Skip to content

Conversation

@fxyfxy777
Copy link
Contributor

@fxyfxy777 fxyfxy777 commented Aug 6, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

修复 paddle.distpaddle.nn.functional.normalize

两个算子共用一个kernel,所以放在一起分析

  1. 前向计算报错(精度误差)
    • Paddle 前向输出为 inf,而 Torch 输出正常数值;
    • 原因:在 FP16 模式下,reduce_sum(pow(...)) 的中间计算过程中发生溢出;
  • 解决方案:将中间计算过程提升到 float32,最后再 cast 回 float16 输出,避免溢出。
  1. 反向计算报错一般情况
    • 修改前向后,反向结果仍不一致,其中部分case出现以下情况
    • 现象:Paddle 可正常求导,Torch 输出为 infnan
    • 提高精度(如使用 float32)后,Paddle 与 Torch 输出一致;
    • 示例(float16 精度报错):
      [accuracy error] backward paddle.nn.functional.normalize(x=Tensor([4, 5, 6, 3579139], "float16"), p=4) Not equal to tolerance rtol=0.01, atol=0.01 Greatest absolute difference: nan
    • 示例(float32 精度一致):
      [Pass] paddle.nn.functional.normalize(x=Tensor([4, 5, 6, 3579139], "float32"), p=4)
  • 解决方案:这部分case写入到torch_error_skip。
  1. 反向计算报错特殊情况
    -现象:Paddle 和 Torch 都能反向计算出结果,但差距极大(最大超过 160)。
    -原因:查看torch源码,发现两者反向计算公式基本一致,但对特殊值 0 的处理不同。
  • Paddle:使用 1 / (norm + eps),导致较大数值。

  • Torch:使用 mask 将这些位置梯度直接置为 0。

  • 优化:参考 Torch 的处理方式,对 norm == 0 的位置直接 mask 为 0,误差从上百缩小至 0.2 以内。
    示例前后对比:

  • 修改前

    输入: Tensor([2281701, 1], dtype="float32"), axis=1 最大误差: 163.0 
  • 修改后

    输入: Tensor([2281701, 1], dtype="float32"), axis=1 最大误差: 0.0625 

4.最终修改内容
• 在 float16 模式下,将中间计算提至 float32 精度,避免溢出。
• 修复整数类型溢出问题,使用 int64_t 替代潜在溢出的 int32。
• 对于 Torch 输出 inf/nan 的反向 case,跳过精度对齐测试。
• 优化反向 norm == 0 情况的处理方式,使用 mask 替代除法以减少误差。
Pcard-92269

@paddle-bot
Copy link

paddle-bot bot commented Aug 6, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@fxyfxy777
Copy link
Contributor Author

/re-run all-failed

@fxyfxy777
Copy link
Contributor Author

/re-run all-failed

@fxyfxy777
Copy link
Contributor Author

/re-run all-failed

Copy link
Contributor

@wanghuancoder wanghuancoder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@wanghuancoder wanghuancoder merged commit 239dff3 into PaddlePaddle:develop Aug 7, 2025
65 of 66 checks passed
Enigmatisms pushed a commit to Enigmatisms/Paddle that referenced this pull request Aug 9, 2025
* fix: fix normalization logic in p_norm kernels * Code Formatting * zancun * fix_dist_normalize * back * huifu zhushi * pre-commit * Fix code style * Standardize code comments
@fxyfxy777 fxyfxy777 deleted the fix_normaliz branch September 9, 2025 06:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

3 participants