Commit 1168003

Adding the Adam Optimizer operator (#4733)
* add adam op

  moment1_out = beta1 * moment1 + (1 - beta1) * grad
  moment2_out = beta2 * moment2 + (1 - beta2) * grad * grad
  moment1_hat = moment1_out / (1 - beta1^t)
  moment2_hat = moment2_out / (1 - beta2^t)
  param_out = param - learning_rate * moment1_hat / (sqrt(moment2_hat) + epsilon)

* fix moment 2
* Adding the Adam optimization operator
* Adding more tests for Adam op
1 parent 7460958 commit 1168003
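
For readers following the formulas in the commit message, here is a minimal standalone sketch of that update rule (plain C++, not part of this commit; the values and variable names are illustrative only):

// Standalone sketch of the Adam update from the commit message.
// Illustrative values; not part of the committed operator code.
#include <cmath>
#include <cstdio>
#include <cstddef>
#include <vector>

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1.0e-8f, lr = 0.001f;
  std::vector<float> param = {1.0f, 2.0f}, grad = {0.1f, -0.2f};
  std::vector<float> moment1(2, 0.0f), moment2(2, 0.0f);
  float beta1_pow = 1.0f, beta2_pow = 1.0f;  // running beta1^t, beta2^t

  for (int t = 1; t <= 3; ++t) {
    beta1_pow *= beta1;
    beta2_pow *= beta2;
    for (std::size_t i = 0; i < param.size(); ++i) {
      moment1[i] = beta1 * moment1[i] + (1 - beta1) * grad[i];
      moment2[i] = beta2 * moment2[i] + (1 - beta2) * grad[i] * grad[i];
      float moment1_hat = moment1[i] / (1 - beta1_pow);
      float moment2_hat = moment2[i] / (1 - beta2_pow);
      param[i] -= lr * moment1_hat / (std::sqrt(moment2_hat) + epsilon);
    }
    std::printf("step %d: param = (%f, %f)\n", t, param[0], param[1]);
  }
  return 0;
}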

File tree

4 files changed: 432 additions, 0 deletions

paddle/operators/adam_op.cc

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/adam_op.h"

namespace paddle {
namespace operators {

class AdamOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Param"),
                   "Input(Param) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Grad"),
                   "Input(Grad) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Moment1"),
                   "Input(Moment1) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Moment2"),
                   "Input(Moment2) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
                   "Input(LearningRate) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
                   "Input(Beta1Pow) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
                   "Input(Beta2Pow) of AdamOp should not be null.");

    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                   "Output(ParamOut) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
                   "Output(Moment1Out) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
                   "Output(Moment2Out) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Beta1PowOut"),
                   "Output(Beta1PowOut) of AdamOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Beta2PowOut"),
                   "Output(Beta2PowOut) of AdamOp should not be null.");

    auto lr_dims = ctx->GetInputDim("LearningRate");
    PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
                      "Learning rate should have 1 dimension");
    auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
    PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
                      "Beta1 power accumulator should have 1 dimension");
    auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
    PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
                      "Beta2 power accumulator should have 1 dimension");

    auto param_dims = ctx->GetInputDim("Param");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Grad"),
        "Param and Grad input of AdamOp should have same dimension");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Moment1"),
        "Param and Moment1 input of AdamOp should have same dimension");
    PADDLE_ENFORCE_EQ(
        param_dims, ctx->GetInputDim("Moment2"),
        "Param and Moment2 input of AdamOp should have same dimension");

    ctx->SetOutputDim("ParamOut", param_dims);
    ctx->SetOutputDim("Moment1Out", param_dims);
    ctx->SetOutputDim("Moment2Out", param_dims);
    ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
    ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims);
  }
};

class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  AdamOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Param", "(Tensor) Input parameter");
    AddInput("Grad", "(Tensor) Input gradient");
    AddInput("LearningRate", "(Tensor) Learning rate");
    AddInput("Moment1", "(Tensor) Input first moment");
    AddInput("Moment2", "(Tensor) Input second moment");
    AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
    AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");

    AddOutput("ParamOut", "(Tensor) Output parameter");
    AddOutput("Moment1Out", "(Tensor) Output first moment");
    AddOutput("Moment2Out", "(Tensor) Output second moment");
    AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
    AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");

    AddAttr<float>("beta1",
                   "(float, default 0.9) "
                   "Exponential decay rate for the "
                   "first moment estimates.")
        .SetDefault(0.9f);
    AddAttr<float>("beta2",
                   "(float, default 0.999) "
                   "Exponential decay rate for the "
                   "second moment estimates.")
        .SetDefault(0.999f);
    AddAttr<float>("epsilon",
                   "(float, default 1.0e-8) "
                   "Constant for numerical stability")
        .SetDefault(1.0e-8f);

    AddComment(R"DOC(
Adam Updates Operator.

This implements the Adam optimizer from Section 2 of the Adam
paper[1]. Adam is a first-order gradient-based optimization
method based on adaptive estimates of lower-order moments.

Adam updates:

moment1_out = beta1 * moment1 + (1 - beta1) * grad
moment2_out = beta2 * moment2 + (1 - beta2) * grad * grad
beta1_pow_out = beta1_pow * beta1
beta2_pow_out = beta2_pow * beta2
learning_rate_t = learning_rate *
                  sqrt(1 - beta2_pow_out) / (1 - beta1_pow_out)
param_out = param - learning_rate_t * moment1_out / (sqrt(moment2_out) + epsilon)

References:
  [1] Adam: A Method for Stochastic Optimization
      (https://arxiv.org/abs/1412.6980)

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
REGISTER_OP_CPU_KERNEL(adam,
                       ops::AdamOpKernel<paddle::platform::CPUPlace, float>);
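
A side note on the two forms of the rule: the DOC block above folds the bias corrections into a rescaled learning rate, while the commit message uses explicit moment1_hat / moment2_hat terms. A small illustrative check (plain C++, not part of this commit) that the two agree, exactly so when epsilon is zero, and otherwise differing only in where epsilon enters:

// Compare the rescaled-learning-rate form (used by this op) with the
// bias-corrected "hat" form (used in the commit message). Illustrative only.
#include <cmath>
#include <cstdio>

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, lr = 0.001f, epsilon = 0.0f;
  float moment1 = 0.0f, moment2 = 0.0f, beta1_pow = 1.0f, beta2_pow = 1.0f;
  const float grad = 0.1f;

  moment1 = beta1 * moment1 + (1 - beta1) * grad;
  moment2 = beta2 * moment2 + (1 - beta2) * grad * grad;
  beta1_pow *= beta1;
  beta2_pow *= beta2;

  // Form used by this operator: rescale the learning rate once per step.
  float lr_t = lr * std::sqrt(1 - beta2_pow) / (1 - beta1_pow);
  float step_op = lr_t * moment1 / (std::sqrt(moment2) + epsilon);

  // Bias-corrected form from the commit message.
  float m1_hat = moment1 / (1 - beta1_pow);
  float m2_hat = moment2 / (1 - beta2_pow);
  float step_hat = lr * m1_hat / (std::sqrt(m2_hat) + epsilon);

  std::printf("op form: %g, hat form: %g\n", step_op, step_hat);
  return 0;
}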

paddle/operators/adam_op.cu

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/adam_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(adam,
                       ops::AdamOpKernel<paddle::platform::GPUPlace, float>);

paddle/operators/adam_op.h

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
    auto moment1_out_tensor = ctx.Output<framework::Tensor>("Moment1Out");
    auto moment2_out_tensor = ctx.Output<framework::Tensor>("Moment2Out");
    auto beta1_pow_out_tensor = ctx.Output<framework::Tensor>("Beta1PowOut");
    auto beta2_pow_out_tensor = ctx.Output<framework::Tensor>("Beta2PowOut");

    param_out_tensor->mutable_data<T>(ctx.GetPlace());
    moment1_out_tensor->mutable_data<T>(ctx.GetPlace());
    moment2_out_tensor->mutable_data<T>(ctx.GetPlace());
    beta1_pow_out_tensor->mutable_data<T>(ctx.GetPlace());
    beta2_pow_out_tensor->mutable_data<T>(ctx.GetPlace());

    float beta1 = ctx.Attr<float>("beta1");
    float beta2 = ctx.Attr<float>("beta2");
    float epsilon = ctx.Attr<float>("epsilon");

    auto param = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Param"));
    auto grad = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Grad"));
    auto moment1 = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Moment1"));
    auto moment2 = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Moment2"));
    auto lr = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("LearningRate"));
    auto beta1_pow = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Beta1Pow"));
    auto beta2_pow = framework::EigenVector<T>::Flatten(
        *ctx.Input<framework::Tensor>("Beta2Pow"));
    auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
    auto moment1_out = framework::EigenVector<T>::Flatten(*moment1_out_tensor);
    auto moment2_out = framework::EigenVector<T>::Flatten(*moment2_out_tensor);
    auto beta1_pow_out =
        framework::EigenVector<T>::Flatten(*beta1_pow_out_tensor);
    auto beta2_pow_out =
        framework::EigenVector<T>::Flatten(*beta2_pow_out_tensor);
    auto place = ctx.GetEigenDevice<Place>();

    moment1_out.device(place) = beta1 * moment1 + (1 - beta1) * grad;
    moment2_out.device(place) = beta2 * moment2 + (1 - beta2) * grad.square();
    beta1_pow_out.device(place) = beta1_pow * beta1;
    beta2_pow_out.device(place) = beta2_pow * beta2;
    // All of these are tensors of 1 element
    auto lr_t = lr * (1 - beta2_pow_out).sqrt() / (1 - beta1_pow_out);
    // Eigen does not support automatic broadcast
    // Get dimensions of moment vector to broadcast lr_t
    Eigen::DSizes<int, 1> m_dsize(moment1_out_tensor->numel());
    param_out.device(place) =
        param -
        lr_t.broadcast(m_dsize) *
            (moment1_out / (moment2_out.sqrt() + epsilon));
  }
};

}  // namespace operators
}  // namespace paddle
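
For reference, a rough element-wise equivalent of the Eigen expressions in Compute() above, written with plain arrays and a hypothetical helper name (illustrative only, not part of this commit):

// Plain-loop sketch of what the kernel computes per element. The LearningRate,
// Beta1Pow, and Beta2Pow inputs are 1-element tensors in the op, so the
// broadcast of lr_t simply applies the same scalar to every element.
#include <cmath>
#include <cstddef>

void adam_update_reference(const float* param, const float* grad,
                           const float* moment1, const float* moment2,
                           float lr, float beta1_pow, float beta2_pow,
                           float beta1, float beta2, float epsilon,
                           std::size_t n, float* param_out,
                           float* moment1_out, float* moment2_out,
                           float* beta1_pow_out, float* beta2_pow_out) {
  *beta1_pow_out = beta1_pow * beta1;
  *beta2_pow_out = beta2_pow * beta2;
  float lr_t = lr * std::sqrt(1 - *beta2_pow_out) / (1 - *beta1_pow_out);
  for (std::size_t i = 0; i < n; ++i) {
    moment1_out[i] = beta1 * moment1[i] + (1 - beta1) * grad[i];
    moment2_out[i] = beta2 * moment2[i] + (1 - beta2) * grad[i] * grad[i];
    param_out[i] = param[i] -
                   lr_t * moment1_out[i] / (std::sqrt(moment2_out[i]) + epsilon);
  }
}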
