1616#include < string>
1717#include < unordered_map>
1818#include < vector>
19+ #include " paddle/fluid/framework/infershape_utils.h"
1920#include " paddle/fluid/framework/op_registry.h"
21+ #include " paddle/phi/core/infermeta_utils.h"
22+ #include " paddle/phi/infermeta/unary.h"
2023#ifdef PADDLE_WITH_MKLDNN
2124#include " paddle/fluid/platform/mkldnn_helper.h"
2225#endif
@@ -27,16 +30,6 @@ namespace operators {
2730class AbsOp : public framework ::OperatorWithKernel {
2831 public:
2932 using framework::OperatorWithKernel::OperatorWithKernel;
30-
31- void InferShape (framework::InferShapeContext* ctx) const override {
32- OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X" , " abs" );
33- OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " abs" );
34-
35- auto in_dims = ctx->GetInputDim (" X" );
36-
37- ctx->SetOutputDim (" Out" , in_dims);
38- ctx->ShareLoD (" X" , /* ->*/ " Out" );
39- }
4033};
4134
4235class AbsOpMaker : public framework ::OpProtoAndCheckerMaker {
@@ -148,11 +141,15 @@ class AbsDoubleGradOp : public framework::OperatorWithKernel {
148141} // namespace operators
149142} // namespace paddle
150143
144+ DELCARE_INFER_SHAPE_FUNCTOR (abs, AbsInferShapeFunctor,
145+ PT_INFER_META (phi::UnchangedInferMeta));
146+
151147namespace ops = paddle::operators;
152148
153149REGISTER_OPERATOR (abs, ops::AbsOp, ops::AbsOpMaker,
154150 ops::AbsGradMaker<paddle::framework::OpDesc>,
155- ops::AbsGradMaker<paddle::imperative::OpBase>);
151+ ops::AbsGradMaker<paddle::imperative::OpBase>,
152+ AbsInferShapeFunctor);
156153
157154REGISTER_OPERATOR (abs_grad, ops::AbsGradOp,
158155 ops::AbsDoubleGradMaker<paddle::framework::OpDesc>,
0 commit comments