Skip to content

Commit b018213

Browse files
chuyang-dengyuanzhua
authored andcommitted
disable downloading data in mnist.py (aws#805)
1 parent 80333fd commit b018213

File tree

1 file changed

+2
-2
lines changed
  • sagemaker-python-sdk/pytorch_mnist

1 file changed

+2
-2
lines changed

sagemaker-python-sdk/pytorch_mnist/mnist.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
4343
dataset = datasets.MNIST(training_dir, train=True, transform=transforms.Compose([
4444
transforms.ToTensor(),
4545
transforms.Normalize((0.1307,), (0.3081,))
46-
]), download=True)
46+
]))
4747
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
4848
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train_sampler is None,
4949
sampler=train_sampler, **kwargs)
@@ -55,7 +55,7 @@ def _get_test_data_loader(test_batch_size, training_dir, **kwargs):
5555
datasets.MNIST(training_dir, train=False, transform=transforms.Compose([
5656
transforms.ToTensor(),
5757
transforms.Normalize((0.1307,), (0.3081,))
58-
]), download=True),
58+
])),
5959
batch_size=test_batch_size, shuffle=True, **kwargs)
6060

6161

0 commit comments

Comments
 (0)