Skip to content

Commit a55a720

Browse files
authored
[X86] Narrow BT/BTC/BTR/BTS compare + RMW patterns on very large integers (#165540)
This patch allows us to narrow single bit-test/twiddle operations on larger-than-legal scalar integers so that they operate efficiently on just the i32 sub-integer block actually affected. The BITOP(X,SHL(1,IDX)) patterns are split, with the IDX used to access the specific i32 block as well as the specific bit within that block. BT comparisons are relatively simple, and build on the truncated shifted loads fold from #165266. BTC/BTR/BTS bit twiddling patterns need to match the entire RMW pattern to safely confirm only one block is affected, but a similar approach is taken and creates codegen that should allow us to further merge with matching BT opcodes in a future patch (see #165291). The resulting codegen is notably more efficient than the heavily micro-coded, memory-folded variants of BT/BTC/BTR/BTS. There is still some work to improve the bit insert 'init' patterns included in bittest-big-integer.ll, but I'm expecting this to be a straightforward future extension. Fixes #164225
1 parent 8c8bead commit a55a720

File tree

2 files changed

+1036
-6275
lines changed

2 files changed

+1036
-6275
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 112 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53481,6 +53481,80 @@ static SDValue combineMaskedStore(SDNode *N, SelectionDAG &DAG,
5348153481
return SDValue();
5348253482
}
5348353483

53484+
// Look for a RMW operation that only touches one bit of a larger than legal
53485+
// type and fold it to a BTC/BTR/BTS pattern acting on a single i32 sub value.
53486+
static SDValue narrowBitOpRMW(StoreSDNode *St, const SDLoc &DL,
53487+
SelectionDAG &DAG,
53488+
const X86Subtarget &Subtarget) {
53489+
using namespace SDPatternMatch;
53490+
53491+
// Only handle normal stores and its chain was a matching normal load.
53492+
auto *Ld = dyn_cast<LoadSDNode>(St->getChain());
53493+
if (!ISD::isNormalStore(St) || !St->isSimple() || !Ld ||
53494+
!ISD::isNormalLoad(Ld) || !Ld->isSimple() ||
53495+
Ld->getBasePtr() != St->getBasePtr() ||
53496+
Ld->getOffset() != St->getOffset())
53497+
return SDValue();
53498+
53499+
SDValue LoadVal(Ld, 0);
53500+
SDValue StoredVal = St->getValue();
53501+
EVT VT = StoredVal.getValueType();
53502+
53503+
// Only narrow larger than legal scalar integers.
53504+
if (!VT.isScalarInteger() ||
53505+
VT.getSizeInBits() <= (Subtarget.is64Bit() ? 64 : 32))
53506+
return SDValue();
53507+
53508+
// BTR: X & ~(1 << ShAmt)
53509+
// BTS: X | (1 << ShAmt)
53510+
// BTC: X ^ (1 << ShAmt)
53511+
SDValue ShAmt;
53512+
if (!StoredVal.hasOneUse() ||
53513+
!(sd_match(StoredVal, m_And(m_Specific(LoadVal),
53514+
m_Not(m_Shl(m_One(), m_Value(ShAmt))))) ||
53515+
sd_match(StoredVal,
53516+
m_Or(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt)))) ||
53517+
sd_match(StoredVal,
53518+
m_Xor(m_Specific(LoadVal), m_Shl(m_One(), m_Value(ShAmt))))))
53519+
return SDValue();
53520+
53521+
// Ensure the shift amount is in bounds.
53522+
KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
53523+
if (KnownAmt.getMaxValue().uge(VT.getSizeInBits()))
53524+
return SDValue();
53525+
53526+
// Split the shift into an alignment shift that moves the active i32 block to
53527+
// the bottom bits for truncation and a modulo shift that can act on the i32.
53528+
EVT AmtVT = ShAmt.getValueType();
53529+
SDValue AlignAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
53530+
DAG.getSignedConstant(-32LL, DL, AmtVT));
53531+
SDValue ModuloAmt =
53532+
DAG.getNode(ISD::AND, DL, AmtVT, ShAmt, DAG.getConstant(31, DL, AmtVT));
53533+
53534+
// Compute the byte offset for the i32 block that is changed by the RMW.
53535+
// combineTruncate will adjust the load for us in a similar way.
53536+
EVT PtrVT = St->getBasePtr().getValueType();
53537+
SDValue PtrBitOfs = DAG.getZExtOrTrunc(AlignAmt, DL, PtrVT);
53538+
SDValue PtrByteOfs = DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs,
53539+
DAG.getShiftAmountConstant(3, PtrVT, DL));
53540+
SDValue NewPtr = DAG.getMemBasePlusOffset(St->getBasePtr(), PtrByteOfs, DL,
53541+
SDNodeFlags::NoUnsignedWrap);
53542+
53543+
// Reconstruct the BTC/BTR/BTS pattern for the i32 block and store.
53544+
SDValue X = DAG.getNode(ISD::SRL, DL, VT, LoadVal, AlignAmt);
53545+
X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
53546+
53547+
SDValue Mask =
53548+
DAG.getNode(ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32),
53549+
DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8));
53550+
if (StoredVal.getOpcode() == ISD::AND)
53551+
Mask = DAG.getNOT(DL, Mask, MVT::i32);
53552+
53553+
SDValue Res = DAG.getNode(StoredVal.getOpcode(), DL, MVT::i32, X, Mask);
53554+
return DAG.getStore(St->getChain(), DL, Res, NewPtr, St->getPointerInfo(),
53555+
Align(), St->getMemOperand()->getFlags());
53556+
}
53557+
5348453558
static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
5348553559
TargetLowering::DAGCombinerInfo &DCI,
5348653560
const X86Subtarget &Subtarget) {
@@ -53707,6 +53781,9 @@ static SDValue combineStore(SDNode *N, SelectionDAG &DAG,
5370753781
}
5370853782
}
5370953783

