Skip to content

Commit 7c498a4

Browse files
committed
Update to TensorFlow 1.1, and misc bugfixes.
1 parent d27ea80 commit 7c498a4

File tree

14 files changed

+325
-49
lines changed

14 files changed

+325
-49
lines changed

WORKSPACE

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,18 @@ local_repository(
77
path = "tensorflow",
88
)
99

10+
# TensorFlow depends on "io_bazel_rules_closure" so we need this here.
11+
# Needs to be kept in sync with the same target in TensorFlow's WORKSPACE file.
12+
http_archive(
13+
name = "io_bazel_rules_closure",
14+
sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce",
15+
strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d",
16+
urls = [
17+
"http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz", # 2017-02-03
18+
"https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz",
19+
],
20+
)
21+
1022
# Import all of the tensorflow dependencies.
1123
load('//tensorflow_fold:workspace.bzl', 'tf_fold_workspace')
1224
tf_fold_workspace()

tensorflow

Submodule tensorflow updated 4089 files

tensorflow_fold/blocks/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ fold_py_library(
140140
":util",
141141
# numpy",
142142
"@org_tensorflow//tensorflow:tensorflow_py",
143+
"@org_tensorflow//tensorflow/python/debug:debug_py",
143144
],
144145
)
145146

tensorflow_fold/blocks/block_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ def init_loom(self, max_depth=None, loom_input_tensor=None, input_tensor=None,
408408
'loom_input_tensor for the same compiler')
409409
self._input_tensor = input_tensor
410410
loom_input_tensor = tf.py_func(
411-
self.build_loom_input_batched,
411+
lambda inp: np.asarray(self.build_loom_input_batched(inp), np.object),
412412
[self._input_tensor], [tf.string], name='Scheduler')
413413
passthrough_types = (self._dry.tagging_op.passthrough_types
414414
if self._dry.tagging_op else None)

tensorflow_fold/blocks/blocks.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def __init__(self, children=None, input_type=None, output_type=None,
8282
self._children = []
8383
self._constructor_name = None
8484
self._constructor_args = None
85+
self._constructor_kwargs = None
8586
if children is not None:
8687
for b in children:
8788
self._add_child(b)
@@ -90,7 +91,7 @@ def __init__(self, children=None, input_type=None, output_type=None,
9091
self._propagate_types_from_child(child)
9192

9293
def __repr__(self):
93-
strs = ['td.%s' % (self._constructor_name or type(self).__name__,)]
94+
strs = [self._constructor_name or ('td.%s' % type(self).__name__)]
9495
if self._name: strs.append('%r' % self._name)
9596
for k, v in sorted(six.iteritems(self._repr_kwargs())):
9697
if isinstance(v, functools.partial): v = v.func
@@ -117,6 +118,24 @@ def set_constructor_args(self, *constructor_args):
117118
self._constructor_args = constructor_args
118119
return self
119120

121+
def set_constructor_name_args(self, name, args, kwargs):
122+
"""Sets the constructor args used to pretty-print this layer.
123+
124+
Should be called by derived classes in __init__.
125+
126+
Args:
127+
name: the fully qualified name of the constructor
128+
args: a list of constructor arguments
129+
kwargs: a list of (key,value,default) triples for keyword arguments
130+
131+
Returns:
132+
self
133+
"""
134+
self._constructor_name = name
135+
self._constructor_args = args
136+
self._constructor_kwargs = kwargs if kwargs is not None else []
137+
return self
138+
120139
def _add_child(self, b):
121140
assert isinstance(b, Block) # internal consistency check
122141
if b.parent is not None:
@@ -282,9 +301,9 @@ def __init__(self, shape, dtype='float32', name=None):
282301

283302
def _repr_kwargs(self):
284303
kwargs = {'dtype': self.output_type.dtype}
285-
if self._constructor_name == 'Vector':
304+
if self._constructor_name == 'td.Vector':
286305
kwargs['size'] = self.output_type.shape[0]
287-
elif self._constructor_name != 'Scalar':
306+
elif self._constructor_name != 'td.Scalar':
288307
kwargs['shape'] = self.output_type.shape
289308
return kwargs
290309

@@ -299,13 +318,13 @@ def _evaluate(self, eval_ctx, x):
299318
def Scalar(dtype='float32', name=None): # pylint: disable=invalid-name
300319
"""A block that converts its input to a scalar."""
301320
return Tensor(shape=[], dtype=dtype, name=name).set_constructor_name(
302-
'Scalar')
321+
'td.Scalar')
303322

304323

305324
def Vector(size, dtype='float32', name=None): # pylint: disable=invalid-name
306325
"""A block that converts its input to a vector."""
307326
return Tensor(shape=[size], dtype=dtype, name=name).set_constructor_name(
308-
'Vector')
327+
'td.Vector')
309328

