@MaheshRavishankar
Contributor

The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices (and the loops they are part of), tries to find the consumer of these slices (all slices are expected to feed the same consumer), and then tiles the consumer into the loop nest using the `TilingInterface`. A more natural way of doing consumer fusion is to start from the consumer and look for operands that are produced by the loop nest passed in as `loops` (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). From the consumer you can find the slices of the operands that are accessed within the loop, which you can then use to tile and fuse the consumer (using `TilingInterface`). This handles the case where multiple operands of the consumer come from the loop nest more naturally.
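To make the consumer-first lookup concrete, here is a minimal sketch on a toy IR. `ToyOp`, `ToyValue` and `collectFusableOperands` are hypothetical stand-ins written for this illustration, not MLIR classes; the sketch only mirrors the step that collects the consumer operands defined by the outermost loop.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-ins for IR objects. These are NOT MLIR classes; they only model
// "a value knows which op defines it" for the purpose of this sketch.
struct ToyOp;
struct ToyValue {
  const ToyOp *definingOp; // op producing this value
  unsigned resultNumber;   // which result of that op
};
struct ToyOp {
  std::string name;
  std::vector<ToyValue> operands;
};

// Return the operand indices of `consumer` whose values are produced by
// `outermostLoop`. An empty result means there is nothing to fuse.
std::vector<unsigned> collectFusableOperands(const ToyOp &consumer,
                                             const ToyOp *outermostLoop) {
  std::vector<unsigned> fusable;
  for (unsigned i = 0; i < consumer.operands.size(); ++i)
    if (consumer.operands[i].definingOp == outermostLoop)
      fusable.push_back(i);
  return fusable;
}
```

With a consumer whose operands 0 and 2 come from the loop and operand 1 comes from elsewhere, the helper returns indices 0 and 2, so both loop-produced operands are handled in one fusion step.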

`scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of `scf::tileAndFuseProducerOfSlice`. For the latter, the source of the slice has a single producer, which makes the slice a natural way of specifying producer fusion. But for consumers, the result might have multiple users, resulting in multiple candidates for fusion, and a single fusion candidate might use multiple results from the tiled loop nest. This means that using slices (`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for consumer fusion turns out to be quite hard to navigate. Using the consumer directly avoids all those pain points. In time, `scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of `scf::tileAndFuseConsumer`. A lot of tech debt has accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be cleaned up. While that gets cleaned up and the required functionality is moved to `scf::tileAndFuseConsumer`, the old path is still maintained.
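The ambiguity described above can be illustrated with a small toy model. All names below (`ToyOp`, `ToyValue`, `usersOfLoopResult`, `loopResultsUsedBy`) are hypothetical and exist only for this sketch. It shows that one loop result can have several users (several fusion candidates when starting from a slice), while one consumer can use several loop results.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy IR, not MLIR: a value records its defining op and result number.
struct ToyOp;
struct ToyValue {
  const ToyOp *definingOp;
  unsigned resultNumber;
};
struct ToyOp {
  std::string name;
  std::vector<ToyValue> operands;
};

// Starting from a slice means starting from one loop result: every op in
// `block` that uses that result is a possible fusion candidate.
std::vector<const ToyOp *>
usersOfLoopResult(const std::vector<const ToyOp *> &block, const ToyOp *loop,
                  unsigned resultNumber) {
  std::vector<const ToyOp *> users;
  for (const ToyOp *op : block) {
    for (const ToyValue &v : op->operands) {
      if (v.definingOp == loop && v.resultNumber == resultNumber) {
        users.push_back(op);
        break; // record each user once
      }
    }
  }
  return users;
}

// Starting from the consumer, the set of loop results involved is determined
// by the consumer's own operands.
std::vector<unsigned> loopResultsUsedBy(const ToyOp &consumer,
                                        const ToyOp *loop) {
  std::vector<unsigned> results;
  for (const ToyValue &v : consumer.operands) {
    bool seen = std::find(results.begin(), results.end(), v.resultNumber) !=
                results.end();
    if (v.definingOp == loop && !seen)
      results.push_back(v.resultNumber);
  }
  return results;
}
```

In a block where result 0 of the loop feeds both a `linalg.add` and a `linalg.exp`, the slice-first view has two candidates to choose from, whereas naming the `linalg.add` directly pins down both the consumer and the loop results it needs.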

The tests for `scf::tileAndFuseConsumerOfSlices` are copied from `tile-and-fuse-consumer.mlir` to `tile-and-fuse-consumer-using-slices.mlir`. All the tests that were in `tile-and-fuse-consumer.mlir` now use the `tileAndFuseConsumer` method. The test op `transform.test.fuse_consumer` is modified to call `scf::tileAndFuseConsumer`, while a new op `transform.test.fuse_consumer_using_slice` keeps the old path tested until it is deprecated.

… a given tiled loop nest.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (MaheshRavishankar)

Changes

Patch is 141.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167634.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+5)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+12)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+172-49)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir (+2-2)
  • (added) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir (+1156)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+189-191)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+71-8)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+23-1)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cd033c140a233..8bdf3e0b566ef 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
+    BlockArgument getTiedBlockArgument(OpResult opResult) {
+      assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
+      return getBody()->getArgument(getRank() + opResult.getResultNumber());
+    }
+
     ::mlir::Value getInductionVar(int64_t idx) {
       return getInductionVars()[idx];
     }
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 7c735d825b445..0005fad3d5c01 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
 /// tiled in a manner that is consistent for all the passed slices. Note that
 /// the method replaces the uses of `candidateSlices` with the tiled and fused
 /// consumer value but does not delete the slice operations.
+/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
+/// is to take the consumer operation, and find the slices to use for fusion
+/// by walking its operands to the `loops` and then into the body to get the
+/// slices used for fusion.
 struct SCFFuseConsumerOfSliceResult {
   // Original untiled consumer operands.
   SmallVector<OpOperand *> origConsumerOperands;
@@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
                             ArrayRef<Operation *> candidateSlices,
                             MutableArrayRef<LoopLikeOpInterface> loops);
 
+/// Fuse the `consumer` operation into the loop nest provided by `loops`.
+/// The transformation looks for operands in the `consumer` that are defined
+/// by the outermost loop of the loop nest in `loops`. The nested loop is
+/// expected to have the structure of the loops generated through tiling.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
+                    MutableArrayRef<LoopLikeOpInterface> loops);
+
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
 FailureOr<SmallVector<scf::ForOp>>
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 29b770fb4b279..7e715ee189740 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
     // Again assume that all the outer loops are scf.for operations.
-    auto outerForLoop = cast<scf::ForOp>(outerLoop);
+    auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
     auto outerLoopYield =
         cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
     SmallVector<Value> newYields =
@@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
   return clonedSlices;
 }
 
-/// Implementation of fusing consumer of a single slice by computing the
-/// slice of the consumer in-place for scf loop.
-FailureOr<scf::SCFFuseConsumerOfSliceResult>
-mlir::scf::tileAndFuseConsumerOfSlices(
-    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
-  if (candidateSlices.empty()) {
-    return rewriter.notifyMatchFailure(
-        rewriter.getUnknownLoc(),
-        "no candidate slices provided for consumer fusion");
-  }
-  // Return if `loops` is empty, return an error for now. Caller is expected
-  // to handle this case.
-  if (loops.empty()) {
-    return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "cannot call tile and fuse consumer with an empty loop nest");
-  }
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
+                                ArrayRef<OpOperand *> consumerOpOperands,
+                                ArrayRef<Operation *> candidateSlices,
+                                MutableArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "expected loops to be not empty");
 
-  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
-        llvm::all_of(candidateSlices,
-                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+  // 1. Check assumption for loop with `reorderOperations` disabled.
+  if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
     return rewriter.notifyMatchFailure(
-        candidateSlices.front(),
-        "candidates slices need to be all `tensor.extract_slice`s or "
-        "`tensor.parallel_insert_slice`s");
-  }
-
-  // 1. Get the consumer of scf.for for the result yielded by
-  // tensor.insert_slice/parallel_insert_slice.
-  SmallVector<OpOperand *> consumerOpOperands;
-  Operation *consumerOp;
-  {
-    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
-        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
-    if (failed(maybeConsumerOpOperand)) {
-      return rewriter.notifyMatchFailure(candidateSlices.front(),
-                                         "could not fetch consumer to fuse");
-    }
-    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
-    consumerOp = consumerOpOperands.front()->getOwner();
+        loops.front(), "the first user of loop should not dominate any define "
+                       "of consumer operand(s)");
   }
 
   LoopLikeOpInterface outerMostLoop = loops.front();
   LoopLikeOpInterface innerMostLoop = loops.back();
 
-  // Check assumption for loop with `reorderOperations` disabled.
-  if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
-    return rewriter.notifyMatchFailure(
-        outerMostLoop, "the first user of loop should not dominate any define "
-                       "of consumer operand(s)");
-  }
-
   OpBuilder::InsertionGuard g(rewriter);
-
   // 2. Check consumer is not using scf loop's output as init.
   auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
   if (!dstOp)
@@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices(
       llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
         return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
       });
+  auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
   return scf::SCFFuseConsumerOfSliceResult{
-      std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
+      std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
       std::move(tileAndFuseResult->tiledOps)};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf loop.
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumerOfSlices(
+    RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  if (candidateSlices.empty()) {
+    return rewriter.notifyMatchFailure(
+        rewriter.getUnknownLoc(),
+        "no candidate slices provided for consumer fusion");
+  }
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
+        llvm::all_of(candidateSlices,
+                     llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
+    return rewriter.notifyMatchFailure(
+        candidateSlices.front(),
+        "candidates slices need to be all `tensor.extract_slice`s or "
+        "`tensor.parallel_insert_slice`s");
+  }
+
+  // Get the consumer of scf.for for the result yielded by
+  // tensor.insert_slice/parallel_insert_slice.
+  SmallVector<OpOperand *> consumerOpOperands;
+  Operation *consumerOp;
+  {
+    FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
+        getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
+    if (failed(maybeConsumerOpOperand)) {
+      return rewriter.notifyMatchFailure(candidateSlices.front(),
+                                         "could not fetch consumer to fuse");
+    }
+    std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
+    consumerOp = consumerOpOperands.front()->getOwner();
+  }
+
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
+}
+
+/// For a given `result` of a `forallOp` return the
+/// `tensor.parallel_insert_slice` op (or combining op) that is used to
+/// construct this result.
+static std::optional<Operation *>
+getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
+  if (result.getOwner() != forallOp)
+    return std::nullopt;
+  BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
+  SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
+  // If the number of combining ops is not 1, then this is unexpected. Return
+  // nullopt.
+  if (combiningOps.size() != 1) {
+    return std::nullopt;
+  }
+  return combiningOps[0];
+}
+
+/// For a given result of the loop nest that is a tiled loop nest, return the
+/// insert slice-like op that is used for consumer fusion
+std::optional<Operation *>
+getProducingInsertSliceLikeOp(OpResult result,
+                              ArrayRef<LoopLikeOpInterface> loops) {
+  assert(!loops.empty() && "Expected loops to be not empty");
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
+    assert(loops.size() == 1 &&
+           "expected only a single loop when tiling using scf.forall");
+    return getProducingParallelInsertSlice(forallOp, result);
+  }
+  // Assume that the loop nest is a nested `scf.for` that is created through
+  // tiling and retrieve the `tensor.insert_slice` operation used to construct
+  // the result.
+  while (loops.size() != 1) {
+    if (result.getOwner() != loops.front())
+      return std::nullopt;
+    auto forOp = dyn_cast<scf::ForOp>(loops.front());
+    if (!forOp)
+      return std::nullopt;
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    OpResult innerForResult =
+        dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
+    if (!innerForResult)
+      return std::nullopt;
+    result = innerForResult;
+    loops = loops.drop_front();
+  }
+  if (result.getOwner() != loops.front())
+    return std::nullopt;
+  auto forOp = dyn_cast<scf::ForOp>(loops.front());
+  if (!forOp)
+    return std::nullopt;
+  auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+  auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
+                           .getDefiningOp<tensor::InsertSliceOp>();
+  if (!insertSliceOp)
+    return std::nullopt;
+  return insertSliceOp;
+}
+
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
+                               MutableArrayRef<LoopLikeOpInterface> loops) {
+  // Only handle users that implement the `TilingInterface`.
+  if (!isa<TilingInterface>(user)) {
+    return rewriter.notifyMatchFailure(
+        user, "unhandled user that does not implement TilingInterface");
+  }
+
+  // Return if `loops` is empty, return an error for now. Caller is expected
+  // to handle this case.
+  if (loops.empty()) {
+    return rewriter.notifyMatchFailure(
+        user, "cannot call tile and fuse consumer with an empty loop nest");
+  }
+
+  LoopLikeOpInterface outermostLoop = loops.front();
+
+  // Collect the operands of the user that come from the outermost loop of the
+  // loop nest.
+  SmallVector<OpOperand *> consumerFusableOperands;
+  for (OpOperand &opOperand : user->getOpOperands()) {
+    if (opOperand.get().getDefiningOp() == outermostLoop) {
+      consumerFusableOperands.push_back(&opOperand);
+    }
+  }
+
+  // Nothing to fuse. Just return an empty set.
+  if (consumerFusableOperands.empty()) {
+    return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
+                                                   SmallVector<OpOperand *>{},
+                                                   SmallVector<Operation *>{}};
+  }
+
+  // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
+  // for fusion.
+  SmallVector<Operation *> candidateSlices;
+  candidateSlices.reserve(consumerFusableOperands.size());
+  for (OpOperand *opOperand : consumerFusableOperands) {
+    std::optional<Operation *> slice =
+        getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
+    if (!slice) {
+      return rewriter.notifyMatchFailure(
+          user,
+          "couldnt find producing insert-slice like operation for operand");
+    }
+    candidateSlices.push_back(slice.value());
+  }
+  return tileAndFuseConsumerOfSlicesImpl(
+      rewriter, user, consumerFusableOperands, candidateSlices, loops);
+}
+
 //===----------------------------------------------------------------------===//
 // lowerToLoopsUsingSCFForOp implementation.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
index 185fb9b358055..d72ab080f3c5c 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir
@@ -170,7 +170,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
         : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      transform.test.fuse_consumer %slice_op in (%forall_op)
+      transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
         : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
       transform.yield
     }
@@ -231,7 +231,7 @@ module {
       // Fuse the consumer operation into the tiled loop.
       %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
         : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
-      // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
+      // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
       // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer
       // to fuse" error.
       transform.yield
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
new file mode 100644
index 0000000000000..62dd7faec4eb7
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir
@@ -0,0 +1,1156 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %in_operand_2 = tensor.empty() : tensor<64xf32>
+    %out_operand_3 = tensor.empty() : tensor<64xf32>
+    %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %loop = transform.structured.match ops{["scf.for"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop)
+      : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: func.func @fuse_tileable_consumer_scf_for(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %0 = tensor.empty() : tensor<64xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]]
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0)
+// CHECK-SAME: {
+// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[MAT_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>)
+// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1]
+// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: %[[ELEM_OUT:.*]] = linalg.add
+// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] :
+// CHECK-SAME: outs(%[[SLICE_OUT]] :
+// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1]
+// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] :
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#2 :
+
+// -----
+
+module {
+  func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+      %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+      }
+    }
+    %in_ope... [truncated]
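The result-to-slice lookup in the patch (`getProducingInsertSliceLikeOp` and `getTiedBlockArgument`) can be pictured with a toy model. The sketch below is an illustration only: `ToyLoop`, `ToyYielded`, `findInsertSlice` and `tiedBlockArgumentIndex` are hypothetical names, and slices are represented by integer ids rather than operations. It follows a result of the outermost loop through the yields of nested `scf.for`-like loops down to the slice yielded by the innermost loop, and it reproduces the `rank + resultNumber` indexing used for `scf.forall`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

struct ToyLoop;

// What a toy loop yields for one of its results: either a result of the next
// inner loop, or an insert-slice identified by a non-negative id.
struct ToyYielded {
  const ToyLoop *innerLoop = nullptr; // set when an inner-loop result is yielded
  unsigned innerResult = 0;           // which result of the inner loop
  int insertSliceId = -1;             // >= 0 when a slice is yielded
};

struct ToyLoop {
  std::vector<ToyYielded> yields; // one entry per loop result
};

// Follow `resultNumber` of the outermost loop through the nest
// (outermost first) and return the id of the slice that produces it.
std::optional<int> findInsertSlice(const std::vector<const ToyLoop *> &loops,
                                   unsigned resultNumber) {
  if (loops.empty())
    return std::nullopt;
  for (std::size_t depth = 0; depth + 1 < loops.size(); ++depth) {
    const ToyLoop *loop = loops[depth];
    if (resultNumber >= loop->yields.size())
      return std::nullopt;
    const ToyYielded &yielded = loop->yields[resultNumber];
    // The yielded value must be a result of the next loop in the nest.
    if (yielded.innerLoop != loops[depth + 1])
      return std::nullopt;
    resultNumber = yielded.innerResult;
  }
  const ToyLoop *innermost = loops.back();
  if (resultNumber >= innermost->yields.size())
    return std::nullopt;
  int id = innermost->yields[resultNumber].insertSliceId;
  if (id < 0)
    return std::nullopt;
  return id;
}

// For a forall-like loop, the body arguments are the `rank` induction
// variables followed by one shared output per result.
unsigned tiedBlockArgumentIndex(unsigned rank, unsigned resultNumber) {
  return rank + resultNumber;
}
```

For the two-dimensional `scf.forall` in the test above, result `#1` maps to body argument index 3 under this scheme, which is `%arg6`, the shared output written by the `tensor.parallel_insert_slice` of the matmul result.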
@llvmbot
Member

llvmbot commented Nov 12, 2025

@llvm/pr-subscribers-mlir-linalg

Author: None (MaheshRavishankar)

+ if (loops.empty()) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "cannot call tile and fuse consumer with an empty loop nest"); + } + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || + llvm::all_of(candidateSlices, + llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } + + // Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + SmallVector<OpOperand *> consumerOpOperands; + Operation *consumerOp; + { + FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperand)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); + consumerOp = consumerOpOperands.front()->getOwner(); + } + + return tileAndFuseConsumerOfSlicesImpl( + rewriter, consumerOp, consumerOpOperands, candidateSlices, loops); +} + +/// For a given `result` of a `forallOp` return the +/// `tensor.parallel_insert_slice` op (or combining op) that is used to +/// construct this result. +static std::optional<Operation *> +getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { + if (result.getOwner() != forallOp) + return std::nullopt; + BlockArgument bbArg = forallOp.getTiedBlockArgument(result); + SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); + // If the number of combining ops is not 1, then this is unexpected. Return + // nullopt. 
+ if (combiningOps.size() != 1) { + return std::nullopt; + } + return combiningOps[0]; +} + +/// For a given result of the loop nest that is a tiled loop nest, return the +/// insert slice-like op that is used for consumer fusion +std::optional<Operation *> +getProducingInsertSliceLikeOp(OpResult result, + ArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "Expected loops to be not empty"); + LoopLikeOpInterface outermostLoop = loops.front(); + + if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) { + assert(loops.size() == 1 && + "expected only a single loop when tiling using scf.forall"); + return getProducingParallelInsertSlice(forallOp, result); + } + // Assume that the loop nest is a nested `scf.for` that is created through + // tiling and retrieve the `tensor.insert_slice` operation used to construct + // the result. + while (loops.size() != 1) { + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + OpResult innerForResult = + dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); + if (!innerForResult) + return std::nullopt; + result = innerForResult; + loops = loops.drop_front(); + } + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) + .getDefiningOp<tensor::InsertSliceOp>(); + if (!insertSliceOp) + return std::nullopt; + return insertSliceOp; +} + +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user, + MutableArrayRef<LoopLikeOpInterface> loops) { + // Only handle users that implement the `TilingInterface`. 
+ if (!isa<TilingInterface>(user)) { + return rewriter.notifyMatchFailure( + user, "unhandled user that does not implement TilingInterface"); + } + + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + user, "cannot call tile and fuse consumer with an empty loop nest"); + } + + LoopLikeOpInterface outermostLoop = loops.front(); + + // Collect the operands of the user that come from the outermost loop of the + // loop nest. + SmallVector<OpOperand *> consumerFusableOperands; + for (OpOperand &opOperand : user->getOpOperands()) { + if (opOperand.get().getDefiningOp() == outermostLoop) { + consumerFusableOperands.push_back(&opOperand); + } + } + + // Nothing to fuse. Just return an empty set. + if (consumerFusableOperands.empty()) { + return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, + SmallVector<OpOperand *>{}, + SmallVector<Operation *>{}}; + } + + // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices + // for fusion. + SmallVector<Operation *> candidateSlices; + candidateSlices.reserve(consumerFusableOperands.size()); + for (OpOperand *opOperand : consumerFusableOperands) { + std::optional<Operation *> slice = + getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); + if (!slice) { + return rewriter.notifyMatchFailure( + user, + "couldnt find producing insert-slice like operation for operand"); + } + candidateSlices.push_back(slice.value()); + } + return tileAndFuseConsumerOfSlicesImpl( + rewriter, user, consumerFusableOperands, candidateSlices, loops); +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 185fb9b358055..d72ab080f3c5c 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,7 +170,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op in (%forall_op) + transform.test.fuse_consumer_using_slice %slice_op in (%forall_op) : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -231,7 +231,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice + // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer // to fuse" error. 
transform.yield diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir new file mode 100644 index 0000000000000..62dd7faec4eb7 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir @@ -0,0 +1,1156 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32> + return %2 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + 
%yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +module { + func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) 
shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %in_ope... [truncated] 

github-actions bot commented Nov 12, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Nov 12, 2025
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>

@hanhanW hanhanW left a comment


> The test for scf::tileAndFuseConsumerOfSlices is copied from tile-and-fuse-consumer.mlir to
> tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
> test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.

I was curious about how the new op works with multiple consumers, but I don't see fuse_add_multiple_tilable_consumers in the tile-and-fuse-consumer.mlir file. Are we missing the test case?

Comment on lines 58 to 59

I think we should revamp the description, so people won't be confused about why two transform ops have the same description. Your PR description is very helpful and may be reused here and below.

Comment on lines +180 to +188
SmallVector<Operation *> fusedConsumerOps;

rewriter.setInsertionPoint(consumer);

FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
scf::tileAndFuseConsumer(rewriter, consumer, loops);

if (failed(fuseConsumerResults))
return consumer->emitOpError("failed to fuse consumer of slice");

I'd remove the blank lines, because they do not help readability, and dropping them makes the code snippet fit the window better.

Comment on lines 173 to 174
/// Apply fusing of consumer transformation to all payload ops and store both
/// the original consumer operation as well as the fused consumer operation.

I think the comment should be updated. This was the old comment for the other method. This method only fuses one operation now, right?


// -----

// Check that when the given operand tiles are inconsistent, tiling fails.

Does this no longer hold? I don't see any changes to the lit checks. Or does the comment belong to the test case below, multi_slice_fusion_with_broadcast, which has an expected_error check?

@@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
   for (auto [outerLoop, innerLoop] :
        llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
     // Again assume that all the outer loops are scf.for operations.
-    auto outerForLoop = cast<scf::ForOp>(outerLoop);
+    auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());

This looks unrelated to the PR itself; we can consider reverting it.


/// For a given result of the loop nest that is a tiled loop nest, return the
/// insert slice-like op that is used for consumer fusion
std::optional<Operation *>

Suggested change
- std::optional<Operation *>
+ static std::optional<Operation *>
Comment on lines +2432 to +2441
{
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSlices.front(),
"could not fetch consumer to fuse");
}
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
consumerOp = consumerOpOperands.front()->getOwner();
}

I'd remove the {} region. I usually use one when I need local variables that shouldn't pollute other phases. Given that this is a small function and we have no such need, let's drop the {}?

Comment on lines +2458 to +2460
if (combiningOps.size() != 1) {
return std::nullopt;
}

style nit: drop braces.

Comment on lines +2470 to +2472
LoopLikeOpInterface outermostLoop = loops.front();

if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {

outermostLoop is only used by this check, so I don't think we need the variable for readability. I could be biased, but it is straightforward to assume that loops is ordered starting from the outermost loop.

In any case, I'd drop the blank line because the two statements belong to the same code snippet: you declare the outermostLoop variable and check it immediately.
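
For readers following the function under review, the scf.for branch of getProducingInsertSliceLikeOp walks from a result of the outermost loop through each loop's yield until it reaches the insert-slice in the innermost body. The sketch below models that walk in plain C++ with toy types. ToyKind, ToyValue, ToyOp, and findProducingInsertSlice are hypothetical names introduced here for illustration; they are not MLIR classes, and the sketch returns an empty optional for an empty loop list where the code in the PR asserts.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Toy stand-ins for illustration only; none of these are MLIR types.
enum class ToyKind { For, InsertSlice, Other };

struct ToyOp;

// "Result number `index` of `owner`". A null owner models a block argument.
struct ToyValue {
  const ToyOp *owner = nullptr;
  std::size_t index = 0;
};

struct ToyOp {
  ToyKind kind = ToyKind::Other;
  // Only meaningful for ToyKind::For: the operands of the loop's yield.
  std::vector<ToyValue> yielded;
};

// Trace `result` from the outermost loop inwards. `loops` is ordered from
// outermost to innermost. Returns the insert-slice that builds the result,
// or nullopt if the nest does not have the expected tiled shape.
std::optional<const ToyOp *>
findProducingInsertSlice(ToyValue result,
                         const std::vector<const ToyOp *> &loops) {
  for (std::size_t i = 0; i < loops.size(); ++i) {
    const ToyOp *loop = loops[i];
    // The value being traced must be a result of the current loop.
    if (loop->kind != ToyKind::For || result.owner != loop)
      return std::nullopt;
    if (result.index >= loop->yielded.size())
      return std::nullopt;
    ToyValue next = loop->yielded[result.index];
    // A yielded block argument cannot be traced any further.
    if (next.owner == nullptr)
      return std::nullopt;
    if (i + 1 == loops.size()) {
      // Innermost loop: the yielded value must come from an insert-slice.
      if (next.owner->kind != ToyKind::InsertSlice)
        return std::nullopt;
      return next.owner;
    }
    // Otherwise the yielded value must be a result of the next inner loop,
    // which the next iteration checks.
    result = next;
  }
  return std::nullopt; // Empty loop list.
}
```

Indexing loops[i] from the front keeps the outermost-first ordering implicit, which is the assumption the comment above relies on.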

Comment on lines +2488 to +2489
OpResult innerForResult =
dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));

style nit: auto, because dyn_cast already spells the type.

@MaheshRavishankar
Contributor Author

> > The test for scf::tileAndFuseConsumerOfSlices is copied from tile-and-fuse-consumer.mlir to
> > tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
> > test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.
>
> I was curious about how the new op works with multiple consumers, but I don't see fuse_add_multiple_tilable_consumers in the tile-and-fuse-consumer.mlir file. Are we missing the test case?

I think the expected use for multiple consumers is to call the fusion method multiple times. I think this test case https://github.com/llvm/llvm-project/pull/167634/files#diff-154c2d387b69e0e07cde6b61865996435b248f8abac476ee42de6c3cfb12d715R736 shows this.
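
As a rough illustration of why a per-consumer entry point composes this way, the sketch below models the operand-collection step of scf::tileAndFuseConsumer in plain C++: for one consumer, it gathers every operand defined by the outermost loop, so a consumer that reads several loop results is handled in a single call. ToyOp, ToyOperand, and collectFusableOperands are hypothetical names for this sketch, not MLIR APIs, and the fusion itself is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-ins for illustration only; none of these are MLIR types.
struct ToyOp;

struct ToyOperand {
  const ToyOp *definingOp; // Op that produces the value used by this operand.
};

struct ToyOp {
  std::vector<ToyOperand> operands;
};

// Return the positions of `consumer`'s operands that are produced by
// `outermostLoop`. An empty result models the "nothing to fuse" case.
std::vector<std::size_t> collectFusableOperands(const ToyOp &consumer,
                                                const ToyOp *outermostLoop) {
  std::vector<std::size_t> fusable;
  for (std::size_t i = 0; i < consumer.operands.size(); ++i) {
    if (consumer.operands[i].definingOp == outermostLoop)
      fusable.push_back(i);
  }
  return fusable;
}
```

With several consumers, a caller would presumably repeat this step once per consumer against the current loop nest, which matches the usage described in the reply above.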
