|
| 1 | +import os |
| 2 | +import re |
| 3 | +import sys |
| 4 | +import unittest |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | +import torch.utils._pytree as pytree |
| 9 | +import torch_xla |
| 10 | +import torch_xla.core.xla_model as xm |
| 11 | +from torch.export import Dim, export |
| 12 | +from torch_xla.experimental.unbounded_dynamism_export import * |
| 13 | +from torch_xla.stablehlo import exported_program_to_stablehlo |
| 14 | +from torch_xla.utils.stablehlo_test_utils import wrap_func_as_nn_module |
| 15 | + |
| 16 | + |
| 17 | +class ExportFxPassTest(unittest.TestCase): |
| 18 | + |
| 19 | + def test_decompose_dynamic_shape_select(self): |
| 20 | + args = (torch.rand((10, 197, 768)), 1, 0) |
| 21 | + dynamic_shapes = ([{0: Dim("bs")}, None, None],) |
| 22 | + m = wrap_func_as_nn_module(torch.ops.aten.select.int) |
| 23 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 24 | + out1 = ep.module()(*args) |
| 25 | + decompose_dynamic_shape_select(ep.graph_module) |
| 26 | + ep.graph_module.recompile() |
| 27 | + self.assertTrue('aten.view' in ep.graph_module.code) |
| 28 | + replace_dynamic_view_with_xla_op(ep.graph_module) |
| 29 | + ep.graph_module.recompile() |
| 30 | + self.assertTrue('aten.view' not in ep.graph_module.code) |
| 31 | + self.assertTrue('xla.dynamic_view' in ep.graph_module.code) |
| 32 | + out2 = ep.module()(*args) |
| 33 | + self.assertTrue(torch.allclose(out1, out2)) |
| 34 | + |
| 35 | + def test_no_op_slice_removal(self): |
| 36 | + |
| 37 | + class M(torch.nn.Module): |
| 38 | + |
| 39 | + def forward(self, x): |
| 40 | + x = x * 2 |
| 41 | + return torch.ops.aten.slice(x, 1, 0, 9223372036854775807) |
| 42 | + |
| 43 | + m = M() |
| 44 | + args = (torch.rand((10, 197, 768)),) |
| 45 | + dynamic_shapes = ({0: Dim("bs")},) |
| 46 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 47 | + out1 = ep.module()(*args) |
| 48 | + self.assertTrue('aten.slice' in ep.graph_module.code) |
| 49 | + remove_no_op_slice(ep.graph_module) |
| 50 | + ep.graph_module.recompile() |
| 51 | + self.assertTrue('aten.slice' not in ep.graph_module.code) |
| 52 | + out2 = ep.module()(*args) |
| 53 | + self.assertTrue(torch.allclose(out1, out2)) |
| 54 | + |
| 55 | + def test_dynamic_view(self): |
| 56 | + |
| 57 | + class M(torch.nn.Module): |
| 58 | + |
| 59 | + def __init__(self): |
| 60 | + super().__init__() |
| 61 | + self.conv = torch.nn.Conv2d(3, 5, [16, 16]) |
| 62 | + |
| 63 | + def forward(self, x): |
| 64 | + x = self.conv(x) |
| 65 | + return x.view(x.shape[0], x.shape[1], -1) |
| 66 | + |
| 67 | + m = M() |
| 68 | + args = (torch.rand((10, 3, 224, 224)),) |
| 69 | + dynamic_shapes = ({0: Dim("bs")},) |
| 70 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 71 | + out1 = ep.module()(*args) |
| 72 | + replace_dynamic_view_with_xla_op(ep.graph_module) |
| 73 | + ep.graph_module.recompile() |
| 74 | + self.assertTrue('xla.dynamic_view' in ep.graph_module.code) |
| 75 | + out2 = ep.module()(*args) |
| 76 | + self.assertTrue(torch.allclose(out1, out2)) |
| 77 | + |
| 78 | + def test_dynamic_view_non_bs(self): |
| 79 | + |
| 80 | + class M(torch.nn.Module): |
| 81 | + |
| 82 | + def forward(self, x): |
| 83 | + return x.view(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) |
| 84 | + |
| 85 | + m = M() |
| 86 | + args = (torch.rand((1, 3, 2, 16)),) |
| 87 | + dynamic_shapes = ({1: Dim("bs")},) |
| 88 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 89 | + out1 = ep.module()(*args) |
| 90 | + replace_dynamic_view_with_xla_op(ep.graph_module) |
| 91 | + ep.graph_module.recompile() |
| 92 | + self.assertTrue('xla.dynamic_view' in ep.graph_module.code) |
| 93 | + out2 = ep.module()(*args) |
| 94 | + self.assertTrue(torch.allclose(out1, out2)) |
| 95 | + |
| 96 | + def test_dynamic_view_multiplier(self): |
| 97 | + |
| 98 | + class M(torch.nn.Module): |
| 99 | + |
| 100 | + def __init__(self): |
| 101 | + super().__init__() |
| 102 | + self.conv = torch.nn.Conv2d(3, 5, [16, 16]) |
| 103 | + |
| 104 | + def forward(self, x): |
| 105 | + x = self.conv(x) |
| 106 | + return x.view(x.shape[0] * x.shape[1], -1) |
| 107 | + |
| 108 | + m = M() |
| 109 | + args = (torch.rand((10, 3, 224, 224)),) |
| 110 | + dynamic_shapes = ({0: Dim("bs")},) |
| 111 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 112 | + out1 = ep.module()(*args) |
| 113 | + replace_dynamic_view_with_xla_op(ep.graph_module) |
| 114 | + print(ep) |
| 115 | + ep.graph_module.recompile() |
| 116 | + print(ep.graph_module.code) |
| 117 | + self.assertTrue('xla.dynamic_view' in ep.graph_module.code) |
| 118 | + out2 = ep.module()(*args) |
| 119 | + self.assertTrue(torch.allclose(out1, out2)) |
| 120 | + |
| 121 | + def test_dynamic_expand(self): |
| 122 | + |
| 123 | + class M(torch.nn.Module): |
| 124 | + |
| 125 | + def forward(self, x, image): |
| 126 | + return x.expand([image.shape[0], -1, -1]) |
| 127 | + |
| 128 | + m = M() |
| 129 | + args = (torch.rand((1, 1, 5)), torch.rand((3, 4))) |
| 130 | + dynamic_shapes = ( |
| 131 | + None, |
| 132 | + { |
| 133 | + 0: Dim("bs") |
| 134 | + }, |
| 135 | + ) |
| 136 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 137 | + out1 = ep.module()(*args) |
| 138 | + replace_dynamic_expand_with_xla_op(ep.graph_module) |
| 139 | + ep.graph_module.recompile() |
| 140 | + self.assertTrue('xla.dynamic_expand' in ep.graph_module.code) |
| 141 | + out2 = ep.module()(*args) |
| 142 | + self.assertTrue(torch.allclose(out1, out2)) |
| 143 | + |
| 144 | + def test_dynamic_expand_2(self): |
| 145 | + |
| 146 | + class M(torch.nn.Module): |
| 147 | + |
| 148 | + def forward(self, x, range): |
| 149 | + return x.expand(1, 1, 8, range.shape[0], 256) |
| 150 | + |
| 151 | + m = M() |
| 152 | + args = (torch.rand((1, 1, 1, 3, 256)), torch.arange(3)) |
| 153 | + dynamic_shapes = ({3: Dim("bs")}, {0: Dim("bs")}) |
| 154 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 155 | + out1 = ep.module()(*args) |
| 156 | + print(ep) |
| 157 | + replace_dynamic_expand_with_xla_op(ep.graph_module) |
| 158 | + print(ep) |
| 159 | + ep.graph_module.recompile() |
| 160 | + self.assertTrue('xla.dynamic_expand' in ep.graph_module.code) |
| 161 | + out2 = ep.module()(*args) |
| 162 | + self.assertTrue(torch.allclose(out1, out2)) |
| 163 | + |
| 164 | + def test_layer_norm_decomp(self): |
| 165 | + |
| 166 | + class M(torch.nn.Module): |
| 167 | + |
| 168 | + def forward(self, x, dim, weight, bias, eps): |
| 169 | + return torch.ops.aten.native_layer_norm.default(x, dim, weight, bias, |
| 170 | + eps)[0] |
| 171 | + |
| 172 | + args = (torch.rand(10, 197, |
| 173 | + 768), [768], torch.rand(768), torch.rand(768), 1e-12) |
| 174 | + dynamic_shapes = ({0: Dim("bs")}, [None], None, None, None) |
| 175 | + m = M().eval() |
| 176 | + before_decomp_out = m(*args) |
| 177 | + after_decomp_out = native_layer_norm_impl(*args) |
| 178 | + self.assertTrue( |
| 179 | + torch.allclose(before_decomp_out, after_decomp_out, atol=1e-6)) |
| 180 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 181 | + decompose_dynamic_native_layer_norm(ep.graph_module) |
| 182 | + ep.graph_module.recompile() |
| 183 | + self.assertFalse('aten.native_layer_norm' in ep.graph_module.code) |
| 184 | + after_decomp_out_2 = ep.module()(*args) |
| 185 | + self.assertTrue( |
| 186 | + torch.allclose(before_decomp_out, after_decomp_out_2, atol=1e-6)) |
| 187 | + |
| 188 | + def test_group_norm_to_layer_norm(self): |
| 189 | + |
| 190 | + class M(torch.nn.Module): |
| 191 | + |
| 192 | + def forward(self, x, weight, bias, N, C, HxW, group, eps): |
| 193 | + return torch.ops.aten.native_group_norm.default(x, weight, bias, N, C, |
| 194 | + HxW, group, eps)[0] |
| 195 | + |
| 196 | + class M2(torch.nn.Module): |
| 197 | + |
| 198 | + def __init__(self): |
| 199 | + super().__init__() |
| 200 | + # self.conv = torch.nn.Conv1d(1, 512, 10, stride=5) |
| 201 | + self.layer_norm = torch.nn.GroupNorm( |
| 202 | + num_groups=512, num_channels=512, affine=True) |
| 203 | + |
| 204 | + def forward(self, x): |
| 205 | + return self.layer_norm(x)[0] |
| 206 | + |
| 207 | + args = (torch.rand(10, 512, 159), torch.rand(512), torch.rand(512), 10, 512, |
| 208 | + 159, 512, 1e-12) |
| 209 | + export_args = (torch.rand(10, 512, 159),) |
| 210 | + dynamic_shapes = ({0: Dim("bs")},) |
| 211 | + m = M().eval() |
| 212 | + before_decomp_out = m(*args) |
| 213 | + after_decomp_out = native_group_norm_impl(*args) |
| 214 | + self.assertTrue( |
| 215 | + torch.allclose(before_decomp_out, after_decomp_out, atol=1e-6)) |
| 216 | + # Test export path with a different to workaround an export issue. |
| 217 | + m2 = M2().eval() |
| 218 | + ep = export(m2, export_args, dynamic_shapes=dynamic_shapes) |
| 219 | + before_decomp_ep_out = m2(*export_args) |
| 220 | + decompose_dynamic_native_group_norm(ep.graph_module) |
| 221 | + ep.graph_module.recompile() |
| 222 | + self.assertFalse('aten.native_group_norm' in ep.graph_module.code) |
| 223 | + after_decomp_ep_out = ep.module()(*export_args) |
| 224 | + # print(before_decomp_ep_out - after_decomp_ep_out) |
| 225 | + self.assertTrue( |
| 226 | + torch.allclose(before_decomp_ep_out, after_decomp_ep_out, atol=1e-6)) |
| 227 | + |
| 228 | + def test_dynamic_unsqueeze_to_view(self): |
| 229 | + |
| 230 | + class M(torch.nn.Module): |
| 231 | + |
| 232 | + def forward(self, x): |
| 233 | + return torch.ops.aten.unsqueeze.default(x, 2) |
| 234 | + |
| 235 | + args = (torch.rand((1, 1, 3, 256)),) |
| 236 | + dynamic_shapes = ({2: Dim("dim")},) |
| 237 | + m = M().eval() |
| 238 | + ep = export(m, args, dynamic_shapes=dynamic_shapes) |
| 239 | + out1 = ep.module()(*args) |
| 240 | + dynamic_unsqueeze_to_view(ep.graph_module) |
| 241 | + ep.graph_module.recompile() |
| 242 | + self.assertFalse('aten.unsqueeze' in ep.graph_module.code) |
| 243 | + self.assertTrue('aten.view' in ep.graph_module.code) |
| 244 | + out2 = ep.module()(*args) |
| 245 | + self.assertTrue(torch.allclose(out1, out2)) |
| 246 | + |
| 247 | + |
| 248 | +if __name__ == "__main__": |
| 249 | + test = unittest.main() |
| 250 | + sys.exit(0 if test.result.wasSuccessful() else 1) |
0 commit comments