
Commit 13d7a2c

Authored by Andreas Bergmeister and rusty1s
sparse element-wise multiplication (#323)
* sparse element-wise multiplication
* update
* update
* update
* update

Co-authored-by: Andreas Bergmeister <andbergm@ethz.ch>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent 2a73371 commit 13d7a2c
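In short, this commit adds element-wise multiplication between two SparseTensor instances, exposed via torch_sparse.mul and the * operator (previously only SparseTensor * dense Tensor was handled). A minimal usage sketch, assuming a torch_sparse build that includes this commit; the entries are taken from the new test added below:

import torch
from torch_sparse import SparseTensor

# Two coalesced 3x3 sparse matrices (same entries as in test/test_mul.py).
A = SparseTensor(row=torch.tensor([0, 0, 1, 2, 2]),
                 col=torch.tensor([0, 2, 1, 0, 1]),
                 value=torch.tensor([1., 2., 4., 1., 3.]))
B = SparseTensor(row=torch.tensor([0, 0, 1, 2, 2]),
                 col=torch.tensor([1, 2, 2, 1, 2]),
                 value=torch.tensor([2., 3., 1., 2., 4.]))

C = A * B  # sparse-sparse element-wise product introduced by this commit
row, col, value = C.coo()
print(row.tolist(), col.tolist(), value.tolist())  # [0, 2] [2, 1] [6.0, 6.0]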

File tree

3 files changed (+136, -20 lines)


.github/workflows/testing.yml

Lines changed: 2 additions & 1 deletion

@@ -48,12 +48,13 @@ jobs:
 
       - name: Install main package
         run: |
-          pip install -e .[test]
+          python setup.py develop
         env:
           WITH_METIS: 1
 
      - name: Run test-suite
        run: |
+          pip install pytest pytest-cov
           pytest --cov --cov-report=xml
 
      - name: Upload coverage

test/test_mul.py

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+from itertools import product
+
+import pytest
+import torch
+
+from torch_sparse import SparseTensor, mul
+from torch_sparse.testing import devices, dtypes, tensor
+
+
+@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+def test_sparse_sparse_mul(dtype, device):
+    rowA = torch.tensor([0, 0, 1, 2, 2], device=device)
+    colA = torch.tensor([0, 2, 1, 0, 1], device=device)
+    valueA = tensor([1, 2, 4, 1, 3], dtype, device)
+    A = SparseTensor(row=rowA, col=colA, value=valueA)
+
+    rowB = torch.tensor([0, 0, 1, 2, 2], device=device)
+    colB = torch.tensor([1, 2, 2, 1, 2], device=device)
+    valueB = tensor([2, 3, 1, 2, 4], dtype, device)
+    B = SparseTensor(row=rowB, col=colB, value=valueB)
+
+    C = A * B
+    rowC, colC, valueC = C.coo()
+
+    assert rowC.tolist() == [0, 2]
+    assert colC.tolist() == [2, 1]
+    assert valueC.tolist() == [6, 6]
+
+    @torch.jit.script
+    def jit_mul(A: SparseTensor, B: SparseTensor) -> SparseTensor:
+        return mul(A, B)
+
+    jit_mul(A, B)
+
+
+@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+def test_sparse_sparse_mul_empty(dtype, device):
+    rowA = torch.tensor([0], device=device)
+    colA = torch.tensor([1], device=device)
+    valueA = tensor([1], dtype, device)
+    A = SparseTensor(row=rowA, col=colA, value=valueA)
+
+    rowB = torch.tensor([1], device=device)
+    colB = torch.tensor([0], device=device)
+    valueB = tensor([2], dtype, device)
+    B = SparseTensor(row=rowB, col=colB, value=valueB)
+
+    C = A * B
+    rowC, colC, valueC = C.coo()
+
+    assert rowC.tolist() == []
+    assert colC.tolist() == []
+    assert valueC.tolist() == []
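The expected results in test_sparse_sparse_mul can be cross-checked by hand: A and B only share the coordinates (0, 2) and (2, 1), where the products are 2 * 3 = 6 and 3 * 2 = 6; every other entry is zero in at least one operand. The same check done with dense tensors (an illustration only, not part of the diff):

import torch

# Dense versions of A and B from test_sparse_sparse_mul.
A = torch.zeros(3, 3)
A[[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]] = torch.tensor([1., 2., 4., 1., 3.])
B = torch.zeros(3, 3)
B[[0, 0, 1, 2, 2], [1, 2, 2, 1, 2]] = torch.tensor([2., 3., 1., 2., 4.])

C = A * B               # dense element-wise product
nz = C.nonzero()        # coordinates where both A and B are nonzero
print(nz.tolist())                      # [[0, 2], [2, 1]]
print(C[nz[:, 0], nz[:, 1]].tolist())   # [6.0, 6.0]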

torch_sparse/mul.py

Lines changed: 81 additions & 19 deletions

@@ -1,27 +1,83 @@
 from typing import Optional
 
 import torch
+from torch import Tensor
 from torch_scatter import gather_csr
+
 from torch_sparse.tensor import SparseTensor
 
 
-def mul(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
-    rowptr, col, value = src.csr()
-    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
-        other = gather_csr(other.squeeze(1), rowptr)
-        pass
-    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise...
-        other = other.squeeze(0)[col]
-    else:
-        raise ValueError(
-            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
-            f'(1, {src.size(1)}, ...), but got size {other.size()}.')
+@torch.jit._overload  # noqa: F811
+def mul(src, other):  # noqa: F811
+    # type: (SparseTensor, Tensor) -> SparseTensor
+    pass
 
-    if value is not None:
-        value = other.to(value.dtype).mul_(value)
+
+@torch.jit._overload  # noqa: F811
+def mul(src, other):  # noqa: F811
+    # type: (SparseTensor, SparseTensor) -> SparseTensor
+    pass
+
+
+def mul(src, other):  # noqa: F811
+    if isinstance(other, Tensor):
+        rowptr, col, value = src.csr()
+        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise...
+            other = gather_csr(other.squeeze(1), rowptr)
+            pass
+        # Col-wise...
+        elif other.size(0) == 1 and other.size(1) == src.size(1):
+            other = other.squeeze(0)[col]
+        else:
+            raise ValueError(
+                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
+                f'(1, {src.size(1)}, ...), but got size {other.size()}.')
+
+        if value is not None:
+            value = other.to(value.dtype).mul_(value)
+        else:
+            value = other
+        return src.set_value(value, layout='coo')
+
+    assert isinstance(other, SparseTensor)
+
+    if not src.is_coalesced():
+        raise ValueError("The `src` tensor is not coalesced")
+    if not other.is_coalesced():
+        raise ValueError("The `other` tensor is not coalesced")
+
+    rowA, colA, valueA = src.coo()
+    rowB, colB, valueB = other.coo()
+
+    row = torch.cat([rowA, rowB], dim=0)
+    col = torch.cat([colA, colB], dim=0)
+
+    if valueA is not None and valueB is not None:
+        value = torch.cat([valueA, valueB], dim=0)
     else:
-        value = other
-    return src.set_value(value, layout='coo')
+        raise ValueError('Both sparse tensors must contain values')
+
+    M = max(src.size(0), other.size(0))
+    N = max(src.size(1), other.size(1))
+    sparse_sizes = (M, N)
+
+    # Sort indices:
+    idx = col.new_full((col.numel() + 1, ), -1)
+    idx[1:] = row * sparse_sizes[1] + col
+    perm = idx[1:].argsort()
+    idx[1:] = idx[1:][perm]
+
+    row, col, value = row[perm], col[perm], value[perm]
+
+    valid_mask = idx[1:] == idx[:-1]
+    valid_idx = valid_mask.nonzero().view(-1)
+
+    return SparseTensor(
+        row=row[valid_mask],
+        col=col[valid_mask],
+        value=value[valid_idx - 1] * value[valid_idx],
+        sparse_sizes=sparse_sizes,
+    )
 
 
 def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:

@@ -43,8 +99,11 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     return src.set_value_(value, layout='coo')
 
 
-def mul_nnz(src: SparseTensor, other: torch.Tensor,
-            layout: Optional[str] = None) -> SparseTensor:
+def mul_nnz(
+    src: SparseTensor,
+    other: torch.Tensor,
+    layout: Optional[str] = None,
+) -> SparseTensor:
     value = src.storage.value()
     if value is not None:
         value = value.mul(other.to(value.dtype))

@@ -53,8 +112,11 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
     return src.set_value(value, layout=layout)
 
 
-def mul_nnz_(src: SparseTensor, other: torch.Tensor,
-             layout: Optional[str] = None) -> SparseTensor:
+def mul_nnz_(
+    src: SparseTensor,
+    other: torch.Tensor,
+    layout: Optional[str] = None,
+) -> SparseTensor:
     value = src.storage.value()
     if value is not None:
         value = value.mul_(other.to(value.dtype))
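The sparse-sparse branch of mul above computes the intersection of the two sparsity patterns without forming a dense intermediate: it requires both inputs to be coalesced, concatenates their COO coordinates, linearizes each (row, col) pair to row * N + col, sorts, and then every position whose linear index equals that of its predecessor marks a coordinate present in both inputs, so the two adjacent values are multiplied. The two @torch.jit._overload stubs keep the function scriptable for both the Tensor and the SparseTensor overloads, which the jit_mul check in the new test exercises. Below is a standalone sketch of the intersection trick on plain tensors (not part of the diff), reusing the matrices from the test:

import torch

# COO coordinates/values of two coalesced 3x3 sparse matrices (from test_mul.py).
rowA = torch.tensor([0, 0, 1, 2, 2])
colA = torch.tensor([0, 2, 1, 0, 1])
valA = torch.tensor([1., 2., 4., 1., 3.])
rowB = torch.tensor([0, 0, 1, 2, 2])
colB = torch.tensor([1, 2, 2, 1, 2])
valB = torch.tensor([2., 3., 1., 2., 4.])
N = 3  # number of columns of the result

# 1) Concatenate both coordinate lists and linearize (row, col) -> row * N + col.
row = torch.cat([rowA, rowB])
col = torch.cat([colA, colB])
val = torch.cat([valA, valB])
idx = torch.full((row.numel() + 1,), -1, dtype=torch.long)
idx[1:] = row * N + col

# 2) Sort by linear index; coordinates shared by A and B become adjacent duplicates.
perm = idx[1:].argsort()
idx[1:] = idx[1:][perm]
row, col, val = row[perm], col[perm], val[perm]

# 3) An entry equal to its predecessor occurs in both inputs; multiply the pair.
mask = idx[1:] == idx[:-1]
pos = mask.nonzero().view(-1)
print(row[mask].tolist(), col[mask].tolist())  # [0, 2] [2, 1]
print((val[pos - 1] * val[pos]).tolist())      # [6.0, 6.0]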
