@@ -345,36 +345,30 @@ def train(model,
345345 "iter_{}" .format (iter ))
346346 if not os .path .isdir (current_save_dir ):
347347 os .makedirs (current_save_dir )
348- states_dict = {
349- 'mIoU' : mean_iou ,
350- 'Acc' : acc ,
351- 'iter' : iter
352- }
353348 paddle .save (model .state_dict (),
354349 os .path .join (current_save_dir , 'model.pdparams' ))
355350 paddle .save (optimizer .state_dict (),
356351 os .path .join (current_save_dir , 'model.pdopt' ))
357- paddle .save (states_dict ,
358- os .path .join (current_save_dir , 'model.pdstates' ))
359352
360353 if use_ema :
361- ema_states_dict = {
362- 'mIoU' : ema_mean_iou ,
363- 'Acc' : ema_acc ,
364- 'iter' : iter
365- }
366354 paddle .save (
367355 ema_model .state_dict (),
368356 os .path .join (current_save_dir , 'ema_model.pdparams' ))
369- paddle .save (ema_states_dict ,
370- os .path .join (current_save_dir , 'ema_model.pdstates' ))
371357
372358 save_models .append (current_save_dir )
373359 if len (save_models ) > keep_checkpoint_max > 0 :
374360 model_to_remove = save_models .popleft ()
375361 shutil .rmtree (model_to_remove )
376362
377363 if val_dataset is not None :
364+ states_dict = {
365+ 'mIoU' : mean_iou ,
366+ 'Acc' : acc ,
367+ 'iter' : iter
368+ }
369+ paddle .save (states_dict ,
370+ os .path .join (current_save_dir , 'model.pdstates' ))
371+
378372 if mean_iou > best_mean_iou :
379373 stop_count = 0
380374 best_mean_iou = mean_iou
@@ -399,6 +393,14 @@ def train(model,
399393 .format (best_mean_iou , best_model_iter ))
400394
401395 if use_ema :
396+ ema_states_dict = {
397+ 'mIoU' : ema_mean_iou ,
398+ 'Acc' : ema_acc ,
399+ 'iter' : iter
400+ }
401+ paddle .save (ema_states_dict ,
402+ os .path .join (current_save_dir , 'ema_model.pdstates' ))
403+
402404 if ema_mean_iou > best_ema_mean_iou :
403405 best_ema_mean_iou = ema_mean_iou
404406 best_ema_model_iter = iter
0 commit comments