33and modifies this to train on the CIFAR10 dataset. The same method generalizes
44well to other datasets, but the modifications to the network may need to be changed.
55
6- Video explanation: https://youtu.be/U4bHxEhMGNk
7- Got any questions leave a comment on youtube :)
8-
96Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
107* 2020-04-08 Initial coding
8+ * 2022-12-19 Updated comments, minor code changes, made sure it works with latest PyTorch
119
1210"""
1311
2220) # Gives easier dataset managment and creates mini batches
2321import torchvision .datasets as datasets # Has standard datasets we can import in a nice way
2422import torchvision .transforms as transforms # Transformations we can perform on our dataset
23+ from tqdm import tqdm
2524
26- # Set device
2725device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
2826
2927# Hyperparameters
3230batch_size = 1024
3331num_epochs = 5
3432
35- # Simple Identity class that let's input pass without changes
36- class Identity (nn .Module ):
37- def __init__ (self ):
38- super (Identity , self ).__init__ ()
39-
40- def forward (self , x ):
41- return x
42-
43-
4433# Load pretrain model & modify it
45- model = torchvision .models .vgg16 (pretrained = True )
34+ model = torchvision .models .vgg16 (weights = "DEFAULT" )
4635
4736# If you want to do finetuning then set requires_grad = False
4837# Remove these two lines if you want to train entire model,
4938# and only want to load the pretrain weights.
5039for param in model .parameters ():
5140 param .requires_grad = False
5241
53- model .avgpool = Identity ()
42+ model .avgpool = nn . Identity ()
5443model .classifier = nn .Sequential (
5544 nn .Linear (512 , 100 ), nn .ReLU (), nn .Linear (100 , num_classes )
5645)
@@ -71,7 +60,7 @@ def forward(self, x):
7160for epoch in range (num_epochs ):
7261 losses = []
7362
74- for batch_idx , (data , targets ) in enumerate (train_loader ):
63+ for batch_idx , (data , targets ) in enumerate (tqdm ( train_loader ) ):
7564 # Get data to cuda if possible
7665 data = data .to (device = device )
7766 targets = targets .to (device = device )
0 commit comments