Commit 834b931

Merge pull request tensorflow#2 from tensorflow/mnist
Add Mnist examples (Simple and CNN)
2 parents: 0ef05d4 + 171114e

File tree

11 files changed: +741 -0 lines changed

tensorflow-examples/pom.xml

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
<project>
  <modelVersion>4.0.0</modelVersion>
  <groupId>org.tensorflow.model</groupId>
  <artifactId>tensorflow-examples</artifactId>
  <version>0.1.0-SNAPSHOT</version>

  <name>TensorFlow Examples</name>
  <description>A suite of executable examples using TensorFlow Java</description>

  <properties>
    <!-- The sample code requires at least JDK 1.8. -->
    <!-- The maven compiler plugin defaults to a lower version -->
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
  </properties>
  <dependencies>
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow-core-platform</artifactId>
      <version>0.1.0-SNAPSHOT</version>
    </dependency>
    <dependency>
      <groupId>org.tensorflow</groupId>
      <artifactId>tensorflow-training</artifactId>
      <version>0.1.0-SNAPSHOT</version>
    </dependency>
  </dependencies>
  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-assembly-plugin</artifactId>
        <executions>
          <execution>
            <phase>package</phase>
            <goals>
              <goal>single</goal>
            </goals>
            <configuration>
              <archive>
                <manifest>
                  <mainClass>
                    org.tensorflow.model.examples.mnist.SimpleMnist
                  </mainClass>
                </manifest>
              </archive>
              <descriptorRefs>
                <descriptorRef>jar-with-dependencies</descriptorRef>
              </descriptorRefs>
            </configuration>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>
</project>
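
Since the assembly plugin binds its single goal to the package phase and sets SimpleMnist as the manifest main class, a plain Maven build should produce a runnable fat jar. A minimal sketch, assuming the 0.1.0-SNAPSHOT TensorFlow artifacts can be resolved from a local build or a snapshot repository:

mvn package
java -jar target/tensorflow-examples-0.1.0-SNAPSHOT-jar-with-dependencies.jar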
tensorflow-examples/src/main/java/org/tensorflow/model/examples/mnist/CnnMnist.java

Lines changed: 313 additions & 0 deletions
@@ -0,0 +1,313 @@
/*
 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tensorflow.model.examples.mnist;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.model.examples.mnist.data.ImageBatch;
import org.tensorflow.model.examples.mnist.data.MnistDataset;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.OneHot;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.tools.Shape;
import org.tensorflow.tools.ndarray.ByteNdArray;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.index.Indices;
import org.tensorflow.training.optimizers.AdaDelta;
import org.tensorflow.training.optimizers.AdaGrad;
import org.tensorflow.training.optimizers.AdaGradDA;
import org.tensorflow.training.optimizers.Adam;
import org.tensorflow.training.optimizers.GradientDescent;
import org.tensorflow.training.optimizers.Momentum;
import org.tensorflow.training.optimizers.Optimizer;
import org.tensorflow.training.optimizers.RMSProp;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TUint8;

/**
 * Builds a LeNet-5 style CNN for MNIST.
 */
public class CnnMnist {

  private static final Logger logger = Logger.getLogger(CnnMnist.class.getName());

  private static final int PIXEL_DEPTH = 255;
  private static final int NUM_CHANNELS = 1;
  private static final int IMAGE_SIZE = 28;
  private static final int NUM_LABELS = MnistDataset.NUM_CLASSES;
  private static final long SEED = 123456789L;

  private static final String PADDING_TYPE = "SAME";

  public static final String INPUT_NAME = "input";
  public static final String OUTPUT_NAME = "output";
  public static final String TARGET = "target";
  public static final String TRAIN = "train";
  public static final String TRAINING_LOSS = "training_loss";
  public static final String INIT = "init";

  public static Graph build(String optimizerName) {
    Graph graph = new Graph();

    Ops tf = Ops.create(graph);

    // Inputs
    Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
        Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
    Reshape<TUint8> inputReshaped = tf
        .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
    Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);

    // Scaling the features
    Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
    Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
    Operand<TFloat32> scaledInput = tf.math
        .div(tf.math.sub(tf.dtypes.cast(inputReshaped, TFloat32.DTYPE), centeringFactor),
            scalingFactor);

    // First conv layer
    Variable<TFloat32> conv1Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Conv2d<TFloat32> conv1 = tf.nn
        .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
    Variable<TFloat32> conv1Biases = tf
        .variable(tf.fill(tf.array(new int[]{32}), tf.constant(0.0f)));
    Relu<TFloat32> relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases));

    // First pooling layer
    MaxPool<TFloat32> pool1 = tf.nn
        .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), PADDING_TYPE);

    // Second conv layer
    Variable<TFloat32> conv2Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Conv2d<TFloat32> conv2 = tf.nn
        .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
    Variable<TFloat32> conv2Biases = tf
        .variable(tf.fill(tf.array(new int[]{64}), tf.constant(0.1f)));
    Relu<TFloat32> relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases));

    // Second pooling layer
    MaxPool<TFloat32> pool2 = tf.nn
        .maxPool(relu2, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1), PADDING_TYPE);

    // Flatten inputs
    Reshape<TFloat32> flatten = tf.reshape(pool2, tf.concat(Arrays
        .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})),
            tf.array(new int[]{-1})), tf.constant(0)));

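    // Shape bookkeeping (derived from the ops above): with SAME padding, each
    // 2x2/stride-2 max-pool halves the spatial dims, so a 28x28x1 image becomes
    // 28x28x32 after conv1, 14x14x32 after pool1, 14x14x64 after conv2, and
    // 7x7x64 after pool2. Each flattened row therefore has 7 * 7 * 64 = 3136
    // = IMAGE_SIZE * IMAGE_SIZE * 4 features, matching the fc1Weights shape below.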
    // Fully connected layer
    Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Variable<TFloat32> fc1Biases = tf
        .variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f)));
    Relu<TFloat32> relu3 = tf.nn
        .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));

    // Softmax layer
    Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Variable<TFloat32> fc2Biases = tf
        .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));

    Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases);

    // Predicted outputs
    Softmax<TFloat32> prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits);

    // Loss function & regularization
    OneHot<TFloat32> oneHot = tf
        .oneHot(labels, tf.constant(NUM_LABELS), tf.constant(1.0f), tf.constant(0.0f));
    SoftmaxCrossEntropyWithLogits<TFloat32> batchLoss = tf.nn
        .softmaxCrossEntropyWithLogits(logits, oneHot);
    Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
    Add<TFloat32> regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
        .add(tf.nn.l2Loss(fc1Biases),
            tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases))));
    Add<TFloat32> loss = tf.withName(TRAINING_LOSS).math
        .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f)));

    // Optimizer
    String lcOptimizerName = optimizerName.toLowerCase();
    Optimizer optimizer;
    switch (lcOptimizerName) {
      case "adadelta":
        optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f);
        break;
      case "adagradda":
        optimizer = new AdaGradDA(graph, 0.01f);
        break;
      case "adagrad":
        optimizer = new AdaGrad(graph, 0.01f);
        break;
      case "adam":
        optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f);
        break;
      case "sgd":
        optimizer = new GradientDescent(graph, 0.01f);
        break;
      case "momentum":
        optimizer = new Momentum(graph, 0.01f, 0.9f, false);
        break;
      case "rmsprop":
        optimizer = new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false);
        break;
      default:
        throw new IllegalArgumentException("Unknown optimizer " + optimizerName);
    }
    logger.info("Optimizer = " + optimizer.toString());

    // Registers the training op in the graph under the name TRAIN.
    optimizer.minimize(loss, TRAIN);

    tf.init();

    return graph;
  }

  public static void train(Session session, int epochs, int minibatchSize, MnistDataset dataset) {
    // Initialises the parameters.
    session.runner().addTarget(INIT).run();
    logger.info("Initialised the model parameters");

    int interval = 0;
    // Train the model
    for (int i = 0; i < epochs; i++) {
      for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
        // The try-with-resources block releases the native tensor memory once
        // each training step completes.
        try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
            Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
            Tensor<TFloat32> loss = session.runner()
                .feed(TARGET, batchLabels)
                .feed(INPUT_NAME, batchImages)
                .addTarget(TRAIN)
                .fetch(TRAINING_LOSS)
                .run().get(0).expect(TFloat32.DTYPE)) {
          if (interval % 100 == 0) {
            logger.log(Level.INFO,
                "Iteration = " + interval + ", training loss = " + loss.data().getFloat());
          }
        }
        interval++;
      }
    }
  }

  public static void test(Session session, int minibatchSize, MnistDataset dataset) {
    int correctCount = 0;
    int[][] confusionMatrix = new int[10][10];

    for (ImageBatch testBatch : dataset.testBatches(minibatchSize)) {
      try (Tensor<TUint8> transformedInput = TUint8.tensorOf(testBatch.images());
          Tensor<TFloat32> outputTensor = session.runner()
              .feed(INPUT_NAME, transformedInput)
              .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {

        ByteNdArray labelBatch = testBatch.labels();
        for (int k = 0; k < labelBatch.shape().size(0); k++) {
          byte trueLabel = labelBatch.getByte(k);
          int predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
          if (predLabel == trueLabel) {
            correctCount++;
          }
          confusionMatrix[trueLabel][predLabel]++;
        }
      }
    }

    logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples());

    StringBuilder sb = new StringBuilder();
    sb.append("Label");
    for (int i = 0; i < confusionMatrix.length; i++) {
      sb.append(String.format("%1$5s", "" + i));
    }
    sb.append("\n");

    for (int i = 0; i < confusionMatrix.length; i++) {
      sb.append(String.format("%1$5s", "" + i));
      for (int j = 0; j < confusionMatrix[i].length; j++) {
        sb.append(String.format("%1$5s", "" + confusionMatrix[i][j]));
      }
      sb.append("\n");
    }

    System.out.println(sb.toString());
  }

  /**
   * Finds the maximum probability and returns its index.
   *
   * @param probabilities The probabilities.
   * @return The index of the max.
   */
  public static int argmax(FloatNdArray probabilities) {
    float maxVal = Float.NEGATIVE_INFINITY;
    int idx = 0;
    for (int i = 0; i < probabilities.shape().size(0); i++) {
      float curVal = probabilities.getFloat(i);
      if (curVal > maxVal) {
        maxVal = curVal;
        idx = i;
      }
    }
    return idx;
  }

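  // For example, argmax over the vector [0.1f, 0.7f, 0.2f] returns 1. (A
  // hypothetical check, assuming the NdArrays factory from
  // org.tensorflow.tools.ndarray: argmax(NdArrays.vectorOf(0.1f, 0.7f, 0.2f)) == 1.)
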
  public static void main(String[] args) {
    if (args.length != 3) {
      logger.info("Usage: CnnMnist <num-epochs> <minibatch-size> <optimizer-name>");
      return;
    }

    MnistDataset dataset = MnistDataset.create(0);

    logger.info("Loaded data.");

    int epochs = Integer.parseInt(args[0]);
    int minibatchSize = Integer.parseInt(args[1]);

    try (Graph graph = build(args[2]);
        Session session = new Session(graph)) {
      train(session, epochs, minibatchSize, dataset);

      logger.info("Trained model");

      test(session, minibatchSize, dataset);
    }
  }
}
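
To run this CNN example instead of the default SimpleMnist main class from the pom, the class can be invoked explicitly with the three arguments main expects: epoch count, minibatch size, and one of the optimizer names from the switch above. A hypothetical invocation, assuming the fat jar sketched earlier (the argument values are illustrative):

java -cp target/tensorflow-examples-0.1.0-SNAPSHOT-jar-with-dependencies.jar \
    org.tensorflow.model.examples.mnist.CnnMnist 2 50 adam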
