Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,11 @@ def ForallOp : SCF_Op<"forall", [
getNumDynamicControlOperands() + getRank());
}

/// Returns the body block argument that is tied to `opResult`. `opResult`
/// must be a result of this operation (checked by the assert). The argument
/// is found `getRank()` positions past the start of the body argument list,
/// offset by the result number.
BlockArgument getTiedBlockArgument(OpResult opResult) {
  assert(opResult.getDefiningOp() == getOperation() && "invalid OpResult");
  // Skip the leading `getRank()` body arguments to reach the one paired with
  // this result.
  const unsigned bodyArgIndex = getRank() + opResult.getResultNumber();
  return getBody()->getArgument(bodyArgIndex);
}

/// Returns the induction variable at position `idx` in the list returned by
/// `getInductionVars()`. This accessor does not range-check `idx` itself.
::mlir::Value getInductionVar(int64_t idx) {
  auto inductionVars = getInductionVars();
  return inductionVars[idx];
}
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,10 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
/// tiled in a manner that is consistent for all the passed slices. Note that
/// the method replaces the uses of `candidateSlices` with the tiled and fused
/// consumer value but does not delete the slice operations.
/// TODO(MaheshRavishankar): A more natural way of exposing the consumer fusion
/// is to take the consumer operation, and find the slices to use for fusion
/// by walking its operands to the `loops` and then into the body to get the
/// slices used for fusion.
struct SCFFuseConsumerOfSliceResult {
// Original untiled consumer operands.
SmallVector<OpOperand *> origConsumerOperands;
Expand All @@ -427,6 +431,14 @@ tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops);

/// Fuse the `consumer` operation into the loop nest provided by `loops`.
/// The transformation looks for operands in the `consumer` that are defined
/// by the outermost loop of the loop nest in `loops`. The nested loop is
/// expected to have the structure of the loops generated through tiling.
///
/// Fails, among other cases, if `consumer` does not implement
/// `TilingInterface`, if `loops` is empty, or if the insert-slice-like
/// operation producing one of the fusable operands cannot be found. If no
/// operand of `consumer` is defined by the outermost loop there is nothing to
/// fuse, and a result with no tiled operands and no tiled ops is returned.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer,
MutableArrayRef<LoopLikeOpInterface> loops);

/// Method to lower an `op` that implements the `TilingInterface` to
/// loops/scalars.
FailureOr<SmallVector<scf::ForOp>>
Expand Down
223 changes: 174 additions & 49 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ static LogicalResult addInitOperandsToLoopNest(
for (auto [outerLoop, innerLoop] :
llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
// Again assume that all the outer loops are scf.for operations.
auto outerForLoop = cast<scf::ForOp>(outerLoop);
auto outerForLoop = cast<scf::ForOp>(outerLoop.getOperation());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks unrelated to the PR itself; we can consider reverting it.

auto outerLoopYield =
cast<scf::YieldOp>(outerForLoop.getBody()->getTerminator());
SmallVector<Value> newYields =
Expand Down Expand Up @@ -2184,61 +2184,24 @@ cloneAsInsertSlices(RewriterBase &rewriter,
return clonedSlices;
}

/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlices(
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) {
if (candidateSlices.empty()) {
return rewriter.notifyMatchFailure(
rewriter.getUnknownLoc(),
"no candidate slices provided for consumer fusion");
}
// Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case.
if (loops.empty()) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"cannot call tile and fuse consumer with an empty loop nest");
}
static FailureOr<scf::SCFFuseConsumerOfSliceResult>
tileAndFuseConsumerOfSlicesImpl(RewriterBase &rewriter, Operation *consumerOp,
ArrayRef<OpOperand *> consumerOpOperands,
ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "expected loops to be not empty");

if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
llvm::all_of(candidateSlices,
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
// 1. Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(loops.front(), consumerOp, false))) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"candidates slices need to be all `tensor.extract_slice`s or "
"`tensor.parallel_insert_slice`s");
}

