|  | 
|  | 1 | +"""Trains and Evaluates the MNIST network using a feed dictionary. | 
|  | 2 | +
 | 
|  | 3 | +TensorFlow install instructions: | 
|  | 4 | +https://tensorflow.org/get_started/os_setup.html | 
|  | 5 | +
 | 
|  | 6 | +MNIST tutorial: | 
|  | 7 | +https://tensorflow.org/tutorials/mnist/tf/index.html | 
|  | 8 | +
 | 
|  | 9 | +""" | 
|  | 10 | +from __future__ import print_function | 
|  | 11 | +# pylint: disable=missing-docstring | 
|  | 12 | +import os.path | 
|  | 13 | +import time | 
|  | 14 | + | 
|  | 15 | +import tensorflow.python.platform | 
|  | 16 | +import numpy | 
|  | 17 | +import tensorflow as tf | 
|  | 18 | + | 
|  | 19 | +import input_data | 
|  | 20 | +import mnist | 
|  | 21 | + | 
|  | 22 | + | 
# Command-line flags holding the basic model and training parameters.
# `FLAGS` is the shared accessor the rest of this module reads them from.
flags = tf.app.flags
FLAGS = flags.FLAGS

# Optimisation settings.
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')

# Network layer sizes.
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')

# Input pipeline settings.
flags.DEFINE_integer(
    'batch_size', 100,
    'Batch size. Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean(
    'fake_data', False, 'If true, uses fake data for unit testing.')
|  | 35 | + | 
|  | 36 | + | 
def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.

  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop, below.

  Args:
    batch_size: The batch size will be baked into both placeholders.

  Returns:
    images_placeholder: Images placeholder, a tf.float32 placeholder of shape
      (batch_size, mnist.IMAGE_PIXELS).
    labels_placeholder: Labels placeholder, a tf.int32 placeholder of shape
      (batch_size,).
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  images_placeholder = tf.placeholder(
      tf.float32, shape=(batch_size, mnist.IMAGE_PIXELS))
  # The trailing comma matters: `(batch_size)` is just a parenthesized int,
  # whereas `(batch_size,)` is the intended one-element shape tuple.
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
  return images_placeholder, labels_placeholder
|  | 57 | + | 
|  | 58 | + | 
def fill_feed_dict(data_set, images_pl, labels_pl):
  """Builds the feed_dict for one step from the next batch of `data_set`.

  A feed_dict maps each placeholder to the values to substitute for it:
    feed_dict = {
        <placeholder>: <tensor of values to be passed for placeholder>,
        ....
    }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets().
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Pull the next FLAGS.batch_size examples; FLAGS.fake_data is forwarded
  # to next_batch() unchanged.
  next_batch = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data)
  images_feed, labels_feed = next_batch
  return {images_pl: images_feed, labels_pl: labels_feed}
|  | 85 | + | 
|  | 86 | + | 
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Runs one evaluation against the full epoch of data and prints the result.

  Only whole batches are evaluated: any trailing examples that do not fill a
  complete batch of FLAGS.batch_size are left out of the reported counts.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  true_count = 0  # Counts the number of correct predictions.
  # Floor division keeps this an integer on both Python 2 and Python 3.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  # `range` (not the Python-2-only `xrange`) so this also runs on Python 3.
  for _ in range(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  # A data set smaller than one batch yields zero evaluated examples; report
  # a precision of 0 instead of dividing by zero.
  if num_examples:
    precision = float(true_count) / float(num_examples)
  else:
    precision = 0.0
  print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))
|  | 114 | + | 
|  | 115 | + | 
def run_training():
  """Train MNIST for a number of steps.

  Builds the inference, loss, training and evaluation ops in a fresh graph,
  then runs FLAGS.max_steps training steps. Every 100 steps the loss is
  printed and a summary is written; every 1000 steps, and on the final step,
  a checkpoint is saved under FLAGS.train_dir and the model is evaluated on
  the training, validation and test sets.
  """
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Run the Op to initialize the variables.
    init = tf.initialize_all_variables()
    sess.run(init)

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

    # Checkpoints are written as files *inside* train_dir. Passing the bare
    # directory name to saver.save() would use it as a filename prefix and
    # drop the checkpoint files next to the directory instead.
    checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')

    # Data sets evaluated at every checkpoint, in reporting order.
    eval_sets = (
        ('Training Data Eval:', data_sets.train),
        ('Validation Data Eval:', data_sets.validation),
        ('Test Data Eval:', data_sets.test),
    )

    # And then after everything is built, start the training loop.
    # `range` (not the Python-2-only `xrange`) so this also runs on Python 3.
    for step in range(FLAGS.max_steps):
      start_time = time.time()

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)

      # Save a checkpoint and evaluate the model periodically.
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        saver.save(sess, checkpoint_path, global_step=step)
        for heading, eval_data in eval_sets:
          print(heading)
          do_eval(sess,
                  eval_correct,
                  images_placeholder,
                  labels_placeholder,
                  eval_data)
|  | 211 | + | 
|  | 212 | + | 
def main(_):
  """Script entry point: runs the training loop.

  The single positional argument is accepted but ignored.
  """
  run_training()


# Hand control to tf.app.run() only when executed as a script, not on import.
if __name__ == '__main__':
  tf.app.run()
0 commit comments