Skip to content

Commit 756af9f

Browse files
authored
modify infershape of multiclass nms (#40059)
* modify infershape of multiclass nms
1 parent 831b69d commit 756af9f

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

paddle/fluid/operators/detection/multiclass_nms_op.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
9393
// Here the box_dims[0] is not the real dimension of output.
9494
// It will be rewritten in the computing kernel.
9595
if (score_size == 3) {
96-
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
96+
ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
9797
} else {
9898
ctx->SetOutputDim("Out", {-1, box_dims[2] + 2});
9999
}
@@ -545,11 +545,10 @@ class MultiClassNMS2Op : public MultiClassNMSOp {
545545
void InferShape(framework::InferShapeContext* ctx) const override {
546546
MultiClassNMSOp::InferShape(ctx);
547547

548-
auto box_dims = ctx->GetInputDim("BBoxes");
549548
auto score_dims = ctx->GetInputDim("Scores");
550549
auto score_size = score_dims.size();
551550
if (score_size == 3) {
552-
ctx->SetOutputDim("Index", {box_dims[1], 1});
551+
ctx->SetOutputDim("Index", {-1, 1});
553552
} else {
554553
ctx->SetOutputDim("Index", {-1, 1});
555554
}

0 commit comments

Comments
 (0)