
Conversation


@sdasgup3 sdasgup3 commented Dec 20, 2023

Background

State of implicit broadcasting support in XLA and PyTorch/XLA codebases.

Recently, HLO was equipped with the capability to express unbounded dynamic shapes (ref). The PyTorch/XLA bridge added the machinery to propagate unbounded dynamic dimensions from torch to XLA.

HLO, in its current form, can handle implicit broadcasting for static shapes and bounded dynamic shapes, and this support is currently leveraged by the PyTorch/XLA bridge as the single source of truth.
However, there is no support in XLA/HLO for implicit broadcasting with unbounded dynamic shapes.

Relevance of shape assertions for supporting implicit broadcasting with unbounded dynamic shapes

With static and bounded dynamic shapes it is feasible to check at compile time whether the broadcasting rules are met. With unbounded dynamic shapes, we need to rely on runtime guards (which we refer to as shape assertions) to ensure the participating shapes in broadcasting are valid. With that said, typical code generation for implicit broadcasting with unbounded shapes consists of two parts: (A) shape assertions, and (B) a broadcast sequence, which actually does the broadcasting assuming all the shape assertions hold.
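For intuition, here is a minimal Python sketch of the two parts using eager tensors rather than HLO; the helper below is illustrative only and is not part of this PR:

import torch

def checked_broadcast_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # (A) Shape assertions: aligned trailing dimensions must be equal, or one
    # of them must be 1. With unbounded dynamic shapes this can only be
    # verified at runtime, since the extents are unknown at compile time.
    for da, db in zip(reversed(a.shape), reversed(b.shape)):
        assert da == db or da == 1 or db == 1, (
            f"shapes {tuple(a.shape)} and {tuple(b.shape)} are not broadcastable")

    # (B) Broadcast sequence: compute the common result shape and expand both
    # operands to it, assuming the assertions above hold.
    result_shape = torch.broadcast_shapes(a.shape, b.shape)
    return a.expand(result_shape) * b.expand(result_shape)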

For example, the CHLO dialect supports implicit broadcasting with unbounded dynamic shapes (ref). The support relies on shape dialect ops, with shape assertions embedded in the code to check that the broadcasting rules are met. Please refer to the Appendix for what the lowered MHLO code (chlo ops → shape dialect ops → mhlo ops) looks like, and note how the shape assertions and broadcast sequence appear in it.

Similarly, JAX supports experimental lowering of polymorphic shape specifications to StableHLO, with shape assertions that validate the specification at runtime.

PyTorch symbolic shape specification

Per ref and ref, PyTorch allows constraints over the dynamic dimensions. The shape constraints are currently represented in the FX graph and can be converted to assertions.
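For illustration, a minimal sketch of declaring a dynamic dimension with torch.export follows; the constraint API has changed across PyTorch releases, so treat the exact spelling here as an assumption rather than the PR's code:

import torch
from torch.export import Dim, export

class MiniModel(torch.nn.Module):
    def forward(self, a, b):
        return a * b

# Declare dimension 1 of `b` as dynamic; any bounds or relations on this Dim
# become shape constraints recorded alongside the exported FX graph.
dyn = Dim("dyn")
ep = export(MiniModel(), (torch.randn(()), torch.randn(10, 5)),
            dynamic_shapes=(None, {1: dyn}))
print(ep.graph_module.graph)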

Proposal

Support implicit broadcasting with unbounded dynamic shapes at PT/XLA level.

With the shape constraint specification provided at the framework (PyTorch) level, it would make sense for PyTorch/XLA to leverage that information while doing implicit broadcasting.

Another option was to support implicit broadcasting at the XLA level. This is discouraged because HLO does not have any notion of a shape constraint specification, so the support would require propagating that information via the PyTorch/XLA bridge. Any change in the specification format/semantics would then require changes to the HLO client APIs.

Current PR

In the current PR, we implement just the broadcast sequence (refer to (B) above), assuming that the participating shapes meet the broadcasting rules at runtime. Tracking issue #6232 covers making sure that the shape constraints, provided in terms of the PyTorch shape specification in the FX graph, are converted to shape assertions.

Example

With the proposed change, the following mini-model

print("(X,?) * c") a = torch.randn(()).to(device=device) b = torch.randn((10,5)).to(device=device) torch_xla._XLAC._xla_mark_dynamic(b, 1) c = a * b print(xm.get_stablehlo([c]))

can be exported to the following StableHLO code:

module @IrToHlo.18 attributes {mhlo.cross_program_prefetches = [], mhlo.is_dynamic = false, mhlo.use_auto_spmd_partitioning = false} {
  func.func @main(%arg0: tensor<10x?xf32>, %arg1: tensor<f32>) -> tensor<10x?xf32> {
    %0 = stablehlo.constant dense<10> : tensor<1xi32>
    %1 = stablehlo.constant dense<1> : tensor<2xi32>
    %2 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<10x?xf32>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.concatenate %0, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %5 = stablehlo.maximum %4, %1 : tensor<2xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %arg1, %5, dims = [] : (tensor<f32>, tensor<2xi32>) -> tensor<10x?xf32>
    %7 = stablehlo.dynamic_broadcast_in_dim %arg0, %5, dims = [0, 1] : (tensor<10x?xf32>, tensor<2xi32>) -> tensor<10x?xf32>
    %8 = stablehlo.multiply %6, %7 : tensor<10x?xf32>
    return %8 : tensor<10x?xf32>
  }
}
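Note that the emitted module contains only the broadcast sequence: the runtime extent of the dynamic dimension is read with stablehlo.get_dimension_size, concatenated with the static extent to form the target shape, and both operands are expanded to that shape with stablehlo.dynamic_broadcast_in_dim before the multiply. No shape assertions are emitted yet; that is tracked in #6232.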

Appendix

Consider the following legalization of chlo.broadcast_add to MHLO ops via the Shape dialect.

//chlo
func.func @same_rank_bcast(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
  %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
  func.return %0 : tensor<?xf32>
}

// mlir-hlo-opt -chlo-legalize-to-hlo="legalize-broadcasts=true" --shape-legalize-to-hlo=legalize-constraints=true -canonicalize
module {
  func.func @same_rank_bcast(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %0 = mhlo.constant dense<1> : tensor<1xi32>
    %1 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
    %2 = mhlo.reshape %1 : (tensor<i32>) -> tensor<1xi32>
    %3 = "mhlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
    %4 = mhlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
    %5 = mhlo.compare EQ, %2, %0, NOTYPE : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %6 = mhlo.compare EQ, %4, %0, NOTYPE : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %7 = mhlo.or %5, %6 : tensor<1xi1>
    %8 = mhlo.compare EQ, %2, %4, NOTYPE : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %9 = mhlo.or %7, %8 : tensor<1xi1>
    %10 = mhlo.reshape %9 : (tensor<1xi1>) -> tensor<i1>
    // shape assertion
    mhlo.custom_call @shape_assertion(%10) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> ()
    // broadcast sequence
    %11 = "mhlo.get_dimension_size"(%arg0) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
    %12 = mhlo.reshape %11 : (tensor<i32>) -> tensor<1xi32>
    %13 = "mhlo.get_dimension_size"(%arg1) {dimension = 0 : i64} : (tensor<?xf32>) -> tensor<i32>
    %14 = mhlo.reshape %13 : (tensor<i32>) -> tensor<1xi32>
    %15 = mhlo.maximum %12, %14 : tensor<1xi32>
    %16 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %15) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %17 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %15) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %18 = mhlo.add %16, %17 : tensor<?xf32>
    return %18 : tensor<?xf32>
  }
}
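In this lowering, the pairwise compatibility checks are combined into a single i1 value and passed to mhlo.custom_call @shape_assertion; the has_side_effect = true attribute marks the call as side-effecting so that the zero-result assertion is not removed by later optimization passes.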
@sdasgup3 sdasgup3 requested review from GleasonK and qihqi December 20, 2023 08:29

@GleasonK GleasonK left a comment


This is a great writeup on broadcasting rationale and the state of broadcasting in CHLO. Thanks for the thorough commit message.

I'm trying to figure out how to keep these methods as simple and maintainable as possible, left a few comments to that end.

@JackCaoG JackCaoG added the dynamism Dynamic Shape Features label Dec 20, 2023
@sdasgup3 sdasgup3 force-pushed the sdasgup3/unbounded-implicit-broadcasting branch 3 times, most recently from 1f65892 to 4b6bc0d Compare December 21, 2023 21:21
@sdasgup3 sdasgup3 requested a review from GleasonK December 21, 2023 21:37

@GleasonK GleasonK left a comment


LGTM. One open comment, we can resolve there. Otherwise LGTM

@sdasgup3 sdasgup3 force-pushed the sdasgup3/unbounded-implicit-broadcasting branch 2 times, most recently from aac915f to 7627e33 Compare December 22, 2023 03:21
@sdasgup3 sdasgup3 requested a review from GleasonK December 22, 2023 03:24
@sdasgup3 sdasgup3 force-pushed the sdasgup3/unbounded-implicit-broadcasting branch from 7627e33 to 1fc1370 Compare December 22, 2023 03:32
@sdasgup3 sdasgup3 requested a review from lsy323 December 22, 2023 03:35
@sdasgup3 sdasgup3 force-pushed the sdasgup3/unbounded-implicit-broadcasting branch from 1fc1370 to 4de4102 Compare December 22, 2023 03:41
@sdasgup3 sdasgup3 force-pushed the sdasgup3/unbounded-implicit-broadcasting branch 2 times, most recently from 14463bb to 284cb0b Compare January 8, 2024 23:57

lsy323 commented Jan 10, 2024

LGTM, thanks! Let's rebase after the CI is green again on HEAD.

@sdasgup3 sdasgup3 force-pushed the sdasgup3/unbounded-implicit-broadcasting branch from 284cb0b to 2def4f1 Compare January 11, 2024 17:30
@lsy323 lsy323 merged commit 896de17 into master Jan 11, 2024