@@ -82,6 +82,7 @@ def __init__(self, children=None, input_type=None, output_type=None,
8282 self ._children = []
8383 self ._constructor_name = None
8484 self ._constructor_args = None
85+ self ._constructor_kwargs = None
8586 if children is not None :
8687 for b in children :
8788 self ._add_child (b )
@@ -90,7 +91,7 @@ def __init__(self, children=None, input_type=None, output_type=None,
9091 self ._propagate_types_from_child (child )
9192
9293 def __repr__ (self ):
93- strs = ['td.%s' % ( self . _constructor_name or type (self ).__name__ , )]
94+ strs = [self . _constructor_name or ( 'td.%s' % type (self ).__name__ )]
9495 if self ._name : strs .append ('%r' % self ._name )
9596 for k , v in sorted (six .iteritems (self ._repr_kwargs ())):
9697 if isinstance (v , functools .partial ): v = v .func
@@ -117,6 +118,24 @@ def set_constructor_args(self, *constructor_args):
117118 self ._constructor_args = constructor_args
118119 return self
119120
121+ def set_constructor_name_args (self , name , args , kwargs ):
122+ """Sets the constructor args used to pretty-print this layer.
123+
124+ Should be called by derived classes in __init__.
125+
126+ Args:
127+ name: the fully qualified name of the constructor
128+ args: a list of constructor arguments
129+ kwargs: a list of (key,value,default) triples for keyword arguments
130+
131+ Returns:
132+ self
133+ """
134+ self ._constructor_name = name
135+ self ._constructor_args = args
136+ self ._constructor_kwargs = kwargs if kwargs is not None else []
137+ return self
138+
120139 def _add_child (self , b ):
121140 assert isinstance (b , Block ) # internal consistency check
122141 if b .parent is not None :
@@ -282,9 +301,9 @@ def __init__(self, shape, dtype='float32', name=None):
282301
283302 def _repr_kwargs (self ):
284303 kwargs = {'dtype' : self .output_type .dtype }
285- if self ._constructor_name == 'Vector' :
304+ if self ._constructor_name == 'td. Vector' :
286305 kwargs ['size' ] = self .output_type .shape [0 ]
287- elif self ._constructor_name != 'Scalar' :
306+ elif self ._constructor_name != 'td. Scalar' :
288307 kwargs ['shape' ] = self .output_type .shape
289308 return kwargs
290309
@@ -299,13 +318,13 @@ def _evaluate(self, eval_ctx, x):
299318def Scalar (dtype = 'float32' , name = None ): # pylint: disable=invalid-name
300319 """A block that converts its input to a scalar."""
301320 return Tensor (shape = [], dtype = dtype , name = name ).set_constructor_name (
302- 'Scalar' )
321+ 'td. Scalar' )
303322
304323
305324def Vector (size , dtype = 'float32' , name = None ): # pylint: disable=invalid-name
306325 """A block that converts its input to a vector."""
307326 return Tensor (shape = [size ], dtype = dtype , name = name ).set_constructor_name (
308- 'Vector' )
327+ 'td. Vector' )
309328
310329
311330class FromTensor (Block ):
@@ -493,7 +512,7 @@ def SerializedMessageToTree(message_type_name): # pylint: disable=invalid-name
493512 return InputTransform (functools .partial (
494513 proto_tools .serialized_message_to_tree , message_type_name ),
495514 name = message_type_name ).set_constructor_name (
496- 'SerializedMessageToTree' )
515+ 'td. SerializedMessageToTree' )
497516
498517
499518class GetItem (Block ):
@@ -581,7 +600,8 @@ def Slice(*args, **kwargs): # pylint: disable=invalid-name
581600 raise TypeError ('Slice does not accept positional arguments; allowed '
582601 'keyword arguments are start, stop, and step' )
583602 name = kwargs .pop ('name' , None )
584- return GetItem (_get_slice (** kwargs ), name = name ).set_constructor_name ('Slice' )
603+ return GetItem (_get_slice (** kwargs ), name = name ).set_constructor_name (
604+ 'td.Slice' )
585605
586606
587607def _get_slice (start = None , stop = None , step = None ):
@@ -630,7 +650,7 @@ def __init__(self, input_type, output_type, name):
630650 self ._target_block = None
631651 super (_ForwardDeclarationRef , self ).__init__ (
632652 input_type = input_type , output_type = output_type , name = name )
633- self .set_constructor_name ('ForwardDeclaration()' )
653+ self .set_constructor_name ('td. ForwardDeclaration()' )
634654
635655 def _update_input_type (self ):
636656 if self ._target_block is not None :
@@ -677,7 +697,7 @@ def __init__(self, parent, parent_name, is_input):
677697 placeholder_name = 'output'
678698 self ._update_parent = parent .set_output_type
679699 super (_ComposeIO , self ).__init__ (name = parent_name )
680- self .set_constructor_name ('Composition.%s' % placeholder_name )
700+ self .set_constructor_name ('td. Composition.%s' % placeholder_name )
681701 self ._parent = parent
682702
683703 def _update_input_type (self ):
@@ -989,29 +1009,29 @@ def Pipe(*blocks, **kwargs): # pylint: disable=invalid-name
9891009 ```
9901010
9911011 Args:
992- *blocks: A tuple of blocks.
993- **kwargs: `{'name': name_string}` or `{}` .
1012+ *blocks: A tuple of blocks.
1013+ **kwargs: Optional keyword arguments. Accepts name='block_name' .
9941014
9951015 Returns:
9961016 A block.
9971017 """
998- return _pipe ([convert_to_block (b ) for b in blocks ],
999- ** kwargs ).set_constructor_name ('Pipe' )
1018+ blocks = [convert_to_block (b ) for b in blocks ]
1019+ blocks = [b for b in blocks if not isinstance (b , Identity )]
1020+ if not blocks : return Identity (** kwargs )
1021+ if len (blocks ) == 1 : return blocks [0 ]
1022+ return _pipe (blocks , ** kwargs )
10001023
10011024
10021025def _pipe (blocks , name = None ):
10031026 """Internal implementation of Pipe."""
1004- if not blocks : return Identity (name = name )
1005- if len (blocks ) == 1 : return blocks [0 ]
1006-
10071027 c = Composition (blocks , name = name )
10081028 c .connect (c .input , blocks [0 ])
10091029 prev = blocks [0 ]
10101030 for b in blocks [1 :]:
10111031 c .connect (prev , b )
10121032 prev = b
10131033 c .connect (prev , c .output )
1014- return c
1034+ return c . set_constructor_name ( 'td.Pipe' )
10151035
10161036
10171037class Record (Block ):
@@ -1135,15 +1155,18 @@ def AllOf(*blocks, **kwargs): # pylint: disable=invalid-name
11351155 Returns:
11361156 See above.
11371157 """
1138- return _all_of ([convert_to_block (b ) for b in blocks ],
1139- ** kwargs ).set_constructor_name ('AllOf' )
1158+ blocks = [convert_to_block (b ) for b in blocks ]
1159+ c = _all_of (blocks , ** kwargs )
1160+ c .set_constructor_name ('td.AllOf' )
1161+ c .set_constructor_args (* blocks )
1162+ return c
11401163
11411164
11421165def _all_of (blocks , name = None ):
11431166 """Internal implementation of AllOf."""
11441167 if not blocks : return Void (name = name )
11451168 if len (blocks ) == 1 :
1146- # TODO(moshelooks ): fix composition to allow for tuple output.
1169+ # TODO(delesley ): fix composition to allow for tuple output.
11471170 return Pipe (blocks [0 ], AllOf (Identity (), Identity ()), Slice (stop = 1 ),
11481171 name = name )
11491172 c = Composition (blocks , name = name )
@@ -1291,7 +1314,6 @@ def __init__(self, rnn_cell_block, name=None):
12911314 """
12921315 self ._rnn_cell_block = convert_to_block (rnn_cell_block )
12931316 super (_RNN , self ).__init__ (children = [self ._rnn_cell_block ], name = name )
1294- self .set_constructor_name ('RNN' )
12951317
12961318 def _repr_kwargs (self ):
12971319 return dict (rnn_cell_block = self .rnn_cell_block )
@@ -1402,12 +1424,15 @@ def RNN(cell, initial_state=None, # pylint: disable=invalid-name
14021424 a block.
14031425 """
14041426 cell = convert_to_block (cell )
1427+ (args , kwargs ) = _get_local_args (RNN )
14051428
14061429 if initial_state_from_input :
14071430 if initial_state is not None :
14081431 raise ValueError ('Cannot specify initial_state if '
14091432 'initial_state_from_input is True.' )
1410- return _RNN (cell , name = name )
1433+ rnn = _RNN (cell , name = name )
1434+ rnn .set_constructor_name_args ('td.RNN' , args , kwargs )
1435+ return rnn
14111436
14121437 # Otherwise create a composition to wire in initial_state.
14131438 if initial_state is None :
@@ -1421,10 +1446,11 @@ def RNN(cell, initial_state=None, # pylint: disable=invalid-name
14211446 else :
14221447 initial_state = convert_to_block (initial_state )
14231448
1424- c = Composition (name = name ).set_constructor_name ('RNN' )
1449+ c = Composition (name = name ).set_constructor_name ('td. RNN' )
14251450 with c .scope ():
14261451 rnn = _RNN (cell , name = name ).reads (c .input , initial_state )
14271452 c .output .reads (rnn )
1453+ c .set_constructor_name_args ('td.RNN' , args , kwargs )
14281454 return c
14291455
14301456
@@ -1484,17 +1510,17 @@ def _evaluate(self, eval_ctx, x):
14841510
14851511def Sum (name = None ): # pylint: disable=invalid-name
14861512 """Sums its inputs."""
1487- return Reduce (Function (tf .add ), name = name ).set_constructor_name ('Sum' )
1513+ return Reduce (Function (tf .add ), name = name ).set_constructor_name ('td. Sum' )
14881514
14891515
14901516def Min (name = None ): # pylint: disable=invalid-name
14911517 """Takes the minimum of its inputs. Zero on no inputs."""
1492- return Reduce (Function (tf .minimum ), name = name ).set_constructor_name ('Min' )
1518+ return Reduce (Function (tf .minimum ), name = name ).set_constructor_name ('td. Min' )
14931519
14941520
14951521def Max (name = None ): # pylint: disable=invalid-name
14961522 """Takes the maximum of its inputs. Zero on no inputs."""
1497- return Reduce (Function (tf .maximum ), name = name ).set_constructor_name ('Max' )
1523+ return Reduce (Function (tf .maximum ), name = name ).set_constructor_name ('td. Max' )
14981524
14991525
15001526def _tf_safe_reciprocal (x ):
@@ -1517,7 +1543,7 @@ def Mean(name=None): # pylint: disable=invalid-name
15171543 with c .scope ():
15181544 c .output .reads (Function (_tf_batch_safe_scalar_division ).reads (
15191545 Sum ().reads (c .input ), Length ().reads (c .input )))
1520- return c .set_constructor_name ('Mean' )
1546+ return c .set_constructor_name ('td. Mean' )
15211547
15221548
15231549class OneOf (Block ):
@@ -1929,7 +1955,7 @@ def OneHotFromList(elements, dtype='float32', strict=True, name=None): # pylint
19291955 key_fn = lambda x : indices .get (x , - 1 )
19301956
19311957 return OneOf (key_fn , tensors , pre_block = Void (),
1932- name = name ).set_constructor_name ('OneHotFromList' )
1958+ name = name ).set_constructor_name ('td. OneHotFromList' )
19331959
19341960
19351961class Nth (Block ):
@@ -1999,13 +2025,13 @@ def Zeros(output_type, name=None): # pylint: disable=invalid-name
19992025 result = _EmptySequence (input_type = tdt .VoidType (), output_type = output_type ,
20002026 name = name )
20012027 result .set_constructor_args (pp_output_type )
2002- return result .set_constructor_name ('Zeros' )
2028+ return result .set_constructor_name ('td. Zeros' )
20032029
20042030
20052031def Void (name = None ): # pylint: disable=invalid-name
20062032 """A block with void output type that accepts any input type."""
20072033 return Composition (name = name ).set_output_type (
2008- tdt .VoidType ()).set_constructor_name ('Void' )
2034+ tdt .VoidType ()).set_constructor_name ('td. Void' ). set_constructor_args ( )
20092035
20102036
20112037def convert_to_block (block_like ):
@@ -2111,3 +2137,7 @@ def _evaluate(self, unused_eval_ctx, unused_x):
21112137def _is_layer (x ):
21122138 return isinstance (
21132139 x , tensorflow_fold .blocks .layers .Layer )
2140+
2141+
2142+ _get_local_args = (
2143+ tensorflow_fold .blocks .layers .get_local_arguments )
0 commit comments