Skip to content

Commit 40688d2

Browse files
committed
refine im2col (up_pad,down_pad)
1 parent 1e60c9b commit 40688d2

File tree

4 files changed

+135
-78
lines changed

4 files changed

+135
-78
lines changed

paddle/operators/math/im2col.cc

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,8 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
140140
public:
141141
void operator()(const platform::DeviceContext& context,
142142
const framework::Tensor& im, framework::Tensor& col,
143-
int stride, int pad, int row_begin, int row_end) {
144-
int stride_height = stride;
145-
int stride_width = 0;
146-
int padding_height = pad;
147-
int padding_width = 0;
143+
int stride_height, int stride_width, int up_pad,
144+
int down_pad) {
148145
PADDLE_ENFORCE(im.dims().size() == 3);
149146
PADDLE_ENFORCE(col.dims().size() == 5);
150147
int input_channels = im.dims()[0];
@@ -155,6 +152,18 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
155152
// int output_height = col.dims()[0];
156153
int output_width = col.dims()[1];
157154

155+
int row_begin, row_end;
156+
int padding_height = std::max(up_pad, down_pad);
157+
int padding_width = 0;
158+
if (up_pad >= down_pad) {
159+
row_begin = 0;
160+
} else {
161+
row_begin = down_pad - up_pad;
162+
}
163+
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
164+
stride_height +
165+
1);
166+
158167
const T* im_data = im.data<T>();
159168
T* col_data = col.data<T>();
160169

@@ -204,12 +213,8 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
204213
platform::CPUPlace, T> {
205214
public:
206215
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
207-
const framework::Tensor& col, int stride, int pad,
208-
int row_start, int row_end) {
209-
int stride_height = stride;
210-
int stride_width = 0;
211-
int padding_height = pad;
212-
int padding_width = 0;
216+
const framework::Tensor& col, int stride_height,
217+
int stride_width, int up_pad, int down_pad) {
213218
PADDLE_ENFORCE(im.dims().size() == 3);
214219
PADDLE_ENFORCE(col.dims().size() == 5);
215220
int input_channels = im.dims()[0];
@@ -220,10 +225,22 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
220225
// int output_height = col.dims()[0];
221226
int output_width = col.dims()[1];
222227

228+
int row_begin, row_end;
229+
int padding_height = std::max(up_pad, down_pad);
230+
int padding_width = 0;
231+
if (up_pad >= down_pad) {
232+
row_begin = 0;
233+
} else {
234+
row_begin = down_pad - up_pad;
235+
}
236+
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
237+
stride_height +
238+
1);
239+
223240
T* im_data = im.data<T>();
224241
const T* col_data = col.data<T>();
225242

