|
11 | 11 |
|
12 | 12 | val dataSet = MNISTLoader.load(Paths.get(tempdir.toString()), MNISTLoader.MNIST) |
13 | 13 |
|
14 | | - val dtf_cifar_data = dtfdata.tf_dataset( |
| 14 | + val dtf_mnist_data = dtfdata.tf_dataset( |
15 | 15 | dtfdata.supervised_dataset( |
16 | 16 | dataSet.trainImages.unstack(axis = 0), |
17 | 17 | dataSet.trainLabels.castTo[Long].unstack(axis = -1)), |
|
47 | 47 | loss |
48 | 48 | ) |
49 | 49 |
|
50 | | - val data_ops = dtflearn.model.data_ops( |
| 50 | + val data_ops = dtflearn.model.data_ops[(Output[UByte], Output[Long])]( |
51 | 51 | shuffleBuffer = 5000, |
52 | 52 | batchSize = 128, |
53 | | - prefetchSize = 10, |
54 | | - concatOpI = Some(dtfpipe.EagerStack[UByte]()), |
55 | | - concatOpT = Some(dtfpipe.EagerStack[Long]()), |
56 | | - concatOpO = Some(dtfpipe.EagerConcatenate[Float]()) |
| 53 | + prefetchSize = 10 |
57 | 54 | ) |
58 | 55 |
|
59 | | - val config = dtflearn.model.trainConfig( |
| 56 | + val train_config = dtflearn.model.trainConfig( |
60 | 57 | tempdir/"mnist_summaries", |
61 | 58 | data_ops, |
62 | 59 | optimizer, |
|
69 | 66 | checkPointFreq = 100) |
70 | 67 | )) |
71 | 68 |
|
| 69 | + val pattern_to_tensor = |
| 70 | + DataPipe[Seq[(Tensor[UByte], Tensor[Long])], (Tensor[UByte], Tensor[Long])]( |
| 71 | + ds => { |
| 72 | + val (xs, ys) = ds.unzip |
| 73 | + |
| 74 | + ( |
| 75 | + dtfpipe.EagerStack[UByte](axis = 0).run(xs), |
| 76 | + dtfpipe.EagerStack[Long](axis = 0).run(ys) |
| 77 | + ) |
| 78 | + } |
| 79 | + ) |
| 80 | + |
| 81 | + val data_handle_ops = dtflearn.model.tf_data_handle_ops[ |
| 82 | + (Tensor[UByte], Tensor[Long]), |
| 83 | + (Tensor[UByte], Tensor[Long]), |
| 84 | + Tensor[Float], |
| 85 | + (Output[UByte], Output[Long]) |
| 86 | + ]( |
| 87 | + bufferSize = 500, |
| 88 | + patternToTensor = Some(pattern_to_tensor), |
| 89 | + concatOpO = Some(dtfpipe.EagerConcatenate[Float]()) |
| 90 | + ) |
72 | 91 |
|
73 | | - |
74 | | - mnist_model.train(dtf_cifar_data.training_dataset, config) |
| 92 | + val data_handle_ops_infer = |
| 93 | + dtflearn.model.tf_data_handle_ops[Tensor[UByte], Tensor[UByte], Tensor[ |
| 94 | + Float |
| 95 | + ], Output[UByte]]( |
| 96 | + bufferSize = 1000, |
| 97 | + patternToTensor = Some(dtfpipe.EagerStack[UByte](axis = 0)), |
| 98 | + concatOpO = Some(dtfpipe.EagerConcatenate[Float]()) |
| 99 | + ) |
| 100 | + |
| 101 | + mnist_model.train( |
| 102 | + dtf_mnist_data.training_dataset, |
| 103 | + train_config, |
| 104 | + data_handle_ops |
| 105 | + ) |
75 | 106 |
|
76 | 107 | def accuracy(predictions: Tensor[Long], labels: Tensor[Long]): Float = |
77 | 108 | tfi.equal(predictions.argmax[Long](1), labels) |
|
81 | 112 | .asInstanceOf[Float] |
82 | 113 |
|
83 | 114 | val (trainingPreds, testPreds): (Tensor[Float], Tensor[Float]) = ( |
84 | | - mnist_model.infer_batch(dtf_cifar_data.training_dataset.map(p => p._1), data_ops).left.get, |
85 | | - mnist_model.infer_batch(dtf_cifar_data.test_dataset.map(p => p._1), data_ops).left.get |
| 115 | + mnist_model |
| 116 | + .infer_batch( |
| 117 | + dtf_mnist_data.training_dataset.map(p => p._1), |
| 118 | + data_handle_ops_infer |
| 119 | + ) |
| 120 | + .left |
| 121 | + .get, |
| 122 | + mnist_model |
| 123 | + .infer_batch( |
| 124 | + dtf_mnist_data.test_dataset.map(p => p._1), |
| 125 | + data_handle_ops_infer |
| 126 | + ) |
| 127 | + .left |
| 128 | + .get |
86 | 129 | ) |
87 | 130 |
|
88 | 131 | val (trainAccuracy, testAccuracy) = ( |
89 | 132 | accuracy(trainingPreds.castTo[Long], dataSet.trainLabels.castTo[Long]), |
90 | | - accuracy(testPreds.castTo[Long], dataSet.testLabels.castTo[Long])) |
| 133 | + accuracy(testPreds.castTo[Long], dataSet.testLabels.castTo[Long]) |
| 134 | + ) |
91 | 135 |
|
92 | 136 | print("Train accuracy = ") |
93 | 137 | pprint.pprintln(trainAccuracy) |
|
0 commit comments