Skip to content
44 changes: 44 additions & 0 deletions paddle/cuda/include/hl_cnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -370,4 +370,48 @@ extern void hl_maxout_backward(real* inGrad,
size_t featLen,
size_t groups);

/**
* @brief Upsample forward.
* @param[in] inputData input data.
* @param[out] maskData the mask data from MaxPoolWithMaskLayer.
* @param[out] batchSize the batch size of the input.
* @param[in] imgSizeH image height.
* @param[in] imgSizeW image width.
* @param[in] channels the input channels.
* @param[in] outputH the output height.
* @param[in] outputW the output widht.
* @param[out] outputData output data.
*/
extern void hl_upsample_forward(real* inputData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* outputData);

/**
* @brief Upsample backward.
* @param[in] outputGradData the output grad data.
* @param[out] maskData the mask data from MaxPoolWithMaskLayer.
* @param[out] batchSize the batch size of the input.
* @param[in] imgSizeH image height.
* @param[in] imgSizeW image width.
* @param[in] channels the input channels.
* @param[in] outputH the output height.
* @param[in] outputW the output widht.
* @param[out] inputGradData the input grad data.
*/
extern void hl_upsample_backward(real* outputGradData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* inputGradData);

#endif // HL_CNN_H_
20 changes: 20 additions & 0 deletions paddle/cuda/include/stub/hl_cnn_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,24 @@ inline void hl_maxout_backward(real* inGrad,
size_t featLen,
size_t group) {}

inline void hl_upsample_forward(real* inputData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* outputData) {}

inline void hl_upsample_backward(real* outputGradData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* inputGradData) {}

#endif // HL_CNN_STUB_H_
76 changes: 76 additions & 0 deletions paddle/cuda/src/hl_cuda_cnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1028,3 +1028,79 @@ void hl_maxout_backward(real* inGrad,
num_kernels, inGrad, outGrad, idData, size, featLen, groups);
CHECK_SYNC("hl_maxout_backward failed");
}

__global__ void upsampleForwardCompute(real* input_data,
real* mask_data,
size_t nthreads,
size_t in_h,
size_t in_w,
size_t out_h,
size_t out_w,
real* output_data) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offset = index / (in_w * in_h) * out_h * out_w;
int upsample_idx = static_cast<int>(mask_data[index]);
output_data[offset + upsample_idx] = input_data[index];
}
}

__global__ void upsampleBackwardCompute(real* out_grad,
real* mask_data,
size_t nthreads,
size_t in_h,
size_t in_w,
size_t out_h,
size_t out_w,
real* input_grad) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offset = index / (in_w * in_h) * out_h * out_w;
int upsample_idx = static_cast<int>(mask_data[index]);
input_grad[index] = out_grad[offset + upsample_idx];
}
}

void hl_upsample_forward(real* inputData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* outputData) {
int num_kernels = batchSize * imgSizeH * imgSizeW * channels;
int blocks = (num_kernels + 1024 - 1) / 1024;
upsampleForwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(inputData,
maskData,
num_kernels,
imgSizeH,
imgSizeW,
outputH,
outputW,
outputData);
CHECK_SYNC("hl_upsample_forward failed");
}

void hl_upsample_backward(real* outputGradData,
real* maskData,
size_t batchSize,
size_t imgSizeH,
size_t imgSizeW,
size_t channels,
size_t outputH,
size_t outputW,
real* inputGradData) {
int num_kernels = batchSize * imgSizeH * imgSizeW * channels;
int blocks = (num_kernels + 1024 - 1) / 1024;
upsampleBackwardCompute<<<blocks, 1024, 0, STREAM_DEFAULT>>>(outputGradData,
maskData,
num_kernels,
imgSizeH,
imgSizeW,
outputH,
outputW,
inputGradData);
CHECK_SYNC("hl_upsample_backward failed");
}
108 changes: 108 additions & 0 deletions paddle/gserver/layers/UpsampleLayer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "UpsampleLayer.h"
#include "iostream"

