Skip to content

Commit 2beaa66

Browse files
author
agunapal
committed
Notebook for training ResNet with torch.compile
1 parent 1ef5680 commit 2beaa66

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

torch_compile/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
##################################################
3737

3838
def train(args, model, device, train_loader, optimizer, criterion, epoch, profile):
39+
"""
40+
Train the model
41+
"""
3942
model.train()
4043

4144
if profile:
@@ -86,6 +89,9 @@ def train(args, model, device, train_loader, optimizer, criterion, epoch, profil
8689
))
8790

8891
def test(model, device, test_loader, criterion):
92+
"""
93+
Evaluate the model
94+
"""
8995
model.eval()
9096
test_loss = 0
9197
correct = 0

0 commit comments

Comments
 (0)