- Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][SCF] Add scf::tileAndFuseConsumer that tiles a consumer into a given tiled loop nest. #167634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][SCF] Add scf::tileAndFuseConsumer that tiles a consumer into a given tiled loop nest. #167634
Conversation
… a given tiled loop nest. The existing `scf::tileAndFuseConsumerOfSlices` takes a list of slices (and loops they are part of), tries to find the consumer of these slices (all slices are expected to be the same consumer), and then tiles the consumer into the loop nest using the `TilingInterface`. A more natural way of doing consumer fusion is to just start from the consumer, look for operands that are produced by the loop nest passed in as `loops` (presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). Using the consumer you can find the slices of the operands that are accessed within the loop which you can then use to tile and fuse the consumer (using `TilingInterface`). This handles more naturally the case where multiple operands of the consumer come from the loop nest. The `scf::tileAndFuseConsumerOfSlices` was implemented as a mirror of `scf::tileAndFuseProducerOfSlice`. For the latter, the slice has a single producer for the source of the slice, which makes it a natural way of specifying producer fusion. But for consumers, the result might have multiple users, resulting in multiple candidates for fusion, as well as a fusion candidate using multiple results from the tiled loop nest. This means using slices (`tensor.insert_slice`/`tensor.parallel_insert_slice`) as a hook for consumer fusion turns out to be quite hard to navigate. The use of the consumer directly avoids all those pain points. In time the `scf::tileAndFuseConsumerOfSlices` should be deprecated in favor of `scf::tileAndFuseConsumer`. There is a lot of tech-debt that has accumulated in `scf::tileAndFuseConsumerOfSlices` that needs to be cleanedup. So while that gets cleaned up, and required functionality is moved to `scf::tileAndFuseConsumer`, the old path is still maintained. The test for `scf::tileAndFuseConsumerUsingSlices` is copied to `tile-and-fuse-consumer.mlir` to `tile-and-fuse-consumer-using-slices.mlir`. All the tests that were there in this file are now using the `tileAndFuseConsumer` method. The test op `test.tile_and_fuse_consumer` is modified to call `scf::tileAndFuseConsumer`, while a new op `test.tile_and_fuse_consumer_of_slice` is used to keep the old path tested while it is deprecated. Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
| @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf Author: None (MaheshRavishankar) ChangesThe existing The The test for Patch is 141.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167634.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index cd033c140a233..8bdf3e0b566ef 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [ getNumDynamicControlOperands() + getRank()); } + BlockArgument getTiedBlockArgument(OpResult opResult) { + assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult"); + return getBody()->getArgument(getRank() + opResult.getResultNumber()); + } + ::mlir::Value getInductionVar(int64_t idx) { return getInductionVars()[idx]; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 7c735d825b445..0005fad3d5c01 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, /// tiled in a manner that is consistent for all the passed slices. Note that /// the method replaces the uses of `candidateSlices` with the tiled and fused /// consumer value but does not delete the slice operations. +/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion +/// is to take the consumer operation, and find the slices to use for fusion +/// by walking its operands to the `loops` and then into the body to get the +/// slices used for fusion. struct SCFFuseConsumerOfSliceResult { // Original untiled consumer operands. SmallVector<OpOperand *> origConsumerOperands; @@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, MutableArrayRef<LoopLikeOpInterface> loops); +/// Fuse the `consumer` operation into the loop nest provided by `loops`. +/// The transformation looks for operands in the `consumer` that are defined +/// by the outermost loop of the loop nest in `loops`. The nested loop is +/// expected to have the structure of the loops generated through tiling. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops); + /// Method to lower an `op` that implements the `TilingInterface` to /// loops/scalars. FailureOr<SmallVector<scf::ForOp>> diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 29b770fb4b279..7e715ee189740 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( for (auto [outerLoop, innerLoop] : llvm::zip_equal(loops.drop_back(), loops.drop_front())) { // Again assume that all the outer loops are scf.for operations. - auto outerForLoop = cast<scf::ForOp>(outerLoop); + auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); auto outerLoopYield = cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); SmallVector<Value> newYields = @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter, return clonedSlices; } -/// Implementation of fusing consumer of a single slice by computing the -/// slice of the consumer in-place for scf loop. -FailureOr<scf::SCFFuseConsumerOfSliceResult> -mlir::scf::tileAndFuseConsumerOfSlices( - RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, - MutableArrayRef<LoopLikeOpInterface> loops) { - if (candidateSlices.empty()) { - return rewriter.notifyMatchFailure( - rewriter.getUnknownLoc(), - "no candidate slices provided for consumer fusion"); - } - // Return if `loops` is empty, return an error for now. Caller is expected - // to handle this case. - if (loops.empty()) { - return rewriter.notifyMatchFailure( - candidateSlices.front(), - "cannot call tile and fuse consumer with an empty loop nest"); - } +static FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, + ArrayRef<OpOperand *> consumerOpOperands, + ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "expected loops to be not empty"); - if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || - llvm::all_of(candidateSlices, - llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + // 1. Check assumption for loop with `reorderOperations` disabled. + if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) { return rewriter.notifyMatchFailure( - candidateSlices.front(), - "candidates slices need to be all `tensor.extract_slice`s or " - "`tensor.parallel_insert_slice`s"); - } - - // 1. Get the consumer of scf.for for the result yielded by - // tensor.insert_slice/parallel_insert_slice. - SmallVector<OpOperand *> consumerOpOperands; - Operation *consumerOp; - { - FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = - getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); - if (failed(maybeConsumerOpOperand)) { - return rewriter.notifyMatchFailure(candidateSlices.front(), - "could not fetch consumer to fuse"); - } - std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); - consumerOp = consumerOpOperands.front()->getOwner(); + loops.front(), "the first user of loop should not dominate any define " + "of consumer operand(s)"); } LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface innerMostLoop = loops.back(); - // Check assumption for loop with `reorderOperations` disabled. - if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { - return rewriter.notifyMatchFailure( - outerMostLoop, "the first user of loop should not dominate any define " - "of consumer operand(s)"); - } - OpBuilder::InsertionGuard g(rewriter); - // 2. Check consumer is not using scf loop's output as init. auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); if (!dstOp) @@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices( llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); }); + auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands); return scf::SCFFuseConsumerOfSliceResult{ - std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), + std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands), std::move(tileAndFuseResult->tiledOps)}; } +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumerOfSlices( + RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (candidateSlices.empty()) { + return rewriter.notifyMatchFailure( + rewriter.getUnknownLoc(), + "no candidate slices provided for consumer fusion"); + } + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "cannot call tile and fuse consumer with an empty loop nest"); + } + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || + llvm::all_of(candidateSlices, + llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } + + // Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + SmallVector<OpOperand *> consumerOpOperands; + Operation *consumerOp; + { + FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperand)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); + consumerOp = consumerOpOperands.front()->getOwner(); + } + + return tileAndFuseConsumerOfSlicesImpl( + rewriter, consumerOp, consumerOpOperands, candidateSlices, loops); +} + +/// For a given `result` of a `forallOp` return the +/// `tensor.parallel_insert_slice` op (or combining op) that is used to +/// construct this result. +static std::optional<Operation *> +getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { + if (result.getOwner() != forallOp) + return std::nullopt; + BlockArgument bbArg = forallOp.getTiedBlockArgument(result); + SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); + // If the number of combining ops is not 1, then this is unexpected. Return + // nullopt. + if (combiningOps.size() != 1) { + return std::nullopt; + } + return combiningOps[0]; +} + +/// For a given result of the loop nest that is a tiled loop nest, return the +/// insert slice-like op that is used for consumer fusion +std::optional<Operation *> +getProducingInsertSliceLikeOp(OpResult result, + ArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "Expected loops to be not empty"); + LoopLikeOpInterface outermostLoop = loops.front(); + + if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) { + assert(loops.size() == 1 && + "expected only a single loop when tiling using scf.forall"); + return getProducingParallelInsertSlice(forallOp, result); + } + // Assume that the loop nest is a nested `scf.for` that is created through + // tiling and retrieve the `tensor.insert_slice` operation used to construct + // the result. + while (loops.size() != 1) { + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + OpResult innerForResult = + dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); + if (!innerForResult) + return std::nullopt; + result = innerForResult; + loops = loops.drop_front(); + } + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) + .getDefiningOp<tensor::InsertSliceOp>(); + if (!insertSliceOp) + return std::nullopt; + return insertSliceOp; +} + +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user, + MutableArrayRef<LoopLikeOpInterface> loops) { + // Only handle users that implement the `TilingInterface`. + if (!isa<TilingInterface>(user)) { + return rewriter.notifyMatchFailure( + user, "unhandled user that does not implement TilingInterface"); + } + + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + user, "cannot call tile and fuse consumer with an empty loop nest"); + } + + LoopLikeOpInterface outermostLoop = loops.front(); + + // Collect the operands of the user that come from the outermost loop of the + // loop nest. + SmallVector<OpOperand *> consumerFusableOperands; + for (OpOperand &opOperand : user->getOpOperands()) { + if (opOperand.get().getDefiningOp() == outermostLoop) { + consumerFusableOperands.push_back(&opOperand); + } + } + + // Nothing to fuse. Just return an empty set. + if (consumerFusableOperands.empty()) { + return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, + SmallVector<OpOperand *>{}, + SmallVector<Operation *>{}}; + } + + // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices + // for fusion. + SmallVector<Operation *> candidateSlices; + candidateSlices.reserve(consumerFusableOperands.size()); + for (OpOperand *opOperand : consumerFusableOperands) { + std::optional<Operation *> slice = + getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); + if (!slice) { + return rewriter.notifyMatchFailure( + user, + "couldnt find producing insert-slice like operation for operand"); + } + candidateSlices.push_back(slice.value()); + } + return tileAndFuseConsumerOfSlicesImpl( + rewriter, user, consumerFusableOperands, candidateSlices, loops); +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 185fb9b358055..d72ab080f3c5c 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,7 +170,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op in (%forall_op) + transform.test.fuse_consumer_using_slice %slice_op in (%forall_op) : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -231,7 +231,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice + // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer // to fuse" error. transform.yield diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir new file mode 100644 index 0000000000000..62dd7faec4eb7 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir @@ -0,0 +1,1156 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32> + return %2 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +module { + func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %in_ope... [truncated] |
| @llvm/pr-subscribers-mlir-linalg Author: None (MaheshRavishankar) ChangesThe existing The The test for Patch is 141.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/167634.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index cd033c140a233..8bdf3e0b566ef 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [ getNumDynamicControlOperands() + getRank()); } + BlockArgument getTiedBlockArgument(OpResult opResult) { + assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult"); + return getBody()->getArgument(getRank() + opResult.getResultNumber()); + } + ::mlir::Value getInductionVar(int64_t idx) { return getInductionVars()[idx]; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 7c735d825b445..0005fad3d5c01 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter, /// tiled in a manner that is consistent for all the passed slices. Note that /// the method replaces the uses of `candidateSlices` with the tiled and fused /// consumer value but does not delete the slice operations. +/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion +/// is to take the consumer operation, and find the slices to use for fusion +/// by walking its operands to the `loops` and then into the body to get the +/// slices used for fusion. struct SCFFuseConsumerOfSliceResult { // Original untiled consumer operands. SmallVector<OpOperand *> origConsumerOperands; @@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, MutableArrayRef<LoopLikeOpInterface> loops); +/// Fuse the `consumer` operation into the loop nest provided by `loops`. +/// The transformation looks for operands in the `consumer` that are defined +/// by the outermost loop of the loop nest in `loops`. The nested loop is +/// expected to have the structure of the loops generated through tiling. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, + MutableArrayRef<LoopLikeOpInterface> loops); + /// Method to lower an `op` that implements the `TilingInterface` to /// loops/scalars. FailureOr<SmallVector<scf::ForOp>> diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 29b770fb4b279..7e715ee189740 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( for (auto [outerLoop, innerLoop] : llvm::zip_equal(loops.drop_back(), loops.drop_front())) { // Again assume that all the outer loops are scf.for operations. - auto outerForLoop = cast<scf::ForOp>(outerLoop); + auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); auto outerLoopYield = cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); SmallVector<Value> newYields = @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter, return clonedSlices; } -/// Implementation of fusing consumer of a single slice by computing the -/// slice of the consumer in-place for scf loop. -FailureOr<scf::SCFFuseConsumerOfSliceResult> -mlir::scf::tileAndFuseConsumerOfSlices( - RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, - MutableArrayRef<LoopLikeOpInterface> loops) { - if (candidateSlices.empty()) { - return rewriter.notifyMatchFailure( - rewriter.getUnknownLoc(), - "no candidate slices provided for consumer fusion"); - } - // Return if `loops` is empty, return an error for now. Caller is expected - // to handle this case. - if (loops.empty()) { - return rewriter.notifyMatchFailure( - candidateSlices.front(), - "cannot call tile and fuse consumer with an empty loop nest"); - } +static FailureOr<scf::SCFFuseConsumerOfSliceResult> +tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, + ArrayRef<OpOperand *> consumerOpOperands, + ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "expected loops to be not empty"); - if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || - llvm::all_of(candidateSlices, - llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + // 1. Check assumption for loop with `reorderOperations` disabled. + if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) { return rewriter.notifyMatchFailure( - candidateSlices.front(), - "candidates slices need to be all `tensor.extract_slice`s or " - "`tensor.parallel_insert_slice`s"); - } - - // 1. Get the consumer of scf.for for the result yielded by - // tensor.insert_slice/parallel_insert_slice. - SmallVector<OpOperand *> consumerOpOperands; - Operation *consumerOp; - { - FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = - getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); - if (failed(maybeConsumerOpOperand)) { - return rewriter.notifyMatchFailure(candidateSlices.front(), - "could not fetch consumer to fuse"); - } - std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); - consumerOp = consumerOpOperands.front()->getOwner(); + loops.front(), "the first user of loop should not dominate any define " + "of consumer operand(s)"); } LoopLikeOpInterface outerMostLoop = loops.front(); LoopLikeOpInterface innerMostLoop = loops.back(); - // Check assumption for loop with `reorderOperations` disabled. - if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { - return rewriter.notifyMatchFailure( - outerMostLoop, "the first user of loop should not dominate any define " - "of consumer operand(s)"); - } - OpBuilder::InsertionGuard g(rewriter); - // 2. Check consumer is not using scf loop's output as init. auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); if (!dstOp) @@ -2428,11 +2391,171 @@ mlir::scf::tileAndFuseConsumerOfSlices( llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); }); + auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands); return scf::SCFFuseConsumerOfSliceResult{ - std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), + std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands), std::move(tileAndFuseResult->tiledOps)}; } +/// Implementation of fusing consumer of a single slice by computing the +/// slice of the consumer in-place for scf loop. +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumerOfSlices( + RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, + MutableArrayRef<LoopLikeOpInterface> loops) { + if (candidateSlices.empty()) { + return rewriter.notifyMatchFailure( + rewriter.getUnknownLoc(), + "no candidate slices provided for consumer fusion"); + } + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "cannot call tile and fuse consumer with an empty loop nest"); + } + + if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || + llvm::all_of(candidateSlices, + llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { + return rewriter.notifyMatchFailure( + candidateSlices.front(), + "candidates slices need to be all `tensor.extract_slice`s or " + "`tensor.parallel_insert_slice`s"); + } + + // Get the consumer of scf.for for the result yielded by + // tensor.insert_slice/parallel_insert_slice. + SmallVector<OpOperand *> consumerOpOperands; + Operation *consumerOp; + { + FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = + getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); + if (failed(maybeConsumerOpOperand)) { + return rewriter.notifyMatchFailure(candidateSlices.front(), + "could not fetch consumer to fuse"); + } + std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); + consumerOp = consumerOpOperands.front()->getOwner(); + } + + return tileAndFuseConsumerOfSlicesImpl( + rewriter, consumerOp, consumerOpOperands, candidateSlices, loops); +} + +/// For a given `result` of a `forallOp` return the +/// `tensor.parallel_insert_slice` op (or combining op) that is used to +/// construct this result. +static std::optional<Operation *> +getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { + if (result.getOwner() != forallOp) + return std::nullopt; + BlockArgument bbArg = forallOp.getTiedBlockArgument(result); + SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); + // If the number of combining ops is not 1, then this is unexpected. Return + // nullopt. + if (combiningOps.size() != 1) { + return std::nullopt; + } + return combiningOps[0]; +} + +/// For a given result of the loop nest that is a tiled loop nest, return the +/// insert slice-like op that is used for consumer fusion +std::optional<Operation *> +getProducingInsertSliceLikeOp(OpResult result, + ArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "Expected loops to be not empty"); + LoopLikeOpInterface outermostLoop = loops.front(); + + if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) { + assert(loops.size() == 1 && + "expected only a single loop when tiling using scf.forall"); + return getProducingParallelInsertSlice(forallOp, result); + } + // Assume that the loop nest is a nested `scf.for` that is created through + // tiling and retrieve the `tensor.insert_slice` operation used to construct + // the result. + while (loops.size() != 1) { + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + OpResult innerForResult = + dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); + if (!innerForResult) + return std::nullopt; + result = innerForResult; + loops = loops.drop_front(); + } + if (result.getOwner() != loops.front()) + return std::nullopt; + auto forOp = dyn_cast<scf::ForOp>(loops.front()); + if (!forOp) + return std::nullopt; + auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); + auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) + .getDefiningOp<tensor::InsertSliceOp>(); + if (!insertSliceOp) + return std::nullopt; + return insertSliceOp; +} + +FailureOr<scf::SCFFuseConsumerOfSliceResult> +mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user, + MutableArrayRef<LoopLikeOpInterface> loops) { + // Only handle users that implement the `TilingInterface`. + if (!isa<TilingInterface>(user)) { + return rewriter.notifyMatchFailure( + user, "unhandled user that does not implement TilingInterface"); + } + + // Return if `loops` is empty, return an error for now. Caller is expected + // to handle this case. + if (loops.empty()) { + return rewriter.notifyMatchFailure( + user, "cannot call tile and fuse consumer with an empty loop nest"); + } + + LoopLikeOpInterface outermostLoop = loops.front(); + + // Collect the operands of the user that come from the outermost loop of the + // loop nest. + SmallVector<OpOperand *> consumerFusableOperands; + for (OpOperand &opOperand : user->getOpOperands()) { + if (opOperand.get().getDefiningOp() == outermostLoop) { + consumerFusableOperands.push_back(&opOperand); + } + } + + // Nothing to fuse. Just return an empty set. + if (consumerFusableOperands.empty()) { + return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, + SmallVector<OpOperand *>{}, + SmallVector<Operation *>{}}; + } + + // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices + // for fusion. + SmallVector<Operation *> candidateSlices; + candidateSlices.reserve(consumerFusableOperands.size()); + for (OpOperand *opOperand : consumerFusableOperands) { + std::optional<Operation *> slice = + getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); + if (!slice) { + return rewriter.notifyMatchFailure( + user, + "couldnt find producing insert-slice like operation for operand"); + } + candidateSlices.push_back(slice.value()); + } + return tileAndFuseConsumerOfSlicesImpl( + rewriter, user, consumerFusableOperands, candidateSlices, loops); +} + //===----------------------------------------------------------------------===// // lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir index 185fb9b358055..d72ab080f3c5c 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-and-fuse-pack-unpack.mlir @@ -170,7 +170,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - transform.test.fuse_consumer %slice_op in (%forall_op) + transform.test.fuse_consumer_using_slice %slice_op in (%forall_op) : (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } @@ -231,7 +231,7 @@ module { // Fuse the consumer operation into the tiled loop. %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice"> - // Note that we cannot apply transform.test.fuse_consumer here because the extract_slice + // Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice // is not qualified consumer operation. Forcing this will yeild "could not fetch consumer // to fuse" error. transform.yield diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir new file mode 100644 index 0000000000000..62dd7faec4eb7 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer-using-slices.mlir @@ -0,0 +1,1156 @@ +// RUN: mlir-opt --transform-interpreter --cse --split-input-file --verify-diagnostics %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +module { + func.func @fuse_tileable_consumer_scf_for(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) { + %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32xf32> + %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32> + scf.yield %arg5, %4 : tensor<64xf32>, tensor<64xf32> + } + %in_operand_2 = tensor.empty() : tensor<64xf32> + %out_operand_3 = tensor.empty() : tensor<64xf32> + %2 = linalg.add ins(%1#1, %in_operand_2 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand_3 : tensor<64xf32>) -> tensor<64xf32> + return %2 : tensor<64xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %loop = transform.structured.match ops{["scf.for"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %yield = transform.structured.match ops{["tensor.insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer_using_slice %yield in (%loop) + : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: func.func @fuse_tileable_consumer_scf_for( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64xf32>) +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %0 = tensor.empty() : tensor<64xf32> +// CHECK: %[[FINAL_RESULT:.*]]:3 = scf.for %[[IV:.*]] = %[[C0]] +// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[SECOND_OUT_ARG:.*]] = %[[ARG2]], %[[ELEM_OUT_ARG:.*]] = %0) +// CHECK-SAME: { +// CHECK: %[[MAT_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[MAT_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[MAT_OUT_SLICE]] : tensor<32xf32>) +// CHECK: %[[INSERT_MAT:.*]] = tensor.insert_slice %[[MAT_OUT]] into %[[FIRST_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OPERAND2:.*]] = tensor.extract_slice %0[%[[IV]]] [32] [1] +// CHECK: %[[SLICE_OUT:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: %[[ELEM_OUT:.*]] = linalg.add +// CHECK-SAME: ins(%[[MAT_OUT]], %[[SLICE_OPERAND2]] : +// CHECK-SAME: outs(%[[SLICE_OUT]] : +// CHECK: %[[INSERT_ELEM:.*]] = tensor.insert_slice %[[ELEM_OUT]] into %[[ELEM_OUT_ARG]][%[[IV]]] [32] [1] +// CHECK: scf.yield %[[SECOND_OUT_ARG]], %[[INSERT_MAT]], %[[INSERT_ELEM]] : +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#2 : + +// ----- + +module { + func.func @fuse_tileable_consumer_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %extracted_slice_1 = tensor.extract_slice %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32> + %3 = linalg.matmul ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + tensor.parallel_insert_slice %extracted_slice_1 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32> + } + } + %in_ope... [truncated] |
| ✅ With the latest revision this PR passed the C/C++ code formatter. |
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test for scf::tileAndFuseConsumerUsingSlices is copied to tile-and-fuse-consumer.mlir to
tile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using the tileAndFuseConsumer method. The test op test.tile_and_fuse_consumer is modified to call scf::tileAndFuseConsumer, while a new op
test.tile_and_fuse_consumer_of_slice is used to keep the old path tested while it is deprecated.
I was curious about how the new op work with multiple consumers, but I don't see fuse_add_multiple_tilable_consumers in the tile-and-fuse-consumer.mlir file. Are we missing the test case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should revamp the description, so people won't be confused about why two transfrom ops have the same description. Your PR description is very helpful and may be reused here and below.
| SmallVector<Operation *> fusedConsumerOps; | ||
| | ||
| rewriter.setInsertionPoint(consumer); | ||
| | ||
| FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults = | ||
| scf::tileAndFuseConsumer(rewriter, consumer, loops); | ||
| | ||
| if (failed(fuseConsumerResults)) | ||
| return consumer->emitOpError("failed to fuse consumer of slice"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd remove blank lines, because it does not help readability and it makes the code snippet fit window size better.
| /// Apply fusing of consumer transformation to all payload ops and store both | ||
| /// the original consumer operation as well as the fused consumer operation. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the comment should be updated. This was the old comment for the other method. This method only fuses one operation now right?
| | ||
| // ----- | ||
| | ||
| // Check that when the given operand tiles are inconsistent, tiling fails. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this no longer hold? I don't see changes on lit checks. Or it may belong to the below test case: multi_slice_fusion_with_broadcast, that has an expected_error check.
| @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( | |||
| for (auto [outerLoop, innerLoop] : | |||
| llvm::zip_equal(loops.drop_back(), loops.drop_front())) { | |||
| // Again assume that all the outer loops are scf.for operations. | |||
| auto outerForLoop = cast<scf::ForOp>(outerLoop); | |||
| auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks not related to the PR itself, we can consider reverting it.
| | ||
| /// For a given result of the loop nest that is a tiled loop nest, return the | ||
| /// insert slice-like op that is used for consumer fusion | ||
| std::optional<Operation *> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| std::optional<Operation *> | |
| static std::optional<Operation *> |
| { | ||
| FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = | ||
| getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); | ||
| if (failed(maybeConsumerOpOperand)) { | ||
| return rewriter.notifyMatchFailure(candidateSlices.front(), | ||
| "could not fetch consumer to fuse"); | ||
| } | ||
| std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); | ||
| consumerOp = consumerOpOperands.front()->getOwner(); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd remove the {} region. I usually use it when I need local variables that don't pollute other phases. Given that it is a small function and we don't have such need, let's drop {}?
| if (combiningOps.size() != 1) { | ||
| return std::nullopt; | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style nit: drop braces.
| LoopLikeOpInterface outermostLoop = loops.front(); | ||
| | ||
| if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
outermostLoop is only used by this check. I think we don't need the variable for readability. I could be biased, but it is straight-forward to assume the loops order is starting from outermost loop.
In any case, I'd drop the blank line because they belong to the same code snippet -- you declare outermostLoop variable and check it immediately.
| OpResult innerForResult = | ||
| dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style nit: auto, because dyn_cast already spells the type.
I think the expected use is for multiple consumers, you call the fusion method multiple times. I think this test case https://github.com/llvm/llvm-project/pull/167634/files#diff-154c2d387b69e0e07cde6b61865996435b248f8abac476ee42de6c3cfb12d715R736 shows this. |
The existing
scf::tileAndFuseConsumerOfSlicestakes a list of slices (and loops they are part of), tries to find the consumer of these slices (all slices are expected to be the same consumer), and then tiles the consumer into the loop nest using theTilingInterface. A more natural way of doing consumer fusion is to just start from the consumer, look for operands that are produced by the loop nest passed in asloops(presumably these loops are generated by tiling, but that is not a requirement for consumer fusion). Using the consumer you can find the slices of the operands that are accessed within the loop which you can then use to tile and fuse the consumer (usingTilingInterface). This handles more naturally the case where multiple operands of the consumer come from the loop nest.The
scf::tileAndFuseConsumerOfSliceswas implemented as a mirror ofscf::tileAndFuseProducerOfSlice. For the latter, the slice has a single producer for the source of the slice, which makes it a natural way of specifying producer fusion. But for consumers, the result might have multiple users, resulting in multiple candidates for fusion, as well as a fusion candidate using multiple results from the tiled loop nest. This means using slices(
tensor.insert_slice/tensor.parallel_insert_slice) as a hook for consumer fusion turns out to be quite hard to navigate. The use of the consumer directly avoids all those pain points. In time thescf::tileAndFuseConsumerOfSlicesshould be deprecated in favor ofscf::tileAndFuseConsumer. There is a lot of tech-debt that has accumulated inscf::tileAndFuseConsumerOfSlicesthat needs to be cleanedup. So while that gets cleaned up, and required functionality is moved toscf::tileAndFuseConsumer, the old path is still maintained.The test for
scf::tileAndFuseConsumerUsingSlicesis copied totile-and-fuse-consumer.mlirtotile-and-fuse-consumer-using-slices.mlir. All the tests that were there in this file are now using thetileAndFuseConsumermethod. The test optest.tile_and_fuse_consumeris modified to callscf::tileAndFuseConsumer, while a new optest.tile_and_fuse_consumer_of_sliceis used to keep the old path tested while it is deprecated.