310329

311330
class FromTensor(Block):
@@ -493,7 +512,7 @@ def SerializedMessageToTree(message_type_name): # pylint: disable=invalid-name
493512
return InputTransform(functools.partial(
494513
proto_tools.serialized_message_to_tree, message_type_name),
495514
name=message_type_name).set_constructor_name(
496-
'SerializedMessageToTree')
515+
'td.SerializedMessageToTree')
497516

498517

499518
class GetItem(Block):
@@ -581,7 +600,8 @@ def Slice(*args, **kwargs): # pylint: disable=invalid-name
581600
raise TypeError('Slice does not accept positional arguments; allowed '
582601
'keyword arguments are start, stop, and step')
583602
name = kwargs.pop('name', None)
584-
return GetItem(_get_slice(**kwargs), name=name).set_constructor_name('Slice')
603+
return GetItem(_get_slice(**kwargs), name=name).set_constructor_name(
604+
'td.Slice')
585605

586606

587607
def _get_slice(start=None, stop=None, step=None):
@@ -630,7 +650,7 @@ def __init__(self, input_type, output_type, name):
630650
self._target_block = None
631651
super(_ForwardDeclarationRef, self).__init__(
632652
input_type=input_type, output_type=output_type, name=name)
633-
self.set_constructor_name('ForwardDeclaration()')
653+
self.set_constructor_name('td.ForwardDeclaration()')
634654

635655
def _update_input_type(self):
636656
if self._target_block is not None:
@@ -677,7 +697,7 @@ def __init__(self, parent, parent_name, is_input):
677697
placeholder_name = 'output'
678698
self._update_parent = parent.set_output_type
679699
super(_ComposeIO, self).__init__(name=parent_name)
680-
self.set_constructor_name('Composition.%s' % placeholder_name)
700+
self.set_constructor_name('td.Composition.%s' % placeholder_name)
681701
self._parent = parent
682702

