Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,32 @@ def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

// Transform op that rewrites a single `scf.forall` payload operation without
// shared outputs into an `scf.parallel` operation. The result list is declared
// variadic, but the C++ `apply` implementation accepts exactly one result
// handle and reports a silenceable failure otherwise.
def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
     DeclareOpInterfaceMethods<TransformOpInterface>]> {
  let summary = "Converts scf.forall into an scf.parallel operation";
  let description = [{
    Converts the `scf.forall` operation pointed to by the given handle into an
    `scf.parallel` operation.

    The operand handle must be associated with exactly one payload operation.

    Loops with outputs are not supported.

    #### Return Modes

    Consumes the operand handle. Produces a silenceable failure if the operand
    is not associated with a single `scf.forall` payload operation.
    Returns a handle to the new `scf.parallel` operation.
    Produces a silenceable failure if another number of resulting handles is
    requested.
  }];
  let arguments = (ins TransformHandleTypeInterface:$target);
  let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);

  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
}

def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForLoopRangeFoldingPass();
/// Creates a pass that converts SCF forall loops to SCF for loops.
std::unique_ptr<Pass> createForallToForLoopPass();

/// Creates a pass that converts SCF forall loops to SCF parallel loops.
/// The pass signals failure when a forall loop cannot be converted, e.g.
/// because it still has shared outputs.
std::unique_ptr<Pass> createForallToParallelLoopPass();

/// Creates a pass which lowers for loops into while loops.
std::unique_ptr<Pass> createForToWhileLoopPass();

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ def SCFForallToForLoop : Pass<"scf-forall-to-for"> {
let constructor = "mlir::createForallToForLoopPass()";
}

// Pass that rewrites every scf.forall nested under the operation it runs on
// into an scf.parallel loop. The pass signals failure for forall loops that
// cannot be converted (e.g. loops with shared outputs).
def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
let summary = "Convert SCF forall loops to SCF parallel loops";
let constructor = "mlir::createForallToParallelLoopPass()";
}

def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
let summary = "Convert SCF for loops to SCF while loops";
let constructor = "mlir::createForToWhileLoopPass()";
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ class WhileOp;
LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
SmallVectorImpl<Operation *> *results = nullptr);

/// Try converting scf.forall into an scf.parallel loop.
/// The conversion is only supported for forall operations with no results.
/// On success the forall op is replaced by the new loop and, if `result` is
/// non-null, `*result` is set to the created scf.parallel op. On failure the
/// IR is left unchanged and `*result` is not written.
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
ParallelOp *result = nullptr);

/// Fuses all adjacent scf.parallel operations with identical bounds and step
/// into one scf.parallel operations. Uses a naive aliasing and dependency
/// analysis.
Expand Down
29 changes: 2 additions & 27 deletions mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
Expand Down Expand Up @@ -688,33 +689,7 @@ IndexSwitchLowering::matchAndRewrite(IndexSwitchOp op,

/// Lowers an scf.forall op to scf.parallel by delegating to the shared
/// scf::forallToParallelLoop helper, so the pattern and the standalone
/// pass/transform op cannot drift apart. The helper reports a match failure
/// (and leaves the IR untouched) for forall ops that still have shared outputs.
LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
                                              PatternRewriter &rewriter) const {
  return scf::forallToParallelLoop(rewriter, forallOp);
}