// 1. Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
SmallVector<OpOperand *> consumerOpOperands;
Operation *consumerOp;
{
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSlices.front(),
"could not fetch consumer to fuse");
}
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
consumerOp = consumerOpOperands.front()->getOwner();
loops.front(), "the first user of loop should not dominate any define "
"of consumer operand(s)");
}

LoopLikeOpInterface outerMostLoop = loops.front();
LoopLikeOpInterface innerMostLoop = loops.back();

// Check assumption for loop with `reorderOperations` disabled.
if (failed(checkAssumptionForLoop(outerMostLoop, consumerOp, false))) {
return rewriter.notifyMatchFailure(
outerMostLoop, "the first user of loop should not dominate any define "
"of consumer operand(s)");
}

OpBuilder::InsertionGuard g(rewriter);

// 2. Check consumer is not using scf loop's output as init.
auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
if (!dstOp)
Expand Down Expand Up @@ -2428,11 +2391,173 @@ mlir::scf::tileAndFuseConsumerOfSlices(
llvm::map_to_vector(operandNumbers, [&](unsigned operandNum) {
return &tileAndFuseResult->tiledOps[0]->getOpOperand(operandNum);
});
auto consumerOpOperandsVec = llvm::to_vector(consumerOpOperands);
return scf::SCFFuseConsumerOfSliceResult{
std::move(consumerOpOperands), std::move(tiledAndFusedOpOperands),
std::move(consumerOpOperandsVec), std::move(tiledAndFusedOpOperands),
std::move(tileAndFuseResult->tiledOps)};
}

/// Fuses the consumer of the loop results produced through `candidateSlices`
/// into the loop nest `loops`, computing the slice of the consumer in-place
/// for the scf loop. This entry point validates its inputs, looks up the
/// untiled consumer operands with `getUntiledConsumerOperandsFromSlices`, and
/// forwards the actual fusion to `tileAndFuseConsumerOfSlicesImpl`.
/// Fails if `candidateSlices` or `loops` is empty, if the slices are not all
/// of one insert-slice kind, or if the consumer cannot be fetched.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumerOfSlices(
RewriterBase &rewriter, ArrayRef<Operation *> candidateSlices,
MutableArrayRef<LoopLikeOpInterface> loops) {
// There is no slice to anchor the diagnostic on, so an unknown location is
// used for this failure.
if (candidateSlices.empty()) {
return rewriter.notifyMatchFailure(
rewriter.getUnknownLoc(),
"no candidate slices provided for consumer fusion");
}
// Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case.
if (loops.empty()) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"cannot call tile and fuse consumer with an empty loop nest");
}

// The candidates must be homogeneous: either every slice is a
// `tensor.insert_slice` or every slice is a `tensor.parallel_insert_slice`.
// NOTE(review): the message below names `tensor.extract_slice`, while the
// check is on `tensor::InsertSliceOp`; it probably should read
// `tensor.insert_slice` - confirm and fix the string.
if (!(llvm::all_of(candidateSlices, llvm::IsaPred<tensor::InsertSliceOp>) ||
llvm::all_of(candidateSlices,
llvm::IsaPred<tensor::ParallelInsertSliceOp>))) {
return rewriter.notifyMatchFailure(
candidateSlices.front(),
"candidates slices need to be all `tensor.extract_slice`s or "
"`tensor.parallel_insert_slice`s");
}

// Get the consumer of scf.for for the result yielded by
// tensor.insert_slice/parallel_insert_slice.
SmallVector<OpOperand *> consumerOpOperands;
Operation *consumerOp;
{
FailureOr<SmallVector<OpOperand *>> maybeConsumerOpOperand =
getUntiledConsumerOperandsFromSlices(rewriter, candidateSlices, loops);
if (failed(maybeConsumerOpOperand)) {
return rewriter.notifyMatchFailure(candidateSlices.front(),
"could not fetch consumer to fuse");
}
std::swap(consumerOpOperands, maybeConsumerOpOperand.value());
// The consumer op is taken to be the owner of the first operand.
// NOTE(review): assumes the returned operand list is non-empty - confirm.
consumerOp = consumerOpOperands.front()->getOwner();
}
Comment on lines +2432 to +2441
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd remove the {} region. I usually use it when I need local variables that don't pollute other phases. Given that it is a small function and we don't have such need, let's drop {}?


