Skip to content

Commit 7c15c3c

Browse files
shi27fengrusty1s
andauthored
Add function for the addition of two matrices (rusty1s#177)
* Create spadd.py Hi, Maybe it's trivial to have this function, but I still think it'll be helpful and it looks neat when applying matrix addition, i.e., C = A + B. Thanks * update * update * fix jit Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent 28f1295 commit 7c15c3c

File tree

5 files changed

+107
-19
lines changed

5 files changed

+107
-19
lines changed

test/test_add.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from itertools import product
2+
3+
import pytest
4+
import torch
5+
from torch_sparse import SparseTensor, add
6+
7+
from .utils import dtypes, devices, tensor
8+
9+
10+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
11+
def test_add(dtype, device):
12+
rowA = torch.tensor([0, 0, 1, 2, 2], device=device)
13+
colA = torch.tensor([0, 2, 1, 0, 1], device=device)
14+
valueA = tensor([1, 2, 4, 1, 3], dtype, device)
15+
A = SparseTensor(row=rowA, col=colA, value=valueA)
16+
17+
rowB = torch.tensor([0, 0, 1, 2, 2], device=device)
18+
colB = torch.tensor([1, 2, 2, 1, 2], device=device)
19+
valueB = tensor([2, 3, 1, 2, 4], dtype, device)
20+
B = SparseTensor(row=rowB, col=colB, value=valueB)
21+
22+
C = A + B
23+
rowC, colC, valueC = C.coo()
24+
25+
assert rowC.tolist() == [0, 0, 0, 1, 1, 2, 2, 2]
26+
assert colC.tolist() == [0, 1, 2, 1, 2, 0, 1, 2]
27+
assert valueC.tolist() == [1, 2, 5, 4, 1, 1, 5, 4]
28+
29+
@torch.jit.script
30+
def jit_add(A: SparseTensor, B: SparseTensor) -> SparseTensor:
31+
return add(A, B)
32+
33+
jit_add(A, B)

torch_sparse/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from .eye import eye # noqa
6666
from .spmm import spmm # noqa
6767
from .spspmm import spspmm # noqa
68+
from .spadd import spadd # noqa
6869

6970
__all__ = [
7071
'SparseStorage',
@@ -111,5 +112,6 @@
111112
'eye',
112113
'spmm',
113114
'spspmm',
115+
'spadd',
114116
'__version__',
115117
]

torch_sparse/add.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,69 @@
11
from typing import Optional
22

33
import torch
4+
from torch import Tensor
45
from torch_scatter import gather_csr
56
from torch_sparse.tensor import SparseTensor
67

78

8-
def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
9-
rowptr, col, value = src.csr()
10-
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
11-
other = gather_csr(other.squeeze(1), rowptr)
12-
pass
13-
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
14-
other = other.squeeze(0)[col]
15-
else:
16-
raise ValueError(
17-
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
18-
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
19-
if value is not None:
20-
value = other.to(value.dtype).add_(value)
9+
@torch.jit._overload # noqa: F811
10+
def add(src, other): # noqa: F811
11+
# type: (SparseTensor, Tensor) -> SparseTensor
12+
pass
13+
14+
15+
@torch.jit._overload # noqa: F811
16+
def add(src, other): # noqa: F811
17+
# type: (SparseTensor, SparseTensor) -> SparseTensor
18+
pass
19+
20+
21+
def add(src, other): # noqa: F811
22+
if isinstance(other, Tensor):
23+
rowptr, col, value = src.csr()
24+
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
25+
other = gather_csr(other.squeeze(1), rowptr)
26+
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
27+
other = other.squeeze(0)[col]
28+
else:
29+
raise ValueError(
30+
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
31+
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
32+
if value is not None:
33+
value = other.to(value.dtype).add_(value)
34+
else:
35+
value = other.add_(1)
36+
return src.set_value(value, layout='coo')
37+
38+
elif isinstance(other, SparseTensor):
39+
rowA, colA, valueA = src.coo()
40+
rowB, colB, valueB = other.coo()
41+
42+
row = torch.cat([rowA, rowB], dim=0)
43+
col = torch.cat([colA, colB], dim=0)
44+
45+
value: Optional[Tensor] = None
46+
if valueA is not None and valueB is not None:
47+
value = torch.cat([valueA, valueB], dim=0)
48+
49+
M = max(src.size(0), other.size(0))
50+
N = max(src.size(1), other.size(1))
51+
sparse_sizes = (M, N)
52+
53+
out = SparseTensor(row=row, col=col, value=value,
54+
sparse_sizes=sparse_sizes)
55+
out = out.coalesce(reduce='sum')
56+
return out
57+
2158
else:
22-
value = other.add_(1)
23-
return src.set_value(value, layout='coo')
59+
raise NotImplementedError
2460

2561

2662
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
2763
rowptr, col, value = src.csr()
28-
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
64+
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
2965
other = gather_csr(other.squeeze(1), rowptr)
30-
pass
31-
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
66+
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
3267
other = other.squeeze(0)[col]
3368
else:
3469
raise ValueError(

torch_sparse/spadd.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import torch
2+
from torch_sparse import coalesce
3+
4+
5+
def spadd(indexA, valueA, indexB, valueB, m, n):
6+
"""Matrix addition of two sparse matrices.
7+
8+
Args:
9+
indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
10+
valueA (:class:`Tensor`): The value tensor of first sparse matrix.
11+
indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
12+
valueB (:class:`Tensor`): The value tensor of second sparse matrix.
13+
m (int): The first dimension of the sparse matrices.
14+
n (int): The second dimension of the sparse matrices.
15+
"""
16+
index = torch.cat([indexA, indexB], dim=-1)
17+
value = torch.cat([valueA, valueB], dim=0)
18+
return coalesce(index=index, value=value, m=m, n=n, op='add')

torch_sparse/storage.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ def sparse_reshape(self, num_rows: int, num_cols: int):
292292

293293
idx = self.sparse_size(1) * self.row() + self.col()
294294

295-
row = idx // num_cols
295+
row = torch.div(idx, num_cols, rounding_mode='floor')
296296
col = idx % num_cols
297297
assert row.dtype == torch.long and col.dtype == torch.long
298298

0 commit comments

Comments
 (0)