Skip to content

Commit 3a3a524

Browse files
authored
[RISCV][NFC] Avoid iteration and division while selecting SHXADD instructions (#158851)
Should improve compilation time.
1 parent 3e023f7 commit 3a3a524

File tree

3 files changed

+104
-78
lines changed

3 files changed

+104
-78
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 71 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -16498,43 +16498,60 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1649816498
SDValue X = N->getOperand(0);
1649916499

1650016500
if (Subtarget.hasShlAdd(3)) {
16501-
for (uint64_t Divisor : {3, 5, 9}) {
16502-
if (MulAmt % Divisor != 0)
16503-
continue;
16504-
uint64_t MulAmt2 = MulAmt / Divisor;
16505-
// 3/5/9 * 2^N -> shl (shXadd X, X), N
16506-
if (isPowerOf2_64(MulAmt2)) {
16507-
SDLoc DL(N);
16508-
SDValue X = N->getOperand(0);
16509-
// Put the shift first if we can fold a zext into the
16510-
// shift forming a slli.uw.
16511-
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
16512-
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
16513-
SDValue Shl = DAG.getNode(ISD::SHL, DL, VT, X,
16514-
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
16515-
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
16516-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT),
16517-
Shl);
16518-
}
16519-
// Otherwise, put rhe shl second so that it can fold with following
16520-
// instructions (e.g. sext or add).
16521-
SDValue Mul359 =
16522-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16523-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
16524-
return DAG.getNode(ISD::SHL, DL, VT, Mul359,
16525-
DAG.getConstant(Log2_64(MulAmt2), DL, VT));
16526-
}
16527-
16528-
// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
16529-
if (MulAmt2 == 3 || MulAmt2 == 5 || MulAmt2 == 9) {
16530-
SDLoc DL(N);
16531-
SDValue Mul359 =
16532-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16533-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
16534-
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
16535-
DAG.getConstant(Log2_64(MulAmt2 - 1), DL, VT),
16536-
Mul359);
16501+
int Shift;
16502+
if (int ShXAmount = isShifted359(MulAmt, Shift)) {
16503+
// 3/5/9 * 2^N -> shl (shXadd X, X), N
16504+
SDLoc DL(N);
16505+
SDValue X = N->getOperand(0);
16506+
// Put the shift first if we can fold a zext into the shift forming
16507+
// a slli.uw.
16508+
if (X.getOpcode() == ISD::AND && isa<ConstantSDNode>(X.getOperand(1)) &&
16509+
X.getConstantOperandVal(1) == UINT64_C(0xffffffff)) {
16510+
SDValue Shl =
16511+
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(Shift, DL, VT));
16512+
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Shl,
16513+
DAG.getConstant(ShXAmount, DL, VT), Shl);
1653716514
}
16515+
// Otherwise, put the shl second so that it can fold with following
16516+
// instructions (e.g. sext or add).
16517+
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16518+
DAG.getConstant(ShXAmount, DL, VT), X);
16519+
return DAG.getNode(ISD::SHL, DL, VT, Mul359,
16520+
DAG.getConstant(Shift, DL, VT));
16521+
}
16522+
16523+
// 3/5/9 * 3/5/9 -> shXadd (shYadd X, X), (shYadd X, X)
16524+
int ShX;
16525+
int ShY;
16526+
switch (MulAmt) {
16527+
case 3 * 5:
16528+
ShY = 1;
16529+
ShX = 2;
16530+
break;
16531+
case 3 * 9:
16532+
ShY = 1;
16533+
ShX = 3;
16534+
break;
16535+
case 5 * 5:
16536+
ShX = ShY = 2;
16537+
break;
16538+
case 5 * 9:
16539+
ShY = 2;
16540+
ShX = 3;
16541+
break;
16542+
case 9 * 9:
16543+
ShX = ShY = 3;
16544+
break;
16545+
default:
16546+
ShX = ShY = 0;
16547+
break;
16548+
}
16549+
if (ShX) {
16550+
SDLoc DL(N);
16551+
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16552+
DAG.getConstant(ShY, DL, VT), X);
16553+
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
16554+
DAG.getConstant(ShX, DL, VT), Mul359);
1653816555
}
1653916556

1654016557
// If this is a power 2 + 2/4/8, we can use a shift followed by a single
@@ -16557,26 +16574,22 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1655716574
// variants we could implement. e.g.
1655816575
// (2^(1,2,3) * 3,5,9 + 1) << C2
1655916576
// 2^(C1>3) * 3,5,9 +/- 1
16560-
for (uint64_t Divisor : {3, 5, 9}) {
16561-
uint64_t C = MulAmt - 1;
16562-
if (C <= Divisor)
16563-
continue;
16564-
unsigned TZ = llvm::countr_zero(C);
16565-
if ((C >> TZ) == Divisor && (TZ == 1 || TZ == 2 || TZ == 3)) {
16577+
if (int ShXAmount = isShifted359(MulAmt - 1, Shift)) {
16578+
assert(Shift != 0 && "MulAmt=4,6,10 handled before");
16579+
if (Shift <= 3) {
1656616580
SDLoc DL(N);
16567-
SDValue Mul359 =
16568-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16569-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
16581+
SDValue Mul359 = DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16582+
DAG.getConstant(ShXAmount, DL, VT), X);
1657016583
return DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359,
16571-
DAG.getConstant(TZ, DL, VT), X);
16584+
DAG.getConstant(Shift, DL, VT), X);
1657216585
}
1657316586
}
1657416587

