Skip to content

Commit 579173d

Browse files
authored
[Phi] Move infershape of roi_pool to phi (#40682)
* move infershape of roi_pool to phi * polish code
1 parent 7f93e2b commit 579173d

File tree

3 files changed

+102
-80
lines changed

3 files changed

+102
-80
lines changed

paddle/fluid/operators/roi_pool_op.cc

Lines changed: 8 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include <memory>
16+
#include "paddle/fluid/framework/infershape_utils.h"
1617
#include "paddle/fluid/framework/op_registry.h"
1718
#include "paddle/fluid/framework/op_version_registry.h"
18-
#include "paddle/phi/kernels/roi_pool_kernel.h"
19+
#include "paddle/phi/core/infermeta_utils.h"
20+
#include "paddle/phi/infermeta/ternary.h"
1921

2022
namespace paddle {
2123
namespace operators {
@@ -27,74 +29,6 @@ class ROIPoolOp : public framework::OperatorWithKernel {
2729
public:
2830
using framework::OperatorWithKernel::OperatorWithKernel;
2931

30-
void InferShape(framework::InferShapeContext* ctx) const override {
31-
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "roi_pool");
32-
OP_INOUT_CHECK(ctx->HasInput("ROIs"), "Input", "ROIs", "roi_pool");
33-
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "roi_pool");
34-
OP_INOUT_CHECK(ctx->HasOutput("Argmax"), "Output", "Argmax", "roi_pool");
35-
36-
auto input_dims = ctx->GetInputDim("X");
37-
auto rois_dims = ctx->GetInputDim("ROIs");
38-
39-
if (ctx->HasInput("RoisNum")) {
40-
auto rois_num_dims = ctx->GetInputDim("RoisNum");
41-
PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
42-
platform::errors::InvalidArgument(
43-
"The second dimension of RoisNum should "
44-
"be 1, but received dimension is %d",
45-
rois_num_dims.size()));
46-
}
47-
PADDLE_ENFORCE_EQ(input_dims.size(), 4,
48-
platform::errors::InvalidArgument(
49-
"The input data should be a four-dimensional "
50-
"tensor with [N,C,H,W], but received input data with "
51-
" %d dimension",
52-
input_dims.size()));
53-
PADDLE_ENFORCE_EQ(
54-
rois_dims.size(), 2,
55-
platform::errors::InvalidArgument(
56-
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
57-
"given as [[x1, y1, x2, y2], ...], but received ROIs is "
58-
"%d-dimensional LoDTensor",
59-
rois_dims.size()));
60-
PADDLE_ENFORCE_EQ(
61-
rois_dims[1], phi::kROISize,
62-
platform::errors::InvalidArgument(
63-
"ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
64-
"given as [[x1, y1, x2, y2], ...]. But the second dimension of "
65-
"the received data is %d",
66-
rois_dims[1]));
67-
68-
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
69-
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
70-
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
71-
72-
PADDLE_ENFORCE_GT(pooled_height, 0,
73-
platform::errors::OutOfRange(
74-
"The pooled output height must be greater than 0"
75-
"but received height is %d",
76-
pooled_height));
77-
PADDLE_ENFORCE_GT(pooled_width, 0,
78-
platform::errors::OutOfRange(
79-
"The pooled output width must be greater than 0"
80-
"but received width is %d",
81-
pooled_width));
82-
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
83-
platform::errors::OutOfRange(
84-
"The spatial scale must be greater than 0, "
85-
"but received spatial scale is %f",
86-
spatial_scale));
87-
88-
auto out_dims = input_dims;
89-
out_dims[0] = rois_dims[0];
90-
out_dims[1] = input_dims[1];
91-
out_dims[2] = pooled_height;
92-
out_dims[3] = pooled_width;
93-
94-
ctx->SetOutputDim("Out", out_dims);
95-
ctx->SetOutputDim("Argmax", out_dims);
96-
}
97-
9832
protected:
9933
framework::OpKernelType GetExpectedKernelType(
10034
const framework::ExecutionContext& ctx) const override {
@@ -213,9 +147,13 @@ class ROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
213147
} // namespace paddle
214148

