|
1 | 1 | # Owner(s): ["oncall: export"] |
2 | 2 | # flake8: noqa |
| 3 | +import types |
3 | 4 | import unittest |
4 | 5 | from typing import Dict, List, Tuple |
5 | 6 |
|
|
9 | 10 | from torch._functorch.aot_autograd import aot_export_module |
10 | 11 | from torch.export import export, export_for_training |
11 | 12 | from torch.export._trace import _convert_ts_to_export_experimental |
12 | | -from torch.export.experimental import _export_forward_backward |
| 13 | +from torch.export.experimental import _export_forward_backward, _sticky_export |
13 | 14 | from torch.export.graph_signature import OutputKind |
14 | 15 | from torch.testing import FileCheck |
15 | 16 |
|
@@ -333,6 +334,111 @@ def forward(self, x, label): |
333 | 334 | OutputKind.LOSS_OUTPUT, |
334 | 335 | ) |
335 | 336 |
|
| 337 | + def test_sticky_export(self): |
| 338 | + class Model(torch.nn.Module): |
| 339 | + def __init__(self): |
| 340 | + super().__init__() |
| 341 | + self.linear = torch.nn.Linear(4, 4) |
| 342 | + |
| 343 | + def forward(self, x): |
| 344 | + return self.linear(x) |
| 345 | + |
| 346 | + class Pipeline: |
| 347 | + def __init__(self, model): |
| 348 | + self.model = model |
| 349 | + |
| 350 | + def generate(self, *args, **kwargs): |
| 351 | + return self.model(*args, **kwargs) |
| 352 | + |
| 353 | + inp = torch.randn(4, 4) |
| 354 | + |
| 355 | + p = Pipeline(Model()) |
| 356 | + orig_forward = p.model.forward |
| 357 | + p.model.forward = _sticky_export(p.model.forward) |
| 358 | + res = p.generate(inp) |
| 359 | + |
| 360 | + p.model.forward = orig_forward |
| 361 | + res2 = p.generate(inp) |
| 362 | + self.assertTrue(torch.allclose(res, res2)) |
| 363 | + |
| 364 | + def test_sticky_export_dynamic(self): |
| 365 | + class Model(torch.nn.Module): |
| 366 | + def __init__(self): |
| 367 | + super().__init__() |
| 368 | + self.linear = torch.nn.Linear(4, 4) |
| 369 | + |
| 370 | + def forward(self, x): |
| 371 | + if x.shape[0] < 5: |
| 372 | + return self.linear(x) |
| 373 | + return x.sin() |
| 374 | + |
| 375 | + class Pipeline: |
| 376 | + def __init__(self, model): |
| 377 | + self.model = model |
| 378 | + |
| 379 | + def generate(self, *args, **kwargs): |
| 380 | + return self.model(*args, **kwargs) |
| 381 | + |
| 382 | + inp = torch.randn(4, 4) |
| 383 | + |
| 384 | + def callback(*args, **kwargs): |
| 385 | + # I think it is bit weird to use the forward arg name here, so |
| 386 | + # lets just use ShapeCollections |
| 387 | + |
| 388 | + flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs)) |
| 389 | + collections = torch.export.ShapesCollection() |
| 390 | + for arg in flat_args: |
| 391 | + if isinstance(arg, torch.Tensor): |
| 392 | + collections[arg] = { |
| 393 | + i: torch.export.Dim.AUTO for i in range(len(arg.shape)) |
| 394 | + } |
| 395 | + return collections |
| 396 | + |
| 397 | + p = Pipeline(Model()) |
| 398 | + p.model.forward = _sticky_export( |
| 399 | + p.model.forward, dynamic_shapes_callback=callback |
| 400 | + ) |
| 401 | + _ = p.generate(inp) |
| 402 | + self.assertExpectedInline( |
| 403 | + str(p.model.forward._exported_artifact.code).strip(), |
| 404 | + """\ |
| 405 | +def forward(self, x): |
| 406 | + x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) |
| 407 | + linear_weight = self.linear.weight |
| 408 | + linear_bias = self.linear.bias |
| 409 | + linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None |
| 410 | + return pytree.tree_unflatten((linear,), self._out_spec)""", |
| 411 | + ) |
| 412 | + |
| 413 | + def test_sticky_export_nested_inp(self): |
| 414 | + class Model(torch.nn.Module): |
| 415 | + def __init__(self): |
| 416 | + super().__init__() |
| 417 | + self.linear = torch.nn.Linear(4, 4) |
| 418 | + |
| 419 | + def forward(self, *, inputs): |
| 420 | + return self.linear(inputs[0]) + self.linear(inputs[1]) |
| 421 | + |
| 422 | + class Pipeline: |
| 423 | + def __init__(self, model): |
| 424 | + self.model = model |
| 425 | + |
| 426 | + def generate(self, *, input_tensor, input_tensor2): |
| 427 | + inputs = [input_tensor, input_tensor2] |
| 428 | + return self.model(inputs=inputs) |
| 429 | + |
| 430 | + inp = torch.randn(4, 4) |
| 431 | + inp2 = torch.randn(4, 4) |
| 432 | + |
| 433 | + p = Pipeline(Model()) |
| 434 | + orig_forward = p.model.forward |
| 435 | + p.model.forward = _sticky_export(p.model.forward) |
| 436 | + res = p.generate(input_tensor=inp, input_tensor2=inp2) |
| 437 | + |
| 438 | + p.model.forward = orig_forward |
| 439 | + res2 = p.generate(input_tensor=inp, input_tensor2=inp2) |
| 440 | + self.assertTrue(torch.allclose(res, res2)) |
| 441 | + |
336 | 442 |
|
337 | 443 | if __name__ == "__main__": |
338 | 444 | run_tests() |
0 commit comments