Skip to content

Commit a117497

Browse files
authored
Optimize bilinear interpolation forward (#39243)
* bilinear_fw init
* optimize code
* pre-compute linear_interp input index
1 parent c86765e commit a117497

File tree

1 file changed

+118
-90
lines changed

1 file changed

+118
-90
lines changed

paddle/fluid/operators/interpolate_v2_op.cu

Lines changed: 118 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ inline platform::GpuLaunchConfig GetGpuLaunchConfig3D(
5959
return config;
6060
}
6161

62+
template <typename T>
63+
__forceinline__ __device__ void PreCalculatorForLinearInterpInputIndex(
64+
int* in_img_idx, int* w_id, T* w1lambda, T* w2lambda, T src_w,
65+
const int in_img_w) {
66+
src_w = (src_w > 0) ? src_w : 0.f;
67+
*in_img_idx = static_cast<int>(src_w);
68+
*w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
69+
*w1lambda = src_w - *in_img_idx;
70+
*w2lambda = 1.f - *w1lambda;
71+
}
72+
6273
struct FastDivModForInterpolate {
6374
public:
6475
FastDivMod channels_div;
@@ -417,96 +428,93 @@ __global__ void KeLinearInterpBw(T* in, const size_t in_img_w,
417428
}
418429

419430
template <typename T>
420-
__global__ void KeBilinearInterpFw(
421-
const T* in, const size_t in_img_h, const size_t in_img_w,
422-
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
423-
const size_t out_img_w, const size_t output_h, const size_t output_w,
424-
const size_t num_channels, const float ratio_h, const float ratio_w,
425-
const bool align_corners, const int align_mode,
426-
const DataLayout data_layout) {
427-
int nthreads = output_h * output_w;
428-
int tid = blockIdx.x * blockDim.x + threadIdx.x;
429-
int stride = blockDim.x * gridDim.x;
430-
bool align_flag = (align_mode == 0 && !align_corners);
431-
for (; tid < nthreads; tid += stride) {
432-
int out_id_h = tid / output_w;
433-
int out_id_w = tid % output_w;
434-
int in_img_size = input_w / num_channels;
435-
int out_img_size = output_w / num_channels;
431+
__global__ void KeBilinearInterpNCHWFw(const T* in, const size_t in_img_h,
432+
const size_t in_img_w, T* out,
433+
const size_t out_img_h,
434+
const size_t out_img_w, const size_t nc,
435+
const float ratio_h, const float ratio_w,
436+
const T align_type_value) {
437+
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
438+
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
439+
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
440+
int nc_stride = blockDim.z * gridDim.z;
436441

437-
int channel_id, out_img_idy, out_img_idx;
438-
if (data_layout == DataLayout::kNCHW) {
439-
channel_id = out_id_w / out_img_size;
440-
out_img_idy = (out_id_w % out_img_size) / out_img_w;
441-
out_img_idx = tid % out_img_w;
442-
} else {
443-
out_img_idy = out_id_w / (out_img_w * num_channels);
444-
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
445-
channel_id = tid % num_channels;
446-
}
442+
int in_img_idx, in_img_idy, h_id, w_id;
443+
T h1lambda, w1lambda, h2lambda, w2lambda;
444+
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
445+
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
447446

448-
int in_img_idy = align_flag
449-
? static_cast<int>(ratio_h * (out_img_idy + 0.5) - 0.5)
450-
: static_cast<int>(ratio_h * out_img_idy);
451-
in_img_idy = (in_img_idy > 0) ? in_img_idy : 0;
452-
int h_id = (in_img_idy < in_img_h - 1) ? 1 : 0;
453-
T src_h = ratio_h * (out_img_idy + 0.5) - 0.5;
454-
src_h = (src_h > 0) ? src_h : 0;
455-
T h1lambda =
456-
align_flag ? src_h - in_img_idy : ratio_h * out_img_idy - in_img_idy;
457-
T h2lambda = 1.f - h1lambda;
447+
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
448+
&w2lambda, src_w, in_img_w);
449+
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
450+
&h2lambda, src_h, in_img_h);
458451

459-
int in_img_idx = align_flag
460-
? static_cast<int>(ratio_w * (out_img_idx + 0.5) - 0.5)
461-
: static_cast<int>(ratio_w * out_img_idx);
462-
in_img_idx = (in_img_idx > 0) ? in_img_idx : 0;
463-
int w_id = (in_img_idx < in_img_w - 1) ? 1 : 0;
464-
T src_w = ratio_w * (out_img_idx + 0.5) - 0.5;
465-
src_w = (src_w > 0) ? src_w : 0;
466-
T w1lambda =
467-
align_flag ? src_w - in_img_idx : ratio_w * out_img_idx - in_img_idx;
468-
T w2lambda = 1.f - w1lambda;
452+
int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
453+
int in_index_stride = nc_stride * in_img_h * in_img_w;
469454

470-
if (data_layout == DataLayout::kNCHW) {
471-
const T* in_pos = &in[out_id_h * input_w + channel_id * in_img_size +
472-
in_img_idy * in_img_w + in_img_idx];
455+
int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
456+
int out_index_stride = nc_stride * out_img_h * out_img_w;
473457

474-
// bilinear interpolation
475-
out[out_id_h * output_w + out_id_w] =
458+
// prevent from multiple threads writing
459+
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
460+
while (nc_id < nc) {
461+
const T* in_pos = &in[in_index];
462+
out[out_index] =
476463
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
477464
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
478465
w1lambda * in_pos[h_id * in_img_w + w_id]);
479-
} else {
480-
const T* in_pos =
481-
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
482-
in_img_idx * num_channels + channel_id];
483466

484-
// bilinear interpolation
485-
out[out_id_h * output_w + out_id_w] =
486-
h2lambda *
487-
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
488-
h1lambda * (w2lambda * in_pos[h_id * in_img_w * num_channels] +
489-
w1lambda * in_pos[h_id * in_img_w * num_channels +
490-
w_id * num_channels]);
467+
in_index += in_index_stride;
468+
out_index += out_index_stride;
469+
nc_id += nc_stride;
491470
}
492471
}
493472
}
494473

