Skip to content

Commit 70ce915

Browse files
authored
[XPU]Fix Option to Enable/Disable xpufft Build (#72734)
* fix_cmake * fix * ci * ci
1 parent d8cb30c commit 70ce915

File tree

6 files changed

+26
-40
lines changed

6 files changed

+26
-40
lines changed

cmake/phi.cmake

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,13 +110,6 @@ function(kernel_declare TARGET_LIST)
110110
set(first_registry "")
111111
endif()
112112
endif()
113-
# The kernel related to xpufft must have WITH_XPU_FFT enabled.
114-
if(WITH_XPU AND NOT WITH_XPU_FFT)
115-
string(FIND "${first_registry}" "xpufft" pos)
116-
if(pos GREATER 1)
117-
set(first_registry "")
118-
endif()
119-
endif()
120113

121114
if(NOT first_registry STREQUAL "")
122115
string(
@@ -148,7 +141,6 @@ function(kernel_declare TARGET_LIST)
148141
string(REPLACE "," ";" kernel_msg "${kernel_msg}")
149142
string(REGEX REPLACE "[ \\\t\r\n]+" "" kernel_msg "${kernel_msg}")
150143
string(REGEX REPLACE "//cuda_only" "" kernel_msg "${kernel_msg}")
151-
string(REGEX REPLACE "//xpufft" "" kernel_msg "${kernel_msg}")
152144

153145
list(GET kernel_msg 0 kernel_name)
154146
if(NOT is_all_backend STREQUAL "")

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1824,7 +1824,15 @@ XPUOpMap& get_kl3_ops() {
18241824
phi::DataType::INT64,
18251825
phi::DataType::INT32})},
18261826
#ifdef PADDLE_WITH_XPU_FFT
1827-
{"conj", XPUKernelSet({phi::DataType::COMPLEX64})},
1827+
{"conj",
1828+
XPUKernelSet({phi::DataType::FLOAT32,
1829+
phi::DataType::FLOAT16,
1830+
phi::DataType::BFLOAT16,
1831+
phi::DataType::FLOAT64,
1832+
phi::DataType::BOOL,
1833+
phi::DataType::INT64,
1834+
phi::DataType::INT32,
1835+
phi::DataType::COMPLEX64})},
18281836
{"real", XPUKernelSet({phi::DataType::COMPLEX64})},
18291837
{"real_grad", XPUKernelSet({phi::DataType::COMPLEX64})},
18301838
{"imag", XPUKernelSet({phi::DataType::COMPLEX64})},

paddle/phi/kernels/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,14 @@ if(WITH_XPU)
325325
file(GLOB_RECURSE kernel_xpu_cc "${CMAKE_CURRENT_BINARY_DIR}/*.cc")
326326
collect_generated_srcs(kernels_srcs SRCS ${kernel_xpu_cc})
327327
set(kernel_cc "")
328+
endif()
328329

330+
# The kernel related to xpufft must have WITH_XPU_FFT enabled.
331+
if(NOT WITH_XPU_FFT)
332+
list(REMOVE_ITEM kernel_xpu "xpu/complex_kernel.cc"
333+
"xpu/complex_grad_kernel.cc")
329334
endif()
335+
330336
collect_srcs(kernels_srcs SRCS ${kernel_xpu})
331337
kernel_declare("${kernel_xpu}")
332338
kernel_declare("${kernel_xpu_kps}")

paddle/phi/kernels/xpu/complex_grad_kernel.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,27 +117,24 @@ void ComplexGradKernel(const Context& dev_ctx,
117117
}
118118
} // namespace phi
119119

120-
PD_REGISTER_KERNEL(imag_grad, // xpufft
120+
PD_REGISTER_KERNEL(imag_grad,
121121
XPU,
122122
ALL_LAYOUT,
123123
phi::ImagGradKernel,
124124
phi::dtype::complex<float>) {
125125
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
126126
}
127127

128-
PD_REGISTER_KERNEL(real_grad, // xpufft
128+
PD_REGISTER_KERNEL(real_grad,
129129
XPU,
130130
ALL_LAYOUT,
131131
phi::RealGradKernel,
132132
phi::dtype::complex<float>) {
133133
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
134134
}
135135

136-
PD_REGISTER_KERNEL(complex_grad, // xpufft
137-
XPU,
138-
ALL_LAYOUT,
139-
phi::ComplexGradKernel,
140-
float) {
136+
PD_REGISTER_KERNEL(
137+
complex_grad, XPU, ALL_LAYOUT, phi::ComplexGradKernel, float) {
141138
kernel->InputAt(2).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
142139
}
143140
#endif // PADDLE_WITH_XPU_FFT

paddle/phi/kernels/xpu/complex_kernel.cc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ void ConjKernel(const Context& dev_ctx,
3636
const DenseTensor& x,
3737
DenseTensor* out) {
3838
dev_ctx.template Alloc<T>(out);
39-
if (std::is_same<T, phi::dtype::complex<float>>::value) {
39+
if (std::is_same_v<T, phi::dtype::complex<float>>) {
4040
int r = xfft_internal::xpu::Conj(
4141
x.numel(),
4242
reinterpret_cast<cuFloatComplex*>(const_cast<T*>(x.data<T>())),
@@ -132,7 +132,7 @@ void ComplexKernel(const Context& dev_ctx,
132132
}
133133
} // namespace phi
134134

135-
PD_REGISTER_KERNEL(conj, // xpufft
135+
PD_REGISTER_KERNEL(conj,
136136
XPU,
137137
ALL_LAYOUT,
138138
phi::ConjKernel,
@@ -145,27 +145,17 @@ PD_REGISTER_KERNEL(conj, // xpufft
145145
phi::dtype::bfloat16,
146146
phi::dtype::complex<float>) {}
147147

148-
PD_REGISTER_KERNEL(real, // xpufft
149-
XPU,
150-
ALL_LAYOUT,
151-
phi::RealKernel,
152-
phi::dtype::complex<float>) {
148+
PD_REGISTER_KERNEL(
149+
real, XPU, ALL_LAYOUT, phi::RealKernel, phi::dtype::complex<float>) {
153150
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
154151
}
155152

156-
PD_REGISTER_KERNEL(imag, // xpufft
157-
XPU,
158-
ALL_LAYOUT,
159-
phi::ImagKernel,
160-
phi::dtype::complex<float>) {
153+
PD_REGISTER_KERNEL(
154+
imag, XPU, ALL_LAYOUT, phi::ImagKernel, phi::dtype::complex<float>) {
161155
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
162156
}
163157

164-
PD_REGISTER_KERNEL(complex, // xpufft
165-
XPU,
166-
ALL_LAYOUT,
167-
phi::ComplexKernel,
168-
float) {
158+
PD_REGISTER_KERNEL(complex, XPU, ALL_LAYOUT, phi::ComplexKernel, float) {
169159
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
170160
}
171161
#endif // PADDLE_WITH_XPU_FFT

test/xpu/xpu_op_test

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)