Skip to content

Commit a5bab28

Browse files
authored
[NFC][SPIRV] Move common SPIRV::LinkageType deduction code to a helper in SPIRVUtils (#164248)
There was some repeated code that was used to deduce the SPIRV::LinkageType from a GlobalVariable/Function. At several related parts of the code we also had functions taking 2 parameters: a 'hasLinkage' bool, and a 'LinkageType'. This is error-prone since the later parameter's meaning depends on the first. This patch also merges these two options into a single `std::optional<SPIRV::LinkageType>`.
1 parent e25e43a commit a5bab28

File tree

7 files changed

+46
-44
lines changed

7 files changed

+46
-44
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,9 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
507507
static Register buildBuiltinVariableLoad(
508508
MachineIRBuilder &MIRBuilder, SPIRVType *VariableType,
509509
SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
510-
Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
510+
Register Reg = Register(0), bool isConst = true,
511+
const std::optional<SPIRV::LinkageType::LinkageType> &LinkageTy = {
512+
SPIRV::LinkageType::Import}) {
511513
Register NewRegister =
512514
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::pIDRegClass);
513515
MIRBuilder.getMRI()->setType(
@@ -521,9 +523,8 @@ static Register buildBuiltinVariableLoad(
521523
// Set up the global OpVariable with the necessary builtin decorations.
522524
Register Variable = GR->buildGlobalVariable(
523525
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltinValue), nullptr,
524-
SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst,
525-
/* HasLinkageTy */ hasLinkageTy, SPIRV::LinkageType::Import, MIRBuilder,
526-
false);
526+
SPIRV::StorageClass::Input, nullptr, /* isConst= */ isConst, LinkageTy,
527+
MIRBuilder, false);
527528

528529
// Load the value from the global variable.
529530
Register LoadedRegister =
@@ -1851,7 +1852,7 @@ static bool generateWaveInst(const SPIRV::IncomingCall *Call,
18511852

18521853
return buildBuiltinVariableLoad(
18531854
MIRBuilder, Call->ReturnType, GR, Value, LLType, Call->ReturnRegister,
1854-
/* isConst= */ false, /* hasLinkageTy= */ false);
1855+
/* isConst= */ false, /* LinkageType= */ std::nullopt);
18551856
}
18561857

18571858
// We expect a builtin

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -479,18 +479,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
479479
.addImm(static_cast<uint32_t>(getExecutionModel(*ST, F)))
480480
.addUse(FuncVReg);
481481
addStringImm(F.getName(), MIB);
482-
} else if (!F.hasLocalLinkage() &&
483-
F.getVisibility() != GlobalValue::HiddenVisibility) {
484-
SPIRV::LinkageType::LinkageType LnkTy =
485-
F.isDeclaration()
486-
? SPIRV::LinkageType::Import
487-
: (F.getLinkage() == GlobalValue::LinkOnceODRLinkage &&
488-
ST->canUseExtension(
489-
SPIRV::Extension::SPV_KHR_linkonce_odr)
490-
? SPIRV::LinkageType::LinkOnceODR
491-
: SPIRV::LinkageType::Export);
482+
} else if (const auto LnkTy = getSpirvLinkageTypeFor(*ST, F)) {
492483
buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
493-
{static_cast<uint32_t>(LnkTy)}, F.getName());
484+
{static_cast<uint32_t>(*LnkTy)}, F.getName());
494485
}
495486

496487
// Handle function pointers decoration

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -712,9 +712,9 @@ SPIRVGlobalRegistry::buildConstantSampler(Register ResReg, unsigned AddrMode,
712712
Register SPIRVGlobalRegistry::buildGlobalVariable(
713713
Register ResVReg, SPIRVType *BaseType, StringRef Name,
714714
const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
715-
const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
716-
SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
717-
bool IsInstSelector) {
715+
const MachineInstr *Init, bool IsConst,
716+
const std::optional<SPIRV::LinkageType::LinkageType> &LinkageType,
717+
MachineIRBuilder &MIRBuilder, bool IsInstSelector) {
718718
const GlobalVariable *GVar = nullptr;
719719
if (GV) {
720720
GVar = cast<const GlobalVariable>(GV);
@@ -792,9 +792,9 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
792792
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
793793
}
794794

795-
if (HasLinkageTy)
795+
if (LinkageType)
796796
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
797-
{static_cast<uint32_t>(LinkageType)}, Name);
797+
{static_cast<uint32_t>(*LinkageType)}, Name);
798798

799799
SPIRV::BuiltIn::BuiltIn BuiltInId;
800800
if (getSpirvBuiltInIdByName(Name, BuiltInId))
@@ -821,8 +821,8 @@ Register SPIRVGlobalRegistry::getOrCreateGlobalVariableWithBinding(
821821
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
822822

823823
buildGlobalVariable(VarReg, VarType, Name, nullptr,
824-
getPointerStorageClass(VarType), nullptr, false, false,
825-
SPIRV::LinkageType::Import, MIRBuilder, false);
824+
getPointerStorageClass(VarType), nullptr, false,
825+
std::nullopt, MIRBuilder, false);
826826

