Skip to content

Commit bf68e16

Browse files
janselpytorchmergebot
authored andcommitted
[dynamo] Fix support for classmethod(property(...)) (pytorch#134968)
Fixes pytorch#134451 Pull Request resolved: pytorch#134968 Approved by: https://github.com/yanboliang
1 parent d732df7 commit bf68e16

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

test/dynamo/test_repros.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from abc import ABC
2121
from collections import namedtuple
2222
from copy import deepcopy
23-
from enum import Enum
23+
from enum import Enum, IntEnum
2424
from functools import wraps
2525
from typing import Any, Dict, Iterator, List, Tuple
2626
from unittest import mock
@@ -4546,6 +4546,81 @@ def f(*args):
45464546
f(*args)
45474547
self.assertEqual(num_compiles, 1)
45484548

4549+
def test_issue134451(self):
4550+
class BoundingBox2DIndex(IntEnum):
4551+
_X = 0
4552+
_Y = 1
4553+
_HEADING = 2
4554+
_LENGTH = 3
4555+
_WIDTH = 4
4556+
4557+
@classmethod
4558+
def size(cls):
4559+
return 5
4560+
4561+
@classmethod
4562+
@property
4563+
def X(cls):
4564+
return cls._X
4565+
4566+
@classmethod
4567+
@property
4568+
def Y(cls):
4569+
return cls._Y
4570+
4571+
@classmethod
4572+
@property
4573+
def HEADING(cls):
4574+
return cls._HEADING
4575+
4576+
@classmethod
4577+
@property
4578+
def LENGTH(cls):
4579+
return cls._LENGTH
4580+
4581+
@classmethod
4582+
@property
4583+
def WIDTH(cls):
4584+
return cls._WIDTH
4585+
4586+
@classmethod
4587+
@property
4588+
def POINT(cls):
4589+
# assumes X, Y have subsequent indices
4590+
return slice(cls._X, cls._Y + 1)
4591+
4592+
@classmethod
4593+
@property
4594+
def STATE_SE2(cls):
4595+
# assumes X, Y, HEADING have subsequent indices
4596+
return slice(cls._X, cls._HEADING + 1)
4597+
4598+
class SimpleModel(nn.Module):
4599+
def __init__(self):
4600+
super().__init__()
4601+
self._mlp_states = nn.Sequential(
4602+
nn.Linear(10, 20),
4603+
nn.ReLU(),
4604+
nn.Linear(20, BoundingBox2DIndex.size()),
4605+
)
4606+
4607+
def forward(self, x):
4608+
agent_states = self._mlp_states(x)
4609+
agent_states[..., BoundingBox2DIndex.POINT] = (
4610+
agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
4611+
)
4612+
agent_states[..., BoundingBox2DIndex.HEADING] = (
4613+
agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi
4614+
)
4615+
return agent_states
4616+
4617+
model = SimpleModel().eval()
4618+
input_tensor = torch.randn(1, 10, dtype=torch.float32)
4619+
opt = torch.compile(model.eval(), backend="eager", fullgraph=True)
4620+
actual = opt(input_tensor)
4621+
expected = model(input_tensor)
4622+
self.assertEqual(actual, expected)
4623+
45494624
def test_invalid_seq_unpack(self):
45504625
def myfn(arg):
45514626
(a, b) = arg

torch/_dynamo/variables/constant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ def create(cls, cls_type, value_vt, options):
222222
unimplemented("Enum variable is constructed with non constant values")
223223

224224
def as_proxy(self):
225+
if isinstance(self.value, int):
226+
return int(self.value) # convert IntEnum to a normal int
225227
return self.value
226228

227229
def __str__(self) -> str:

torch/_dynamo/variables/user_defined.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
197197
else:
198198
return SourcelessBuilder.create(tx, func)
199199
elif isinstance(obj, classmethod):
200+
if isinstance(obj.__func__, property):
201+
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
202+
tx, [self], {}
203+
)
200204
return variables.UserMethodVariable(obj.__func__, self, source=source)
201205
elif isinstance(obj, types.ClassMethodDescriptorType):
202206
# e.g.: inspect.getattr_static(dict, "fromkeys")

0 commit comments

Comments
 (0)