7 changes: 4 additions & 3 deletions torchprime/hf_models/configs/default.yaml
@@ -18,8 +18,7 @@ train_script:

# If minibatch is False, this should be set to the global batch size.
# If minibatch is True, this should be set to the per host batch size.
-per_device_train_batch_size: 256
+per_device_train_batch_size: 1024
do_train: true
output_dir: "test-clm"
overwrite_output_dir: true
@@ -35,5 +34,7 @@ train_script:
torch_dtype: "bfloat16"
dataloader_drop_last: true
flash_attention: true
-max_steps: 50
+max_steps: 500
seed: 42

+ignore_data_skip: true
3 changes: 2 additions & 1 deletion torchprime/launcher/Dockerfile
@@ -70,7 +70,8 @@ RUN if [ "$USE_TRANSFORMERS" = "true" ] && [ -d "local_transformers" ]; then \

# Only install transformers if USE_TRANSFORMERS is true
RUN if [ "$USE_TRANSFORMERS" = "true" ]; then \
-pip install -e /workspaces/torchprime/local_transformers evaluate; \
+pip install --no-deps -e /workspaces/torchprime/local_transformers; \
+pip install --no-deps evaluate; \
fi

ENV LIBTPU_INIT_ARGS "--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true"
32 changes: 31 additions & 1 deletion torchprime/sharding/shard_model.py
@@ -5,6 +5,33 @@

import torch.nn


@torch.library.custom_op("xla::aot_mark_sharding", mutates_args=())
def aot_mark_sharding(t: torch.Tensor, partition_spec: str) -> torch.Tensor:
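  """Apply torch_xla SPMD mark_sharding to `t` using the global mesh.

  The partition spec tuple is passed serialized as a string (e.g.
  str(("fsdp", None))) and parsed back with ast.literal_eval.
  Returns a clone of `t`.
  """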
import torch_xla
if t is None:
return None
import ast
mesh = torch_xla.distributed.spmd.get_global_mesh()
partition_spec_eval = ast.literal_eval(partition_spec)
torch_xla.distributed.spmd.mark_sharding(
t, mesh, partition_spec_eval)
return t.clone()

@aot_mark_sharding.register_fake
def aot_mark_sharding_fake(t: torch.Tensor, partition_spec: str) -> torch.Tensor:
if t is None:
return None
return torch.empty_like(t)


def aot_mark_sharding_backward(ctx, grad):
return grad, None


aot_mark_sharding.register_autograd(aot_mark_sharding_backward)


ShardWeightFn = Callable[[torch.Tensor, str], torch.Tensor]
"""
ShardWeightFn optionally transforms a weight tensor based on its name.
@@ -228,7 +255,10 @@ def shard_activation(tensor, spec: tuple[str, ...]):
def shard_param(tensor, spec: tuple[str, ...]):
the_mesh = mesh if mesh is not None else xs.get_global_mesh()
assert the_mesh is not None, "No mesh found"
-return xs.mark_sharding(tensor, the_mesh, spec).global_tensor
+# TODO(https://github.com/pytorch/xla/issues/8678): Shard the gradient too.
+# Previously we used xs.mark_sharding(tensor, the_mesh, spec).global_tensor.
+# However, this is not supported for AOT compilation.
+return aot_mark_sharding(tensor, str(spec))

return shard_model_from_config(
model,
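To illustrate how the new op is meant to be called, here is a minimal, hypothetical sketch (the mesh construction, the axis names "fsdp"/"tensor", and the tensor shape are assumptions for illustration; it requires an XLA device with SPMD enabled):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

from torchprime.sharding.shard_model import aot_mark_sharding

# Hypothetical 2D mesh; in torchprime the mesh comes from the training config.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices, 1), ("fsdp", "tensor"))
xs.set_global_mesh(mesh)

t = torch.randn(128, 64, device=xm.xla_device())
# The spec is serialized with str() and parsed back inside the op via
# ast.literal_eval, mirroring shard_param's aot_mark_sharding(tensor, str(spec)).
sharded = aot_mark_sharding(t, str(("fsdp", None)))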
7 changes: 4 additions & 3 deletions torchprime/torch_xla_models/configs/default.yaml
@@ -5,17 +5,17 @@
# override the earlier ones.
defaults:
- _self_ # refers to this config file
-- model: llama-3-8b # refers to model/llama-3-8b.yaml
+- model: llama-3.1-8b # refers to model/llama-3.1-8b.yaml

dataset_name: wikitext
-dataset_config_name: wikitext-2-raw-v1
+dataset_config_name: wikitext-103-raw-v1
global_batch_size: 8
logging_steps: 10
max_steps: 15
block_size: 8192
cache_dir: /tmp/
seed: 42
-profile_step: -1
+profile_step: 3

# This might be overwritten when using tp run to launch the run using XPK
profile_dir: profile
@@ -47,3 +47,4 @@ dcn_mesh:
fsdp: 1
tensor: 1
expert: 1

1 change: 1 addition & 0 deletions torchprime/torch_xla_models/configs/model/llama-3-8b.yaml
@@ -20,3 +20,4 @@ attention_dropout: false
attention_bias: false
flash_attention: true
rope_theta: 500000.0
+scan_decoder_layers: true
@zpcore (Collaborator, Author) commented on Feb 26, 2025:

move to default yaml file

4 changes: 3 additions & 1 deletion torchprime/torch_xla_models/configs/model/llama-3.1-8b.yaml
@@ -24,4 +24,6 @@ rope_scaling:
factor: 8.0
low_freq_factor: 1.0
high_freq_factor: 4.0
-original_context_len: 8192
+original_context_len: 8192

+scan_decoder_layers: true
@@ -1,5 +1,5 @@
-activation_checkpoint_layers:
-- LlamaDecoderLayer
+# activation_checkpoint_layers:
+# - LlamaDecoderLayer

# Refer to https://github.com/pytorch/xla/issues/6379 for backward optimization barrier info.
optimization_barrier_layers:
@@ -23,3 +23,4 @@ sharding:
# Activations
model.layers.*: [fsdp, null, null]
lm_head: [fsdp, null, null]

177 changes: 177 additions & 0 deletions torchprime/torch_xla_models/experimental/offloading.py
@@ -0,0 +1,177 @@
from typing import Sequence
import torch.fx as fx
import torch
import torch_xla
from torch.utils._pytree import tree_iter

from functorch.compile import aot_function, make_boxed_func # type:ignore
from .remat_all import remat_all_partition_fn


@torch.library.custom_op("xla::offload_name", mutates_args=())
def offload_name(t: torch.Tensor, name: str) -> torch.Tensor:
"""
`offload_name` is an identity function that associates the input
tensor with `name`. It is primarily useful in conjunction with
`remat_all_and_offload_these_inputs`, which will rematerialize
intermediate activations and also offload inputs with the specified
names to host memory, moving them back during the backward pass.
"""
if t is None:
return None
return t.clone()


@offload_name.register_fake
def _(t: torch.Tensor, name: str) -> torch.Tensor:
if t is None:
return None
return torch.empty_like(t)


def offload_name_backward(ctx, grad):
return grad, None


offload_name.register_autograd(offload_name_backward)


def remat_all_and_offload_these_inputs(
joint_module: fx.GraphModule,
_joint_inputs,
*,
num_fwd_outputs,
names_to_offload: Sequence[str],
):
"""
`remat_all_and_offload_these_inputs` will rematerialize (recompute) all
intermediate activations in `joint_module`, and offload inputs with the
specified names to host memory, moving them back during the backward pass.
It transforms the joint graph into separate forward and backward graphs.
"""
input_device = next(iter(tree_iter(_joint_inputs))).device
fwd, bwd = remat_all_partition_fn(
joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
with torch.device(input_device):
fw_example_args = make_arguments(fwd)
bw_example_args = make_arguments(bwd)

fw_name_in_output_indices = get_name_in_output_indices(fwd)
bw_name_in_input_names = get_name_in_input_names(bwd)

for name in names_to_offload:
print(f"Going to offload {name}")
assert name in fw_name_in_output_indices
assert name in bw_name_in_input_names

# print("fw_name_in_output_indices", fw_name_in_output_indices)
# print("bw_name_in_input_names", bw_name_in_input_names)

with torch.no_grad():

def forward(**kwargs):
import pdb
try:
out = fwd(**kwargs)
indices_to_offload = set(
[fw_name_in_output_indices[name] for name in names_to_offload])
return tuple(
torch.ops.xla.place_to_host(v) if i in # type:ignore
indices_to_offload else v for i, v in enumerate(out))
except Exception:
pdb.post_mortem()

def backward(**kwargs):
# print(f"Backward got {len(kwargs)} arguments:")
for k, v in kwargs.items():
print(f"Arg {k}: {v.shape if v is not None else 'None'}")
arguments_to_move_back = set(
[bw_name_in_input_names[name] for name in names_to_offload])
kwargs = {
k: torch.ops.xla.place_to_device(v) # type: ignore
if k in arguments_to_move_back else v for k, v in kwargs.items()
}
import pdb
try:
values = bwd(**kwargs)
print(f"Backward will return {len(values)} values:")
for i, v in enumerate(values):
print(f"Arg {i}: {v.shape if v is not None else 'None'}")
return values
except Exception:
pdb.post_mortem()

# Use AOTAutograd to retrace forward and backward, thus incorporating
# the offloading ops.
graph = [None]

def get_graph(g, _):
graph[0] = g
return make_boxed_func(g)

_ = aot_function(forward, fw_compiler=get_graph)(**fw_example_args)
aot_forward = graph[0]

_ = aot_function(backward, fw_compiler=get_graph)(**bw_example_args)
aot_backward = graph[0]

return aot_forward, aot_backward


def make_arguments(gm: fx.GraphModule):
"""
Given a graph module, `make_arguments` returns a dictionary of example inputs
that can be used as keyword arguments to call the graph module.
"""
example_args = {}
for node in gm.graph.nodes:
if node.op != 'placeholder':
continue
if 'tensor_meta' in node.meta:
tensor_meta = node.meta['tensor_meta']
tensor = torch.zeros(
tensor_meta.shape,
dtype=tensor_meta.dtype,
requires_grad=tensor_meta.requires_grad)
example_args[node.name] = tensor
return example_args


def get_named_nodes(gm: torch.fx.GraphModule):
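  """Return a dict mapping each tensor node tagged via `offload_name` to its name."""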
named_nodes = {}

for node in gm.graph.nodes:
if node.op == "call_function":
if hasattr(node.target, "name"):
if node.target.name() == offload_name._qualname: # type: ignore
named_nodes[node.args[0]] = node.args[1]

return named_nodes


def get_name_in_output_indices(gm: torch.fx.GraphModule):
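  """Map each `offload_name` tag to the index of the tagged node in the graph outputs."""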
named_nodes = get_named_nodes(gm)
name_in_output_indices = {}

for node in gm.graph.nodes:
if node.op == "output":
assert len(node.args) <= 1
if len(node.args) == 0:
continue
for i, arg in enumerate(next(iter(node.args))): # type: ignore
if arg in named_nodes:
name_in_output_indices[named_nodes[arg]] = i

return name_in_output_indices


def get_name_in_input_names(gm: torch.fx.GraphModule):
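  """Map each `offload_name` tag to the placeholder (input) name of the tagged node."""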
named_nodes = get_named_nodes(gm)
name_in_input_names = {}

for node in gm.graph.nodes:
if node.op == "placeholder":
if node in named_nodes:
name_in_input_names[named_nodes[node]] = node.target

return name_in_input_names
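A minimal wiring sketch of how these pieces fit together (the toy layer function, the shapes, and the "decoder_input" tag are hypothetical; it assumes an XLA device and a torch_xla build that provides torch.ops.xla.place_to_host / place_to_device):

import functools

import torch
import torch_xla.core.xla_model as xm
from functorch.compile import aot_function, make_boxed_func

from torchprime.torch_xla_models.experimental.offloading import (
    offload_name, remat_all_and_offload_these_inputs)


def layer(x, w):
  # Tag the tensor that should live on host memory between forward and backward.
  x = offload_name(x, "decoder_input")
  return torch.tanh(x @ w)


# Rematerialize everything and offload the tagged input.
partition_fn = functools.partial(
    remat_all_and_offload_these_inputs, names_to_offload=["decoder_input"])

compiled = aot_function(
    layer,
    fw_compiler=lambda gm, _: make_boxed_func(gm),
    bw_compiler=lambda gm, _: make_boxed_func(gm),
    partition_fn=partition_fn,
)

device = xm.xla_device()
x = torch.randn(128, 64, device=device, requires_grad=True)
w = torch.randn(64, 64, device=device, requires_grad=True)
compiled(x, w).sum().backward()

Note that the tagged tensor must end up among the values the partitioner saves for the backward pass; otherwise the asserts on fw_name_in_output_indices / bw_name_in_input_names fire.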
75 changes: 75 additions & 0 deletions torchprime/torch_xla_models/experimental/remat_all.py
@@ -0,0 +1,75 @@
import torch.fx
import torch._functorch.config
from functorch.compile import min_cut_rematerialization_partition

from contextlib import contextmanager


@contextmanager
def remat_all_config():
# Backup existing config values
backup = {
"activation_memory_budget":
torch._functorch.config.activation_memory_budget,
"aggressive_recomputation":
torch._functorch.config.aggressive_recomputation,
"recompute_views":
torch._functorch.config.recompute_views,
"ban_recompute_reductions":
torch._functorch.config.ban_recompute_reductions,
"ban_recompute_not_in_allowlist":
torch._functorch.config.ban_recompute_not_in_allowlist,
"ban_recompute_materialized_backward":
torch._functorch.config.ban_recompute_materialized_backward,
"ban_recompute_long_fusible_chains":
torch._functorch.config.ban_recompute_long_fusible_chains,
"ban_recompute_used_far_apart":
torch._functorch.config.ban_recompute_used_far_apart,
}

try:
# Set activation_memory_budget to zero to force the min cut partitioner
# to recompute instead of saving. Also don't ban the recomputing of any ops.
torch._functorch.config.activation_memory_budget = 0.0
torch._functorch.config.aggressive_recomputation = True
torch._functorch.config.recompute_views = True
torch._functorch.config.ban_recompute_reductions = False
torch._functorch.config.ban_recompute_not_in_allowlist = False
torch._functorch.config.ban_recompute_materialized_backward = False
torch._functorch.config.ban_recompute_long_fusible_chains = False
torch._functorch.config.ban_recompute_used_far_apart = False
yield

finally:
# Restore the original config values
torch._functorch.config.activation_memory_budget = backup[
"activation_memory_budget"]
torch._functorch.config.aggressive_recomputation = backup[
"aggressive_recomputation"]
torch._functorch.config.recompute_views = backup["recompute_views"]
torch._functorch.config.ban_recompute_reductions = backup[
"ban_recompute_reductions"]
torch._functorch.config.ban_recompute_not_in_allowlist = backup[
"ban_recompute_not_in_allowlist"]
torch._functorch.config.ban_recompute_materialized_backward = backup[
"ban_recompute_materialized_backward"]
torch._functorch.config.ban_recompute_long_fusible_chains = backup[
"ban_recompute_long_fusible_chains"]
torch._functorch.config.ban_recompute_used_far_apart = backup[
"ban_recompute_used_far_apart"]


def remat_all_partition_fn(
joint_module: torch.fx.GraphModule,
_joint_inputs,
*,
num_fwd_outputs,
):
"""
remat_all_partition_fn is a graph partition function that closely matches the
default behavior of `torch.utils.checkpoint`, which is to discard all intermediate
activations and recompute all of them during the backward pass.
"""
with remat_all_config():
return min_cut_rematerialization_partition(
joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
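For comparison, a short sketch of using remat_all_partition_fn on its own as the AOTAutograd partition function, which recomputes all intermediates in the backward pass much like torch.utils.checkpoint (the toy MLP and shapes are hypothetical; this particular example also runs on CPU):

import torch
from functorch.compile import aot_function, make_boxed_func

from torchprime.torch_xla_models.experimental.remat_all import remat_all_partition_fn


def mlp(x, w1, w2):
  return torch.relu(x @ w1) @ w2


compiled = aot_function(
    mlp,
    fw_compiler=lambda gm, _: make_boxed_func(gm),
    bw_compiler=lambda gm, _: make_boxed_func(gm),
    partition_fn=remat_all_partition_fn,
)

x = torch.randn(4, 16, requires_grad=True)
w1 = torch.randn(16, 32, requires_grad=True)
w2 = torch.randn(32, 8, requires_grad=True)
loss = compiled(x, w1, w2).sum()
loss.backward()  # relu(x @ w1) is recomputed here rather than saved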