Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions paddle/cinn/common/integer_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ cas_intervals_t CollectVarIntervalsOfExprs(const std::vector<ir::Expr>& exprs,
lower_bound = ir::Expr(1);
}
var_intervals.insert(
{var->name, CasInterval(lower_bound, upper_bound)});
{var->name,
CasInterval(lower_bound, NormalizeUpperBound(upper_bound))});
}
return false;
});
Expand Down Expand Up @@ -572,14 +573,21 @@ class BoundReplacer : public ir::IRMutator<> {
ir::Expr SymbolicExprAnalyzer::LowerBound(const ir::Expr& expr) const {
BoundReplacer bound_replacer(var_intervals_, true);
ir::Expr bound = ir::ir_utils::IRCopy(expr);
if (bound.is_index()) {
bound = bound.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3);
}
bound_replacer(&bound);
return optim::ArithSimplify(bound);
}

ir::Expr SymbolicExprAnalyzer::UpperBound(const ir::Expr& expr) const {
BoundReplacer bound_replacer(var_intervals_, false);
ir::Expr bound = ir::ir_utils::IRCopy(expr);
if (bound.is_index()) {
bound = bound.as_index().Normalize(ir::IndexExpr::OptLevel::kLevel3);
}
bound_replacer(&bound);

return optim::ArithSimplify(bound);
}

Expand Down Expand Up @@ -709,7 +717,8 @@ SingleIntervalIntSet::SingleIntervalIntSet(const ir::Expr& min,
? x->as_var()->upper_bound
: SymbolicExprLimit::positive_inf;
var_intervals_.insert(
{x->as_var()->name, CasInterval(lower_bound, upper_bound)});
{x->as_var()->name,
CasInterval(lower_bound, NormalizeUpperBound(upper_bound))});
}
return false;
};
Expand Down
10 changes: 10 additions & 0 deletions paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,16 @@ bool is_zero(Expr v) {
return false;
}

Expr NormalizeUpperBound(Expr upper_bound, bool minus_one /* = true */) {
if (upper_bound == SymbolicExprLimit::positive_inf) {
return upper_bound;
}
if (minus_one) {
return upper_bound - ir::Expr(1); // [lower, upper) to [lower, upper]
}
return upper_bound + ir::Expr(1); // (lower, upper] to [lower, upper)
}

Expr CastIfNeeded(Expr body, Type type) {
if (body.type() == type) return body;
return ir::Cast::Make(type, body);
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/common/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ std::vector<std::string> GatherItersToTensorProducer(

bool is_zero(Expr v);

Expr NormalizeUpperBound(Expr upper_bound, bool minus_one = true);

bool MathEqual(const Expr &a, const Expr &b);

//! helper function to get a ir::Select node.
Expand Down
10 changes: 7 additions & 3 deletions paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ IntSet Evaluate(Expr expr,
const std::unordered_map<ir::Var, IntSet>& var_domain) {
Expr copy_for_upper_bound = ir::ir_utils::IRCopy(expr);
Expr copy_for_lower_bound = ir::ir_utils::IRCopy(expr);
common::cas_intervals_t var_intervals;
common::cas_intervals_t
var_intervals; // variable name -> CasIntervals[lower_bound, upper_bound]
std::vector<ir::Expr> var_vec = ir::ir_utils::CollectIRNodesWithoutTensor(
expr, [](const ir::Expr* x) { return x->as_var(); });
for (Expr var_expr : var_vec) {
Expand All @@ -150,7 +151,9 @@ IntSet Evaluate(Expr expr,
const ir::Var& fixed_var = fixed.at(var);
var_intervals.emplace(
fixed_var->name,
common::CasInterval(fixed_var->lower_bound, fixed_var->upper_bound));
common::CasInterval(
fixed_var->lower_bound,
cinn::common::NormalizeUpperBound(fixed_var->upper_bound)));
optim::ReplaceVarWithExpr(&copy_for_lower_bound, var, Expr(fixed_var));
optim::ReplaceVarWithExpr(&copy_for_upper_bound, var, Expr(fixed_var));
} else if (var_domain.count(var) != 0) {
Expand All @@ -172,7 +175,8 @@ IntSet Evaluate(Expr expr,
::common::errors::InvalidArgument(
"The 'upper_bound' of the variable must be defined."));
optim::ReplaceVarWithExpr(&copy_for_lower_bound, var, var->lower_bound);
optim::ReplaceVarWithExpr(&copy_for_upper_bound, var, var->upper_bound);
optim::ReplaceVarWithExpr(
&copy_for_upper_bound, var, NormalizeUpperBound(var->upper_bound));
}
}
ir::Expr lower_bound = optim::ArithSimplify(copy_for_lower_bound);
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ struct _Var_ : public ExprNode<_Var_> {
};

//! A named variable.
// i ∈ [lower_bound, upper_bound)
struct Var : public IrNodeRef {
Var() = default;
explicit Var(IrNode* n) : IrNodeRef(n) {}
Expand Down Expand Up @@ -846,6 +847,7 @@ struct For : public ExprNode<For>, public ForBase {
//! The minimum value of the iteration.
Expr min;
//! The extent of the iteration.
// loop_var ∈ [min, min + extent)
Expr extent;

Expr body;
Expand Down
12 changes: 8 additions & 4 deletions paddle/cinn/ir/ir_analyzer/ir_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,8 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
if (e.is_constant()) {
std::string var_name =
cinn::UniqName("constant" + static_cast<int>(e.get_constant()));
result.emplace_back(e, e, var_name, /* is_reduce = */ false);
result.emplace_back(
e, NormalizeUpperBound(e, false), var_name, /* is_reduce = */ false);
} else if (e.As<ir::_Var_>() != nullptr) {
ir::Expr copy_e = ir::ir_utils::IRCopy(e);
ir::_Var_* var_ref = copy_e.As<ir::_Var_>();
Expand All @@ -635,14 +636,17 @@ std::vector<ir::Var> IndicesToVars(const std::vector<ir::Expr>& indices) {
ir::Var var = x->as_var_ref();
var_intervals.insert(
{var->name,
common::CasInterval{var->lower_bound, var->upper_bound}});
common::CasInterval{var->lower_bound,
NormalizeUpperBound(var->upper_bound)}});
if (var->is_reduce_axis) is_reduce = true;
}
return false;
});
common::SymbolicExprAnalyzer analyzer(var_intervals);
result.emplace_back(
analyzer.LowerBound(e), analyzer.UpperBound(e), var_name, is_reduce);
result.emplace_back(analyzer.LowerBound(e),
NormalizeUpperBound(analyzer.UpperBound(e), false),
var_name,
is_reduce);
}
}
return result;
Expand Down
Loading