
Conversation

@waliwali777
Contributor

@waliwali777 waliwali777 commented Jul 1, 2025

PR Category

Auto Parallel

PR Types

Performance

Description

Implement shared-parameter optimization for dynamic semi-auto parallel pipeline parallelism (PP).

I. Implementation:

  1. The user creates a shared_parameters object of type list[dict[str, list[EagerParamBase]]] when building the model network and passes it to PipelineStage for stage construction.
    Currently the user must locate the two parameters to be shared in the model and build shared_parameters from them. Multiple shared-parameter entries are supported, and the structure must strictly follow the format below:

    # Descriptions of all shared-parameter pairs
    shared_parameters = [
        {   # 1st shared-parameter pair
            "params": [param_1, param_2],  # required; the two parameters to share, one on each of the 2 stages
        },
        {   # 2nd shared-parameter pair
            "params": [param_3, param_4],  # required; the two parameters to share, one on each of the 2 stages
        },
    ]

    **Note:** In VPP mode, all stages on the current rank must use the same shared_parameters object. When the first stage is built, the communication-group information for each shared-parameter pair is stored in that pair's dict, so subsequent stage construction avoids creating redundant groups. See the VPP shared-parameter test in pipeline_sync_shared_parameters_unittest.py for usage.

  2. Add a shared_parameters argument to the PipelineStage class, together with four member functions (a sketch of the two sync functions follows this list):

    • _validate_shared_parameter_pair: since the shared-parameter information currently depends on user input, this performs strict checks on shared_parameters (types, element counts, etc.) and logs error messages on failure.

    • _init_shared_group: for each shared-parameter pair dict, creates "shared_param" and "group" on the current rank:

      • shared_param: the parameter actually shared on the current rank, normally one of the two parameters in "params"; defaults to None if neither parameter in "params" lives on the current rank.
      • group: the communication group needed for parameter sharing on the current rank.
    • _sync_shared_param: broadcasts shared_param on the current rank to synchronize it.

    • _sync_shared_param_grads: all-reduces the gradients of shared_param on the current rank.

  3. During stage init, _validate_shared_parameter_pair is called first to check the input, then _init_shared_group and _sync_shared_param are called to complete the shared-parameter synchronization.

  4. At the end of every schedule step, sync_shared_params is called to synchronize the shared-parameter gradients.
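
To make the flow above concrete, here is a minimal sketch (not this PR's actual implementation) of what the two sync member functions conceptually do, assuming shared_param and group have already been filled in by _init_shared_group; the choice of broadcast source rank is illustrative.

    import paddle.distributed as dist

    def sync_shared_param(shared_param, group):
        # Sketch of _sync_shared_param: broadcast the weight so both stage copies start identical.
        if shared_param is None:  # this rank holds no shared parameter
            return
        dist.broadcast(shared_param, src=group.ranks[0], group=group)

    def sync_shared_param_grads(shared_param, group):
        # Sketch of _sync_shared_param_grads: all-reduce the gradient so both copies get the same update.
        if shared_param is None or shared_param.grad is None:
            return
        dist.all_reduce(shared_param.grad, op=dist.ReduceOp.SUM, group=group)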

II. Testing:

(naive PP + parameter sharing) should match each of the three PP schedules (FThenB, 1F1B, VPP) combined with the shared-parameter optimization, with a loss precision diff < 1e-5.

III. Verification:

Convergence was verified on GPT-13B over 3000 steps; results are shown below:
[loss-curve screenshot]
The loss diff fluctuates around 0, and convergence is consistent.

PCard-91691

@paddle-bot

paddle-bot bot commented Jul 1, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@liym27 liym27 requested review from liym27 and xuxinyi389 July 4, 2025 03:12
@waliwali777 waliwali777 force-pushed the dynamic_sync_param branch from 8c57a1c to fab7b8f Compare July 7, 2025 01:22
sync_process_ids = sync_param.process_mesh.process_ids
cur_group = _build_current_sync_commm_group(
    ori_process_ids, sync_process_ids, get_group_from_ranks
)
Contributor

Does the shared-parameter communication group now have to be created by the user?

Contributor Author

Yes, for now it can only be created by the user.
Because of the design of dynamic semi-auto parallel PP, the ranks stop being aware of each other once PipelineStage construction begins.
To keep the communication groups built on each rank consistent, the user currently has to create the groups before stage construction and pass them in via shared_param_map.
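
For illustration, a minimal sketch of how such consistent pairwise groups could be created before stage construction; build_shared_group is a hypothetical helper, not code from this PR.

    import paddle.distributed as dist

    def build_shared_group(ori_param, sync_param):
        # Hypothetical helper: pair up corresponding ranks of the two stage meshes and
        # return the communication group that contains the current rank (or None).
        ori_ranks = ori_param.process_mesh.process_ids
        sync_ranks = sync_param.process_mesh.process_ids
        assert len(ori_ranks) == len(sync_ranks), f"{ori_ranks} vs {sync_ranks}"
        cur_group = None
        # Every rank must issue the same sequence of new_group calls so the groups stay consistent.
        for ori_rank, sync_rank in zip(ori_ranks, sync_ranks):
            g = dist.new_group(ranks=[ori_rank, sync_rank])
            if dist.get_rank() in (ori_rank, sync_rank):
                cur_group = g
        return cur_group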

Contributor Author

Update:
Group creation has been moved into the stage's init_shared_group; the user only needs to pass in the shared params and their mesh info, and no longer has to create the group manually.

# Description of each shared-parameter entry: sync_param on stage j will share the same
# parameter data as ori_param on stage i.
shared_param_map[key] = {          # key: a descriptive identifier for this shared parameter
    "ori_param": ori_param,        # the original shared parameter on stage i
    "sync_param": sync_param,      # the parameter on stage j to be synchronized
    "ori_mesh": ori_mesh,          # mesh info of stage i
    "sync_mesh": sync_mesh,        # mesh info of stage j
}
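
As a usage illustration (make_shared_entry is a hypothetical helper, not part of the PR), one entry of the updated map can be built directly from the two parameters, taking the mesh info from the parameters themselves:

    def make_shared_entry(ori_param, sync_param):
        # Hypothetical helper: build one shared_param_map entry in the updated format.
        return {
            "ori_param": ori_param,                 # on stage i
            "sync_param": sync_param,               # on stage j
            "ori_mesh": ori_param.process_mesh,     # mesh of stage i
            "sync_mesh": sync_param.process_mesh,   # mesh of stage j
        }

    # e.g. shared_param_map["gpt_shared_weight_1"] = make_shared_entry(ori_param, sync_param)
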
input_args (TensorMeta|tuple[TensorMeta, ...]|None): The input arguments for the layer.
output_args (TensorMeta|tuple[TensorMeta, ...]|None): The output arguments for the layer.
group (Group, None): The process group for distributed training. If None, default group.
shared_param_map (dict[str, dict[str, paddle.Tensor | paddle.distributed.collective.Group]] | None): A description
Contributor

dict[str, dict[str, paddle.Tensor | paddle.distributed.collective.Group]] | None
What do the two str keys stand for here, and what are they used for? Why is a nested dict needed?

Contributor Author

shared_param_map supports multiple shared parameters. Each entry's key is a descriptive identifier for a shared parameter, and each entry's value describes the shared param and the communication group that parameter needs.
An example shared_param_map:
{
    'gpt_shared_weight_1': {'param': paddle.Tensor, 'group': paddle.distributed.collective.Group},
    'gpt_shared_weight_2': {'param': paddle.Tensor, 'group': paddle.distributed.collective.Group},
}
The docstring will be updated later to make this clearer.

Contributor

What exactly is a key such as 'gpt_shared_weight_1' used for? It doesn't seem to be used anywhere in the code logic; is it reserved for future extension? If it isn't needed, wouldn't list[dict[str, paddle.Tensor | ...]] be simpler?

Contributor Author

The key is a user-specified descriptive identifier for the shared parameter. A debug log has now been added so that the status of the shared parameter named by key can be printed during parameter synchronization.

@codecov-commenter

codecov-commenter commented Jul 7, 2025

Codecov Report

Attention: Patch coverage is 16.21622% with 62 lines in your changes missing coverage. Please review.

Please upload report for BASE (develop@cfab2c4). Learn more about missing BASE report.

Files with missing lines                                   Patch %   Lines
...ddle/distributed/auto_parallel/pipelining/stage.py      16.21%    62 Missing ⚠️

❌ Your patch status has failed because the patch coverage (16.21%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #73733   +/-   ##
==========================================
  Coverage           ?   16.21%
==========================================
  Files              ?        1
  Lines              ?       74
  Branches           ?        0
==========================================
  Hits               ?       12
  Misses             ?       62
  Partials           ?        0

@waliwali777
Contributor Author

/re-run all-failed

1 similar comment
@waliwali777
Contributor Author

/re-run all-failed

@waliwali777
Contributor Author

/re-run inference build

@waliwali777
Contributor Author

/re-run all-failed

1 similar comment
@waliwali777
Contributor Author

/re-run all-failed

input_args (TensorMeta|tuple[TensorMeta, ...]|None): The input arguments for the layer.
output_args (TensorMeta|tuple[TensorMeta, ...]|None): The output arguments for the layer.
group (Group, None): The process group for distributed training. If None, default group.
shared_param_map (dict[str, dict[str, paddle.Tensor | paddle.distributed.collective.Group]] | None): A description
Contributor

What exactly is a key such as 'gpt_shared_weight_1' used for? It doesn't seem to be used anywhere in the code logic; is it reserved for future extension? If it isn't needed, wouldn't list[dict[str, paddle.Tensor | ...]] be simpler?

group=sync_group,
)

def sync_shared_params(self):
Contributor

Shouldn't this be sync_shared_param_grads?

Contributor Author

done

@waliwali777 waliwali777 force-pushed the dynamic_sync_param branch from e4bda45 to 4db2e37 Compare July 8, 2025 03:41
@waliwali777
Contributor Author

/re-run all-failed

input_args (TensorMeta|tuple[TensorMeta, ...]|None): The input arguments for the layer.
output_args (TensorMeta|tuple[TensorMeta, ...]|None): The output arguments for the layer.
group (Group, None): The process group for distributed training. If None, default group.
shared_param_map (dict[str, dict[str, paddle.Tensor | paddle.distributed.collective.Group]] | None): A nested
Contributor

Hasn't the type annotation here been updated to match yet?

Contributor Author

done, thx

@xuxinyi389
Contributor

LGTM

# Build identical communication groups for each rank within the mesh.
assert (
    "ori_mesh" in a_map and "sync_mesh" in a_map
), "Missing 'ori_mesh' or `sync_mesh` key in `shared_param_map` entry"
Contributor

Suggestion: print the a_map info in the assert message.

Contributor Author

done, a 'check_a_shared_param_map' function has been added to strictly validate the user input

get_group_from_ranks = {}
for key, a_map in self.shared_param_map.items():
    if "param" in a_map and "group" in a_map:
        continue
Contributor

What are "param" and "group" here? And why are they skipped?

Contributor Author

"param" is the parameter to be shared on the current rank; it has since been renamed to "shared_param".
"group" is the communication group used for shared-parameter communication.
If "param" is absent, there is no shared parameter to synchronize on the current rank. This has been updated: "shared_param" being None now means there is no shared parameter to synchronize on the current rank.
See the latest comment for the update; more code comments have also been added.
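
For illustration, a minimal sketch of the "shared_param is None" convention described above (pick_shared_param is a hypothetical helper, not code from this PR):

    import paddle.distributed as dist

    def pick_shared_param(params):
        # Hypothetical helper: return whichever of the two shared parameters lives on this
        # rank, or None when neither does (then there is nothing to synchronize here).
        cur_rank = dist.get_rank()
        for p in params:
            if cur_rank in p.process_mesh.process_ids:
                return p
        return None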

), "Missing 'ori_mesh' or `sync_mesh` key in `shared_param_map` entry"
ori_ranks = a_map["ori_mesh"].process_ids
sync_ranks = a_map["sync_mesh"].process_ids
assert len(ori_ranks) == len(sync_ranks)
Contributor

Suggestion: add an error message.

Contributor Author

done, added in the 'check_a_shared_param_map' function

assert len(ori_ranks) == len(sync_ranks)
assert (
    ori_ranks != sync_ranks
), "Shared parameters must be on different stage meshes."
Contributor

Suggestion: add an error message that prints ori_ranks and sync_ranks.
Same issue in the other places.

Contributor Author

done, added in the 'check_a_shared_param_map' function

# on the shared parameters.
for key, a_map in self.shared_param_map.items():
    if "param" not in a_map or "group" not in a_map:
        continue
Contributor

"param" not in a_map or "group" not in a_map 这种情况是符合预期的吗

Contributor Author

Yes, it is expected. With VPP, each rank holds multiple PipelineStage instances, and every PipelineStage runs group creation. To avoid creating redundant groups, the case where the group has already been created is skipped here.
The check here is not strict enough; it will be tightened later.

@waliwali777
Contributor Author

waliwali777 commented Jul 8, 2025

Update:
The argument has been renamed to shared_parameters, of type list[dict[str, list[EagerParamBase]]], and must strictly follow the format below:

    # Descriptions of all shared-parameter pairs
    shared_parameters = [
        {   # one shared-parameter pair
            "params": [param_1, param_2],  # required; the parameters to share on the 2 stages
        },
    ]

Since specifying shared parameters currently relies on user input, a 'check_a_shared_param_map' function has been added to strictly validate each user-supplied shared-parameter pair, with improved error messages; a better interface that reduces the dependence on user input is planned for later.
Group construction has been moved into stage initialization and is done automatically.

Note:
In VPP mode, multiple stages are built on the same rank and are required to use the same shared_param_map, which avoids creating redundant groups.
The created group and the parameter actually shared are stored back into shared_param_map:

    # Descriptions of all shared-parameter pairs
    shared_parameters = [
        {   # one shared-parameter pair
            "params": [param_1, param_2],   # required; the parameters to share on the 2 stages
            "shared_param": param,          # created automatically; None or one of the entries in "params",
                                            # the parameter shared on the current rank; None means the shared
                                            # parameter does not live on the current rank
            "group": group,                 # created automatically; the communication group needed for sharing,
                                            # only meaningful when "shared_param" is not None
        },
    ]
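
Putting the pieces above together, a minimal sketch of how a stage could fill in "shared_param" and "group" for each pair while skipping pairs that an earlier stage on the same rank already initialized; the actual _init_shared_group in this PR may differ in its details.

    import paddle.distributed as dist

    def init_shared_groups(shared_parameters):
        # Sketch only: fill in "shared_param" and "group" for every shared-parameter pair.
        cur_rank = dist.get_rank()
        for pair in shared_parameters:
            if "group" in pair:   # already initialized by a previous stage on this rank (VPP case)
                continue          # reuse it instead of creating a redundant group
            param_a, param_b = pair["params"]
            ranks_a = param_a.process_mesh.process_ids
            ranks_b = param_b.process_mesh.process_ids
            # The parameter shared on this rank, or None if neither copy lives here.
            if cur_rank in ranks_a:
                pair["shared_param"] = param_a
            elif cur_rank in ranks_b:
                pair["shared_param"] = param_b
            else:
                pair["shared_param"] = None
            # All ranks create the same pairwise groups; keep only the one containing this rank.
            group = None
            for ra, rb in zip(ranks_a, ranks_b):
                g = dist.new_group(ranks=[ra, rb])
                if cur_rank in (ra, rb):
                    group = g
            pair["group"] = group
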
@waliwali777 waliwali777 force-pushed the dynamic_sync_param branch from 297c6f4 to 6b2b132 Compare July 9, 2025 07:37
Contributor

@liym27 liym27 left a comment

LGTM

@waliwali777
Contributor Author

/re-run all-failed

2 similar comments
@waliwali777
Contributor Author

/re-run all-failed

@waliwali777
Contributor Author

/re-run all-failed

@xuxinyi389
Contributor

LGTM

@liym27 liym27 merged commit 9f969d4 into PaddlePaddle:develop Jul 10, 2025
77 of 82 checks passed

5 participants