Conversation

@alanwaketan (Collaborator)
Summary:
This adds manual all-reduce support to SPMD; it currently supports only a single input tensor. Array support can be handled in the Python layer instead.

Test Plan:
python ./test/spmd/test_xla_sharding.py -v -k test_spmd_all_reduce

@JackCaoG (Collaborator) left a comment

Approve to unblock, but I think we should fix the tensor method name.

Inline review comment on:

    XLATensorPtr all_reduce(const XLATensorPtr& input, AllReduceType reduce_type,
@JackCaoG (Collaborator):

Can you call it all_reduce_no_token? The only difference in the signature is that it does not take pin_layout, but the main difference in the op is that it does not set the token. It is better to reflect that in the name.

@alanwaketan (Collaborator, Author):

Sure. I can follow up with that.

@JackCaoG (Collaborator)

For array support, do you plan to call all_reduce multiple times? In our C++ implementation, I think we group tensors by dtype and call all_reduce once per dtype.

@alanwaketan (Collaborator, Author)

> For array support, do you plan to call all_reduce multiple times? In our C++ implementation, I think we group tensors by dtype and call all_reduce once per dtype.

I don't think that's necessary. I'm thinking the compiler should be smart enough to fuse all-reduces when fusion is needed.
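For context, the dtype-grouping strategy Jack describes can be sketched in plain Python. This is a hedged illustration, not the actual torch_xla implementation: `fake_all_reduce` is a hypothetical stand-in for the real collective, and the point is simply that bucketing by dtype turns N per-tensor calls into one call per dtype group.

```python
from collections import defaultdict

def fake_all_reduce(group):
    # Hypothetical placeholder for a real cross-replica collective
    # that reduces a list of same-dtype tensors in one call.
    return group

def all_reduce_grouped(named_tensors):
    """Bucket (dtype, tensor) pairs by dtype and issue one simulated
    collective per bucket. Returns the number of collective calls."""
    buckets = defaultdict(list)
    for dtype, tensor in named_tensors:
        buckets[dtype].append(tensor)
    calls = 0
    for dtype, group in buckets.items():
        fake_all_reduce(group)  # one collective per dtype, not per tensor
        calls += 1
    return calls

# Three f32 tensors and one s64 tensor -> 2 collective calls, not 4.
inputs = [("f32", [1.0]), ("f32", [2.0]), ("s64", [3]), ("f32", [4.0])]
```

The Python-layer alternative discussed above would instead call all_reduce once per tensor and rely on the XLA compiler to fuse the resulting collectives.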

@alanwaketan (Collaborator, Author)

Thanks Jack for approving.

@alanwaketan alanwaketan merged commit 0df5c29 into master Jun 26, 2024
@alanwaketan alanwaketan deleted the alanwaketan/spmd_all_reduce branch June 26, 2024 18:27