@@ -20,13 +20,20 @@
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from safetensors import safe_open
 from tqdm.auto import tqdm
 
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
+from paddlenlp.transformers.model_utils import (
+    _add_variant,
+    load_state_dict,
+    unwrap_model,
+)
 from paddlenlp.utils.env import (
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_NAME,
     SAFE_OPTIMIZER_INDEX_NAME,
+    SAFE_OPTIMIZER_NAME,
 )
 from paddlenlp.utils.nested import nested_copy
 
@@ -175,6 +182,26 @@ def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     return optim_state_dict, master_weights
 
 
+def get_params_info(comm_buffer_list):
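+    """Scan the sharding comm buffers and return the param names owned by this
+    rank, plus each param's (begin, end) slice and its
+    (shape, numel, index, padded_size) layout info."""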
+    expected_keys = []
+    param_slice_info = {}
+    param_shape_info = {}
+
+    for buffer in comm_buffer_list:
+        for key in buffer._sharding_param_grad_view.keys():
+            begin = buffer._sharding_param_grad_view[key]._param_begin
+            end = buffer._sharding_param_grad_view[key]._param_end
+            if end > begin:
+                expected_keys.append(key)
+                shape = buffer._sharding_param_grad_view[key]._param.shape
+                numel = buffer._sharding_param_grad_view[key]._param.numel().item()
+                index = buffer._sharding_param_grad_view[key]._index
+                padded_size = buffer._sharding_param_grad_view[key]._padded_size
+                param_slice_info[key] = (begin, end)
+                param_shape_info[key] = (shape, numel, index, padded_size)
+    return expected_keys, param_slice_info, param_shape_info
+
+
 def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
@@ -196,28 +223,12 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
-    expected_keys = []
-    param_slice_info = {}
-    param_shape_info = {}
-
     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
     if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
         model = unwrap_model(model)
 
-    for buffer in comm_buffer_list:
-        for key in buffer._sharding_param_grad_view.keys():
-            begin = buffer._sharding_param_grad_view[key]._param_begin
-            end = buffer._sharding_param_grad_view[key]._param_end
-            if end > begin:
-                expected_keys.append(key)
-                shape = buffer._sharding_param_grad_view[key]._param.shape
-                numel = buffer._sharding_param_grad_view[key]._param.numel().item()
-                index = buffer._sharding_param_grad_view[key]._index
-                padded_size = buffer._sharding_param_grad_view[key]._padded_size
-                param_slice_info[key] = (begin, end)
-                param_shape_info[key] = (shape, numel, index, padded_size)
-
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
     expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
     expected_keys_optim = []
     for key in expected_keys:
@@ -291,7 +302,7 @@ def load_resolved_archive_file(
 
         if int(state_dict_optim[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
             state_dict_optim[key] = state_dict_optim[key].reshape([-1])
             state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]
 
@@ -330,7 +341,7 @@ def load_resolved_archive_file(
         static_name = struct2static_name_mappings.get(key, None)
         if int(state_dict_master_weight[key].numel()) > 1:
             begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
             state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
             state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]
 
@@ -357,3 +368,122 @@ def load_resolved_archive_file(
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
     return returned_optim_state_dict
+
+
+def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
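+    """Load per-rank (non-merged) safetensors optimizer states when sharding
+    splits parameters, slicing each state tensor down to the local shard."""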
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
+    optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+    master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
+    optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
+    master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
+
+    # No quantization and no master weights indicate the O1 AMP strategy.
+    is_amp_o1 = args.fp16_opt_level == "O1"
+
+    model_state_dict = get_expected_state_dict(model)
+    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+
+    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
+    expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
+    expected_keys_optim = []
+    typename_set = set()
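+    # optimizer keys are stored as "<param_name>/<typename>"; collect every typename in this shard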
+    with safe_open(optimizer_path, framework="numpy") as f:
+        optim_keys = f.keys()
+        for key in optim_keys:
+            _, typename = key.split("/")
+            typename_set.add(typename)
+    for key in expected_keys:
+        for typename in typename_set:
+            expected_keys_optim.append(f"{key}/{typename}")
+    expected_keys_optim = set(expected_keys_optim)
+
+    optimizer_state_dict = load_state_dict(
+        optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
+    )
+    master_weights = {}
+    # normal AMP O2
+    if not is_amp_o1 and os.path.isfile(master_weights_path):
+        master_weights = load_state_dict(master_weights_path, None, None, device="expected")
+
+    # Some expected optimizer slices may live in other shard files; recovering
+    # them is not supported here, so fail with the candidate files listed.
+    unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
+    if len(unfound_keys) > 0:
+        backup_files = []
+        for f in os.listdir(resume_from_checkpoint):
+            if f.startswith("optimizer") and f.endswith("safetensors"):
+                backup_files.append(f)
+        raise ValueError(
+            f"Optimizer states {unfound_keys} not found in {optimizer_name}; candidate shard files: {backup_files}"
+        )
+
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        static_name = struct2static_name_mappings.get(key_name[0], None)
+
+        if int(optimizer_state_dict[key].numel()) > 1:
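+            # keep only this rank's slice of the flattened tensor, then restore
+            # any trailing padding that falls inside the slice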
+            begin, end = param_slice_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
+            optimizer_state_dict[key] = optimizer_state_dict[key].reshape([-1])
+            optimizer_state_dict[key] = optimizer_state_dict[key][begin - index : end - index]
+
+            padding_start = max(begin, index + numel)
+            padding_end = min(end, index + padded_size)
+            if padding_start < padding_end:
+                optimizer_state_dict[key] = paddle.concat(
+                    (
+                        optimizer_state_dict[key],
+                        paddle.zeros([padding_end - padding_start], dtype=optimizer_state_dict[key].dtype),
+                    )
+                )
+
+    # rename and move to paddle.Tensor
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        model_weight_key = key_name[0]
+        static_name = struct2static_name_mappings[key_name[0]]
+        if not is_amp_o1:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
+        else:
+            key_name = "_".join([static_name, key_name[1]])
+        returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
+        returned_optim_state_dict[key_name].name = key_name
+
+        # master weight cast (only in AMP O2 + remove_master_weight)
+        if not is_amp_o1 and not os.path.isfile(master_weights_path):
+            master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)
+
+    if not is_amp_o1:
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings.get(key, None)
+            if int(master_weights[key].numel()) > 1:
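+                # master weights get the same slice-and-pad treatment as the optimizer states above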
+                begin, end = param_slice_info[static_name]
+                _, numel, index, padded_size = param_shape_info[static_name]
+                master_weights[key] = master_weights[key].reshape([-1])
+                master_weights[key] = master_weights[key][begin - index : end - index]
+
+                padding_start = max(begin, index + numel)
+                padding_end = min(end, index + padded_size)
+                if padding_start < padding_end:
+                    master_weights[key] = paddle.concat(
+                        (
+                            master_weights[key],
+                            paddle.zeros([padding_end - padding_start], dtype=master_weights[key].dtype),
+                        )
+                    )
+
+        returned_optim_state_dict["master_weights"] = {}
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings[key]
+            returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
+            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
+
+    return returned_optim_state_dict