case with a (very) small and simple feedforward network training on the MNIST
dataset with a learning rate scheduler. In this case the ReduceLROnPlateau
scheduler is used, but it can easily be changed to any of the other schedulers
- available.
-
- Video explanation: https://youtu.be/P31hB37g4Ak
- Got any questions leave a comment on youtube :)
+ available. I think simply reducing the LR by 1/10 or so when the loss
+ plateaus is a good default.

Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
* 2020-04-10 Initial programming
+ * 2022-12-19 Updated comments, made sure it works with the latest PyTorch

"""

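# The docstring above mentions swapping in another scheduler; as a rough sketch,
# StepLR (one of the other options in torch.optim.lr_scheduler, with illustrative
# step_size/gamma values here) would cut the LR on a fixed schedule instead of
# reacting to the loss:
#
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
#     ...
#     scheduler.step()  # stepped once per epoch, no metric is passed in
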
# Hyperparameters
num_classes = 10
- learning_rate = 0.1
+ learning_rate = (
+     0.1  # way too high learning rate, but we want to see the scheduler in action
+ )
batch_size = 128
num_epochs = 100

# Define Scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-     optimizer, factor=0.1, patience=5, verbose=True
+     optimizer, factor=0.1, patience=10, verbose=True
)
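# With these arguments the LR is multiplied by factor=0.1 once the metric passed to
# scheduler.step() has stopped improving for patience=10 consecutive epochs; the
# current LR can be inspected at any point via optimizer.param_groups[0]["lr"].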

# Train Network

        losses.append(loss.item())

        # backward
+         optimizer.zero_grad()
        loss.backward()
-
-         # gradient descent or adam step
-         # scheduler.step(loss)
        optimizer.step()
-         optimizer.zero_grad()
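        # (zeroing the gradients at the start of the step instead of after optimizer.step()
        # behaves the same in this loop; it is just the more common idiom and avoids ever
        # carrying stale gradients into the next backward pass)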

    mean_loss = sum(losses) / len(losses)
+     mean_loss = round(mean_loss, 2)  # we should see a difference in the loss at 2 decimals

    # After each epoch do scheduler.step, note in this scheduler we need to send
-     # in loss for that epoch!
+     # in loss for that epoch! This could also be the validation loss instead, and
+     # scheduler.step could even be called per batch inside the inner loop, but then
+     # the patience parameter would probably need to be adjusted
    scheduler.step(mean_loss)
-     print(f"Cost at epoch {epoch} is {mean_loss}")
+     print(f"Average loss for epoch {epoch} was {mean_loss}")
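    # To make the scheduler's effect visible, the current LR could also be logged each
    # epoch, e.g. something like:
    #     print(f"LR is now {optimizer.param_groups[0]['lr']}")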

# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
@@ -90,6 +91,7 @@ def check_accuracy(loader, model):
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
+             x = x.reshape(x.shape[0], -1)  # flatten the 28x28 MNIST images into 784-dim vectors for the feedforward net
            y = y.to(device=device)

            scores = model(x)