11import warnings
2- from typing import Optional , List , Tuple
2+ from typing import List , Optional , Tuple
33
44import torch
5- from torch_scatter import segment_csr , scatter_add
6- from torch_sparse .utils import Final
5+ from torch_scatter import scatter_add , segment_csr
6+
7+ from torch_sparse .utils import Final , index_sort
78
89layouts : Final [List [str ]] = ['coo' , 'csr' , 'csc' ]
910
@@ -151,7 +152,8 @@ def __init__(
151152 idx [1 :] *= self ._sparse_sizes [1 ]
152153 idx [1 :] += self ._col
153154 if (idx [1 :] < idx [:- 1 ]).any ():
154- perm = idx [1 :].argsort ()
155+ max_value = self ._sparse_sizes [0 ] * self ._sparse_sizes [1 ]
156+ _ , perm = index_sort (idx [1 :], max_value )
155157 self ._row = self .row ()[perm ]
156158 self ._col = self ._col [perm ]
157159 if value is not None :
@@ -163,10 +165,20 @@ def __init__(
163165 def empty (self ):
164166 row = torch .tensor ([], dtype = torch .long )
165167 col = torch .tensor ([], dtype = torch .long )
166- return SparseStorage (row = row , rowptr = None , col = col , value = None ,
167- sparse_sizes = (0 , 0 ), rowcount = None , colptr = None ,
168- colcount = None , csr2csc = None , csc2csr = None ,
169- is_sorted = True , trust_data = True )
168+ return SparseStorage (
169+ row = row ,
170+ rowptr = None ,
171+ col = col ,
172+ value = None ,
173+ sparse_sizes = (0 , 0 ),
174+ rowcount = None ,
175+ colptr = None ,
176+ colcount = None ,
177+ csr2csc = None ,
178+ csc2csr = None ,
179+ is_sorted = True ,
180+ trust_data = True ,
181+ )
170182
171183 def has_row (self ) -> bool :
172184 return self ._row is not None
@@ -209,8 +221,11 @@ def has_value(self) -> bool:
209221 def value (self ) -> Optional [torch .Tensor ]:
210222 return self ._value
211223
212- def set_value_ (self , value : Optional [torch .Tensor ],
213- layout : Optional [str ] = None ):
224+ def set_value_ (
225+ self ,
226+ value : Optional [torch .Tensor ],
227+ layout : Optional [str ] = None ,
228+ ):
214229 if value is not None :
215230 if get_layout (layout ) == 'csc' :
216231 value = value [self .csc2csr ()]
@@ -221,8 +236,11 @@ def set_value_(self, value: Optional[torch.Tensor],
221236 self ._value = value
222237 return self
223238
224- def set_value (self , value : Optional [torch .Tensor ],
225- layout : Optional [str ] = None ):
239+ def set_value (
240+ self ,
241+ value : Optional [torch .Tensor ],
242+ layout : Optional [str ] = None ,
243+ ):
226244 if value is not None :
227245 if get_layout (layout ) == 'csc' :
228246 value = value [self .csc2csr ()]
@@ -375,8 +393,11 @@ def colcount(self) -> torch.Tensor:
375393 if colptr is not None :
376394 colcount = colptr [1 :] - colptr [:- 1 ]
377395 else :
378- colcount = scatter_add (torch .ones_like (self ._col ), self ._col ,
379- dim_size = self ._sparse_sizes [1 ])
396+ colcount = scatter_add (
397+ torch .ones_like (self ._col ),
398+ self ._col ,
399+ dim_size = self ._sparse_sizes [1 ],
400+ )
380401 self ._colcount = colcount
381402 return colcount
382403
@@ -389,7 +410,8 @@ def csr2csc(self) -> torch.Tensor:
389410 return csr2csc
390411
391412 idx = self ._sparse_sizes [0 ] * self ._col + self .row ()
392- csr2csc = idx .argsort ()
413+ max_value = self ._sparse_sizes [0 ] * self ._sparse_sizes [1 ]
414+ _ , csr2csc = index_sort (idx , max_value )
393415 self ._csr2csc = csr2csc
394416 return csr2csc
395417
@@ -401,7 +423,8 @@ def csc2csr(self) -> torch.Tensor:
401423 if csc2csr is not None :
402424 return csc2csr
403425
404- csc2csr = self .csr2csc ().argsort ()
426+ max_value = self ._sparse_sizes [0 ] * self ._sparse_sizes [1 ]
427+ _ , csc2csr = index_sort (self .csr2csc (), max_value )
405428 self ._csc2csr = csc2csr
406429 return csc2csr
407430
@@ -543,7 +566,8 @@ def type(self, dtype: torch.dtype, non_blocking: bool = False):
543566 else :
544567 return self .set_value (
545568 value .to (dtype = dtype , non_blocking = non_blocking ),
546- layout = 'coo' )
569+ layout = 'coo' ,
570+ )
547571 else :
548572 return self
549573
0 commit comments