@@ -334,7 +334,20 @@ def launch_ps(args, distribute_mode):
334334 return
335335
336336
def infer_backend(args):
    """Resolve ``args.backend`` in place when it was left as ``"auto"``.

    An explicitly selected backend is kept untouched. Otherwise the build
    flags exposed by ``fluid.core`` are queried in a fixed order (CUDA,
    then NPU, then XPU) and the first one that reports true decides the
    backend. When none of them does, the backend falls back to ``'gloo'``.

    Args:
        args: Object with a mutable ``backend`` attribute.

    Returns:
        None. The outcome is written to ``args.backend``.
    """
    if args.backend != "auto":
        return

    # (name of the fluid.core build-flag probe, backend it selects).
    # Order matters: the first probe that returns true wins.
    ordered_probes = (
        ("is_compiled_with_cuda", 'nccl'),
        ("is_compiled_with_npu", 'unknown'),
        ("is_compiled_with_xpu", 'bkcl'),
    )
    for probe_name, chosen in ordered_probes:
        # Look each probe up lazily so later ones are never touched once
        # an earlier one has matched.
        if getattr(fluid.core, probe_name)():
            args.backend = chosen
            return

    # No accelerator build flag is set.
    args.backend = 'gloo'
347+
348+
337349def which_distributed_mode (args ):
350+ infer_backend (args ) # modify the args.backend
338351 if args .run_mode is not None :
339352 assert args .run_mode in ["collective" , "ps" , "ps-heter" ]
340353
@@ -368,12 +381,9 @@ def which_distributed_mode(args):
368381
369382 if fluid .core .is_compiled_with_cuda ():
370383 accelerators = fluid .core .get_cuda_device_count ()
371- args .backend = 'nccl'
372384 elif fluid .core .is_compiled_with_npu ():
373- args .backend = 'unknown'
374385 accelerators = fluid .core .get_npu_device_count ()
375386 elif fluid .core .is_compiled_with_xpu ():
376- args .backend = 'bkcl'
377387 accelerators = fluid .core .get_xpu_device_count ()
378388 else :
379389 accelerators = 0
@@ -400,7 +410,6 @@ def which_distributed_mode(args):
400410 But found args.servers not empty, default use ps mode" )
401411 return DistributeMode .PS
402412 else :
403- args .backend = "gloo"
404413 return DistributeMode .COLLECTIVE
405414 else :
406415 logger .warning (
@@ -583,20 +592,21 @@ def launch():
583592 _print_arguments (args )
584593
585594 if args .backend == 'auto' :
586- distribute_mode = which_distributed_mode (args )
587- assert args .backend in [
588- 'gloo' , 'nccl' , 'bkcl' , 'unknown'
589- ] # which_distributed_mode must modify args.backend
595+ distribute_mode = which_distributed_mode (
596+ args ) # which_distributed_mode must modify args.backend
590597 else :
591598 assert args .run_mode == 'collective' or args .run_mode == None , "When backend is not 'auto', run mode must be collective"
592599 check_backend (args .backend )
593600 distribute_mode = DistributeMode .COLLECTIVE
594601
595- block_windows_and_macos (
596- args . backend ) # raise error when using gloo on windows or macos
602+ assert args . backend in [ 'gloo' , 'nccl' , 'bkcl' , 'unknown' ]
603+
597604 if args .backend == 'gloo' :
598605 logger .warning ("launch start with CPUONLY mode" )
599606
607+ block_windows_and_macos (
608+ args .backend ) # raise error when using gloo on windows or macos
609+
600610 if enable_elastic (args , distribute_mode ):
601611 launch_elastic (args , distribute_mode )
602612 return
0 commit comments