@@ -123,6 +123,149 @@ std::vector<Tensor> split_impl(const Tensor& x,
   return out;
 }
 
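+// Manually implemented C++ API for the momentum optimizer op: it resolves
+// the kernel key, prepares the inputs, runs InferMeta, and dispatches the
+// selected phi kernel.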
+std::tuple<Tensor, Tensor, Tensor> momentum_impl(
+    const Tensor& param,
+    const Tensor& grad,
+    const Tensor& velocity,
+    const Tensor& learning_rate,
+    paddle::optional<const Tensor&> master_param,
+    float mu,
+    bool use_nesterov,
+    const std::string& regularization_method,
+    float regularization_coeff,
+    bool multi_precision,
+    float rescale_grad) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(param);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
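+  // A dense-param / sparse-grad kernel is selected when the gradient is a
+  // SelectedRows tensor; otherwise the plain dense momentum kernel is used.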
+  std::string kernel_name = "momentum";
+  if (grad.is_selected_rows()) {
+    kernel_name = "momentum_dense_param_sparse_grad";
+  }
+  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << kernel_name << " API kernel: " << kernel;
+
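+  // Fetch the device context for the chosen backend and transform each
+  // input into the place/layout/dtype the selected kernel expects.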
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  auto input_param = PrepareData(param, kernel.InputAt(0), {});
+  auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
+  auto input_velocity = PrepareData(velocity, kernel.InputAt(2), {});
+  auto input_learning_rate = PrepareData(learning_rate, kernel.InputAt(3), {});
+  paddle::optional<const phi::DenseTensor&> input_master_param(paddle::none);
+  auto input_master_param_ptr =
+      PrepareData(master_param, kernel.InputAt(4), {});
+
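+  // momentum updates its outputs in place: the param and velocity outputs
+  // alias the prepared inputs, and the master_param output is forwarded
+  // only when multi-precision training supplies one.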
+  std::tuple<Tensor, Tensor, Tensor> api_output;
+  auto kernel_out_0 = input_param.get();
+  auto kernel_out_1 = input_velocity.get();
+  phi::DenseTensor* kernel_out_2 = nullptr;
+  if (input_master_param_ptr) {
+    input_master_param =
+        paddle::make_optional<const phi::DenseTensor&>(*input_master_param_ptr);
+    kernel_out_2 =
+        paddle::make_optional<phi::DenseTensor&>(*input_master_param_ptr)
+            .get_ptr();
+  }
+
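+  // Build an optional MetaTensor for master_param so InferMeta can see its
+  // dtype/dims/layout when it is present.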
+  paddle::optional<const phi::MetaTensor&> input_meta_ref_master_param(
+      paddle::none);
+  phi::DenseTensor dt;
+  phi::MetaTensor input_meta_tmp_master_param(dt);
+  if (input_master_param_ptr) {
+    input_meta_tmp_master_param.set_dtype(input_master_param_ptr->dtype());
+    input_meta_tmp_master_param.set_dims(input_master_param_ptr->dims());
+    input_meta_tmp_master_param.set_layout(input_master_param_ptr->layout());
+    input_meta_ref_master_param = input_meta_tmp_master_param;
+  }
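+  // Run shape/dtype inference on the outputs; the two branches differ only
+  // in whether a MetaTensor for the master_param output is passed.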
+  phi::MetaTensor meta_out_0(kernel_out_0);
+  phi::MetaTensor meta_out_1(kernel_out_1);
+  if (kernel_out_2) {
+    phi::MetaTensor meta_out_2(kernel_out_2);
+    phi::MomentumInferMeta(MakeMetaTensor(*input_param),
+                           MakeMetaTensor(*input_grad),
+                           MakeMetaTensor(*input_velocity),
+                           MakeMetaTensor(*input_learning_rate),
+                           input_meta_ref_master_param,
+                           mu,
+                           use_nesterov,
+                           regularization_method,
+                           regularization_coeff,
+                           multi_precision,
+                           rescale_grad,
+                           &meta_out_0,
+                           &meta_out_1,
+                           &meta_out_2);
+  } else {
+    phi::MomentumInferMeta(MakeMetaTensor(*input_param),
+                           MakeMetaTensor(*input_grad),
+                           MakeMetaTensor(*input_velocity),
+                           MakeMetaTensor(*input_learning_rate),
+                           input_meta_ref_master_param,
+                           mu,
+                           use_nesterov,
+                           regularization_method,
+                           regularization_coeff,
+                           multi_precision,
+                           rescale_grad,
+                           &meta_out_0,
+                           &meta_out_1,
+                           nullptr);
+  }
+
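+  // The phi kernel is invoked through its variadic function pointer, so this
+  // signature must match the kernel's argument order exactly.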
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    paddle::optional<const phi::DenseTensor&>,
+                                    float,
+                                    bool,
+                                    const std::string&,
+                                    float,
+                                    bool,
+                                    float,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+
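+  // Dispatch to the selected kernel with the prepared inputs and attributes.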
+  (*kernel_fn)(*dev_ctx,
+               *input_param,
+               *input_grad,
+               *input_velocity,
+               *input_learning_rate,
+               input_master_param,
+               mu,
+               use_nesterov,
+               regularization_method,
+               regularization_coeff,
+               multi_precision,
+               rescale_grad,
+               kernel_out_0,
+               kernel_out_1,
+               kernel_out_2);
+
+  return api_output;
+}
+
 // //////////////// Backward(grad) api impls //////////////////////
 
 // TODO(chenweihang): the original sum grad op can support higher-level