Skip to content

Conversation

@HydrogenSulfate
Copy link
Contributor

@HydrogenSulfate HydrogenSulfate commented Jul 21, 2025

PR Category

Operator Mechanism

PR Types

New features

Description

Pcard-75624

支持gather_double_grad动态图二阶组合反向,并删除原有的一阶组合

精度测试代码如下

Details

import paddle import torch import numpy as np for i in range(50): for axis in [-1, -2, 0, 1]: x_np = np.random.randn(50, 50) index_np = np.random.randint(0, x_np.shape[axis], size=[2048]) # paddle x_pd = paddle.to_tensor(x_np, paddle.float32) x_pd.stop_gradient = False index_pd = paddle.to_tensor(index_np, paddle.int64) out_pd = paddle.gather(x_pd, index_pd, axis=axis) dout_pd = paddle.randn_like(out_pd) dout_pd.stop_gradient = False dx_pd, = paddle.grad(out_pd, x_pd, dout_pd, create_graph=True) ddx_pd = paddle.randn_like(dx_pd) ddout_pd, = paddle.grad(dx_pd, dout_pd, ddx_pd, create_graph=True) # torch x_pt = torch.from_dlpack(x_pd.detach()).requires_grad_(True) index_pt = torch.from_dlpack(index_pd.detach()) out_pt = torch.index_select(x_pt, dim=axis, index=index_pt) np.testing.assert_allclose(out_pd.numpy(), out_pt.detach().cpu().numpy()) dout_pt = torch.from_dlpack(dout_pd.detach()).requires_grad_(True) dx_pt, = torch.autograd.grad(out_pt, x_pt, dout_pt, create_graph=True) np.testing.assert_allclose(dx_pd.numpy(), dx_pt.detach().cpu().numpy(), 1e-5, 1e-5) ddx_pt = torch.from_dlpack(ddx_pd.detach()).requires_grad_(True) ddout_pt, = torch.autograd.grad(dx_pt, dout_pt, ddx_pt, create_graph=True) np.testing.assert_allclose(ddout_pd.numpy(), ddout_pt.detach().cpu().numpy(), 1e-5, 1e-5)

@paddle-bot
Copy link

paddle-bot bot commented Jul 21, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@HydrogenSulfate HydrogenSulfate force-pushed the add_gather_double_grad branch from dc09c83 to 0786688 Compare July 23, 2025 02:30
@HydrogenSulfate HydrogenSulfate merged commit f756146 into PaddlePaddle:develop Jul 23, 2025
54 of 55 checks passed
@HydrogenSulfate HydrogenSulfate deleted the add_gather_double_grad branch July 23, 2025 09:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

3 participants