
import gc
import os
+import re
from itertools import chain

import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
+from safetensors import safe_open
from tqdm.auto import tqdm

from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
+from paddlenlp.transformers.model_utils import (
+    _add_variant,
+    load_state_dict,
+    unwrap_model,
+)
+from paddlenlp.transformers.utils import device_guard
from paddlenlp.utils.env import (
    SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_NAME,
    SAFE_OPTIMIZER_INDEX_NAME,
+    SAFE_OPTIMIZER_NAME,
)
from paddlenlp.utils.nested import nested_copy

@@ -175,6 +184,26 @@ def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
    return optim_state_dict, master_weights


+def get_params_info(comm_buffer_list):
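+    # Record each parameter's flat slice range and shape metadata from the sharding
+    # comm buffers; only slices actually held locally (end > begin) become expected keys.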
+    expected_keys = []
+    param_slice_info = {}
+    param_shape_info = {}
+
+    for buffer in comm_buffer_list:
+        for key in buffer._sharding_param_grad_view.keys():
+            begin = buffer._sharding_param_grad_view[key]._param_begin
+            end = buffer._sharding_param_grad_view[key]._param_end
+            if end > begin:
+                expected_keys.append(key)
+            shape = buffer._sharding_param_grad_view[key]._param.shape
+            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
+            index = buffer._sharding_param_grad_view[key]._index
+            padded_size = buffer._sharding_param_grad_view[key]._padded_size
+            param_slice_info[key] = (begin, end)
+            param_shape_info[key] = (shape, numel, index, padded_size)
+    return expected_keys, param_slice_info, param_shape_info
+
+
def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
    returned_optim_state_dict = nested_copy(optimizer.state_dict())

@@ -196,28 +225,12 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}

-    expected_keys = []
-    param_slice_info = {}
-    param_shape_info = {}
-
    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
        model = unwrap_model(model)

-    for buffer in comm_buffer_list:
-        for key in buffer._sharding_param_grad_view.keys():
-            begin = buffer._sharding_param_grad_view[key]._param_begin
-            end = buffer._sharding_param_grad_view[key]._param_end
-            if end > begin:
-                expected_keys.append(key)
-            shape = buffer._sharding_param_grad_view[key]._param.shape
-            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
-            index = buffer._sharding_param_grad_view[key]._index
-            padded_size = buffer._sharding_param_grad_view[key]._padded_size
-            param_slice_info[key] = (begin, end)
-            param_shape_info[key] = (shape, numel, index, padded_size)
-
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
    expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
    expected_keys_optim = []
    for key in expected_keys:
@@ -291,7 +304,7 @@ def load_resolved_archive_file(

        if int(state_dict_optim[key].numel()) > 1:
            begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
            state_dict_optim[key] = state_dict_optim[key].reshape([-1])
            state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]

@@ -330,7 +343,7 @@ def load_resolved_archive_file(
            static_name = struct2static_name_mappings.get(key, None)
            if int(state_dict_master_weight[key].numel()) > 1:
                begin, end = param_slice_info[static_name]
-                shape, numel, index, padded_size = param_shape_info[static_name]
+                _, numel, index, padded_size = param_shape_info[static_name]
                state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
                state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]

@@ -357,3 +370,142 @@ def load_resolved_archive_file(
            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])

    return returned_optim_state_dict
+
+
+def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
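+    # Load per-rank (non-merged) safetensors optimizer state when sharding splits
+    # parameters, keeping only the flat slice of each tensor owned by the local rank.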
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
+    optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+    master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
+    optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
+    master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
+
+    # no quantization and no master weights indicates the O1 AMP strategy.
+    is_amp_o1 = args.fp16_opt_level == "O1"
+
+    model_state_dict = get_expected_state_dict(model)
+    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+
+    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
+    expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
+    expected_keys_optim = []
+    typename_set = set()
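+    # Optimizer keys are stored as "<param name>/<typename>"; collect the typenames
+    # present in the checkpoint to build the full set of expected optimizer keys.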
+    with safe_open(optimizer_path, framework="numpy") as f:
+        optim_keys = f.keys()
+        for key in optim_keys:
+            _, typename = key.split("/")
+            typename_set.add(typename)
+    for key in expected_keys:
+        for typename in typename_set:
+            expected_keys_optim.append(f"{key}/{typename}")
+    expected_keys_optim = set(expected_keys_optim)
+
+    optimizer_state_dict = load_state_dict(
+        optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
+    )
+    master_weights = {}
+    # normal AMP O2
+    if not is_amp_o1 and os.path.isfile(master_weights_path):
+        master_weights = load_state_dict(master_weights_path, None, None, device="expected")
+
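+    # Keys missing from the local shard file may live in another rank's shard;
+    # search the sibling "*shard*" safetensors files and copy any matches over.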
+    def get_unfound_params(unfound_keys, state_dict, is_optimizer=True):
+        if len(unfound_keys) > 0:
+            backup_files = []
+            files = os.listdir(resume_from_checkpoint)
+            name = optimizer_name if is_optimizer else master_weights_name
+            name_without_shard = re.sub(r"_?shard\d+_?", "", name)
+            name_ = "optimizer" if is_optimizer else "master_weights"
+            for f in files:
+                if f.startswith(name_) and f.endswith("safetensors") and f != name:
+                    if re.sub(r"_?shard\d+_?", "", f) == name_without_shard:
+                        backup_files.append(f)
+            for f in backup_files:
+                new_path = os.path.join(resume_from_checkpoint, f)
+                with safe_open(new_path, framework="numpy") as fin:
+                    keys = fin.keys()
+                    for key in unfound_keys:
+                        if key in keys:
+                            tensor = fin.get_tensor(key)
+                            with device_guard():
+                                tensor = paddle.Tensor(tensor, zero_copy=True)
+                            state_dict[key] = tensor._copy_to(paddle.framework._current_expected_place(), False)
+
+    # Get other optimizer parameters which may be in other shard files.
+    unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
+    get_unfound_params(unfound_keys, optimizer_state_dict, True)
+
+    # Get other master weight parameters which may be in other shard files.
+    if master_weights != {}:
+        unfound_keys = expected_keys - master_weights.keys()
+        get_unfound_params(unfound_keys, master_weights, False)
+
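+    # Slice each loaded optimizer tensor down to the flat segment owned by the local
+    # sharding rank, re-appending zero padding where the stored parameter was padded.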
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        static_name = struct2static_name_mappings.get(key_name[0], None)
+
+        if int(optimizer_state_dict[key].numel()) > 1:
+            begin, end = param_slice_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
+            optimizer_state_dict[key] = optimizer_state_dict[key].reshape([-1])
+            optimizer_state_dict[key] = optimizer_state_dict[key][begin - index : end - index]
+
+            padding_start = max(begin, index + numel)
+            padding_end = min(end, index + padded_size)
+            if padding_start < padding_end:
+                optimizer_state_dict[key] = paddle.concat(
+                    (
+                        optimizer_state_dict[key],
+                        paddle.zeros([padding_end - padding_start], dtype=optimizer_state_dict[key].dtype),
+                    )
+                )
+
+    # rename and move to paddle.Tensor
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        model_weight_key = key_name[0]
+        static_name = struct2static_name_mappings[key_name[0]]
+        if not is_amp_o1:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
+        else:
+            key_name = "_".join([static_name, key_name[1]])
+        returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
+        returned_optim_state_dict[key_name].name = key_name
+
+        # master weight cast (only in AMP O2 + remove_master_weight)
+        if not is_amp_o1 and not os.path.isfile(master_weights_path):
+            master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)
+
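+    # Under AMP O2, master weights receive the same local-slice and padding treatment,
+    # then are re-keyed by static parameter name before being returned.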
+    if not is_amp_o1:
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings.get(key, None)
+            if int(master_weights[key].numel()) > 1:
+                begin, end = param_slice_info[static_name]
+                _, numel, index, padded_size = param_shape_info[static_name]
+                master_weights[key] = master_weights[key].reshape([-1])
+                master_weights[key] = master_weights[key][begin - index : end - index]
+
+                padding_start = max(begin, index + numel)
+                padding_end = min(end, index + padded_size)
+                if padding_start < padding_end:
+                    master_weights[key] = paddle.concat(
+                        (
+                            master_weights[key],
+                            paddle.zeros([padding_end - padding_start], dtype=master_weights[key].dtype),
+                        )
+                    )
+
+        returned_optim_state_dict["master_weights"] = {}
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings[key]
+            returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
+            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
+
+    return returned_optim_state_dict