void mlir::populateSCFToControlFlowConversionPatterns(
Expand Down
44 changes: 44 additions & 0 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,50 @@ transform::ForallToForOp::apply(transform::TransformRewriter &rewriter,
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// ForallToParallelOp
//===----------------------------------------------------------------------===//

/// Applies the transform: the target handle must map to exactly one
/// scf.forall payload op without shared outputs, and exactly one result handle
/// must be requested. On success the result handle is associated with the
/// newly created scf.parallel op; every rejection is reported as a silenceable
/// failure.
DiagnosedSilenceableFailure
transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitSilenceableError() << "expected a single payload op";

  // The handle maps to exactly one op from here on.
  Operation *payloadOp = *payloadOps.begin();
  auto forallOp = dyn_cast<scf::ForallOp>(payloadOp);
  if (!forallOp) {
    DiagnosedSilenceableFailure notForall =
        emitSilenceableError() << "expected the payload to be scf.forall";
    notForall.attachNote(payloadOp->getLoc()) << "payload op";
    return notForall;
  }

  // Shared outputs have no scf.parallel counterpart; reject them up front.
  if (!forallOp.getOutputs().empty())
    return emitSilenceableError()
           << "unsupported shared outputs (didn't bufferize?)";

  // The result list is variadic in ODS, but only one handle can be produced.
  if (getNumResults() != 1) {
    DiagnosedSilenceableFailure wrongResultCount =
        emitSilenceableError()
        << "op expects one result, given " << getNumResults();
    wrongResultCount.attachNote(forallOp.getLoc()) << "payload op";
    return wrongResultCount;
  }

  scf::ParallelOp parallelOp;
  if (failed(scf::forallToParallelLoop(rewriter, forallOp, &parallelOp)))
    return emitSilenceableError() << "failed to convert forall into parallel";

  results.set(cast<OpResult>(getTransformed()[0]), {parallelOp});
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// LoopOutlineOp
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
ForallToFor.cpp
ForallToParallel.cpp
ForToWhile.cpp
LoopCanonicalization.cpp
LoopPipelining.cpp
Expand Down
82 changes: 82 additions & 0 deletions mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Transforms SCF.ForallOp's into SCF.ParallelOp's.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Passes.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
#define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// Rewrites `forallOp` as an scf.parallel loop that reuses its bounds, steps
/// and body; the body terminator is replaced by an scf.reduce. Fails without
/// modifying the IR when the forall op has shared outputs. When `result` is
/// non-null it receives the newly created loop. The rewriter's insertion point
/// is restored on return.
LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
                                              scf::ForallOp forallOp,
                                              scf::ParallelOp *result) {
  OpBuilder::InsertionGuard restoreInsertionPoint(rewriter);
  rewriter.setInsertionPoint(forallOp);

  const bool hasSharedOutputs = !forallOp.getOutputs().empty();
  if (hasSharedOutputs) {
    return rewriter.notifyMatchFailure(
        forallOp,
        "only fully bufferized scf.forall ops can be lowered to scf.parallel");
  }

  // scf.parallel only takes SSA operands, so materialize any static
  // bound/step as a constant index op in front of the loop.
  Location loc = forallOp.getLoc();
  SmallVector<Value> lowerBounds = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedLowerBound());
  SmallVector<Value> upperBounds = getValueOrCreateConstantIndexOp(
      rewriter, loc, forallOp.getMixedUpperBound());
  SmallVector<Value> strides =
      getValueOrCreateConstantIndexOp(rewriter, loc, forallOp.getMixedStep());

  // Build the new loop, drop its auto-generated body and move the forall body
  // into its place.
  auto newLoop =
      rewriter.create<scf::ParallelOp>(loc, lowerBounds, upperBounds, strides);
  Region &newBody = newLoop.getRegion();
  rewriter.eraseBlock(&newBody.front());
  rewriter.inlineRegionBefore(forallOp.getRegion(), newBody, newBody.begin());

  // Swap the moved body's terminator for the one scf.parallel expects.
  Operation *oldTerminator = newBody.front().getTerminator();
  rewriter.setInsertionPointToEnd(&newBody.front());
  rewriter.replaceOpWithNewOp<scf::ReduceOp>(oldTerminator);

  // The forall op is now an empty shell; replace it with the new loop.
  rewriter.replaceOp(forallOp, newLoop);

  if (result != nullptr)
    *result = newLoop;

  return success();
}

