Skip to content

Commit 887c350

Browse files
committed
Add loading vanilla Tensorflow/Keras models
Add vanilla example Add logic to validate we have a genuine TFKG model in model.LoadMoodel Fix bug where inaccurate dataset len or an error were causing the callback.EventEnd not to fire on train/val/test
1 parent 1225959 commit 887c350

File tree

13 files changed

+977
-53
lines changed

13 files changed

+977
-53
lines changed

Makefile

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,20 @@ examples-transfer-raw:
101101
go generate ./...
102102
cd examples/transfer_learning && go run main.go
103103

104+
examples-vanilla:
105+
go generate ./...
106+
docker-compose up -d tf-jupyter-golang
107+
docker-compose exec tf-jupyter-golang sh -c "cd /go/src/tfkg/examples/vanilla && python generate_vanilla_model.py && go run main.go"
108+
109+
examples-vanilla-gpu:
110+
go generate ./...
111+
docker-compose up -d tf-jupyter-golang-gpu
112+
docker-compose exec tf-jupyter-golang-gpu sh -c "cd /go/src/tfkg/examples/vanilla && go run main.go"
113+
114+
examples-vanilla-raw:
115+
go generate ./...
116+
cd examples/vanilla && go run main.go
117+
104118
test-python:
105119
docker-compose up -d tf-jupyter-golang
106120
docker-compose exec tf-jupyter-golang sh -c "cd /go/src/tfkg && python test.py"

examples/iris/main.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,11 @@ func main() {
168168
},
169169
)
170170

171+
m, e = model.LoadModel(errorHandler, logger, saveDir)
172+
if e != nil {
173+
return
174+
}
175+
171176
logger.InfoF("main", "Finished training")
172177

173178
// Create an inference provider, with a processor which will accept our input of [][]float32 and turn it into a tensor

