1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15- #include " paddle/fluid/operators/isfinite_v2_op.h"
16-
1715#include < string>
1816
17+ #include " paddle/fluid/framework/infershape_utils.h"
18+ #include " paddle/fluid/framework/op_registry.h"
1919#include " paddle/fluid/operators/common_infer_shape_functions.h"
20+ #include " paddle/phi/core/infermeta_utils.h"
21+ #include " paddle/phi/infermeta/unary.h"
2022
2123namespace paddle {
2224namespace framework {
@@ -49,11 +51,6 @@ class OverflowV2Op : public framework::OperatorWithKernel {
4951 const framework::VariableNameMap &outputs,
5052 const framework::AttributeMap &attrs)
5153 : OperatorWithKernel(type, inputs, outputs, attrs) {}
52- void InferShape (framework::InferShapeContext *ctx) const override {
53- OP_INOUT_CHECK (ctx->HasInput (" X" ), " Input" , " X" , " isfinitev2" );
54- OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " isfinitev2" );
55- UnaryOpUnchangedInferShape (ctx);
56- }
5754
5855 protected:
5956 framework::OpKernelType GetExpectedKernelType (
@@ -104,6 +101,14 @@ element of X as a tensor.
104101} // namespace paddle
105102
106103namespace ops = paddle::operators;
104+ DECLARE_INFER_SHAPE_FUNCTOR (isinf_v2, IsinfInferShapeFunctor,
105+ PD_INFER_META (phi::IsfiniteInferMeta));
106+
107+ DECLARE_INFER_SHAPE_FUNCTOR (isnan_v2, IsnanInferShapeFunctor,
108+ PD_INFER_META (phi::IsfiniteInferMeta));
109+
110+ DECLARE_INFER_SHAPE_FUNCTOR (isfinite_v2, IsfiniteInferShapeFunctor,
111+ PD_INFER_META (phi::IsfiniteInferMeta));
107112
108113#define REGISTER_V2OP_MAKER (op_type, comment ) \
109114 namespace paddle { \
@@ -124,50 +129,17 @@ REGISTER_V2OP_MAKER(isfinite_v2, "isfinitev2(X)");
124129REGISTER_OPERATOR (
125130 isinf_v2, ops::OverflowV2Op, ops::_isinf_v2OverflowV2OpMaker,
126131 paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
127- paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
132+ paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
133+ IsinfInferShapeFunctor);
128134
129135REGISTER_OPERATOR (
130136 isnan_v2, ops::OverflowV2Op, ops::_isnan_v2OverflowV2OpMaker,
131137 paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
132- paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
138+ paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
139+ IsnanInferShapeFunctor);
133140
134141REGISTER_OPERATOR (
135142 isfinite_v2, ops::OverflowV2Op, ops::_isfinite_v2OverflowV2OpMaker,
136143 paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
137- paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
138-
139- REGISTER_OP_CPU_KERNEL (isnan_v2,
140- ops::OverflowKernel<paddle::platform::CPUDeviceContext,
141- int , ops::NANV2Functor>,
142- ops::OverflowKernel<paddle::platform::CPUDeviceContext,
143- int64_t , ops::NANV2Functor>,
144- ops::OverflowKernel<paddle::platform::CPUDeviceContext,
145- float , ops::NANV2Functor>,
146- ops::OverflowKernel<paddle::platform::CPUDeviceContext,
147- double , ops::NANV2Functor>,
148- ops::OverflowKernel<paddle::platform::CPUDeviceContext,
149- plat::float16, ops::NANV2Functor>);
150-
151- REGISTER_OP_CPU_KERNEL (
152- isinf_v2, ops::OverflowKernel<paddle::platform::CPUDeviceContext, int ,
153- ops::InfinityV2Functor>,
154- ops::OverflowKernel<paddle::platform::CPUDeviceContext, int64_t ,
155- ops::InfinityV2Functor>,
156- ops::OverflowKernel<paddle::platform::CPUDeviceContext, float ,
157- ops::InfinityV2Functor>,
158- ops::OverflowKernel<paddle::platform::CPUDeviceContext, double ,
159- ops::InfinityV2Functor>,
160- ops::OverflowKernel<paddle::platform::CPUDeviceContext, plat::float16,
161- ops::InfinityV2Functor>);
162-
163- REGISTER_OP_CPU_KERNEL (
164- isfinite_v2, ops::OverflowKernel<paddle::platform::CPUDeviceContext, int ,
165- ops::IsfiniteV2Functor>,
166- ops::OverflowKernel<paddle::platform::CPUDeviceContext, int64_t ,
167- ops::IsfiniteV2Functor>,
168- ops::OverflowKernel<paddle::platform::CPUDeviceContext, float ,
169- ops::IsfiniteV2Functor>,
170- ops::OverflowKernel<paddle::platform::CPUDeviceContext, double ,
171- ops::IsfiniteV2Functor>,
172- ops::OverflowKernel<paddle::platform::CPUDeviceContext, plat::float16,
173- ops::IsfiniteV2Functor>);
144+ paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
145+ IsfiniteInferShapeFunctor);
0 commit comments