     core.VarDesc.VarType.BOOL: 1
 }

-sub_block_ops = [
+SUB_BLOCK_OPS = [
     "while", "while_grad", "parallel_do", "parallel_do_grad",
     "conditional_block", "conditional_block_grad"
 ]

+SUB_BLOCK_PAIR = [("while", "while_grad"), ("parallel_do", "parallel_do_grad"),
+                  ("conditional_block", "conditional_block_grad")]
+
 PRINT_LOG = False


 class ControlFlowGraph(object):
-    def __init__(self, Program, ops, forward_num, skip_opt):
-        self._program = Program
+    def __init__(self, program, ops, forward_num, skip_opt):
+        self._program = program
         self._ops = ops
         self._forward_num = forward_num
         self._successors = defaultdict(set)
@@ -51,14 +54,19 @@ def __init__(self, Program, ops, forward_num, skip_opt):
         self._skip_opt = skip_opt

     def _add_connections(self, connections):
+        """Populates _successors and _presuccessors for neighboring nodes."""
         for node1, node2 in connections:
             self._add(node1, node2)

     def _add(self, node1, node2):
         self._successors[node1].add(node2)
         self._presuccessors[node2].add(node1)

+    # TODO(panyx0718): We need to have a unified way of building an
+    # intermediate representation.
     def _build_graph(self):
+        """Build a graph based on the op sequence.
+        """
         self.op_size = len(self._ops)
         op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
         self._add_connections(op_node_connections)
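(Reviewer note: for a straight-line op sequence, `_build_graph` simply chains op `i` to op `i + 1` in both direction maps. A minimal sketch of the structure it produces; `successors`/`predecessors` are illustrative stand-ins for `_successors`/`_presuccessors`:)

```python
from collections import defaultdict

# Chain built from a 4-op sequence: op i is the sole predecessor of op i + 1.
successors = defaultdict(set)
predecessors = defaultdict(set)

op_size = 4
for node1, node2 in [(i, i + 1) for i in range(op_size - 1)]:
    successors[node1].add(node2)    # mirrors _add(node1, node2)
    predecessors[node2].add(node1)

assert successors[0] == {1}
assert predecessors[3] == {2}
```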
@@ -82,22 +90,23 @@ def _update_graph(self, old_name, new_name, begin_idx=0):
                 self._live_out[i].add(new_name)

     def _reach_fixed_point(self, live_in, live_out):
+        """Check whether the liveness sets have stabilized."""
         if len(live_in) != len(self._live_in):
             return False
         if len(live_out) != len(self._live_out):
             return False
         for i in range(self.op_size):
-            if live_in[i] != self._live_in[i]:
-                return False
-        for i in range(self.op_size):
-            if live_out[i] != self._live_out[i]:
+            if (live_in[i] != self._live_in[i] or
+                    live_out[i] != self._live_out[i]):
                 return False
         return True

     def _dataflow_analyze(self):
         self._build_graph()
         live_in = defaultdict(set)
         live_out = defaultdict(set)
+        # Repeatedly apply liveness updates until the algorithm stabilizes
+        # on a complete set of live input vars and live output vars.
         while True:
             for i in range(self.op_size, 0, -1):
                 live_in[i] = set(self._live_in[i])
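(Reviewer note: the loop above is classic backward liveness analysis iterated to a fixed point. A self-contained sketch of the same computation, with ops modeled as hypothetical `(uses, defs)` pairs rather than the class's real data structures:)

```python
from collections import defaultdict

def liveness(ops):
    """live_in[i] = uses[i] | (live_out[i] - defs[i]);
    live_out[i] = live_in[i + 1] for a straight-line chain.
    Sweep backward until neither set changes (the fixed point)."""
    live_in, live_out = defaultdict(set), defaultdict(set)
    n = len(ops)
    while True:
        changed = False
        for i in range(n - 1, -1, -1):
            uses, defs = ops[i]
            out = set(live_in[i + 1]) if i + 1 < n else set()
            new_in = uses | (out - defs)
            changed |= (new_in != live_in[i] or out != live_out[i])
            live_in[i], live_out[i] = new_in, out
        if not changed:
            return live_in, live_out

# op0 defines a; op1 reads a and defines b; op2 reads b.
ops = [(set(), {"a"}), ({"a"}, {"b"}), ({"b"}, set())]
live_in, live_out = liveness(ops)
assert live_out[0] == {"a"} and live_out[1] == {"b"}
```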
@@ -141,6 +150,8 @@ def _check_var_validity(self, block_desc, x, is_forward):
             return False
         return True

+    # TODO(panyx0718): This needs to be less hacky. It seems memory
+    # optimization doesn't consider vars copied between cpu and gpu.
     def _update_skip_opt_set(self):
         for i in range(self.op_size):
             op = self._ops[i]
@@ -154,7 +165,7 @@ def release_memory(self):
         bwd_id = 0
         for i in range(self.op_size):
             op = self._ops[i]
-            if op.type() in sub_block_ops:
+            if op.type() in SUB_BLOCK_OPS:
                 continue
             block_desc = op.block()
             is_forward = i < self._forward_num
@@ -177,24 +188,25 @@ def memory_optimize(self, level=0):
         def compare_shape(x_shape, cache_shape, opt_level):
             if opt_level == 0:
                 return x_shape == cache_shape
-            if opt_level == 1:
+            elif opt_level == 1:
                 if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
                     return False
                 x_size = abs(reduce(lambda x, y: x * y, x_shape))
                 cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
                 if x_size <= cache_size:
                     return True
+            else:
+                raise ValueError("only support opt_level 0 or 1.")
             return False

         self._dataflow_analyze()
         self._update_skip_opt_set()
         self.pool = []
         for i in range(self.op_size):
             op = self._ops[i]
-            if op.type() in sub_block_ops:
+            if op.type() in SUB_BLOCK_OPS:
                 continue
             block_desc = op.block()
-            self.current_block_desc = block_desc
             is_forward = i < self._forward_num
             if self.pool:
                 defs_can_optimize = filter(
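(Reviewer note: a worked illustration of the two `level` policies implemented by `compare_shape` above. This condensed copy returns the same results; the `functools` import is only needed on Python 3, since the original module targets Python 2 where `reduce` is a builtin:)

```python
from functools import reduce

def compare_shape(x_shape, cache_shape, opt_level):
    if opt_level == 0:
        return x_shape == cache_shape          # exact shape match only
    elif opt_level == 1:
        # A dynamic batch dim (-1) may only reuse another dynamic one.
        if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
            return False
        x_size = abs(reduce(lambda x, y: x * y, x_shape))
        cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
        return x_size <= cache_size            # smaller var fits in larger
    raise ValueError("only support opt_level 0 or 1.")

assert compare_shape([32, 128], [32, 128], 0)
assert not compare_shape([32, 128], [32, 256], 0)
assert compare_shape([-1, 128], [-1, 256], 1)
assert not compare_shape([32, 128], [-1, 128], 1)
```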
@@ -211,37 +223,40 @@ def compare_shape(x_shape, cache_shape, opt_level):
                     for index, cache_pair in enumerate(self.pool):
                         cache_var = cache_pair[0]
                         cache_shape = cache_pair[1]
-                        if compare_shape(x_shape, cache_shape, level):
-                            if self._has_var(block_desc, cache_var, is_forward):
-                                x_dtype = self._find_var(block_desc, x,
-                                                         is_forward).dtype()
-                                cache_dtype = self._find_var(
-                                    block_desc, cache_var, is_forward).dtype()
-                                # TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
-                                # and dtype_to_size[cache_dtype]
-                                if x_dtype == cache_dtype:
-                                    if PRINT_LOG:
-                                        print(
-                                            ("Hit Cache !!!! cache pool index "
-                                             "is %d, var name is %s, "
-                                             "cached var name is %s, "
-                                             "var shape is %s ") %
-                                            (index, x, cache_var,
-                                             str(cache_shape)))
-                                    self.pool.pop(index)
-                                    if x == cache_var:
-                                        break
-                                    _rename_arg_(
-                                        self._ops, x, cache_var, begin_idx=i)
-                                    self._program.block(block_desc.id).var(
-                                        str(x)).desc = self._find_var(
-                                        block_desc, cache_var, is_forward)
-                                    self._update_graph(
-                                        x, cache_var, begin_idx=i)
-                                    break
-
-            in_diff, out_diff = self._get_diff(self._live_in[i],
-                                               self._live_out[i])
+                        if not compare_shape(x_shape, cache_shape, level):
+                            continue
+
+                        if not self._has_var(block_desc, cache_var, is_forward):
+                            continue
+
+                        x_dtype = self._find_var(block_desc, x,
+                                                 is_forward).dtype()
+                        cache_dtype = self._find_var(block_desc, cache_var,
+                                                     is_forward).dtype()
+                        # TODO(qijun): actually, we should compare
+                        # dtype_to_size[x_dtype] and dtype_to_size[cache_dtype]
+                        if x_dtype != cache_dtype:
+                            continue
+
+                        if PRINT_LOG:
+                            print(("Hit Cache !!!! cache pool index "
+                                   "is %d, var name is %s, "
+                                   "cached var name is %s, "
+                                   "var shape is %s ") % (index, x, cache_var,
+                                                          str(cache_shape)))
+                        self.pool.pop(index)
+                        if x == cache_var:
+                            break
+                        # Rename x to the cache var, which already has memory
+                        # allocated, so that x reuses that memory.
+                        _rename_arg_(self._ops, x, cache_var, begin_idx=i)
+                        self._program.block(block_desc.id).var(str(
+                            x)).desc = self._find_var(block_desc, cache_var,
+                                                      is_forward)
+                        self._update_graph(x, cache_var, begin_idx=i)
+                        break
+
+            in_diff, _ = self._get_diff(self._live_in[i], self._live_out[i])
             can_optimize = filter(
                 lambda x: self._check_var_validity(block_desc, x, is_forward),
                 in_diff)
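(Reviewer note: the restructured guard-clause loop above amounts to a free-list keyed by shape and dtype. A minimal level-0 simulation with hypothetical var names; `try_reuse` stands in for the pool scan plus `_rename_arg_`:)

```python
pool = []      # (var_name, shape) of vars that died but still own memory
renames = {}   # logical var name -> physical (reused) var name

def release(var, shape):
    pool.append((var, shape))

def try_reuse(var, shape):
    # level=0 policy: reuse requires an exact shape match (see compare_shape).
    for index, (cache_var, cache_shape) in enumerate(pool):
        if shape == cache_shape:
            pool.pop(index)
            renames[var] = cache_var   # the rename makes the memory shared
            return cache_var
    return var

release("fc_0.tmp_0", [32, 128])   # went dead at some op i
assert try_reuse("relu_1.tmp_0", [32, 128]) == "fc_0.tmp_0"
assert try_reuse("relu_2.tmp_0", [32, 128]) == "relu_2.tmp_0"  # pool now empty
```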
@@ -252,6 +267,19 @@ def compare_shape(x_shape, cache_shape, opt_level):


 def _process_sub_block_pair(pdesc, sub_block_pair):
+    """Creates a list of tuples, each of which tracks info of a subblock.
+
+    Note: this function doesn't handle nested subblocks yet.
+    TODO(panyx0718): assert in case nested subblocks happen.
+
+    :param pdesc: ProgramDesc.
+    :param sub_block_pair: A list of op pairs. Each op pair is made of a
+        forward op and its backward op. The ops in the list are special in
+        that they contain a subblock of ops.
+    :return: A list of tuples, each of which is (all ops in a subblock pair,
+        including forward and backward, the number of forward ops, and all
+        output arg names of the ops in the subblock pair).
+    """
     ops_list = []
     block_desc = pdesc.block(0)
     op_size = block_desc.op_size()
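(Reviewer note: to make the documented return shape concrete, one element of the returned `ops_list` for a hypothetical while/while_grad pair might look like this; op and var names are invented for illustration:)

```python
# (ops in the pair, number of forward ops, output arg names to skip)
example_entry = (
    ["mul", "elementwise_add", "elementwise_add_grad", "mul_grad"],
    2,                    # the first two ops form the forward subblock
    {"while_out.tmp_0"},  # outputs of the pair; excluded from reuse
)
```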
@@ -308,6 +336,11 @@ def _process_sub_block_pair(pdesc, sub_block_pair):


 def _get_cfgs(input_program):
+    """Process each block and create a ControlFlowGraph for each of them.
+
+    :param input_program: Program object.
+    :return: A list of ControlFlowGraphs, one per block.
+    """
     ops_list = []
     pdesc = input_program.get_desc()
     block_desc = pdesc.block(0)
@@ -316,11 +349,8 @@ def _get_cfgs(input_program):
     ops_list.append(
         ([block_desc.op(i) for i in range(op_size)], op_size, set()))

-    sub_block_pair = [("while", "while_grad"), ("parallel_do",
-                                                "parallel_do_grad"),
-                      ("conditional_block", "conditional_block_grad")]
-
-    ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
+    # Only process one level of nested subblocks.
+    ops_list.extend(_process_sub_block_pair(pdesc, SUB_BLOCK_PAIR))

     cfgs = [
         ControlFlowGraph(input_program, ops, forward_num, skip_opt)
@@ -330,6 +360,17 @@ def _get_cfgs(input_program):


 def memory_optimize(input_program, print_log=False, level=0):
+    """Optimize memory by reusing var memory.
+
+    Note: it does not support subblocks nested within subblocks.
+
+    :param input_program: Input Program.
+    :param print_log: whether to print debug log.
+    :param level: If level=0, reuse only when var shapes are exactly equal;
+        if level=1, reuse when the cached var is at least as large.
+    :return: None. The input_program is rewritten in place.
+    """
+    if level != 0 and level != 1:
+        raise ValueError("only support opt_level 0 or 1.")
     global PRINT_LOG
     PRINT_LOG = print_log
     cfgs = _get_cfgs(input_program)
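(Reviewer note: a hedged usage sketch of the pass end to end, assuming the `paddle.fluid` 0.x API of this era; exact layer signatures and module paths may differ by version:)

```python
import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[128], dtype="float32")
y = fluid.layers.fc(input=x, size=10)
loss = fluid.layers.mean(x=y)
fluid.backward.append_backward(loss)

# Rewrites default_main_program() in place so that vars which go dead can
# hand their allocated memory to later, shape-compatible vars.
fluid.memory_optimize(fluid.default_main_program(), print_log=False, level=0)
```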