@@ -122,20 +122,23 @@ def _build_model(self):
122122
123123 if self .use_terminal_symbol :
124124 # 0 index indicates terminal
125- first_decoder_input = tf .expand_dims (trainable_initial_state (
125+ self . first_decoder_input = tf .expand_dims (trainable_initial_state (
126126 batch_size , self .hidden_dim , name = "first_decoder_input" ), 1 )
127127 self .enc_outputs = tf .concat_v2 (
128- [first_decoder_input , self .enc_outputs ], axis = 1 )
128+ [self . first_decoder_input , self .enc_outputs ], axis = 1 )
129129
130130 with tf .variable_scope ("dencoder" ):
131+ self .idx_pairs = index_matrix_to_pairs (self .dec_targets )
132+ self .embeded_dec_inputs = tf .stop_gradient (
133+ tf .gather_nd (self .enc_outputs , self .idx_pairs ))
134+
131135 if self .use_terminal_symbol :
132136 tiled_zero_idxs = tf .tile (tf .zeros (
133137 [1 , 1 ], dtype = tf .int32 ), [batch_size , 1 ], name = "tiled_zero_idxs" )
134138 self .dec_targets = tf .concat_v2 ([self .dec_targets , tiled_zero_idxs ], axis = 1 )
135139
136- self .idx_pairs = index_matrix_to_pairs (self .dec_targets )
137- self .embeded_dec_inputs = tf .stop_gradient (
138- tf .gather_nd (self .enc_outputs , self .idx_pairs ))
140+ self .embeded_dec_inputs = tf .concat_v2 (
141+ [self .first_decoder_input , self .embeded_dec_inputs ], axis = 1 )
139142
140143 self .dec_cell = LSTMCell (
141144 self .hidden_dim ,
@@ -148,41 +151,51 @@ def _build_model(self):
148151 self .dec_pred_logits , _ , _ = decoder_rnn (
149152 self .dec_cell , self .embeded_dec_inputs ,
150153 self .enc_outputs , self .enc_final_states ,
151- self .dec_seq_length , self .hidden_dim , self . num_glimpse ,
152- self .max_dec_length , batch_size , is_train = True ,
154+ self .dec_seq_length , self .hidden_dim ,
155+ self .num_glimpse , batch_size , is_train = True ,
153156 initializer = self .initializer )
154157 self .dec_pred_prob = tf .nn .softmax (
155158 self .dec_pred_logits , 2 , name = "dec_pred_prob" )
156159 self .dec_pred = tf .argmax (
157160 self .dec_pred_logits , 2 , name = "dec_pred" )
158161
159162 with tf .variable_scope ("dencoder" , reuse = True ):
160- self .dec_inference_outputs , _ , self . dec_inference = decoder_rnn (
161- self .dec_cell , first_decoder_input ,
163+ self .dec_inference_logits , _ , _ = decoder_rnn (
164+ self .dec_cell , None ,
162165 self .enc_outputs , self .enc_final_states ,
163- self .dec_seq_length , self .hidden_dim , self . num_glimpse ,
164- self .max_dec_length , batch_size , is_train = False ,
165- initializer = self .initializer )
166+ self .dec_seq_length , self .hidden_dim ,
167+ self .num_glimpse , batch_size , is_train = False ,
168+ initializer = self .initializer , first_decoder_input = self . first_decoder_input )
166169 self .dec_inference_prob = tf .nn .softmax (
167- self .dec_inference_outputs , 2 , name = "dec_inference_prob" )
170+ self .dec_inference_logits , 2 , name = "dec_inference_logits" )
171+ self .dec_inference = tf .argmax (
172+ self .dec_inference_logits , 2 , name = "dec_inference" )
168173
169174 def _build_optim (self ):
170175 losses = tf .nn .sparse_softmax_cross_entropy_with_logits (
171176 labels = self .dec_targets , logits = self .dec_pred_logits )
177+ inference_losses = tf .nn .sparse_softmax_cross_entropy_with_logits (
178+ labels = self .dec_targets , logits = self .dec_inference_logits )
172179
173180 def apply_mask (op ):
174181 length = tf .cast (op [:1 ], tf .int32 )
175182 loss = op [1 :]
176183 return tf .multiply (loss , tf .ones (length , dtype = tf .float32 ))
177184
178- batch_loss = tf .div (tf .reduce_sum (tf .multiply (losses , self .mask )),
179- tf .reduce_sum (self .mask ), name = "batch_loss" )
185+ batch_loss = tf .div (
186+ tf .reduce_sum (tf .multiply (losses , self .mask )),
187+ tf .reduce_sum (self .mask ), name = "batch_loss" )
188+
189+ batch_inference_loss = tf .div (
190+ tf .reduce_sum (tf .multiply (inference_losses , self .mask )),
191+ tf .reduce_sum (self .mask ), name = "batch_inference_loss" )
180192
181193 tf .losses .add_loss (batch_loss )
182194 total_loss = tf .losses .get_total_loss ()
183195
184196 self .total_loss = total_loss
185197 self .target_cross_entropy_losses = losses
198+ self .total_inference_loss = batch_inference_loss
186199
187200 self .lr = tf .train .exponential_decay (
188201 self .lr_start , self .global_step , self .lr_decay_step ,
0 commit comments