-
Couldn't load subscription status.
- Fork 9
Use scan and hostoffloading for llama model #123
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
7 commits Select commit Hold shift + click to select a range
157d30f Support run trainer locally
zpcore a25f2f0 nit
zpcore b4a92df update docker command
zpcore 3d06daa initial runnable version
zpcore 470553e support hostoffloading
zpcore 4630cbd update config useless
zpcore 22f20da clean up PR for quick test
zpcore File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| from typing import Sequence | ||
| import torch.fx as fx | ||
| import torch | ||
| import torch_xla | ||
| from torch.utils._pytree import tree_iter | ||
| | ||
| from functorch.compile import aot_function, make_boxed_func # type:ignore | ||
| from .remat_all import remat_all_partition_fn | ||
| | ||
| | ||
| @torch.library.custom_op("xla::offload_name", mutates_args=()) | ||
| def offload_name(t: torch.Tensor, name: str) -> torch.Tensor: | ||
| """ | ||
| `offload_name` is an identity function that associates the input | ||
| tensor with `name`. It is primarily useful in conjunction with | ||
| `remat_all_and_offload_these_inputs`, which will rematerialize | ||
| intermediate activations and also offload inputs with the specified | ||
| names to host memory, moving them back during the backward pass. | ||
| """ | ||
| if t is None: | ||
| return None | ||
| return t.clone() | ||
| | ||
| | ||
| @offload_name.register_fake | ||
| def _(t: torch.Tensor, name: str) -> torch.Tensor: | ||
| if t is None: | ||
| return None | ||
| return torch.empty_like(t) | ||
| | ||
| | ||
| def offload_name_backward(ctx, grad): | ||
| return grad, None | ||
| | ||
| | ||
| offload_name.register_autograd(offload_name_backward) | ||
| | ||
| | ||
| def remat_all_and_offload_these_inputs( | ||
| joint_module: fx.GraphModule, | ||
| _joint_inputs, | ||
| *, | ||
| num_fwd_outputs, | ||
| names_to_offload: Sequence[str], | ||
| ): | ||
| """ | ||
| `remat_all_and_offload_these_inputs` will rematerialize (recompute) all | ||
| intermediate activations in `joint_module`, and offload inputs with the | ||
| specified names to host memory, moving them back during the backward pass. | ||
| It transforms the joint graph into separate forward and backward graphs. | ||
| """ | ||
| input_device = next(iter(tree_iter(_joint_inputs))).device | ||
| fwd, bwd = remat_all_partition_fn( | ||
| joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs) | ||
| with torch.device(input_device): | ||
| fw_example_args = make_arguments(fwd) | ||
| bw_example_args = make_arguments(bwd) | ||
| | ||
| fw_name_in_output_indices = get_name_in_output_indices(fwd) | ||
| bw_name_in_input_names = get_name_in_input_names(bwd) | ||
| | ||
| for name in names_to_offload: | ||
| print(f"Going to offload {name}") | ||
| assert name in fw_name_in_output_indices | ||
| assert name in bw_name_in_input_names | ||
| | ||
| # print("fw_name_in_output_indices", fw_name_in_output_indices) | ||
| # print("bw_name_in_input_names", bw_name_in_input_names) | ||
| | ||
| with torch.no_grad(): | ||
| | ||
| def forward(**kwargs): | ||
| import pdb | ||
| try: | ||
| out = fwd(**kwargs) | ||
| indices_to_offload = set( | ||
| [fw_name_in_output_indices[name] for name in names_to_offload]) | ||
| return tuple( | ||
| torch.ops.xla.place_to_host(v) if i in # type:ignore | ||
| indices_to_offload else v for i, v in enumerate(out)) | ||
| except Exception: | ||
| pdb.post_mortem() | ||
| | ||
| def backward(**kwargs): | ||
| # print(f"Backward got {len(kwargs)} arguments:") | ||
| for k, v in kwargs.items(): | ||
| print(f"Arg {k}: {v.shape if v is not None else 'None'}") | ||
| arguments_to_move_back = set( | ||
| [bw_name_in_input_names[name] for name in names_to_offload]) | ||
| kwargs = { | ||
| k: torch.ops.xla.place_to_device(v) # type: ignore | ||
| if k in arguments_to_move_back else v for k, v in kwargs.items() | ||
| } | ||
| import pdb | ||
| try: | ||
| values = bwd(**kwargs) | ||
| print(f"Backward will return {len(values)} values:") | ||
| for i, v in enumerate(values): | ||
| print(f"Arg {i}: {v.shape if v is not None else 'None'}") | ||
| return values | ||
| except Exception: | ||
| pdb.post_mortem() | ||
| | ||
| # Use AOTAutograd to retrace forward and backward, thus incorporating | ||
| # the offloading ops. | ||
| graph = [None] | ||
| | ||
| def get_graph(g, _): | ||
| graph[0] = g | ||
| return make_boxed_func(g) | ||
| | ||
| _ = aot_function(forward, fw_compiler=get_graph)(**fw_example_args) | ||
| aot_forward = graph[0] | ||
| | ||
| _ = aot_function(backward, fw_compiler=get_graph)(**bw_example_args) | ||
| aot_backward = graph[0] | ||
| | ||
| return aot_forward, aot_backward | ||
| | ||
| | ||
| def make_arguments(gm: fx.GraphModule): | ||
| """ | ||
| Given a graph module, `make_arguments` returns a dictionary of example inputs | ||
| that can be used as keyward arguments to call the graph module. | ||
| """ | ||
| example_args = {} | ||
| for node in gm.graph.nodes: | ||
| if node.op != 'placeholder': | ||
| continue | ||
| if 'tensor_meta' in node.meta: | ||
| tensor_meta = node.meta['tensor_meta'] | ||
| tensor = torch.zeros( | ||
| tensor_meta.shape, | ||
| dtype=tensor_meta.dtype, | ||
| requires_grad=tensor_meta.requires_grad) | ||
| example_args[node.name] = tensor | ||
| return example_args | ||
| | ||
| | ||
| def get_named_nodes(gm: torch.fx.GraphModule): | ||
| named_nodes = {} | ||
| | ||
| for node in gm.graph.nodes: | ||
| if node.op == "call_function": | ||
| if hasattr(node.target, "name"): | ||
| if node.target.name() == offload_name._qualname: # type: ignore | ||
| named_nodes[node.args[0]] = node.args[1] | ||
| | ||
| return named_nodes | ||
| | ||
| | ||
| def get_name_in_output_indices(gm: torch.fx.GraphModule): | ||
| named_nodes = get_named_nodes(gm) | ||
| name_in_output_indices = {} | ||
| | ||
| for node in gm.graph.nodes: | ||
| if node.op == "output": | ||
| assert len(node.args) <= 1 | ||
| if len(node.args) == 0: | ||
| continue | ||
| for i, arg in enumerate(next(iter(node.args))): # type: ignore | ||
| if arg in named_nodes: | ||
| name_in_output_indices[named_nodes[arg]] = i | ||
| | ||
| return name_in_output_indices | ||
| | ||
| | ||
| def get_name_in_input_names(gm: torch.fx.GraphModule): | ||
| named_nodes = get_named_nodes(gm) | ||
| name_in_input_names = {} | ||
| | ||
| for node in gm.graph.nodes: | ||
| if node.op == "placeholder": | ||
| if node in named_nodes: | ||
| name_in_input_names[named_nodes[node]] = node.target | ||
| | ||
| return name_in_input_names |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| import torch.fx | ||
| import torch._functorch.config | ||
| from functorch.compile import min_cut_rematerialization_partition | ||
| | ||
| from contextlib import contextmanager | ||
| | ||
| | ||
| @contextmanager | ||
| def remat_all_config(): | ||
| # Backup existing config values | ||
| backup = { | ||
| "activation_memory_budget": | ||
| torch._functorch.config.activation_memory_budget, | ||
| "aggressive_recomputation": | ||
| torch._functorch.config.aggressive_recomputation, | ||
| "recompute_views": | ||
| torch._functorch.config.recompute_views, | ||
| "ban_recompute_reductions": | ||
| torch._functorch.config.ban_recompute_reductions, | ||
| "ban_recompute_not_in_allowlist": | ||
| torch._functorch.config.ban_recompute_not_in_allowlist, | ||
| "ban_recompute_materialized_backward": | ||
| torch._functorch.config.ban_recompute_materialized_backward, | ||
| "ban_recompute_long_fusible_chains": | ||
| torch._functorch.config.ban_recompute_long_fusible_chains, | ||
| "ban_recompute_used_far_apart": | ||
| torch._functorch.config.ban_recompute_used_far_apart, | ||
| } | ||
| | ||
| try: | ||
| # Set activation_memory_budget to zero to force the min cut partitioner | ||
| # to recompute instead of saving. Also don't ban the recomputing of any ops. | ||
| torch._functorch.config.activation_memory_budget = 0.0 | ||
| torch._functorch.config.aggressive_recomputation = True | ||
| torch._functorch.config.recompute_views = True | ||
| torch._functorch.config.ban_recompute_reductions = False | ||
| torch._functorch.config.ban_recompute_not_in_allowlist = False | ||
| torch._functorch.config.ban_recompute_materialized_backward = False | ||
| torch._functorch.config.ban_recompute_long_fusible_chains = False | ||
| torch._functorch.config.ban_recompute_used_far_apart = False | ||
| yield | ||
| | ||
| finally: | ||
| # Restore the original config values | ||
| torch._functorch.config.activation_memory_budget = backup[ | ||
| "activation_memory_budget"] | ||
| torch._functorch.config.aggressive_recomputation = backup[ | ||
| "aggressive_recomputation"] | ||
| torch._functorch.config.recompute_views = backup["recompute_views"] | ||
| torch._functorch.config.ban_recompute_reductions = backup[ | ||
| "ban_recompute_reductions"] | ||
| torch._functorch.config.ban_recompute_not_in_allowlist = backup[ | ||
| "ban_recompute_not_in_allowlist"] | ||
| torch._functorch.config.ban_recompute_materialized_backward = backup[ | ||
| "ban_recompute_materialized_backward"] | ||
| torch._functorch.config.ban_recompute_long_fusible_chains = backup[ | ||
| "ban_recompute_long_fusible_chains"] | ||
| torch._functorch.config.ban_recompute_used_far_apart = backup[ | ||
| "ban_recompute_used_far_apart"] | ||
| | ||
| | ||
| def remat_all_partition_fn( | ||
| joint_module: torch.fx.GraphModule, | ||
| _joint_inputs, | ||
| *, | ||
| num_fwd_outputs, | ||
| ): | ||
| """ | ||
| remat_all_partition_fn is a graph partition function that closely matches the | ||
| default behavior of `torch.utils.checkpoint`, which is to discard all intermediate | ||
| activations and recompute all of them during the backward pass. | ||
| """ | ||
| with remat_all_config(): | ||
| return min_cut_rematerialization_partition( | ||
| joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move to default yaml file