[Distributed] Move the cached all reduce token to C++ #4912
Conversation
Okay, GPU CI is happy without test_zero1.py. Let's skip that and I will follow up next week.

torch_xla/core/xla_model.py (outdated):
result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
                                         groups or [], pin_layout)
devctx.all_reduce_token = result[1]
torch_xla._XLAC._set_all_reduce_token(result[1])
hmm, I guess in the long term if we don't even set the token in python, there is no need to return the token to the python layer? We can look into this later, maybe it is still cleaner this way.
Yea, I'm just too lazy to make the changes to all cc ops to remove the token in the python layer.
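A minimal, self-contained sketch of that follow-up idea (stand-in types and names, not the actual torch_xla code): once the token is cached in C++, each cc op binding could read and update it internally and hand only the data tensor back to Python, so nothing token-related would need to cross the boundary.

```cpp
// Sketch only: Tensor/Token are stand-ins, not at::Tensor or the real IR token.
#include <utility>

struct Tensor {};
struct Token {};

Token g_all_reduce_token;  // cached on the C++ side, as in this PR

// Pretend lowering: consumes the current token and yields new data + token.
std::pair<Tensor, Token> BuildAllGather(const Tensor& input, const Token& token) {
  return {Tensor{}, Token{}};
}

// If the binding updates the cached token itself, it can return just the
// tensor, and the Python wrapper never sees (or sets) the token.
Tensor AllGatherBinding(const Tensor& input) {
  auto result = BuildAllGather(input, g_all_reduce_token);
  g_all_reduce_token = result.second;
  return result.first;
}
```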
  std::mutex lock;
};

AllReduceToken g_all_reduce_token;
Shouldn't we have one token per device? Under PJRT for the v3 case, each process will have 2 threads and 1 device per thread. In that case those two threads should not share the same token.
Good point. Completely forgot the V3 case...
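A rough sketch of the per-device alternative raised above, using only standard C++ (the names are illustrative, not the actual torch_xla types): the token is keyed by device ordinal and the map is guarded by a mutex, so the two per-device threads in the PJRT v3 setup each chain their own token instead of racing on a shared global.

```cpp
#include <mutex>
#include <unordered_map>

struct Token {};  // stand-in for the cached IR token

class PerDeviceTokens {
 public:
  Token Get(int device_ordinal) {
    std::lock_guard<std::mutex> guard(mutex_);
    return tokens_[device_ordinal];  // default-constructed on first access
  }

  void Set(int device_ordinal, Token token) {
    std::lock_guard<std::mutex> guard(mutex_);
    tokens_[device_ordinal] = token;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<int, Token> tokens_;
};

// One registry per process; each of the two device threads only touches the
// entry for its own ordinal.
PerDeviceTokens g_all_reduce_tokens;
```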
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
  return XLANativeFunctions::set_(self, source);
});
m.def("_get_all_reduce_token",
I think it is better to call it _get_cc_token or _get_xla_token, although it is currently only for all_reduce. We can also do this after we convert the second cc op to use the cpp token.
I'm just following the convention in the python layer, where it's named all_reduce_token. haha.
Thanks Jack for approving the change.
Thanks for the feedback everyone.
Summary:
For all the cc ops, we use a token to introduce control dependencies among them so that they are executed in order. This token is currently cached in the Python layer, and this pull request moves it to C++, since the upcoming pytorch/pytorch#93173 won't carry the token from Python to C++.
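As a rough illustration of the ordering mechanism described above (stand-in types, not the real lowering): each cc op takes the current token as an input and produces a fresh one, so the second op carries a data dependency on the first and cannot be reordered ahead of it.

```cpp
#include <utility>

struct Value {};  // stand-in for an IR value
struct Token {};  // stand-in for the ordering token

// Each collective consumes a token and emits a new one alongside its result.
std::pair<Value, Token> AllReduce(const Value& input, const Token& token) {
  return {Value{}, Token{}};
}

void TwoOrderedAllReduces(const Value& a, const Value& b) {
  Token token;                          // initial token
  auto [r1, t1] = AllReduce(a, token);  // first collective
  auto [r2, t2] = AllReduce(b, t1);     // consumes t1, so it must come second
  (void)r1; (void)r2; (void)t2;
}
```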
Test Plan:
CI.