Skip to content

Commit 2e2a674

Browse files
Merge pull request #2490 from wanghaoshuang/crop_layer
add crop layer
2 parents d529134 + 676b76d commit 2e2a674

File tree

14 files changed

+714
-3
lines changed

14 files changed

+714
-3
lines changed

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License
1414

1515
cmake_minimum_required(VERSION 3.0)
16-
1716
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
1817
set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
1918
set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})

paddle/function/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ if(WITH_GPU)
3636
add_simple_unittest(MulOpTest)
3737
add_simple_unittest(CosSimOpTest)
3838
add_simple_unittest(RowConvOpTest)
39+
add_simple_unittest(CropOpTest)
3940
endif()
4041

4142
add_simple_unittest(ConvOpTest)

paddle/function/CropOp.cpp

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "CropOp.h"
16+
#include "paddle/function/TensorShape.h"
17+
#include "paddle/math/Vector.h"
18+
19+
namespace paddle {
20+
21+
template <>
22+
void Crop<DEVICE_TYPE_CPU>(real* outputs,
23+
const real* inputs,
24+
const TensorShape inShape,
25+
const TensorShape outShape,
26+
const FuncConfig& conf) {
27+
std::vector<uint32_t> crop_corner =
28+
conf.get<std::vector<uint32_t>>("crop_corner");
29+
int cCrop = crop_corner[1];
30+
int hCrop = crop_corner[2];
31+
int wCrop = crop_corner[3];
32+
33+
int num = inShape[0];
34+
int inC = inShape[1];
35+
int inH = inShape[2];
36+
int inW = inShape[3];
37+
38+
int outC = outShape[1];
39+
int outH = outShape[2];
40+
int outW = outShape[3];
41+
42+
for (int n = 0; n < num; n++) {
43+
for (int c = 0; c < outC; c++) {
44+
for (int h = 0; h < outH; h++) {
45+
int outoff = ((n * outC + c) * outH + h) * outW;
46+
int inoff = ((n * inC + c + cCrop) * inH + h + hCrop) * inW + wCrop;
47+
memcpy(outputs + outoff, inputs + inoff, outW * sizeof(real));
48+
}
49+
}
50+
}
51+
}
52+
53+
template <>
54+
void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
55+
real* outGrad,
56+
const TensorShape inShape,
57+
const TensorShape outShape,
58+
const FuncConfig& conf) {
59+
std::vector<uint32_t> crop_corner =
60+
conf.get<std::vector<uint32_t>>("crop_corner");
61+
int cCrop = crop_corner[1];
62+
int hCrop = crop_corner[2];
63+
int wCrop = crop_corner[3];
64+
65+
int num = outShape[0];
66+
int outC = outShape[1];
67+
int outH = outShape[2];
68+
int outW = outShape[3];
69+
70+
int inC = inShape[1];
71+
int inH = inShape[2];
72+
int inW = inShape[3];
73+
74+
for (int n = 0; n < num; n++) {
75+
for (int c = 0; c < inC; c++) {
76+
for (int h = 0; h < inH; h++) {
77+
int outoff = ((n * outC + c + cCrop) * outH + h + hCrop) * outW + wCrop;
78+
int inoff = ((n * inC + c) * inH + h) * inW;
79+
CpuVector inG = CpuVector(inW, const_cast<real*>(inGrad + inoff));
80+
CpuVector outG = CpuVector(inW, outGrad + outoff);
81+
outG += inG;
82+
}
83+
}
84+
}
85+
}
86+
87+
/**
88+
* \brief Crop input according to the specify corner and shape.
89+
* The input and output is a 4D tensor. In CropFunc, we only
90+
* crop the 2nd to 4th dimension.
91+
*
92+
* Argument in this Function:
93+
* \param pad_ A struct object contains the cropping corner and shape.
94+
* \param inputs A 4D tensor, only one input.
95+
* \param outputs A 4D tensor, the output value after cropping.
96+
*
97+
* For example,
98+
* Input(2,2,2,3) = [
99+
* [ [[1,2,3], [3,4,5]],
100+
* [[2,3,5], [1,6,7]] ],
101+
* [ [[4,3,1], [1,8,7]],
102+
* [[3,8,9], [2,3,5]] ]
103+
* ] # the input shape is (2,2,2,3)
104+
*
105+
* pad_: if corner = (0,1,1) and crop_shape = (2,1,2)
106+
* Output(2,2,1,2) = [
107+
* [ [[4,5]],
108+
* [[6,7]] ],
109+
* [ [[8,7]],
110+
* [[3,5]] ]
111+
* ] # the input shape is (2,2,2,3)
112+
*/
113+
template <DeviceType Device>
114+
class CropFunc : public FunctionBase {
115+
public:
116+
void init(const FuncConfig& config) override { conf_ = config; }
117+
118+
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
119+
CHECK_EQ(1UL, inputs.size());
120+
CHECK_EQ(1UL, outputs.size());
121+
CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
122+
123+
TensorShape inShape = inputs[0].shape();
124+
TensorShape outShape = outputs[0].shape();
125+
126+
Crop<Device>(outputs[0].data<real>(),
127+
inputs[0].data<real>(),
128+
inShape,
129+
outShape,
130+
conf_);
131+
}
132+
133+
private:
134+
FuncConfig conf_;
135+
};
136+
137+
/**
138+
* \brief The backward propagation of cropping Function.
139+
*
140+
* Argument in this Function:
141+
* \param crop_ The same meaning as it in CropFunc.
142+
* \param inputs The gradient with respect to the output value of CropFunc.
143+
* \param outputs The gradient with respect to the input value of CropFunc.
144+
*/
145+
146+
template <DeviceType Device>
147+
class CropGradFunc : public FunctionBase {
148+
public:
149+
void init(const FuncConfig& config) override { conf_ = config; }
150+
151+
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
152+
CHECK_EQ(1UL, inputs.size());
153+
CHECK_EQ(1UL, outputs.size());
154+
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
155+
156+
TensorShape outShape = outputs[0].shape();
157+
TensorShape inShape = inputs[0].shape();
158+
159+
CropGrad<Device>(inputs[0].data<real>(),
160+
outputs[0].data<real>(),
161+
inShape,
162+
outShape,
163+
conf_);
164+
}
165+
166+
private:
167+
FuncConfig conf_;
168+
};
169+
170+
// Register the CPU variants of the forward (Crop) and backward (CropGrad)
// functions.
REGISTER_TYPED_FUNC(Crop, CPU, CropFunc);
171+
REGISTER_TYPED_FUNC(CropGrad, CPU, CropGradFunc);
172+
// The GPU variants are registered only when PADDLE_ONLY_CPU is not defined.
#ifndef PADDLE_ONLY_CPU
173+
REGISTER_TYPED_FUNC(Crop, GPU, CropFunc);
174+
REGISTER_TYPED_FUNC(CropGrad, GPU, CropGradFunc);
175+
#endif
176+
177+
} // namespace paddle

