@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/grid_sampler_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
+#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

 namespace paddle {
@@ -292,15 +293,12 @@ class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
     auto* output_data = output->mutable_data<T>(ctx.GetPlace());
     VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1]
             << "; " << output->dims()[2] << "; " << output->dims()[3];
-    phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
-        dev_ctx, output, static_cast<T>(0));
     int count = static_cast<int>(n * out_h * out_w);
     auto cu_stream = dev_ctx.stream();
-    int block_size = 512;
-    int grid_size = (count + block_size - 1) / block_size;
-    VLOG(3) << "cuda launch - grid dims: " << grid_size << "; block dims"
-            << block_size;
-    grid_sample_cuda_kernel<T><<<grid_size, block_size, 0, cu_stream>>>(
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, count);
+    grid_sample_cuda_kernel<
+        T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
         count, n, c, out_h, out_w, in_h, in_w, input->data<T>(),
         grid->data<T>(), output_data, mode, padding_mode, align_corners);
   }
@@ -467,19 +465,14 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
     if (ctx.HasOutput(framework::GradVarName("Grid"))) {
       auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
       grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
-      phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
-          ctx.template device_context<paddle::platform::CUDADeviceContext>(),
-          grid_grad, static_cast<T>(0));
     }

     int count = static_cast<int>(n * out_h * out_w);
     auto cu_stream = dev_ctx.stream();
-    int block_size = 512;
-    int grid_size = (count + block_size - 1) / block_size;
-    VLOG(3) << "cuda launch grad kernel - grid dims: " << grid_size
-            << "; block dims" << block_size << "; count: " << count;
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(dev_ctx, count);
     grid_sampler_cuda_backward_kernel<
-        T><<<grid_size, block_size, 0, cu_stream>>>(
+        T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
         count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
         out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
         padding_mode, align_corners);
0 commit comments