
Commit 50e664a

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into transpose_conv1d
test=develop
2 parents f4e4c4d + 6b28456

89 files changed: +6552 additions, −1586 deletions


paddle/fluid/operators/activation_op.cc

Lines changed: 4 additions & 4 deletions

@@ -781,8 +781,8 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
   }
 };

-// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
-// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
+// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
+// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
 template <typename T>
 class LeakyReluDoubleGradMaker
     : public ::paddle::framework::SingleGradOpMaker<T> {
@@ -792,8 +792,8 @@ class LeakyReluDoubleGradMaker
  protected:
   void Apply(GradOpPtr<T> op) const override {
     op->SetType("leaky_relu_grad_grad");
-    // input1: Out
-    op->SetInput("Out", this->Input("Out"));
+    // input1: X
+    op->SetInput("X", this->Input("X"));
     // X@GRAD@GRAD: ddx
     op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
     op->SetAttrMap(this->Attrs());
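
For reference, the rule in the updated comments, written out as a standalone scalar sketch (plain C++; the helper names are illustrative and not part of Paddle):

#include <cassert>

// Hypothetical scalar helpers mirroring the comments above:
//   leaky_relu Grad:     dx  = dy  if x >= 0 else alpha * dy
//   leaky_relu GradGrad: ddy = ddx if x >= 0 else alpha * ddx
float leaky_relu_grad(float x, float dy, float alpha) {
  return x >= 0.f ? dy : alpha * dy;
}

float leaky_relu_grad_grad(float x, float ddx, float alpha) {
  return x >= 0.f ? ddx : alpha * ddx;
}

int main() {
  const float alpha = 0.02f;
  assert(leaky_relu_grad(3.f, 1.f, alpha) == 1.f);          // positive input passes dy through
  assert(leaky_relu_grad(-3.f, 1.f, alpha) == alpha);       // negative input scales dy by alpha
  assert(leaky_relu_grad_grad(-3.f, 1.f, alpha) == alpha);  // same rule for the double grad
  return 0;
}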

paddle/fluid/operators/activation_op.h

Lines changed: 16 additions & 12 deletions

@@ -1084,7 +1084,11 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {

   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
+    if (alpha < 1.f) {
+      out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
+    } else {
+      out.device(d) = x.cwiseMin(static_cast<T>(alpha) * x);
+    }
   }
 };

@@ -1098,12 +1102,12 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
     auto temp1 =
-        static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
-    auto temp2 = (out > static_cast<T>(0)).template cast<T>();
+        static_cast<T>(alpha) * (x < static_cast<T>(0)).template cast<T>();
+    auto temp2 = (x >= static_cast<T>(0)).template cast<T>();
     dx.device(d) = dout * (temp1 + temp2).template cast<T>();
   }

-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };

 template <typename T>
@@ -1451,18 +1455,18 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
     auto* d = dev.eigen_device();
     auto ddx = framework::EigenVector<T>::Flatten(
         GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
-    auto out = framework::EigenVector<T>::Flatten(
-        GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
+    auto x = framework::EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad"));
     auto ddout = framework::EigenVector<T>::Flatten(
         GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
-    ddout.device(*d) = ddx *
-                       ((out > static_cast<T>(0)).template cast<T>() +
-                        static_cast<T>(alpha) *
-                            (out <= static_cast<T>(0)).template cast<T>())
-                           .template cast<T>();
+    ddout.device(*d) =
+        ddx *
+        ((x > static_cast<T>(0)).template cast<T>() +
+         static_cast<T>(alpha) * (x <= static_cast<T>(0)).template cast<T>())
+            .template cast<T>();
     }
   }
-  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
 };

 template <typename T>
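
The forward branch above relies on the identity leaky_relu(x) = max(x, alpha * x) when alpha <= 1 and min(x, alpha * x) when alpha >= 1; the single cwiseMax form breaks once alpha may exceed 1. Likewise the backward functors now key their mask on the input x (and report kDepX) because, when alpha is allowed to be zero or negative, the sign of the output no longer determines the sign of the input. A small self-contained check of the forward identity (plain C++, illustrative only):

#include <algorithm>
#include <cassert>

// Reference definition: x for x >= 0, alpha * x otherwise.
float leaky_relu_ref(float x, float alpha) { return x >= 0.f ? x : alpha * x; }

// The branchy max/min form used by the functor above.
float leaky_relu_minmax(float x, float alpha) {
  return alpha < 1.f ? std::max(x, alpha * x) : std::min(x, alpha * x);
}

int main() {
  const float xs[] = {-2.f, -0.5f, 0.f, 0.5f, 2.f};
  const float alphas[] = {0.02f, 0.5f, 1.f, 2.f};  // including alpha > 1
  for (float a : alphas)
    for (float x : xs) assert(leaky_relu_ref(x, a) == leaky_relu_minmax(x, a));
  return 0;
}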

paddle/fluid/operators/arg_min_max_op_base.cu.h

Lines changed: 50 additions & 21 deletions

@@ -53,9 +53,9 @@ using Tensor = framework::Tensor;
   FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__);

 template <typename T, typename IndType, class Reducer, size_t BlockDim>
-__global__ void ArgCUDAKernel(const IndType height,     // n * h
-                              const IndType width,      // c
-                              const IndType post_size,  // h
+__global__ void ArgCUDAKernel(const int64_t height,     // n * h
+                              const int64_t width,      // c
+                              const int64_t post_size,  // h
                               const Reducer reducer, const T init, const T* in,
                               IndType* out) {
   typedef cub::BlockReduce<KeyValuePair<int, T>, BlockDim> BlockReduce;
@@ -79,10 +79,10 @@ __global__ void ArgCUDAKernel(const IndType height,  // n * h

 template <typename T, typename IndType, class Reducer>
 void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
-                    Tensor* indices, const IndType pre, const IndType post,
-                    const IndType n) {
+                    Tensor* indices, const int64_t pre, const int64_t post,
+                    const int64_t n) {
   auto cu_stream = ctx.stream();
-  auto ComputeBlockSize = [](IndType col) {
+  auto ComputeBlockSize = [](int64_t col) {
     if (col > 512)
       return 1024;
     else if (col > 256)
@@ -101,10 +101,10 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
     return 8;
   };

-  int max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
-  int height = pre * post;
-  int width = n;
-  int grid_size = height < max_grid_dimx ? height : max_grid_dimx;
+  int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize().x;
+  int64_t height = pre * post;
+  int64_t width = n;
+  int64_t grid_size = height < max_grid_dimx ? height : max_grid_dimx;

   const T* in_data = input.data<T>();
   IndType* out_data = indices->mutable_data<IndType>(ctx.GetPlace());
@@ -129,31 +129,60 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
 }

 template <typename T, class Reducer>
-class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+struct VisitDataCudaArgMinMaxFunctor {
+  const framework::ExecutionContext& ctx;
+
+  explicit VisitDataCudaArgMinMaxFunctor(const framework::ExecutionContext& ctx)
+      : ctx(ctx) {}
+  template <typename IndType>
+  void apply() const {
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
     int axis = ctx.Attr<int64_t>("axis");
-    auto in_dims = input->dims();
-    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+    const bool& flatten = ctx.Attr<bool>("flatten");
+
+    framework::DDim input_dims;
+    if (flatten) {
+      input_dims = framework::make_ddim({input->numel()});
+      // if flatten, the axis just as 0
+      axis = 0;
+    } else {
+      input_dims = input->dims();
+      if (axis < 0) axis += input->dims().size();
+    }

     int64_t numel = input->numel();
-    int64_t groups = numel / in_dims[axis];
+    int64_t groups = numel / input_dims[axis];
     int64_t pre = 1;
     int64_t post = 1;
-    int64_t n = in_dims[axis];
+    int64_t n = input_dims[axis];

     for (int i = 0; i < axis; i++) {
-      pre *= in_dims[i];
+      pre *= input_dims[i];
     }

-    for (int i = axis + 1; i < in_dims.size(); i++) {
-      post *= in_dims[i];
+    for (int i = axis + 1; i < input_dims.size(); i++) {
+      post *= input_dims[i];
     }

     const auto& dev_ctx = ctx.cuda_device_context();
-    ComputeFullArg<T, int64_t, Reducer>(dev_ctx, *input, output, pre, post, n);
+    ComputeFullArg<T, IndType, Reducer>(dev_ctx, *input, output, pre, post, n);
+  }
+};
+template <typename T, class Reducer>
+class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dtype = ctx.Attr<int>("dtype");
+    if (dtype < 0) {
+      framework::VisitDataType(static_cast<framework::proto::VarType::Type>(
+                                   framework::proto::VarType::INT64),
+                               VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
+      return;
+    }
+    framework::VisitDataType(
+        static_cast<framework::proto::VarType::Type>(dtype),
+        VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
   }
 };
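
ComputeFullArg treats the N-D arg reduction as height = pre * post independent segments of length n (the reduced axis), and the new dtype dispatch only selects the index type the results are written in. A minimal CPU sketch of the same pre / n / post decomposition (plain C++, illustrative only, not the CUDA kernel):

#include <cstdint>
#include <cstdio>
#include <vector>

// Arg-max over `axis` of a row-major tensor with shape `dims`,
// using the same pre / n / post split as ComputeFullArg above.
std::vector<int64_t> ArgMaxCPU(const std::vector<float>& in,
                               const std::vector<int64_t>& dims, int axis) {
  int64_t pre = 1, post = 1, n = dims[axis];
  for (int i = 0; i < axis; ++i) pre *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) post *= dims[i];

  std::vector<int64_t> out(pre * post);
  for (int64_t p = 0; p < pre; ++p) {
    for (int64_t q = 0; q < post; ++q) {
      int64_t best = 0;
      for (int64_t k = 1; k < n; ++k) {
        int64_t idx = (p * n + k) * post + q;           // element (p, k, q)
        int64_t best_idx = (p * n + best) * post + q;
        if (in[idx] > in[best_idx]) best = k;
      }
      out[p * post + q] = best;
    }
  }
  return out;
}

int main() {
  // Shape {2, 3}: arg-max along axis 1 of [[1, 5, 2], [7, 0, 3]] -> [1, 0].
  auto out = ArgMaxCPU({1, 5, 2, 7, 0, 3}, {2, 3}, 1);
  std::printf("%lld %lld\n", (long long)out[0], (long long)out[1]);
  return 0;
}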

paddle/fluid/operators/arg_min_max_op_base.h

Lines changed: 34 additions & 12 deletions

@@ -38,8 +38,9 @@ struct ArgMinMaxFunctor {};
 struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank,                        \
                         enum_argminmax_value> {                              \
   void operator()(const DeviceContext& ctx, const framework::LoDTensor& in,  \
-                  framework::LoDTensor* out, int64_t axis, bool keepdims) {  \
-    auto in_eigen = framework::EigenTensor<T, Rank>::From(in);               \
+                  framework::LoDTensor* out, framework::DDim x_dims,         \
+                  int64_t axis, bool keepdims) {                             \
+    auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims);       \
     if (keepdims) {                                                          \
       auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out);       \
       out_eigen.device(*(ctx.eigen_device())) =                              \
@@ -68,16 +69,26 @@ struct VisitDataArgMinMaxFunctor {
     out.template mutable_data<Tout>(ctx.GetPlace());
     auto axis = ctx.Attr<int64_t>("axis");
     auto keepdims = ctx.Attr<bool>("keepdims");
-    auto x_rank = x.dims().size();
-    if (axis < 0) axis += x_rank;
+    const bool& flatten = ctx.Attr<bool>("flatten");
+
+    // if flatten, will construct the new dims for the cacluate
+    framework::DDim x_dims;
+    if (flatten) {
+      x_dims = framework::make_ddim({x.numel()});
+      // if flatten, the axis just as 0
+      axis = 0;
+    } else {
+      x_dims = x.dims();
+      if (axis < 0) axis += x_dims.size();
+    }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();

 #define CALL_ARG_MINMAX_FUNCTOR(rank)                                \
   ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
       functor##rank;                                                 \
-  functor##rank(dev_ctx, x, &out, axis, keepdims)
+  functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims)

-    switch (x.dims().size()) {
+    switch (x_dims.size()) {
       case 1:
         CALL_ARG_MINMAX_FUNCTOR(1);
         break;
@@ -141,6 +152,7 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
     const auto& x_dims = ctx->GetInputDim("X");
     int64_t axis = ctx->Attrs().Get<int64_t>("axis");
     bool keepdims = ctx->Attrs().Get<bool>("keepdims");
+    const bool& flatten = ctx->Attrs().Get<bool>("flatten");

     PADDLE_ENFORCE_GE(axis, -x_dims.size(),
                       platform::errors::InvalidArgument(
@@ -152,14 +164,21 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
         platform::errors::InvalidArgument(
             "'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));

-    auto x_rank = x_dims.size();
-    if (axis < 0) axis += x_rank;
     std::vector<int64_t> vec;
-    for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
-    if (keepdims) {
-      vec.push_back(static_cast<int64_t>(1));
+    if (flatten) {
+      // if is flatten, will return the only on element
+      if (keepdims) {
+        vec.emplace_back(static_cast<int64_t>(1));
+      }
+    } else {
+      auto x_rank = x_dims.size();
+      if (axis < 0) axis += x_rank;
+      for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
+      if (keepdims) {
+        vec.emplace_back(static_cast<int64_t>(1));
+      }
+      for (int64_t i = axis + 1; i < x_rank; i++) vec.emplace_back(x_dims[i]);
     }
-    for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
     ctx->SetOutputDim("Out", framework::make_ddim(vec));
   }
 };
@@ -176,6 +195,9 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
     AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
     AddAttr<int>("dtype", "Keep the dim that to reduce.").SetDefault(-1);
+    AddAttr<bool>("flatten",
+                  "Flatten the input value, and search the min or max indices")
+        .SetDefault(false);
     AddComment(string::Sprintf(R"DOC(
 %s Operator.
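
The new flatten attribute changes both where the reduction runs (over the whole tensor, with axis forced to 0) and the inferred output shape. A compact mirror of the InferShape output-dims rule above (plain C++, illustrative only):

#include <cstdint>
#include <vector>

// Mirrors the output-dim logic of ArgMinMaxOp::InferShape above.
std::vector<int64_t> ArgMinMaxOutDims(std::vector<int64_t> x_dims, int64_t axis,
                                      bool keepdims, bool flatten) {
  std::vector<int64_t> vec;
  if (flatten) {
    if (keepdims) vec.push_back(1);  // one element; scalar shape otherwise
  } else {
    int64_t rank = static_cast<int64_t>(x_dims.size());
    if (axis < 0) axis += rank;
    for (int64_t i = 0; i < axis; ++i) vec.push_back(x_dims[i]);
    if (keepdims) vec.push_back(1);
    for (int64_t i = axis + 1; i < rank; ++i) vec.push_back(x_dims[i]);
  }
  return vec;
}

int main() {
  // {4, 5, 6}, axis = 1: {4, 6} without keepdims, {4, 1, 6} with keepdims,
  // {} (or {1} with keepdims) when flatten is set.
  auto d = ArgMinMaxOutDims({4, 5, 6}, 1, /*keepdims=*/true, /*flatten=*/false);
  return static_cast<int>(d.size()) == 3 ? 0 : 1;
}
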
paddle/fluid/operators/bernoulli_op.cc (new file)

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include "paddle/fluid/operators/bernoulli_op.h"
+
+#include <algorithm>
+#include <string>
+
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/common_infer_shape_functions.h"
+
+namespace paddle {
+namespace operators {
+
+class BernoulliOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "A tensor with probabilities for generating the random binary "
+             "number");
+    AddOutput("Out", "A Tensor filled with random binary number");
+    AddComment(R"DOC(
+This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution.
+
+    Out ~ Bernoulli(X)
+
+)DOC");
+  }
+};
+
+class BernoulliOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    return UnaryOpUnchangedInferShape(ctx);
+  }
+};
+
+// It seems that Eigen::Tensor::random in GPU will SEGFAULT.
+// Use std::random and thrust::random(thrust is a std library in CUDA) to
+// implement uniform random.
+template <typename T>
+class BernoulliOpKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto x = ctx.Input<framework::Tensor>("X");
+    auto out = ctx.Output<framework::Tensor>("Out");
+    auto *in_data = x->data<T>();
+    auto *out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    int64_t size = x->numel();
+    std::uniform_real_distribution<T> dist(0.0, 1.0);
+    auto gen_ptr = framework::Generator::GetInstance();
+    std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine();
+
+    for (int64_t i = 0; i < size; ++i) {
+      out_data[i] = BernoulliFunctor(in_data[i], dist(gen_engine));
+    }
+  }
+};  // namespace operators
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OPERATOR(
+    bernoulli, ops::BernoulliOp, ops::BernoulliOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OP_CPU_KERNEL(bernoulli,
+                       ops::BernoulliOpKernel<plat::CPUDeviceContext, float>,
+                       ops::BernoulliOpKernel<plat::CPUDeviceContext, double>);
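
BernoulliFunctor itself is declared in bernoulli_op.h, which is not shown in this view; from the call site above it presumably maps a probability p and a uniform sample u in [0, 1) to 1 when u < p and 0 otherwise. A self-contained sketch of that sampling rule (plain C++ with std::mt19937_64, illustrative only and not the Paddle implementation):

#include <cstddef>
#include <random>
#include <vector>

// Hypothetical scalar rule: emit 1 with probability p, 0 otherwise,
// given a uniform sample u drawn from [0, 1).
template <typename T>
T BernoulliSample(T p, T u) {
  return u < p ? static_cast<T>(1) : static_cast<T>(0);
}

int main() {
  std::mt19937_64 engine(42);  // fixed seed for reproducibility
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);

  std::vector<float> probs = {0.0f, 0.25f, 0.9f, 1.0f};
  std::vector<float> out(probs.size());
  for (size_t i = 0; i < probs.size(); ++i) {
    out[i] = BernoulliSample(probs[i], dist(engine));  // each out[i] is 0 or 1
  }
  return 0;
}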
