@@ -1005,11 +1005,12 @@ struct FoldRedundantSymbolicBroadcast {
10051005 * Simplify Example:
10061006 * Broadcast(S0,S0,S1) => Broadcast(S0,S1)
10071007 */
1008- struct FoldRedundantBroadcast {
1009- using dim_expr_type = Broadcast<DimExpr>;
1008+ template <template <typename > class Op >
1009+ struct FoldRepetitiveSymbol {
1010+ using dim_expr_type = Op<DimExpr>;
10101011
10111012 DimExpr Rewrite (const DimExpr& expr) {
1012- const auto & [operands] = expr.Get <Broadcast <DimExpr>>();
1013+ const auto & [operands] = expr.Get <Op <DimExpr>>();
10131014 while (operands->size () > 1 ) {
10141015 int pos_index = SearchSameIndex (operands);
10151016 if (pos_index < 0 ) {
@@ -1020,7 +1021,7 @@ struct FoldRedundantBroadcast {
10201021 if (operands->size () == 1 ) {
10211022 return operands->at (0 );
10221023 } else {
1023- return Broadcast <DimExpr>{operands};
1024+ return Op <DimExpr>{operands};
10241025 }
10251026 PADDLE_THROW (common::errors::Fatal (" Dead code." ));
10261027 }
@@ -1040,6 +1041,197 @@ struct FoldRedundantBroadcast {
10401041 }
10411042};
10421043
1044+ DimExprCompareResult EasyCompareAddWithZero (const Add<DimExpr>& add) {
1045+ // Only return GT, GE, UNKNOWN.
1046+ List<DimExpr> operands = add.operands ;
1047+ for (const auto & operand : *operands) {
1048+ if (operand.isa <std::string>()) {
1049+ continue ;
1050+ }
1051+ if (operand.isa <std::int64_t >() && operand.dyn_cast <int64_t >() > 0 ) {
1052+ continue ;
1053+ }
1054+ return DimExprCompareResult::UNKNOWN;
1055+ }
1056+ return DimExprCompareResult::GT;
1057+ }
1058+
1059+ DimExprCompareResult EasyCompareMulWithOne (const Mul<DimExpr>& mul) {
1060+ // Only return GT, GE, UNKNOWN.
1061+ List<DimExpr> operands = mul.operands ;
1062+ int64_t const_result = 1 ;
1063+ for (const auto & operand : *operands) {
1064+ if (operand.isa <std::string>()) {
1065+ continue ;
1066+ }
1067+ if (operand.isa <std::int64_t >() && operand.dyn_cast <int64_t >() > 1 ) {
1068+ const_result = operand.dyn_cast <int64_t >();
1069+ continue ;
1070+ }
1071+
1072+ return DimExprCompareResult::UNKNOWN;
1073+ }
1074+ if (const_result == 1 ) {
1075+ return DimExprCompareResult::GE;
1076+ } else {
1077+ return DimExprCompareResult::GT;
1078+ }
1079+ }
1080+
1081+ bool EasyIsGtWithZero (const DimExpr& expr) {
1082+ auto ExprVisit = common::Overloaded{
1083+ [](const std::int64_t & expr) { return expr > 0 ; },
1084+ [](const std::string& expr) { return true ; },
1085+ [](const Mul<DimExpr>& expr) {
1086+ return EasyCompareMulWithOne (expr) != DimExprCompareResult::UNKNOWN;
1087+ },
1088+ [](const Add<DimExpr>& expr) {
1089+ return EasyCompareAddWithZero (expr) == DimExprCompareResult::GT;
1090+ },
1091+ [](const Broadcast<DimExpr>& expr) { return true ; },
1092+ [](const auto & expr) { return false ; }};
1093+ return std::visit (ExprVisit, expr.variant ());
1094+ }
1095+
1096+ DimExprCompareResult EasyCompareGtOrGe (const DimExpr& lhs,
1097+ const DimExpr& rhs,
1098+ bool is_broadcast = false ) {
1099+ // TODO(ooooo): not perfect but ensures accuracy now.Such as:
1100+ // S0 < Add(S0, Mul(S1, S2)), S2 also can be Add(S4, S5, -1)
1101+ // range info may be used.
1102+ auto CompareDivResult = common::Overloaded{
1103+ [](const std::int64_t & expr) {
1104+ // trick for Min(Mul(5, S0), Mul(3, S0))
1105+ return expr >= 1 ? DimExprCompareResult::GT
1106+ : DimExprCompareResult::UNKNOWN;
1107+ },
1108+ [](const std::string& expr) { return DimExprCompareResult::GE; },
1109+ [&](const Mul<DimExpr>& expr) { return EasyCompareMulWithOne (expr); },
1110+ [](const auto & expr) { return DimExprCompareResult::UNKNOWN; }};
1111+
1112+ auto CompareSubResult = common::Overloaded{
1113+ [](const std::int64_t & expr) {
1114+ return expr > 0 ? DimExprCompareResult::GT
1115+ : DimExprCompareResult::UNKNOWN;
1116+ },
1117+ [](const std::string& expr) { return DimExprCompareResult::GT; },
1118+ [&](const Add<DimExpr>& expr) { return EasyCompareAddWithZero (expr); },
1119+ [](const auto & expr) { return DimExprCompareResult::UNKNOWN; }};
1120+
1121+ if (lhs.isa <std::string>() && rhs.isa <std::string>()) {
1122+ return DimExprCompareResult::UNKNOWN;
1123+ }
1124+
1125+ auto IsAddOrMul = [](const DimExpr& expr) {
1126+ return expr.isa <Add<DimExpr>>() || expr.isa <Mul<DimExpr>>();
1127+ };
1128+ auto IsOneOrZero = [](const DimExpr& expr) {
1129+ return expr == DimExpr{1 } || expr == DimExpr{0 };
1130+ };
1131+
1132+ if (!IsAddOrMul (lhs) && !IsAddOrMul (rhs) && !IsOneOrZero (lhs) &&
1133+ !IsOneOrZero (rhs)) {
1134+ return DimExprCompareResult::UNKNOWN;
1135+ }
1136+
1137+ // check with Sub
1138+ DimExpr simplified_result_sub = SimplifyDimExpr (DimExpr{lhs} - DimExpr{rhs});
1139+ auto sub_compare =
1140+ std::visit (CompareSubResult, simplified_result_sub.variant ());
1141+ if (sub_compare != DimExprCompareResult::UNKNOWN) {
1142+ return sub_compare;
1143+ }
1144+ if (rhs != symbol::DimExpr{0 }) {
1145+ if (!is_broadcast && !EasyIsGtWithZero (rhs)) {
1146+ // assume operands in broadcast is always positive.
1147+ return DimExprCompareResult::UNKNOWN;
1148+ }
1149+ DimExpr simplified_result_div =
1150+ SimplifyDimExpr (DimExpr{lhs} / DimExpr{rhs});
1151+ auto div_compare =
1152+ std::visit (CompareDivResult, simplified_result_div.variant ());
1153+ return div_compare;
1154+ } else {
1155+ return DimExprCompareResult::UNKNOWN;
1156+ }
1157+ }
1158+
1159+ struct SimplifyMaxWithGE {
1160+ using dim_expr_type = Max<DimExpr>;
1161+ static List<DimExpr> SearchErasable (const List<DimExpr>& operands) {
1162+ List<DimExpr> simplified_operands{};
1163+ for (std::size_t i = 0 ; i < operands->size (); ++i) {
1164+ bool is_redundant = false ;
1165+ for (std::size_t j = 0 ; j < operands->size (); ++j) {
1166+ if (i == j) {
1167+ continue ;
1168+ }
1169+ auto compare_j_i = EasyCompareGtOrGe (operands->at (j), operands->at (i));
1170+ if (compare_j_i == DimExprCompareResult::GT ||
1171+ compare_j_i == DimExprCompareResult::GE) {
1172+ is_redundant = true ;
1173+ break ;
1174+ }
1175+ }
1176+ if (!is_redundant) {
1177+ simplified_operands->push_back (operands->at (i));
1178+ }
1179+ }
1180+ return simplified_operands;
1181+ }
1182+
1183+ DimExpr Rewrite (const DimExpr& expr) {
1184+ const auto [operands] = expr.Get <Max<DimExpr>>();
1185+ List<DimExpr> simplified_operands = SearchErasable (operands);
1186+
1187+ if (simplified_operands->size () == 1 ) {
1188+ return simplified_operands->at (0 );
1189+ } else {
1190+ return Max<DimExpr>{simplified_operands};
1191+ }
1192+ }
1193+ };
1194+
1195+ /*
1196+ * Simplify Example:
1197+ * Min(S0, Mul(S0, S1)) => S0
1198+ */
1199+ struct SimplifyMinWithGE {
1200+ using dim_expr_type = Min<DimExpr>;
1201+ static List<DimExpr> SearchErasable (const List<DimExpr>& operands) {
1202+ List<DimExpr> simplified_operands{};
1203+ for (std::size_t i = 0 ; i < operands->size (); ++i) {
1204+ bool is_redundant = false ;
1205+ for (std::size_t j = 0 ; j < operands->size (); ++j) {
1206+ if (i == j) {
1207+ continue ;
1208+ }
1209+ auto compare_i_j = EasyCompareGtOrGe (operands->at (i), operands->at (j));
1210+ if (compare_i_j == DimExprCompareResult::GT ||
1211+ compare_i_j == DimExprCompareResult::GE) {
1212+ is_redundant = true ;
1213+ break ;
1214+ }
1215+ }
1216+ if (!is_redundant) {
1217+ simplified_operands->push_back (operands->at (i));
1218+ }
1219+ }
1220+ return simplified_operands;
1221+ }
1222+
1223+ DimExpr Rewrite (const DimExpr& expr) {
1224+ const auto [operands] = expr.Get <Min<DimExpr>>();
1225+ List<DimExpr> simplified_operands = SearchErasable (operands);
1226+
1227+ if (simplified_operands->size () == 1 ) {
1228+ return simplified_operands->at (0 );
1229+ } else {
1230+ return Min<DimExpr>{simplified_operands};
1231+ }
1232+ }
1233+ };
1234+
10431235/*
10441236 * Simplify Example:
10451237 * Broadcast(S0, Mul(S0, S1)) => Mul(S0, S1)
@@ -1266,6 +1458,8 @@ DimExpr Simplify(const DimExpr& expr) {
12661458 DoPass<SimplifyOperands<Mul>>(&keep_rewrite, &ret);
12671459 DoPass<SimplifyOperands<Div>>(&keep_rewrite, &ret);
12681460 DoPass<SimplifyOperands<Broadcast>>(&keep_rewrite, &ret);
1461+ DoPass<SimplifyOperands<Min>>(&keep_rewrite, &ret);
1462+ DoPass<SimplifyOperands<Max>>(&keep_rewrite, &ret);
12691463 DoPass<SortOperands<Add>>(&keep_rewrite, &ret);
12701464 DoPass<SortOperands<Mul>>(&keep_rewrite, &ret);
12711465 DoPass<SortOperands<Broadcast>>(&keep_rewrite, &ret);
@@ -1283,9 +1477,13 @@ DimExpr Simplify(const DimExpr& expr) {
12831477 DoPass<FoldConstants<Min>>(&keep_rewrite, &ret);
12841478 DoPass<FoldConstants<Broadcast>>(&keep_rewrite, &ret);
12851479 DoPass<FoldInversedPairToUnit<Add>>(&keep_rewrite, &ret);
1286- DoPass<FoldRedundantBroadcast>(&keep_rewrite, &ret);
1480+ DoPass<FoldRepetitiveSymbol<Broadcast>>(&keep_rewrite, &ret);
1481+ DoPass<FoldRepetitiveSymbol<Min>>(&keep_rewrite, &ret);
1482+ DoPass<FoldRepetitiveSymbol<Max>>(&keep_rewrite, &ret);
12871483 DoPass<FoldRedundantSymbolicBroadcast>(&keep_rewrite, &ret);
12881484 DoPass<SimplifyBroadcast>(&keep_rewrite, &ret);
1485+ DoPass<SimplifyMinWithGE>(&keep_rewrite, &ret);
1486+ DoPass<SimplifyMaxWithGE>(&keep_rewrite, &ret);
12891487 DoPass<SimplifyDiv>(&keep_rewrite, &ret);
12901488 if (expr_before_run_pipeline == ret) break ;
12911489 }
0 commit comments