Skip to content

Commit 9c1b5c4

Browse files
authored
[0-size Tensor Job2 No.97] Add 0-size Tensor support for paddle.Tensor.set_[fluid_ops] (#74200)
* Fix * Fix * Fix
1 parent 4da463e commit 9c1b5c4

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

paddle/phi/kernels/set_kernel.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// limitations under the License.
1414
#include "paddle/phi/kernels/set_kernel.h"
1515
#include "paddle/phi/core/kernel_registry.h"
16-
16+
#include "paddle/phi/kernels/full_kernel.h"
1717
namespace phi {
1818

1919
template <typename T, typename Context>
@@ -28,6 +28,17 @@ void SetKernel(const Context& dev_ctx,
2828
meta.dims = DDim(dims.data(), static_cast<int>(dims.size()));
2929
meta.strides = DDim(stride.data(), static_cast<int>(stride.size()));
3030
meta.offset = offset;
31+
if (x.numel() == 0 || source.numel() == 0) {
32+
if (source.numel() != 0) {
33+
out->clear();
34+
*out = DenseTensor{source.Holder(), meta};
35+
} else if (x.numel() == 0) {
36+
phi::Full<T, Context>(
37+
dev_ctx, phi::IntArray(common::vectorize(out->dims())), 0, out);
38+
}
39+
out->ShareInplaceVersionCounterWith(x);
40+
return;
41+
}
3142
if (x.IsSharedWith(source)) {
3243
out->set_meta(meta);
3344
} else {

test/legacy_test/test_inplace.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import unittest
1717

1818
import numpy as np
19+
from op_test import get_places
1920

2021
import paddle
2122

@@ -2491,5 +2492,16 @@ def test_inplace_api(self):
24912492
self.assertTrue(id(x) == id(inplace_x2))
24922493

24932494

2495+
class TestSet_API_ZeroSize(unittest.TestCase):
2496+
def setUp(self):
2497+
self.places = get_places()
2498+
2499+
def test_set_api(self):
2500+
for place in self.places:
2501+
with paddle.base.dygraph.guard(place):
2502+
out = paddle.randn([20]).set_(paddle.randn([0, 3]), [20], [2])
2503+
np.testing.assert_allclose(out.shape, [20])
2504+
2505+
24942506
if __name__ == '__main__':
24952507
unittest.main()

0 commit comments

Comments
 (0)