/*
 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.tensorflow.model.examples.mnist;

import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.model.examples.mnist.data.ImageBatch;
import org.tensorflow.model.examples.mnist.data.MnistDataset;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.OneHot;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.math.Add;
import org.tensorflow.op.math.Mean;
import org.tensorflow.op.nn.Conv2d;
import org.tensorflow.op.nn.MaxPool;
import org.tensorflow.op.nn.Relu;
import org.tensorflow.op.nn.Softmax;
import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits;
import org.tensorflow.op.random.TruncatedNormal;
import org.tensorflow.tools.Shape;
import org.tensorflow.tools.ndarray.ByteNdArray;
import org.tensorflow.tools.ndarray.FloatNdArray;
import org.tensorflow.tools.ndarray.index.Indices;
import org.tensorflow.training.optimizers.AdaDelta;
import org.tensorflow.training.optimizers.AdaGrad;
import org.tensorflow.training.optimizers.AdaGradDA;
import org.tensorflow.training.optimizers.Adam;
import org.tensorflow.training.optimizers.GradientDescent;
import org.tensorflow.training.optimizers.Momentum;
import org.tensorflow.training.optimizers.Optimizer;
import org.tensorflow.training.optimizers.RMSProp;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TUint8;

/**
 * Builds, trains, and evaluates a LeNet-5 style CNN for MNIST.
 *
 * <p>Arguments: {@code <num-epochs> <minibatch-size> <optimizer-name>}, e.g. {@code 10 100 adam}.
 */
public class CnnMnist {

  private static final Logger logger = Logger.getLogger(CnnMnist.class.getName());

  private static final int PIXEL_DEPTH = 255;
  private static final int NUM_CHANNELS = 1;
  private static final int IMAGE_SIZE = 28;
  private static final int NUM_LABELS = MnistDataset.NUM_CLASSES;
  private static final long SEED = 123456789L;

  private static final String PADDING_TYPE = "SAME";

  public static final String INPUT_NAME = "input";
  public static final String OUTPUT_NAME = "output";
  public static final String TARGET = "target";
  public static final String TRAIN = "train";
  public static final String TRAINING_LOSS = "training_loss";
  public static final String INIT = "init";

  public static Graph build(String optimizerName) {
    Graph graph = new Graph();

    Ops tf = Ops.create(graph);

    // Inputs
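    // Images arrive as uint8 tensors of shape [batch, 28, 28]; the -1 leaves
    // the batch dimension unspecified so any minibatch size can be fed.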
    Placeholder<TUint8> input = tf.withName(INPUT_NAME).placeholder(TUint8.DTYPE,
        Placeholder.shape(Shape.of(-1, IMAGE_SIZE, IMAGE_SIZE)));
    Reshape<TUint8> inputReshaped = tf
        .reshape(input, tf.array(-1, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS));
    Placeholder<TUint8> labels = tf.withName(TARGET).placeholder(TUint8.DTYPE);

    // Scaling the features
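    // Maps raw pixel values from [0, 255] into [-0.5, 0.5]:
    // (x - PIXEL_DEPTH / 2) / PIXEL_DEPTH.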
    Constant<TFloat32> centeringFactor = tf.constant(PIXEL_DEPTH / 2.0f);
    Constant<TFloat32> scalingFactor = tf.constant((float) PIXEL_DEPTH);
    Operand<TFloat32> scaledInput = tf.math
        .div(tf.math.sub(tf.dtypes.cast(inputReshaped, TFloat32.DTYPE), centeringFactor),
            scalingFactor);

    // First conv layer
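    // 32 filters of 5x5 over the single input channel, initialized from a
    // scaled truncated normal; SAME padding preserves the 28x28 spatial size.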
    Variable<TFloat32> conv1Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(5, 5, NUM_CHANNELS, 32), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Conv2d<TFloat32> conv1 = tf.nn
        .conv2d(scaledInput, conv1Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
    Variable<TFloat32> conv1Biases = tf
        .variable(tf.fill(tf.array(new int[]{32}), tf.constant(0.0f)));
    Relu<TFloat32> relu1 = tf.nn.relu(tf.nn.biasAdd(conv1, conv1Biases));

    // First pooling layer
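    // 2x2 max pooling with stride 2 halves the spatial size: 28x28 -> 14x14.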
    MaxPool<TFloat32> pool1 = tf.nn
        .maxPool(relu1, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1),
            PADDING_TYPE);

    // Second conv layer
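    // 64 filters of 5x5 over the 32 channels produced by the first block.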
    Variable<TFloat32> conv2Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(5, 5, 32, 64), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Conv2d<TFloat32> conv2 = tf.nn
        .conv2d(pool1, conv2Weights, Arrays.asList(1L, 1L, 1L, 1L), PADDING_TYPE);
    Variable<TFloat32> conv2Biases = tf
        .variable(tf.fill(tf.array(new int[]{64}), tf.constant(0.1f)));
    Relu<TFloat32> relu2 = tf.nn.relu(tf.nn.biasAdd(conv2, conv2Biases));

    // Second pooling layer
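    // A second 2x2/stride-2 pooling halves the size again: 14x14 -> 7x7.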
    MaxPool<TFloat32> pool2 = tf.nn
        .maxPool(relu2, tf.array(1, 2, 2, 1), tf.array(1, 2, 2, 1),
            PADDING_TYPE);

    // Flatten inputs
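    // Keeps the (dynamic) batch dimension, sliced from the runtime shape of
    // pool2, and collapses the remaining 7x7x64 activations into one axis (-1).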
    Reshape<TFloat32> flatten = tf.reshape(pool2, tf.concat(Arrays
        .asList(tf.slice(tf.shape(pool2), tf.array(new int[]{0}), tf.array(new int[]{1})),
            tf.array(new int[]{-1})), tf.constant(0)));

    // Fully connected layer
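    // The flattened feature count is 7 * 7 * 64 = 3136, which equals
    // IMAGE_SIZE * IMAGE_SIZE * 4; the dense layer maps it to 512 units.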
    Variable<TFloat32> fc1Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(IMAGE_SIZE * IMAGE_SIZE * 4, 512), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Variable<TFloat32> fc1Biases = tf
        .variable(tf.fill(tf.array(new int[]{512}), tf.constant(0.1f)));
    Relu<TFloat32> relu3 = tf.nn
        .relu(tf.math.add(tf.linalg.matMul(flatten, fc1Weights), fc1Biases));

    // Softmax layer
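    // Projects the 512 hidden units down to one logit per class.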
    Variable<TFloat32> fc2Weights = tf.variable(tf.math.mul(tf.random
        .truncatedNormal(tf.array(512, NUM_LABELS), TFloat32.DTYPE,
            TruncatedNormal.seed(SEED)), tf.constant(0.1f)));
    Variable<TFloat32> fc2Biases = tf
        .variable(tf.fill(tf.array(new int[]{NUM_LABELS}), tf.constant(0.1f)));

    Add<TFloat32> logits = tf.math.add(tf.linalg.matMul(relu3, fc2Weights), fc2Biases);

    // Predicted outputs
    Softmax<TFloat32> prediction = tf.withName(OUTPUT_NAME).nn.softmax(logits);

    // Loss function & regularization
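    // One-hot encodes the integer labels, averages the per-example
    // cross-entropy over the batch, and adds an L2 penalty (weighted 5e-4)
    // on the fully connected layers' parameters.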
    OneHot<TFloat32> oneHot = tf
        .oneHot(labels, tf.constant(NUM_LABELS), tf.constant(1.0f), tf.constant(0.0f));
    SoftmaxCrossEntropyWithLogits<TFloat32> batchLoss = tf.nn
        .softmaxCrossEntropyWithLogits(logits, oneHot);
    Mean<TFloat32> labelLoss = tf.math.mean(batchLoss.loss(), tf.constant(0));
    Add<TFloat32> regularizers = tf.math.add(tf.nn.l2Loss(fc1Weights), tf.math
        .add(tf.nn.l2Loss(fc1Biases),
            tf.math.add(tf.nn.l2Loss(fc2Weights), tf.nn.l2Loss(fc2Biases))));
    Add<TFloat32> loss = tf.withName(TRAINING_LOSS).math
        .add(labelLoss, tf.math.mul(regularizers, tf.constant(5e-4f)));

    // Optimizer
    String lcOptimizerName = optimizerName.toLowerCase();
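    // Each branch constructs the chosen optimizer with fixed hyperparameters;
    // unknown names fail fast with an IllegalArgumentException.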
    Optimizer optimizer;
    switch (lcOptimizerName) {
      case "adadelta":
        optimizer = new AdaDelta(graph, 1f, 0.95f, 1e-8f);
        break;
      case "adagradda":
        optimizer = new AdaGradDA(graph, 0.01f);
        break;
      case "adagrad":
        optimizer = new AdaGrad(graph, 0.01f);
        break;
      case "adam":
        optimizer = new Adam(graph, 0.001f, 0.9f, 0.999f, 1e-8f);
        break;
      case "sgd":
        optimizer = new GradientDescent(graph, 0.01f);
        break;
      case "momentum":
        optimizer = new Momentum(graph, 0.01f, 0.9f, false);
        break;
      case "rmsprop":
        optimizer = new RMSProp(graph, 0.01f, 0.9f, 0.0f, 1e-10f, false);
        break;
      default:
        throw new IllegalArgumentException("Unknown optimizer " + optimizerName);
    }
    logger.info("Optimizer = " + optimizer);
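    // minimize() appends the gradient computation and variable-update ops to
    // the graph under the TRAIN target name; the returned handle is unused here.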
    Op minimize = optimizer.minimize(loss, TRAIN);

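    // Registers an op grouping all variable initializers; its default name is
    // expected to match the INIT constant ("init"), so train() can run it by name.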
    tf.init();

    return graph;
  }

  public static void train(Session session, int epochs, int minibatchSize, MnistDataset dataset) {
    // Initialises the parameters.
    session.runner().addTarget(INIT).run();
    logger.info("Initialised the model parameters");

    int interval = 0;
    // Train the model
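    // Each step feeds one minibatch, runs the TRAIN target, and fetches the
    // training loss; try-with-resources closes every tensor to release its
    // native memory.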
    for (int i = 0; i < epochs; i++) {
      for (ImageBatch trainingBatch : dataset.trainingBatches(minibatchSize)) {
        try (Tensor<TUint8> batchImages = TUint8.tensorOf(trainingBatch.images());
            Tensor<TUint8> batchLabels = TUint8.tensorOf(trainingBatch.labels());
            Tensor<TFloat32> loss = session.runner()
                .feed(TARGET, batchLabels)
                .feed(INPUT_NAME, batchImages)
                .addTarget(TRAIN)
                .fetch(TRAINING_LOSS)
                .run().get(0).expect(TFloat32.DTYPE)) {
          if (interval % 100 == 0) {
            logger.log(Level.INFO,
                "Iteration = " + interval + ", training loss = " + loss.data().getFloat());
          }
        }
        interval++;
      }
    }
  }

  public static void test(Session session, int minibatchSize, MnistDataset dataset) {
    int correctCount = 0;
    int[][] confusionMatrix = new int[NUM_LABELS][NUM_LABELS];
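    // Rows index the true label, columns the predicted label.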

    for (ImageBatch testBatch : dataset.testBatches(minibatchSize)) {
      try (Tensor<TUint8> transformedInput = TUint8.tensorOf(testBatch.images());
          Tensor<TFloat32> outputTensor = session.runner()
              .feed(INPUT_NAME, transformedInput)
              .fetch(OUTPUT_NAME).run().get(0).expect(TFloat32.DTYPE)) {

        ByteNdArray labelBatch = testBatch.labels();
        for (int k = 0; k < labelBatch.shape().size(0); k++) {
          byte trueLabel = labelBatch.getByte(k);
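          // Slice out row k of the batch output and take the index of its
          // largest probability as the predicted class.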
          int predLabel = argmax(outputTensor.data().slice(Indices.at(k), Indices.all()));
          if (predLabel == trueLabel) {
            correctCount++;
          }

          confusionMatrix[trueLabel][predLabel]++;
        }
      }
    }

    logger.info("Final accuracy = " + ((float) correctCount) / dataset.numTestingExamples());

    StringBuilder sb = new StringBuilder();
    sb.append("Label");
    for (int i = 0; i < confusionMatrix.length; i++) {
      sb.append(String.format("%5d", i));
    }
    sb.append("\n");

    for (int i = 0; i < confusionMatrix.length; i++) {
      sb.append(String.format("%5d", i));
      for (int j = 0; j < confusionMatrix[i].length; j++) {
        sb.append(String.format("%5d", confusionMatrix[i][j]));
      }
      sb.append("\n");
    }

    System.out.println(sb);
  }

  /**
   * Finds the maximum probability and returns its index.
   *
   * @param probabilities The probabilities.
   * @return The index of the max.
   */
  public static int argmax(FloatNdArray probabilities) {
    float maxVal = Float.NEGATIVE_INFINITY;
    int idx = 0;
    for (int i = 0; i < probabilities.shape().size(0); i++) {
      float curVal = probabilities.getFloat(i);
      if (curVal > maxVal) {
        maxVal = curVal;
        idx = i;
      }
    }
    return idx;
  }

  public static void main(String[] args) {
    logger.info(
        "Usage: CnnMnist <num-epochs> <minibatch-size> <optimizer-name>");

    MnistDataset dataset = MnistDataset.create(0);

    logger.info("Loaded data.");

    int epochs = Integer.parseInt(args[0]);
    int minibatchSize = Integer.parseInt(args[1]);

    try (Graph graph = build(args[2]);
        Session session = new Session(graph)) {
      train(session, epochs, minibatchSize, dataset);

      logger.info("Trained model");

      test(session, minibatchSize, dataset);
    }
  }
}