Skip to content

Commit bccbb5a

Browse files
authored
Add support for _unsafe_index. (#5707)
* Add support for `_unsafe_index`.
* Fix lint issues.
* Add tests.
1 parent 6ea9947 commit bccbb5a

File tree

4 files changed

+51
-0
lines changed

4 files changed

+51
-0
lines changed

codegen/xla_native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ supported:
130130
- _to_cpu
131131
- _to_copy
132132
- _unsafe_view
133+
- _unsafe_index.Tensor
133134
- adaptive_max_pool2d
134135
- adaptive_max_pool2d_backward
135136
- add.Scalar

test/cpp/test_aten_xla_tensor_1.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2033,6 +2033,34 @@ TEST_F(AtenXlaTensorTest, TestScatterReduceMaxInPlace) {
20332033
ExpectCounterChanged("xla::scatter_reduce", cpp_test::GetIgnoredCounters());
20342034
}
20352035

2036+
TEST_F(AtenXlaTensorTest, TestUnsafeIndex) {
2037+
for (torch::ScalarType scalar_type :
2038+
{torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,
2039+
torch::kLong}) {
2040+
torch::Tensor a =
2041+
isFloatingType(scalar_type)
2042+
? torch::rand({3, 4}, torch::TensorOptions(scalar_type))
2043+
: torch::randint(100, {3, 4}, torch::TensorOptions(scalar_type));
2044+
for (torch::ScalarType index_scalar_type : {torch::kInt, torch::kLong}) {
2045+
torch::List<torch::optional<torch::Tensor>> indices{
2046+
torch::tensor({0, 1}, torch::TensorOptions(index_scalar_type)),
2047+
torch::tensor({2, 3}, torch::TensorOptions(index_scalar_type))};
2048+
torch::Tensor c0 = torch::_unsafe_index(a, indices);
2049+
ForEachDevice([&](const torch::Device& device) {
2050+
torch::Tensor xla_a = CopyToDevice(a, device);
2051+
torch::List<torch::optional<torch::Tensor>> xla_indices{
2052+
CopyToDevice(*indices.get(0), device),
2053+
CopyToDevice(*indices.get(1), device)};
2054+
torch::Tensor xla_c0 = torch::_unsafe_index(xla_a, xla_indices);
2055+
AllEqual(c0, xla_c0);
2056+
});
2057+
}
2058+
}
2059+
ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
2060+
ExpectCounterChanged("xla::index", cpp_test::GetIgnoredCounters());
2061+
ExpectCounterChanged("xla::_unsafe_index", cpp_test::GetIgnoredCounters());
2062+
}
2063+
20362064
TEST_F(AtenXlaTensorTest, TestIndexSelect) {
20372065
for (torch::ScalarType scalar_type :
20382066
{torch::kFloat, torch::kByte, torch::kChar, torch::kShort, torch::kInt,

test/dynamo/test_bridge.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,19 @@ def get_random_inputs(self):
7878
return (torch.randn(10), torch.randn(10))
7979

8080

class UpsampleModule(nn.Module):
  """Minimal module wrapping ``nn.Upsample`` (scale factor 2).

  Used by the dynamo bridge tests to exercise training through an
  upsample op; ``get_random_inputs`` supplies a matching random input.
  """

  def __init__(self):
    super().__init__()
    # Public attribute name kept stable for state_dict / introspection.
    self.upsample = nn.Upsample(scale_factor=2)

  def forward(self, x):
    # Delegate straight to the wrapped upsample layer.
    return self.upsample(x)

  def get_random_inputs(self):
    # Single random (N, C, L) tensor shaped for 1-D upsampling.
    return (torch.randn((1, 1, 5)),)
8194
def allclose(expected, actual):
8295

8396
def unwrap(cont):
@@ -179,6 +192,7 @@ def test_wrapper(self):
179192

180193
model = model.to(device=xla_dev)
181194
inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
195+
inputs = tuple(inp.requires_grad_() for inp in inputs)
182196

183197
# do baseline
184198
baseline_model = copy.deepcopy(model)
@@ -206,6 +220,7 @@ class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
206220

207221
test_training_linear = make_training_test(LinearModule)
208222
test_training_maxpool = make_training_test(MaxPoolModule)
223+
test_training_upsample = make_training_test(UpsampleModule)
209224

210225
def test_non_tensor_args_for_partition(self):
211226

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,13 @@ at::Tensor XLANativeFunctions::_unsafe_view(const at::Tensor& self,
626626
return view_copy_symint(self, c10::fromIntArrayRefSlow(size));
627627
}
628628

// Lowering for aten::_unsafe_index.Tensor. _unsafe_index differs from
// index only in the safety guarantees it makes to the compiler, so this
// records the op counter and delegates to the existing index lowering.
// NOTE(review): assumes index() here resolves to XLANativeFunctions::index
// declared elsewhere in this file — confirm against the full source.
at::Tensor XLANativeFunctions::_unsafe_index(
    const at::Tensor& self,
    const c10::List<c10::optional<at::Tensor>>& indices) {
  TORCH_LAZY_FN_COUNTER("xla::");
  return index(self, indices);
}
629636
at::Tensor XLANativeFunctions::add(const at::Tensor& self,
630637
const at::Tensor& other,
631638
const at::Scalar& alpha) {

0 commit comments

Comments
 (0)