@@ -2158,31 +2158,6 @@ def test_cache_tuple_outputs(self):
21582158 extract_tensor_metadata (b ),
21592159 )
21602160
2161-
2162- def test_cache_aten_index (self ):
2163- with FakeTensorMode ():
2164- x = torch .randn (4 , 4 , 4 )
2165- idx_tensor1 = torch .tensor ([0 , 2 , 3 ])
2166- idx_tensor2 = torch .tensor ([0 , 1 , 2 ])
2167-
2168- FakeTensorMode .cache_clear ()
2169- self .assertHitsMisses (0 , 0 )
2170-
2171- ref = torch .ops .aten .index (x , [None , idx_tensor1 , idx_tensor2 ])
2172- self .assertHitsMisses (0 , 3 )
2173-
2174- res = torch .ops .aten .index (x , [None , idx_tensor1 , idx_tensor2 ])
2175- self .assertHitsMisses (1 , 3 )
2176- self .assertEqual (extract_tensor_metadata (ref ), extract_tensor_metadata (res ))
2177-
2178- with FakeTensorMode ():
2179- x = torch .randn (4 , 4 , 4 )
2180- idx_tensor1 = torch .tensor ([True , True , False , True ])
2181- self .assertRaises (DynamicOutputShapeException , lambda : torch .ops .aten .index (x , [None , idx_tensor1 ]))
2182-
2183- idx_tensor1 = torch .tensor ([1 , - 2 , 3 , - 4 ], dtype = torch .int8 )
2184- self .assertRaises (DynamicOutputShapeException , lambda : torch .ops .aten .index (x , [None , idx_tensor1 ]))
2185-
21862161 @skipIfTorchDynamo ("cache hit/miss changes with invoke_subgraph caching" )
21872162 def test_invoke_subgraph (self ):
21882163 """
0 commit comments