Skip to content

Conversation

ashermancinelli
Copy link
Contributor

Previously, when expanding an FPowI operation with a negative exponent, the base was inverted before performing the sequence of multiplications. This led to discrepancies between the results of LLVM's pow intrinsic folding and those produced by the math dialect.

See compiler-rt's version, which performs the inversion at the end of the calculation: compiler-rt/lib/builtins/powidf2.c

Previously, when expanding an FPowI operation with a negative exponent, the base was inverted *before* performing the sequence of multiplications. This led to discrepancies between the results of LLVM's pow intrinsic folding and those produced by the math dialect. See compiler-rt's version, which performs the inversion at the end of the calculation: compiler-rt/lib/builtins/powidf2.c
@llvmbot
Copy link
Member

llvmbot commented Apr 15, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Asher Mancinelli (ashermancinelli)

Changes

Previously, when expanding an FPowI operation with a negative exponent, the base was inverted before performing the sequence of multiplications. This led to discrepancies between the results of LLVM's pow intrinsic folding and those produced by the math dialect.

See compiler-rt's version, which performs the inversion at the end of the calculation: compiler-rt/lib/builtins/powidf2.c


Full diff: https://github.com/llvm/llvm-project/pull/135735.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp (+5-5)
  • (modified) mlir/test/Dialect/Math/algebraic-simplification.mlir (+24-24)
