@@ -16,8 +16,11 @@ limitations under the License. */
1616#include < string>
1717#include < unordered_map>
1818
19- #include " paddle/fluid/operators/erf_op.h"
19+ #include " paddle/fluid/framework/infershape_utils.h"
20+ #include " paddle/fluid/framework/op_registry.h"
2021#include " paddle/fluid/platform/float16.h"
22+ #include " paddle/phi/core/infermeta_utils.h"
23+ #include " paddle/phi/infermeta/unary.h"
2124
2225namespace paddle {
2326namespace operators {
@@ -29,18 +32,6 @@ class ErfOp : public framework::OperatorWithKernel {
2932 const framework::AttributeMap &attrs)
3033 : OperatorWithKernel(type, inputs, outputs, attrs) {}
3134
32- void InferShape (framework::InferShapeContext *ctx) const override {
33- PADDLE_ENFORCE_EQ (ctx->HasInput (" X" ), true ,
34- platform::errors::InvalidArgument (
35- " Input(%s) of ErfOp should not be null." , " X" ));
36- PADDLE_ENFORCE_EQ (ctx->HasOutput (" Out" ), true ,
37- platform::errors::InvalidArgument (
38- " Output(%s) of ErfOp should not be null." , " Out" ));
39-
40- ctx->ShareDim (" X" , /* ->*/ " Out" );
41- ctx->ShareLoD (" X" , /* ->*/ " Out" );
42- }
43-
4435 protected:
4536 framework::OpKernelType GetExpectedKernelType (
4637 const framework::ExecutionContext &ctx) const override {
@@ -116,28 +107,10 @@ class ErfGradOpMaker : public framework::SingleGradOpMaker<T> {
116107
117108namespace ops = paddle::operators;
118109
110+ DECLARE_INFER_SHAPE_FUNCTOR (erf, ErfInferShapeFunctor,
111+ PD_INFER_META (phi::ErfInferMeta));
119112REGISTER_OPERATOR (erf, ops::ErfOp, ops::ErfOpMaker,
120113 ops::ErfGradOpMaker<paddle::framework::OpDesc>,
121- ops::ErfGradOpMaker<paddle::imperative::OpBase>);
114+ ops::ErfGradOpMaker<paddle::imperative::OpBase>,
115+ ErfInferShapeFunctor);
122116REGISTER_OPERATOR (erf_grad, ops::ErfGradOp);
123- REGISTER_OP_CPU_KERNEL (
124- erf, ops::ErfKernel<paddle::platform::CPUDeviceContext, float >,
125- ops::ErfKernel<paddle::platform::CPUDeviceContext, double >,
126- ops::ErfKernel<paddle::platform::CPUDeviceContext,
127- paddle::platform::float16>);
128- REGISTER_OP_CPU_KERNEL (
129- erf_grad, ops::ErfGradKernel<paddle::platform::CPUDeviceContext, float >,
130- ops::ErfGradKernel<paddle::platform::CPUDeviceContext, double >,
131- ops::ErfGradKernel<paddle::platform::CPUDeviceContext,
132- paddle::platform::float16>);
133-
134- REGISTER_OP_CUDA_KERNEL (
135- erf, ops::ErfKernel<paddle::platform::CUDADeviceContext, float >,
136- ops::ErfKernel<paddle::platform::CUDADeviceContext, double >,
137- ops::ErfKernel<paddle::platform::CUDADeviceContext,
138- paddle::platform::float16>);
139- REGISTER_OP_CUDA_KERNEL (
140- erf_grad, ops::ErfGradKernel<paddle::platform::CUDADeviceContext, float >,
141- ops::ErfGradKernel<paddle::platform::CUDADeviceContext, double >,
142- ops::ErfGradKernel<paddle::platform::CUDADeviceContext,
143- paddle::platform::float16>);
0 commit comments