[Auto-parallel] Fix sharding all_gather overlap in auto_dy #73717
Conversation
def fuse_all_gather_hook_func(param_storage, comm_group):
    @paddle.autograd.no_grad()
    def fuse_comm(*_):
        shard_size = param_storage._numel() // comm_group.nranks
Review comment: What happens here if param_storage._numel() is not evenly divisible by comm_group.nranks?
Reply: get_padded_size in _build_fuse_param_view guarantees that param_storage._numel() is an integer multiple of comm_group.nranks, so this case cannot occur.
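A minimal sketch of the padding idea behind that guarantee (the rounding formula is an assumption; only the helper name get_padded_size appears in the thread):

```python
def get_padded_size(numel: int, nranks: int) -> int:
    # Round numel up to the next multiple of nranks so every rank
    # gets an equally sized shard (assumed behavior of the helper).
    return ((numel + nranks - 1) // nranks) * nranks

# Example: a 10-element storage across 4 ranks is padded to 12,
# so shard_size = 12 // 4 = 3 with no remainder.
assert get_padded_size(10, 4) == 12
```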
    task = paddle.distributed.all_gather(
        param_storage,
        slice_buffer,
        group=self._sharding_group,
Review comment: Why is comm_group passed in but self._sharding_group used here instead?
Reply: Fixed, thanks!

    def _set_sharding_overlap(self, enable_sharding_overlap, layers):
        self.enable_sharding_overlap = enable_sharding_overlap
        self._layers = layers
Review comment:
1. self._layers is later used for parameter lookup and hook registration, so the layers argument should be validated here, e.g. checked to be a paddle.nn.Layer.
2. This function is only called when enable_sharding_overlap is True, so is there any need to pass that flag as a parameter?
Reply: Both 1 and 2 have been addressed, thanks!
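A sketch of how the revised setter might look after addressing both points (the method name _enable_sharding_overlap and the exact check are assumptions, mirroring the _enable_tensor_fusion suggestion further below):

```python
import paddle

def _enable_sharding_overlap(self, layers):
    # Hypothetical post-review shape: the boolean argument is dropped and
    # layers is validated before being stored for later hook registration.
    if not isinstance(layers, paddle.nn.Layer):
        raise TypeError(
            f"layers must be a paddle.nn.Layer, but got {type(layers)}"
        )
    self.enable_sharding_overlap = True
    self._layers = layers
```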
            'param'
        ]
        layer = _find_layer_containing_param(first_param)
        layer.register_forward_pre_hook(
Review comment:
- Every call to _find_layer_containing_param here walks all sublayers; consider caching the param-to-layer mapping.
- Also handle the case where layer is None.
Reply: Changed to cache the mapping in a local dict param2layer = {}; an error is already raised when self._layers is None.
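A minimal sketch of the caching approach described in the reply (the surrounding hook wiring is illustrative; only param2layer, self._layers, and the Paddle layer APIs come from the thread):

```python
# Build the param -> layer mapping once, instead of re-walking all
# sublayers for every parameter lookup.
param2layer = {}
for layer in self._layers.sublayers():
    for p in layer.parameters(include_sublayers=False):
        param2layer[id(p)] = layer

owning_layer = param2layer.get(id(first_param))
if owning_layer is None:
    raise RuntimeError("No sublayer owns the given parameter.")
owning_layer.register_forward_pre_hook(
    fuse_all_gather_hook_func(param_storage, comm_group)
)
```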
        )

    def _set_tensor_fusion(self, enable_tensor_fusion):
        self.enable_tensor_fusion = enable_tensor_fusion
Review comment: This function is only called when enable_tensor_fusion is True, so the enable_tensor_fusion parameter is unnecessary. Suggestion:
    def _enable_tensor_fusion(self):
        self.enable_tensor_fusion = True
Reply: Fixed, thanks!
        )
        for layer in self._layers.sublayers():
            for p in layer.parameters(include_sublayers=False):
                if param.name == p.name:
Review comment: Is matching by name the only option here? Could the parameter name be modified by the user?
Reply: Changed to match by the parameter's id.
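A sketch of the identity-based lookup described in the reply (the function body is illustrative; only the loop structure over self._layers comes from the diff above):

```python
def _find_layer_containing_param(self, param):
    # Match by object identity rather than by name: the name string is
    # user-visible and could in principle be changed, while id() is stable
    # for the lifetime of the parameter object.
    for layer in self._layers.sublayers():
        for p in layer.parameters(include_sublayers=False):
            if id(p) == id(param):
                return layer
    return None
```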
            sync_op=False,
        ).wait()

    def _async_reduce_scatter(self):
Review comment: As discussed offline, a couple of remaining issues:
- function naming
- add comments
Reply: Updated accordingly, thanks!
LGTM
Codecov Report: patch coverage is 55.81% against develop (1 file, 43 new lines: 24 hits, 19 misses, 0 partials).
LGTM
PR Category
Auto Parallel
PR Types
Bug fixes
Description
Launching all all_gather ops at once blocks overlap with other sync/comm ops.
Fix: prefetch one buffer ahead via a hook to enable overlap; a sketch of the idea follows below.
Ref: Same fix in dynamic_hand #73406
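A minimal, self-contained sketch of the prefetch-one-buffer-ahead idea (buffer layout, hook wiring, and helper names are assumptions for illustration, not the PR's actual code; the all_gather call shape follows the diff above):

```python
import paddle
import paddle.distributed as dist

# Pending gather tasks, keyed by the storage buffer they fill (illustrative).
_pending_tasks = {}

def make_prefetch_hook(next_param_storage, next_slice_buffer, comm_group):
    """Forward pre-hook that launches the *next* fused buffer's all_gather
    asynchronously, so the communication overlaps with the current layer's
    compute instead of all all_gathers being issued up front."""

    @paddle.autograd.no_grad()
    def prefetch(layer, inputs):
        task = dist.all_gather(
            next_param_storage,   # output buffer for the gathered parameters
            next_slice_buffer,    # this rank's shard of the next buffer
            group=comm_group,
            sync_op=False,        # async: returns a task to wait on later
        )
        _pending_tasks[id(next_param_storage)] = task

    return prefetch

def wait_for_buffer(param_storage):
    """Block until the prefetched gather for param_storage has finished."""
    task = _pending_tasks.pop(id(param_storage), None)
    if task is not None:
        task.wait()

# Hypothetical wiring: register the hook on the first layer that consumes the
# current buffer, and call wait_for_buffer(next_storage) right before the
# next buffer's parameters are actually used.
# layer.register_forward_pre_hook(make_prefetch_hook(storage, shard, group))
```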
Pcard-70448