@@ -12,9 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212See the License for the specific language governing permissions and
1313limitations under the License. */
1414
15- #include " paddle/fluid/operators/linspace_op.h"
1615#include < string>
16+
17+ #include " paddle/fluid/framework/infershape_utils.h"
18+ #include " paddle/fluid/framework/op_registry.h"
1719#include " paddle/fluid/framework/op_version_registry.h"
20+ #include " paddle/phi/core/infermeta_utils.h"
21+ #include " paddle/phi/infermeta/ternary.h"
1822
1923namespace paddle {
2024namespace operators {
@@ -23,33 +27,6 @@ class LinspaceOp : public framework::OperatorWithKernel {
2327 public:
2428 using framework::OperatorWithKernel::OperatorWithKernel;
2529
26- void InferShape (framework::InferShapeContext *ctx) const override {
27- OP_INOUT_CHECK (ctx->HasInput (" Start" ), " Input" , " Start" , " linspace" );
28- OP_INOUT_CHECK (ctx->HasInput (" Stop" ), " Input" , " Stop" , " linspace" );
29- OP_INOUT_CHECK (ctx->HasInput (" Num" ), " Input" , " Num" , " linspace" );
30- OP_INOUT_CHECK (ctx->HasOutput (" Out" ), " Output" , " Out" , " linspace" );
31-
32- auto s_dims = ctx->GetInputDim (" Start" );
33- PADDLE_ENFORCE_EQ ((s_dims.size () == 1 ) && (s_dims[0 ] == 1 ), true ,
34- platform::errors::InvalidArgument (
35- " The shape of Input(Start) must be [1],"
36- " but received input shape is [%s]." ,
37- s_dims));
38- auto e_dims = ctx->GetInputDim (" Stop" );
39- PADDLE_ENFORCE_EQ ((e_dims.size () == 1 ) && (e_dims[0 ] == 1 ), true ,
40- platform::errors::InvalidArgument (
41- " The shape of Input(Stop) must be [1],"
42- " but received input shape is [%s]." ,
43- e_dims));
44- auto step_dims = ctx->GetInputDim (" Num" );
45- PADDLE_ENFORCE_EQ (
46- (step_dims.size () == 1 ) && (step_dims[0 ] == 1 ), true ,
47- platform::errors::InvalidArgument (" The shape of Input(Num) must be [1],"
48- " but received input shape is [%s]." ,
49- step_dims));
50- ctx->SetOutputDim (" Out" , {-1 });
51- }
52-
5330 protected:
5431 framework::OpKernelType GetExpectedKernelType (
5532 const framework::ExecutionContext &ctx) const override {
@@ -88,11 +65,13 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
8865} // namespace paddle
8966
9067namespace ops = paddle::operators;
91- REGISTER_OP_WITHOUT_GRADIENT (linspace, ops::LinspaceOp, ops::LinspaceOpMaker);
92- REGISTER_OP_CPU_KERNEL (linspace, ops::CPULinspaceKernel<float >,
93- ops::CPULinspaceKernel<int32_t >,
94- ops::CPULinspaceKernel<int64_t >,
95- ops::CPULinspaceKernel<double >);
68+ DECLARE_INFER_SHAPE_FUNCTOR (linspace, LinspaceInferShapeFunctor,
69+ PD_INFER_META (phi::LinspaceInferMeta));
70+ REGISTER_OPERATOR (
71+ linspace, ops::LinspaceOp, ops::LinspaceOpMaker,
72+ paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
73+ paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
74+ LinspaceInferShapeFunctor);
9675
9776REGISTER_OP_VERSION (linspace)
9877 .AddCheckpoint(
0 commit comments