
Conversation

@Enigmatisms (Contributor) commented May 6, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

This PR fixes an overflow on large tensors in several elementwise operations with broadcasting, where shape inference was done in int32. The following operators are fixed:

  • bitwise_leftshift & bitwise_rightshift
  • equal
  • fmin
  • hypot
  • logical_or
  • masked_fill
  • functional.normalize
  • functional.pairwise_distance

argsort has not been tested yet; the same fix is expected to apply, but the available devices do not have enough GPU memory to run a large-tensor test.

In addition, the current fix only makes the simplest corrections, those tied to the API correctness-scan results; in practice the codebase still has many other places that use int for shape inference. The sketch below illustrates the underlying overflow.
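The failure mode is ordinary integer wrap-around: once the broadcast output's element count exceeds 2^31 - 1, a 32-bit count wraps to a much smaller (or negative) value and the output is sized incorrectly. A minimal sketch of the arithmetic (pure Python emulating a C int32; not the actual Paddle shape-inference code):

def to_int32(x):
    # Emulate C int32 wrap-around on overflow.
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

numel = 4294967297        # elements in paddle.ones([4294967297]); equals 2**32 + 1
print(to_int32(numel))    # 1 -> the op behaves as if the tensor held a single element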

Examples

Only a few examples are given here; each is a unit test that used to produce incorrect results and now passes:

bitwise_left_shift:

paddle.bitwise_left_shift(paddle.to_tensor([1], dtype="uint8"), paddle.ones([4294967297], dtype="uint8"))
paddle.bitwise_left_shift(paddle.to_tensor([1], dtype="int16"), paddle.ones([4294967297], dtype="int16"))
paddle.bitwise_left_shift(paddle.ones([4294967297], dtype="uint8"), paddle.to_tensor([1], dtype="uint8"))
paddle.bitwise_left_shift(paddle.ones([4294967297], dtype="int16"), paddle.to_tensor([1], dtype="int16"))

fmin:

paddle.fmin(paddle.to_tensor([1], dtype="int64"), paddle.ones([2281701379], dtype="int64"))
paddle.fmin(paddle.ones([2147483649], dtype="int64"), paddle.to_tensor([1], dtype="int64"))
paddle.fmin(paddle.ones([2281701379], dtype="int64"), paddle.to_tensor([1], dtype="int64"))

masked_fill:

paddle.masked_fill(paddle.ones([4294967295], dtype="float16"), paddle.ones([4294967295], dtype="bool"), -0.7255859375)
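As a hedged illustration (this check is not taken from the PR's test suite), each case above can be verified by asserting that the broadcast output takes the shape of the large operand:

import paddle

out = paddle.fmin(paddle.to_tensor([1], dtype="int64"),
                  paddle.ones([2281701379], dtype="int64"))
# Before the fix, the int32 shape inference made this assertion fail.
assert out.shape == [2281701379]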

Performance analysis

The current changes do not yet make the GPU kernels compute with int64 internally, so nsys profiling found no significant performance difference.
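The exact profiling commands are not included in the PR; a plausible micro-benchmark (the script name and iteration count are assumptions) that could be run under nsys profile on both branches looks like:

import paddle

# Hypothetical benchmark, not part of the PR: exercise one fixed op on a large
# tensor; run as `nsys profile python bench.py` before and after the fix.
x = paddle.ones([2281701379], dtype="int64")
y = paddle.to_tensor([1], dtype="int64")
for _ in range(10):
    out = paddle.fmin(y, x)
paddle.device.synchronize()  # ensure all GPU kernels finish before the profile ends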

Pcard-89620

@paddle-bot commented May 6, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@Enigmatisms force-pushed the broadcast_shape_overflow branch from 29eef2a to 7a5d4f6 on May 6, 2025 06:23
@Enigmatisms marked this pull request as ready for review on May 6, 2025 08:48
@lshpku merged commit 88243ff into PaddlePaddle:develop on May 7, 2025
43 of 45 checks passed
wanghuancoder added a commit that referenced this pull request Jun 3, 2025
* refine forrange (#72360)
  * refine forrange
  * refine forrange
* reduce support big tensor (#71970)
  * reduce support big tensor
* [PHI] Fix gridDim limit for reduce kernel (#72507)
* [API] isclose support bigtensor (#72516)
  * isclose support bigtensor
  * refine
* [API] isnan isinf isfinite support bigtensor (#72517)
  * isnan isinf isfinite support bigtensor
  * refine
* [PHI] Fix cum kernel for big tensor (#72562)
* [PHI] Preliminary fix for elementwise broadcast int32 shape overflow (#72584)
* [PHI] Align linalg.solve kernel with torch (#72608)
* Update strided copy kernel (#72662)
* [PHI] Fix grid sample kernel for big tensor (#72628)
* [PHI] Fix argsort big tensor bug (#72712)
  * [PHI] Fixed argsort big tensor bug
  * [PHI] Fixed shape mismatch problem.
* [PHI] Fix contiguous kernel for big tensor (#72705)
* [PHI] Fix flatten and split kernel for big tensor (#72634)
* [PHI] Fix out-of-bound issue of paddle.take_along_axis (#72757)
* [PHI] fix paddle.diag with big tensor (#72638)
* [API] fix paddle.cross with big tensor (#72652)
* [PHI] Fix paddle.where api for big tensor (#72717)
* [PHI] Fix bincount kernel for big tensor (#72706)
  * fix bincount kernel for big tensor
  * use HostAlloc to alloc memory
  * add cpu test case
* [PHI] Fix full_like kernel for big tensor (#72831)
* [API] Fix int overflow and float16 support for paddle.frac (#72815)
* [PHI] Align paddle.inner with torch in matmul logic (#72843)
* [PHI] Fix paddle.var & paddle.std float16 overflow (#72650)
* [PHI] Fix logsumexp precision problem (#72681)
  * [PHI] Debug for logsumexp, bug source found
  * [PHI] Removed GetNumBlocks func to get correct logsumexp
  * [PHI] Removed redundant debug VLOG
  * [PHI] Elegant grid bounded solution
* [Accuracy diff No.55-56、76-77] Fix accuracy diff for var&std API (#72879)
* [Accuracy diff No.21] Fix accuracy diff for heaviside API (#72894)

Co-authored-by: Shuhao Liang <50269654+lshpku@users.noreply.github.com>
Co-authored-by: Qianyue He <46109954+Enigmatisms@users.noreply.github.com>
Co-authored-by: Lei Ding <69283446+Dmovic@users.noreply.github.com>
Co-authored-by: ggggxm <66855582+ggggxm@users.noreply.github.com>
Co-authored-by: xkkkkkk23 <xiekeke@baidu.com>
Co-authored-by: Zx <zhangxiao35@baidu.com>
Co-authored-by: huangjiyi <43315610+huangjiyi@users.noreply.github.com>
Co-authored-by: ooo oo <106524776+ooooo-create@users.noreply.github.com>
