@@ -59,6 +59,17 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
5959 return config;
6060}
6161
// Computes the neighbor index and the two interpolation weights for one axis
// of linear/bilinear interpolation.
//
// Given the already ratio-scaled source coordinate src_w, writes:
//   *in_img_idx - floor of the (clamped-to-0) source coordinate, i.e. the
//                 left/top neighbor index
//   *w_id       - offset to the right/bottom neighbor: 1 normally, 0 when the
//                 neighbor would fall past the last valid index (border clamp)
//   *w1lambda   - weight of the right/bottom neighbor (fractional part)
//   *w2lambda   - weight of the left/top neighbor (1 - w1lambda)
//
// in_img_w is the input extent along this axis. Negative coordinates are
// clamped to 0 so the computed index never underruns the image.
template <typename T>
__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
    int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w,
    const int in_img_w) {
  src_w = (src_w > 0) ? src_w : 0.f;
  *in_img_idx = static_cast<int>(src_w);
  *w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
  *w1lambda = src_w - *in_img_idx;
  *w2lambda = 1.f - *w1lambda;
}
72+
6273struct FastDivModForInterpolate {
6374 public:
6475 FastDivMod channels_div;
@@ -417,96 +428,93 @@ __global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
417428}
418429
// Bilinear-interpolation forward kernel specialized for the NCHW layout.
//
// Expects a 3D launch configuration: x indexes output width, y indexes output
// height, and z indexes the fused batch*channel (nc) dimension. Each (x, y)
// thread handles one output pixel and walks the nc slices with a grid-stride
// loop over z, so any grid size in z is correct.
//
// align_type_value folds the two coordinate-transform variants into one
// formula: it is 0.5 when (align_mode == 0 && !align_corners), else 0
// (set by the host-side launcher).
template <typename T>
__global__ void KeBilinearInterpNCHWFw(const T* in, const size_t in_img_h,
                                       const size_t in_img_w, T* out,
                                       const size_t out_img_h,
                                       const size_t out_img_w, const size_t nc,
                                       const float ratio_h, const float ratio_w,
                                       const T align_type_value) {
  int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
  int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
  int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
  int nc_stride = blockDim.z * gridDim.z;

  // Map the output pixel back to a (possibly fractional) input coordinate,
  // then derive the 2x2 neighborhood indices and lerp weights per axis.
  int in_img_idx, in_img_idy, h_id, w_id;
  T h1lambda, w1lambda, h2lambda, w2lambda;
  T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
  T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;

  PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
                                         &w2lambda, src_w, in_img_w);
  PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
                                         &h2lambda, src_h, in_img_h);

  // Flat NCHW offsets for the first nc slice this thread owns, plus the
  // stride that advances both offsets by nc_stride whole images.
  int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
  int in_index_stride = nc_stride * in_img_h * in_img_w;

  int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
  int out_index_stride = nc_stride * out_img_h * out_img_w;

  // prevent from multiple threads writing
  if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
    while (nc_id < nc) {
      const T* in_pos = &in[in_index];
      // Bilinear blend of the 2x2 input neighborhood.
      out[out_index] =
          h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
          h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
                      w1lambda * in_pos[h_id * in_img_w + w_id]);

      in_index += in_index_stride;
      out_index += out_index_stride;
      nc_id += nc_stride;
    }
  }
}
494473
// Bilinear-interpolation forward kernel for the NHWC layout.
//
// One thread per flattened output element, iterated with a grid-stride loop
// so correctness does not depend on the launch configuration.
// FastDivModForInterpolate supplies precomputed fast divisors that replace
// the expensive runtime integer div/mod when decomposing the flat thread id
// into (batch, out_y, out_x, channel) coordinates.
//
// align_type_value is 0.5 when (align_mode == 0 && !align_corners), else 0
// (set by the host-side launcher); it selects the coordinate transform.
template <typename T>
__global__ void KeBilinearInterpFw(
    const T* in, const size_t in_img_h, const size_t in_img_w,
    const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
    const size_t out_img_w, const size_t output_h, const size_t output_w,
    const size_t num_channels, const float ratio_h, const float ratio_w,
    const T align_type_value, FastDivModForInterpolate divmods) {
  int nthreads = output_h * output_w;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = blockDim.x * gridDim.x;

  for (; tid < nthreads; tid += stride) {
    // Decompose tid -> (out_id_h, out_id_w), then out_id_w -> (y, x, channel),
    // all via precomputed fast div/mod.
    auto out_id_divmod = divmods.output_w_div.Divmod(tid);
    int out_id_h = out_id_divmod.val[0];
    int out_id_w = out_id_divmod.val[1];

    int channel_id = divmods.channels_div.Divmod(tid).val[1];
    auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
    int out_img_idy = outimg_id_divmod.val[0];
    int out_img_idx =
        divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];

    // Map the output pixel back to a fractional input coordinate and derive
    // the 2x2 neighborhood indices and lerp weights per axis.
    int in_img_idx, in_img_idy, h_id, w_id;
    T h1lambda, w1lambda, h2lambda, w2lambda;
    T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
    T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;

    PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
                                           &w2lambda, src_w, in_img_w);
    PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
                                           &h2lambda, src_h, in_img_h);

    // bilinear interpolation: NHWC neighbors along x are num_channels apart,
    // neighbors along y are in_img_w * num_channels apart.
    const T* in_pos =
        &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
            in_img_idx * num_channels + channel_id];
    out[tid] =
        h2lambda *
            (w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
        h1lambda *
            (w2lambda * in_pos[h_id * in_img_w * num_channels] +
             w1lambda *
                 in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
  }
}
511519
512520/* Calculate the minimum of partial elements in a block */
@@ -574,9 +582,11 @@ __global__ void KeBilinearInterpBwShareMemory(
574582 T w1lambda, h1lambda, w2lambda, h2lambda;
575583 T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
576584 T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
577- PreCalculatorForInputIndex (&in_img_idx, &in_img_idy, &w_id, &h_id,
578- &w1lambda, &h1lambda, &w2lambda, &h2lambda,
579- src_w, src_h, in_w, in_h);
585+
586+ PreCalculatorForLinearInterpInputIndex (&in_img_idx, &w_id, &w1lambda,
587+ &w2lambda, src_w, in_w);
588+ PreCalculatorForLinearInterpInputIndex (&in_img_idy, &h_id, &h1lambda,
589+ &h2lambda, src_h, in_h);
580590
581591 // top_left_index is just input_index.
582592 int input_index = out_id_h * in_chw + channel_id * in_img_size +
@@ -661,9 +671,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
661671
662672 T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
663673 T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
664- PreCalculatorForInputIndex (&in_img_idx, &in_img_idy, &w_id, &h_id,
665- &w1lambda, &h1lambda, &w2lambda, &h2lambda,
666- src_w, src_h, in_w, in_h);
674+
675+ PreCalculatorForLinearInterpInputIndex (&in_img_idx, &w_id, &w1lambda,
676+ &w2lambda, src_w, in_w);
677+ PreCalculatorForLinearInterpInputIndex (&in_img_idy, &h_id, &h1lambda,
678+ &h2lambda, src_h, in_h);
667679
668680 T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
669681 in_img_idy * in_w + in_img_idx];
@@ -690,9 +702,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
690702 T w1lambda, h1lambda, w2lambda, h2lambda;
691703 T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
692704 T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
693- PreCalculatorForInputIndex (&in_img_idx, &in_img_idy, &w_id, &h_id,
694- &w1lambda, &h1lambda, &w2lambda, &h2lambda,
695- src_w, src_h, in_w, in_h);
705+
706+ PreCalculatorForLinearInterpInputIndex (&in_img_idx, &w_id, &w1lambda,
707+ &w2lambda, src_w, in_w);
708+ PreCalculatorForLinearInterpInputIndex (&in_img_idy, &h_id, &h1lambda,
709+ &h2lambda, src_h, in_h);
696710
697711 T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
698712 in_img_idx * num_channels + channel_id];
@@ -1398,11 +1412,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
13981412 thread_num = 512 ;
13991413 }
14001414#endif
1401-
1402- KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0 ,
1403- ctx.cuda_device_context().stream()>>> (
1404- input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
1405- out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
1415+ const T align_type_value = (align_mode == 0 && !align_corners) ? 0 .5f : 0 ;
1416+ if (data_layout == DataLayout::kNCHW ) {
1417+ // get launch 3D config
1418+ int nc = n * c;
1419+ platform::GpuLaunchConfig config_3d =
1420+ GetGpuLaunchConfig3D (ctx.cuda_device_context (), nc, out_h, out_w);
1421+ KeBilinearInterpNCHWFw<
1422+ T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0 ,
1423+ ctx.cuda_device_context().stream()>>> (
1424+ input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
1425+ ratio_w, align_type_value);
1426+ } else {
1427+ int64_t cw = c * out_w;
1428+ auto interp_divmods = FastDivModForInterpolate (c, out_chw, cw);
1429+ KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0 ,
1430+ ctx.cuda_device_context().stream()>>> (
1431+ input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
1432+ out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
1433+ }
14061434 } else if (" bicubic" == interp_method) {
14071435#ifdef __HIPCC__
14081436 constexpr int thread_per_block = 256 ;
0 commit comments