Skip to content

Commit 99735df

Browse files
Graph partition based on balance_edge (rusty1s#309)
* add balance_edge for graph partition * add Tensor defition in metis.py * update --------- Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent a980efd commit 99735df

File tree

1 file changed

+20
-6
lines changed

1 file changed

+20
-6
lines changed

torch_sparse/metis.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from typing import Tuple, Optional
1+
from typing import Optional, Tuple
22

33
import torch
4-
from torch_sparse.tensor import SparseTensor
4+
from torch import Tensor
5+
56
from torch_sparse.permute import permute
7+
from torch_sparse.tensor import SparseTensor
68

79

8-
def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
10+
def weight2metis(weight: Tensor) -> Optional[Tensor]:
911
sorted_weight = weight.sort()[0]
1012
diff = sorted_weight[1:] - sorted_weight[:-1]
1113
if diff.sum() == 0:
@@ -20,16 +22,24 @@ def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
2022

2123

2224
def partition(
23-
src: SparseTensor, num_parts: int, recursive: bool = False,
24-
weighted: bool = False, node_weight: Optional[torch.Tensor] = None
25-
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
25+
src: SparseTensor,
26+
num_parts: int,
27+
recursive: bool = False,
28+
weighted: bool = False,
29+
node_weight: Optional[Tensor] = None,
30+
balance_edge: bool = False,
31+
) -> Tuple[SparseTensor, Tensor, Tensor]:
2632

2733
assert num_parts >= 1
2834
if num_parts == 1:
2935
partptr = torch.tensor([0, src.size(0)], device=src.device())
3036
perm = torch.arange(src.size(0), device=src.device())
3137
return src, partptr, perm
3238

39+
if balance_edge and node_weight:
40+
raise ValueError("Cannot set 'balance_edge' and 'node_weight' at the "
41+
"same time in 'torch_sparse.partition'")
42+
3343
rowptr, col, value = src.csr()
3444
rowptr, col = rowptr.cpu(), col.cpu()
3545

@@ -41,6 +51,10 @@ def partition(
4151
else:
4252
value = None
4353

54+
if balance_edge:
55+
node_weight = col.new_zeros(col.size(0))
56+
node_weight.scatter_add_(0, col, torch.ones_like(col))
57+
4458
if node_weight is not None:
4559
assert node_weight.numel() == rowptr.numel() - 1
4660
node_weight = node_weight.view(-1).detach().cpu()

0 commit comments

Comments
 (0)