@@ -71,14 +71,22 @@ def main():
     else:
         model = torch.nn.DataParallel(model).cuda()

-    # optionally resume from a checkpoint
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda()
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
             print("=> loading checkpoint '{}'".format(args.resume))
             checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_prec1 = checkpoint['best_prec1']
             model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
         else:
@@ -112,13 +120,6 @@ def main():
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True)

-    # define loss function (criterion) and optimizer
-    criterion = nn.CrossEntropyLoss().cuda()
-
-    optimizer = torch.optim.SGD(model.parameters(), args.lr,
-                                momentum=args.momentum,
-                                weight_decay=args.weight_decay)
-
     if args.evaluate:
         validate(val_loader, model, criterion)
         return
@@ -140,6 +141,7 @@ def main():
             'arch': args.arch,
             'state_dict': model.state_dict(),
             'best_prec1': best_prec1,
+            'optimizer': optimizer.state_dict(),
         }, is_best)

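The point of the reordering is that the optimizer now exists before the resume block runs, so its state (momentum buffers and the hyperparameters in its param groups) can be restored from the checkpoint alongside the model weights, and the same state is written back out by save_checkpoint. A minimal standalone sketch of this save/resume pattern follows; the toy model, the file name 'checkpoint.pth.tar', and the CPU-only setup (no .cuda()) are illustrative assumptions, not taken from the diff.

import torch
import torch.nn as nn

# Toy model and optimizer (illustrative only).
model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Save model and optimizer state together, as the diff does in save_checkpoint().
torch.save({
    'epoch': 1,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'checkpoint.pth.tar')

# Resume: construct the model and optimizer first, then restore both.
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']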