@@ -61,13 +61,13 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
 
 template <typename T>
 __forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
-    int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w,
-    const int in_img_w) {
-  src_w = (src_w > 0) ? src_w : 0.f;
-  *in_img_idx = static_cast<int>(src_w);
-  *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
-  *w1lambda = src_w - *in_img_idx;
-  *w2lambda = 1.f - *w1lambda;
+    int* in_img_idx, int* x_id, T* lambda1, T* lambda2, T src_x,
+    const int in_img_x) {
+  src_x = (src_x > 0) ? src_x : 0.f;
+  *in_img_idx = static_cast<int>(src_x);
+  *x_id = (*in_img_idx < in_img_x - 1) ? 1 : 0;
+  *lambda1 = src_x - *in_img_idx;
+  *lambda2 = 1.f - *lambda1;
 }
 
 struct FastDivModForInterpolate {
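Note: the rename only generalizes the helper's names (it is used for both the x and y axes); the math is unchanged. As a minimal host-side sketch of what it computes, assuming nothing beyond the lines above: the clamped source coordinate is split into a left cell index, an offset `x_id` that is 0 on the last cell so `idx + x_id` stays in bounds, and the two linear weights.

```cpp
#include <cstdio>

// Standalone mirror of PreCalculatorForLinearInterpInputIndex
// (illustration only, not the kernel code itself).
void LinearWeights(float src_x, int in_img_x, int* idx, int* x_id,
                   float* lambda1, float* lambda2) {
  src_x = (src_x > 0.f) ? src_x : 0.f;    // clamp negative coordinates
  *idx = static_cast<int>(src_x);         // left neighbor index
  *x_id = (*idx < in_img_x - 1) ? 1 : 0;  // 0 on the last cell: in bounds
  *lambda1 = src_x - *idx;                // weight of the right neighbor
  *lambda2 = 1.f - *lambda1;              // weight of the left neighbor
}

int main() {
  int idx, x_id;
  float l1, l2;
  LinearWeights(2.3f, 5, &idx, &x_id, &l1, &l2);
  // Prints: idx=2 x_id=1 lambda1=0.30 lambda2=0.70
  printf("idx=%d x_id=%d lambda1=%.2f lambda2=%.2f\n", idx, x_id, l1, l2);
  return 0;
}
```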
@@ -670,83 +670,102 @@ __global__ void KeBilinearInterpBwShareMemory(
   }
 }
 
+__device__ __forceinline__ int GetInputIndex(const size_t nc, const int height,
+                                             const int width, const int h,
+                                             const int w) {
+  return (nc * height + h) * width + w;
+}
+
+template <typename T>
+__global__ void KeBilinearInterpNCHWBw(T* in, const int in_h, const int in_w,
+                                       const int out_h, const int out_w,
+                                       const int n, const int num_channels,
+                                       float ratio_h, float ratio_w,
+                                       const T* __restrict__ out,
+                                       const T align_type_value) {
+  int index = threadIdx.x + blockDim.x * blockIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int num_out = n * num_channels * out_h * out_w;
+  int num_in = n * num_channels * in_h * in_w;
+
+  for (; index < num_out; index += stride) {
+    int index_tmp = index;
+    int w2 = index_tmp % out_w;
+    index_tmp /= out_w;
+    int h2 = index_tmp % out_h;
+    int nc = index_tmp / out_h;
+
+    int h1, y_id;
+    T h1lambda, h0lambda;
+    T src_y = ratio_h * (h2 + align_type_value) - align_type_value;
+
+    PreCalculatorForLinearInterpInputIndex(&h1, &y_id, &h1lambda, &h0lambda,
+                                           src_y, in_h);
+    int w1, x_id;
+    T w1lambda, w0lambda;
+    T src_x = ratio_w * (w2 + align_type_value) - align_type_value;
+    PreCalculatorForLinearInterpInputIndex(&w1, &x_id, &w1lambda, &w0lambda,
+                                           src_x, in_w);
+
+    T d2val = out[index];
+
+    platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
+                            h0lambda * w0lambda * d2val);
+    platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
+                            h0lambda * w1lambda * d2val);
+    platform::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
+                            h1lambda * w0lambda * d2val);
+    platform::CudaAtomicAdd(
+        in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
+        h1lambda * w1lambda * d2val);
+  }
+}
+
 template <typename T>
 __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
                                    const T* __restrict__ out, const int out_h,
                                    const int out_w, const int n,
-                                   const int num_channels, float ratio_h,
-                                   float ratio_w, const T align_type_value,
-                                   bool is_nchw) {
+                                   const int out_chw, const int num_channels,
+                                   float ratio_h, float ratio_w,
+                                   const T align_type_value,
+                                   FastDivModForInterpolate divmods) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int stride = blockDim.x * gridDim.x;
   int in_chw = in_h * in_w * num_channels;
-  int out_chw = num_channels * out_h * out_w;
   int nthreads = n * out_chw;
 
-  if (is_nchw) {
-    for (; tid < nthreads; tid += stride) {
-      int out_id_h = tid / out_chw;
-      int out_id_w = tid % out_chw;
-      const int in_img_size = in_h * in_w;
-      const int out_img_size = out_h * out_w;
-      T value = out[out_id_h * out_chw + out_id_w];
-
-      int channel_id = out_id_w / out_img_size;
-      int out_img_idy = (out_id_w % out_img_size) / out_w;
-      int out_img_idx = tid % out_w;
-      int in_img_idx, in_img_idy, w_id, h_id;
-      T w1lambda, h1lambda, w2lambda, h2lambda;
-
-      T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
-      T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
-
-      PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
-                                             &w2lambda, src_w, in_w);
-      PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
-                                             &h2lambda, src_h, in_h);
-
-      T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
-                      in_img_idy * in_w + in_img_idx];
-      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
-      platform::CudaAtomicAdd(&in_pos[w_id], h2lambda * w1lambda * value);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_w],
-                              h1lambda * w2lambda * value);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_w + w_id],
-                              h1lambda * w1lambda * value);
-    }
-  } else {
-    for (; tid < nthreads; tid += stride) {
-      int out_id_h = tid / out_chw;
-      int out_id_w = tid % out_chw;
-      const int in_img_size = in_h * in_w;
-      const int out_img_size = out_h * out_w;
-      T value = out[out_id_h * out_chw + out_id_w];
-
-      int out_img_idy = out_id_w / (out_w * num_channels);
-      int out_img_idx = out_id_w % (out_w * num_channels) / num_channels;
-      int channel_id = tid % num_channels;
-
-      int in_img_idx, in_img_idy, w_id, h_id;
-      T w1lambda, h1lambda, w2lambda, h2lambda;
-      T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
-      T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
-
-      PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
-                                             &w2lambda, src_w, in_w);
-      PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
-                                             &h2lambda, src_h, in_h);
-
-      T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
-                      in_img_idx * num_channels + channel_id];
-      platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
-      platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
-                              h2lambda * w1lambda * value);
-      platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
-                              h1lambda * w2lambda * value);
-      platform::CudaAtomicAdd(
-          &in_pos[h_id * in_w * num_channels + w_id * num_channels],
-          h1lambda * w1lambda * value);
-    }
+  for (; tid < nthreads; tid += stride) {
+    auto out_id_divmod = divmods.output_w_div.Divmod(tid);
+    int out_id_h = out_id_divmod.val[0];
+    int out_id_w = out_id_divmod.val[1];
+
+    int channel_id = divmods.channels_div.Divmod(tid).val[1];
+    auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
+    int out_img_idy = outimg_id_divmod.val[0];
+    int out_img_idx =
+        divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
+
+    int in_img_idx, in_img_idy, w_id, h_id;
+    T w1lambda, h1lambda, w2lambda, h2lambda;
+    T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
+    T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
+
+    PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
+                                           &w2lambda, src_w, in_w);
+    PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
+                                           &h2lambda, src_h, in_h);
+
+    T value = out[tid];
+    T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
+                    in_img_idx * num_channels + channel_id];
+    platform::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
+    platform::CudaAtomicAdd(&in_pos[w_id * num_channels],
+                            h2lambda * w1lambda * value);
+    platform::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
+                            h1lambda * w2lambda * value);
+    platform::CudaAtomicAdd(
+        &in_pos[h_id * in_w * num_channels + w_id * num_channels],
+        h1lambda * w1lambda * value);
   }
 }
 
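In the NHWC kernel above, the per-thread chains of `/` and `%` are replaced by precomputed fast-divmod objects (their internals are outside this diff). A small host-side sanity check of the index algebra, using plain `/` and `%` to stand in for `output_w_div`, `output_wc_div`, and `channels_div`: decoding `tid` into `(batch, y, x, channel)` and re-encoding must round-trip, matching what the removed `is_nchw == false` branch computed.

```cpp
#include <cassert>

int main() {
  // Illustrative NHWC output shape.
  const int n = 2, c = 3, out_h = 4, out_w = 5;
  const int out_chw = c * out_h * out_w;

  for (int tid = 0; tid < n * out_chw; ++tid) {
    int out_id_h = tid / out_chw;                    // batch (output_w_div)
    int out_id_w = tid % out_chw;                    // offset in one image
    int channel_id = tid % c;                        // channel (channels_div)
    int out_img_idy = out_id_w / (out_w * c);        // row (output_wc_div)
    int out_img_idx = (out_id_w % (out_w * c)) / c;  // column

    // Re-encoding the NHWC coordinates must recover the linear id.
    assert(tid == out_id_h * out_chw +
                      (out_img_idy * out_w + out_img_idx) * c + channel_id);
  }
  return 0;
}
```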
@@ -1907,11 +1926,23 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
                ctx.cuda_device_context().stream()>>>(
           input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
           ratio_h, ratio_w, align_type_value, is_nchw);
+    } else if (!optimize_flag && is_nchw) {
+      // General NCHW path: one thread per output-grad element, atomic adds.
+      const int num_kernels = n * c * out_h * out_w;
+      const int num_threads =
+          std::min(ctx.cuda_device_context().GetMaxThreadsPerBlock(), 1024);
+      KeBilinearInterpNCHWBw<
+          T><<<platform::DivUp(num_kernels, num_threads), num_threads, 0,
+               ctx.cuda_device_context().stream()>>>(
+          input_grad_data, in_h, in_w, out_h, out_w, n, c, ratio_h, ratio_w,
+          output_grad_data, align_type_value);
     } else {
+      int64_t cw = c * out_w;
+      auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
       KeBilinearInterpBw<T><<<config.block_per_grid, config.thread_per_block, 0,
                               ctx.cuda_device_context().stream()>>>(
-          input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n, c,
-          ratio_h, ratio_w, align_type_value, is_nchw);
+          input_grad_data, in_h, in_w, output_grad_data, out_h, out_w, n,
+          out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
     }
   } else if ("bicubic" == interp_method) {
 #ifdef __HIPCC__
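With this hunk, the bilinear backward dispatch has three branches: the shared-memory kernel when `optimize_flag && is_nchw`, the new `KeBilinearInterpNCHWBw` for the remaining NCHW cases, and the divmod-based `KeBilinearInterpBw` for NHWC. The NCHW launch sizes its grid by ceil division; a minimal sketch of that arithmetic, assuming `platform::DivUp` is the usual `(a + b - 1) / b` and a 1024-thread block cap:

```cpp
#include <cstdio>

// Assumed semantics of platform::DivUp: ceil(a / b) for positive ints.
inline int DivUp(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Illustrative gradient shape: n=8, c=64, out_h=out_w=224.
  const int num_kernels = 8 * 64 * 224 * 224;  // one thread per output value
  const int num_threads = 1024;                // GetMaxThreadsPerBlock() cap
  // Prints: launch 25088 blocks x 1024 threads
  printf("launch %d blocks x %d threads\n", DivUp(num_kernels, num_threads),
         num_threads);
  return 0;
}
```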