@@ -42,46 +42,21 @@ namespace op {
4242using cinn::common::CINNValue;
4343using cinn::common::CINNValuePack;
4444
// NOTE: every line of this function carries a '-' marker, i.e. this diff
// deletes it. Its closing brace is the unchanged '}' context line that
// follows the hunk's '+' lines.
//
// GatherNd builds the gather_nd output tensor from `x` and `index`.
//   x     - tensor the values are read from.
//   index - tensor whose last dimension holds coordinates into `x`.
//   name  - forwarded to Compute as its last argument.
// Returns the tensor produced by Compute.
//
// The output shape is index->shape without its last dimension, followed by
// x->shape[k:], where k = index_shape.back().as_int32().
// Values read from `index` are used as coordinates as-is; this version does
// no wrap-around for negative values.
45- ir::Tensor GatherNd (const ir::Tensor &x,
46- const ir::Tensor &index,
47- const std::string &name) {
48- std::vector<Expr> x_shape = x->shape ;
49- std::vector<Expr> index_shape = index->shape ;
50- size_t x_shape_size = x_shape.size ();
51- size_t index_shape_size = index_shape.size ();
52- std::vector<Expr> out_shape;
// Leading output dims: all dims of `index` except the last one.
53- out_shape.insert (out_shape.end (), index_shape.begin (), index_shape.end () - 1 );
// Trailing output dims: the dims of `x` not addressed by an index row.
54- out_shape.insert (out_shape.end (),
55- x_shape.begin () + index_shape.back ().as_int32 (),
56- x_shape.end ());
57- auto res = Compute (
58- out_shape,
59- [=](const std::vector<Expr> &indices) {
// Position inside `index`: the first (index rank - 1) output coordinates,
// plus one trailing slot that is overwritten in the loop below.
60- std::vector<Expr> indices_position;
61- for (size_t i = 0 ; i < index_shape_size - 1 ; ++i) {
62- indices_position.push_back (
63- ir::Cast::Make (cinn::common::Int (32 ), indices[i]));
64- }
65- indices_position.push_back (
66- ir::Cast::Make (cinn::common::Int (32 ), Expr (0 )));
67- size_t indices_position_size = indices_position.size ();
// Read each entry along the last dimension of `index`, cast it to int32 and
// use it as the next leading coordinate into `x`.
68- std::vector<Expr> real_indices;
69- for (size_t i = 0 ; i < index_shape.back ().as_int32 (); ++i) {
70- indices_position[indices_position_size - 1 ] =
71- ir::Cast::Make (cinn::common::Int (32 ), Expr (i));
72- real_indices.push_back (
73- ir::Cast::Make (cinn::common::Int (32 ), index (indices_position)));
74- }
// The gathered coordinates already address every dimension of `x`.
75- if (real_indices.size () == x_shape_size) {
76- return x (real_indices);
77- }
// Otherwise the remaining output coordinates address the rest of `x`.
78- for (size_t i = index_shape_size - 1 ; i < indices.size (); ++i) {
79- real_indices.push_back (indices[i]);
80- }
81- return x (real_indices);
82- },
83- name);
84- return res;
// CastShapeElementType: prepares one shape-dimension expression for use in
// int64 index arithmetic.
//   shape_element - the dimension expression (taken by value).
// Returns:
//   - the expression unchanged when is_constant() is true;
//   - the same Max/Min expression after both operands have been processed
//     recursively and assigned back through a()/b();
//   - otherwise an intrinsic Call typed ir::Int(64) that takes the original
//     expression as its only argument.
// NOTE(review): operands are assigned through the pointer returned by As<>(),
// so this presumably rewrites the Max/Min node that the caller's Expr also
// refers to - confirm that this in-place update is intended.
45+ Expr CastShapeElementType (Expr shape_element) {
46+ // constant values should already have been cast to int64_t for shape calc
47+ // make sure the operands of min/max in a shape expr do not mix int types
48+ if (shape_element.is_constant ()) return shape_element;
49+ if (auto max_elem = shape_element.As <ir::Max>()) {
50+ max_elem->a () = CastShapeElementType (max_elem->a ());
51+ max_elem->b () = CastShapeElementType (max_elem->b ());
52+ } else if (auto min_elem = shape_element.As <ir::Min>()) {
53+ min_elem->a () = CastShapeElementType (min_elem->a ());
54+ min_elem->b () = CastShapeElementType (min_elem->b ());
55+ } else {
// NOTE(review): the callee-name literal below shows a leading space; this is
// likely an artifact of how the diff was rendered - verify against the file.
56+ shape_element = ir::Call::Make (
57+ ir::Int (64 ), " int64_t" , {shape_element}, {}, ir::CallType::Intrinsic);
58+ }
59+ return shape_element;
8560}
8661
8762ir::Tensor GatherNdSymbolic (const ir::Tensor &x,
@@ -111,8 +86,15 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
11186 for (size_t i = 0 ; i < index_shape.back ().as_int64 (); ++i) {
11287 indices_position[indices_position_size - 1 ] =
11388 ir::Cast::Make (cinn::common::Int (64 ), Expr (i));
89+ // support negative indices
90+ auto idx_expr =
91+ ir::Cast::Make (cinn::common::Int (64 ), index (indices_position));
92+ auto real_expr =
93+ ir::Select::Make (ir::GE::Make (idx_expr, Expr (0 )),
94+ idx_expr,
95+ CastShapeElementType (x_shape[i]) + idx_expr);
11496 real_indices.push_back (
115- ir::Cast::Make (cinn::common::Int (64 ), index (indices_position) ));
97+ ir::Cast::Make (cinn::common::Int (64 ), real_expr ));
11698 }
11799 if (real_indices.size () == x_shape_size) {
118100 return x (real_indices);
@@ -126,67 +108,6 @@ ir::Tensor GatherNdSymbolic(const ir::Tensor &x,
126108 return res;
127109}
128110
// NOTE: every line of this function carries a '-' marker, i.e. this diff
// deletes it.
//
// StrategyForGatherNd builds an OpStrategy holding a single compute
// implementation that wraps GatherNd.
//   out_type      - only checked for being non-empty.
//   output_shapes - checked for being non-empty; output_shapes[0] is logged.
//   attrs, inputs, target - not referenced anywhere in the body.
// Returns the strategy with the compute registered via AddImpl.
129- std::shared_ptr<framework::OpStrategy> StrategyForGatherNd (
130- const framework::NodeAttr &attrs,
131- const std::vector<ir::Tensor> &inputs,
132- const std::vector<Type> &out_type,
133- const std::vector<std::vector<int >> &output_shapes,
134- const Target &target) {
135- std::string op_name (" gather_nd" );
136-
// The compute callback: validates the packed arguments, calls GatherNd and
// writes the resulting tensor to *ret.
137- framework::CINNCompute gather_nd_compute ([=](lang::Args args,
138- lang::RetValue *ret) {
139- PADDLE_ENFORCE_NE (
140- args.empty (),
141- true ,
142- ::common::errors::InvalidArgument (
143- " The input argument of %s compute is empty! Please check." ,
144- op_name));
// args[0] is the pack; its entries 0 and 1 are read as x and index.
145- CINNValuePack pack_args = args[0 ];
146- PADDLE_ENFORCE_GE (
147- pack_args.size (),
148- 2U ,
149- ::common::errors::InvalidArgument (" 2 input tensors for compute\n " ));
150- Expr x = pack_args[0 ];
151- Expr index = pack_args[1 ];
152- PADDLE_ENFORCE_NOT_NULL (x.as_tensor (),
153- ::common::errors::InvalidArgument (
154- " Required x must be a tensor. Please check." ));
155- PADDLE_ENFORCE_NOT_NULL (
156- index.as_tensor (),
157- ::common::errors::InvalidArgument (
158- " Required index must be a tensor. Please check." ));
159- PADDLE_ENFORCE_NE (
160- output_shapes.empty (),
161- true ,
162- ::common::errors::InvalidArgument (
163- " The output shape of gather_nd is empty! Please check." ));
164- auto tensor_x = x.as_tensor_ref ();
165- auto tensor_index = index.as_tensor_ref ();
166- VLOG (3 ) << " x shape: " << utils::Join (tensor_x->shape , " , " )
167- << " , index shape: " << utils::Join (tensor_index->shape , " , " )
168- << " , output_shapes: " << utils::Join (output_shapes[0 ], " , " );
// Together with the >= 2 check above, the pack must hold exactly 3 entries;
// entry 2 is converted to the output tensor name.
169- PADDLE_ENFORCE_EQ (pack_args.size (),
170- 3U ,
171- ::common::errors::InvalidArgument (
172- " The size of pack_args should be 3\n " ));
173- std::string tensor_name = pack_args[2 ].operator std::string ();
174- ir::Tensor out = GatherNd (tensor_x, tensor_index, tensor_name);
175- std::vector<CINNValue> res;
176- res.push_back (CINNValue (out));
177- PADDLE_ENFORCE_NE (
178- out_type.empty (),
179- true ,
180- ::common::errors::InvalidArgument (
181- " The output type of gather_nd is empty! Please check." ));
182- *ret = CINNValuePack{res};
183- });
184-
// Register the compute as the strategy's only implementation.
185- auto strategy = std::make_shared<framework::OpStrategy>();
186- strategy->AddImpl (gather_nd_compute, " strategy.gather_nd.x86" , 1 );
187- return strategy;
188- }
189-
190111std::shared_ptr<framework::OpStrategy> StrategyForGatherNdSymbolic (
191112 const framework::NodeAttr &attrs,
192113 const std::vector<ir::Tensor> &inputs,
@@ -259,8 +180,6 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
259180 .set_num_outputs (1 )
260181 .set_attr <cinn::hlir::framework::StrategyFunctionSymbolic>(
261182 " CINNStrategySymbolic" , cinn::hlir::op::StrategyForGatherNdSymbolic)
262- .set_attr <cinn::hlir::framework::StrategyFunction>(
263- " CINNStrategy" , cinn::hlir::op::StrategyForGatherNd)
264183 .set_attr <cinn::hlir::framework::OpPatternKind>(
265184 " OpPattern" , cinn::hlir::framework::OpPatternKind::kInjective )
266185 .set_support_level (4 );
0 commit comments