@@ -17,6 +17,10 @@ limitations under the License. */
1717#include " paddle/fluid/platform/complex.h"
1818#include " paddle/fluid/platform/float16.h"
1919
20+ // only can include the headers in paddle/top/api dirs
21+ #include " paddle/pten/api/lib/utils/tensor_utils.h"
22+ #include " paddle/pten/include/core.h"
23+ #include " paddle/pten/include/math.h"
2024namespace ops = paddle::operators;
2125namespace plat = paddle::platform;
2226
@@ -28,15 +32,39 @@ class ElementwiseMulKernel<platform::CUDADeviceContext, T>
2832 : public framework::OpKernel<T> {
2933 public:
3034 void Compute (const framework::ExecutionContext& ctx) const override {
31- framework::Tensor x_for_selectedrows;
32- std::vector<const framework::Tensor*> ins;
33- std::vector<framework::Tensor*> outs;
35+ auto x_var = ctx.InputVar (" X" );
36+ PADDLE_ENFORCE_EQ (x_var != nullptr , true ,
37+ platform::errors::InvalidArgument (
38+ " Cannot get input Variable X, Variable name = %s." ,
39+ ctx.InputName (" X" )));
3440 const auto & cuda_ctx =
3541 ctx.template device_context <platform::CUDADeviceContext>();
36-
37- int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
38- LaunchElementwiseCudaKernel<ElementwiseType::kBinary , T, T>(
39- cuda_ctx, ins, &outs, axis, MulFunctor<T>());
42+ if (x_var->IsType <framework::SelectedRows>()) {
43+ framework::Tensor x_for_selectedrows;
44+ std::vector<const framework::Tensor*> ins;
45+ std::vector<framework::Tensor*> outs;
46+ int axis =
47+ PackTensorsIntoVector<T>(ctx, &ins, &outs, &x_for_selectedrows);
48+ LaunchElementwiseCudaKernel<ElementwiseType::kBinary , T, T>(
49+ cuda_ctx, ins, &outs, axis, MulFunctor<T>());
50+ } else if (x_var->IsType <framework::LoDTensor>()) {
51+ auto * x_lod = ctx.Input <framework::LoDTensor>(" X" );
52+ auto * y_lod = ctx.Input <framework::LoDTensor>(" Y" );
53+ auto * z_lod = ctx.Output <framework::LoDTensor>(" Out" );
54+ z_lod->mutable_data <T>(ctx.GetPlace ());
55+
56+ int axis = ctx.Attr <int >(" axis" );
57+ auto pt_x = paddle::experimental::MakePtenDenseTensor (*x_lod);
58+ auto pt_y = paddle::experimental::MakePtenDenseTensor (*y_lod);
59+ auto pt_z = paddle::experimental::MakePtenDenseTensor (*z_lod);
60+ pten::ElementwiseMul<T>(cuda_ctx, *pt_x.get (), *pt_y.get (), axis,
61+ pt_z.get ());
62+ } else {
63+ PADDLE_THROW (platform::errors::InvalidArgument (
64+ " X's type[%s] is not supported by elementwise_op. X's type should be "
65+ " LoDTensor or SelectedRows." ,
66+ framework::ToTypeName (x_var->Type ())));
67+ }
4068 }
4169};
4270
0 commit comments