Skip to content

Commit 9d0aa9b

Browse files
committed
[CINN] Fixed gather_nd incorrect logic for negative inputs.
1 parent ec6f870 commit 9d0aa9b

File tree

1 file changed

+6
-106
lines changed

1 file changed

+6
-106
lines changed

paddle/cinn/hlir/op/contrib/gather_nd.cc

Lines changed: 6 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -42,48 +42,6 @@ namespace op {
4242
using cinn::common::CINNValue;
4343
using cinn::common::CINNValuePack;
4444

45-
ir::Tensor GatherNd(const ir::Tensor &x,
46-
const ir::Tensor &index,
47-
const std::string &name) {
48-
std::vector<Expr> x_shape = x->shape;
49-
std::vector<Expr> index_shape = index->shape;
50-
size_t x_shape_size = x_shape.size();
51-
size_t index_shape_size = index_shape.size();
52-
std::vector<Expr> out_shape;
53-
out_shape.insert(out_shape.end(), index_shape.begin(), index_shape.end() - 1);
54-
out_shape.insert(out_shape.end(),
55-
x_shape.begin() + index_shape.back().as_int32(),
56-
x_shape.end());
57-
auto res = Compute(
58-
out_shape,
59-
[=](const std::vector<Expr> &indices) {
60-
std::vector<Expr> indices_position;
61-
for (size_t i = 0; i < index_shape_size - 1; ++i) {
62-
indices_position.push_back(
63-
ir::Cast::Make(cinn::common::Int(32), indices[i]));
64-
}
65-
indices_position.push_back(
66-
ir::Cast::Make(cinn::common::Int(32), Expr(0)));
67-
size_t indices_position_size = indices_position.size();
68-
std::vector<Expr> real_indices;
69-
for (size_t i = 0; i < index_shape.back().as_int32(); ++i) {
70-
indices_position[indices_position_size - 1] =
71-
ir::Cast::Make(cinn::common::Int(32), Expr(i));
72-
real_indices.push_back(
73-
ir::Cast::Make(cinn::common::Int(32), index(indices_position)));
74-
}
75-
if (real_indices.size() == x_shape_size) {
76-
return x(real_indices);
77-
}
78-
for (size_t i = index_shape_size - 1; i < indices.size(); ++i) {
79-
real_indices.push_back(indices[i]);
80-
}
81-
return x(real_indices);
82-
},
83-
name);
84-
return res;
85-
}
86-
8745
ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
8846
const ir::Tensor &index,
8947
const std::string &name) {
@@ -111,8 +69,13 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
11169
for (size_t i = 0; i < index_shape.back().as_int64(); ++i) {
11270
indices_position[indices_position_size - 1] =
11371
ir::Cast::Make(cinn::common::Int(64), Expr(i));
72+
// support negative indices
73+
auto idx_expr =
74+
ir::Cast::Make(cinn::common::Int(64), index(indices_position));
75+
auto real_expr = ir::Select::Make(
76+
ir::GE::Make(idx_expr, Expr(0)), idx_expr, x_shape[i] + idx_expr);
11477
real_indices.push_back(
115-
ir::Cast::Make(cinn::common::Int(64), index(indices_position)));
78+
ir::Cast::Make(cinn::common::Int(64), real_expr));
11679
}
11780
if (real_indices.size() == x_shape_size) {
11881
return x(real_indices);
@@ -126,67 +89,6 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
12689
return res;
12790
}
12891

129-
std::shared_ptr<framework::OpStrategy> StrategyForGatherNd(
130-
const framework::NodeAttr &attrs,
131-
const std::vector<ir::Tensor> &inputs,
132-
const std::vector<Type> &out_type,
133-
const std::vector<std::vector<int>> &output_shapes,
134-
const Target &target) {
135-
std::string op_name("gather_nd");
136-
137-
framework::CINNCompute gather_nd_compute([=](lang::Args args,
138-
lang::RetValue *ret) {
139-
PADDLE_ENFORCE_NE(
140-
args.empty(),
141-
true,
142-
::common::errors::InvalidArgument(
143-
"The input argument of %s compute is empty! Please check.",
144-
op_name));
145-
CINNValuePack pack_args = args[0];
146-
PADDLE_ENFORCE_GE(
147-
pack_args.size(),
148-
2U,
149-
::common::errors::InvalidArgument("2 input tensors for compute\n"));
150-
Expr x = pack_args[0];
151-
Expr index = pack_args[1];
152-
PADDLE_ENFORCE_NOT_NULL(x.as_tensor(),
153-
::common::errors::InvalidArgument(
154-
"Required x must be a tensor. Please check."));
155-
PADDLE_ENFORCE_NOT_NULL(
156-
index.as_tensor(),
157-
::common::errors::InvalidArgument(
158-
"Required index must be a tensor. Please check."));
159-
PADDLE_ENFORCE_NE(
160-
output_shapes.empty(),
161-
true,
162-
::common::errors::InvalidArgument(
163-
"The output shape of gather_nd is empty! Please check."));
164-
auto tensor_x = x.as_tensor_ref();
165-
auto tensor_index = index.as_tensor_ref();
166-
VLOG(3) << "x shape: " << utils::Join(tensor_x->shape, ", ")
167-
<< ", index shape: " << utils::Join(tensor_index->shape, ", ")
168-
<< ", output_shapes: " << utils::Join(output_shapes[0], ", ");
169-
PADDLE_ENFORCE_EQ(pack_args.size(),
170-
3U,
171-
::common::errors::InvalidArgument(
172-
"The size of pack_args should be 3\n"));
173-
std::string tensor_name = pack_args[2].operator std::string();
174-
ir::Tensor out = GatherNd(tensor_x, tensor_index, tensor_name);
175-
std::vector<CINNValue> res;
176-
res.push_back(CINNValue(out));
177-
PADDLE_ENFORCE_NE(
178-
out_type.empty(),
179-
true,
180-
::common::errors::InvalidArgument(
181-
"The output type of gather_nd is empty! Please check."));
182-
*ret = CINNValuePack{res};
183-
});
184-
185-
auto strategy = std::make_shared<framework::OpStrategy>();
186-
strategy->AddImpl(gather_nd_compute, "strategy.gather_nd.x86", 1);
187-
return strategy;
188-
}
189-
19092
std::shared_ptr<framework::OpStrategy> StrategyForGatherNdSymbolic(
19193
const framework::NodeAttr &attrs,
19294
const std::vector<ir::Tensor> &inputs,
@@ -259,8 +161,6 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
259161
.set_num_outputs(1)
260162
.set_attr<cinn::hlir::framework::StrategyFunctionSymbolic>(
261163
"CINNStrategySymbolic", cinn::hlir::op::StrategyForGatherNdSymbolic)
262-
.set_attr<cinn::hlir::framework::StrategyFunction>(
263-
"CINNStrategy", cinn::hlir::op::StrategyForGatherNd)
264164
.set_attr<cinn::hlir::framework::OpPatternKind>(
265165
"OpPattern", cinn::hlir::framework::OpPatternKind::kInjective)
266166
.set_support_level(4);

0 commit comments

Comments
 (0)