1 change: 1 addition & 0 deletions paddle/fluid/operators/detection/CMakeLists.txt
@@ -66,6 +66,7 @@ detection_library(yolo_box_op SRCS yolo_box_op.cc)
detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc box_decoder_and_assign_op.cu)
detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu)
detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc)
detection_library(nms_op SRCS nms_op.cc nms_op.cu)

if(WITH_GPU OR WITH_ROCM)
set(TMPDEPS memory)
147 changes: 147 additions & 0 deletions paddle/fluid/operators/detection/nms_op.cc
@@ -0,0 +1,147 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/detection/nms_op.h"
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class NMSOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Boxes",
"(Tensor) "
"Boxes is a Tensor with shape [N, 4] "
"N is the number of boxes "
"in last dimension in format [x1, x2, y1, y2] "
"the relation should be ``0 <= x1 < x2 && 0 <= y1 < y2``.");

AddOutput("KeepBoxesIdxs",
"(Tensor) "
"KeepBoxesIdxs is a Tensor with shape [N] ");
AddAttr<float>(
"iou_threshold",
"iou_threshold is a threshold value used to compress similar boxes "
"boxes with IoU > iou_threshold will be considered as overlapping "
"and just one of them can be kept.")
.SetDefault(1.0f)
.AddCustomChecker([](const float& iou_threshold) {
PADDLE_ENFORCE_LE(iou_threshold, 1.0f,
platform::errors::InvalidArgument(
"iou_threshold should less equal than 1.0 "
"but got %f",
iou_threshold));
PADDLE_ENFORCE_GE(iou_threshold, 0.0f,
platform::errors::InvalidArgument(
"iou_threshold should greater equal than 0.0 "
"but got %f",
iou_threshold));
});
AddComment(R"DOC(
NMS Operator.
This Operator performs Non-Maximum Suppression (NMS) over the input boxes.
The indices of the boxes kept by NMS are output in score order (the input
boxes are expected to be pre-sorted by score in descending order).
)DOC");
}
};

class NMSOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Boxes"), "Input", "Boxes", "NMS");
OP_INOUT_CHECK(ctx->HasOutput("KeepBoxesIdxs"), "Output", "KeepBoxesIdxs",
"NMS");

auto boxes_dim = ctx->GetInputDim("Boxes");
PADDLE_ENFORCE_EQ(boxes_dim.size(), 2,
platform::errors::InvalidArgument(
"The Input Boxes must be 2-dimention "
"whose shape must be [N, 4] "
"N is the number of boxes "
"in last dimension in format [x1, x2, y1, y2]. "));
auto num_boxes = boxes_dim[0];

ctx->SetOutputDim("KeepBoxesIdxs", {num_boxes});
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Boxes"), ctx.GetPlace());
}
};

template <typename T>
static void NMS(const T* boxes_data, int64_t* output_data, float threshold,
int64_t num_boxes) {
auto num_masks = CeilDivide(num_boxes, 64);
std::vector<uint64_t> masks(num_masks, 0);

for (int64_t i = 0; i < num_boxes; ++i) {
if (masks[i / 64] & 1ULL << (i % 64)) continue;
T box_1[4];
for (int k = 0; k < 4; ++k) {
box_1[k] = boxes_data[i * 4 + k];
}
for (int64_t j = i + 1; j < num_boxes; ++j) {
if (masks[j / 64] & 1ULL << (j % 64)) continue;
T box_2[4];
for (int k = 0; k < 4; ++k) {
box_2[k] = boxes_data[j * 4 + k];
}
bool is_overlap = CalculateIoU<T>(box_1, box_2, threshold);
if (is_overlap) {
masks[j / 64] |= 1ULL << (j % 64);
}
}
}

int64_t output_data_idx = 0;
for (int64_t i = 0; i < num_boxes; ++i) {
if (masks[i / 64] & 1ULL << (i % 64)) continue;
output_data[output_data_idx++] = i;
}

for (; output_data_idx < num_boxes; ++output_data_idx) {
output_data[output_data_idx] = 0;
}
}
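The bookkeeping above packs one "suppressed" flag per box into 64-bit words, and zero-pads the output to a fixed length of N. A minimal standalone sketch of the scheme (the helper names are ours, not part of the PR):

#include <cstdint>
#include <vector>

// Box i is tracked at word i / 64, bit i % 64 (hypothetical helpers).
static void MarkSuppressed(std::vector<uint64_t>* masks, int64_t i) {
  (*masks)[i / 64] |= 1ULL << (i % 64);
}
static bool IsSuppressed(const std::vector<uint64_t>& masks, int64_t i) {
  return (masks[i / 64] >> (i % 64)) & 1ULL;
}

// E.g. num_boxes = 130 needs CeilDivide(130, 64) = 3 words; suppressing box 70
// sets bit 6 of word 1. If boxes 1 and 3 of a 5-box input are suppressed, the
// zero-padded output is KeepBoxesIdxs = [0, 2, 4, 0, 0].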

template <typename T>
class NMSKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* boxes = context.Input<Tensor>("Boxes");
Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
int64_t* output_data = output->mutable_data<int64_t>(context.GetPlace());
auto threshold = context.template Attr<float>("iou_threshold");
NMS<T>(boxes->data<T>(), output_data, threshold, boxes->dims()[0]);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(
nms, ops::NMSOp, ops::NMSOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(nms, ops::NMSKernel<float>, ops::NMSKernel<double>);
108 changes: 108 additions & 0 deletions paddle/fluid/operators/detection/nms_op.cu
@@ -0,0 +1,108 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <vector>
#include "paddle/fluid/operators/detection/nms_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

static const int64_t threadsPerBlock = sizeof(int64_t) * 8;

namespace paddle {
namespace operators {

using framework::Tensor;

template <typename T>
static __global__ void NMS(const T* boxes_data, float threshold,
int64_t num_boxes, uint64_t* masks) {
auto row_start = blockIdx.y;
auto col_start = blockIdx.x;
if (row_start > col_start) return;

const int row_last_storage =
min(num_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_last_storage =
min(num_boxes - col_start * threadsPerBlock, threadsPerBlock);

if (threadIdx.x < row_last_storage) {
uint64_t mask = 0;
auto current_box_idx = row_start * threadsPerBlock + threadIdx.x;
const T* current_box = boxes_data + current_box_idx * 4;
for (int i = 0; i < col_last_storage; ++i) {
const T* target_box = boxes_data + (col_start * threadsPerBlock + i) * 4;
if (CalculateIoU<T>(current_box, target_box, threshold)) {
mask |= 1ULL << i;
}
}
const int blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);
masks[current_box_idx * blocks_per_line + col_start] = mask;
}
}
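Each CUDA block above handles one 64 x 64 tile of the pairwise-IoU matrix: block (x = col_start, y = row_start) compares 64 "row" boxes against 64 "column" boxes, and each thread writes one 64-bit mask word, producing a [num_boxes, blocks_per_line] matrix of words. A sketch of how a pair (i, j) maps into that layout (this helper is ours, for illustration only):

#include <cstdint>

// Returns whether the kernel marked box j as overlapping box i. Row i owns
// blocks_per_line mask words; box j is bit j % 64 of word j / 64.
// (Hypothetical helper, not part of the PR.)
inline bool IsMarkedOverlapping(const uint64_t* masks, int64_t blocks_per_line,
                                int64_t i, int64_t j) {
  const uint64_t word = masks[i * blocks_per_line + j / 64];
  return (word >> (j % 64)) & 1ULL;
}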

template <typename T>
class NMSCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* boxes = context.Input<Tensor>("Boxes");
Tensor* output = context.Output<Tensor>("KeepBoxesIdxs");
auto* output_data = output->mutable_data<int64_t>(context.GetPlace());

auto threshold = context.template Attr<float>("iou_threshold");
const int64_t num_boxes = boxes->dims()[0];
const auto blocks_per_line = CeilDivide(num_boxes, threadsPerBlock);

dim3 block(threadsPerBlock);
dim3 grid(blocks_per_line, blocks_per_line);

auto mask_data =
memory::Alloc(context.cuda_device_context(),
num_boxes * blocks_per_line * sizeof(uint64_t));
uint64_t* mask_dev = reinterpret_cast<uint64_t*>(mask_data->ptr());
NMS<T><<<grid, block, 0, context.cuda_device_context().stream()>>>(
boxes->data<T>(), threshold, num_boxes, mask_dev);

std::vector<uint64_t> mask_host(num_boxes * blocks_per_line);
memory::Copy(platform::CPUPlace(), mask_host.data(), context.GetPlace(),
mask_dev, num_boxes * blocks_per_line * sizeof(uint64_t),
context.cuda_device_context().stream());
Contributor
After the mask data is copied from the GPU back to the CPU, a synchronization is required; otherwise the mask_host used below is very likely to contain stale data.

Contributor Author
Discussed offline; this will be fixed in the next PR.
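The fix the reviewer asks for would sit between the device-to-host copy and the host-side loop that follows; a minimal sketch, assuming the Paddle device context exposes the usual blocking Wait():

// Hedged sketch (deferred to a follow-up PR, per the thread above): block until
// the async D2H copy on the context's stream has completed before reading
// mask_host on the CPU.
context.cuda_device_context().Wait();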


std::vector<int64_t> remv(blocks_per_line);

std::vector<int64_t> keep_boxes_idxs(num_boxes);
int64_t* output_host = keep_boxes_idxs.data();

int64_t last_box_num = 0;
for (int64_t i = 0; i < num_boxes; ++i) {
auto remv_element_id = i / threadsPerBlock;
auto remv_bit_id = i % threadsPerBlock;
if (!(remv[remv_element_id] & 1ULL << remv_bit_id)) {
output_host[last_box_num++] = i;
uint64_t* current_mask = mask_host.data() + i * blocks_per_line;
for (auto j = remv_element_id; j < blocks_per_line; ++j) {
remv[j] |= current_mask[j];
}
}
}
memory::Copy(context.GetPlace(), output_data, platform::CPUPlace(),
output_host, sizeof(int64_t) * num_boxes,
context.cuda_device_context().stream());
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(nms, ops::NMSCudaKernel<float>,
ops::NMSCudaKernel<double>);
51 changes: 51 additions & 0 deletions paddle/fluid/operators/detection/nms_op.h
@@ -0,0 +1,51 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

HOSTDEVICE static inline int64_t CeilDivide(int64_t n, int64_t m) {
return (n + m - 1) / m;
}
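// For example, CeilDivide(130, 64) == (130 + 63) / 64 == 3: three 64-bit words
// are enough to hold one flag per box for 130 boxes.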

template <typename T>
HOSTDEVICE inline bool CalculateIoU(const T* const box_1, const T* const box_2,
const float threshold) {
auto box_1_x0 = box_1[0], box_1_y0 = box_1[1];
auto box_1_x1 = box_1[2], box_1_y1 = box_1[3];
auto box_2_x0 = box_2[0], box_2_y0 = box_2[1];
auto box_2_x1 = box_2[2], box_2_y1 = box_2[3];

auto inter_box_x0 = box_1_x0 > box_2_x0 ? box_1_x0 : box_2_x0;
auto inter_box_y0 = box_1_y0 > box_2_y0 ? box_1_y0 : box_2_y0;
auto inter_box_x1 = box_1_x1 < box_2_x1 ? box_1_x1 : box_2_x1;
auto inter_box_y1 = box_1_y1 < box_2_y1 ? box_1_y1 : box_2_y1;

auto inter_width =
inter_box_x1 - inter_box_x0 > 0 ? inter_box_x1 - inter_box_x0 : 0;
auto inter_height =
inter_box_y1 - inter_box_y0 > 0 ? inter_box_y1 - inter_box_y0 : 0;
auto inter_area = inter_width * inter_height;
auto union_area = (box_1_x1 - box_1_x0) * (box_1_y1 - box_1_y0) +
(box_2_x1 - box_2_x0) * (box_2_y1 - box_2_y0) - inter_area;
return inter_area / union_area > threshold;
}
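A quick worked example of the predicate above, as a usage sketch (the boxes and thresholds are ours):

// box_1 = [0, 0, 10, 10] and box_2 = [5, 5, 15, 15], each with area 100.
// The intersection is [5, 5, 10, 10], area 25; the union is 100 + 100 - 25 =
// 175, so IoU = 25 / 175, about 0.143.
float box_1[4] = {0.f, 0.f, 10.f, 10.f};
float box_2[4] = {5.f, 5.f, 15.f, 15.f};
bool overlapping = CalculateIoU<float>(box_1, box_2, 0.1f);  // true:  0.143 > 0.1
bool kept_apart = !CalculateIoU<float>(box_1, box_2, 0.5f);  // true:  0.143 < 0.5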

} // namespace operators
} // namespace paddle
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -234,6 +234,7 @@ endif()

if(WIN32)
LIST(REMOVE_ITEM TEST_OPS test_complex_matmul)
LIST(REMOVE_ITEM TEST_OPS test_ops_nms)
endif()

LIST(REMOVE_ITEM TEST_OPS test_fleet_checkpoint)