53784+
if (SDValue R = narrowBitOpRMW(St, dl, DAG, Subtarget))
53785+
return R;
53786+
5371053787
// Convert store(cmov(load(p), x, CC), p) to cstore(x, p, CC)
5371153788
// store(cmov(x, load(p), CC), p) to cstore(x, p, InvertCC)
5371253789
if ((VT == MVT::i16 || VT == MVT::i32 || VT == MVT::i64) &&
@@ -54661,8 +54738,9 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
5466154738
// truncation, see if we can convert the shift into a pointer offset instead.
5466254739
// Limit this to normal (non-ext) scalar integer loads.
5466354740
if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL &&
54664-
Src.hasOneUse() && Src.getOperand(0).hasOneUse() &&
54665-
ISD::isNormalLoad(Src.getOperand(0).getNode())) {
54741+
Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) &&
54742+
(Src.getOperand(0).hasOneUse() ||
54743+
!DAG.getTargetLoweringInfo().isOperationLegal(ISD::LOAD, SrcVT))) {
5466654744
auto *Ld = cast<LoadSDNode>(Src.getOperand(0));
5466754745
if (Ld->isSimple() && VT.isByteSized() &&
5466854746
isPowerOf2_64(VT.getSizeInBits())) {
@@ -56460,6 +56538,7 @@ static SDValue combineAVX512SetCCToKMOV(EVT VT, SDValue Op0, ISD::CondCode CC,
5646056538
static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
5646156539
TargetLowering::DAGCombinerInfo &DCI,
5646256540
const X86Subtarget &Subtarget) {
56541+
using namespace SDPatternMatch;
5646356542
const ISD::CondCode CC = cast<CondCodeSDNode>(N->getOperand(2))->get();
5646456543
const SDValue LHS = N->getOperand(0);
5646556544
const SDValue RHS = N->getOperand(1);
@@ -56518,6 +56597,37 @@ static SDValue combineSetCC(SDNode *N, SelectionDAG &DAG,
5651856597
if (SDValue AndN = MatchAndCmpEq(RHS, LHS))
5651956598
return DAG.getSetCC(DL, VT, AndN, DAG.getConstant(0, DL, OpVT), CC);
5652056599

56600+
// If we're performing a bit test on a larger than legal type, attempt
56601+
// to (aligned) shift down the value to the bottom 32-bits and then
56602+
// perform the bittest on the i32 value.
56603+
// ICMP_ZERO(AND(X,SHL(1,IDX)))
56604+
// --> ICMP_ZERO(AND(TRUNC(SRL(X,AND(IDX,-32))),SHL(1,AND(IDX,31))))
56605+
if (isNullConstant(RHS) &&
56606+
OpVT.getScalarSizeInBits() > (Subtarget.is64Bit() ? 64 : 32)) {
56607+
SDValue X, ShAmt;
56608+
if (sd_match(LHS, m_OneUse(m_And(m_Value(X),
56609+
m_Shl(m_One(), m_Value(ShAmt)))))) {
56610+
// Only attempt this if the shift amount is known to be in bounds.
56611+
KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
56612+
if (KnownAmt.getMaxValue().ult(OpVT.getScalarSizeInBits())) {
56613+
EVT AmtVT = ShAmt.getValueType();
56614+
SDValue AlignAmt =
56615+
DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
56616+
DAG.getSignedConstant(-32LL, DL, AmtVT));
56617+
SDValue ModuloAmt = DAG.getNode(ISD::AND, DL, AmtVT, ShAmt,
56618+
DAG.getConstant(31, DL, AmtVT));
56619+
SDValue Mask = DAG.getNode(
56620+
ISD::SHL, DL, MVT::i32, DAG.getConstant(1, DL, MVT::i32),
56621+
DAG.getZExtOrTrunc(ModuloAmt, DL, MVT::i8));
56622+
X = DAG.getNode(ISD::SRL, DL, OpVT, X, AlignAmt);
56623+
X = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, X);
56624+
X = DAG.getNode(ISD::AND, DL, MVT::i32, X, Mask);
56625+
return DAG.getSetCC(DL, VT, X, DAG.getConstant(0, DL, MVT::i32),
56626+
CC);
56627+
}
56628+
}
56629+
}
56630+
5652156631
// cmpeq(trunc(x),C) --> cmpeq(x,C)
5652256632
// cmpne(trunc(x),C) --> cmpne(x,C)
5652356633
// iff x upper bits are zero.

0 commit comments

Comments
 (0)