@@ -51,22 +51,28 @@ function(generate_unify_header DIR_NAME)
5151 endforeach ()
5252 # append header into extension.h
5353 string (REPLACE "${PADDLE_SOURCE_DIR} \/ " "" header_file "${header_file} " )
54- file (APPEND ${pten_extension_header_file } "#include \" ${header_file} \"\n " )
54+ file (APPEND ${phi_extension_header_file } "#include \" ${header_file} \"\n " )
5555endfunction ()
5656
5757# before calling kernel_declare, make sure that every input target file exists
5858function (kernel_declare TARGET_LIST)
5959 foreach (kernel_path ${TARGET_LIST} )
6060 file (READ ${kernel_path} kernel_impl)
61- # TODO(chenweihang): rename PD_REGISTER_KERNEL to PD_REGISTER_KERNEL
62- # NOTE(chenweihang): now we don't recommend to use digit in kernel name
63- string (REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\ ([ \t\r\n ]*[a-z0-9_]*," first_registry "${kernel_impl} " )
61+ string (REGEX MATCH "(PD_REGISTER_KERNEL|PD_REGISTER_GENERAL_KERNEL)\\ ([ \t\r\n ]*[a-z0-9_]*,[ \t\r\n\/ ]*[a-z0-9_]*" first_registry "${kernel_impl} " )
6462 if (NOT first_registry STREQUAL "" )
63+ # some GPU kernels can only run on CUDA and do not support ROCm, so skip them in ROCm builds
64+ if (WITH_ROCM)
65+ string (FIND "${first_registry} " "cuda_only" pos)
66+ if (pos GREATER 1)
67+ continue ()
68+ endif ()
69+ endif ()
6570 # parse the first kernel name
6671 string (REPLACE "PD_REGISTER_KERNEL(" "" kernel_name "${first_registry} " )
6772 string (REPLACE "PD_REGISTER_GENERAL_KERNEL(" "" kernel_name "${kernel_name} " )
6873 string (REPLACE "," "" kernel_name "${kernel_name} " )
6974 string (REGEX REPLACE "[ \t\r\n ]+" "" kernel_name "${kernel_name} " )
75+ string (REGEX REPLACE "//cuda_only" "" kernel_name "${kernel_name} " )
7076 # append kernel declare into declarations.h
7177 # TODO(chenweihang): default declare ALL_LAYOUT for each kernel
7278 if (${kernel_path} MATCHES "./cpu\/ " )
@@ -75,6 +81,8 @@ function(kernel_declare TARGET_LIST)
7581 file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , GPU, ALL_LAYOUT);\n " )
7682 elseif (${kernel_path} MATCHES "./xpu\/ " )
7783 file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , XPU, ALL_LAYOUT);\n " )
84+ elseif (${kernel_path} MATCHES "./gpudnn\/ " )
85+ file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , GPUDNN, ALL_LAYOUT);\n " )
7886 else ()
7987 # handle device-independent kernels; for now we temporarily declare them as CPU
8088 file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , CPU, ALL_LAYOUT);\n " )
@@ -88,13 +96,16 @@ function(kernel_library TARGET)
8896 set (cpu_srcs)
8997 set (gpu_srcs)
9098 set (xpu_srcs)
99+ set (gpudnn_srcs)
91100 set (selected_rows_srcs)
92101 # parse and save the dependent kernel targets
93102 set (all_srcs)
94103 set (kernel_deps)
95104
96105 set (oneValueArgs SUB_DIR)
97106 set (multiValueArgs SRCS DEPS)
107+ set (target_build_flag 1)
108+
98109 cmake_parse_arguments (kernel_library "${options} " "${oneValueArgs} "
99110 "${multiValueArgs} " ${ARGN} )
100111
@@ -117,6 +128,9 @@ function(kernel_library TARGET)
117128 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /gpu/${TARGET} .cu.cc)
118129 list (APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR} /gpu/${TARGET} .cu.cc)
119130 endif ()
131+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} _gpudnn.cu)
132+ list (APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} _gpudnn.cu)
133+ endif ()
120134 endif ()
121135 if (WITH_XPU)
122136 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /xpu/${TARGET} .cc)
@@ -135,6 +149,7 @@ function(kernel_library TARGET)
135149 list (APPEND all_srcs ${cpu_srcs} )
136150 list (APPEND all_srcs ${gpu_srcs} )
137151 list (APPEND all_srcs ${xpu_srcs} )
152+ list (APPEND all_srcs ${gpudnn_srcs} )
138153 foreach (src ${all_srcs} )
139154 file (READ ${src} target_content)
140155 string (REGEX MATCHALL "#include \" paddle\/ phi\/ kernels\/ [a-z0-9_]+_kernel.h\" " include_kernels ${target_content} )
@@ -160,21 +175,22 @@ function(kernel_library TARGET)
160175 list (LENGTH cpu_srcs cpu_srcs_len)
161176 list (LENGTH gpu_srcs gpu_srcs_len)
162177 list (LENGTH xpu_srcs xpu_srcs_len)
178+ list (LENGTH gpudnn_srcs gpudnn_srcs_len)
163179 list (LENGTH selected_rows_srcs selected_rows_srcs_len)
164180
165181 # Build the target according to how its sources are organized
166182 if ((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
167- ${xpu_srcs_len} GREATER 0) AND ( ${common_srcs_len } GREATER 0 OR
168- ${selected_rows_srcs_len} GREATER 0))
183+ ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len } GREATER 0) AND
184+ ( ${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0))
169185 # If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule.
170186 if (WITH_GPU)
171- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
172- nv_library(${TARGET} _part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
187+ if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 )
188+ nv_library(${TARGET} _part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
173189 nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
174190 endif ()
175191 elseif (WITH_ROCM)
176- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
177- hip_library(${TARGET} _part SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
192+ if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 )
193+ hip_library(${TARGET} _part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
178194 hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
179195 endif ()
180196 else ()
@@ -184,14 +200,14 @@ function(kernel_library TARGET)
184200 endif ()
185201 endif ()
186202 # If there are only specific device srcs, build target using this rule.
187- elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
203+ elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 )
188204 if (WITH_GPU)
189- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
190- nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
205+ if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 )
206+ nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
191207 endif ()
192208 elseif (WITH_ROCM)
193- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0)
194- hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
209+ if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 )
210+ hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
195211 endif ()
196212 else ()
197213 if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
@@ -228,35 +244,40 @@ function(kernel_library TARGET)
228244 cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
229245 endif ()
230246 else ()
231- message (FATAL_ERROR "Cannot find any implementation for ${TARGET} " )
247+ set (target_build_flag 0 )
232248 endif ()
233249
234- if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
235- ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
236- ${selected_rows_srcs_len} GREATER 0)
237- # append target into PTEN_KERNELS property
238- get_property (pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
239- set (pten_kernels ${pten_kernels} ${TARGET} )
240- set_property (GLOBAL PROPERTY PTEN_KERNELS ${pten_kernels} )
241- endif ()
250+ if (${target_build_flag} EQUAL 1)
251+ if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
252+ ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
253+ ${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0)
254+ # append target into PHI_KERNELS property
255+ get_property (phi_kernels GLOBAL PROPERTY PHI_KERNELS)
256+ set (phi_kernels ${phi_kernels} ${TARGET} )
257+ set_property (GLOBAL PROPERTY PHI_KERNELS ${phi_kernels} )
258+ endif ()
242259
243- # parse kernel name and auto generate kernel declaration
244- # here, we don't need to check WITH_XXX, because if not WITH_XXX, the
245- # xxx_srcs_len will be equal to 0
246- if (${common_srcs_len} GREATER 0)
247- kernel_declare(${common_srcs} )
248- endif ()
249- if (${cpu_srcs_len} GREATER 0)
250- kernel_declare(${cpu_srcs} )
251- endif ()
252- if (${gpu_srcs_len} GREATER 0)
253- kernel_declare(${gpu_srcs} )
254- endif ()
255- if (${xpu_srcs_len} GREATER 0)
256- kernel_declare(${xpu_srcs} )
257- endif ()
258- if (${selected_rows_srcs_len} GREATER 0)
259- kernel_declare(${selected_rows_srcs} )
260+ # parse kernel name and auto generate kernel declaration
261+ # here, we don't need to check WITH_XXX, because if not WITH_XXX, the
262+ # xxx_srcs_len will be equal to 0
263+ if (${common_srcs_len} GREATER 0)
264+ kernel_declare(${common_srcs} )
265+ endif ()
266+ if (${cpu_srcs_len} GREATER 0)
267+ kernel_declare(${cpu_srcs} )
268+ endif ()
269+ if (${gpu_srcs_len} GREATER 0)
270+ kernel_declare(${gpu_srcs} )
271+ endif ()
272+ if (${xpu_srcs_len} GREATER 0)
273+ kernel_declare(${xpu_srcs} )
274+ endif ()
275+ if (${gpudnn_srcs_len} GREATER 0)
276+ kernel_declare(${gpudnn_srcs} )
277+ endif ()
278+ if (${selected_rows_srcs_len} GREATER 0)
279+ kernel_declare(${selected_rows_srcs} )
280+ endif ()
260281 endif ()
261282endfunction ()
262283
0 commit comments