return tileAndFuseConsumerOfSlicesImpl(
rewriter, consumerOp, consumerOpOperands, candidateSlices, loops);
}

/// For a given `result` of a `forallOp` return the
/// `tensor.parallel_insert_slice` op (or combining op) that is used to
/// construct this result. Returns `std::nullopt` when `result` is not owned
/// by `forallOp`, or when the block argument tied to `result` does not have
/// exactly one combining op.
static std::optional<Operation *>
getProducingParallelInsertSlice(scf::ForallOp forallOp, OpResult result) {
if (result.getOwner() != forallOp)
return std::nullopt;
// Map the result to its tied body block argument, then collect the ops that
// combine into that argument.
BlockArgument bbArg = forallOp.getTiedBlockArgument(result);
SmallVector<Operation *> combiningOps = forallOp.getCombiningOps(bbArg);
// If the number of combining ops is not 1, then this is unexpected. Return
// nullopt.
if (combiningOps.size() != 1) {
return std::nullopt;
}
Comment on lines +2458 to +2460
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style nit: drop braces.

return combiningOps[0];
}

/// For a given result of the loop nest that is a tiled loop nest, return the
/// insert slice-like op that is used for consumer fusion.
///
/// If the outermost loop is an `scf.forall`, the lookup is delegated to
/// `getProducingParallelInsertSlice`. Otherwise every loop in `loops` is
/// required to be an `scf.for`: `result` is followed through the value
/// yielded by each outer loop to the matching result of the next inner loop,
/// and the `tensor.insert_slice` that defines the value yielded by the
/// innermost loop is returned. Returns `std::nullopt` whenever the loop nest
/// does not have this structure. `loops` must not be empty (asserted).
std::optional<Operation *>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::optional<Operation *>
static std::optional<Operation *>
getProducingInsertSliceLikeOp(OpResult result,
ArrayRef<LoopLikeOpInterface> loops) {
assert(!loops.empty() && "Expected loops to be not empty");
LoopLikeOpInterface outermostLoop = loops.front();

// A loop nest tiled with `scf.forall` is expected to consist of exactly one
// loop (asserted below).
if (auto forallOp = dyn_cast<scf::ForallOp>(outermostLoop.getOperation())) {
Comment on lines +2470 to +2472
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outermostLoop is only used by this check. I think we don't need the variable for readability. I could be biased, but it is straight-forward to assume the loops order is starting from outermost loop.

In any case, I'd drop the blank line because they belong to the same code snippet -- you declare outermostLoop variable and check it immediately.

assert(loops.size() == 1 &&
"expected only a single loop when tiling using scf.forall");
return getProducingParallelInsertSlice(forallOp, result);
}
// Assume that the loop nest is a nested `scf.for` that is created through
// tiling and retrieve the `tensor.insert_slice` operation used to construct
// the result.
while (loops.size() != 1) {
LoopLikeOpInterface loop = loops.front();
if (result.getOwner() != loop)
return std::nullopt;
auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
if (!forOp)
return std::nullopt;
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
OpResult innerForResult =
dyn_cast<OpResult>(yieldOp.getOperand(result.getResultNumber()));
Comment on lines +2488 to +2489
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style nit: auto, because dyn_cast already spells the type.

if (!innerForResult)
return std::nullopt;
result = innerForResult;
loops = loops.drop_front();
}
LoopLikeOpInterface loop = loops.front();
if (result.getOwner() != loop)
return std::nullopt;
auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
if (!forOp)
return std::nullopt;
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
auto insertSliceOp = yieldOp.getOperand(result.getResultNumber())
.getDefiningOp<tensor::InsertSliceOp>();
if (!insertSliceOp)
return std::nullopt;
return insertSliceOp;
}

