Support of implicit broadcasting with unbounded dynamism #6219
Conversation
This is a great writeup on broadcasting rationale and the state of broadcasting in CHLO. Thanks for the thorough commit message.
I'm trying to figure out how to keep these methods as simple and maintainable as possible; I've left a few comments to that end.
One open comment; we can resolve it there. Otherwise LGTM.
LGTM, thanks! Let's rebase after the CI is green again on HEAD.
Background
State of implicit broadcasting support in XLA and PyTorch/XLA codebases.
HLO was recently equipped with the capability to express unbounded dynamic shapes ref, and the PyTorch/XLA bridge added the machinery to propagate unbounded dynamic dimensions from torch to XLA.
HLO, in its current form, can handle implicit broadcasting for static shapes and bounded dynamic shapes, and the PyTorch/XLA bridge currently leverages this as the single source of truth.
However, there is no support in XLA/HLO for implicit broadcasting with unbounded dynamic shapes.
Relevance of shape assertions to supporting implicit broadcasting with unbounded dynamic shapes
With static and bounded dynamic shapes, it is feasible to check at compile time whether the broadcasting rules are met. With unbounded dynamic shapes, we need to rely on runtime guards (which we refer to as shape assertions) to ensure the participating shapes are valid for broadcasting. With that said, typical code generation for implicit broadcasting with unbounded shapes consists of two parts: (A) shape assertions, and (B) a broadcasting sequence, which actually does the broadcasting assuming all the shape assertions hold.
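As a framework-agnostic illustration, here is a minimal Python sketch of these two parts (the function name and structure are hypothetical, not code from this PR):

```python
import torch

def broadcast_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # (A) Shape assertions: runtime guards that the broadcasting rules hold
    # (trailing dimensions must match or be 1).
    for dx, dy in zip(reversed(x.shape), reversed(y.shape)):
        assert dx == dy or dx == 1 or dy == 1, "shapes are not broadcastable"
    # (B) Broadcasting sequence: expand both operands to the common shape,
    # then apply the element-wise op.
    out_shape = torch.broadcast_shapes(x.shape, y.shape)
    return x.broadcast_to(out_shape) + y.broadcast_to(out_shape)
```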
For example, the CHLO dialect supports implicit broadcasting with unbounded dynamic shapes ref. The support relies on shape dialect ops, with shape assertions embedded in the code to check that the broadcasting rules are met. Please refer to the Appendix for what the lowered MHLO code (chlo ops → shape dialect ops → mhlo ops) looks like; note, in particular, what the shape assertions and the broadcasting sequence look like.
Similarly, JAX supports experimental lowering of polymorphic shape specifications to StableHLO, with shape assertions that validate the specification at runtime.
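For instance, a minimal sketch using the jax2tf polymorphic-shapes entry point (one concrete instance of this mechanism; exact APIs vary across JAX versions, and running it assumes tensorflow is installed):

```python
from jax.experimental import jax2tf

def add(x, y):
    return x + y  # element-wise add with implicit broadcasting

# "b" is a symbolic dimension; the lowering emits shape assertions that
# check at runtime that both inputs agree on "b".
tf_add = jax2tf.convert(add, polymorphic_shapes=["(b, 4)", "(b, 4)"])
```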
PyTorch symbolic shape specification
Per ref and ref, PyTorch allows constraints over dynamic dimensions. These shape constraints are currently represented in the FX graph and can be converted to assertions, as sketched below.
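A minimal sketch of declaring such a constraint via torch.export (the model and dimension name are hypothetical):

```python
import torch
from torch.export import export, Dim

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y

batch = Dim("batch")  # both inputs must share the same symbolic dim
ep = export(M(), (torch.randn(4, 3), torch.randn(4, 3)),
            dynamic_shapes=({0: batch}, {0: batch}))
print(ep.range_constraints)  # recorded constraints over the symbolic dims
```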
Proposal
Support implicit broadcasting with unbounded dynamic shapes at the PyTorch/XLA level.
With the shape constraint specification provided at the framework (PyTorch) level, it makes sense for PyTorch/XLA to leverage that information when implementing implicit broadcasting.
Another option is to support implicit broadcasting at the XLA level. This is discouraged because HLO has no notion of a shape constraint specification, so the support would require propagating that information via the PyTorch/XLA bridge, and any change in the specification format or semantics would require changes to the HLO client APIs.
Current PR
In the current PR, we implement just the broadcasting sequence (part B above), assuming that the participating shapes meet the broadcasting rules at runtime. Tracking issue #6232 covers converting the shape constraints in the FX graph, provided via the PyTorch shape specification, into shape assertions.
Example
With the proposed change, a mini-model that relies on implicit broadcasting between tensors with unbounded dynamic dimensions can be exported to StableHLO code containing the broadcasting sequence.
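For illustration, a minimal sketch of such an export path, assuming the torch.export and torch_xla.stablehlo APIs (the model and dimension name are illustrative, not the exact example from this PR; depending on the torch_xla version, unbounded dynamism may sit behind an experimental flag):

```python
import torch
from torch.export import export, Dim
from torch_xla.stablehlo import exported_program_to_stablehlo

class MiniModel(torch.nn.Module):
    def forward(self, x, y):
        # y (shape [3]) is implicitly broadcast against x (shape [batch, 3]).
        return x + y

batch = Dim("batch")  # dynamic batch dimension
ep = export(MiniModel(), (torch.randn(4, 3), torch.randn(3)),
            dynamic_shapes=({0: batch}, None))
shlo = exported_program_to_stablehlo(ep)
print(shlo.get_stablehlo_text())  # StableHLO with the broadcasting sequence
```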
Appendix
Consider the following legalization of chlo.broadcast_add to MHLO ops via the Shape dialect.
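A schematic sketch of what this legalization looks like, simplified from the chlo-to-hlo lowering patterns (op spellings and types vary across versions of the dialects):

```mlir
func.func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
  // (A) Shape assertions: runtime guard that the operand shapes are broadcastable.
  %s0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %s1 = shape.shape_of %arg1 : tensor<?x?xf32> -> tensor<2xindex>
  %w = shape.cstr_broadcastable %s0, %s1 : tensor<1xindex>, tensor<2xindex>
  %res = shape.assuming %w -> (tensor<?x?xf32>) {
    // (B) Broadcasting sequence: compute the common result shape and
    // expand both operands to it before the element-wise op.
    %bcast = shape.broadcast %s0, %s1
        : tensor<1xindex>, tensor<2xindex> -> tensor<2xindex>
    %lhs = "mhlo.dynamic_broadcast_in_dim"(%arg0, %bcast)
        {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
        : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %rhs = "mhlo.dynamic_broadcast_in_dim"(%arg1, %bcast)
        {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
        : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
    %sum = mhlo.add %lhs, %rhs : tensor<?x?xf32>
    shape.assuming_yield %sum : tensor<?x?xf32>
  }
  return %res : tensor<?x?xf32>
}
```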