
Commit c0c7c78

fix
1 parent 30df8b6 commit c0c7c78

File tree

2 files changed (+193 / -21 lines)


paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 172 additions & 20 deletions
@@ -15,18 +15,27 @@
 
 import gc
 import os
+import re
 from itertools import chain
 
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from safetensors import safe_open
 from tqdm.auto import tqdm
 
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
+from paddlenlp.transformers.model_utils import (
+    _add_variant,
+    load_state_dict,
+    unwrap_model,
+)
+from paddlenlp.transformers.utils import device_guard
 from paddlenlp.utils.env import (
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_NAME,
     SAFE_OPTIMIZER_INDEX_NAME,
+    SAFE_OPTIMIZER_NAME,
 )
 from paddlenlp.utils.nested import nested_copy
 
@@ -175,6 +184,26 @@ def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     return optim_state_dict, master_weights
 
 
+def get_params_info(comm_buffer_list):
+    expected_keys = []
+    param_slice_info = {}
+    param_shape_info = {}
+
+    for buffer in comm_buffer_list:
+        for key in buffer._sharding_param_grad_view.keys():
+            begin = buffer._sharding_param_grad_view[key]._param_begin
+            end = buffer._sharding_param_grad_view[key]._param_end
+            if end > begin:
+                expected_keys.append(key)
+            shape = buffer._sharding_param_grad_view[key]._param.shape
+            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
+            index = buffer._sharding_param_grad_view[key]._index
+            padded_size = buffer._sharding_param_grad_view[key]._padded_size
+            param_slice_info[key] = (begin, end)
+            param_shape_info[key] = (shape, numel, index, padded_size)
+    return expected_keys, param_slice_info, param_shape_info
+
+
 def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
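A minimal, self-contained sketch of what get_params_info records per parameter, using made-up numbers and a hypothetical key name rather than Paddle's real _sharding_param_grad_view objects:

# Illustrative sketch only: real values come from Paddle's internal
# _sharding_param_grad_view objects, and "linear_0.w_0" is a made-up key.

# A 2x5 parameter (10 elements) sits at offset 100 of the fused buffer and is
# padded to 12 elements; this sharding rank owns the buffer range [104, 112).
index, numel, padded_size = 100, 10, 12
begin, end = 104, 112

# Per parameter key, get_params_info records:
param_slice_info = {"linear_0.w_0": (begin, end)}
param_shape_info = {"linear_0.w_0": ([2, 5], numel, index, padded_size)}

# end > begin means this rank holds part of the parameter, so the key goes
# into expected_keys and must be restored from the checkpoint on this rank.
print(end > begin)  # True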
@@ -196,28 +225,12 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
-    expected_keys = []
-    param_slice_info = {}
-    param_shape_info = {}
-
     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
     if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
         model = unwrap_model(model)
 
-    for buffer in comm_buffer_list:
-        for key in buffer._sharding_param_grad_view.keys():
-            begin = buffer._sharding_param_grad_view[key]._param_begin
-            end = buffer._sharding_param_grad_view[key]._param_end
-            if end > begin:
-                expected_keys.append(key)
-            shape = buffer._sharding_param_grad_view[key]._param.shape
-            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
-            index = buffer._sharding_param_grad_view[key]._index
-            padded_size = buffer._sharding_param_grad_view[key]._padded_size
-            param_slice_info[key] = (begin, end)
-            param_shape_info[key] = (shape, numel, index, padded_size)
-
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
     expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
     expected_keys_optim = []
     for key in expected_keys:
@@ -291,7 +304,7 @@ def load_resolved_archive_file(
 
         if int(state_dict_optim[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
            state_dict_optim[key] = state_dict_optim[key].reshape([-1])
            state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]
 
@@ -330,7 +343,7 @@ def load_resolved_archive_file(
        static_name = struct2static_name_mappings.get(key, None)
        if int(state_dict_master_weight[key].numel()) > 1:
            begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
            state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
            state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]
 
@@ -357,3 +370,142 @@ def load_resolved_archive_file(
        returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
    return returned_optim_state_dict
+
+
+def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
+    optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+    master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
+    optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
+    master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
+
+    # No quantization and no master weights indicates the O1 AMP strategy.
+    is_amp_o1 = args.fp16_opt_level == "O1"
+
+    model_state_dict = get_expected_state_dict(model)
+    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+
+    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
+    expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
+    expected_keys_optim = []
+    typename_set = set()
+    with safe_open(optimizer_path, framework="numpy") as f:
+        optim_keys = f.keys()
+        for key in optim_keys:
+            _, typename = key.split("/")
+            typename_set.add(typename)
+    for key in expected_keys:
+        for typename in typename_set:
+            expected_keys_optim.append(f"{key}/{typename}")
+    expected_keys_optim = set(expected_keys_optim)
+
+    optimizer_state_dict = load_state_dict(
+        optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
+    )
+    master_weights = {}
+    # normal AMP O2
+    if not is_amp_o1 and os.path.isfile(master_weights_path):
+        master_weights = load_state_dict(master_weights_path, None, None, device="expected")
+
+    def get_unfound_params(unfound_keys, state_dict, is_optimizer=True):
+        if len(unfound_keys) > 0:
+            backup_files = []
+            files = os.listdir(resume_from_checkpoint)
+            name = optimizer_name if is_optimizer else master_weights_name
+            name_without_shard = re.sub(r"_?shard\d+_?", "", name)
+            name_ = "optimizer" if is_optimizer else "master_weights"
+            for f in files:
+                if f.startswith(name_) and f.endswith("safetensors") and f != name:
+                    if re.sub(r"_?shard\d+_?", "", f) == name_without_shard:
+                        backup_files.append(f)
+            for f in backup_files:
+                new_path = os.path.join(resume_from_checkpoint, f)
+                with safe_open(new_path, framework="numpy") as fin:
+                    keys = fin.keys()
+                    for key in unfound_keys:
+                        if key in keys:
+                            tensor = fin.get_tensor(key)
+                            with device_guard():
+                                tensor = paddle.Tensor(tensor, zero_copy=True)
+                            state_dict[key] = tensor._copy_to(paddle.framework._current_expected_place(), False)
+
+    # Fetch optimizer parameters that may live in other shard files.
+    unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
+    get_unfound_params(unfound_keys, optimizer_state_dict, True)
+
+    # Fetch master weight parameters that may live in other shard files.
+    if master_weights != {}:
+        unfound_keys = expected_keys - master_weights.keys()
+        get_unfound_params(unfound_keys, master_weights, False)
+
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        static_name = struct2static_name_mappings.get(key_name[0], None)
+
+        if int(optimizer_state_dict[key].numel()) > 1:
+            begin, end = param_slice_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
+            optimizer_state_dict[key] = optimizer_state_dict[key].reshape([-1])
+            optimizer_state_dict[key] = optimizer_state_dict[key][begin - index : end - index]
+
+            padding_start = max(begin, index + numel)
+            padding_end = min(end, index + padded_size)
+            if padding_start < padding_end:
+                optimizer_state_dict[key] = paddle.concat(
+                    (
+                        optimizer_state_dict[key],
+                        paddle.zeros([padding_end - padding_start], dtype=optimizer_state_dict[key].dtype),
+                    )
+                )
+
+    # Rename keys and move tensors to paddle.Tensor.
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        model_weight_key = key_name[0]
+        static_name = struct2static_name_mappings[key_name[0]]
+        if not is_amp_o1:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
+        else:
+            key_name = "_".join([static_name, key_name[1]])
+        returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
+        returned_optim_state_dict[key_name].name = key_name
+
+        # master weight cast (only in AMP O2 + remove_master_weight)
+        if not is_amp_o1 and not os.path.isfile(master_weights_path):
+            master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)
+
+    if not is_amp_o1:
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings.get(key, None)
+            if int(master_weights[key].numel()) > 1:
+                begin, end = param_slice_info[static_name]
+                _, numel, index, padded_size = param_shape_info[static_name]
+                master_weights[key] = master_weights[key].reshape([-1])
+                master_weights[key] = master_weights[key][begin - index : end - index]
+
+                padding_start = max(begin, index + numel)
+                padding_end = min(end, index + padded_size)
+                if padding_start < padding_end:
+                    master_weights[key] = paddle.concat(
+                        (
+                            master_weights[key],
+                            paddle.zeros([padding_end - padding_start], dtype=master_weights[key].dtype),
+                        )
+                    )
+
+    returned_optim_state_dict["master_weights"] = {}
+    for key in list(master_weights.keys()):
+        static_name = struct2static_name_mappings[key]
+        returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
+        returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
+
+    return returned_optim_state_dict
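
The slice-and-pad step above (begin, end, index, numel, padded_size) can be illustrated with toy numbers in plain Python; this is only a sketch of the arithmetic, not PaddleNLP code:

# Standalone sketch of the slice-and-pad arithmetic (toy numbers, no Paddle needed).
# A parameter occupies [index, index + padded_size) in the fused buffer; only the
# first `numel` elements are real data, the rest is alignment padding.
index, numel, padded_size = 100, 10, 12
begin, end = 104, 112          # portion of the buffer owned by this sharding rank

full_param = list(range(numel))                    # the checkpoint stores the full flat tensor
local = full_param[begin - index : end - index]    # elements 4..9 (6 values)

# If the local range extends into the padded tail, re-append zeros so the local
# slice has exactly end - begin elements, matching the fused optimizer buffer.
padding_start = max(begin, index + numel)          # 110
padding_end = min(end, index + padded_size)        # 112
if padding_start < padding_end:
    local += [0] * (padding_end - padding_start)

assert len(local) == end - begin                   # 8 elements for this rank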

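get_unfound_params falls back to sibling shard files whose names differ only in the shard id. A small illustration of the regex it relies on, with hypothetical file names (the real suffix format comes from args.optimizer_name_suffix):

import re

def strip_shard(name):
    # Remove the shard id so sibling shard files map to the same base name.
    return re.sub(r"_?shard\d+_?", "", name)

# Hypothetical file names for illustration only.
name = "optimizer-tp01_shard0.safetensors"
candidate = "optimizer-tp01_shard3.safetensors"

# Both collapse to "optimizer-tp01.safetensors", so `candidate` is searched
# for any expected keys missing from this rank's own file.
print(strip_shard(name) == strip_shard(candidate))  # True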
paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 21 additions & 1 deletion
@@ -62,7 +62,10 @@
     save_single_card_checkpoint,
     save_single_card_optimizer,
 )
-from .sharding_split_param_utils import gather_splited_param_for_optimizer
+from .sharding_split_param_utils import (
+    gather_splited_param_for_optimizer,
+    load_non_merge_optimizer_with_split_param,
+)
 from .utils import (
     FP32_MASTER,
     UnifiedCheckpointOption,
@@ -263,6 +266,23 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
        )
 
    def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+        """Load the non-merge optimizer state.
+
+        Args:
+            model (PretrainedModel): model used to get the key mapping.
+            optimizer (Optimizer): optimizer to load.
+            resume_from_checkpoint (str): path of the checkpoint to load.
+            ckpt_quant_stage (str): checkpoint quantization stage.
+
+        Returns:
+            dict: optimizer state dict.
+        """
+
+        if is_sharding_split_param_mode(self.args):
+            return load_non_merge_optimizer_with_split_param(
+                self.args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
+            )
+
        # init and get optimizer LR_Scheduler
        returned_optim_state_dict = nested_copy(optimizer.state_dict())
 