Skip to content

Commit 849ad1c

Browse files
committed
get rid of dynamic rnn... phew
1 parent 8a848e8 commit 849ad1c

File tree

5 files changed

+70
-56
lines changed

5 files changed

+70
-56
lines changed

config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def add_argument_group(name):
4747

4848
# Misc
4949
misc_arg = add_argument_group('Misc')
50-
misc_arg.add_argument('--log_step', type=int, default=20, help='')
51-
misc_arg.add_argument('--num_log_samples', type=int, default=2, help='')
50+
misc_arg.add_argument('--log_step', type=int, default=50, help='')
51+
misc_arg.add_argument('--num_log_samples', type=int, default=3, help='')
5252
misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'], help='')
5353
misc_arg.add_argument('--log_dir', type=str, default='logs')
5454
misc_arg.add_argument('--data_dir', type=str, default='data')

data_loader.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ def generate_one_example(n_nodes, rng):
3535
solutions = solve_tsp_dynamic(nodes)
3636
return nodes, solutions
3737

38+
def pad(x, max_length):
    """Zero-pad `x` along its first axis so it has `max_length` rows.

    Args:
        x: NumPy array with at least one dimension.
        max_length: Desired size of the first axis after padding.

    Returns:
        A new array whose first axis has length `max_length`: the rows of
        `x` followed by zero rows. Trailing dimensions and dtype match `x`.

    Raises:
        ValueError: If `x` already has more than `max_length` rows.
    """
    pad_length = max_length - x.shape[0]
    if pad_length < 0:
        raise ValueError(
            "cannot pad array of length {} to shorter max_length {}".format(
                x.shape[0], max_length))

    # `x.shape` is an immutable tuple, so build the padding shape as a new
    # tuple instead of assigning into it.
    pad_shape = (pad_length,) + tuple(x.shape[1:])

    # Keep the padding in the input dtype so integer targets are not
    # silently upcast to float by the concatenation.
    return np.concatenate([x, np.zeros(pad_shape, dtype=x.dtype)])
43+
3844
class TSPDataLoader(object):
3945
def __init__(self, config, rng=None):
4046
self.config = config
@@ -78,10 +84,10 @@ def _create_input_queue(self, queue_capacity_factor=16):
7884
min_after_dequeue = 1000
7985
capacity = min_after_dequeue + 3 * self.batch_size
8086

81-
self.queue_ops[name] = tf.PaddingFIFOQueue(
87+
self.queue_ops[name] = tf.FIFOQueue(
8288
capacity=capacity,
8389
dtypes=[tf.float32, tf.int32],
84-
shapes=[[None, 2,], [None]],
90+
shapes=[[self.max_length, 2,], [self.max_length]],
8591
name="fifo_{}".format(name))
8692
self.enqueue_ops[name] = \
8793
self.queue_ops[name].enqueue([self.input_ops[name], self.target_ops[name]])
@@ -136,12 +142,13 @@ def _maybe_generate_and_save(self):
136142
if not os.path.exists(path):
137143
tf.logging.info("Creating {} for [{}]".format(path, self.task))
138144

139-
x, y = [], []
145+
x = np.zeros([num, self.max_length, 2], dtype=np.float32)
146+
y = np.zeros([num, self.max_length], dtype=np.int32)
140147
for i in trange(num, desc="Create {} data".format(name)):
141148
n_nodes = self.rng.randint(self.min_length, self.max_length+ 1)
142149
nodes, res = generate_one_example(n_nodes, self.rng)
143-
x.append(nodes)
144-
y.append(res)
150+
x[i,:len(nodes)] = nodes
151+
y[i,:len(res)] = res
145152

146153
np.savez(path, x=x, y=y)
147154
self.data[name] = TSP(x=x, y=y, name=name)

layers.py

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111

1212
def decoder_rnn(cell, inputs,
1313
enc_outputs, enc_final_states,
14-
seq_length, hidden_dim, num_glimpse,
15-
max_dec_length, batch_size, is_train,
16-
end_of_sequence_id=0, initializer=None):
14+
seq_length, hidden_dim,
15+
num_glimpse, batch_size, is_train,
16+
end_of_sequence_id=0, initializer=None,
17+
first_decoder_input=None):
1718
with tf.variable_scope("decoder_rnn") as scope:
1819
def attention(ref, query, with_softmax, scope="attention"):
1920
with tf.variable_scope(scope):
@@ -41,40 +42,32 @@ def glimpse(ref, query, scope="glimpse"):
4142
return tf.reduce_sum(alignments * ref, [1])
4243

