 from paddlenlp.transformers.model_utils import (
     PretrainedModel,
     _load_state_dict_into_model,
+    faster_set_state_dict,
     get_parameter_dtype,
     load_state_dict,
     unwrap_model,
 from paddlenlp.utils.log import logger
 
 if is_safetensors_available():
-    from safetensors import safe_open
+    # from safetensors import safe_open
     from safetensors.numpy import save_file as safe_save_file
 
+    from paddlenlp.utils.safetensors import fast_safe_open as safe_open
 
 FP32_MASTER = "fp32_master_0"
 optimizer_scalar_name = [
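
This hunk swaps the upstream safetensors.safe_open for PaddleNLP's fast_safe_open while keeping the safe_open name, so downstream call sites stay untouched. A minimal sketch of such a call site, assuming fast_safe_open mirrors the safetensors context-manager interface (keys()/get_tensor(), numpy arrays for framework="np"); that interface parity is an assumption, not something this diff states:

def read_one_tensor(shard_path, key):
    # safe_open may be bound to either implementation; both are assumed to
    # expose keys() and get_tensor() and to return numpy arrays here.
    with safe_open(shard_path, framework="np") as f:
        if key in f.keys():
            return f.get_tensor(key)
    return None
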
@@ -195,7 +197,6 @@ def load_unified_checkpoint(args, model, optimizer, resume_from_checkpoint: str,
     Returns:
         None
     """
-
     if paddle.distributed.get_world_size() <= 1:
         load_single_card_checkpoint(args, model, resume_from_checkpoint)
         return
@@ -221,7 +222,6 @@ def load_unified_checkpoint_locally(args, model, resume_from_checkpoint: str, sa
         pretrained_model_name_or_path=resume_from_checkpoint,
         index_filename=os.path.join(resume_from_checkpoint, index_filename),
     )
-
     loaded_keys = sharded_metadata["all_checkpoint_keys"]
 
     model_state_dict = get_expected_state_dict(model)
@@ -265,7 +265,9 @@ def _remove_unused_keys(
         else:
             tp_actions = model.get_tensor_parallel_convert_actions(model.config, loaded_keys, ignore_error=True)
         # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-        state_dict = load_state_dict(shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys)
+        state_dict = load_state_dict(
+            shard_file, tp_actions if pre_tensor_parallel_split else None, expected_keys, device="expected"
+        )
 
         if not pre_tensor_parallel_split:
             # Since we load all keys but we only need one of pipeline stages
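
The new device="expected" argument appears at every load_state_dict call site in this diff. A plausible reading is that loaded tensors are materialized directly on the device each rank expects instead of being staged on CPU first; the helper below is a hypothetical sketch of that idea, not PaddleNLP's actual load_state_dict:

import paddle
import paddle.distributed as dist

def place_loaded_tensor(np_array, device="cpu"):
    # Hypothetical helper: "expected" is read as "the accelerator this rank
    # trains on"; anything else stays in host memory.
    if device == "expected" and paddle.device.is_compiled_with_cuda():
        dev_id = dist.get_rank() % paddle.device.cuda.device_count()
        return paddle.to_tensor(np_array, place=paddle.CUDAPlace(dev_id))
    return paddle.to_tensor(np_array, place=paddle.CPUPlace())
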
@@ -278,11 +280,12 @@ def _remove_unused_keys(
                 None, model.config, state_dict=state_dict, ignore_error=len(resolved_archive_file) > 1
             )
 
-        error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        # error_msgs += _load_state_dict_into_model(model, state_dict, "")
+        error_msgs += faster_set_state_dict(model, state_dict, strict_dtype=False)
 
         # force memory release
         del state_dict
-        gc.collect()
+        # gc.collect()
 
     if len(error_msgs) > 0:
         error_msg = "\n\t".join(error_msgs)
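
Here faster_set_state_dict replaces _load_state_dict_into_model, and the per-shard gc.collect() is dropped. The snippet below is only a rough sketch of the in-place idea (copy loaded values into existing parameters, casting on dtype mismatch when strict_dtype=False); it is not the PaddleNLP implementation and its name is made up:

import paddle

def naive_set_state_dict(model, state_dict, strict_dtype=True):
    # Illustrative only: write loaded tensors into existing parameters in place
    # instead of rebuilding parameter objects.
    error_msgs = []
    params = dict(model.named_parameters())
    for name, value in state_dict.items():
        if name not in params:
            error_msgs.append(f"unexpected key: {name}")
            continue
        target = params[name]
        if target.dtype != value.dtype:
            if strict_dtype:
                error_msgs.append(f"dtype mismatch for {name}")
                continue
            value = paddle.cast(value, target.dtype)
        paddle.assign(value, output=target)  # in-place copy into the live parameter
    return error_msgs
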
@@ -336,6 +339,7 @@ def unified_checkpoint_into_shards(
         tp_actions = model_to_save.get_tensor_parallel_convert_actions(
             model_to_save.config, state_dict.keys(), is_split=False, ignore_error=True
         )
+        logger.info("Unified model tensor parallel weights in shards")
         state_dict = merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys)
 
     # build index json file
@@ -489,6 +493,7 @@ def load_unified_optimizer_locally(args, model, optimizer, resume_from_checkpoin
     # This should always be a list but, just to be sure.
     if not isinstance(resolved_archive_file, list):
         resolved_archive_file = [resolved_archive_file]
+
     if len(resolved_archive_file) > 1:
         resolved_archive_file = tqdm(resolved_archive_file, desc="Loading optimizer shards")
 
@@ -536,10 +541,10 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
                     tp_actions = mapping_optimizer_tp_actions(tp_actions, expected_keys)
 
                 # Here we use expected_keys to optimize weights loading for pipeline model. Only works for safetensors
-                state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
+                state_dict = load_state_dict(shard_file, tp_actions, expected_keys, device="expected")
             else:
                 # for pipeline model, we don't need to use tp_actions
-                state_dict = load_state_dict(shard_file, None, expected_keys)
+                state_dict = load_state_dict(shard_file, None, expected_keys, device="expected")
 
             returned_state_dict.update(state_dict)
             # force memory release
@@ -552,7 +557,6 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
         state_dict_master_weight = load_resolved_archive_file(
             resolved_archive_file_mw, sharded_metadata_mw, expected_keys_mw, is_master_weights=True
         )
-
     # rename optimizer param
     for key in list(state_dict_optim.keys()):
         key_name = key.split("/")
@@ -561,13 +565,13 @@ def load_resolved_archive_file(resolved_archive_file, sharded_metadata, expected
             key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
         else:
             key_name = "_".join([static_name, key_name[1]])
-        returned_optim_state_dict[key_name] = state_dict_optim[key]
+        returned_optim_state_dict[key_name] = state_dict_optim.pop(key)
         returned_optim_state_dict[key_name].name = key_name
 
     if has_master_weights:
         for key in list(state_dict_master_weight.keys()):
             static_name = struct2static_name_mappings[key]
-            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight[key]
+            returned_optim_state_dict["master_weights"][static_name] = state_dict_master_weight.pop(key)
             returned_optim_state_dict["master_weights"][static_name].name = "_".join([static_name, FP32_MASTER])
 
     returned_optim_state_dict = nested_copy_place(
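
A concrete, self-contained illustration of the rename above, with made-up structured and static names:

# Hypothetical example of the key rename (names are invented for illustration).
struct2static_name_mappings = {"llama.layers.0.mlp.gate_proj.weight": "linear_3.w_0"}
key = "llama.layers.0.mlp.gate_proj.weight/moment1_0"
static_name = struct2static_name_mappings[key.split("/")[0]]
with_master = "_".join([static_name, "fp32_master_0", key.split("/")[1]])   # linear_3.w_0_fp32_master_0_moment1_0
without_master = "_".join([static_name, key.split("/")[1]])                 # linear_3.w_0_moment1_0
# Switching to .pop(key) releases each source entry as soon as it is re-keyed,
# keeping peak host memory lower than indexing with state_dict_optim[key].
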
@@ -639,6 +643,7 @@ def unified_optimizer_into_shards(
         tp_actions = model.get_tensor_parallel_convert_actions(
             model.config, model_keys, is_split=False, ignore_error=True
         )
+        logger.info("Unified optimizer tensor parallel in shards")
         optim_state_dict = merge_tensor_parallel_for_optimizer(
             optim_state_dict,
             tp_actions,
@@ -647,6 +652,7 @@ def unified_optimizer_into_shards(
         paddle.device.cuda.empty_cache()
 
         if master_weights is not None:
+            logger.info("Unified master weight tensor parallel in shards")
             master_weights = merge_tensor_parallel_for_optimizer(
                 master_weights,
                 tp_actions,
@@ -702,7 +708,6 @@ def unified_optimizer_into_shards(
 def check_unified_checkpoint(args, model, resume_from_checkpoint, safe_serialization=False):
     index_filename = select_model_weight_index(args, model, resume_from_checkpoint, safe_serialization, local=False)
     index_filename = os.path.join(resume_from_checkpoint, index_filename)
-
     # Find index json file and distribute this file in global group.
     if distributed_isfile(index_filename):
         distributed_file(index_filename)
@@ -1604,7 +1609,9 @@ def gather_sharded_object(index_file, total_size, is_optimizer=False):
     tp_group = hcg.get_model_parallel_group()
     pp_group = hcg.get_pipe_parallel_group()
 
-    logger.info("Unified checkpoint generating sharded_index json files.")
+    logger.info(
+        f"Unified checkpoint: generating sharded_index json files for {'optimizer or master weight' if is_optimizer else 'model weight'}."
+    )
 
     if tp_group.nranks > 1:
         dist.all_gather_object(index_file_list, index_file, tp_group)
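
dist.all_gather_object is the collective used to collect each rank's shard index before a single rank merges it. A minimal standalone usage pattern, assuming a multi-rank job launched with paddle.distributed (the index contents are illustrative):

import paddle.distributed as dist

dist.init_parallel_env()
local_index = {"model-00001-of-00002.safetensors": ["layer.0.weight"]}  # this rank's piece
gathered = []
dist.all_gather_object(gathered, local_index)  # default (global) group
# gathered now holds one object per rank; one rank can merge them into the
# final sharded_index json, as gather_sharded_object does per tp/pp group.
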
@@ -1713,8 +1720,6 @@ def filter_params(model_to_save, state_dict, is_optimizer=False):
 
 
 def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified checkpoint merge tensor parallel in shards")
-
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1740,7 +1745,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
                 action = tp_actions.pop(key)
                 tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
 
             if is_dst:
                 state_dict_to_save[key] = tensor
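
The place change from paddle.CPUPlace() to paddle.CUDAPinnedPlace() stages gathered shards in page-locked host memory instead of ordinary pageable memory. A small sketch of the difference, assuming a CUDA build of Paddle (_copy_to is the same private API used in the hunk):

import paddle

x = paddle.randn([1024, 1024])                        # on GPU by default in a CUDA build
pageable = x._copy_to(paddle.CPUPlace(), False)       # ordinary host memory
pinned = x._copy_to(paddle.CUDAPinnedPlace(), False)  # page-locked host memory
# Pinned buffers allow faster, asynchronous device<->host transfers, which
# matters when many gathered shards are staged off-GPU during the merge.
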
@@ -1753,8 +1758,7 @@ def merge_tensor_parallel_with_shard(state_dict, tp_actions, all_filter_keys):
 
 
 def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys):
-    logger.info("Unified optimizer tensor parallel in shards")
-
+    # Core function for UC
     hcg = fleet.get_hybrid_communicate_group()
     tp_group = hcg.get_model_parallel_group()
     tp_rank = tp_group.rank
@@ -1773,14 +1777,14 @@ def merge_tensor_parallel_for_optimizer(state_dict, tp_actions, all_filter_keys)
                 # for example: beta1, beta2
                 if tensor.numel().item() == 1:
                     tensor = (
-                        tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                        tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
                     )  # Need broadcast when loaded
                 else:
                     ret = distributed_gather(tensor, dst=j, group=tp_group, offload=False)
                     action = tp_actions[model_key]
                     tensor = action(ret) if is_dst else None
             else:
-                tensor = tensor._copy_to(paddle.CPUPlace(), False) if is_dst else None
+                tensor = tensor._copy_to(paddle.CUDAPinnedPlace(), False) if is_dst else None
 
             if is_dst:
                 state_dict_to_save[filter_keys[i]] = tensor
@@ -1892,7 +1896,10 @@ def nested_copy_place(inputs, place=None, blocking=False):
             outputs[key] = nested_copy_place(inputs[key], place, blocking)
         return outputs
     if isinstance(inputs, paddle.Tensor):
-        inputs = inputs if inputs.place == place else inputs._copy_to(place, blocking)
+        if inputs.place._equals(place):
+            return inputs
+        else:
+            return inputs._copy_to(place, blocking)
     return inputs
 
 
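nested_copy_place now compares places with Place._equals instead of ==. A plausible reason is that tensor.place is a generic Place wrapper, so == against a freshly constructed CPUPlace/CUDAPinnedPlace can compare different wrapper types and report a false mismatch; _equals compares the underlying place. A tiny check, reusing only APIs that already appear in this diff (assumes a CUDA build):

import paddle

t = paddle.ones([2])._copy_to(paddle.CUDAPinnedPlace(), False)
target = paddle.CUDAPinnedPlace()
same = t.place._equals(target)               # robust comparison used above
moved = t if same else t._copy_to(target, False)  # skip the copy when already in place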