Contributor

@newling newling commented Oct 16, 2025

PR #162167 removed the pattern that linearizes vector.splat without adding an equivalent pattern for vector.broadcast. This PR adds such a pattern, which should bring vector.broadcast up to full parity with the now-removed vector.splat.

@llvmbot
Member

llvmbot commented Oct 16, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)


Full diff: https://github.com/llvm/llvm-project/pull/163845.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+46-2)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+41)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 1b656d82f3201..9b2f88b7bbe9d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -817,6 +817,50 @@ struct LinearizeVectorToElements final
   }
 };
 
+/// Convert broadcasts from scalars or 1-element vectors, such as
+///
+/// ```mlir
+/// vector.broadcast %value : f32 to vector<4x4xf32>
+/// ```
+///
+/// to broadcasts to rank-1 vectors, with shape_casts before/after as needed.
+/// The above becomes,
+///
+/// ```mlir
+/// %out_1d = vector.broadcast %value : f32 to vector<16xf32>
+/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
+/// ```
+struct LinearizeVectorBroadcast final
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using Base::Base;
+
+  LinearizeVectorBroadcast(const TypeConverter &typeConverter,
+                           MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    int numElements = 1;
+    Type sourceType = broadcastOp.getSourceType();
+    if (auto vecType = dyn_cast<VectorType>(sourceType)) {
+      numElements = vecType.getNumElements();
+    }
+
+    if (numElements != 1) {
+      return rewriter.notifyMatchFailure(
+          broadcastOp, "only broadcasts of single elements can be linearized.");
+    }
+
+    auto dstTy = getTypeConverter()->convertType(broadcastOp.getType());
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(broadcastOp, dstTy,
+                                                     adaptor.getSource());
+
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -909,8 +953,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
-           LinearizeVectorFromElements, LinearizeVectorToElements>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorBroadcast, LinearizeVectorFromElements,
+           LinearizeVectorToElements>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index ee5cfbcda5c19..cbbc833d7a51d 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -428,6 +428,47 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
 
 // -----
 
+// CHECK-LABEL: linearize_vector_broadcast_scalar_source
+// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
+func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {
+
+  // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+  // CHECK: return %[[CAST]] : vector<4x2xi32>
+  %0 = vector.broadcast %arg0 : i32 to vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
+// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
+func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {
+
+  // CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
+  // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
+  // CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+  // CHECK: return %[[CAST1]] : vector<4x2xi32>
+  %0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32>
+  return %0 : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: linearize_scalable_vector_broadcast
+// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
+func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> {
+
+  // CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<[8]xi32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<[8]xi32> to vector<4x[2]xi32>
+  // CHECK: return %[[CAST]] : vector<4x[2]xi32>
+  %0 = vector.broadcast %arg0 : i32 to vector<4x[2]xi32>
+  return %0 : vector<4x[2]xi32>
+
+}
+
+// -----
+
 // CHECK-LABEL: linearize_create_mask
 // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
 func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
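The shape logic behind the new pattern can be illustrated without any MLIR dependency. The following is only a rough sketch under stated assumptions: `VecShape`, `numElements`, `isLinearizableBroadcast`, and `linearize` are hypothetical names invented for this illustration and are not part of the MLIR API. The sketch models a scalar source as `std::nullopt`, and it assumes that a linearized type is scalable whenever any original dimension is scalable, which matches the tests in this diff but glosses over whatever restrictions the real type converter applies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a vector type: a static shape plus a
// per-dimension scalable flag. This is NOT mlir::VectorType.
struct VecShape {
  std::vector<int64_t> dims;
  std::vector<bool> scalable; // same length as dims
};

// Product of the static dimension sizes (scalability is ignored here).
int64_t numElements(const VecShape &shape) {
  int64_t n = 1;
  for (int64_t d : shape.dims)
    n *= d;
  return n;
}

// Mirrors the match condition in the diff: the source must be a scalar
// (modelled as std::nullopt) or a vector holding exactly one element.
bool isLinearizableBroadcast(const std::optional<VecShape> &source) {
  int64_t n = 1;
  if (source)
    n = numElements(*source);
  return n == 1;
}

// Collapse an n-D shape to rank 1. Assumption for this sketch: the
// result is scalable if any input dimension is scalable.
VecShape linearize(const VecShape &shape) {
  assert(shape.dims.size() == shape.scalable.size());
  bool anyScalable = false;
  for (std::size_t i = 0; i < shape.scalable.size(); ++i)
    anyScalable = anyScalable || shape.scalable[i];
  VecShape out;
  out.dims = {numElements(shape)};
  out.scalable = {anyScalable};
  return out;
}
```

Under this model, the result type `vector<4x2xi32>` from the first test collapses to a single dimension of 8, and `vector<4x[2]xi32>` collapses to a single scalable dimension of 8. A `vector<1x1xi32>` source passes the match check, while a hypothetical `vector<2xi32>` source would be rejected with the "only broadcasts of single elements" failure.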
@Garra1980

Thanks a lot!

Contributor

@banach-space banach-space left a comment

Thanks for sending this so quickly! Just some minor comments. Otherwise LGTM

Comment on lines +433 to +460
func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {

// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
// CHECK: return %[[CAST]] : vector<4x2xi32>
%0 = vector.broadcast %arg0 : i32 to vector<4x2xi32>
return %0 : vector<4x2xi32>
}

// -----

// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {

// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
// CHECK: return %[[CAST1]] : vector<4x2xi32>
%0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32>
return %0 : vector<4x2xi32>
}

// -----

// CHECK-LABEL: linearize_scalable_vector_broadcast
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> {
Contributor

[nit] I think that you can safely skip "linearize" and "vector" from these names (that info is already "encoded" in the path). Naming is hard 🤷🏻

Suggested change
-func.func @linearize_vector_broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {
-// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
-// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
-// CHECK: return %[[CAST]] : vector<4x2xi32>
-%0 = vector.broadcast %arg0 : i32 to vector<4x2xi32>
-return %0 : vector<4x2xi32>
-}
-// -----
-// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
-// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
-func.func @linearize_vector_broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {
-// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
-// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
-// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
-// CHECK: return %[[CAST1]] : vector<4x2xi32>
-%0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32>
-return %0 : vector<4x2xi32>
-}
-// -----
-// CHECK-LABEL: linearize_scalable_vector_broadcast
-// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
-func.func @linearize_scalable_vector_broadcast(%arg0: i32) -> vector<4x[2]xi32> {
+func.func @broadcast_scalar_source(%arg0: i32) -> vector<4x2xi32> {
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[ARG]] : i32 to vector<8xi32>
+// CHECK: %[[CAST:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+// CHECK: return %[[CAST]] : vector<4x2xi32>
+%0 = vector.broadcast %arg0 : i32 to vector<4x2xi32>
+return %0 : vector<4x2xi32>
+}
+// -----
+// CHECK-LABEL: linearize_vector_broadcast_rank_two_source
+// CHECK-SAME: (%[[ARG:.*]]: vector<1x1xi32>) -> vector<4x2xi32>
+func.func @broadcast_rank_two_source(%arg0: vector<1x1xi32>) -> vector<4x2xi32> {
+// CHECK: %[[CAST0:.*]] = vector.shape_cast %[[ARG]] : vector<1x1xi32> to vector<1xi32>
+// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[CAST0]] : vector<1xi32> to vector<8xi32>
+// CHECK: %[[CAST1:.*]] = vector.shape_cast %[[BROADCAST]] : vector<8xi32> to vector<4x2xi32>
+// CHECK: return %[[CAST1]] : vector<4x2xi32>
+%0 = vector.broadcast %arg0 : vector<1x1xi32> to vector<4x2xi32>
+return %0 : vector<4x2xi32>
+}
+// -----
+// CHECK-LABEL: linearize_scalable_vector_broadcast
+// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
+func.func @broadcast_rank_two_source_scalable(%arg0: i32) -> vector<4x[2]xi32> {
Contributor Author

I can post a follow-up with updates to the function names; here I was just being consistent with what is currently in this file!

@newling newling enabled auto-merge (squash) October 17, 2025 23:21
Signed-off-by: James Newling <james.newling@gmail.com>
@newling newling merged commit fe5b72a into llvm:main Oct 17, 2025
10 checks passed
@Garra1980

Thanks for the quick fix!
