9 changes: 6 additions & 3 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
@@ -1194,23 +1194,26 @@ XPUOpMap& get_kl3_ops() {
                    phi::DataType::INT64,
                    phi::DataType::INT32,
                    phi::DataType::BOOL,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::COMPLEX64})},
     {"reshape2",
      XPUKernelSet({phi::DataType::FLOAT64,
                    phi::DataType::FLOAT16,
                    phi::DataType::BFLOAT16,
                    phi::DataType::INT64,
                    phi::DataType::INT32,
                    phi::DataType::BOOL,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::COMPLEX64})},
     {"reshape",
      XPUKernelSet({phi::DataType::FLOAT64,
                    phi::DataType::FLOAT16,
                    phi::DataType::INT64,
                    phi::DataType::INT32,
                    phi::DataType::BFLOAT16,
                    phi::DataType::BOOL,
-                   phi::DataType::FLOAT32})},
+                   phi::DataType::FLOAT32,
+                   phi::DataType::COMPLEX64})},
     {"resnet_unit",
      XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})},
     {"resnet_unit_grad",
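
Note on the change above: the op list maps each op name to the set of dtypes its XPU kernel is registered for, so adding phi::DataType::COMPLEX64 is what lets complex-valued reshapes dispatch to the XPU kernel at all. A minimal self-contained sketch of that lookup pattern, using hypothetical stand-ins rather than Paddle's actual XPUOpMap/XPUKernelSet types:

#include <iostream>
#include <map>
#include <set>
#include <string>

// Hypothetical stand-ins for phi::DataType and XPUKernelSet.
enum class DataType { FLOAT32, COMPLEX64 };
using KernelSet = std::set<DataType>;
using OpMap = std::map<std::string, KernelSet>;

int main() {
  OpMap ops{{"reshape2", {DataType::FLOAT32, DataType::COMPLEX64}}};
  // Dispatch amounts to a membership test: the op runs on XPU only if
  // the requested dtype is in its registered kernel set.
  bool ok = ops.at("reshape2").count(DataType::COMPLEX64) > 0;
  std::cout << (ok ? "dispatch to XPU kernel" : "unsupported dtype") << "\n";
  return 0;
}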
17 changes: 9 additions & 8 deletions paddle/phi/kernels/funcs/fft_fill_conj_xpu.h
@@ -62,36 +62,37 @@ void FFTFillConj(const DeviceContext& ctx,
   for (const auto i : axes) {
     _is_fft_axis[i] = true;
   }
+
Contributor Author: This blank line is intentional; it separates the parameter computation above from the data-copy section below. It was not added by mistake.
   xpu::ctx_guard RAII_GUARD(ctx.x_context());
   int64_t* src_strides_ptr =
       RAII_GUARD.alloc_l3_or_gm<int64_t>(src_strides_v.size());
   PADDLE_ENFORCE_NOT_NULL(src_strides_ptr,
                           common::errors::External("XPU has no enough memory"));
-  xpu_memcpy(reinterpret_cast<void*>(src_strides_ptr),
-             reinterpret_cast<void*>(src_strides_v.data()),
+  xpu_memcpy(src_strides_ptr,
+             src_strides_v.data(),
              src_strides_v.size() * sizeof(int64_t),
              XPUMemcpyKind::XPU_HOST_TO_DEVICE);
   int64_t* dst_strides_ptr =
       RAII_GUARD.alloc_l3_or_gm<int64_t>(dst_strides_v.size());
   PADDLE_ENFORCE_NOT_NULL(dst_strides_ptr,
                           common::errors::External("XPU has no enough memory"));
-  xpu_memcpy(reinterpret_cast<void*>(dst_strides_ptr),
-             reinterpret_cast<void*>(dst_strides_v.data()),
+  xpu_memcpy(dst_strides_ptr,
+             dst_strides_v.data(),
              dst_strides_v.size() * sizeof(int64_t),
              XPUMemcpyKind::XPU_HOST_TO_DEVICE);
   int64_t* dst_shape_ptr =
       RAII_GUARD.alloc_l3_or_gm<int64_t>(dst_shape_v.size());
   PADDLE_ENFORCE_NOT_NULL(dst_shape_ptr,
                           common::errors::External("XPU has no enough memory"));
-  xpu_memcpy(reinterpret_cast<void*>(dst_shape_ptr),
-             reinterpret_cast<void*>(dst_shape_v.data()),
+  xpu_memcpy(dst_shape_ptr,
+             dst_shape_v.data(),
              dst_shape_v.size() * sizeof(int64_t),
              XPUMemcpyKind::XPU_HOST_TO_DEVICE);
   bool* _is_fft_axis_ptr = RAII_GUARD.alloc_l3_or_gm<bool>(rank);
   PADDLE_ENFORCE_NOT_NULL(_is_fft_axis_ptr,
                           common::errors::External("XPU has no enough memory"));
-  xpu_memcpy(reinterpret_cast<void*>(_is_fft_axis_ptr),
-             reinterpret_cast<void*>(_is_fft_axis.get()),
+  xpu_memcpy(_is_fft_axis_ptr,
+             _is_fft_axis.get(),
              rank * sizeof(bool),
              XPUMemcpyKind::XPU_HOST_TO_DEVICE);

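
Note on the xpu_memcpy changes: the removed reinterpret_casts were redundant, since any object pointer converts implicitly to void* (and to const void*). Assuming xpu_memcpy takes void*/const void* parameters like standard memcpy, the typed pointers can be passed directly. A minimal sketch of why the casts are unnecessary (fake_xpu_memcpy is a stand-in, not the real runtime call):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in with a memcpy-like signature; assumed to mirror xpu_memcpy.
int fake_xpu_memcpy(void* dst, const void* src, std::size_t n) {
  std::memcpy(dst, src, n);
  return 0;
}

int main() {
  std::vector<int64_t> host_strides{1, 2, 3};
  int64_t device_buf[3];
  // int64_t* converts to void* / const void* implicitly, so no
  // reinterpret_cast is needed on either argument.
  return fake_xpu_memcpy(device_buf,
                         host_strides.data(),
                         host_strides.size() * sizeof(int64_t));
}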
46 changes: 4 additions & 42 deletions paddle/phi/kernels/funcs/fft_xpu.cc
@@ -77,39 +77,11 @@ void exec_normalization(const phi::XPUContext& ctx,
     DenseTensor scale_tensor =
         phi::Full<T, phi::XPUContext>(ctx, {1}, static_cast<T>(scale));
     MultiplyKernel<T, phi::XPUContext>(ctx, in, scale_tensor, out);
-    // ScaleKernel<T, phi::XPUContext>(ctx, in, scale, 0, true, out);
   } else {
     AssignKernel<phi::XPUContext>(ctx, in, out);
   }
 }
 
-bool has_large_prime_factor(int64_t n) {
-  constexpr int64_t first_large_prime = 11;
-  const std::array<int64_t, 4> prime_radices{{2, 3, 5, 7}};
-  for (auto prime : prime_radices) {
-    if (n < first_large_prime) {
-      return false;
-    }
-    while (n % prime == 0) {
-      n /= prime;
-    }
-  }
-  return n != 1;
-}
-
-inline bool use_cache(const int64_t* signal_size) {
-  bool using_cache = true;
-  int cufft_version;
-  phi::dynload::cufftGetVersion(&cufft_version);
-  if (10300 <= cufft_version && cufft_version <= 10400) {
-    using_cache = std::none_of(
-        signal_size + 1, signal_size + kMaxDataNdim, [](int64_t dim_size) {
-          return has_large_prime_factor(dim_size);
-        });
-  }
-  return using_cache;
-}
-
 // up to 3d unnormalized fft transform (c2r, r2c, c2c)
 template <typename Ti, typename To>
 void exec_fft(const phi::XPUContext& ctx,
@@ -171,22 +143,12 @@ void exec_fft(const phi::XPUContext& ctx,
   collapsed_output.Resize(collapsed_output_shape);
   ctx.Alloc<To>(&collapsed_output);
 
+  int64_t device_id = ctx.GetPlace().GetDeviceId();
+  FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
+  std::lock_guard<std::mutex> guard(plan_cache.mutex);
   FFTConfigKey key =
       create_fft_configkey(collapsed_input, collapsed_output, signal_ndim);
-  int64_t device_id = ctx.GetPlace().GetDeviceId();
-  FFTConfig* config = nullptr;
-  std::unique_ptr<FFTConfig> config_ = nullptr;
-  bool using_cache = use_cache(key.sizes_);
-
-  if (using_cache) {
-    FFTConfigCache& plan_cache = get_fft_plan_cache(device_id);
-    std::unique_lock<std::mutex> guard(plan_cache.mutex, std::defer_lock);
-    guard.lock();
-    config = &(plan_cache.lookup(key));
-  } else {
-    config_ = std::make_unique<FFTConfig>(key);
-    config = config_.get();
-  }
+  FFTConfig* config = &(plan_cache.lookup(key));
 
   const int64_t workspace_size = static_cast<int64_t>(config->workspace_size());
   DenseTensor workspace_tensor = Empty<uint8_t>(ctx, {workspace_size});
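
Note on the deleted helpers: has_large_prime_factor/use_cache bypass the plan cache for signal sizes with prime factors of 11 or more on cuFFT versions 10300-10400, a workaround carried over from the CUDA backend; its removal here implies it has no counterpart on XPU, so the plan cache is now consulted unconditionally under its mutex. A minimal sketch of the resulting lookup pattern, with hypothetical Plan/PlanCache types and an integer key standing in for FFTConfig/FFTConfigCache/FFTConfigKey:

#include <cstdint>
#include <map>
#include <mutex>

struct Plan {                      // hypothetical plan type
  int64_t workspace_size = 1024;
};

struct PlanCache {
  std::mutex mutex;
  std::map<int64_t, Plan> plans;   // hypothetical integer key
  Plan& lookup(int64_t key) { return plans[key]; }  // creates on miss
};

int64_t plan_workspace(PlanCache& cache, int64_t key) {
  // Hold the cache mutex for the whole use of the cached plan, matching
  // the function-scope std::lock_guard in the diff above.
  std::lock_guard<std::mutex> guard(cache.mutex);
  return cache.lookup(key).workspace_size;
}

int main() {
  PlanCache cache;
  return plan_workspace(cache, 42) > 0 ? 0 : 1;
}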
3 changes: 3 additions & 0 deletions paddle/phi/kernels/impl/fill_kernel_impl.h
@@ -36,6 +36,9 @@ void FillKernel(const Context& dev_ctx,
                         " but received NaN"));
 
   dev_ctx.template Alloc<T>(out);
+  if (out->numel() == 0) {
+    return;
+  }
 
   phi::funcs::SetConstant<Context, T> functor;
   functor(dev_ctx, out, value.to<T>());
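
Note on the added guard: once the output has been allocated, a tensor with numel() == 0 has nothing to fill, so returning early skips the SetConstant launch on an empty buffer. A standalone sketch of the same pattern with a hypothetical helper, not the Paddle kernel itself:

#include <vector>

// Hypothetical fill helper mirroring the numel() == 0 early return.
template <typename T>
void fill(std::vector<T>* out, T value) {
  if (out->empty()) return;  // zero elements: nothing to write
  for (auto& x : *out) x = value;
}

int main() {
  std::vector<int> empty_out;
  fill(&empty_out, 7);  // safe no-op on a zero-sized output
  return 0;
}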