Commit f17cfe4

martinraison authored and apaszke committed
sparse tensor operations (pytorch#735)
1 parent c93c884 commit f17cfe4

52 files changed (+2556, -271 lines)

test/common.py

Lines changed: 12 additions & 5 deletions
@@ -118,11 +118,18 @@ def assertEqual(self, x, y, prec=None, message=''):
             y = y.data
 
         if torch.is_tensor(x) and torch.is_tensor(y):
-            max_err = 0
-            super(TestCase, self).assertEqual(x.size(), y.size())
-            for index in iter_indices(x):
-                max_err = max(max_err, abs(x[index] - y[index]))
-            self.assertLessEqual(max_err, prec, message)
+            def assertTensorsEqual(a, b):
+                max_err = 0
+                super(TestCase, self).assertEqual(a.size(), b.size())
+                for index in iter_indices(a):
+                    max_err = max(max_err, abs(a[index] - b[index]))
+                self.assertLessEqual(max_err, prec, message)
+            self.assertEqual(x.is_sparse, y.is_sparse, message)
+            if x.is_sparse:
+                assertTensorsEqual(x.indices(), y.indices())
+                assertTensorsEqual(x.values(), y.values())
+            else:
+                assertTensorsEqual(x, y)
         elif type(x) == str and type(y) == str:
             super(TestCase, self).assertEqual(x, y)
         elif is_iterable(x) and is_iterable(y):
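
Note: the new branch compares a sparse tensor by its COO components (indices and values) instead of iterating over every dense position. A minimal standalone sketch of the same comparison, assuming the constructor and accessors used elsewhere in this commit (torch.sparse.DoubleTensor, .indices(), .values()); the tensors are made up for illustration:

import torch

size = torch.Size([3, 3])
i = torch.LongTensor([[0, 2], [1, 0]])   # 2 x nnz matrix of coordinates
v = torch.DoubleTensor([1.0, 2.0])       # one value per coordinate
a = torch.sparse.DoubleTensor(i, v, size)
b = torch.sparse.DoubleTensor(i.clone(), v.clone(), size)

# A dense comparison would have to materialize every entry; comparing the
# COO components only touches the nnz stored elements.
assert a.is_sparse and b.is_sparse
assert a.indices().equal(b.indices())
assert a.values().equal(b.values())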

test/common_nn.py

Lines changed: 6 additions & 3 deletions
@@ -337,15 +337,18 @@ def _jacobian(self, input, num_out):
 
     def _flatten_tensors(self, x):
         if torch.is_tensor(x):
-            return x.view(-1)
+            if x.is_sparse:
+                return x.to_dense().view(-1)
+            else:
+                return x.view(-1)
         elif isinstance(x, Variable):
-            return x.data.view(-1)
+            return self._flatten_tensors(x.data)
         else:
             return tuple(self._flatten_tensors(a) for a in x)
 
     def _zero_grad_input(self, input):
         if isinstance(input, Variable):
-            if input.requires_grad:
+            if input.requires_grad and input.grad is not None:
                 input.grad.data.zero_()
         elif torch.is_tensor(input):
             return
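
Note: view(-1) is only defined for dense layouts, so _flatten_tensors now densifies sparse inputs first. A rough illustration under the same assumptions (legacy torch.sparse.DoubleTensor constructor, made-up values):

import torch

i = torch.LongTensor([[0, 2], [1, 0]])
v = torch.DoubleTensor([1.0, 2.0])
s = torch.sparse.DoubleTensor(i, v, torch.Size([3, 3]))

# Calling view(-1) on the sparse tensor itself is not supported; convert first.
flat = s.to_dense().view(-1)   # length-9 dense vector with zeros filled in
print(flat)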

test/test_autograd.py

Lines changed: 45 additions & 1 deletion
@@ -128,6 +128,49 @@ def _test_backward(self):
     def test_backward(self):
         self._test_backward()
 
+    def test_sparse_backward(self):
+        class FixedGradientFunction(Function):
+
+            def __init__(self, grad):
+                self.grad = grad
+
+            def forward(self, x):
+                return x
+
+            def backward(self, grad_x):
+                return self.grad
+
+        size = torch.Size([6, 3, 2])
+        i1 = torch.LongTensor([
+            [0, 3, 4],
+            [0, 2, 2],
+        ])
+        v1 = torch.DoubleTensor([[1, 2], [4, 5], [7, 8]])
+        sparse_grad1 = torch.sparse.DoubleTensor(i1, v1, size)
+        i2 = torch.LongTensor([
+            [0, 1, 3, 4],
+            [0, 1, 2, 2],
+        ])
+        v2 = torch.DoubleTensor([[1, 2], [4, 3], [4, 5], [7, 8]])
+        sparse_grad2 = torch.sparse.DoubleTensor(i2, v2, size)
+        dense_grad = torch.rand(size).double()
+        sparse_fn1 = FixedGradientFunction(sparse_grad1)
+        sparse_fn2 = FixedGradientFunction(sparse_grad2)
+        dense_fn = FixedGradientFunction(dense_grad)
+
+        # sparse first
+        x = Variable(torch.randn(5, 5), requires_grad=True)
+        (sparse_fn1(x) + dense_fn(x) + sparse_fn2(x)).sum().backward()
+        self.assertEqual(x.grad.data, dense_grad + sparse_grad1 + sparse_grad2)
+        # dense first
+        x = Variable(torch.randn(5, 5), requires_grad=True)
+        (dense_fn(x) + sparse_fn1(x) + sparse_fn2(x)).sum().backward()
+        self.assertEqual(x.grad.data, dense_grad + sparse_grad1 + sparse_grad2)
+        # sparse only
+        x = Variable(torch.randn(5, 5), requires_grad=True)
+        (sparse_fn1(x) + sparse_fn2(x)).sum().backward()
+        self.assertEqual(x.grad.data, sparse_grad1 + sparse_grad2)
+
     @unittest.skip("BasicEngine is out of date")
     def test_backward_basic_engine(self):
         with backward_engine(torch.autograd.engine.BasicEngine):
@@ -197,7 +240,8 @@ def test_indexing(self):
         y = Variable(x, requires_grad=True)
 
         def check_index(idx):
-            y.grad.data.zero_()
+            if y.grad is not None:
+                y.grad.data.zero_()
             indexed_tensor = x[idx]
             indexed_var = y[idx]
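
Note: test_sparse_backward relies on the engine accumulating a mix of sparse and dense gradients into a single .grad. A condensed sketch of the same pattern, reusing the commit's FixedGradientFunction idea with made-up 5x5 gradients and the old-style Function/Variable API of this PyTorch version:

import torch
from torch.autograd import Variable, Function

class FixedGradientFunction(Function):
    # Identity forward; backward ignores the incoming gradient and returns a
    # predetermined one, which may be sparse or dense.
    def __init__(self, grad):
        self.grad = grad

    def forward(self, x):
        return x

    def backward(self, grad_output):
        return self.grad

size = torch.Size([5, 5])
i = torch.LongTensor([[0, 4], [0, 4]])
v = torch.DoubleTensor([1, 2])
sparse_grad = torch.sparse.DoubleTensor(i, v, size)   # entries at (0,0) and (4,4)
dense_grad = torch.rand(size).double()

x = Variable(torch.randn(5, 5).double(), requires_grad=True)
(FixedGradientFunction(sparse_grad)(x) + FixedGradientFunction(dense_grad)(x)).sum().backward()
# Sparse and dense contributions are summed into one gradient on x.
print(x.grad.data)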

test/test_multiprocessing.py

Lines changed: 3 additions & 3 deletions
@@ -80,8 +80,8 @@ def autograd_sharing(queue, ready, master_modified):
     is_ok = var.data.equal(expected_var)
     var.data[:] = torch.ones(5, 5)
 
-    is_ok &= var.grad.data.equal(torch.zeros(5, 5))
-    var.grad.data[:] = torch.ones(5, 5)
+    is_ok &= var.grad is None
+    var._grad = Variable(torch.ones(5, 5), requires_grad=False)
 
     queue.put(is_ok)
 
@@ -358,7 +358,7 @@ def _test_autograd_sharing(self, var):
         queue = mp.Queue()
         p = mp.Process(target=autograd_sharing, args=(queue, ready, master_modified))
        p.start()
-        var.grad.data.zero_()
+        var._grad = Variable(torch.zeros(5, 5), requires_grad=False)
         queue.put(var)
 
         ready.wait()
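
Note: these edits follow from .grad now starting out as None instead of a pre-allocated zero buffer, which is why the worker asserts var.grad is None and the parent installs a gradient explicitly through the internal _grad slot. A small sketch of the lazy behaviour, assuming the Variable API of this version:

import torch
from torch.autograd import Variable

v = Variable(torch.randn(5, 5), requires_grad=True)
print(v.grad)        # None: no gradient buffer has been allocated yet

v.sum().backward()
print(v.grad.data)   # allocated on the first backward; all ones here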

test/test_nn.py

Lines changed: 29 additions & 17 deletions
@@ -196,7 +196,8 @@ def _forward_criterion(self, criterion, input, target):
     def _backward_criterion(self, criterion, input, target):
         input_tuple = input if isinstance(input, tuple) else (input,)
         for i in input_tuple:
-            i.grad.data.zero_()
+            if i.grad is not None:
+                i.grad.data.zero_()
         args = input_tuple + (target,)
         criterion(*args).backward()
         if isinstance(input, tuple):
@@ -206,18 +207,24 @@ def _backward_criterion(self, criterion, input, target):
 
     def _zero_grad_parameters(self, module):
         if hasattr(module, 'weight') and module.weight is not None:
-            module.weight.grad.data.zero_()
+            if module.weight.grad is not None:
+                module.weight.grad.data.zero_()
         if hasattr(module, 'bias') and module.bias is not None:
-            module.bias.grad.data.zero_()
+            if module.bias.grad is not None:
+                module.bias.grad.data.zero_()
 
     def _get_parameters(self, module):
         params = []
         d_params = []
         if hasattr(module, 'weight') and module.weight is not None:
             params += [module.weight.data]
+            if module.weight.grad is None:
+                module.weight._grad = Variable(module.weight.data.clone().zero_())
             d_params += [module.weight.grad.data]
         if hasattr(module, 'bias') and module.bias is not None:
             params += [module.bias.data]
+            if module.bias.grad is None:
+                module.bias._grad = Variable(module.bias.data.clone().zero_())
             d_params += [module.bias.grad.data]
         return params, d_params
 
@@ -356,13 +363,13 @@ def test_zero_grad(self):
         module.zero_grad()
 
         module.weight.requires_grad = True
-        module.weight.grad.data.fill_(1)
+        module.weight._grad = Variable(module.weight.data.clone().fill_(1))
         module.zero_grad()
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
 
         module.bias.requires_grad = True
-        module.weight.grad.data.fill_(1)
-        module.bias.grad.data.fill_(1)
+        module.weight._grad = Variable(module.weight.data.clone().fill_(1))
+        module.bias._grad = Variable(module.bias.data.clone().fill_(1))
         module.zero_grad()
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
         self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
@@ -586,7 +593,7 @@ def compare_scaling(grads):
         grads = torch.range(1, 100), torch.ones(10).div(1000)
         for norm_type in [0.5, 1.5, 2, 4, 'inf']:
             for p, g in zip(l.parameters(), grads):
-                p.grad.data.copy_(g)
+                p._grad = Variable(g.clone())
             norm_before = compute_norm(norm_type)
             clip_grad_norm(l.parameters(), max_norm, norm_type=norm_type)
             norm_after = compute_norm(norm_type)
@@ -1167,7 +1174,8 @@ def pad(tensor, length):
         self.assertEqual(unpacked_len, lengths)
 
         # check grad
-        padded.grad.data.zero_()
+        if padded.grad is not None:
+            padded.grad.data.zero_()
         grad_output = unpacked.data.clone().normal_()
         unpacked.backward(grad_output)
         if batch_first:
@@ -1185,13 +1193,15 @@ def pad(var, length):
 
         lengths = [10, 10, 6, 2, 2, 1, 1]
         max_length = lengths[0]
-        x = Variable(torch.randn(max_length, len(lengths), 3), requires_grad=True)
+        x_leaf = Variable(torch.randn(max_length, len(lengths), 3), requires_grad=True)
         lstm = nn.LSTM(3, 4, bidirectional=True, num_layers=2)
         lstm2 = deepcopy(lstm)
         if cuda:
-            x = x.cuda()
+            x = x_leaf.cuda()
             lstm.cuda()
             lstm2.cuda()
+        else:
+            x = x_leaf
 
         # Compute sequences separately
         seq_outs = []
@@ -1216,11 +1226,11 @@ def pad(var, length):
 
         # Check backward
         seq_out.sum().backward()
-        grad_x = x.grad.data.clone()
-        x.grad.data.zero_()
+        grad_x = x_leaf.grad.data.clone()
+        x_leaf.grad.data.zero_()
         unpacked.sum().backward()
 
-        self.assertEqual(x.grad.data, grad_x)
+        self.assertEqual(x_leaf.grad.data, grad_x)
         for p1, p2 in zip(lstm.parameters(), lstm2.parameters()):
             self.assertEqual(p1.grad, p2.grad)
 
@@ -1576,11 +1586,12 @@ def test_noncontig_conv_grad(self):
         grad = torch.randn(2, 2, 5, 10, 10).cuda()[:, 1]
         assert not grad.is_contiguous()
         output.backward(grad, retain_variables=True)
-        result = output.grad.data.clone()
-        output.grad.data.zero_()
+        self.assertIsNotNone(input.grad)
+        result = input.grad.data.clone()
+        input.grad.data.zero_()
 
         output.backward(grad.contiguous())
-        self.assertEqual(result, output.grad.data)
+        self.assertEqual(result, input.grad.data)
 
     def test_pixel_shuffle(self):
         batch_size = random.randint(1, 3)
@@ -1613,7 +1624,8 @@ def test_batchnorm_eval(self):
         grad1 = data.grad.data.clone()
 
         # 2nd pass
-        data.grad.data.zero_()
+        if data.grad is not None:
+            data.grad.data.zero_()
 
         res2 = module(data)
         res2.backward(grad)
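
Note: most of the test_nn.py edits are the same mechanical guard: code that used to call p.grad.data.zero_() unconditionally now checks whether the gradient exists, or lazily creates one through the internal _grad slot as the tests do. A hypothetical helper in that spirit (zero_or_init_grad is not part of this commit):

import torch
from torch.autograd import Variable

def zero_or_init_grad(param):
    # Hypothetical helper: zero an existing gradient, or attach a fresh
    # zero-filled one through the internal _grad slot, mirroring the tests.
    if param.grad is not None:
        param.grad.data.zero_()
    else:
        param._grad = Variable(param.data.clone().zero_())

w = Variable(torch.randn(3, 3), requires_grad=True)
zero_or_init_grad(w)   # creates a zero gradient the first time
zero_or_init_grad(w)   # zeroes the existing one afterwards
print(w.grad.data)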
