Skip to content
7 changes: 6 additions & 1 deletion tensorflow-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<description>A suite of executable examples using TensorFlow Java</description>

<properties>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The sample code requires at least JDK 1.8. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
Expand All @@ -25,6 +25,11 @@
<version>0.1.0-SNAPSHOT</version>
<classifier>macosx-x86_64</classifier>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-training</artifactId>
<version>0.1.0-SNAPSHOT</version>
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>proto</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
/*
* Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.tensorflow.model.examples.mnist;

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.model.examples.mnist.data.ImageBatch;
import org.tensorflow.model.examples.mnist.data.MnistDataset;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.OneHot;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.tools.ndarray.ByteNdArray;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.index.Indices;
import org.tensorflow.training.optimizers.AdaDelta;
import org.tensorflow.training.optimizers.AdaGrad;
import org.tensorflow.training.optimizers.AdaGradDA;
import org.tensorflow.training.optimizers.Adam;
import org.tensorflow.training.optimizers.GradientDescent;
import org.tensorflow.training.optimizers.Momentum;
import org.tensorflow.training.optimizers.Optimizer;
import org.tensorflow.training.optimizers.RMSProp;
import org.tensorflow.tools.Shape;
import org.tensorflow.types.TFloat32;

import java.io.IOException;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.types.TUint8;

/**
 * Builds and trains a LeNet-5 style CNN for MNIST digit classification.
 *
 * <p>Usage: {@code MNISTTest <num-epochs> <minibatch-size> <optimizer-name>}.
 */
public class MNISTTest {

private static final Logger logger = Logger.getLogger(MNISTTest.class.getName());

// Maximum ubyte pixel intensity; used to centre and scale raw pixels.
private static final int PIXEL_DEPTH = 255;
// MNIST images are single-channel (greyscale).
private static final int NUM_CHANNELS = 1;
// MNIST images are 28x28 pixels.
private static final int IMAGE_SIZE = 28;
// Digit classes 0-9.
private static final int NUM_LABELS = 10;
// Fixed RNG seed so weight initialisation is reproducible across runs.
private static final long SEED = 123456789L;

private static final String PADDING_TYPE = "SAME";

// Graph op names shared between build(), train() and test() so that ops can be
// addressed by name through a Session.Runner.
public static final String INPUT_NAME = "input";
public static final String OUTPUT_NAME = "output";
public static final String TARGET = "target";
public static final String TRAIN = "train";
public static final String TRAINING_LOSS = "training_loss";
public static final String INIT = "init";
/**
 * Builds the complete training graph: a LeNet-5 style CNN (two conv+max-pool
 * layers, one fully-connected layer, softmax output), a softmax cross-entropy
 * loss with L2 regularisation on the dense layers, and a minimisation op for
 * the selected optimizer.
 *
 * @param optimizerName case-insensitive optimizer name: one of "adadelta",
 *     "adagradda", "adagrad", "adam", "sgd", "momentum", "rmsprop"
 * @return the constructed graph; ownership passes to the caller (close it)
 * @throws IllegalArgumentException if the optimizer name is not recognised
 */
public static Graph build(String optimizerName) {
Graph graph = new Graph();

Ops tf = Ops.create(graph);

// Inputs
// Batch of raw ubyte images in NHWC layout; -1 allows a variable batch size.
Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)));
// Per-example integer class labels (0-9), fed by name under TARGET.
Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);

// Scaling the features: maps pixel values from [0, 255] to [-0.5, 0.5].
Constant<TFloat32> centeringFactor = tf.val(PIXEL_DEPTH / 2.0f);
Constant<TFloat32> scalingFactor = tf.val((float) PIXEL_DEPTH);
Operand<TFloat32> scaledInput = tf.math.div(tf.math.sub(tf.dtypes.cast(input,TFloat32.DTYPE), centeringFactor), scalingFactor);