4344
def output_fn(ref, query, num_glimpse):
44-
for idx in range(num_glimpse):
45-
query = glimpse(ref, query, "glimpse_{}".format(idx))
46-
return attention(ref, query, with_softmax=False, scope="attention")
47-
48-
maximum_length = tf.convert_to_tensor(max_dec_length, tf.int32)
49-
def decoder_fn_inference(
50-
time, cell_state, cell_input, cell_output, context_state):
51-
if context_state is None:
52-
context_state = tf.zeros([batch_size,], dtype=tf.int32)
53-
54-
if cell_output is None:
55-
# time == 0
56-
cell_state = enc_final_states
57-
done = tf.zeros([batch_size,], dtype=tf.bool)
45+
if query is None:
46+
return tf.zeros([11], tf.float32) # only used for shape inference
5847
else:
59-
output_logit = output_fn(enc_outputs, cell_output, num_glimpse)
48+
for idx in range(num_glimpse):
49+
query = glimpse(ref, query, "glimpse_{}".format(idx))
50+
return attention(ref, query, with_softmax=False, scope="attention")
6051

61-
sampled_idx = tf.squeeze(
62-
tf.cast(tf.multinomial(output_logit, 1), tf.int32), -1)
63-
done = tf.equal(sampled_idx, end_of_sequence_id)
64-
65-
cell_input = tf.stop_gradient(
66-
tf.gather_nd(enc_outputs, index_matrix_to_pairs(sampled_idx)))
67-
conext_state = tf.stack([context_state, sampled_idx], 0)
68-
69-
done = tf.cond(tf.greater(time, maximum_length),
70-
lambda: tf.ones([batch_size,], dtype=tf.bool),
71-
lambda: done)
72-
return (done, cell_state, cell_input, cell_output, context_state)
52+
def input_fn(sampled_idx):
53+
return tf.stop_gradient(
54+
tf.gather_nd(enc_outputs, index_matrix_to_pairs(sampled_idx)))
7355

7456
if is_train:
7557
decoder_fn = simple_decoder_fn_train(enc_final_states)
7658
else:
77-
decoder_fn = decoder_fn_inference
59+
def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
60+
cell_output = output_fn(enc_outputs, cell_output, num_glimpse)
61+
62+
if cell_state is None:
63+
cell_state = enc_final_states
64+
next_input = tf.squeeze(first_decoder_input, 1)
65+
done = tf.zeros([batch_size,], dtype=tf.bool)
66+
else:
67+
sampled_idx = tf.cast(tf.argmax(cell_output, 1), tf.int32)
68+
next_input = input_fn(sampled_idx)
69+
done = tf.equal(sampled_idx, end_of_sequence_id)
70+
return (done, cell_state, next_input, cell_output, context_state)
7871

