215 changes: 166 additions & 49 deletions paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py
@@ -15,18 +15,27 @@

import gc
import os
import re
from itertools import chain

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from safetensors import safe_open
from tqdm.auto import tqdm

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
from paddlenlp.transformers.model_utils import (
_add_variant,
load_state_dict,
unwrap_model,
)
from paddlenlp.transformers.utils import device_guard
from paddlenlp.utils.env import (
SAFE_MASTER_WEIGHTS_INDEX_NAME,
SAFE_MASTER_WEIGHTS_NAME,
SAFE_OPTIMIZER_INDEX_NAME,
SAFE_OPTIMIZER_NAME,
)
from paddlenlp.utils.nested import nested_copy

Expand Down Expand Up @@ -175,6 +184,49 @@
return optim_state_dict, master_weights


def get_params_info(comm_buffer_list):
expected_keys = []
param_slice_info = {}
param_shape_info = {}

for buffer in comm_buffer_list:
for key in buffer._sharding_param_grad_view.keys():
begin = buffer._sharding_param_grad_view[key]._param_begin
end = buffer._sharding_param_grad_view[key]._param_end
if end > begin:
expected_keys.append(key)
shape = buffer._sharding_param_grad_view[key]._param.shape
numel = buffer._sharding_param_grad_view[key]._param.numel().item()
index = buffer._sharding_param_grad_view[key]._index
padded_size = buffer._sharding_param_grad_view[key]._padded_size
param_slice_info[key] = (begin, end)
param_shape_info[key] = (shape, numel, index, padded_size)
return expected_keys, param_slice_info, param_shape_info
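
A minimal illustration of the mapping layout this helper produces (the values below are hypothetical, not taken from the PR; the `_sharding_param_grad_view` fields it reads are private Paddle structures):

# Hypothetical example: a parameter "linear_0.w_0" with 6 real elements,
# padded to 8 in the flat comm buffer at offset index=0, where this
# sharding rank owns the flat-buffer slice [4, 8).
expected_keys = ["linear_0.w_0"]
param_slice_info = {"linear_0.w_0": (4, 8)}              # (begin, end)
param_shape_info = {"linear_0.w_0": ([2, 3], 6, 0, 8)}   # (shape, numel, index, padded_size)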


def reshape_params(state_dict, struct2static_name_mappings, param_shape_info, param_slice_info):
"""Reshape params to 1-D tensors"""
for key in list(state_dict.keys()):
key_name = key.split("/")[0]
static_name = struct2static_name_mappings.get(key_name, None)
if int(state_dict[key].numel()) > 1:
begin, end = param_slice_info[static_name]
_, numel, index, padded_size = param_shape_info[static_name]
state_dict[key] = state_dict[key].reshape([-1])
state_dict[key] = state_dict[key][begin - index : end - index]

padding_start = max(begin, index + numel)
padding_end = min(end, index + padded_size)
if padding_start < padding_end:
state_dict[key] = paddle.concat(
(
state_dict[key],
paddle.zeros([padding_end - padding_start], dtype=state_dict[key].dtype),
)
)
return state_dict
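
The slice-and-pad arithmetic above is easiest to see on a toy tensor. A self-contained sketch, reusing the hypothetical numbers from the example after get_params_info (a 6-element parameter padded to 8, rank slice [4, 8)):

import paddle

full = paddle.arange(6, dtype="float32")     # checkpoint holds the full parameter
begin, end, index, numel, padded_size = 4, 8, 0, 6, 8

local = full.reshape([-1])[begin - index : end - index]   # -> values [4., 5.]
padding_start = max(begin, index + numel)    # 6: real data ends here
padding_end = min(end, index + padded_size)  # 8: this rank's slice ends here
if padding_start < padding_end:              # slice overlaps alignment padding
    local = paddle.concat(
        (local, paddle.zeros([padding_end - padding_start], dtype=local.dtype))
    )
assert local.shape[0] == end - begin         # 4 elements: 2 real values + 2 zeros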


def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
returned_optim_state_dict = nested_copy(optimizer.state_dict())

@@ -196,28 +248,12 @@
static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # get optimizer param mappings
struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
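# e.g. {"param_0": "linear_0.w_0"} and {"linear_0.w_0": "param_0"} (illustrative names)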

expected_keys = []
param_slice_info = {}
param_shape_info = {}

comm_buffer_list = optimizer._inner_opt._comm_buffer_list
if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
model = unwrap_model(model)

for buffer in comm_buffer_list:
for key in buffer._sharding_param_grad_view.keys():
begin = buffer._sharding_param_grad_view[key]._param_begin
end = buffer._sharding_param_grad_view[key]._param_end
if end > begin:
expected_keys.append(key)
shape = buffer._sharding_param_grad_view[key]._param.shape
numel = buffer._sharding_param_grad_view[key]._param.numel().item()
index = buffer._sharding_param_grad_view[key]._index
padded_size = buffer._sharding_param_grad_view[key]._padded_size
param_slice_info[key] = (begin, end)
param_shape_info[key] = (shape, numel, index, padded_size)

expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)

expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
expected_keys_optim = []
for key in expected_keys:
Expand Down Expand Up @@ -285,25 +321,10 @@
)

# need to split param for different sharding rank, maybe need to deal with oom issue.
reshape_params(state_dict_optim, struct2static_name_mappings, param_shape_info, param_slice_info)

for key in list(state_dict_optim.keys()):
key_name = key.split("/")
static_name = struct2static_name_mappings.get(key_name[0], None)

if int(state_dict_optim[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_optim[key] = state_dict_optim[key].reshape([-1])
state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]

padding_start = max(begin, index + numel)
padding_end = min(end, index + padded_size)
if padding_start < padding_end:
state_dict_optim[key] = paddle.concat(
(
state_dict_optim[key],
paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
)
)
if has_master_weights:
if model_state_dict[key_name[0]].dtype != paddle.float32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
@@ -325,24 +346,10 @@
expected_keys,
is_master_weights=True,
)
reshape_params(state_dict_master_weight, struct2static_name_mappings, param_shape_info, param_slice_info)

for key in list(state_dict_master_weight.keys()):
static_name = struct2static_name_mappings.get(key, None)
if int(state_dict_master_weight[key].numel()) > 1:
begin, end = param_slice_info[static_name]
shape, numel, index, padded_size = param_shape_info[static_name]
state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]

padding_start = max(begin, index + numel)
padding_end = min(end, index + padded_size)
if padding_start < padding_end:
state_dict_master_weight[key] = paddle.concat(
(
state_dict_master_weight[key],
paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
)
)
state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
paddle.framework._current_expected_place(), False
)
@@ -357,3 +364,113 @@
returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

return returned_optim_state_dict
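
For reference, the key renaming performed by the loaders can be sketched as follows. The struct/static names are made up for illustration, and the value of the FP32_MASTER constant is an assumption, not confirmed by this diff:

# Hypothetical names for illustration only.
struct2static_name_mappings = {"linear_0.w_0": "param_0"}
FP32_MASTER = "fp32_master_0"   # assumed value of the paddlenlp constant

key = "linear_0.w_0/moment1_0"  # "<struct_name>/<optimizer typename>"
struct_name, typename = key.split("/")
static_name = struct2static_name_mappings[struct_name]

# AMP O2 with a non-fp32 parameter: the master-weight tag is spliced in.
amp_o2_key = "_".join([static_name, FP32_MASTER, typename])  # "param_0_fp32_master_0_moment1_0"
# AMP O1, or an fp32 parameter: no tag.
amp_o1_key = "_".join([static_name, typename])               # "param_0_moment1_0"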


def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
returned_optim_state_dict = nested_copy(optimizer.state_dict())

optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)

# no quantization & no master weight represent O1 AMP strategy.
is_amp_o1 = args.fp16_opt_level == "O1"

model_state_dict = get_expected_state_dict(model)
static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()} # get optimizer param mappings
struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}

comm_buffer_list = optimizer._inner_opt._comm_buffer_list
if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))

expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
expected_keys_optim = []
sharding_typename_set, typename_set = [], []
with safe_open(optimizer_path, framework="numpy") as f:
optim_keys = f.keys()
for key in optim_keys:
_, typename = key.split("/")
typename_set.append(typename)

# Some shard files may hold an incomplete typename set, so gather and union typenames across the sharding group.
hcg = fleet.get_hybrid_communicate_group()
sharding_group = hcg.get_sharding_parallel_group()
dist.all_gather_object(sharding_typename_set, typename_set, sharding_group)
typename_set = set(chain(*sharding_typename_set))
for key in expected_keys:
for typename in typename_set:
expected_keys_optim.append(f"{key}/{typename}")
expected_keys_optim = set(expected_keys_optim)

optimizer_state_dict = load_state_dict(
optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
)
master_weights = {}

# normal AMP O2
if not is_amp_o1 and os.path.isfile(master_weights_path):
master_weights = load_state_dict(master_weights_path, None, None, device="expected")

def get_unfound_params(unfound_keys, state_dict, is_optimizer=True):
if len(unfound_keys) > 0:
backup_files = []
files = os.listdir(resume_from_checkpoint)
name = optimizer_name if is_optimizer else master_weights_name
name_without_shard = re.sub(r"_?shard\d+_?", "", name)
name_ = "optimizer" if is_optimizer else "master_weights"
for f in files:
if f.startswith(name_) and f.endswith("safetensors") and f != name:
if re.sub(r"_?shard\d+_?", "", f) == name_without_shard:
backup_files.append(f)
for f in backup_files:
new_path = os.path.join(resume_from_checkpoint, f)
with safe_open(new_path, framework="numpy") as fin:
keys = fin.keys()
for key in unfound_keys:
if key in keys:
tensor = fin.get_tensor(key)
with device_guard():
tensor = paddle.Tensor(tensor, zero_copy=True)
state_dict[key] = tensor._copy_to(paddle.framework._current_expected_place(), False)
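
The shard-suffix normalization used above can be checked in isolation. The filenames here are hypothetical stand-ins for the variants produced by args.optimizer_name_suffix:

import re

name = "optimizer_tp00_shard00.safetensors"   # this rank's own shard (hypothetical)
candidates = [
    "optimizer_tp00_shard01.safetensors",     # same prefix, other shard -> backup
    "optimizer_tp01_shard00.safetensors",     # different tp rank -> skipped
]

def strip_shard(f):
    return re.sub(r"_?shard\d+_?", "", f)     # drop the shard index from the name

backups = [
    f for f in candidates
    if f.startswith("optimizer") and f.endswith("safetensors")
    and f != name and strip_shard(f) == strip_shard(name)
]
print(backups)   # ['optimizer_tp00_shard01.safetensors']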

# Fetch optimizer parameters that may live in other shard files.
unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
get_unfound_params(unfound_keys, optimizer_state_dict, True)

# Fetch master weight parameters that may live in other shard files.
if master_weights != {}:
unfound_keys = expected_keys - master_weights.keys()
get_unfound_params(unfound_keys, master_weights, False)
reshape_params(optimizer_state_dict, struct2static_name_mappings, param_shape_info, param_slice_info)

# rename and move to paddle.Tensor
for key in list(optimizer_state_dict.keys()):
key_name = key.split("/")
model_weight_key = key_name[0]
static_name = struct2static_name_mappings[key_name[0]]
if not is_amp_o1:
if model_state_dict[key_name[0]].dtype != paddle.float32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])

Check warning on line 455 in paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py#L449-L455

Added lines #L449 - L455 were not covered by tests
else:
key_name = "_".join([static_name, key_name[1]])

Check warning on line 457 in paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py#L457

Added line #L457 was not covered by tests
else:
key_name = "_".join([static_name, key_name[1]])
returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
returned_optim_state_dict[key_name].name = key_name

# master weight cast (only in AMP O2 + remove_master_weight)
if not is_amp_o1 and not os.path.isfile(master_weights_path):
master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)

if not is_amp_o1:
reshape_params(master_weights, struct2static_name_mappings, param_shape_info, param_slice_info)

returned_optim_state_dict["master_weights"] = {}
for key in list(master_weights.keys()):
static_name = struct2static_name_mappings[key]
returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

return returned_optim_state_dict

22 changes: 21 additions & 1 deletion paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py
@@ -62,7 +62,10 @@
save_single_card_checkpoint,
save_single_card_optimizer,
)
from .sharding_split_param_utils import gather_splited_param_for_optimizer
from .sharding_split_param_utils import (
gather_splited_param_for_optimizer,
load_non_merge_optimizer_with_split_param,
)
from .utils import (
FP32_MASTER,
UnifiedCheckpointOption,
@@ -263,6 +266,23 @@
)

def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
"""load non merge optimizer

Args:
model (PretrainedModel): model used to get key mapping.
optimizer (Optimizer): optimizer to load
resume_from_checkpoint (str): path of the checkpoint to load
ckpt_quant_stage (str): ckpt quant stage

Returns:
dict: optimizer state dict
"""

if is_sharding_split_param_mode(self.args):
return load_non_merge_optimizer_with_split_param(
self.args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
)

# init and get optimizer LR_Scheduler
returned_optim_state_dict = nested_copy(optimizer.state_dict())
