11"""
2- Working code of a simple Fully Connected (FC) network training on MNIST dataset.
3- The code is intended to show how to create a FC network as well
4- as how to initialize loss, optimizer, etc. in a simple way to get
5- training to work with function that checks accuracy as well.
2+ A simple walkthrough of how to code a fully connected neural network
3+ using the PyTorch library. For demonstration we train it on the very
4+ common MNIST dataset of handwritten digits. In this code we go through
5+ how to create the network as well as initialize a loss function, optimizer,
6+ check accuracy and more.
67
7- Video explanation: https://youtu.be/Jy4wM2X21u0
8- Got any questions leave a comment on youtube :)
9-
10- Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
11- * 2020-04-08 Initial coding
8+ Programmed by Aladdin Persson
9+ * 2020-04-08: Initial coding
10+ * 2021-03-24: Added more detailed comments also removed part of
11+ check_accuracy which would only work specifically on MNIST.
1212
1313"""
1414
1515# Imports
1616import torch
-import torchvision
-import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
-import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
-import torch.nn.functional as F  # All functions that don't have any parameters
-from torch.utils.data import (
-    DataLoader,
-)  # Gives easier dataset managment and creates mini batches
-import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
-import torchvision.transforms as transforms  # Transformations we can perform on our dataset
-
-# Create Fully Connected Network
+import torchvision  # torch package for vision related things
+import torch.nn.functional as F  # Parameterless functions, like (some) activation functions
+import torchvision.datasets as datasets  # Standard datasets
+import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
+from torch import optim  # For optimizers like SGD, Adam, etc.
+from torch import nn  # All neural network modules
+from torch.utils.data import DataLoader  # Gives easier dataset management by creating mini batches etc.
+from tqdm import tqdm  # For a nice progress bar!
+
+# Here we create our simple neural network. We do this by subclassing and
+# inheriting from nn.Module, which is the most general way to create networks
+# and allows for more flexibility. I encourage you to also check out
+# nn.Sequential, which would be easier to use in this scenario, but I wanted
+# to show you something that "always" works (see the sketch right below).
 class NN(nn.Module):
     def __init__(self, input_size, num_classes):
         super(NN, self).__init__()
+        # Our first linear layer takes input_size, in this case 784 nodes, to 50,
+        # and our second linear layer takes 50 to the num_classes we have, in
+        # this case 10.
         self.fc1 = nn.Linear(input_size, 50)
         self.fc2 = nn.Linear(50, num_classes)

     def forward(self, x):
41+ """
42+ x here is the mnist images and we run it through fc1, fc2 that we created above.
43+ we also add a ReLU activation function in between and for that (since it has no parameters)
44+ I recommend using nn.functional (F)
45+ """
46+
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return x


-# Set device
+# Set device to cuda if a GPU is available, otherwise run on the CPU
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-# Hyperparameters
+# Hyperparameters of our neural network, which depend on the dataset and
+# also on just experimenting to see what works well (learning rate, for example).
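+# Note: input_size = 784 because each MNIST image is 28x28 pixels, and we
+# flatten it into a single vector (28 * 28 = 784) before the first linear layer.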
 input_size = 784
 num_classes = 10
 learning_rate = 0.001
 batch_size = 64
-num_epochs = 1
+num_epochs = 3

-# Load Data
-train_dataset = datasets.MNIST(
-    root="dataset/", train=True, transform=transforms.ToTensor(), download=True
-)
+# Load Training and Test data
+train_dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
+test_dataset = datasets.MNIST(root="dataset/", train=False, transform=transforms.ToTensor(), download=True)
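+# transforms.ToTensor() converts the PIL images to float tensors and scales the
+# pixel values from [0, 255] down to [0.0, 1.0].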
 train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
-test_dataset = datasets.MNIST(
-    root="dataset/", train=False, transform=transforms.ToTensor(), download=True
-)
 test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

 # Initialize network
@@ -66,7 +75,7 @@ def forward(self, x):

 # Train Network
 for epoch in range(num_epochs):
-    for batch_idx, (data, targets) in enumerate(train_loader):
+    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
         # Get data to cuda if possible
         data = data.to(device=device)
         targets = targets.to(device=device)
@@ -85,15 +94,9 @@ def forward(self, x):
         # gradient descent or adam step
         optimizer.step()

-# Check accuracy on training & test to see how good our model
-

+# Check accuracy on training & test to see how good our model is
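+# The function below returns the fraction of correctly classified samples.
+# (model.eval() and model.train() toggle evaluation behavior for layers such as
+# dropout and batchnorm; this network has neither, but toggling is good practice.)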
 def check_accuracy(loader, model):
-    if loader.dataset.train:
-        print("Checking accuracy on training data")
-    else:
-        print("Checking accuracy on test data")
-
     num_correct = 0
     num_samples = 0
     model.eval()
@@ -109,12 +112,9 @@ def check_accuracy(loader, model):
             num_correct += (predictions == y).sum()
             num_samples += predictions.size(0)

-    print(
-        f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}"
-    )
-
     model.train()
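+    # num_correct ends up as a 0-dim tensor (from .sum()), so the returned
+    # ratio is a tensor as well; it still formats fine in the f-strings below.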
+    return num_correct / num_samples


-check_accuracy(train_loader, model)
-check_accuracy(test_loader, model)
+print(f"Accuracy on training set: {check_accuracy(train_loader, model)*100:.2f}")
+print(f"Accuracy on test set: {check_accuracy(test_loader, model)*100:.2f}")
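+# Running the script will download MNIST into dataset/ (if it is not already
+# there), train for num_epochs epochs with a tqdm progress bar, and then print
+# the final accuracy on the training and test sets.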