 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/library.h>
-#include <ATen/cuda/Atomic.cuh>
+#include <ATen/native/cuda/KernelUtils.cuh>
 
 #include "cuda_helpers.h"
 
@@ -218,7 +218,8 @@ __global__ void roi_align_backward_kernel_impl(
     int n_stride,
     int c_stride,
     int h_stride,
-    int w_stride) {
+    int w_stride,
+    const int memory_span) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
     // (n, c, ph, pw) is an element in the pooled output
     int pw = index % pooled_width;
@@ -247,12 +248,9 @@ __global__ void roi_align_backward_kernel_impl(
     T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
     T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);
 
-    T* offset_grad_input =
-        grad_input + ((roi_batch_ind * channels + c) * height * width);
-
     // We need to index the gradient using the tensor strides to access the
     // correct values.
-    int output_offset = n * n_stride + c * c_stride;
+    const int output_offset = n * n_stride + c * c_stride;
     const T* offset_grad_output = grad_output + output_offset;
     const T grad_output_this_bin =
         offset_grad_output[ph * h_stride + pw * w_stride];
@@ -267,6 +265,8 @@ __global__ void roi_align_backward_kernel_impl(
     // We do average (integral) pooling inside a bin
     const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
 
+    const int input_offset = (roi_batch_ind * channels + c) * height * width;
+
     for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1
     {
       const T y = roi_start_h + ph * bin_size_h +
@@ -301,14 +301,30 @@ __global__ void roi_align_backward_kernel_impl(
         T g4 = grad_output_this_bin * w4 / count;
 
         if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
-          gpuAtomicAdd(
-              offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
-          gpuAtomicAdd(
-              offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
-          gpuAtomicAdd(
-              offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
-          gpuAtomicAdd(
-              offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_low * width + x_low,
+              memory_span,
+              static_cast<T>(g1),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_low * width + x_high,
+              memory_span,
+              static_cast<T>(g2),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_high * width + x_low,
+              memory_span,
+              static_cast<T>(g3),
+              true);
+          at::native::fastAtomicAdd(
+              grad_input,
+              input_offset + y_high * width + x_high,
+              memory_span,
+              static_cast<T>(g4),
+              true);
         } // if
       } // ix
     } // iy
@@ -442,7 +458,8 @@ at::Tensor roi_align_backward_kernel(
             n_stride,
             c_stride,
             h_stride,
-            w_stride);
+            w_stride,
+            grad_input.numel());
       });
   AT_CUDA_CHECK(cudaGetLastError());
   return grad_input;
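For reference, `at::native::fastAtomicAdd` (declared in `ATen/native/cuda/KernelUtils.cuh`) takes the tensor's base pointer, a linear element index, the tensor's total element count, the value to add, and a fast-path flag. The element count is why the kernel now receives `memory_span` (passed as `grad_input.numel()`): the half-precision fast path can issue a packed two-element atomic and needs to know where the tensor ends before doing so. Below is a minimal, self-contained sketch of the same call pattern; the `scatter_add_kernel` name and its parameters are illustrative only and not part of this commit.

// Sketch only: a generic scatter-add kernel using the fastAtomicAdd pattern.
#include <ATen/native/cuda/KernelUtils.cuh>

template <typename T>
__global__ void scatter_add_kernel(
    T* out,                // base pointer of the output tensor
    const int* index,      // target element index for each input value
    const T* values,       // values to accumulate into `out`
    int n,                 // number of input values
    int out_numel) {       // total elements in `out` (the "memory span")
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Arguments: base pointer, linear index, element count, value, fast flag.
    // For half-precision types with the flag set, this can use a packed
    // two-element atomic where alignment and architecture allow; otherwise
    // it falls back to an ordinary per-element atomic add.
    at::native::fastAtomicAdd(out, index[i], out_numel, values[i], true);
  }
}

For non-half scalar types the helper simply falls through to a plain atomic add, so the float and double instantiations of the backward kernel should behave as before.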