@@ -268,6 +268,38 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
   return Result;
 }
 
+/// Given a function, if it represents the entry point of a target kernel, this
+/// returns the execution mode flags associated to that kernel.
+static std::optional<omp::OMPTgtExecModeFlags>
+getTargetKernelExecMode(Function &Kernel) {
+  CallInst *TargetInitCall = nullptr;
+  for (Instruction &Inst : Kernel.getEntryBlock()) {
+    if (auto *Call = dyn_cast<CallInst>(&Inst)) {
+      if (Call->getCalledFunction()->getName() == "__kmpc_target_init") {
+        TargetInitCall = Call;
+        break;
+      }
+    }
+  }
+
+  if (!TargetInitCall)
+    return std::nullopt;
+
+  // Get the kernel mode information from the global variable associated to the
+  // first argument to the call to __kmpc_target_init. Refer to
+  // createTargetInit() to see how this is initialized.
+  Value *InitOperand = TargetInitCall->getArgOperand(0);
+  GlobalVariable *KernelEnv = nullptr;
+  if (auto *Cast = dyn_cast<ConstantExpr>(InitOperand))
+    KernelEnv = cast<GlobalVariable>(Cast->getOperand(0));
+  else
+    KernelEnv = cast<GlobalVariable>(InitOperand);
+  auto *KernelEnvInit = cast<ConstantStruct>(KernelEnv->getInitializer());
+  auto *ConfigEnv = cast<ConstantStruct>(KernelEnvInit->getOperand(0));
+  auto *KernelMode = cast<ConstantInt>(ConfigEnv->getOperand(2));
+  return static_cast<OMPTgtExecModeFlags>(KernelMode->getZExtValue());
+}
+
 /// Make \p Source branch to \p Target.
 ///
 /// Handles two situations:
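The constant-struct indices used by getTargetKernelExecMode (operand 0 of the kernel environment, operand 2 of the configuration environment) mirror the kernel environment layout that createTargetInit emits for the OpenMP device runtime. As a point of reference, an abridged sketch of that layout follows; it is not part of this patch, the field names are taken from the device runtime, and only the field order matters for the indexing above.

#include <cstdint>

// Abridged sketch of the kernel environment layout (assumed, not verbatim).
struct ConfigurationEnvironmentTy {
  uint8_t UseGenericStateMachine;  // operand 0
  uint8_t MayUseNestedParallelism; // operand 1
  uint8_t ExecMode; // operand 2: stands in for omp::OMPTgtExecModeFlags,
                    // this is the field the helper reads.
  // ... launch bound and reduction buffer fields follow in the real struct.
};

struct KernelEnvironmentTy {
  ConfigurationEnvironmentTy Configuration; // operand 0, read by the helper
  void *Ident;                              // IdentTy * in the runtime
  void *DynamicEnv;                         // DynamicEnvironmentTy *
};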
@@ -702,15 +734,19 @@ void OpenMPIRBuilder::finalize(Function *Fn) {
     // CodeExtractor generates correct code for extracted functions
     // which are used by OpenMP runtime.
     bool ArgsInZeroAddressSpace = Config.isTargetDevice();
-    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
-                            /* AggregateArgs */ true,
-                            /* BlockFrequencyInfo */ nullptr,
-                            /* BranchProbabilityInfo */ nullptr,
-                            /* AssumptionCache */ nullptr,
-                            /* AllowVarArgs */ true,
-                            /* AllowAlloca */ true,
-                            /* AllocaBlock*/ OuterAllocaBB,
-                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
+    CodeExtractor Extractor(
+        Blocks, /* DominatorTree */ nullptr,
+        /* AggregateArgs */ true,
+        /* BlockFrequencyInfo */ nullptr,
+        /* BranchProbabilityInfo */ nullptr,
+        /* AssumptionCache */ nullptr,
+        /* AllowVarArgs */ true,
+        /* AllowAlloca */ true,
+        /* AllocaBlock*/ OuterAllocaBB,
+        /* Suffix */ ".omp_par", ArgsInZeroAddressSpace,
+        OI.CustomArgAllocatorCB ? &OI.CustomArgAllocatorCB : nullptr,
+        /* DeallocationBlock */ ExitBB,
+        OI.CustomArgDeallocatorCB ? &OI.CustomArgDeallocatorCB : nullptr);
 
     LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
     LLVM_DEBUG(dbgs() << "Entry " << EntryBB->getName()
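The three trailing arguments are new CodeExtractor parameters introduced by this patch; their declarations are not visible in this hunk. A plausible shape for the OutlineInfo callback members, inferred purely from the lambdas assigned to them in createParallel below, could be the following sketch (alias names, parameter names, and the allocator's return type are assumptions, not taken from the patch):

#include <functional>
#include "llvm/ADT/Twine.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Value.h"

// Hypothetical callback aliases: the allocator returns the value holding the
// aggregated argument struct, the deallocator returns the freeing instruction.
using CustomArgAllocatorCBTy = std::function<llvm::Value *(
    llvm::BasicBlock *InsertBB, llvm::BasicBlock::iterator InsertPt,
    llvm::Type *ArgTy, const llvm::Twine &Name)>;
using CustomArgDeallocatorCBTy = std::function<llvm::Instruction *(
    llvm::BasicBlock *InsertBB, llvm::BasicBlock::iterator InsertPt,
    llvm::Value *Arg, llvm::Type *ArgTy)>;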
@@ -1614,6 +1650,50 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createParallel(
                              IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                              ThreadID, ToBeDeletedVec);
     };
+
+    std::optional<omp::OMPTgtExecModeFlags> ExecMode =
+        getTargetKernelExecMode(*OuterFn);
+
+    // If OuterFn is not a Generic kernel, skip custom allocation. This causes
+    // the CodeExtractor to follow its default behavior. Otherwise, we need to
+    // use device shared memory to allocate argument structures.
+    if (ExecMode && (*ExecMode & OMP_TGT_EXEC_MODE_GENERIC)) {
+      OI.CustomArgAllocatorCB = [this,
+                                 EntryBB](BasicBlock *, BasicBlock::iterator,
+                                          Type *ArgTy, const Twine &Name) {
+        // Instead of using the insertion point provided by the CodeExtractor,
+        // here we need to use the block that eventually calls the outlined
+        // function for the `parallel` construct.
+        //
+        // The reason is that the explicit deallocation call will be inserted
+        // within the outlined function, whereas the alloca insertion point
+        // might actually be located somewhere else in the caller. This becomes
+        // a problem when e.g. `parallel` is inside of a `distribute` construct,
+        // because the deallocation would be executed multiple times and the
+        // allocation just once (outside of the loop).
+        //
+        // TODO: Ideally, we'd want to do the allocation and deallocation
+        // outside of the `parallel` outlined function, hence using here the
+        // insertion point provided by the CodeExtractor. We can't do this at
+        // the moment because there is currently no way of passing an eligible
+        // insertion point for the explicit deallocation to the CodeExtractor,
+        // as that block is created (at least when nested inside of
+        // `distribute`) sometime after createParallel() completed, so it can't
+        // be stored in the OutlineInfo structure here.
+        //
+        // The current approach results in an explicit allocation and
+        // deallocation pair for each `distribute` loop iteration in that case,
+        // which is suboptimal.
+        return createOMPAllocShared(
+            InsertPointTy(EntryBB, EntryBB->getFirstInsertionPt()), ArgTy,
+            Name);
+      };
+      OI.CustomArgDeallocatorCB =
+          [this](BasicBlock *BB, BasicBlock::iterator AllocIP, Value *Arg,
+                 Type *ArgTy) -> Instruction * {
+        return createOMPFreeShared(InsertPointTy(BB, AllocIP), Arg, ArgTy);
+      };
+    }
   } else {
     // Generate OpenMP host runtime call
     OI.PostOutlineCB = [=, ToBeDeletedVec =
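The net effect of the Generic-mode path above is that the aggregated argument structure for the outlined `parallel` region lives in device shared memory instead of a private alloca. The createOMPAllocShared and createOMPFreeShared helpers are not shown in this hunk; a rough sketch of what such an allocation helper could emit, assuming it lowers to the device runtime's __kmpc_alloc_shared entry point, is given below (the function name and signature are illustrative only, not the patch's implementation):

#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
using namespace llvm;

// Sketch only: emit `ptr @__kmpc_alloc_shared(i64 sizeof(ArgTy))` at the
// builder's current insertion point. A matching __kmpc_free_shared(ptr, size)
// call would be emitted by the deallocator callback. Assumes a 64-bit size_t
// on the offload target and that the builder has an insertion point set.
static Value *emitAllocSharedSketch(OpenMPIRBuilder &OMPBuilder,
                                    IRBuilderBase &Builder, Type *ArgTy,
                                    const Twine &Name) {
  Module &M = *Builder.GetInsertBlock()->getModule();
  uint64_t Size = M.getDataLayout().getTypeAllocSize(ArgTy).getFixedValue();
  FunctionCallee AllocFn = OMPBuilder.getOrCreateRuntimeFunction(
      M, omp::RuntimeFunction::OMPRTL___kmpc_alloc_shared);
  return Builder.CreateCall(AllocFn, {Builder.getInt64(Size)}, Name);
}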