namespace {
/// Pass that visits every scf.forall nested under the root operation and
/// rewrites it into an scf.parallel loop. The pass is marked as failed for
/// each loop that cannot be converted.
struct ForallToParallelLoop final
    : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> {
  void runOnOperation() override {
    Operation *root = getOperation();
    IRRewriter rewriter(root->getContext());

    root->walk([this, &rewriter](scf::ForallOp loop) {
      LogicalResult status = scf::forallToParallelLoop(rewriter, loop);
      if (failed(status))
        signalPassFailure();
    });
  }
};
} // namespace

/// Factory for the scf.forall -> scf.parallel conversion pass; returns a fresh
/// instance of the file-local ForallToParallelLoop pass.
std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() {
return std::make_unique<ForallToParallelLoop>();
}
62 changes: 62 additions & 0 deletions mlir/test/Dialect/SCF/forall-to-parallel.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-forall-to-parallel))' -split-input-file | FileCheck %s

// A two-dimensional scf.forall without outputs must become one scf.parallel
// over the same upper bounds, keep its body and end in scf.reduce.
func.func private @callee(%i: index, %j: index)

// CHECK-LABEL: @two_iters
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @two_iters(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
return
}

// -----

// Two sibling scf.forall loops in the same function must both be converted,
// each into its own scf.parallel.
func.func private @callee(%i: index, %j: index)

// CHECK-LABEL: @repeated
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @repeated(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}

// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV3]], %[[IV4]])
// CHECK: scf.reduce
return
}

// -----

// Nested scf.forall loops must be converted at both levels, producing an
// scf.parallel nested inside another scf.parallel.
func.func private @callee(%i: index, %j: index, %k: index, %l: index)

// CHECK-LABEL: @nested
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index, %[[UB3:.+]]: index, %[[UB4:.+]]: index
func.func @nested(%ub1: index, %ub2: index, %ub3: index, %ub4: index) {
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]]) step (%{{.*}}, %{{.*}}) {
// CHECK: scf.parallel (%[[IV3:.+]], %[[IV4:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB3]], %[[UB4]]) step (%{{.*}}, %{{.*}}) {
// CHECK: func.call @callee(%[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]])
// CHECK: scf.reduce
// CHECK: }
// CHECK: scf.reduce
// CHECK: }
scf.forall (%i, %j) in (%ub1, %ub2) {
scf.forall (%k, %l) in (%ub3, %ub4) {
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
}
}
return
}
60 changes: 60 additions & 0 deletions mlir/test/Dialect/SCF/transform-op-forall-to-parallel.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s

// Positive case for the transform op: the handle matches the single
// scf.forall in the payload, which must be rewritten into scf.parallel.
func.func private @callee(%i: index, %j: index)

// CHECK-LABEL: @two_iters
// CHECK-SAME: %[[UB1:.+]]: index, %[[UB2:.+]]: index
func.func @two_iters(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
// CHECK: scf.parallel (%[[IV1:.+]], %[[IV2:.+]]) = (%{{.*}}, %{{.*}}) to (%[[UB1]], %[[UB2]])
// CHECK: func.call @callee(%[[IV1]], %[[IV2]]) : (index, index) -> ()
// CHECK: scf.reduce
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// -----

// Negative case: the match yields two scf.forall payload ops, so the
// transform op must reject the handle with a silenceable error.
func.func private @callee(%i: index, %j: index)

func.func @repeated(%ub1: index, %ub2: index) {
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
scf.forall (%i, %j) in (%ub1, %ub2) {
func.call @callee(%i, %j) : (index, index) -> ()
}
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected a single payload op}}
transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
transform.yield
}
}

// -----

// expected-note @below {{payload op}}
func.func private @callee(%i: index, %j: index)

// Negative case: the handle points at a func.func rather than an scf.forall,
// so the op under test (forall_to_parallel) must reject the payload.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // expected-error @below {{expected the payload to be scf.forall}}
    transform.loop.forall_to_parallel %0 : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }
}