diff --git a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp index dcace489673f0..13e2a4b5541b2 100644 --- a/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp +++ b/mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp @@ -197,11 +197,6 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( if (exponentValue > exponentThreshold) return failure(); - // Inverse the base for negative exponent, i.e. for - // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. - if (exponentIsNegative) - base = rewriter.create<DivOpTy>(loc, bcast(one), base); - Value result = base; // Transform to naive sequence of multiplications: // * For positive exponent case replace: @@ -215,6 +210,11 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite( for (unsigned i = 1; i < exponentValue; ++i) result = rewriter.create<MulOpTy>(loc, result, base); + // Inverse the base for negative exponent, i.e. for + // `[fi]powi(x, negative_exponent)` set `x` to `1 / x`. 
+ if (exponentIsNegative) + result = rewriter.create<DivOpTy>(loc, bcast(one), result); + rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/test/Dialect/Math/algebraic-simplification.mlir b/mlir/test/Dialect/Math/algebraic-simplification.mlir index a97ecc52a17e9..e0e2b9853a2a1 100644 --- a/mlir/test/Dialect/Math/algebraic-simplification.mlir +++ b/mlir/test/Dialect/Math/algebraic-simplification.mlir @@ -135,11 +135,11 @@ func.func @ipowi_exp_two(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi32 // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1> : vector<4xi32> // CHECK: %[[SCALAR0:.*]] = arith.muli %[[ARG0]], %[[ARG0]] // CHECK: %[[VECTOR0:.*]] = arith.muli %[[ARG1]], %[[ARG1]] - // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]] - // CHECK: %[[SMUL:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]] - // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[ARG1]] - // CHECK: %[[VMUL:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]] - // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]] + // CHECK: %[[SMUL:.*]] = arith.muli %[[ARG0]], %[[ARG0]] + // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL]] + // CHECK: %[[VMUL:.*]] = arith.muli %[[ARG1]], %[[ARG1]] + // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]] %c1 = arith.constant 2 : i32 %v1 = arith.constant dense <2> : vector<4xi32> %0 = math.ipowi %arg0, %c1 : i32 @@ -162,13 +162,13 @@ func.func @ipowi_exp_three(%arg0: i32, %arg1: vector<4xi32>) -> (i32, vector<4xi // CHECK: %[[SCALAR0:.*]] = arith.muli %[[SMUL0]], %[[ARG0]] // CHECK: %[[VMUL0:.*]] = arith.muli %[[ARG1]], %[[ARG1]] // CHECK: %[[VECTOR0:.*]] = arith.muli %[[VMUL0]], %[[ARG1]] - // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[ARG0]] - // CHECK: %[[SMUL1:.*]] = arith.muli %[[SCALAR1]], %[[SCALAR1]] - // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[SCALAR1]] - // CHECK: %[[VECTOR1:.*]] = arith.divsi 
%[[CST_V]], %[[ARG1]] - // CHECK: %[[VMUL1:.*]] = arith.muli %[[VECTOR1]], %[[VECTOR1]] - // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[VECTOR1]] - // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]] + // CHECK: %[[SMUL1:.*]] = arith.muli %[[ARG0]], %[[ARG0]] + // CHECK: %[[SMUL2:.*]] = arith.muli %[[SMUL1]], %[[ARG0]] + // CHECK: %[[SCALAR1:.*]] = arith.divsi %[[CST_S]], %[[SMUL2]] + // CHECK: %[[VMUL1:.*]] = arith.muli %[[ARG1]], %[[ARG1]] + // CHECK: %[[VMUL2:.*]] = arith.muli %[[VMUL1]], %[[ARG1]] + // CHECK: %[[VECTOR1:.*]] = arith.divsi %[[CST_V]], %[[VMUL2]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]] %c1 = arith.constant 3 : i32 %v1 = arith.constant dense <3> : vector<4xi32> %0 = math.ipowi %arg0, %c1 : i32 @@ -225,11 +225,11 @@ func.func @fpowi_exp_two(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf32 // CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32> // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] - // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]] - // CHECK: %[[SMUL:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]] - // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]] - // CHECK: %[[VMUL:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]] - // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL]], %[[VMUL]] + // CHECK: %[[SMUL:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] + // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL]] + // CHECK: %[[VMUL:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] + // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]] %c1 = arith.constant 2 : i32 %v1 = arith.constant dense <2> : vector<4xi32> %0 = math.fpowi %arg0, %c1 : f32, i32 @@ -252,13 +252,13 @@ func.func @fpowi_exp_three(%arg0: f32, %arg1: vector<4xf32>) -> (f32, vector<4xf // CHECK: %[[SCALAR0:.*]] = arith.mulf %[[SMUL0]], 
%[[ARG0]] // CHECK: %[[VMUL0:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] // CHECK: %[[VECTOR0:.*]] = arith.mulf %[[VMUL0]], %[[ARG1]] - // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[ARG0]] - // CHECK: %[[SMUL1:.*]] = arith.mulf %[[SCALAR1]], %[[SCALAR1]] - // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[SCALAR1]] - // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[ARG1]] - // CHECK: %[[VMUL1:.*]] = arith.mulf %[[VECTOR1]], %[[VECTOR1]] - // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[VECTOR1]] - // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SMUL2]], %[[VMUL2]] + // CHECK: %[[SMUL1:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] + // CHECK: %[[SMUL2:.*]] = arith.mulf %[[SMUL1]], %[[ARG0]] + // CHECK: %[[SCALAR1:.*]] = arith.divf %[[CST_S]], %[[SMUL2]] + // CHECK: %[[VMUL1:.*]] = arith.mulf %[[ARG1]], %[[ARG1]] + // CHECK: %[[VMUL2:.*]] = arith.mulf %[[VMUL1]], %[[ARG1]] + // CHECK: %[[VECTOR1:.*]] = arith.divf %[[CST_V]], %[[VMUL2]] + // CHECK: return %[[SCALAR0]], %[[VECTOR0]], %[[SCALAR1]], %[[VECTOR1]] %c1 = arith.constant 3 : i32 %v1 = arith.constant dense <3> : vector<4xi32> %0 = math.fpowi %arg0, %c1 : f32, i32 
Copy link
Contributor

@vzakhari vzakhari left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Asher!

@ashermancinelli ashermancinelli merged commit 9ab2dea into llvm:main Apr 15, 2025
13 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

3 participants