
Commit 9344da8

Revert "[fake tensor cache] Support index with non bool/int8 indices (pytorch#151477)"
This reverts commit bdb34f5. Reverted pytorch#151477 on behalf of https://github.com/wdvr due to a confusing ghstack state ([comment](pytorch#151477 (comment)))
1 parent 348272e · commit 9344da8
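
For context on the change being reverted: it taught FakeTensorMode's dispatch cache to serve aten.index calls whose indices are ordinary integer tensors, while still bypassing the cache when an index is bool or int8 (those go through a nonzero-style path and have data-dependent output shapes). A minimal sketch of that case, assuming a recent PyTorch build; the variable names are illustrative only, not taken from the PR:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(4, 4, 4)

    # int64 indices: the output shape follows from the input shapes alone,
    # so the cache can compute and reuse the result's metadata safely.
    idx = torch.tensor([0, 2, 3])
    out = torch.ops.aten.index(x, [None, idx])
    print(out.shape)  # torch.Size([4, 3, 4])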

File tree: 2 files changed, 0 additions (+0), 40 deletions (−40)

test/test_fake_tensor.py (0 additions, 25 deletions)

@@ -2158,31 +2158,6 @@ def test_cache_tuple_outputs(self):
                     extract_tensor_metadata(b),
                 )
 
-
-    def test_cache_aten_index(self):
-        with FakeTensorMode():
-            x = torch.randn(4, 4, 4)
-            idx_tensor1 = torch.tensor([0, 2, 3])
-            idx_tensor2 = torch.tensor([0, 1, 2])
-
-            FakeTensorMode.cache_clear()
-            self.assertHitsMisses(0, 0)
-
-            ref = torch.ops.aten.index(x, [None, idx_tensor1, idx_tensor2])
-            self.assertHitsMisses(0, 3)
-
-            res = torch.ops.aten.index(x, [None, idx_tensor1, idx_tensor2])
-            self.assertHitsMisses(1, 3)
-            self.assertEqual(extract_tensor_metadata(ref), extract_tensor_metadata(res))
-
-        with FakeTensorMode():
-            x = torch.randn(4, 4, 4)
-            idx_tensor1 = torch.tensor([True, True, False, True])
-            self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
-
-            idx_tensor1 = torch.tensor([1, -2, 3, -4], dtype=torch.int8)
-            self.assertRaises(DynamicOutputShapeException, lambda: torch.ops.aten.index(x, [None, idx_tensor1]))
-
     @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
     def test_invoke_subgraph(self):
         """

torch/_subclasses/fake_tensor.py (0 additions, 15 deletions)

@@ -1539,21 +1539,6 @@ def _validate_cache_key(
             raise _BypassDispatchCache("data dependent output")
 
         if torch.Tag.dynamic_output_shape in func.tags:
-            if func is aten.index.Tensor:
-                _, new_kwargs = normalize_function(  # type: ignore[misc]
-                    func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True  # type: ignore[arg-type]
-                )
-                for index in new_kwargs["indices"]:
-                    # index calls nonzero for bool or int8 tensors, and
-                    # therefore has a dynamic shape output. For other dtypes,
-                    # the output shape depends on the input shape (and not data)
-                    if isinstance(index, torch.Tensor) and index.dtype in (
-                        torch.bool,
-                        torch.int8,
-                    ):
-                        raise _BypassDispatchCache("dynamic output shape")
-                return
-
             raise _BypassDispatchCache("dynamic output shape")
 
         if torch.Tag.inplace_view in func.tags:
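
The comment in the removed block carries the key reasoning: with bool (or int8) indices, aten.index effectively masks via nonzero, so the result's length depends on the values in the index tensor rather than just its shape, which is why such calls must bypass a metadata-keyed cache. A small eager-mode illustration of that distinction, using plain tensors and illustrative names (not code from the PR):

import torch

x = torch.randn(4, 4, 4)

# Integer indices: the output shape is fixed by the input and index shapes.
int_idx = torch.tensor([0, 2, 3])
print(torch.ops.aten.index(x, [None, int_idx]).shape)  # torch.Size([4, 3, 4])

# Boolean indices: the output length equals the number of True entries,
# so two masks with the same shape can produce different output shapes.
mask_a = torch.tensor([True, True, False, True])
mask_b = torch.tensor([True, False, False, False])
print(torch.ops.aten.index(x, [None, mask_a]).shape)   # torch.Size([4, 3, 4])
print(torch.ops.aten.index(x, [None, mask_b]).shape)   # torch.Size([4, 1, 4])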
