Skip to content

Commit c5e857d

Browse files
authored
elementwise_mul refactor (#37471)
* elementwise_mul refactor * perfect code in test * delete redundant code * fix bugs when run test_multiply * adjust the location of macro * fix bugs when run ci
1 parent 0f24de8 commit c5e857d

File tree

16 files changed

+395
-123
lines changed

16 files changed

+395
-123
lines changed

paddle/fluid/operators/elementwise/elementwise_mul_op.cu

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ limitations under the License. */
1717
#include "paddle/fluid/platform/complex.h"
1818
#include "paddle/fluid/platform/float16.h"
1919

20+
// can only include the headers in paddle/pten/include dirs
21+
#include "paddle/pten/api/lib/utils/tensor_utils.h"
22+
#include "paddle/pten/include/core.h"
23+
#include "paddle/pten/include/math.h"
2024
namespace ops = paddle::operators;
2125
namespace plat = paddle::platform;
2226

@@ -28,15 +32,39 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
2832
: public framework::OpKernel<T> {
2933
public:
3034
void Compute(const framework::ExecutionContext& ctx) const override {
31-
framework::Tensor x_for_selectedrows;
32-
std::vector<const framework::Tensor*> ins;
33-
std::vector<framework::Tensor*> outs;
35+
auto x_var = ctx.InputVar("X");
36+
PADDLE_ENFORCE_EQ(x_var != nullptr, true,
37+
platform::errors::InvalidArgument(
38+
"Cannot get input Variable X, Variable name = %s.",
39+
ctx.InputName("X")));
3440
const auto& cuda_ctx =
3541
ctx.template device_context<platform::CUDADeviceContext>();
36-
37-
int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
38-
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
39-
cuda_ctx, ins, &outs, axis, MulFunctor<T>());
42+
if (x_var->IsType<framework::SelectedRows>()) {
43+
framework::Tensor x_for_selectedrows;
44+
std::vector<const framework::Tensor*> ins;
45+
std::vector<framework::Tensor*> outs;
46+
int axis =
47+
PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
48+
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
49+
cuda_ctx, ins, &outs, axis, MulFunctor<T>());
50+
} else if (x_var->IsType<framework::LoDTensor>()) {
51+
auto* x_lod = ctx.Input<framework::LoDTensor>("X");
52+
auto* y_lod = ctx.Input<framework::LoDTensor>("Y");
53+
auto* z_lod = ctx.Output<framework::LoDTensor>("Out");
54+
z_lod->mutable_data<T>(ctx.GetPlace());
55+
56+
int axis = ctx.Attr<int>("axis");
57+
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
58+
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y_lod);
59+
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
60+
pten::ElementwiseMul<T>(cuda_ctx, *pt_x.get(), *pt_y.get(), axis,
61+
pt_z.get());
62+
} else {
63+
PADDLE_THROW(platform::errors::InvalidArgument(
64+
"X's type[%s] is not supported by elementwise_op. X's type should be "
65+
"LoDTensor or SelectedRows.",
66+
framework::ToTypeName(x_var->Type())));
67+
}
4068
}
4169
};
4270

paddle/fluid/operators/elementwise/elementwise_mul_op.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,16 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include <string>
18+
#include "paddle/fluid/framework/pten_utils.h"
1819
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
1920
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
2021
#include "paddle/fluid/operators/math/blas.h"
2122
#include "paddle/fluid/platform/cpu_info.h"
2223