// First conv layer
// 5x5 kernels, NUM_CHANNELS -> 32 feature maps; seeded truncated normal
// scaled by 0.1 keeps initialisation small and reproducible.
Variable<TFloat32> conv1Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE,
TruncatedNormal.seed(SEED)), tf.val(0.1f)));
Conv2d<TFloat32> conv1 = tf.nn
.conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
Variable<TFloat32> conv1Biases = tf
.variable(tf.fill(tf.array(new int[]{32}), tf.val(0.0f)));
Relu<TFloat32> relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases));

// First pooling layer: 2x2 windows with stride 2 halve the spatial dims.
MaxPool<TFloat32> pool1 = tf.nn
.maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1),
PADDING_TYPE);

// Second conv layer: 5x5 kernels, 32 -> 64 feature maps.
Variable<TFloat32> conv2Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE,
TruncatedNormal.seed(SEED)), tf.val(0.1f)));
Conv2d<TFloat32> conv2 = tf.nn
.conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
Variable<TFloat32> conv2Biases = tf
.variable(tf.fill(tf.array(new int[]{64}), tf.val(0.1f)));
Relu<TFloat32> relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases));

// Second pooling layer: again 2x2 / stride 2.
MaxPool<TFloat32> pool2 = tf.nn
.maxPool(relu2, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1),
PADDING_TYPE);

// Flatten inputs: keep the batch dimension (slice [0,1) of the shape) and
// collapse H x W x C into a single -1 (inferred) axis.
Reshape<TFloat32> flatten = tf.reshape(pool2, tf.concat(Arrays
.asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})),
tf.array(new int[]{-1})), tf.val(0)));

// Fully connected layer.
// After two 2x2 poolings a 28x28 image yields 7*7*64 = IMAGE_SIZE*IMAGE_SIZE*4
// flattened features per example.
Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE,
TruncatedNormal.seed(SEED)), tf.val(0.1f)));
Variable<TFloat32> fc1Biases = tf
.variable(tf.fill(tf.array(new int[]{512}), tf.val(0.1f)));
Relu<TFloat32> relu3 = tf.nn
.relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));

// Softmax layer: 512 -> NUM_LABELS.
Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
.truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE,
TruncatedNormal.seed(SEED)), tf.val(0.1f)));
Variable<TFloat32> fc2Biases = tf
.variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.val(0.1f)));

// Unnormalised class scores.
Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases);

// Predicted outputs (probabilities), fetchable by name OUTPUT_NAME.
Softmax<TFloat32> prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits);

// Loss function & regularization
// Labels are one-hot encoded before the fused softmax cross-entropy op.
OneHot<TFloat32> oneHot = tf
.oneHot(labels, tf.val(10), tf.val(1.0f), tf.val(0.0f));
SoftmaxCrossEntropyWithLogits<TFloat32> batchLoss = tf.nn
.softmaxCrossEntropyWithLogits(logits, oneHot);
Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.val(0));
// L2 penalty applies to the dense-layer parameters only; conv weights are
// left unregularised, matching the classic LeNet MNIST recipe.
Add<TFloat32> regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
.add(tf.nn.l2Loss(fc1Biases),
tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases))));
Add<TFloat32> loss = tf.withName(TRAINING_LOSS).math
.add(labelLoss, tf.math.mul(regularizers, tf.val(5e-4f)));

String lcOptimizerName = optimizerName.toLowerCase();
// Optimizer selection; hyperparameters are the library defaults for each.
Optimizer optimizer;
switch (lcOptimizerName) {
case "adadelta":
optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f);
break;
case "adagradda":
optimizer = new AdaGradDA(graph, 0.01f);
break;
case "adagrad":
optimizer = new AdaGrad(graph, 0.01f);
break;
case "adam":
optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f);
break;
case "sgd":
optimizer = new GradientDescent(graph, 0.01f);
break;
case "momentum":
optimizer = new Momentum(graph, 0.01f, 0.9f, false);
break;
case "rmsprop":
optimizer = new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false);
break;
default:
throw new IllegalArgumentException("Unknown optimizer " + optimizerName);
}
logger.info("Optimizer = " + optimizer.toString());
// Registers the training op under the name TRAIN so train() can target it.
Op minimize = optimizer.minimize(loss, TRAIN);