/// Fuses `user` into the loop nest `loops`. The operands of `user` that are
/// defined by the outermost loop are collected, the insert-slice-like op
/// producing each of them is looked up with `getProducingInsertSliceLikeOp`,
/// and the fusion itself is delegated to `tileAndFuseConsumerOfSlicesImpl`.
/// Fails if `user` does not implement `TilingInterface`, if `loops` is empty,
/// or if a producing slice cannot be found for one of the operands. When no
/// operand of `user` is defined by the outermost loop, a result with no tiled
/// operands and no tiled ops is returned instead of a failure.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
mlir::scf::tileAndFuseConsumer(RewriterBase &rewriter, Operation *user,
MutableArrayRef<LoopLikeOpInterface> loops) {
Comment on lines +2509 to +2511
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This mismatches the declaration that uses consumer instead of user.

FailureOr<scf::SCFFuseConsumerOfSliceResult> tileAndFuseConsumer(RewriterBase &rewriter, Operation *consumer, MutableArrayRef<LoopLikeOpInterface> loops); 
// Only handle users that implement the `TilingInterface`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove the comment because below error message already spells it out.

if (!isa<TilingInterface>(user)) {
return rewriter.notifyMatchFailure(
user, "unhandled user that does not implement TilingInterface");
}

// Return if `loops` is empty, return an error for now. Caller is expected
// to handle this case.
if (loops.empty()) {
return rewriter.notifyMatchFailure(
user, "cannot call tile and fuse consumer with an empty loop nest");
}

LoopLikeOpInterface outermostLoop = loops.front();

// Collect the operands of the user that come from the outermost loop of the
// loop nest.
SmallVector<OpOperand *> consumerFusableOperands;
for (OpOperand &opOperand : user->getOpOperands()) {
if (opOperand.get().getDefiningOp() == outermostLoop) {
consumerFusableOperands.push_back(&opOperand);
}
}

// Nothing to fuse. Just return an empty set.
if (consumerFusableOperands.empty()) {
return mlir::scf::SCFFuseConsumerOfSliceResult{consumerFusableOperands,
SmallVector<OpOperand *>{},
SmallVector<Operation *>{}};
}

// Collect the relevant tensor.insert_slice/tensor.parallel_insert_slices
// for fusion.
SmallVector<Operation *> candidateSlices;
candidateSlices.reserve(consumerFusableOperands.size());
for (OpOperand *opOperand : consumerFusableOperands) {
// The operands collected above have the outermost loop as their defining
// op, so their values are expected to be OpResults and the cast below is
// expected to succeed.
std::optional<Operation *> slice =
getProducingInsertSliceLikeOp(cast<OpResult>(opOperand->get()), loops);
// NOTE(review): `couldnt` in the message below is a misspelling; it is a
// runtime string, so it is left unchanged here.
if (!slice) {
return rewriter.notifyMatchFailure(
user,
"couldnt find producing insert-slice like operation for operand");
}
candidateSlices.push_back(slice.value());
}
// One candidate slice was recorded per fusable operand, in the same order.
return tileAndFuseConsumerOfSlicesImpl(
rewriter, user, consumerFusableOperands, candidateSlices, loops);
}

//===----------------------------------------------------------------------===//
// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ module {
// Fuse the consumer operation into the tiled loop.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
transform.test.fuse_consumer %slice_op in (%forall_op)
transform.test.fuse_consumer_using_slice %slice_op in (%forall_op)
: (!transform.op<"tensor.parallel_insert_slice">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
Expand Down Expand Up @@ -231,7 +231,7 @@ module {
// Fuse the consumer operation into the tiled loop.
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
: (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
// Note that we cannot apply transform.test.fuse_consumer here because the extract_slice
// Note that we cannot apply transform.test.fuse_consumer_using_slice here because the extract_slice
// is not a qualified consumer operation. Forcing this will yield "could not fetch consumer
// to fuse" error.
transform.yield
Expand Down
Loading
Loading