@@ -340,51 +340,51 @@ void RoiAlignInferMeta(const MetaTensor& x,
340340 PADDLE_ENFORCE_EQ (
341341 boxes_num_dims.size (),
342342 1 ,
343- phi::errors::InvalidArgument (" The size of RoisNum should be 1"
343+ phi::errors::InvalidArgument (" The size of boxes_num should be 1"
344344 " , but received size = %d" ,
345345 boxes_num_dims.size ()));
346346 }
347347 PADDLE_ENFORCE_EQ (input_dims.size (),
348348 4 ,
349349 phi::errors::InvalidArgument (
350- " The format of Input(X ) in"
351- " RoIAlignOp is NCHW. And the rank of input must be 4. "
350+ " The format of Input(x ) in"
351+ " RoiAlignOp is NCHW. And the rank of input must be 4. "
352352 " But received rank = %d" ,
353353 input_dims.size ()));
354354 PADDLE_ENFORCE_EQ (boxes_dims.size (),
355355 2 ,
356- phi::errors::InvalidArgument (" The rank of Input(ROIs ) "
357- " in RoIAlignOp should be 2. "
358- " But the rank of RoIs is %d" ,
356+ phi::errors::InvalidArgument (" The rank of Input(boxes ) "
357+ " in RoiAlignOp should be 2. "
358+ " But the rank of boxes is %d" ,
359359 boxes_dims.size ()));
360360 if (config.is_runtime ) {
361361 PADDLE_ENFORCE_EQ (boxes_dims[1 ],
362362 4 ,
363363 phi::errors::InvalidArgument (
364364 " The second dimension "
365- " of Input(ROIs ) should be 4. But received the "
365+ " of Input(boxes ) should be 4. But received the "
366366 " dimension = %d" ,
367367 boxes_dims[1 ]));
368368 }
369369
370370 PADDLE_ENFORCE_GT (pooled_height,
371371 0 ,
372372 phi::errors::InvalidArgument (
373- " The 'pooled_height' attribute in RoIAlignOp is "
373+ " The 'pooled_height' attribute in RoiAlignOp is "
374374 " invalid. The height must be greater than 0. But "
375375 " received 'pooled_height' = %d" ,
376376 pooled_height));
377377 PADDLE_ENFORCE_GT (pooled_width,
378378 0 ,
379379 phi::errors::InvalidArgument (
380- " The 'pooled_width' attribute in RoIAlignOp is "
380+ " The 'pooled_width' attribute in RoiAlignOp is "
381381 " invalid. The width must be greater than 0. But "
382382 " received 'pooled_width' = %d" ,
383383 pooled_width));
384384 PADDLE_ENFORCE_GT (spatial_scale,
385385 0 .0f ,
386386 phi::errors::InvalidArgument (
387- " The 'spatial_scale' attribute in RoIAlignOp is "
387+ " The 'spatial_scale' attribute in RoiAlignOp is "
388388 " invalid. The scale must be greater than 0. But "
389389 " received 'spatial_scale' = %f" ,
390390 spatial_scale));
@@ -399,6 +399,81 @@ void RoiAlignInferMeta(const MetaTensor& x,
399399 out->set_dtype (x.dtype ());
400400}
401401
402+ void RoiPoolInferMeta (const MetaTensor& x,
403+ const MetaTensor& boxes,
404+ paddle::optional<const MetaTensor&> boxes_num,
405+ int pooled_height,
406+ int pooled_width,
407+ float spatial_scale,
408+ MetaTensor* out,
409+ MetaTensor* arg_max) {
410+ auto input_dims = x.dims ();
411+ auto boxes_dims = boxes.dims ();
412+
413+ if (boxes_num) {
414+ auto boxes_num_dims = boxes_num->dims ();
415+ PADDLE_ENFORCE_EQ (
416+ boxes_num_dims.size (),
417+ 1 ,
418+ phi::errors::InvalidArgument (" The second dimension of boxes_num should "
419+ " be 1, but received dimension is %d" ,
420+ boxes_num_dims.size ()));
421+ }
422+ PADDLE_ENFORCE_EQ (input_dims.size (),
423+ 4 ,
424+ phi::errors::InvalidArgument (
425+ " The input data should be a four-dimensional "
426+ " tensor with [N,C,H,W], but received input data with "
427+ " %d dimension" ,
428+ input_dims.size ()));
429+ PADDLE_ENFORCE_EQ (
430+ boxes_dims.size (),
431+ 2 ,
432+ phi::errors::InvalidArgument (
433+ " boxes should be a 2-D LoDTensor with shape (num_boxes, 4)"
434+ " given as [[x1, y1, x2, y2], ...], but received boxes is "
435+ " %d-dimensional LoDTensor" ,
436+ boxes_dims.size ()));
437+ PADDLE_ENFORCE_EQ (
438+ boxes_dims[1 ],
439+ 4 ,
440+ phi::errors::InvalidArgument (
441+ " boxes should be a 2-D LoDTensor with shape (num_boxes, 4)"
442+ " given as [[x1, y1, x2, y2], ...]. But the second dimension of "
443+ " the received data is %d" ,
444+ boxes_dims[1 ]));
445+
446+ PADDLE_ENFORCE_GT (
447+ pooled_height,
448+ 0 ,
449+ phi::errors::OutOfRange (" The pooled output height must be greater than 0"
450+ " but received height is %d" ,
451+ pooled_height));
452+ PADDLE_ENFORCE_GT (
453+ pooled_width,
454+ 0 ,
455+ phi::errors::OutOfRange (" The pooled output width must be greater than 0"
456+ " but received width is %d" ,
457+ pooled_width));
458+ PADDLE_ENFORCE_GT (
459+ spatial_scale,
460+ 0 .0f ,
461+ phi::errors::OutOfRange (" The spatial scale must be greater than 0, "
462+ " but received spatial scale is %f" ,
463+ spatial_scale));
464+
465+ auto out_dims = input_dims;
466+ out_dims[0 ] = boxes_dims[0 ];
467+ out_dims[1 ] = input_dims[1 ];
468+ out_dims[2 ] = pooled_height;
469+ out_dims[3 ] = pooled_width;
470+
471+ out->set_dims (out_dims);
472+ out->set_dtype (x.dtype ());
473+ arg_max->set_dims (out_dims);
474+ arg_max->set_dtype (DataType::INT64);
475+ }
476+
402477void ScatterInferMeta (const MetaTensor& x,
403478 const MetaTensor& index,
404479 const MetaTensor& updates,
0 commit comments