|
14 | 14 | # limitations under the License. |
15 | 15 | # ============================================================================== |
16 | 16 | """Visitor restricting traversal to only the public tensorflow API.""" |
17 | | -import inspect |
18 | 17 |
|
| 18 | +import ast |
| 19 | +import inspect |
| 20 | +import typing |
19 | 21 |
|
20 | 22 | from tensorflow_docs.api_generator import doc_controls |
21 | 23 |
|
22 | | -import typing |
23 | 24 | _TYPING = frozenset(id(value) for value in typing.__dict__.values()) |
24 | 25 |
|
25 | 26 |
|
@@ -121,6 +122,77 @@ def util_2 |
121 | 122 | return filtered_children |
122 | 123 |
|
123 | 124 |
|
| 125 | +def _get_imported_symbols(obj): |
| 126 | + """Returns a list of symbol names imported by the given `obj`.""" |
| 127 | + |
| 128 | + class ImportNodeVisitor(ast.NodeVisitor): |
| 129 | + """An `ast.Visitor` that collects the names of imported symbols.""" |
| 130 | + |
| 131 | + def __init__(self): |
| 132 | + self.imported_symbols = [] |
| 133 | + |
| 134 | + def _add_imported_symbol(self, node): |
| 135 | + self.imported_symbols.extend([alias.name for alias in node.names]) |
| 136 | + |
| 137 | + def visit_Import(self, node): # pylint: disable=invalid-name |
| 138 | + self._add_imported_symbol(node) |
| 139 | + |
| 140 | + def visit_ImportFrom(self, node): # pylint: disable=invalid-name |
| 141 | + self._add_imported_symbol(node) |
| 142 | + |
| 143 | + source = inspect.getsource(obj) |
| 144 | + tree = ast.parse(source) |
| 145 | + visitor = ImportNodeVisitor() |
| 146 | + visitor.visit(tree) |
| 147 | + return visitor.imported_symbols |
| 148 | + |
| 149 | + |
| 150 | +def explicit_package_contents_filter(path, parent, children): |
| 151 | + """Filter modules to only include explicit contents. |
| 152 | +
|
| 153 | + This function returns the children explicitly included by this module, meaning |
| 154 | + that it will exclude: |
| 155 | +
|
| 156 | + * Modules in a package not explicitly imported by the package (submodules |
| 157 | + are implicitly injected into their parent's namespace). |
| 158 | + * Modules imported by a module that is not a package. |
| 159 | +
|
| 160 | + This filter is useful if you explicitly define your API in the packages of |
| 161 | + your library, but do not expliticly define that API in the `__all__` variable |
| 162 | + of each module. The purpose is to make it easier to maintain that API. |
| 163 | +
|
| 164 | + Note: This filter does work with wildcard imports, however it is generally not |
| 165 | + recommended to use wildcard imports. |
| 166 | +
|
| 167 | + Args: |
| 168 | + path: A tuple of names forming the path to the object. |
| 169 | + parent: The parent object. |
| 170 | + children: A list of (name, value) tuples describing the attributes of the |
| 171 | + patent. |
| 172 | +
|
| 173 | + Returns: |
| 174 | + A filtered list of children `(name, value)` pairs. |
| 175 | + """ |
| 176 | + del path # Unused |
| 177 | + is_parent_module = inspect.ismodule(parent) |
| 178 | + is_parent_package = is_parent_module and hasattr(parent, '__path__') |
| 179 | + if is_parent_package: |
| 180 | + imported_symbols = _get_imported_symbols(parent) |
| 181 | + filtered_children = [] |
| 182 | + for child in children: |
| 183 | + name, obj = child |
| 184 | + if inspect.ismodule(obj): |
| 185 | + # Do not include modules in a package not explicitly imported by the |
| 186 | + # package. |
| 187 | + if is_parent_package and name not in imported_symbols: |
| 188 | + continue |
| 189 | + # Do not include modules imported by a module that is not a package. |
| 190 | + if is_parent_module and not is_parent_package: |
| 191 | + continue |
| 192 | + filtered_children.append(child) |
| 193 | + return filtered_children |
| 194 | + |
| 195 | + |
124 | 196 | class PublicAPIFilter(object): |
125 | 197 | """Visitor to use with `traverse` to filter just the public API.""" |
126 | 198 |
|
|
0 commit comments