Skip to content

Commit 02a9aaf

Browse files
michaelreneercopybara-github
authored andcommitted
Added a new tensorflow_docs public filter and removed the _allowed_symbols variables from the TFF packages.
PiperOrigin-RevId: 306446724
1 parent 4755927 commit 02a9aaf

File tree

2 files changed

+115
-2
lines changed

2 files changed

+115
-2
lines changed

tools/tensorflow_docs/api_generator/public_api.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616
"""Visitor restricting traversal to only the public tensorflow API."""
17-
import inspect
1817

18+
import ast
19+
import inspect
20+
import typing
1921

2022
from tensorflow_docs.api_generator import doc_controls
2123

22-
import typing
2324
_TYPING = frozenset(id(value) for value in typing.__dict__.values())
2425

2526

@@ -121,6 +122,77 @@ def util_2
121122
return filtered_children
122123

123124

125+
def _get_imported_symbols(obj):
126+
"""Returns a list of symbol names imported by the given `obj`."""
127+
128+
class ImportNodeVisitor(ast.NodeVisitor):
129+
"""An `ast.Visitor` that collects the names of imported symbols."""
130+
131+
def __init__(self):
132+
self.imported_symbols = []
133+
134+
def _add_imported_symbol(self, node):
135+
self.imported_symbols.extend([alias.name for alias in node.names])
136+
137+
def visit_Import(self, node): # pylint: disable=invalid-name
138+
self._add_imported_symbol(node)
139+
140+
def visit_ImportFrom(self, node): # pylint: disable=invalid-name
141+
self._add_imported_symbol(node)
142+
143+
source = inspect.getsource(obj)
144+
tree = ast.parse(source)
145+
visitor = ImportNodeVisitor()
146+
visitor.visit(tree)
147+
return visitor.imported_symbols
148+
149+
150+
def explicit_package_contents_filter(path, parent, children):
151+
"""Filter modules to only include explicit contents.
152+
153+
This function returns the children explicitly included by this module, meaning
154+
that it will exclude:
155+
156+
* Modules in a package not explicitly imported by the package (submodules
157+
are implicitly injected into their parent's namespace).
158+
* Modules imported by a module that is not a package.
159+
160+
This filter is useful if you explicitly define your API in the packages of
161+
your library, but do not expliticly define that API in the `__all__` variable
162+
of each module. The purpose is to make it easier to maintain that API.
163+
164+
Note: This filter does work with wildcard imports, however it is generally not
165+
recommended to use wildcard imports.
166+
167+
Args:
168+
path: A tuple of names forming the path to the object.
169+
parent: The parent object.
170+
children: A list of (name, value) tuples describing the attributes of the
171+
patent.
172+
173+
Returns:
174+
A filtered list of children `(name, value)` pairs.
175+
"""
176+
del path # Unused
177+
is_parent_module = inspect.ismodule(parent)
178+
is_parent_package = is_parent_module and hasattr(parent, '__path__')
179+
if is_parent_package:
180+
imported_symbols = _get_imported_symbols(parent)
181+
filtered_children = []
182+
for child in children:
183+
name, obj = child
184+
if inspect.ismodule(obj):
185+
# Do not include modules in a package not explicitly imported by the
186+
# package.
187+
if is_parent_package and name not in imported_symbols:
188+
continue
189+
# Do not include modules imported by a module that is not a package.
190+
if is_parent_module and not is_parent_package:
191+
continue
192+
filtered_children.append(child)
193+
return filtered_children
194+
195+
124196
class PublicAPIFilter(object):
125197
"""Visitor to use with `traverse` to filter just the public API."""
126198

tools/tensorflow_docs/api_generator/public_api_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import typing
2222

2323
from absl.testing import absltest
24+
# This import is using to test
25+
from tensorflow_docs import api_generator
2426
from tensorflow_docs.api_generator import public_api
2527

2628

@@ -150,6 +152,45 @@ def public_members(obj):
150152

151153
self.assertCountEqual([], filtered_names)
152154

155+
def test_explicit_package_contents_filter_removes_modules_not_explicitly_imported(
156+
self):
157+
path = ('tensorflow_docs', 'api_generator')
158+
parent = api_generator
159+
members = inspect.getmembers(parent)
160+
members.append(('inspect', inspect))
161+
162+
# Assert that parent is a module and is a package, and that the members of
163+
# parent include a module named `inspect`.
164+
self.assertTrue(inspect.ismodule(parent))
165+
self.assertTrue(hasattr(parent, '__path__'))
166+
self.assertIn('inspect', [name for name, _ in members])
167+
self.assertTrue(inspect.ismodule(inspect))
168+
169+
filtered_members = public_api.explicit_package_contents_filter(
170+
path, parent, members)
171+
172+
# Assert that the filtered_members do not include a module named `inspect`.
173+
self.assertNotIn('inspect', [name for name, _ in filtered_members])
174+
175+
def test_explicit_package_contents_filter_removes_modules_imported_by_modules(
176+
self):
177+
path = ('tensorflow_docs', 'api_generator', 'public_api')
178+
parent = public_api
179+
members = inspect.getmembers(parent)
180+
181+
# Assert that parent is a module and not a package, and that the members of
182+
# parent include a module named `inspect`.
183+
self.assertTrue(inspect.ismodule(parent))
184+
self.assertFalse(hasattr(parent, '__path__'))
185+
self.assertIn('inspect', [name for name, _ in members])
186+
self.assertTrue(inspect.ismodule(inspect))
187+
188+
filtered_members = public_api.explicit_package_contents_filter(
189+
path, parent, members)
190+
191+
# Assert that the filtered_members do not include a module named `inspect`.
192+
self.assertNotIn('inspect', [name for name, _ in filtered_members])
193+
153194
def test_ignore_typing(self):
154195
children_before = [('a', 1), ('b', 3), ('c', typing.List)]
155196
children_after = public_api.ignore_typing('ignored', 'ignored',

0 commit comments

Comments
 (0)