|
20 | 20 | from abc import ABC |
21 | 21 | from collections import namedtuple |
22 | 22 | from copy import deepcopy |
23 | | -from enum import Enum |
| 23 | +from enum import Enum, IntEnum |
24 | 24 | from functools import wraps |
25 | 25 | from typing import Any, Dict, Iterator, List, Tuple |
26 | 26 | from unittest import mock |
@@ -4546,6 +4546,81 @@ def f(*args): |
4546 | 4546 | f(*args) |
4547 | 4547 | self.assertEqual(num_compiles, 1) |
4548 | 4548 |
|
| 4549 | + def test_issue134451(self): |
| 4550 | + class BoundingBox2DIndex(IntEnum): |
| 4551 | + _X = 0 |
| 4552 | + _Y = 1 |
| 4553 | + _HEADING = 2 |
| 4554 | + _LENGTH = 3 |
| 4555 | + _WIDTH = 4 |
| 4556 | + |
| 4557 | + @classmethod |
| 4558 | + def size(cls): |
| 4559 | + return 5 |
| 4560 | + |
| 4561 | + @classmethod |
| 4562 | + @property |
| 4563 | + def X(cls): |
| 4564 | + return cls._X |
| 4565 | + |
| 4566 | + @classmethod |
| 4567 | + @property |
| 4568 | + def Y(cls): |
| 4569 | + return cls._Y |
| 4570 | + |
| 4571 | + @classmethod |
| 4572 | + @property |
| 4573 | + def HEADING(cls): |
| 4574 | + return cls._HEADING |
| 4575 | + |
| 4576 | + @classmethod |
| 4577 | + @property |
| 4578 | + def LENGTH(cls): |
| 4579 | + return cls._LENGTH |
| 4580 | + |
| 4581 | + @classmethod |
| 4582 | + @property |
| 4583 | + def WIDTH(cls): |
| 4584 | + return cls._WIDTH |
| 4585 | + |
| 4586 | + @classmethod |
| 4587 | + @property |
| 4588 | + def POINT(cls): |
| 4589 | + # assumes X, Y have subsequent indices |
| 4590 | + return slice(cls._X, cls._Y + 1) |
| 4591 | + |
| 4592 | + @classmethod |
| 4593 | + @property |
| 4594 | + def STATE_SE2(cls): |
| 4595 | + # assumes X, Y, HEADING have subsequent indices |
| 4596 | + return slice(cls._X, cls._HEADING + 1) |
| 4597 | + |
| 4598 | + class SimpleModel(nn.Module): |
| 4599 | + def __init__(self): |
| 4600 | + super().__init__() |
| 4601 | + self._mlp_states = nn.Sequential( |
| 4602 | + nn.Linear(10, 20), |
| 4603 | + nn.ReLU(), |
| 4604 | + nn.Linear(20, BoundingBox2DIndex.size()), |
| 4605 | + ) |
| 4606 | + |
| 4607 | + def forward(self, x): |
| 4608 | + agent_states = self._mlp_states(x) |
| 4609 | + agent_states[..., BoundingBox2DIndex.POINT] = ( |
| 4610 | + agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 |
| 4611 | + ) |
| 4612 | + agent_states[..., BoundingBox2DIndex.HEADING] = ( |
| 4613 | + agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi |
| 4614 | + ) |
| 4615 | + return agent_states |
| 4616 | + |
| 4617 | + model = SimpleModel().eval() |
| 4618 | + input_tensor = torch.randn(1, 10, dtype=torch.float32) |
| 4619 | + opt = torch.compile(model.eval(), backend="eager", fullgraph=True) |
| 4620 | + actual = opt(input_tensor) |
| 4621 | + expected = model(input_tensor) |
| 4622 | + self.assertEqual(actual, expected) |
| 4623 | + |
4549 | 4624 | def test_invalid_seq_unpack(self): |
4550 | 4625 | def myfn(arg): |
4551 | 4626 | (a, b) = arg |
|
0 commit comments