1616#include " paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
1717#include " paddle/pir/core/builtin_attribute.h"
1818#include " paddle/pir/core/builtin_type.h"
19+ #include " paddle/pir/core/builtin_type_interfaces.h"
1920#include " paddle/pir/dialect/shape/ir/shape_attribute.h"
2021
2122namespace paddle ::dialect {
@@ -33,27 +34,25 @@ bool SameOperandsAndResultShape(
3334 pir::Value operand_source = op->operand_source (0 );
3435
3536 symbol::ShapeOrDataDimExprs operand_shape_or_data =
36- shape_analysis->value_to_shape_or_data_ [ operand_source] ;
37+ shape_analysis->GetShapeOrDataForValue ( operand_source) ;
3738
3839 op->set_attribute (" symbolic_shape" ,
3940 pir::shape::SymbolAttribute::get (pir::IrContext::Instance (),
4041 operand_shape_or_data));
4142 pir::OpResult res = op->result (0 );
42- shape_analysis->value_to_shape_or_data_ [ res] = operand_shape_or_data;
43+ shape_analysis->SetShapeOrDataForValue ( res, operand_shape_or_data) ;
4344 return true ;
4445}
4546
4647bool InferSymbolicShapeElementWiseBinary (
4748 pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
4849 pir::Value operand_source_0 = op->operand_source (0 );
49- std::string operand_source_0_id = pir::GetValueId (&operand_source_0);
5050 std::vector<symbol::DimExpr> shape_0{
51- shape_analysis->value_id_to_shapeordata_ [operand_source_0_id] .shape ()};
51+ shape_analysis->GetShapeOrDataForValue (operand_source_0) .shape ()};
5252
5353 pir::Value operand_source_1 = op->operand_source (1 );
54- std::string operand_source_1_id = pir::GetValueId (&operand_source_1);
5554 std::vector<symbol::DimExpr> shape_1{
56- shape_analysis->value_id_to_shapeordata_ [operand_source_1_id] .shape ()};
55+ shape_analysis->GetShapeOrDataForValue (operand_source_1) .shape ()};
5756
5857 if (shape_0.size () > shape_1.size ()) {
5958 for (size_t i = 0 ; i < shape_0.size () - shape_1.size (); i++) {
@@ -75,9 +74,11 @@ bool InferSymbolicShapeElementWiseBinary(
7574 std::vector<symbol::DimExpr> data;
7675
7776 pir::OpResult res = op->result (0 );
78- std::string res_id = pir::GetValueId (&res);
7977 symbol::ShapeOrDataDimExprs shape_data{shapes, data};
80- shape_analysis->value_id_to_shapeordata_ [res_id] = shape_data;
78+ shape_analysis->SetShapeOrDataForValue (res, shape_data);
79+ op->set_attribute (
80+ " symbolic_shape" ,
81+ pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
8182 return true ;
8283}
8384
@@ -104,7 +105,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
104105 std::vector<symbol::DimExpr> sym_dims;
105106 for (auto dim : dims) {
106107 symbol::DimExpr dim_expr;
107- if (dim == - 1 ) {
108+ if (dim == pir::ShapedTypeInterface:: kDynamic ) {
108109 symbol::DimExpr symbolic_dim_expr (shape_analysis->GetNextSymName ());
109110 dim_expr = symbolic_dim_expr;
110111 } else {
@@ -120,7 +121,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
120121 pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
121122
122123 pir::OpResult res = op->result (0 );
123- shape_analysis->value_to_shape_or_data_ [ res] = shape_data;
124+ shape_analysis->SetShapeOrDataForValue ( res, shape_data) ;
124125
125126 return true ;
126127}
@@ -171,13 +172,13 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
171172 pir::OpResult res = op->result (0 );
172173
173174 symbol::ShapeOrDataDimExprs operand_shape_or_data =
174- shape_analysis->value_to_shape_or_data_ [ operand_source] ;
175+ shape_analysis->GetShapeOrDataForValue ( operand_source) ;
175176
176177 symbol::ShapeOrDataDimExprs extend_shape_or_data =
177178 symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData (
178179 operand_shape_or_data);
179180
180- shape_analysis->value_to_shape_or_data_ [ res] = extend_shape_or_data;
181+ shape_analysis->SetShapeOrDataForValue ( res, extend_shape_or_data) ;
181182 op->set_attribute (" symbolic_shape" ,
182183 pir::shape::SymbolAttribute::get (pir::IrContext::Instance (),
183184 extend_shape_or_data));
@@ -193,7 +194,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
193194 pir::ShapeConstraintIRAnalysis *shape_analysis) {
194195 pir::Value operand_source = op->operand_source (0 );
195196 symbol::ShapeOrDataDimExprs operand_shape_or_data =
196- shape_analysis->value_to_shape_or_data_ [ operand_source] ;
197+ shape_analysis->GetShapeOrDataForValue ( operand_source) ;
197198
198199 std::vector<symbol::DimExpr> out_dims;
199200 if (operand_shape_or_data.data ().has_value ()) {
@@ -213,7 +214,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
213214 " symbolic_shape" ,
214215 pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
215216 pir::OpResult res = op->result (0 );
216- shape_analysis->value_to_shape_or_data_ [ res] = shape_data;
217+ shape_analysis->SetShapeOrDataForValue ( res, shape_data) ;
217218 return true ;
218219}
219220
@@ -222,7 +223,7 @@ bool ReshapeOpInferSymbolicShape(
222223 pir::Value operand_source_shape = op->operand_source (1 );
223224
224225 symbol::ShapeOrDataDimExprs operand_shape_or_data =
225- shape_analysis->value_to_shape_or_data_ [ operand_source_shape] ;
226+ shape_analysis->GetShapeOrDataForValue ( operand_source_shape) ;
226227
227228 std::vector<symbol::DimExpr> out_dims;
228229 if (operand_shape_or_data.data ().has_value ()) {
@@ -236,9 +237,9 @@ bool ReshapeOpInferSymbolicShape(
236237
237238 pir::OpResult res0 = op->result (0 );
238239 pir::OpResult res1 = op->result (1 );
239- shape_analysis->value_to_shape_or_data_ [ res0] = shape_data;
240- shape_analysis->value_to_shape_or_data_ [res1] =
241- shape_analysis->value_to_shape_or_data_ [ operand_source_shape] ;
240+ shape_analysis->SetShapeOrDataForValue ( res0, shape_data) ;
241+ shape_analysis->SetShapeOrDataForValue (
242+ res1, shape_analysis->GetShapeOrDataForValue ( operand_source_shape)) ;
242243 return true ;
243244}
244245
@@ -267,7 +268,7 @@ bool FullIntArrayOpInferSymbolicShape(
267268 pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
268269
269270 pir::OpResult res = op->result (0 );
270- shape_analysis->value_to_shape_or_data_ [ res] = shape_data;
271+ shape_analysis->SetShapeOrDataForValue ( res, shape_data) ;
271272 return true ;
272273}
273274
@@ -286,7 +287,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
286287 // dialect.
287288 pir::Value operand_source = op->operand_source (0 );
288289 symbol::ShapeOrDataDimExprs operand_shape_or_data =
289- shape_analysis->value_to_shape_or_data_ [ operand_source] ;
290+ shape_analysis->GetShapeOrDataForValue ( operand_source) ;
290291 pir::AttributeMap attributes = op->attributes ();
291292
292293 std::vector<pir::Attribute> attr_starts =
@@ -309,7 +310,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
309310 pir::shape::SymbolAttribute::get (pir::IrContext::Instance (), shape_data));
310311
311312 pir::OpResult res = op->result (0 );
312- shape_analysis->value_to_shape_or_data_ [ res] = shape_data;
313+ shape_analysis->SetShapeOrDataForValue ( res, shape_data) ;
313314 return true ;
314315}
315316
0 commit comments