- Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][Vector] Pattern to linearize broadcast #163845
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: James Newling (newling) Changes: The PR #162167 removed a pattern to linearize vector.splat, without adding the equivalent pattern for vector.broadcast. This PR adds such a pattern, hopefully bringing vector.broadcast up to full parity with vector.splat that has now been removed. Full diff: https://github.com/llvm/llvm-project/pull/163845.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 1b656d82f3201..9b2f88b7bbe9d 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -817,6 +817,50 @@ struct LinearizeVectorToElements final } }; +/// Convert broadcasts from scalars or 1-element vectors, such as +/// +/// ```mlir +/// vector.broadcast %value : f32 to vector<4x4xf32> +/// ``` +/// +/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed. +/// The above becomes, +/// +/// ```mlir +/// %out_1d = vector.splat %value : f32 to vector<16xf32> +/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> +/// ``` +struct LinearizeVectorBroadcast final + : public OpConversionPattern<vector::BroadcastOp> { + using Base::Base; + + LinearizeVectorBroadcast(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + + LogicalResult + matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + int numElements = 1; + Type sourceType = broadcastOp.getSourceType(); + if (auto vecType = dyn_cast<VectorType>(sourceType)) { + numElements = vecType.getNumElements(); + } + + if (numElements != 1) { + return rewriter.notifyMatchFailure( + broadcastOp, "only broadcasts of single elements can be linearized."); + } + + auto dstTy = getTypeConverter()->convertType(broadcastOp.getType()); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy, + adaptor.getSource()); + + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, 
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore, - LinearizeVectorFromElements, LinearizeVectorToElements>( - typeConverter, patterns.getContext()); + LinearizeVectorBroadcast, LinearizeVectorFromElements, + LinearizeVectorToElements>(typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index ee5cfbcda5c19..cbbc833d7a51d 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -428,6 +428,47 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> { // ----- +// CHECK-LABEL: linearize_vector_broadcast_scalar_source +// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32> +func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> { + + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> + // CHECK: return %[[CAST]] : vector<4x2xi32> + %0 = vector.broadcast %arg0 : i32 to vector<4x2xi32> + return %0 : vector<4x2xi32> +} + +// ----- + +// CHECK-LABEL: linearize_vector_broadcast_rank_two_source +// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32> +func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> { + + // CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32> + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32> + // CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> + // CHECK: return %[[CAST1]] : vector<4x2xi32> + %0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32> + return %0 : vector<4x2xi32> +} + +// ----- + +// CHECK-LABEL: linearize_scalable_vector_broadcast +// CHECK-SAME: (%[[ARG:.*]]: i32) -> 
vector<4x[2]xi32> +func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> { + + // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<[8]xi32> + // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<[8]xi32> to vector<4x[2]xi32> + // CHECK: return %[[CAST]] : vector<4x[2]xi32> + %0 = vector.broadcast %arg0 : i32 to vector<4x[2]xi32> + return %0 : vector<4x[2]xi32> + +} + +// ----- + // CHECK-LABEL: linearize_create_mask // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1> func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> { |
Thanks a lot! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for sending this so quickly! Just some minor comments. Otherwise LGTM
func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> { | ||
| ||
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32> | ||
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> | ||
// CHECK: return %[[CAST]] : vector<4x2xi32> | ||
%0 = vector.broadcast %arg0 : i32 to vector<4x2xi32> | ||
return %0 : vector<4x2xi32> | ||
} | ||
| ||
// ----- | ||
| ||
// CHECK-LABEL: linearize_vector_broadcast_rank_two_source | ||
// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32> | ||
func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> { | ||
| ||
// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32> | ||
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32> | ||
// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> | ||
// CHECK: return %[[CAST1]] : vector<4x2xi32> | ||
%0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32> | ||
return %0 : vector<4x2xi32> | ||
} | ||
| ||
// ----- | ||
| ||
// CHECK-LABEL: linearize_scalable_vector_broadcast | ||
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32> | ||
func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] I think that you can safely skip "linearize" and "vector" from these names (that info is already "encoded" in the path). Naming is hard 🤷🏻
func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> { | |
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32> | |
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> | |
// CHECK: return %[[CAST]] : vector<4x2xi32> | |
%0 = vector.broadcast %arg0 : i32 to vector<4x2xi32> | |
return %0 : vector<4x2xi32> | |
} | |
// ----- | |
// CHECK-LABEL: linearize_vector_broadcast_rank_two_source | |
// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32> | |
func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> { | |
// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32> | |
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32> | |
// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> | |
// CHECK: return %[[CAST1]] : vector<4x2xi32> | |
%0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32> | |
return %0 : vector<4x2xi32> | |
} | |
// ----- | |
// CHECK-LABEL: linearize_scalable_vector_broadcast | |
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32> | |
func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> { | |
func.func @broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> { | |
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32> | |
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> | |
// CHECK: return %[[CAST]] : vector<4x2xi32> | |
%0 = vector.broadcast %arg0 : i32 to vector<4x2xi32> | |
return %0 : vector<4x2xi32> | |
} | |
// ----- | |
// CHECK-LABEL: linearize_vector_broadcast_rank_two_source | |
// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32> | |
func.func @broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> { | |
// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32> | |
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32> | |
// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32> | |
// CHECK: return %[[CAST1]] : vector<4x2xi32> | |
%0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32> | |
return %0 : vector<4x2xi32> | |
} | |
// ----- | |
// CHECK-LABEL: linearize_scalable_vector_broadcast | |
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32> | |
func.func @broadcast_rank_two_source_scalable(%arg0: i32) -> vector<4x[2]xi32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can post a follow-up with updates to the function names, here I was just being consistent with what is currently in this file!
Signed-off-by: James Newling <james.newling@gmail.com>
Thanks for the quick fix! |
The PR #162167 removed a pattern to linearize vector.splat, without adding the equivalent pattern for vector.broadcast. This PR adds such a pattern, hopefully bringing vector.broadcast up to full parity with vector.splat that has now been removed.