26 changes: 26 additions & 0 deletions distributed/FSDP2/README.md
@@ -0,0 +1,26 @@
## FSDP2
To run FSDP2 on a transformer model:
```
cd distributed/FSDP2
torchrun --nproc_per_node 2 train.py
```
* On the first run, it creates a "checkpoints" folder and saves state dicts there
* On subsequent runs, it resumes by loading from the latest checkpoint folder (see the sketch below)
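
A minimal sketch of how the `Checkpointer` from `checkpoint.py` supports that flow. The model wrapping mirrors the files in this PR; the variable names, the AdamW hyperparameters, and the flag handling are placeholders for what `train.py` actually does:
```
import torch
from torch.distributed.fsdp import fully_shard
from checkpoint import Checkpointer
from model import ModelArgs, Transformer

# Assumes torchrun has already initialized the default process group.
model = Transformer(ModelArgs())
for layer in model.layers:
    fully_shard(layer)          # shard each transformer block
fully_shard(model)              # shard the remaining root parameters
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

checkpointer = Checkpointer("checkpoints", dcp_api=False)
if not checkpointer.is_empty():
    checkpointer.load_model(model)         # reshard the saved full state dict
    checkpointer.load_optim(model, optim)  # same for the optimizer state

# ... training steps ...

checkpointer.save(model, optim)  # rank 0 writes checkpoints/dtensor_api/<timestamp>/*.pt
```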

To enable explicit prefetching:
```
torchrun --nproc_per_node 2 train.py --explicit-prefetch
```
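
Explicit prefetching uses the FSDP2 module APIs on the wrapped blocks. Continuing the sketch above, something along these lines is what the flag enables (the prefetch depth of 2 is an assumption, not necessarily what `train.py` uses):
```
# Each block schedules the all-gathers for the next blocks ahead of time,
# overlapping communication with the current block's forward compute.
num_to_prefetch = 2
layers = list(model.layers)
for i, layer in enumerate(layers):
    layer.set_modules_to_forward_prefetch(layers[i + 1 : i + 1 + num_to_prefetch])
```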

To enable mixed precision:
```
torchrun --nproc_per_node 2 train.py --mixed-precision
```
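
With FSDP2, mixed precision is a per-`fully_shard` policy rather than a global flag. In the sketch above, the `fully_shard` calls would instead pass a `MixedPrecisionPolicy` (the bf16/fp32 choice here is an assumption):
```
from torch.distributed.fsdp import MixedPrecisionPolicy

# Parameters are all-gathered and compute runs in bf16; gradients are reduced in fp32.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```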

To showcase the DCP API:
```
torchrun --nproc_per_node 2 train.py --dcp-api
```
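
`--dcp-api` switches `Checkpointer` (see `checkpoint.py` below) from manually resharding DTensors to the `torch.distributed.checkpoint.state_dict` helpers. For example, saving the model then goes through:
```
from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions

# Gather a full, CPU-offloaded state dict from the sharded model.
full_sd = get_model_state_dict(
    model=model,
    options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
```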

## Ensure you are running a recent version of PyTorch
See https://pytorch.org/get-started/locally/ to install at least PyTorch 2.5, and ideally a current nightly build.
209 changes: 209 additions & 0 deletions distributed/FSDP2/checkpoint.py
@@ -0,0 +1,209 @@
import os
import time

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    _init_optim_state,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import distribute_tensor, DTensor


MODEL_CHECKPOINT = "model_state_dict.pt"
OPTIM_CHECKPOINT = "optim_state_dict.pt"
PARAMS = "params"


def get_latest_checkpoint_folder(path):
    max_num = None
    if not os.path.exists(path):
        return max_num
    for name in os.listdir(path):
        folder_path = os.path.join(path, name)
        if os.path.isdir(folder_path):
            try:
                num = int(name)
                if max_num is None or num > max_num:
                    max_num = num
            except ValueError:
                pass  # Skip non-numeric folder names
    return max_num


class Checkpointer:
    def __init__(self, folder: str, dcp_api: bool):
        self.folder = folder
        self.dcp_api = dcp_api
        self.last_training_time = get_latest_checkpoint_folder(
            f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}"
        )

    def is_empty(self):
        return self.last_training_time is None

    def load_model(self, model: FSDPModule):
        last_model_checkpoint = (
            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
            f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
        )
        full_sd = torch.load(
            last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
        )
        if self.dcp_api:
            set_model_state_dict(
                model=model,
                model_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                ),
            )
            return
        meta_sharded_sd = model.state_dict()
        sharded_sd = {}
        for param_name, full_tensor in full_sd.items():
            sharded_meta_param = meta_sharded_sd.get(param_name)
            sharded_tensor = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
        # choose `assign=True` since we cannot call `copy_` on meta tensor
        model.load_state_dict(sharded_sd, strict=False, assign=True)

    def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
        last_optim_checkpoint = (
            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
            f"/{self.last_training_time}/{OPTIM_CHECKPOINT}"
        )
        full_sd = torch.load(
            last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu"
        )
        if self.dcp_api:
            set_optimizer_state_dict(
                model=model,
                optimizers=opt,
                optim_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                ),
            )
            return
        _init_optim_state(opt)
        param_groups = opt.state_dict()["param_groups"]
        state = opt.state_dict()["state"]

        full_param_groups = full_sd["param_groups"]
        full_state = full_sd["state"]

        for param_group, full_param_group in zip(param_groups, full_param_groups):
            for key, value in full_param_group.items():
                if key == PARAMS:
                    continue
                param_group[key] = value
            for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
                if pid not in state:
                    continue
                param_state = state[pid]
                full_param_state = full_state[full_pid]
                for attr, full_tensor in full_param_state.items():
                    sharded_tensor = param_state[attr]
                    if isinstance(sharded_tensor, DTensor):
                        # exp_avg is DTensor
                        param_state[attr] = distribute_tensor(
                            full_tensor,
                            sharded_tensor.device_mesh,
                            sharded_tensor.placements,
                        )
                    else:
                        # step is plain tensor
                        param_state[attr] = full_tensor
        opt.load_state_dict(
            {
                "param_groups": param_groups,
                "state": state,
            }
        )

    def _get_full_model_state_dict(self, model: FSDPModule):
        if self.dcp_api:
            return get_model_state_dict(
                model=model,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                ),
            )

        sharded_sd = model.state_dict()
        cpu_state_dict = {}
        for param_name, sharded_param in sharded_sd.items():
            full_param = sharded_param.full_tensor()
            if torch.distributed.get_rank() == 0:
                cpu_state_dict[param_name] = full_param.cpu()
            else:
                del full_param
        return cpu_state_dict

    def _get_full_optimizer_state_dict(
        self,
        model: FSDPModule,
        opt: torch.optim.Optimizer,
    ):
        if self.dcp_api:
            return get_optimizer_state_dict(
                model=model,
                optimizers=opt,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                ),
            )
        is_rank_zero = torch.distributed.get_rank() == 0
        sharded_sd = opt.state_dict()
        sharded_state = sharded_sd["state"]
        full_state = {}
        for group_id, sharded_group in sharded_state.items():
            group_state = {}
            for attr, sharded_tensor in sharded_group.items():
                if isinstance(sharded_tensor, DTensor):
                    # "exp_avg" in AdamW is `DTensor`
                    full_tensor = sharded_tensor.full_tensor()
                else:
                    # "step" in AdamW is plain tensor
                    full_tensor = sharded_tensor
                if is_rank_zero:
                    group_state[attr] = full_tensor.cpu()
                else:
                    del full_tensor
            if is_rank_zero:
                full_state[group_id] = group_state
            else:
                del group_state
        if is_rank_zero:
            return {
                "param_groups": sharded_sd["param_groups"],
                "state": full_state,
            }
        else:
            return {}

    def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
        model_state_dict = self._get_full_model_state_dict(model)
        optim_state_dict = self._get_full_optimizer_state_dict(model, optim)
        if torch.distributed.get_rank() == 0:
            new_training_time = int(time.time() * 1000)
            new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}"
            new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}"
            new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}"
            os.makedirs(new_checkpoint_folder, exist_ok=True)
            torch.save(model_state_dict, new_model_checkpoint)
            torch.save(optim_state_dict, new_optim_checkpoint)
134 changes: 134 additions & 0 deletions distributed/FSDP2/model.py
@@ -0,0 +1,134 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)

        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

    def reset_parameters(self):
        self.wq.reset_parameters()
        self.wk.reset_parameters()
        self.wv.reset_parameters()
        self.wo.reset_parameters()


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

    def reset_parameters(self):
        self.w1.reset_parameters()
        self.w2.reset_parameters()


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def reset_parameters(self):
        self.attention_norm.reset_parameters()
        self.attention.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.feed_forward.reset_parameters()


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    def reset_parameters(self):
        self.tok_embeddings.reset_parameters()
        self.pos_embeddings.reset_parameters()
        self.norm.reset_parameters()
        self.output.reset_parameters()
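
The toy model above can be exercised on a single process without any FSDP wrapping; a minimal sketch (the batch size of 2 is arbitrary):
```
import torch
from model import ModelArgs, Transformer

args = ModelArgs()
model = Transformer(args)
tokens = torch.randint(0, args.vocab_size, (2, args.max_seq_len))
logits = model(tokens)
assert logits.shape == (2, args.max_seq_len, args.vocab_size)
```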