215149
namespace ops = paddle::operators;
150+
DECLARE_INFER_SHAPE_FUNCTOR(roi_pool, RoiPoolInferShapeFunctor,
151+
PD_INFER_META(phi::RoiPoolInferMeta));
152+
216153
REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
217154
ops::ROIPoolGradMaker<paddle::framework::OpDesc>,
218-
ops::ROIPoolGradMaker<paddle::imperative::OpBase>);
155+
ops::ROIPoolGradMaker<paddle::imperative::OpBase>,
156+
RoiPoolInferShapeFunctor);
219157
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
220158

221159
REGISTER_OP_VERSION(roi_pool)

paddle/phi/infermeta/ternary.cc

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -340,51 +340,51 @@ void RoiAlignInferMeta(const MetaTensor& x,
340340
PADDLE_ENFORCE_EQ(
341341
boxes_num_dims.size(),
342342
1,
343-
phi::errors::InvalidArgument("The size of RoisNum should be 1"
343+
phi::errors::InvalidArgument("The size of boxes_num should be 1"
344344
", but received size = %d",
345345
boxes_num_dims.size()));
346346
}
347347
PADDLE_ENFORCE_EQ(input_dims.size(),
348348
4,
349349
phi::errors::InvalidArgument(
350-
"The format of Input(X) in"
351-
"RoIAlignOp is NCHW. And the rank of input must be 4. "
350+
"The format of Input(x) in"
351+
"RoiAlignOp is NCHW. And the rank of input must be 4. "
352352
"But received rank = %d",
353353
input_dims.size()));
354354
PADDLE_ENFORCE_EQ(boxes_dims.size(),
355355
2,
356-
phi::errors::InvalidArgument("The rank of Input(ROIs) "
357-
"in RoIAlignOp should be 2. "
358-
"But the rank of RoIs is %d",
356+
phi::errors::InvalidArgument("The rank of Input(boxes) "
357+
"in RoiAlignOp should be 2. "
358+
"But the rank of boxes is %d",
359359
boxes_dims.size()));
360360
if (config.is_runtime) {
361361
PADDLE_ENFORCE_EQ(boxes_dims[1],
362362
4,
363363
phi::errors::InvalidArgument(
364364
"The second dimension "
365-
"of Input(ROIs) should be 4. But received the "
365+
"of Input(boxes) should be 4. But received the "
366366
"dimension = %d",
367367
boxes_dims[1]));
368368
}
369369

370370
PADDLE_ENFORCE_GT(pooled_height,
371371
0,
372372
phi::errors::InvalidArgument(
373-
"The 'pooled_height' attribute in RoIAlignOp is "
373+
"The 'pooled_height' attribute in RoiAlignOp is "
374374
"invalid. The height must be greater than 0. But "
375375
"received 'pooled_height' = %d",
376376
pooled_height));
377377
PADDLE_ENFORCE_GT(pooled_width,
378378
0,
379379
phi::errors::InvalidArgument(
380-
"The 'pooled_width' attribute in RoIAlignOp is "
380+
"The 'pooled_width' attribute in RoiAlignOp is "
381381
"invalid. The width must be greater than 0. But "
382382
"received 'pooled_width' = %d",
383383
pooled_width));
384384
PADDLE_ENFORCE_GT(spatial_scale,
385385
0.0f,
386386
phi::errors::InvalidArgument(
387-
"The 'spatial_scale' attribute in RoIAlignOp is "
387+
"The 'spatial_scale' attribute in RoiAlignOp is "
388388
"invalid. The scale must be greater than 0. But "
389389
"received 'spatial_scale' = %f",
390390
spatial_scale));
@@ -399,6 +399,81 @@ void RoiAlignInferMeta(const MetaTensor& x,
399399
out->set_dtype(x.dtype());
400400
}
401401

402+
void RoiPoolInferMeta(const MetaTensor& x,
403+
const MetaTensor& boxes,
404+
paddle::optional<const MetaTensor&> boxes_num,
405+
int pooled_height,
406+
int pooled_width,
407+
float spatial_scale,
408+
MetaTensor* out,
409+
MetaTensor* arg_max) {
410+
auto input_dims = x.dims();
411+
auto boxes_dims = boxes.dims();
412+
413+
if (boxes_num) {
414+
auto boxes_num_dims = boxes_num->dims();
415+
PADDLE_ENFORCE_EQ(
416+
boxes_num_dims.size(),
417+
1,
418+
phi::errors::InvalidArgument("The second dimension of boxes_num should "
419+
"be 1, but received dimension is %d",
420+
boxes_num_dims.size()));
421+
}
422+
PADDLE_ENFORCE_EQ(input_dims.size(),
423+
4,
424+
phi::errors::InvalidArgument(
425+
"The input data should be a four-dimensional "
426+
"tensor with [N,C,H,W], but received input data with "
427+
" %d dimension",
428+
input_dims.size()));
429+
PADDLE_ENFORCE_EQ(
430+
boxes_dims.size(),
431+
2,
432+
phi::errors::InvalidArgument(
433+
"boxes should be a 2-D LoDTensor with shape (num_boxes, 4)"
434+
"given as [[x1, y1, x2, y2], ...], but received boxes is "
435+
"%d-dimensional LoDTensor",
436+
boxes_dims.size()));
437+
PADDLE_ENFORCE_EQ(
438+
boxes_dims[1],
439+
4,
440+
phi::errors::InvalidArgument(
441+
"boxes should be a 2-D LoDTensor with shape (num_boxes, 4)"
442+
"given as [[x1, y1, x2, y2], ...]. But the second dimension of "
443+
"the received data is %d",
444+
boxes_dims[1]));
445+
446+
PADDLE_ENFORCE_GT(
447+
pooled_height,
448+
0,
449+
phi::errors::OutOfRange("The pooled output height must be greater than 0"
450+
"but received height is %d",
451+
pooled_height));
452+
PADDLE_ENFORCE_GT(
453+
pooled_width,
454+
0,
455+
phi::errors::OutOfRange("The pooled output width must be greater than 0"
456+
"but received width is %d",
457+
pooled_width));
458+
PADDLE_ENFORCE_GT(
459+
spatial_scale,
460+
0.0f,
461+
phi::errors::OutOfRange("The spatial scale must be greater than 0, "
462+
"but received spatial scale is %f",
463+
spatial_scale));
464+
465+
auto out_dims = input_dims;
466+
out_dims[0] = boxes_dims[0];
467+
out_dims[1] = input_dims[1];
468+
out_dims[2] = pooled_height;
469+
out_dims[3] = pooled_width;
470+
471+
out->set_dims(out_dims);
472+
out->set_dtype(x.dtype());
473+
arg_max->set_dims(out_dims);
474+
arg_max->set_dtype(DataType::INT64);
475+
}
476+
402477
void ScatterInferMeta(const MetaTensor& x,
403478
const MetaTensor& index,
404479
const MetaTensor& updates,

paddle/phi/infermeta/ternary.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ void RoiAlignInferMeta(const MetaTensor& x,
8484
MetaTensor* out,
8585
MetaConfig config = MetaConfig());
8686

87+
// Infers output metadata for the roi_pool op.
//
// Expects `x` to be rank 4, `boxes` to be rank 2 with a second dimension of 4,
// and `boxes_num` (optional) to be rank 1 when provided. `pooled_height`,
// `pooled_width` and `spatial_scale` must all be greater than 0.
//
// On success `out` and `arg_max` both receive dims
// [boxes.dims()[0], x.dims()[1], pooled_height, pooled_width]; `out` takes the
// dtype of `x` and `arg_max` is set to INT64.
void RoiPoolInferMeta(const MetaTensor& x,
88+
const MetaTensor& boxes,
89+
paddle::optional<const MetaTensor&> boxes_num,
90+
int pooled_height,
91+
int pooled_width,
92+
float spatial_scale,
93+
MetaTensor* out,
94+
MetaTensor* arg_max);
95+
8796
void ScatterInferMeta(const MetaTensor& x,
8897
const MetaTensor& index,
8998
const MetaTensor& updates,

0 commit comments

Comments
 (0)