
Commit 2dd33c7

msaroufim authored and pytorchmergebot committed

Docs for torch.compile and functorch (pytorch#101881)

Thanks @zou3519. Pull Request resolved: pytorch#101881. Approved by: https://github.com/eellison, https://github.com/zou3519

1 parent 81c181d commit 2dd33c7

File tree: 2 files changed, +79 −0 lines
Lines changed: 78 additions & 0 deletions
torch.func interaction with torch.compile
==============================================

So you want to use a `torch.func` ("functorch") transform (like `vmap`, `grad`, `jacrev`, etc.) with `torch.compile`. Here's a guide to what works today, what doesn't, and how to work around it.

Applying a `torch.func` transform to a `torch.compile`'d function
-----------------------------------------------------------------

This doesn't work yet and is being tracked by `https://github.com/pytorch/pytorch/issues/100320`.
.. code:: python

    import torch

    @torch.compile
    def f(x):
        return torch.sin(x)

    def g(x):
        return torch.vmap(f)(x)

    x = torch.randn(2, 3)
    g(x)
As a workaround, please put the `torch.compile` outside of the `torch.func` transform:

.. code:: python

    import torch

    def f(x):
        return torch.sin(x)

    @torch.compile
    def g(x):
        return torch.vmap(f)(x)

    x = torch.randn(2, 3)
    g(x)
Doesn't work (PT 2.0): calling a `torch.func` transform inside of a `torch.compile`'d function
------------------------------------------------------------------------------------------------

.. code:: python

    import torch

    @torch.compile
    def f(x):
        return torch.vmap(torch.sum)(x)

    x = torch.randn(2, 3)
    f(x)

This doesn't work yet. Please see the workaround in the next section.
Workaround: use `torch._dynamo.allow_in_graph`
----------------------------------------------

`allow_in_graph` is an escape hatch. If your code does not work with `torch.compile`, which introspects Python bytecode, but you believe it will work via a symbolic tracing approach (like `jax.jit`), then use `allow_in_graph`.

By using `allow_in_graph` to annotate a function, you promise PyTorch several things that we are unable to completely verify:

- Your function is pure. That is, all outputs depend only on the inputs and do not depend on any captured Tensors.
- Your function is functional. That is, it does not mutate any state. This may be relaxed; in practice we support functions that appear to be functional from the outside: they may have in-place PyTorch operations, but may not mutate global state or inputs to the function.
- Your function does not raise data-dependent errors.
.. code:: python

    import torch

    @torch.compile
    def f(x):
        return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x)

    x = torch.randn(2, 3)
    f(x)

A common pitfall is using `allow_in_graph` to annotate a function that invokes an `nn.Module`, because the outputs then depend on the parameters of the `nn.Module`, not just on the inputs. To actually get this to work, use `torch.func.functional_call` to extract the module state.
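As an illustration, here is a minimal sketch of that pattern. The `Linear` module, the `loss` function, and the tensor shapes are hypothetical, chosen only for this example: `torch.func.functional_call` runs a module against an explicitly passed state dict, so the resulting function depends only on its arguments and can be treated as pure.

.. code:: python

    import torch
    from torch.func import functional_call, grad

    # Hypothetical module for illustration; any nn.Module works the same way.
    model = torch.nn.Linear(3, 1)

    def loss(params, x):
        # functional_call runs `model` with the explicitly passed state,
        # so the output depends only on the arguments (no captured Tensors).
        return functional_call(model, params, (x,)).sum()

    # Extract the module state as a plain dict of Tensors.
    params = dict(model.named_parameters())

    x = torch.randn(2, 3)
    grads = grad(loss)(params, x)  # gradients w.r.t. the extracted parameters

Because `loss` no longer closes over the module's parameters, it satisfies the purity promise above and can be combined with `torch.func` transforms (and, with care, `allow_in_graph`).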

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@ Features described in this documentation are classified by release status:

       compile/custom-backends
       compile/deep-dive
       compile/performance-dashboard
    +  compile/torchfunc-and-torchcompile
       ir

    .. toctree::