24+
// can only include the headers in paddle/pten/include dirs
25+
#include "paddle/pten/api/lib/utils/tensor_utils.h"
26+
#include "paddle/pten/include/core.h"
27+
#include "paddle/pten/include/math.h"
2328
namespace paddle {
2429
namespace operators {
2530

@@ -106,24 +111,32 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
106111
out_sele->mutable_value()->Resize(x_sele.value().dims());
107112
out_sele->mutable_value()->mutable_data(ctx.GetPlace(), x.type());
108113
z = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
114+
z->mutable_data<T>(ctx.GetPlace());
115+
auto dims_equal = x.dims() == y->dims();
116+
if (dims_equal) {
117+
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
118+
same_dims_mul(ctx, &x, y, z);
119+
} else {
120+
default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
121+
}
109122
} else if (x_var->IsType<framework::LoDTensor>()) {
110-
x = x_var->Get<framework::LoDTensor>();
111-
z = ctx.Output<framework::LoDTensor>("Out");
123+
auto* x_lod = ctx.Input<framework::LoDTensor>("X");
124+
auto* z_lod = ctx.Output<framework::LoDTensor>("Out");
125+
z_lod->mutable_data<T>(ctx.GetPlace());
126+
127+
auto& dev_ctx = ctx.device_context<DeviceContext>();
128+
int axis = ctx.Attr<int>("axis");
129+
auto pt_x = paddle::experimental::MakePtenDenseTensor(*x_lod);
130+
auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
131+
auto pt_z = paddle::experimental::MakePtenDenseTensor(*z_lod);
132+
pten::ElementwiseMul<T>(dev_ctx, *pt_x.get(), *pt_y.get(), axis,
133+
pt_z.get());
112134
} else {
113135
PADDLE_THROW(platform::errors::InvalidArgument(
114136
"X's type[%s] is not supported by elementwise_op. X's type should be "
115137
"LoDTensor or SelectedRows.",
116138
framework::ToTypeName(x_var->Type())));
117139
}
118-
119-
z->mutable_data<T>(ctx.GetPlace());
120-
auto dims_equal = x.dims() == y->dims();
121-
if (dims_equal) {
122-
SameDimsElemwiseMul<DeviceContext, T> same_dims_mul;
123-
same_dims_mul(ctx, &x, y, z);
124-
} else {
125-
default_elementwise_mul<DeviceContext, T>(ctx, &x, y, z);
126-
}
127140
}
128141
};
129142
template <typename T>

paddle/fluid/operators/elementwise/elementwise_op.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
160160
{"axis"}, {"Out"});
161161
}
162162
}
163+
if (Type() == "elementwise_mul") {
164+
if (ctx.InputVar("X")->IsType<framework::LoDTensor>()) {
165+
return framework::KernelSignature("elementwise_mul", {"X", "Y"},
166+
{"axis"}, {"Out"});
167+
}
168+
}
163169
return framework::KernelSignature("None", {"X"}, {}, {"Out"});
164170
}
165171
};

paddle/pten/api/include/math.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,8 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y);
2828
PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y);
2929

3030
PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y);
31+
32+
PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y);
33+
3134
} // namespace experimental
3235
} // namespace paddle

paddle/pten/api/lib/math.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,41 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) {
172172

173173
return out;
174174
}
175+
176+
PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y) {
177+
// 1. Get kernel signature and kernel
178+
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
179+
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
180+
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
181+
"elementwise_mul", kernel_key);
182+
183+
// 2. Get Device Context
184+
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
185+
auto kernel_context = pten::KernelContext(dev_ctx);
186+
187+
// 3. Auto data transform
188+
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
189+
kernel_context.EmplaceBackInput(dense_x);
190+
auto dense_y = std::dynamic_pointer_cast<pten::DenseTensor>(y.impl());
191+
kernel_context.EmplaceBackInput(dense_y);
192+
kernel_context.EmplaceBackAttr(-1);
193+
194+
// 4. InferShape
195+
auto out_meta = ElementwiseInferShape(dense_x->meta(), dense_y->meta(), -1);
196+
197+
// 5. Prepare outputs
198+
Tensor out;
199+
const auto allocator = std::make_shared<DefaultAllocator>(
200+
pten::TransToFluidPlace(kernel_key.backend()));
201+
auto dense_out = std::make_shared<pten::DenseTensor>(allocator, out_meta);
202+
kernel_context.EmplaceBackOutput(dense_out);
203+
out.set_impl(dense_out);
204+
205+
// 6. Call kernel
206+
kernel(&kernel_context);
207+
208+
return out;
209+
}
175210
} // namespace experimental
176211
} // namespace paddle
177212

