@@ -258,11 +258,11 @@ def build_training_graph(self, input_tensors):
                                                        mode='FAN_OUT',
                                                        uniform=True))

-            weighted_average_contexts, _ = self.calculate_weighted_contexts(words_vocab, paths_vocab, attention_param,
+            code_vectors, _ = self.calculate_weighted_contexts(words_vocab, paths_vocab, attention_param,
                                                                 source_input, path_input, target_input,
                                                                 valid_mask)

-        logits = tf.matmul(weighted_average_contexts, target_words_vocab, transpose_b=True)
+        logits = tf.matmul(code_vectors, target_words_vocab, transpose_b=True)
         batch_size = tf.to_float(tf.shape(words_input)[0])
         loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
             labels=tf.reshape(words_input, [-1]),
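
The renamed code_vectors are what the training graph scores against the target-word embedding table: logits = code_vectors @ target_words_vocab.T, followed by sparse softmax cross-entropy against the true target-word index. A small NumPy sketch of that path, with toy shapes; the hand-rolled cross-entropy only stands in for tf.nn.sparse_softmax_cross_entropy_with_logits and is not part of this repo:

import numpy as np

batch, dim, target_vocab = 2, 4, 6                            # toy sizes; real values come from config
code_vectors = np.random.randn(batch, dim * 3)                # (batch, dim * 3)
target_words_vocab = np.random.randn(target_vocab, dim * 3)   # (target_word_vocab, dim * 3)
true_word_idx = np.array([1, 4])                              # flattened words_input

logits = code_vectors @ target_words_vocab.T                  # transpose_b=True in the graph

# sparse softmax cross-entropy, summed over examples and divided by the batch size
shifted = logits - logits.max(axis=1, keepdims=True)
log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
loss = -log_probs[np.arange(batch), true_word_idx].sum() / batch
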
@@ -302,10 +302,10 @@ def calculate_weighted_contexts(self, words_vocab, paths_vocab, attention_param,
         attention_weights = tf.nn.softmax(batched_contexts_weights, axis=1)  # (batch, max_contexts, 1)

         batched_embed = tf.reshape(flat_embed, shape=[-1, max_contexts, self.config.EMBEDDINGS_SIZE * 3])
-        weighted_average_contexts = tf.reduce_sum(tf.multiply(batched_embed, attention_weights),
+        code_vectors = tf.reduce_sum(tf.multiply(batched_embed, attention_weights),
                                                   axis=1)  # (batch, dim * 3)

-        return weighted_average_contexts, attention_weights
+        return code_vectors, attention_weights

     def build_test_graph(self, input_tensors, normalize_scores=False):
         with tf.variable_scope('model', reuse=self.get_should_reuse_variables()):
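
calculate_weighted_contexts builds code_vectors by attention-pooling the per-context embeddings: softmax the per-context scores along the contexts axis, then take the weighted sum. A rough NumPy equivalent, with shapes taken from the comments above and random toy inputs standing in for the real tensors:

import numpy as np

batch, max_contexts, dim3 = 2, 5, 12                                 # toy sizes
batched_embed = np.random.randn(batch, max_contexts, dim3)           # (batch, max_contexts, dim * 3)
batched_contexts_weights = np.random.randn(batch, max_contexts, 1)   # pre-softmax attention scores

# softmax over the contexts axis -> attention_weights, (batch, max_contexts, 1)
exp = np.exp(batched_contexts_weights - batched_contexts_weights.max(axis=1, keepdims=True))
attention_weights = exp / exp.sum(axis=1, keepdims=True)

# weighted sum over contexts -> one vector per example, (batch, dim * 3)
code_vectors = (batched_embed * attention_weights).sum(axis=1)
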
@@ -322,19 +322,19 @@ def build_test_graph(self, input_tensors, normalize_scores=False):
                                       shape=(self.path_vocab_size + 1, self.config.EMBEDDINGS_SIZE),
                                       dtype=tf.float32, trainable=False)

-            target_words_vocab = tf.transpose(target_words_vocab)  # (dim, word_vocab+1)
+            target_words_vocab = tf.transpose(target_words_vocab)  # (dim * 3, target_word_vocab+1)

             words_input, source_input, path_input, target_input, valid_mask, source_string, path_string, path_target_string = input_tensors  # (batch, 1), (batch, max_contexts)

-            weighted_average_contexts, attention_weights = self.calculate_weighted_contexts(words_vocab, paths_vocab,
+            code_vectors, attention_weights = self.calculate_weighted_contexts(words_vocab, paths_vocab,
                                                                                attention_param,
                                                                                source_input, path_input,
                                                                                target_input,
                                                                                valid_mask, True)

-        cos = tf.matmul(weighted_average_contexts, target_words_vocab)
+        scores = tf.matmul(code_vectors, target_words_vocab)  # (batch, target_word_vocab+1)

-        topk_candidates = tf.nn.top_k(cos, k=tf.minimum(self.topk, self.target_word_vocab_size))
+        topk_candidates = tf.nn.top_k(scores, k=tf.minimum(self.topk, self.target_word_vocab_size))
         top_indices = tf.to_int64(topk_candidates.indices)
         top_words = self.index_to_target_word_table.lookup(top_indices)
         original_words = words_input
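
At test time the same code_vectors are multiplied by the already-transposed target_words_vocab to get one score per target word, and tf.nn.top_k keeps the best k. A NumPy approximation of the scores/top_k step (the lookup from indices back to word strings via index_to_target_word_table is omitted; all sizes below are toy values):

import numpy as np

batch, dim3, target_vocab, topk = 2, 12, 7, 3
code_vectors = np.random.randn(batch, dim3)
target_words_vocab_t = np.random.randn(dim3, target_vocab)    # already transposed: (dim * 3, vocab)

scores = code_vectors @ target_words_vocab_t                  # (batch, target_word_vocab)

# indices of the top-k scores per example, best first (mirrors tf.nn.top_k)
k = min(topk, target_vocab)
top_indices = np.argsort(-scores, axis=1)[:, :k]
top_scores = np.take_along_axis(scores, top_indices, axis=1)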