examples/vanilla/data/iris.data

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
5.1,3.5,1.4,0.2,0
2+
4.9,3.0,1.4,0.2,0
3+
4.7,3.2,1.3,0.2,0
4+
4.6,3.1,1.5,0.2,0
5+
5.0,3.6,1.4,0.2,0
6+
5.4,3.9,1.7,0.4,0
7+
4.6,3.4,1.4,0.3,0
8+
5.0,3.4,1.5,0.2,0
9+
4.4,2.9,1.4,0.2,0
10+
4.9,3.1,1.5,0.1,0
11+
5.4,3.7,1.5,0.2,0
12+
4.8,3.4,1.6,0.2,0
13+
4.8,3.0,1.4,0.1,0
14+
4.3,3.0,1.1,0.1,0
15+
5.8,4.0,1.2,0.2,0
16+
5.7,4.4,1.5,0.4,0
17+
5.4,3.9,1.3,0.4,0
18+
5.1,3.5,1.4,0.3,0
19+
5.7,3.8,1.7,0.3,0
20+
5.1,3.8,1.5,0.3,0
21+
5.4,3.4,1.7,0.2,0
22+
5.1,3.7,1.5,0.4,0
23+
4.6,3.6,1.0,0.2,0
24+
5.1,3.3,1.7,0.5,0
25+
4.8,3.4,1.9,0.2,0
26+
5.0,3.0,1.6,0.2,0
27+
5.0,3.4,1.6,0.4,0
28+
5.2,3.5,1.5,0.2,0
29+
5.2,3.4,1.4,0.2,0
30+
4.7,3.2,1.6,0.2,0
31+
4.8,3.1,1.6,0.2,0
32+
5.4,3.4,1.5,0.4,0
33+
5.2,4.1,1.5,0.1,0
34+
5.5,4.2,1.4,0.2,0
35+
4.9,3.1,1.5,0.1,0
36+
5.0,3.2,1.2,0.2,0
37+
5.5,3.5,1.3,0.2,0
38+
4.9,3.1,1.5,0.1,0
39+
4.4,3.0,1.3,0.2,0
40+
5.1,3.4,1.5,0.2,0
41+
5.0,3.5,1.3,0.3,0
42+
4.5,2.3,1.3,0.3,0
43+
4.4,3.2,1.3,0.2,0
44+
5.0,3.5,1.6,0.6,0
45+
5.1,3.8,1.9,0.4,0
46+
4.8,3.0,1.4,0.3,0
47+
5.1,3.8,1.6,0.2,0
48+
4.6,3.2,1.4,0.2,0
49+
5.3,3.7,1.5,0.2,0
50+
5.0,3.3,1.4,0.2,0
51+
7.0,3.2,4.7,1.4,1
52+
6.4,3.2,4.5,1.5,1
53+
6.9,3.1,4.9,1.5,1
54+
5.5,2.3,4.0,1.3,1
55+
6.5,2.8,4.6,1.5,1
56+
5.7,2.8,4.5,1.3,1
57+
6.3,3.3,4.7,1.6,1
58+
4.9,2.4,3.3,1.0,1
59+
6.6,2.9,4.6,1.3,1
60+
5.2,2.7,3.9,1.4,1
61+
5.0,2.0,3.5,1.0,1
62+
5.9,3.0,4.2,1.5,1
63+
6.0,2.2,4.0,1.0,1
64+
6.1,2.9,4.7,1.4,1
65+
5.6,2.9,3.6,1.3,1
66+
6.7,3.1,4.4,1.4,1
67+
5.6,3.0,4.5,1.5,1
68+
5.8,2.7,4.1,1.0,1
69+
6.2,2.2,4.5,1.5,1
70+
5.6,2.5,3.9,1.1,1
71+
5.9,3.2,4.8,1.8,1
72+
6.1,2.8,4.0,1.3,1
73+
6.3,2.5,4.9,1.5,1
74+
6.1,2.8,4.7,1.2,1
75+
6.4,2.9,4.3,1.3,1
76+
6.6,3.0,4.4,1.4,1
77+
6.8,2.8,4.8,1.4,1
78+
6.7,3.0,5.0,1.7,1
79+
6.0,2.9,4.5,1.5,1
80+
5.7,2.6,3.5,1.0,1
81+
5.5,2.4,3.8,1.1,1
82+
5.5,2.4,3.7,1.0,1
83+
5.8,2.7,3.9,1.2,1
84+
6.0,2.7,5.1,1.6,1
85+
5.4,3.0,4.5,1.5,1
86+
6.0,3.4,4.5,1.6,1
87+
6.7,3.1,4.7,1.5,1
88+
6.3,2.3,4.4,1.3,1
89+
5.6,3.0,4.1,1.3,1
90+
5.5,2.5,4.0,1.3,1
91+
5.5,2.6,4.4,1.2,1
92+
6.1,3.0,4.6,1.4,1
93+
5.8,2.6,4.0,1.2,1
94+
5.0,2.3,3.3,1.0,1
95+
5.6,2.7,4.2,1.3,1
96+
5.7,3.0,4.2,1.2,1
97+
5.7,2.9,4.2,1.3,1
98+
6.2,2.9,4.3,1.3,1
99+
5.1,2.5,3.0,1.1,1
100+
5.7,2.8,4.1,1.3,1
101+
6.3,3.3,6.0,2.5,2
102+
5.8,2.7,5.1,1.9,2
103+
7.1,3.0,5.9,2.1,2
104+
6.3,2.9,5.6,1.8,2
105+
6.5,3.0,5.8,2.2,2
106+
7.6,3.0,6.6,2.1,2
107+
4.9,2.5,4.5,1.7,2
108+
7.3,2.9,6.3,1.8,2
109+
6.7,2.5,5.8,1.8,2
110+
7.2,3.6,6.1,2.5,2
111+
6.5,3.2,5.1,2.0,2
112+
6.4,2.7,5.3,1.9,2
113+
6.8,3.0,5.5,2.1,2
114+
5.7,2.5,5.0,2.0,2
115+
5.8,2.8,5.1,2.4,2
116+
6.4,3.2,5.3,2.3,2
117+
6.5,3.0,5.5,1.8,2
118+
7.7,3.8,6.7,2.2,2
119+
7.7,2.6,6.9,2.3,2
120+
6.0,2.2,5.0,1.5,2
121+
6.9,3.2,5.7,2.3,2
122+
5.6,2.8,4.9,2.0,2
123+
7.7,2.8,6.7,2.0,2
124+
6.3,2.7,4.9,1.8,2
125+
6.7,3.3,5.7,2.1,2
126+
7.2,3.2,6.0,1.8,2
127+
6.2,2.8,4.8,1.8,2
128+
6.1,3.0,4.9,1.8,2
129+
6.4,2.8,5.6,2.1,2
130+
7.2,3.0,5.8,1.6,2
131+
7.4,2.8,6.1,1.9,2
132+
7.9,3.8,6.4,2.0,2
133+
6.4,2.8,5.6,2.2,2
134+
6.3,2.8,5.1,1.5,2
135+
6.1,2.6,5.6,1.4,2
136+
7.7,3.0,6.1,2.3,2
137+
6.3,3.4,5.6,2.4,2
138+
6.4,3.1,5.5,1.8,2
139+
6.0,3.0,4.8,1.8,2
140+
6.9,3.1,5.4,2.1,2
141+
6.7,3.1,5.6,2.4,2
142+
6.9,3.1,5.1,2.3,2
143+
5.8,2.7,5.1,1.9,2
144+
6.8,3.2,5.9,2.3,2
145+
6.7,3.3,5.7,2.5,2
146+
6.7,3.0,5.2,2.3,2
147+
6.3,2.5,5.0,1.9,2
148+
6.5,3.0,5.2,2.0,2
149+
6.2,3.4,5.4,2.3,2
150+
5.9,3.0,5.1,1.8,2
151+

