Skip to content

Commit 47ad351

Browse files
tugsbayasgalanpytorchmergebot
authored andcommitted
[DRAFT] INitial version of sticky export (pytorch#151047)
Summary: This is to make torchnative demos and benchmarking real models more simple by not requiring ppl to find example inputs first. Test Plan: CI Differential Revision: D72815584 Pull Request resolved: pytorch#151047 Approved by: https://github.com/zhxchen17
1 parent bd19173 commit 47ad351

File tree

2 files changed

+145
-1
lines changed

2 files changed

+145
-1
lines changed

test/export/test_experimental.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Owner(s): ["oncall: export"]
22
# flake8: noqa
3+
import types
34
import unittest
45
from typing import Dict, List, Tuple
56

@@ -9,7 +10,7 @@
910
from torch._functorch.aot_autograd import aot_export_module
1011
from torch.export import export, export_for_training
1112
from torch.export._trace import _convert_ts_to_export_experimental
12-
from torch.export.experimental import _export_forward_backward
13+
from torch.export.experimental import _export_forward_backward, _sticky_export
1314
from torch.export.graph_signature import OutputKind
1415
from torch.testing import FileCheck
1516

@@ -333,6 +334,111 @@ def forward(self, x, label):
333334
OutputKind.LOSS_OUTPUT,
334335
)
335336

337+
def test_sticky_export(self):
338+
class Model(torch.nn.Module):
339+
def __init__(self):
340+
super().__init__()
341+
self.linear = torch.nn.Linear(4, 4)
342+
343+
def forward(self, x):
344+
return self.linear(x)
345+
346+
class Pipeline:
347+
def __init__(self, model):
348+
self.model = model
349+
350+
def generate(self, *args, **kwargs):
351+
return self.model(*args, **kwargs)
352+
353+
inp = torch.randn(4, 4)
354+
355+
p = Pipeline(Model())
356+
orig_forward = p.model.forward
357+
p.model.forward = _sticky_export(p.model.forward)
358+
res = p.generate(inp)
359+
360+
p.model.forward = orig_forward
361+
res2 = p.generate(inp)
362+
self.assertTrue(torch.allclose(res, res2))
363+
364+
def test_sticky_export_dynamic(self):
365+
class Model(torch.nn.Module):
366+
def __init__(self):
367+
super().__init__()
368+
self.linear = torch.nn.Linear(4, 4)
369+
370+
def forward(self, x):
371+
if x.shape[0] < 5:
372+
return self.linear(x)
373+
return x.sin()
374+
375+
class Pipeline:
376+
def __init__(self, model):
377+
self.model = model
378+
379+
def generate(self, *args, **kwargs):
380+
return self.model(*args, **kwargs)
381+
382+
inp = torch.randn(4, 4)
383+
384+
def callback(*args, **kwargs):
385+
# I think it is bit weird to use the forward arg name here, so
386+
# lets just use ShapeCollections
387+
388+
flat_args, _ = torch.utils._pytree.tree_flatten((args, kwargs))
389+
collections = torch.export.ShapesCollection()
390+
for arg in flat_args:
391+
if isinstance(arg, torch.Tensor):
392+
collections[arg] = {
393+
i: torch.export.Dim.AUTO for i in range(len(arg.shape))
394+
}
395+
return collections
396+
397+
p = Pipeline(Model())
398+
p.model.forward = _sticky_export(
399+
p.model.forward, dynamic_shapes_callback=callback
400+
)
401+
_ = p.generate(inp)
402+
self.assertExpectedInline(
403+
str(p.model.forward._exported_artifact.code).strip(),
404+
"""\
405+
def forward(self, x):
406+
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
407+
linear_weight = self.linear.weight
408+
linear_bias = self.linear.bias
409+
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
410+
return pytree.tree_unflatten((linear,), self._out_spec)""",
411+
)
412+
413+
def test_sticky_export_nested_inp(self):
414+
class Model(torch.nn.Module):
415+
def __init__(self):
416+
super().__init__()
417+
self.linear = torch.nn.Linear(4, 4)
418+
419+
def forward(self, *, inputs):
420+
return self.linear(inputs[0]) + self.linear(inputs[1])
421+
422+
class Pipeline:
423+
def __init__(self, model):
424+
self.model = model
425+
426+
def generate(self, *, input_tensor, input_tensor2):
427+
inputs = [input_tensor, input_tensor2]
428+
return self.model(inputs=inputs)
429+
430+
inp = torch.randn(4, 4)
431+
inp2 = torch.randn(4, 4)
432+
433+
p = Pipeline(Model())
434+
orig_forward = p.model.forward
435+
p.model.forward = _sticky_export(p.model.forward)
436+
res = p.generate(input_tensor=inp, input_tensor2=inp2)
437+
438+
p.model.forward = orig_forward
439+
res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
440+
self.assertTrue(torch.allclose(res, res2))
441+
336442

337443
if __name__ == "__main__":
338444
run_tests()

torch/export/experimental/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import copy
2+
import functools
3+
import types
24
import typing
35

46
import torch
@@ -67,3 +69,39 @@ def _export_forward_backward(
6769
_remove_detach_pass(gm, new_graph_signature)
6870

6971
return ep._update(gm, new_graph_signature)
72+
73+
74+
@typing.no_type_check
75+
def _sticky_export(forward_func, dynamic_shapes_callback=None):
76+
"""
77+
Lazily export the model on first forward call.
78+
Usage:
79+
model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback)
80+
"""
81+
model = forward_func.__self__
82+
original_forward = forward_func.__func__
83+
84+
@functools.wraps(forward_func)
85+
def wrapper(*args, **kwargs):
86+
# Unpatch forward to avoid recursion during export
87+
model.forward = types.MethodType(original_forward, model)
88+
89+
dynamic_shapes_spec = None
90+
if dynamic_shapes_callback:
91+
dynamic_shapes_spec = dynamic_shapes_callback(*args, **kwargs)
92+
93+
try:
94+
exported = torch.export.export(
95+
model,
96+
args,
97+
kwargs,
98+
dynamic_shapes=dynamic_shapes_spec,
99+
).module()
100+
wrapper._exported_artifact = exported
101+
finally:
102+
# Restore the wrapper after export
103+
model.forward = wrapper
104+
105+
return exported(*args, **kwargs)
106+
107+
return wrapper

0 commit comments

Comments
 (0)