50 changes: 25 additions & 25 deletions paddle/phi/kernels/funcs/elementwise_base.h
@@ -202,14 +202,14 @@ class TransformFunctor {
TransformFunctor(const DenseTensor &x,
const DenseTensor &y,
DenseTensor *z,
const DeviceContext &ctx,
const DeviceContext &dev_ctx,
Functor func,
const bool is_xsize_larger = true)
: x_(x.data<T>()),
y_(y.data<T>()),
z_(ctx.template Alloc<OutType>(z)),
z_(dev_ctx.template Alloc<OutType>(z)),
nx_(x.numel()),
ctx_(ctx),
dev_ctx_(dev_ctx),
func_(func),
is_xsize_larger_(is_xsize_larger) {
if (is_xsize_larger_ == false) {
@@ -219,20 +219,20 @@ class TransformFunctor {

inline void Run() const {
phi::Transform<DeviceContext> trans;
trans(ctx_, x_, x_ + nx_, y_, z_, func_);
trans(dev_ctx_, x_, x_ + nx_, y_, z_, func_);
}

inline void RunRowWise(int n) const {
phi::Transform<DeviceContext> trans;
if (is_xsize_larger_) {
trans(ctx_,
trans(dev_ctx_,
x_,
x_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(y_, n),
z_,
func_);
} else {
trans(ctx_,
trans(dev_ctx_,
y_,
y_ + nx_,
RowwiseTransformIterator<T, DeviceContext>(x_, n),
@@ -244,14 +244,14 @@ class TransformFunctor {
inline void RunMidWise(int n, int post) const {
phi::Transform<DeviceContext> trans;
if (is_xsize_larger_) {
trans(ctx_,
trans(dev_ctx_,
x_,
x_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(y_, n, post),
z_,
func_);
} else {
trans(ctx_,
trans(dev_ctx_,
y_,
y_ + nx_,
MidWiseTransformIterator<T, DeviceContext>(x_, n, post),
@@ -265,7 +265,7 @@ class TransformFunctor {
const T *y_;
OutType *z_;
int64_t nx_;
const DeviceContext &ctx_;
const DeviceContext &dev_ctx_;
Functor func_;
bool is_xsize_larger_;
};
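Note: the hunks above are a pure rename of the device-context parameter and member (ctx/ctx_ to dev_ctx/dev_ctx_); no behavior changes. As a minimal, hypothetical sketch of the pattern this class follows — hold the context by const reference and funnel the work through it — the following standalone C++ uses illustrative stand-ins (ToyContext, ScaleFunctor), not Paddle types:

#include <cstdint>

struct ToyContext {
  // A real DeviceContext owns streams/allocators; this stand-in is empty.
};

template <typename T>
class ScaleFunctor {
 public:
  ScaleFunctor(const ToyContext &dev_ctx, const T *x, T *z, int64_t n, T k)
      : dev_ctx_(dev_ctx), x_(x), z_(z), nx_(n), k_(k) {}

  void Run() const {
    // dev_ctx_ would normally be handed to a Transform/launch helper.
    for (int64_t i = 0; i < nx_; ++i) z_[i] = x_[i] * k_;
  }

 private:
  const ToyContext &dev_ctx_;  // held by const reference, mirroring ctx_ -> dev_ctx_
  const T *x_;
  T *z_;
  int64_t nx_;
  T k_;
};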
@@ -278,17 +278,17 @@ void CommonForwardBroadcastCPU(const DenseTensor &x,
int64_t *y_dims_array,
int64_t *out_dims_array,
int max_dim,
const CPUContext &ctx,
const CPUContext &dev_ctx,
Functor func,
const bool is_xsize_larger = true) {
std::vector<int64_t> index_array(max_dim, 0);
const T *x_data = x.data<T>();
const T *y_data = y.data<T>();
if (z && z->numel() == 0) {
ctx.Alloc<OutType>(z);
dev_ctx.Alloc<OutType>(z);
return;
}
OutType *out_data = ctx.Alloc<OutType>(z);
OutType *out_data = dev_ctx.Alloc<OutType>(z);

const int64_t out_size = std::accumulate(out_dims_array,
out_dims_array + max_dim,
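Note: the std::accumulate call is cut off at the end of this hunk; its role is to reduce out_dims_array into the flattened output element count. A self-contained sketch of that reduction, assuming an initial value of 1 and a std::multiplies<> functor (the original call's trailing arguments are not visible here):

#include <cstdint>
#include <functional>
#include <numeric>

// Flattened element count of a shape given as an int64_t dims array.
// The init value and multiplies<> reduction are assumptions; the original
// call is truncated in the diff above.
int64_t NumElements(const int64_t *dims, int ndim) {
  return std::accumulate(dims, dims + ndim, static_cast<int64_t>(1),
                         std::multiplies<int64_t>());
}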
@@ -731,7 +731,7 @@ __global__ void VectorizedElementwiseKernel(
}

template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
void LaunchElementwiseKernel(const KPDevice &ctx,
void LaunchElementwiseKernel(const KPDevice &dev_ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
@@ -754,18 +754,18 @@ void LaunchElementwiseKernel(const KPDevice &ctx,
int block_size = 64;
int grid_size = 8;
int read_lens = kps::details::GetXpuReadLens(numel, block_size, grid_size);
auto stream = ctx.x_context()->xpu_stream;
auto stream = dev_ctx.x_context()->xpu_stream;
int64_t main_offset =
(numel / (read_lens * block_size)) * read_lens * block_size;
VectorizedElementwiseKernel<OutT, Functor, Arity, NumOuts, VecSize>
<<<grid_size, block_size, 0, stream>>>(
ins_data, outs_data, numel, main_offset, read_lens, func);
#else
auto gpu_config =
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, VecSize);
int64_t main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) *
VecSize * gpu_config.GetBlockSize();
auto stream = ctx.stream();
auto stream = dev_ctx.stream();
VectorizedElementwiseKernel<OutT, Functor, Arity, NumOuts, VecSize>
<<<gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(
ins_data, outs_data, numel, main_offset, VecSize, func);
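Note: in the GPU branch above, main_offset is the largest multiple of VecSize * blockSize that does not exceed numel; indices below it are processed in fully vectorized chunks, and the remaining tail is handled inside VectorizedElementwiseKernel, which this hunk does not show. A runnable sketch of just the arithmetic, with made-up values:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t numel = 1000003;  // hypothetical element count
  const int vec_size = 4;         // vector width chosen by the dispatcher
  const int block_size = 256;     // hypothetical threads per block
  const int64_t chunk = static_cast<int64_t>(vec_size) * block_size;
  const int64_t main_offset = (numel / chunk) * chunk;  // largest full-chunk prefix
  std::printf("main_offset=%lld, tail=%lld\n",
              static_cast<long long>(main_offset),
              static_cast<long long>(numel - main_offset));
  return 0;
}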
@@ -775,18 +775,18 @@ template <typename OutT, typename Functor, int Arity, int NumOuts = 1>
template <typename OutT, typename Functor, int Arity, int NumOuts = 1>
typename std::enable_if<!NeedVectorized<OutT>::value, void>::type
ElementwiseKernelForDifferentVecSize(
const KPDevice &ctx,
const KPDevice &dev_ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
LaunchElementwiseKernel<OutT, Functor, Arity, NumOuts, VecSizeS>(
ctx, ins, outs, func);
dev_ctx, ins, outs, func);
}

template <typename OutT, typename Functor, int Arity, int NumOuts = 1>
typename std::enable_if<NeedVectorized<OutT>::value, void>::type
ElementwiseKernelForDifferentVecSize(
const KPDevice &ctx,
const KPDevice &dev_ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
@@ -795,15 +795,15 @@ ElementwiseKernelForDifferentVecSize(
switch (vec_size) {
case VecSizeL:
LaunchElementwiseKernel<OutT, Functor, Arity, NumOuts, VecSizeL>(
ctx, ins, outs, func);
dev_ctx, ins, outs, func);
break;
case VecSizeM:
LaunchElementwiseKernel<OutT, Functor, Arity, NumOuts, VecSizeM>(
ctx, ins, outs, func);
dev_ctx, ins, outs, func);
break;
case VecSizeS:
LaunchElementwiseKernel<OutT, Functor, Arity, NumOuts, VecSizeS>(
ctx, ins, outs, func);
dev_ctx, ins, outs, func);
break;
default: {
PADDLE_THROW(common::errors::Unimplemented(
@@ -814,7 +814,7 @@
}
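Note: the switch in the previous hunk turns a runtime vector size into one of a fixed set of compile-time instantiations, since the kernel's vector width must be a template parameter. A generic sketch of that dispatch pattern (Launch<> and the literal sizes are stand-ins, not the VecSizeL/M/S constants):

#include <cstdint>
#include <cstdio>
#include <stdexcept>

template <int VecSize>
void Launch(int64_t numel) {
  // Stand-in for the templated kernel launch; VecSize is fixed at compile time.
  std::printf("launching with VecSize=%d over %lld elements\n", VecSize,
              static_cast<long long>(numel));
}

void Dispatch(int vec_size, int64_t numel) {
  switch (vec_size) {
    case 4: Launch<4>(numel); break;
    case 2: Launch<2>(numel); break;
    case 1: Launch<1>(numel); break;
    default: throw std::runtime_error("unsupported vector size");
  }
}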

template <typename OutT, typename Functor, int NumOuts = 1>
void ElementwiseKernel(const KPDevice &ctx,
void ElementwiseKernel(const KPDevice &dev_ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
@@ -850,14 +850,14 @@ void ElementwiseKernel(const KPDevice &ctx,
"but %dth output tensor`s shape is not.",
i));
}
ctx.template Alloc<OutT>((*outs)[i]);
dev_ctx.template Alloc<OutT>((*outs)[i]);
}
if (have_0_size) {
return;
}

ElementwiseKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
ctx, ins, outs, func);
dev_ctx, ins, outs, func);
}

#endif
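Note: ElementwiseKernel above allocates every output through dev_ctx and only then returns early when an output has zero elements, apparently so callers always receive allocated tensors even when there is nothing to compute. A toy illustration of that ordering (the Tensor type is a stand-in, not phi::DenseTensor):

#include <cstdint>
#include <vector>

struct Tensor {
  std::vector<float> data;
  int64_t numel() const { return static_cast<int64_t>(data.size()); }
  void Alloc(int64_t n) { data.assign(static_cast<size_t>(n), 0.0f); }
};

void AddKernel(const Tensor &x, const Tensor &y, Tensor *out) {
  out->Alloc(x.numel());          // allocate unconditionally, as above
  if (out->numel() == 0) return;  // skip the compute for empty outputs
  for (int64_t i = 0; i < out->numel(); ++i)
    out->data[i] = x.data[i] + y.data[i];
}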