Skip to content

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Oct 17, 2025

This PR implements unrolling for vector.shape_cast operations by decomposing them into smaller tiles that are processed element by element. For each element in a result tile, it converts the result position to a linear index, then maps that linear index back to the corresponding source coordinates for extraction.

@llvmbot
Copy link
Member

llvmbot commented Oct 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Changes

This PR implements unrolling for vector.shape_cast operations by decomposing them into smaller tiles that are processed element by element. For each element in a result tile, it converts the result position to a linear index, then maps that linear index back to the corresponding source coordinates for extraction.


Full diff: https://github.com/llvm/llvm-project/pull/164010.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+4)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+149-2)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+92)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 6e79085afac9f..39097368b1e71 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2408,6 +2408,7 @@ def Vector_CompressStoreOp : def Vector_ShapeCastOp : Vector_Op<"shape_cast", [Pure, + DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]> ]>, Arguments<(ins AnyVectorOfAnyRank:$source)>, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 58256b0ade9f6..dff66a6e829a9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6233,6 +6233,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, setResultRanges(getResult(), argRanges.front()); } +std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultVectorType().getShape()); +} + LogicalResult ShapeCastOp::verify() { VectorType sourceType = getSourceVectorType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae0989bed26..8a969b6c6be6b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,153 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.shape_cast` operations according to the +/// provided target unroll shape. It decomposes a large shape_cast operation +/// into smaller tiles and reconstructs each tile by extracting individual +/// elements from the source vector and placing them at the correct positions. 
+/// +/// Since shape_cast performs linear element reindexing, the pattern uses +/// linear indexing as a bridge to map between source and result coordinates. +/// For each element in a result tile, it calculates the corresponding source +/// position and extracts that element. +/// +/// Example: +/// Given a shape_cast operation: +/// %0 = vector.shape_cast %src : vector<2x8xf32> to vector<4x4xf32> +/// +/// and a target unroll shape of <2x2>, the pattern produces: +/// +/// %zero = arith.constant dense<0.0> : vector<4x4xf32> +/// %tile_zero = arith.constant dense<0.0> : vector<2x2xf32> +/// +/// // First tile [0,0]: elements at result positions +/// (0,0),(0,1),(1,0),(1,1) +/// %e0 = vector.extract %src[0, 0] : f32 from vector<2x8xf32> +/// %t0 = vector.insert %e0, %tile_zero [0, 0] : f32 into vector<2x2xf32> +/// %e1 = vector.extract %src[0, 1] : f32 from vector<2x8xf32> +/// %t1 = vector.insert %e1, %t0 [0, 1] : f32 into vector<2x2xf32> +/// %e2 = vector.extract %src[0, 4] : f32 from vector<2x8xf32> +/// %t2 = vector.insert %e2, %t1 [1, 0] : f32 into vector<2x2xf32> +/// %e3 = vector.extract %src[0, 5] : f32 from vector<2x8xf32> +/// %t3 = vector.insert %e3, %t2 [1, 1] : f32 into vector<2x2xf32> +/// %r0 = vector.insert_strided_slice %t3, %zero +/// {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into +/// vector<4x4xf32> +/// +struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> { + UnrollShapeCastPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern<vector::ShapeCastOp>(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, shapeCastOp); + if (!targetShape) + return failure(); + + Location loc = shapeCastOp.getLoc(); + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType resultType = 
shapeCastOp.getResultVectorType(); + + ArrayRef<int64_t> resultShape = resultType.getShape(); + ArrayRef<int64_t> sourceShape = sourceType.getShape(); + + SmallVector<int64_t> strides(targetShape->size(), 1); + Value result = rewriter.create<arith::ConstantOp>( + loc, resultType, rewriter.getZeroAttr(resultType)); + + // For each unrolled tile in the result + for (SmallVector<int64_t> tileOffsets : + StaticTileOffsetRange(resultShape, *targetShape)) { + + // Create the target tile type + VectorType tileType = + VectorType::get(*targetShape, resultType.getElementType()); + + // Build the tile by extracting individual elements + Value tile = createTileFromElements( + rewriter, loc, shapeCastOp.getSource(), sourceShape, resultShape, + tileOffsets, *targetShape, tileType); + + // Insert the tile into the result + result = rewriter.create<vector::InsertStridedSliceOp>( + loc, tile, result, tileOffsets, strides); + } + + rewriter.replaceOp(shapeCastOp, result); + return success(); + } + +private: + /// Creates a result tile by extracting individual elements from the source + /// and inserting them at the correct positions in the tile. 
+ Value createTileFromElements(PatternRewriter &rewriter, Location loc, + Value source, ArrayRef<int64_t> sourceShape, + ArrayRef<int64_t> resultShape, + ArrayRef<int64_t> tileOffsets, + ArrayRef<int64_t> tileShape, + VectorType tileType) const { + + // Initialize tile with zeros + Value tile = rewriter.create<arith::ConstantOp>( + loc, tileType, rewriter.getZeroAttr(tileType)); + + // Calculate strides for both source and result shapes + SmallVector<int64_t> sourceStrides = computeStrides(sourceShape); + SmallVector<int64_t> resultStrides = computeStrides(resultShape); + + // Iterate over all positions in the tile using linear indexing + for (int64_t linearTileIdx = 0; linearTileIdx < computeProduct(tileShape); + ++linearTileIdx) { + // Convert linear tile index to multi-dimensional tile position + SmallVector<int64_t> tilePosition = + linearIndexToMultiDim(linearTileIdx, tileShape); + + // Calculate the global position in the result + SmallVector<int64_t> globalResultPos; + globalResultPos.reserve(tileOffsets.size()); + for (auto [offset, pos] : llvm::zip(tileOffsets, tilePosition)) { + globalResultPos.push_back(offset + pos); + } + + // Convert result position to linear index + int64_t linearIndex = linearize(globalResultPos, resultStrides); + + // Convert linear index to source position + SmallVector<int64_t> sourcePos = + linearIndexToMultiDim(linearIndex, sourceShape); + + // Extract element from source + Value element = + rewriter.create<vector::ExtractOp>(loc, source, sourcePos); + + // Insert element into tile + tile = + rewriter.create<vector::InsertOp>(loc, element, tile, tilePosition); + } + + return tile; + } + + /// Converts a linear index to multi-dimensional position within a given + /// shape. Used for both tile iteration and source coordinate computation. 
+ SmallVector<int64_t> linearIndexToMultiDim(int64_t linearIndex, + ArrayRef<int64_t> shape) const { + SmallVector<int64_t> position(shape.size()); + + for (int64_t i = shape.size() - 1; i >= 0; --i) { + position[i] = linearIndex % shape[i]; + linearIndex /= shape[i]; + } + + return position; + } + + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -1013,8 +1160,8 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>( + patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5c67f33..7a7129e9027a0 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,95 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + +//CHECK-LABEL: func @shape_cast_1D_to_2D +// CHECK-SAME: (%[[ARG0:.*]]: vector<16xf32>) -> vector<4x4xf32> +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> +// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<16xf32> +// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<16xf32> +// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into 
vector<2x2xf32> +// CHECK: %[[E2:.*]] = vector.extract %[[ARG0]][4] : f32 from vector<16xf32> +// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E3:.*]] = vector.extract %[[ARG0]][5] : f32 from vector<16xf32> +// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK: %[[E4:.*]] = vector.extract %[[ARG0]][2] : f32 from vector<16xf32> +// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E5:.*]] = vector.extract %[[ARG0]][3] : f32 from vector<16xf32> +// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E6:.*]] = vector.extract %[[ARG0]][6] : f32 from vector<16xf32> +// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E7:.*]] = vector.extract %[[ARG0]][7] : f32 from vector<16xf32> +// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK: %[[E8:.*]] = vector.extract %[[ARG0]][8] : f32 from vector<16xf32> +// CHECK: %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E9:.*]] = vector.extract %[[ARG0]][9] : f32 from vector<16xf32> +// CHECK: %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E10:.*]] = vector.extract %[[ARG0]][12] : f32 from vector<16xf32> +// CHECK: %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E11:.*]] = vector.extract %[[ARG0]][13] : f32 from vector<16xf32> +// CHECK: %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 
into vector<2x2xf32> +// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK: %[[E12:.*]] = vector.extract %[[ARG0]][10] : f32 from vector<16xf32> +// CHECK: %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E13:.*]] = vector.extract %[[ARG0]][11] : f32 from vector<16xf32> +// CHECK: %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E14:.*]] = vector.extract %[[ARG0]][14] : f32 from vector<16xf32> +// CHECK: %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E15:.*]] = vector.extract %[[ARG0]][15] : f32 from vector<16xf32> +// CHECK: %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK: return %[[I3]] : vector<4x4xf32> +func.func @shape_cast_1D_to_2D(%v: vector<16xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %v : vector<16xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +//CHECK-LABEL: func @shape_cast_2D +// CHECK-SAME: (%[[ARG0:.*]]: vector<2x8xf32>) -> vector<4x4xf32> +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> +// CHECK: %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf32> +// CHECK: %[[E0:.*]] = vector.extract %[[ARG0]][0, 0] : f32 from vector<2x8xf32> +// CHECK: %[[INS0:.*]] = vector.insert %[[E0]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E1:.*]] = vector.extract %[[ARG0]][0, 1] : f32 from vector<2x8xf32> +// CHECK: %[[INS1:.*]] = vector.insert %[[E1]], %[[INS0]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E2:.*]] = vector.extract %[[ARG0]][0, 4] : f32 from vector<2x8xf32> +// CHECK: %[[INS2:.*]] = vector.insert %[[E2]], %[[INS1]] [1, 0] : f32 into 
vector<2x2xf32> +// CHECK: %[[E3:.*]] = vector.extract %[[ARG0]][0, 5] : f32 from vector<2x8xf32> +// CHECK: %[[V0:.*]] = vector.insert %[[E3]], %[[INS2]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK: %[[E4:.*]] = vector.extract %[[ARG0]][0, 2] : f32 from vector<2x8xf32> +// CHECK: %[[INS3:.*]] = vector.insert %[[E4]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E5:.*]] = vector.extract %[[ARG0]][0, 3] : f32 from vector<2x8xf32> +// CHECK: %[[INS4:.*]] = vector.insert %[[E5]], %[[INS3]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E6:.*]] = vector.extract %[[ARG0]][0, 6] : f32 from vector<2x8xf32> +// CHECK: %[[INS5:.*]] = vector.insert %[[E6]], %[[INS4]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E7:.*]] = vector.extract %[[ARG0]][0, 7] : f32 from vector<2x8xf32> +// CHECK: %[[V1:.*]] = vector.insert %[[E7]], %[[INS5]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[V1]], %[[I0]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK: %[[E8:.*]] = vector.extract %[[ARG0]][1, 0] : f32 from vector<2x8xf32> +// CHECK: %[[INS6:.*]] = vector.insert %[[E8]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E9:.*]] = vector.extract %[[ARG0]][1, 1] : f32 from vector<2x8xf32> +// CHECK: %[[INS7:.*]] = vector.insert %[[E9]], %[[INS6]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E10:.*]] = vector.extract %[[ARG0]][1, 4] : f32 from vector<2x8xf32> +// CHECK: %[[INS8:.*]] = vector.insert %[[E10]], %[[INS7]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E11:.*]] = vector.extract %[[ARG0]][1, 5] : f32 from vector<2x8xf32> +// CHECK: %[[V2:.*]] = vector.insert %[[E11]], %[[INS8]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I2:.*]] = vector.insert_strided_slice %[[V2]], %[[I1]] {offsets = [2, 0], strides = [1, 1]} : 
vector<2x2xf32> into vector<4x4xf32> +// CHECK: %[[E12:.*]] = vector.extract %[[ARG0]][1, 2] : f32 from vector<2x8xf32> +// CHECK: %[[INS9:.*]] = vector.insert %[[E12]], %[[CST_0]] [0, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E13:.*]] = vector.extract %[[ARG0]][1, 3] : f32 from vector<2x8xf32> +// CHECK: %[[INS10:.*]] = vector.insert %[[E13]], %[[INS9]] [0, 1] : f32 into vector<2x2xf32> +// CHECK: %[[E14:.*]] = vector.extract %[[ARG0]][1, 6] : f32 from vector<2x8xf32> +// CHECK: %[[INS11:.*]] = vector.insert %[[E14]], %[[INS10]] [1, 0] : f32 into vector<2x2xf32> +// CHECK: %[[E15:.*]] = vector.extract %[[ARG0]][1, 7] : f32 from vector<2x8xf32> +// CHECK: %[[V3:.*]] = vector.insert %[[E15]], %[[INS11]] [1, 1] : f32 into vector<2x2xf32> +// CHECK: %[[I3:.*]] = vector.insert_strided_slice %[[V3]], %[[I2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> +// CHECK: return %[[I3]] : vector<4x4xf32> +func.func @shape_cast_2D(%v: vector<2x8xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %v : vector<2x8xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bbcda71..0a54f06f5d6b6 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -163,8 +163,8 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success( isa<arith::AddFOp, vector::FMAOp, vector::MultiDimReductionOp, - vector::BroadcastOp, vector::LoadOp, vector::StoreOp>( - op)); + vector::BroadcastOp, vector::LoadOp, vector::StoreOp, + vector::ShapeCastOp>(op)); })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() 
@nbpatel nbpatel requested a review from newling October 17, 2025 20:32
@nbpatel
Copy link
Contributor Author

nbpatel commented Oct 17, 2025

@newling @dcaballe @banach-space Please let me know if this approach is good or if there is a better way to unroll it.

@nbpatel
Copy link
Contributor Author

nbpatel commented Oct 18, 2025

@kuhar Thanks for the feedback. I addressed the comments

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment