- Notifications
You must be signed in to change notification settings - Fork 15k
[X86] Narrow BT/BTC/BTR/BTS compare + RMW patterns on very large integers #165540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e3c0892 0d0f720 0903c72 e33cc86 6fedca8 78ee615 File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | @@ -53480,6 +53480,80 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG, | |
| return SDValue(); | ||
| } | ||
| | ||
| // Look for a RMW operation that only touches one bit of a larger than legal | ||
| // type and fold it to a BTC/BTR/BTS pattern acting on a single i32 sub value. | ||
| static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL, | ||
| SelectionDAG &DAG, | ||
| const X86Subtarget &Subtarget) { | ||
| using namespace SDPatternMatch; | ||
| | ||
| // Only handle normal stores and its chain was a matching normal load. | ||
| auto *Ld = dyn_cast<LoadSDNode>(St->getChain()); | ||
| if (!ISD::isNormalStore(St) || !St->isSimple() || !Ld || | ||
| !ISD::isNormalLoad(Ld) || !Ld->isSimple() || | ||
| Ld->getBasePtr() != St->getBasePtr() || | ||
| Ld->getOffset() != St->getOffset()) | ||
| return SDValue(); | ||
| | ||
| SDValue LoadVal(Ld, 0); | ||
| SDValue StoredVal = St->getValue(); | ||
| EVT VT = StoredVal.getValueType(); | ||
| | ||
| // Only narrow larger than legal scalar integers. | ||
| if (!VT.isScalarInteger() || | ||
| VT.getSizeInBits() <= (Subtarget.is64Bit() ? 64 : 32)) | ||
| return SDValue(); | ||
| | ||
| // BTR: X & ~(1 << ShAmt) | ||
| // BTS: X | (1 << ShAmt) | ||
| // BTC: X ^ (1 << ShAmt) | ||
| SDValue ShAmt; | ||
| if (!StoredVal.hasOneUse() || | ||
| !(sd_match(StoredVal, m_And(m_Specific(LoadVal), | ||
| m_Not(m_Shl(m_One(), m_Value(ShAmt))))) || | ||
| sd_match(StoredVal, | ||
| m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) || | ||
| sd_match(StoredVal, | ||
| m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))))) | ||
| return SDValue(); | ||
| | ||
| // Ensure the shift amount is in bounds. | ||
| KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); | ||
| if (KnownAmt.getMaxValue().uge(VT.getSizeInBits())) | ||
| return SDValue(); | ||
| | ||
| // Split the shift into an alignment shift that moves the active i32 block to | ||
| // the bottom bits for truncation and a modulo shift that can act on the i32. | ||
| EVT AmtVT = ShAmt.getValueType(); | ||
| SDValue AlignAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, | ||
| DAG.getSignedConstant(-32LL, DL, AmtVT)); | ||
| SDValue ModuloAmt = | ||
| DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT)); | ||
| | ||
| // Compute the byte offset for the i32 block that is changed by the RMW. | ||
| // combineTruncate will adjust the load for us in a similar way. | ||
| EVT PtrVT = St->getBasePtr().getValueType(); | ||
| SDValue PtrBitOfs = DAG.getZExtOrTrunc(AlignAmt, DL, PtrVT); | ||
| SDValue PtrByteOfs = DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs, | ||
| DAG.getShiftAmountConstant(3, PtrVT, DL)); | ||
| SDValue NewPtr = DAG.getMemBasePlusOffset(St->getBasePtr(), PtrByteOfs, DL, | ||
| SDNodeFlags::NoUnsignedWrap); | ||
| | ||
| // Reconstruct the BTC/BTR/BTS pattern for the i32 block and store. | ||
| SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt); | ||
| X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); | ||
| | ||
| SDValue Mask = | ||
| DAG.getNode(ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32), | ||
| DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8)); | ||
| if (StoredVal.getOpcode() == ISD::AND) | ||
| Mask = DAG.getNOT(DL, Mask, MVT::i32); | ||
| | ||
| SDValue Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask); | ||
| return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(), | ||
| Align(), St->getMemOperand()->getFlags()); | ||
| } | ||
| | ||
| static SDValue combineStore(SDNode *N, SelectionDAG &DAG, | ||
| TargetLowering::DAGCombinerInfo &DCI, | ||
| const X86Subtarget &Subtarget) { | ||
| | @@ -53706,6 +53780,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG, | |
| } | ||
| } | ||
| | ||
| if (SDValue R = narrowBitOpRMW(St, dl, DAG, Subtarget)) | ||
| return R; | ||
| | ||
| // Convert store(cmov(load(p), x, CC), p) to cstore(x, p, CC) | ||
| // store(cmov(x, load(p), CC), p) to cstore(x, p, InvertCC) | ||
| if ((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) && | ||
| | @@ -54660,8 +54737,9 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG, | |
| // truncation, see if we can convert the shift into a pointer offset instead. | ||
| // Limit this to normal (non-ext) scalar integer loads. | ||
| if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL && | ||
| Src.hasOneUse() && Src.getOperand(0).hasOneUse() && | ||
| ISD::isNormalLoad(Src.getOperand(0).getNode())) { | ||
| Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) && | ||
| (Src.getOperand(0).hasOneUse() || | ||
| !DAG.getTargetLoweringInfo().isOperationLegal(ISD::LOAD, SrcVT))) { | ||
| auto *Ld = cast<LoadSDNode>(Src.getOperand(0)); | ||
| if (Ld->isSimple() && VT.isByteSized() && | ||
| isPowerOf2_64(VT.getSizeInBits())) { | ||
| | @@ -56459,6 +56537,7 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC, | |
| static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, | ||
| TargetLowering::DAGCombinerInfo &DCI, | ||
| const X86Subtarget &Subtarget) { | ||
| using namespace SDPatternMatch; | ||
| const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get(); | ||
| const SDValue LHS = N->getOperand(0); | ||
| const SDValue RHS = N->getOperand(1); | ||
| | @@ -56517,6 +56596,37 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG, | |
| if (SDValue AndN = MatchAndCmpEq(RHS, LHS)) | ||
| return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC); | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it better to split to a different patch given the currect large effects in the tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm happy to split this into separate BT and BTC/BTR/BTS patches - however there are some codegen regressions, but given the size of the current codegen is that a problem? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, the concerns are the size of test case. It's good to not split if there are regressions. | ||
| | ||
| // If we're performing a bit test on a larger than legal type, attempt | ||
| // to (aligned) shift down the value to the bottom 32-bits and then | ||
| // perform the bittest on the i32 value. | ||
| // ICMP_ZERO(AND(X,SHL(1,IDX))) | ||
| // --> ICMP_ZERO(AND(TRUNC(SRL(X,AND(IDX,-32))),SHL(1,AND(IDX,31)))) | ||
| if (isNullConstant(RHS) && | ||
| OpVT.getScalarSizeInBits() > (Subtarget.is64Bit() ? 64 : 32)) { | ||
| SDValue X, ShAmt; | ||
| if (sd_match(LHS, m_OneUse(m_And(m_Value(X), | ||
| m_Shl(m_One(), m_Value(ShAmt)))))) { | ||
| // Only attempt this if the shift amount is known to be in bounds. | ||
| KnownBits KnownAmt = DAG.computeKnownBits(ShAmt); | ||
| if (KnownAmt.getMaxValue().ult(OpVT.getScalarSizeInBits())) { | ||
| EVT AmtVT = ShAmt.getValueType(); | ||
| SDValue AlignAmt = | ||
| DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, | ||
| DAG.getSignedConstant(-32LL, DL, AmtVT)); | ||
| SDValue ModuloAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, | ||
| DAG.getConstant(31, DL, AmtVT)); | ||
| SDValue Mask = DAG.getNode( | ||
| ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32), | ||
| DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8)); | ||
| X = DAG.getNode(ISD::SRL, DL, OpVT, X, AlignAmt); | ||
| X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X); | ||
| X = DAG.getNode(ISD::AND, DL, MVT::i32, X, Mask); | ||
| return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, MVT::i32), | ||
| CC); | ||
| } | ||
| } | ||
| } | ||
| | ||
| // cmpeq(trunc(x),C) --> cmpeq(x,C) | ||
| // cmpne(trunc(x),C) --> cmpne(x,C) | ||
| // iff x upper bits are zero. | ||
| | ||
Uh oh!
There was an error while loading. Please reload this page.