@@ -42,48 +42,6 @@ namespace op {
4242using  cinn::common::CINNValue;
4343using  cinn::common::CINNValuePack;
4444
45- ir::Tensor GatherNd (const  ir::Tensor &x,
46-  const  ir::Tensor &index,
47-  const  std::string &name) {
48-  std::vector<Expr> x_shape = x->shape ;
49-  std::vector<Expr> index_shape = index->shape ;
50-  size_t  x_shape_size = x_shape.size ();
51-  size_t  index_shape_size = index_shape.size ();
52-  std::vector<Expr> out_shape;
53-  out_shape.insert (out_shape.end (), index_shape.begin (), index_shape.end () - 1 );
54-  out_shape.insert (out_shape.end (),
55-  x_shape.begin () + index_shape.back ().as_int32 (),
56-  x_shape.end ());
57-  auto  res = Compute (
58-  out_shape,
59-  [=](const  std::vector<Expr> &indices) {
60-  std::vector<Expr> indices_position;
61-  for  (size_t  i = 0 ; i < index_shape_size - 1 ; ++i) {
62-  indices_position.push_back (
63-  ir::Cast::Make (cinn::common::Int (32 ), indices[i]));
64-  }
65-  indices_position.push_back (
66-  ir::Cast::Make (cinn::common::Int (32 ), Expr (0 )));
67-  size_t  indices_position_size = indices_position.size ();
68-  std::vector<Expr> real_indices;
69-  for  (size_t  i = 0 ; i < index_shape.back ().as_int32 (); ++i) {
70-  indices_position[indices_position_size - 1 ] =
71-  ir::Cast::Make (cinn::common::Int (32 ), Expr (i));
72-  real_indices.push_back (
73-  ir::Cast::Make (cinn::common::Int (32 ), index (indices_position)));
74-  }
75-  if  (real_indices.size () == x_shape_size) {
76-  return  x (real_indices);
77-  }
78-  for  (size_t  i = index_shape_size - 1 ; i < indices.size (); ++i) {
79-  real_indices.push_back (indices[i]);
80-  }
81-  return  x (real_indices);
82-  },
83-  name);
84-  return  res;
85- }
86- 
8745ir::Tensor GatherNdSymbolic (const  ir::Tensor &x,
8846 const  ir::Tensor &index,
8947 const  std::string &name) {
@@ -111,8 +69,13 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
11169 for  (size_t  i = 0 ; i < index_shape.back ().as_int64 (); ++i) {
11270 indices_position[indices_position_size - 1 ] =
11371 ir::Cast::Make (cinn::common::Int (64 ), Expr (i));
72+  //  support negative indices
73+  auto  idx_expr =
74+  ir::Cast::Make (cinn::common::Int (64 ), index (indices_position));
75+  auto  real_expr = ir::Select::Make (
76+  ir::GE::Make (idx_expr, Expr (0 )), idx_expr, x_shape[i] + idx_expr);
11477 real_indices.push_back (
115-  ir::Cast::Make (cinn::common::Int (64 ), index (indices_position) ));
78+  ir::Cast::Make (cinn::common::Int (64 ), real_expr ));
11679 }
11780 if  (real_indices.size () == x_shape_size) {
11881 return  x (real_indices);
@@ -126,67 +89,6 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
12689 return  res;
12790}
12891
129- std::shared_ptr<framework::OpStrategy> StrategyForGatherNd (
130-  const  framework::NodeAttr &attrs,
131-  const  std::vector<ir::Tensor> &inputs,
132-  const  std::vector<Type> &out_type,
133-  const  std::vector<std::vector<int >> &output_shapes,
134-  const  Target &target) {
135-  std::string op_name (" gather_nd" 
136- 
137-  framework::CINNCompute gather_nd_compute ([=](lang::Args args,
138-  lang::RetValue *ret) {
139-  PADDLE_ENFORCE_NE (
140-  args.empty (),
141-  true ,
142-  ::common::errors::InvalidArgument (
143-  " The input argument of %s compute is empty! Please check." 
144-  op_name));
145-  CINNValuePack pack_args = args[0 ];
146-  PADDLE_ENFORCE_GE (
147-  pack_args.size (),
148-  2U ,
149-  ::common::errors::InvalidArgument (" 2 input tensors for compute\n " 
150-  Expr x = pack_args[0 ];
151-  Expr index = pack_args[1 ];
152-  PADDLE_ENFORCE_NOT_NULL (x.as_tensor (),
153-  ::common::errors::InvalidArgument (
154-  " Required x must be a tensor. Please check." 
155-  PADDLE_ENFORCE_NOT_NULL (
156-  index.as_tensor (),
157-  ::common::errors::InvalidArgument (
158-  " Required index must be a tensor. Please check." 
159-  PADDLE_ENFORCE_NE (
160-  output_shapes.empty (),
161-  true ,
162-  ::common::errors::InvalidArgument (
163-  " The output shape of gather_nd is empty! Please check." 
164-  auto  tensor_x = x.as_tensor_ref ();
165-  auto  tensor_index = index.as_tensor_ref ();
166-  VLOG (3 ) << " x shape: " utils::Join (tensor_x->shape , " , " 
167-  << " , index shape: " utils::Join (tensor_index->shape , " , " 
168-  << " , output_shapes: " utils::Join (output_shapes[0 ], " , " 
169-  PADDLE_ENFORCE_EQ (pack_args.size (),
170-  3U ,
171-  ::common::errors::InvalidArgument (
172-  " The size of pack_args should be 3\n " 
173-  std::string tensor_name = pack_args[2 ].operator  std::string ();
174-  ir::Tensor out = GatherNd (tensor_x, tensor_index, tensor_name);
175-  std::vector<CINNValue> res;
176-  res.push_back (CINNValue (out));
177-  PADDLE_ENFORCE_NE (
178-  out_type.empty (),
179-  true ,
180-  ::common::errors::InvalidArgument (
181-  " The output type of gather_nd is empty! Please check." 
182-  *ret = CINNValuePack{res};
183-  });
184- 
185-  auto  strategy = std::make_shared<framework::OpStrategy>();
186-  strategy->AddImpl (gather_nd_compute, " strategy.gather_nd.x86" 1 );
187-  return  strategy;
188- }
189- 
19092std::shared_ptr<framework::OpStrategy> StrategyForGatherNdSymbolic (
19193 const  framework::NodeAttr &attrs,
19294 const  std::vector<ir::Tensor> &inputs,
@@ -259,8 +161,6 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
259161 .set_num_outputs (1 )
260162 .set_attr <cinn::hlir::framework::StrategyFunctionSymbolic>(
261163 " CINNStrategySymbolic" 
262-  .set_attr <cinn::hlir::framework::StrategyFunction>(
263-  " CINNStrategy" 
264164 .set_attr <cinn::hlir::framework::OpPatternKind>(
265165 " OpPattern" kInjective )
266166 .set_support_level (4 );
0 commit comments