
Commit 1ec577a

update stack op
1 parent 740fd7d commit 1ec577a

File tree

4 files changed: +95 -4 lines changed

  paddle/phi/kernels/cpu/stack_grad_kernel.cc
  paddle/phi/kernels/cpu/stack_kernel.cc
  paddle/phi/kernels/funcs/stack_and_unstack.h
  test/legacy_test/test_stack_op.py

paddle/phi/kernels/cpu/stack_grad_kernel.cc

Lines changed: 10 additions & 0 deletions
@@ -37,6 +37,16 @@ void StackGradKernel(const Context& dev_ctx,
     }
   }
   auto dy_data = out.data<T>();
+
+  // zero sized tensor case
+  if (out.numel() == 0) {
+    for (int i = 0; i < n; i++) {
+      auto x_grad_dim = x_grad[i]->dims();
+      x_grad[i]->Resize(x_grad_dim);
+    }
+    return;
+  }
+
   int pre = 1;
   for (int i = 0; i < axis; ++i) pre *= static_cast<int>(out.dims()[i]);
   int total_num = static_cast<int>(out.numel());
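
Note: a minimal dygraph sketch of the backward path this early return serves, mirroring the test added in test_stack_op.py below (assumes a Paddle build that includes this commit):

    import paddle

    x1 = paddle.ones([1, 0])
    x2 = paddle.ones([1, 0])
    x1.stop_gradient = False
    x2.stop_gradient = False
    out = paddle.stack([x1, x2])  # shape [2, 1, 0], so out.numel() == 0
    out.backward()                # StackGradKernel takes the zero-sized early return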

paddle/phi/kernels/cpu/stack_kernel.cc

Lines changed: 12 additions & 4 deletions
@@ -28,10 +28,18 @@ void StackKernel(const Context& dev_ctx,
 
   auto x_dims = x[0]->dims();
   for (int i = 0; i < x_dims.size(); i++) {
-    PADDLE_ENFORCE_GT(x_dims[i],
-                      0,
-                      phi::errors::InvalidArgument(
-                          "The dims of Input(X) should be greater than 0"));
+    PADDLE_ENFORCE_GE(
+        x_dims[i],
+        0,
+        phi::errors::InvalidArgument(
+            "The dims of Input(X) should be greater than or equal to 0"));
+  }
+  // zero sized tensor case
+  if (x[0]->numel() == 0) {
+    dev_ctx.template Alloc<T>(out);
+    auto out_dims = out->dims();
+    out->Resize(out_dims);
+    return;
   }
 
   int n = static_cast<int>(x.size());
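
Note: the user-visible effect of relaxing PADDLE_ENFORCE_GT to PADDLE_ENFORCE_GE plus the new early return, as a minimal sketch (assumes a Paddle build that includes this commit):

    import paddle

    x = paddle.ones([1, 0])     # a 0-sized dim previously failed the > 0 check
    out = paddle.stack([x, x])  # output is now allocated and returned early
    print(out.shape)            # [2, 1, 0]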

paddle/phi/kernels/funcs/stack_and_unstack.h

Lines changed: 16 additions & 0 deletions
@@ -77,6 +77,13 @@ void StackRawKernel(const Context& ctx,
   if (axis < 0) axis += (x[0]->dims().size() + 1);
   int num = static_cast<int>(x.size());
 
+  // zero sized tensor case
+  if (x[0]->numel() == 0) {
+    ctx.template Alloc<T>(out);
+    auto out_dims = out->dims();
+    out->Resize(out_dims);
+    return;
+  }
   // Split x dim from axis to matrix of shape [x_row, x_col], and the output
   // tensor's shape is [x_row, out_col].
   int64_t x_row = 1, x_row_bak = 1;

@@ -251,6 +258,15 @@ void UnStackRawKernel(const Context& ctx,
   // Input tensor is splited to split_dim tensors along split_dim dimension.
   int64_t split_dim = x_dims[axis];
 
+  // zero sized tensor case
+  if (x.numel() == 0) {
+    for (int i = 0; i < split_dim; i++) {
+      ctx.template Alloc<T>((*outs)[i]);
+      auto x_grad_dim = (*outs)[i]->dims();
+      (*outs)[i]->Resize(x_grad_dim);
+    }
+    return;
+  }
   // Treat outs[i] as [out_row, out_col], and x as [out_row, split_dim,
   // out_col].
   int64_t out_row = 1;
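
Note: these raw kernels sit behind the GPU stack/unstack (and stack_grad) kernels, so the UnStackRawKernel branch should also cover a call like the following rough sketch (shapes are illustrative; assumes a Paddle build that includes this commit):

    import paddle

    x = paddle.ones([2, 0, 3])        # x.numel() == 0
    a, b = paddle.unstack(x, axis=0)  # each output allocated with shape [0, 3]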

test/legacy_test/test_stack_op.py

Lines changed: 57 additions & 0 deletions
@@ -452,5 +452,62 @@ def test_stack_triple_grad(self):
         paddle.base.core.set_prim_eager_enabled(False)
 
 
+class TestStackAPI_ZeroSizedTensor(unittest.TestCase):
+    def test_dygraph(self):
+        places = [base.CPUPlace()]
+        if base.is_compiled_with_cuda():
+            places.append(base.CUDAPlace(0))
+
+        for place in places:
+            with base.dygraph.guard():
+                paddle.disable_static(place)
+
+                x1 = paddle.ones([1, 0])
+                x2 = paddle.ones([1, 0])
+                x1.stop_gradient = False
+                x2.stop_gradient = False
+                out = paddle.stack([x1, x2])
+                out.retain_grads()
+                out.backward()
+
+                np.testing.assert_equal(out.shape, [2, 1, 0])
+                np.testing.assert_equal(x1.grad, None)
+                np.testing.assert_equal(x2.grad, None)
+                np.testing.assert_equal(out, np.ones([2, 1, 0]))
+
+        paddle.enable_static()
+
+    @test_with_pir_api
+    def test_static(self):
+        places = [paddle.CPUPlace()]
+        if base.is_compiled_with_cuda():
+            places.append(paddle.CUDAPlace(0))
+        paddle.enable_static()
+        for place in places:
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
+                data1 = paddle.static.data(
+                    'data1', shape=[0, 2], dtype='float64'
+                )
+                data2 = paddle.static.data(
+                    'data2', shape=[0, 2], dtype='float64'
+                )
+                data3 = paddle.static.data(
+                    'data3', shape=[0, 2], dtype='float64'
+                )
+                result_stack = paddle.stack([data1, data2, data3], axis=0)
+                exe = base.Executor(place)
+                input1 = np.ones([0, 2]).astype('float64')
+                input2 = np.ones([0, 2]).astype('float64')
+                input3 = np.ones([0, 2]).astype('float64')
+                (result,) = exe.run(
+                    feed={"data1": input1, "data2": input2, "data3": input3},
+                    fetch_list=[result_stack],
+                )
+                expected_result = np.stack([input1, input2, input3], axis=0)
+                np.testing.assert_equal(expected_result, result)
+
+
 if __name__ == '__main__':
     unittest.main()