226-
for (int col_row_idx = row_start; col_row_idx < row_end; ++col_row_idx) {
243+
for (int col_row_idx = row_begin; col_row_idx < row_end; ++col_row_idx) {
227244
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
228245
for (int channel = 0; channel < input_channels; ++channel) {
229246
for (int filter_row_idx = 0; filter_row_idx < filter_height;
@@ -235,7 +252,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
235252
int im_col_offset =
236253
col_col_idx * stride_width + filter_col_idx - padding_width;
237254
int col_offset =
238-
((((col_row_idx - row_start) * output_width + col_col_idx) *
255+
((((col_row_idx - row_begin) * output_width + col_col_idx) *
239256
input_channels +
240257
channel) *
241258
filter_height +

paddle/operators/math/im2col.cu

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -240,19 +240,28 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
240240
public:
241241
void operator()(const platform::DeviceContext& context,
242242
const framework::Tensor& im, framework::Tensor& col,
243-
int stride, int pad, int row_begin, int row_end) {
244-
int stride_height = stride;
245-
int stride_width = 0;
246-
int padding_height = pad;
247-
int padding_width = 0;
248-
243+
int stride_height, int stride_width, int up_pad,
244+
int down_pad) {
249245
PADDLE_ENFORCE(im.dims().size() == 3);
250246
PADDLE_ENFORCE(col.dims().size() == 5);
251247
int input_channels = im.dims()[0];
252248
int input_height = im.dims()[1];
253249
int input_width = im.dims()[2];
254250
int filter_height = col.dims()[3];
255251
int filter_width = col.dims()[4];
252+
253+
int row_begin, row_end;
254+
int padding_height = std::max(up_pad, down_pad);
255+
int padding_width = 0;
256+
if (up_pad >= down_pad) {
257+
row_begin = 0;
258+
} else {
259+
row_begin = down_pad - up_pad;
260+
}
261+
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
262+
stride_height +
263+
1);
264+
256265
int output_height = row_end - row_begin; // col.dims()[0];
257266
int output_width = col.dims()[1];
258267

@@ -295,7 +304,6 @@ __global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
295304
int row_end) {
296305
int swid = blockIdx.x;
297306
int shid = blockIdx.y;
298-
// if (shid < row_begin || shid > row_end) return;
299307
for (int channelid = threadIdx.z; channelid < input_channels;
300308
channelid += blockDim.z) {
301309
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
@@ -331,19 +339,28 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
331339
platform::GPUPlace, T> {
332340
public:
333341
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
334-
const framework::Tensor& col, int stride, int pad,
335-
int row_begin, int row_end) {
336-
int stride_height = stride;
337-
int stride_width = 0;
338-
int padding_height = pad;
339-
int padding_width = 0;
342+
const framework::Tensor& col, int stride_height,
343+
int stride_width, int up_pad, int down_pad) {
340344
PADDLE_ENFORCE(im.dims().size() == 3);
341345
PADDLE_ENFORCE(col.dims().size() == 5);
342346
int input_channels = im.dims()[0];
343347
int input_height = im.dims()[1];
344348
int input_width = im.dims()[2];
345349
int filter_height = col.dims()[3];
346350
int filter_width = col.dims()[4];
351+
352+
int row_begin, row_end;
353+
int padding_height = std::max(up_pad, down_pad);
354+
int padding_width = 0;
355+
if (up_pad >= down_pad) {
356+
row_begin = 0;
357+
} else {
358+
row_begin = down_pad - up_pad;
359+
}
360+
row_end = row_begin + ((input_height + up_pad + down_pad - filter_height) /
361+
stride_height +
362+
1);
363+
347364
int output_height = row_end - row_begin; // col.dims()[0];
348365
int output_width = col.dims()[1];
349366

paddle/operators/math/im2col_test.cc

Lines changed: 71 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ void testIm2col() {
3535
*
3636
* output_ocf = [0, 1, 3, 4
3737
* 1, 2, 4, 5]
38+
*
39+
* col2im_cfo = [0, 2, 2
40+
* 3, 4, 5]
41+
*
42+
* col2im_ocf = [0, 2, 2
43+
* 3, 4, 5]
3844
*/
3945
int input_height = 2;
4046
int input_width = 3;
@@ -59,7 +65,7 @@ void testIm2col() {
5965
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
6066
#else
6167
PADDLE_THROW("no GPU support");
62-
#endif // PADDLE_ONLY_CPU
68+
#endif // PADDLE_WITH_CUDA
6369
}
6470
if (paddle::platform::is_cpu_place(*place)) {
6571
input = input_tmp;
@@ -71,6 +77,7 @@ void testIm2col() {
7177
output_ocf.mutable_data<float>(
7278
{output_height, output_width, 1, filter_size, filter_size}, *place);
7379

80+
// Im2Col
7481
paddle::operators::math::Im2ColFunctor<
7582
paddle::operators::math::ColFormat::kCFO, Place, float>
7683
im2col;
@@ -79,8 +86,12 @@ void testIm2col() {
7986
im2col_ocf;
8087

8188
im2col(*context, input, output_cfo, stride, stride, padding, padding);
82-
im2col_ocf(*context, input, output_ocf, stride, padding, 0,
83-
output_height * output_width);
89+
im2col_ocf(*context, input, output_ocf, /*stride_height*/ stride,
90+
/*stride_width*/ stride, /*up_pad*/ padding,
91+
/*down_pad*/ padding);
92+
93+
float out_cfo_data[] = {0, 1, 1, 2, 3, 4, 4, 5};
94+
float out_ocf_data[] = {0, 1, 3, 4, 1, 2, 4, 5};
8495

8596
float* out_cfo_ptr;
8697
if (paddle::platform::is_cpu_place(*place)) {
@@ -90,14 +101,9 @@ void testIm2col() {
90101
*context);
91102
out_cfo_ptr = output_tmp.data<float>();
92103
}
93-
EXPECT_EQ(out_cfo_ptr[0], 0);
94-
EXPECT_EQ(out_cfo_ptr[1], 1);
95-
EXPECT_EQ(out_cfo_ptr[2], 1);
96-
EXPECT_EQ(out_cfo_ptr[3], 2);
97-
EXPECT_EQ(out_cfo_ptr[4], 3);
98-
EXPECT_EQ(out_cfo_ptr[5], 4);
99-
EXPECT_EQ(out_cfo_ptr[6], 4);
100-
EXPECT_EQ(out_cfo_ptr[7], 5);
104+
for (int i = 0; i < 6; ++i) {
105+
EXPECT_EQ(out_cfo_ptr[i], out_cfo_data[i]);
106+
}
101107

102108
float* out_ocf_ptr;
103109
if (paddle::platform::is_cpu_place(*place)) {
@@ -107,14 +113,60 @@ void testIm2col() {
107113
*context);
108114
out_ocf_ptr = output_tmp.data<float>();
109115
}
110-
EXPECT_EQ(out_ocf_ptr[0], 0);
111-
EXPECT_EQ(out_ocf_ptr[1], 1);
112-
EXPECT_EQ(out_ocf_ptr[2], 3);
113-
EXPECT_EQ(out_ocf_ptr[3], 4);
114-
EXPECT_EQ(out_ocf_ptr[4], 1);
115-
EXPECT_EQ(out_ocf_ptr[5], 2);
116-
EXPECT_EQ(out_ocf_ptr[6], 4);
117-
EXPECT_EQ(out_ocf_ptr[7], 5);
116+
for (int i = 0; i < 6; ++i) {
117+
EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]);
118+
}
119+
120+
// Col2Im: kCFO
121+
paddle::operators::math::Col2ImFunctor<
122+
paddle::operators::math::ColFormat::kCFO, Place, float>
123+
col2im;
124+
paddle::operators::math::Col2ImFunctor<
125+
paddle::operators::math::ColFormat::kOCF, Place, float>
126+
col2im_ocf;
127+
float col2im_data[] = {0, 2, 2, 3, 8, 5};
128+
129+
memset(input_ptr, 0, 6 * sizeof(float));
130+
if (paddle::platform::is_cpu_place(*place)) {
131+
input = input_tmp;
132+
} else {
133+
input.CopyFrom<float>(input_tmp, *place, *context);
134+
}
135+
136+
col2im(*context, input, output_cfo, stride, stride, padding, padding);
137+
138+
float* in_ptr;
139+
if (paddle::platform::is_cpu_place(*place)) {
140+
in_ptr = input.data<float>();
141+
} else {
142+
input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace(), *context);
143+
in_ptr = input_tmp.data<float>();
144+
}
145+
for (int i = 0; i < 6; ++i) {
146+
EXPECT_EQ(in_ptr[i], col2im_data[i]);
147+
}
148+
149+
// Col2Im: kOCF
150+
memset(input_ptr, 0, 6 * sizeof(float));
151+
if (paddle::platform::is_cpu_place(*place)) {
152+
input = input_tmp;
153+
} else {
154+
input.CopyFrom<float>(input_tmp, *place, *context);
155+
}
156+
157+
col2im_ocf(*context, input, output_ocf, /*stride_height*/ stride,
158+
/*stride_width*/ stride, /*up_pad*/ padding,
159+
/*down_pad*/ padding);
160+
161+
if (paddle::platform::is_cpu_place(*place)) {
162+
in_ptr = input.data<float>();
163+
} else {
164+
input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace(), *context);
165+
in_ptr = input_tmp.data<float>();
166+
}
167+
for (int i = 0; i < 6; ++i) {
168+
EXPECT_EQ(in_ptr[i], col2im_data[i]);
169+
}
118170
}
119171

120172
TEST(math, im2col) {

paddle/operators/sequence_project_op.h

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -87,24 +87,9 @@ class SequenceProjectKernel : public framework::OpKernel<T> {
8787
sequence_width}); // input_channels, input_height, input_width
8888
in_t.Resize(framework::make_ddim(input_shape));
8989
for (int j = 0; j < context_length; ++j) {
90-
int pad;
91-
int row_start;
92-
93-
if (up_pad != 0) {
94-
pad = up_pad;
95-
row_start = 0;
96-
} else if (down_pad != 0) {
97-
pad = down_pad;
98-
row_start = down_pad;
99-
} else {
100-
pad = 0;
101-
row_start = 0;
102-
}
103-
10490
im2col_ocf(context.device_context(), in_t, out_t,
105-
/*stride*/ context_stride, /*pad*/ pad,
106-
/*row_start*/ row_start,
107-
/*row_end*/ row_start + sequence_height);
91+
/*stride_height*/ context_stride, /*stride_width*/ 0, up_pad,
92+
down_pad);
10893
if (padding_trainable) {
10994
// add up trainable data
11095
out_t.Resize(framework::make_ddim(
@@ -229,23 +214,9 @@ class SequenceProjectGradKernel : public framework::OpKernel<T> {
229214
out_g_t.Resize(framework::make_ddim(
230215
{sequence_height, 1, 1, context_length, sequence_width}));
231216

232-
int pad;
233-
int row_start;
234-
235-
if (up_pad != 0) {
236-
pad = up_pad;
237-
row_start = 0;
238-
} else if (down_pad != 0) {
239-
pad = down_pad;
240-
row_start = down_pad;
241-
} else {
242-
pad = 0;
243-
row_start = 0;
244-
}
245217
col2im_ocf(context.device_context(), in_g_t, out_g_t,
246-
/*stride*/ context_stride, /*pad*/ pad,
247-
/*row_start*/ row_start,
248-
/*row_end*/ row_start + sequence_height);
218+
/*stride_height*/ context_stride, /*stride_width*/ 0, up_pad,
219+
down_pad);
249220

250221
// out_g_t back to orign size
251222
}

0 commit comments

Comments
 (0)