Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions paddle/cinn/backends/codegen_gpu_dev.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,33 @@ void CodeGenGpuDev::VisitStmt(const ir::stmt::Alloc &stmt) {
PrintTempBufferCreation(stmt->destination().as_buffer_ref());
}

inline void ProcessMinMaxOperand(ir::Expr *a,
ir::Expr *b,
int unify_bit,
bool both_dyn) {
if (unify_bit > 0) {
std::string type_func = "int" + std::to_string(unify_bit) + "_t";
if (both_dyn) {
// if both contains dynamic symbol, like: min(S0, S1), it it likely that
// S0 is int and S1 is int64_t. So we need to enforce the type cast by
// ir::Call
*a = ir::Call::Make(
common::Int(unify_bit), type_func, {*a}, {}, ir::CallType::Intrinsic);
*b = ir::Call::Make(
common::Int(unify_bit), type_func, {*b}, {}, ir::CallType::Intrinsic);
} else {
*a = ir::Cast::Make(common::Int(unify_bit), *a);
*b = ir::Cast::Make(common::Int(unify_bit), *b);
}
}
}

void CodeGenGpuDev::Visit(const ir::Min *op) {
str_ += "min(";
ir::Expr a = op->a(), b = op->b();
int unify_bit = common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
if (unify_bit > 0) {
a = ir::Cast::Make(common::Int(unify_bit), a);
b = ir::Cast::Make(common::Int(unify_bit), b);
}
auto [unify_bit, both_dyn] =
common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
ProcessMinMaxOperand(&a, &b, unify_bit, both_dyn);
IrPrinter::Visit(a);
str_ += ", ";
IrPrinter::Visit(b);
Expand All @@ -234,11 +253,9 @@ void CodeGenGpuDev::Visit(const ir::Min *op) {
void CodeGenGpuDev::Visit(const ir::Max *op) {
str_ += "max(";
ir::Expr a = op->a(), b = op->b();
int unify_bit = common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
if (unify_bit > 0) {
a = ir::Cast::Make(common::Int(unify_bit), a);
b = ir::Cast::Make(common::Int(unify_bit), b);
}
auto [unify_bit, both_dyn] =
common::UnifiedOperandTypeBits(&dynamic_shape_map_, op);
ProcessMinMaxOperand(&a, &b, unify_bit, both_dyn);
IrPrinter::Visit(a);
str_ += ", ";
IrPrinter::Visit(b);
Expand Down
30 changes: 15 additions & 15 deletions paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -614,21 +614,21 @@ struct DynamicSymbolExprBitTracker : public ir::IRVisitor {
int dyn_symbol_bit = 0;
};

#define VISIT_OP(NodeType) \
int UnifiedOperandTypeBits( \
const std::unordered_map<std::string, common::Type> *search_map, \
const ir::NodeType *node) { \
if (search_map->empty()) return 0; \
if (!node->a().type().is_int() || !node->b().type().is_int()) return 0; \
int node_a_bits = node->a().type().bits(); \
int node_b_bits = node->b().type().bits(); \
if (node_a_bits < 32 || node_b_bits < 32) return 0; \
DynamicSymbolExprBitTracker tracker; \
tracker(search_map, &node->a()); \
int target_bit = tracker(search_map, &node->b()); \
if (target_bit > 0) { \
} \
return target_bit; \
#define VISIT_OP(NodeType) \
std::pair<int, bool> UnifiedOperandTypeBits( \
const std::unordered_map<std::string, common::Type> *search_map, \
const ir::NodeType *node) { \
if (search_map->empty()) return {0, false}; \
if (!node->a().type().is_int() || !node->b().type().is_int()) \
return {0, false}; \
int node_a_bits = node->a().type().bits(); \
int node_b_bits = node->b().type().bits(); \
if (node_a_bits < 32 || node_b_bits < 32) return {0, false}; \
DynamicSymbolExprBitTracker tracker; \
int b1 = tracker(search_map, &node->a()); \
tracker.dyn_symbol_bit = 0; \
int b2 = tracker(search_map, &node->b()); \
return std::make_pair(std::max(b1, b2), b1 > 0 && b2 > 0); \
}

VISIT_OP(Min)
Expand Down
9 changes: 6 additions & 3 deletions paddle/cinn/common/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,14 @@ void OpDataTypePromote(ir::LoweredFunc *func);

// only process ir::Min and ir::Max where the operands 1. contains dynamic shape
// symbols. 2. the operands are both int types and both are 32/64 bits. Returns
// the number of bits for unifying operands (by casting)
int UnifiedOperandTypeBits(
// the number of bits for unifying operands (by casting). The bool flag
// indicates whether both sides has different dynamic shape symbols, since if
// true (like min(S0, S1))), we should not make a ir::Cast but a ir::Call
// (coercion)
std::pair<int, bool> UnifiedOperandTypeBits(
const std::unordered_map<std::string, common::Type> *search_map,
const ir::Min *op);
int UnifiedOperandTypeBits(
std::pair<int, bool> UnifiedOperandTypeBits(
const std::unordered_map<std::string, common::Type> *search_map,
const ir::Max *op);
} // namespace common
Expand Down
Loading