
Commit 2b74b73

[NPU] add merged_momentum (PaddlePaddle#40875)
* [NPU] add merged_momentum
* fix
* fix device
1 parent 139a30e commit 2b74b73
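
For reference, the kernel added below performs the standard momentum update for each parameter in the merged list, delegating the arithmetic to the ApplyMomentum NPU op. A sketch of the usual semantics, with the optional l2_decay regularization folded into the gradient (lambda is the regularization_coeff):

$$
g \leftarrow g + \lambda\,\theta,\qquad
v \leftarrow \mu\, v + g,\qquad
\theta \leftarrow \theta - \eta \cdot \begin{cases} g + \mu\, v & \text{use\_nesterov} \\ v & \text{otherwise.} \end{cases}
$$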

File tree

2 files changed: +540 −0 lines

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/optimizers/merged_momentum_op.h"

#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
namespace operators {

template <typename T>
class NPUMergedMomentumOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto params = ctx.MultiInput<framework::Tensor>("Param");
    auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
    size_t n = params.size();
    PADDLE_ENFORCE_EQ(n, params_out.size(),
                      platform::errors::InvalidArgument(
                          "The size of Output(ParamOut) must be equal to "
                          "Input(Param), but got the size of Output(ParamOut) "
                          "is %d, the size of Input(Param) is %d.",
                          params_out.size(), n));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(params[i], params_out[i],
                        platform::errors::InvalidArgument(
                            "Input(Param) and Output(ParamOut) "
                            "must be the same Tensors."));
    }

    auto grads = ctx.MultiInput<framework::Tensor>("Grad");
    PADDLE_ENFORCE_EQ(
        n, grads.size(),
        platform::errors::InvalidArgument(
            "The size of Input(Grad) must be equal to Input(Param), but got "
            "the size of Input(Grad) is %d, the size of Input(Param) is %d.",
            grads.size(), n));

    auto velocitys = ctx.MultiInput<framework::Tensor>("Velocity");
    PADDLE_ENFORCE_EQ(n, velocitys.size(),
                      platform::errors::InvalidArgument(
                          "The size of Input(Velocity) must be equal to "
                          "Input(Param), but got the size of Input(Velocity) "
                          "is %d, the size of Input(Param) is %d.",
                          velocitys.size(), n));

    auto velocitys_out = ctx.MultiOutput<framework::Tensor>("VelocityOut");
    PADDLE_ENFORCE_EQ(
        n, velocitys_out.size(),
        platform::errors::InvalidArgument(
            "The size of Output(VelocityOut) must be "
            "equal to Input(Param), but got the size of Output(VelocityOut) is "
            "%d, the size of Input(Param) is %d.",
            velocitys_out.size(), n));
    for (size_t i = 0; i < n; ++i) {
      PADDLE_ENFORCE_EQ(velocitys[i], velocitys_out[i],
                        platform::errors::InvalidArgument(
                            "Input(Velocity) and Output(VelocityOut) must be "
                            "the same Tensors."));
    }
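
    // Attr "mu" is shared by every parameter; Input "LearningRate" may hold
    // a single tensor shared by all parameters or one tensor per parameter.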
    T mu = static_cast<T>(ctx.Attr<float>("mu"));
    auto lrs = ctx.MultiInput<framework::Tensor>("LearningRate");
    if (lrs.size() != 1) {
      PADDLE_ENFORCE_EQ(
          n, lrs.size(),
          platform::errors::InvalidArgument(
              "If the size of Input(LearningRate) is not 1, the size of "
              "Input(LearningRate) must be "
              "equal to Input(Param), but got the size of Input(LearningRate) "
              "is %d, the size of Input(Param) is %d.",
              lrs.size(), n));
    }
    auto use_nesterov = ctx.Attr<bool>("use_nesterov");
    auto regularization_methods =
        ctx.Attr<std::vector<std::string>>("regularization_method");
    auto regularization_coeffs =
        ctx.Attr<std::vector<float>>("regularization_coeff");
    if (regularization_methods.size() != 0) {
      PADDLE_ENFORCE_EQ(
          n, regularization_methods.size(),
          platform::errors::InvalidArgument(
              "The size of Attr(regularization_method) must be equal "
              "to Input(Param), but got the size of "
              "Attr(regularization_method) is %d, the size of Input(Param) is "
              "%d.",
              regularization_methods.size(), n));
      PADDLE_ENFORCE_EQ(
          n, regularization_coeffs.size(),
          platform::errors::InvalidArgument(
              "The size of Attr(regularization_coeff) must be equal "
              "to Input(Param), but got the size of Attr(regularization_coeff) "
              "is %d, the size of Input(Param) is %d.",
              regularization_coeffs.size(), n));
    }

    VLOG(5) << "use_nesterov: " << use_nesterov
            << ", regularization_methods.size(): "
            << regularization_methods.size()
            << ", regularization_coeffs.size(): "
            << regularization_coeffs.size();

    auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
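
    // ApplyMomentum consumes mu as a device tensor rather than an attribute,
    // so materialize it once here and reuse it for every parameter below.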
    Tensor mu_tensor;
    mu_tensor.mutable_data<T>(phi::make_ddim({1}), ctx.GetPlace());
    FillNpuTensorWithConstant<T>(&mu_tensor, mu);
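
    // The "merged" op takes the whole parameter list in a single operator
    // call and launches one ApplyMomentum per (param, grad, velocity) tuple.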
    for (size_t idx = 0; idx < n; ++idx) {
      RegularizationType regularization_flag =
          regularization_methods.size() > 0 &&
                  regularization_methods[idx] == "l2_decay"
              ? RegularizationType::kL2DECAY
              : RegularizationType::kNONE;
      float regularization_coeff = 0.0;
      if (regularization_coeffs.size() != 0) {
        regularization_coeff = regularization_coeffs[idx];
      }

      auto learning_rate = lrs.size() > 1 ? lrs[idx] : lrs[0];
      auto param = params[idx];
      auto param_out = params_out[idx];
      auto velocity = velocitys[idx];
      auto velocity_out = velocitys_out[idx];

      auto grad = grads[idx];
      Tensor regularized_grad;
      if (regularization_flag == RegularizationType::kL2DECAY) {
        // Fold L2 weight decay into the gradient:
        // regularized_grad = regularization_coeff * param + grad.
        regularized_grad.mutable_data<T>(grad->dims(), ctx.GetPlace());
        const auto& runner1 = NpuOpRunner("Muls", {*param}, {regularized_grad},
                                          {{"value", regularization_coeff}});
        runner1.Run(dev_ctx.stream());
        const auto& runner2 = NpuOpRunner("Add", {regularized_grad, *grad},
                                          {regularized_grad}, {});
        runner2.Run(dev_ctx.stream());
      } else {
        regularized_grad.ShareDataWith(*grad);
      }
      // NOTE: ApplyMomentum updates its inputs in place, so stage Param and
      // Velocity in the outputs first and let the op overwrite them.
      framework::TensorCopy(*param, ctx.GetPlace(), dev_ctx, param_out);
      framework::TensorCopy(*velocity, ctx.GetPlace(), dev_ctx, velocity_out);
      const auto& runner = NpuOpRunner(
          "ApplyMomentum", {*param_out, *velocity_out, *learning_rate,
                            regularized_grad, mu_tensor},
          {*param_out}, {{"use_nesterov", use_nesterov}});
      runner.Run(dev_ctx.stream());
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(merged_momentum, ops::NPUMergedMomentumOpKernel<float>,
                       ops::NPUMergedMomentumOpKernel<plat::float16>);
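
For intuition, here is a minimal host-side reference of the per-tensor arithmetic the kernel dispatches to the device, assuming the usual ApplyMomentum semantics. MergedMomentumRef and its parameter layout are illustrative only, not part of the Paddle API:

// merged_momentum_ref.cc -- illustrative sketch, not Paddle code.
#include <cstdio>
#include <string>
#include <vector>

void MergedMomentumRef(std::vector<std::vector<float>>* params,
                       std::vector<std::vector<float>>* velocities,
                       const std::vector<std::vector<float>>& grads,
                       const std::vector<float>& lrs,  // size 1 or n
                       float mu, bool use_nesterov,
                       const std::vector<std::string>& reg_methods,
                       const std::vector<float>& reg_coeffs) {
  for (size_t i = 0; i < params->size(); ++i) {
    const float lr = lrs.size() > 1 ? lrs[i] : lrs[0];
    const bool l2 = !reg_methods.empty() && reg_methods[i] == "l2_decay";
    const float coeff = reg_coeffs.empty() ? 0.0f : reg_coeffs[i];
    std::vector<float>& p = (*params)[i];
    std::vector<float>& v = (*velocities)[i];
    for (size_t j = 0; j < p.size(); ++j) {
      // Optional L2 decay folds the parameter into the gradient.
      const float g = l2 ? coeff * p[j] + grads[i][j] : grads[i][j];
      v[j] = mu * v[j] + g;  // velocity update
      // Nesterov looks one momentum step ahead of the plain update.
      p[j] -= use_nesterov ? lr * (g + mu * v[j]) : lr * v[j];
    }
  }
}

int main() {
  std::vector<std::vector<float>> params{{1.0f, 2.0f}};
  std::vector<std::vector<float>> velocities{{0.0f, 0.0f}};
  const std::vector<std::vector<float>> grads{{0.1f, 0.2f}};
  MergedMomentumRef(&params, &velocities, grads, /*lrs=*/{0.5f}, /*mu=*/0.9f,
                    /*use_nesterov=*/false, {"l2_decay"}, {0.01f});
  std::printf("param[0] = {%f, %f}\n", params[0][0], params[0][1]);
  return 0;
}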
