
Conversation

@cangtianhuang (Contributor) commented Jul 16, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

Diagnosis:

  1. ThrustCumsumKernel itself has very large precision error
  2. Under fp32, large tensors accumulate rounding error (a quick reproduction follows this list)
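The fp32 accumulation error in item 2 is easy to reproduce outside the kernel. A minimal NumPy check (tensor size chosen for illustration, not taken from the failing cases):

    import numpy as np

    # Fill a large float32 tensor with 0.01 and compare the final
    # prefix sum against a float64 reference: the float32 running
    # sum drifts further from the true value as the length grows.
    n = 10_000_000
    x = np.full(n, 0.01, dtype=np.float32)
    fp32_tail = np.cumsum(x)[-1]                     # accumulated in fp32
    fp64_tail = np.cumsum(x.astype(np.float64))[-1]  # fp64 reference
    print(fp32_tail, fp64_tail, abs(fp32_tail - fp64_tail))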

Fix:

  1. Remove the ThrustCumsumKernel branch entirely, falling through to the subsequent CUDA cub computation
  2. Make BlockPrefixCallbackOp use the Kahan summation algorithm; reference: https://en.wikipedia.org/wiki/Kahan_summation_algorithm
  3. For the special case of the LogAddExp operator, BlockPrefixCallbackOp uses Kahan + online scaling for better numerical stability (both ideas are sketched below this list)
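The actual change lives in the CUDA C++ BlockPrefixCallbackOp; the following is only a minimal Python sketch of the two numerical ideas, with illustrative function names that do not appear in the kernel:

    import math

    def kahan_step(running_sum, compensation, value):
        # One step of Kahan (compensated) summation: `compensation`
        # carries the low-order bits lost when `value` was folded
        # into `running_sum`, and feeds them back into the next add.
        y = value - compensation
        t = running_sum + y
        compensation = (t - running_sum) - y
        return t, compensation

    def logaddexp_online(running_log, value_log):
        # Stable log(exp(a) + exp(b)): rescale by the running maximum
        # so the exponentials stay in range (the "online scale" idea
        # applied to the LogAddExp case).
        m = max(running_log, value_log)
        if m == float("-inf"):
            return float("-inf")
        return m + math.log(math.exp(running_log - m) +
                            math.exp(value_log - m))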

Other changes:

  1. Rework np_logcumsumexp_grad to avoid computing np.log(-dout) directly, which raises errors when dout > 0 (see the sketch after this item)
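A minimal sketch of the masking idea (the helper name is hypothetical): np.where pre-masks the argument of np.log so that non-positive values never reach it:

    import numpy as np

    def masked_log(v, mask):
        # Evaluate log only where `mask` holds; elsewhere return -inf.
        # Pre-masking the argument keeps non-positive values out of
        # np.log, which would otherwise raise invalid-value warnings.
        return np.where(mask, np.log(np.where(mask, v, 1.0)), -np.inf)

    dout = np.array([0.5, -0.3, 0.0, 2.0])
    log_pos = masked_log(dout, dout > 0)   # log of the positive part
    log_neg = masked_log(-dout, dout < 0)  # log of the negative part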

Testing:

After the fix, the test cases fall roughly into the following categories:

  1. Small accumulation dimension: passes directly
  2. Very large accumulation dimension: precision error exceeds 1e-2:
[screenshot]

Since error accumulation is inherently larger for very large tensors, changing atol and rtol to 1 makes these tests mostly pass:
[screenshot]

  3. fp16 dtype: torch computes incorrect results, while paddle uses MPType and matches the theoretical value:
[screenshot]

Added to torch_error_skip to skip the precision check.

  4. torch directly reports CUDA error 700:
[screenshot]

Added to torch_error_skip to skip the precision check.

Additional tests:

  1. All paddle_only tests pass:
[screenshot]
  2. Theoretical-value analysis of paddle vs. torch on fixed inputs, for the following test cases:

    paddle.cumsum(x=Tensor([4294967297], "float32"))
    paddle.logcumsumexp(x=Tensor([4294967297], "float32"))

The tensor is filled with 0.01 via full, so the expected final cumsum value is about 42949672.97 and the expected final logcumsumexp value is about 22.1907 (a quick closed-form check follows this item):
[screenshots]

Both frameworks' cumsum results differ slightly from the theoretical value but are numerically close and take the same time. For logcumsumexp, torch is closer to the theoretical value, while paddle still deviates and runs far slower than torch; a further algorithmic fix is pending.
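For reference, both expected values quoted above follow from a closed form; a quick check:

    import math

    n = 4294967297   # tensor length from the test cases above
    fill = 0.01      # value the tensor is filled with

    # cumsum: the last prefix sum is just n * fill.
    print(n * fill)            # ~42949672.97

    # logcumsumexp: the last value is log(n * exp(fill))
    #             = fill + log(n).
    print(fill + math.log(n))  # ~22.1907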

Pcard-85711

@paddle-bot bot commented Jul 16, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@cangtianhuang marked this pull request as ready for review July 17, 2025 07:07
@wanghuancoder (Contributor) left a comment

LGTM. The kernel changes are fairly large, so asking Shuhao to help double-check them.

@lshpku merged commit 0ca88c4 into PaddlePaddle:develop Jul 17, 2025
92 of 94 checks passed
@cangtianhuang deleted the fix-cumsum branch July 26, 2025 14:30
