Skip to content

Commit e5857c5

Browse files
gchanansoumith
authored andcommitted
Implement Gather double backwards.
1 parent 7da77c4 commit e5857c5

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

test/test_autograd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1891,7 +1891,6 @@ def unpack_variables(args):
18911891

18921892
gradgradcheck_exclude_classes = set((
18931893
'Cumprod',
1894-
'Gather',
18951894
'Norm',
18961895
'Prod',
18971896
))

torch/autograd/_functions/tensor.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,10 +556,9 @@ def forward(ctx, input, dim, index):
556556
return input.gather(dim, index)
557557

558558
@staticmethod
559-
@once_differentiable
560559
def backward(ctx, grad_output):
561-
index, = ctx.saved_tensors
562-
grad_input = grad_output.new(ctx.input_size).zero_()
560+
index, = ctx.saved_variables
561+
grad_input = Variable(grad_output.data.new(ctx.input_size).zero_())
563562
return grad_input.scatter_add_(ctx.dim, index, grad_output), None, None
564563

565564

0 commit comments

Comments
 (0)