143 changes: 143 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -123,6 +123,149 @@ std::vector<Tensor> split_impl(const Tensor& x,
return out;
}

std::tuple<Tensor, Tensor, Tensor> momentum_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& velocity,
const Tensor& learning_rate,
paddle::optional<const Tensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(param);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
std::string kernel_name = "momentum";
if (grad.is_selected_rows()) {
kernel_name = "momentum_dense_param_sparse_grad";
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << kernel_name << " API kernel: " << kernel;

auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

auto input_param = PrepareData(param, kernel.InputAt(0), {});
auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
auto input_velocity = PrepareData(velocity, kernel.InputAt(2), {});
auto input_learning_rate = PrepareData(learning_rate, kernel.InputAt(3), {});
paddle::optional<const phi::DenseTensor&> input_master_param(paddle::none);
auto input_master_param_ptr =
PrepareData(master_param, kernel.InputAt(4), {});

std::tuple<Tensor, Tensor, Tensor> api_output;
auto kernel_out_0 = input_param.get();
auto kernel_out_1 = input_velocity.get();
phi::DenseTensor* kernel_out_2 = nullptr;
if (input_master_param_ptr) {
input_master_param =
paddle::make_optional<const phi::DenseTensor&>(*input_master_param_ptr);
kernel_out_2 =
paddle::make_optional<phi::DenseTensor&>(*input_master_param_ptr)
.get_ptr();
}

paddle::optional<const phi::MetaTensor&> input_meta_ref_master_param(
paddle::none);
phi::DenseTensor dt;
phi::MetaTensor input_meta_tmp_master_param(dt);
if (input_master_param_ptr) {
input_meta_tmp_master_param.set_dtype(input_master_param_ptr->dtype());
input_meta_tmp_master_param.set_dims(input_master_param_ptr->dims());
input_meta_tmp_master_param.set_layout(input_master_param_ptr->layout());
input_meta_ref_master_param = input_meta_tmp_master_param;
}
phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1);
if (kernel_out_2) {
phi::MetaTensor meta_out_2(kernel_out_2);
phi::MomentumInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_velocity),
MakeMetaTensor(*input_learning_rate),
input_meta_ref_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
&meta_out_0,
&meta_out_1,
&meta_out_2);
} else {
phi::MomentumInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_velocity),
MakeMetaTensor(*input_learning_rate),
input_meta_ref_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
&meta_out_0,
&meta_out_1,
nullptr);
}

using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
paddle::optional<const phi::DenseTensor&>,
float,
bool,
const std::string&,
float,
bool,
float,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();

(*kernel_fn)(*dev_ctx,
*input_param,
*input_grad,
*input_velocity,
*input_learning_rate,
input_master_param,
mu,
use_nesterov,
regularization_method,
regularization_coeff,
multi_precision,
rescale_grad,
kernel_out_0,
kernel_out_1,
kernel_out_2);

return api_output;
}

////////////////// Backward(grad) api impls //////////////////////

// TODO(chenweihang): the original sum grad op can support higher-level
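Note on the new function above: the selected "momentum" / "momentum_dense_param_sparse_grad" kernel updates the parameter and velocity in place (kernel_out_0 and kernel_out_1 alias the prepared inputs), and the master-parameter output is only wired up when a master parameter is supplied. As a hedged summary, paraphrased from the documentation of Paddle's momentum optimizer rather than taken from this diff, the update the kernel computes is roughly:

\[
\begin{aligned}
g' &= \text{rescale\_grad}\cdot g\\
g' &= g' + \text{regularization\_coeff}\cdot p \quad\text{(only when regularization\_method is l2\_decay)}\\
v_{t+1} &= \mu\, v_t + g'\\
p_{t+1} &= \begin{cases}
  p_t - lr\,(g' + \mu\, v_{t+1}) & \text{if use\_nesterov}\\
  p_t - lr\, v_{t+1} & \text{otherwise}
\end{cases}
\end{aligned}
\]

where p, g, v, and lr are the param, grad, velocity, and learning_rate inputs.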
14 changes: 14 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.h
@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/utils/optional.h"

namespace paddle {
namespace experimental {
@@ -33,6 +34,19 @@ std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections,
const Scalar& axis);

std::tuple<Tensor, Tensor, Tensor> momentum_impl(
const Tensor& param,
const Tensor& grad,
const Tensor& velocity,
const Tensor& learning_rate,
paddle::optional<const Tensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad);

////////////////// Backward(grad) api impls //////////////////////

std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
47 changes: 47 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -1504,6 +1504,53 @@ void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
}
}

void MomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
paddle::optional<const MetaTensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out) {
PADDLE_ENFORCE_NE(
param_out,
nullptr,
errors::NotFound("Output(ParamOut) of Momentum should not be null."));
PADDLE_ENFORCE_NE(
velocity_out,
nullptr,
errors::NotFound("Output(VelocityOut) of Momentum should not be null."));

auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_NE(
phi::product(lr_dims),
0,
errors::InvalidArgument("Maybe the Input variable LearningRate has not "
"been initialized. You may need to confirm "
"if you put exe.run(startup_program) "
"after optimizer.minimize function."));
PADDLE_ENFORCE_EQ(
phi::product(lr_dims),
1,
errors::InvalidArgument("Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
phi::product(lr_dims)));

auto param_dim = param.dims();
param_out->set_dims(param_dim);
velocity_out->set_dims(param_dim);

if (master_param_out) {
master_param_out->set_dims(param_dim);
}
}

void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
auto inputs_dims = GetMetaTensorsDim(x);

15 changes: 15 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -230,6 +230,21 @@ void InterpolateInferMeta(
void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs);

void MomentumInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& velocity,
const MetaTensor& learning_rate,
paddle::optional<const MetaTensor&> master_param,
float mu,
bool use_nesterov,
const std::string& regularization_method,
float regularization_coeff,
bool multi_precision,
float rescale_grad,
MetaTensor* param_out,
MetaTensor* velocity_out,
MetaTensor* master_param_out);

void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);

void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,
4 changes: 3 additions & 1 deletion python/paddle/fluid/layers/nn.py
@@ -12806,8 +12806,10 @@ def mean(x, name=None):
mean = fluid.layers.mean(input)
"""

if _non_static_mode():
if _in_legacy_dygraph():
return _C_ops.mean(x)
if in_dygraph_mode():
return _C_ops.final_state_mean_all(x)

helper = LayerHelper("mean", **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'mean')
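A minimal sketch of how the two dispatch branches above are exercised; this is illustrative usage, not code from the diff, and assumes an ordinary dygraph session:

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import _test_eager_guard

x_np = np.random.rand(3, 4).astype('float32')

# Legacy dygraph: fluid.layers.mean routes to _C_ops.mean.
y_legacy = fluid.layers.mean(paddle.to_tensor(x_np)).numpy()

# New eager mode: the same call routes to _C_ops.final_state_mean_all.
with _test_eager_guard():
    y_eager = fluid.layers.mean(paddle.to_tensor(x_np)).numpy()

# Both paths compute the mean over all elements of x_np.
assert np.allclose(y_legacy, y_eager)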
10 changes: 5 additions & 5 deletions python/paddle/fluid/tests/unittests/test_mean_op.py
@@ -21,7 +21,7 @@
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

from paddle.fluid.framework import _test_eager_guard
np.random.seed(10)


@@ -40,7 +40,7 @@ def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False):
class TestMeanOp(OpTest):
def setUp(self):
self.op_type = "mean"
self.python_api = mean_wrapper
self.python_api = fluid.layers.mean
self.dtype = np.float64
self.init_dtype_type()
self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
@@ -81,7 +81,7 @@ def init_dtype_type(self):
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place)
self.check_output_with_place(place, check_eager=True)

def test_checkout_grad(self):
place = core.CUDAPlace(0)
@@ -104,11 +104,11 @@ def init_dtype_type(self):

def test_check_output(self):
paddle.enable_static()
self.check_output_with_place(core.CPUPlace())
self.check_output_with_place(core.CPUPlace(), check_eager=True)

def test_checkout_grad(self):
place = core.CPUPlace()
self.check_grad_with_place(place, ['X'], 'Out')
self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)


def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
10 changes: 10 additions & 0 deletions python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -22,6 +22,7 @@
import paddle
import paddle.fluid as fluid
import numpy
from paddle.fluid.framework import _test_eager_guard


def calculate_momentum_by_numpy(param,
@@ -528,6 +529,11 @@ def test_raise_error(self):
ValueError, paddle.optimizer.Momentum, learning_rate=None)
self.assertRaises(ValueError, paddle.optimizer.Momentum, momentum=None)

def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_momentum_dygraph()
self.test_raise_error()


class TestMomentumOpWithDecay(OpTest):
def setUp(self):
@@ -921,6 +927,10 @@ def test_main(self):
self._check_with_param_arrt(place, use_amp)
self._check_with_param_group(place, use_amp)

def test_api_eager_dygraph(self):
with _test_eager_guard():
self.test_main()


class TestMultiTensorMomentumStatic(unittest.TestCase):
def _momentum_optimize_static(self,
12 changes: 10 additions & 2 deletions python/paddle/optimizer/momentum.py
@@ -25,6 +25,7 @@
from paddle.fluid.regularizer import L2DecayRegularizer
from paddle import _C_ops
import paddle
from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

__all__ = []

@@ -313,7 +314,7 @@ def _append_optimize_op(self, block, param_and_grad):
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)

if framework._non_static_mode():
if _in_legacy_dygraph():
if isinstance(param_and_grad, dict):
self._update_regularization(param_and_grad['weight_decay'])
_, _, _ = _C_ops.momentum(
@@ -323,8 +324,15 @@
'regularization_method', regularization_method,
'regularization_coeff', regularization_coeff, 'multi_precision',
find_master)

return None
if in_dygraph_mode():
if isinstance(param_and_grad, dict):
self._update_regularization(param_and_grad['weight_decay'])
return _C_ops.final_state_momentum(
param_and_grad[0], param_and_grad[1], velocity_acc, lr,
master_weight, self._momentum, self._use_nesterov,
regularization_method, regularization_coeff, find_master,
self._rescale_grad)

attrs = {
"mu": self._momentum,
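A minimal end-to-end sketch of the dygraph path the new branch enables; names such as model, opt, and x are illustrative and not taken from the PR. In eager mode, each parameter update issued by Momentum._append_optimize_op goes through _C_ops.final_state_momentum, which lands in the momentum_impl added above:

import paddle
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():  # switch the dygraph session to the new eager mode
    model = paddle.nn.Linear(10, 10)
    opt = paddle.optimizer.Momentum(learning_rate=0.01,
                                    momentum=0.9,
                                    parameters=model.parameters())
    x = paddle.rand([4, 10])
    loss = paddle.mean(model(x))
    loss.backward()
    opt.step()        # each parameter update dispatches final_state_momentum
    opt.clear_grad()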