1
1
from typing import Optional
2
2
3
3
import torch
4
+ from torch import Tensor
4
5
from torch_scatter import gather_csr
6
+
5
7
from torch_sparse .tensor import SparseTensor
6
8
7
9
8
- def mul (src : SparseTensor , other : torch .Tensor ) -> SparseTensor :
9
- rowptr , col , value = src .csr ()
10
- if other .size (0 ) == src .size (0 ) and other .size (1 ) == 1 : # Row-wise...
11
- other = gather_csr (other .squeeze (1 ), rowptr )
12
- pass
13
- elif other .size (0 ) == 1 and other .size (1 ) == src .size (1 ): # Col-wise...
14
- other = other .squeeze (0 )[col ]
15
- else :
16
- raise ValueError (
17
- f'Size mismatch: Expected size ({ src .size (0 )} , 1, ...) or '
18
- f'(1, { src .size (1 )} , ...), but got size { other .size ()} .' )
10
+ @torch .jit ._overload # noqa: F811
11
+ def mul (src , other ): # noqa: F811
12
+ # type: (SparseTensor, Tensor) -> SparseTensor
13
+ pass
19
14
20
- if value is not None :
21
- value = other .to (value .dtype ).mul_ (value )
15
+
16
+ @torch .jit ._overload # noqa: F811
17
+ def mul (src , other ): # noqa: F811
18
+ # type: (SparseTensor, SparseTensor) -> SparseTensor
19
+ pass
20
+
21
+
22
+ def mul (src , other ): # noqa: F811
23
+ if isinstance (other , Tensor ):
24
+ rowptr , col , value = src .csr ()
25
+ if other .size (0 ) == src .size (0 ) and other .size (1 ) == 1 : # Row-wise...
26
+ other = gather_csr (other .squeeze (1 ), rowptr )
27
+ pass
28
+ # Col-wise...
29
+ elif other .size (0 ) == 1 and other .size (1 ) == src .size (1 ):
30
+ other = other .squeeze (0 )[col ]
31
+ else :
32
+ raise ValueError (
33
+ f'Size mismatch: Expected size ({ src .size (0 )} , 1, ...) or '
34
+ f'(1, { src .size (1 )} , ...), but got size { other .size ()} .' )
35
+
36
+ if value is not None :
37
+ value = other .to (value .dtype ).mul_ (value )
38
+ else :
39
+ value = other
40
+ return src .set_value (value , layout = 'coo' )
41
+
42
+ assert isinstance (other , SparseTensor )
43
+
44
+ if not src .is_coalesced ():
45
+ raise ValueError ("The `src` tensor is not coalesced" )
46
+ if not other .is_coalesced ():
47
+ raise ValueError ("The `other` tensor is not coalesced" )
48
+
49
+ rowA , colA , valueA = src .coo ()
50
+ rowB , colB , valueB = other .coo ()
51
+
52
+ row = torch .cat ([rowA , rowB ], dim = 0 )
53
+ col = torch .cat ([colA , colB ], dim = 0 )
54
+
55
+ if valueA is not None and valueB is not None :
56
+ value = torch .cat ([valueA , valueB ], dim = 0 )
22
57
else :
23
- value = other
24
- return src .set_value (value , layout = 'coo' )
58
+ raise ValueError ('Both sparse tensors must contain values' )
59
+
60
+ M = max (src .size (0 ), other .size (0 ))
61
+ N = max (src .size (1 ), other .size (1 ))
62
+ sparse_sizes = (M , N )
63
+
64
+ # Sort indices:
65
+ idx = col .new_full ((col .numel () + 1 , ), - 1 )
66
+ idx [1 :] = row * sparse_sizes [1 ] + col
67
+ perm = idx [1 :].argsort ()
68
+ idx [1 :] = idx [1 :][perm ]
69
+
70
+ row , col , value = row [perm ], col [perm ], value [perm ]
71
+
72
+ valid_mask = idx [1 :] == idx [:- 1 ]
73
+ valid_idx = valid_mask .nonzero ().view (- 1 )
74
+
75
+ return SparseTensor (
76
+ row = row [valid_mask ],
77
+ col = col [valid_mask ],
78
+ value = value [valid_idx - 1 ] * value [valid_idx ],
79
+ sparse_sizes = sparse_sizes ,
80
+ )
25
81
26
82
27
83
def mul_ (src : SparseTensor , other : torch .Tensor ) -> SparseTensor :
@@ -43,8 +99,11 @@ def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
43
99
return src .set_value_ (value , layout = 'coo' )
44
100
45
101
46
- def mul_nnz (src : SparseTensor , other : torch .Tensor ,
47
- layout : Optional [str ] = None ) -> SparseTensor :
102
+ def mul_nnz (
103
+ src : SparseTensor ,
104
+ other : torch .Tensor ,
105
+ layout : Optional [str ] = None ,
106
+ ) -> SparseTensor :
48
107
value = src .storage .value ()
49
108
if value is not None :
50
109
value = value .mul (other .to (value .dtype ))
@@ -53,8 +112,11 @@ def mul_nnz(src: SparseTensor, other: torch.Tensor,
53
112
return src .set_value (value , layout = layout )
54
113
55
114
56
- def mul_nnz_ (src : SparseTensor , other : torch .Tensor ,
57
- layout : Optional [str ] = None ) -> SparseTensor :
115
+ def mul_nnz_ (
116
+ src : SparseTensor ,
117
+ other : torch .Tensor ,
118
+ layout : Optional [str ] = None ,
119
+ ) -> SparseTensor :
58
120
value = src .storage .value ()
59
121
if value is not None :
60
122
value = value .mul_ (other .to (value .dtype ))
0 commit comments