
Commit 7f1133e

Adding notebook for MNIST training using PyTorch and StepFunctions (aws#4599)
* Update training_pipeline_pytorch_mnist.ipynb
* Update mnist.py: set a fixed random seed for reproducibility (SEED = 42; torch.manual_seed(SEED); np.random.seed(SEED); random.seed(SEED))
* Update mnist.py: the main change is replacing torch.no_grad() with torch.inference_mode()
* Added CI badge in notebook
* Reformatted the code
1 parent 9c156a8 commit 7f1133e
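
The no_grad-to-inference_mode swap mentioned in the message follows the standard PyTorch evaluation pattern. As a minimal, self-contained sketch (the Linear module and random batch below are stand-ins for illustration, not part of this commit):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)   # stand-in for any trained module
data = torch.randn(8, 4)  # stand-in input batch

model.eval()                  # dropout / batch norm switch to eval behavior
with torch.inference_mode():  # like torch.no_grad(), but tensors created here
    output = model(data)      # are inference-only: autograd skips view and
                              # version tracking, so eval runs slightly faster

Tensors produced under inference_mode cannot later be used in autograd, which is the trade-off for the speedup; for a pure evaluation loop like the one in this commit, that restriction is harmless.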

File tree

3 files changed, +941 -0 lines

mnist.py: 268 additions, 0 deletions
@@ -0,0 +1,268 @@
import argparse
import json
import logging
import os
import sys
import random
import sagemaker_containers
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import numpy as np
from torchvision import datasets, transforms

# Set a fixed random seed for reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))


# Based on https://github.com/pytorch/examples/blob/master/mnist/main.py
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
    logger.info("Training dir path: %s", training_dir)
    logger.info(os.listdir(os.path.join(training_dir, "MNIST")))
    logger.info("Get train data loader")
    dataset = datasets.MNIST(
        training_dir,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    )
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        **kwargs
    )


def _get_test_data_loader(test_batch_size, training_dir, **kwargs):
    logger.info("Get test data loader")
    return torch.utils.data.DataLoader(
        datasets.MNIST(
            training_dir,
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            ),
        ),
        batch_size=test_batch_size,
        shuffle=True,
        **kwargs
    )


def _average_gradients(model):
    # Gradient averaging across all workers (multi-machine CPU case).
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size


def train(args):
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = args.num_gpus > 0
    logger.debug("Number of gpus available - {}".format(args.num_gpus))
    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
    device = torch.device("cuda" if use_cuda else "cpu")

    if is_distributed:
        # Initialize the distributed environment.
        world_size = len(args.hosts)
        os.environ["WORLD_SIZE"] = str(world_size)
        host_rank = args.hosts.index(args.current_host)
        os.environ["RANK"] = str(host_rank)
        dist.init_process_group(backend=args.backend, rank=host_rank, world_size=world_size)
        logger.info(
            "Initialized the distributed environment: '{}' backend on {} nodes. ".format(
                args.backend, dist.get_world_size()
            )
            + "Current host rank is {}. Number of gpus: {}".format(dist.get_rank(), args.num_gpus)
        )

    # Set the seed for generating random numbers.
    torch.manual_seed(args.seed)
    if use_cuda:
        torch.cuda.manual_seed(args.seed)

    train_loader = _get_train_data_loader(args.batch_size, args.data_dir, is_distributed, **kwargs)
    test_loader = _get_test_data_loader(args.test_batch_size, args.data_dir, **kwargs)

    logger.debug(
        "Processes {}/{} ({:.0f}%) of train data".format(
            len(train_loader.sampler),
            len(train_loader.dataset),
            100.0 * len(train_loader.sampler) / len(train_loader.dataset),
        )
    )

    logger.debug(
        "Processes {}/{} ({:.0f}%) of test data".format(
            len(test_loader.sampler),
            len(test_loader.dataset),
            100.0 * len(test_loader.sampler) / len(test_loader.dataset),
        )
    )

    model = Net().to(device)
    if is_distributed and use_cuda:
        # Multi-machine multi-gpu case.
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        # Single-machine multi-gpu case or single-machine or multi-machine cpu case.
        model = torch.nn.DataParallel(model)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader, 1):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            if is_distributed and not use_cuda:
                # Average gradients manually for multi-machine cpu case only.
                _average_gradients(model)
            optimizer.step()
            if batch_idx % args.log_interval == 0:
                logger.info(
                    "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(train_loader.sampler),
                        100.0 * batch_idx / len(train_loader),
                        loss.item(),
                    )
                )
        test(model, test_loader, device)
    save_model(model, args.model_dir)


def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.inference_mode():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction="sum").item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    logger.info(
        "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset)
        )
    )


def model_fn(model_dir):
    # save_model() writes a TorchScript archive, so load it with torch.jit.load
    # rather than torch.load + load_state_dict.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device)
    return model.to(device)


def save_model(model, model_dir):
    logger.info("Saving the model.")
    path = os.path.join(model_dir, "model.pth")
    # Unwrap DataParallel/DistributedDataParallel and move to CPU so the traced
    # module and the example input live on the same device.
    module = model.module if hasattr(model, "module") else model
    scripted_module = torch.jit.trace(module.cpu().eval(), torch.randn((1, 1, 28, 28)))
    torch.jit.save(scripted_module, path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Data and model checkpoints directories
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1000,
        metavar="N",
        help="input batch size for testing (default: 1000)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        metavar="N",
        help="number of epochs to train (default: 10)",
    )
    parser.add_argument(
        "--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)"
    )
    parser.add_argument(
        "--momentum", type=float, default=0.5, metavar="M", help="SGD momentum (default: 0.5)"
    )
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=100,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default=None,
        help="backend for distributed training (tcp, gloo on cpu and gloo, nccl on gpu)",
    )

    # Container environment
    parser.add_argument("--hosts", type=list, default=json.loads(os.environ["SM_HOSTS"]))
    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--data-dir", type=str, default=os.environ["SM_CHANNEL_TRAINING"])
    parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"])

    train(parser.parse_args())
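
Because the script reads its paths and cluster topology from SageMaker's SM_* environment variables in the __main__ block, a quick local smoke test only needs to fake those variables. A hypothetical sketch (the /tmp paths are placeholders; the training channel directory must already contain the MNIST/ data folder the loader expects):

import json
import os
import subprocess

env = dict(
    os.environ,
    SM_HOSTS=json.dumps(["algo-1"]),   # single host, so no distributed setup
    SM_CURRENT_HOST="algo-1",
    SM_MODEL_DIR="/tmp/model",         # placeholder output dir (must exist)
    SM_CHANNEL_TRAINING="/tmp/data",   # placeholder; needs the MNIST/ folder
    SM_NUM_GPUS="0",
)
subprocess.run(["python", "mnist.py", "--epochs", "1"], env=env, check=True)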
requirements.txt: 1 addition, 0 deletions
@@ -0,0 +1 @@
sagemaker_containers
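
For context, here is a minimal sketch of how a training script like mnist.py is typically wired into a Step Functions pipeline with the Data Science SDK. This is not the notebook's exact code; the role ARNs, bucket path, framework version, instance type, and job name below are placeholders to be replaced with your own values.

import sagemaker
from sagemaker.pytorch import PyTorch
from stepfunctions.steps import Chain, TrainingStep
from stepfunctions.workflow import Workflow

sagemaker_role = "arn:aws:iam::123456789012:role/SageMakerRole"     # placeholder
workflow_role = "arn:aws:iam::123456789012:role/StepFunctionsRole"  # placeholder
s3_train_path = "s3://my-bucket/mnist"                              # placeholder

# SageMaker estimator that runs mnist.py as its entry point; requirements.txt
# in the same source dir pulls in sagemaker_containers.
estimator = PyTorch(
    entry_point="mnist.py",
    role=sagemaker_role,
    framework_version="1.13.1",  # any version with torch.inference_mode (>=1.9)
    py_version="py39",
    instance_count=1,
    instance_type="ml.c5.xlarge",
    hyperparameters={"epochs": 6, "backend": "gloo"},
)

# Wrap the estimator in a Step Functions training step and a one-step workflow.
training_step = TrainingStep(
    "MNIST training",
    estimator=estimator,
    data={"training": sagemaker.TrainingInput(s3_train_path)},
    job_name="mnist-pytorch-training",  # placeholder job name
)

workflow = Workflow(
    name="MNISTTrainingPipeline",
    definition=Chain([training_step]),
    role=workflow_role,
)
workflow.create()
workflow.execute()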
