Skip to content

Commit 8532d71

Browse files
authored
[CINN] Add SimplifyWithObviousGreaterThan (#71341)
* Add SimplifyWithObviousGreaterThan * refine Min(S0, Add(S0,S1), Mul(Add(S1,S2),S2)) => S0 * fix typo * Add more tests * fix * fix bugs * log * apply review * clean * fix bugs * fix * make sure accuracy * fix typo * may be * Add a test * finally test * deal zero, one * fix bug for if condition
1 parent 90f3f4c commit 8532d71

File tree

3 files changed

+254
-6
lines changed

3 files changed

+254
-6
lines changed

paddle/pir/include/dialect/shape/utils/dim_expr_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ IR_API PriorityComparisonStatus CompareDimExprPriority(const DimExpr& lhs,
3939

4040
enum class DimExprCompareResult {
4141
GT, // lhs is greater than rhs
42+
GE, // lhs is greater than or equal to rhs
4243
EQ, // lhs and rhs is equal
4344
LT, // lhs is less than rhs
45+
LE, // lhs is less than or equal to rhs
4446
UNKNOWN, // lhs and rhs is not comparable
4547
};
4648
IR_API DimExprCompareResult Compare(const DimExpr& lhs, const DimExpr& rhs);

paddle/pir/src/dialect/shape/utils/dim_expr_util.cc

Lines changed: 203 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,11 +1005,12 @@ struct FoldRedundantSymbolicBroadcast {
10051005
* Simplify Example:
10061006
* Broadcast(S0,S0,S1) => Broadcast(S0,S1)
10071007
*/
1008-
struct FoldRedundantBroadcast {
1009-
using dim_expr_type = Broadcast<DimExpr>;
1008+
template <template <typename> class Op>
1009+
struct FoldRepetitiveSymbol {
1010+
using dim_expr_type = Op<DimExpr>;
10101011

10111012
DimExpr Rewrite(const DimExpr& expr) {
1012-
const auto& [operands] = expr.Get<Broadcast<DimExpr>>();
1013+
const auto& [operands] = expr.Get<Op<DimExpr>>();
10131014
while (operands->size() > 1) {
10141015
int pos_index = SearchSameIndex(operands);
10151016
if (pos_index < 0) {
@@ -1020,7 +1021,7 @@ struct FoldRedundantBroadcast {
10201021
if (operands->size() == 1) {
10211022
return operands->at(0);
10221023
} else {
1023-
return Broadcast<DimExpr>{operands};
1024+
return Op<DimExpr>{operands};
10241025
}
10251026
PADDLE_THROW(common::errors::Fatal("Dead code."));
10261027
}
@@ -1040,6 +1041,197 @@ struct FoldRedundantBroadcast {
10401041
}
10411042
};
10421043

1044+
DimExprCompareResult EasyCompareAddWithZero(const Add<DimExpr>& add) {
1045+
// Only return GT, GE, UNKNOWN.
1046+
List<DimExpr> operands = add.operands;
1047+
for (const auto& operand : *operands) {
1048+
if (operand.isa<std::string>()) {
1049+
continue;
1050+
}
1051+
if (operand.isa<std::int64_t>() && operand.dyn_cast<int64_t>() > 0) {
1052+
continue;
1053+
}
1054+
return DimExprCompareResult::UNKNOWN;
1055+
}
1056+
return DimExprCompareResult::GT;
1057+
}
1058+
1059+
DimExprCompareResult EasyCompareMulWithOne(const Mul<DimExpr>& mul) {
1060+
// Only return GT, GE, UNKNOWN.
1061+
List<DimExpr> operands = mul.operands;
1062+
int64_t const_result = 1;
1063+
for (const auto& operand : *operands) {
1064+
if (operand.isa<std::string>()) {
1065+
continue;
1066+
}
1067+
if (operand.isa<std::int64_t>() && operand.dyn_cast<int64_t>() > 1) {
1068+
const_result = operand.dyn_cast<int64_t>();
1069+
continue;
1070+
}
1071+
1072+
return DimExprCompareResult::UNKNOWN;
1073+
}
1074+
if (const_result == 1) {
1075+
return DimExprCompareResult::GE;
1076+
} else {
1077+
return DimExprCompareResult::GT;
1078+
}
1079+
}
1080+
1081+
bool EasyIsGtWithZero(const DimExpr& expr) {
1082+
auto ExprVisit = common::Overloaded{
1083+
[](const std::int64_t& expr) { return expr > 0; },
1084+
[](const std::string& expr) { return true; },
1085+
[](const Mul<DimExpr>& expr) {
1086+
return EasyCompareMulWithOne(expr) != DimExprCompareResult::UNKNOWN;
1087+
},
1088+
[](const Add<DimExpr>& expr) {
1089+
return EasyCompareAddWithZero(expr) == DimExprCompareResult::GT;
1090+
},
1091+
[](const Broadcast<DimExpr>& expr) { return true; },
1092+
[](const auto& expr) { return false; }};
1093+
return std::visit(ExprVisit, expr.variant());
1094+
}
1095+
1096+
DimExprCompareResult EasyCompareGtOrGe(const DimExpr& lhs,
1097+
const DimExpr& rhs,
1098+
bool is_broadcast = false) {
1099+
// TODO(ooooo): not perfect but ensures accuracy now.Such as:
1100+
// S0 < Add(S0, Mul(S1, S2)), S2 also can be Add(S4, S5, -1)
1101+
// range info may be used.
1102+
auto CompareDivResult = common::Overloaded{
1103+
[](const std::int64_t& expr) {
1104+
// trick for Min(Mul(5, S0), Mul(3, S0))
1105+
return expr >= 1 ? DimExprCompareResult::GT
1106+
: DimExprCompareResult::UNKNOWN;
1107+
},
1108+
[](const std::string& expr) { return DimExprCompareResult::GE; },
1109+
[&](const Mul<DimExpr>& expr) { return EasyCompareMulWithOne(expr); },
1110+
[](const auto& expr) { return DimExprCompareResult::UNKNOWN; }};
1111+
1112+
auto CompareSubResult = common::Overloaded{
1113+
[](const std::int64_t& expr) {
1114+
return expr > 0 ? DimExprCompareResult::GT
1115+
: DimExprCompareResult::UNKNOWN;
1116+
},
1117+
[](const std::string& expr) { return DimExprCompareResult::GT; },
1118+
[&](const Add<DimExpr>& expr) { return EasyCompareAddWithZero(expr); },
1119+
[](const auto& expr) { return DimExprCompareResult::UNKNOWN; }};
1120+
1121+
if (lhs.isa<std::string>() && rhs.isa<std::string>()) {
1122+
return DimExprCompareResult::UNKNOWN;
1123+
}
1124+
1125+
auto IsAddOrMul = [](const DimExpr& expr) {
1126+
return expr.isa<Add<DimExpr>>() || expr.isa<Mul<DimExpr>>();
1127+
};
1128+
auto IsOneOrZero = [](const DimExpr& expr) {
1129+
return expr == DimExpr{1} || expr == DimExpr{0};
1130+
};
1131+
1132+
if (!IsAddOrMul(lhs) && !IsAddOrMul(rhs) && !IsOneOrZero(lhs) &&
1133+
!IsOneOrZero(rhs)) {
1134+
return DimExprCompareResult::UNKNOWN;
1135+
}
1136+
1137+
// check with Sub
1138+
DimExpr simplified_result_sub = SimplifyDimExpr(DimExpr{lhs} - DimExpr{rhs});
1139+
auto sub_compare =
1140+
std::visit(CompareSubResult, simplified_result_sub.variant());
1141+
if (sub_compare != DimExprCompareResult::UNKNOWN) {
1142+
return sub_compare;
1143+
}
1144+
if (rhs != symbol::DimExpr{0}) {
1145+
if (!is_broadcast && !EasyIsGtWithZero(rhs)) {
1146+
// assume operands in broadcast is always positive.
1147+
return DimExprCompareResult::UNKNOWN;
1148+
}
1149+
DimExpr simplified_result_div =
1150+
SimplifyDimExpr(DimExpr{lhs} / DimExpr{rhs});
1151+
auto div_compare =
1152+
std::visit(CompareDivResult, simplified_result_div.variant());
1153+
return div_compare;
1154+
} else {
1155+
return DimExprCompareResult::UNKNOWN;
1156+
}
1157+
}
1158+
1159+
struct SimplifyMaxWithGE {
1160+
using dim_expr_type = Max<DimExpr>;
1161+
static List<DimExpr> SearchErasable(const List<DimExpr>& operands) {
1162+
List<DimExpr> simplified_operands{};
1163+
for (std::size_t i = 0; i < operands->size(); ++i) {
1164+
bool is_redundant = false;
1165+
for (std::size_t j = 0; j < operands->size(); ++j) {
1166+
if (i == j) {
1167+
continue;
1168+
}
1169+
auto compare_j_i = EasyCompareGtOrGe(operands->at(j), operands->at(i));
1170+
if (compare_j_i == DimExprCompareResult::GT ||
1171+
compare_j_i == DimExprCompareResult::GE) {
1172+
is_redundant = true;
1173+
break;
1174+
}
1175+
}
1176+
if (!is_redundant) {
1177+
simplified_operands->push_back(operands->at(i));
1178+
}
1179+
}
1180+
return simplified_operands;
1181+
}
1182+
1183+
DimExpr Rewrite(const DimExpr& expr) {
1184+
const auto [operands] = expr.Get<Max<DimExpr>>();
1185+
List<DimExpr> simplified_operands = SearchErasable(operands);
1186+
1187+
if (simplified_operands->size() == 1) {
1188+
return simplified_operands->at(0);
1189+
} else {
1190+
return Max<DimExpr>{simplified_operands};
1191+
}
1192+
}
1193+
};
1194+
1195+
/*
1196+
* Simplify Example:
1197+
* Min(S0, Mul(S0, S1)) => S0
1198+
*/
1199+
struct SimplifyMinWithGE {
1200+
using dim_expr_type = Min<DimExpr>;
1201+
static List<DimExpr> SearchErasable(const List<DimExpr>& operands) {
1202+
List<DimExpr> simplified_operands{};
1203+
for (std::size_t i = 0; i < operands->size(); ++i) {
1204+
bool is_redundant = false;
1205+
for (std::size_t j = 0; j < operands->size(); ++j) {
1206+
if (i == j) {
1207+
continue;
1208+
}
1209+
auto compare_i_j = EasyCompareGtOrGe(operands->at(i), operands->at(j));
1210+
if (compare_i_j == DimExprCompareResult::GT ||
1211+
compare_i_j == DimExprCompareResult::GE) {
1212+
is_redundant = true;
1213+
break;
1214+
}
1215+
}
1216+
if (!is_redundant) {
1217+
simplified_operands->push_back(operands->at(i));
1218+
}
1219+
}
1220+
return simplified_operands;
1221+
}
1222+
1223+
DimExpr Rewrite(const DimExpr& expr) {
1224+
const auto [operands] = expr.Get<Min<DimExpr>>();
1225+
List<DimExpr> simplified_operands = SearchErasable(operands);
1226+
1227+
if (simplified_operands->size() == 1) {
1228+
return simplified_operands->at(0);
1229+
} else {
1230+
return Min<DimExpr>{simplified_operands};
1231+
}
1232+
}
1233+
};
1234+
10431235
/*
10441236
* Simplify Example:
10451237
* Broadcast(S0, Mul(S0, S1)) => Mul(S0, S1)
@@ -1266,6 +1458,8 @@ DimExpr Simplify(const DimExpr& expr) {
12661458
DoPass<SimplifyOperands<Mul>>(&keep_rewrite, &ret);
12671459
DoPass<SimplifyOperands<Div>>(&keep_rewrite, &ret);
12681460
DoPass<SimplifyOperands<Broadcast>>(&keep_rewrite, &ret);
1461+
DoPass<SimplifyOperands<Min>>(&keep_rewrite, &ret);
1462+
DoPass<SimplifyOperands<Max>>(&keep_rewrite, &ret);
12691463
DoPass<SortOperands<Add>>(&keep_rewrite, &ret);
12701464
DoPass<SortOperands<Mul>>(&keep_rewrite, &ret);
12711465
DoPass<SortOperands<Broadcast>>(&keep_rewrite, &ret);
@@ -1283,9 +1477,13 @@ DimExpr Simplify(const DimExpr& expr) {
12831477
DoPass<FoldConstants<Min>>(&keep_rewrite, &ret);
12841478
DoPass<FoldConstants<Broadcast>>(&keep_rewrite, &ret);
12851479
DoPass<FoldInversedPairToUnit<Add>>(&keep_rewrite, &ret);
1286-
DoPass<FoldRedundantBroadcast>(&keep_rewrite, &ret);
1480+
DoPass<FoldRepetitiveSymbol<Broadcast>>(&keep_rewrite, &ret);
1481+
DoPass<FoldRepetitiveSymbol<Min>>(&keep_rewrite, &ret);
1482+
DoPass<FoldRepetitiveSymbol<Max>>(&keep_rewrite, &ret);
12871483
DoPass<FoldRedundantSymbolicBroadcast>(&keep_rewrite, &ret);
12881484
DoPass<SimplifyBroadcast>(&keep_rewrite, &ret);
1485+
DoPass<SimplifyMinWithGE>(&keep_rewrite, &ret);
1486+
DoPass<SimplifyMaxWithGE>(&keep_rewrite, &ret);
12891487
DoPass<SimplifyDiv>(&keep_rewrite, &ret);
12901488
if (expr_before_run_pipeline == ret) break;
12911489
}

test/cpp/pir/shape_dialect/simplify_dim_expr_test.cc

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,60 @@ TEST(Simplify, FoldBroadcast) {
213213
ASSERT_TRUE(simplify_broadcast3 == add);
214214
}
215215

216-
TEST(Simplify, FoldRedundantBroadcast) {
216+
TEST(Simplify, FoldRepetitiveSymbol) {
217217
DimExpr S0{"S0"};
218218
DimExpr S1{"S1"};
219219
DimExpr bc{Broadcast<DimExpr>{{S0, S0, S1, S1}}};
220+
DimExpr max{Max<DimExpr>{{S0, S0, S1, S1}}};
221+
DimExpr min{Min<DimExpr>{{S0, S0, S1, S1}}};
220222
DimExpr simplify_bc = SimplifyDimExpr(bc);
223+
DimExpr simplify_max = SimplifyDimExpr(max);
224+
DimExpr simplify_min = SimplifyDimExpr(min);
221225
ASSERT_TRUE((simplify_bc == Broadcast<DimExpr>{{S0, S1}}));
226+
ASSERT_TRUE((simplify_max == Max<DimExpr>{{S0, S1}}));
227+
ASSERT_TRUE((simplify_min == Min<DimExpr>{{S0, S1}}));
228+
}
229+
230+
TEST(Simplify, SimplifyMaxAndMinWithGE) {
231+
DimExpr S0{"S0"};
232+
DimExpr S1{"S1"};
233+
DimExpr S2{"S2"};
234+
DimExpr add1{Add<DimExpr>{{S0, S1}}};
235+
DimExpr max1{Max<DimExpr>{{S0, add1, S2}}};
236+
DimExpr min1{Min<DimExpr>{{S0, add1, S2}}};
237+
// Min(S0, Add(S0, S1), S2) => Min(S0, S2)
238+
ASSERT_TRUE((SimplifyDimExpr(max1) == Max<DimExpr>{{add1, S2}}));
239+
ASSERT_TRUE((SimplifyDimExpr(min1) == Min<DimExpr>{{S0, S2}}));
240+
241+
// Min(S0, Add(S0,S1), Mul(Add(S1, S2), S2)) => S0
242+
DimExpr mul{Mul<DimExpr>{{add1, S2}}};
243+
DimExpr max2{Max<DimExpr>{{S0, add1, mul}}};
244+
DimExpr min2{Min<DimExpr>{{S0, add1, mul}}};
245+
ASSERT_TRUE((SimplifyDimExpr(max2) == mul));
246+
ASSERT_TRUE((SimplifyDimExpr(min2) == S0));
247+
248+
// Min(S0, Add(S0, -1)) => Add(S0, -1)
249+
DimExpr add2{Add<DimExpr>{{S0, Negative<DimExpr>{DimExpr(1)}}}};
250+
ASSERT_TRUE(
251+
(SimplifyDimExpr(Min<DimExpr>{{S0, add2}}) == Add<DimExpr>{{S0, -1}}));
252+
253+
// Min(S0, Add(S0, -S1)) => Add(S0, -S1)
254+
DimExpr add3{Add<DimExpr>{{S0, Negative<DimExpr>{S1}}}};
255+
ASSERT_TRUE((SimplifyDimExpr(Min<DimExpr>{{S0, add3}}) == add3));
256+
257+
// Min(S0, 0) => 0, Max(S0, 0) => S0
258+
ASSERT_TRUE((SimplifyDimExpr(Min<DimExpr>{{S0, DimExpr(0)}}) == DimExpr(0)));
259+
ASSERT_TRUE((SimplifyDimExpr(Max<DimExpr>{{S0, DimExpr(0)}}) == S0));
260+
261+
// Min(S0, 1) => 1, Max(S0, 1) => S0
262+
ASSERT_TRUE((SimplifyDimExpr(Min<DimExpr>{{S0, DimExpr(1)}}) == DimExpr(1)));
263+
ASSERT_TRUE((SimplifyDimExpr(Max<DimExpr>{{S0, DimExpr(1)}}) == S0));
264+
265+
// Min(Mul(S0, S1), 0) => Min(Mul(S0, S1), 0)
266+
// Now simplify ability is limited.
267+
DimExpr mul2{Mul<DimExpr>{{S0, S1}}};
268+
ASSERT_TRUE((SimplifyDimExpr(Min<DimExpr>{{mul2, DimExpr(0)}}) ==
269+
Min<DimExpr>{{mul2, DimExpr(0)}}));
222270
}
223271

224272
TEST(Simplify, SimplifyDoubleNegForMulAndDiv) {

0 commit comments

Comments
 (0)