110 changes: 60 additions & 50 deletions paddle/fluid/operators/optimizers/merged_momentum_op.h
@@ -48,13 +48,13 @@ struct MergedMomentumKernelParam
T *PADDLE_RESTRICT params[N];
const T *PADDLE_RESTRICT grads[N];
MT *PADDLE_RESTRICT velocitys[N];
const MT *PADDLE_RESTRICT lr;
const MultiPrecisionType<MT> *PADDLE_RESTRICT lr;
MT mu;
MT rescale_grad;
uint32_t param_num;

HOSTDEVICE void operator()(size_t i) const {
const auto lr_val = *lr;
const MT lr_val = static_cast<MT>(*lr);
for (uint32_t idx = 0; idx < param_num; ++idx) {
auto size = sizes[idx];
if (i >= size) continue;
@@ -81,8 +81,22 @@ struct MergedMomentumKernelParam

template <typename DeviceContext, typename T>
class MergedMomentumOpKernel : public framework::OpKernel<T> {
using MPType = typename operators::details::MPTypeTrait<T>::Type;

public:
void Compute(const framework::ExecutionContext &ctx) const override {
const bool multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
InnerCompute<MPType>(ctx, multi_precision);
} else {
InnerCompute<T>(ctx, multi_precision);
}
}

private:
template <typename MT>
void InnerCompute(const framework::ExecutionContext &ctx,
const bool multi_precision) const {
auto params = ctx.MultiInput<framework::Tensor>("Param");
auto params_out = ctx.MultiOutput<framework::Tensor>("ParamOut");
size_t n = params.size();
@@ -133,7 +147,6 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
auto master_params = ctx.MultiInput<framework::Tensor>("MasterParam");
auto master_params_out =
ctx.MultiOutput<framework::Tensor>("MasterParamOut");
auto multi_precision = ctx.Attr<bool>("multi_precision");
if (multi_precision) {
PADDLE_ENFORCE_EQ(
n, master_params.size(),
@@ -206,39 +219,37 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
<< ", regularization_coeffs.size(): "
<< regularization_coeffs.size();

using MPType = typename operators::details::MPTypeTrait<T>::Type;

auto &dev_ctx = ctx.template device_context<DeviceContext>();

if (lrs.size() == 1 && use_nesterov == false &&
regularization_methods.size() == 0) {
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MPType, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MPType>(mu); \
kernel_params.rescale_grad = static_cast<MPType>(rescale_grad); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
kernel_params.param_num = static_cast<uint32_t>(end - start); \
size_t max_size = 0; \
for (size_t j = 0; j < kernel_params.param_num; ++j) { \
auto size = static_cast<size_t>(params_out[j + start]->numel()); \
max_size = std::max(max_size, size); \
kernel_params.sizes[j] = size; \
kernel_params.params[j] = params_out[j + start]->data<T>(); \
kernel_params.grads[j] = grads[j + start]->data<T>(); \
kernel_params.velocitys[j] = velocitys_out[j + start]->data<MPType>(); \
kernel_params.SetMasterParam( \
j, kMultiPrecision ? master_params_out[j + start]->data<MPType>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
#define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(kMultiPrecision) \
MergedMomentumKernelParam<T, MT, kMultiPrecision> kernel_params; \
constexpr auto kMaxMergedNum = decltype(kernel_params)::N; \
size_t kernel_num = (n + kMaxMergedNum - 1) / kMaxMergedNum; \
kernel_params.mu = static_cast<MT>(mu); \
kernel_params.rescale_grad = static_cast<MT>(rescale_grad); \
kernel_params.lr = lrs[0]->data<MPType>(); \
for (size_t i = 0; i < kernel_num; ++i) { \
size_t start = i * kMaxMergedNum; \
size_t end = std::min((i + 1) * kMaxMergedNum, n); \
kernel_params.param_num = static_cast<uint32_t>(end - start); \
size_t max_size = 0; \
for (size_t j = 0; j < kernel_params.param_num; ++j) { \
auto size = static_cast<size_t>(params_out[j + start]->numel()); \
max_size = std::max(max_size, size); \
kernel_params.sizes[j] = size; \
kernel_params.params[j] = params_out[j + start]->data<T>(); \
kernel_params.grads[j] = grads[j + start]->data<T>(); \
kernel_params.velocitys[j] = velocitys_out[j + start]->data<MT>(); \
kernel_params.SetMasterParam( \
j, kMultiPrecision ? master_params_out[j + start]->data<MT>() \
: nullptr); \
} \
platform::ForRange<DeviceContext> for_range(dev_ctx, max_size); \
for_range(kernel_params); \
VLOG(10) << "Launch MergedMomentum kernel " << i << " " \
<< kernel_params.param_num; \
}
if (multi_precision) {
PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL(true);
@@ -254,34 +265,33 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
? RegularizationType::kL2DECAY
: RegularizationType::kNONE;

MPType regularization_coeff = static_cast<MPType>(0.0);
MT regularization_coeff = static_cast<MT>(0.0);
if (regularization_coeffs.size() != 0) {
regularization_coeff =
static_cast<MPType>(regularization_coeffs[idx]);
regularization_coeff = static_cast<MT>(regularization_coeffs[idx]);
}
auto lr_temp = lrs.size() > 1 ? lrs[idx] : lrs[0];

const MPType *master_in_data =
multi_precision ? master_params[idx]->data<MPType>() : nullptr;
MPType *master_out_data =
multi_precision ? master_params_out[idx]->data<MPType>() : nullptr;
const MT *master_in_data =
multi_precision ? master_params[idx]->data<MT>() : nullptr;
MT *master_out_data =
multi_precision ? master_params_out[idx]->data<MT>() : nullptr;
if (platform::is_cpu_place(ctx.GetPlace())) {
CPUDenseMomentumFunctor<MPType> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp, mu,
use_nesterov, regularization_flag, regularization_coeff,
params_out[idx], velocitys_out[idx]);
CPUDenseMomentumFunctor<MT> functor;
functor(params[idx], grads[idx], velocitys[idx], lr_temp,
static_cast<MT>(mu), use_nesterov, regularization_flag,
regularization_coeff, params_out[idx], velocitys_out[idx]);
VLOG(10) << "Launch MergedMomentum cpu kernel.";
} else if (platform::is_gpu_place(ctx.GetPlace())) {
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext &>(ctx.device_context()),
params[idx]->numel());
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MPType, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MPType>(), lr_temp->data<MPType>(), master_in_data, \
mu, rescale_grad, params[idx]->numel(), regularization_coeff, \
params_out[idx]->data<T>(), velocitys_out[idx]->data<MPType>(), \
master_out_data); \
#define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL(__nesterov, __reg_type) \
DenseMomentumFunctor<T, MT, __reg_type, __nesterov> functor( \
params[idx]->data<T>(), grads[idx]->data<T>(), \
velocitys[idx]->data<MT>(), lr_temp->data<MPType>(), master_in_data, \
static_cast<MT>(mu), static_cast<MT>(rescale_grad), \
params[idx]->numel(), regularization_coeff, params_out[idx]->data<T>(), \
velocitys_out[idx]->data<MT>(), master_out_data); \
for_range(functor);
if (use_nesterov) {
if (regularization_flag == RegularizationType::kL2DECAY) {
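The net effect of the header changes: the `multi_precision` attribute is read once in `Compute` and baked into the template parameter `MT` via `InnerCompute<MPType>` / `InnerCompute<T>`, rather than being re-read inside the kernel body, and the learning rate is always stored at full precision (`MultiPrecisionType<MT>`) and cast to the math type only where it is consumed. Below is a hypothetical, self-contained sketch of the per-element update the fused kernel applies on its plain-momentum path (names and signature simplified; this is not the Paddle source):

```cpp
#include <cstddef>

// Hypothetical sketch of the merged-momentum update (no Nesterov, no
// regularization). T is the storage type of params/grads (e.g. FP16), MT the
// math type (float when multi_precision is on). The LR tensor stays in FP32
// and is cast to MT at the point of use, mirroring the lr change in the diff.
template <typename T, typename MT>
void MomentumUpdateSketch(T *param, const T *grad, MT *velocity,
                          MT *master_param,      // nullptr when MT == T
                          const float *lr_fp32,  // LR kept at full precision
                          MT mu, MT rescale_grad, std::size_t n) {
  const MT lr = static_cast<MT>(*lr_fp32);
  for (std::size_t i = 0; i < n; ++i) {
    // Read the full-precision master copy when it exists, otherwise upcast.
    const MT p = master_param ? master_param[i] : static_cast<MT>(param[i]);
    const MT g = static_cast<MT>(grad[i]) * rescale_grad;
    const MT v = velocity[i] * mu + g;   // velocity_out = mu * velocity + grad
    const MT p_new = p - lr * v;         // param_out = param - lr * velocity_out
    velocity[i] = v;
    param[i] = static_cast<T>(p_new);    // low-precision copy used by forward
    if (master_param) master_param[i] = p_new;  // full-precision master copy
  }
}
```

On the single-LR, non-Nesterov fast path the operator applies this same update to every parameter in the merged list, chunking them `kMaxMergedNum` at a time as in the macro above.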
9 changes: 4 additions & 5 deletions python/paddle/optimizer/adam.py
@@ -551,8 +551,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
if key == 'FP32_LODTensor':
self._multi_precision = False
find_master = self._multi_precision and key == 'FP16_LODTensor'

_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
@@ -571,7 +570,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
self._beta2_pow_acc_dict[key],
self._master_weight_dict[key], 'epsilon', self._epsilon,
'beta1', _beta1, 'beta2', _beta2, 'multi_precision',
self._multi_precision)
find_master)
else:
inputs = {
"Param": self._param_dict[key],
@@ -594,11 +593,11 @@ def _append_optimize_multi_tensor_op(self, target_block,
"beta1": _beta1,
"beta2": _beta2
}
if self._multi_precision:
if find_master:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
attrs["multi_precision"] = find_master
target_block.append_op(
type="merged_adam",
inputs=inputs,
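The adam.py change fixes a state leak: the old code set `self._multi_precision = False` whenever the FP32 bucket was processed, which persisted on the optimizer and disabled master weights for the FP16 bucket handled afterwards (and on later steps). The new code derives a local `find_master` flag per bucket and leaves the optimizer attribute untouched. A minimal, hypothetical illustration of the corrected predicate (the real optimizer computes this inline as `find_master`):

```python
# Hypothetical helper illustrating the per-bucket decision; not the actual
# optimizer code.
def _use_master_weight(multi_precision, key):
    # Master weights apply only to the FP16 bucket; the FP32 bucket must not
    # mutate optimizer-level state to opt out.
    return multi_precision and key == 'FP16_LODTensor'

assert _use_master_weight(True, 'FP16_LODTensor')
assert not _use_master_weight(True, 'FP32_LODTensor')
assert not _use_master_weight(False, 'FP16_LODTensor')
```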
9 changes: 4 additions & 5 deletions python/paddle/optimizer/momentum.py
@@ -464,8 +464,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
multi_tensor_list = ['FP32_LODTensor', 'FP16_LODTensor']
for key in multi_tensor_list:
if len(self._param_dict[key]) > 0:
if key == 'FP32_LODTensor':
self._multi_precision = False
find_master = self._multi_precision and key == 'FP16_LODTensor'

if framework.in_dygraph_mode():
_, _, _ = _C_ops.merged_momentum(
@@ -478,7 +477,7 @@ def _append_optimize_multi_tensor_op(self, target_block,
self._regularization_method_dict[key],
'regularization_coeff',
self._regularization_coeff_dict[key], 'multi_precision',
self._multi_precision)
find_master)
else:
inputs = {
"Param": self._param_dict[key],
@@ -498,11 +497,11 @@ def _append_optimize_multi_tensor_op(self, target_block,
"regularization_coeff":
self._regularization_coeff_dict[key],
}
if self._multi_precision:
if find_master:
inputs["MasterParam"] = self._master_weight_dict[key]
outputs["MasterParamOut"] = self._master_weight_dict[
key]
attrs["multi_precision"] = self._multi_precision
attrs["multi_precision"] = find_master
target_block.append_op(
type="merged_momentum",
inputs=inputs,
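momentum.py receives the same fix, and the local flag also gates how the merged op is assembled: `MasterParam`, `MasterParamOut`, and the `multi_precision` attribute are attached only for the FP16 bucket. A hedged sketch of that assembly (dictionary keys taken from the diff; the helper and its signature are hypothetical):

```python
def _build_merged_momentum_args(params, grads, velocities, lr, master_weights,
                                multi_precision, key, mu=0.9):
    # Decide per bucket; never flip optimizer-level state.
    find_master = multi_precision and key == 'FP16_LODTensor'
    inputs = {"Param": params, "Grad": grads, "Velocity": velocities,
              "LearningRate": lr}
    outputs = {"ParamOut": params, "VelocityOut": velocities}
    attrs = {"mu": mu}
    if find_master:
        # Only the FP16 bucket carries master weights and the attribute.
        inputs["MasterParam"] = master_weights[key]
        outputs["MasterParamOut"] = master_weights[key]
        attrs["multi_precision"] = find_master
    return inputs, outputs, attrs
```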