@@ -186,9 +186,8 @@ def trainEpoch(epoch):
         # shuffle mini batch order
         batchOrder = torch.randperm(len(trainData))

-        total_loss, report_loss = 0, 0
-        total_words, report_tgt_words, report_src_words = 0, 0, 0
-        total_num_correct = 0
+        total_loss, total_words, total_num_correct = 0, 0, 0
+        report_loss, report_tgt_words, report_src_words, report_num_correct = 0, 0, 0, 0
         start = time.time()
         for i in range(len(trainData)):

@@ -206,23 +205,24 @@ def trainEpoch(epoch):
             # update the parameters
             optim.step()

-            report_loss += loss
-            total_num_correct += num_correct
-            total_loss += loss
             num_words = targets.data.ne(onmt.Constants.PAD).sum()
-            total_words += num_words
+            report_loss += loss
+            report_num_correct += num_correct
             report_tgt_words += num_words
             report_src_words += batch[0].data.ne(onmt.Constants.PAD).sum()
+            total_loss += loss
+            total_num_correct += num_correct
+            total_words += num_words
             if i % opt.log_interval == -1 % opt.log_interval:
                 print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
                       (epoch, i, len(trainData),
-                      num_correct / num_words * 100,
+                      report_num_correct / report_tgt_words * 100,
                       math.exp(report_loss / report_tgt_words),
                       report_src_words/(time.time()-start),
                       report_tgt_words/(time.time()-start),
                       time.time()-start_time))

-                report_loss = report_tgt_words = report_src_words = 0
+                report_loss = report_tgt_words = report_src_words = report_num_correct = 0
                 start = time.time()

         return total_loss / total_words, total_num_correct / total_words
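
For reference, the reordered counters above follow a common interval-reporting pattern: the total_* accumulators run for the whole epoch, while the report_* counters are reset after every log interval, so the printed accuracy and perplexity describe the whole interval rather than only the most recent batch (previously the accuracy line used num_correct / num_words from the last batch alone). Below is a minimal, self-contained sketch of that pattern; train_epoch, batches, and step_fn are hypothetical names used for illustration, not part of this repository's API.

# Minimal sketch of the interval-reporting pattern (hypothetical names).
import math
import time

def train_epoch(batches, step_fn, log_interval=50):
    # epoch-level totals vs. report_* counters that reset at every log interval
    total_loss, total_words, total_num_correct = 0, 0, 0
    report_loss, report_tgt_words, report_num_correct = 0, 0, 0
    start = time.time()
    for i, batch in enumerate(batches):
        loss, num_words, num_correct = step_fn(batch)  # one forward/backward/optimizer step
        report_loss += loss
        report_tgt_words += num_words
        report_num_correct += num_correct
        total_loss += loss
        total_words += num_words
        total_num_correct += num_correct
        if i % log_interval == -1 % log_interval:  # i.e. every log_interval batches
            print("acc: %6.2f; ppl: %6.2f; %3.0f tgt tok/s" %
                  (report_num_correct / report_tgt_words * 100,
                   math.exp(report_loss / report_tgt_words),
                   report_tgt_words / (time.time() - start)))
            # reset only the interval counters; the epoch totals keep accumulating
            report_loss = report_tgt_words = report_num_correct = 0
            start = time.time()
    return total_loss / total_words, total_num_correct / total_words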