paddle/function/CropOp.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "Function.h"
18+
19+
namespace paddle {
20+
21+
/**
22+
* \brief This function crops inputs according to the specified start point and
23+
* shape.
24+
*
25+
* \param[out] outputs  buffer that receives the cropped result.
26+
* \param[in] inputs    input data.
27+
* \param[in] inShape   the shape of the input tensor.
* \param[in] outShape  the shape of the output (cropped) tensor.
28+
* \param[in] conf      the cropping config.
29+
*/
30+
template <DeviceType Device>
31+
void Crop(real* outputs,
32+
const real* inputs,
33+
const TensorShape inShape,
34+
const TensorShape outShape,
35+
const FuncConfig& conf);
36+
37+
/**
38+
* \brief Cropping operation backward.
39+
*
40+
* \param[in] inGrad    gradient with respect to the cropped output (read only).
41+
* \param[out] outGrad  gradient with respect to the crop input (written).
42+
* \param[in] inShape   the shape of inGrad.
* \param[in] outShape  the shape of outGrad.
43+
* \param[in] conf      the cropping config.
44+
*/
45+
template <DeviceType Device>
46+
void CropGrad(const real* inGrad,
47+
real* outGrad,
48+
const TensorShape inShape,
49+
const TensorShape outShape,
50+
const FuncConfig& conf);
51+
} // namespace paddle

paddle/function/CropOpGpu.cu

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "hl_base.h"
16+
#include "CropOp.h"
17+
18+
namespace paddle {
19+
20+
/**
 * Forward crop kernel: each thread with idx < nthreads copies one output
 * element.
 *
 * idx is the flat offset of the element in `outputs`; it is split into
 * (n, c, h, w) using the output extents, and the source element is the one at
 * (n, c + cropC, h + cropH, w + cropW) in `inputs`.
 */
__global__ void KeCrop(real* outputs, const real* inputs,
                       int inC, int inH, int inW,
                       int cropC, int cropH, int cropW,
                       int outC, int outH, int outW, int nthreads) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= nthreads) {
    return;  // surplus thread of the last block
  }

  // Peel the coordinates off the flat index, innermost dimension first.
  int rest = idx;
  const int w = rest % outW;
  rest /= outW;
  const int h = rest % outH;
  rest /= outH;
  const int c = rest % outC;
  const int n = rest / outC;

  const int srcPlane = n * inC + c + cropC;
  const int srcRow = srcPlane * inH + h + cropH;
  outputs[idx] = inputs[srcRow * inW + cropW + w];
}
35+
36+
/**
 * \brief GPU implementation of the forward crop.
 *
 * Launches KeCrop with one thread per output element (blocks of 1024
 * threads) on STREAM_DEFAULT.
 *
 * \param[out] outputs  destination buffer, laid out according to outShape.
 * \param[in]  inputs   source buffer, laid out according to inShape.
 * \param[in]  inShape  4D shape of the input tensor.
 * \param[in]  outShape 4D shape of the output (cropped) tensor.
 * \param[in]  conf     must hold "crop_corner" (std::vector<uint32_t>);
 *                      entries 1, 2 and 3 are the start offsets in the
 *                      2nd, 3rd and 4th dimension. Entry 0 is not read.
 */
