@@ -32,51 +32,7 @@ limitations under the License. */
 namespace paddle {
 namespace experimental {
 
-// TODO(chenweihang): the original sum grad op supports higher-order
-// differentiation, but this impl does not. We need to be able to reuse
-// the autograd API here, which is not yet implemented.
-// TODO(chenweihang): we should support calling generated APIs in custom
-// API impls.
-std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
-                                    const Tensor& out_grad) {
-  auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
-  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
-
-  Backend kernel_backend = kernel_key.backend();
-  DataLayout kernel_layout = kernel_key.layout();
-  DataType kernel_data_type = kernel_key.dtype();
-
-  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
-      "scale", {kernel_backend, kernel_layout, kernel_data_type});
-  VLOG(6) << "add_n_grad API kernel key: [" << kernel_backend << ", "
-          << kernel_layout << ", " << kernel_data_type << "]";
-  VLOG(6) << "add_n_grad API kernel: " << kernel;
-
-  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
-
-  auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {});
-
-  size_t out_number = x.size();
-  std::vector<Tensor> x_grad;
-  auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
-
-  using kernel_signature = void (*)(const platform::DeviceContext&,
-                                    const phi::DenseTensor&,
-                                    const phi::Scalar&,
-                                    float,
-                                    bool,
-                                    phi::DenseTensor*);
-  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
-
-  for (auto* dense_x_grad_t : dense_x_grad) {
-    phi::MetaTensor meta_out(dense_x_grad_t);
-    phi::UnchangedInferMeta(MakeMetaTensor(*dense_out_grad), &meta_out);
-    (*kernel_fn)(
-        *dev_ctx, *dense_out_grad, phi::Scalar(1.0), 0.0, true, dense_x_grad_t);
-  }
-
-  return x_grad;
-}
+////////////////// Forward api impls //////////////////////
 
 Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
   auto kernel_key_set = ParseKernelKeyByInputArgs(x);
@@ -310,6 +266,54 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
   return api_output;
 }
 
+////////////////// Backward(grad) api impls //////////////////////
+
+// TODO(chenweihang): the original sum grad op supports higher-order
+// differentiation, but this impl does not. We need to be able to reuse
+// the autograd API here, which is not yet implemented.
+// TODO(chenweihang): we should support calling generated APIs in custom
+// API impls.
+std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,
+                                    const Tensor& out_grad) {
+  auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
+  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+
+  Backend kernel_backend = kernel_key.backend();
+  DataLayout kernel_layout = kernel_key.layout();
+  DataType kernel_data_type = kernel_key.dtype();
+
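+  // d(add_n)/d(x_i) is the identity for every input, so each x_grad is just
+  // a copy of out_grad. There is no dedicated add_n_grad kernel; the "scale"
+  // kernel (with scale=1.0, bias=0.0) is reused to perform the copy.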
+  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "scale", {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << "add_n_grad API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << "add_n_grad API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
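+  // PrepareData transforms out_grad into the dense tensor the selected
+  // kernel expects (transferring device/layout/dtype only when needed).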
+  auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {});
+
+  size_t out_number = x.size();
+  std::vector<Tensor> x_grad;
+  auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
+
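+  // The function-pointer type below mirrors phi's scale kernel signature:
+  // (dev_ctx, x, scale, bias, bias_after_scale, out).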
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    const phi::Scalar&,
+                                    float,
+                                    bool,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+
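+  // With scale=1.0 and bias=0.0 the scale kernel reduces to an identity
+  // copy, so every input's gradient receives out_grad unchanged.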
+  for (auto* dense_x_grad_t : dense_x_grad) {
+    phi::MetaTensor meta_out(dense_x_grad_t);
+    phi::UnchangedInferMeta(MakeMetaTensor(*dense_out_grad), &meta_out);
+    (*kernel_fn)(
+        *dev_ctx, *dense_out_grad, phi::Scalar(1.0), 0.0, true, dense_x_grad_t);
+  }
+
+  return x_grad;
+}
+
 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
     const Tensor& x,
     const Tensor& scale,
@@ -504,5 +508,50 @@ std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
   return x_grad;
 }
 
+std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
+                                    const Tensor& out_grad,
+                                    int axis) {
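+  // stack's gradient splits out_grad along `axis` back into x.size() slices,
+  // one per forward input, so the output count equals the input count.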
+  auto kernel_key_set = ParseKernelKeyByInputArgs(out_grad);
+  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+
+  Backend kernel_backend = kernel_key.backend();
+  DataLayout kernel_layout = kernel_key.layout();
+  DataType kernel_data_type = kernel_key.dtype();
+
+  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "stack_grad", {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << "stack_grad API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << "stack_grad API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  auto dense_out_grad = PrepareData(out_grad, kernel.InputAt(0), {});
+
+  size_t out_number = x.size();
+  std::vector<Tensor> x_grad;
+  auto dense_x_grad = SetKernelOutput(out_number, kernel_backend, &x_grad);
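+  // Wrap each dense output in a MetaTensor so StackGradInferMeta can set the
+  // dims/dtype of every slice before the kernel writes into it.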
+  std::vector<phi::MetaTensor> meta_x_grad;
+  meta_x_grad.reserve(out_number);
+  std::vector<phi::MetaTensor*> meta_x_grad_ptrs;
+  meta_x_grad_ptrs.reserve(out_number);
+  for (size_t i = 0; i < out_number; ++i) {
+    meta_x_grad.push_back(dense_x_grad[i]);
+    meta_x_grad_ptrs.push_back(&meta_x_grad.back());
+  }
+
+  phi::StackGradInferMeta(
+      MakeMetaTensor(*dense_out_grad), axis, meta_x_grad_ptrs);
+
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    int axis,
+                                    std::vector<phi::DenseTensor*>);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+  (*kernel_fn)(*dev_ctx, *dense_out_grad, axis, dense_x_grad);
+
+  return x_grad;
+}
+
 }  // namespace experimental
 }  // namespace paddle