@@ -1996,7 +1996,8 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
19961996 UndefValue::get (Int8Ty), F->getName () + " .ID" );
19971997
19981998 for (Use *U : ToBeReplacedStateMachineUses)
1999- U->set (ConstantExpr::getBitCast (ID, U->get ()->getType ()));
1999+ U->set (ConstantExpr::getPointerBitCastOrAddrSpaceCast (
2000+ ID, U->get ()->getType ()));
20002001
20012002 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
20022003
@@ -3183,10 +3184,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
31833184 IsWorker->setDebugLoc (DLoc);
31843185 BranchInst::Create (StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
31853186
3187+ Module &M = *Kernel->getParent ();
3188+
31863189 // Create local storage for the work function pointer.
3190+ const DataLayout &DL = M.getDataLayout ();
31873191 Type *VoidPtrTy = Type::getInt8PtrTy (Ctx);
3188- AllocaInst *WorkFnAI = new AllocaInst (VoidPtrTy, 0 , " worker.work_fn.addr" ,
3189- &Kernel->getEntryBlock ().front ());
3192+ Instruction *WorkFnAI =
3193+ new AllocaInst (VoidPtrTy, DL.getAllocaAddrSpace (), nullptr ,
3194+ " worker.work_fn.addr" , &Kernel->getEntryBlock ().front ());
31903195 WorkFnAI->setDebugLoc (DLoc);
31913196
31923197 auto &OMPInfoCache = static_cast <OMPInformationCache &>(A.getInfoCache ());
@@ -3199,13 +3204,23 @@ struct AAKernelInfoFunction : AAKernelInfo {
31993204 Value *Ident = KernelInitCB->getArgOperand (0 );
32003205 Value *GTid = KernelInitCB;
32013206
3202- Module &M = *Kernel->getParent ();
32033207 FunctionCallee BarrierFn =
32043208 OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
32053209 M, OMPRTL___kmpc_barrier_simple_spmd);
32063210 CallInst::Create (BarrierFn, {Ident, GTid}, " " , StateMachineBeginBB)
32073211 ->setDebugLoc (DLoc);
32083212
3213+ if (WorkFnAI->getType ()->getPointerAddressSpace () !=
3214+ (unsigned int )AddressSpace::Generic) {
3215+ WorkFnAI = new AddrSpaceCastInst (
3216+ WorkFnAI,
3217+ PointerType::getWithSamePointeeType (
3218+ cast<PointerType>(WorkFnAI->getType ()),
3219+ (unsigned int )AddressSpace::Generic),
3220+ WorkFnAI->getName () + " .generic" , StateMachineBeginBB);
3221+ WorkFnAI->setDebugLoc (DLoc);
3222+ }
3223+
32093224 FunctionCallee KernelParallelFn =
32103225 OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
32113226 M, OMPRTL___kmpc_kernel_parallel);
0 commit comments