@python3kgae
Contributor

@python3kgae python3kgae commented Sep 30, 2025

Fold mulf(x, 0) -> 0 when (nnan | ninf)

Fold `mulf(x, 0) -> 0`. Updated the yield_constant_loop test in mlir/test/Dialect/SCF/loop-pipelining.mlir to work around the [TODO](https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp#L163) in TestSCFUtils.cpp.
@llvmbot
Member

llvmbot commented Sep 30, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-arith

Author: Xiang Li (python3kgae)

Changes

Fold mulf(x, 0) -> 0.

Updated the yield_constant_loop test in mlir/test/Dialect/SCF/loop-pipelining.mlir
to work around the TODO in TestSCFUtils.cpp.


Full diff: https://github.com/llvm/llvm-project/pull/161395.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+3)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+10)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+6-6)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 7cfd6d3a98df8..676297f56ac0f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1281,6 +1281,9 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
   // mulf(x, 1) -> x
   if (matchPattern(adaptor.getRhs(), m_OneFloat()))
     return getLhs();
+  // mulf(x, 0) -> 0
+  if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
+    return getRhs();

   return constFoldBinaryOp<FloatAttr>(
       adaptor.getOperands(),
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index ca3de3a2d7703..4c72a1bb27b01 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2216,6 +2216,16 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
   return %2 : f32
 }

+// CHECK-LABEL: @test_mulf2(
+func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
+  // CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
+  // CHECK-NEXT: return %[[C0]], %[[C0]]
+  %c0 = arith.constant 0.0 : f32
+  %0 = arith.mulf %arg0, %c0 : f32
+  %1 = arith.mulf %c0, %arg1 : f32
+  return %0, %1 : f32, f32
+}
+
 // -----

 // CHECK-LABEL: @test_divf(
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 86af637fc05d7..11dc55c7ebb17 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -930,7 +930,7 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST10:.*]] = arith.constant 1.000000e+01 : f32
 // CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
 // Prologue:
 // CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32>
@@ -938,15 +938,15 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i
 // CHECK-NEXT: %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]]
 // CHECK-SAME: step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) {
 // CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32
-// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32
+// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST10]] : f32
 // CHECK-NEXT: memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32>
 // CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index
 // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32>
-// CHECK-NEXT: scf.yield %[[CST0]], %[[L2]] : f32
+// CHECK-NEXT: scf.yield %[[CST10]], %[[L2]] : f32
 // CHECK-NEXT: }
 // Epilogue:
-// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32
-// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32
+// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST10]] : f32
+// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST10]] : f32
 // CHECK-NEXT: memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32>
 // CHECK-NEXT: return %[[L1]]#0 : f32

@@ -954,7 +954,7 @@ func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
-  %cf0 = arith.constant 0.0 : f32
+  %cf0 = arith.constant 10.0 : f32
   %cf2 = arith.constant 2.0 : f32
   %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 {
     %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32>
Comment on lines 2219 to 2227
// CHECK-LABEL: @test_mulf2(
func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) {
// CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: return %[[C0]], %[[C0]]
%c0 = arith.constant 0.0 : f32
%0 = arith.mulf %arg0, %c0 : f32
%1 = arith.mulf %c0, %arg1 : f32
return %0, %1 : f32, f32
}
Contributor

That's not correct for NaN.

Collaborator

Also does not preserve the sign of the operand.

We could do all this with fast-math flags.
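
For illustration, the two objections above (NaN and the sign of the operand) can be reproduced with ordinary floating-point arithmetic. This is a minimal C++ sketch, assuming IEEE 754 `float` semantics and a build without fast-math compiler options; the helper names `mulByZero` and `checkCounterexamples` are hypothetical and not part of MLIR.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Multiply by +0.0f at run time; `volatile` keeps the compiler from
// constant-folding the product.
inline float mulByZero(float x) {
  volatile float zero = 0.0f;
  return x * zero;
}

// Checks the two counterexamples raised in review.
inline void checkCounterexamples() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  // NaN * 0 is NaN, so the product is not 0.
  assert(std::isnan(mulByZero(nan)));
  // -1 * +0 is -0: equal to 0 under ==, but the sign bit is set.
  const float negZero = mulByZero(-1.0f);
  assert(negZero == 0.0f);
  assert(std::signbit(negZero));
  // A positive finite operand yields +0.
  assert(!std::signbit(mulByZero(2.0f)));
}
```

Calling `checkCounterexamples()` from a test `main` should pass silently under those assumptions; an unconditional `x * 0 -> 0` rewrite would change the result in the first two cases.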

Contributor Author

Fixed by folding NaN before 0.

Contributor

That's still not correct: the value may be NaN dynamically rather than a constant NaN.
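
A toy model may make the constant-versus-dynamic distinction concrete. The sketch below is not MLIR code: `MaybeConst`, `foldNanThenZero`, and `evalMul` are hypothetical names, and an operand whose value is only known at run time is modeled as `std::nullopt`. It assumes IEEE 754 `float` semantics without fast-math compiler options.

```cpp
#include <cassert>
#include <cmath>
#include <limits>
#include <optional>

// Toy model (not MLIR API): an operand is either a compile-time constant
// or unknown until run time (std::nullopt).
using MaybeConst = std::optional<float>;

// Hypothetical "fold NaN before 0" folder: returns the folded constant,
// or std::nullopt when it declines to fold.
inline std::optional<float> foldNanThenZero(MaybeConst lhs, MaybeConst rhs) {
  if (lhs && std::isnan(*lhs))
    return *lhs; // constant NaN operand: fold to NaN
  if (rhs && *rhs == 0.0f)
    return *rhs; // mulf(x, 0) -> 0, even when x is unknown
  return std::nullopt;
}

// What the unfolded multiply computes once x is known at run time.
inline float evalMul(float x, float rhs) { return x * rhs; }

// The folder answers 0 for an unknown x, but x may turn out to be NaN.
inline void demonstrateMismatch() {
  const std::optional<float> folded = foldNanThenZero(std::nullopt, 0.0f);
  assert(folded.has_value() && *folded == 0.0f);
  const float nan = std::numeric_limits<float>::quiet_NaN();
  assert(std::isnan(evalMul(nan, 0.0f)));
}
```

The constant-NaN check only helps when the NaN is visible to the folder; for a non-constant operand the folder still returns 0, which is the case the reviewer points out.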

Contributor Author

I see.
Updated with fast-math flags.

Collaborator

Is nnan/ninf enough to lose the sign of the input?

Contributor Author

I don't think so :( .
Will FastMathFlags::nsz cover it, or will we have to go with FastMathFlags::fast?

Contributor Author

Added FastMathFlags::nsz.
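
As a rough illustration of a fold gated on both flags, here is a self-contained C++ sketch. The `Fmf` enum, `containsAll`, and `foldMulByZero` are hypothetical stand-ins written for this example; they only mimic the shape of a "contains all of these bits" check and are not the MLIR `arith::FastMathFlags` API.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Toy stand-in for fast-math flags (hypothetical; not the MLIR enum).
enum class Fmf : std::uint32_t { none = 0, nnan = 1, ninf = 2, nsz = 4 };

constexpr Fmf operator|(Fmf a, Fmf b) {
  return static_cast<Fmf>(static_cast<std::uint32_t>(a) |
                          static_cast<std::uint32_t>(b));
}

// True when every bit of `required` is set in `flags`.
constexpr bool containsAll(Fmf flags, Fmf required) {
  const auto f = static_cast<std::uint32_t>(flags);
  const auto r = static_cast<std::uint32_t>(required);
  return (f & r) == r;
}

// Fold x * rhs to rhs only when rhs is +0 or -0 and both nnan and nsz are
// set; otherwise decline to fold.
inline std::optional<float> foldMulByZero(float rhs, Fmf flags) {
  if (rhs == 0.0f && containsAll(flags, Fmf::nnan | Fmf::nsz))
    return rhs;
  return std::nullopt;
}
```

Requiring all bits, rather than any one of them, matters here: `nnan` alone leaves the sign problem open, and `nsz` alone leaves the NaN problem open.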

@python3kgae python3kgae changed the title [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder [mlir][arith] Add more patterns to mulf folder Sep 30, 2025
@kuhar kuhar self-requested a review September 30, 2025 17:45
@python3kgae python3kgae changed the title [mlir][arith] Add more patterns to mulf folder [mlir][arith] Add mulf(x, 0) -> 0 to mulf folder Sep 30, 2025
Member

@kuhar kuhar left a comment

Since you match anyZeroFloat, maybe also add a test case with -0.0? Looks good otherwise.

@python3kgae
Contributor Author

> Since you match anyZeroFloat, maybe also add a test case with -0.0? Looks good otherwise.

Done.

@python3kgae python3kgae merged commit 2d06374 into llvm:main Oct 2, 2025
9 checks passed
@python3kgae python3kgae deleted the fold_mulf_0 branch October 2, 2025 02:47
Comment on lines +1285 to +1286
if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
arith::FastMathFlags::nsz)) {
Contributor

Doesn't it also need ninf? inf * 0 -> NaN

Member

I tried to check this with Alive: https://alive2.llvm.org/ce/z/wvNkdy

Member

@kuhar kuhar Oct 2, 2025

It's because nnan applies to the result as well:

nnan
No NaNs - Allow optimizations to assume the arguments and result are not NaN.
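
The argument can be spot-checked on a handful of values. The sketch below is a sample-based C++ illustration, not a proof and not a substitute for the Alive check; it assumes IEEE 754 `float` semantics without fast-math compiler options, and the names `inNnanDomain`, `foldAgrees`, and `foldIsSoundOnSamples` are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Under the quoted nnan wording, only inputs where neither the argument
// nor the result is NaN need to be considered.
inline bool inNnanDomain(float x) {
  return !std::isnan(x) && !std::isnan(x * 0.0f);
}

// With nsz the sign of a zero result is ignored, so == (which treats
// +0 and -0 as equal) is the comparison used here.
inline bool foldAgrees(float x) { return x * 0.0f == 0.0f; }

// Returns true if folding to 0 matches the multiply on every sample that
// the nnan assumption leaves in scope.
inline bool foldIsSoundOnSamples() {
  const float inf = std::numeric_limits<float>::infinity();
  const float samples[] = {0.0f, -0.0f, 1.5f, -2.0f, inf, -inf,
                           std::numeric_limits<float>::quiet_NaN()};
  for (float x : samples)
    if (inNnanDomain(x) && !foldAgrees(x))
      return false;
  return true;
}
```

Because `inf * 0` is NaN, infinite inputs fall outside the domain that `nnan` leaves in scope, which is why an extra `ninf` requirement adds nothing for this particular fold.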

mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Oct 3, 2025