Skip to content

Commit e625315

Browse files
authored
bugfix: only check backend when mode == Collective (#36758)
* bugfix: only check backend when mode == Collective * fix bug
1 parent d5245a3 commit e625315

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

python/paddle/distributed/fleet/launch.py

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -334,7 +334,20 @@ def launch_ps(args, distribute_mode):
334334
return
335335

336336

337+
def infer_backend(args):
338+
if args.backend != "auto": return
339+
if fluid.core.is_compiled_with_cuda():
340+
args.backend = 'nccl'
341+
elif fluid.core.is_compiled_with_npu():
342+
args.backend = 'unknown'
343+
elif fluid.core.is_compiled_with_xpu():
344+
args.backend = 'bkcl'
345+
else:
346+
args.backend = 'gloo'
347+
348+
337349
def which_distributed_mode(args):
350+
infer_backend(args) # modify the args.backend
338351
if args.run_mode is not None:
339352
assert args.run_mode in ["collective", "ps", "ps-heter"]
340353

@@ -368,12 +381,9 @@ def which_distributed_mode(args):
368381

369382
if fluid.core.is_compiled_with_cuda():
370383
accelerators = fluid.core.get_cuda_device_count()
371-
args.backend = 'nccl'
372384
elif fluid.core.is_compiled_with_npu():
373-
args.backend = 'unknown'
374385
accelerators = fluid.core.get_npu_device_count()
375386
elif fluid.core.is_compiled_with_xpu():
376-
args.backend = 'bkcl'
377387
accelerators = fluid.core.get_xpu_device_count()
378388
else:
379389
accelerators = 0
@@ -400,7 +410,6 @@ def which_distributed_mode(args):
400410
But found args.servers not empty, default use ps mode")
401411
return DistributeMode.PS
402412
else:
403-
args.backend = "gloo"
404413
return DistributeMode.COLLECTIVE
405414
else:
406415
logger.warning(
@@ -583,20 +592,21 @@ def launch():
583592
_print_arguments(args)
584593

585594
if args.backend == 'auto':
586-
distribute_mode = which_distributed_mode(args)
587-
assert args.backend in [
588-
'gloo', 'nccl', 'bkcl', 'unknown'
589-
] # which_distributed_mode must modify args.backend
595+
distribute_mode = which_distributed_mode(
596+
args) # which_distributed_mode must modify args.backend
590597
else:
591598
assert args.run_mode == 'collective' or args.run_mode == None, "When backend is not 'auto', run mode must be collective"
592599
check_backend(args.backend)
593600
distribute_mode = DistributeMode.COLLECTIVE
594601

595-
block_windows_and_macos(
596-
args.backend) # raise error when using gloo on windows or macos
602+
assert args.backend in ['gloo', 'nccl', 'bkcl', 'unknown']
603+
597604
if args.backend == 'gloo':
598605
logger.warning("launch start with CPUONLY mode")
599606

607+
block_windows_and_macos(
608+
args.backend) # raise error when using gloo on windows or macos
609+
600610
if enable_elastic(args, distribute_mode):
601611
launch_elastic(args, distribute_mode)
602612
return

0 commit comments

Comments
 (0)