// NOTE(review): train() runs the initialiser by name via addTarget(INIT), i.e.
// "init". The returned Op is otherwise unused here, so this only works if
// variablesInitializer() registers its op under that exact name — TODO confirm
// against the tensorflow-java 0.1.0-SNAPSHOT API.
Op init = graph.variablesInitializer();

return graph;
}

/**
 * Trains the model held in {@code session} for the given number of epochs.
 *
 * <p>Feeds each minibatch's images to the {@code input} placeholder and its
 * labels to the {@code target} placeholder, runs the {@code train} op, and
 * fetches {@code training_loss} for periodic progress logging.
 *
 * @param session a session over a graph produced by {@link #build(String)}
 * @param epochs number of full passes over the training set
 * @param minibatchSize number of examples per training step
 * @param dataset the MNIST dataset supplying training batches
 */
public static void train(Session session, int epochs, int minibatchSize, MnistDataset dataset) {
// Initialises the parameters.
session.runner().addTarget(INIT).run();
logger.info("Initialised the model parameters");

int interval = 0;
// Train the model
for (int i = 0; i < epochs; i++) {
for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
Tensor<?> loss = session.runner()
// BUG FIX: labels must be fed to the TARGET placeholder, not to
// OUTPUT_NAME (the softmax op). Fetching TRAINING_LOSS requires the
// "target" placeholder to be fed; feeding the output instead leaves
// it unfed and overrides the prediction op.
.feed(TARGET, batchLabels)
.feed(INPUT_NAME, batchImages)
.addTarget(TRAIN)
.fetch(TRAINING_LOSS)
.run().get(0)) {
// Log the loss every 100 minibatches to keep output manageable.
if (interval % 100 == 0) {
logger.log(Level.INFO,
"Iteration = " + interval + ", training loss = " + loss.floatValue());
}
}
interval++;
}
}
}

/**
 * Evaluates the trained model over the test set, logging running and final
 * accuracy and printing a 10x10 confusion matrix to stdout.
 *
 * @param session a session over a graph produced by {@link #build(String)}
 * @param minibatchSize number of examples per evaluation batch
 * @param dataset the MNIST dataset supplying test batches
 */
public static void test(Session session, int minibatchSize, MnistDataset dataset) {
int correctCount = 0;
int[][] confusionMatrix = new int[10][10];
TFloat32 prediction;

int examplesSeen = 0;
for (ImageBatch batch : dataset.testBatches(minibatchSize)) {
// Run the forward pass for this batch and copy out the probabilities.
try (Tensor<TUint8> imageTensor = TUint8.tensorOf(batch.images());
Tensor<?> outputTensor = session.runner()
.feed(INPUT_NAME, imageTensor)
.fetch(OUTPUT_NAME).run().get(0)) {
prediction = (TFloat32) outputTensor.data();
}

// Compare each predicted label against the ground truth.
ByteNdArray labelBatch = batch.labels();
long batchCount = labelBatch.shape().size(0);
for (int row = 0; row < batchCount; row++) {
byte trueLabel = labelBatch.getByte(row);
int predLabel = argmax(prediction.slice(Indices.at(row), Indices.all()));
if (predLabel == trueLabel) {
correctCount++;
}
confusionMatrix[trueLabel][predLabel]++;
}

// Periodic running-accuracy log, once per 1000 examples.
if (examplesSeen % 1000 == 0) {
logger.log(Level.INFO, "Cur accuracy = " + ((float) correctCount) / (examplesSeen + minibatchSize));
}
examplesSeen += minibatchSize;
}

logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples());

// Render the confusion matrix: header row of predicted labels, then one row
// per true label, each cell right-padded to 5 characters.
StringBuilder sb = new StringBuilder();
sb.append("Label");
for (int col = 0; col < confusionMatrix.length; col++) {
sb.append(String.format("%1$5s", "" + col));
}
sb.append("\n");

for (int row = 0; row < confusionMatrix.length; row++) {
sb.append(String.format("%1$5s", "" + row));
for (int col = 0; col < confusionMatrix[row].length; col++) {
sb.append(String.format("%1$5s", "" + confusionMatrix[row][col]));
}
sb.append("\n");
}