paddle/pten/api/lib/utils/tensor_utils.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,14 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
234234
const pten::TensorArgDef& arg_def,
235235
pten::DenseTensor* dst) {
236236
auto expected_place = pten::TransToFluidPlace(arg_def.backend);
237-
238237
if (variable.IsType<framework::LoDTensor>()) {
239238
const auto& tensor = variable.Get<framework::LoDTensor>();
239+
// check input dtype before ReMakePtenDenseTensor
240+
PADDLE_ENFORCE(
241+
(arg_def.dtype == pten::TransToPtenDataType(tensor.type())),
242+
paddle::platform::errors::InvalidArgument(
243+
"The type of input data is different from the type of the "
244+
"argument's definition in kernel."));
240245
if (!platform::is_same_place(tensor.place(), expected_place)) {
241246
framework::LoDTensor tmp_tensor;
242247
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
@@ -248,6 +253,11 @@ void ReMakePtenDenseTensorFromVar(const framework::Variable& variable,
248253
// TODO(chenweihang): now we don't deal with row and height
249254
// by xiaowei's advice
250255
const auto& tensor = variable.Get<framework::SelectedRows>();
256+
PADDLE_ENFORCE(
257+
(arg_def.dtype == pten::TransToPtenDataType(tensor.value().type())),
258+
paddle::platform::errors::InvalidArgument(
259+
"The type of input data is different from the type of the "
260+
"argument's definition in kernel."));
251261
if (!platform::is_same_place(tensor.value().place(), expected_place)) {
252262
framework::Tensor tmp_tensor;
253263
TensorCopySync(tensor.value(), expected_place, &tmp_tensor);

paddle/pten/include/math.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,18 @@ DenseTensor Divide(const ContextT& dev_ctx,
115115
ElementwiseDiv<T>(dev_ctx, x, y, axis, &dense_out);
116116
return dense_out;
117117
}
118+
119+
template <typename T, typename ContextT>
120+
DenseTensor Multiply(const ContextT& dev_ctx,
121+
const DenseTensor& x,
122+
const DenseTensor& y,
123+
int axis) {
124+
auto out_meta = ElementwiseInferShape(x.meta(), y.meta(), axis);
125+
const auto allocator =
126+
std::make_shared<paddle::experimental::DefaultAllocator>(
127+
dev_ctx.GetPlace());
128+
pten::DenseTensor dense_out(allocator, out_meta);
129+
ElementwiseMul<T>(dev_ctx, x, y, axis, &dense_out);
130+
return dense_out;
131+
}
118132
} // namespace pten

paddle/pten/kernels/cpu/math.cc

Lines changed: 20 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -64,56 +64,6 @@ void ScaleHost(const CPUContext& dev_ctx,
6464
out);
6565
}
6666

67-
template <typename T>
68-
void ElementwiseAdd(const CPUContext& dev_ctx,
69-
const DenseTensor& x,
70-
const DenseTensor& y,
71-
int axis,
72-
DenseTensor* out) {
73-
// allocate memory for out
74-
out->mutable_data<T>();
75-
76-
if (x.dims() == y.dims()) {
77-
SameDimsElementwiseCompute<general::SameDimsAddFunctor<CPUContext, T>>()(
78-
dev_ctx, x, y, out);
79-
} else {
80-
auto x_dims = x.dims();
81-
auto y_dims = y.dims();
82-
if (x_dims.size() >= y_dims.size()) {
83-
ElementwiseCompute<general::AddFunctor<T>, T>(
84-
dev_ctx, x, y, axis, general::AddFunctor<T>(), out);
85-
} else {
86-
ElementwiseCompute<general::InverseAddFunctor<T>, T>(
87-
dev_ctx, x, y, axis, general::InverseAddFunctor<T>(), out);
88-
}
89-
}
90-
}
91-
92-
template <typename T>
93-
void ElementwiseSub(const CPUContext& dev_ctx,
94-
const DenseTensor& x,
95-
const DenseTensor& y,
96-
int axis,
97-
DenseTensor* out) {
98-
// allocate memory for out
99-
out->mutable_data<T>();
100-
101-
if (x.dims() == y.dims()) {
102-
SameDimsElementwiseCompute<general::SameDimsSubFunctor<CPUContext, T>>()(
103-
dev_ctx, x, y, out);
104-
} else {
105-
auto x_dims = x.dims();
106-
auto y_dims = y.dims();
107-
if (x_dims.size() >= y_dims.size()) {
108-
ElementwiseCompute<general::SubFunctor<T>, T>(
109-
dev_ctx, x, y, axis, general::SubFunctor<T>(), out);
110-
} else {
111-
ElementwiseCompute<general::InverseSubFunctor<T>, T>(
112-
dev_ctx, x, y, axis, general::InverseSubFunctor<T>(), out);
113-
}
114-
}
115-
}
116-
11767
template <typename T>
11868
void ElementwiseDiv(const CPUContext& dev_ctx,
11969
const DenseTensor& x,
@@ -138,6 +88,15 @@ void ElementwiseDiv(const CPUContext& dev_ctx,
13888
}
13989
}
14090

91+
// Create the definition of ElementwiseAdd
92+
DEFINE_CPU_ELEMENTWISE_OP(Add)
93+
94+
// Create the definition of ElementwiseSub
95+
DEFINE_CPU_ELEMENTWISE_OP(Sub)
96+
97+
// Create the definition of ElementwiseMul
98+
DEFINE_CPU_ELEMENTWISE_OP(Mul)
99+
141100
} // namespace pten
142101

143102
// TODO(chenweihang): replace by better impl
@@ -208,3 +167,14 @@ PT_REGISTER_KERNEL("elementwise_div",
208167
int64_t,
209168
complex64,
210169
complex128) {}
170+
PT_REGISTER_KERNEL("elementwise_mul",
171+
CPU,
172+
ANY,
173+
pten::ElementwiseMul,
174+
float,
175+
double,
176+
int,
177+
int64_t,
178+
bool,
179+
complex64,
180+
complex128) {}

paddle/pten/kernels/cpu/math.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,36 @@ void ElementwiseDiv(const CPUContext& dev_ctx,
6666
const DenseTensor& y,
6767
int axis,
6868
DenseTensor* out);
69+
70+
template <typename T>
71+
void ElementwiseMul(const CPUContext& dev_ctx,
72+
const DenseTensor& x,
73+
const DenseTensor& y,
74+
int axis,
75+
DenseTensor* out);
6976
} // namespace pten
77+
78+
#define DEFINE_CPU_ELEMENTWISE_OP(name) \
79+
template <typename T> \
80+
void Elementwise##name(const CPUContext& dev_ctx, \
81+
const DenseTensor& x, \
82+
const DenseTensor& y, \
83+
int axis, \
84+
DenseTensor* out) { \
85+
out->mutable_data<T>(); \
86+
if (x.dims() == y.dims()) { \
87+
SameDimsElementwiseCompute< \
88+
general::SameDims##name##Functor<CPUContext, T>>()( \
89+
dev_ctx, x, y, out); \
90+
} else { \
91+
auto x_dims = x.dims(); \
92+
auto y_dims = y.dims(); \
93+
if (x_dims.size() >= y_dims.size()) { \
94+
ElementwiseCompute<general::name##Functor<T>, T>( \
95+
dev_ctx, x, y, axis, general::name##Functor<T>(), out); \
96+
} else { \
97+
ElementwiseCompute<general::Inverse##name##Functor<T>, T>( \
98+
dev_ctx, x, y, axis, general::Inverse##name##Functor<T>(), out); \
99+
} \
100+
} \
101+
}

0 commit comments

Comments
 (0)