
Commit e77ef45

[AArch64][GlobalISel] Improve lowering of vector fp16 fptrunc and fpext
1 parent 45495b5 commit e77ef45

15 files changed: +592 -712 lines
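
The change teaches GlobalISel to legalise and lower vector conversions between f64 and f16 lanes without scalarising. As a rough illustration (not taken from the commit), loops like the C++ sketch below, once auto-vectorised, are the kind of source that ends up as the vector G_FPTRUNC f64 -> f16 and G_FPEXT f16 -> f64 operations handled by the diffs that follow; the _Float16 type is a compiler extension and is assumed to be available.

#include <cstddef>

// Hypothetical example, not part of the commit: after vectorisation these
// loops become <N x double> -> <N x half> fptrunc and <N x half> ->
// <N x double> fpext, the shapes this commit legalises and lowers.
void narrow(const double *in, _Float16 *out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i)
    out[i] = static_cast<_Float16>(in[i]); // vector fptrunc f64 -> f16
}

void widen(const _Float16 *in, double *out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i)
    out[i] = static_cast<double>(in[i]); // vector fpext f16 -> f64
}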

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 8 additions & 1 deletion
@@ -333,6 +333,13 @@ def combine_mul_cmlt : GICombineRule<
   (apply [{ applyCombineMulCMLT(*${root}, MRI, B, ${matchinfo}); }])
 >;
 
+def lower_fptrunc_fptrunc: GICombineRule<
+  (defs root:$root),
+  (match (wip_match_opcode G_FPTRUNC):$root,
+         [{ return matchFpTruncFpTrunc(*${root}, MRI); }]),
+  (apply [{ applyFpTruncFpTrunc(*${root}, MRI, B); }])
+>;
+
 // Post-legalization combines which should happen at all optimization levels.
 // (E.g. ones that facilitate matching for the selector) For example, matching
 // pseudos.
@@ -341,7 +348,7 @@ def AArch64PostLegalizerLowering
     [shuffle_vector_lowering, vashr_vlshr_imm,
      icmp_lowering, build_vector_lowering,
      lower_vector_fcmp, form_truncstore, fconstant_to_constant,
-     vector_sext_inreg_to_shift,
+     vector_sext_inreg_to_shift, lower_fptrunc_fptrunc,
      unmerge_ext_to_unmerge, lower_mulv2s64,
      vector_unmerge_lowering, insertelt_nonconst,
      unmerge_duplanes]> {

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp

Lines changed: 60 additions & 2 deletions
@@ -21,6 +21,7 @@
 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
 #include "llvm/CodeGen/GlobalISel/Utils.h"
 #include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/TargetOpcodes.h"
 #include "llvm/IR/DerivedTypes.h"
@@ -817,14 +818,31 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
       .legalFor(
           {{s16, s32}, {s16, s64}, {s32, s64}, {v4s16, v4s32}, {v2s32, v2s64}})
       .libcallFor({{s16, s128}, {s32, s128}, {s64, s128}})
-      .clampNumElements(0, v4s16, v4s16)
-      .clampNumElements(0, v2s32, v2s32)
+      .moreElementsToNextPow2(1)
+      .customIf([](const LegalityQuery &Q) {
+        LLT DstTy = Q.Types[0];
+        LLT SrcTy = Q.Types[1];
+        return SrcTy.isFixedVector() && DstTy.isFixedVector() &&
+               SrcTy.getScalarSizeInBits() == 64 &&
+               DstTy.getScalarSizeInBits() == 16;
+      })
+      // Clamp based on input
+      .clampNumElements(1, v4s32, v4s32)
+      .clampNumElements(1, v2s64, v2s64)
       .scalarize(0);
 
   getActionDefinitionsBuilder(G_FPEXT)
       .legalFor(
          {{s32, s16}, {s64, s16}, {s64, s32}, {v4s32, v4s16}, {v2s64, v2s32}})
       .libcallFor({{s128, s64}, {s128, s32}, {s128, s16}})
+      .moreElementsToNextPow2(0)
+      .customIf([](const LegalityQuery &Q) {
+        LLT DstTy = Q.Types[0];
+        LLT SrcTy = Q.Types[1];
+        return SrcTy.isVector() && DstTy.isVector() &&
+               SrcTy.getScalarSizeInBits() == 16 &&
+               DstTy.getScalarSizeInBits() == 64;
+      })
       .clampNumElements(0, v4s32, v4s32)
       .clampNumElements(0, v2s64, v2s64)
       .scalarize(0);
@@ -1472,6 +1490,12 @@ bool AArch64LegalizerInfo::legalizeCustom(
     return legalizeICMP(MI, MRI, MIRBuilder);
   case TargetOpcode::G_BITCAST:
     return legalizeBitcast(MI, Helper);
+  case TargetOpcode::G_FPEXT:
+    // In order to vectorise f16 to f64 properly, we need to use f32 as an
+    // intermediary
+    return legalizeViaF32(MI, MIRBuilder, MRI, TargetOpcode::G_FPEXT);
+  case TargetOpcode::G_FPTRUNC:
+    return legalizeViaF32(MI, MIRBuilder, MRI, TargetOpcode::G_FPTRUNC);
   }
 
   llvm_unreachable("expected switch to return");
@@ -2396,3 +2420,37 @@ bool AArch64LegalizerInfo::legalizePrefetch(MachineInstr &MI,
   MI.eraseFromParent();
   return true;
 }
+
+bool AArch64LegalizerInfo::legalizeViaF32(MachineInstr &MI,
+                                          MachineIRBuilder &MIRBuilder,
+                                          MachineRegisterInfo &MRI,
+                                          unsigned Opcode) const {
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+  LLT DstTy = MRI.getType(Dst);
+  LLT SrcTy = MRI.getType(Src);
+
+  LLT MidTy = LLT::fixed_vector(SrcTy.getNumElements(), LLT::scalar(32));
+
+  MachineInstrBuilder Mid;
+  MachineInstrBuilder Fin;
+  MIRBuilder.setInstrAndDebugLoc(MI);
+  switch (Opcode) {
+  default:
+    return false;
+  case TargetOpcode::G_FPEXT: {
+    Mid = MIRBuilder.buildFPExt(MidTy, Src);
+    Fin = MIRBuilder.buildFPExt(DstTy, Mid.getReg(0));
+    break;
+  }
+  case TargetOpcode::G_FPTRUNC: {
+    Mid = MIRBuilder.buildFPTrunc(MidTy, Src);
+    Fin = MIRBuilder.buildFPTrunc(DstTy, Mid.getReg(0));
+    break;
+  }
+  }
+
+  MRI.replaceRegWith(Dst, Fin.getReg(0));
+  MI.eraseFromParent();
+  return true;
+}
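
A note on the f32 intermediary introduced by legalizeViaF32 above: for the extend direction the extra hop costs nothing in accuracy, since every f16 value is exactly representable as an f32 and every f32 as an f64, so widening in two steps gives the same result as widening in one. The truncating direction is not exact in general, which is why the post-legalizer combine further down rewrites the resulting FPTRUNC(FPTRUNC) pair. A minimal sketch of that reasoning, not from the commit and assuming _Float16 support:

// Widening through float is exact: f16 values are a subset of f32 values,
// and f32 values are a subset of f64 values, so the two-step conversion
// matches the direct one for every input.
double widen_via_f32(_Float16 h) {
  return static_cast<double>(static_cast<float>(h));
}
double widen_direct(_Float16 h) { return static_cast<double>(h); }

// Narrowing through float is not exact in general; see the double-rounding
// example after the post-legalizer lowering diff below.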

llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.h

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,8 @@ class AArch64LegalizerInfo : public LegalizerInfo {
   bool legalizeDynStackAlloc(MachineInstr &MI, LegalizerHelper &Helper) const;
   bool legalizePrefetch(MachineInstr &MI, LegalizerHelper &Helper) const;
   bool legalizeBitcast(MachineInstr &MI, LegalizerHelper &Helper) const;
+  bool legalizeViaF32(MachineInstr &MI, MachineIRBuilder &MIRBuilder,
+                      MachineRegisterInfo &MRI, unsigned Opcode) const;
   const AArch64Subtarget *ST;
 };
 } // End llvm namespace.

llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp

Lines changed: 194 additions & 0 deletions
@@ -901,6 +901,200 @@ unsigned getCmpOperandFoldingProfit(Register CmpOp, MachineRegisterInfo &MRI) {
   return 0;
 }
 
+// Helper function for matchFpTruncFpTrunc.
+// Checks that the given definition belongs to an FPTRUNC and that the source is
+// not an integer, as no rounding is necessary due to the range of values
+bool checkTruncSrc(MachineRegisterInfo &MRI, MachineInstr *MaybeFpTrunc) {
+  if (!MaybeFpTrunc || MaybeFpTrunc->getOpcode() != TargetOpcode::G_FPTRUNC)
+    return false;
+
+  // Check the source is 64 bits as we only want to match a very specific
+  // pattern
+  Register FpTruncSrc = MaybeFpTrunc->getOperand(1).getReg();
+  LLT SrcTy = MRI.getType(FpTruncSrc);
+  if (SrcTy.getScalarSizeInBits() != 64)
+    return false;
+
+  // Need to check the float didn't come from an int as no rounding is
+  // necessary
+  MachineInstr *FpTruncSrcDef = getDefIgnoringCopies(FpTruncSrc, MRI);
+  if (FpTruncSrcDef->getOpcode() == TargetOpcode::G_SITOFP ||
+      FpTruncSrcDef->getOpcode() == TargetOpcode::G_UITOFP)
+    return false;
+
+  return true;
+}
+
+// To avoid double rounding issues we need to lower FPTRUNC(FPTRUNC) to an odd
+// rounding truncate and a normal truncate. When
+// truncating an FP that came from an integer this is not a problem as the range
+// of values is lower in the int
+bool matchFpTruncFpTrunc(MachineInstr &MI, MachineRegisterInfo &MRI) {
+  if (MI.getOpcode() != TargetOpcode::G_FPTRUNC)
+    return false;
+
+  // Check the destination is 16 bits as we only want to match a very specific
+  // pattern
+  Register Dst = MI.getOperand(0).getReg();
+  LLT DstTy = MRI.getType(Dst);
+  if (DstTy.getScalarSizeInBits() != 16)
+    return false;
+
+  Register Src = MI.getOperand(1).getReg();
+
+  MachineInstr *ParentDef = getDefIgnoringCopies(Src, MRI);
+  if (!ParentDef)
+    return false;
+
+  MachineInstr *FpTruncDef;
+  switch (ParentDef->getOpcode()) {
+  default:
+    return false;
+  case TargetOpcode::G_CONCAT_VECTORS: {
+    // Expecting exactly two FPTRUNCs
+    if (ParentDef->getNumOperands() != 3)
+      return false;
+
+    // All operands need to be FPTRUNC
+    for (unsigned OpIdx = 1, NumOperands = ParentDef->getNumOperands();
+         OpIdx != NumOperands; ++OpIdx) {
+      Register FpTruncDst = ParentDef->getOperand(OpIdx).getReg();
+
+      FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+      if (!checkTruncSrc(MRI, FpTruncDef))
+        return false;
+    }
+
+    return true;
+  }
+  // This is to match cases in which vectors are widened to a larger size
+  case TargetOpcode::G_INSERT_VECTOR_ELT: {
+    Register VecExtractDst = ParentDef->getOperand(2).getReg();
+    MachineInstr *VecExtractDef = getDefIgnoringCopies(VecExtractDst, MRI);
+
+    Register FpTruncDst = VecExtractDef->getOperand(1).getReg();
+    FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+    if (!checkTruncSrc(MRI, FpTruncDef))
+      return false;
+    break;
+  }
+  case TargetOpcode::G_FPTRUNC: {
+    Register FpTruncDst = ParentDef->getOperand(1).getReg();
+    FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+    if (!checkTruncSrc(MRI, FpTruncDef))
+      return false;
+    break;
+  }
+  }
+
+  return true;
+}
+
+void applyFpTruncFpTrunc(MachineInstr &MI, MachineRegisterInfo &MRI,
+                         MachineIRBuilder &B) {
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+
+  LLT V2F32 = LLT::fixed_vector(2, LLT::scalar(32));
+  LLT V4F32 = LLT::fixed_vector(4, LLT::scalar(32));
+  LLT V4F16 = LLT::fixed_vector(4, LLT::scalar(16));
+
+  B.setInstrAndDebugLoc(MI);
+
+  MachineInstr *ParentDef = getDefIgnoringCopies(Src, MRI);
+  if (!ParentDef)
+    return;
+
+  switch (ParentDef->getOpcode()) {
+  default:
+    return;
+  case TargetOpcode::G_INSERT_VECTOR_ELT: {
+    Register VecExtractDst = ParentDef->getOperand(2).getReg();
+    MachineInstr *VecExtractDef = getDefIgnoringCopies(VecExtractDst, MRI);
+
+    Register FpTruncDst = VecExtractDef->getOperand(1).getReg();
+    MachineInstr *FpTruncDef = getDefIgnoringCopies(FpTruncDst, MRI);
+
+    Register FpTruncSrc = FpTruncDef->getOperand(1).getReg();
+    MRI.setRegClass(FpTruncSrc, &AArch64::FPR128RegClass);
+
+    Register Fp32 = MRI.createGenericVirtualRegister(V2F32);
+    MRI.setRegClass(Fp32, &AArch64::FPR64RegClass);
+
+    B.buildInstr(AArch64::FCVTXNv2f32, {Fp32}, {FpTruncSrc});
+
+    // Only 4f32 -> 4f16 is legal so we need to mimic that situation
+    Register Fp32Padding = B.buildUndef(V2F32).getReg(0);
+    MRI.setRegClass(Fp32Padding, &AArch64::FPR64RegClass);
+
+    Register Fp32Full = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(Fp32Full, &AArch64::FPR128RegClass);
+    B.buildConcatVectors(Fp32Full, {Fp32, Fp32Padding});
+
+    Register Fp16 = MRI.createGenericVirtualRegister(V4F16);
+    MRI.setRegClass(Fp16, &AArch64::FPR64RegClass);
+    B.buildFPTrunc(Fp16, Fp32Full);
+
+    MRI.replaceRegWith(Dst, Fp16);
+    MI.eraseFromParent();
+    break;
+  }
+  case TargetOpcode::G_CONCAT_VECTORS: {
+    // Get the two FP Truncs that are being concatenated
+    Register FpTrunc1Dst = ParentDef->getOperand(1).getReg();
+    Register FpTrunc2Dst = ParentDef->getOperand(2).getReg();
+
+    MachineInstr *FpTrunc1Def = getDefIgnoringCopies(FpTrunc1Dst, MRI);
+    MachineInstr *FpTrunc2Def = getDefIgnoringCopies(FpTrunc2Dst, MRI);
+
+    // Make the registers 128bit to store the 2 doubles
+    Register LoFp64 = FpTrunc1Def->getOperand(1).getReg();
+    MRI.setRegClass(LoFp64, &AArch64::FPR128RegClass);
+    Register HiFp64 = FpTrunc2Def->getOperand(1).getReg();
+    MRI.setRegClass(HiFp64, &AArch64::FPR128RegClass);
+
+    B.setInstrAndDebugLoc(MI);
+
+    // Convert the lower half
+    Register LoFp32 = MRI.createGenericVirtualRegister(V2F32);
+    MRI.setRegClass(LoFp32, &AArch64::FPR64RegClass);
+    B.buildInstr(AArch64::FCVTXNv2f32, {LoFp32}, {LoFp64});
+
+    // Create a register for the high half to use
+    Register AccUndef = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(AccUndef, &AArch64::FPR128RegClass);
+    B.buildUndef(AccUndef);
+
+    Register Acc = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(Acc, &AArch64::FPR128RegClass);
+    B.buildInstr(TargetOpcode::INSERT_SUBREG)
+        .addDef(Acc)
+        .addUse(AccUndef)
+        .addUse(LoFp32)
+        .addImm(AArch64::dsub);
+
+    // Convert the high half
+    Register AccOut = MRI.createGenericVirtualRegister(V4F32);
+    MRI.setRegClass(AccOut, &AArch64::FPR128RegClass);
+    B.buildInstr(AArch64::FCVTXNv4f32)
+        .addDef(AccOut)
+        .addUse(Acc)
+        .addUse(HiFp64);
+
+    Register Fp16 = MRI.createGenericVirtualRegister(V4F16);
+    MRI.setRegClass(Fp16, &AArch64::FPR64RegClass);
+    B.buildFPTrunc(Fp16, AccOut);
+
+    MRI.replaceRegWith(Dst, Fp16);
+    MI.eraseFromParent();
+    break;
+  }
+  }
+}
+
 /// \returns true if it would be profitable to swap the LHS and RHS of a G_ICMP
 /// instruction \p MI.
 bool trySwapICmpOperands(MachineInstr &MI, MachineRegisterInfo &MRI) {
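
Why the FCVTXN-based lowering above matters: truncating f64 to f16 through f32 with two round-to-nearest-even steps can give a different answer from a single rounding, while FCVTXN's round-to-odd first step keeps enough of the discarded bits for the final f32 to f16 rounding to land on the correct value, which is what applyFpTruncFpTrunc arranges (a round-to-odd narrow followed by an ordinary f32 to f16 truncate). The small standalone program below is not part of the commit and assumes a host compiler with the _Float16 extension; it searches for a case where the two-step conversion disagrees with the direct one.

#include <cstdio>

int main() {
  // Compare a single rounding (double -> half) against two nearest-even
  // roundings (double -> float -> half) for doubles just above 1.0.
  // Any mismatch is a double-rounding case of the kind the round-to-odd
  // FCVTXN step avoids.
  for (unsigned i = 1; i < (1u << 20); ++i) {
    double d = 1.0 + i * 0x1.0p-24;
    _Float16 direct = static_cast<_Float16>(d);                       // one rounding
    _Float16 two_step = static_cast<_Float16>(static_cast<float>(d)); // two roundings
    if (direct != two_step) {
      std::printf("double rounding at d = %a\n", d);
      return 0;
    }
  }
  std::printf("no mismatch in the scanned range\n");
  return 0;
}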

llvm/test/CodeGen/AArch64/GlobalISel/legalizer-info-validation.mir

Lines changed: 4 additions & 4 deletions
@@ -555,11 +555,11 @@
 # DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: G_FPEXT (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
-# DEBUG-NEXT: .. the first uncovered type index: 2, OK
-# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
+# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: G_FPTRUNC (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
-# DEBUG-NEXT: .. the first uncovered type index: 2, OK
-# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
+# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
+# DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: G_FPTOSI (opcode {{[0-9]+}}): 2 type indices, 0 imm indices
 # DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
 # DEBUG-NEXT: .. imm index coverage check SKIPPED: user-defined predicate detected

llvm/test/CodeGen/AArch64/arm64-fp128.ll

Lines changed: 8 additions & 16 deletions
@@ -1197,30 +1197,22 @@ define <2 x half> @vec_round_f16(<2 x fp128> %val) {
 ;
 ; CHECK-GI-LABEL: vec_round_f16:
 ; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    sub sp, sp, #64
-; CHECK-GI-NEXT:    str x30, [sp, #48] // 8-byte Folded Spill
-; CHECK-GI-NEXT:    .cfi_def_cfa_offset 64
+; CHECK-GI-NEXT:    sub sp, sp, #48
+; CHECK-GI-NEXT:    str x30, [sp, #32] // 8-byte Folded Spill
+; CHECK-GI-NEXT:    .cfi_def_cfa_offset 48
 ; CHECK-GI-NEXT:    .cfi_offset w30, -16
-; CHECK-GI-NEXT:    mov v2.d[0], x8
 ; CHECK-GI-NEXT:    str q1, [sp] // 16-byte Folded Spill
-; CHECK-GI-NEXT:    mov v2.d[1], x8
-; CHECK-GI-NEXT:    str q2, [sp, #32] // 16-byte Folded Spill
 ; CHECK-GI-NEXT:    bl __trunctfhf2
 ; CHECK-GI-NEXT:    // kill: def $h0 killed $h0 def $q0
 ; CHECK-GI-NEXT:    str q0, [sp, #16] // 16-byte Folded Spill
 ; CHECK-GI-NEXT:    ldr q0, [sp] // 16-byte Folded Reload
 ; CHECK-GI-NEXT:    bl __trunctfhf2
+; CHECK-GI-NEXT:    ldr q1, [sp, #16] // 16-byte Folded Reload
 ; CHECK-GI-NEXT:    // kill: def $h0 killed $h0 def $q0
-; CHECK-GI-NEXT:    str q0, [sp] // 16-byte Folded Spill
-; CHECK-GI-NEXT:    ldr q0, [sp, #32] // 16-byte Folded Reload
-; CHECK-GI-NEXT:    bl __trunctfhf2
-; CHECK-GI-NEXT:    ldr q0, [sp, #32] // 16-byte Folded Reload
-; CHECK-GI-NEXT:    bl __trunctfhf2
-; CHECK-GI-NEXT:    ldp q1, q0, [sp] // 32-byte Folded Reload
-; CHECK-GI-NEXT:    ldr x30, [sp, #48] // 8-byte Folded Reload
-; CHECK-GI-NEXT:    mov v0.h[1], v1.h[0]
-; CHECK-GI-NEXT:    // kill: def $d0 killed $d0 killed $q0
-; CHECK-GI-NEXT:    add sp, sp, #64
+; CHECK-GI-NEXT:    ldr x30, [sp, #32] // 8-byte Folded Reload
+; CHECK-GI-NEXT:    mov v1.h[1], v0.h[0]
+; CHECK-GI-NEXT:    fmov d0, d1
+; CHECK-GI-NEXT:    add sp, sp, #48
 ; CHECK-GI-NEXT:    ret
   %dst = fptrunc <2 x fp128> %val to <2 x half>
   ret <2 x half> %dst