495474
template <typename T>
496-
__forceinline__ __device__ void PreCalculatorForInputIndex(
497-
int* in_img_idx, int* in_img_idy, int* w_id, int* h_id, T* w1lambda,
498-
T* h1lambda, T* w2lambda, T* h2lambda, T src_w, T src_h, const int in_img_w,
499-
const int in_img_h) {
500-
src_w = (src_w > 0) ? src_w : 0.f;
501-
src_h = (src_h > 0) ? src_h : 0.f;
502-
*in_img_idx = static_cast<int>(src_w);
503-
*in_img_idy = static_cast<int>(src_h);
504-
*w_id = (*in_img_idx < in_img_w - 1) ? 1 : 0;
505-
*h_id = (*in_img_idy < in_img_h - 1) ? 1 : 0;
506-
*w1lambda = src_w - *in_img_idx;
507-
*h1lambda = src_h - *in_img_idy;
508-
*w2lambda = 1.f - *w1lambda;
509-
*h2lambda = 1.f - *h1lambda;
475+
__global__ void KeBilinearInterpFw(
476+
const T* in, const size_t in_img_h, const size_t in_img_w,
477+
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
478+
const size_t out_img_w, const size_t output_h, const size_t output_w,
479+
const size_t num_channels, const float ratio_h, const float ratio_w,
480+
const T align_type_value, FastDivModForInterpolate divmods) {
481+
int nthreads = output_h * output_w;
482+
int tid = blockIdx.x * blockDim.x + threadIdx.x;
483+
int stride = blockDim.x * gridDim.x;
484+
485+
for (; tid < nthreads; tid += stride) {
486+
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
487+
int out_id_h = out_id_divmod.val[0];
488+
int out_id_w = out_id_divmod.val[1];
489+
490+
int channel_id = divmods.channels_div.Divmod(tid).val[1];
491+
auto outimg_id_divmod = divmods.output_wc_div.Divmod(out_id_w);
492+
int out_img_idy = outimg_id_divmod.val[0];
493+
int out_img_idx =
494+
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
495+
496+
int in_img_idx, in_img_idy, h_id, w_id;
497+
T h1lambda, w1lambda, h2lambda, w2lambda;
498+
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
499+
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
500+
501+
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
502+
&w2lambda, src_w, in_img_w);
503+
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
504+
&h2lambda, src_h, in_img_h);
505+
506+
// bilinear interpolation
507+
const T* in_pos =
508+
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
509+
in_img_idx * num_channels + channel_id];
510+
out[tid] =
511+
h2lambda *
512+
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
513+
h1lambda *
514+
(w2lambda * in_pos[h_id * in_img_w * num_channels] +
515+
w1lambda *
516+
in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
517+
}
510518
}
511519

