Skip to content

Commit 2694a29

Browse files
authored
[XPU] support take_along_axis_grad in XPU (#71779)
* xpu: support take_along_axis_grad in XPU
* fix test cases for take_along_axis
1 parent 0ff150d commit 2694a29

File tree

3 files changed: +105 additions, −12 deletions

paddle/phi/backends/xpu/xpu3_op_list.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1497,6 +1497,10 @@ XPUOpMap& get_kl3_ops() {
14971497
XPUKernelSet({phi::DataType::FLOAT32,
14981498
phi::DataType::FLOAT16,
14991499
phi::DataType::BFLOAT16})},
1500+
{"take_along_axis_grad",
1501+
XPUKernelSet({phi::DataType::FLOAT32,
1502+
phi::DataType::FLOAT16,
1503+
phi::DataType::BFLOAT16})},
15001504
{"tanh_grad",
15011505
XPUKernelSet({phi::DataType::FLOAT32,
15021506
phi::DataType::FLOAT16,
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/take_along_axis_grad_kernel.h"
16+
17+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
18+
#include "paddle/phi/core/kernel_registry.h"
19+
20+
namespace phi {
21+
22+
template <typename T, typename Context>
23+
void TakeAlongAxisGradKernel(const Context& dev_ctx,
24+
const DenseTensor& x,
25+
const DenseTensor& index,
26+
const DenseTensor& out_grad,
27+
int axis,
28+
DenseTensor* x_grad) {
29+
using XPUType = typename XPUTypeTrait<T>::Type;
30+
dev_ctx.template Alloc<T>(x_grad);
31+
32+
const auto& index_dtype = index.dtype();
33+
bool index_dtype_match =
34+
index_dtype == DataType::INT32 || index_dtype == DataType::INT64;
35+
PADDLE_ENFORCE_EQ(index_dtype_match,
36+
true,
37+
errors::InvalidArgument(
38+
"Input(Index) holds the wrong type, it holds %s, but "
39+
"desires to be %s or %s",
40+
DataTypeToString(index_dtype),
41+
DataTypeToString(DataType::INT32),
42+
DataTypeToString(DataType::INT64)));
43+
44+
int r = xpu::constant(dev_ctx.x_context(),
45+
reinterpret_cast<XPUType*>(x_grad->data<T>()),
46+
x_grad->numel(),
47+
XPUType(0));
48+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
49+
50+
auto x_shape = common::vectorize<int64_t>(x.dims());
51+
auto out_grad_shape = common::vectorize<int64_t>(out_grad.dims());
52+
auto index_shape = common::vectorize<int64_t>(index.dims());
53+
54+
if (index_dtype == DataType::INT32) {
55+
r = xpu::paddle_put_along_axis<XPUType, int>(
56+
dev_ctx.x_context(),
57+
reinterpret_cast<const XPUType*>(x_grad->data<T>()),
58+
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
59+
reinterpret_cast<const int*>(index.data<int>()),
60+
reinterpret_cast<XPUType*>(x_grad->data<T>()),
61+
x_shape,
62+
out_grad_shape,
63+
index_shape,
64+
axis,
65+
1,
66+
false);
67+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_put_along_axis");
68+
} else {
69+
r = xpu::paddle_put_along_axis<XPUType, int64_t>(
70+
dev_ctx.x_context(),
71+
reinterpret_cast<const XPUType*>(x_grad->data<T>()),
72+
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
73+
reinterpret_cast<const int64_t*>(index.data<int64_t>()),
74+
reinterpret_cast<XPUType*>(x_grad->data<T>()),
75+
x_shape,
76+
out_grad_shape,
77+
index_shape,
78+
axis,
79+
1,
80+
false);
81+
PADDLE_ENFORCE_XDNN_SUCCESS(r, "paddle_put_along_axis");
82+
}
83+
}
84+
} // namespace phi
85+
86+
// Register take_along_axis_grad for XPU; the dtype list mirrors the entry
// added to xpu3_op_list.cc (fp32 / fp16 / bf16).
PD_REGISTER_KERNEL(take_along_axis_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::TakeAlongAxisGradKernel,
                   float,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}

test/xpu/test_take_along_axis_op_xpu.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ def setUp(self):
6161
self.outputs = {'Result': self.target}
6262

6363
def init_config(self):
64-
self.in_type = np.float32
6564
self.x_shape = (1, 4, 10)
6665
self.index_type = np.int32
6766
self.index = np.array([[[0, 1, 3, 5, 6]]]).astype(self.index_type)
@@ -77,39 +76,38 @@ def test_check_grad(self):
7776

7877
class TestCase1(TestXPUTakeAlongAxisOp):
    """take_along_axis over the last axis of a (1, 10, 100) input, int32 index."""

    def init_config(self):
        self.axis = 2
        self.x_shape = (1, 10, 100)
        self.index_type = np.int32
        self.index = np.array(
            [[[0, 1, 3, 5, 13]]], dtype=self.index_type
        )
8583

8684
class TestCase2(TestXPUTakeAlongAxisOp):
    """Same as TestCase1 but exercises the int64 index path."""

    def init_config(self):
        self.axis = 2
        self.x_shape = (1, 10, 100)
        self.index_type = np.int64
        self.index = np.array(
            [[[0, 1, 3, 5, 13]]], dtype=self.index_type
        )
9390

9491
class TestCase3(TestXPUTakeAlongAxisOp):
    """Gather along axis 1 with a column-shaped int32 index."""

    def init_config(self):
        self.axis = 1
        self.x_shape = (1, 10, 100)
        self.index_type = np.int32
        self.index = np.array(
            [[[0], [1], [3], [5]]], dtype=self.index_type
        )
10199

102100
class TestCase4(TestXPUTakeAlongAxisOp):
    """Same as TestCase3 but exercises the int64 index path."""

    def init_config(self):
        self.axis = 1
        self.x_shape = (1, 10, 100)
        self.index_type = np.int64
        self.index = np.array(
            [[[0], [1], [3], [5]]], dtype=self.index_type
        )
109108

110109
class TestCase5(TestXPUTakeAlongAxisOp):
111110
def init_config(self):
112-
self.in_type = np.float32
113111
self.x_shape = (1, 10, 100)
114112
self.index_type = np.int32
115113
self.index = np.array([[[0], [1], [3], [5], [8]]]).astype(
@@ -119,8 +117,7 @@ def init_config(self):
119117

120118
class TestCase6(TestXPUTakeAlongAxisOp):
    """Batched (2, 30, 50) input gathered along the last axis, int64 index."""

    def init_config(self):
        self.axis = 2
        self.x_shape = (2, 30, 50)
        self.index_type = np.int64
        self.index = np.array(
            [[[0, 1, 3, 5, 13]]], dtype=self.index_type
        )

0 commit comments

Comments
 (0)