@@ -410,5 +410,153 @@ std::vector<Tensor> stack_grad_impl(const std::vector<Tensor>& x,
   return x_grad;
 }
 
+std::vector<Tensor> meshgrid_impl(const std::vector<Tensor>& inputs) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+
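+  // No kernel key is specified by the caller, so derive the backend, layout
+  // and dtype from the input tensors and take the highest-priority candidate.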
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(inputs);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+
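+  // Look up the registered "meshgrid" kernel matching the resolved key;
+  // SelectKernelOrThrowError throws if no such kernel is registered.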
+  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "meshgrid", {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << "meshgrid API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << "meshgrid API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
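+  // PrepareData transforms the inputs (device, layout, dtype) to match what
+  // the selected kernel expects for its first input slot.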
+  auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
+  std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
+  for (size_t i = 0; i < input_inputs.size(); ++i) {
+    input_inputs[i] = &input_inputs_vec->at(i);
+  }
+
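+  // Wrap the dense inputs as MetaTensors so shape/dtype inference can run
+  // without touching the tensor data.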
+  auto x_meta_vec = MakeMetaTensor(input_inputs);
+  std::vector<phi::MetaTensor*> inputs_metas(x_meta_vec.size());
+  for (size_t i = 0; i < x_meta_vec.size(); ++i) {
+    inputs_metas[i] = &x_meta_vec[i];
+  }
+
+  // Calculate the number of out tensors
+  size_t out_number = inputs.size();
+
+  std::vector<Tensor> out;
+  auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
+
+  std::vector<phi::MetaTensor> meta_outs;
+  meta_outs.reserve(out_number);
+  std::vector<phi::MetaTensor*> meta_out_ptrs;
+  meta_out_ptrs.reserve(out_number);
+  for (size_t i = 0; i < out_number; ++i) {
+    meta_outs.push_back(dense_outs[i]);
+    meta_out_ptrs.push_back(&meta_outs.back());
+  }
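+  // MeshgridInferMeta sets the dims and dtype of each output meta tensor from
+  // the full set of input metas before the kernel runs.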
+  phi::MeshgridInferMeta(inputs_metas, meta_out_ptrs);
+
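+  // Fetch the kernel as a plain function pointer with the variadic signature
+  // (dev_ctx, inputs, outputs) and launch it on the selected device context.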
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const std::vector<const phi::DenseTensor*>&,
+                                    std::vector<phi::DenseTensor*>&);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+  (*kernel_fn)(*dev_ctx, input_inputs, dense_outs);
+
+  return out;
+}
+
+std::vector<Tensor> meshgrid_grad_impl(
+    const std::vector<Tensor>& inputs,
+    const std::vector<Tensor>& outputs_grad) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+
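+  // The kernel key for the backward pass is derived from both the forward
+  // inputs and the incoming output gradients.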
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(inputs, outputs_grad);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+
+  const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "meshgrid_grad", {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << "meshgrid_grad API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << "meshgrid_grad API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
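+  // Prepare both the forward inputs and the output gradients so they match
+  // the backend/layout the selected grad kernel expects.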
+  auto input_inputs_vec = PrepareData(inputs, kernel.InputAt(0), {});
+  std::vector<const phi::DenseTensor*> input_inputs(input_inputs_vec->size());
+  for (size_t i = 0; i < input_inputs.size(); ++i) {
+    input_inputs[i] = &input_inputs_vec->at(i);
+  }
+  auto input_outputs_grad_vec =
+      PrepareData(outputs_grad, kernel.InputAt(1), {});
+  std::vector<const phi::DenseTensor*> input_outputs_grad(
+      input_outputs_grad_vec->size());
+  for (size_t i = 0; i < input_outputs_grad.size(); ++i) {
+    input_outputs_grad[i] = &input_outputs_grad_vec->at(i);
+  }
+
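+  // The backward pass produces one gradient tensor per forward input.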
+  size_t out_number = inputs.size();
+  std::vector<Tensor> api_output;
+  auto kernel_out = SetKernelOutput(out_number, kernel_backend, &api_output);
+
+  auto inputs_meta_vec = MakeMetaTensor(input_inputs);
+  std::vector<phi::MetaTensor*> inputs_metas(inputs_meta_vec.size());
+  for (size_t i = 0; i < inputs_meta_vec.size(); ++i) {
+    inputs_metas[i] = &inputs_meta_vec[i];
+  }
+
+  auto outputs_grad_meta_vec = MakeMetaTensor(input_outputs_grad);
+  std::vector<phi::MetaTensor*> outputs_grad_metas(
+      outputs_grad_meta_vec.size());
+  for (size_t i = 0; i < outputs_grad_meta_vec.size(); ++i) {
+    outputs_grad_metas[i] = &outputs_grad_meta_vec[i];
+  }
+
+  std::vector<phi::MetaTensor> meta_outs;
+  meta_outs.reserve(out_number);
+  std::vector<phi::MetaTensor*> meta_out_ptrs;
+  meta_out_ptrs.reserve(out_number);
+  for (size_t i = 0; i < out_number; ++i) {
+    meta_outs.push_back(kernel_out[i]);
+    meta_out_ptrs.push_back(&meta_outs.back());
+  }
+
+  phi::MeshgridGradInferMeta(inputs_metas, outputs_grad_metas, meta_out_ptrs);
+
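+  // The grad kernel takes (dev_ctx, inputs, output grads, input grads); fetch
+  // it as a variadic function pointer and invoke it directly.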
+  using kernel_signature = void (*)(const platform::DeviceContext&,
+                                    const std::vector<const phi::DenseTensor*>&,
+                                    const std::vector<const phi::DenseTensor*>&,
+                                    std::vector<phi::DenseTensor*>&);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+  (*kernel_fn)(*dev_ctx, input_inputs, input_outputs_grad, kernel_out);
+
+  return api_output;
+}
+
 }  // namespace experimental
 }  // namespace paddle