1919#include < string>
2020#include < tuple>
2121#include < vector>
22-
22+ # ifdef CINN_WITH_CUDA
2323#include " paddle/cinn/backends/codegen_cuda_dev.h"
24+ #endif
2425#include " paddle/cinn/cinn.h"
2526#include " paddle/cinn/ir/ir.h"
2627#include " paddle/cinn/ir/ir_mutator.h"
2728#include " paddle/cinn/ir/utils/ir_copy.h"
29+ #include " paddle/cinn/runtime/flags.h"
2830
2931namespace cinn {
3032namespace backends {
@@ -43,7 +45,7 @@ namespace backends {
4345 * - replace the original kernel function with a Call node and add it to the
4446 * first module, add a device kernel function to the second module.
4547 */
46- std::tuple<ir::Module, ir::Module> SplitCudaAndHostModule (ir::Module module );
48+ std::tuple<ir::Module, ir::Module> SplitDeviceAndHostModule (ir::Module module );
4749
4850namespace detail {
4951
@@ -52,7 +54,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
5254 : host_module_builder(module_name + " _host" ,
5355 cinn::common::DefaultHostTarget ()),
5456 device_module_builder(module_name + " _gpu_device" ,
55- cinn::common::DefaultNVGPUTarget ()) {}
57+ cinn::common::DefaultDeviceTarget ()) {}
5658
5759 std::tuple<ir::Module, ir::Module> operator ()(Expr* expr) {
5860 ir::IRMutator<>::Visit (expr, expr);
@@ -109,9 +111,18 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
109111 // shared_mem_bytes can be calculated after codegen_cuda_dev buffer creation;
110112 // however, this makes CodeGenCUDA_Dev run before splitting the host and
111113 // device module. Maybe we could reorder the process.
112- CodeGenCUDA_Dev codegen_dev (cinn::common::DefaultNVGPUTarget ());
113- codegen_dev.Compile (ir::LoweredFunc (func));
114- Expr shared_mem_bytes = codegen_dev.GetDynSharedMemOffset ();
114+ std::optional<Expr> shared_mem_bytes;
115+ cinn::common::DefaultDeviceTarget ().arch .Match (
116+ [&](std::variant<common::UnknownArch,
117+ common::X86Arch,
118+ common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
119+ [&](common::NVGPUArch) {
120+ #ifdef CINN_WITH_CUDA
121+ CodeGenCUDA_Dev codegen_dev (cinn::common::DefaultNVGPUTarget ());
122+ codegen_dev.Compile (ir::LoweredFunc (func));
123+ shared_mem_bytes = codegen_dev.GetDynSharedMemOffset ();
124+ #endif
125+ });
115126
116127 VLOG (6 ) << " Add a call node for func->name " << func->name << " \n "
117128 << " grid_dim: (" << func->cuda_axis_info .grid_dim (0 ) << " , "
@@ -120,10 +131,20 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
120131 << " block_dim: (" << func->cuda_axis_info .block_dim (0 ) << " , "
121132 << func->cuda_axis_info .block_dim (1 ) << " , "
122133 << func->cuda_axis_info .block_dim (2 ) << " ), "
123- << " shared_mem: " << shared_mem_bytes;
134+ << " shared_mem: " << shared_mem_bytes.value ();
135+
136+ std::optional<const char *> call_kernel;
137+ cinn::common::DefaultDeviceTarget ().arch .Match (
138+ [&](std::variant<common::UnknownArch,
139+ common::X86Arch,
140+ common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
141+ [&](common::NVGPUArch) {
142+ call_kernel = runtime::intrinsic::call_cuda_kernel;
143+ });
144+
124145 auto call_extern_api =
125146 ir::Call::Make (Void (),
126- runtime::intrinsic::call_cuda_kernel ,
147+ call_kernel. value () ,
127148 {kernel_ptr,
128149 kernel_args,
129150 kernel_args_num,
@@ -133,7 +154,7 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
133154 func->cuda_axis_info .block_dim (0 ), // block_x
134155 func->cuda_axis_info .block_dim (1 ), // block_y
135156 func->cuda_axis_info .block_dim (2 ), // block_z
136- shared_mem_bytes,
157+ shared_mem_bytes. value () ,
137158 kernel_stream},
138159 {},
139160 ir::CallType::Extern,
0 commit comments