Skip to content

Conversation

@aquagull
Copy link
Contributor

@aquagull aquagull commented Apr 9, 2025

PR Category

User Experience

PR Types

Bug fixes

Description

当输入都为int32时,输出会是int64。

import paddle import torch x1 = paddle.zeros(shape=[0, 1, 1], dtype='int32') x2 = paddle.zeros(shape=[1,1], dtype='int32') res = paddle.tensordot(x1, x2) print(res.dtype) 
@paddle-bot
Copy link

paddle-bot bot commented Apr 9, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Apr 9, 2025
Comment on lines +6013 to +6016
y = y.sum(dim_y, dtype=y.dtype).reshape(shape_y)
elif sy == 1:
shape_x[dim_x] = 1
x = x.sum(dim_x).reshape(shape_x)
x = x.sum(dim_x, dtype=x.dtype).reshape(shape_x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的逻辑应该是中间结果用int64存,然后最终结果转成int32,还是保持中间结果也是int32?如果输入中含有接近int32最大值的场景,哪一种比较合适呢?可以跟numpy或者torch对比一下这种场景下的数值溢出情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个修改的意图是明确指定 sum 操作的输出数据类型与其输入数据类型保持一致。溢出这一块应该是由开发者处理,框架应该保证正确性(未指定dtype,输入int32,输出int32)

Copy link
Contributor

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit 2925993 into PaddlePaddle:develop Apr 17, 2025
33 of 34 checks passed
YqGe585 pushed a commit to YqGe585/Paddle that referenced this pull request May 7, 2025
@aquagull aquagull deleted the tensordot branch August 8, 2025 08:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

2 participants