1- from typing import Tuple , Optional
1+ from typing import Optional , Tuple
22
33import torch
4- from torch_sparse .tensor import SparseTensor
4+ from torch import Tensor
5+
56from torch_sparse .permute import permute
7+ from torch_sparse .tensor import SparseTensor
68
79
8- def weight2metis (weight : torch . Tensor ) -> Optional [torch . Tensor ]:
10+ def weight2metis (weight : Tensor ) -> Optional [Tensor ]:
911 sorted_weight = weight .sort ()[0 ]
1012 diff = sorted_weight [1 :] - sorted_weight [:- 1 ]
1113 if diff .sum () == 0 :
@@ -20,16 +22,24 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
2022
2123
2224def partition (
23- src : SparseTensor , num_parts : int , recursive : bool = False ,
24- weighted : bool = False , node_weight : Optional [torch .Tensor ] = None
25- ) -> Tuple [SparseTensor , torch .Tensor , torch .Tensor ]:
25+ src : SparseTensor ,
26+ num_parts : int ,
27+ recursive : bool = False ,
28+ weighted : bool = False ,
29+ node_weight : Optional [Tensor ] = None ,
30+ balance_edge : bool = False ,
31+ ) -> Tuple [SparseTensor , Tensor , Tensor ]:
2632
2733 assert num_parts >= 1
2834 if num_parts == 1 :
2935 partptr = torch .tensor ([0 , src .size (0 )], device = src .device ())
3036 perm = torch .arange (src .size (0 ), device = src .device ())
3137 return src , partptr , perm
3238
39+ if balance_edge and node_weight :
40+ raise ValueError ("Cannot set 'balance_edge' and 'node_weight' at the "
41+ "same time in 'torch_sparse.partition'" )
42+
3343 rowptr , col , value = src .csr ()
3444 rowptr , col = rowptr .cpu (), col .cpu ()
3545
@@ -41,6 +51,10 @@ def partition(
4151 else :
4252 value = None
4353
54+ if balance_edge :
55+ node_weight = col .new_zeros (col .size (0 ))
56+ node_weight .scatter_add_ (0 , col , torch .ones_like (col ))
57+
4458 if node_weight is not None :
4559 assert node_weight .numel () == rowptr .numel () - 1
4660 node_weight = node_weight .view (- 1 ).detach ().cpu ()
0 commit comments