1- from paddle .v2 .framework .layer_helper import LayerHelper
1+ from paddle .v2 .framework .layer_helper import LayerHelper , unique_name
22import paddle .v2 .framework .core as core
3- from paddle .v2 .framework .framework import OpProtoHolder , Variable
3+ from paddle .v2 .framework .framework import OpProtoHolder , Variable , Program
44import re
55
66__all__ = [
7- 'fc' , 'data' , 'cross_entropy' , 'conv2d' , 'pool2d' , 'embedding' , 'concat'
7+ 'fc' , 'data' , 'cross_entropy' , 'conv2d' , 'pool2d' , 'embedding' , 'concat' ,
8+ 'StaticRNN'
89]
910
1011
@@ -26,7 +27,9 @@ def fc(input,
2627 mul_results = []
2728 for input_var , param_attr in helper .iter_inputs_and_params ():
2829 input_shape = input_var .shape
29- param_shape = list (input_shape [num_flatten_dims :]) + [size ]
30+ param_shape = [
31+ reduce (lambda a , b : a * b , input_shape [num_flatten_dims :], 1 )
32+ ] + [size ]
3033
3134 w = helper .create_parameter (
3235 attr = param_attr , shape = param_shape , dtype = dtype )
@@ -38,10 +41,8 @@ def fc(input,
3841 "Y" : w ,
3942 },
4043 outputs = {"Out" : tmp },
41- attrs = {
42- 'x_num_col_dims' : num_flatten_dims ,
43- 'y_num_col_dims' : len (input_shape ) - num_flatten_dims
44- })
44+ attrs = {'x_num_col_dims' : num_flatten_dims ,
45+ 'y_num_col_dims' : 1 })
4546 mul_results .append (tmp )
4647
4748 # sum
@@ -273,3 +274,170 @@ def pool2d(input,
273274 })
274275
275276 return pool_out
277+
278+
class BlockGuard(object):
    """
    Context manager that wraps a sub-block of a program in a Python `with`
    statement.

    Entering the guard calls `program.create_block()`; leaving it calls
    `program.rollback()`, whether or not the body raised.

    :param program: the program in which the sub-block is created
    :type program: Program
    :raises TypeError: if `program` is not a `Program`
    """

    def __init__(self, program):
        if isinstance(program, Program):
            self.program = program
        else:
            raise TypeError("BlockGuard takes a program")

    def __enter__(self):
        # Nothing is bound by `with ... as`; the call is made for its effect.
        self.program.create_block()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.program.rollback()
        # False lets an exception raised inside the `with` body propagate;
        # True is returned only when the body finished cleanly.
        return exc_type is None
