@@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/strided_slice_op.h"
 #include <algorithm>
 #include <memory>
 #include <string>
 #include <vector>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/slice_op.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/backward.h"
+#include "paddle/phi/kernels/funcs/strided_slice.h"
 
 namespace paddle {
 namespace operators {
@@ -28,149 +33,6 @@ class StridedSliceOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "StridedSlice");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "StridedSlice");
-    auto input_var_type = ctx->GetInputsVarType("Input")[0];
-    if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
-      if (ctx->IsRuntime()) {
-        // shape is determined by Runtime.
-        return;
-      }
-    }
-    auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_LT(
-        in_dims.size(), 7,
-        platform::errors::InvalidArgument(
-            "The dimension of StridedSlice operator's input should be less "
-            "than 7, but received dimension is %d.",
-            in_dims.size()));
-
-    auto starts_int = ctx->Attrs().Get<std::vector<int>>("starts");
-    auto ends_int = ctx->Attrs().Get<std::vector<int>>("ends");
-    auto strides_int = ctx->Attrs().Get<std::vector<int>>("strides");
-
-    std::vector<int64_t> starts(starts_int.begin(), starts_int.end());
-    std::vector<int64_t> ends(ends_int.begin(), ends_int.end());
-    std::vector<int64_t> strides(strides_int.begin(), strides_int.end());
-
-    auto axes = ctx->Attrs().Get<std::vector<int>>("axes");
-    auto infer_flags = ctx->Attrs().Get<std::vector<int>>("infer_flags");
-    auto decrease_axis = ctx->Attrs().Get<std::vector<int>>("decrease_axis");
-
-    auto starts_size = starts.size();
-    auto ends_size = ends.size();
-    auto strides_size = strides.size();
-
-    for (size_t i = 0; i < axes.size(); ++i) {
-      PADDLE_ENFORCE_GE(axes[i], 0,
-                        platform::errors::InvalidArgument(
-                            "The axis should be greater than or equal to 0."
-                            "But received %d of axes[%d]",
-                            axes[i], i));
-      PADDLE_ENFORCE_LT(
-          axes[i], in_dims.size(),
-          platform::errors::InvalidArgument(
-              "The axes should be less than or equal to input tensor's rank."
-              "But received %d of axes[%d], input tensor shape [%d]",
-              axes[i], i, in_dims.size()));
-    }
-
-    if (ctx->HasInputs("StartsTensorList")) {
-      auto StartsTensorList = ctx->Inputs("StartsTensorList");
-      PADDLE_ENFORCE_GT(
-          StartsTensorList.size(), 0,
-          platform::errors::InvalidArgument(
-              "StridedSlice operator's StartsTensorList is empty."));
-      starts_size = StartsTensorList.size();
-    }
-    if (ctx->HasInputs("EndsTensorList")) {
-      auto EndsTensorList = ctx->Inputs("EndsTensorList");
-      PADDLE_ENFORCE_GT(
-          EndsTensorList.size(), 0,
-          platform::errors::InvalidArgument(
-              "StridedSlice operator's EndsTensorList is empty."));
-      ends_size = EndsTensorList.size();
-    }
-    if (ctx->HasInputs("StridesTensorList")) {
-      auto StridesTensorList = ctx->Inputs("StridesTensorList");
-      PADDLE_ENFORCE_GT(
-          StridesTensorList.size(), 0,
-          platform::errors::InvalidArgument(
-              "StridedSlice operator's StridesTensorList is empty."));
-      strides_size = StridesTensorList.size();
-    }
-
-    auto tensor_input = false;
-    if (ctx->HasInput("EndsTensor") || ctx->HasInput("StartsTensor") ||
-        ctx->HasInput("StridesTensor")) {
-      tensor_input = true;
-    }
-    if (!ctx->HasInput("EndsTensor")) {
-      PADDLE_ENFORCE_EQ(
-          ends_size, axes.size(),
-          platform::errors::InvalidArgument(
-              "The size of ends attribute in StridedSlice operator is not "
-              "equal to the size of axes attribute. The ends attribute's size "
-              "is %d, axes attribute's size is %d.",
-              ends_size, axes.size()));
-    }
-    if (!ctx->HasInput("StartsTensor")) {
-      PADDLE_ENFORCE_EQ(
-          starts_size, axes.size(),
-          platform::errors::InvalidArgument(
-              "The size of starts attribute in StridedSlice operator is not "
-              "equal to the size of axes attribute. The starts attribute's "
-              "size is %d, axes attribute's size is %d.",
-              starts_size, axes.size()));
-    }
-    if (!ctx->HasInput("StridesTensor")) {
-      PADDLE_ENFORCE_EQ(
-          strides_size, axes.size(),
-          platform::errors::InvalidArgument(
-              "The size of strides attribute in StridedSlice operator is not "
-              "equal to the size of axes attribute. The strides attribute's "
-              "size is %d, axes attribute's size is %d.",
-              strides_size, axes.size()));
-    }
-    // we need to analysis strided slice op is valid for
-    // the parameter that we get from python front
-    std::vector<int64_t> out_dims_vector(in_dims.size(), -1);
-    if (!tensor_input) {
-      StridedSliceOutDims(starts, ends, strides, axes, infer_flags, in_dims,
-                          decrease_axis, out_dims_vector.data(), axes.size(),
-                          true);
-    }
-    framework::DDim out_dims(phi::make_ddim(out_dims_vector));
-    // generate new shape
-    if (decrease_axis.size() > 0) {
-      std::vector<int64_t> new_out_shape;
-      for (size_t i = 0; i < decrease_axis.size(); ++i) {
-        if (ctx->IsRuntime() && infer_flags[i] != -1) {
-          PADDLE_ENFORCE_EQ(out_dims[decrease_axis[i]], 1,
-                            platform::errors::InvalidArgument(
-                                "the size of decrease dimension should be 1, "
-                                "but received %d.",
-                                out_dims[decrease_axis[i]]));
-        }
-        out_dims[decrease_axis[i]] = 0;
-      }
-
-      for (int i = 0; i < out_dims.size(); ++i) {
-        if (out_dims[i] != 0) {
-          new_out_shape.push_back(out_dims[i]);
-        }
-      }
-      if (new_out_shape.size() == 0) {
-        new_out_shape.push_back(1);
-      }
-
-      out_dims = phi::make_ddim(new_out_shape);
-    }
-    ctx->SetOutputDim("Out", out_dims);
-    ctx->ShareLoD("Input", /*->*/ "Out");
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
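Note: the InferShape override deleted in the hunk above folds the bounds checks and the decrease_axis handling into one method. As a reading aid, the decrease_axis step alone can be sketched as standalone C++; ApplyDecreaseAxis is an illustrative name, not a function in the Paddle tree, and the sketch assumes the strided-slice bounds math has already filled out_dims:

#include <cstdint>
#include <vector>

// Sketch of the removed decrease_axis logic: axes listed in decrease_axis
// are dropped from the output shape after being verified to have size 1.
std::vector<int64_t> ApplyDecreaseAxis(std::vector<int64_t> out_dims,
                                       const std::vector<int>& decrease_axis) {
  for (int axis : decrease_axis) {
    // The deleted InferShape enforced out_dims[axis] == 1 at runtime here.
    out_dims[axis] = 0;  // mark the axis for removal
  }
  std::vector<int64_t> new_shape;
  for (int64_t d : out_dims) {
    if (d != 0) new_shape.push_back(d);
  }
  // If every axis was decreased, the result is a 1-D tensor of size 1.
  if (new_shape.empty()) new_shape.push_back(1);
  return new_shape;
}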
@@ -304,26 +166,6 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
-                   "StridedSliceGrad");
-    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
-                   "Out@GRAD", "StridedSliceGrad");
-
-    auto input_var_type = ctx->GetInputsVarType("Input")[0];
-    if (input_var_type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
-      if (ctx->IsRuntime()) {
-        // shape is determined by Runtime
-        return;
-      }
-    }
-    auto x_dims = ctx->GetInputDim("Input");
-    auto x_grad_name = framework::GradVarName("Input");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->SetOutputDim(x_grad_name, x_dims);
-    }
-  }
-
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
@@ -384,35 +226,19 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(StridedSliceOpGradNoNeedBufferVarsInferer,
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+
+DECLARE_INFER_SHAPE_FUNCTOR(strided_slice, StridedSliceInferShape,
+                            PD_INFER_META(phi::StridedSliceInferMeta));
+
 REGISTER_OPERATOR(strided_slice, ops::StridedSliceOp, ops::StridedSliceOpMaker,
                   ops::StridedSliceOpGradMaker<paddle::framework::OpDesc>,
                   ops::StridedSliceOpGradMaker<paddle::imperative::OpBase>,
-                  ops::StridedSliceOpVarTypeInference);
+                  ops::StridedSliceOpVarTypeInference, StridedSliceInferShape);
+
+DECLARE_INFER_SHAPE_FUNCTOR(strided_slice_grad, StridedSliceGradInferShape,
+                            PD_INFER_META(phi::GeneralUnaryGradInferMeta));
 
 REGISTER_OPERATOR(strided_slice_grad, ops::StridedSliceOpGrad,
                   ops::StridedSliceOpGradNoNeedBufferVarsInferer,
-                  ops::StridedSliceGradOpVarTypeInference);
-
-REGISTER_OP_CPU_KERNEL(
-    strided_slice,
-    ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, bool>,
-    ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
-                            paddle::platform::complex<float>>,
-    ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
-                            paddle::platform::complex<double>>);
-
-REGISTER_OP_CPU_KERNEL(
-    strided_slice_grad,
-    ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, bool>,
-    ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
-                                paddle::platform::complex<float>>,
-    ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
-                                paddle::platform::complex<double>>);
+                  ops::StridedSliceGradOpVarTypeInference,
+                  StridedSliceGradInferShape);
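The registration hunk above shows the standard fluid-to-phi migration: the handwritten InferShape override is removed, DECLARE_INFER_SHAPE_FUNCTOR wraps a phi InferMeta function via PD_INFER_META, and the generated functor is passed to REGISTER_OPERATOR. A minimal sketch of the same wiring for a hypothetical operator (my_relu and MyReluInferMeta are invented names for illustration, not part of this commit):

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h"

namespace phi {
// An InferMeta function manipulates only metadata (dims, dtype), never data,
// so one implementation serves static shape inference and dynamic graphs.
void MyReluInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
}
}  // namespace phi

// PD_INFER_META adapts the phi signature to fluid's InferShapeContext, and
// DECLARE_INFER_SHAPE_FUNCTOR names the functor handed to REGISTER_OPERATOR.
DECLARE_INFER_SHAPE_FUNCTOR(my_relu, MyReluInferShape,
                            PD_INFER_META(phi::MyReluInferMeta));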