System.out.println(sb.toString());
}

/**
 * Finds the index of the maximum probability in a rank-1 probability vector.
 *
 * @param probabilities the probabilities; expected rank 1 (a single example's
 *     class distribution)
 * @return the index of the maximum value, or 0 if the vector is empty
 */
public static int argmax(FloatNdArray probabilities) {
float maxVal = Float.NEGATIVE_INFINITY;
int idx = 0;
// BUG FIX: the loop bound previously read shape().size(i), querying a new
// dimension on every iteration instead of the length of dimension 0.
long length = probabilities.shape().size(0);
for (int i = 0; i < length; i++) {
float curVal = probabilities.getFloat(i);
if (curVal > maxVal) {
maxVal = curVal;
idx = i;
}
}
return idx;
}

/**
 * Entry point: loads MNIST, builds the graph for the requested optimizer,
 * trains it, then evaluates it on the test set.
 *
 * @param args {@code <num-epochs> <minibatch-size> <optimizer-name>}
 * @throws IOException if the MNIST archives cannot be read
 * @throws ClassNotFoundException retained for interface compatibility
 */
public static void main(String[] args) throws IOException, ClassNotFoundException {
logger.info(
"Usage: MNISTTest <num-epochs> <minibatch-size> <optimizer-name>");

// ROBUSTNESS FIX: validate the argument count before indexing, instead of
// failing later with ArrayIndexOutOfBoundsException.
if (args.length < 3) {
logger.severe("Expected 3 arguments, got " + args.length);
return;
}

MnistDataset dataset = MnistDataset.create(0);

logger.info("Loaded data.");

int epochs = Integer.parseInt(args[0]);
int minibatchSize = Integer.parseInt(args[1]);

// try-with-resources releases the native graph and session handles.
try (Graph graph = build(args[2]);
Session session = new Session(graph)) {
train(session, epochs, minibatchSize, dataset);

logger.info("Trained model");

test(session, minibatchSize, dataset);

}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ public Iterable<ImageBatch> validationBatches(int batchSize) {
return () -> new ImageBatchIterator(batchSize, validationImages, validationLabels);
}

/**
 * Returns an iterable over the test set, partitioned into batches of at most
 * {@code batchSize} images with their matching labels.
 *
 * @param batchSize the number of examples per batch
 * @return a freshly-iterable view over the test images and labels
 */
public Iterable<ImageBatch> testBatches(int batchSize) {
return () -> new ImageBatchIterator(batchSize, testImages, testLabels);
}

/**
 * Returns the entire test set as a single batch of images and labels.
 *
 * @return one batch containing all test examples
 */
public ImageBatch testBatch() {
return new ImageBatch(testImages, testLabels);
}
Expand All @@ -72,6 +76,18 @@ public long imageSize() {
return imageSize;
}

/**
 * Returns the number of examples in the training set (size of the labels'
 * leading dimension).
 */
public long numTrainingExamples() {
return trainingLabels.shape().size(0);
}

/**
 * Returns the number of examples in the test set (size of the labels'
 * leading dimension).
 */
public long numTestingExamples() {
return testLabels.shape().size(0);
}

/**
 * Returns the number of examples in the validation set (size of the labels'
 * leading dimension).
 */
public long numValidationExamples() {
return validationLabels.shape().size(0);
}

// Gzipped IDX archive filenames from the standard MNIST distribution.
private static final String TRAINING_IMAGES_ARCHIVE = "train-images-idx3-ubyte.gz";
private static final String TRAINING_LABELS_ARCHIVE = "train-labels-idx1-ubyte.gz";
private static final String TEST_IMAGES_ARCHIVE = "t10k-images-idx3-ubyte.gz";
Expand Down Expand Up @@ -121,6 +137,6 @@ private static ByteNdArray readArchive(String archiveName) throws IOException {
}
byte[] bytes = new byte[size];
archiveStream.readFully(bytes);
return NdArrays.wrap(DataBuffers.of(bytes, true, false), Shape.of(dimSizes));
return NdArrays.wrap(DataBuffers.from(bytes, true, false), Shape.of(dimSizes));
}
}