There was an error while loading. Please reload this page.
1 parent 89facbe · commit 5c41070
mnist_hogwild/main.py
@@ -50,7 +50,7 @@ def forward(self, x):
50
torch.manual_seed(args.seed)
51
52
model = Net()
53
- model.share_memory()
+ model.share_memory() # gradients are allocated lazily, so they are not shared here
54
55
processes = []
56
for rank in range(args.num_processes):
mnist_hogwild/train.py
@@ -7,10 +7,6 @@
7
8
def train(rank, args, model):
9
torch.manual_seed(args.seed + rank)
10
- for param in model.parameters():
11
- # Break gradient sharing
12
- if param.grad is not None:
13
- param.grad.data = param.grad.data.clone()
14
15
train_loader = torch.utils.data.DataLoader(
16
datasets.MNIST('../data', train=True, download=True,
0 commit comments