Skip to content

Commit 3bbb4e5

Browse files
committed
First draft for a MNIST example
1 parent 0ef05d4 commit 3bbb4e5

File tree

10 files changed

+421
-0
lines changed

10 files changed

+421
-0
lines changed

tensorflow-examples/pom.xml

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
<project>
2+
<modelVersion>4.0.0</modelVersion>
3+
<groupId>org.tensorflow.model</groupId>
4+
<artifactId>tensorflow-examples</artifactId>
5+
<version>0.1.0-SNAPSHOT</version>
6+
7+
<name>TensorFlow Examples</name>
8+
<description>A suite of executable examples using TensorFlow Java</description>
9+
10+
<properties>
11+
<!-- The sample code requires at least JDK 1.7. -->
12+
<!-- The maven compiler plugin defaults to a lower version -->
13+
<maven.compiler.source>1.8</maven.compiler.source>
14+
<maven.compiler.target>1.8</maven.compiler.target>
15+
</properties>
16+
<dependencies>
17+
<dependency>
18+
<groupId>org.tensorflow</groupId>
19+
<artifactId>tensorflow-core-api</artifactId>
20+
<version>0.1.0-SNAPSHOT</version>
21+
</dependency>
22+
<dependency>
23+
<groupId>org.tensorflow</groupId>
24+
<artifactId>tensorflow-core-api</artifactId>
25+
<version>0.1.0-SNAPSHOT</version>
26+
<classifier>macosx-x86_64</classifier>
27+
</dependency>
28+
<dependency>
29+
<groupId>org.tensorflow</groupId>
30+
<artifactId>proto</artifactId>
31+
<version>1.15.0</version>
32+
</dependency>
33+
</dependencies>
34+
<build>
35+
<plugins>
36+
<plugin>
37+
<groupId>org.apache.maven.plugins</groupId>
38+
<artifactId>maven-assembly-plugin</artifactId>
39+
<executions>
40+
<execution>
41+
<phase>package</phase>
42+
<goals>
43+
<goal>single</goal>
44+
</goals>
45+
<configuration>
46+
<archive>
47+
<manifest>
48+
<mainClass>
49+
org.tensorflow.model.examples.mnist.SimpleMnist
50+
</mainClass>
51+
</manifest>
52+
</archive>
53+
<descriptorRefs>
54+
<descriptorRef>jar-with-dependencies</descriptorRef>
55+
</descriptorRefs>
56+
</configuration>
57+
</execution>
58+
</executions>
59+
</plugin>
60+
</plugins>
61+
</build>
62+
</project>
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package org.tensorflow.model.examples.mnist;
2+
3+
import java.util.Arrays;
4+
import org.tensorflow.Graph;
5+
import org.tensorflow.Operand;
6+
import org.tensorflow.Session;
7+
import org.tensorflow.Tensor;
8+
import org.tensorflow.model.examples.mnist.data.ImageBatch;
9+
import org.tensorflow.model.examples.mnist.data.MnistDataset;
10+
import org.tensorflow.op.Ops;
11+
import org.tensorflow.op.core.Assign;
12+
import org.tensorflow.op.core.Constant;
13+
import org.tensorflow.op.core.Gradients;
14+
import org.tensorflow.op.core.Placeholder;
15+
import org.tensorflow.op.core.Variable;
16+
import org.tensorflow.op.math.Mean;
17+
import org.tensorflow.op.nn.Softmax;
18+
import org.tensorflow.op.train.ApplyGradientDescent;
19+
import org.tensorflow.tools.Shape;
20+
import org.tensorflow.tools.ndarray.ByteNdArray;
21+
import org.tensorflow.types.TFloat32;
22+
import org.tensorflow.types.TInt64;
23+
24+
public class SimpleMnist implements Runnable {
25+
26+
public static void main(String[] args) {
27+
MnistDataset dataset = MnistDataset.create(VALIDATION_SIZE);
28+
try (Graph graph = new Graph()) {
29+
SimpleMnist mnist = new SimpleMnist(graph, dataset);
30+
mnist.run();
31+
}
32+
}
33+
34+
@Override
35+
public void run() {
36+
Ops tf = Ops.create(graph);
37+
38+
// Create placeholders and variables, which should fit batches of an unknown number of images
39+
Placeholder<TFloat32> images = tf.placeholder(TFloat32.DTYPE);
40+
Placeholder<TFloat32> labels = tf.placeholder(TFloat32.DTYPE);
41+
42+
// Create weights with an initial value of 0
43+
Shape weightShape = Shape.of(dataset.imageSize(), MnistDataset.NUM_CLASSES);
44+
Variable<TFloat32> weights = tf.variable(weightShape, TFloat32.DTYPE);
45+
Assign<TFloat32> weightsInit = tf.assign(weights, tf.zerosLike(weights));
46+
47+
// Create biases with an initial value of 0
48+
Shape biasShape = Shape.of(MnistDataset.NUM_CLASSES);
49+
Variable<TFloat32> biases = tf.variable(biasShape, TFloat32.DTYPE);
50+
Assign<TFloat32> biasesInit = tf.assign(biases, tf.zerosLike(biases));
51+
52+
// Predict the class of each image in the batch and compute the loss
53+
Softmax<TFloat32> softmax =
54+
tf.nn.softmax(
55+
tf.math.add(
56+
tf.linalg.matMul(images, weights),
57+
biases
58+
)
59+
);
60+
Mean<TFloat32> crossEntropy =
61+
tf.math.mean(
62+
tf.math.neg(
63+
tf.reduceSum(
64+
tf.math.mul(labels, tf.math.log(softmax)),
65+
tf.array(1)
66+
)
67+
),
68+
tf.array(0)
69+
);
70+
71+
// Back-propagate gradients to variables for training
72+
Gradients gradients = tf.gradients(crossEntropy, Arrays.asList(weights, biases));
73+
Constant<TFloat32> alpha = tf.val(LEARNING_RATE);
74+
ApplyGradientDescent<TFloat32> weightGradientDescent = tf.train.applyGradientDescent(weights, alpha, gradients.dy(0));
75+
ApplyGradientDescent<TFloat32> biasGradientDescent = tf.train.applyGradientDescent(biases, alpha, gradients.dy(1));
76+
77+
// Compute the accuracy of the model
78+
Operand<TInt64> predicted = tf.math.argMax(softmax, tf.val(1));
79+
Operand<TInt64> expected = tf.math.argMax(labels, tf.val(1));
80+
Operand<TFloat32> accuracy = tf.math.mean(tf.dtypes.cast(tf.math.equal(predicted, expected), TFloat32.DTYPE), tf.array(0));
81+
82+
// Run the graph
83+
try (Session session = new Session(graph)) {
84+
85+
// Initialize variables
86+
session.runner()
87+
.addTarget(weightsInit)
88+
.addTarget(biasesInit)
89+
.run();
90+
91+
// Train the model
92+
for (ImageBatch trainingBatch : dataset.trainingBatches(TRAINING_BATCH_SIZE)) {
93+
try (Tensor<TFloat32> batchImages = preprocessImages(trainingBatch.images());
94+
Tensor<TFloat32> batchLabels = preprocessLabels(trainingBatch.labels())) {
95+
session.runner()
96+
.addTarget(weightGradientDescent)
97+
.addTarget(biasGradientDescent)
98+
.feed(images.asOutput(), batchImages)
99+
.feed(labels.asOutput(), batchLabels)
100+
.run();
101+
}
102+
}
103+
104+
// Test the model
105+
ImageBatch testBatch = dataset.testBatch();
106+
try (Tensor<TFloat32> testImages = preprocessImages(testBatch.images());
107+
Tensor<TFloat32> testLabels = preprocessLabels(testBatch.labels());
108+
Tensor<TFloat32> accuracyValue = session.runner()
109+
.fetch(accuracy)
110+
.feed(images.asOutput(), testImages)
111+
.feed(labels.asOutput(), testLabels)
112+
.run()
113+
.get(0)
114+
.expect(TFloat32.DTYPE)) {
115+
System.out.println("Accuracy: " + accuracyValue.data().getFloat());
116+
}
117+
}
118+
}
119+
120+
private static final int VALIDATION_SIZE = 0;
121+
private static final int TRAINING_BATCH_SIZE = 100;
122+
private static final float LEARNING_RATE = 0.2f;
123+
124+
private static Tensor<TFloat32> preprocessImages(ByteNdArray rawImages) {
125+
Ops tf = Ops.create();
126+
127+
// Flatten images in a single dimension and normalize their pixels as floats.
128+
long imageSize = rawImages.get(0).shape().size();
129+
return tf.math.div(
130+
tf.reshape(
131+
tf.dtypes.cast(tf.val(rawImages), TFloat32.DTYPE),
132+
tf.array(-1L, imageSize)
133+
),
134+
tf.val(255.0f)
135+
).asTensor();
136+
}
137+
138+
private static Tensor<TFloat32> preprocessLabels(ByteNdArray rawLabels) {
139+
Ops tf = Ops.create();
140+
141+
// Map labels to one hot vectors where only the expected predictions as a value of 1.0
142+
return tf.oneHot(
143+
tf.val(rawLabels),
144+
tf.val(MnistDataset.NUM_CLASSES),
145+
tf.val(1.0f),
146+
tf.val(0.0f)
147+
).asTensor();
148+
}
149+
150+
private Graph graph;
151+
private MnistDataset dataset;
152+
153+
private SimpleMnist(Graph graph, MnistDataset dataset) {
154+
this.graph = graph;
155+
this.dataset = dataset;
156+
}
157+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package org.tensorflow.model.examples.mnist.data;
2+
3+
import org.tensorflow.tools.ndarray.ByteNdArray;
4+
5+
public class ImageBatch {
6+
7+
public ByteNdArray images() {
8+
return images;
9+
}
10+
11+
public ByteNdArray labels() {
12+
return labels;
13+
}
14+
15+
ImageBatch(ByteNdArray images, ByteNdArray labels) {
16+
this.images = images;
17+
this.labels = labels;
18+
}
19+
20+
private final ByteNdArray images;
21+
private final ByteNdArray labels;
22+
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =======================================================================
16+
*/
17+
18+
package org.tensorflow.model.examples.mnist.data;
19+
20+
import static org.tensorflow.tools.ndarray.index.Indices.range;
21+
22+
import java.util.Iterator;
23+
import org.tensorflow.tools.ndarray.ByteNdArray;
24+
import org.tensorflow.tools.ndarray.index.Index;
25+
26+
class ImageBatchIterator implements Iterator<ImageBatch> {
27+
28+
@Override
29+
public boolean hasNext() {
30+
return batchStart < numImages;
31+
}
32+
33+
@Override
34+
public ImageBatch next() {
35+
long nextBatchSize = Math.min(batchSize, numImages - batchStart);
36+
Index range = range(batchStart, batchStart + nextBatchSize);
37+
batchStart += nextBatchSize;
38+
return new ImageBatch(images.slice(range), labels.slice(range));
39+
}
40+
41+
ImageBatchIterator(int batchSize, ByteNdArray images, ByteNdArray labels) {
42+
this.batchSize = batchSize;
43+
this.images = images;
44+
this.labels = labels;
45+
this.numImages = images != null ? images.shape().size(0) : 0;
46+
this.batchStart = 0;
47+
}
48+
49+
private final int batchSize;
50+
private final ByteNdArray images;
51+
private final ByteNdArray labels;
52+
private final long numImages;
53+
private int batchStart;
54+
}

0 commit comments

Comments
 (0)