examples/vanilla/data/iris.names

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
1. Title: Iris Plants Database
2+
Updated Sept 21 by C.Blake - Added discrepency information
3+
4+
2. Sources:
5+
(a) Creator: R.A. Fisher
6+
(b) Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
7+
(c) Date: July, 1988
8+
9+
3. Past Usage:
10+
- Publications: too many to mention!!! Here are a few.
11+
1. Fisher,R.A. "The use of multiple measurements in taxonomic problems"
12+
Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions
13+
to Mathematical Statistics" (John Wiley, NY, 1950).
14+
2. Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.
15+
(Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.
16+
3. Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
17+
Structure and Classification Rule for Recognition in Partially Exposed
18+
Environments". IEEE Transactions on Pattern Analysis and Machine
19+
Intelligence, Vol. PAMI-2, No. 1, 67-71.
20+
-- Results:
21+
-- very low misclassification rates (0% for the setosa class)
22+
4. Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule". IEEE
23+
Transactions on Information Theory, May 1972, 431-433.
24+
-- Results:
25+
-- very low misclassification rates again
26+
5. See also: 1988 MLC Proceedings, 54-64. Cheeseman et al's AUTOCLASS II
27+
conceptual clustering system finds 3 classes in the data.
28+
29+
4. Relevant Information:
30+
--- This is perhaps the best known database to be found in the pattern
31+
recognition literature. Fisher's paper is a classic in the field
32+
and is referenced frequently to this day. (See Duda & Hart, for
33+
example.) The data set contains 3 classes of 50 instances each,
34+
where each class refers to a type of iris plant. One class is
35+
linearly separable from the other 2; the latter are NOT linearly
36+
separable from each other.
37+
--- Predicted attribute: class of iris plant.
38+
--- This is an exceedingly simple domain.
39+
--- This data differs from the data presented in Fishers article
40+
(identified by Steve Chadwick, spchadwick@espeedaz.net )
41+
The 35th sample should be: 4.9,3.1,1.5,0.2,"Iris-setosa"
42+
where the error is in the fourth feature.
43+
The 38th sample: 4.9,3.6,1.4,0.1,"Iris-setosa"
44+
where the errors are in the second and third features.
45+
46+
5. Number of Instances: 150 (50 in each of three classes)
47+
48+
6. Number of Attributes: 4 numeric, predictive attributes and the class
49+
50+
7. Attribute Information:
51+
1. sepal length in cm
52+
2. sepal width in cm
53+
3. petal length in cm
54+
4. petal width in cm
55+
5. class:
56+
-- Iris Setosa
57+
-- Iris Versicolour
58+
-- Iris Virginica
59+
60+
8. Missing Attribute Values: None
61+
62+
Summary Statistics:
63+
Min Max Mean SD Class Correlation
64+
sepal length: 4.3 7.9 5.84 0.83 0.7826
65+
sepal width: 2.0 4.4 3.05 0.43 -0.4194
66+
petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)
67+
petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)
68+
69+
9. Class Distribution: 33.3% for each of 3 classes.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import os
2+
import tensorflow as tf
3+
import numpy as np
4+
5+
csv_data = np.loadtxt('data/iris.data', delimiter=',')
6+
target_all = csv_data[:, -1]
7+
8+
csv_data = csv_data[:, 0:-1]
9+
10+
shuffled_indices = np.arange(csv_data.shape[0])
11+
np.random.shuffle(shuffled_indices)
12+
13+
shuffled_inputs = csv_data[shuffled_indices]
14+
shuffled_targets = target_all[shuffled_indices]
15+
16+
train_inputs = shuffled_inputs
17+
train_targets = shuffled_targets
18+
19+
model = tf.keras.Sequential([
20+
tf.keras.layers.Input(shape=(4,), dtype=tf.float32),
21+
tf.keras.layers.Dense(10, activation="swish"),
22+
tf.keras.layers.Dense(10, activation="swish"),
23+
tf.keras.layers.Dense(3, activation="softmax")
24+
])
25+
26+
model.compile(
27+
loss="sparse_categorical_crossentropy",
28+
optimizer="adam",
29+
metrics=['accuracy']
30+
)
31+
32+
model.fit(
33+
train_inputs,
34+
train_targets,
35+
batch_size=3,
36+
epochs=10,
37+
verbose=0,
38+
)
39+
40+
model.evaluate(
41+
train_inputs,
42+
train_targets,
43+
batch_size=3,
44+
verbose=1,
45+
)
46+
47+
model.save("vanilla_model")

0 commit comments

Comments
 (0)