7972
outputs, final_state, final_context_state = \
8073
dynamic_rnn_decoder(cell, decoder_fn, inputs=inputs,

model.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,23 @@ def _build_model(self):
122122

123123
if self.use_terminal_symbol:
124124
# 0 index indicates terminal
125-
first_decoder_input = tf.expand_dims(trainable_initial_state(
125+
self.first_decoder_input = tf.expand_dims(trainable_initial_state(
126126
batch_size, self.hidden_dim, name="first_decoder_input"), 1)
127127
self.enc_outputs = tf.concat_v2(
128-
[first_decoder_input, self.enc_outputs], axis=1)
128+
[self.first_decoder_input, self.enc_outputs], axis=1)
129129

130130
with tf.variable_scope("dencoder"):
131+
self.idx_pairs = index_matrix_to_pairs(self.dec_targets)
132+
self.embeded_dec_inputs = tf.stop_gradient(
133+
tf.gather_nd(self.enc_outputs, self.idx_pairs))
134+
131135
if self.use_terminal_symbol:
132136
tiled_zero_idxs = tf.tile(tf.zeros(
133137
[1, 1], dtype=tf.int32), [batch_size, 1], name="tiled_zero_idxs")
134138
self.dec_targets = tf.concat_v2([self.dec_targets, tiled_zero_idxs], axis=1)
135139

136-
self.idx_pairs = index_matrix_to_pairs(self.dec_targets)
137-
self.embeded_dec_inputs = tf.stop_gradient(
138-
tf.gather_nd(self.enc_outputs, self.idx_pairs))
140+
self.embeded_dec_inputs = tf.concat_v2(
141+
[self.first_decoder_input, self.embeded_dec_inputs], axis=1)
139142

140143
self.dec_cell = LSTMCell(
141144
self.hidden_dim,
@@ -148,41 +151,51 @@ def _build_model(self):
148151
self.dec_pred_logits, _, _ = decoder_rnn(
149152
self.dec_cell, self.embeded_dec_inputs,
150153
self.enc_outputs, self.enc_final_states,
151-
self.dec_seq_length, self.hidden_dim, self.num_glimpse,
152-
self.max_dec_length, batch_size, is_train=True,
154+
self.dec_seq_length, self.hidden_dim,
155+
self.num_glimpse, batch_size, is_train=True,
153156
initializer=self.initializer)
154157
self.dec_pred_prob = tf.nn.softmax(
155158
self.dec_pred_logits, 2, name="dec_pred_prob")
156159
self.dec_pred = tf.argmax(
157160
self.dec_pred_logits, 2, name="dec_pred")
158161

159162
with tf.variable_scope("dencoder", reuse=True):
160-
self.dec_inference_outputs, _, self.dec_inference = decoder_rnn(
161-
self.dec_cell, first_decoder_input,
163+
self.dec_inference_logits, _, _ = decoder_rnn(
164+
self.dec_cell, None,
162165
self.enc_outputs, self.enc_final_states,
163-
self.dec_seq_length, self.hidden_dim, self.num_glimpse,
164-
self.max_dec_length, batch_size, is_train=False,
165-
initializer=self.initializer)
166+
self.dec_seq_length, self.hidden_dim,
167+
self.num_glimpse, batch_size, is_train=False,
168+
initializer=self.initializer, first_decoder_input=self.first_decoder_input)
166169
self.dec_inference_prob = tf.nn.softmax(
167-
self.dec_inference_outputs, 2, name="dec_inference_prob")
170+
self.dec_inference_logits, 2, name="dec_inference_logits")
171+
self.dec_inference = tf.argmax(
172+
self.dec_inference_logits, 2, name="dec_inference")
168173

169174
def _build_optim(self):
170175
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
171176
labels=self.dec_targets, logits=self.dec_pred_logits)
177+
inference_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
178+
labels=self.dec_targets, logits=self.dec_inference_logits)
172179

173180
def apply_mask(op):
174181
length = tf.cast(op[:1], tf.int32)
175182
loss = op[1:]
176183
return tf.multiply(loss, tf.ones(length, dtype=tf.float32))
177184

178-
batch_loss = tf.div(tf.reduce_sum(tf.multiply(losses, self.mask)),
179-
tf.reduce_sum(self.mask), name="batch_loss")
185+
batch_loss = tf.div(
186+
tf.reduce_sum(tf.multiply(losses, self.mask)),
187+
tf.reduce_sum(self.mask), name="batch_loss")
188+
189+
batch_inference_loss = tf.div(
190+
tf.reduce_sum(tf.multiply(losses, self.mask)),
191+
tf.reduce_sum(self.mask), name="batch_inference_loss")
180192

181193
tf.losses.add_loss(batch_loss)
182194
total_loss = tf.losses.get_total_loss()
183195

184196
self.total_loss = total_loss
185197
self.target_cross_entropy_losses = losses
198+
self.total_inference_loss = batch_inference_loss
186199

187200
self.lr = tf.train.exponential_decay(
188201
self.lr_start, self.global_step, self.lr_decay_step,

trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,13 @@ def train(self):
7676

7777
if result['step'] % self.log_step == 0:
7878
fetch = {
79-
'loss': self.model.total_loss,
80-
'pred': self.model.dec_inference,
81-
'targets': self.model.dec_targets,
79+
'test l': self.model.total_inference_loss,
80+
'test x': self.model.dec_inference,
81+
'test y': self.model.dec_targets,
8282
}
8383
result = self.model.test(self.sess, fetch, self.summary_writer)
8484

85+
tf.logging.info("")
8586
tf.logging.info("loss: {}".format(result['loss']))
8687
for idx in range(self.num_log_samples):
8788
tf.logging.info("preds: {}".format(result['preds'][idx]))

0 commit comments

Comments
 (0)