Skip to content

Commit 57f8c36

Browse files
authored
fix min max bug (#72267)
1 parent 3b7e089 commit 57f8c36

File tree

3 files changed

+20
-16
lines changed

3 files changed

+20
-16
lines changed

paddle/cinn/ir/ir_base.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,9 @@ void ElevateInt64ToInt32_(Expr &expr) { // NOLINT
734734
"Current only support convert int64_t "
735735
"to int32_t, but get type is: %s",
736736
expr->type()));
737+
738+
// althoughtype is Int(32), we also need to convert it indices to Int(32).
739+
if (expr->node_type() == IrNodeTy::Load) expr->convert_int64_to_int32();
737740
if (expr->type() == Int(64)) {
738741
expr->convert_int64_to_int32();
739742
if (expr->node_type() == IrNodeTy::Cast) {

paddle/cinn/optim/longlong2int_pass.cc

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,25 +138,28 @@ class CastLonglong2IntMutator : public ir::IRMutator<> {
138138
}
139139
void Visit(const ir::Min* op, Expr* expr) override {
140140
auto node = expr->As<ir::Min>();
141-
if (node->a().is_index() && node->b().is_index()) {
142-
if ((node->a().is_var() && node->a().as_var()->is_symbolic_constant) ||
143-
(node->b().is_var() && node->b().as_var()->is_symbolic_constant)) {
144-
ir::ElevateInt64ToInt32_((*expr)->operands);
145-
}
141+
// min(min(S0, 1ll), 1ll) ==> min(min(S0, 1), 1)
142+
// min(V[S0, S1], 1ll) ==> min(V[S0, S1], 1ll)
143+
// min(S0 + 1ll, 1ll) ==> max(S0 + 1, 1)
144+
// IndexType::kValid means expr only has +-*/%, Const, Symbol, Min, Max.
145+
// IsDynamic == true means expr has Symbol.
146+
if (optim::VerifyIndex(*expr) == ir::IndexExpr::IndexType::kValid &&
147+
expr->as_index().IsDynamic()) {
148+
ir::ElevateInt64ToInt32_((*expr)->operands);
149+
} else {
150+
ir::IRMutator<>::Visit(&node->a(), &node->a());
151+
ir::IRMutator<>::Visit(&node->b(), &node->b());
146152
}
147-
ir::IRMutator<>::Visit(&node->a(), &node->a());
148-
ir::IRMutator<>::Visit(&node->b(), &node->b());
149153
}
150154
void Visit(const ir::Max* op, Expr* expr) override {
151155
auto node = expr->As<ir::Max>();
152-
if (node->a().is_index() && node->b().is_index()) {
153-
if ((node->a().is_var() && node->a().as_var()->is_symbolic_constant) ||
154-
(node->b().is_var() && node->b().as_var()->is_symbolic_constant)) {
155-
ir::ElevateInt64ToInt32_((*expr)->operands);
156-
}
156+
if (optim::VerifyIndex(*expr) == ir::IndexExpr::IndexType::kValid &&
157+
expr->as_index().IsDynamic()) {
158+
ir::ElevateInt64ToInt32_((*expr)->operands);
159+
} else {
160+
ir::IRMutator<>::Visit(&node->a(), &node->a());
161+
ir::IRMutator<>::Visit(&node->b(), &node->b());
157162
}
158-
ir::IRMutator<>::Visit(&node->a(), &node->a());
159-
ir::IRMutator<>::Visit(&node->b(), &node->b());
160163
}
161164
};
162165

test/dygraph_to_static/test_word2vec.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from dygraph_to_static_utils import (
2121
Dy2StTestBase,
2222
enable_to_static_guard,
23-
test_phi_only,
2423
)
2524

2625
import paddle
@@ -318,7 +317,6 @@ def train():
318317

319318

320319
class TestWord2Vec(Dy2StTestBase):
321-
@test_phi_only
322320
def test_dygraph_static_same_loss(self):
323321
with enable_to_static_guard(False):
324322
dygraph_loss = train()

0 commit comments

Comments
 (0)