Skip to content

Commit 11854bc

Browse files
Martin Yuanfacebook-github-bot
authored andcommitted
Add test to torch.jit.export_opnames, make the _C function private
Summary: Pull Request resolved: pytorch#31446 Test Plan: Imported from OSS Differential Revision: D19172851 Pulled By: iseeyuan fbshipit-source-id: f06d8766ed73c9abe4ebf41c402ee64880d745be
1 parent 81329c9 commit 11854bc

File tree

4 files changed

+34
-4
lines changed

4 files changed

+34
-4
lines changed

docs/source/torch.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,4 +377,3 @@ Utilities
377377
.. autofunction:: result_type
378378
.. autofunction:: can_cast
379379
.. autofunction:: promote_types
380-
.. autofunction:: export_opnames

test/test_jit.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3693,6 +3693,38 @@ def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
36933693
self.assertTrue(type(block.paramNode()) == torch._C.Node)
36943694
self.assertTrue(tested_blocks)
36953695

3696+
def test_export_opnames(self):
3697+
class Foo(torch.jit.ScriptModule):
3698+
def __init__(self):
3699+
super(Foo, self).__init__()
3700+
3701+
def one(self, x, y):
3702+
# type: (Tensor, Tensor) -> Tensor
3703+
return x + y
3704+
3705+
def two(self, x):
3706+
# type: (Tensor) -> Tensor
3707+
return 2 * x
3708+
3709+
@torch.jit.script_method
3710+
def forward(self, x):
3711+
# type: (Tensor) -> Tensor
3712+
return self.one(self.two(x), x)
3713+
3714+
class Bar(torch.jit.ScriptModule):
3715+
def __init__(self):
3716+
super(Bar, self).__init__()
3717+
self.sub = Foo()
3718+
3719+
def forward(self, x):
3720+
# type: (Tensor) -> Tensor
3721+
return self.sub.forward(x)
3722+
3723+
bar = Bar()
3724+
ops = torch.jit.export_opnames(bar)
3725+
expected = ['aten::add.Tensor', 'aten::mul.Scalar', 'prim::Constant']
3726+
self.assertEqual(ops, expected)
3727+
36963728
def test_pytorch_jit_env_off(self):
36973729
import subprocess
36983730
env = os.environ.copy()
@@ -3735,7 +3767,6 @@ def forward(self):
37353767
model_loaded = torch.jit.load(buffer)
37363768
self.assertEqual(model_loaded(), model())
37373769

3738-
37393770
class TestFrontend(JitTestCase):
37403771

37413772
def test_instancing_error(self):

torch/csrc/jit/script/init.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,7 @@ void initJitScriptBindings(PyObject* module) {
11751175
return Module(get_python_cu(), type);
11761176
});
11771177

1178-
m.def("export_opnames",
1178+
m.def("_export_opnames",
11791179
[](script::Module& sm) {return debugMakeList(torch::jit::export_opnames(sm));});
11801180

11811181
py::class_<ConcreteModuleTypeBuilder, std::shared_ptr<ConcreteModuleTypeBuilder>>(

torch/jit/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def export_opnames(m):
240240
r"""
241241
Returns a list of operator names of a script module and its submodules
242242
"""
243-
return torch._C.export_opnames(m._c)
243+
return torch._C._export_opnames(m._c)
244244

245245
def _get_trace_graph(f, args=(), kwargs=None, _force_outplace=False,
246246
return_inputs=False, _return_inputs_states=False):

0 commit comments

Comments
 (0)