File tree Expand file tree Collapse file tree 1 file changed +24
-11
lines changed Expand file tree Collapse file tree 1 file changed +24
-11
lines changed Original file line number Diff line number Diff line change 22import torch
33from torch import * # noqa: F401, F403
44
5- for n in dir (torch ):
5+ from .._internal import _get_all_public_members
6+
7+
8+ def filter_ (name ):
69 if (
7- n .startswith ("_" )
8- or n .endswith ("_" )
9- or "cuda" in n
10- or "cpu" in n
11- or "backward" in n
10+ name .startswith ("_" )
11+ or name .endswith ("_" )
12+ or "cuda" in name
13+ or "cpu" in name
14+ or "backward" in name
1215 ):
13- continue
14- exec (n + " = torch." + n )
16+ return False
17+ return True
18+
19+
20+ _torch_all = _get_all_public_members (torch , filter_ = filter_ )
1521
22+ for _name in _torch_all :
23+ globals ()[_name ] = getattr (torch , _name )
1624
17- from ..common ._helpers import (
25+
26+ from ..common ._helpers import ( # noqa: E402
1827 array_namespace ,
1928 device ,
2029 get_namespace ,
2433)
2534
2635# These imports may overwrite names from the import * above.
27- from ._aliases import (
36+ from ._aliases import ( # noqa: E402
2837 add ,
2938 all ,
3039 any ,
92101 zeros ,
93102)
94103
95- __all__ = [
104+ __all__ = []
105+
106+ __all__ += _torch_all
107+
108+ __all__ += [
96109 "is_array_api_obj" ,
97110 "array_namespace" ,
98111 "get_namespace" ,
You can’t perform that action at this time.
0 commit comments