@@ -345,15 +345,29 @@ def train(model,
345345 "iter_{}" .format (iter ))
346346 if not os .path .isdir (current_save_dir ):
347347 os .makedirs (current_save_dir )
348+ states_dict = {
349+ 'mIoU' : mean_iou ,
350+ 'Acc' : acc ,
351+ 'iter' : iter
352+ }
348353 paddle .save (model .state_dict (),
349354 os .path .join (current_save_dir , 'model.pdparams' ))
350355 paddle .save (optimizer .state_dict (),
351356 os .path .join (current_save_dir , 'model.pdopt' ))
357+ paddle .save (states_dict ,
358+ os .path .join (current_save_dir , 'model.pdstates' ))
352359
353360 if use_ema :
361+ ema_states_dict = {
362+ 'mIoU' : ema_mean_iou ,
363+ 'Acc' : ema_acc ,
364+ 'iter' : iter
365+ }
354366 paddle .save (
355367 ema_model .state_dict (),
356368 os .path .join (current_save_dir , 'ema_model.pdparams' ))
369+ paddle .save (ema_states_dict ,
370+ os .path .join (current_save_dir , 'ema_model.pdstates' ))
357371
358372 save_models .append (current_save_dir )
359373 if len (save_models ) > keep_checkpoint_max > 0 :
@@ -369,6 +383,8 @@ def train(model,
369383 paddle .save (
370384 model .state_dict (),
371385 os .path .join (best_model_dir , 'model.pdparams' ))
386+ paddle .save (states_dict ,
387+ os .path .join (best_model_dir , 'model.pdstates' ))
372388 elif mean_iou < best_mean_iou :
373389 stop_count += 1
374390
@@ -391,6 +407,8 @@ def train(model,
391407 paddle .save (ema_model .state_dict (),
392408 os .path .join (best_ema_model_dir ,
393409 'ema_model.pdparams' ))
410+ paddle .save (ema_states_dict ,
411+ os .path .join (best_ema_model_dir , 'ema_model.pdstates' ))
394412 logger .info (
395413 '[EVAL] The EMA model with the best validation mIoU ({:.4f}) was saved at iter {}.'
396414 .format (best_ema_mean_iou , best_ema_model_iter ))
0 commit comments