298+
299+
class StaticRNNGuard(BlockGuard):
    """
    Guard returned by `StaticRNN.step()`.

    It opens a sub-block in the RNN's program (via `BlockGuard`) and keeps
    `rnn.status` in sync with whether the caller is inside the `with` body.

    :param rnn: the RNN whose step block is being defined
    :type rnn: StaticRNN
    :raises TypeError: if `rnn` is not a `StaticRNN`
    """

    def __init__(self, rnn):
        if not isinstance(rnn, StaticRNN):
            raise TypeError("StaticRNNGuard takes an StaticRNN")
        super(StaticRNNGuard, self).__init__(rnn.helper.program)
        self.rnn = rnn

    def __enter__(self):
        self.rnn.status = StaticRNN.IN_RNN_BLOCK
        return super(StaticRNNGuard, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
        try:
            # Only finalise the RNN when the step block was built without
            # error; finalising a half-built block could raise a second
            # exception that hides the original one.
            if exc_type is None:
                self.rnn.complete_rnn_op()
        finally:
            # Always leave the sub-block, even if finalising failed, so the
            # program is not left pointing at the RNN block.
            leave_result = super(StaticRNNGuard, self).__exit__(
                exc_type, exc_val, exc_tb)
        return leave_result
315+
316+
class StaticRNNMemoryLink(object):
    """
    Plain record grouping the variables that describe one RNN memory.

    :param init: the initial variable of the memory
    :type init: Variable
    :param pre_mem: the memory variable in the previous time step
    :type pre_mem: Variable
    :param mem: the memory variable in the current time step; None until it
        is assigned
    :type mem: Variable
    """

    def __init__(self, init, pre_mem, mem=None):
        self.init, self.pre_mem, self.mem = init, pre_mem, mem
332+
class StaticRNN(object):
    """
    Builder for an RNN with a fixed sequence length.

    The step computation is described inside a `with rnn.step():` block,
    using `memory`, `step_input`, `step_output`/`output` and `update_memory`.
    After the block, calling the instance returns the collected output
    variable(s).

    :param name: name forwarded to the underlying `LayerHelper`
    :param program: program forwarded to the underlying `LayerHelper`
    """

    # Values of `status`: where the caller is relative to the step block.
    BEFORE_RNN_BLOCK = 0
    IN_RNN_BLOCK = 1
    AFTER_RNN_BLOCK = 2

    def __init__(self, name=None, program=None):
        self.helper = LayerHelper("static_rnn", name=name, program=program)
        self.memories = {}  # memory map, from pre_mem.name --> MemoryLink
        self.inputs = []  # input variable list in current block
        self.outputs = []  # output variable list in parent block
        self.status = StaticRNN.BEFORE_RNN_BLOCK  # status flag.
        # Sequence length. Since this is a static RNN, the sequence length is
        # fixed: it is taken from the first step input and every later step
        # input must match it.
        self.seq_len = None

    def step(self):
        """Return the guard to use as `with rnn.step():`."""
        return StaticRNNGuard(self)

    def _assert_in_rnn_block_(self, method):
        """
        Raise ValueError unless the caller is inside the step block.

        :param method: name of the calling method, used in the error message
        """
        if self.status != StaticRNN.IN_RNN_BLOCK:
            raise ValueError("You must invoke {0} in rnn block".format(method))

    def memory(self, init=None, shape=None, dtype=None, init_value=0):
        """
        Declare a memory and return its previous-step variable.

        :param init: variable the memory is initialised from. When None, a
            boot variable is created in the parent block and filled with
            `init_value` by a `fill_constant` op.
        :param shape: shape of the boot variable; required when `init` is None
        :param dtype: dtype of the boot variable; required when `init` is None
        :param init_value: constant used to fill the boot variable
        :return: the variable standing for the memory in the previous step
        :raises ValueError: if called outside the step block, or if `init` is
            None and `shape` or `dtype` is missing
        """
        self._assert_in_rnn_block_('memory')
        if init is None:
            if shape is None or dtype is None:
                raise ValueError(
                    "if init is None, memory at least need shape and dtype")
            parent_block = self.parent_block()
            var_name = unique_name("@".join([self.helper.name, "memory_boot"]))
            boot_var = parent_block.create_var(
                name=var_name, shape=shape, dtype=dtype, persistable=False)

            parent_block.append_op(
                type="fill_constant",
                inputs={},
                outputs={'Out': [boot_var]},
                attrs={
                    'value': init_value,
                    'shape': boot_var.shape,
                    'data_type': boot_var.data_type
                })

            # Re-enter with the freshly created boot variable as `init`.
            return self.memory(init=boot_var)
        else:
            pre_mem = self.helper.create_variable(
                name=unique_name("@".join([self.helper.name, "mem"])),
                dtype=init.data_type,
                shape=init.shape)
            self.memories[pre_mem.name] = StaticRNNMemoryLink(
                init=init, pre_mem=pre_mem)
            return pre_mem

    def step_input(self, x):
        """
        Register `x` as a sequence input and return its per-step variable.

        `x.shape[1]` is used as the sequence length; the returned variable has
        shape `[-1] + x.shape[2:]` and reuses `x`'s name, dtype and type.

        :param x: the whole-sequence input
        :type x: Variable
        :return: the variable to use for `x` inside the step block
        :raises ValueError: if called outside the step block, or if the
            sequence length differs from that of an earlier step input
        :raises TypeError: if `x` is not a `Variable`
        """
        self._assert_in_rnn_block_('step_input')
        if not isinstance(x, Variable):
            raise TypeError("step input takes a Variable")
        if self.seq_len is None:
            self.seq_len = x.shape[1]
        elif self.seq_len != x.shape[1]:
            raise ValueError("Static RNN only take fix seq_len input")

        ipt = self.helper.create_variable(
            name=x.name,
            dtype=x.data_type,
            shape=[-1] + list(x.shape[2:]),
            type=x.type)
        self.inputs.append(ipt)
        return ipt

    def step_output(self, o):
        """
        Register the per-step variable `o` as an output of the RNN.

        A variable with the same name and dtype, and shape
        `[-1, seq_len] + o.shape[1:]`, is created in the parent block and
        appended to `self.outputs`.

        :param o: a variable computed inside the step block
        :type o: Variable
        :raises ValueError: if called outside the step block
        :raises TypeError: if `o` is not a `Variable`
        """
        self._assert_in_rnn_block_('step_output')
        if not isinstance(o, Variable):
            raise TypeError("step output takes a Variable")

        out_var = self.parent_block().create_var(
            name=o.name,
            shape=[-1, self.seq_len] + list(o.shape[1:]),
            dtype=o.data_type)

        self.outputs.append(out_var)

    def output(self, *outputs):
        """Register each of `outputs` through `step_output`, in order."""
        for each in outputs:
            self.step_output(each)

    def update_memory(self, mem, var):
        """
        Set `var` as the current-step value of the memory `mem`.

        :param mem: a variable previously returned by `memory`
        :type mem: Variable
        :param var: the variable holding the memory's new value
        :type var: Variable
        :raises TypeError: if either argument is not a `Variable`
        :raises KeyError: if `mem` is not a memory registered on this RNN
        """
        if not isinstance(mem, Variable) or not isinstance(var, Variable):
            raise TypeError("update memory should take variables")
        link = self.memories.get(mem.name)
        if link is None:
            # Same exception type as a plain dict lookup, but with a message
            # that says what went wrong.
            raise KeyError("{0} is not a memory created by this StaticRNN".
                           format(mem.name))
        link.mem = var

    def parent_block(self):
        """
        Return the parent of the program's current block.

        :raises ValueError: if the current block's `parent_idx` is negative
        """
        prog = self.helper.program
        parent_idx = prog.current_block().parent_idx
        # Explicit check rather than `assert`, so it still runs under -O.
        if parent_idx < 0:
            raise ValueError("current block has no parent block")
        return prog.block(parent_idx)

    def __call__(self, *args, **kwargs):
        """
        Return the RNN output(s); positional and keyword arguments are ignored.

        :return: the single output variable if exactly one was registered,
            otherwise the list of output variables
        :raises ValueError: if called before the step block has been left, or
            if no output was registered
        """
        if self.status != StaticRNN.AFTER_RNN_BLOCK:
            raise ValueError("RNN output can only be retrieved after rnn block")
        if not self.outputs:
            raise ValueError("RNN has no output")
        if len(self.outputs) == 1:
            return self.outputs[0]
        return self.outputs

    def complete_rnn_op(self):
        """Hook invoked by the step guard on exit; currently does nothing."""
        # TODO(yuyang18): Create RNN Op here.
        # Implement this method after RNN op complete.
        pass
0 commit comments