Skip to content

Commit e1842d4

Browse files
authored
[CINN] Fixed gather_nd incorrect logic for negative inputs (#73940)
* [CINN] Fixed gather_nd symbolic shape compile bug * [CINN] Add unittest for gather_nd op
1 parent 74bdde0 commit e1842d4

File tree

2 files changed

+149
-104
lines changed

2 files changed

+149
-104
lines changed

paddle/cinn/hlir/op/contrib/gather_nd.cc

Lines changed: 23 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -42,46 +42,21 @@ namespace op {
4242
using cinn::common::CINNValue;
4343
using cinn::common::CINNValuePack;
4444

45-
ir::Tensor GatherNd(const ir::Tensor &x,
46-
const ir::Tensor &index,
47-
const std::string &name) {
48-
std::vector<Expr> x_shape = x->shape;
49-
std::vector<Expr> index_shape = index->shape;
50-
size_t x_shape_size = x_shape.size();
51-
size_t index_shape_size = index_shape.size();
52-
std::vector<Expr> out_shape;
53-
out_shape.insert(out_shape.end(), index_shape.begin(), index_shape.end() - 1);
54-
out_shape.insert(out_shape.end(),
55-
x_shape.begin() + index_shape.back().as_int32(),
56-
x_shape.end());
57-
auto res = Compute(
58-
out_shape,
59-
[=](const std::vector<Expr> &indices) {
60-
std::vector<Expr> indices_position;
61-
for (size_t i = 0; i < index_shape_size - 1; ++i) {
62-
indices_position.push_back(
63-
ir::Cast::Make(cinn::common::Int(32), indices[i]));
64-
}
65-
indices_position.push_back(
66-
ir::Cast::Make(cinn::common::Int(32), Expr(0)));
67-
size_t indices_position_size = indices_position.size();
68-
std::vector<Expr> real_indices;
69-
for (size_t i = 0; i < index_shape.back().as_int32(); ++i) {
70-
indices_position[indices_position_size - 1] =
71-
ir::Cast::Make(cinn::common::Int(32), Expr(i));
72-
real_indices.push_back(
73-
ir::Cast::Make(cinn::common::Int(32), index(indices_position)));
74-
}
75-
if (real_indices.size() == x_shape_size) {
76-
return x(real_indices);
77-
}
78-
for (size_t i = index_shape_size - 1; i < indices.size(); ++i) {
79-
real_indices.push_back(indices[i]);
80-
}
81-
return x(real_indices);
82-
},
83-
name);
84-
return res;
45+
// Normalizes one shape-dimension expression for use in index arithmetic.
//  - A constant expression is returned as-is.
//  - For ir::Max / ir::Min, both operands are processed recursively and the
//    results are written back through the node's a()/b() accessors.
//  - Any other expression is wrapped in an intrinsic call named "int64_t"
//    whose result type is ir::Int(64).
// Returns the resulting expression.
// NOTE(review): the Max/Min branches assign through the pointer returned by
// As<>(), so they appear to modify the node `shape_element` refers to rather
// than a fresh copy -- confirm callers are fine with that.
Expr CastShapeElementType(Expr shape_element) {
46+
// Constant values should already have been cast to int64_t for shape
// calculation, so they are returned untouched.
47+
// Make sure the shape expr does not contain mismatched int types inside
// min and max: their operands are converted individually.
48+
if (shape_element.is_constant()) return shape_element;
49+
if (auto max_elem = shape_element.As<ir::Max>()) {
50+
max_elem->a() = CastShapeElementType(max_elem->a());
51+
max_elem->b() = CastShapeElementType(max_elem->b());
52+
} else if (auto min_elem = shape_element.As<ir::Min>()) {
53+
min_elem->a() = CastShapeElementType(min_elem->a());
54+
min_elem->b() = CastShapeElementType(min_elem->b());
55+
} else {
56+
// Fallback for every non-constant, non-Max/Min expression: wrap it in an
// "int64_t" intrinsic call typed as Int(64).
shape_element = ir::Call::Make(
57+
ir::Int(64), "int64_t", {shape_element}, {}, ir::CallType::Intrinsic);
58+
}
59+
return shape_element;
8560
}
8661

8762
ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
@@ -111,8 +86,15 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
11186
for (size_t i = 0; i < index_shape.back().as_int64(); ++i) {
11287
indices_position[indices_position_size - 1] =
11388
ir::Cast::Make(cinn::common::Int(64), Expr(i));
89+
// support negative indices
90+
auto idx_expr =
91+
ir::Cast::Make(cinn::common::Int(64), index(indices_position));
92+
auto real_expr =
93+
ir::Select::Make(ir::GE::Make(idx_expr, Expr(0)),
94+
idx_expr,
95+
CastShapeElementType(x_shape[i]) + idx_expr);
11496
real_indices.push_back(
115-
ir::Cast::Make(cinn::common::Int(64), index(indices_position)));
97+
ir::Cast::Make(cinn::common::Int(64), real_expr));
11698
}
11799
if (real_indices.size() == x_shape_size) {
118100
return x(real_indices);
@@ -126,67 +108,6 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
126108
return res;
127109
}
128110

129-
std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
130-
const framework::NodeAttr &attrs,
131-
const std::vector<ir::Tensor> &inputs,
132-
const std::vector<Type> &out_type,
133-
const std::vector<std::vector<int>> &output_shapes,
134-
const Target &target) {
135-
std::string op_name("gather_nd");
136-
137-
framework::CINNCompute gather_nd_compute([=](lang::Args args,
138-
lang::RetValue *ret) {
139-
PADDLE_ENFORCE_NE(
140-
args.empty(),
141-
true,
142-
::common::errors::InvalidArgument(
143-
"The input argument of %s compute is empty! Please check.",
144-
op_name));
145-
CINNValuePack pack_args = args[0];
146-
PADDLE_ENFORCE_GE(
147-
pack_args.size(),
148-
2U,
149-
::common::errors::InvalidArgument("2 input tensors for compute\n"));
150-
Expr x = pack_args[0];
151-
Expr index = pack_args[1];
152-
PADDLE_ENFORCE_NOT_NULL(x.as_tensor(),
153-
::common::errors::InvalidArgument(
154-
"Required x must be a tensor. Please check."));
155-
PADDLE_ENFORCE_NOT_NULL(
156-
index.as_tensor(),
157-
::common::errors::InvalidArgument(
158-
"Required index must be a tensor. Please check."));
159-
PADDLE_ENFORCE_NE(
160-
output_shapes.empty(),
161-
true,
162-
::common::errors::InvalidArgument(
163-
"The output shape of gather_nd is empty! Please check."));
164-
auto tensor_x = x.as_tensor_ref();
165-
auto tensor_index = index.as_tensor_ref();
166-
VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ")
167-
<< ", index shape: " << utils::Join(tensor_index->shape, ", ")
168-
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
169-
PADDLE_ENFORCE_EQ(pack_args.size(),
170-
3U,
171-
::common::errors::InvalidArgument(
172-
"The size of pack_args should be 3\n"));
173-
std::string tensor_name = pack_args[2].operator std::string();
174-
ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name);
175-
std::vector<CINNValue> res;
176-
res.push_back(CINNValue(out));
177-
PADDLE_ENFORCE_NE(
178-
out_type.empty(),
179-
true,
180-
::common::errors::InvalidArgument(
181-
"The output type of gather_nd is empty! Please check."));
182-
*ret = CINNValuePack{res};
183-
});
184-
185-
auto strategy = std::make_shared<framework::OpStrategy>();
186-
strategy->AddImpl(gather_nd_compute, "strategy.gather_nd.x86", 1);
187-
return strategy;
188-
}
189-
190111
std::shared_ptr<framework::OpStrategy> StrategyForGatherNdSymbolic(
191112
const framework::NodeAttr &attrs,
192113
const std::vector<ir::Tensor> &inputs,
@@ -259,8 +180,6 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
259180
.set_num_outputs(1)
260181
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
261182
"CINNStrategySymbolic", cinn::hlir::op::StrategyForGatherNdSymbolic)
262-
.set_attr<cinn::hlir::framework::StrategyFunction>(
263-
"CINNStrategy", cinn::hlir::op::StrategyForGatherNd)
264183
.set_attr<cinn::hlir::framework::OpPatternKind>(
265184
"OpPattern", cinn::hlir::framework::OpPatternKind::kInjective)
266185
.set_support_level(4);
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy
18+
import utils
19+
20+
import paddle
21+
22+
23+
class TestGatherNd(unittest.TestCase):
24+
# Note that GatherNd is also used in index_put, so we can test it by using index_put.
25+
# Runs `dy_compute` on `inputs` directly and again through
# `utils.apply_to_static(..., use_cinn=True)`, then asserts that the flattened
# outputs of both runs agree with atol=1e-6 and rtol=1e-6.
# NOTE(review): the `input_spec` parameter is accepted but never forwarded;
# the call below passes the literal `input_spec=None`. Confirm whether that
# is intentional before relying on the specs supplied by callers.
def eval(self, dy_compute, inputs, input_spec=None):
26+
# Reference result from calling the function as-is.
dy_out = dy_compute(*inputs)
27+
28+
static_compute = utils.apply_to_static(
29+
dy_compute, use_cinn=True, input_spec=None
30+
)
31+
st_out = static_compute(*inputs)
32+
33+
# Compare the flattened outputs of the two runs pairwise.
for a, b in zip(
34+
paddle.utils.flatten(dy_out), paddle.utils.flatten(st_out)
35+
):
36+
numpy.testing.assert_allclose(a, b, atol=1e-6, rtol=1e-6)
37+
38+
# Builds random inputs for the index_put-based tests.
#   x_shape:            shape passed to paddle.randn for `x` and for `dout`.
#   indices_shape:      [n_indices] or [n_indices, index_dim_size]; one index
#                       tensor of shape [n_indices] is created per indexed dim.
#   value_shape:        shape passed to paddle.randn for the value tensor.
#   has_negative_index: when True the first randint argument is -x_shape[i],
#                       otherwise 0; the second argument is always x_shape[i].
# Returns (x, indices_tuple, value, dout); x, value and dout have
# stop_gradient set to False.
@staticmethod
39+
def get_input(x_shape, indices_shape, value_shape, has_negative_index=True):
40+
n_indices = indices_shape[0]
41+
# A 1-D indices_shape is treated as indexing a single dimension.
index_dim_size = indices_shape[1] if len(indices_shape) > 1 else 1
42+
43+
x_pd = paddle.randn(x_shape)
44+
x_pd.stop_gradient = False
45+
46+
# One index tensor per indexed dimension i, bounded by x_shape[i].
indices_pd = tuple(
47+
[
48+
paddle.randint(
49+
-x_shape[i] if has_negative_index else 0,
50+
x_shape[i],
51+
[n_indices],
52+
)
53+
for i in range(max(index_dim_size, 1))
54+
]
55+
)
56+
value_pd = paddle.randn(value_shape)
57+
value_pd.stop_gradient = False
58+
59+
# `dout` has the same shape as `x`; callers pass it on as the `dy`
# argument of index_put_grad.
dout_pd = paddle.randn(x_shape)
60+
dout_pd.stop_gradient = False
61+
return x_pd, indices_pd, value_pd, dout_pd
62+
63+
# Returns a list of four input specs in the argument order of
# index_put_grad (x, indices, v, dy):
#   - a float32 spec of shape [-1, -1],
#   - a tuple of `indice_dim` int64 specs of shape [-1],
#   - two more float32 specs of shape [-1, -1].
# NOTE(review): every float32 spec is declared rank-2, while some tests in
# this class build rank-3 `x`/`dout` tensors -- confirm the intended ranks.
@staticmethod
64+
def get_input_spec(indice_dim):
65+
return [
66+
paddle.static.InputSpec(shape=[-1, -1], dtype="float32"),
67+
tuple(
68+
paddle.static.InputSpec(shape=[-1], dtype="int64")
69+
for _ in range(indice_dim)
70+
),
71+
paddle.static.InputSpec(shape=[-1, -1], dtype="float32"),
72+
paddle.static.InputSpec(shape=[-1, -1], dtype="float32"),
73+
]
74+
75+
# Computes y = paddle.index_put(x, indices, v, True) and returns
# paddle.grad(y, [x, v], dy), i.e. the gradients with respect to x and v.
# NOTE(review): the fourth positional argument `True` is presumably
# `accumulate` -- confirm against the paddle.index_put signature.
@staticmethod
76+
def index_put_grad(x, indices, v, dy):
77+
y = paddle.index_put(x, indices, v, True)
78+
return paddle.grad(y, [x, v], dy)
79+
80+
# index_put gradient check with has_negative_index=False:
# x is [12, 13, 14], two index tensors of 88 entries each, value is [88, 14].
# NOTE(review): get_input_spec(2) declares x as rank-2 while x here is
# rank-3; eval does not forward input_spec, so the spec is not exercised.
def test_index_put_grad_non_negative_index(self):
81+
x_pd, indices_pd, value_pd, dout_pd = self.get_input(
82+
[12, 13, 14], [88, 2], [88, 14], False
83+
)
84+
85+
self.eval(
86+
TestGatherNd.index_put_grad,
87+
[x_pd, indices_pd, value_pd, dout_pd],
88+
input_spec=self.get_input_spec(2),
89+
)
90+
91+
# index_put gradient check with the default has_negative_index=True:
# x is [12, 13, 14], a single index tensor of 88 entries (first dimension
# only), value is [88, 13, 14].
# NOTE(review): get_input_spec(1) declares x as rank-2 while x here is
# rank-3; eval does not forward input_spec, so the spec is not exercised.
def test_index_put_grad_negative_index_1(self):
92+
x_pd, indices_pd, value_pd, dout_pd = self.get_input(
93+
[12, 13, 14], [88, 1], [88, 13, 14]
94+
)
95+
96+
self.eval(
97+
TestGatherNd.index_put_grad,
98+
[x_pd, indices_pd, value_pd, dout_pd],
99+
input_spec=self.get_input_spec(1),
100+
)
101+
102+
# index_put gradient check with the default has_negative_index=True where
# both dimensions of x are indexed: x is [16, 16], two index tensors of 20
# entries each, value is [20].
def test_index_put_grad_negative_index_2(self):
103+
x_pd, indices_pd, value_pd, dout_pd = self.get_input(
104+
[16, 16], [20, 2], [20]
105+
)
106+
107+
self.eval(
108+
TestGatherNd.index_put_grad,
109+
[x_pd, indices_pd, value_pd, dout_pd],
110+
input_spec=self.get_input_spec(2),
111+
)
112+
113+
# Calls paddle.gather_nd directly, placed between an elementwise multiply
# (x * y) and an add (+ z), and compares both execution paths via self.eval.
# x and y are [256, 128], z is [100], and indices is a [100, 2] tensor drawn
# with paddle.randint(-128, 128, ...), so negative indices can occur.
# No input_spec is passed for this case.
def test_gather_nd_fusion(self):
114+
x_pd = paddle.randn([256, 128])
115+
y_pd = paddle.randn_like(x_pd)
116+
z_pd = paddle.randn([100])
117+
indices_pd = paddle.randint(-128, 128, [100, 2])
118+
119+
# Function under test: the gather is applied to the product x * y.
def func(x, y, z, indices):
120+
return paddle.gather_nd(x * y, indices) + z
121+
122+
self.eval(func, [x_pd, y_pd, z_pd, indices_pd])
123+
124+
125+
# Run the unittest test runner when this file is executed as a script.
if __name__ == "__main__":
126+
unittest.main()

0 commit comments

Comments
 (0)