
Commit 5c41070

pfrendl authored and apaszke committed
mnist_hogwild manual breaking of gradient sharing removed (pytorch#138)
1 parent 89facbe commit 5c41070

File tree

2 files changed: +1 -5 lines changed


mnist_hogwild/main.py

Lines changed: 1 addition & 1 deletion

@@ -50,7 +50,7 @@ def forward(self, x):
     torch.manual_seed(args.seed)

     model = Net()
-    model.share_memory()
+    model.share_memory()  # gradients are allocated lazily, so they are not shared here

     processes = []
     for rank in range(args.num_processes):
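The new comment in main.py is the whole rationale for this commit: `share_memory()` moves parameter storage into shared memory, but `.grad` tensors do not exist yet at that point, so they are never shared and there is nothing to "un-share" later. A minimal sketch of that behavior (using an arbitrary `nn.Linear` as a stand-in for the example's `Net`):

```python
import torch
import torch.nn as nn

# Sketch: share_memory() shares parameter storage, but gradient tensors
# are only allocated on the first backward() call, so .grad is still None
# here and the later per-process gradients are not shared.
model = nn.Linear(4, 2)
model.share_memory()

# Parameters live in shared memory...
assert all(p.is_shared() for p in model.parameters())
# ...but no gradient tensors exist yet, so there is nothing to un-share.
assert all(p.grad is None for p in model.parameters())

# Gradients appear only after backward(), allocated in process-local memory.
loss = model(torch.randn(3, 4)).sum()
loss.backward()
assert all(p.grad is not None for p in model.parameters())
assert not any(p.grad.is_shared() for p in model.parameters())
```

Because each worker process triggers its own first `backward()`, each allocates its own private gradient tensors, which is exactly why the manual clone loop removed in train.py below was redundant.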

mnist_hogwild/train.py

Lines changed: 0 additions & 4 deletions

@@ -7,10 +7,6 @@

 def train(rank, args, model):
     torch.manual_seed(args.seed + rank)
-    for param in model.parameters():
-        # Break gradient sharing
-        if param.grad is not None:
-            param.grad.data = param.grad.data.clone()

     train_loader = torch.utils.data.DataLoader(
         datasets.MNIST('../data', train=True, download=True,

0 commit comments
