| 
 | 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.  | 
 | 2 | +//  | 
 | 3 | +// Licensed under the Apache License, Version 2.0 (the "License");  | 
 | 4 | +// you may not use this file except in compliance with the License.  | 
 | 5 | +// You may obtain a copy of the License at  | 
 | 6 | +//  | 
 | 7 | +// http://www.apache.org/licenses/LICENSE-2.0  | 
 | 8 | +//  | 
 | 9 | +// Unless required by applicable law or agreed to in writing, software  | 
 | 10 | +// distributed under the License is distributed on an "AS IS" BASIS,  | 
 | 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  | 
 | 12 | +// See the License for the specific language governing permissions and  | 
 | 13 | +// limitations under the License.  | 
 | 14 | + | 
 | 15 | +#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"  | 
 | 16 | + | 
 | 17 | +#include "paddle/phi/backends/xpu/enforce_xpu.h"  | 
 | 18 | +#include "paddle/phi/core/kernel_registry.h"  | 
 | 19 | + | 
 | 20 | +namespace phi {  | 
 | 21 | + | 
 | 22 | +template <typename T, typename Context>  | 
 | 23 | +void TakeAlongAxisGradKernel(const Context& dev_ctx,  | 
 | 24 | + const DenseTensor& x,  | 
 | 25 | + const DenseTensor& index,  | 
 | 26 | + const DenseTensor& out_grad,  | 
 | 27 | + int axis,  | 
 | 28 | + DenseTensor* x_grad) {  | 
 | 29 | + using XPUType = typename XPUTypeTrait<T>::Type;  | 
 | 30 | + dev_ctx.template Alloc<T>(x_grad);  | 
 | 31 | + | 
 | 32 | + const auto& index_dtype = index.dtype();  | 
 | 33 | + bool index_dtype_match =  | 
 | 34 | + index_dtype == DataType::INT32 || index_dtype == DataType::INT64;  | 
 | 35 | + PADDLE_ENFORCE_EQ(index_dtype_match,  | 
 | 36 | + true,  | 
 | 37 | + errors::InvalidArgument(  | 
 | 38 | + "Input(Index) holds the wrong type, it holds %s, but "  | 
 | 39 | + "desires to be %s or %s",  | 
 | 40 | + DataTypeToString(index_dtype),  | 
 | 41 | + DataTypeToString(DataType::INT32),  | 
 | 42 | + DataTypeToString(DataType::INT64)));  | 
 | 43 | + | 
 | 44 | + int r = xpu::constant(dev_ctx.x_context(),  | 
 | 45 | + reinterpret_cast<XPUType*>(x_grad->data<T>()),  | 
 | 46 | + x_grad->numel(),  | 
 | 47 | + XPUType(0));  | 
 | 48 | + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");  | 
 | 49 | + | 
 | 50 | + auto x_shape = common::vectorize<int64_t>(x.dims());  | 
 | 51 | + auto out_grad_shape = common::vectorize<int64_t>(out_grad.dims());  | 
 | 52 | + auto index_shape = common::vectorize<int64_t>(index.dims());  | 
 | 53 | + | 
 | 54 | + if (index_dtype == DataType::INT32) {  | 
 | 55 | + r = xpu::paddle_put_along_axis<XPUType, int>(  | 
 | 56 | + dev_ctx.x_context(),  | 
 | 57 | + reinterpret_cast<const XPUType*>(x_grad->data<T>()),  | 
 | 58 | + reinterpret_cast<const XPUType*>(out_grad.data<T>()),  | 
 | 59 | + reinterpret_cast<const int*>(index.data<int>()),  | 
 | 60 | + reinterpret_cast<XPUType*>(x_grad->data<T>()),  | 
 | 61 | + x_shape,  | 
 | 62 | + out_grad_shape,  | 
 | 63 | + index_shape,  | 
 | 64 | + axis,  | 
 | 65 | + 1,  | 
 | 66 | + false);  | 
 | 67 | + PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_put_along_axis");  | 
 | 68 | + } else {  | 
 | 69 | + r = xpu::paddle_put_along_axis<XPUType, int64_t>(  | 
 | 70 | + dev_ctx.x_context(),  | 
 | 71 | + reinterpret_cast<const XPUType*>(x_grad->data<T>()),  | 
 | 72 | + reinterpret_cast<const XPUType*>(out_grad.data<T>()),  | 
 | 73 | + reinterpret_cast<const int64_t*>(index.data<int64_t>()),  | 
 | 74 | + reinterpret_cast<XPUType*>(x_grad->data<T>()),  | 
 | 75 | + x_shape,  | 
 | 76 | + out_grad_shape,  | 
 | 77 | + index_shape,  | 
 | 78 | + axis,  | 
 | 79 | + 1,  | 
 | 80 | + false);  | 
 | 81 | + PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_put_along_axis");  | 
 | 82 | + }  | 
 | 83 | +}  | 
 | 84 | +} // namespace phi  | 
 | 85 | + | 
// Register the XPU backward kernel for take_along_axis. Supported value
// dtypes: float32, float16, bfloat16; any memory layout.
PD_REGISTER_KERNEL(take_along_axis_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::TakeAlongAxisGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
0 commit comments