Skip to content

Commit 86ee75f

Browse files
committed
Fix for Long and Byte tensor indexing of Variables
1 parent 3194191 commit 86ee75f

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

test/test_autograd.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,34 @@ def test_volatile(self):
194194

195195
def test_indexing(self):
196196
x = torch.range(1, 16).resize_(4, 4)
197-
y = Variable(x)
198-
self.assertEqual(x[1], y[1].data)
199-
self.assertEqual(x[1, 1], y[1, 1].data[0])
200-
self.assertEqual(x[1:], y[1:].data)
201-
self.assertEqual(x[:2], y[:2].data)
202-
self.assertEqual(x[:2, 2], y[:2, 2].data)
203-
self.assertEqual(x[1:2, 2], y[1:2, 2].data)
204-
self.assertEqual(x[1, 2:], y[1, 2:].data)
197+
y = Variable(x, requires_grad=True)
198+
199+
def check_index(idx):
200+
y.grad.data.zero_()
201+
indexed_tensor = x[idx]
202+
indexed_var = y[idx]
203+
204+
indexed_var_t = indexed_var.data
205+
if not torch.is_tensor(indexed_tensor):
206+
indexed_var_t = indexed_var_t[0]
207+
self.assertEqual(indexed_tensor, indexed_var)
208+
209+
indexed_var.sum().backward()
210+
expected_grad = torch.zeros(4, 4)
211+
expected_grad[idx] = 1
212+
self.assertEqual(y.grad.data, expected_grad)
213+
214+
check_index(1)
215+
check_index((1, 1))
216+
check_index(slice(1, None))
217+
check_index(slice(None, 2))
218+
check_index((slice(None, 2), 2))
219+
check_index((slice(1, 2), 2))
220+
check_index((1, slice(2, None)))
221+
check_index((slice(None, None), slice(2, None)))
222+
check_index(torch.LongTensor([0, 2]))
223+
check_index(torch.rand(4, 4).bernoulli().byte())
224+
check_index((Ellipsis, slice(2, None)))
205225

206226
def test_requires_grad(self):
207227
x = Variable(torch.randn(5, 5))

torch/autograd/_functions/tensor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@ def forward(self, i):
1818
return result
1919

2020
def backward(self, grad_output):
21-
# TODO: this won't have to be zeroed
2221
grad_input = grad_output.new(self.input_size).zero_()
23-
grad_input.index(self.index).copy_(grad_output)
22+
grad_input._set_index(self.index, grad_output)
2423
return grad_input
2524

2625

0 commit comments

Comments
 (0)