683703
def _update_input_type(self):
@@ -989,29 +1009,29 @@ def Pipe(*blocks, **kwargs): # pylint: disable=invalid-name
9891009
```
9901010
9911011
Args:
992-
*blocks: A tuple of blocks.
993-
**kwargs: `{'name': name_string}` or `{}`.
1012+
*blocks: A tuple of blocks.
1013+
**kwargs: Optional keyword arguments. Accepts name='block_name'.
9941014
9951015
Returns:
9961016
A block.
9971017
"""
998-
return _pipe([convert_to_block(b) for b in blocks],
999-
**kwargs).set_constructor_name('Pipe')
1018+
blocks = [convert_to_block(b) for b in blocks]
1019+
blocks = [b for b in blocks if not isinstance(b, Identity)]
1020+
if not blocks: return Identity(**kwargs)
1021+
if len(blocks) == 1: return blocks[0]
1022+
return _pipe(blocks, **kwargs)
10001023

10011024

10021025
def _pipe(blocks, name=None):
10031026
"""Internal implementation of Pipe."""
1004-
if not blocks: return Identity(name=name)
1005-
if len(blocks) == 1: return blocks[0]
1006-
10071027
c = Composition(blocks, name=name)
10081028
c.connect(c.input, blocks[0])
10091029
prev = blocks[0]
10101030
for b in blocks[1:]:
10111031
c.connect(prev, b)
10121032
prev = b
10131033
c.connect(prev, c.output)
1014-
return c
1034+
return c.set_constructor_name('td.Pipe')
10151035

10161036

10171037
class Record(Block):
@@ -1135,15 +1155,18 @@ def AllOf(*blocks, **kwargs): # pylint: disable=invalid-name
11351155
Returns:
11361156
See above.
11371157
"""
1138-
return _all_of([convert_to_block(b) for b in blocks],
1139-
**kwargs).set_constructor_name('AllOf')
1158+
blocks = [convert_to_block(b) for b in blocks]
1159+
c = _all_of(blocks, **kwargs)
1160+
c.set_constructor_name('td.AllOf')
1161+
c.set_constructor_args(*blocks)
1162+
return c
11401163

11411164

11421165
def _all_of(blocks, name=None):
11431166
"""Internal implementation of AllOf."""
11441167
if not blocks: return Void(name=name)
11451168
if len(blocks) == 1:
1146-
# TODO(moshelooks): fix composition to allow for tuple output.
1169+
# TODO(delesley): fix composition to allow for tuple output.
11471170
return Pipe(blocks[0], AllOf(Identity(), Identity()), Slice(stop=1),
11481171
name=name)
11491172
c = Composition(blocks, name=name)
@@ -1291,7 +1314,6 @@ def __init__(self, rnn_cell_block, name=None):
12911314
"""
12921315
self._rnn_cell_block = convert_to_block(rnn_cell_block)
12931316
super(_RNN, self).__init__(children=[self._rnn_cell_block], name=name)
1294-
self.set_constructor_name('RNN')
12951317

12961318
def _repr_kwargs(self):
12971319
return dict(rnn_cell_block=self.rnn_cell_block)
@@ -1402,12 +1424,15 @@ def RNN(cell, initial_state=None, # pylint: disable=invalid-name
14021424
a block.
14031425
"""
14041426
cell = convert_to_block(cell)
1427+
(args, kwargs) = _get_local_args(RNN)
14051428

14061429
if initial_state_from_input:
14071430
if initial_state is not None:
14081431
raise ValueError('Cannot specify initial_state if '
14091432
'initial_state_from_input is True.')
1410-
return _RNN(cell, name=name)
1433+
rnn = _RNN(cell, name=name)
1434+
rnn.set_constructor_name_args('td.RNN', args, kwargs)
1435+
return rnn
14111436

14121437
# Otherwise create a composition to wire in initial_state.
14131438
if initial_state is None:
@@ -1421,10 +1446,11 @@ def RNN(cell, initial_state=None, # pylint: disable=invalid-name
14211446
else:
14221447
initial_state = convert_to_block(initial_state)
14231448

1424-
c = Composition(name=name).set_constructor_name('RNN')
1449+
c = Composition(name=name).set_constructor_name('td.RNN')
14251450
with c.scope():
14261451
rnn = _RNN(cell, name=name).reads(c.input, initial_state)
14271452
c.output.reads(rnn)
1453+
c.set_constructor_name_args('td.RNN', args, kwargs)
14281454
return c
14291455

14301456

@@ -1484,17 +1510,17 @@ def _evaluate(self, eval_ctx, x):
14841510

14851511
def Sum(name=None): # pylint: disable=invalid-name
14861512
"""Sums its inputs."""
1487-
return Reduce(Function(tf.add), name=name).set_constructor_name('Sum')
1513+
return Reduce(Function(tf.add), name=name).set_constructor_name('td.Sum')
14881514

14891515

14901516
def Min(name=None): # pylint: disable=invalid-name
14911517
"""Takes the minimum of its inputs. Zero on no inputs."""
1492-
return Reduce(Function(tf.minimum), name=name).set_constructor_name('Min')
1518+
return Reduce(Function(tf.minimum), name=name).set_constructor_name('td.Min')
14931519

14941520

14951521
def Max(name=None): # pylint: disable=invalid-name
14961522
"""Takes the maximum of its inputs. Zero on no inputs."""
1497-
return Reduce(Function(tf.maximum), name=name).set_constructor_name('Max')
1523+
return Reduce(Function(tf.maximum), name=name).set_constructor_name('td.Max')
14981524

14991525

15001526
def _tf_safe_reciprocal(x):
@@ -1517,7 +1543,7 @@ def Mean(name=None): # pylint: disable=invalid-name
15171543
with c.scope():
15181544
c.output.reads(Function(_tf_batch_safe_scalar_division).reads(
15191545
Sum().reads(c.input), Length().reads(c.input)))
1520-
return c.set_constructor_name('Mean')
1546+
return c.set_constructor_name('td.Mean')
15211547

15221548

15231549
class OneOf(Block):
@@ -1929,7 +1955,7 @@ def OneHotFromList(elements, dtype='float32', strict=True, name=None): # pylint
19291955
key_fn = lambda x: indices.get(x, -1)
19301956

19311957
return OneOf(key_fn, tensors, pre_block=Void(),
1932-
name=name).set_constructor_name('OneHotFromList')
1958+
name=name).set_constructor_name('td.OneHotFromList')
19331959

19341960

19351961
class Nth(Block):
@@ -1999,13 +2025,13 @@ def Zeros(output_type, name=None): # pylint: disable=invalid-name
19992025
result = _EmptySequence(input_type=tdt.VoidType(), output_type=output_type,
20002026
name=name)
20012027
result.set_constructor_args(pp_output_type)
2002-
return result.set_constructor_name('Zeros')
2028+
return result.set_constructor_name('td.Zeros')
20032029

20042030

20052031
def Void(name=None): # pylint: disable=invalid-name
20062032
"""A block with void output type that accepts any input type."""
20072033
return Composition(name=name).set_output_type(
2008-
tdt.VoidType()).set_constructor_name('Void')
2034+
tdt.VoidType()).set_constructor_name('td.Void').set_constructor_args()
20092035

20102036

20112037
def convert_to_block(block_like):
@@ -2111,3 +2137,7 @@ def _evaluate(self, unused_eval_ctx, unused_x):
21112137
def _is_layer(x):
21122138
return isinstance(
21132139
x, tensorflow_fold.blocks.layers.Layer)
2140+
2141+
2142+
_get_local_args = (
2143+
tensorflow_fold.blocks.layers.get_local_arguments)

tensorflow_fold/blocks/blocks_test.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ def assertBuilds(self, desired, block, inp, max_depth=1, feed_dict=None):
7171
self.assertSameStructure(desired_out, out[0].tolist())
7272

7373
def assertBuildsConst(self, desired, block, inp):
74-
# TODO(moshelooks): actually test constness
74+
# In this context 'const' means a block that can be evaluated
75+
# entirely in python, without running TF at all. For example,
76+
# (td.Map(td.Scalar()) >> td.Sum()) is const.
77+
#
78+
# There currently are no optimizations in Fold for const
79+
# blocks. If/when these are implemented, code should go here for
80+
# testing that 'block' is in fact const.
7581
self.assertBuilds(desired, block, inp, max_depth=None)
7682

7783
def test_scalar(self):
@@ -301,13 +307,14 @@ def test_composition_slice(self):
301307
self.assertBuilds(4, c1, None, max_depth=1)
302308

303309
def test_composition_backward_type_inference(self):
304-
b = tdb.Map(tdb.Identity()) >> tdb.Identity() >> tdb.Identity()
310+
b = tdb._pipe([tdb.Map(tdb.Identity()), tdb.Identity(), tdb.Identity()])
305311
six.assertRaisesRegex(
306312
self, TypeError, 'bad output type VoidType',
307313
b.output.set_output_type, tdt.VoidType())
308314

309315
def test_composition_forward_type_inference(self):
310-
b = tdb.Identity() >> tdb.Identity() >> tdb.Map(tdb.Function(tf.negative))
316+
b = tdb._pipe([tdb.Identity(), tdb.Identity(),
317+
tdb.Map(tdb.Function(tf.negative))])
311318
six.assertRaisesRegex(
312319
self, TypeError, 'bad input type PyObjectType',
313320
b.input.set_input_type, tdt.PyObjectType())
@@ -1090,8 +1097,9 @@ def test_repr(self):
10901097
tdb.Composition(name='x').output: '<td.Composition.output \'x\'>',
10911098
tdb.Composition(name='x'): '<td.Composition \'x\'>',
10921099

1093-
tdb.Pipe(): '<td.Pipe>',
1094-
tdb.Pipe(tdb.Scalar(), tdb.Identity()): '<td.Pipe>',
1100+
tdb.Pipe(): '<td.Identity>',
1101+
tdb.Pipe(tdb.Scalar(), tdb.Identity()): '<td.Scalar dtype=\'float32\'>',
1102+
tdb.Pipe(tdb.InputTransform(ord), tdb.Scalar('int32')): '<td.Pipe>',
10951103

10961104
tdb.Record({}, name='x'): '<td.Record \'x\' ordered=False>',
10971105
tdb.Record((), name='x'): '<td.Record \'x\' ordered=True>',

0 commit comments

Comments
 (0)