[AutoParallel] add sync param dynamic #73733
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open-source project!
8c57a1c to fab7b8f Compare

```python
sync_process_ids = sync_param.process_mesh.process_ids
cur_group = _build_current_sync_commm_group(
    ori_process_ids, sync_process_ids, get_group_from_ranks
)
```
Does the shared-parameter communication group now have to be created by the user?
Yes, for now it can only rely on the user to create it.
Because of the dynamic semi-auto-parallel PP design, once PipelineStage construction begins, the ranks are no longer aware of each other.
To keep communication group construction consistent across ranks, the user currently needs to create the communication group before the stages are built and pass it in via shared_param_map.
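For illustration, a minimal sketch of what this (since-superseded) setup might have looked like; the rank values are placeholders, and `new_group` is Paddle's standard collective API:

```python
import paddle.distributed as dist

# Every rank must create the group before any PipelineStage is built,
# so that group construction stays consistent across ranks.
shared_group = dist.new_group(ranks=[0, 3])  # placeholder: ranks of the two stages sharing the parameter
```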
Update:
Communication group creation has been moved into the stage's init_shared_group. The user only needs to pass in the shared param and mesh information and no longer has to create the group themselves.
```python
# Description of each group of shared parameters: sync_param on stage j will use
# the same parameter values as ori_param on stage i.
shared_param_map[key] = {       # key: descriptive identifier of the shared parameter
    "ori_param": ori_param,     # the original shared parameter on stage i
    "sync_param": sync_param,   # the shared parameter to be synchronized on stage j
    "ori_mesh": ori_mesh,       # mesh info of stage i
    "sync_mesh": sync_mesh,     # mesh info of stage j
}
```

```
input_args (TensorMeta|tuple[TensorMeta, ...]|None): The input arguments for the layer.
output_args (TensorMeta|tuple[TensorMeta, ...]|None): The output arguments for the layer.
group (Group, None): The process group for distributed training. If None, default group.
shared_param_map (dict[str, dict[str, paddle.Tensor | paddle.distributed.collective.Group]] | None): A description
```
dict[str, dict[str, paddle.Tensor | paddle.distributed.collective.Group]] | None

What do the two `str`s here stand for, and what are they used for? And why is there a nested dict?
shared_param_map supports multiple shared parameters. The key of each entry is a descriptive identifier for the shared parameter, and the value contains the shared parameter `param` together with the communication group `group` that parameter needs.
An example shared_param_map:

```python
{
    'gpt_shared_weight_1': {'param': paddle.Tensor, 'group': paddle.distributed.collective.Group},
    'gpt_shared_weight_2': {'param': paddle.Tensor, 'group': paddle.distributed.collective.Group},
}
```

The comment here will be updated later to make the description clearer.
What is the purpose of keys such as 'gpt_shared_weight_1'? Looking at the code logic, they don't seem to be used. Is this reserved for future extension? If they are not needed, wouldn't list[dict[str, paddle.Tensor | ...]] be simpler?
The key is a user-specified descriptive identifier for the shared parameter. A debug log has now been added so that the status of the shared parameter identified by the key can be printed during parameter synchronization.
Codecov Report
Attention: Patch coverage is 16.21%.
❌ Your patch status has failed because the patch coverage (16.21%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

```
@@           Coverage Diff            @@
##           develop   #73733   +/-   ##
=========================================
  Coverage         ?   16.21%
=========================================
  Files            ?        1
  Lines            ?       74
  Branches         ?        0
=========================================
  Hits             ?       12
  Misses           ?       62
  Partials         ?        0
```

☔ View full report in Codecov by Sentry.
/re-run all-failed

1 similar comment

/re-run all-failed

/re-run inference build

/re-run all-failed

1 similar comment

/re-run all-failed
```python
    group=sync_group,
)

def sync_shared_params(self):
```
Shouldn't this be sync_shared_param_grads?
done
e4bda45 to 4db2e37 Compare

/re-run all-failed
```
input_args (TensorMeta|tuple[TensorMeta, ...]|None): The input arguments for the layer.
output_args (TensorMeta|tuple[TensorMeta, ...]|None): The output arguments for the layer.
group (Group, None): The process group for distributed training. If None, default group.
shared_param_map (dict[str, dict[str, paddle.Tensor | paddle.distributed.collective.Group]] | None): A nested
```
Hasn't the type annotation here been updated accordingly yet?
done, thx
LGTM
```python
# Build identical communication groups for each rank within the mesh.
assert (
    "ori_mesh" in a_map and "sync_mesh" in a_map
), "Missing 'ori_mesh' or `sync_mesh` key in `shared_param_map` entry"
```
Suggestion: include the a_map information in the assert message.
done; added a `check_a_shared_param_map` function that strictly validates the user input.
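For illustration, a minimal sketch of what such a strict check might look like; the field names and messages follow the quoted diffs below, but this is an assumption rather than the PR's actual implementation:

```python
def check_a_shared_param_map(key, a_map):
    # Strictly validate one user-provided shared_param_map entry and include
    # the offending data in the error messages, as suggested in review.
    assert "ori_mesh" in a_map and "sync_mesh" in a_map, (
        f"Missing 'ori_mesh' or 'sync_mesh' in shared_param_map['{key}']: {a_map}"
    )
    ori_ranks = a_map["ori_mesh"].process_ids
    sync_ranks = a_map["sync_mesh"].process_ids
    assert len(ori_ranks) == len(sync_ranks), (
        f"'{key}': ori_ranks {ori_ranks} and sync_ranks {sync_ranks} differ in length"
    )
    assert ori_ranks != sync_ranks, (
        f"'{key}': shared parameters must be on different stage meshes, "
        f"got ori_ranks {ori_ranks} and sync_ranks {sync_ranks}"
    )
```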
```python
get_group_from_ranks = {}
for key, a_map in self.shared_param_map.items():
    if "param" in a_map and "group" in a_map:
        continue
```
What are "param" and "group" here? And why are they skipped?
"param" is the parameter to be shared on the current rank; it has now been renamed to "shared_param".
"group" is the communication group used for shared-parameter communication.
If "param" is absent, it means the current rank has no shared parameter to synchronize. This has been updated: if "shared_param" is None, it means the current rank has no shared parameter to synchronize.
The update can be seen in the latest comment; more explanatory comments have also been added.
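To illustrate the semantics just described, a rough sketch of how the broadcast step might skip ranks where "shared_param" is None; the attribute names and source-rank choice are assumptions, not the PR's actual code:

```python
import paddle.distributed as dist

def _sync_shared_param(self):
    for a_map in self.shared_parameters:
        if a_map["shared_param"] is None:
            # neither parameter of this pair lives on the current rank, nothing to sync
            continue
        group = a_map["group"]
        # broadcast from the first rank in the pair's group so both copies end up identical
        dist.broadcast(a_map["shared_param"], src=group.ranks[0], group=group)
```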
| ), "Missing 'ori_mesh' or `sync_mesh` key in `shared_param_map` entry" | ||
| ori_ranks = a_map["ori_mesh"].process_ids | ||
| sync_ranks = a_map["sync_mesh"].process_ids | ||
| assert len(ori_ranks) == len(sync_ranks) |
Suggestion: add an error message here.
done; this has been added in the `check_a_shared_param_map` function.
```python
assert len(ori_ranks) == len(sync_ranks)
assert (
    ori_ranks != sync_ranks
), "Shared parameters must be on different stage meshes."
```
Suggestion: add an error message that prints ori_ranks and sync_ranks.
The same issue exists in other places.
done; this has been added in the `check_a_shared_param_map` function.
```python
# on the shared parameters.
for key, a_map in self.shared_param_map.items():
    if "param" not in a_map or "group" not in a_map:
        continue
```
"param" not in a_map or "group" not in a_map 这种情况是符合预期的吗
Yes, it is expected. With VPP, each rank contains multiple PipelineStages, and every PipelineStage has to perform group creation. To avoid creating redundant groups, pairs whose group has already been created are skipped here.
The check here is not strict enough; it will be tightened later.
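A rough sketch of the group-reuse idea discussed above; the names follow the quoted diff, but the rank pairing and caching details are assumptions rather than the actual `_build_current_sync_commm_group` logic:

```python
import paddle.distributed as dist

get_group_from_ranks = {}  # cache shared across all stages on this rank: rank tuple -> Group

def build_sync_group(ori_ranks, sync_ranks):
    group_for_me = None
    for ori_rank, sync_rank in zip(ori_ranks, sync_ranks):
        ranks = sorted((ori_rank, sync_rank))
        key = tuple(ranks)
        if key not in get_group_from_ranks:
            # every rank must call new_group in the same order to keep group creation consistent
            get_group_from_ranks[key] = dist.new_group(ranks=ranks)
        if dist.get_rank() in ranks:
            group_for_me = get_group_from_ranks[key]
    return group_for_me
```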
Update:

```python
# description of all shared-parameter pairs
shared_parameters = [
    # description of each shared-parameter pair
    {
        "params": [param_1, param_2],  # required: the parameters on 2 stages to be shared
    },
]
```

Because the shared parameters are currently specified by user input, a `check_a_shared_param_map` function has been added to strictly validate each user-provided shared-parameter pair, with improved error messages. We hope to extend a better interface later to reduce the reliance on user input.

Note:

```python
# description of all shared-parameter pairs
shared_parameters = [
    # description of each shared-parameter pair
    {
        "params": [param_1, param_2],  # required: the parameters on 2 stages to be shared
        "shared_param": param,  # created automatically: None or one of "params"; the parameter to be shared on the current rank. None means the shared parameter does not live on this rank
        "group": group,  # created automatically: the communication group needed by the shared parameter; only meaningful when "shared_param" is not None
    },
]
```
297c6f4 to 6b2b132 Compare
LGTM
/re-run all-failed

2 similar comments

/re-run all-failed

/re-run all-failed

LGTM
PR Category
Auto Parallel
PR Types
Performance
Description
Adds shared-parameter optimization support for dynamic semi-auto-parallel pipeline parallelism.

1. Code implementation:
The user needs to create `shared_parameters` in the model network definition. It has type `list[dict[str, list[EagerParamBase]]]` and is passed into `PipelineStage` to take part in stage construction. Currently, the user has to locate the two parameters to be shared in the model and build `shared_parameters`; it can describe multiple shared parameters and must strictly follow the format sketched below.
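Reconstructed from the review discussion above, a sketch of the expected format; `param_1` and `param_2` stand for the two `EagerParamBase` objects found in the model:

```python
# description of all shared-parameter pairs
shared_parameters = [
    # one entry per shared-parameter pair
    {
        "params": [param_1, param_2],  # required: the two parameters to keep identical across 2 stages
    },
]
```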
**Note:** In VPP mode, all stages on the current rank must use the same `shared_parameters` object. When the first stage is created, the shared-parameter communication group info can be stored in each shared-parameter pair's dict, so subsequent stage creation avoids producing redundant groups. See the VPP shared-parameter test in `pipeline_sync_shared_parameters_unittest.py` for usage.

The `shared_parameters` argument is added to the `PipelineStage` class, along with 4 member functions:

- `_validate_shared_parameter_pair`: since the shared-parameter information currently depends on user input, this performs strict checks (types, counts, etc.) on the `shared_parameters` argument and logs error messages.
- `_init_shared_group`: on the current rank, creates "shared_param" and "group" for each parameter-pair dict:
  - `shared_param`: the parameter actually shared on the current rank, normally one of the two parameters in "params"; if neither parameter in "params" lives on the current rank, it defaults to None.
  - `group`: the communication group needed for parameter sharing on the current rank.
- `_sync_shared_param`: broadcasts the `shared_param` on the current rank to synchronize it.
- `_sync_shared_param_grads`: all-reduces the gradients of `shared_param` on the current rank to synchronize them.

During stage init, `_validate_shared_parameter_pair` is called first to check the arguments, then `_init_shared_group` and `_sync_shared_param` are called to complete shared-parameter synchronization.

At the end of every scheduling step, the schedule calls `sync_shared_params` to synchronize shared-parameter gradients.

2. Tests:
(Naive PP + parameter sharing) should match each of (the 3 PP schedules (FThenB, 1F1B, VPP) + shared-parameter optimization) with a loss precision diff < 1e-5.
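As a rough illustration of this acceptance criterion; the loss variables are hypothetical names, not from the actual test:

```python
import numpy as np

# losses collected from the two runs being compared (hypothetical variable names)
np.testing.assert_allclose(loss_naive_pp_sharing, loss_1f1b_pp_sharing, atol=1e-5)
```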
3. Verification:

A 3000-step convergence check on GPT-13B gives the following results:

The loss diff fluctuates around 0, so convergence is consistent.
PCard-91691