namespace paddle {

REGISTER_LAYER(upsample, UpsampleLayer);

size_t UpsampleLayer::getOutputSize() {
if (upsampleSize_ == 0) {
upsampleSize_ = imgSize_ * scale_ - static_cast<int>(padOutX_);
upsampleSizeY_ = imgSizeY_ * scaleY_ - static_cast<int>(padOutY_);
}
return upsampleSize_ * upsampleSizeY_ * channels_;
}

bool UpsampleLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) {
Layer::init(layerMap, parameterMap);

CHECK_EQ(inputLayers_.size(), 2U);
CHECK_EQ(config_.inputs_size(), 2);
const auto& conf = config_.inputs(0).upsample_conf();
const auto& img_conf = conf.image_conf();

imgSizeY_ =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
imgSize_ = img_conf.img_size();
channels_ = img_conf.channels();

CHECK((conf.has_upsample_size()) || (conf.has_scale()))
<< "scale or upsample_size is required.";

if (conf.has_upsample_size()) {
upsampleSize_ = conf.upsample_size();
upsampleSizeY_ = upsampleSize_;
if (conf.has_upsample_size_y()) {
upsampleSizeY_ = conf.upsample_size_y();
}
} else {
if (!conf.has_scale_y()) {
scale_ = scaleY_ = conf.scale_y();
CHECK_GT(static_cast<int>(scale_), 1);
} else {
scale_ = conf.scale();
scaleY_ = conf.scale_y();
}
padOutX_ = conf.pad_out_x();
padOutY_ = conf.pad_out_y();
CHECK(!padOutX_ || scale_ == 2)
<< "Output height padding compensation requires scale_ == 2";
CHECK(!padOutY_ || scaleY_ == 2)
<< "Output width padding compensation requires scaleY_ == 2";
upsampleSize_ = upsampleSizeY_ = 0;
}
return true;
}

void UpsampleLayer::forward(PassType passType) {
Layer::forward(passType);

MatrixPtr input = getInputValue(0);
MatrixPtr mask = inputLayers_[1]->getOutput("mask").value;

size_t batchSize = input->getHeight();
size_t outSize = getOutputSize();

CHECK_EQ(input->getWidth(), mask->getWidth());
CHECK_EQ(mask->getHeight(), batchSize);
resetOutput(batchSize, outSize);

MatrixPtr output = getOutputValue();
output->upsampleForward(*input,
*mask,
imgSize_,
imgSizeY_,
channels_,
upsampleSize_,
upsampleSizeY_);
}

void UpsampleLayer::backward(const UpdateCallback& callback) {
MatrixPtr mask = inputLayers_[1]->getOutput("mask").value;
MatrixPtr inputGrad = getInputGrad(0);
MatrixPtr outputGrad = getOutputGrad();
inputGrad->upsampleBackward(*outputGrad,
*mask,
imgSize_,
imgSizeY_,
channels_,
upsampleSize_,
upsampleSizeY_);
}

} // namespace paddle
53 changes: 53 additions & 0 deletions paddle/gserver/layers/UpsampleLayer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>
#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

/**
* This layer transpose the pooling process.
* It takes two input, the first input is the input data, and
* the second is the mask data from the max-pool-with-mask layer.
*
*/

class UpsampleLayer : public Layer {
public:
explicit UpsampleLayer(const LayerConfig& config) : Layer(config) {}
~UpsampleLayer() {}

bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;

void forward(PassType passType) override;
void backward(const UpdateCallback& callback) override;

size_t getOutputSize();

protected:
size_t scale_, scaleY_;
size_t upsampleSize_, upsampleSizeY_;
size_t padOutX_, padOutY_;
size_t imgSize_, imgSizeY_;
size_t channels_;
};

} // namespace paddle
1 change: 1 addition & 0 deletions paddle/gserver/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ gserver_test(test_BatchNorm)
gserver_test(test_KmaxSeqScore)
gserver_test(test_Expand)
gserver_test(test_MaxPoolingWithMaskOutput)
gserver_test(test_Upsample)

set(PYTHON_PATH
${PADDLE_SOURCE_DIR}/paddle/.set_python_path.sh -d
Expand Down
Loading