- Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][math] Add clampf and clean math ExpandOps API #151153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| @llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Fabian Mora (fabianmcg) ChangesThis patch adds the The reasoning behind adding this operation is that some GPU vendors offer specialized intrinsics for this operation, or subsets of this operation. For example, __saturatef in NVIDIA GPUs, or This patch also removes void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef<StringRef> opMnemonics = {});Patch is 20.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151153.diff 11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index 56370388dea87..cfd8c4b8f11f7 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -352,6 +352,37 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// ClampFOp +//===----------------------------------------------------------------------===// + +def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> { + let summary = "floating point clamping operation"; + let description = [{ + The `clampf` operation takes three operands and returns one result, each of + these is required to be the same type. Operands must be of floating point type + (i.e., scalar, tensor or vector). + + The semantics of the operation are described by: + ``` + clampf(value, min, max) = maxf(minf(value, min), max) + ``` + + Example: + + ```mlir + %d = math.clampf %value to [%min, %max] : f64 + ``` + }]; + let arguments = (ins FloatLike:$value, FloatLike:$min, FloatLike:$max, + DefaultValuedAttr<Arith_FastMathAttr, + "::mlir::arith::FastMathFlags::none">:$fastmath); + let assemblyFormat = [{ + $value `to` ` ` `[` $min `,` $max `]` (`fastmath` `` $fastmath^)? + attr-dict `:` type($result) + }]; +} + //===----------------------------------------------------------------------===// // CopySignOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index c0fe5d3be448a..b3abbf728a3c6 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -23,22 +23,16 @@ class ConversionTarget; class RewritePatternSet; class TypeConverter; -void populateExpandCtlzPattern(RewritePatternSet &patterns); -void populateExpandTanPattern(RewritePatternSet &patterns); -void populateExpandSinhPattern(RewritePatternSet &patterns); -void populateExpandCoshPattern(RewritePatternSet &patterns); -void populateExpandTanhPattern(RewritePatternSet &patterns); -void populateExpandAsinhPattern(RewritePatternSet &patterns); -void populateExpandAcoshPattern(RewritePatternSet &patterns); -void populateExpandAtanhPattern(RewritePatternSet &patterns); -void populateExpandFmaFPattern(RewritePatternSet &patterns); -void populateExpandCeilFPattern(RewritePatternSet &patterns); -void populateExpandExp2FPattern(RewritePatternSet &patterns); -void populateExpandPowFPattern(RewritePatternSet &patterns); -void populateExpandFPowIPattern(RewritePatternSet &patterns); -void populateExpandRoundFPattern(RewritePatternSet &patterns); -void populateExpandRoundEvenPattern(RewritePatternSet &patterns); -void populateExpandRsqrtPattern(RewritePatternSet &patterns); +namespace math { +/// Adds patterns to expand math operations into other more fundamental +/// operations. For example, hyperbolic functions are expanded into expressions +/// using `exp`. If `opMnemonics` is empty then all available patterns will be +/// added, otherwise only the patterns corresponding to ops in `opMnemonics` +/// will be added to the set. +void populateExpansionPatterns(RewritePatternSet &patterns, + ArrayRef<StringRef> opMnemonics = {}); +} // namespace math + void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); struct MathPolynomialApproximationOptions { diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index a84c89020d4f3..4d415aeac8f58 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -44,4 +44,24 @@ def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> { let dependentDialects = ["math::MathDialect", "arith::ArithDialect"]; } +def MathExpandOpsPass : Pass<"math-expand-ops"> { + let summary = "Expand math operations."; + let description = [{ + Expands some math operations into more fundamental operations, allowing them + to be subsequently lowered through these. For example, hyperbolic functions + are transformed into their expanded form containing only `exp` functions. + + The `ops` parameter can be used to apply only a subset of all the + available expansions, these must correspond to the operation mnemonic. + For example, `ops=sinh,acosh` will expand only `math.sinh` and + `math.acosh` operations. If the list is empty, then all expansions are + applied. + }]; + let dependentDialects = ["arith::ArithDialect"]; + let options = [ + ListOption<"opMnemonics", "ops", "std::string", + "Operations to expand."> + ]; +} + #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index e1c0c2410c126..d37a056e8e158 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_dialect_library(MLIRMathTransforms AlgebraicSimplification.cpp - ExpandPatterns.cpp + ExpandOps.cpp ExtendToSupportedTypes.cpp PolynomialApproximation.cpp UpliftToFMA.cpp diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp similarity index 89% rename from mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp rename to mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp index 4a40a3055ed62..cd68039d0d964 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp @@ -13,14 +13,18 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +namespace mlir::math { +#define GEN_PASS_DEF_MATHEXPANDOPSPASS +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + /// Create a float constant. static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b) { @@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, return success(); } -void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { - patterns.add(convertCtlzOp); -} - -void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { - patterns.add(convertSinhOp); -} - -void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { - patterns.add(convertCoshOp); -} - -void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { - patterns.add(convertTanOp); -} - -void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { - patterns.add(convertTanhOp); -} - -void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { - patterns.add(convertAsinhOp); -} - -void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { - patterns.add(convertAcoshOp); -} - -void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { - patterns.add(convertAtanhOp); -} - -void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { - patterns.add(convertFmaFOp); -} - -void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { - patterns.add(convertCeilOp); -} - -void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { - patterns.add(convertExp2fOp); -} - -void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { - patterns.add(convertPowfOp); -} - -void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { - patterns.add(convertFPowIOp); -} - -void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundOp); +// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf` +static LogicalResult convertClampfOp(math::ClampFOp op, + PatternRewriter &rewriter) { + auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(), + op.getMin(), op.getFastmath()); + rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(), + op.getFastmath()); + return success(); } -void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundEvenOp); +void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns, + ArrayRef<StringRef> opMnemonics) { + auto filter = [&](StringRef name) { + // This should be a static assert and `consume_front` take a twine, but none + // is currently possible. TODO: augment `StringRef::consume_front` and make + // `getDialectNamespace` use `std::string_view`. + assert("math" == MathDialect::getDialectNamespace()); + name.consume_front("math."); + return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0); + }; + if (filter(CountLeadingZerosOp::getOperationName())) + patterns.add(convertCtlzOp); + if (filter(SinhOp::getOperationName())) + patterns.add(convertSinhOp); + if (filter(CoshOp::getOperationName())) + patterns.add(convertCoshOp); + if (filter(TanOp::getOperationName())) + patterns.add(convertTanOp); + if (filter(TanhOp::getOperationName())) + patterns.add(convertTanhOp); + if (filter(AsinhOp::getOperationName())) + patterns.add(convertAsinhOp); + if (filter(AcoshOp::getOperationName())) + patterns.add(convertAcoshOp); + if (filter(AtanhOp::getOperationName())) + patterns.add(convertAtanhOp); + if (filter(FmaOp::getOperationName())) + patterns.add(convertFmaFOp); + if (filter(CeilOp::getOperationName())) + patterns.add(convertCeilOp); + if (filter(Exp2Op::getOperationName())) + patterns.add(convertExp2fOp); + if (filter(PowFOp::getOperationName())) + patterns.add(convertPowfOp); + if (filter(FPowIOp::getOperationName())) + patterns.add(convertFPowIOp); + if (filter(RoundOp::getOperationName())) + patterns.add(convertRoundOp); + if (filter(RoundEvenOp::getOperationName())) + patterns.add(convertRoundEvenOp); + if (filter(RsqrtOp::getOperationName())) + patterns.add(convertRsqrtOp); + if (filter(ClampFOp::getOperationName())) + patterns.add(convertClampfOp); } -void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { - patterns.add(convertRsqrtOp); -} +//===----------------------------------------------------------------------===// +// MathExpandOpsPass pass +//===----------------------------------------------------------------------===// +namespace { +struct MathExpandOpsPass final + : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> { + using MathExpandOpsPassBase::MathExpandOpsPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + SmallVector<StringRef> mnemonics = + llvm::to_vector_of<StringRef>(opMnemonics); + math::populateExpansionPatterns(patterns, mnemonics); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 1420acaa40d35..615c607efc3c3 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -1,7 +1,9 @@ -// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s +// RUN: mlir-opt %s --split-input-file -math-expand-ops | FileCheck %s +// RUN: mlir-opt %s --split-input-file -math-expand-ops=ops=tanh,tan | FileCheck %s --check-prefix=CHECK-FILTER // CHECK-LABEL: func @tanh func.func @tanh(%arg: f32) -> f32 { + // CHECK-FILTER-NOT: math.tanh %res = math.tanh %arg : f32 return %res : f32 } @@ -27,6 +29,7 @@ func.func @tanh(%arg: f32) -> f32 { // CHECK-LABEL: func @vector_tanh func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> { // CHECK-NOT: math.tanh + // CHECK-FILTER-NOT: math.tanh %res = math.tanh %arg : vector<4xf32> return %res : vector<4xf32> } @@ -35,6 +38,7 @@ func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> { // CHECK-LABEL: func @tan func.func @tan(%arg: f32) -> f32 { + // CHECK-FILTER-NOT: math.tan %res = math.tan %arg : f32 return %res : f32 } @@ -49,6 +53,7 @@ func.func @tan(%arg: f32) -> f32 { // CHECK-LABEL: func @vector_tan func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> { + // CHECK-FILTER-NOT: math.tan %res = math.tan %arg : vector<4xf32> return %res : vector<4xf32> } @@ -58,6 +63,7 @@ func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> { // ----- func.func @ctlz(%arg: i32) -> i32 { + // CHECK-FILTER: math.ctlz %res = math.ctlz %arg : i32 return %res : i32 } @@ -112,6 +118,7 @@ func.func @ctlz(%arg: i32) -> i32 { // ----- func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> { + // CHECK-FILTER: math.ctlz %res = math.ctlz %arg : vector<4xi32> return %res : vector<4xi32> } @@ -145,6 +152,7 @@ func.func @ceilf_func(%a: f64) -> f64 { // CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]] // CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]] // CHECK-NEXT: return [[ADDF]] + // CHECK-FILTER: math.ceil %ret = math.ceil %a : f64 return %ret : f64 } @@ -158,6 +166,7 @@ func.func @exp2f_func(%a: f64) -> f64 { // CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]] // CHECK: [[EXP:%.+]] = math.exp [[MULF]] // CHECK: return [[EXP]] + // CHECK-FILTER: math.exp2 %ret = math.exp2 %a : f64 return %ret : f64 } @@ -813,3 +822,27 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{ %a = math.rsqrt %arg : tensor<*xf32> return %a: tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func.func @clampf_scalar_op +// CHECK-SAME: (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16) +// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] : f16 +// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] : f16 +// CHECK: return %[[V1]] : f16 + +func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 { + %a = math.clampf %arg to [%min, %max] : f16 + return %a: f16 +} + +// CHECK-LABEL: func.func @clampf_vector_op +// CHECK-SAME: (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>) +// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32> +// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32> +// CHECK: return %[[V1]] : vector<3x4xf32> + +func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{ + %a = math.clampf %arg to [%min, %max] fastmath<fast> : vector<3x4xf32> + return %a: vector<3x4xf32> +} diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir index 8feadedd1860e..cb10fc4397ffc 100644 --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s // RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s // CHECK-LABEL: func @atan( @@ -337,3 +337,16 @@ func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>) math.isnormal %t : tensor<4x?xf32> return } + +// CHECK-LABEL: func @clampf( +func.func @clampf(%av: vector<3x4xf32>, %mv: vector<3x4xf32>, %Mv: vector<3x4xf32>, + %as: f32, %ms: f32, %Ms: f32, + %at: tensor<?xf80>, %mt: tensor<?xf80>, %Mt: tensor<?xf80>) { + // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] fastmath<fast> : vector<3x4xf32> + %rv = math.clampf %av to [%mv, %Mv] fastmath<fast> : vector<3x4xf32> + // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : f32 + %rs = math.clampf %as to [%ms, %Ms] fastmath<none> : f32 + // CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : tensor<?xf80> + %rt = math.clampf %at to [%mt, %Mt] : tensor<?xf80> + return +} diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt index 91e70d1785369..900dff3b5e9f1 100644 --- a/mlir/test/lib/Dialect/Math/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt @@ -1,7 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRMathTestPasses TestAlgebraicSimplification.cpp - TestExpandMath.cpp TestPolynomialApproximation.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp deleted file mode 100644 index efc1acf8bb6cd..0000000000000 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===- TestExpandMath.cpp - Test expand math op into exp form -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains test passes for expanding math operations. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -struct TestExpandMathPass - : public PassWrapper<TestExpandMathPass, OperationPass<>> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestExpandMathPass) - - void runOnOperation() override; - StringRef getArgument() const final { return "test-expand-math"; } - void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert<arith::ArithDialect, scf::SCFDialect, vector::VectorDialect>(); - } - StringRef getDescription() const final { return "Test expanding math"; } -}; -} // namespace - -void TestExpandMathPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - populateExpandCtlzPattern(patterns); - populateExpandExp2FPattern(patterns); - populateExpandTanPattern(patterns); - populateExpandSinhPattern(patterns); - populateExpandCoshPattern(patterns); - populateExpandTanhPattern(patterns); - populateExpandAsinhPattern(patterns); - populateExpandAcoshPattern(patterns); - populateExpandAtanhPattern(patterns); - populateExpandFmaFPattern(patterns); - populateExpandCeilFPattern(patterns); - populateExpandPowFPattern(patterns); - populateExpandFPowIPattern(patterns); - populateExpandRoundFPattern(patterns); - populateExpandRoundEvenPattern(patterns); - populateExpandRsqrtPattern(patterns); - (void)applyPatternsGreedily(getOperation(), std::move(patterns)); -} - -namespace mlir { -namespace test { -void registerTestExpandMathPass() { PassRegistration<TestExpandMathPass>(); } -} // namespace test -} // namespace mlir diff --git a/mlir/test/mlir-runner/test-expand-math-approx.mlir b/mlir/test/mlir-runner/test-expand-math-approx.mlir index b599c9d8435d4..3f9d3f2125e1a 100644 --- a/mlir/test/mlir-runner/test-expand-math-approx.mlir +++ b/mlir/test/mlir-runner/test-expand-math-approx.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(test-expand-math),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(math-expand-ops),convert-vector-to-scf,convert-scf-to-cf,convert-vector-to-llvm,convert-to-llvm,reconcile-unrealized-casts)" \ // RUN: | mlir-runner \ // RUN: -e main -entry-point-result=void -O0 ... [truncated] |
clampf and clean math ExpandOpsclampf and clean math ExpandOps API
nicolasvasilache left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG
This patch adds the
clampfoperation to the math dialect. The semantics op are defined as:The reasoning behind adding this operation is that some GPU vendors offer specialized intrinsics for this operation, or subsets of this operation. For example, __saturatef in NVIDIA GPUs, or
__builtin_amdgcn_fmed3fin AMD GPUs.This patch also removes
test-expand-mathin favor ofmath-expand-ops.Finally, it removes individual expansion population API calls like
populateExpandCoshPatternin favor of: