
Commit e588c7f

[X86] Attempt to fold trunc(srl(load(p),amt)) -> load(p+amt/8) (#165266)
As reported in #164853, we only attempt to reduce shifted loads for constant shift amounts, but we could do more with non-constant values if value tracking can confirm basic alignments.

This patch determines whether a truncated, shifted scalar integer load shifts by a byte-aligned amount and, if so, replaces the non-constant shift amount with a pointer offset instead.

I had hoped to make this a generic DAG fold, but reduceLoadWidth isn't ready to be converted to a KnownBits value tracking mechanism, and other targets don't have complex address math like X86.

Fixes #164853
1 parent 4678f16 commit e588c7f
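To illustrate the pattern, here is a minimal IR sketch (not taken from the patch's test files; the function name and exact shape are hypothetical) of a truncated, variable-shift load that the new combine targets, assuming value tracking can prove the shift amount is a multiple of 8 and small enough that the truncation only reads loaded bits:

; Hypothetical example: masking %idx bounds the bit offset to a multiple of 8
; that is at most 56, which satisfies the KnownBits checks in the patch.
define i8 @shifted_byte_load(ptr %p, i64 %idx) {
  %byteofs = and i64 %idx, 7           ; byte index in [0,7]
  %bitofs = shl i64 %byteofs, 3        ; bit offset: multiple of 8, at most 56
  %wide = load i64, ptr %p, align 8    ; simple, non-extending scalar load
  %shifted = lshr i64 %wide, %bitofs
  %byte = trunc i64 %shifted to i8
  ret i8 %byte
}

With the fold, the backend should be able to select a single byte load from %p plus the masked index instead of a 64-bit load followed by a variable shift and truncation.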


5 files changed (+177, -1618 lines changed)


llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 34 additions & 0 deletions
@@ -54634,6 +54634,7 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
                                const X86Subtarget &Subtarget) {
   EVT VT = N->getValueType(0);
   SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
   SDLoc DL(N);
 
   // Attempt to pre-truncate inputs to arithmetic ops instead.
@@ -54652,6 +54653,39 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
     return V;
 
+  // Fold trunc(srl(load(p),amt)) -> load(p+amt/8)
+  // If we're shifting down byte aligned bit chunks from a larger load for
+  // truncation, see if we can convert the shift into a pointer offset instead.
+  // Limit this to normal (non-ext) scalar integer loads.
+  if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL &&
+      Src.hasOneUse() && Src.getOperand(0).hasOneUse() &&
+      ISD::isNormalLoad(Src.getOperand(0).getNode())) {
+    auto *Ld = cast<LoadSDNode>(Src.getOperand(0));
+    if (Ld->isSimple() && VT.isByteSized() &&
+        isPowerOf2_64(VT.getSizeInBits())) {
+      SDValue ShAmt = Src.getOperand(1);
+      KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
+      // Check the shift amount is byte aligned.
+      // Check the truncation doesn't use any shifted in (zero) top bits.
+      if (KnownAmt.countMinTrailingZeros() >= 3 &&
+          KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() -
+                                     VT.getSizeInBits())) {
+        EVT PtrVT = Ld->getBasePtr().getValueType();
+        SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT);
+        SDValue PtrByteOfs =
+            DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs,
+                        DAG.getShiftAmountConstant(3, PtrVT, DL));
+        SDValue NewPtr = DAG.getMemBasePlusOffset(
+            Ld->getBasePtr(), PtrByteOfs, DL, SDNodeFlags::NoUnsignedWrap);
+        SDValue NewLoad =
+            DAG.getLoad(VT, DL, Ld->getChain(), NewPtr, Ld->getMemOperand());
+        DAG.ReplaceAllUsesOfValueWith(Src.getOperand(0).getValue(1),
+                                      NewLoad.getValue(1));
+        return NewLoad;
+      }
+    }
+  }
+
   // The bitcast source is a direct mmx result.
   // Detect bitcasts between i32 to x86mmx
   if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {

llvm/test/CodeGen/X86/bfloat-calling-conv.ll

Lines changed: 2 additions & 4 deletions
@@ -660,8 +660,7 @@ define <3 x bfloat> @call_ret_v3bf16(ptr %ptr) #0 {
 ; SSE2-LABEL: call_ret_v3bf16:
 ; SSE2: # %bb.0:
 ; SSE2-NEXT: pushq %rax
-; SSE2-NEXT: movl 4(%rdi), %eax
-; SSE2-NEXT: pinsrw $0, %eax, %xmm1
+; SSE2-NEXT: pinsrw $0, 4(%rdi), %xmm1
 ; SSE2-NEXT: movd {{.*#+}} xmm0 = mem[0],zero,zero,zero
 ; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
 ; SSE2-NEXT: callq returns_v3bf16@PLT
@@ -725,8 +724,7 @@ define <3 x bfloat> @call_ret_v3bf16(ptr %ptr) #0 {
 ; AVXNECONVERT-LABEL: call_ret_v3bf16:
 ; AVXNECONVERT: # %bb.0:
 ; AVXNECONVERT-NEXT: pushq %rax
-; AVXNECONVERT-NEXT: movl 4(%rdi), %eax
-; AVXNECONVERT-NEXT: vpinsrw $0, %eax, %xmm0, %xmm0
+; AVXNECONVERT-NEXT: vpinsrw $0, 4(%rdi), %xmm0, %xmm0
 ; AVXNECONVERT-NEXT: vmovss {{.*#+}} xmm1 = mem[0],zero,zero,zero
 ; AVXNECONVERT-NEXT: vinsertps {{.*#+}} xmm0 = xmm1[0],xmm0[0],zero,zero
 ; AVXNECONVERT-NEXT: callq returns_v3bf16@PLT
