Skip to content

Commit 77fe494

Browse files
committed
Updated mnist script.
1 parent 5d27d08 commit 77fe494

File tree

1 file changed

+56
-12
lines changed

1 file changed

+56
-12
lines changed

scripts/mnist.sc

Lines changed: 56 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
val dataSet = MNISTLoader.load(Paths.get(tempdir.toString()), MNISTLoader.MNIST)
1313

14-
val dtf_cifar_data = dtfdata.tf_dataset(
14+
val dtf_mnist_data = dtfdata.tf_dataset(
1515
dtfdata.supervised_dataset(
1616
dataSet.trainImages.unstack(axis = 0),
1717
dataSet.trainLabels.castTo[Long].unstack(axis = -1)),
@@ -47,16 +47,13 @@
4747
loss
4848
)
4949

50-
val data_ops = dtflearn.model.data_ops(
50+
val data_ops = dtflearn.model.data_ops[(Output[UByte], Output[Long])](
5151
shuffleBuffer = 5000,
5252
batchSize = 128,
53-
prefetchSize = 10,
54-
concatOpI = Some(dtfpipe.EagerStack[UByte]()),
55-
concatOpT = Some(dtfpipe.EagerStack[Long]()),
56-
concatOpO = Some(dtfpipe.EagerConcatenate[Float]())
53+
prefetchSize = 10
5754
)
5855

59-
val config = dtflearn.model.trainConfig(
56+
val train_config = dtflearn.model.trainConfig(
6057
tempdir/"mnist_summaries",
6158
data_ops,
6259
optimizer,
@@ -69,9 +66,43 @@
6966
checkPointFreq = 100)
7067
))
7168

69+
val pattern_to_tensor =
70+
DataPipe[Seq[(Tensor[UByte], Tensor[Long])], (Tensor[UByte], Tensor[Long])](
71+
ds => {
72+
val (xs, ys) = ds.unzip
73+
74+
(
75+
dtfpipe.EagerStack[UByte](axis = 0).run(xs),
76+
dtfpipe.EagerStack[Long](axis = 0).run(ys)
77+
)
78+
}
79+
)
80+
81+
val data_handle_ops = dtflearn.model.tf_data_handle_ops[
82+
(Tensor[UByte], Tensor[Long]),
83+
(Tensor[UByte], Tensor[Long]),
84+
Tensor[Float],
85+
(Output[UByte], Output[Long])
86+
](
87+
bufferSize = 500,
88+
patternToTensor = Some(pattern_to_tensor),
89+
concatOpO = Some(dtfpipe.EagerConcatenate[Float]())
90+
)
7291

73-
74-
mnist_model.train(dtf_cifar_data.training_dataset, config)
92+
val data_handle_ops_infer =
93+
dtflearn.model.tf_data_handle_ops[Tensor[UByte], Tensor[UByte], Tensor[
94+
Float
95+
], Output[UByte]](
96+
bufferSize = 1000,
97+
patternToTensor = Some(dtfpipe.EagerStack[UByte](axis = 0)),
98+
concatOpO = Some(dtfpipe.EagerConcatenate[Float]())
99+
)
100+
101+
mnist_model.train(
102+
dtf_mnist_data.training_dataset,
103+
train_config,
104+
data_handle_ops
105+
)
75106

76107
def accuracy(predictions: Tensor[Long], labels: Tensor[Long]): Float =
77108
tfi.equal(predictions.argmax[Long](1), labels)
@@ -81,13 +112,26 @@
81112
.asInstanceOf[Float]
82113

83114
val (trainingPreds, testPreds): (Tensor[Float], Tensor[Float]) = (
84-
mnist_model.infer_batch(dtf_cifar_data.training_dataset.map(p => p._1), data_ops).left.get,
85-
mnist_model.infer_batch(dtf_cifar_data.test_dataset.map(p => p._1), data_ops).left.get
115+
mnist_model
116+
.infer_batch(
117+
dtf_mnist_data.training_dataset.map(p => p._1),
118+
data_handle_ops_infer
119+
)
120+
.left
121+
.get,
122+
mnist_model
123+
.infer_batch(
124+
dtf_mnist_data.test_dataset.map(p => p._1),
125+
data_handle_ops_infer
126+
)
127+
.left
128+
.get
86129
)
87130

88131
val (trainAccuracy, testAccuracy) = (
89132
accuracy(trainingPreds.castTo[Long], dataSet.trainLabels.castTo[Long]),
90-
accuracy(testPreds.castTo[Long], dataSet.testLabels.castTo[Long]))
133+
accuracy(testPreds.castTo[Long], dataSet.testLabels.castTo[Long])
134+
)
91135

92136
print("Train accuracy = ")
93137
pprint.pprintln(trainAccuracy)

0 commit comments

Comments
 (0)