Commit fac7fd4

[Phi]Add mean/momentum yaml (#41319)
* move yaml
* add momentum yaml
* delete code
* delete some code
* add meshgrid backward
* delete code
* fix compile bugs
1 parent a288fca commit fac7fd4

10 files changed: +272 additions, -8 deletions

paddle/phi/api/lib/api_custom_impl.cc

Lines changed: 143 additions & 0 deletions
@@ -123,6 +123,149 @@ std::vector<Tensor> split_impl(const Tensor& x,
   return out;
 }
 
+std::tuple<Tensor, Tensor, Tensor> momentum_impl(
+    const Tensor& param,
+    const Tensor& grad,
+    const Tensor& velocity,
+    const Tensor& learning_rate,
+    paddle::optional<const Tensor&> master_param,
+    float mu,
+    bool use_nesterov,
+    const std::string& regularization_method,
+    float regularization_coeff,
+    bool multi_precision,
+    float rescale_grad) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(param);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+  std::string kernel_name = "momentum";
+  if (grad.is_selected_rows()) {
+    kernel_name = "momentum_dense_param_sparse_grad";
+  }
+  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << kernel_name << " API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  auto input_param = PrepareData(param, kernel.InputAt(0), {});
+  auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
+  auto input_velocity = PrepareData(velocity, kernel.InputAt(2), {});
+  auto input_learning_rate = PrepareData(learning_rate, kernel.InputAt(3), {});
+  paddle::optional<const phi::DenseTensor&> input_master_param(paddle::none);
+  auto input_master_param_ptr =
+      PrepareData(master_param, kernel.InputAt(4), {});
+
+  std::tuple<Tensor, Tensor, Tensor> api_output;
+  auto kernel_out_0 = input_param.get();
+  auto kernel_out_1 = input_velocity.get();
+  phi::DenseTensor* kernel_out_2 = nullptr;
+  if (input_master_param_ptr) {
+    input_master_param =
+        paddle::make_optional<const phi::DenseTensor&>(*input_master_param_ptr);
+    kernel_out_2 =
+        paddle::make_optional<phi::DenseTensor&>(*input_master_param_ptr)
+            .get_ptr();
+  }
+
+  paddle::optional<const phi::MetaTensor&> input_meta_ref_master_param(
+      paddle::none);
+  phi::DenseTensor dt;
+  phi::MetaTensor input_meta_tmp_master_param(dt);
+  if (input_master_param_ptr) {
+    input_meta_tmp_master_param.set_dtype(input_master_param_ptr->dtype());
+    input_meta_tmp_master_param.set_dims(input_master_param_ptr->dims());
+    input_meta_tmp_master_param.set_layout(input_master_param_ptr->layout());
+    input_meta_ref_master_param = input_meta_tmp_master_param;
+  }
+  phi::MetaTensor meta_out_0(kernel_out_0);
+  phi::MetaTensor meta_out_1(kernel_out_1);
+  if (kernel_out_2) {
+    phi::MetaTensor meta_out_2(kernel_out_2);
+    phi::MomentumInferMeta(MakeMetaTensor(*input_param),
+                           MakeMetaTensor(*input_grad),
+                           MakeMetaTensor(*input_velocity),
+                           MakeMetaTensor(*input_learning_rate),
+                           input_meta_ref_master_param,
+                           mu,
+                           use_nesterov,
+                           regularization_method,
+                           regularization_coeff,
+                           multi_precision,
+                           rescale_grad,
+                           &meta_out_0,
+                           &meta_out_1,
+                           &meta_out_2);
+  } else {
+    phi::MomentumInferMeta(MakeMetaTensor(*input_param),
+                           MakeMetaTensor(*input_grad),
+                           MakeMetaTensor(*input_velocity),
+                           MakeMetaTensor(*input_learning_rate),
+                           input_meta_ref_master_param,
+                           mu,
+                           use_nesterov,
+                           regularization_method,
+                           regularization_coeff,
+                           multi_precision,
+                           rescale_grad,
+                           &meta_out_0,
+                           &meta_out_1,
+                           nullptr);
+  }
+
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    paddle::optional<const phi::DenseTensor&>,
+                                    float,
+                                    bool,
+                                    const std::string&,
+                                    float,
+                                    bool,
+                                    float,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+
+  (*kernel_fn)(*dev_ctx,
+               *input_param,
+               *input_grad,
+               *input_velocity,
+               *input_learning_rate,
+               input_master_param,
+               mu,
+               use_nesterov,
+               regularization_method,
+               regularization_coeff,
+               multi_precision,
+               rescale_grad,
+               kernel_out_0,
+               kernel_out_1,
+               kernel_out_2);
+
+  return api_output;
+}
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 // TODO(chenweihang): the original sum grad op can support higher-level
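
Note: for orientation, the update rule behind these kernels can be sketched in numpy; it mirrors calculate_momentum_by_numpy in test_momentum_op.py below. This is a reference sketch only, not the kernel code — rescale_grad and the sparse-grad path are omitted.

import numpy as np

def momentum_update(param, grad, velocity, lr, mu, use_nesterov=False,
                    regularization_coeff=0.0):
    # 'l2_decay' regularization folds a weight-decay term into the gradient
    grad = grad + regularization_coeff * param
    # velocity accumulates the gradient, scaled by mu each step
    velocity_out = mu * velocity + grad
    if use_nesterov:
        # Nesterov looks one step ahead along the updated velocity
        param_out = param - (grad + mu * velocity_out) * lr
    else:
        param_out = param - lr * velocity_out
    return param_out, velocity_out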

paddle/phi/api/lib/api_custom_impl.h

Lines changed: 14 additions & 0 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/common/scalar.h"
+#include "paddle/utils/optional.h"
 
 namespace paddle {
 namespace experimental {
@@ -33,6 +34,19 @@ std::vector<Tensor> split_impl(const Tensor& x,
                                const IntArray& num_or_sections,
                                const Scalar& axis);
 
+std::tuple<Tensor, Tensor, Tensor> momentum_impl(
+    const Tensor& param,
+    const Tensor& grad,
+    const Tensor& velocity,
+    const Tensor& learning_rate,
+    paddle::optional<const Tensor&> master_param,
+    float mu,
+    bool use_nesterov,
+    const std::string& regularization_method,
+    float regularization_coeff,
+    bool multi_precision,
+    float rescale_grad);
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,

paddle/phi/infermeta/multiary.cc

Lines changed: 47 additions & 0 deletions
@@ -1504,6 +1504,53 @@ void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
   }
 }
 
+void MomentumInferMeta(const MetaTensor& param,
+                       const MetaTensor& grad,
+                       const MetaTensor& velocity,
+                       const MetaTensor& learning_rate,
+                       paddle::optional<const MetaTensor&> master_param,
+                       float mu,
+                       bool use_nesterov,
+                       const std::string& regularization_method,
+                       float regularization_coeff,
+                       bool multi_precision,
+                       float rescale_grad,
+                       MetaTensor* param_out,
+                       MetaTensor* velocity_out,
+                       MetaTensor* master_param_out) {
+  PADDLE_ENFORCE_NE(
+      param_out,
+      nullptr,
+      errors::NotFound("Output(ParamOut) of Momentum should not be null."));
+  PADDLE_ENFORCE_NE(
+      velocity_out,
+      nullptr,
+      errors::NotFound("Output(VelocityOut) of Momentum should not be null."));
+
+  auto lr_dims = learning_rate.dims();
+  PADDLE_ENFORCE_NE(
+      phi::product(lr_dims),
+      0,
+      errors::InvalidArgument("Maybe the Input variable LearningRate has not "
+                              "been initialized. You may need to confirm "
+                              "if you put exe.run(startup_program) "
+                              "after optimizer.minimize function."));
+  PADDLE_ENFORCE_EQ(
+      phi::product(lr_dims),
+      1,
+      errors::InvalidArgument("Learning_rate should be a scalar. But Received "
+                              "LearningRate's dim [%s]",
+                              phi::product(lr_dims)));
+
+  auto param_dim = param.dims();
+  param_out->set_dims(param_dim);
+  velocity_out->set_dims(param_dim);
+
+  if (master_param_out) {
+    master_param_out->set_dims(param_dim);
+  }
+}
+
 void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out) {
   auto inputs_dims = GetMetaTensorsDim(x);

paddle/phi/infermeta/multiary.h

Lines changed: 15 additions & 0 deletions
@@ -230,6 +230,21 @@ void InterpolateInferMeta(
 void MeshgridInferMeta(const std::vector<MetaTensor*>& inputs,
                        std::vector<MetaTensor*> outputs);
 
+void MomentumInferMeta(const MetaTensor& param,
+                       const MetaTensor& grad,
+                       const MetaTensor& velocity,
+                       const MetaTensor& learning_rate,
+                       paddle::optional<const MetaTensor&> master_param,
+                       float mu,
+                       bool use_nesterov,
+                       const std::string& regularization_method,
+                       float regularization_coeff,
+                       bool multi_precision,
+                       float rescale_grad,
+                       MetaTensor* param_out,
+                       MetaTensor* velocity_out,
+                       MetaTensor* master_param_out);
+
 void MultiDotInferMeta(const std::vector<MetaTensor*>& x, MetaTensor* out);
 
 void MultiplexInferMeta(const std::vector<MetaTensor*>& ins,

python/paddle/fluid/layers/nn.py

Lines changed: 3 additions & 1 deletion
@@ -12806,8 +12806,10 @@ def mean(x, name=None):
             mean = fluid.layers.mean(input)
     """
 
-    if _non_static_mode():
+    if _in_legacy_dygraph():
         return _C_ops.mean(x)
+    if in_dygraph_mode():
+        return _C_ops.final_state_mean_all(x)
 
     helper = LayerHelper("mean", **locals())
     check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'mean')
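
Note: this is the standard three-way dispatch used throughout the eager migration: _in_legacy_dygraph() routes to the old _C_ops.mean, in_dygraph_mode() routes to the final-state kernel registered via the yaml in this PR, and static graph falls through to the LayerHelper path. A usage sketch (assumes a Paddle build from this era; mean reduces all elements to a single value):

import paddle
import paddle.fluid as fluid

paddle.disable_static()  # dygraph: takes one of the _C_ops branches above
x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
y = fluid.layers.mean(x)
print(y.numpy())  # [2.5]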

python/paddle/fluid/tests/unittests/test_mean_op.py

Lines changed: 5 additions & 5 deletions
@@ -21,7 +21,7 @@
 import paddle.fluid.core as core
 import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
-
+from paddle.fluid.framework import _test_eager_guard
 np.random.seed(10)
 
 
@@ -40,7 +40,7 @@ def reduce_mean_wrapper(x, axis=0, keepdim=False, reduce_all=False):
 class TestMeanOp(OpTest):
     def setUp(self):
         self.op_type = "mean"
-        self.python_api = mean_wrapper
+        self.python_api = fluid.layers.mean
         self.dtype = np.float64
         self.init_dtype_type()
         self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
@@ -81,7 +81,7 @@ def init_dtype_type(self):
     def test_check_output(self):
         place = core.CUDAPlace(0)
         if core.is_float16_supported(place):
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_eager=True)
 
     def test_checkout_grad(self):
         place = core.CUDAPlace(0)
@@ -104,11 +104,11 @@ def init_dtype_type(self):
 
     def test_check_output(self):
         paddle.enable_static()
-        self.check_output_with_place(core.CPUPlace())
+        self.check_output_with_place(core.CPUPlace(), check_eager=True)
 
     def test_checkout_grad(self):
         place = core.CPUPlace()
-        self.check_grad_with_place(place, ['X'], 'Out')
+        self.check_grad_with_place(place, ['X'], 'Out', check_eager=True)
 
 
 def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False):
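
Note: check_eager=True asks the OpTest harness to also run the final-state (eager) kernel against the same numpy reference. A condensed, illustrative sketch of the pattern — the real class above carries more dtype plumbing, and the harness import path is the one used inside Paddle's test directory:

import numpy as np
import paddle.fluid as fluid
from op_test import OpTest  # Paddle's per-op unittest harness

class TestMeanOpSketch(OpTest):
    def setUp(self):
        self.op_type = "mean"
        self.python_api = fluid.layers.mean  # eager-mode reference entry point
        x = np.random.random((10, 10)).astype(np.float64)
        self.inputs = {'X': x}
        self.outputs = {'Out': np.mean(x)}

    def test_check_output(self):
        # checks static, legacy-dygraph, and (with the flag) eager kernels
        self.check_output(check_eager=True)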

python/paddle/fluid/tests/unittests/test_momentum_op.py

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,7 @@
 import paddle
 import paddle.fluid as fluid
 import numpy
+from paddle.fluid.framework import _test_eager_guard
 
 
 def calculate_momentum_by_numpy(param,
@@ -528,6 +529,11 @@ def test_raise_error(self):
             ValueError, paddle.optimizer.Momentum, learning_rate=None)
         self.assertRaises(ValueError, paddle.optimizer.Momentum, momentum=None)
 
+    def test_api_eager_dygraph(self):
+        with _test_eager_guard():
+            self.test_momentum_dygraph()
+            self.test_raise_error()
+
 
 class TestMomentumOpWithDecay(OpTest):
     def setUp(self):
@@ -921,6 +927,10 @@ def test_main(self):
             self._check_with_param_arrt(place, use_amp)
             self._check_with_param_group(place, use_amp)
 
+    def test_api_eager_dygraph(self):
+        with _test_eager_guard():
+            self.test_main()
+
 
 class TestMultiTensorMomentumStatic(unittest.TestCase):
     def _momentum_optimize_static(self,
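
Note: _test_eager_guard is the transitional helper these tests use to replay an existing dygraph test body under eager mode, so both code paths stay covered by one set of assertions. The pattern, in a minimal sketch:

import unittest
from paddle.fluid.framework import _test_eager_guard

class MomentumEagerSketch(unittest.TestCase):
    def test_momentum_dygraph(self):
        pass  # the legacy-dygraph assertions live here

    def test_api_eager_dygraph(self):
        # replay the same assertions on the final-state eager kernels
        with _test_eager_guard():
            self.test_momentum_dygraph()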

python/paddle/optimizer/momentum.py

Lines changed: 10 additions & 2 deletions
@@ -25,6 +25,7 @@
 from paddle.fluid.regularizer import L2DecayRegularizer
 from paddle import _C_ops
 import paddle
+from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph
 
 __all__ = []
 
@@ -313,7 +314,7 @@ def _append_optimize_op(self, block, param_and_grad):
         master_weight = (self._master_weights[param_and_grad[0].name]
                          if find_master else None)
 
-        if framework._non_static_mode():
+        if _in_legacy_dygraph():
             if isinstance(param_and_grad, dict):
                 self._update_regularization(param_and_grad['weight_decay'])
             _, _, _ = _C_ops.momentum(
@@ -323,8 +324,15 @@ def _append_optimize_op(self, block, param_and_grad):
                 'regularization_method', regularization_method,
                 'regularization_coeff', regularization_coeff, 'multi_precision',
                 find_master)
-
             return None
+        if in_dygraph_mode():
+            if isinstance(param_and_grad, dict):
+                self._update_regularization(param_and_grad['weight_decay'])
+            return _C_ops.final_state_momentum(
+                param_and_grad[0], param_and_grad[1], velocity_acc, lr,
+                master_weight, self._momentum, self._use_nesterov,
+                regularization_method, regularization_coeff, find_master,
+                self._rescale_grad)
 
         attrs = {
             "mu": self._momentum,
