[DAGCombine] Propagate truncate to operands #98666
base: main
Conversation
@llvm/pr-subscribers-backend-amdgpu @llvm/pr-subscribers-backend-nvptx

Author: Justin Fargnoli (justinfargnoli)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/98666.diff

2 Files Affected:
```diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 476a532db0a37..26729c7adb020 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -725,7 +725,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
-                       ISD::VSELECT});
+                       ISD::TRUNCATE, ISD::VSELECT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5541,6 +5541,53 @@ static SDValue PerformREMCombine(SDNode *N,
   return SDValue();
 }
 
+// truncate (logic_op x, y) --> logic_op (truncate x), (truncate y)
+// This will reduce register pressure.
+static SDValue PerformTruncCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  if (!DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  SDValue LogicalOp = N->getOperand(0);
+  switch (LogicalOp.getOpcode()) {
+  default:
+    break;
+  case ISD::ADD:
+  case ISD::SUB:
+  case ISD::MUL:
+  case ISD::AND:
+  case ISD::OR:
+  case ISD::XOR: {
+    EVT VT = N->getValueType(0);
+    EVT LogicalVT = LogicalOp.getValueType();
+    if (VT != MVT::i32 || LogicalVT != MVT::i64)
+      break;
+    const TargetLowering &TLI = DCI.DAG.getTargetLoweringInfo();
+    if (!VT.isScalarInteger() &&
+        !TLI.isOperationLegal(LogicalOp.getOpcode(), VT))
+      break;
+    if (!all_of(LogicalOp.getNode()->uses(), [](SDNode *U) {
+          return U->isMachineOpcode()
+                     ? U->getMachineOpcode() == NVPTX::CVT_u32_u64
+                     : U->getOpcode() == ISD::TRUNCATE;
+        }))
+      break;
+
+    SDLoc DL(N);
+    SDValue CVTNone =
+        DCI.DAG.getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32);
+    SDNode *NarrowL = DCI.DAG.getMachineNode(NVPTX::CVT_u32_u64, DL, VT,
+                                             LogicalOp.getOperand(0), CVTNone);
+    SDNode *NarrowR = DCI.DAG.getMachineNode(NVPTX::CVT_u32_u64, DL, VT,
+                                             LogicalOp.getOperand(1), CVTNone);
+    return DCI.DAG.getNode(LogicalOp.getOpcode(), DL, VT, SDValue(NarrowL, 0),
+                           SDValue(NarrowR, 0));
+  }
+  }
+
+  return SDValue();
+}
+
 enum OperandSignedness {
   Signed = 0,
   Unsigned,
@@ -5957,6 +6004,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::UREM:
   case ISD::SREM:
     return PerformREMCombine(N, DCI, OptLevel);
+  case ISD::TRUNCATE:
+    return PerformTruncCombine(N, DCI);
   case ISD::SETCC:
     return PerformSETCCCombine(N, DCI, STI.getSmVersion());
   case ISD::LOAD:
@@ -5974,6 +6023,10 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::VSELECT:
     return PerformVSELECTCombine(N, DCI);
   }
+
+  if (N->isMachineOpcode() && N->getMachineOpcode() == NVPTX::CVT_u32_u64)
+    return PerformTruncCombine(N, DCI);
+
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/NVPTX/combine-truncate.ll b/llvm/test/CodeGen/NVPTX/combine-truncate.ll
new file mode 100644
index 0000000000000..30e415ebe9527
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/combine-truncate.ll
@@ -0,0 +1,90 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+define i32 @trunc(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_param_1];
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;
+; CHECK-NEXT:    or.b32 %r3, %r2, %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT:    ret;
+  %or = or i64 %a, %b
+  %trunc = trunc i64 %or to i32
+  ret i32 %trunc
+}
+
+define i32 @trunc_not(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc_not(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_not_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_not_param_1];
+; CHECK-NEXT:    or.b64 %rd3, %rd1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd3;
+; CHECK-NEXT:    mov.u64 %rd4, 0;
+; CHECK-NEXT:    st.u64 [%rd4], %rd3;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r1;
+; CHECK-NEXT:    ret;
+  %or = or i64 %a, %b
+  %trunc = trunc i64 %or to i32
+  store i64 %or, ptr null
+  ret i32 %trunc
+}
+
+define i32 @trunc_cvt(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc_cvt(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_cvt_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_cvt_param_1];
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd2;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;
+; CHECK-NEXT:    add.s32 %r3, %r2, %r1;
+; CHECK-NEXT:    or.b32 %r4, %r3, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r4;
+; CHECK-NEXT:    ret;
+  %add = add i64 %a, %b
+  %or = or i64 %add, %a
+  %trunc = trunc i64 %or to i32
+  ret i32 %trunc
+}
+
+define i32 @trunc_cvt_not(i64 %a, i64 %b) {
+; CHECK-LABEL: trunc_cvt_not(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-NEXT:    .reg .b64 %rd<5>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [trunc_cvt_not_param_0];
+; CHECK-NEXT:    ld.param.u64 %rd2, [trunc_cvt_not_param_1];
+; CHECK-NEXT:    add.s64 %rd3, %rd1, %rd2;
+; CHECK-NEXT:    mov.u64 %rd4, 0;
+; CHECK-NEXT:    st.u64 [%rd4], %rd3;
+; CHECK-NEXT:    cvt.u32.u64 %r1, %rd3;
+; CHECK-NEXT:    cvt.u32.u64 %r2, %rd1;
+; CHECK-NEXT:    or.b32 %r3, %r1, %r2;
+; CHECK-NEXT:    st.param.b32 [func_retval0+0], %r3;
+; CHECK-NEXT:    ret;
+  %add = add i64 %a, %b
+  store i64 %add, ptr null
+  %or = or i64 %add, %a
+  %trunc = trunc i64 %or to i32
+  ret i32 %trunc
+}
```
The DAGCombiner has the same combine:
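For reference, the generic version lives in DAGCombiner::visitTRUNCATE. Paraphrased from memory (the exact guard conditions in trunk may differ):

```cpp
// trunc (binop x, c) --> binop (trunc x), (trunc c)
// Only fires when one operand is a constant, so truncating it is free
// and no extra truncate node survives.
switch (N0.getOpcode()) {
case ISD::ADD:
case ISD::SUB:
case ISD::MUL:
case ISD::AND:
case ISD::OR:
case ISD::XOR:
  if (!LegalOperations && N0.hasOneUse() &&
      (isConstantOrConstantVector(N0.getOperand(0), true) ||
       isConstantOrConstantVector(N0.getOperand(1), true))) {
    SDLoc SL(N);
    SDValue NarrowL = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(0));
    SDValue NarrowR = DAG.getNode(ISD::TRUNCATE, SL, VT, N0.getOperand(1));
    return DAG.getNode(N0.getOpcode(), SL, VT, NarrowL, NarrowR);
  }
  break;
}
```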
Interesting. Then the question is why it apparently doesn't fire for NVPTX. That may be what needs fixing.
It expects one of the operands of the binop to be constant. This combine is less restrictive.
I see. I guess the assumption is that if one of the arguments is a constant, then truncating it is free, so the combine never increases the number of truncates, and it assumes the logical op costs the same regardless of size. I think adding a new target-specific check here would be appropriate. Considering that register pressure is a pretty common issue for NVPTX, something fairly generic like …
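The comment is cut off at this point in the page. Purely as an illustration of the idea, such a generic target hook might look like the sketch below; the name shouldReduceRegisterPressure is invented for this sketch and is not an existing TargetLowering API:

```cpp
// Invented hook, for illustration only: a register-pressure-bound target
// opts in, letting the generic combine fire even when neither operand of
// the binop is a constant.
virtual bool shouldReduceRegisterPressure(EVT FromVT, EVT ToVT) const {
  return false; // conservative default; NVPTX could return true for i64->i32
}
```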
Your 128-bit mul could become a 32-bit mul after the combine. |
Mul is the odd one out on the list, as it nominally has a different number of valid bits in the operands and the result. I think the optimization for mul could be improved by checking whether UMUL_LOHI is legal and using that with operands truncated to half the size. That's somewhat orthogonal to the simpler choice of whether it's beneficial to use an additional truncate but save on the number of registers used. Let's get this part sorted out first.
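A rough sketch of that UMUL_LOHI idea, under the assumptions that the high halves of both operands are known zero, HalfVT is half the width of the result type VT, and Mul is the wide multiply node:

```cpp
// Do the multiply at half width and reassemble the full product from the
// lo/hi halves that UMUL_LOHI produces.
if (TLI.isOperationLegal(ISD::UMUL_LOHI, HalfVT)) {
  SDValue LHS = DAG.getNode(ISD::TRUNCATE, DL, HalfVT, Mul.getOperand(0));
  SDValue RHS = DAG.getNode(ISD::TRUNCATE, DL, HalfVT, Mul.getOperand(1));
  SDValue LoHi =
      DAG.getNode(ISD::UMUL_LOHI, DL, DAG.getVTList(HalfVT, HalfVT), LHS, RHS);
  return DAG.getNode(ISD::BUILD_PAIR, DL, VT, LoHi.getValue(0),
                     LoHi.getValue(1));
}
```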
✅ With the latest revision this PR passed the C/C++ code formatter.
Artem-B left a comment
Nice. LGTM with a few minor nits.
@justinfargnoli reverse-ping?
```cpp
SDValue TruncatedOp =
    DAG.getNode(N0.getOpcode(), DL, VT, NarrowL, NarrowR);
if (TLI.IsDesirableToPromoteOp(TruncatedOp, SrcVT))
  break;
```
Needing to speculatively create some nodes just to check whether the combine is profitable is unfortunate. Is there a way to avoid this?
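One possible way to avoid it, sketched under the assumption that the profitability question can be rephrased in terms of the opcode and types alone (isDesirableToTruncateBinOp is an invented name, not an existing TargetLowering hook):

```cpp
// Ask the target about (opcode, wide VT, narrow VT) directly instead of
// materializing TruncatedOp just to query IsDesirableToPromoteOp on it.
if (!TLI.isDesirableToTruncateBinOp(N0.getOpcode(), SrcVT, VT))
  break;
```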
```cpp
case ISD::MUL:
case ISD::SETCC:
case ISD::SELECT:
  if (DestVT.getScalarSizeInBits() == 1)
```
getScalarType() == i1
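That nit applied, as a minimal sketch (assuming DestVT is the destination EVT being tested):

```cpp
// Compare the scalar type directly instead of its bit width.
if (DestVT.getScalarType() == MVT::i1)
```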
Propagate ISD::TRUNCATEs to the operands of logical operations if a target is particularly sensitive to register pressure.