1657516588
// 2^n + 2/4/8 + 1 -> (add (shl X, C1), (shXadd X, X))
1657616589
if (MulAmt > 2 && isPowerOf2_64((MulAmt - 1) & (MulAmt - 2))) {
1657716590
unsigned ScaleShift = llvm::countr_zero(MulAmt - 1);
1657816591
if (ScaleShift >= 1 && ScaleShift < 4) {
16579-
unsigned ShiftAmt = Log2_64(((MulAmt - 1) & (MulAmt - 2)));
16592+
unsigned ShiftAmt = llvm::countr_zero((MulAmt - 1) & (MulAmt - 2));
1658016593
SDLoc DL(N);
1658116594
SDValue Shift1 =
1658216595
DAG.getNode(ISD::SHL, DL, VT, X, DAG.getConstant(ShiftAmt, DL, VT));
@@ -16589,7 +16602,7 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1658916602
// 2^N - 3/5/9 --> (sub (shl X, C1), (shXadd X, X))
1659016603
for (uint64_t Offset : {3, 5, 9}) {
1659116604
if (isPowerOf2_64(MulAmt + Offset)) {
16592-
unsigned ShAmt = Log2_64(MulAmt + Offset);
16605+
unsigned ShAmt = llvm::countr_zero(MulAmt + Offset);
1659316606
if (ShAmt >= VT.getSizeInBits())
1659416607
continue;
1659516608
SDLoc DL(N);
@@ -16608,21 +16621,16 @@ static SDValue expandMul(SDNode *N, SelectionDAG &DAG,
1660816621
uint64_t MulAmt2 = MulAmt / Divisor;
1660916622
// 3/5/9 * 3/5/9 * 2^N - In particular, this covers multiples
1661016623
// of 25 which happen to be quite common.
16611-
for (uint64_t Divisor2 : {3, 5, 9}) {
16612-
if (MulAmt2 % Divisor2 != 0)
16613-
continue;
16614-
uint64_t MulAmt3 = MulAmt2 / Divisor2;
16615-
if (isPowerOf2_64(MulAmt3)) {
16616-
SDLoc DL(N);
16617-
SDValue Mul359A =
16618-
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16619-
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
16620-
SDValue Mul359B = DAG.getNode(
16621-
RISCVISD::SHL_ADD, DL, VT, Mul359A,
16622-
DAG.getConstant(Log2_64(Divisor2 - 1), DL, VT), Mul359A);
16623-
return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
16624-
DAG.getConstant(Log2_64(MulAmt3), DL, VT));
16625-
}
16624+
if (int ShBAmount = isShifted359(MulAmt2, Shift)) {
16625+
SDLoc DL(N);
16626+
SDValue Mul359A =
16627+
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, X,
16628+
DAG.getConstant(Log2_64(Divisor - 1), DL, VT), X);
16629+
SDValue Mul359B =
16630+
DAG.getNode(RISCVISD::SHL_ADD, DL, VT, Mul359A,
16631+
DAG.getConstant(ShBAmount, DL, VT), Mul359A);
16632+
return DAG.getNode(ISD::SHL, DL, VT, Mul359B,
16633+
DAG.getConstant(Shift, DL, VT));
1662616634
}
1662716635
}
1662816636
}

llvm/lib/Target/RISCV/RISCVInstrInfo.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4586,24 +4586,23 @@ void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
45864586
.addReg(DestReg, RegState::Kill)
45874587
.addImm(ShiftAmount)
45884588
.setMIFlag(Flag);
4589-
} else if (STI.hasShlAdd(3) &&
4590-
((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
4591-
(Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
4592-
(Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
4589+
} else if (int ShXAmount, ShiftAmount;
4590+
STI.hasShlAdd(3) &&
4591+
(ShXAmount = isShifted359(Amount, ShiftAmount)) != 0) {
45934592
// We can use Zba SHXADD+SLLI instructions for multiply in some cases.
45944593
unsigned Opc;
4595-
uint32_t ShiftAmount;
4596-
if (Amount % 9 == 0) {
4597-
Opc = RISCV::SH3ADD;
4598-
ShiftAmount = Log2_64(Amount / 9);
4599-
} else if (Amount % 5 == 0) {
4600-
Opc = RISCV::SH2ADD;
4601-
ShiftAmount = Log2_64(Amount / 5);
4602-
} else if (Amount % 3 == 0) {
4594+
switch (ShXAmount) {
4595+
case 1:
46034596
Opc = RISCV::SH1ADD;
4604-
ShiftAmount = Log2_64(Amount / 3);
4605-
} else {
4606-
llvm_unreachable("implied by if-clause");
4597+
break;
4598+
case 2:
4599+
Opc = RISCV::SH2ADD;
4600+
break;
4601+
case 3:
4602+
Opc = RISCV::SH3ADD;
4603+
break;
4604+
default:
4605+
llvm_unreachable("unexpected result of isShifted359");
46074606
}
46084607
if (ShiftAmount)
46094608
BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)

llvm/lib/Target/RISCV/RISCVInstrInfo.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,25 @@
2525

2626
namespace llvm {
2727

28+
// If Value is of the form C1<<C2, where C1 = 3, 5 or 9,
29+
// returns log2(C1 - 1) and assigns Shift = C2.
30+
// Otherwise, returns 0.
31+
template <typename T> int isShifted359(T Value, int &Shift) {
32+
if (Value == 0)
33+
return 0;
34+
Shift = llvm::countr_zero(Value);
35+
switch (Value >> Shift) {
36+
case 3:
37+
return 1;
38+
case 5:
39+
return 2;
40+
case 9:
41+
return 3;
42+
default:
43+
return 0;
44+
}
45+
}
46+
2847
class RISCVSubtarget;
2948

3049
static const MachineMemOperand::Flags MONontemporalBit0 =

0 commit comments

Comments
 (0)