827827
buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::DescriptorSet, {Set});
828828
buildOpDecorate(VarReg, MIRBuilder, SPIRV::Decoration::Binding, {Binding});

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -548,14 +548,12 @@ class SPIRVGlobalRegistry : public SPIRVIRMapping {
548548
MachineIRBuilder &MIRBuilder);
549549
Register getOrCreateUndef(MachineInstr &I, SPIRVType *SpvType,
550550
const SPIRVInstrInfo &TII);
551-
Register buildGlobalVariable(Register Reg, SPIRVType *BaseType,
552-
StringRef Name, const GlobalValue *GV,
553-
SPIRV::StorageClass::StorageClass Storage,
554-
const MachineInstr *Init, bool IsConst,
555-
bool HasLinkageTy,
556-
SPIRV::LinkageType::LinkageType LinkageType,
557-
MachineIRBuilder &MIRBuilder,
558-
bool IsInstSelector);
551+
Register buildGlobalVariable(
552+
Register Reg, SPIRVType *BaseType, StringRef Name, const GlobalValue *GV,
553+
SPIRV::StorageClass::StorageClass Storage, const MachineInstr *Init,
554+
bool IsConst,
555+
const std::optional<SPIRV::LinkageType::LinkageType> &LinkageType,
556+
MachineIRBuilder &MIRBuilder, bool IsInstSelector);
559557
Register getOrCreateGlobalVariableWithBinding(const SPIRVType *VarType,
560558
uint32_t Set, uint32_t Binding,
561559
StringRef Name,

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4350,22 +4350,16 @@ bool SPIRVInstructionSelector::selectGlobalValue(
43504350
if (hasInitializer(GlobalVar) && !Init)
43514351
return true;
43524352

4353-
bool HasLnkTy = !GV->hasLocalLinkage() && !GV->hasHiddenVisibility();
4354-
SPIRV::LinkageType::LinkageType LnkType =
4355-
GV->isDeclarationForLinker()
4356-
? SPIRV::LinkageType::Import
4357-
: (GV->hasLinkOnceODRLinkage() &&
4358-
STI.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr)
4359-
? SPIRV::LinkageType::LinkOnceODR
4360-
: SPIRV::LinkageType::Export);
4353+
const std::optional<SPIRV::LinkageType::LinkageType> LnkType =
4354+
getSpirvLinkageTypeFor(STI, *GV);
43614355

43624356
const unsigned AddrSpace = GV->getAddressSpace();
43634357
SPIRV::StorageClass::StorageClass StorageClass =
43644358
addressSpaceToStorageClass(AddrSpace, STI);
43654359
SPIRVType *ResType = GR.getOrCreateSPIRVPointerType(GVType, I, StorageClass);
43664360
Register Reg = GR.buildGlobalVariable(
43674361
ResVReg, ResType, GlobalIdent, GV, StorageClass, Init,
4368-
GlobalVar->isConstant(), HasLnkTy, LnkType, MIRBuilder, true);
4362+
GlobalVar->isConstant(), LnkType, MIRBuilder, true);
43694363
return Reg.isValid();
43704364
}
43714365

@@ -4516,8 +4510,8 @@ bool SPIRVInstructionSelector::loadVec3BuiltinInputID(
45164510
// builtin variable.
45174511
Register Variable = GR.buildGlobalVariable(
45184512
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr,
4519-
SPIRV::StorageClass::Input, nullptr, true, false,
4520-
SPIRV::LinkageType::Import, MIRBuilder, false);
4513+
SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder,
4514+
false);
45214515

45224516
// Create new register for loading value.
45234517
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
@@ -4569,8 +4563,8 @@ bool SPIRVInstructionSelector::loadBuiltinInputID(
45694563
// builtin variable.
45704564
Register Variable = GR.buildGlobalVariable(
45714565
NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr,
4572-
SPIRV::StorageClass::Input, nullptr, true, false,
4573-
SPIRV::LinkageType::Import, MIRBuilder, false);
4566+
SPIRV::StorageClass::Input, nullptr, true, std::nullopt, MIRBuilder,
4567+
false);
45744568

45754569
// Load uint value from the global variable.
45764570
auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpLoad))

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,4 +1040,19 @@ getFirstValidInstructionInsertPoint(MachineBasicBlock &BB) {
10401040
: VarPos;
10411041
}
10421042

1043+
std::optional<SPIRV::LinkageType::LinkageType>
1044+
getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV) {
1045+
if (GV.hasLocalLinkage() || GV.hasHiddenVisibility())
1046+
return std::nullopt;
1047+
1048+
if (GV.isDeclarationForLinker())
1049+
return SPIRV::LinkageType::Import;
1050+
1051+
if (GV.hasLinkOnceODRLinkage() &&
1052+
ST.canUseExtension(SPIRV::Extension::SPV_KHR_linkonce_odr))
1053+
return SPIRV::LinkageType::LinkOnceODR;
1054+
1055+
return SPIRV::LinkageType::Export;
1056+
}
1057+
10431058
} // namespace llvm

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,5 +559,8 @@ unsigned getArrayComponentCount(const MachineRegisterInfo *MRI,
559559
const MachineInstr *ResType);
560560
MachineBasicBlock::iterator
561561
getFirstValidInstructionInsertPoint(MachineBasicBlock &BB);
562+
563+
std::optional<SPIRV::LinkageType::LinkageType>
564+
getSpirvLinkageTypeFor(const SPIRVSubtarget &ST, const GlobalValue &GV);
562565
} // namespace llvm
563566
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

0 commit comments

Comments
 (0)