Skip to content

Commit 7554f42

Browse files
Add nms op and batched_nms api (#40962)
* add nms op and batched_nms api
1 parent 510347f commit 7554f42

File tree

9 files changed

+740
-0
lines changed

9 files changed

+740
-0
lines changed

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ detection_library(yolo_box_op SRCS yolo_box_op.cc)
6666
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
6767
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
6868
detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
69+
detection_library(nms_op SRCS nms_op.cc nms_op.cu)
6970

7071
if(WITH_GPU OR WITH_ROCM)
7172
set(TMPDEPS memory)
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/detection/nms_op.h"
16+
#include <vector>
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
using framework::Tensor;
22+
23+
/**
 * Proto maker for the `nms` operator.
 *
 * Declares the operator's interface and its user-facing documentation:
 *  - Input  "Boxes":          the candidate boxes.
 *  - Output "KeepBoxesIdxs":  indices of the boxes kept by NMS.
 *  - Attr   "iou_threshold":  float, default 1.0, checked to lie in [0, 1].
 */
class NMSOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // The documented coordinate order has to agree with CalculateIoU in
    // nms_op.h, which reads index 0/1 as the minimum corner (x, y) and
    // index 2/3 as the maximum corner (x, y), i.e. [x1, y1, x2, y2].
    AddInput("Boxes",
             "(Tensor) "
             "Boxes is a Tensor with shape [N, 4] "
             "N is the number of boxes "
             "in last dimension in format [x1, y1, x2, y2] "
             "the relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``.");

    AddOutput("KeepBoxesIdxs",
              "(Tensor) "
              "KeepBoxesIdxs is a Tensor with shape [N] ");
    AddAttr<float>(
        "iou_threshold",
        "iou_threshold is a threshold value used to compress similar boxes "
        "boxes with IoU > iou_threshold will be considered as overlapping "
        "and just one of them can be kept.")
        .SetDefault(1.0f)
        // Reject thresholds outside the closed interval [0, 1].
        .AddCustomChecker([](const float& iou_threshold) {
          PADDLE_ENFORCE_LE(iou_threshold, 1.0f,
                            platform::errors::InvalidArgument(
                                "iou_threshold should less equal than 1.0 "
                                "but got %f",
                                iou_threshold));
          PADDLE_ENFORCE_GE(iou_threshold, 0.0f,
                            platform::errors::InvalidArgument(
                                "iou_threshold should greater equal than 0.0 "
                                "but got %f",
                                iou_threshold));
        });
    AddComment(R"DOC(
                NMS Operator.
                This Operator is used to perform Non-Maximum Suppression for input boxes.
                Indices of boxes kept by NMS will be sorted by scores and output.
        )DOC");
  }
};
61+
62+
class NMSOp : public framework::OperatorWithKernel {
63+
public:
64+
using framework::OperatorWithKernel::OperatorWithKernel;
65+
void InferShape(framework::InferShapeContext* ctx) const override {
66+
OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS");
67+
OP_INOUT_CHECK(ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs",
68+
"NMS");
69+
70+
auto boxes_dim = ctx->GetInputDim("Boxes");
71+
PADDLE_ENFORCE_EQ(boxes_dim.size(), 2,
72+
platform::errors::InvalidArgument(
73+
"The Input Boxes must be 2-dimention "
74+
"whose shape must be [N, 4] "
75+
"N is the number of boxes "
76+
"in last dimension in format [x1, x2, y1, y2]. "));
77+
auto num_boxes = boxes_dim[0];
78+
79+
ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes});
80+
}
81+
82+
protected:
83+
framework::OpKernelType GetExpectedKernelType(
84+
const framework::ExecutionContext& ctx) const override {
85+
return framework::OpKernelType(
86+
OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace());
87+
}
88+
};
89+
90+
template <typename T>
91+
static void NMS(const T* boxes_data, int64_t* output_data, float threshold,
92+
int64_t num_boxes) {
93+
auto num_masks = CeilDivide(num_boxes, 64);
94+
std::vector<uint64_t> masks(num_masks, 0);
95+
96+
for (int64_t i = 0; i < num_boxes; ++i) {
97+
if (masks[i / 64] & 1ULL << (i % 64)) continue;
98+
T box_1[4];
99+
for (int k = 0; k < 4; ++k) {
100+
box_1[k] = boxes_data[i * 4 + k];
101+
}
102+
for (int64_t j = i + 1; j < num_boxes; ++j) {
103+
if (masks[j / 64] & 1ULL << (j % 64)) continue;
104+
T box_2[4];
105+
for (int k = 0; k < 4; ++k) {
106+
box_2[k] = boxes_data[j * 4 + k];
107+
}
108+
bool is_overlap = CalculateIoU<T>(box_1, box_2, threshold);
109+
if (is_overlap) {
110+
masks[j / 64] |= 1ULL << (j % 64);
111+
}
112+
}
113+
}
114+
115+
int64_t output_data_idx = 0;
116+
for (int64_t i = 0; i < num_boxes; ++i) {
117+
if (masks[i / 64] & 1ULL << (i % 64)) continue;
118+
output_data[output_data_idx++] = i;
119+
}
120+
121+
for (; output_data_idx < num_boxes; ++output_data_idx) {
122+
output_data[output_data_idx] = 0;
123+
}
124+
}
125+
126+
template <typename T>
127+
class NMSKernel : public framework::OpKernel<T> {
128+
public:
129+
void Compute(const framework::ExecutionContext& context) const override {
130+
const Tensor* boxes = context.Input<Tensor>("Boxes");
131+
Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
132+
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
133+
auto threshold = context.template Attr<float>("iou_threshold");
134+
NMS<T>(boxes->data<T>(), output_data, threshold, boxes->dims()[0]);
135+
}
136+
};
137+
138+
} // namespace operators
139+
} // namespace paddle
140+
141+
namespace ops = paddle::operators;
142+
143+
// Register the `nms` operator with its definition (NMSOp) and proto maker
// (NMSOpMaker). EmptyGradOpMaker is passed for both OpDesc and the imperative
// OpBase — presumably meaning no gradient op is generated; confirm.
REGISTER_OPERATOR(
144+
nms, ops::NMSOp, ops::NMSOpMaker,
145+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
146+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
147+
// Register the CPU kernels of `nms` for float and double box coordinates.
REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel<float>, ops::NMSKernel<double>);
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <vector>
16+
#include "paddle/fluid/operators/detection/nms_op.h"
17+
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
18+
19+
// Number of bits in an int64_t. Used both as the CUDA block size
// (dim3 block(threadsPerBlock)) and as the number of boxes whose overlap
// bits are packed into one uint64_t mask word.
static const int64_t threadsPerBlock = sizeof(int64_t) * 8;
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
using framework::Tensor;
25+
26+
template <typename T>
27+
static __global__ void NMS(const T* boxes_data, float threshold,
28+
int64_t num_boxes, uint64_t* masks) {
29+
auto raw_start = blockIdx.y;
30+
auto col_start = blockIdx.x;
31+
if (raw_start > col_start) return;
32+
33+
const int raw_last_storage =
34+
min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock);
35+
const int col_last_storage =
36+
min(num_boxes - col_start * threadsPerBlock, threadsPerBlock);
37+
38+
if (threadIdx.x < raw_last_storage) {
39+
uint64_t mask = 0;
40+
auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x;
41+
const T* current_box = boxes_data + current_box_idx * 4;
42+
for (int i = 0; i < col_last_storage; ++i) {
43+
const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4;
44+
if (CalculateIoU<T>(current_box, target_box, threshold)) {
45+
mask |= 1ULL << i;
46+
}
47+
}
48+
const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
49+
masks[current_box_idx * blocks_per_line + col_start] = mask;
50+
}
51+
}
52+
53+
template <typename T>
54+
class NMSCudaKernel : public framework::OpKernel<T> {
55+
public:
56+
void Compute(const framework::ExecutionContext& context) const override {
57+
const Tensor* boxes = context.Input<Tensor>("Boxes");
58+
Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
59+
auto* output_data = output->mutable_data<int64_t>(context.GetPlace());
60+
61+
auto threshold = context.template Attr<float>("iou_threshold");
62+
const int64_t num_boxes = boxes->dims()[0];
63+
const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
64+
65+
dim3 block(threadsPerBlock);
66+
dim3 grid(blocks_per_line, blocks_per_line);
67+
68+
auto mask_data =
69+
memory::Alloc(context.cuda_device_context(),
70+
num_boxes * blocks_per_line * sizeof(uint64_t));
71+
uint64_t* mask_dev = reinterpret_cast<uint64_t*>(mask_data->ptr());
72+
NMS<T><<<grid, block, 0, context.cuda_device_context().stream()>>>(
73+
boxes->data<T>(), threshold, num_boxes, mask_dev);
74+
75+
std::vector<uint64_t> mask_host(num_boxes * blocks_per_line);
76+
memory::Copy(platform::CPUPlace(), mask_host.data(), context.GetPlace(),
77+
mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t),
78+
context.cuda_device_context().stream());
79+
80+
std::vector<int64_t> remv(blocks_per_line);
81+
82+
std::vector<int64_t> keep_boxes_idxs(num_boxes);
83+
int64_t* output_host = keep_boxes_idxs.data();
84+
85+
int64_t last_box_num = 0;
86+
for (int64_t i = 0; i < num_boxes; ++i) {
87+
auto remv_element_id = i / threadsPerBlock;
88+
auto remv_bit_id = i % threadsPerBlock;
89+
if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) {
90+
output_host[last_box_num++] = i;
91+
uint64_t* current_mask = mask_host.data() + i * blocks_per_line;
92+
for (auto j = remv_element_id; j < blocks_per_line; ++j) {
93+
remv[j] |= current_mask[j];
94+
}
95+
}
96+
}
97+
memory::Copy(context.GetPlace(), output_data, platform::CPUPlace(),
98+
output_host, sizeof(int64_t) * num_boxes,
99+
context.cuda_device_context().stream());
100+
}
101+
};
102+
103+
} // namespace operators
104+
} // namespace paddle
105+
106+
namespace ops = paddle::operators;
107+
// Register the CUDA kernels of `nms` for float and double box coordinates.
REGISTER_OP_CUDA_KERNEL(nms, ops::NMSCudaKernel<float>,
108+
ops::NMSCudaKernel<double>);
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/framework/operator.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
/**
 * Rounded-up integer division, computed as (n + m - 1) / m.
 *
 * This equals ceil(n / m) for n >= 0 and m > 0, which is how the callers in
 * the NMS kernels use it (number of 64-box words/tiles for a box count).
 */
HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) {
  const int64_t rounded_up = n + m - 1;
  return rounded_up / m;
}
26+
27+
/**
 * Tests whether two boxes overlap by more than `threshold` in terms of
 * intersection-over-union.
 *
 * Each box is four values read as [x_min, y_min, x_max, y_max]. The
 * intersection extent is clamped to 0 on each axis, so disjoint boxes have
 * an intersection area of 0.
 *
 * @param box_1      First box (4 values).
 * @param box_2      Second box (4 values).
 * @param threshold  IoU threshold.
 * @return true when intersection_area / union_area > threshold.
 */
template <typename T>
HOSTDEVICE inline bool CalculateIoU(const T* const box_1, const T* const box_2,
                                    const float threshold) {
  const T a_left = box_1[0];
  const T a_top = box_1[1];
  const T a_right = box_1[2];
  const T a_bottom = box_1[3];

  const T b_left = box_2[0];
  const T b_top = box_2[1];
  const T b_right = box_2[2];
  const T b_bottom = box_2[3];

  // Intersection rectangle: larger of the min corners, smaller of the max
  // corners.
  const T left = b_left < a_left ? a_left : b_left;
  const T top = b_top < a_top ? a_top : b_top;
  const T right = b_right > a_right ? a_right : b_right;
  const T bottom = b_bottom > a_bottom ? a_bottom : b_bottom;

  // A non-positive extent means the boxes do not overlap on that axis.
  T inter_width = right - left;
  if (!(inter_width > 0)) inter_width = 0;
  T inter_height = bottom - top;
  if (!(inter_height > 0)) inter_height = 0;

  const T inter_area = inter_width * inter_height;
  const T union_area = (a_right - a_left) * (a_bottom - a_top) +
                       (b_right - b_left) * (b_bottom - b_top) - inter_area;
  return inter_area / union_area > threshold;
}
49+
50+
} // namespace operators
51+
} // namespace paddle

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ endif()
234234

235235
if(WIN32)
236236
LIST(REMOVE_ITEM TEST_OPS test_complex_matmul)
237+
LIST(REMOVE_ITEM TEST_OPS test_ops_nms)
237238
endif()
238239

239240
LIST(REMOVE_ITEM TEST_OPS test_fleet_checkpoint)

0 commit comments

Comments
 (0)