@@ -82,22 +82,32 @@ bool AddNOpInferSymbolicShape(pir::Operation *op,
8282 " should be larger than 0. But received X's dimensions %d." ,
8383 inputs_shape.size ()));
8484 symbol::TensorShapeOrDataDimExprs candidate_shape = inputs_shape.front ();
85+ std::vector<symbol::DimExpr> candidate_shape_vec = candidate_shape.shape ();
8586 for (size_t i = 1 ; i < inputs_shape.size (); ++i) {
8687 // 0D tensor
8788 if (inputs_shape[i].shape ().size () == 0 ) {
8889 continue ;
8990 }
9091 if (candidate_shape.shape ().size () == 0 ) {
9192 candidate_shape = inputs_shape[i];
93+ candidate_shape_vec = candidate_shape.shape ();
9294 continue ;
9395 }
94- for (size_t j = 0 ; j < candidate_shape.shape ().size (); ++j) {
95- infer_context->AddEqualCstr (candidate_shape.shape ()[j],
96- inputs_shape[i].shape ()[j]);
96+ for (size_t j = 0 ; j < candidate_shape_vec.size (); ++j) {
97+ if (candidate_shape_vec[j] != 0 ) {
98+ if (inputs_shape[i].shape ()[j] != 0 ) {
99+ infer_context->AddEqualCstr (candidate_shape_vec[j],
100+ inputs_shape[i].shape ()[j]);
101+ } else {
102+ candidate_shape_vec[j] = symbol::DimExpr{0 };
103+ }
104+ }
97105 }
98106 }
99107 infer_context->SetShapeOrDataForValue (
100- op->result (0 ), symbol::ShapeOrDataDimExprs{candidate_shape});
108+ op->result (0 ),
109+ symbol::ShapeOrDataDimExprs{
110+ symbol::TensorShapeOrDataDimExprs (candidate_shape_vec)});
101111
102112 return true ;
103113}
0 commit comments