-
Couldn't load subscription status.
- Fork 15k
[OpenMP][OMPIRBuilder] Support parallel in Generic kernels #150926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/skatrak/flang-generic-04-parallel-args
Are you sure you want to change the base?
[OpenMP][OMPIRBuilder] Support parallel in Generic kernels #150926
Conversation
| @llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-flang-openmp Author: Sergio Afonso (skatrak) Changes: This patch introduces codegen logic to produce a wrapper function argument for the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode. Full diff: https://github.com/llvm/llvm-project/pull/150926.diff 2 Files Affected:
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index a913958c0de9a..0005a72e86324 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -1323,6 +1323,86 @@ Error OpenMPIRBuilder::emitCancelationCheckImpl( return Error::success(); } +// Create wrapper function used to gather the outlined function's argument +// structure from a shared buffer and to forward them to it when running in +// Generic mode. +// +// The outlined function is expected to receive 2 integer arguments followed by +// an optional pointer argument to an argument structure holding the rest. +static Function *createTargetParallelWrapper(OpenMPIRBuilder *OMPIRBuilder, + Function &OutlinedFn) { + size_t NumArgs = OutlinedFn.arg_size(); + assert((NumArgs == 2 || NumArgs == 3) && + "expected a 2-3 argument parallel outlined function"); + bool UseArgStruct = NumArgs == 3; + + IRBuilder<> &Builder = OMPIRBuilder->Builder; + IRBuilder<>::InsertPointGuard IPG(Builder); + auto *FnTy = FunctionType::get(Builder.getVoidTy(), + {Builder.getInt16Ty(), Builder.getInt32Ty()}, + /*isVarArg=*/false); + auto *WrapperFn = + Function::Create(FnTy, GlobalValue::InternalLinkage, + OutlinedFn.getName() + ".wrapper", OMPIRBuilder->M); + + WrapperFn->addParamAttr(0, Attribute::NoUndef); + WrapperFn->addParamAttr(0, Attribute::ZExt); + WrapperFn->addParamAttr(1, Attribute::NoUndef); + + BasicBlock *EntryBB = + BasicBlock::Create(OMPIRBuilder->M.getContext(), "entry", WrapperFn); + Builder.SetInsertPoint(EntryBB); + + // Allocation. 
+ Value *AddrAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), + /*ArraySize=*/nullptr, "addr"); + AddrAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast( + AddrAlloca, Builder.getPtrTy(/*AddrSpace=*/0), + AddrAlloca->getName() + ".ascast"); + + Value *ZeroAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), + /*ArraySize=*/nullptr, "zero"); + ZeroAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast( + ZeroAlloca, Builder.getPtrTy(/*AddrSpace=*/0), + ZeroAlloca->getName() + ".ascast"); + + Value *ArgsAlloca = nullptr; + if (UseArgStruct) { + ArgsAlloca = Builder.CreateAlloca(Builder.getPtrTy(), + /*ArraySize=*/nullptr, "global_args"); + ArgsAlloca = Builder.CreatePointerBitCastOrAddrSpaceCast( + ArgsAlloca, Builder.getPtrTy(/*AddrSpace=*/0), + ArgsAlloca->getName() + ".ascast"); + } + + // Initialization. + Builder.CreateStore(WrapperFn->getArg(1), AddrAlloca); + Builder.CreateStore(Builder.getInt32(0), ZeroAlloca); + if (UseArgStruct) { + Builder.CreateCall( + OMPIRBuilder->getOrCreateRuntimeFunctionPtr( + llvm::omp::RuntimeFunction::OMPRTL___kmpc_get_shared_variables), + {ArgsAlloca}); + } + + SmallVector<Value *, 3> Args{AddrAlloca, ZeroAlloca}; + + // Load structArg from global_args. + if (UseArgStruct) { + Value *StructArg = Builder.CreateLoad(Builder.getPtrTy(), ArgsAlloca); + StructArg = Builder.CreateInBoundsGEP(Builder.getPtrTy(), StructArg, + {Builder.getInt64(0)}); + StructArg = Builder.CreateLoad(Builder.getPtrTy(), StructArg, "structArg"); + Args.push_back(StructArg); + } + + // Call the outlined function holding the parallel body. + Builder.CreateCall(&OutlinedFn, Args); + Builder.CreateRetVoid(); + + return WrapperFn; +} + // Callback used to create OpenMP runtime calls to support // omp parallel clause for the device. 
// We need to use this callback to replace call to the OutlinedFn in OuterFn @@ -1332,6 +1412,10 @@ static void targetParallelCallback( BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition, Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr, Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) { + assert(OutlinedFn.arg_size() >= 2 && + "Expected at least tid and bounded tid as arguments"); + unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2; + // Add some known attributes. IRBuilder<> &Builder = OMPIRBuilder->Builder; OutlinedFn.addParamAttr(0, Attribute::NoAlias); @@ -1340,17 +1424,12 @@ static void targetParallelCallback( OutlinedFn.addParamAttr(1, Attribute::NoUndef); OutlinedFn.addFnAttr(Attribute::NoUnwind); - assert(OutlinedFn.arg_size() >= 2 && - "Expected at least tid and bounded tid as arguments"); - unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2; - CallInst *CI = cast<CallInst>(OutlinedFn.user_back()); assert(CI && "Expected call instruction to outlined function"); CI->getParent()->setName("omp_parallel"); Builder.SetInsertPoint(CI); Type *PtrTy = OMPIRBuilder->VoidPtr; - Value *NullPtrValue = Constant::getNullValue(PtrTy); // Add alloca for kernel args OpenMPIRBuilder ::InsertPointTy CurrentIP = Builder.saveIP(); @@ -1376,6 +1455,15 @@ static void targetParallelCallback( IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32) : Builder.getInt32(1); + // If this is not a Generic kernel, we can skip generating the wrapper. 
+ std::optional<omp::OMPTgtExecModeFlags> ExecMode = + getTargetKernelExecMode(*OuterFn); + Value *WrapperFn; + if (ExecMode && *ExecMode & OMP_TGT_EXEC_MODE_GENERIC) + WrapperFn = createTargetParallelWrapper(OMPIRBuilder, OutlinedFn); + else + WrapperFn = Constant::getNullValue(PtrTy); + // Build kmpc_parallel_51 call Value *Parallel51CallArgs[] = { /* identifier*/ Ident, @@ -1384,7 +1472,7 @@ static void targetParallelCallback( /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1), /* Proc bind */ Builder.getInt32(-1), /* outlined function */ &OutlinedFn, - /* wrapper function */ NullPtrValue, + /* wrapper function */ WrapperFn, /* arguments of the outlined funciton*/ Args, /* number of arguments */ Builder.getInt64(NumCapturedVars)}; diff --git a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir index 504e39c96f008..ca998b4672ba0 100644 --- a/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir +++ b/mlir/test/Target/LLVMIR/omptarget-parallel-llvm.mlir @@ -69,7 +69,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: store ptr %[[TMP6]], ptr %[[GEP_]], align 8 // CHECK: %[[TMP7:.*]] = getelementptr inbounds [1 x ptr], ptr %[[TMP2]], i64 0, i64 0 // CHECK: store ptr %[[STRUCTARG]], ptr %[[TMP7]], align 8 -// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr null, ptr %[[TMP2]], i64 1) +// CHECK: call void @__kmpc_parallel_51(ptr addrspacecast (ptr addrspace(1) @[[GLOB1]] to ptr), i32 %[[OMP_GLOBAL_THREAD_NUM]], i32 1, i32 -1, i32 -1, ptr @[[FUNC1:.*]], ptr @[[FUNC1_WRAPPER:.*]], ptr %[[TMP2]], i64 1) // CHECK: call void @__kmpc_free_shared(ptr %[[STRUCTARG]], i64 8) // CHECK: call void @__kmpc_target_deinit() @@ -84,7 +84,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: call void 
@__kmpc_parallel_51(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) @[[NUM_THREADS_GLOB:[0-9]+]] to ptr), // CHECK-SAME: i32 [[NUM_THREADS_TMP0:%.*]], i32 1, i32 156, -// CHECK-SAME: i32 -1, ptr [[FUNC_NUM_THREADS1:@.*]], ptr null, ptr [[NUM_THREADS_TMP1:%.*]], i64 1) +// CHECK-SAME: i32 -1, ptr @[[FUNC_NUM_THREADS1:.*]], ptr @[[FUNC2_WRAPPER:.*]], ptr [[NUM_THREADS_TMP1:%.*]], i64 1) // One of the arguments of kmpc_parallel_51 function is responsible for handling if clause // of omp parallel construct for target region. If this argument is nonzero, @@ -105,4 +105,23 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo // CHECK: call void @__kmpc_parallel_51(ptr addrspacecast ( // CHECK-SAME: ptr addrspace(1) {{.*}} to ptr), // CHECK-SAME: i32 {{.*}}, i32 %[[IFCOND_TMP4]], i32 -1, -// CHECK-SAME: i32 -1, ptr {{.*}}, ptr null, ptr {{.*}}, i64 1) +// CHECK-SAME: i32 -1, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i64 1) + +// CHECK: define internal void @[[FUNC1_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %[[ADDR:.*]]) +// CHECK: %[[ADDR_ALLOCA:.*]] = alloca i32, align 4, addrspace(5) +// CHECK: %[[ADDR_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ADDR_ALLOCA]] to ptr +// CHECK: %[[ZERO_ALLOCA:.*]] = alloca i32, align 4, addrspace(5) +// CHECK: %[[ZERO_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ZERO_ALLOCA]] to ptr +// CHECK: %[[ARGS_ALLOCA:.*]] = alloca ptr, align 8, addrspace(5) +// CHECK: %[[ARGS_ASCAST:.*]] = addrspacecast ptr addrspace(5) %[[ARGS_ALLOCA]] to ptr +// CHECK: store i32 %[[ADDR]], ptr %[[ADDR_ASCAST]] +// CHECK: store i32 0, ptr %[[ZERO_ASCAST]] +// CHECK: call void @__kmpc_get_shared_variables(ptr %[[ARGS_ASCAST]]) +// CHECK: %[[LOAD_ARGS:.*]] = load ptr, ptr %[[ARGS_ASCAST]], align 8 +// CHECK: %[[FIRST_ARG:.*]] = getelementptr inbounds ptr, ptr %[[LOAD_ARGS]], i64 0 +// CHECK: %[[STRUCTARG:.*]] = load ptr, ptr %[[FIRST_ARG]], align 8 +// CHECK: call void @[[FUNC1]](ptr %[[ADDR_ASCAST]], ptr 
%[[ZERO_ASCAST]], ptr %[[STRUCTARG]]) + +// CHECK: define internal void @[[FUNC2_WRAPPER]](i16 noundef zeroext %{{.*}}, i32 noundef %{{.*}}) +// CHECK-NOT: define +// CHECK: call void @[[FUNC_NUM_THREADS1]]({{.*}}) |
6a97ff2 to 8b34402 Compare 8771c0f to b6e9849 Compare b6e9849 to 6a81001 Compare There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
0237b82 to 06a2570 Compare 4f751f7 to e57c671 Compare This patch introduces codegen logic to produce a wrapper function argument for the `__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.
be95567 to 6bcb74a Compare e57c671 to f027b86 Compare
This patch introduces codegen logic to produce a wrapper function argument for the
`__kmpc_parallel_51` DeviceRTL function needed to handle arguments passed using device shared memory in Generic mode.