
Conversation

@HydrogenSulfate (Contributor) commented Jul 2, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

Pcard-75624

Fix the issue where index_put_double_grad produces an incorrect ddout when only ddv is given. The verification code is as follows:

Code

def test_index_put_fwd_bwd_double_bwd():
    import paddle
    from paddle.framework import core
    import numpy as np
    import torch

    for place in ["gpu", "cpu"]:
        for accumulate in [True, False]:
            for x_shape, indices_shape, value_shape in [
                ([16], [5], [5]),
                ([16, 16], [15, 2], [15]),
                ([12, 13, 14], [27, 1], [27, 13, 14]),
                ([12, 13, 14], [27, 2], [27, 14]),
                ([12, 13, 14], [27, 3], [27]),
                ([12, 13, 14], [19, 3], [19]),
            ]:
                paddle.device.set_device(place)
                try:
                    core.set_prim_eager_enabled(True)
                    n_indices = indices_shape[0]
                    index_dim_size = indices_shape[1] if len(indices_shape) > 1 else 1
                    x_np = np.random.randn(*x_shape)
                    indices_np = tuple(
                        [
                            np.random.randint(0, x_shape[i], [n_indices])
                            for i in range(max(index_dim_size, 1))
                        ]
                    )
                    value_np = np.random.randn(*value_shape).astype("float32")

                    # run paddle
                    x_pd = paddle.to_tensor(
                        x_np.copy(), "float32", stop_gradient=False, place=place
                    )
                    indices_pd = [
                        paddle.to_tensor(
                            indice.copy(), "int64", stop_gradient=True, place=place
                        )
                        for indice in indices_np
                    ]
                    value_pd = paddle.to_tensor(
                        value_np.copy(), "float32", stop_gradient=False, place=place
                    )
                    out_pd = paddle.index_put(
                        x_pd, indices_pd, value_pd, accumulate=accumulate
                    )
                    # out_pd = paddle.tanh(out_pd)
                    #
                    dout_np = np.random.randn(*out_pd.shape)
                    dout_pd = paddle.to_tensor(
                        dout_np.copy(), "float32", stop_gradient=False, place=place
                    )
                    dout_pd.stop_gradient = False
                    dx_pd = paddle.grad(out_pd, x_pd, dout_pd, create_graph=True)[0]
                    ddx_np = np.random.randn(*dx_pd.shape)
                    dvalue_pd = paddle.grad(out_pd, value_pd, dout_pd, create_graph=True)[0]
                    ddvalue_np = np.random.randn(*dvalue_pd.shape)
                    ddx_pd = paddle.to_tensor(
                        ddx_np.copy(), "float32", stop_gradient=False, place=place
                    )
                    ddvalue_pd = paddle.to_tensor(
                        ddvalue_np.copy(), "float32", stop_gradient=False, place=place
                    )
                    ddout1_pd = paddle.grad(dx_pd, dout_pd, ddx_pd, create_graph=True)[0]
                    ddout2_pd = paddle.grad(
                        dvalue_pd, dout_pd, ddvalue_pd, create_graph=True
                    )[0]
                    ddout3_pd = paddle.grad(
                        [dvalue_pd, dx_pd], dout_pd, [ddvalue_pd, ddx_pd], create_graph=True
                    )[0]

                    # run torch
                    x_pt = torch.as_tensor(
                        x_np,
                        dtype=torch.float32,
                        device="cuda" if place == "gpu" else place,
                    ).requires_grad_(True)
                    indices_pt = [
                        torch.as_tensor(
                            indice,
                            dtype=torch.int64,
                            device="cuda" if place == "gpu" else place,
                        ).requires_grad_(False)
                        for indice in indices_np
                    ]
                    value_pt = torch.as_tensor(
                        value_np,
                        dtype=torch.float32,
                        device="cuda" if place == "gpu" else place,
                    ).requires_grad_(True)
                    out_pt = torch.index_put(
                        x_pt, indices_pt, value_pt, accumulate=accumulate
                    )
                    # out_pt = torch.tanh(out_pt)
                    dout_pt = torch.as_tensor(
                        dout_np,
                        dtype=torch.float32,
                        device="cuda" if place == "gpu" else place,
                    ).requires_grad_(True)
                    dout_pt.stop_gradient = False
                    dx_pt = torch.autograd.grad(out_pt, x_pt, dout_pt, create_graph=True)[0]
                    dvalue_pt = torch.autograd.grad(
                        out_pt, value_pt, dout_pt, create_graph=True
                    )[0]
                    ddx_pt = torch.as_tensor(
                        ddx_np.copy(),
                        dtype=torch.float32,
                        device="cuda" if place == "gpu" else place,
                    ).requires_grad_(True)
                    ddvalue_pt = torch.as_tensor(
                        ddvalue_np.copy(),
                        dtype=torch.float32,
                        device="cuda" if place == "gpu" else place,
                    ).requires_grad_(True)
                    ddout1_pt = torch.autograd.grad(
                        dx_pt, dout_pt, ddx_pt, create_graph=True
                    )[0]
                    ddout2_pt = torch.autograd.grad(
                        dvalue_pt, dout_pt, ddvalue_pt, create_graph=True
                    )[0]
                    ddout3_pt = torch.autograd.grad(
                        [dvalue_pt, dx_pt], dout_pt, [ddvalue_pt, ddx_pt], create_graph=True
                    )[0]

                    # compare result
                    ## output
                    if accumulate:
                        np.testing.assert_allclose(
                            out_pt.detach().cpu().numpy(),
                            out_pd.numpy(),
                            1.3e-6,
                            1e-5,
                            err_msg=f"accumulate={accumulate}\nx_np:\n{x_np.shape}\nindices_np:\n{[q.shape for q in indices_np]}\nvalue_np:\n{value_np.shape}",
                        )
                    ## 1-order grad
                    np.testing.assert_allclose(
                        dx_pt.detach().cpu().numpy(),
                        dx_pd.numpy(),
                        1.3e-6,
                        1e-5,
                        err_msg=f"accumulate={accumulate}\nx_np:\n{x_np.shape}\nindices_np:\n{[q.shape for q in indices_np]}\nvalue_np:\n{value_np.shape}",
                    )
                    # if accumulate:
                    np.testing.assert_allclose(
                        dvalue_pt.detach().cpu().numpy(),
                        dvalue_pd.numpy(),
                        1.3e-6,
                        1e-5,
                        err_msg=f"accumulate={accumulate}\nx_np:\n{x_np.shape}\nindices_np:\n{[q.shape for q in indices_np]}\nvalue_np:\n{value_np.shape}",
                    )
                    ## 2-order grad
                    np.testing.assert_allclose(
                        ddout1_pt.detach().cpu().numpy(),
                        ddout1_pd.numpy(),
                        1.3e-6,
                        1e-5,
                        err_msg=f"accumulate={accumulate}\nx_np:\n{x_np.shape}\nindices_np:\n{[q.shape for q in indices_np]}\nvalue_np:\n{value_np.shape}",
                    )
                    if accumulate:
                        np.testing.assert_allclose(
                            ddout2_pt.detach().cpu().numpy(),
                            ddout2_pd.numpy(),
                            1.3e-6,
                            1e-5,
                            err_msg=f"accumulate={accumulate}\nx_np:\n{x_np.shape}\nindices_np:\n{[q for q in indices_np]}\nvalue_np:\n{value_np.shape}\nddout:{dout_np}\nddvalue:{ddvalue_np}\n",
                        )
                        np.testing.assert_allclose(
                            ddout3_pt.detach().cpu().numpy(),
                            ddout3_pd.numpy(),
                            1.3e-6,
                            1e-5,
                            err_msg=f"accumulate={accumulate}\nx_np:\n{x_np.shape}\nindices_np:\n{[q.shape for q in indices_np]}\nvalue_np:\n{value_np.shape}",
                        )
                finally:
                    core.set_prim_eager_enabled(False)
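For readers who only want the specific path this PR fixes, here is a distilled sketch (not part of the PR itself): only ddvalue is fed into the second-order paddle.grad call, and the resulting ddout is checked against what linearity of index_put implies. The shapes, the accumulate=True setting, and the expected-value construction below are illustrative assumptions, not taken from the test above.

# Distilled sketch of the fixed case: only ddvalue is supplied to the
# second-order grad call, so ddout should be ddvalue scattered into zeros
# (index_put is linear in x and value). Shapes and accumulate=True are
# illustrative assumptions.
import numpy as np
import paddle
from paddle.framework import core

core.set_prim_eager_enabled(True)

x = paddle.randn([16], dtype="float32")
x.stop_gradient = False
indices = [paddle.randint(0, 16, [5], dtype="int64")]
value = paddle.randn([5], dtype="float32")
value.stop_gradient = False

out = paddle.index_put(x, indices, value, accumulate=True)
dout = paddle.randn(out.shape, dtype="float32")
dout.stop_gradient = False

# First-order grad w.r.t. value, keeping the graph for double backward.
dvalue = paddle.grad(out, value, dout, create_graph=True)[0]

# Second order: only ddvalue is given (no ddx), the path that previously
# produced a wrong ddout.
ddvalue = paddle.randn(dvalue.shape, dtype="float32")
ddout = paddle.grad(dvalue, dout, ddvalue)[0]

# With ddx == 0 and accumulate=True, ddout should equal ddvalue
# scattered (with accumulation) into a zero tensor.
expected = paddle.index_put(
    paddle.zeros_like(out), indices, ddvalue, accumulate=True
)
np.testing.assert_allclose(ddout.numpy(), expected.numpy(), 1e-6, 1e-5)

core.set_prim_eager_enabled(False)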

@paddle-bot (bot) commented Jul 3, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@HydrogenSulfate merged commit 57e8475 into PaddlePaddle:develop Jul 3, 2025
77 of 78 checks passed
@HydrogenSulfate deleted the fix_index_put_double_grad branch July 3, 2025 05:12
