Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 113 additions & 115 deletions paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,52 +20,52 @@

namespace phi {

template <typename T>
template <typename T, typename IndexT>
inline void ModulatedDeformableCol2imCPUKernel(
const int num_kernels,
const IndexT num_kernels,
const T* data_col,
const T* data_offset,
const T* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int deformable_group,
const int height_col,
const int width_col,
const IndexT channels,
const IndexT height,
const IndexT width,
const IndexT kernel_h,
const IndexT kernel_w,
const IndexT pad_h,
const IndexT pad_w,
const IndexT stride_h,
const IndexT stride_w,
const IndexT dilation_h,
const IndexT dilation_w,
const IndexT channel_per_deformable_group,
const IndexT batch_size,
const IndexT deformable_group,
const IndexT height_col,
const IndexT width_col,
T* grad_im) {
for (int thread = 0; thread < num_kernels; thread++) {
const int j = (thread / width_col / height_col / batch_size) % kernel_w;
const int i =
for (IndexT thread = 0; thread < num_kernels; thread++) {
const IndexT j = (thread / width_col / height_col / batch_size) % kernel_w;
const IndexT i =
(thread / width_col / height_col / batch_size / kernel_w) % kernel_h;
const int c =
const IndexT c =
thread / width_col / height_col / batch_size / kernel_w / kernel_h;

const int deformable_group_index = c / channel_per_deformable_group;
const IndexT deformable_group_index = c / channel_per_deformable_group;

int w_out = thread % width_col;
int h_out = (thread / width_col) % height_col;
int b = (thread / width_col / height_col) % batch_size;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
IndexT w_out = thread % width_col;
IndexT h_out = (thread / width_col) % height_col;
IndexT b = (thread / width_col / height_col) % batch_size;
IndexT w_in = w_out * stride_w - pad_w;
IndexT h_in = h_out * stride_h - pad_h;

const T* data_offset_ptr =
data_offset + (b * deformable_group + deformable_group_index) * 2 *
kernel_h * kernel_w * height_col * width_col;
const int data_offset_h_ptr =
const IndexT data_offset_h_ptr =
((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
const int data_offset_w_ptr =
const IndexT data_offset_w_ptr =
((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
const int data_mask_hw_ptr =
const IndexT data_mask_hw_ptr =
((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
const T offset_h = data_offset_ptr[data_offset_h_ptr];
const T offset_w = data_offset_ptr[data_offset_w_ptr];
Expand All @@ -80,14 +80,14 @@ inline void ModulatedDeformableCol2imCPUKernel(
const T mask = data_mask_ptr[data_mask_hw_ptr];
cur_top_grad *= mask;
}
const int cur_h = static_cast<int>(cur_inv_h_data);
const int cur_w = static_cast<int>(cur_inv_w_data);
for (int dy = -2; dy <= 2; dy++) {
for (int dx = -2; dx <= 2; dx++) {
const IndexT cur_h = static_cast<IndexT>(cur_inv_h_data);
const IndexT cur_w = static_cast<IndexT>(cur_inv_w_data);
for (IndexT dy = -2; dy <= 2; dy++) {
for (IndexT dx = -2; dx <= 2; dx++) {
if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 &&
cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
abs(cur_inv_w_data - (cur_w + dx)) < 1) {
int cur_bottom_grad_pos =
IndexT cur_bottom_grad_pos =
((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
T weight = DmcnGetGradientWeight(cur_inv_h_data,
cur_inv_w_data,
Expand All @@ -104,7 +104,7 @@ inline void ModulatedDeformableCol2imCPUKernel(
}
}

template <typename T, typename Context>
template <typename T, typename Context, typename IndexT>
void ModulatedDeformableCol2im(const Context& dev_ctx,
const T* data_col,
const T* data_offset,
Expand All @@ -117,70 +117,69 @@ void ModulatedDeformableCol2im(const Context& dev_ctx,
const std::vector<int>& dilation,
const int deformable_group,
T* grad_im) {
int channel_per_deformable_group =
static_cast<int>(im_shape[0] / deformable_group);
int num_kernels = static_cast<int>(col_shape[0] * col_shape[1] *
col_shape[2] * col_shape[3]);
int64_t channel_per_deformable_group = im_shape[0] / deformable_group;
int64_t num_kernels =
col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3];

ModulatedDeformableCol2imCPUKernel(num_kernels,
data_col,
data_offset,
data_mask,
im_shape[0],
im_shape[1],
im_shape[2],
kernel_shape[2],
kernel_shape[3],
pad[0],
pad[1],
stride[0],
stride[1],
dilation[0],
dilation[1],
channel_per_deformable_group,
col_shape[1],
deformable_group,
col_shape[2],
col_shape[3],
grad_im);
ModulatedDeformableCol2imCPUKernel<T, IndexT>(num_kernels,
data_col,
data_offset,
data_mask,
im_shape[0],
im_shape[1],
im_shape[2],
kernel_shape[2],
kernel_shape[3],
pad[0],
pad[1],
stride[0],
stride[1],
dilation[0],
dilation[1],
channel_per_deformable_group,
col_shape[1],
deformable_group,
col_shape[2],
col_shape[3],
grad_im);
}

template <typename T>
template <typename T, typename IndexT>
void ModulatedDeformableCol2imCoordCPUKernel(
const int num_kernels,
const IndexT num_kernels,
const T* data_col,
const T* data_im,
const T* data_offset,
const T* data_mask,
const int channels,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
const int channel_per_deformable_group,
const int batch_size,
const int offset_channels,
const int deformable_group,
const int height_col,
const int width_col,
const IndexT channels,
const IndexT height,
const IndexT width,
const IndexT kernel_h,
const IndexT kernel_w,
const IndexT pad_h,
const IndexT pad_w,
const IndexT stride_h,
const IndexT stride_w,
const IndexT dilation_h,
const IndexT dilation_w,
const IndexT channel_per_deformable_group,
const IndexT batch_size,
const IndexT offset_channels,
const IndexT deformable_group,
const IndexT height_col,
const IndexT width_col,
T* grad_offset,
T* grad_mask) {
for (int i = 0; i < num_kernels; i++) {
for (IndexT i = 0; i < num_kernels; i++) {
T val = 0, mval = 0;
const int w = i % width_col;
const int h = (i / width_col) % height_col;
const int c = (i / width_col / height_col) % offset_channels;
const int b = (i / width_col / height_col) / offset_channels;
const IndexT w = i % width_col;
const IndexT h = (i / width_col) % height_col;
const IndexT c = (i / width_col / height_col) % offset_channels;
const IndexT b = (i / width_col / height_col) / offset_channels;

const int deformable_group_index = c / (2 * kernel_h * kernel_w);
const int col_step = kernel_h * kernel_w;
int cnt = 0;
const IndexT deformable_group_index = c / (2 * kernel_h * kernel_w);
const IndexT col_step = kernel_h * kernel_w;
IndexT cnt = 0;
const T* data_col_ptr = data_col + deformable_group_index *
channel_per_deformable_group *
batch_size * width_col * height_col;
Expand All @@ -197,24 +196,25 @@ void ModulatedDeformableCol2imCoordCPUKernel(
kernel_h * kernel_w * height_col * width_col
: nullptr;

const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
const IndexT offset_c =
c - deformable_group_index * 2 * kernel_h * kernel_w;

for (int col_c = offset_c / 2; col_c < channel_per_deformable_group;
for (IndexT col_c = offset_c / 2; col_c < channel_per_deformable_group;
col_c += col_step) {
const int col_pos =
const IndexT col_pos =
(((col_c * batch_size + b) * height_col) + h) * width_col + w;
const int bp_dir = offset_c % 2;
const IndexT bp_dir = offset_c % 2;

int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
int i =
IndexT j = (col_pos / width_col / height_col / batch_size) % kernel_w;
IndexT i =
(col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
int w_out = col_pos % width_col;
int h_out = (col_pos / width_col) % height_col;
int w_in = w_out * stride_w - pad_w;
int h_in = h_out * stride_h - pad_h;
const int data_offset_h_ptr =
IndexT w_out = col_pos % width_col;
IndexT h_out = (col_pos / width_col) % height_col;
IndexT w_in = w_out * stride_w - pad_w;
IndexT h_in = h_out * stride_h - pad_h;
const IndexT data_offset_h_ptr =
(((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
const int data_offset_w_ptr =
const IndexT data_offset_w_ptr =
(((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col +
w_out);
const T offset_h = data_offset_ptr[data_offset_h_ptr];
Expand All @@ -241,7 +241,7 @@ void ModulatedDeformableCol2imCoordCPUKernel(
width,
bp_dir);
if (data_mask_ptr) {
const int data_mask_hw_ptr =
const IndexT data_mask_hw_ptr =
(((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
const T mask = data_mask_ptr[data_mask_hw_ptr];
val += weight * data_col_ptr[col_pos] * mask;
Expand All @@ -262,7 +262,7 @@ void ModulatedDeformableCol2imCoordCPUKernel(
}
}

template <typename T, typename Context>
template <typename T, typename Context, typename IndexT>
void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
const T* data_col,
const T* data_im,
Expand All @@ -277,13 +277,11 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
const int deformable_groups,
T* grad_offset,
T* grad_mask) {
int num_kernels =
static_cast<int>(2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] *
col_shape[2] * col_shape[3] * deformable_groups);
int channel_per_deformable_group =
static_cast<int>(col_shape[0] / deformable_groups);
int64_t num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] *
col_shape[2] * col_shape[3] * deformable_groups;
int64_t channel_per_deformable_group = col_shape[0] / deformable_groups;

ModulatedDeformableCol2imCoordCPUKernel(
ModulatedDeformableCol2imCoordCPUKernel<T, IndexT>(
num_kernels,
data_col,
data_im,
Expand All @@ -310,15 +308,15 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
grad_mask);
}

template <typename T, typename Context>
template <typename T, typename Context, typename IndexT>
void FilterGradAddup(const Context& dev_ctx,
const int nthreads,
const int n,
const int height,
const int width,
const int64_t nthreads,
const int64_t n,
const int64_t height,
const int64_t width,
const T* dweight_3d,
T* filter_grad) {
for (int i = 0; i < nthreads; i++) {
for (IndexT i = 0; i < nthreads; i++) {
filter_grad[i] = filter_grad[i] + dweight_3d[i];
}
}
Expand Down
Loading