@@ -134,8 +134,8 @@ function(kernel_library TARGET)
134134 if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /kps/${TARGET} .cu)
135135 list (APPEND gpu_srcs ${CMAKE_CURRENT_SOURCE_DIR} /kps/${TARGET} .cu)
136136 endif ()
137- if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} _gpudnn .cu)
138- list (APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} _gpudnn .cu)
137+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} .cu)
138+ list (APPEND gpudnn_srcs ${CMAKE_CURRENT_SOURCE_DIR} /gpudnn/${TARGET} .cu)
139139 endif ()
140140 endif ()
141141 if (WITH_XPU)
@@ -197,92 +197,88 @@ function(kernel_library TARGET)
197197
198198 # kernel source file level
199199 # level 1: base device kernel
200- # - cpu_srcs / gpu_srcs / xpu_srcs / kps_srcs
200+ # - cpu_srcs / gpu_srcs / xpu_srcs / gpudnn_srcs / kps_srcs
201201 # level 2: device-independent kernel
202202 # - common_srcs
203203 # level 3: Kernel implemented by reusing device-independent kernel
204204 # - selected_rows_srcs
205+ set (base_device_kernels)
206+ set (device_independent_kernel)
207+ set (high_level_kernels)
205208
206- # Build Target according different src organization
207- if ((${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR
208- ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0) AND
209- (${common_srcs_len} GREATER 0 OR ${selected_rows_srcs_len} GREATER 0))
210- # If the common_srcs/selected_rows_srcs depends on specific device srcs, build target using this rule.
209+ # 1. Base device kernel compile
210+ if (${cpu_srcs_len} GREATER 0)
211+ cc_library(${TARGET} _cpu SRCS ${cpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
212+ list (APPEND base_device_kernels ${TARGET} _cpu)
213+ endif ()
214+ if (${gpu_srcs_len} GREATER 0)
211215 if (WITH_GPU)
212- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
213- nv_library(${TARGET} _part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
214- nv_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
215- endif ()
216+ nv_library(${TARGET} _gpu SRCS ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
216217 elseif (WITH_ROCM)
217- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
218- hip_library(${TARGET} _part SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
219- hip_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
220- endif ()
221- elseif (WITH_XPU_KP)
222- if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
223- xpu_library(${TARGET} _part SRCS ${cpu_srcs} ${xpu_srcs} ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
224- xpu_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
225- endif ()
226- else ()
227- if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
228- cc_library(${TARGET} _part SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
229- cc_library(${TARGET} SRCS ${common_srcs} ${selected_rows_srcs} DEPS ${TARGET} _part)
230- endif ()
218+ hip_library(${TARGET} _gpu SRCS ${gpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
231219 endif ()
232- # If there are only specific device srcs, build target using this rule.
233- elseif (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
220+ list (APPEND base_device_kernels ${TARGET} _gpu)
221+ endif ()
222+ if (${xpu_srcs_len} GREATER 0)
223+ cc_library(${TARGET} _xpu SRCS ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
224+ list (APPEND base_device_kernels ${TARGET} _xpu)
225+ endif ()
226+ if (${gpudnn_srcs_len} GREATER 0)
234227 if (WITH_GPU)
235- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
236- nv_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
237- endif ()
228+ nv_library(${TARGET} _gpudnn SRCS ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
238229 elseif (WITH_ROCM)
239- if (${cpu_srcs_len} GREATER 0 OR ${gpu_srcs_len} GREATER 0 OR ${gpudnn_srcs_len} GREATER 0)
240- hip_library(${TARGET} SRCS ${cpu_srcs} ${gpu_srcs} ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
241- endif ()
242- elseif (WITH_XPU_KP)
243- if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0 OR ${kps_srcs_len} GREATER 0)
244- xpu_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
245- endif ()
246- else ()
247- if (${cpu_srcs_len} GREATER 0 OR ${xpu_srcs_len} GREATER 0)
248- cc_library(${TARGET} SRCS ${cpu_srcs} ${xpu_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
249- endif ()
230+ hip_library(${TARGET} _gpudnn SRCS ${gpudnn_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
250231 endif ()
251- # If the selected_rows_srcs depends on common_srcs, build target using this rule.
252- elseif (${common_srcs_len} GREATER 0 AND ${selected_rows_srcs_len} GREATER 0)
232+ list (APPEND base_device_kernels ${TARGET} _gpudnn)
233+ endif ()
234+ if (${kps_srcs_len} GREATER 0)
235+ # only when WITH_XPU_KP, the kps_srcs_len can be > 0
236+ xpu_library(${TARGET} _kps SRCS ${kps_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
237+ list (APPEND base_device_kernels ${TARGET} _kps)
238+ endif ()
239+
240+ # 2. Device-independent kernel compile
241+ if (${common_srcs_len} GREATER 0)
253242 if (WITH_GPU)
254- nv_library(${TARGET} _part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
255- nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET} _part)
243+ nv_library(${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
256244 elseif (WITH_ROCM)
257- hip_library(${TARGET} _part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
258- hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET} _part)
245+ hip_library(${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
259246 elseif (WITH_XPU_KP)
260- xpu_library(${TARGET} _part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
261- xpu_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET} _part)
247+ xpu_library(${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
262248 else ()
263- cc_library(${TARGET} _part SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} )
264- cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${TARGET} _part)
249+ cc_library(${TARGET} _common SRCS ${common_srcs} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} )
265250 endif ()
266- # If there are only common_srcs or selected_rows_srcs, build target using below rules.
267- elseif (${common_srcs_len} GREATER 0)
251+ list (APPEND device_independent_kernel ${TARGET} _common)
252+ endif ()
253+
254+ # 3. Reusing kernel compile
255+ if (${selected_rows_srcs_len} GREATER 0)
268256 if (WITH_GPU)
269- nv_library(${TARGET} SRCS ${common_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} )
257+ nv_library(${TARGET} _sr SRCS ${selected_rows_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel } )
270258 elseif (WITH_ROCM)
271- hip_library(${TARGET} SRCS ${common_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} )
259+ hip_library(${TARGET} _sr SRCS ${selected_rows_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel } )
272260 elseif (WITH_XPU_KP)
273- xpu_library(${TARGET} SRCS ${common_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} )
261+ xpu_library(${TARGET} _sr SRCS ${selected_rows_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel } )
274262 else ()
275- cc_library(${TARGET} SRCS ${common_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} )
263+ cc_library(${TARGET} _sr SRCS ${selected_rows_srcs } DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels} ${device_independent_kernel } )
276264 endif ()
277- elseif (${selected_rows_srcs_len} GREATER 0)
265+ list (APPEND high_level_kernels ${TARGET} _sr)
266+ endif ()
267+
268+ # 4. Unify target compile
269+ list (LENGTH base_device_kernels base_device_kernels_len)
270+ list (LENGTH device_independent_kernel device_independent_kernel_len)
271+ list (LENGTH high_level_kernels high_level_kernels_len)
272+ if (${base_device_kernels_len} GREATER 0 OR ${device_independent_kernel_len} GREATER 0 OR
273+ ${high_level_kernels_len} GREATER 0)
278274 if (WITH_GPU)
279- nv_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS } ${kernel_deps } )
275+ nv_library(${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels } ${device_independent_kernel} ${high_level_kernels } )
280276 elseif (WITH_ROCM)
281- hip_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS } ${kernel_deps } )
277+ hip_library(${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels } ${device_independent_kernel} ${high_level_kernels } )
282278 elseif (WITH_XPU_KP)
283- xpu_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS } ${kernel_deps } )
279+ xpu_library(${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels } ${device_independent_kernel} ${high_level_kernels } )
284280 else ()
285- cc_library(${TARGET} SRCS ${selected_rows_srcs} DEPS ${kernel_library_DEPS } ${kernel_deps } )
281+ cc_library(${TARGET} DEPS ${kernel_library_DEPS} ${kernel_deps} ${base_device_kernels } ${device_independent_kernel} ${high_level_kernels } )
286282 endif ()
287283 else ()
288284 set (target_build_flag 0)
0 commit comments