@@ -83,6 +83,8 @@ function(kernel_declare TARGET_LIST)
8383 file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , XPU, ALL_LAYOUT);\n " )
8484 elseif (${kernel_path} MATCHES "./gpudnn\/ " )
8585 file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , GPUDNN, ALL_LAYOUT);\n " )
86+ elseif (${kernel_path} MATCHES "./kps\/ " )
87+ file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , KPS, ALL_LAYOUT);\n " )
8688 else ()
8789 # deal with device independent kernel, now we use CPU temporarily
8890 file (APPEND ${kernel_declare_file} "PD_DECLARE_KERNEL(${kernel_name} , CPU, ALL_LAYOUT);\n " )
@@ -97,6 +99,7 @@ function(kernel_library TARGET)
9799 set (gpu_srcs)
98100 set (xpu_srcs)
99101 set (gpudnn_srcs)
102+ set (kps_srcs)
100103 set (selected_rows_srcs)
101104 # parse and save the deps kernel targets
102105 set (all_srcs)
@@ -128,6 +131,9 @@ function(kernel_library TARGET)
128131 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /gpu/${TARGET} .cu.cc)
129132 list (APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR} /gpu/${TARGET} .cu.cc)
130133 endif ()
134+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /kps/${TARGET} .cu)
135+ list (APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR} /kps/${TARGET} .cu)
136+ endif ()
131137 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} _gpudnn.cu)
132138 list (APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} _gpudnn.cu)
133139 endif ()
@@ -137,6 +143,15 @@ function(kernel_library TARGET)
137143 list (APPEND xpu_srcs ${CMAKE_CURRENT_SOURCE_DIR} /xpu/${TARGET} .cc)
138144 endif ()
139145 endif ()
146+ if (WITH_XPU_KP)
147+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /kps/${TARGET} .cu)
148+ # Change XPU2 file suffix
149+ # NOTE(chenweihang): If we can be sure that the *.kps suffix is no longer used, it can be copied directly to *.xpu
150+ file (COPY ${CMAKE_CURRENT_SOURCE_DIR} /kps/${TARGET} .cu DESTINATION ${CMAKE_CURRENT_BINARY_DIR} /kps)
151+ file (RENAME ${CMAKE_CURRENT_BINARY_DIR} /kps/${TARGET} .cu ${CMAKE_CURRENT_BINARY_DIR} /kps/${TARGET} .kps)
152+ list (APPEND kps_srcs ${CMAKE_CURRENT_BINARY_DIR} /kps/${TARGET} .kps)
153+ endif ()
154+ endif ()
140155 else ()
141156 # TODO(chenweihang): impl compile by source later
142157 endif ()
@@ -150,6 +165,7 @@ function(kernel_library TARGET)
150165 list (APPEND all_srcs ${gpu_srcs} )
151166 list (APPEND all_srcs ${xpu_srcs} )
152167 list (APPEND all_srcs ${gpudnn_srcs} )
168+ list (APPEND all_srcs ${kps_srcs} )
153169 foreach (src ${all_srcs} )
154170 file (READ ${src} target_content)
155171 string (REGEX MATCHALL "#include \" paddle\/ phi\/ kernels\/ [a-z0-9_]+_kernel.h\" " include_kernels ${target_content} )
@@ -159,11 +175,11 @@ function(kernel_library TARGET)
159175 string (REGEX MATCHALL "#include \" paddle\/ phi\/ kernels\/ ${kernel_library_SUB_DIR} \/ [a-z0-9_]+_kernel.h\" " include_kernels ${target_content} )
160176 endif ()
161177 foreach (include_kernel ${include_kernels} )
162- if ("${kernel_library_SUB_DIR} " STREQUAL "" )
163- string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ " "" kernel_name ${include_kernel} )
164- else ()
165- string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ ${kernel_library_SUB_DIR} \/ " "" kernel_name ${include_kernel} )
166- endif ()
178+ if ("${kernel_library_SUB_DIR} " STREQUAL "" )
179+ string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ " "" kernel_name ${include_kernel} )
180+ else ()
181+ string (REGEX REPLACE "#include \" paddle\/ phi\/ kernels\/ ${kernel_library_SUB_DIR} \/ " "" kernel_name ${include_kernel} )
182+ endif ()
167183 string (REGEX REPLACE ".h\" " "" kernel_name ${kernel_name} )
168184 list (APPEND kernel_deps ${kernel_name} )
169185 endforeach ()
@@ -176,11 +192,20 @@ function(kernel_library TARGET)
176192 list (LENGTH gpu_srcs gpu_srcs_len)
177193 list (LENGTH xpu_srcs xpu_srcs_len)
178194 list (LENGTH gpudnn_srcs gpudnn_srcs_len)
195+ list (LENGTH kps_srcs kps_srcs_len)
179196 list (LENGTH selected_rows_srcs selected_rows_srcs_len)
180197
198+ # kernel source file level
199+ # level 1: base device kernel
200+ # - cpu_srcs / gpu_srcs / xpu_srcs / kps_srcs
201+ # level 2: device-independent kernel
202+ # - common_srcs
203+ # level 3: Kernel implemented by reusing device-independent kernel
204+ # - selected_rows_srcs
205+
181206 # Build Target according to different src organization
182207 if ((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
183- ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0) AND
208+ ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0 ) AND
184209 (${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0))
185210 # If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule.
186211 if (WITH_GPU)
@@ -193,14 +218,19 @@ function(kernel_library TARGET)
193218 hip_library(${TARGET} _part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
194219 hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
195220 endif ()
221+ elseif (WITH_XPU_KP)
222+ if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
223+ xpu_library(${TARGET} _part SRCS ${cpu_srcs} ${xpu_srcs} ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
224+ xpu_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
225+ endif ()
196226 else ()
197227 if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
198228 cc_library(${TARGET} _part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
199229 cc_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
200230 endif ()
201231 endif ()
202232 # If there are only specific device srcs, build target using this rule.
203- elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
233+ elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0 )
204234 if (WITH_GPU)
205235 if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
206236 nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
@@ -209,6 +239,10 @@ function(kernel_library TARGET)
209239 if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
210240 hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
211241 endif ()
242+ elseif (WITH_XPU_KP)
243+ if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
244+ xpu_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
245+ endif ()
212246 else ()
213247 if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
214248 cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
@@ -222,6 +256,9 @@ function(kernel_library TARGET)
222256 elseif (WITH_ROCM)
223257 hip_library(${TARGET} _part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
224258 hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET} _part)
259+ elseif (WITH_XPU_KP)
260+ xpu_library(${TARGET} _part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
261+ xpu_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET} _part)
225262 else ()
226263 cc_library(${TARGET} _part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
227264 cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET} _part)
@@ -232,6 +269,8 @@ function(kernel_library TARGET)
232269 nv_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
233270 elseif (WITH_ROCM)
234271 hip_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
272+ elseif (WITH_XPU_KP)
273+ xpu_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
235274 else ()
236275 cc_library(${TARGET} SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
237276 endif ()
@@ -240,6 +279,8 @@ function(kernel_library TARGET)
240279 nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
241280 elseif (WITH_ROCM)
242281 hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
282+ elseif (WITH_XPU_KP)
283+ xpu_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
243284 else ()
244285 cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
245286 endif ()
@@ -249,7 +290,7 @@ function(kernel_library TARGET)
249290
250291 if (${target_build_flag} EQUAL 1)
251292 if (${common_srcs_len} GREATER 0 OR ${cpu_srcs_len} GREATER 0 OR
252- ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR
293+ ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0 OR
253294 ${gpudnn_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0)
254295 # append target into PHI_KERNELS property
255296 get_property (phi_kernels GLOBAL PROPERTY PHI_KERNELS)
@@ -275,6 +316,9 @@ function(kernel_library TARGET)
275316 if (${gpudnn_srcs_len} GREATER 0)
276317 kernel_declare(${gpudnn_srcs} )
277318 endif ()
319+ if (${kps_srcs_len} GREATER 0)
320+ kernel_declare(${kps_srcs} )
321+ endif ()
278322 if (${selected_rows_srcs_len} GREATER 0)
279323 kernel_declare(${selected_rows_srcs} )
280324 endif ()
0 commit comments