Skip to content

Commit 790f573

Browse files
msaroufimpytorchmergebot
authored andcommitted
Fix Graph Break on builtin comparison on NNModule (pytorch#103176)
Fixes pytorch#102338 Pull Request resolved: pytorch#103176 Approved by: https://github.com/anijain2305
1 parent 95fced4 commit 790f573

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed

test/dynamo/test_modules.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1925,6 +1925,33 @@ def fn(x):
19251925
res = opt_fn(x)
19261926
self.assertEqual(ref, res)
19271927

1928+
def test_no_graphbreak_builtin_equal(self):
1929+
class MyModule(torch.nn.Module):
1930+
def __init__(self):
1931+
super().__init__()
1932+
self.layer0 = torch.nn.Linear(10, 10)
1933+
self.layer1 = torch.nn.Linear(10, 10)
1934+
self.layer2 = torch.nn.Linear(10, 10)
1935+
1936+
@property
1937+
def encoder_layers(self):
1938+
return [self.layer0, self.layer1, self.layer2]
1939+
1940+
def forward(self, x):
1941+
for layer in self.encoder_layers:
1942+
output = layer(x)
1943+
if layer == self.layer0:
1944+
output = F.relu6(output)
1945+
else:
1946+
output = F.relu(output)
1947+
return output
1948+
1949+
x = torch.randn(10, 10)
1950+
1951+
m = MyModule()
1952+
1953+
opt_m = torch.compile(backend="eager", fullgraph=True)(m)
1954+
19281955

19291956
if __name__ == "__main__":
19301957
from torch._dynamo.test_case import run_tests

torch/_dynamo/symbolic_convert.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,7 @@ def COMPARE_OP(self, inst):
10631063
supported_const_comparison_ops[op](object(), right.value), **options
10641064
)
10651065
)
1066+
10661067
elif (
10671068
left.is_python_constant()
10681069
and right.is_python_constant()

torch/_dynamo/variables/builtin.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,7 @@ def _comparison(self, tx, left, right):
12151215
from . import (
12161216
BaseListVariable,
12171217
ConstantVariable,
1218+
NNModuleVariable,
12181219
TensorVariable,
12191220
UserFunctionVariable,
12201221
)
@@ -1229,6 +1230,13 @@ def _comparison(self, tx, left, right):
12291230
def _unimplemented():
12301231
unimplemented(f"comparison {typestr(left)} {op} {typestr(right)}")
12311232

1233+
if (
1234+
isinstance(left, NNModuleVariable)
1235+
and isinstance(right, NNModuleVariable)
1236+
and op in supported_const_comparison_ops
1237+
):
1238+
self.push(ConstantVariable(op(left, right)))
1239+
12321240
if isinstance(left, UserFunctionVariable):
12331241
if op not in supported_const_comparison_ops.values():
12341242
_unimplemented()

0 commit comments

Comments
 (0)