Skip to content

Commit 99fb78d

Browse files
committed
update typename
1 parent 1e250f4 commit 99fb78d

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

paddlenlp/trainer/unified_checkpoint/sharding_split_param_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,12 +388,18 @@ def load_non_merge_optimizer_with_split_param(args, model, optimizer, resume_fro
388388
expected_keys, param_slice_info, param_shape_info = get_params_info(comm_buffer_list)
389389
expected_keys = set([static2struct_name_mappings.get(name, None) for name in expected_keys])
390390
expected_keys_optim = []
391-
typename_set = set()
391+
sharding_typename_set, typename_set = [], []
392392
with safe_open(optimizer_path, framework="numpy") as f:
393393
optim_keys = f.keys()
394394
for key in optim_keys:
395395
_, typename = key.split("/")
396-
typename_set.add(typename)
396+
typename_set.append(typename)
397+
398+
# To avoid incomplete typename in some shard files, communication is performed.
399+
hcg = fleet.get_hybrid_communicate_group()
400+
sharding_group = hcg.get_sharding_parallel_group()
401+
dist.all_gather_object(sharding_typename_set, typename_set, sharding_group)
402+
typename_set = set(chain(*sharding_typename_set))
397403
for key in expected_keys:
398404
for typename in typename_set:
399405
expected_keys_optim.append(f"{key}/{typename}")

0 commit comments

Comments
 (0)