@@ -65,7 +65,7 @@ def transpile(self, startup_program, main_program, rank, endpoints,
         self.main_program = default_main_program()
 
         self.nranks = len(endpoints)
-        if self.nranks == 1 and self.mode != "single_process_multi_thread":
+        if self.nranks == 1 and self.mode != "single_process_multi_thread" and self.mode != "box":
             raise ValueError('the number of endpoints must > 1')
 
         if rank < 0:
@@ -441,9 +441,14 @@ class MultiThread(GradAllReduce):
     '''
     '''
 
-    def __init__(self, nrings=1):
+    def __init__(self, nrings=1, trans_mode="all_reduce"):
         GradAllReduce.__init__(self, nrings)
-        self.mode = "single_process_multi_thread"
+        self.mode = "box"
+        self.trans_mode = trans_mode
+        self.fuse_grad_size_in_num = 128
+        gpu_nums = os.getenv("FLAGS_selected_gpus",
+                             "0,1,2,3,4,5,6,7,8").split(",")
+        self.gpu_num = len(gpu_nums)
 
     def _transpile_startup_program(self):
         if len(self.endpoints) > 1:
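
Note: with this change, MultiThread runs in the new "box" mode and accepts a trans_mode of "all_reduce" (the default), "all_gather", or "fuse_all_reduce"; the per-node GPU count is read from the FLAGS_selected_gpus environment variable at construction time. A minimal usage sketch, assuming MultiThread is imported from paddle.fluid.transpiler.collective and that transpile() keeps its existing (startup_program, main_program, rank, endpoints, current_endpoint, wait_port) signature; the programs, endpoint, and GPU list below are illustrative only:

    import os
    import paddle.fluid as fluid
    from paddle.fluid.transpiler.collective import MultiThread

    # hypothetical setup: 4 GPUs on one node, so gpu_num = 4
    os.environ["FLAGS_selected_gpus"] = "0,1,2,3"

    transpiler = MultiThread(nrings=1, trans_mode="fuse_all_reduce")
    transpiler.transpile(
        startup_program=fluid.default_startup_program(),
        main_program=fluid.default_main_program(),
        rank=0,
        endpoints=["127.0.0.1:6170"],  # a single endpoint is allowed in "box" mode
        current_endpoint="127.0.0.1:6170",
        wait_port=False)
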
@@ -460,3 +465,259 @@ def _transpile_startup_program(self):
             print("begin to _transpile_startup_program for single-node")
             block = self.startup_program.global_block()
             block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
+
+    def _transpile_main_program(self):
+        self._insert_scale_loss_grad_ops()
+        if self.trans_mode == "all_gather":
+            print("begin to transpile in all-gather mode")
+            self.allgather_ranks = self.nranks * self.gpu_num
+            self._insert_allgather_ops()
+            self._update_adam_ops()
+        elif self.trans_mode == "fuse_all_reduce":
+            print("begin to transpile in fuse all-reduce mode")
+            self._insert_fuse_allreduce_ops()
+        else:
+            print("begin to transpile in all-reduce mode")
+            self._insert_allreduce_ops()
+
+    def _insert_allgather_ops(self):
+        """
+        insert allgather op to the main_program
+        """
+        block = self.main_program.global_block()
+        ring_id = -1
+        grad = None
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if self._is_backward_op(op) and \
+                    self.op_role_var_key in op.attr_names:
+                op_role_var = op.all_attrs()[self.op_role_var_key]
+                if len(op_role_var) == 0:
+                    continue
+                assert len(op_role_var) % 2 == 0
+
+                offset = idx
+                for i in range(0, len(op_role_var), 2):
+                    param = block.vars[op_role_var[i]]
+                    new_grad_var = block.create_var(
+                        name=op_role_var[i] + "_allgather",
+                        shape=[self.allgather_ranks] + list(param.shape),
+                        persistable=False,
+                        dtype=core.VarDesc.VarType.FP32,
+                        stop_gradient=True)
+                    grad = block.vars[op_role_var[i + 1]]
+                    if param.is_distributed:  # no need to care: used in PLSC
+                        continue
+
+                    if offset == idx:
+                        offset += 1
+                        block._insert_op(
+                            offset,
+                            type='c_sync_calc_stream',
+                            inputs={'X': grad},
+                            outputs={'Out': grad},
+                            attrs={self.op_role_key: OpRole.Backward})
+                        offset += 1
+
+                    # As we search ops reversedly, we should insert c_allgather
+                    # op in the same way to keep the ring_id alternate
+                    ring_id = (ring_id + 1) % self.nrings
+                    block._insert_op(
+                        offset,
+                        type='c_allgather',
+                        inputs={'X': grad},
+                        outputs={'Out': new_grad_var},
+                        attrs={
+                            'nranks': self.allgather_ranks,
+                            'ring_id': ring_id,
+                            self.op_role_key: OpRole.Backward
+                        })
+
+        if grad is None:
+            return
+
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                for ring_id in range(self.nrings):
+                    block._insert_op(
+                        idx + ring_id,
+                        type='c_sync_comm_stream',
+                        inputs={'X': grad},
+                        outputs={'Out': grad},
+                        attrs={
+                            'ring_id': ring_id,
+                            self.op_role_key: OpRole.Backward
+                        })
+                break
+
+    def _update_adam_ops(self):
+        """
+        remove the original adam op, and add new adam ops
+        """
+        block = self.main_program.global_block()
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if self._is_optimizer_op(op):
+                offset = idx
+                if op.type != 'adam' and op.type != 'lamb':  # filter out scale op
+                    continue
+                param_name = op.input("Param")[0]
+                inputs = {
+                    "Param": block.vars[op.input("Param")[0]],
+                    "LearningRate": block.vars[op.input("LearningRate")[0]],
+                    "Moment1": block.vars[op.input("Moment1")[0]],
+                    "Moment2": block.vars[op.input("Moment2")[0]],
+                    "Beta1Pow": block.vars[op.input("Beta1Pow")[0]],
+                    "Beta2Pow": block.vars[op.input("Beta2Pow")[0]]
+                }
+                outputs = {
+                    "ParamOut": block.vars[op.output("ParamOut")[0]],
+                    "Moment1Out": block.vars[op.output("Moment1Out")[0]],
+                    "Moment2Out": block.vars[op.output("Moment2Out")[0]],
+                    "Beta1PowOut": block.vars[op.output("Beta1PowOut")[0]],
+                    "Beta2PowOut": block.vars[op.output("Beta2PowOut")[0]]
+                }
+                attrs = {
+                    "epsilon": op.attr('epsilon'),
+                    "beta1": op.attr('beta1'),
+                    "beta2": op.attr('beta2'),
+                    "lazy_mode": op.attr('lazy_mode'),
+                    "min_row_size_to_use_multithread":
+                    op.attr('min_row_size_to_use_multithread')
+                }
+                split_vars = [
+                    block.create_var(
+                        name=param_name + "_" + str(i),
+                        shape=block.vars[op.input("Param")[0]].shape,
+                        persistable=False,
+                        dtype=core.VarDesc.VarType.FP32,
+                        stop_gradient=True) for i in range(self.allgather_ranks)
+                ]
+                block._insert_op(
+                    offset,
+                    type="split",
+                    inputs={
+                        'X': block.vars[op.input("Param")[0] + "_allgather"]
+                    },
+                    outputs={'Out': split_vars},
+                    attrs={'num': self.allgather_ranks,
+                           'axis': 0})
+                offset += 1
+
+                for i in range(self.allgather_ranks):
+                    inputs["Grad"] = split_vars[i]
+                    block._insert_op(
+                        offset,
+                        type=op.type,
+                        inputs=inputs,
+                        outputs=outputs,
+                        attrs=attrs)
+                    offset += 1
+                # remove the original adam op
+                block._remove_op(offset)
+
+    def _insert_fuse_allreduce_ops(self):
+        """
+        insert coalesce_tensor and all reduce ops
+        """
+        block = self.main_program.global_block()
+        ring_id = 0 % self.nrings
+        grad = None
+        param_grads = []
+        # find all grad params
+        for op in reversed(block.ops):
+            if self._is_backward_op(op) and \
+                    self.op_role_var_key in op.attr_names:
+                op_role_var = op.all_attrs()[self.op_role_var_key]
+                if len(op_role_var) == 0:
+                    continue
+                assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
+                    "but got odd number of vars"
+                for i in range(0, len(op_role_var), 2):
+                    param_name = op_role_var[i]
+                    param = block.var(param_name)
+                    grad_name = op_role_var[i + 1]
+                    grad = block.var(grad_name)
+                    if param.is_distributed:
+                        continue
+                    param_grads.append(grad)
+        if grad is None:
+            return
+
+        segments = []
+        last_dtype = None
+        # split the grad based on dtype and fused size
+        for var in param_grads:
+            if len(segments) == 0 \
+                    or len(segments[-1]) == self.fuse_grad_size_in_num \
+                    or var.dtype != last_dtype:
+                segments.append([var])
+                last_dtype = var.dtype
+            else:
+                segments[-1].append(var)
+
+        fused_vars = []
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                for segment in segments:
+                    # insert coalesce tensor
+                    tmp_var = block.create_var(
+                        name=unique_name.generate('FusedOutput_{}'.format(
+                            segment[0].name)),
+                        dtype=segment[0].dtype,
+                        persistable=False,
+                        stop_gradient=True)
+                    fused_vars.append(tmp_var)
+                    block._insert_op(
+                        idx,
+                        type="coalesce_tensor",
+                        inputs={"Input": segment},
+                        outputs={"Output": segment,
+                                 "FusedOutput": tmp_var},
+                        attrs={
+                            "copy_data": True,
+                            "use_align": True,
+                            "dtype": segment[0].dtype,
+                            self.op_role_key: OpRole.Backward
+                        })
+                break
+
+        # insert the allreduce_sum op
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                for fused_var in fused_vars:
+                    block._insert_op(
+                        idx,
+                        type='c_allreduce_sum',
+                        inputs={'X': fused_var},
+                        outputs={'Out': fused_var},
+                        attrs={
+                            'ring_id': ring_id,
+                            'use_calc_stream': False,
+                            self.op_role_key: OpRole.Backward
+                        })
+                    block._insert_op(
+                        idx,
+                        type='c_sync_calc_stream',
+                        inputs={'X': fused_var},
+                        outputs={'Out': fused_var},
+                        attrs={self.op_role_key: OpRole.Backward})
+                break
+
+        if len(fused_vars) == 0:
+            block._sync_with_cpp()
+            return
+
+        # insert the sync comm op
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                block._insert_op(
+                    idx,
+                    type='c_sync_comm_stream',
+                    inputs={'X': fused_vars[0]},
+                    outputs={'Out': fused_vars[0]},
+                    attrs={
+                        'ring_id': ring_id,
+                        self.op_role_key: OpRole.Backward
+                    })
+                break
+        block._sync_with_cpp()
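
Note: in the "fuse_all_reduce" path above, gradients are packed into one coalesce_tensor segment after another until either fuse_grad_size_in_num (128) variables are collected or the dtype changes, and each fused buffer then gets a single c_allreduce_sum. A standalone sketch of just that grouping rule, with a made-up FakeVar type and sample sizes for illustration:

    from collections import namedtuple

    def split_into_segments(grads, fuse_grad_size_in_num=128):
        # mirrors the segmenting loop in _insert_fuse_allreduce_ops
        segments, last_dtype = [], None
        for var in grads:
            if (not segments
                    or len(segments[-1]) == fuse_grad_size_in_num
                    or var.dtype != last_dtype):
                segments.append([var])    # start a new segment
                last_dtype = var.dtype
            else:
                segments[-1].append(var)  # keep fusing into the current segment
        return segments

    FakeVar = namedtuple("FakeVar", ["name", "dtype"])
    grads = [FakeVar("w%d@GRAD" % i, "float32") for i in range(130)]
    grads.append(FakeVar("b0@GRAD", "float16"))
    print([len(s) for s in split_into_segments(grads)])  # [128, 2, 1]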