
Commit e5b4dd7

[heterps] add fuse_allreduce (#35131)
* heterps: add fuse_allreduce op; test=develop
* add program_mode in minimize for pslib mode; test=develop
1 parent 339cb19 commit e5b4dd7
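
For orientation, a hedged summary distilled from the diffs below (not text from the commit): each program_mode value selects one of the MultiThread transpile paths added in collective.py.

# Hedged summary of the dispatch in MultiThread._transpile_main_program;
# all method names below come from this commit.
PROGRAM_MODE_TO_TRANSPILE = {
    "all_reduce": ["_insert_allreduce_ops"],            # default: per-gradient c_allreduce_sum
    "fuse_all_reduce": ["_insert_fuse_allreduce_ops"],  # coalesce_tensor + allreduce on fused buffers
    "all_gather": ["_insert_allgather_ops", "_update_adam_ops"],  # c_allgather, then per-shard adam/lamb
}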

3 files changed: +284 additions, -9 deletions

python/paddle/distributed/fleet/utils/fs.py

Lines changed: 10 additions & 3 deletions
@@ -468,10 +468,17 @@ def __init__(
         self._bd_err_re = re.compile(
             r'\s?responseErrorMsg\s?\:.*, errorCode\:\s?[0-9]+, path\:')

-    def _run_cmd(self, cmd, redirect_stderr=False):
+    def _run_cmd(self, cmd, redirect_stderr=False, retry_times=5):
         exe_cmd = "{} -{}".format(self._base_cmd, cmd)
-        ret, output = core.shell_execute_cmd(exe_cmd, 0, 0, redirect_stderr)
-        ret = int(ret)
+        ret = 0
+        output = None
+        retry_sleep_second = 3
+        for x in range(retry_times + 1):
+            ret, output = core.shell_execute_cmd(exe_cmd, 0, 0, redirect_stderr)
+            ret = int(ret)
+            if ret == 0:
+                break
+            time.sleep(retry_sleep_second)
         if ret == 134:
             raise FSShellCmdAborted(cmd)
         return ret, output.splitlines()
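
The change above makes the HDFS shell wrapper retry a failed command up to retry_times extra attempts. A minimal standalone sketch of the same retry pattern, using subprocess instead of Paddle's internal core.shell_execute_cmd (run_cmd_with_retry and its arguments are illustrative names, not part of this commit):

import subprocess
import time

def run_cmd_with_retry(cmd, retry_times=5, retry_sleep_second=3):
    # Run a shell command; on a non-zero exit code, sleep and retry, returning
    # (exit_code, stdout_lines) just like the patched _run_cmd.
    ret, output = 1, ""
    for _ in range(retry_times + 1):
        proc = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        ret, output = proc.returncode, proc.stdout
        if ret == 0:
            break
        time.sleep(retry_sleep_second)
    return ret, output.splitlines()

# Example (assumes a hadoop client on PATH):
# run_cmd_with_retry("hadoop fs -ls /some/path")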

python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py

Lines changed: 10 additions & 3 deletions
@@ -1091,7 +1091,8 @@ def minimize(self,
                  scopes=None,
                  startup_programs=None,
                  parameter_list=None,
-                 no_grad_set=None):
+                 no_grad_set=None,
+                 program_mode="all_reduce"):
         """
         minimize a program through loss, loss can be a list in DistributedOptimizer.
         Note that in parameter server mode, a worker will not get anything about optimize_os
@@ -1105,6 +1106,7 @@ def minimize(self,
                 in `parameter_list`.
             parameter_list (list): list of Variables to update.
             no_grad_set (set|None): set of Variables should be ignored.
+            program_mode (str|"all_reduce"): grad action for program when use_ps_gpu.
         Returns:
             tuple: (optimize_ops, params_grads) which are, list of operators appended;
             and list of (param, grad) Variables pair for optimization.
@@ -1139,12 +1141,17 @@ def minimize(self,
         if opt_info["use_ps_gpu"]:
             from paddle.fluid.transpiler.collective import MultiThread
             # check start program
-
+            if program_mode not in [
+                    "all_reduce", "fuse_all_reduce", "all_gather"
+            ]:
+                raise ValueError("You should set program_mode in [ all_reduce, \
+                    fuse_all_reduce, all_gather ]")
             env = self.get_dist_env()
             if not isinstance(losses, list):
                 startup_programs = [startup_programs]
             for i in range(0, len(startup_programs)):
-                t = MultiThread()
+
+                t = MultiThread(trans_mode=program_mode)
                 start_program = startup_programs[i]
                 main_program = programs[i]
                 t.transpile(
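
For illustration, a hedged standalone mirror of the new program_mode guard, plus a hypothetical call site for the extended minimize signature (check_program_mode, optimizer, loss and startup_program are placeholders, not defined in this commit):

def check_program_mode(program_mode):
    # Mirrors the check minimize() now performs when opt_info["use_ps_gpu"] is set.
    if program_mode not in ["all_reduce", "fuse_all_reduce", "all_gather"]:
        raise ValueError("You should set program_mode in "
                         "[ all_reduce, fuse_all_reduce, all_gather ]")
    return program_mode

check_program_mode("fuse_all_reduce")   # passes
# check_program_mode("ring_allreduce")  # would raise ValueError

# Hypothetical call site (placeholders):
# optimize_ops, params_grads = optimizer.minimize(
#     loss, startup_programs=[startup_program], program_mode="fuse_all_reduce")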

python/paddle/fluid/transpiler/collective.py

Lines changed: 264 additions & 3 deletions
@@ -65,7 +65,7 @@ def transpile(self, startup_program, main_program, rank, endpoints,
         self.main_program = default_main_program()

         self.nranks = len(endpoints)
-        if self.nranks == 1 and self.mode != "single_process_multi_thread":
+        if self.nranks == 1 and self.mode != "single_process_multi_thread" and self.mode != "box":
             raise ValueError('the number of endpoints must > 1')

         if rank < 0:
@@ -441,9 +441,14 @@ class MultiThread(GradAllReduce):
     '''
     '''

-    def __init__(self, nrings=1):
+    def __init__(self, nrings=1, trans_mode="all_reduce"):
         GradAllReduce.__init__(self, nrings)
-        self.mode = "single_process_multi_thread"
+        self.mode = "box"
+        self.trans_mode = trans_mode
+        self.fuse_grad_size_in_num = 128
+        gpu_nums = os.getenv("FLAGS_selected_gpus",
+                             "0,1,2,3,4,5,6,7,8").split(",")
+        self.gpu_num = len(gpu_nums)

     def _transpile_startup_program(self):
         if len(self.endpoints) > 1:
@@ -460,3 +465,259 @@ def _transpile_startup_program(self):
             print("begin to _transpile_startup_program for single-node")
             block = self.startup_program.global_block()
             block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
+
+    def _transpile_main_program(self):
+        self._insert_scale_loss_grad_ops()
+        if self.trans_mode == "all_gather":
+            print("begin to transpile in all-gather mode")
+            self.allgather_ranks = self.nranks * self.gpu_num
+            self._insert_allgather_ops()
+            self._update_adam_ops()
+        elif self.trans_mode == "fuse_all_reduce":
+            print("begin to transpile in fuse all-reduce mode")
+            self._insert_fuse_allreduce_ops()
+        else:
+            print("begin to transpile in all-reduce mode")
+            self._insert_allreduce_ops()
+
+    def _insert_allgather_ops(self):
+        """
+        insert allgather op to the main_program
+        """
+        block = self.main_program.global_block()
+        ring_id = -1
+        grad = None
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if self._is_backward_op(op) and \
+                    self.op_role_var_key in op.attr_names:
+                op_role_var = op.all_attrs()[self.op_role_var_key]
+                if len(op_role_var) == 0:
+                    continue
+                assert len(op_role_var) % 2 == 0
+
+                offset = idx
+                for i in range(0, len(op_role_var), 2):
+                    param = block.vars[op_role_var[i]]
+                    new_grad_var = block.create_var(
+                        name=op_role_var[i] + "_allgather",
+                        shape=[self.allgather_ranks] + list(param.shape),
+                        persistable=False,
+                        dtype=core.VarDesc.VarType.FP32,
+                        stop_gradient=True)
+                    grad = block.vars[op_role_var[i + 1]]
+                    if param.is_distributed:  # no need to care: used in PLSC
+                        continue
+
+                    if offset == idx:
+                        offset += 1
+                        block._insert_op(
+                            offset,
+                            type='c_sync_calc_stream',
+                            inputs={'X': grad},
+                            outputs={'Out': grad},
+                            attrs={self.op_role_key: OpRole.Backward})
+                        offset += 1
+
+                    # As we search ops reversedly, we should insert c_allgather
+                    # op in the same way to keep the ring_id alternate
+                    ring_id = (ring_id + 1) % self.nrings
+                    block._insert_op(
+                        offset,
+                        type='c_allgather',
+                        inputs={'X': grad},
+                        outputs={'Out': new_grad_var},
+                        attrs={
+                            'nranks': self.allgather_ranks,
+                            'ring_id': ring_id,
+                            self.op_role_key: OpRole.Backward
+                        })
+
+        if grad is None:
+            return
+
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                for ring_id in range(self.nrings):
+                    block._insert_op(
+                        idx + ring_id,
+                        type='c_sync_comm_stream',
+                        inputs={'X': grad},
+                        outputs={'Out': grad},
+                        attrs={
+                            'ring_id': ring_id,
+                            self.op_role_key: OpRole.Backward
+                        })
+                break
+
+    def _update_adam_ops(self):
+        """
+        remove the original adam op, and add new adam ops
+        """
+        block = self.main_program.global_block()
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if self._is_optimizer_op(op):
+                offset = idx
+                if op.type != 'adam' and op.type != 'lamb':  # filter out scale op
+                    continue
+                param_name = op.input("Param")[0]
+                inputs = {
+                    "Param": block.vars[op.input("Param")[0]],
+                    "LearningRate": block.vars[op.input("LearningRate")[0]],
+                    "Moment1": block.vars[op.input("Moment1")[0]],
+                    "Moment2": block.vars[op.input("Moment2")[0]],
+                    "Beta1Pow": block.vars[op.input("Beta1Pow")[0]],
+                    "Beta2Pow": block.vars[op.input("Beta2Pow")[0]]
+                }
+                outputs = {
+                    "ParamOut": block.vars[op.output("ParamOut")[0]],
+                    "Moment1Out": block.vars[op.output("Moment1Out")[0]],
+                    "Moment2Out": block.vars[op.output("Moment2Out")[0]],
+                    "Beta1PowOut": block.vars[op.output("Beta1PowOut")[0]],
+                    "Beta2PowOut": block.vars[op.output("Beta2PowOut")[0]]
+                }
+                attrs = {
+                    "epsilon": op.attr('epsilon'),
+                    "beta1": op.attr('beta1'),
+                    "beta2": op.attr('beta2'),
+                    "lazy_mode": op.attr('lazy_mode'),
+                    "min_row_size_to_use_multithread":
+                    op.attr('min_row_size_to_use_multithread')
+                }
+                split_vars = [
+                    block.create_var(
+                        name=param_name + "_" + str(i),
+                        shape=block.vars[op.input("Param")[0]].shape,
+                        persistable=False,
+                        dtype=core.VarDesc.VarType.FP32,
+                        stop_gradient=True) for i in range(self.allgather_ranks)
+                ]
+                block._insert_op(
+                    offset,
+                    type="split",
+                    inputs={
+                        'X': block.vars[op.input("Param")[0] + "_allgather"]
+                    },
+                    outputs={'Out': split_vars},
+                    attrs={'num': self.allgather_ranks,
+                           'axis': 0})
+                offset += 1
+
+                for i in range(self.allgather_ranks):
+                    inputs["Grad"] = split_vars[i]
+                    block._insert_op(
+                        offset,
+                        type=op.type,
+                        inputs=inputs,
+                        outputs=outputs,
+                        attrs=attrs)
+                    offset += 1
+                # remove the original adam op
+                block._remove_op(offset)
+
+    def _insert_fuse_allreduce_ops(self):
+        """
+        insert coalesce_tensor and all reduce ops
+        """
+        block = self.main_program.global_block()
+        ring_id = 0 % self.nrings
+        grad = None
+        param_grads = []
+        # find all grad params
+        for op in reversed(block.ops):
+            if self._is_backward_op(op) and \
+                    self.op_role_var_key in op.attr_names:
+                op_role_var = op.all_attrs()[self.op_role_var_key]
+                if len(op_role_var) == 0:
+                    continue
+                assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
+                                                  "but got odd number of vars"
+                for i in range(0, len(op_role_var), 2):
+                    param_name = op_role_var[i]
+                    param = block.var(param_name)
+                    grad_name = op_role_var[i + 1]
+                    grad = block.var(grad_name)
+                    if param.is_distributed:
+                        continue
+                    param_grads.append(grad)
+        if grad is None:
+            return
+
+        segments = []
+        last_dtype = None
+        # split the grad based on dtype and fused size
+        for var in param_grads:
+            if len(segments) == 0 \
+                    or len(segments[-1]) == self.fuse_grad_size_in_num \
+                    or var.dtype != last_dtype:
+                segments.append([var])
+                last_dtype = var.dtype
+            else:
+                segments[-1].append(var)
+
+        fused_vars = []
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                for segment in segments:
+                    # insert coalesce tensor
+                    tmp_var = block.create_var(
+                        name=unique_name.generate('FusedOutput_{}'.format(
+                            segment[0].name)),
+                        dtype=segment[0].dtype,
+                        persistable=False,
+                        stop_gradient=True)
+                    fused_vars.append(tmp_var)
+                    block._insert_op(
+                        idx,
+                        type="coalesce_tensor",
+                        inputs={"Input": segment},
+                        outputs={"Output": segment,
+                                 "FusedOutput": tmp_var},
+                        attrs={
+                            "copy_data": True,
+                            "use_align": True,
+                            "dtype": segment[0].dtype,
+                            self.op_role_key: OpRole.Backward
+                        })
+                break
+
+        # insert the allreduce_sum op
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                for fused_var in fused_vars:
+                    block._insert_op(
+                        idx,
+                        type='c_allreduce_sum',
+                        inputs={'X': fused_var},
+                        outputs={'Out': fused_var},
+                        attrs={
+                            'ring_id': ring_id,
+                            'use_calc_stream': False,
+                            self.op_role_key: OpRole.Backward
+                        })
+                    block._insert_op(
+                        idx,
+                        type='c_sync_calc_stream',
+                        inputs={'X': fused_var},
+                        outputs={'Out': fused_var},
+                        attrs={self.op_role_key: OpRole.Backward})
+                break
+
+        if len(fused_vars) == 0:
+            block._sync_with_cpp()
+            return
+
+        # insert the sync comm op
+        for idx, op in enumerate(block.ops):
+            if self._is_optimizer_op(op):
+                block._insert_op(
+                    idx,
+                    type='c_sync_comm_stream',
+                    inputs={'X': fused_vars[0]},
+                    outputs={'Out': fused_vars[0]},
+                    attrs={
+                        'ring_id': ring_id,
+                        self.op_role_key: OpRole.Backward
+                    })
+                break
+        block._sync_with_cpp()
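
In the fuse path above, gradients are grouped into segments before coalesce_tensor is inserted: consecutive gradients of the same dtype share a segment until it reaches fuse_grad_size_in_num (128) tensors, and a dtype change starts a new segment. A standalone sketch of that bucketing rule on (name, dtype) tuples instead of Paddle Variables (split_into_segments is an illustrative name, not from this commit):

def split_into_segments(grads, fuse_grad_size_in_num=128):
    # grads: list of (name, dtype) tuples in backward order.
    segments, last_dtype = [], None
    for name, dtype in grads:
        if (not segments or len(segments[-1]) == fuse_grad_size_in_num
                or dtype != last_dtype):
            segments.append([(name, dtype)])
            last_dtype = dtype
        else:
            segments[-1].append((name, dtype))
    return segments

# Example: a dtype change forces a new segment even before the size cap is hit.
grads = [("w0@GRAD", "fp32"), ("w1@GRAD", "fp32"), ("w2@GRAD", "fp16")]
print([len(s) for s in split_into_segments(grads)])  # [2, 1]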
