- Notifications
You must be signed in to change notification settings - Fork 5.9k
Add nms op and batched_nms api #40962
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
heavengate merged 23 commits into PaddlePaddle:develop from RichardWooSJTU:add_nms_op_and_batched_nms Apr 5, 2022
Merged
Changes from all commits
Commits
Show all changes
23 commits Select commit Hold shift + click to select a range
0f7c4bf add nms op and batched_nms api
RichardWooSJTU bd9918f modify description of nms op
RichardWooSJTU 7ea82ed fix error msg of PADDLE_ENFORCE
RichardWooSJTU b3b48de delete debug info
RichardWooSJTU 8eeef56 modify HOSTDEVICE keyword
RichardWooSJTU 8d9a9ba accelerate test
RichardWooSJTU 042eb6b Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
RichardWooSJTU a17bfb1 fix rocm compile error
RichardWooSJTU 918cd3f modify api doc and fix shape bug
RichardWooSJTU 264ab25 fix topk error when compile time
RichardWooSJTU a4537c7 add api to __all__
RichardWooSJTU 085ef8a fix doc string
RichardWooSJTU 407b831 fix doc string
RichardWooSJTU b2e7ed5 fix doc example and math error
RichardWooSJTU 7b52f99 fix doc example and math error
RichardWooSJTU 905acf8 fix doc math error
RichardWooSJTU e421ca1 fix doc math error
RichardWooSJTU 611a872 delete duplicated code
RichardWooSJTU 6186d8c try to fix CI-Windows-Inference memory error
RichardWooSJTU 9f92dcd merge nms and batched_nms
RichardWooSJTU 3052238 fix coverage
RichardWooSJTU dfb5f20 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
RichardWooSJTU 615aff7 skip test_ops_nms
RichardWooSJTU File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| | ||
| #include "paddle/fluid/operators/detection/nms_op.h" | ||
| #include <vector> | ||
| | ||
| namespace paddle { | ||
| namespace operators { | ||
| | ||
| using framework::Tensor; | ||
| | ||
| class NMSOpMaker : public framework::OpProtoAndCheckerMaker { | ||
| public: | ||
| void Make() override { | ||
| AddInput("Boxes", | ||
| "(Tensor) " | ||
| "Boxes is a Tensor with shape [N, 4] " | ||
| "N is the number of boxes " | ||
| "in last dimension in format [x1, x2, y1, y2] " | ||
| "the relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``."); | ||
| | ||
| AddOutput("KeepBoxesIdxs", | ||
| "(Tensor) " | ||
| "KeepBoxesIdxs is a Tensor with shape [N] "); | ||
| AddAttr<float>( | ||
| "iou_threshold", | ||
| "iou_threshold is a threshold value used to compress similar boxes " | ||
| "boxes with IoU > iou_threshold will be considered as overlapping " | ||
| "and just one of them can be kept.") | ||
| .SetDefault(1.0f) | ||
| .AddCustomChecker([](const float& iou_threshold) { | ||
| PADDLE_ENFORCE_LE(iou_threshold, 1.0f, | ||
| platform::errors::InvalidArgument( | ||
| "iou_threshold should less equal than 1.0 " | ||
| "but got %f", | ||
| iou_threshold)); | ||
| PADDLE_ENFORCE_GE(iou_threshold, 0.0f, | ||
| platform::errors::InvalidArgument( | ||
| "iou_threshold should greater equal than 0.0 " | ||
| "but got %f", | ||
| iou_threshold)); | ||
| }); | ||
| AddComment(R"DOC( | ||
| NMS Operator. | ||
| This Operator is used to perform Non-Maximum Compress for input boxes. | ||
| Indices of boxes kept by NMS will be sorted by scores and output. | ||
| )DOC"); | ||
| } | ||
| }; | ||
| | ||
| class NMSOp : public framework::OperatorWithKernel { | ||
| public: | ||
| using framework::OperatorWithKernel::OperatorWithKernel; | ||
| void InferShape(framework::InferShapeContext* ctx) const override { | ||
| OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS"); | ||
| OP_INOUT_CHECK(ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs", | ||
| "NMS"); | ||
| | ||
| auto boxes_dim = ctx->GetInputDim("Boxes"); | ||
| PADDLE_ENFORCE_EQ(boxes_dim.size(), 2, | ||
| platform::errors::InvalidArgument( | ||
| "The Input Boxes must be 2-dimention " | ||
| "whose shape must be [N, 4] " | ||
| "N is the number of boxes " | ||
| "in last dimension in format [x1, x2, y1, y2]. ")); | ||
| auto num_boxes = boxes_dim[0]; | ||
| | ||
| ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes}); | ||
| } | ||
| | ||
| protected: | ||
| framework::OpKernelType GetExpectedKernelType( | ||
| const framework::ExecutionContext& ctx) const override { | ||
| return framework::OpKernelType( | ||
| OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace()); | ||
| } | ||
| }; | ||
| | ||
| template <typename T> | ||
| static void NMS(const T* boxes_data, int64_t* output_data, float threshold, | ||
| int64_t num_boxes) { | ||
| auto num_masks = CeilDivide(num_boxes, 64); | ||
| std::vector<uint64_t> masks(num_masks, 0); | ||
| | ||
| for (int64_t i = 0; i < num_boxes; ++i) { | ||
| if (masks[i / 64] & 1ULL << (i % 64)) continue; | ||
| T box_1[4]; | ||
| for (int k = 0; k < 4; ++k) { | ||
| box_1[k] = boxes_data[i * 4 + k]; | ||
| } | ||
| for (int64_t j = i + 1; j < num_boxes; ++j) { | ||
| if (masks[j / 64] & 1ULL << (j % 64)) continue; | ||
| T box_2[4]; | ||
| for (int k = 0; k < 4; ++k) { | ||
| box_2[k] = boxes_data[j * 4 + k]; | ||
| } | ||
| bool is_overlap = CalculateIoU<T>(box_1, box_2, threshold); | ||
| if (is_overlap) { | ||
| masks[j / 64] |= 1ULL << (j % 64); | ||
| } | ||
| } | ||
| } | ||
| | ||
| int64_t output_data_idx = 0; | ||
| for (int64_t i = 0; i < num_boxes; ++i) { | ||
| if (masks[i / 64] & 1ULL << (i % 64)) continue; | ||
| output_data[output_data_idx++] = i; | ||
| } | ||
| | ||
| for (; output_data_idx < num_boxes; ++output_data_idx) { | ||
| output_data[output_data_idx] = 0; | ||
| } | ||
| } | ||
| | ||
| template <typename T> | ||
| class NMSKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| const Tensor* boxes = context.Input<Tensor>("Boxes"); | ||
| Tensor* output = context.Output<Tensor>("KeepBoxesIdxs"); | ||
| int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace()); | ||
| auto threshold = context.template Attr<float>("iou_threshold"); | ||
| NMS<T>(boxes->data<T>(), output_data, threshold, boxes->dims()[0]); | ||
| } | ||
| }; | ||
| | ||
| } // namespace operators | ||
| } // namespace paddle | ||
| | ||
| namespace ops = paddle::operators; | ||
| | ||
| REGISTER_OPERATOR( | ||
| nms, ops::NMSOp, ops::NMSOpMaker, | ||
| paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, | ||
| paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); | ||
| REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel<float>, ops::NMSKernel<double>); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| | ||
| #include <vector> | ||
| #include "paddle/fluid/operators/detection/nms_op.h" | ||
| #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" | ||
| | ||
| static const int64_t threadsPerBlock = sizeof(int64_t) * 8; | ||
| | ||
| namespace paddle { | ||
| namespace operators { | ||
| | ||
| using framework::Tensor; | ||
| | ||
| template <typename T> | ||
| static __global__ void NMS(const T* boxes_data, float threshold, | ||
| int64_t num_boxes, uint64_t* masks) { | ||
| auto raw_start = blockIdx.y; | ||
| auto col_start = blockIdx.x; | ||
| if (raw_start > col_start) return; | ||
| | ||
| const int raw_last_storage = | ||
| min(num_boxes - raw_start * threadsPerBlock, threadsPerBlock); | ||
| const int col_last_storage = | ||
| min(num_boxes - col_start * threadsPerBlock, threadsPerBlock); | ||
| | ||
| if (threadIdx.x < raw_last_storage) { | ||
| uint64_t mask = 0; | ||
| auto current_box_idx = raw_start * threadsPerBlock + threadIdx.x; | ||
| const T* current_box = boxes_data + current_box_idx * 4; | ||
| for (int i = 0; i < col_last_storage; ++i) { | ||
| const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4; | ||
| if (CalculateIoU<T>(current_box, target_box, threshold)) { | ||
| mask |= 1ULL << i; | ||
| } | ||
| } | ||
| const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); | ||
| masks[current_box_idx * blocks_per_line + col_start] = mask; | ||
| } | ||
| } | ||
| | ||
| template <typename T> | ||
| class NMSCudaKernel : public framework::OpKernel<T> { | ||
| public: | ||
| void Compute(const framework::ExecutionContext& context) const override { | ||
| const Tensor* boxes = context.Input<Tensor>("Boxes"); | ||
| Tensor* output = context.Output<Tensor>("KeepBoxesIdxs"); | ||
| auto* output_data = output->mutable_data<int64_t>(context.GetPlace()); | ||
| | ||
| auto threshold = context.template Attr<float>("iou_threshold"); | ||
| const int64_t num_boxes = boxes->dims()[0]; | ||
| const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock); | ||
| | ||
| dim3 block(threadsPerBlock); | ||
| dim3 grid(blocks_per_line, blocks_per_line); | ||
| | ||
| auto mask_data = | ||
| memory::Alloc(context.cuda_device_context(), | ||
| num_boxes * blocks_per_line * sizeof(uint64_t)); | ||
| uint64_t* mask_dev = reinterpret_cast<uint64_t*>(mask_data->ptr()); | ||
| NMS<T><<<grid, block, 0, context.cuda_device_context().stream()>>>( | ||
| boxes->data<T>(), threshold, num_boxes, mask_dev); | ||
| | ||
| std::vector<uint64_t> mask_host(num_boxes * blocks_per_line); | ||
| memory::Copy(platform::CPUPlace(), mask_host.data(), context.GetPlace(), | ||
| mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t), | ||
| context.cuda_device_context().stream()); | ||
| | ||
| std::vector<int64_t> remv(blocks_per_line); | ||
| | ||
| std::vector<int64_t> keep_boxes_idxs(num_boxes); | ||
| int64_t* output_host = keep_boxes_idxs.data(); | ||
| | ||
| int64_t last_box_num = 0; | ||
| for (int64_t i = 0; i < num_boxes; ++i) { | ||
| auto remv_element_id = i / threadsPerBlock; | ||
| auto remv_bit_id = i % threadsPerBlock; | ||
| if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) { | ||
| output_host[last_box_num++] = i; | ||
| uint64_t* current_mask = mask_host.data() + i * blocks_per_line; | ||
| for (auto j = remv_element_id; j < blocks_per_line; ++j) { | ||
| remv[j] |= current_mask[j]; | ||
| } | ||
| } | ||
| } | ||
| memory::Copy(context.GetPlace(), output_data, platform::CPUPlace(), | ||
| output_host, sizeof(int64_t) * num_boxes, | ||
| context.cuda_device_context().stream()); | ||
| } | ||
| }; | ||
| | ||
| } // namespace operators | ||
| } // namespace paddle | ||
| | ||
| namespace ops = paddle::operators; | ||
| REGISTER_OP_CUDA_KERNEL(nms, ops::NMSCudaKernel<float>, | ||
| ops::NMSCudaKernel<double>); | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
| | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| | ||
| #pragma once | ||
| | ||
| #include "paddle/fluid/framework/op_registry.h" | ||
| #include "paddle/fluid/framework/operator.h" | ||
| | ||
| namespace paddle { | ||
| namespace operators { | ||
| | ||
| HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) { | ||
| return (n + m - 1) / m; | ||
| } | ||
| | ||
| template <typename T> | ||
| HOSTDEVICE inline bool CalculateIoU(const T* const box_1, const T* const box_2, | ||
| const float threshold) { | ||
| auto box_1_x0 = box_1[0], box_1_y0 = box_1[1]; | ||
| auto box_1_x1 = box_1[2], box_1_y1 = box_1[3]; | ||
| auto box_2_x0 = box_2[0], box_2_y0 = box_2[1]; | ||
| auto box_2_x1 = box_2[2], box_2_y1 = box_2[3]; | ||
| | ||
| auto inter_box_x0 = box_1_x0 > box_2_x0 ? box_1_x0 : box_2_x0; | ||
| auto inter_box_y0 = box_1_y0 > box_2_y0 ? box_1_y0 : box_2_y0; | ||
| auto inter_box_x1 = box_1_x1 < box_2_x1 ? box_1_x1 : box_2_x1; | ||
| auto inter_box_y1 = box_1_y1 < box_2_y1 ? box_1_y1 : box_2_y1; | ||
| | ||
| auto inter_width = | ||
| inter_box_x1 - inter_box_x0 > 0 ? inter_box_x1 - inter_box_x0 : 0; | ||
| auto inter_height = | ||
| inter_box_y1 - inter_box_y0 > 0 ? inter_box_y1 - inter_box_y0 : 0; | ||
| auto inter_area = inter_width * inter_height; | ||
| auto union_area = (box_1_x1 - box_1_x0) * (box_1_y1 - box_1_y0) + | ||
| (box_2_x1 - box_2_x0) * (box_2_y1 - box_2_y0) - inter_area; | ||
| return inter_area / union_area > threshold; | ||
| } | ||
| | ||
| } // namespace operators | ||
| } // namespace paddle |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GPU内容拷回CPU后,需要同步,不然后面用到的
mask_host极有可能是脏数据。There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已线下沟通 下个PR修改