1313# limitations under the License.
1414from __future__ import annotations
1515
16+ import concurrent .futures
1617import contextlib
1718import copy
1819import gc
@@ -319,6 +320,65 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
319320 return last_dtype
320321
321322
323+ def _split_keys_evenly (keys : list , n : int ) -> list :
324+ """Split a list into n lists with an equal number of elements.
325+
326+ Args:
327+ keys (list): the list to be split
328+ n (int): number of splits
329+
330+ Returns:
331+ result: list of lists
332+ """
333+
334+ total_len = len (keys )
335+ base_size = total_len // n
336+ extra = total_len % n
337+
338+ result = []
339+ index = 0
340+ for _ in range (n ):
341+ part_size = base_size + 1 if extra > 0 else base_size
342+ extra -= 1
343+ result .append (keys [index : index + part_size ])
344+ index += part_size
345+
346+ return result
347+
348+
def _load_part_state_dict(
    keys, checkpoint_file: Union[str, os.PathLike], tensor_parallel_split_mapping, fliter_dict_keys, device
):
    """Load a subset of tensors from a safetensors checkpoint file.

    Args:
        keys (list): names of the tensors to read from the file
        checkpoint_file (str): path of the safetensors checkpoint file
        tensor_parallel_split_mapping (dict): maps a tensor name to the
            function that extracts this rank's tensor-parallel shard
        fliter_dict_keys (list): optional allow-list; keys not in it are
            skipped entirely
        device (str): when "expected", the weight is wrapped as a paddle
            Tensor and copied to the current expected place

    Returns:
        part_state_dict (dict): mapping from tensor name to loaded weight

    """
    part_state_dict = {}
    with safe_open(checkpoint_file, framework="np") as f:
        for key in keys:
            # Skip anything outside the caller-provided allow-list.
            if fliter_dict_keys is not None and key not in fliter_dict_keys:
                continue
            tensor_slice = f.get_slice(key)
            # Either take this rank's tensor-parallel shard, or materialize
            # the full tensor from the lazy slice.
            if key in tensor_parallel_split_mapping:
                weight = tensor_parallel_split_mapping[key](tensor_slice)
            else:
                weight = tensor_slice[:]
            if device == "expected":
                with device_guard():
                    weight = paddle.Tensor(weight, zero_copy=True)
                weight = weight._copy_to(paddle.framework._current_expected_place(), False)
            part_state_dict[key] = weight
    return part_state_dict
380+
381+
322382def load_state_dict (
323383 checkpoint_file : Union [str , os .PathLike ], tensor_parallel_split_mapping = None , fliter_dict_keys = None , device = "cpu"
324384):
@@ -343,21 +403,36 @@ def load_state_dict(
343403 if metadata .get ("format" , "np" ) == "pd" :
344404 raise ValueError ("Currently unsupport paddle weights file, use numpy instead." )
345405 if metadata .get ("format" , "np" ) == "np" :
406+ thread_num = int (os .environ .get ("LOAD_STATE_DICT_THREAD_NUM" , "1" ))
346407 state_dict = {}
347- with safe_open (checkpoint_file , framework = "np" ) as f :
348- for key in f .keys ():
349- if fliter_dict_keys is not None and key not in fliter_dict_keys :
350- continue
351- py_safe_slice_ = f .get_slice (key )
352- if key in tensor_parallel_split_mapping :
353- weight = tensor_parallel_split_mapping [key ](py_safe_slice_ )
354- else :
355- weight = py_safe_slice_ [:]
356- if device == "expected" :
357- with device_guard ():
358- weight = paddle .Tensor (weight , zero_copy = True )
359- weight = weight ._copy_to (paddle .framework ._current_expected_place (), False )
360- state_dict [key ] = weight
408+ if thread_num <= 1 :
409+ with safe_open (checkpoint_file , framework = "np" ) as f :
410+ state_dict = _load_part_state_dict (
411+ list (f .keys ()),
412+ checkpoint_file ,
413+ tensor_parallel_split_mapping ,
414+ fliter_dict_keys ,
415+ device ,
416+ )
417+ else :
418+ # Load state dict in multi-thread to speed up loading
419+ with safe_open (checkpoint_file , framework = "np" ) as f :
420+ keys_groups = _split_keys_evenly (list (f .keys ()), thread_num )
421+ with concurrent .futures .ThreadPoolExecutor (max_workers = thread_num ) as executor :
422+ future_to_key = {
423+ executor .submit (
424+ _load_part_state_dict ,
425+ keys ,
426+ checkpoint_file ,
427+ tensor_parallel_split_mapping ,
428+ fliter_dict_keys ,
429+ device ,
430+ ): keys
431+ for keys in keys_groups
432+ }
433+ for future in concurrent .futures .as_completed (future_to_key ):
434+ result = future .result ()
435+ state_dict .update (result )
361436
362437 if device == "cpu" :
363438 for k in list (state_dict .keys ()):
@@ -1963,7 +2038,6 @@ def _fuse_or_split_keys(
19632038
19642039 if config .quantization_config .is_weight_quantize ():
19652040 filter_dict_keys = None
1966-
19672041 state_dict = load_state_dict (
19682042 shard_file , tp_actions if pre_tensor_parallel_split else None , filter_dict_keys
19692043 )
@@ -2279,7 +2353,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
22792353 else :
22802354 raise ValueError (f"Unexpected file: { resolved_archive_file } for weight conversion." )
22812355 # load pt weights early so that we know which dtype to init the model under
2282-
22832356 if not is_sharded and state_dict is None :
22842357 # 4. loading non-sharded ckpt from the state dict
22852358 if config .tensor_parallel_degree > 1 and resolved_archive_file .endswith ("model_state.pdparams" ):
0 commit comments