Commit 1491a61

Revert "[hop] ban creating hop by directly instantiating HigherOrderOperator. (pytorch#133645)"
This reverts commit 696107e. Reverted pytorch#133645 on behalf of https://github.com/ydwu4 due to breaking CI, probably due to a land race ([comment](pytorch#133645 (comment)))
1 parent 5fcfcce commit 1491a61
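For context, this revert restores the ability to construct a HigherOrderOperator (HOP) directly. A minimal sketch of the two construction patterns, using hypothetical operator names:

    from torch._ops import HigherOrderOperator

    # Direct instantiation, allowed again after this revert:
    my_op = HigherOrderOperator("my_op")  # "my_op" is a hypothetical name

    # Subclassing, which the reverted pytorch#133645 required, still works:
    class MyOp(HigherOrderOperator):
        def __init__(self):
            super().__init__("my_subclassed_op")  # hypothetical name

    my_subclassed_op = MyOp()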

11 files changed: +11 -68 lines changed


test/dynamo/test_higher_order_ops.py

Lines changed: 1 addition & 5 deletions
@@ -6299,11 +6299,7 @@ def fn(x):
         self._validate(fn, backend, x)
 
     def test_override_fallthrough_dispatch_key(self):
-        class _FallthroughTestOnly(torch._ops.HigherOrderOperator):
-            def __init__(self):
-                super().__init__("_fallthrough_test_only")
-
-        test_op = _FallthroughTestOnly()
+        test_op = torch._ops.HigherOrderOperator("_fallthrough_test_only")
         default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
         self.assertTrue(
             not any(test_op.non_fallthrough_keys.has(key) for key in default_keys)
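The assertion above verifies that a freshly constructed HOP treats the default dispatch keys as fallthroughs. A hedged sketch of the same check outside the test harness, with a hypothetical operator name:

    import torch

    op = torch._ops.HigherOrderOperator("_fallthrough_sketch")  # hypothetical name
    default_keys = torch._ops._HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS
    # None of the default fallthrough keys should appear in the op's
    # non-fallthrough key set.
    assert not any(op.non_fallthrough_keys.has(k) for k in default_keys)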

test/functorch/test_eager_transforms.py

Lines changed: 1 addition & 5 deletions
@@ -4989,11 +4989,7 @@ def forward(self, x_1):
 
 
 def construct_sum_pyop():
-    class MySum(HigherOrderOperator):
-        def __init__(self):
-            super().__init__("mysum")
-
-    mysum = MySum()
+    mysum = HigherOrderOperator("mysum")
 
     @mysum.py_impl(torch._C._functorch.TransformType.Vmap)
     def mysum_batch_rule(interpreter, x, dim):
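The mysum operator above registers transform-specific rules with py_impl. A hedged sketch of the overall shape, pairing a dispatch-key implementation with a freshly built operator (the decorator usage mirrors this diff; the operator name and body are hypothetical):

    import torch
    from torch._C import DispatchKey
    from torch._ops import HigherOrderOperator

    mysum_sketch = HigherOrderOperator("mysum_sketch")  # hypothetical name

    @mysum_sketch.py_impl(DispatchKey.CompositeExplicitAutograd)
    def mysum_sketch_dense(x, dim):
        # Plain eager implementation used when no functorch transform
        # intercepts the call.
        return torch.sum(x, [dim])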

torch/_dynamo/_trace_wrapped_higher_order_op.py

Lines changed: 1 addition & 6 deletions
@@ -48,13 +48,8 @@ def trace_wrapped(*args, **kwargs):
     return _trace_wrapped_op(*args, **kwargs)
 
 
-class TraceWrapped(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("trace_wrapped")
-
-
 # TODO(jansel): need to ensure this does not get DCEed
-_trace_wrapped_op = TraceWrapped()
+_trace_wrapped_op = HigherOrderOperator("trace_wrapped")
 
 
 def _assert_meta(grad, size, stride, dtype):

torch/_export/wrappers.py

Lines changed: 1 addition & 6 deletions
@@ -12,12 +12,7 @@
 from torch.utils import _pytree as pytree
 
 
-class ExportTracepoint(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("export_tracepoint")
-
-
-_export_tracepoint = ExportTracepoint()
+_export_tracepoint = HigherOrderOperator("_export_tracepoint")
 
 
 @_export_tracepoint.py_impl(ProxyTorchDispatchMode)

torch/_higher_order_ops/executorch_call_delegate.py

Lines changed: 1 addition & 6 deletions
@@ -25,12 +25,7 @@
 from torch.utils._pytree import tree_flatten
 
 
-class ExecutorchCallDelegate(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("executorch_call_delegate")
-
-
-executorch_call_delegate = ExecutorchCallDelegate()
+executorch_call_delegate = HigherOrderOperator("executorch_call_delegate")
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonDispatcher)
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.PythonTLSSnapshot)
 executorch_call_delegate.fallthrough(torch._C.DispatchKey.ADInplaceOrView)
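The fallthrough registrations above tell the dispatcher to skip this operator for the listed keys. A minimal sketch of the same pattern on a hypothetical operator:

    import torch
    from torch._ops import HigherOrderOperator

    delegate_sketch = HigherOrderOperator("delegate_sketch")  # hypothetical name
    # Dispatch for this key falls through to the next key in the set
    # instead of looking for a Python implementation here.
    delegate_sketch.fallthrough(torch._C.DispatchKey.ADInplaceOrView)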

torch/_higher_order_ops/map.py

Lines changed: 1 addition & 7 deletions
@@ -36,14 +36,8 @@ def __call__(self, xs, *args):
         return map_wrapper(xs, *args)
 
 
-class MapImpl(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("map_impl")
-
-
 map = MapWrapper("map")
-
-map_impl = MapImpl()
+map_impl = HigherOrderOperator("map_impl")
 
 dummy_aot_config = AOTConfig(
     fw_compiler=None,  # type: ignore[arg-type]

torch/_higher_order_ops/run_const_graph.py

Lines changed: 1 addition & 6 deletions
@@ -8,12 +8,7 @@
 from torch.utils import _pytree as pytree
 
 
-class RunConstGraph(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("run_const_graph")
-
-
-run_const_graph = RunConstGraph()
+run_const_graph = HigherOrderOperator("run_const_graph")
 
 
 @run_const_graph.py_impl(ProxyTorchDispatchMode)
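Several of these files register mode-specific implementations via py_impl(ProxyTorchDispatchMode). A hedged sketch of that shape, with hypothetical names; such implementations receive the active dispatch mode before the operator's own arguments:

    from torch._ops import HigherOrderOperator
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode

    demo_op = HigherOrderOperator("demo_op")  # hypothetical name

    @demo_op.py_impl(ProxyTorchDispatchMode)
    def demo_op_proxy(mode, *args):
        # The proxy-mode implementation would trace the call into the
        # proxy graph here; body omitted in this sketch.
        raise NotImplementedError("sketch only")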

torch/_higher_order_ops/strict_mode.py

Lines changed: 1 addition & 6 deletions
@@ -28,12 +28,7 @@ def strict_mode(callable, operands):
     )
 
 
-class StrictMode(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("strict_mode")
-
-
-strict_mode_op = StrictMode()
+strict_mode_op = HigherOrderOperator("strict_mode")
 
 
 @strict_mode_op.py_impl(DispatchKey.CompositeExplicitAutograd)

torch/_higher_order_ops/torchbind.py

Lines changed: 1 addition & 7 deletions
@@ -16,18 +16,12 @@
 
 log = logging.getLogger(__name__)
 
-
 # The call_torchbind operator represents a method invocation on a torchbind
 # object. The calling convention is:
 #   call_torchbind(self: ScriptObject, method_name: str, *method_args, **method_kwargs)
 # We do not expect users to write this operator directly. Instead it will be
 # emitted by Dynamo when tracing encounters a torchbind object.
-class CallTorchBind(HigherOrderOperator):
-    def __init__(self):
-        super().__init__("call_torchbind")
-
-
-call_torchbind = CallTorchBind()
+call_torchbind = HigherOrderOperator("call_torchbind")
 
 # Register this operator as side-effectful with FX.
 # TODO: this is not really sufficient. While passes (hopefully) check
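The truncated comment above leads into an FX side-effect registration. A hedged sketch of that pattern with a hypothetical operator (torch.fx.node.has_side_effect adds a callable to FX's set of side-effectful functions so dead-code elimination keeps its calls):

    import torch.fx
    from torch._ops import HigherOrderOperator

    noisy_op = HigherOrderOperator("noisy_op")  # hypothetical name
    # Keep calls to this op alive in FX graphs even when their outputs
    # are unused.
    torch.fx.node.has_side_effect(noisy_op)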

torch/_ops.py

Lines changed: 0 additions & 4 deletions
@@ -246,10 +246,6 @@ class HigherOrderOperator(OperatorBase):
     # practice due to name collisions.
     def __init__(self, name):
         super().__init__()
-        if type(self) is HigherOrderOperator:
-            raise RuntimeError(
-                "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it."
-            )
         self._name = name
 
     # Make _OPNamespace not scream, this whole name based association needs a good hard look
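For reference, the removed guard only affected the base class. A sketch of the before/after behavior, with a hypothetical operator name:

    from torch._ops import HigherOrderOperator

    # Before this revert, the deleted lines raised:
    #     RuntimeError: Direct instantiation of HigherOrderOperator is
    #     not allowed. Please subclass it.
    # After the revert, the constructor simply records the name:
    op = HigherOrderOperator("guard_sketch")  # hypothetical name
    assert op._name == "guard_sketch"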
