Commit f578946

fix

1 parent 30df8b6 commit f578946
2 files changed: +171 -21 lines changed

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 150 additions & 20 deletions
@@ -20,13 +20,20 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from safetensors import safe_open
 from tqdm.auto import tqdm
 
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
+from paddlenlp.transformers.model_utils import (
+    _add_variant,
+    load_state_dict,
+    unwrap_model,
+)
 from paddlenlp.utils.env import (
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_NAME,
     SAFE_OPTIMIZER_INDEX_NAME,
+    SAFE_OPTIMIZER_NAME,
 )
 from paddlenlp.utils.nested import nested_copy
 
@@ -175,6 +182,26 @@ def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     return optim_state_dict, master_weights
 
 
+def get_params_info(comm_buffer_list):
+    expected_keys = []
+    param_slice_info = {}
+    param_shape_info = {}
+
+    for buffer in comm_buffer_list:
+        for key in buffer._sharding_param_grad_view.keys():
+            begin = buffer._sharding_param_grad_view[key]._param_begin
+            end = buffer._sharding_param_grad_view[key]._param_end
+            if end > begin:
+                expected_keys.append(key)
+            shape = buffer._sharding_param_grad_view[key]._param.shape
+            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
+            index = buffer._sharding_param_grad_view[key]._index
+            padded_size = buffer._sharding_param_grad_view[key]._padded_size
+            param_slice_info[key] = (begin, end)
+            param_shape_info[key] = (shape, numel, index, padded_size)
+    return expected_keys, param_slice_info, param_shape_info
+
+
 def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
@@ -196,28 +223,12 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
-    expected_keys = []
-    param_slice_info = {}
-    param_shape_info = {}
-
     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
     if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
         model = unwrap_model(model)
 
-    for buffer in comm_buffer_list:
-        for key in buffer._sharding_param_grad_view.keys():
-            begin = buffer._sharding_param_grad_view[key]._param_begin
-            end = buffer._sharding_param_grad_view[key]._param_end
-            if end > begin:
-                expected_keys.append(key)
-            shape = buffer._sharding_param_grad_view[key]._param.shape
-            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
-            index = buffer._sharding_param_grad_view[key]._index
-            padded_size = buffer._sharding_param_grad_view[key]._padded_size
-            param_slice_info[key] = (begin, end)
-            param_shape_info[key] = (shape, numel, index, padded_size)
-
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
     expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
     expected_keys_optim = []
     for key in expected_keys:
@@ -291,7 +302,7 @@ def load_resolved_archive_file(
 
         if int(state_dict_optim[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
             state_dict_optim[key] = state_dict_optim[key].reshape([-1])
             state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]
 
@@ -330,7 +341,7 @@ def load_resolved_archive_file(
         static_name = struct2static_name_mappings.get(key, None)
         if int(state_dict_master_weight[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
             state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
             state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]
 
@@ -357,3 +368,122 @@ def load_resolved_archive_file(
         returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
     return returned_optim_state_dict
+
+
+def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
+    optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+    master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
+    optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
+    master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
+
+    # no quantization & no master weight represent O1 AMP strategy.
+    is_amp_o1 = args.fp16_opt_level == "O1"
+
+    model_state_dict = get_expected_state_dict(model)
+    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+
+    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
+    expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
+    expected_keys_optim = []
+    typename_set = set()
+    with safe_open(optimizer_path, framework="numpy") as f:
+        optim_keys = f.keys()
+        for key in optim_keys:
+            _, typename = key.split("/")
+            typename_set.add(typename)
+    for key in expected_keys:
+        for typename in typename_set:
+            expected_keys_optim.append(f"{key}/{typename}")
+    expected_keys_optim = set(expected_keys_optim)
+
+    optimizer_state_dict = load_state_dict(
+        optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
+    )
+    master_weights = {}
+    # normal AMP O2
+    if not is_amp_o1 and os.path.isfile(master_weights_path):
+        master_weights = load_state_dict(master_weights_path, None, None, device="expected")
+
+    # Get other param slices which may be in other shard files.
+    unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
+    if len(unfound_keys) > 0:
+        backup_files = []
+        files = os.listdir(resume_from_checkpoint)
+        for f in files:
+            if f.startswith("optimizer") and f.endswith("safetensors"):
+                backup_files.append(f)
+        print(backup_files)
+        raise ValueError
+
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        static_name = struct2static_name_mappings.get(key_name[0], None)
+
+        if int(optimizer_state_dict[key].numel()) > 1:
+            begin, end = param_slice_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
+            optimizer_state_dict[key] = optimizer_state_dict[key].reshape([-1])
+            optimizer_state_dict[key] = optimizer_state_dict[key][begin - index : end - index]
+
+            padding_start = max(begin, index + numel)
+            padding_end = min(end, index + padded_size)
+            if padding_start < padding_end:
+                optimizer_state_dict[key] = paddle.concat(
+                    (
+                        optimizer_state_dict[key],
+                        paddle.zeros([padding_end - padding_start], dtype=optimizer_state_dict[key].dtype),
+                    )
+                )
+
+    # rename and move to paddle.Tensor
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        model_weight_key = key_name[0]
+        static_name = struct2static_name_mappings[key_name[0]]
+        if not is_amp_o1:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
+        else:
+            key_name = "_".join([static_name, key_name[1]])
+        returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
+        returned_optim_state_dict[key_name].name = key_name
+
+        # master weight cast (only in AMP O2 + remove_master_weight)
+        if not is_amp_o1 and not os.path.isfile(master_weights_path):
+            master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)
+
+    if not is_amp_o1:
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings.get(key, None)
+            if int(master_weights[key].numel()) > 1:
+                begin, end = param_slice_info[static_name]
+                _, numel, index, padded_size = param_shape_info[static_name]
+                master_weights[key] = master_weights[key].reshape([-1])
+                master_weights[key] = master_weights[key][begin - index : end - index]
+
+                padding_start = max(begin, index + numel)
+                padding_end = min(end, index + padded_size)
+                if padding_start < padding_end:
+                    master_weights[key] = paddle.concat(
+                        (
+                            master_weights[key],
+                            paddle.zeros([padding_end - padding_start], dtype=master_weights[key].dtype),
+                        )
+                    )
+
+    returned_optim_state_dict["master_weights"] = {}
+    for key in list(master_weights.keys()):
+        static_name = struct2static_name_mappings[key]
+        returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
+        returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
+
+    return returned_optim_state_dict
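
Note: the new loader trims each flattened optimizer tensor to the slice owned by the current sharding rank and, where that window runs past the parameter's real elements, appends zeros up to the padded buffer slot. Below is a minimal standalone sketch of that slice-and-pad arithmetic; the values for begin/end, index, numel, and padded_size are made up for illustration, whereas in the code above they come from get_params_info.

import paddle

numel = 10          # real number of elements in the parameter
index = 0           # offset of the parameter inside its comm buffer
padded_size = 12    # buffer slot size after alignment padding
begin, end = 6, 12  # [begin, end) window owned by this sharding rank

flat = paddle.arange(numel, dtype="float32").reshape([-1])  # stands in for the saved tensor

# keep only the elements that fall inside this rank's window
shard = flat[begin - index : end - index]

# if the window extends past the real elements, append zeros up to padded_size
padding_start = max(begin, index + numel)
padding_end = min(end, index + padded_size)
if padding_start < padding_end:
    shard = paddle.concat((shard, paddle.zeros([padding_end - padding_start], dtype=shard.dtype)))

print(shard.shape)  # [6]: four real elements (values 6..9) plus two zero-padded slots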

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 21 additions & 1 deletion
@@ -62,7 +62,10 @@
     save_single_card_checkpoint,
     save_single_card_optimizer,
 )
-from .sharding_split_param_utils import gather_splited_param_for_optimizer
+from .sharding_split_param_utils import (
+    gather_splited_param_for_optimizer,
+    load_non_merge_optimizer_with_split_param,
+)
 from .utils import (
     FP32_MASTER,
     UnifiedCheckpointOption,
@@ -263,6 +266,23 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
         )
 
     def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+        """load non merge optimizer
+
+        Args:
+            model (PretrainedModel): model used to get key mapping.
+            optimizer (Optimizer): optimizer to load.
+            resume_from_checkpoint (str): path of the checkpoint to load.
+            ckpt_quant_stage (str): checkpoint quantization stage.
+
+        Returns:
+            dict: optimizer state dict
+        """
+
+        if is_sharding_split_param_mode(self.args):
+            return load_non_merge_optimizer_with_split_param(
+                self.args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage
+            )
+
         # init and get optimizer LR_Scheduler
         returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
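For context on how the expected optimizer entries are built in load_non_merge_optimizer_with_split_param: every key stored in the optimizer safetensors file follows a "<struct_name>/<typename>" scheme, so the typenames observed in the file are crossed with the parameter names this rank expects to own, and any expected entry missing from the loaded shard triggers the error branch. A toy illustration with made-up parameter and state names (no file I/O, plain Python):

optim_keys = [
    "llama.embed_tokens.weight/moment1_0",
    "llama.embed_tokens.weight/moment2_0",
    "llama.embed_tokens.weight/beta1_pow_acc_0",
]
expected_keys = {"llama.embed_tokens.weight", "lm_head.weight"}

# typenames present in this shard file
typename_set = {key.split("/")[1] for key in optim_keys}

# every (parameter, typename) combination this rank expects to find
expected_keys_optim = {f"{key}/{typename}" for key in expected_keys for typename in typename_set}

# entries that would have to live in another shard file
unfound_keys = expected_keys_optim - set(optim_keys)
print(sorted(unfound_keys))  # the three lm_head.weight/* entries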