Skip to content

Commit 7e0d3a5

Browse files
lsy323Siyuan Liu
andauthored
Add fx passes to support exporting unbounded dynamism (#6653)
Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
1 parent 1ccd6a6 commit 7e0d3a5

22 files changed

+1925
-395
lines changed

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ function run_xla_op_tests2 {
212212
function run_xla_op_tests3 {
213213
# TODO(qihqi): this test require tensorflow to run. need to setup separate
214214
# CI with tf.
215+
run_test "$CDIR/stablehlo/test_export_fx_passes.py"
215216
run_test "$CDIR/stablehlo/test_implicit_broadcasting.py"
216217
run_test "$CDIR/stablehlo/test_mark_pattern.py"
217218
run_test "$CDIR/stablehlo/test_pt2e_qdq.py"
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
import os
2+
import re
3+
import sys
4+
import unittest
5+
6+
import numpy as np
7+
import torch
8+
import torch.utils._pytree as pytree
9+
import torch_xla
10+
import torch_xla.core.xla_model as xm
11+
from torch.export import Dim, export
12+
from torch_xla.experimental.unbounded_dynamism_export import *
13+
from torch_xla.stablehlo import exported_program_to_stablehlo
14+
from torch_xla.utils.stablehlo_test_utils import wrap_func_as_nn_module
15+
16+
17+
class ExportFxPassTest(unittest.TestCase):
18+
19+
def test_decompose_dynamic_shape_select(self):
20+
args = (torch.rand((10, 197, 768)), 1, 0)
21+
dynamic_shapes = ([{0: Dim("bs")}, None, None],)
22+
m = wrap_func_as_nn_module(torch.ops.aten.select.int)
23+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
24+
out1 = ep.module()(*args)
25+
decompose_dynamic_shape_select(ep.graph_module)
26+
ep.graph_module.recompile()
27+
self.assertTrue('aten.view' in ep.graph_module.code)
28+
replace_dynamic_view_with_xla_op(ep.graph_module)
29+
ep.graph_module.recompile()
30+
self.assertTrue('aten.view' not in ep.graph_module.code)
31+
self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
32+
out2 = ep.module()(*args)
33+
self.assertTrue(torch.allclose(out1, out2))
34+
35+
def test_no_op_slice_removal(self):
36+
37+
class M(torch.nn.Module):
38+
39+
def forward(self, x):
40+
x = x * 2
41+
return torch.ops.aten.slice(x, 1, 0, 9223372036854775807)
42+
43+
m = M()
44+
args = (torch.rand((10, 197, 768)),)
45+
dynamic_shapes = ({0: Dim("bs")},)
46+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
47+
out1 = ep.module()(*args)
48+
self.assertTrue('aten.slice' in ep.graph_module.code)
49+
remove_no_op_slice(ep.graph_module)
50+
ep.graph_module.recompile()
51+
self.assertTrue('aten.slice' not in ep.graph_module.code)
52+
out2 = ep.module()(*args)
53+
self.assertTrue(torch.allclose(out1, out2))
54+
55+
def test_dynamic_view(self):
56+
57+
class M(torch.nn.Module):
58+
59+
def __init__(self):
60+
super().__init__()
61+
self.conv = torch.nn.Conv2d(3, 5, [16, 16])
62+
63+
def forward(self, x):
64+
x = self.conv(x)
65+
return x.view(x.shape[0], x.shape[1], -1)
66+
67+
m = M()
68+
args = (torch.rand((10, 3, 224, 224)),)
69+
dynamic_shapes = ({0: Dim("bs")},)
70+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
71+
out1 = ep.module()(*args)
72+
replace_dynamic_view_with_xla_op(ep.graph_module)
73+
ep.graph_module.recompile()
74+
self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
75+
out2 = ep.module()(*args)
76+
self.assertTrue(torch.allclose(out1, out2))
77+
78+
def test_dynamic_view_non_bs(self):
79+
80+
class M(torch.nn.Module):
81+
82+
def forward(self, x):
83+
return x.view(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
84+
85+
m = M()
86+
args = (torch.rand((1, 3, 2, 16)),)
87+
dynamic_shapes = ({1: Dim("bs")},)
88+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
89+
out1 = ep.module()(*args)
90+
replace_dynamic_view_with_xla_op(ep.graph_module)
91+
ep.graph_module.recompile()
92+
self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
93+
out2 = ep.module()(*args)
94+
self.assertTrue(torch.allclose(out1, out2))
95+
96+
def test_dynamic_view_multiplier(self):
97+
98+
class M(torch.nn.Module):
99+
100+
def __init__(self):
101+
super().__init__()
102+
self.conv = torch.nn.Conv2d(3, 5, [16, 16])
103+
104+
def forward(self, x):
105+
x = self.conv(x)
106+
return x.view(x.shape[0] * x.shape[1], -1)
107+
108+
m = M()
109+
args = (torch.rand((10, 3, 224, 224)),)
110+
dynamic_shapes = ({0: Dim("bs")},)
111+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
112+
out1 = ep.module()(*args)
113+
replace_dynamic_view_with_xla_op(ep.graph_module)
114+
print(ep)
115+
ep.graph_module.recompile()
116+
print(ep.graph_module.code)
117+
self.assertTrue('xla.dynamic_view' in ep.graph_module.code)
118+
out2 = ep.module()(*args)
119+
self.assertTrue(torch.allclose(out1, out2))
120+
121+
def test_dynamic_expand(self):
122+
123+
class M(torch.nn.Module):
124+
125+
def forward(self, x, image):
126+
return x.expand([image.shape[0], -1, -1])
127+
128+
m = M()
129+
args = (torch.rand((1, 1, 5)), torch.rand((3, 4)))
130+
dynamic_shapes = (
131+
None,
132+
{
133+
0: Dim("bs")
134+
},
135+
)
136+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
137+
out1 = ep.module()(*args)
138+
replace_dynamic_expand_with_xla_op(ep.graph_module)
139+
ep.graph_module.recompile()
140+
self.assertTrue('xla.dynamic_expand' in ep.graph_module.code)
141+
out2 = ep.module()(*args)
142+
self.assertTrue(torch.allclose(out1, out2))
143+
144+
def test_dynamic_expand_2(self):
145+
146+
class M(torch.nn.Module):
147+
148+
def forward(self, x, range):
149+
return x.expand(1, 1, 8, range.shape[0], 256)
150+
151+
m = M()
152+
args = (torch.rand((1, 1, 1, 3, 256)), torch.arange(3))
153+
dynamic_shapes = ({3: Dim("bs")}, {0: Dim("bs")})
154+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
155+
out1 = ep.module()(*args)
156+
print(ep)
157+
replace_dynamic_expand_with_xla_op(ep.graph_module)
158+
print(ep)
159+
ep.graph_module.recompile()
160+
self.assertTrue('xla.dynamic_expand' in ep.graph_module.code)
161+
out2 = ep.module()(*args)
162+
self.assertTrue(torch.allclose(out1, out2))
163+
164+
def test_layer_norm_decomp(self):
165+
166+
class M(torch.nn.Module):
167+
168+
def forward(self, x, dim, weight, bias, eps):
169+
return torch.ops.aten.native_layer_norm.default(x, dim, weight, bias,
170+
eps)[0]
171+
172+
args = (torch.rand(10, 197,
173+
768), [768], torch.rand(768), torch.rand(768), 1e-12)
174+
dynamic_shapes = ({0: Dim("bs")}, [None], None, None, None)
175+
m = M().eval()
176+
before_decomp_out = m(*args)
177+
after_decomp_out = native_layer_norm_impl(*args)
178+
self.assertTrue(
179+
torch.allclose(before_decomp_out, after_decomp_out, atol=1e-6))
180+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
181+
decompose_dynamic_native_layer_norm(ep.graph_module)
182+
ep.graph_module.recompile()
183+
self.assertFalse('aten.native_layer_norm' in ep.graph_module.code)
184+
after_decomp_out_2 = ep.module()(*args)
185+
self.assertTrue(
186+
torch.allclose(before_decomp_out, after_decomp_out_2, atol=1e-6))
187+
188+
def test_group_norm_to_layer_norm(self):
189+
190+
class M(torch.nn.Module):
191+
192+
def forward(self, x, weight, bias, N, C, HxW, group, eps):
193+
return torch.ops.aten.native_group_norm.default(x, weight, bias, N, C,
194+
HxW, group, eps)[0]
195+
196+
class M2(torch.nn.Module):
197+
198+
def __init__(self):
199+
super().__init__()
200+
# self.conv = torch.nn.Conv1d(1, 512, 10, stride=5)
201+
self.layer_norm = torch.nn.GroupNorm(
202+
num_groups=512, num_channels=512, affine=True)
203+
204+
def forward(self, x):
205+
return self.layer_norm(x)[0]
206+
207+
args = (torch.rand(10, 512, 159), torch.rand(512), torch.rand(512), 10, 512,
208+
159, 512, 1e-12)
209+
export_args = (torch.rand(10, 512, 159),)
210+
dynamic_shapes = ({0: Dim("bs")},)
211+
m = M().eval()
212+
before_decomp_out = m(*args)
213+
after_decomp_out = native_group_norm_impl(*args)
214+
self.assertTrue(
215+
torch.allclose(before_decomp_out, after_decomp_out, atol=1e-6))
216+
# Test export path with a different to workaround an export issue.
217+
m2 = M2().eval()
218+
ep = export(m2, export_args, dynamic_shapes=dynamic_shapes)
219+
before_decomp_ep_out = m2(*export_args)
220+
decompose_dynamic_native_group_norm(ep.graph_module)
221+
ep.graph_module.recompile()
222+
self.assertFalse('aten.native_group_norm' in ep.graph_module.code)
223+
after_decomp_ep_out = ep.module()(*export_args)
224+
# print(before_decomp_ep_out - after_decomp_ep_out)
225+
self.assertTrue(
226+
torch.allclose(before_decomp_ep_out, after_decomp_ep_out, atol=1e-6))
227+
228+
def test_dynamic_unsqueeze_to_view(self):
229+
230+
class M(torch.nn.Module):
231+
232+
def forward(self, x):
233+
return torch.ops.aten.unsqueeze.default(x, 2)
234+
235+
args = (torch.rand((1, 1, 3, 256)),)
236+
dynamic_shapes = ({2: Dim("dim")},)
237+
m = M().eval()
238+
ep = export(m, args, dynamic_shapes=dynamic_shapes)
239+
out1 = ep.module()(*args)
240+
dynamic_unsqueeze_to_view(ep.graph_module)
241+
ep.graph_module.recompile()
242+
self.assertFalse('aten.unsqueeze' in ep.graph_module.code)
243+
self.assertTrue('aten.view' in ep.graph_module.code)
244+
out2 = ep.module()(*args)
245+
self.assertTrue(torch.allclose(out1, out2))
246+
247+
248+
if __name__ == "__main__":
249+
test = unittest.main()
250+
sys.exit(0 if test.result.wasSuccessful() else 1)

test/stablehlo/test_mark_pattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch_xla import stablehlo
1212
from torch_xla.experimental import xla_marker
1313
from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
14-
from utils import has_tf_package
14+
from torch_xla.utils.stablehlo_test_utils import has_tf_package
1515

1616
try:
1717
from torch_xla.tf_saved_model_integration import \

test/stablehlo/test_pt2e_qdq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
1212
XNNPACKQuantizer, get_symmetric_quantization_config)
1313
from torch_xla import stablehlo
14-
from utils import has_tf_package
14+
from torch_xla.utils.stablehlo_test_utils import has_tf_package
1515

1616
try:
1717
from torch_xla.tf_saved_model_integration import \

test/stablehlo/test_saved_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from torch_xla.tf_saved_model_integration import (
1515
make_tf_function, save_stablehlo_graph_as_tf,
1616
save_torch_module_as_tf_saved_model)
17-
from utils import (compare_exported_program_and_saved_model_result,
18-
has_tf_package, wrap_func_as_nn_module)
17+
from torch_xla.utils.stablehlo_test_utils import (
18+
compare_exported_program_and_saved_model_result, has_tf_package,
19+
wrap_func_as_nn_module)
1920

2021

2122
class StableHLOInferenceTest(unittest.TestCase):

0 commit comments

Comments
 (0)