
Commit ec99ec0

Author: Ryo Miyajima
Commit message: init, copy mnist beginners tutorial
Parents: none (initial commit)

File tree

6 files changed: +581 / -0 lines changed


.gitignore

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
data
data-*
*.pyc
checkpoint
download

__init__.py

Whitespace-only changes.

fully_connected_feed.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
"""Trains and Evaluates the MNIST network using a feed dictionary.

TensorFlow install instructions:
https://tensorflow.org/get_started/os_setup.html

MNIST tutorial:
https://tensorflow.org/tutorials/mnist/tf/index.html

"""
from __future__ import print_function
# pylint: disable=missing-docstring
import os.path
import time

import tensorflow.python.platform
import numpy
import tensorflow as tf

import input_data
import mnist


# Basic model parameters as external flags.
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.  '
                     'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                     'for unit testing.')


def placeholder_inputs(batch_size):
  """Generate placeholder variables to represent the input tensors.

  These placeholders are used as inputs by the rest of the model building
  code and will be fed from the downloaded data in the .run() loop, below.

  Args:
    batch_size: The batch size will be baked into both placeholders.

  Returns:
    images_placeholder: Images placeholder.
    labels_placeholder: Labels placeholder.
  """
  # Note that the shapes of the placeholders match the shapes of the full
  # image and label tensors, except the first dimension is now batch_size
  # rather than the full size of the train or test data sets.
  images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,
                                                         mnist.IMAGE_PIXELS))
  labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
  return images_placeholder, labels_placeholder


def fill_feed_dict(data_set, images_pl, labels_pl):
  """Fills the feed_dict for training the given step.

  A feed_dict takes the form of:
  feed_dict = {
      <placeholder>: <tensor of values to be passed for placeholder>,
      ....
  }

  Args:
    data_set: The set of images and labels, from input_data.read_data_sets()
    images_pl: The images placeholder, from placeholder_inputs().
    labels_pl: The labels placeholder, from placeholder_inputs().

  Returns:
    feed_dict: The feed dictionary mapping from placeholders to values.
  """
  # Create the feed_dict for the placeholders filled with the next
  # `batch_size` examples.
  images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,
                                                 FLAGS.fake_data)
  feed_dict = {
      images_pl: images_feed,
      labels_pl: labels_feed,
  }
  return feed_dict


def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
  """Runs one evaluation against the full epoch of data.

  Args:
    sess: The session in which the model has been trained.
    eval_correct: The Tensor that returns the number of correct predictions.
    images_placeholder: The images placeholder.
    labels_placeholder: The labels placeholder.
    data_set: The set of images and labels to evaluate, from
      input_data.read_data_sets().
  """
  # And run one epoch of eval.
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = int(data_set.num_examples / FLAGS.batch_size)
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    feed_dict = fill_feed_dict(data_set,
                               images_placeholder,
                               labels_placeholder)
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = float(true_count) / float(num_examples)
  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
        (num_examples, true_count, precision))


def run_training():
  """Train MNIST for a number of steps."""
  # Get the sets of images and labels for training, validation, and
  # test on MNIST.
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Generate placeholders for the images and labels.
    images_placeholder, labels_placeholder = placeholder_inputs(
        FLAGS.batch_size)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images_placeholder,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the Ops for loss calculation.
    loss = mnist.loss(logits, labels_placeholder)

    # Add to the Graph the Ops that calculate and apply gradients.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # Add the Op to compare the logits to the labels during evaluation.
    eval_correct = mnist.evaluation(logits, labels_placeholder)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Run the Op to initialize the variables.
    init = tf.initialize_all_variables()
    sess.run(init)

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
                                            graph_def=sess.graph_def)

    # And then after everything is built, start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      feed_dict = fill_feed_dict(data_sets.train,
                                 images_placeholder,
                                 labels_placeholder)

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss],
                               feed_dict=feed_dict)

      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)

      # Save a checkpoint and evaluate the model periodically.
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        saver.save(sess, FLAGS.train_dir, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)


def main(_):
  run_training()


if __name__ == '__main__':
  tf.app.run()
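
Usage note (not part of the commit): the flags defined at the top of the file become command-line options, so the script would normally be launched as something like `python fully_connected_feed.py --max_steps=2000 --learning_rate=0.01`. A minimal in-process sketch, assuming `input_data.py` and `mnist.py` from the TensorFlow MNIST tutorial sit next to this file:

# Hypothetical driver, not part of this commit: run the trainer in-process
# with the default flag values defined in fully_connected_feed.py.
import fully_connected_feed

# Downloads MNIST into FLAGS.train_dir on first use, builds the graph via
# mnist.inference()/loss()/training(), and runs the feed_dict training loop.
fully_connected_feed.run_training()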
