@@ -24,6 +24,7 @@ def translate_batch(exe,
2424 n_best ,
2525 batch_size ,
2626 n_head ,
27+  d_model ,
2728 src_pad_idx ,
2829 trg_pad_idx ,
2930 bos_idx ,
@@ -43,6 +44,11 @@ def translate_batch(exe,
4344 return_pos = True ,
4445 return_attn_bias = True ,
4546 return_max_len = False )
47+  # Append the data shape input to reshape the output of embedding layer. 
48+  enc_in_data  =  enc_in_data  +  [
49+  np .array (
50+  [- 1 , enc_in_data [2 ].shape [- 1 ], d_model ], dtype = "int32" )
51+  ]
4652 # Append the shape inputs to reshape before and after softmax in encoder 
4753 # self attention. 
4854 enc_in_data  =  enc_in_data  +  [
@@ -59,9 +65,14 @@ def translate_batch(exe,
5965 scores  =  np .zeros ((batch_size , beam_size ), dtype = "float32" )
6066 prev_branchs  =  [[] for  i  in  range (batch_size )]
6167 next_ids  =  [[] for  i  in  range (batch_size )]
62-  # Use beam_map  to map the instance idx in batch to beam idx , since the 
68+  # Use beam_inst_map  to map beam idx to  the instance idx in batch, since the 
6369 # size of feeded batch is changing. 
64-  beam_map  =  range (batch_size )
70+  beam_inst_map  =  {
71+  beam_idx : inst_idx 
72+  for  inst_idx , beam_idx  in  enumerate (range (batch_size ))
73+  }
74+  # Use active_beams to record the indices of the alive instances. 
75+  active_beams  =  range (batch_size )
6576
6677 def  beam_backtrace (prev_branchs , next_ids , n_best = beam_size ):
6778 """ 
@@ -98,8 +109,14 @@ def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
98109 [- 1e9 ]).astype ("float32" )
99110 # This is used to remove attention on the paddings of source sequences. 
100111 trg_src_attn_bias  =  np .tile (
101-  src_slf_attn_bias [:, :, ::src_max_length , :],
102-  [beam_size , 1 , trg_max_len , 1 ])
112+  src_slf_attn_bias [:, :, ::src_max_length , :][:, np .newaxis ],
113+  [1 , beam_size , 1 , trg_max_len , 1 ]).reshape ([
114+  - 1 , src_slf_attn_bias .shape [1 ], trg_max_len ,
115+  src_slf_attn_bias .shape [- 1 ]
116+  ])
117+  # Append the shape input to reshape the output of embedding layer. 
118+  trg_data_shape  =  np .array (
119+  [batch_size  *  beam_size , trg_max_len , d_model ], dtype = "int32" )
103120 # Append the shape inputs to reshape before and after softmax in 
104121 # decoder self attention. 
105122 trg_slf_attn_pre_softmax_shape  =  np .array (
@@ -112,22 +129,24 @@ def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
112129 [- 1 , trg_src_attn_bias .shape [- 1 ]], dtype = "int32" )
113130 trg_src_attn_post_softmax_shape  =  np .array (
114131 trg_src_attn_bias .shape , dtype = "int32" )
115-  enc_output  =  np .tile (enc_output , [beam_size , 1 , 1 ])
132+  enc_output  =  np .tile (
133+  enc_output [:, np .newaxis ], [1 , beam_size , 1 , 1 ]).reshape (
134+  [- 1 , enc_output .shape [- 2 ], enc_output .shape [- 1 ]])
116135 return  trg_words , trg_pos , trg_slf_attn_bias , trg_src_attn_bias , \
117-  trg_slf_attn_pre_softmax_shape ,  trg_slf_attn_post_softmax_shape , \
118-  trg_src_attn_pre_softmax_shape ,  trg_src_attn_post_softmax_shape , \
119-  enc_output 
136+  trg_data_shape ,  trg_slf_attn_pre_softmax_shape , \
137+  trg_slf_attn_post_softmax_shape ,  trg_src_attn_pre_softmax_shape , \
138+  trg_src_attn_post_softmax_shape ,  enc_output 
120139
121-  def  update_dec_in_data (dec_in_data , next_ids , active_beams ):
140+  def  update_dec_in_data (dec_in_data , next_ids , active_beams ,  beam_inst_map ):
122141 """ 
123142 Update the input data of decoder mainly by slicing from the previous 
124143 input data and dropping the finished instance beams. 
125144 """ 
126145 trg_words , trg_pos , trg_slf_attn_bias , trg_src_attn_bias , \
127-  trg_slf_attn_pre_softmax_shape ,  trg_slf_attn_post_softmax_shape , \
128-  trg_src_attn_pre_softmax_shape ,  trg_src_attn_post_softmax_shape , \
129-  enc_output  =  dec_in_data 
130-  trg_cur_len  =  len ( next_ids [ 0 ])  +  1   # include the <bos> 
146+  trg_data_shape ,  trg_slf_attn_pre_softmax_shape , \
147+  trg_slf_attn_post_softmax_shape ,  trg_src_attn_pre_softmax_shape , \
148+  trg_src_attn_post_softmax_shape ,  enc_output  =  dec_in_data 
149+  trg_cur_len  =  trg_slf_attn_bias . shape [ - 1 ]  +  1 
131150 trg_words  =  np .array (
132151 [
133152 beam_backtrace (prev_branchs [beam_idx ], next_ids [beam_idx ])
@@ -138,6 +157,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
138157 trg_pos  =  np .array (
139158 [range (1 , trg_cur_len  +  1 )] *  len (active_beams ) *  beam_size ,
140159 dtype = "int64" ).reshape ([- 1 , 1 ])
160+  active_beams  =  [beam_inst_map [beam_idx ] for  beam_idx  in  active_beams ]
141161 active_beams_indice  =  (
142162 (np .array (active_beams ) *  beam_size )[:, np .newaxis ] + 
143163 np .array (range (beam_size ))[np .newaxis , :]).flatten ()
@@ -152,6 +172,10 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
152172 trg_src_attn_bias  =  np .tile (trg_src_attn_bias [
153173 active_beams_indice , :, ::trg_src_attn_bias .shape [2 ], :],
154174 [1 , 1 , trg_cur_len , 1 ])
175+  # Append the shape input to reshape the output of embedding layer. 
176+  trg_data_shape  =  np .array (
177+  [len (active_beams ) *  beam_size , trg_cur_len , d_model ],
178+  dtype = "int32" )
155179 # Append the shape inputs to reshape before and after softmax in 
156180 # decoder self attention. 
157181 trg_slf_attn_pre_softmax_shape  =  np .array (
@@ -166,9 +190,9 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
166190 trg_src_attn_bias .shape , dtype = "int32" )
167191 enc_output  =  enc_output [active_beams_indice , :, :]
168192 return  trg_words , trg_pos , trg_slf_attn_bias , trg_src_attn_bias , \
169-  trg_slf_attn_pre_softmax_shape ,  trg_slf_attn_post_softmax_shape , \
170-  trg_src_attn_pre_softmax_shape ,  trg_src_attn_post_softmax_shape , \
171-  enc_output 
193+  trg_data_shape ,  trg_slf_attn_pre_softmax_shape , \
194+  trg_slf_attn_post_softmax_shape ,  trg_src_attn_pre_softmax_shape , \
195+  trg_src_attn_post_softmax_shape ,  enc_output 
172196
173197 dec_in_data  =  init_dec_in_data (batch_size , beam_size , enc_in_data ,
174198 enc_output )
@@ -177,15 +201,18 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
177201 feed = dict (zip (dec_in_names , dec_in_data )),
178202 fetch_list = dec_out_names )[0 ]
179203 predict_all  =  np .log (
180-  predict_all .reshape ([len (beam_map ) *  beam_size , i  +  1 , - 1 ])[:, 
181-     - 1 , :])
182-  predict_all  =  (predict_all  +  scores [beam_map ].reshape (
183-  [len (beam_map ) *  beam_size , - 1 ])).reshape (
184-  [len (beam_map ), beam_size , - 1 ])
204+  predict_all .reshape ([len (beam_inst_map ) *  beam_size , i  +  1 , - 1 ])
205+  [:,  - 1 , :])
206+  predict_all  =  (predict_all  +  scores [active_beams ].reshape (
207+  [len (beam_inst_map ) *  beam_size , - 1 ])).reshape (
208+  [len (beam_inst_map ), beam_size , - 1 ])
185209 if  not  output_unk : # To exclude the <unk> token. 
186210 predict_all [:, :, unk_idx ] =  - 1e9 
187211 active_beams  =  []
188-  for  inst_idx , beam_idx  in  enumerate (beam_map ):
212+  for  beam_idx  in  range (batch_size ):
213+  if  beam_idx  not  in  beam_inst_map :
214+  continue 
215+  inst_idx  =  beam_inst_map [beam_idx ]
189216 predict  =  (predict_all [inst_idx , :, :]
190217 if  i  !=  0  else  predict_all [inst_idx , 0 , :]).flatten ()
191218 top_k_indice  =  np .argpartition (predict , - beam_size )[- beam_size :]
@@ -198,10 +225,14 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
198225 next_ids [beam_idx ].append (top_scores_ids  %  predict_all .shape [- 1 ])
199226 if  next_ids [beam_idx ][- 1 ][0 ] !=  eos_idx :
200227 active_beams .append (beam_idx )
201-  beam_map  =  active_beams 
202-  if  len (beam_map ) ==  0 :
228+  if  len (active_beams ) ==  0 :
203229 break 
204-  dec_in_data  =  update_dec_in_data (dec_in_data , next_ids , active_beams )
230+  dec_in_data  =  update_dec_in_data (dec_in_data , next_ids , active_beams ,
231+  beam_inst_map )
232+  beam_inst_map  =  {
233+  beam_idx : inst_idx 
234+  for  inst_idx , beam_idx  in  enumerate (active_beams )
235+  }
205236
206237 # Decode beams and select n_best sequences for each instance by backtrace. 
207238 seqs  =  [
@@ -215,10 +246,8 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
215246def  main ():
216247 place  =  fluid .CUDAPlace (0 ) if  InferTaskConfig .use_gpu  else  fluid .CPUPlace ()
217248 exe  =  fluid .Executor (place )
218-  # The current program desc is coupled with batch_size and the only 
219-  # supported batch size is 1 currently. 
249+ 
220250 encoder_program  =  fluid .Program ()
221-  model .batch_size  =  InferTaskConfig .batch_size 
222251 with  fluid .program_guard (main_program = encoder_program ):
223252 enc_output  =  encoder (
224253 ModelHyperParams .src_vocab_size  +  1 ,
@@ -228,7 +257,6 @@ def main():
228257 ModelHyperParams .d_inner_hid , ModelHyperParams .dropout ,
229258 ModelHyperParams .src_pad_idx , ModelHyperParams .pos_pad_idx )
230259
231-  model .batch_size  =  InferTaskConfig .batch_size  *  InferTaskConfig .beam_size 
232260 decoder_program  =  fluid .Program ()
233261 with  fluid .program_guard (main_program = decoder_program ):
234262 predict  =  decoder (
@@ -273,6 +301,9 @@ def main():
273301
274302 trg_idx2word  =  paddle .dataset .wmt16 .get_dict (
275303 "de" , dict_size = ModelHyperParams .trg_vocab_size , reverse = True )
304+  # Append the <pad> token since the dict provided by dataset.wmt16 does 
305+  # not include it. 
306+  trg_idx2word [ModelHyperParams .trg_pad_idx ] =  "<pad>" 
276307
277308 def  post_process_seq (seq ,
278309 bos_idx = ModelHyperParams .bos_idx ,
@@ -306,6 +337,7 @@ def post_process_seq(seq,
306337 InferTaskConfig .n_best ,
307338 len (data ),
308339 ModelHyperParams .n_head ,
340+  ModelHyperParams .d_model ,
309341 ModelHyperParams .src_pad_idx ,
310342 ModelHyperParams .trg_pad_idx ,
311343 ModelHyperParams .bos_idx ,
0 commit comments