Skip to content

Conversation

@Enigmatisms
Copy link
Contributor

@Enigmatisms Enigmatisms commented Jul 25, 2025

PR Category

CINN

PR Types

Bug fixes

Description

longlong2int 是一个 int类型计算的 optimization pass,包括了:

  • OpDataTypePromote (显式在 CodeGenAndJit中调用)
  • TryElevateIntxxtoIntxx(32->64或者反过来都有):这个函数在很多类型、函数中都隐式调用了。包括 ir::BinaryNode 构造时

特别是第二点,导致type cast问题涉及的范围很广(甚至在 IrPrinter 生成代码的环节都能出现 type cast,比如 load/store node调用 index 方法时)。问题的主要表现是:当被编译操作中含有 ir::Min / ir::Max 两种 node 时,可能会因为动态 shape 的存在而导致 operand type 不匹配。究其原因:

  • CINN 根据 predicate 对动态shape同一个 kernel 编译几个版本:比如 shape 大时,动态 shape 输入将会使用 int64,反之用 int32,kernel 在 host 端通过 branches 进行选择。
  • 输入参数的动态 shape 表达式(ir::LoweredFunc)的 ir::Argument 与在 kernel 内部的动态 shape 表达式(ir::_Var_)不是强关联的,导致同时有 int32 / int64 版本ir::Argument 的 LoweredFunc,但 func body 内部类型并没有正确设置。具体表现如下:
// func body IR: i32 version var[some_index] = cinn_min(S0, 0); // func body IR: i64 version var[some_index] = cinn_min(S0, 0ll); // CUDA code: // 32bit ir::Argument __global__ foo_predicate_le_int_max_kernel(int S0) { var[threadIdx.x] = cinn_min(S0, 0ll); // 报错! } // 64bit ir::Argument __global__ foo_predicate_gt_int_max_kernel(int64_t S0) { var[threadIdx.x] = cinn_min(S0, 0ll); // 不报错 }

这一问题的解决,可以说目前我的方案还比较暴力(但可能很难找到更简单的方法了):

  • codegen 开始阶段(GPU only),通过哈希表记录变量名到动态shape symbol 类型的关系
  • 在codegen阶段,IrPrinter 打印 ir::Min / ir::Max node 时,对于左右operands进行 ir-visiting:当左右operand都是int类型输入,并且都包含ir::Argument中已经记录的symbol名称时,确定 最大的 int类型 bit 数(32/64)。
  • 根据最大bit数,将左右 operand 进行 cast,IrPrinter print cast 后(unified type)的 operands。

举个例子:cinn_min((S0 + 1), 1024ll)ir::Argument 中的 S0 是 int32类型的 ir::_Var_,其left operand 包含动态shape symbol,并且查哈希表知,最大 bit 是32,而 right operand 不包含动态 shape symbol,故为0,则这个 min node 应该被整体 cast 为 int32 类型的(否则在 codegen 带入时,S0 是 int,左边整体是int)。

之所以说目前没有更简单的方法,就是因为 longlong2int 相关的改动设计的API太多了,从 PostProcess 到 CodeGen 的 string code printing 过程都有int类型的相互变换。

感兴趣可以测试一下本 PR 提供的单测:test_unifying_minmax_type.py。本 PR 前此单测没有一个可以通过,并且原始的 bug 也影响了一些 CINN 算子的支持(比如 gather_nd 我尝试将min/max 进行简单的 type unify,统一到int64,引入了一定的性能问题,见 #73940),后续会对对应算子支持进行进一步简化。

Pcard-89620

@paddle-bot
Copy link

paddle-bot bot commented Jul 25, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@lshpku lshpku merged commit 8830a97 into PaddlePaddle:develop Jul 26, 2025
66 of 68 checks passed
@Enigmatisms Enigmatisms deleted the longlong2int_fix branch August 29, 2025 05:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants