- Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][SCF] Add scf::tileAndFuseConsumer that tiles a consumer into a given tiled loop nest. #167634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][SCF] Add scf::tileAndFuseConsumer that tiles a consumer into a given tiled loop nest. #167634
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| | @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest( | |||||
| for (auto [outerLoop, innerLoop] : | ||||||
| llvm::zip_equal(loops.drop_back(), loops.drop_front())) { | ||||||
| // Again assume that all the outer loops are scf.for operations. | ||||||
| auto outerForLoop = cast<scf::ForOp>(outerLoop); | ||||||
| auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation()); | ||||||
| auto outerLoopYield = | ||||||
| cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator()); | ||||||
| SmallVector<Value> newYields = | ||||||
| | @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter, | |||||
| return clonedSlices; | ||||||
| } | ||||||
| | ||||||
| /// Implementation of fusing consumer of a single slice by computing the | ||||||
| /// slice of the consumer in-place for scf loop. | ||||||
| FailureOr<scf::SCFFuseConsumerOfSliceResult> | ||||||
| mlir::scf::tileAndFuseConsumerOfSlices( | ||||||
| RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, | ||||||
| MutableArrayRef<LoopLikeOpInterface> loops) { | ||||||
| if (candidateSlices.empty()) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| rewriter.getUnknownLoc(), | ||||||
| "no candidate slices provided for consumer fusion"); | ||||||
| } | ||||||
| // Return if `loops` is empty, return an error for now. Caller is expected | ||||||
| // to handle this case. | ||||||
| if (loops.empty()) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| candidateSlices.front(), | ||||||
| "cannot call tile and fuse consumer with an empty loop nest"); | ||||||
| } | ||||||
| static FailureOr<scf::SCFFuseConsumerOfSliceResult> | ||||||
| tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp, | ||||||
| ArrayRef<OpOperand *> consumerOpOperands, | ||||||
| ArrayRef<Operation *> candidateSlices, | ||||||
| MutableArrayRef<LoopLikeOpInterface> loops) { | ||||||
| assert(!loops.empty() && "expected loops to be not empty"); | ||||||
| | ||||||
| if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || | ||||||
| llvm::all_of(candidateSlices, | ||||||
| llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { | ||||||
| // 1. Check assumption for loop with `reorderOperations` disabled. | ||||||
| if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| candidateSlices.front(), | ||||||
| "candidates slices need to be all `tensor.extract_slice`s or " | ||||||
| "`tensor.parallel_insert_slice`s"); | ||||||
| } | ||||||
| | ||||||
| // 1. Get the consumer of scf.for for the result yielded by | ||||||
| // tensor.insert_slice/parallel_insert_slice. | ||||||
| SmallVector<OpOperand *> consumerOpOperands; | ||||||
| Operation *consumerOp; | ||||||
| { | ||||||
| FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = | ||||||
| getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); | ||||||
| if (failed(maybeConsumerOpOperand)) { | ||||||
| return rewriter.notifyMatchFailure(candidateSlices.front(), | ||||||
| "could not fetch consumer to fuse"); | ||||||
| } | ||||||
| std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); | ||||||
| consumerOp = consumerOpOperands.front()->getOwner(); | ||||||
| loops.front(), "the first user of loop should not dominate any define " | ||||||
| "of consumer operand(s)"); | ||||||
| } | ||||||
| | ||||||
| LoopLikeOpInterface outerMostLoop = loops.front(); | ||||||
| LoopLikeOpInterface innerMostLoop = loops.back(); | ||||||
| | ||||||
| // Check assumption for loop with `reorderOperations` disabled. | ||||||
| if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| outerMostLoop, "the first user of loop should not dominate any define " | ||||||
| "of consumer operand(s)"); | ||||||
| } | ||||||
| | ||||||
| OpBuilder::InsertionGuard g(rewriter); | ||||||
| | ||||||
| // 2. Check consumer is not using scf loop's output as init. | ||||||
| auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp); | ||||||
| if (!dstOp) | ||||||
| | @@ -2428,11 +2391,173 @@ mlir::scf::tileAndFuseConsumerOfSlices( | |||||
| llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) { | ||||||
| return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum); | ||||||
| }); | ||||||
| auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands); | ||||||
| return scf::SCFFuseConsumerOfSliceResult{ | ||||||
| std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands), | ||||||
| std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands), | ||||||
| std::move(tileAndFuseResult->tiledOps)}; | ||||||
| } | ||||||
| | ||||||
| /// Fuses the consumer reached through `candidateSlices` into the tiled loop | ||||||
| /// nest `loops`. Validates the inputs, looks up the untiled consumer operands | ||||||
| /// via `getUntiledConsumerOperandsFromSlices`, and forwards to | ||||||
| /// `tileAndFuseConsumerOfSlicesImpl`. | ||||||
| /// | ||||||
| /// Fails (through `notifyMatchFailure`) if `candidateSlices` is empty, if | ||||||
| /// `loops` is empty, if the candidates are not all `tensor.insert_slice` ops | ||||||
| /// or all `tensor.parallel_insert_slice` ops, or if no consumer is found. | ||||||
| FailureOr<scf::SCFFuseConsumerOfSliceResult> | ||||||
| mlir::scf::tileAndFuseConsumerOfSlices( | ||||||
| RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices, | ||||||
| MutableArrayRef<LoopLikeOpInterface> loops) { | ||||||
| if (candidateSlices.empty()) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| rewriter.getUnknownLoc(), | ||||||
| "no candidate slices provided for consumer fusion"); | ||||||
| } | ||||||
| // An empty loop nest is reported as a failure for now. The caller is | ||||||
| // expected to handle this case. | ||||||
| if (loops.empty()) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| candidateSlices.front(), | ||||||
| "cannot call tile and fuse consumer with an empty loop nest"); | ||||||
| } | ||||||
| | ||||||
| // The candidates must be homogeneous: either all `tensor.insert_slice` or | ||||||
| // all `tensor.parallel_insert_slice`. | ||||||
| if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) || | ||||||
| llvm::all_of(candidateSlices, | ||||||
| llvm::IsaPred<tensor::ParallelInsertSliceOp>))) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| candidateSlices.front(), | ||||||
| "candidate slices need to be all `tensor.insert_slice`s or " | ||||||
| "`tensor.parallel_insert_slice`s"); | ||||||
| } | ||||||
| | ||||||
| // Get the consumer of scf.for for the result yielded by | ||||||
| // tensor.insert_slice/parallel_insert_slice. | ||||||
| SmallVector<OpOperand *> consumerOpOperands; | ||||||
| Operation *consumerOp; | ||||||
| { | ||||||
| FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand = | ||||||
| getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops); | ||||||
| if (failed(maybeConsumerOpOperand)) { | ||||||
| return rewriter.notifyMatchFailure(candidateSlices.front(), | ||||||
| "could not fetch consumer to fuse"); | ||||||
| } | ||||||
| std::swap(consumerOpOperands, maybeConsumerOpOperand.value()); | ||||||
| // The owner of the first returned operand is taken as the consumer. | ||||||
| consumerOp = consumerOpOperands.front()->getOwner(); | ||||||
| } | ||||||
| Comment on lines +2432 to +2441 Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd remove the | ||||||
| | ||||||
| return tileAndFuseConsumerOfSlicesImpl( | ||||||
| rewriter, consumerOp, consumerOpOperands, candidateSlices, loops); | ||||||
| } | ||||||
| | ||||||
| /// For a given `result` of a `forallOp`, return the single combining op | ||||||
| /// (e.g. a `tensor.parallel_insert_slice`) of the block argument tied to that result. | ||||||
| /// Returns `std::nullopt` if `result` is not owned by `forallOp` or if the number of combining ops is not exactly one. | ||||||
| static std::optional<Operation *> | ||||||
| getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) { | ||||||
| if (result.getOwner() != forallOp) | ||||||
| return std::nullopt; | ||||||
| BlockArgument bbArg = forallOp.getTiedBlockArgument(result); | ||||||
| SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg); | ||||||
| // Exactly one combining op is expected for the tied block argument. Any | ||||||
| // other count is unexpected, so return nullopt. | ||||||
| if (combiningOps.size() != 1) { | ||||||
| return std::nullopt; | ||||||
| } | ||||||
| Comment on lines +2458 to +2460 Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style nit: drop braces. | ||||||
| return combiningOps[0]; | ||||||
| } | ||||||
| | ||||||
| /// For a given `result` of the outermost loop of the tiled loop nest `loops`, return the insert slice-like op that is used for consumer fusion: | ||||||
| /// the combining op of a single `scf.forall`, or the `tensor.insert_slice` yielded by the innermost loop of an `scf.for` nest. Returns `std::nullopt` if the nest does not have this structure. | ||||||
| std::optional<Operation *> | ||||||
| Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Suggested change
| ||||||
| getProducingInsertSliceLikeOp(OpResult result, | ||||||
| ArrayRef<LoopLikeOpInterface> loops) { | ||||||
| assert(!loops.empty() && "Expected loops to be not empty"); | ||||||
| LoopLikeOpInterface outermostLoop = loops.front(); | ||||||
| | ||||||
| if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) { | ||||||
| Comment on lines +2470 to +2472 Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In any case, I'd drop the blank line because they belong to the same code snippet -- you declare outermostLoop variable and check it immediately. | ||||||
| assert(loops.size() == 1 && | ||||||
| "expected only a single loop when tiling using scf.forall"); | ||||||
| return getProducingParallelInsertSlice(forallOp, result); | ||||||
| } | ||||||
| // Otherwise assume that the loop nest is a nested `scf.for` that is created through | ||||||
| // tiling. Walk inward: each outer loop must own `result` and yield, at the same result number, | ||||||
| // an op result, which becomes the `result` checked against the next loop in the nest. | ||||||
| while (loops.size() != 1) { | ||||||
| LoopLikeOpInterface loop = loops.front(); | ||||||
| if (result.getOwner() != loop) | ||||||
| return std::nullopt; | ||||||
| auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); | ||||||
| if (!forOp) | ||||||
| return std::nullopt; | ||||||
| auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); | ||||||
| OpResult innerForResult = | ||||||
| dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber())); | ||||||
| Comment on lines +2488 to +2489 Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style nit: auto, because | ||||||
| if (!innerForResult) | ||||||
| return std::nullopt; | ||||||
| result = innerForResult; | ||||||
| loops = loops.drop_front(); | ||||||
| } | ||||||
| LoopLikeOpInterface loop = loops.front(); | ||||||
| if (result.getOwner() != loop) | ||||||
| return std::nullopt; | ||||||
| auto forOp = dyn_cast<scf::ForOp>(loop.getOperation()); | ||||||
| if (!forOp) | ||||||
| return std::nullopt; | ||||||
| auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); | ||||||
| auto insertSliceOp = yieldOp.getOperand(result.getResultNumber()) | ||||||
| .getDefiningOp<tensor::InsertSliceOp>(); | ||||||
| if (!insertSliceOp) | ||||||
| return std::nullopt; | ||||||
| return insertSliceOp; | ||||||
| } | ||||||
| | ||||||
| /// Tiles `user` and fuses it into the given tiled loop nest `loops`. | ||||||
| /// | ||||||
| /// The operands of `user` that are defined by the outermost loop are the | ||||||
| /// fusable operands. For each of them the insert-slice-like op producing the | ||||||
| /// matching loop result is looked up, and the fusion itself is delegated to | ||||||
| /// `tileAndFuseConsumerOfSlicesImpl`. | ||||||
| /// | ||||||
| /// Fails (through `notifyMatchFailure`) if `user` does not implement | ||||||
| /// `TilingInterface`, if `loops` is empty, or if a producing slice cannot be | ||||||
| /// found for a fusable operand. If no operand of `user` comes from the | ||||||
| /// outermost loop, succeeds with a result whose lists are all empty. | ||||||
| FailureOr<scf::SCFFuseConsumerOfSliceResult> | ||||||
| mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user, | ||||||
| MutableArrayRef<LoopLikeOpInterface> loops) { | ||||||
| Comment on lines +2509 to +2511 Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This mismatches the declaration that uses | ||||||
| // Only handle users that implement the `TilingInterface`. | ||||||
| Contributor There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: remove the comment because below error message already spells it out. | ||||||
| if (!isa<TilingInterface>(user)) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| user, "unhandled user that does not implement TilingInterface"); | ||||||
| } | ||||||
| | ||||||
| // An empty loop nest is reported as a failure for now. The caller is | ||||||
| // expected to handle this case. | ||||||
| if (loops.empty()) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| user, "cannot call tile and fuse consumer with an empty loop nest"); | ||||||
| } | ||||||
| | ||||||
| LoopLikeOpInterface outermostLoop = loops.front(); | ||||||
| | ||||||
| // Collect the operands of the user that come from the outermost loop of the | ||||||
| // loop nest. | ||||||
| SmallVector<OpOperand *> consumerFusableOperands; | ||||||
| for (OpOperand &opOperand : user->getOpOperands()) { | ||||||
| if (opOperand.get().getDefiningOp() == outermostLoop) | ||||||
| consumerFusableOperands.push_back(&opOperand); | ||||||
| } | ||||||
| | ||||||
| // Nothing to fuse. Just return an empty set. | ||||||
| if (consumerFusableOperands.empty()) { | ||||||
| return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands, | ||||||
| SmallVector<OpOperand *>{}, | ||||||
| SmallVector<Operation *>{}}; | ||||||
| } | ||||||
| | ||||||
| // Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices | ||||||
| // for fusion. Every fusable operand is a result of the outermost loop, so | ||||||
| // the cast to `OpResult` below cannot fail. | ||||||
| SmallVector<Operation *> candidateSlices; | ||||||
| candidateSlices.reserve(consumerFusableOperands.size()); | ||||||
| for (OpOperand *opOperand : consumerFusableOperands) { | ||||||
| std::optional<Operation *> slice = | ||||||
| getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops); | ||||||
| if (!slice) { | ||||||
| return rewriter.notifyMatchFailure( | ||||||
| user, | ||||||
| "couldn't find producing insert-slice like operation for operand"); | ||||||
| } | ||||||
| candidateSlices.push_back(slice.value()); | ||||||
| } | ||||||
| return tileAndFuseConsumerOfSlicesImpl( | ||||||
| rewriter, user, consumerFusableOperands, candidateSlices, loops); | ||||||
| } | ||||||
| | ||||||
| //===----------------------------------------------------------------------===// | ||||||
| // lowerToLoopsUsingSCFForOp implementation. | ||||||
| //===----------------------------------------------------------------------===// | ||||||
| | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks unrelated to the PR itself; we can consider reverting it.