@@ -93,30 +93,33 @@ def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")
 
 
-def split_dense_variable(var_list,
-                         pserver_count,
-                         min_block_size=1024,
-                         max_block_size=1048576):
+def split_dense_variable(var_list, service_count, min_block_size=8192):
     """
-    We may need to split dense tensor to one or more blocks and put
-    them equally onto parameter server. One block is a sub-tensor
-    aligned by dim[0] of the tensor.
-
-    We need to have a minimal block size so that the calculations in
-    the parameter server side can gain better performance. By default
-    minimum block size is 1024. The max block size is used to prevent
-    very large blocks that may cause send error.
-    :return: A list of VarBlocks. Each VarBlock specifies a shard of
-       the var.
+    We may need to split dense tensor to one or more blocks and put
+    them equally onto parameter server. One block is a sub-tensor
+    aligned by dim[0] of the tensor.
+
+    We need to have a minimal block size so that the calculations in
+    the parameter server side can gain better performance. By default
+    the minimum block size is 8K elements (each element may be 16, 32
+    or 64 bits).
+
+    Args:
+        var_list (list): List of variables.
+        service_count (int): Number of pserver services. A pserver may have
+            two or more listening ports.
+        min_block_size (int): Minimum split block size.
+    Returns:
+        blocks (list[(varname, block_id, current_block_size)]): A list of
+            VarBlocks. Each VarBlock specifies a shard of the var.
     """
     blocks = []
     for var in var_list:
-        split_count = pserver_count
+        split_count = service_count
         var_numel = reduce(lambda x, y: x * y, var.shape)
         max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
             max_pserver_count = 1
-        if max_pserver_count < pserver_count:
+        if max_pserver_count < service_count:
             split_count = max_pserver_count
         block_size = int(math.ceil(var_numel / float(split_count)))
 
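To make the splitting arithmetic in `split_dense_variable` concrete, the sketch below reproduces the size calculation in plain Python. It is a standalone illustration, not the transpiler itself: `FakeVar` and `split_blocks` are made-up names, and the real function additionally aligns each block to dim[0] of the tensor, which this sketch skips.

```python
# Standalone sketch of the block-size arithmetic above (illustrative only).
import math
from functools import reduce


class FakeVar(object):
    """Stand-in for a framework variable; only .name and .shape are used."""

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape


def split_blocks(var_list, service_count, min_block_size=8192):
    blocks = []
    for var in var_list:
        split_count = service_count
        var_numel = reduce(lambda x, y: x * y, var.shape)
        # Never produce blocks smaller than min_block_size elements.
        max_pserver_count = max(1, int(math.floor(var_numel / float(min_block_size))))
        if max_pserver_count < service_count:
            split_count = max_pserver_count
        block_size = int(math.ceil(var_numel / float(split_count)))
        for block_id in range(split_count):
            curr = min(block_size, var_numel - block_id * block_size)
            blocks.append("%s:%d:%d" % (var.name, block_id, curr))
    return blocks


# A 784x100 weight (78400 elements) split across 4 pserver services:
print(split_blocks([FakeVar("fc_0.w_0", [784, 100])], service_count=4))
# ['fc_0.w_0:0:19600', 'fc_0.w_0:1:19600', 'fc_0.w_0:2:19600', 'fc_0.w_0:3:19600']
```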
@@ -270,16 +273,19 @@ def transpile(self,
         grad_var_mapping = self._append_split_op(program, grad_blocks)
         param_var_mapping = self._create_vars_from_blocklist(program,
                                                              param_blocks)
+
         # step3: Add gradients as send op inputs and parameters as send
         # op outputs.
         send_inputs = []
         send_outputs = []
         for b in grad_blocks:  # append by order
             varname, block_id, _ = b.split(":")
             send_inputs.append(grad_var_mapping[varname][int(block_id)])
+
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
+
         # let send_op know which endpoint to send which var to, eplist has the same
         # order as send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
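The `split_method` call above decides which pserver endpoint receives each send variable, so `eplist[i]` must correspond to `send_inputs[i]`. Below is a minimal sketch of one plausible policy, plain round-robin; the policy actually passed in may differ (e.g. hashing by variable name), and the endpoint and block names are made up.

```python
# Hedged sketch of a round-robin split_method; names and endpoints are illustrative.
def round_robin_split(varlist, pserver_endpoints):
    # eplist[i] is the endpoint that will receive varlist[i].
    return [pserver_endpoints[i % len(pserver_endpoints)] for i in range(len(varlist))]


endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]
grad_blocks = ["fc_0.w_0@GRAD:0:19600", "fc_0.w_0@GRAD:1:19600", "fc_0.b_0@GRAD:0:100"]
# Parsing mirrors the loop above: each entry is "varname:block_id:size".
names = [b.split(":")[0] for b in grad_blocks]
print(list(zip(names, round_robin_split(grad_blocks, endpoints))))
# [('fc_0.w_0@GRAD', '127.0.0.1:6170'), ('fc_0.w_0@GRAD', '127.0.0.1:6171'),
#  ('fc_0.b_0@GRAD', '127.0.0.1:6170')]
```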
@@ -751,9 +757,18 @@ def _create_vars_from_blocklist(self,
         Create vars for each split.
         NOTE: only grads need to be named for different trainers, use
         add_trainer_suffix to rename the grad vars.
-        :return: A dict mapping from original var name to each var split.
+        Args:
+            program (ProgramDesc): ProgramDesc which the gradients belong to.
+            block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
+            add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True.
+        Returns:
+            var_mapping (dict(varname->[new_varname_variable])): A dict mapping
+                from original var name to each var split.
         """
+
+        # varname -> [(block_id, current_block_size)]
         block_map = dict()
+
         var_mapping = dict()
         for block_str in block_list:
             varname, offset, size = block_str.split(":")
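A short sketch of how the `block_map` commented above can be filled from the `"varname:block_id:size"` strings produced by the splitting step: grouping by variable name yields the ordered `(block_id, current_block_size)` pairs from which the split variables are then created. The helper name is illustrative; the real method builds the map inline in its loop.

```python
# Illustrative helper mirroring the parsing loop shown in the hunk above.
def build_block_map(block_list):
    block_map = dict()
    for block_str in block_list:
        varname, block_id, size = block_str.split(":")
        block_map.setdefault(varname, []).append((int(block_id), int(size)))
    return block_map


blocks = ["fc_0.w_0:0:19600", "fc_0.w_0:1:19600", "fc_0.b_0:0:100"]
print(build_block_map(blocks))
# {'fc_0.w_0': [(0, 19600), (1, 19600)], 'fc_0.b_0': [(0, 100)]}
```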
@@ -824,7 +839,16 @@ def _clone_var(self, block, var, persistable=True):
             persistable=persistable)
 
     def _append_split_op(self, program, gradblocks):
-        # Split variables that need to be split and append respective ops
+        """
+        Split variables that need to be split and append respective ops.
+        Args:
+            program (ProgramDesc): ProgramDesc that the gradients belong to.
+            gradblocks (list[(varname, block_id, block_size)]): List of gradient blocks.
+        Returns:
+            var_mapping (dict(varname->[new_splitted_variable])): A dict mapping
+                from original var name to each var split.
+        """
851+
828852 add_suffix = False
829853 if self .trainer_num > 1 :
830854 add_suffix = True
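Two ideas from this hunk, sketched below with illustrative names (the suffix format and helper names are assumptions, not the transpiler's exact scheme): when more than one trainer is present, split gradient variables get a per-trainer suffix so a pserver can tell senders apart, and each split of a tensor can be described by its dim[0] section sizes, which is the form a split op consumes. The `.block<N>` part matches the `same_or_split_var` check at the top of this file.

```python
# Illustrative sketch only; not the transpiler's naming or op-building code.
def block_var_name(varname, block_id, trainer_id, add_suffix):
    name = "%s.block%d" % (varname, block_id)
    if add_suffix:
        name += ".trainer_%d" % trainer_id  # assumed suffix format
    return name


def dim0_sections(orig_shape, block_sizes):
    # Elements per "row" along dim[0]; each block spans a whole number of rows.
    row_numel = 1
    for d in orig_shape[1:]:
        row_numel *= d
    return [size // row_numel for size in block_sizes]


print(block_var_name("fc_0.w_0@GRAD", 1, trainer_id=0, add_suffix=True))
# fc_0.w_0@GRAD.block1.trainer_0
print(dim0_sections([784, 100], [19600, 19600, 19600, 19600]))
# [196, 196, 196, 196]
```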
@@ -1148,6 +1172,12 @@ def _get_lr_ops(self):
         return lr_ops
 
     def _get_optimize_pass(self):
+        """
+        Get optimizer operators, parameters and gradients from origin_program.
+        Returns:
+            opt_ops (list): optimize operators.
+            params_grads (list): (parameter, gradient) pairs.
+        """
         block = self.origin_program.global_block()
         opt_ops = []
         params_grads = []
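A framework-agnostic sketch of the pass described by the docstring above: walk the operators of the global block, keep those whose type looks like an optimizer update, and record which gradient updates which parameter. `FakeOp`, its attribute names, and the set of optimizer types are assumptions for illustration only, not the framework's operator API.

```python
# Illustrative stand-ins for operator descriptors.
OPTIMIZER_OP_TYPES = {"sgd", "momentum", "adam", "adagrad"}


class FakeOp(object):
    def __init__(self, type, param=None, grad=None):
        self.type = type
        self.param = param
        self.grad = grad


def get_optimize_pass(ops):
    opt_ops = []
    params_grads = []
    for op in ops:
        if op.type in OPTIMIZER_OP_TYPES:
            opt_ops.append(op)
            params_grads.append((op.param, op.grad))
    return opt_ops, params_grads


ops = [FakeOp("mul"), FakeOp("elementwise_add"),
       FakeOp("sgd", param="fc_0.w_0", grad="fc_0.w_0@GRAD")]
print(get_optimize_pass(ops)[1])  # [('fc_0.w_0', 'fc_0.w_0@GRAD')]
```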