|
| 1 | +torch.func interaction with torch.compile |
| 2 | +============================================== |
| 3 | + |
| 4 | +So you want to use a `torch.func` ("functorch") transform (like `vmap`, `grad`, `jacrev`, etc) with `torch.compile`. Here's a guide to what works today, what doesn't, and how to work around it. |
| 5 | + |
| 6 | +Applying a `torch.func` transform to a `torch.compile`'d function |
| 7 | +----------------------------------------------------------------- |
| 8 | + |
| 9 | +This doesn't work and is being tracked by `https://github.com/pytorch/pytorch/issues/100320`. |
| 10 | + |
| 11 | +.. code:: python |
| 12 | +
|
| 13 | + import torch |
| 14 | +
|
| 15 | + @torch.compile |
| 16 | + def f(x): |
| 17 | + return torch.sin(x) |
| 18 | +
|
| 19 | + def g(x): |
| 20 | + return torch.grad(f)(x) |
| 21 | +
|
| 22 | + x = torch.randn(2, 3) |
| 23 | + g(x) |
| 24 | +
|
| 25 | +As a workaround, please put the `torch.compile` outside of the `torch.func` transform: |
| 26 | + |
| 27 | +.. code:: python |
| 28 | +
|
| 29 | + import torch |
| 30 | +
|
| 31 | + def f(x): |
| 32 | + return torch.sin(x) |
| 33 | +
|
| 34 | + @torch.compile |
| 35 | + def g(x): |
| 36 | + return torch.vmap(f)(x) |
| 37 | +
|
| 38 | + x = torch.randn(2, 3) |
| 39 | + g(x) |
| 40 | +
|
| 41 | +Doesn't work (PT 2.0): calling a `torch.func` transform inside of a `torch.compile`'ed function |
| 42 | +------------------------------------------------------------------------------------------------ |
| 43 | + |
| 44 | +.. code:: python |
| 45 | +
|
| 46 | + import torch |
| 47 | +
|
| 48 | + @torch.compile |
| 49 | + def f(x): |
| 50 | + return torch.vmap(torch.sum)(x) |
| 51 | +
|
| 52 | + x = torch.randn(2, 3) |
| 53 | + f(x) |
| 54 | +
|
| 55 | +This doesn't work yet. Please see the workaround (the next section). |
| 56 | + |
| 57 | +Workaround: use `torch._dynamo.allow_in_graph` |
| 58 | +---------------------------------------------- |
| 59 | + |
| 60 | +`allow_in_graph` is an escape hatch. If your code does not work with `torch.compile`, which introspects Python bytecode, but you believe it will work via a symbolic tracing approach (like `jax.jit`), then use `allow_in_graph`. |
| 61 | + |
| 62 | +By using `allow_in_graph` to annotate a function, you promise PyTorch a couple of things that we are unable to completely verify: |
| 63 | +- Your function is pure. That is, all outputs only depend on the inputs and do not depend on any captured Tensors. |
| 64 | +- Your function is functional. That is, it does not mutate any state. This may be relaxed; we actually support functions that appear to be functional from the outside: they may have in-place PyTorch operations, but may not mutate global state or inputs to the function. |
| 65 | +- Your function does not raise data-dependent errors. |
| 66 | + |
| 67 | +.. code:: python |
| 68 | +
|
| 69 | + import torch |
| 70 | +
|
| 71 | + @torch.compile |
| 72 | + def f(x): |
| 73 | + return torch._dynamo.allow_in_graph(torch.vmap(torch.sum))(x) |
| 74 | +
|
| 75 | + x = torch.randn(2, 3) |
| 76 | + f(x) |
| 77 | +
|
| 78 | +A common pitfall is using `allow_in_graph` to annotate a function that invokes an `nn.Module`. This is because the outputs now depend on the parameters of the `nn.Module`. To actually get this to work, use `torch.func.functional_call` to extract the module state. |
0 commit comments