Conversation

@alanwaketan
Collaborator

Summary:
For all the cc ops, we use a token to introduce control dependencies among them so that they are executed in order. This token is cached in the Python layer, and this pull request moves it to C++, since the upcoming pytorch/pytorch#93173 won't carry the token from Python to C++.

Test Plan:
CI.
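
For reference, here is a minimal sketch of what caching the token in C++ and exposing it to Python might look like. This is an illustration only: the real token wraps an XLA IR value rather than the placeholder struct below, and everything except the g_all_reduce_token, _get_all_reduce_token, and _set_all_reduce_token names (which appear in this PR's diff) is an assumption.

#include <cstdint>
#include <mutex>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Placeholder payload; the real token holds an XLA IR value.
struct Token {
  int64_t id = 0;
};

struct AllReduceToken {
  Token value;
  std::mutex lock;
};

// Process-wide cache replacing the Python-side devctx.all_reduce_token.
AllReduceToken g_all_reduce_token;

PYBIND11_MODULE(_token_sketch, m) {
  py::class_<Token>(m, "Token")
      .def(py::init<>())
      .def_readwrite("id", &Token::id);
  m.def("_get_all_reduce_token", []() {
    std::lock_guard<std::mutex> guard(g_all_reduce_token.lock);
    return g_all_reduce_token.value;  // returned by copy
  });
  m.def("_set_all_reduce_token", [](const Token& token) {
    std::lock_guard<std::mutex> guard(g_all_reduce_token.lock);
    g_all_reduce_token.value = token;
  });
}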

@alanwaketan alanwaketan self-assigned this Apr 20, 2023
@alanwaketan alanwaketan marked this pull request as draft April 20, 2023 01:08
@alanwaketan alanwaketan changed the base branch from master to alanwaketan/cc April 20, 2023 01:08
@alanwaketan alanwaketan changed the base branch from alanwaketan/cc to master April 20, 2023 01:09
@alanwaketan alanwaketan marked this pull request as ready for review April 20, 2023 07:04
@alanwaketan alanwaketan requested a review from JackCaoG April 20, 2023 07:04
@alanwaketan
Collaborator Author

Okay, GPU CI is happy without test_zero1.py. Let's skip that and I will follow up next week.

result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
                                         groups or [], pin_layout)
devctx.all_reduce_token = result[1]
torch_xla._XLAC._set_all_reduce_token(result[1])
Collaborator

Hmm, I guess in the long term, if we don't even set the token in Python, there is no need to return the token to the Python layer? We can look into this later; maybe it is still cleaner this way.

Collaborator Author

Yeah, I'm just too lazy to make the changes to all the cc ops to remove the token from the Python layer.
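
(For illustration, a rough sketch of that longer-term shape, where the binding consumes and refreshes the cached token itself so nothing token-related crosses into Python. All of the types and helpers below, Token, Tensor, AllGather, are placeholders rather than the real torch_xla code.)

#include <cstdint>
#include <mutex>
#include <utility>

// Placeholders standing in for the real XLA IR token, at::Tensor, and cc op.
struct Token { int64_t id = 0; };
struct Tensor {};

std::mutex g_lock;
Token g_token;

// Pretend cc op: returns the result plus the next token in the ordering chain.
std::pair<Tensor, Token> AllGather(const Tensor& value, const Token& token) {
  return {value, Token{token.id + 1}};
}

// If the binding reads and updates the cached token itself, the Python wrapper
// no longer needs to receive, store, or pass the token at all.
Tensor AllGatherBinding(const Tensor& value) {
  std::lock_guard<std::mutex> guard(g_lock);
  auto [result, new_token] = AllGather(value, g_token);
  g_token = new_token;
  return result;
}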

std::mutex lock;
};

AllReduceToken g_all_reduce_token;
Collaborator

Shouldn't we have one token per device? Under PJRT in the v3 case, each process will have 2 threads and 1 device per thread. In that case, those two threads should not share the same token.

Collaborator Author

Good point. Completely forgot the V3 case...
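
(Sketching one possible per-device shape, purely illustrative; the keying by device string and the class below are assumptions, not this PR's code.)

#include <cstdint>
#include <map>
#include <mutex>
#include <string>

// Placeholder payload; the real token holds an XLA IR value.
struct Token {
  int64_t id = 0;
};

// One token per device, so the two per-process threads under PJRT on v3
// (one device each) never share ordering state.
class PerDeviceTokens {
 public:
  Token Get(const std::string& device) {
    std::lock_guard<std::mutex> guard(lock_);
    return tokens_[device];  // default-constructed on first access
  }

  void Set(const std::string& device, const Token& token) {
    std::lock_guard<std::mutex> guard(lock_);
    tokens_[device] = token;
  }

 private:
  std::mutex lock_;
  std::map<std::string, Token> tokens_;
};

PerDeviceTokens g_all_reduce_tokens;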

      [](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
        return XLANativeFunctions::set_(self, source);
      });
  m.def("_get_all_reduce_token",
Collaborator

I think it is better to call it _get_cc_token or _get_xla_token, although it is currently only used for all_reduce. We can also do this after we convert the second cc op to use the C++ token.

Collaborator Author

I'm just following the tradition in the Python layer, where it's named all_reduce_token. haha.

Collaborator

@JackCaoG JackCaoG left a comment

Mostly LGTM. @pratnali @amithrm FYI, we are moving the token to C++ so it can be traced by dynamo.

@alanwaketan
Collaborator Author

Thanks Jack for approving the change.

@alanwaketan alanwaketan merged commit 44c2fa0 into master Apr 21, 2023
@pratnali

Thanks for the feedback everyone.
