 
 # GPU
 parser.add_argument('-gpus', default=[], nargs='+', type=int,
-                    help="Use CUDA")
+                    help="Use CUDA on the listed devices. ")
 
 parser.add_argument('-log_interval', type=int, default=50,
                     help="Print stats at this interval.")
@@ -255,7 +255,8 @@ def trainEpoch(epoch):
             'dicts': dataset['dicts'],
             'opt': opt,
             'epoch': epoch,
-            'optim': optim,
+            'optimizer': optim.optimizer.state_dict(),
+            'last_ppl': optim.last_ppl,
         }
         torch.save(checkpoint,
                    '%s_acc_%.2f_ppl_%.2f_e%d.pt' % (opt.save_model, 100 * valid_acc, valid_ppl, epoch))
@@ -299,12 +300,14 @@ def main():
         print('Loading model from checkpoint at %s' % opt.train_from)
         model.load_state_dict(checkpoint['model'])
         generator.load_state_dict(checkpoint['generator'])
-        optim = checkpoint['optim']
         opt.start_epoch = checkpoint['epoch'] + 1
 
     if len(opt.gpus) >= 1:
         model.cuda()
         generator.cuda()
+    else:
+        model.cpu()
+        generator.cpu()
 
     if len(opt.gpus) > 1:
         model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
@@ -316,11 +319,15 @@ def main():
         for p in model.parameters():
             p.data.uniform_(-opt.param_init, opt.param_init)
 
-    optim = onmt.Optim(
-        model.parameters(), opt.optim, opt.learning_rate, opt.max_grad_norm,
-        lr_decay=opt.learning_rate_decay,
-        start_decay_at=opt.start_decay_at
-    )
+    optim = onmt.Optim(
+        model.parameters(), opt.optim, opt.learning_rate, opt.max_grad_norm,
+        lr_decay=opt.learning_rate_decay,
+        start_decay_at=opt.start_decay_at
+    )
+
+    if opt.train_from:
+        optim.last_ppl = checkpoint['last_ppl']
+        optim.optimizer.load_state_dict(checkpoint['optimizer'])
 
     nParams = sum([p.nelement() for p in model.parameters()])
     print('* number of parameters: %d' % nParams)
0 commit comments