template <>
void Crop<DEVICE_TYPE_GPU>(real* outputs,
                           const real* inputs,
                           const TensorShape inShape,
                           const TensorShape outShape,
                           const FuncConfig& conf) {
  std::vector<uint32_t> crop_corner = conf.get<std::vector<uint32_t>>("crop_corner");
  int cropC = crop_corner[1];
  int cropH = crop_corner[2];
  int cropW = crop_corner[3];

  int num = inShape[0];
  int inC = inShape[1];
  int inH = inShape[2];
  int inW = inShape[3];

  int outC = outShape[1];
  int outH = outShape[2];
  int outW = outShape[3];

  // Widen before multiplying so the product is not evaluated in int.
  const size_t nth = static_cast<size_t>(num) * outC * outH * outW;
  if (nth == 0) {
    // Nothing to copy; a grid of zero blocks is not a valid launch.
    return;
  }
  const int blockSize = 1024;
  const int gridSize = static_cast<int>((nth + blockSize - 1) / blockSize);

  // NOTE(review): KeCrop indexes with int, so nth has to fit in an int.
  KeCrop<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
    (outputs, inputs, inC, inH, inW, cropC, cropH, cropW,
     outC, outH, outW, static_cast<int>(nth));
  CHECK_SYNC("Crop");
}
65+
66+
/**
 * Backward crop kernel: each thread with idx < nthreads adds one element of
 * inGrad onto outGrad.
 *
 * idx is the flat offset of the element in `inGrad`; it is split into
 * (n, c, h, w) using the in* extents, and the target element is the one at
 * (n, c + cropC, h + cropH, w + cropW) in `outGrad`.
 */
__global__ void KeCropDiff(const real* inGrad, real* outGrad,
                           int inC, int inH, int inW,
                           int cropC, int cropH, int cropW,
                           int outC, int outH, int outW, int nthreads) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx >= nthreads) {
    return;  // surplus thread of the last block
  }

  // Peel the coordinates off the flat index, innermost dimension first.
  int rest = idx;
  const int w = rest % inW;
  rest /= inW;
  const int h = rest % inH;
  rest /= inH;
  const int c = rest % inC;
  const int n = rest / inC;

  const int dstPlane = n * outC + c + cropC;
  const int dstRow = dstPlane * outH + h + cropH;
  outGrad[dstRow * outW + cropW + w] += inGrad[idx];
}
82+
83+
/**
 * \brief GPU implementation of the backward crop.
 *
 * Launches KeCropDiff with one thread per element of inGrad (blocks of 1024
 * threads) on STREAM_DEFAULT; each thread accumulates into outGrad.
 *
 * \param[in]     inGrad   gradient buffer laid out according to inShape.
 * \param[in,out] outGrad  gradient buffer laid out according to outShape;
 *                         values are accumulated into it.
 * \param[in]     inShape  4D shape of inGrad (the cropped shape).
 * \param[in]     outShape 4D shape of outGrad (the uncropped shape).
 * \param[in]     conf     must hold "crop_corner" (std::vector<uint32_t>);
 *                         entries 1, 2 and 3 are the start offsets in the
 *                         2nd, 3rd and 4th dimension. Entry 0 is not read.
 */
template <>
void CropGrad<DEVICE_TYPE_GPU>(const real* inGrad,
                               real* outGrad,
                               const TensorShape inShape,
                               const TensorShape outShape,
                               const FuncConfig& conf) {
  std::vector<uint32_t> crop_corner = conf.get<std::vector<uint32_t>>("crop_corner");
  int cropC = crop_corner[1];
  int cropH = crop_corner[2];
  int cropW = crop_corner[3];

  int num = outShape[0];
  int outC = outShape[1];
  int outH = outShape[2];
  int outW = outShape[3];

  int inC = inShape[1];
  int inH = inShape[2];
  int inW = inShape[3];

  // Widen before multiplying so the product is not evaluated in int.
  const size_t nth = static_cast<size_t>(num) * inC * inH * inW;
  if (nth == 0) {
    // Nothing to accumulate; a grid of zero blocks is not a valid launch.
    return;
  }
  const int blockSize = 1024;
  const int gridSize = static_cast<int>((nth + blockSize - 1) / blockSize);

  // NOTE(review): KeCropDiff indexes with int, so nth has to fit in an int.
  KeCropDiff <<<gridSize, blockSize, 0, STREAM_DEFAULT>>>
    (inGrad, outGrad, inC, inH, inW, cropC, cropH, cropW,
     outC, outH, outW, static_cast<int>(nth));
  CHECK_SYNC("CropGrad");
}
112+
113+
} // namespace paddle

0 commit comments

Comments
 (0)