Skip to content

Commit b1123db

Browse files
authored
Run decomp before processing (#5713)
1 parent bccbb5a commit b1123db

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

test/stablehlo/test_exports.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import unittest
2+
import torch
3+
import torch.nn.functional as F
4+
from torch_xla.stablehlo import exported_program_to_stablehlo
5+
6+
7+
class Interpolate(torch.nn.Module):
8+
9+
def forward(self, masks: torch.Tensor) -> torch.Tensor:
10+
masks = F.interpolate(
11+
masks,
12+
size=(500, 500),
13+
mode="bilinear",
14+
align_corners=False,
15+
)
16+
return masks
17+
18+
19+
class ExportTest(unittest.TestCase):
20+
21+
def test_interpolate(self):
22+
23+
arg = (torch.randn(3, 3, 200, 200),)
24+
model = Interpolate()
25+
26+
ans = model(*arg)
27+
28+
with torch.no_grad():
29+
exported = torch._export.export(model, arg)
30+
shlo = exported_program_to_stablehlo(exported)
31+
ans2 = shlo(*arg).cpu().to(torch.float32)
32+
self.assertTrue(torch.allclose(ans, ans2, atol=1e-5))

torch_xla/stablehlo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def _exported_program_to_stablehlo_bundle(exported_model,
240240
options) -> StableHLOModelBundle:
241241
if options is None:
242242
options = StableHLOExportOptions()
243-
243+
exported_model = exported_model.run_decompositions()
244244
input_args = _extract_input_args(exported_model, options)
245245

246246
device = xm.xla_device()

0 commit comments

Comments
 (0)