
Commit b262375

Fixed two bugs when importing MetaGraphDefs that contain ResourceVariables.
1) In the ResourceVariable implementation, pass import_scope when creating the SaveSliceInfo. This is present in the implementation of plain variables, and was likely a copy-and-paste omission.

2) When importing a MetaGraphDef, restoring the GLOBAL_VARIABLES and TRAINABLE_VARIABLES collections will add ops to the graph for ResourceVariables. Made graph construction deterministic by fixing the order in which collections are restored.

PiperOrigin-RevId: 177144138
Parent: 64e1459
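To make bug 2) concrete, here is a minimal sketch of the underlying behavior, assuming TF 1.x graph mode (illustrative, not part of the commit): simply reading a ResourceVariable adds ops to the graph, so the order in which variable collections are restored changes which graph gets built.

import tensorflow as tf

with tf.Graph().as_default() as g:
  v = tf.get_variable("v", shape=[], use_resource=True)
  num_ops_before = len(g.get_operations())
  _ = v.read_value()  # each read of a ResourceVariable creates new ops
  assert len(g.get_operations()) > num_ops_before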

File tree: 3 files changed, +45 −8 lines

tensorflow/python/framework/meta_graph.py

Lines changed: 1 addition & 1 deletion
@@ -663,7 +663,7 @@ def import_scoped_meta_graph(meta_graph_or_file,
       [part for part in [graph.get_name_scope(), import_scope] if part])
 
   # Restores all the other collections.
-  for key, col_def in meta_graph_def.collection_def.items():
+  for key, col_def in sorted(meta_graph_def.collection_def.items()):
     # Don't add unbound_inputs to the new graph.
     if key == unbound_inputs_col_name:
       continue
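For context, a standalone sketch of the ordering issue this hunk addresses (illustrative, not from the commit): collection_def is a protobuf map field, and protobuf makes no guarantee about map iteration order, so sorting the items is what pins down the order in which collections are restored.

from tensorflow.core.protobuf import meta_graph_pb2

meta_graph_def = meta_graph_pb2.MetaGraphDef()
for key in ("variables", "trainable_variables", "z", "a"):
  meta_graph_def.collection_def[key].node_list.value.append("v:0")

# Iterating the map directly may visit keys in any order; sorted() fixes
# the order, and with it the order of any ops created during restore.
ordered = [key for key, _ in sorted(meta_graph_def.collection_def.items())]
assert ordered == ["a", "trainable_variables", "variables", "z"]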

tensorflow/python/framework/meta_graph_test.py

Lines changed: 42 additions & 6 deletions
@@ -662,22 +662,36 @@ def _enqueue_vector(sess, queue, values, shape=None):
 class ExportImportAcrossScopesTest(test.TestCase):
 
   def testPartionedVariables(self):
-    def make_graph_with_partitioned_variables():
+
+    def make_graph_with_partitioned_variables(use_resource):
       variable_scope.get_variable(
           name="weights",
           partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0),
-          initializer=random_ops.truncated_normal([100, 10]))
-    self._testExportImportAcrossScopes(make_graph_with_partitioned_variables)
+          initializer=random_ops.truncated_normal([100, 10]),
+          use_resource=use_resource)
+      # The next variable illustrates the necessity of restoring collections
+      # in a deterministic fashion when using ResourceVariables.
+      variable_scope.get_variable(
+          name="another",
+          shape=[],
+          collections=["a", "b", "z", "f", "e", "d", "g"],
+          use_resource=use_resource)
+
+    self._testExportImportAcrossScopes(
+        make_graph_with_partitioned_variables, use_resource=False)
+    self._testExportImportAcrossScopes(
+        make_graph_with_partitioned_variables, use_resource=True)
 
-  def _testExportImportAcrossScopes(self, graph_fn):
+  def _testExportImportAcrossScopes(self, graph_fn, use_resource):
     """Tests export and importing a graph across scopes.
 
     Args:
       graph_fn: A closure that creates a graph on the current scope.
+      use_resource: A bool indicating whether or not to use ResourceVariables.
     """
     with ops.Graph().as_default() as original_graph:
       with variable_scope.variable_scope("dropA/dropB/keepA"):
-        graph_fn()
+        graph_fn(use_resource=use_resource)
       exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
           graph=original_graph,
           export_scope="dropA/dropB")[0]
@@ -689,10 +703,32 @@ def _testExportImportAcrossScopes(self, graph_fn):
 
     with ops.Graph().as_default() as expected_graph:
       with variable_scope.variable_scope("importA/keepA"):
-        graph_fn()
+        graph_fn(use_resource=use_resource)
+
+      if use_resource:
+        # Bringing in a collection that contains ResourceVariables adds ops
+        # to the graph, so mimic the same behavior.
+        for collection_key in sorted([
+            ops.GraphKeys.GLOBAL_VARIABLES,
+            ops.GraphKeys.TRAINABLE_VARIABLES,
+        ]):
+          for var in expected_graph.get_collection(collection_key):
+            var._read_variable_op()
 
     result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
     expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]
+
+    if use_resource:
+      # Clear all shared_name attributes before comparing, since they are
+      # supposed to be orthogonal to scopes.
+      for meta_graph_def in [result, expected]:
+        for node in meta_graph_def.graph_def.node:
+          shared_name_attr = "shared_name"
+          shared_name_value = node.attr.get(shared_name_attr, None)
+          if shared_name_value and shared_name_value.HasField("s"):
+            if shared_name_value.s:
+              node.attr[shared_name_attr].s = b""
+
     self.assertProtoEquals(expected, result)
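For orientation, the export/import pattern this test exercises can be condensed as follows; this is a hedged sketch with illustrative scope names ("outer", "reimported"), assuming TF 1.x graph mode, not code from the commit.

import tensorflow as tf
from tensorflow.python.framework import meta_graph
from tensorflow.python.ops import variable_scope

with tf.Graph().as_default() as g:
  with variable_scope.variable_scope("outer/inner"):
    variable_scope.get_variable("w", shape=[2], use_resource=True)
  # Export everything under "outer"; that prefix is stripped from names.
  exported = meta_graph.export_scoped_meta_graph(graph=g,
                                                 export_scope="outer")[0]

with tf.Graph().as_default():
  # Re-import under a fresh prefix; collection entries are re-scoped too.
  meta_graph.import_scoped_meta_graph(exported, import_scope="reimported")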
tensorflow/python/ops/resource_variable_ops.py

Lines changed: 2 additions & 1 deletion
@@ -513,7 +513,8 @@ def _init_from_proto(self, variable_def, import_scope=None):
     self._cached_value = None
     if variable_def.HasField("save_slice_info_def"):
       self._save_slice_info = variables.Variable.SaveSliceInfo(
-          save_slice_info_def=variable_def.save_slice_info_def)
+          save_slice_info_def=variable_def.save_slice_info_def,
+          import_scope=import_scope)
     else:
       self._save_slice_info = None
     self._caching_device = None
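To see what the added import_scope does, a small sketch; the names here are hypothetical, and it assumes SaveSliceInfo re-scopes the full_name recorded in the proto the same way plain variables do, via ops.prepend_name_scope.

from tensorflow.python.framework import ops

full_name_in_proto = "keepA/weights"  # partitioned variable name as serialized
import_scope = "importA"              # prefix applied at import time

# Without the fix, the imported variable was renamed to
# "importA/keepA/weights" while its SaveSliceInfo still said
# "keepA/weights", so the two disagreed after import.
rescoped = ops.prepend_name_scope(full_name_in_proto, import_scope)
assert rescoped == "importA/keepA/weights"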
