                     help='the number of heads in the encoder/decoder of the transformer model')
 parser.add_argument('--dry-run', action='store_true',
                     help='verify the code and the model')
-parser.add_argument('--accel', action='store_true',help='Enables accelerated training')
+parser.add_argument('--accel', action='store_true',
+                    help='Enables accelerated training')
+parser.add_argument('--use-optimizer', action='store_true',
+                    help='Uses AdamW optimizer for gradient updating')
 args = parser.parse_args()

 # Set the random seed manually for reproducibility.
@@ -104,6 +107,8 @@ def batchify(data, bsz):
 model = RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied).to(device)

 criterion = nn.NLLLoss()
+if args.use_optimizer:
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

 ###############################################################################
 # Training code
@@ -167,7 +172,10 @@ def train():
         data, targets = get_batch(train_data, i)
         # Starting each batch, we detach the hidden state from how it was previously produced.
         # If we didn't, the model would try backpropagating all the way to start of the dataset.
-        model.zero_grad()
+        if args.use_optimizer:
+            optimizer.zero_grad()
+        else:
+            model.zero_grad()
         if args.model == 'Transformer':
             output = model(data)
             output = output.view(-1, ntokens)
@@ -179,8 +187,11 @@ def train():

         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
         torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
-        for p in model.parameters():
-            p.data.add_(p.grad, alpha=-lr)
+        if args.use_optimizer:
+            optimizer.step()
+        else:
+            for p in model.parameters():
+                p.data.add_(p.grad, alpha=-lr)

         total_loss += loss.item()

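For readers skimming the diff, the sketch below shows the two gradient-update paths side by side: the script's original manual SGD-style update and the new AdamW path guarded by the --use-optimizer flag. It is a minimal, self-contained example, assuming a toy nn.Linear model with made-up shapes and hyperparameter values in place of the RNNModel, args.lr, and args.clip used in the actual script.

# Standalone sketch of the two update paths this change introduces.
# The nn.Linear model, shapes, and hyperparameter values are illustrative
# only, not part of the patched script.
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(10, 5)
criterion = nn.NLLLoss()
lr = 0.001            # illustrative value; the real script takes this from args.lr
clip = 0.25           # illustrative value; the real script uses args.clip
use_optimizer = True  # stands in for args.use_optimizer

if use_optimizer:
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

data = torch.randn(8, 10)
targets = torch.randint(0, 5, (8,))

output = F.log_softmax(model(data), dim=-1)   # NLLLoss expects log-probabilities
loss = criterion(output, targets)

if use_optimizer:
    optimizer.zero_grad()
else:
    model.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

if use_optimizer:
    optimizer.step()                          # AdamW performs the parameter update
else:
    for p in model.parameters():
        p.data.add_(p.grad, alpha=-lr)        # manual SGD-style update, as before the patch

With the patch applied, the AdamW path would be selected from the command line by passing --use-optimizer alongside the script's existing flags; omitting the flag keeps the original manual update behaviour.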