
Commit a4d07bb

[AMP] Add multi_precision for sgd (#38231)
1 parent 08941ed commit a4d07bb

File tree

6 files changed (+408, -36 lines)

paddle/fluid/operators/optimizers/sgd_op.cc

Lines changed: 11 additions & 0 deletions

@@ -126,13 +126,24 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Param", "(Tensor or SelectedRows) Input parameter");
     AddInput("LearningRate", "(Tensor) Learning rate of SGD");
     AddInput("Grad", "(Tensor or SelectedRows) Input gradient");
+    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
     AddOutput("ParamOut",
               "(Tensor or SelectedRows, same with Param) "
               "Output parameter, should share the same memory with Param");
+    AddOutput("MasterParamOut",
+              "The updated FP32 master weight for AMP. "
+              "It shared memory with Input(MasterParam).")
+        .AsDispensable();
+
     AddAttr<bool>(
         "use_mkldnn",
         "(bool, default false) Indicates if MKL-DNN kernel will be used")
         .SetDefault(false);
+    AddAttr<bool>("multi_precision",
+                  "(bool, default false) "
+                  "Whether to use multi-precision during weight updating.")
+        .SetDefault(false);
+
     AddComment(R"DOC(

 SGD operator

paddle/fluid/operators/optimizers/sgd_op.cu

Lines changed: 42 additions & 18 deletions

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include <algorithm>
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/optimizers/sgd_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

@@ -21,14 +22,19 @@ namespace operators {

 namespace {

-template <typename T>
-__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
-                          const int num, T* p_out) {
-  T lr = learning_rate[0];
+template <typename T, typename MT>
+__global__ void SGDKernelMT(const T* param, const T* grad,
+                            const T* learning_rate, const int num, T* param_out,
+                            const MT* master_param, MT* master_param_out) {
+  MT lr = static_cast<MT>(learning_rate[0]);
   CUDA_KERNEL_LOOP(i, num) {
-    T g_data = g[i];
-    T p_data = p[i];
-    p_out[i] = p_data - lr * g_data;
+    MT p_data = master_param ? master_param[i] : static_cast<MT>(param[i]);
+    MT g_data = static_cast<MT>(grad[i]);
+    p_data = p_data - lr * g_data;
+    param_out[i] = static_cast<T>(p_data);
+    if (master_param_out) {
+      master_param_out[i] = p_data;
+    }
   }
 }

@@ -63,30 +69,48 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
                           "but the received is %s",
                           ctx.InputNames("Param").front(),
                           paddle::framework::ToTypeName(param_var->Type())));
+    using paddle::framework::Tensor;
+    using MPDType = typename details::MPTypeTrait<T>::Type;

     auto* param = ctx.Input<framework::Tensor>("Param");
     auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
     auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");

     auto* grad_var = ctx.InputVar("Grad");
+
+    const bool multi_precision = ctx.Attr<bool>("multi_precision");
+    const Tensor* master_param = nullptr;
+    Tensor* master_param_out = nullptr;
+    if (multi_precision) {
+      bool has_master =
+          ctx.HasInput("MasterParam") && ctx.HasOutput("MasterParamOut");
+      PADDLE_ENFORCE_EQ(has_master, true,
+                        platform::errors::InvalidArgument(
+                            "The Input(MasterParam) and Output(MasterParamOut) "
+                            "should not be null when "
+                            "the attr `multi_precision` is true"));
+      master_param = ctx.Input<framework::Tensor>("MasterParam");
+      master_param_out = ctx.Output<framework::Tensor>("MasterParamOut");
+    }
+    const MPDType* master_in_data =
+        multi_precision ? master_param->data<MPDType>() : nullptr;
+    MPDType* master_out_data =
+        multi_precision
+            ? master_param_out->mutable_data<MPDType>(ctx.GetPlace())
+            : nullptr;
+
     // Actually, all tensors are LoDTensor except SelectedRows.
     if (grad_var->IsType<framework::LoDTensor>()) {
-      param_out->mutable_data<T>(ctx.GetPlace());
       auto* grad = ctx.Input<framework::Tensor>("Grad");
-      // LOG(ERROR) << "grad";
-      // LOG(ERROR) << ctx.op().Input("Grad");
-      auto* grad_data = grad->data<T>();
-      // LOG(ERROR) << "param";
-      auto* param_data = param->data<T>();
-      // LOG(ERROR) << "fin";
-      auto* param_out_data = param_out->data<T>();

       int block = 512;
       int grid = (param->numel() + block - 1) / block;

-      SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
-          grad_data, param_data, learning_rate->data<T>(), param->numel(),
-          param_out_data);
+      SGDKernelMT<
+          T, MPDType><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+          param->data<T>(), grad->data<T>(), learning_rate->data<T>(),
+          param->numel(), param_out->mutable_data<T>(ctx.GetPlace()),
+          master_in_data, master_out_data);

     } else if (grad_var->IsType<framework::SelectedRows>()) {
       // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
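
The rewritten kernel does the arithmetic in the higher-precision type MT (FP32 when T is FP16) and only casts the result back to T, so rounding to FP16 happens once per step. A minimal NumPy sketch of that per-element update rule (illustrative only; the function name and signature are not part of the commit):

import numpy as np

def sgd_multi_precision_step(param_fp16, grad_fp16, lr, master_fp32=None):
    """Sketch of what SGDKernelMT computes for each element."""
    # Prefer the FP32 master copy; otherwise up-cast the FP16 parameter.
    p = master_fp32 if master_fp32 is not None else param_fp16.astype(np.float32)
    g = grad_fp16.astype(np.float32)
    p = p - np.float32(lr) * g                 # the update itself runs in FP32
    new_param = p.astype(np.float16)           # cast back for the FP16 parameter
    new_master = p if master_fp32 is not None else None  # keep the master in sync
    return new_param, new_master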

paddle/fluid/pybind/op_function_generator.h

Lines changed: 3 additions & 1 deletion

@@ -79,6 +79,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
                "Beta2Pow", "MasterParam"}},
     {"sparse_attention",
      {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
+    {"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
 };

 // NOTE(zhiqiu): Like op_ins_map.
@@ -125,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
     {"adamw",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
+    {"sgd", {"ParamOut", "MasterParamOut"}},
     {"lamb",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},
@@ -142,7 +144,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
 // especially in declarative mode.
 // For those OPs, we need to manually specify the outs need to pass in this map.
 std::map<std::string, std::set<std::string>> op_passing_outs_map = {
-    {"sgd", {"ParamOut"}},
+    {"sgd", {"ParamOut", "MasterParamOut"}},
     {"adam",
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut",
       "MasterParamOut"}},

python/paddle/fluid/optimizer.py

Lines changed: 70 additions & 9 deletions

@@ -1296,6 +1296,7 @@ def __init__(self,
                  parameter_list=None,
                  regularization=None,
                  grad_clip=None,
+                 multi_precision=False,
                  name=None):
         assert learning_rate is not None
         super(SGDOptimizer, self).__init__(
@@ -1306,26 +1307,86 @@ def __init__(self,
             name=name)
         self.type = "sgd"
         self._use_mkldnn = False
+        self._multi_precision = multi_precision
+        self._master_weights = {}
+
+    def _create_master_weight(self, param):
+        if param.name in self._master_weights:
+            var = self._master_weights[param.name]
+        else:
+            assert isinstance(self.helper, LayerHelper)
+
+            var_name = param.name + "_fp32_master"
+            var_name = unique_name.generate(var_name)
+            var = layers.create_global_var(
+                name=var_name,
+                shape=param.shape,
+                value=0,
+                dtype='float32',
+                persistable=True)
+            block = self.helper.startup_program.global_block()
+            block.append_op(
+                type="cast",
+                inputs={"X": [param]},
+                outputs={"Out": [var]},
+                attrs={
+                    "in_dtype": param.dtype,
+                    "out_dtype": core.VarDesc.VarType.FP32
+                })
+            self._master_weights[param.name] = var
+        return var
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+        if isinstance(parameters, dict):
+            parameters = self._update_param_group(parameters)
+
+        # Create accumulator tensors for first and second moments
+        for p in parameters:
+            if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
+                master_p = self._create_master_weight(p)
+                continue
+            if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
+                warnings.warn(
+                    "Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
+                    "Consider using multi_precision=True option of the Adam optimizer."
+                )

     @no_grad
     def _append_optimize_op(self, block, param_and_grad):
+
+        find_master = self._multi_precision and param_and_grad[
+            0].dtype == core.VarDesc.VarType.FP16
+        master_weight = (self._master_weights[param_and_grad[0].name]
+                         if find_master else None)
+
         lr = self._create_param_lr(param_and_grad)
         if framework.in_dygraph_mode():
-            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1],
-                       param_and_grad[0])
+            _C_ops.sgd(param_and_grad[0], lr, param_and_grad[1], master_weight,
+                       param_and_grad[0], master_weight)
             return None

         assert isinstance(block, framework.Block)
         # create the optimize op
+        inputs = {
+            "Param": param_and_grad[0],
+            "Grad": param_and_grad[1],
+            "LearningRate": lr
+        }
+
+        outputs = {"ParamOut": param_and_grad[0]}
+
+        attrs = {"multi_precision": find_master}
+
+        if find_master:
+            inputs["MasterParam"] = master_weight
+            outputs["MasterParamOut"] = master_weight
+
         sgd_op = block.append_op(
             type=self.type,
-            inputs={
-                "Param": param_and_grad[0],
-                "Grad": param_and_grad[1],
-                "LearningRate": lr
-            },
-            attrs={"use_mkldnn": self._use_mkldnn},
-            outputs={"ParamOut": param_and_grad[0]},
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs,
             stop_gradient=True)

         return sgd_op
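
For completeness, a minimal usage sketch of the new keyword (illustrative, not taken from the commit; the toy network and variable names are assumptions). The flag only takes effect for FP16 parameters, e.g. in a program rewritten by Paddle's AMP utilities; for FP32 parameters no master weights are created:

import paddle
import paddle.fluid as fluid

paddle.enable_static()
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[None, 8], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(y)

    # New keyword added by this commit: keep FP32 master weights for FP16
    # parameters and feed MasterParam/MasterParamOut to the sgd op.
    sgd = fluid.optimizer.SGDOptimizer(learning_rate=0.01, multi_precision=True)
    sgd.minimize(loss)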
