17 changes: 16 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -16,11 +16,26 @@ limitations under the License. */

#include <vector>
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
Contributor:

Is it appropriate to put this sub function inside div?

Contributor (Author):

In the old code, div referenced this function from the sub_op header. After the migration, the function has been removed from the sub_op header, and the pten version was modified slightly to fit the pten style. So, to avoid changing div, sub's function was copied into div. This is a temporary arrangement; it will be deleted when div itself is migrated later.

void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseSubFunctor<T>(), z);
}
}

template <typename DeviceContext, typename T>
void default_elementwise_div(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
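For context on the rank check in the function above: when x has at least as many dimensions as y, y is broadcast against x and a plain x - y functor applies; when y has the higher rank, the broadcast machinery swaps the operands, so an inverse functor is needed to still produce x - y. A minimal sketch of the two functors as assumed here; the real definitions live in Paddle's elementwise functor headers:

#ifndef HOSTDEVICE
#define HOSTDEVICE  // host-only fallback; Paddle defines this macro
#endif

// Sketch only: assumed shape of the functors dispatched above.
template <typename T>
struct SubFunctor {
  // x is the higher-rank operand: out = a - b.
  inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};

template <typename T>
struct InverseSubFunctor {
  // Operands arrive swapped as (y, x), so subtract in reverse order
  // to keep the result equal to x - y.
  inline HOSTDEVICE T operator()(T a, T b) const { return b - a; }
};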
97 changes: 0 additions & 97 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.cu
@@ -17,103 +17,6 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;

namespace paddle {
namespace operators {

template <typename T>
static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
int64_t size, T* dx,
T* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;

while (col < size) {
if (dx != nullptr) {
dx[col] = dout[col];
}
dy[col] = -dout[col];
col += blockDim.x * gridDim.x;
}
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
auto* dout_data = dout->data<T>();
// dx
if (dx != nullptr) {
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
if (dx->dims() == dout->dims()) {
if (dx_data != dout_data) {
framework::TensorCopy(
*dout, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), dx);
}
} else {
// For inplace strategy, dx will be stored in addr of dout, which makes
// the result of dy wrong.
if (dx->IsSharedBufferWith(*dout)) {
dx->clear();
dx->mutable_data<T>(x->dims(), ctx.GetPlace());
}
std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
*dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
}
}
// dy
if (dy != nullptr) {
auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
if (dy->dims() == dout->dims()) {
if (dy_data != dout_data) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = dy->numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<T><<<
grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, nullptr,
dy->mutable_data<T>(ctx.GetPlace()));
}
} else {
std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
gpuStream_t stream = ctx.cuda_device_context().stream();
TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
*dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
}
}
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
auto size = x->numel();
dim3 grid_size =
dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
SimpleElemwiseSubGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
}

} // namespace operators
} // namespace paddle

REGISTER_OP_CUDA_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, float>,
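The removed CUDA code above implements the subtraction gradient directly: for out = x - y, dx = dout and dy = -dout, with a sum-reduction over the broadcast dimensions when an operand was broadcast (negated via kps::InverseFunctor for dy). A plain C++ reference for the same-shape case, as a hedged sketch rather than Paddle API:

#include <cstddef>
#include <vector>

// Reference semantics of SimpleElemwiseSubGradCUDAKernel when all
// tensors share one shape: dx = dout, dy = -dout, element by element.
void sub_grad_reference(const std::vector<float>& dout,
                        std::vector<float>* dx,
                        std::vector<float>* dy) {
  for (std::size_t i = 0; i < dout.size(); ++i) {
    if (dx != nullptr) (*dx)[i] = dout[i];  // d(x - y)/dx = 1
    (*dy)[i] = -dout[i];                    // d(x - y)/dy = -1
  }
}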
120 changes: 22 additions & 98 deletions paddle/fluid/operators/elementwise/elementwise_sub_op.h
@@ -17,26 +17,11 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/place.h"

#include "paddle/pten/kernels/elementwise_grad_kernel.h"
#include "paddle/pten/kernels/math_kernel.h"
namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
void default_elementwise_sub(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y, framework::Tensor* z) {
int axis = ctx.Attr<int>("axis");
auto x_dims = x->dims();
auto y_dims = y->dims();
if (x_dims.size() >= y_dims.size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
SubFunctor<T>(), z);
} else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
ctx, x, y, axis, InverseSubFunctor<T>(), z);
}
}

template <typename DeviceContext, typename T>
class ElementwiseSubKernel : public framework::OpKernel<T> {
public:
@@ -48,76 +33,13 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {

auto& dev_ctx = ctx.device_context<DeviceContext>();
int axis = ctx.Attr<int>("axis");
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
pten::SubtractRawKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*pt_x.get(), *pt_y.get(), axis, pt_z.get());
*x, *y, axis, z);
}
};

template <typename T>
struct SubGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};

template <typename T>
struct SubGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy) {
int axis = ctx.Attr<int>("axis");
const auto& dev_ctx =
ctx.template device_context<platform::CPUDeviceContext>();
pten::ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
dev_ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(),
SubGradDY<T>());
}

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// cuda definition
template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x,
const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout,
framework::Tensor* dx, framework::Tensor* dy);

template <typename DeviceContext, typename T>
typename std::enable_if<
std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
elementwise_sub_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy);
#endif

template <typename DeviceContext, typename T>
class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
public:
Expand All @@ -130,14 +52,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
// skip out
auto* out = dout;
if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
} else {
default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
dy);
}
int axis = ctx.Attr<int>("axis");
auto& dev_ctx = ctx.device_context<DeviceContext>();

pten::SubtractGradKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*x, *y, *dout, axis, dx, dy);
}
};

@@ -153,18 +74,21 @@ class ElementwiseSubDoubleGradKernel : public framework::OpKernel<T> {
auto* ddy = ctx.Input<Tensor>("DDY");

auto* ddout = ctx.Output<Tensor>("DDOut");
int axis = ctx.Attr<int>("axis");
auto& dev_ctx = ctx.device_context<DeviceContext>();

// DDOut = ddx - ddy
if (ddout) {
Tensor ddx_safe, ddy_safe;
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);

ddout->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &ddx_safe, &ddy_safe, axis, SubFunctor<T>(), ddout);
paddle::optional<const pten::DenseTensor&> ddx_optional = paddle::none;
paddle::optional<const pten::DenseTensor&> ddy_optional = paddle::none;
if (ddx != nullptr) {
ddx_optional = *ddx;
}
if (ddy != nullptr) {
ddy_optional = *ddy;
}
pten::SubtractDoubleGradKernel<T>(
static_cast<const typename framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*y, ddx_optional, ddy_optional, *dout, axis, ddout);
}
};

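A pattern worth noting in the double-grad kernel above is the adaptation from nullable Tensor pointers (the fluid convention) to paddle::optional arguments (the pten convention): a null ddx or ddy becomes paddle::none, otherwise the optional holds a reference. A generic illustration of that mapping, using std::optional as a stand-in because this is a sketch and not Paddle's actual helper:

#include <functional>
#include <optional>

// Sketch: null pointer -> empty optional; otherwise wrap a const
// reference, mirroring what the kernel above does by hand.
template <typename Tensor>
std::optional<std::reference_wrapper<const Tensor>> ToOptionalRef(
    const Tensor* t) {
  if (t == nullptr) return std::nullopt;
  return std::cref(*t);
}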
1 change: 1 addition & 0 deletions paddle/pten/core/kernel_alias_name.h
@@ -25,6 +25,7 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"elementwise_div", "divide_raw"},
{"elementwise_mul", "muliply_raw"},
{"elementwise_sub", "subtract_raw"},
{"elementwise_sub_grad", "subtract_grad"},
{"fill_any_like", "full_like"},
{"fill_constant", "full"},
{"flatten_contiguous_range", "flatten"},
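The alias map above lets the framework translate legacy fluid op names into pten kernel names, and this change registers the backward op so that elementwise_sub_grad resolves to subtract_grad. A hypothetical lookup helper, only to show how such a map would be consulted (ToPtenKernelName is illustrative, not an actual Paddle function):

#include <string>
#include <unordered_map>

// Illustrative only: return the pten alias if one is registered,
// otherwise fall back to the original op name.
std::string ToPtenKernelName(
    const std::unordered_map<std::string, std::string>& alias_map,
    const std::string& fluid_op_name) {
  auto it = alias_map.find(fluid_op_name);
  return it == alias_map.end() ? fluid_op_name : it->second;
}

// e.g. ToPtenKernelName(kernel_alias_name_map, "elementwise_sub_grad")
//      yields "subtract_grad"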
36 changes: 34 additions & 2 deletions paddle/pten/kernels/cpu/elementwise.h
@@ -753,8 +753,11 @@ void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx,
}
}

// Add Grad

/*
******************************
Add Grad
******************************
*/
template <typename T>
struct IdentityGrad {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
@@ -796,4 +799,33 @@ elementwise_add_grad(const CPUContext& ctx,
ctx, x, y, out, dout, axis, dx, dy, IdentityGrad<T>(), IdentityGrad<T>());
}

/*
******************************
Sub Grad
******************************
*/

template <typename T>
struct SubGradDX {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
};

template <typename T>
struct SubGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
};

template <typename T>
void elementwise_sub_grad(const CPUContext& ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dout,
DenseTensor* dx,
DenseTensor* dy,
int axis = -1) {
ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
ctx, x, y, out, dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
}

} // namespace pten
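As a quick sanity check on the SubGradDX and SubGradDY functors added above: for out = x - y, each element passes dout through unchanged to dx and negated to dy. A self-contained numeric check under assumed sample values:

#include <cassert>

// Standalone float copies of the functors above, purely for a local
// check outside the Paddle build.
struct SubGradDXf {
  float operator()(float x, float y, float out, float dout) const {
    return dout;  // d(x - y)/dx = 1
  }
};
struct SubGradDYf {
  float operator()(float x, float y, float out, float dout) const {
    return -dout;  // d(x - y)/dy = -1
  }
};

int main() {
  // Assumed values: x = 5, y = 2, out = 3, dout = 0.5.
  assert(SubGradDXf()(5.0f, 2.0f, 3.0f, 0.5f) == 0.5f);
  assert(SubGradDYf()(5.0f, 2.0f, 3.0f, 0.5f) == -0.5f);
  return 0;
}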