@@ -21,9 +21,9 @@ namespace math {
 
 template <typename T>
 __global__ void KernelMaxOut(const int nthreads, const T* input_data,
-                             const int channels,
-                             const int input_height, const int input_width,
-                             int groups, T* output_data) {
+                             const int channels, const int input_height,
+                             const int input_width, int groups,
+                             T* output_data) {
   const int size = input_height * input_width * channels / groups;
   const int feat_len = input_height * input_width;
   int index = blockIdx.x * blockDim.x + threadIdx.x;
@@ -34,7 +34,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
     int channel_idx = batch_offset / feat_len;
     int feat_idx = batch_offset % feat_len;
     int data_idx =
-      (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
+        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
     T ele = static_cast<T>(-FLT_MAX);
     for (int g = 0; g < groups; ++g) {
       T x = input_data[data_idx + g * feat_len];
@@ -44,34 +44,35 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
   }
 }
 template <typename T>
-__global__ void KernelMaxoutGrad(
-    const int nthreads, const T* input_data, const T* output_data,
-    const T* output_grad, T* input_grad, const int channels,
-    const int input_height, const int input_width, int groups) {
-    const int size = input_height * input_width * channels / groups;
-    const int feat_len = input_height * input_width;
-    int index = blockIdx.x * blockDim.x + threadIdx.x;
-    int offset = blockDim.x * gridDim.x;
-    for (int i = index; i < nthreads; i += offset) {
-      int batch_idx = i / size;
-      int batch_offset = i % size;
-      int channel_idx = batch_offset / feat_len;
-      int feat_idx = batch_offset % feat_len;
-      int data_idx =
+__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
+                                 const T* output_data, const T* output_grad,
+                                 T* input_grad, const int channels,
+                                 const int input_height, const int input_width,
+                                 int groups) {
+  const int size = input_height * input_width * channels / groups;
+  const int feat_len = input_height * input_width;
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (int i = index; i < nthreads; i += offset) {
+    int batch_idx = i / size;
+    int batch_offset = i % size;
+    int channel_idx = batch_offset / feat_len;
+    int feat_idx = batch_offset % feat_len;
+    int data_idx =
         (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
-      int max_index = -1;
-      bool continue_match = true;
-      for (int g = 0; g < groups && continue_match; ++g) {
-        if (input_data[data_idx + g * feat_len] == output_data[i]) {
-          max_index = data_idx + g * feat_len;
-          continue_match = false;
-          break;
-        }
-      }
-      if (max_index != -1) {
-        input_grad[max_index] += output_grad[index];
+    int max_index = -1;
+    bool continue_match = true;
+    for (int g = 0; g < groups && continue_match; ++g) {
+      if (input_data[data_idx + g * feat_len] == output_data[i]) {
+        max_index = data_idx + g * feat_len;
+        continue_match = false;
+        break;
       }
     }
+    if (max_index != -1) {
+      input_grad[max_index] += output_grad[index];
+    }
+  }
 }
 /*
  * All tensors are in NCHW format.
@@ -80,7 +81,7 @@ template <typename T>
 class MaxOutFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input, framework::Tensor * output,
+                  const framework::Tensor& input, framework::Tensor* output,
                   int groups) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
@@ -92,7 +93,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
 
     const T* input_data = input.data<T>();
     T* output_data = output->mutable_data<T>(context.GetPlace());
-    int nthreads = output->numel();
+    int nthreads = output->numel();
     int blocks = (nthreads + 1024 - 1) / 1024;
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
@@ -101,8 +102,7 @@ class MaxOutFunctor<platform::GPUPlace, T> {
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(nthreads, input_data, input_channels,
-                              input_height, input_width, groups,
-                              output_data);
+                              input_height, input_width, groups, output_data);
   }
 };
 /*
@@ -112,11 +112,9 @@ template <typename T>
 class MaxOutGradFunctor<platform::GPUPlace, T> {
  public:
   void operator()(const platform::DeviceContext& context,
-                  const framework::Tensor& input,
-                  framework::Tensor * input_grad,
+                  const framework::Tensor& input, framework::Tensor* input_grad,
                   const framework::Tensor& output,
-                  const framework::Tensor& output_grad,
-                  int groups) {
+                  const framework::Tensor& output_grad, int groups) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -129,17 +127,17 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
     const T* output_data = output.data<T>();
     const T* output_grad_data = output_grad.data<T>();
     T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
-    int nthreads = output.numel();
+    int nthreads = output.numel();
     int blocks = (nthreads + 1024 - 1) / 1024;
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);
 
     KernelMaxoutGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(
-        nthreads, input_data, output_data, output_grad_data, input_grad_data,
-        input_channels, input_height, input_width, groups);
+                 .stream()>>>(nthreads, input_data, output_data,
+                              output_grad_data, input_grad_data, input_channels,
+                              input_height, input_width, groups);
   }
 };
 
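Note on the index arithmetic in the kernels above: each output element i enumerates the N x (C/groups) x H x W output positions; data_idx points at the first of the `groups` input channels feeding that position, the forward kernel takes the elementwise maximum over those channels, and the backward kernel routes the gradient back to whichever channel produced that maximum. The standalone C++ sketch below is not part of this commit; MaxOutCPU and the toy shapes in main are illustrative names only, but the data_idx formula is copied from KernelMaxOut so the forward result can be checked on the CPU.

// Standalone CPU sketch of the maxout forward pass (illustrative, not from the commit).
#include <algorithm>
#include <cfloat>
#include <cstdio>
#include <vector>

// input:  NCHW layout; `channels` must be divisible by `groups`.
// output: N x (channels / groups) x H x W, filled with the per-group maxima.
void MaxOutCPU(const std::vector<float>& input, std::vector<float>* output,
               int batch_size, int channels, int height, int width,
               int groups) {
  const int feat_len = height * width;
  const int size = feat_len * channels / groups;  // per-batch output size
  const int nthreads = batch_size * size;         // total output elements
  output->assign(nthreads, 0.f);
  for (int i = 0; i < nthreads; ++i) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;  // output channel
    int feat_idx = batch_offset % feat_len;     // spatial offset
    // First of the `groups` input channels feeding this output element;
    // same formula as data_idx in KernelMaxOut.
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    float ele = -FLT_MAX;
    for (int g = 0; g < groups; ++g) {
      ele = std::max(ele, input[data_idx + g * feat_len]);
    }
    (*output)[i] = ele;
  }
}

int main() {
  // Tiny example: N=1, C=4, H=W=2, groups=2 -> output has 2 channels.
  const int n = 1, c = 4, h = 2, w = 2, groups = 2;
  std::vector<float> input(n * c * h * w);
  for (size_t i = 0; i < input.size(); ++i) input[i] = static_cast<float>(i);
  std::vector<float> output;
  MaxOutCPU(input, &output, n, c, h, w, groups);
  for (float v : output) std::printf("%g ", v);  // expected: 4 5 6 7 12 13 14 15
  std::printf("\n");
  return 0;
}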