Skip to content

Conversation

@Enigmatisms
Copy link
Contributor

@Enigmatisms Enigmatisms commented Jul 9, 2025

PR Category

CINN

PR Types

Bug fixes

Description

本PR修复了 gather_nd 算子的CINN lowering逻辑。原始逻辑不会对负数输入进行处理,导致计算错误。本PR:

  • 增加了对负数的处理(基于ir::Select
  • 删除了无效的 non-symbolic strategy
  • 增加了对应的单测 test_cinn_gather_nd.py

性能测试结果

配置 原始 本PR
test_llama forward.py 约38.3us 约45.3us
test_llama inference.py 约38.8us 约45.6us
gather_nd(x * y, indices) + z (256, 128) + (100, 2) 3.35ms 3.32ms
gather_nd(x * y, indices) + z (4096, 2048) + (131072, 2) 41.3ms 41.5ms

注:前两个测试例子为 paddle 的单测,内部有融合了 gather_nd 算子的kernel,本实验统计的是与 gather_nd 相关kernel 的平均运行时间。

部分kernel的性能好像降了很多?经过NCU的分析发现:对应的kernel 实际上throughput提高了,且bottleneck环节delay大幅下降(比如lg throttle等等),但执行的SASS指令数量显著提升。

进一步对比了下面几个实现方法对应的速度:

  • select(本PR,单个kernel): ~17us
  • 位操作(64bit,具体:((-int(index > 0)) & shape) + index): ~17.5 us
  • 位操作(32bit): ~14.5us
  • mod操作((index + shape) % shape): ~12us,与修改前一致(甚至快一丢丢)。这个处理方法原本是非常理想的(使用 mod 操作实现 CINN 内部逻辑时,甚至不用对 shape 内可能的 min、max操作 operand 进行 recast),但现有的表达式简化逻辑暂不支持存在负数的情况((a+N)%N会被简化为a % N,在a为负数,N为常数时,这个是一个不成立的简化)。故本优化被暂时放弃,因为上述select实现的PR引入的性能下降反应在整个pass上很小(受影响kernel的时间占比小于1%)。

Pcard-89620

@Enigmatisms Enigmatisms force-pushed the index_put_grad branch 2 times, most recently from 9d0aa9b to 4fcb05f Compare July 11, 2025 07:32
@lshpku lshpku merged commit e1842d4 into PaddlePaddle:develop Jul 15, 2025
105 of 110 checks passed
@Enigmatisms Enigmatisms deleted the index_put_grad branch August 29, 2025 05:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants