Skip to content

Commit a576356

Browse files
authored
[PIR][DynamicShape] Remove redundant code for shapeAnalysis and shapedTypeInterface (#60744)
att, remove redundant code for shapeAnalysis and shapedTypeInterface
1 parent bcd5e37 commit a576356

File tree

12 files changed

+69
-100
lines changed

12 files changed

+69
-100
lines changed

paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,8 @@ bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) {
139139
pir::ShapeConstraintIRAnalysis& shape_analysis =
140140
pir::ShapeAnalysisManager::Instance().Get(
141141
op.x().defining_op()->GetParentProgram());
142-
CHECK(shape_analysis.value_id_to_shapeordata_.find(GetValueId(&value)) !=
143-
shape_analysis.value_id_to_shapeordata_.end());
144-
return shape_analysis.value_id_to_shapeordata_.at(GetValueId(&value));
142+
143+
return shape_analysis.GetShapeOrDataForValue(value);
145144
};
146145
std::optional<pir::Value> opt_generated_shape =
147146
GetOutOfRewritedGenerateShapeOp(

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape.cc

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
1717
#include "paddle/pir/core/builtin_attribute.h"
1818
#include "paddle/pir/core/builtin_type.h"
19+
#include "paddle/pir/core/builtin_type_interfaces.h"
1920
#include "paddle/pir/dialect/shape/ir/shape_attribute.h"
2021

2122
namespace paddle::dialect {
@@ -33,27 +34,25 @@ bool SameOperandsAndResultShape(
3334
pir::Value operand_source = op->operand_source(0);
3435

3536
symbol::ShapeOrDataDimExprs operand_shape_or_data =
36-
shape_analysis->value_to_shape_or_data_[operand_source];
37+
shape_analysis->GetShapeOrDataForValue(operand_source);
3738

3839
op->set_attribute("symbolic_shape",
3940
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
4041
operand_shape_or_data));
4142
pir::OpResult res = op->result(0);
42-
shape_analysis->value_to_shape_or_data_[res] = operand_shape_or_data;
43+
shape_analysis->SetShapeOrDataForValue(res, operand_shape_or_data);
4344
return true;
4445
}
4546

4647
bool InferSymbolicShapeElementWiseBinary(
4748
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
4849
pir::Value operand_source_0 = op->operand_source(0);
49-
std::string operand_source_0_id = pir::GetValueId(&operand_source_0);
5050
std::vector<symbol::DimExpr> shape_0{
51-
shape_analysis->value_id_to_shapeordata_[operand_source_0_id].shape()};
51+
shape_analysis->GetShapeOrDataForValue(operand_source_0).shape()};
5252

5353
pir::Value operand_source_1 = op->operand_source(1);
54-
std::string operand_source_1_id = pir::GetValueId(&operand_source_1);
5554
std::vector<symbol::DimExpr> shape_1{
56-
shape_analysis->value_id_to_shapeordata_[operand_source_1_id].shape()};
55+
shape_analysis->GetShapeOrDataForValue(operand_source_1).shape()};
5756

5857
if (shape_0.size() > shape_1.size()) {
5958
for (size_t i = 0; i < shape_0.size() - shape_1.size(); i++) {
@@ -75,9 +74,11 @@ bool InferSymbolicShapeElementWiseBinary(
7574
std::vector<symbol::DimExpr> data;
7675

7776
pir::OpResult res = op->result(0);
78-
std::string res_id = pir::GetValueId(&res);
7977
symbol::ShapeOrDataDimExprs shape_data{shapes, data};
80-
shape_analysis->value_id_to_shapeordata_[res_id] = shape_data;
78+
shape_analysis->SetShapeOrDataForValue(res, shape_data);
79+
op->set_attribute(
80+
"symbolic_shape",
81+
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
8182
return true;
8283
}
8384

@@ -104,7 +105,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
104105
std::vector<symbol::DimExpr> sym_dims;
105106
for (auto dim : dims) {
106107
symbol::DimExpr dim_expr;
107-
if (dim == -1) {
108+
if (dim == pir::ShapedTypeInterface::kDynamic) {
108109
symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName());
109110
dim_expr = symbolic_dim_expr;
110111
} else {
@@ -120,7 +121,7 @@ bool DataOpInferSymbolicShape(pir::Operation *op,
120121
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
121122

122123
pir::OpResult res = op->result(0);
123-
shape_analysis->value_to_shape_or_data_[res] = shape_data;
124+
shape_analysis->SetShapeOrDataForValue(res, shape_data);
124125

125126
return true;
126127
}
@@ -171,13 +172,13 @@ bool ShapeOpInferSymbolicShape(pir::Operation *op,
171172
pir::OpResult res = op->result(0);
172173

173174
symbol::ShapeOrDataDimExprs operand_shape_or_data =
174-
shape_analysis->value_to_shape_or_data_[operand_source];
175+
shape_analysis->GetShapeOrDataForValue(operand_source);
175176

176177
symbol::ShapeOrDataDimExprs extend_shape_or_data =
177178
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(
178179
operand_shape_or_data);
179180

180-
shape_analysis->value_to_shape_or_data_[res] = extend_shape_or_data;
181+
shape_analysis->SetShapeOrDataForValue(res, extend_shape_or_data);
181182
op->set_attribute("symbolic_shape",
182183
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(),
183184
extend_shape_or_data));
@@ -193,7 +194,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
193194
pir::ShapeConstraintIRAnalysis *shape_analysis) {
194195
pir::Value operand_source = op->operand_source(0);
195196
symbol::ShapeOrDataDimExprs operand_shape_or_data =
196-
shape_analysis->value_to_shape_or_data_[operand_source];
197+
shape_analysis->GetShapeOrDataForValue(operand_source);
197198

198199
std::vector<symbol::DimExpr> out_dims;
199200
if (operand_shape_or_data.data().has_value()) {
@@ -213,7 +214,7 @@ bool StackOpInferSymbolicShape(pir::Operation *op,
213214
"symbolic_shape",
214215
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
215216
pir::OpResult res = op->result(0);
216-
shape_analysis->value_to_shape_or_data_[res] = shape_data;
217+
shape_analysis->SetShapeOrDataForValue(res, shape_data);
217218
return true;
218219
}
219220

@@ -222,7 +223,7 @@ bool ReshapeOpInferSymbolicShape(
222223
pir::Value operand_source_shape = op->operand_source(1);
223224

224225
symbol::ShapeOrDataDimExprs operand_shape_or_data =
225-
shape_analysis->value_to_shape_or_data_[operand_source_shape];
226+
shape_analysis->GetShapeOrDataForValue(operand_source_shape);
226227

227228
std::vector<symbol::DimExpr> out_dims;
228229
if (operand_shape_or_data.data().has_value()) {
@@ -236,9 +237,9 @@ bool ReshapeOpInferSymbolicShape(
236237

237238
pir::OpResult res0 = op->result(0);
238239
pir::OpResult res1 = op->result(1);
239-
shape_analysis->value_to_shape_or_data_[res0] = shape_data;
240-
shape_analysis->value_to_shape_or_data_[res1] =
241-
shape_analysis->value_to_shape_or_data_[operand_source_shape];
240+
shape_analysis->SetShapeOrDataForValue(res0, shape_data);
241+
shape_analysis->SetShapeOrDataForValue(
242+
res1, shape_analysis->GetShapeOrDataForValue(operand_source_shape));
242243
return true;
243244
}
244245

@@ -267,7 +268,7 @@ bool FullIntArrayOpInferSymbolicShape(
267268
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
268269

269270
pir::OpResult res = op->result(0);
270-
shape_analysis->value_to_shape_or_data_[res] = shape_data;
271+
shape_analysis->SetShapeOrDataForValue(res, shape_data);
271272
return true;
272273
}
273274

@@ -286,7 +287,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
286287
// dialect.
287288
pir::Value operand_source = op->operand_source(0);
288289
symbol::ShapeOrDataDimExprs operand_shape_or_data =
289-
shape_analysis->value_to_shape_or_data_[operand_source];
290+
shape_analysis->GetShapeOrDataForValue(operand_source);
290291
pir::AttributeMap attributes = op->attributes();
291292

292293
std::vector<pir::Attribute> attr_starts =
@@ -309,7 +310,7 @@ bool SliceOpInferSymbolicShape(pir::Operation *op,
309310
pir::shape::SymbolAttribute::get(pir::IrContext::Instance(), shape_data));
310311

311312
pir::OpResult res = op->result(0);
312-
shape_analysis->value_to_shape_or_data_[res] = shape_data;
313+
shape_analysis->SetShapeOrDataForValue(res, shape_data);
313314
return true;
314315
}
315316

paddle/fluid/pir/dialect/operator/ir/manual_op.cc

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3157,15 +3157,9 @@ bool ShapeBroadcastOp::InferSymbolicShape(
31573157
pir::ShapeConstraintIRAnalysis *shape_analysis) {
31583158
pir::Value x = operand_source(0);
31593159
pir::Value y = operand_source(1);
3160-
std::string x_id = pir::GetValueId(&x);
3161-
std::string y_id = pir::GetValueId(&y);
3162-
3163-
IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(x_id) > 0,
3164-
"x_id does not exist.");
3165-
IR_ENFORCE(shape_analysis->value_id_to_shapeordata_.count(y_id) > 0,
3166-
"y_id does not exist.");
3167-
const auto &x_data_shape = shape_analysis->value_id_to_shapeordata_.at(x_id);
3168-
const auto &y_data_shape = shape_analysis->value_id_to_shapeordata_.at(y_id);
3160+
3161+
const auto &x_data_shape = shape_analysis->GetShapeOrDataForValue(x);
3162+
const auto &y_data_shape = shape_analysis->GetShapeOrDataForValue(y);
31693163
IR_ENFORCE(x_data_shape.data().has_value(),
31703164
"Value x comes from ShapeOp, it must have data");
31713165
IR_ENFORCE(y_data_shape.data().has_value(),
@@ -3180,10 +3174,9 @@ bool ShapeBroadcastOp::InferSymbolicShape(
31803174
}
31813175

31823176
pir::OpResult res = result(0);
3183-
std::string res_id = pir::GetValueId(&res);
31843177
symbol::ShapeOrDataDimExprs output_data_shape =
31853178
symbol::ShapeOrDataDimExprs::MakeConsistentShapeOrData(output_data);
3186-
shape_analysis->value_id_to_shapeordata_[res_id] = output_data_shape;
3179+
shape_analysis->SetShapeOrDataForValue(res, output_data_shape);
31873180
return true;
31883181
}
31893182

paddle/fluid/pir/dialect/operator/ir/op_dialect.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ struct CombineOpInferSymbolicShapeInterfaceModel
6262
}
6363

6464
auto operand_source_1st_data =
65-
shape_analysis->value_to_shape_or_data_[op->operand_source(0)].data();
65+
shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).data();
6666
if (operand_source_1st_data.has_value()) {
6767
for (auto operand_source : op->operands_source()) {
6868
auto source_data =
69-
shape_analysis->value_to_shape_or_data_[operand_source]
69+
shape_analysis->GetShapeOrDataForValue(operand_source)
7070
.data()
7171
.value();
7272
out_dims.push_back(source_data[0]);
@@ -83,7 +83,7 @@ struct CombineOpInferSymbolicShapeInterfaceModel
8383
pir::shape::SymbolAttribute::get(
8484
pir::IrContext::Instance(), shape_data));
8585
auto res = op->result(0);
86-
shape_analysis->value_to_shape_or_data_[res] = shape_data;
86+
shape_analysis->SetShapeOrDataForValue(res, shape_data);
8787
return true;
8888
}
8989

paddle/fluid/pir/transforms/shape_optimization_pass.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ void DebugPrintOpInfo(
5151
<< "ShapeOrData: ";
5252

5353
if (shape_analysis != nullptr) {
54-
auto shape_data = shape_analysis->value_to_shape_or_data_[res];
54+
auto shape_data = shape_analysis->GetShapeOrDataForValue(res);
5555
print_stream << "shape: [";
5656

5757
for (size_t i = 0; i < shape_data.shape().size(); ++i) {
@@ -94,7 +94,9 @@ void InferSymExprForAllValues(ModuleOp module_op) {
9494
if (infer_symbolic_shape_interface) {
9595
VLOG(3) << op.name() << " has InferSymbolicShapeInterface.";
9696
PADDLE_ENFORCE(infer_symbolic_shape_interface.InferSymbolicShape(
97-
&shape_analysis));
97+
&shape_analysis),
98+
"InferSymbolicShape for %s failed.",
99+
op.name());
98100
}
99101
DebugPrintOpInfo(&op, &shape_analysis);
100102
}

paddle/pir/core/builtin_type_interfaces.cc

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,8 @@ Type ShapedTypeInterface::GetElementType() const {
2121
return impl_->get_element_type(*this);
2222
}
2323

24-
std::vector<int64_t> ShapedTypeInterface::GetDyShape() const {
25-
if (dy_shape_.size() == 0) {
26-
auto ddim_vec = common::vectorize(impl_->get_shape(*this));
27-
dy_shape_ = ddim_vec;
28-
std::replace(dy_shape_.begin(),
29-
dy_shape_.end(),
30-
(int64_t)-1,
31-
ShapedTypeInterface::kDynamic);
32-
}
33-
return dy_shape_;
24+
pir::DDim ShapedTypeInterface::GetShape() const {
25+
return impl_->get_shape(*this);
3426
}
3527

3628
} // namespace pir

paddle/pir/core/builtin_type_interfaces.h

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class IR_API ShapedTypeInterface
5656
///
5757
/// \brief kDynamic
5858
///
59-
static constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();
59+
static constexpr int64_t kDynamic = std::int64_t(-1);
6060

6161
ShapedTypeInterface(Type type, Concept *impl)
6262
: TypeInterfaceBase<ShapedTypeInterface>(type), impl_(impl) {}
@@ -69,7 +69,7 @@ class IR_API ShapedTypeInterface
6969
///
7070
/// \brief Get the shape of this type.
7171
///
72-
std::vector<int64_t> GetDyShape() const;
72+
pir::DDim GetShape() const;
7373

7474
///
7575
/// \brief Check whether this type is ranked, currently return true.
@@ -81,7 +81,7 @@ class IR_API ShapedTypeInterface
8181
///
8282
int64_t GetRank() const {
8383
IR_ENFORCE((*this).HasRank(), "Cannot query rank of unranked shaped type.");
84-
return (*this).GetDyShape().size();
84+
return (*this).GetShape().size();
8585
}
8686

8787
///
@@ -94,11 +94,10 @@ class IR_API ShapedTypeInterface
9494
/// dimension.
9595
///
9696
bool IsDynamicShape() const {
97-
auto size_vec = (*this).GetDyShape();
98-
return std::any_of(
99-
size_vec.begin(), size_vec.end(), [](int64_t size_value) {
100-
return IsDynamic(size_value);
101-
});
97+
auto size_vec = common::vectorize(impl_->get_shape(*this));
98+
return std::any_of(size_vec.begin(), size_vec.end(), [](int64_t size_val) {
99+
return IsDynamic(size_val);
100+
});
102101
}
103102

104103
///
@@ -112,15 +111,15 @@ class IR_API ShapedTypeInterface
112111
///
113112
bool IsDynamicDim(unsigned idx) const {
114113
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
115-
return ShapedTypeInterface::IsDynamic((*this).GetDyShape()[idx]);
114+
return ShapedTypeInterface::IsDynamic((*this).GetShape()[idx]);
116115
}
117116

118117
///
119118
/// \brief Get the number of dimensions with dynamic size for a ranked type.
120119
/// Aborts for unranked types.
121120
///
122121
int64_t GetNumDynamicDims() const {
123-
auto shape_vec = (*this).GetDyShape();
122+
auto shape_vec = vectorize((*this).GetShape());
124123
return std::count_if(
125124
shape_vec.begin(), shape_vec.end(), ShapedTypeInterface::IsDynamic);
126125
}
@@ -131,12 +130,11 @@ class IR_API ShapedTypeInterface
131130
///
132131
int64_t GetDimSize(unsigned idx) const {
133132
IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type.");
134-
return (*this).GetDyShape()[idx];
133+
return (*this).GetShape()[idx];
135134
}
136135

137136
private:
138137
Concept *impl_;
139-
mutable std::vector<int64_t> dy_shape_;
140138
};
141139

142140
} // namespace pir

paddle/pir/core/type_util.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ Type GetElementTypeOrSelf(Type type) {
2323
return type;
2424
}
2525

26-
bool VerifyCompatibleShape(const std::vector<int64_t> &lhs_shape,
27-
const std::vector<int64_t> &rhs_shape) {
26+
bool VerifyCompatibleShape(const pir::DDim &lhs_shape,
27+
const pir::DDim &rhs_shape) {
2828
if (lhs_shape.size() != rhs_shape.size()) return false;
2929

30-
for (auto dim1 : lhs_shape) {
31-
for (auto dim2 : rhs_shape) {
30+
for (auto dim1 : common::vectorize(lhs_shape)) {
31+
for (auto dim2 : common::vectorize(rhs_shape)) {
3232
if (!ShapedTypeInterface::IsDynamic(dim1) &&
3333
!ShapedTypeInterface::IsDynamic(dim2) && dim1 != dim2)
3434
return false;
@@ -47,8 +47,8 @@ bool VerifyCompatibleShape(Type lhs_type, Type rhs_type) {
4747

4848
if (!lhs_shaped_type.HasRank() || !rhs_shaped_type.HasRank()) return true;
4949

50-
return VerifyCompatibleShape(lhs_shaped_type.GetDyShape(),
51-
rhs_shaped_type.GetDyShape());
50+
return VerifyCompatibleShape(lhs_shaped_type.GetShape(),
51+
rhs_shaped_type.GetShape());
5252
}
5353

5454
bool VerifyCompatibleDims(const std::vector<int64_t> &dims) {

paddle/pir/dialect/shape/utils/shape_optimization_utils.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,9 @@ std::vector<SymbolicDimOp> SymbolicDimMgr::CreateSymbolicDimsForRankedValue(
201201
std::vector<SymbolicDimOp> symbols;
202202
auto dims = value.type().dyn_cast<pir::DenseTensorType>().dims();
203203
for (int idx = 0; idx < dims.size(); ++idx) {
204-
symbols.push_back(
205-
(dims[idx] == ShapedTypeInterface::kDynamic || dims[idx] == -1)
206-
? NewSymbolicDim()
207-
: NewConstantSymbolicDim(dims[idx]));
204+
symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic
205+
? NewSymbolicDim()
206+
: NewConstantSymbolicDim(dims[idx]));
208207
}
209208
return symbols;
210209
}

0 commit comments

Comments
 (0)