- Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][arith] Add mulf(x, 0) -> 0 to mulf folder #161395
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Fold `mulf(x, 0) -> 0`. Updated the yield_constant_loop test in mlir/test/Dialect/SCF/loop-pipelining.mlir to workaround [TODO](https://github.com/llvm/llvm-project/blob/main/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp#L163) in TestSCFUtils.cpp
| @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-arith Author: Xiang Li (python3kgae) ChangesFold Updated the yield_constant_loop test in mlir/test/Dialect/SCF/loop-pipelining.mlir Full diff: https://github.com/llvm/llvm-project/pull/161395.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3a98df8..676297f56ac0f 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1281,6 +1281,9 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { // mulf(x, 1) -> x if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); + // mulf(x, 0) -> 0 + if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat())) + return getRhs(); return constFoldBinaryOp<FloatAttr>( adaptor.getOperands(), diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index ca3de3a2d7703..4c72a1bb27b01 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2216,6 +2216,16 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) { return %2 : f32 } +// CHECK-LABEL: @test_mulf2( +func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) { + // CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-NEXT: return %[[C0]], %[[C0]] + %c0 = arith.constant 0.0 : f32 + %0 = arith.mulf %arg0, %c0 : f32 + %1 = arith.mulf %c0, %arg1 : f32 + return %0, %1 : f32, f32 +} + // ----- // CHECK-LABEL: @test_divf( diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir index 86af637fc05d7..11dc55c7ebb17 100644 --- a/mlir/test/Dialect/SCF/loop-pipelining.mlir +++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir @@ -930,7 +930,7 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST10:.*]] = arith.constant 1.000000e+01 : f32 // CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 // Prologue: // CHECK: %[[L0:.*]] = memref.load %[[A]][%[[C0]]] : memref<?xf32> @@ -938,15 +938,15 @@ func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: i // CHECK-NEXT: %[[L1:.*]]:2 = scf.for %[[IV:.*]] = %[[C0]] to %[[C3]] // CHECK-SAME: step %[[C1]] iter_args(%[[ARG0:.*]] = %[[CST2]], %[[ARG1:.*]] = %[[L0]]) -> (f32, f32) { // CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[ARG1]], %[[ARG0]] : f32 -// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST0]] : f32 +// CHECK-NEXT: %[[MUL0:.*]] = arith.mulf %[[ADD0]], %[[CST10]] : f32 // CHECK-NEXT: memref.store %[[MUL0]], %[[A]][%[[IV]]] : memref<?xf32> // CHECK-NEXT: %[[IV1:.*]] = arith.addi %[[IV]], %[[C1]] : index // CHECK-NEXT: %[[L2:.*]] = memref.load %[[A]][%[[IV1]]] : memref<?xf32> -// CHECK-NEXT: scf.yield %[[CST0]], %[[L2]] : f32 +// CHECK-NEXT: scf.yield %[[CST10]], %[[L2]] : f32 // CHECK-NEXT: } // Epilogue: -// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST0]] : f32 -// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST0]] : f32 +// CHECK-NEXT: %[[ADD1:.*]] = arith.addf %[[L1]]#1, %[[CST10]] : f32 +// CHECK-NEXT: %[[MUL1:.*]] = arith.mulf %[[ADD1]], %[[CST10]] : f32 // CHECK-NEXT: memref.store %[[MUL1]], %[[A]][%[[C3]]] : memref<?xf32> // CHECK-NEXT: return %[[L1]]#0 : f32 @@ -954,7 +954,7 @@ func.func @yield_constant_loop(%A: memref<?xf32>) -> f32 { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index - %cf0 = arith.constant 0.0 : f32 + %cf0 = arith.constant 10.0 : f32 %cf2 = arith.constant 2.0 : f32 %r = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%arg0 = %cf2) -> f32 { %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 3 } : memref<?xf32> |
| // CHECK-LABEL: @test_mulf2( | ||
| func.func @test_mulf2(%arg0 : f32, %arg1 : f32) -> (f32, f32) { | ||
| // CHECK-NEXT: %[[C0:.+]] = arith.constant 0.000000e+00 : f32 | ||
| // CHECK-NEXT: return %[[C0]], %[[C0]] | ||
| %c0 = arith.constant 0.0 : f32 | ||
| %0 = arith.mulf %arg0, %c0 : f32 | ||
| %1 = arith.mulf %c0, %arg1 : f32 | ||
| return %0, %1 : f32, f32 | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's not correct for Nan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also does not preserve the sign of the operand.
We could do all this with fast-math flags.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed by fold NaN before 0.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's still not correct, the value may dynamically be Nan and not be a constant Nan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see.
Updated with fast-math flags.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is nnan/ninf enough to lose the sign of the input?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so :( .
Will FastMathFlags::nsz cover it or we'll have to go FastMathFlags::fast?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added FastMathFlags::nsz.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since you match anyZeroFloat, maybe also add a test case with -0.0? Looks good otherwise.
Done. |
| if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan | | ||
| arith::FastMathFlags::nsz)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doesn't it need also ninf? inf * 0 -> Nan
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to check this with Alive: https://alive2.llvm.org/ce/z/wvNkdy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's because nnan applies to the result as well:
nnan
No NaNs - Allow optimizations to assume the arguments and result are not NaN.
Fold `mulf(x, 0) -> 0` when (nnan | nsz)
Fold
mulf(x, 0) -> 0when (nnan | ninf)