 
 import gc
 import os
+import re
 from itertools import chain
 
 import paddle
 import paddle.distributed as dist
 from paddle.distributed import fleet
+from safetensors import safe_open
 from tqdm.auto import tqdm
 
 from paddlenlp.peft import LoRAModel, PrefixModelForCausalLM
-from paddlenlp.transformers.model_utils import load_state_dict, unwrap_model
+from paddlenlp.transformers.model_utils import (
+    _add_variant,
+    load_state_dict,
+    unwrap_model,
+)
+from paddlenlp.transformers.utils import device_guard
 from paddlenlp.utils.env import (
     SAFE_MASTER_WEIGHTS_INDEX_NAME,
+    SAFE_MASTER_WEIGHTS_NAME,
     SAFE_OPTIMIZER_INDEX_NAME,
+    SAFE_OPTIMIZER_NAME,
 )
 from paddlenlp.utils.nested import nested_copy
 
@@ -175,6 +184,49 @@ def gather_splited_param_for_optimizer(optimizer, ckpt_quant_stage="O0"):
     return optim_state_dict, master_weights
 
 
+def get_params_info(comm_buffer_list):
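+    """Collect, from the sharding communication buffers, each parameter's flat-buffer
+    slice range and shape/padding info, plus the keys whose slice on this rank is non-empty."""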
+    expected_keys = []
+    param_slice_info = {}
+    param_shape_info = {}
+
+    for buffer in comm_buffer_list:
+        for key in buffer._sharding_param_grad_view.keys():
+            begin = buffer._sharding_param_grad_view[key]._param_begin
+            end = buffer._sharding_param_grad_view[key]._param_end
+            if end > begin:
+                expected_keys.append(key)
+            shape = buffer._sharding_param_grad_view[key]._param.shape
+            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
+            index = buffer._sharding_param_grad_view[key]._index
+            padded_size = buffer._sharding_param_grad_view[key]._padded_size
+            param_slice_info[key] = (begin, end)
+            param_shape_info[key] = (shape, numel, index, padded_size)
+    return expected_keys, param_slice_info, param_shape_info
+
+
+def reshape_params(state_dict, struct2static_name_mappings, param_shape_info, param_slice_info):
+    """Reshape params to 1-D tensors and slice out the portion owned by this sharding rank."""
+    for key in list(state_dict.keys()):
+        key_name = key.split("/")[0]
+        static_name = struct2static_name_mappings.get(key_name, None)
+        if int(state_dict[key].numel()) > 1:
+            begin, end = param_slice_info[static_name]
+            _, numel, index, padded_size = param_shape_info[static_name]
+            state_dict[key] = state_dict[key].reshape([-1])
+            state_dict[key] = state_dict[key][begin - index : end - index]
+
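+            # The checkpoint tensor only holds the real `numel` elements, so if this rank's
+            # slice extends into the comm buffer's padded tail, append zeros to match it.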
+            padding_start = max(begin, index + numel)
+            padding_end = min(end, index + padded_size)
+            if padding_start < padding_end:
+                state_dict[key] = paddle.concat(
+                    (
+                        state_dict[key],
+                        paddle.zeros([padding_end - padding_start], dtype=state_dict[key].dtype),
+                    )
+                )
+    return state_dict
+
+
 def load_unified_optimizer_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
     returned_optim_state_dict = nested_copy(optimizer.state_dict())
 
@@ -196,28 +248,12 @@ def load_unified_optimizer_split_param(args, model, optimizer, resume_from_check
     static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
     struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
 
-    expected_keys = []
-    param_slice_info = {}
-    param_shape_info = {}
-
     comm_buffer_list = optimizer._inner_opt._comm_buffer_list
     if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
         comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
         model = unwrap_model(model)
 
-    for buffer in comm_buffer_list:
-        for key in buffer._sharding_param_grad_view.keys():
-            begin = buffer._sharding_param_grad_view[key]._param_begin
-            end = buffer._sharding_param_grad_view[key]._param_end
-            if end > begin:
-                expected_keys.append(key)
-            shape = buffer._sharding_param_grad_view[key]._param.shape
-            numel = buffer._sharding_param_grad_view[key]._param.numel().item()
-            index = buffer._sharding_param_grad_view[key]._index
-            padded_size = buffer._sharding_param_grad_view[key]._padded_size
-            param_slice_info[key] = (begin, end)
-            param_shape_info[key] = (shape, numel, index, padded_size)
-
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
     expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
     expected_keys_optim = []
     for key in expected_keys:
@@ -285,25 +321,10 @@ def load_resolved_archive_file(
     )
 
     # need to split param for different sharding rank, maybe need to deal with oom issue.
+    reshape_params(state_dict_optim, struct2static_name_mappings, param_shape_info, param_slice_info)
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
         static_name = struct2static_name_mappings.get(key_name[0], None)
-
-        if int(state_dict_optim[key].numel()) > 1:
-            begin, end = param_slice_info[static_name]
-            shape, numel, index, padded_size = param_shape_info[static_name]
-            state_dict_optim[key] = state_dict_optim[key].reshape([-1])
-            state_dict_optim[key] = state_dict_optim[key][begin - index : end - index]
-
-            padding_start = max(begin, index + numel)
-            padding_end = min(end, index + padded_size)
-            if padding_start < padding_end:
-                state_dict_optim[key] = paddle.concat(
-                    (
-                        state_dict_optim[key],
-                        paddle.zeros([padding_end - padding_start], dtype=state_dict_optim[key].dtype),
-                    )
-                )
         if has_master_weights:
             if model_state_dict[key_name[0]].dtype != paddle.float32:
                 key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
@@ -325,24 +346,10 @@ def load_resolved_archive_file(
             expected_keys,
             is_master_weights=True,
         )
+        reshape_params(state_dict_master_weight, struct2static_name_mappings, param_shape_info, param_slice_info)
 
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings.get(key, None)
-            if int(state_dict_master_weight[key].numel()) > 1:
-                begin, end = param_slice_info[static_name]
-                shape, numel, index, padded_size = param_shape_info[static_name]
-                state_dict_master_weight[key] = state_dict_master_weight[key].reshape([-1])
-                state_dict_master_weight[key] = state_dict_master_weight[key][begin - index : end - index]
-
-                padding_start = max(begin, index + numel)
-                padding_end = min(end, index + padded_size)
-                if padding_start < padding_end:
-                    state_dict_master_weight[key] = paddle.concat(
-                        (
-                            state_dict_master_weight[key],
-                            paddle.zeros([padding_end - padding_start], dtype=state_dict_master_weight[key].dtype),
-                        )
-                    )
             state_dict_master_weight[key] = state_dict_master_weight[key]._copy_to(
                 paddle.framework._current_expected_place(), False
             )
@@ -357,3 +364,113 @@ def load_resolved_archive_file(
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
     return returned_optim_state_dict
+
+
+def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_from_checkpoint, ckpt_quant_stage="O0"):
+    returned_optim_state_dict = nested_copy(optimizer.state_dict())
+
+    optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, args.optimizer_name_suffix)
+    master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, args.optimizer_name_suffix)
+    optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
+    master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
+
+    # no quantization & no master weight represent the O1 AMP strategy.
+    is_amp_o1 = args.fp16_opt_level == "O1"
+
+    model_state_dict = get_expected_state_dict(model)
+    static2struct_name_mappings = {v.name: k for k, v in model_state_dict.items()}  # get optimizer param mappings
+    struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}
+
+    comm_buffer_list = optimizer._inner_opt._comm_buffer_list
+    if hasattr(args, "enable_sharding_comm_overlap") and args.enable_sharding_comm_overlap:
+        comm_buffer_list = list(chain(*model._chunk_2_comm_buffers.values()))
+
+    expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
+    expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
+    expected_keys_optim = []
+    sharding_typename_set, typename_set = [], []
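+    # Collect the optimizer state typenames (the part after "/" in each key) stored in this rank's shard file.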
+    with safe_open(optimizer_path, framework="numpy") as f:
+        optim_keys = f.keys()
+        for key in optim_keys:
+            _, typename = key.split("/")
+            typename_set.append(typename)
+
+    # To avoid missing typenames in some shard files, gather the typenames across the sharding group.
+    hcg = fleet.get_hybrid_communicate_group()
+    sharding_group = hcg.get_sharding_parallel_group()
+    dist.all_gather_object(sharding_typename_set, typename_set, sharding_group)
+    typename_set = set(chain(*sharding_typename_set))
+    for key in expected_keys:
+        for typename in typename_set:
+            expected_keys_optim.append(f"{key}/{typename}")
+    expected_keys_optim = set(expected_keys_optim)
+
+    optimizer_state_dict = load_state_dict(
+        optimizer_path, None, None, device="expected", ckpt_quant_stage=ckpt_quant_stage
+    )
+    master_weights = {}
+    # normal AMP O2
+    if not is_amp_o1 and os.path.isfile(master_weights_path):
+        master_weights = load_state_dict(master_weights_path, None, None, device="expected")
+
+    def get_unfound_params(unfound_keys, state_dict, is_optimizer=True):
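+        """Look through sibling shard files in resume_from_checkpoint (same file name once the
+        shard{N} marker is stripped) for expected keys missing from this rank's own file."""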
+        if len(unfound_keys) > 0:
+            backup_files = []
+            files = os.listdir(resume_from_checkpoint)
+            name = optimizer_name if is_optimizer else master_weights_name
+            name_without_shard = re.sub(r"_?shard\d+_?", "", name)
+            name_ = "optimizer" if is_optimizer else "master_weights"
+            for f in files:
+                if f.startswith(name_) and f.endswith("safetensors") and f != name:
+                    if re.sub(r"_?shard\d+_?", "", f) == name_without_shard:
+                        backup_files.append(f)
+            for f in backup_files:
+                new_path = os.path.join(resume_from_checkpoint, f)
+                with safe_open(new_path, framework="numpy") as fin:
+                    keys = fin.keys()
+                    for key in unfound_keys:
+                        if key in keys:
+                            tensor = fin.get_tensor(key)
+                            with device_guard():
+                                tensor = paddle.Tensor(tensor, zero_copy=True)
+                            state_dict[key] = tensor._copy_to(paddle.framework._current_expected_place(), False)
+
+    # Get other optimizer parameters which may be in other shard files.
+    unfound_keys = expected_keys_optim - optimizer_state_dict.keys()
+    get_unfound_params(unfound_keys, optimizer_state_dict, True)
+
+    # Get other master weight parameters which may be in other shard files.
+    if master_weights != {}:
+        unfound_keys = expected_keys - master_weights.keys()
+        get_unfound_params(unfound_keys, master_weights, False)
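+    # Slice each loaded optimizer tensor down to this sharding rank's local (padded) shard.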
+    reshape_params(optimizer_state_dict, struct2static_name_mappings, param_shape_info, param_slice_info)
+
+    # rename and move to paddle.Tensor
+    for key in list(optimizer_state_dict.keys()):
+        key_name = key.split("/")
+        model_weight_key = key_name[0]
+        static_name = struct2static_name_mappings[key_name[0]]
+        if not is_amp_o1:
+            if model_state_dict[key_name[0]].dtype != paddle.float32:
+                key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
+            else:
+                key_name = "_".join([static_name, key_name[1]])
+        else:
+            key_name = "_".join([static_name, key_name[1]])
+        returned_optim_state_dict[key_name] = optimizer_state_dict.pop(key)
+        returned_optim_state_dict[key_name].name = key_name
+
+        # master weight cast (only in AMP O2 + remove_master_weight)
+        if not is_amp_o1 and not os.path.isfile(master_weights_path):
+            master_weights[model_weight_key] = paddle.cast(model_state_dict[model_weight_key], dtype=paddle.float32)
+
+    if not is_amp_o1:
+        reshape_params(master_weights, struct2static_name_mappings, param_shape_info, param_slice_info)
+
+        returned_optim_state_dict["master_weights"] = {}
+        for key in list(master_weights.keys()):
+            static_name = struct2static_name_mappings[key]
+            returned_optim_state_dict["master_weights"][static_name] = master_weights.pop(key)
+            returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
+
+    return returned_optim_state_dict