
Commit 92f49a6

[Prim] Add stack_double_grad (#63161)
* add stack_double_grad composite API
* add TestStackDoubleGradCheck
1 parent c5f73f6 commit 92f49a6


4 files changed (48 additions, 2 deletions)


paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@
     "conv3d_double_grad",
     "depthwise_conv2d_grad_grad",
     "concat_double_grad",
+    "stack_double_grad",
     "expand_grad",
     "argsort_grad",
     "eigh_grad",

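This list appears to be the eager code generator's ops_to_fill_zero_for_empty_grads set: for ops named here, gradient slots that arrive empty are zero-filled before the grad kernel runs, so the new stack_double_grad op is registered alongside the other double-grad entries. A rough illustrative helper, not code from the generator (the function name and exact behavior are assumptions):

import paddle

def fill_zero_for_empty_grads(grad_slots, forward_refs):
    # Hypothetical helper mirroring the assumed behavior: replace missing
    # (None) gradient slots with zeros shaped like the matching forward tensor.
    return [
        g if g is not None else paddle.zeros_like(ref)
        for g, ref in zip(grad_slots, forward_refs)
    ]

refs = [paddle.randn([2, 3]), paddle.randn([2, 3])]
slots = [None, paddle.ones([2, 3])]
print([t.shape for t in fill_zero_for_empty_grads(slots, refs)])  # [[2, 3], [2, 3]]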
paddle/phi/api/yaml/backward.yaml

Lines changed: 7 additions & 0 deletions
@@ -2375,6 +2375,12 @@
   inplace : (out_grad -> x_grad)
   backward: squeeze_double_grad
 
+- backward_op : stack_double_grad
+  forward : stack_grad (Tensor[] x, Tensor grad_out, int axis=0) -> Tensor[](grad_x)
+  args : (Tensor[] grad_x_grad, int axis = 0)
+  output : Tensor(grad_out_grad)
+  invoke : stack(grad_x_grad, axis)
+
 - backward_op : stack_grad
   forward : stack (Tensor[] x, int axis) -> Tensor(out)
   args : (Tensor[] x, Tensor out_grad, int axis)
@@ -2389,6 +2395,7 @@
     data_type : out_grad
   no_need_buffer : x
   composite : stack_grad(x, out_grad, axis, x_grad)
+  backward: stack_double_grad
 
 - backward_op : stanh_grad
   forward : stanh(Tensor x, float scale_a, float scale_b) -> Tensor(out)
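Because stack is linear, its first-order backward (stack_grad) is just an unstack of out_grad, and differentiating that backward again only requires restacking the incoming per-input gradients along the same axis, which is why the new op can be a one-line invoke of stack(grad_x_grad, axis). A minimal dygraph sketch of the effect, illustrative only and not part of the commit:

import paddle

x1 = paddle.randn([2, 3])
x1.stop_gradient = False
x2 = paddle.randn([2, 3])
x2.stop_gradient = False

out = paddle.stack([x1, x2], axis=0)
loss = (out * out).sum()  # quadratic head so second-order gradients are non-trivial

# First backward pass: routes through stack_grad (an unstack of out_grad).
g1, g2 = paddle.grad(loss, [x1, x2], create_graph=True)

# Second backward pass: the upstream gradients flowing into the stacked node
# are restacked by stack_double_grad on the way back through the graph.
gg1, gg2 = paddle.grad(g1.sum() + g2.sum(), [x1, x2])

# For loss = sum(out**2), d(loss)/dx_i = 2*x_i, so the second pass yields 2s.
print(paddle.allclose(gg1, paddle.full_like(gg1, 2.0)))  # expected: True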

paddle/phi/api/yaml/op_compat.yaml

Lines changed: 1 addition & 1 deletion
@@ -3225,7 +3225,7 @@
   outputs : [xshape]
 
 - op : stack
-  backward : stack_grad
+  backward : stack_grad, stack_double_grad
   inputs :
     x : X
   outputs :

test/legacy_test/test_nn_grad.py

Lines changed: 39 additions & 1 deletion
@@ -405,7 +405,6 @@ def concat_wrapper(self, x):
     @prog_scope()
     def func(self, place):
         x_shape = [2, 3, 4, 5]
-        pad = [1, 1, 1, 1]
         dtype = np.float64
 
         x1 = paddle.static.data('x', x_shape, dtype)
@@ -437,6 +436,45 @@ def test_grad(self):
             self.func(p)
 
 
+class TestStackDoubleGradCheck(unittest.TestCase):
+    def stack_wrapper(self, x):
+        return paddle.stack(x, axis=1)
+
+    @test_with_pir_api
+    @prog_scope()
+    def func(self, place):
+        x_shape = [2, 3, 4, 5]
+        dtype = np.float64
+
+        x1 = paddle.static.data('x', x_shape, dtype)
+        x2 = paddle.static.data('x', x_shape, dtype)
+        x1.persistable = True
+        x1.stop_gradient = False
+        x2.persistable = True
+        x2.stop_gradient = False
+        out = paddle.stack([x1, x2], axis=0)
+        x2_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
+        x1_arr = np.random.uniform(-1, 1, x_shape).astype(dtype)
+
+        gradient_checker.double_grad_check(
+            [x1, x2], out, x_init=[x1_arr, x2_arr], place=place
+        )
+        gradient_checker.double_grad_check_for_dygraph(
+            self.stack_wrapper,
+            [x1, x2],
+            out,
+            x_init=[x1_arr, x2_arr],
+            place=place,
+        )
+
+    def test_grad(self):
+        places = [base.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(base.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 class TestAvgPool2DDoubleGradCheckCase1(unittest.TestCase):
     @test_with_pir_api
     @prog_scope()
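The new test follows the file's existing pattern: gradient_checker.double_grad_check compares analytic second-order gradients against finite differences of the first-order gradients, and double_grad_check_for_dygraph repeats the comparison through the eager stack_wrapper path. A rough standalone sketch of that idea, an assumption about the checker's approach rather than its actual implementation:

import numpy as np
import paddle

def analytic_second_grad_entry(x1_arr, x2_arr):
    # d2(sum(stack([x1, x2])**2)) / d x1[0, 0]^2 via two backward passes.
    x1 = paddle.to_tensor(x1_arr)
    x1.stop_gradient = False
    x2 = paddle.to_tensor(x2_arr)
    x2.stop_gradient = False
    loss = (paddle.stack([x1, x2], axis=0) ** 2).sum()
    (g1,) = paddle.grad(loss, [x1], create_graph=True)
    (gg1,) = paddle.grad(g1[0, 0], [x1])
    return float(gg1[0, 0])

def numeric_second_grad_entry(x1_arr, x2_arr, eps=1e-4):
    # Central finite difference of the first-order gradient at entry (0, 0).
    def first_grad_entry(arr):
        x1 = paddle.to_tensor(arr)
        x1.stop_gradient = False
        x2 = paddle.to_tensor(x2_arr)
        loss = (paddle.stack([x1, x2], axis=0) ** 2).sum()
        return float(paddle.grad(loss, [x1])[0][0, 0])

    up, down = x1_arr.copy(), x1_arr.copy()
    up[0, 0] += eps
    down[0, 0] -= eps
    return (first_grad_entry(up) - first_grad_entry(down)) / (2 * eps)

x1_arr = np.random.uniform(-1, 1, [2, 3]).astype(np.float64)
x2_arr = np.random.uniform(-1, 1, [2, 3]).astype(np.float64)
# Both values should agree (about 2.0 for this quadratic example).
print(analytic_second_grad_entry(x1_arr, x2_arr), numeric_second_grad_entry(x1_arr, x2_arr))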
