
Conversation

@wanghuancoder
Contributor

@wanghuancoder wanghuancoder commented Mar 28, 2025

PR Category

Execute Infrastructure

PR Types

Bug fixes

Description

This PR makes the reduce-related kernels support big tensors. With it, 726 of the 728 previously failing big-tensor test cases pass. For the full list of test cases, see:
big tensor reduce error case.txt
The cases that still fail are:

  1. paddle.sum(Tensor([3, 715827883, 2],"int32"), axis=1, keepdim=True, )
  2. paddle.sum(Tensor([1, 2281701379, 1],"float32"), 1, )

Both report precision problems (a repro sketch of case 1 is shown below). This PR is merged first; the two remaining issues will be fixed in separate PRs.
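For reference, a minimal repro sketch of failing case 1 (illustrative harness only; the shape and arguments are taken verbatim from the list above, and the int32 input alone needs roughly 17 GB of device memory):

```python
import paddle

# Failing case 1: sum over axis=1 of a [3, 715827883, 2] int32 tensor.
# 3 * 715827883 * 2 = 4294967298 elements, just past the int32 limit.
x = paddle.ones([3, 715827883, 2], dtype="int32")
out = paddle.sum(x, axis=1, keepdim=True)  # still reports a precision error
```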

In addition, a large number of APIs/kernels depend on phi::funcs::ReduceKernel, so they should all benefit from these changes.

The changes in this PR include:

  • Big-tensor related
  1. Replaced many int32 usages, mainly involving numel, reduce_num, stride, index, etc.
  2. For index, we follow Torch's approach: use int32 when it cannot overflow, and int64 only when it would, because int64 increments are more expensive (see the sketch after this list).
  3. The bundled cub version is too old to handle tensors with more than int32-max elements, so cub is not used for big tensors.
  4. Testing showed that the earlier Eigen-based big-tensor support in SumRawKernel suffered from hangs, error 700, and precision issues, so it now uses phi::funcs::ReduceKernel instead.
  • Precision related
  1. Tightened the floating-point equality tolerance from 1e-8 to 1e-15; otherwise some results disagree with Torch.
  2. Changed the max_grad/min_grad rule: when axis=None and multiple elements tie for the forward maximum, the backward pass must split the gradient evenly among them.
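A minimal sketch of the index-type selection described in item 2 of the big-tensor list (the function name is illustrative; the real logic lives in the C++ kernels):

```python
INT32_MAX = 2**31 - 1

def choose_index_type(numel: int) -> str:
    # Prefer 32-bit indexing because 64-bit increments are slower on GPU;
    # fall back to 64-bit only when a linear index could overflow int32.
    return "int32" if numel <= INT32_MAX else "int64"

# 3 * 715827883 * 2 = 4294967298 elements, just past the int32 limit.
assert choose_index_type(3 * 715827883 * 2) == "int64"
assert choose_index_type(1024 * 1024) == "int32"
```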

Paddle's max/min differ from Torch's.
Before this change:

  • Paddle's max/min simply return the maximum/minimum value; in the backward pass, out_grad is copied to every position that held the forward maximum/minimum.
  • Torch:
    • If axis is None, it returns the maximum/minimum value; in the backward pass, out_grad is split evenly across all positions that held the forward maximum/minimum.
    • If axis is not None, it returns the maximum/minimum along that dimension plus the index of one maximum per slice (if several values tie, only one index is provided); in the backward pass, out_grad is copied to the indexed position only.

Paddle cannot change max/min to also return indices, since that would be an incompatible upgrade, and there is currently no need for it. After this change, the behavior is:

  • If axis is None, return the maximum/minimum value; in the backward pass, split out_grad evenly across all positions that held the forward maximum/minimum (see the example below).
  • If axis is not None, return the maximum/minimum along that dimension without indices; in the backward pass, copy out_grad to every position that held the forward maximum/minimum. (This could also be changed to an even split; to be discussed.)
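A minimal sketch of the new axis=None behavior described above (values chosen for illustration):

```python
import paddle

# Two elements tie for the global maximum; under the new rule the
# backward pass splits out_grad evenly among the tied positions.
x = paddle.to_tensor([1.0, 3.0, 3.0], stop_gradient=False)
y = paddle.max(x)   # axis=None: global maximum -> 3.0
y.backward()
print(x.grad)       # new behavior: [0.0, 0.5, 0.5]
                    # old behavior copied instead: [0.0, 1.0, 1.0]
```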

The diff between the copy and even-split implementations of max_grad/min_grad:

[screenshot of the max_grad/min_grad kernel diff]

The even split therefore does more work, since it must also count the tied maxima before dividing. Taking paddle.max(Tensor([20010241024],"float32")) as an example, max_grad time rises by 73.46%.
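A reference sketch of why the even split costs more (NumPy, for illustration only; the real code is a CUDA kernel): it needs an extra reduction to count the tied maxima.

```python
import numpy as np

def max_grad_copy(x, out_grad):
    mask = (x == x.max())                  # old rule: one reduction (max)
    return mask * out_grad                 # copy out_grad to every maximum

def max_grad_split(x, out_grad):
    mask = (x == x.max())                  # new rule: the same reduction...
    return mask * (out_grad / mask.sum())  # ...plus a second one to count ties

x = np.array([1.0, 3.0, 3.0])
print(max_grad_copy(x, 1.0))   # [0. 1. 1.]
print(max_grad_split(x, 1.0))  # [0.  0.5 0.5]
```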

In addition:

  1. For the reasons above, some of the basic reduce logic changed, and the resulting compile-time ripple effects required touching many kernel files.
  2. The precision changes likewise rippled into the composite-operator logic.
  3. The corresponding unit-test logic was updated to match.

Pcard-67164

@paddle-bot
Copy link

paddle-bot bot commented Mar 28, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-ci-bot

paddle-ci-bot bot commented Apr 21, 2025

Sorry to inform you that bddf7db's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

LGTM

@wanghuancoder wanghuancoder merged commit 29c3bf3 into PaddlePaddle:develop Apr 23, 2025
36 of 39 checks passed
YqGe585 pushed a commit to YqGe585/Paddle that referenced this pull request May 7, 2025
wanghuancoder added a commit to wanghuancoder/Paddle that referenced this pull request May 27, 2025
wanghuancoder added a commit that referenced this pull request Jun 3, 2025
* refine forrange (#72360)
* reduce support big tensor (#71970)
* [PHI] Fix gridDim limit for reduce kernel (#72507)
* [API] isclose support bigtensor (#72516)
* [API] isnan isinf isfinite support bigtensor (#72517)
* [PHI] Fix cum kernel for big tensor (#72562)
* [PHI] Preliminary fix for elementwise broadcast int32 shape overflow (#72584)
* [PHI] Align linalg.solve kernel with torch (#72608)
* Update strided copy kernel (#72662)
* [PHI] Fix grid sample kernel for big tensor (#72628)
* [PHI] Fix argsort big tensor bug (#72712)
* [PHI] Fix contiguous kernel for big tensor (#72705)
* [PHI] Fix flatten and split kernel for big tensor (#72634)
* [PHI] Fix out-of-bound issue of paddle.take_along_axis (#72757)
* [PHI] fix paddle.diag with big tensor (#72638)
* [API] fix paddle.cross with big tensor (#72652)
* [PHI] Fix paddle.where api for big tensor (#72717)
* [PHI] Fix bincount kernel for big tensor (#72706)
* [PHI] Fix full_like kernel for big tensor (#72831)
* [API] Fix int overflow and float16 support for paddle.frac (#72815)
* [PHI] Align paddle.inner with torch in matmul logic (#72843)
* [PHI] Fix paddle.var & paddle.std float16 overflow (#72650)
* [PHI] Fix logsumexp precision problem (#72681)
* [Accuracy diff No.55-56、76-77] Fix accuracy diff for var&std API (#72879)
* [Accuracy diff No.21] Fix accuracy diff for heaviside API (#72894)

--------

Co-authored-by: Shuhao Liang <50269654+lshpku@users.noreply.github.com>
Co-authored-by: Qianyue He <46109954+Enigmatisms@users.noreply.github.com>
Co-authored-by: Lei Ding <69283446+Dmovic@users.noreply.github.com>
Co-authored-by: ggggxm <66855582+ggggxm@users.noreply.github.com>
Co-authored-by: xkkkkkk23 <xiekeke@baidu.com>
Co-authored-by: Zx <zhangxiao35@baidu.com>
Co-authored-by: huangjiyi <43315610+huangjiyi@users.noreply.github.com>
Co-authored-by: ooo oo <106524776+ooooo-create@users.noreply.github.com>