Skip to content

Commit acb784e

Browse files
committed
hack nvlink
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 054ac0f commit acb784e

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

vllm/distributed/device_communicators/all2all.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -262,7 +262,7 @@ def __init__(self, cpu_group, device):
262262
self.device = device
263263
self.handle_cache = Cache()
264264

265-
def get_handle(self, kwargs):
265+
def get_handle(self, kwargs, nvlink: bool = True):
266266
from rose.distributed.torch_group import TorchParallelGroup
267267
from rose.kernels.efa_all_to_all import EfaAllToAll
268268

@@ -285,11 +285,9 @@ def get_handle(self, kwargs):
285285
ranks=dp_group.ranks,
286286
)
287287

288-
kwargs["nvlink"] = True
289-
290288
kwargs["nets_per_gpu"] = _nets_per_gpu()
291289
kwargs["dp_group"] = tp_group
292-
kwargs["node_group"] = global_group if kwargs.get("nvlink", False) else None
290+
kwargs["node_group"] = global_group if nvlink else None
293291
kwargs["global_group"] = global_group
294292
kwargs["device"] = self.device
295293

0 commit comments

Comments (0)