-
Couldn't load subscription status.
- Fork 560
Add alternative dynamo backend #8893
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits Select commit Hold shift + click to select a range
9e0999a Add alternative dynamo backend
qihqi 570b8e1 Dynamo backend update
qihqi 83454a6 update messages
qihqi 97811d5 yapf
qihqi 33a6004 comments
qihqi 889b3a0 yafp
qihqi 4ca0a4c fix tests
qihqi 426e1b2 nits
qihqi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| import functools | ||
| from typing import Any | ||
| import torch | ||
| from torch.utils import _pytree as pytree | ||
| from torch_xla.core import xla_builder as xb | ||
| import torch_xla | ||
| | ||
| from torch._dynamo.backends.common import aot_autograd | ||
| from functorch.compile import make_boxed_func | ||
| | ||
| | ||
| def _dynamo_backend(model: torch.fx.GraphModule, sample_args: Any): | ||
| """A dynamo backend that compiles a FX graph to HLO using JAX and torchax. | ||
| | ||
| It takes FX graph as input and returns a compiled PyTorch function. The FX graph | ||
| is traced into a JAX function using torchax, and the JAX function is lowered to HLO. | ||
| | ||
| Args: | ||
| model: the graph to be compiled | ||
| sample_args: a tuple or list of sample inputs. I.e. model(*sample_args) produces | ||
| the model output | ||
| | ||
| Returns: | ||
| Another callable f such that f(*sample_inputs) computes the same thing as model. | ||
| """ | ||
| | ||
| try: | ||
| import torchax.interop | ||
| from torchax.export import JaxInterpreter | ||
| import jax | ||
| except ImportError: | ||
| print('To use this dynamo backend, please install torchax') | ||
| raise | ||
| | ||
| jax.config.update("jax_enable_x64", True) | ||
| env = torchax.default_env() | ||
| xla_device = torch_xla.device() | ||
| | ||
| def run_jax(*args, initial_rng_key): | ||
| args_t = torchax.interop.torch_view(args) | ||
| env.manual_seed(initial_rng_key) | ||
| with env: | ||
| res = model(*args_t) | ||
| return torchax.interop.jax_view(res) | ||
| | ||
| initial_rng_key = torch.tensor(0, device=xla_device, dtype=torch.uint32) | ||
| computation = xb.jax_func_to_xla_computation( | ||
| run_jax, sample_args, {'initial_rng_key': initial_rng_key}, 'dynamo_jax') | ||
| | ||
| def equivalent(*args, **kwargs): | ||
| kwargs['initial_rng_key'] = torch.randint( | ||
qihqi marked this conversation as resolved. Show resolved Hide resolved | ||
| 0, 2**32, (), dtype=torch.uint32, device=xla_device) | ||
| flattened, _ = pytree.tree_flatten((args, kwargs)) | ||
| res = computation(flattened) | ||
| if not isinstance(res, (list, tuple)): | ||
| return (res,) | ||
| return res | ||
| | ||
| return make_boxed_func(equivalent) | ||
| | ||
| | ||
| def dynamo_backend(fx, args): | ||
| from functorch.compile import aot_function | ||
| return aot_function(fx, fw_compiler=_dynamo_backend) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.