 
 def NMTCriterion(vocabSize):
     weight = torch.ones(vocabSize)
-    weight[onmt.Constants.PAD] = 0
+    weight[onmt.Constants.PAD] = 0
     crit = nn.NLLLoss(weight, size_average=False)
     if opt.gpus:
         crit.cuda()
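
A note on the context lines above: the criterion gives the PAD class a zero weight, so padding tokens contribute neither loss nor gradient. A minimal sketch of the effect, assuming a PAD index of 0 and a toy five-word vocabulary (none of these names come from the patch):

import torch
import torch.nn as nn

PAD = 0                        # assumed padding index, standing in for onmt.Constants.PAD
weight = torch.ones(5)         # toy vocabulary of five classes
weight[PAD] = 0                # predictions at PAD positions cost nothing
crit = nn.NLLLoss(weight, size_average=False)  # summed loss; reduction='sum' in newer PyTorch

log_probs = torch.log(torch.full((3, 5), 0.2))  # three uniform predictions
targets = torch.tensor([2, 4, PAD])             # the last target is padding
print(crit(log_probs, targets))                 # 2 * -log(0.2); the PAD term adds zero
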
@@ -129,40 +129,45 @@ def NMTCriterion(vocabSize):
 
 def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
     # compute generations one piece at a time
-    loss = 0
+    num_correct, loss = 0, 0
     outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)
 
     batch_size = outputs.size(1)
     outputs_split = torch.split(outputs, opt.max_generator_batches)
     targets_split = torch.split(targets, opt.max_generator_batches)
-    for out_t, targ_t in zip(outputs_split, targets_split):
+    for i, (out_t, targ_t) in enumerate(zip(outputs_split, targets_split)):
         out_t = out_t.view(-1, out_t.size(2))
-        pred_t = generator(out_t)
-        loss_t = crit(pred_t, targ_t.view(-1))
+        scores_t = generator(out_t)
+        loss_t = crit(scores_t, targ_t.view(-1))
+        pred_t = scores_t.max(1)[1]
+        num_correct_t = pred_t.data.eq(targ_t.data).masked_select(targ_t.ne(onmt.Constants.PAD).data).sum()
+        num_correct += num_correct_t
         loss += loss_t.data[0]
         if not eval:
             loss_t.div(batch_size).backward()
 
     grad_output = None if outputs.grad is None else outputs.grad.data
-    return loss, grad_output
+    return loss, grad_output, num_correct
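
The function hosting the new count is worth a gloss: outputs is detached into a fresh leaf, split into shards so the vocabulary-sized generator output never exists for the whole batch at once, and each shard's backward accumulates into the leaf's .grad, which the caller later pushes through the real graph with outputs.backward(gradOutput). A minimal sketch of the same trick in current autograd spelling (the helper name, shard_size, and loss_fn are mine, not the patch's):

import torch

def sharded_backward(outputs, targets, loss_fn, shard_size=32):
    # Detach into a new leaf so each shard's graph can be freed after use.
    detached = outputs.detach().requires_grad_()
    total = 0.0
    for out_s, tgt_s in zip(torch.split(detached, shard_size),
                            torch.split(targets, shard_size)):
        loss = loss_fn(out_s, tgt_s)   # only this shard's graph is alive
        total += loss.item()
        loss.backward()                # accumulates into detached.grad
    # Reattach: one backward pass carries the pooled gradient through the model.
    outputs.backward(detached.grad)
    return total
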
 
 
 def eval(model, criterion, data):
     total_loss = 0
     total_words = 0
+    total_num_correct = 0
 
     model.eval()
     for i in range(len(data)):
         batch = data[i]
         outputs = model(batch)  # FIXME volatile
         targets = batch[1][1:]  # exclude <s> from targets
-        loss, _ = memoryEfficientLoss(
+        loss, _, num_correct = memoryEfficientLoss(
             outputs, targets, model.generator, criterion, eval=True)
         total_loss += loss
+        total_num_correct += num_correct
         total_words += targets.data.ne(onmt.Constants.PAD).sum()
 
     model.train()
-    return total_loss / total_words
+    return total_loss / total_words, total_num_correct / total_words
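
Both returned ratios are per non-PAD token: summed NLL over token count, and correct predictions over token count. A tiny worked example of the masked counting that feeds them (invented values, PAD assumed to be index 0):

import torch

PAD = 0
pred = torch.tensor([3, 1, 2, PAD])            # argmax predictions
targ = torch.tensor([3, 2, 2, PAD])            # references; the last slot is padding
mask = targ.ne(PAD)                            # [True, True, True, False]
num_correct = pred.eq(targ).masked_select(mask).sum()
print(num_correct.item())                      # 2 -- the trivial PAD==PAD hit is excluded
print(num_correct.item() / mask.sum().item())  # ~0.67 per-token accuracy
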
 
 
 def trainModel(model, trainData, validData, dataset, optim):
@@ -183,6 +188,7 @@ def trainEpoch(epoch):
 
         total_loss, report_loss = 0, 0
         total_words, report_tgt_words, report_src_words = 0, 0, 0
+        total_num_correct = 0
         start = time.time()
         for i in range(len(trainData)):
 
@@ -192,7 +198,7 @@ def trainEpoch(epoch):
             model.zero_grad()
             outputs = model(batch)
             targets = batch[1][1:]  # exclude <s> from targets
-            loss, gradOutput = memoryEfficientLoss(
+            loss, gradOutput, num_correct = memoryEfficientLoss(
                 outputs, targets, model.generator, criterion)
 
             outputs.backward(gradOutput)
@@ -201,36 +207,40 @@ def trainEpoch(epoch):
             optim.step()
 
             report_loss += loss
+            total_num_correct += num_correct
             total_loss += loss
             num_words = targets.data.ne(onmt.Constants.PAD).sum()
             total_words += num_words
             report_tgt_words += num_words
             report_src_words += batch[0].data.ne(onmt.Constants.PAD).sum()
             if i % opt.log_interval == 0 and i > 0:
-                print("Epoch %2d, %5d/%5d batches; perplexity: %6.2f; %3.0f source tokens/s; %3.0f target tokens/s; %6.0f s elapsed" %
+                print("Epoch %2d, %5d/%5d; ppl: %6.2f; acc: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
                       (epoch, i, len(trainData),
                        math.exp(report_loss / report_tgt_words),
+                       num_correct / num_words * 100,
                        report_src_words/(time.time()-start),
                        report_tgt_words/(time.time()-start),
                        time.time()-start_time))
 
                 report_loss = report_tgt_words = report_src_words = 0
                 start = time.time()
 
-        return total_loss / total_words
+        return total_loss / total_words, total_num_correct / total_words
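
Because report_loss is a summed NLL and report_tgt_words counts the non-PAD tokens it was summed over, the logged math.exp(report_loss / report_tgt_words) is a per-token perplexity. With invented numbers:

import math

report_loss = 6907.8      # summed NLL over the reporting window
report_tgt_words = 3000   # non-PAD target tokens in the same window
print(math.exp(report_loss / report_tgt_words))  # ~10.0: each token is roughly a 10-way guess

# The epoch summaries below clamp the exponent, so a diverging loss prints
# as exp(100) -- huge but finite -- instead of overflowing math.exp:
print(math.exp(min(350.0, 100)))
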
 
     for epoch in range(opt.start_epoch, opt.epochs + 1):
         print('')
 
         # (1) train for one epoch on the training set
-        train_loss = trainEpoch(epoch)
+        train_loss, train_acc = trainEpoch(epoch)
         train_ppl = math.exp(min(train_loss, 100))
         print('Train perplexity: %g' % train_ppl)
+        print('Train accuracy: %g' % train_acc)
 
         # (2) evaluate on the validation set
-        valid_loss = eval(model, criterion, validData)
+        valid_loss, valid_acc = eval(model, criterion, validData)
         valid_ppl = math.exp(min(valid_loss, 100))
         print('Validation perplexity: %g' % valid_ppl)
+        print('Validation accuracy: %g' % valid_acc)
 
         # (3) maybe update the learning rate
         if opt.update_learning_rate:
@@ -245,7 +255,7 @@ def trainEpoch(epoch):
             'optim': optim,
         }
         torch.save(checkpoint,
-                   '%s_val_%.2f_train_%.2f_e%d.pt' % (opt.save_model, valid_ppl, train_ppl, epoch))
+                   '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (opt.save_model, valid_acc, valid_ppl, epoch))
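
The rename puts the two headline metrics in the filename, validation accuracy first. Illustrative values only, taking the ratio eval returns as a plain float fraction:

save_model = 'demo'       # stands in for opt.save_model
valid_acc, valid_ppl, epoch = 0.57, 8.42, 13
print('%s_acc_%.2f_ppl_%.2f_e%d.pt' % (save_model, valid_acc, valid_ppl, epoch))
# -> demo_acc_0.57_ppl_8.42_e13.pt
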
 
 def main():
 