512520
/* Calculate the minimum of partial elements in a block */
@@ -574,9 +582,11 @@ __global__ void KeBilinearInterpBwShareMemory(
574582
T w1lambda, h1lambda, w2lambda, h2lambda;
575583
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
576584
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
577-
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
578-
&w1lambda, &h1lambda, &w2lambda, &h2lambda,
579-
src_w, src_h, in_w, in_h);
585+
586+
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
587+
&w2lambda, src_w, in_w);
588+
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
589+
&h2lambda, src_h, in_h);
580590

581591
// top_left_index is just input_index.
582592
int input_index = out_id_h * in_chw + channel_id * in_img_size +
@@ -661,9 +671,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
661671

662672
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
663673
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
664-
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
665-
&w1lambda, &h1lambda, &w2lambda, &h2lambda,
666-
src_w, src_h, in_w, in_h);
674+
675+
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
676+
&w2lambda, src_w, in_w);
677+
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
678+
&h2lambda, src_h, in_h);
667679

668680
T* in_pos = &in[out_id_h * in_chw + channel_id * in_img_size +
669681
in_img_idy * in_w + in_img_idx];
@@ -690,9 +702,11 @@ __global__ void KeBilinearInterpBw(T* in, const int in_h, const int in_w,
690702
T w1lambda, h1lambda, w2lambda, h2lambda;
691703
T src_w = ratio_w * (out_img_idx + align_type_value) - align_type_value;
692704
T src_h = ratio_h * (out_img_idy + align_type_value) - align_type_value;
693-
PreCalculatorForInputIndex(&in_img_idx, &in_img_idy, &w_id, &h_id,
694-
&w1lambda, &h1lambda, &w2lambda, &h2lambda,
695-
src_w, src_h, in_w, in_h);
705+
706+
PreCalculatorForLinearInterpInputIndex(&in_img_idx, &w_id, &w1lambda,
707+
&w2lambda, src_w, in_w);
708+
PreCalculatorForLinearInterpInputIndex(&in_img_idy, &h_id, &h1lambda,
709+
&h2lambda, src_h, in_h);
696710

697711
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
698712
in_img_idx * num_channels + channel_id];
@@ -1398,11 +1412,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
13981412
thread_num = 512;
13991413
}
14001414
#endif
1401-
1402-
KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
1403-
ctx.cuda_device_context().stream()>>>(
1404-
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
1405-
out_chw, c, ratio_h, ratio_w, align_corners, align_mode, data_layout);
1415+
const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0;
1416+
if (data_layout == DataLayout::kNCHW) {
1417+
// get launch 3D config
1418+
int nc = n * c;
1419+
platform::GpuLaunchConfig config_3d =
1420+
GetGpuLaunchConfig3D(ctx.cuda_device_context(), nc, out_h, out_w);
1421+
KeBilinearInterpNCHWFw<
1422+
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
1423+
ctx.cuda_device_context().stream()>>>(
1424+
input_data, in_h, in_w, output_data, out_h, out_w, nc, ratio_h,
1425+
ratio_w, align_type_value);
1426+
} else {
1427+
int64_t cw = c * out_w;
1428+
auto interp_divmods = FastDivModForInterpolate(c, out_chw, cw);
1429+
KeBilinearInterpFw<T><<<config.block_per_grid, thread_num, 0,
1430+
ctx.cuda_device_context().stream()>>>(
1431+
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
1432+
out_chw, c, ratio_h, ratio_w, align_type_value, interp_divmods);
1433+
}
14061434
} else if ("bicubic" == interp_method) {
14071435
#ifdef __HIPCC__
14081436
constexpr int thread_per_block = 256;

0 commit comments

Comments
 (0)