Skip to content

Commit 347d212

Browse files
authored
[Zero-Dim] reshape/reshape_/reverse 0D support (PaddlePaddle#49357)
* [Zero-Dim] reshape/reshape_/reverse 0D support * rm comment * change paddle.to_tensor to paddle.full * fix docs * update paddle.full
1 parent 021085e commit 347d212

File tree

5 files changed

+267
-15
lines changed

5 files changed

+267
-15
lines changed

paddle/fluid/operators/reshape_op.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
114114
return;
115115
}
116116

117-
PADDLE_ENFORCE_EQ(!shape.empty(),
118-
true,
119-
platform::errors::InvalidArgument(
120-
"The parameter 'shape' in ReshapeOp must be set. "
121-
"But received 'shape' is empty."));
122117
auto x_dims = ctx->GetInputDim("X");
123118
auto out_dims = ValidateShape(shape, x_dims);
124119
ctx->SetOutputDim("Out", out_dims);

python/paddle/fluid/tests/unittests/test_reshape_op.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,20 @@ def test_check_grad(self):
4949
class TestReshapeOp_ZeroDim1(OpTest):
5050
def init_data(self):
5151
self.ori_shape = ()
52-
self.new_shape = 1
53-
self.infered_shape = 1
52+
self.new_shape = (1,)
53+
self.infered_shape = (1,)
5454

5555

5656
class TestReshapeOp_ZeroDim2(OpTest):
5757
def init_data(self):
5858
self.ori_shape = ()
59-
self.new_shape = -1
60-
self.infered_shape = 1
59+
self.new_shape = (-1,)
60+
self.infered_shape = (1,)
6161

6262

6363
class TestReshapeOp_ZeroDim3(OpTest):
6464
def init_data(self):
65-
self.ori_shape = 1
65+
self.ori_shape = (1,)
6666
self.new_shape = ()
6767
self.infered_shape = ()
6868

python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,105 @@ def test_floor_divide(self):
756756
np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
757757
np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
758758

759+
def test_reshape_list(self):
760+
x = paddle.rand([])
761+
x.stop_gradient = False
762+
763+
out = paddle.reshape(x, [])
764+
out.backward()
765+
self.assertEqual(x.grad.shape, [])
766+
self.assertEqual(out.shape, [])
767+
self.assertEqual(out.grad.shape, [])
768+
769+
out = paddle.reshape(x, [1])
770+
out.backward()
771+
self.assertEqual(x.grad.shape, [])
772+
self.assertEqual(out.shape, [1])
773+
self.assertEqual(out.grad.shape, [1])
774+
775+
out = paddle.reshape(x, [-1])
776+
out.backward()
777+
self.assertEqual(x.grad.shape, [])
778+
self.assertEqual(out.shape, [1])
779+
self.assertEqual(out.grad.shape, [1])
780+
781+
out = paddle.reshape(x, [-1, 1])
782+
out.backward()
783+
self.assertEqual(x.grad.shape, [])
784+
self.assertEqual(out.shape, [1, 1])
785+
self.assertEqual(out.grad.shape, [1, 1])
786+
787+
def test_reshape_tensor(self):
788+
x = paddle.rand([1, 1])
789+
x.stop_gradient = False
790+
791+
out = paddle.reshape(x, [])
792+
out.backward()
793+
self.assertEqual(x.grad.shape, [1, 1])
794+
self.assertEqual(out.shape, [])
795+
self.assertEqual(out.grad.shape, [])
796+
797+
new_shape = paddle.full([1], 1, "int32")
798+
out = paddle.reshape(x, new_shape)
799+
out.backward()
800+
self.assertEqual(x.grad.shape, [1, 1])
801+
self.assertEqual(out.shape, [1])
802+
self.assertEqual(out.grad.shape, [1])
803+
804+
new_shape = paddle.full([1], -1, "int32")
805+
out = paddle.reshape(x, new_shape)
806+
out.backward()
807+
self.assertEqual(x.grad.shape, [1, 1])
808+
self.assertEqual(out.shape, [1])
809+
self.assertEqual(out.grad.shape, [1])
810+
811+
new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
812+
out = paddle.reshape(x, new_shape)
813+
out.backward()
814+
self.assertEqual(x.grad.shape, [1, 1])
815+
self.assertEqual(out.shape, [1, 1])
816+
self.assertEqual(out.grad.shape, [1, 1])
817+
818+
def test_reshape__list(self):
819+
x = paddle.rand([])
820+
out = paddle.reshape_(x, [])
821+
self.assertEqual(out.shape, [])
822+
823+
out = paddle.reshape_(x, [1])
824+
self.assertEqual(out.shape, [1])
825+
826+
out = paddle.reshape_(x, [-1])
827+
self.assertEqual(out.shape, [1])
828+
829+
out = paddle.reshape_(x, [-1, 1])
830+
self.assertEqual(out.shape, [1, 1])
831+
832+
def test_reshape__tensor(self):
833+
x = paddle.rand([1, 1])
834+
out = paddle.reshape_(x, [])
835+
self.assertEqual(out.shape, [])
836+
837+
new_shape = paddle.full([1], 1, "int32")
838+
out = paddle.reshape_(x, new_shape)
839+
self.assertEqual(out.shape, [1])
840+
841+
new_shape = paddle.full([1], -1, "int32")
842+
out = paddle.reshape_(x, new_shape)
843+
self.assertEqual(out.shape, [1])
844+
845+
new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
846+
out = paddle.reshape_(x, new_shape)
847+
self.assertEqual(out.shape, [1, 1])
848+
849+
def test_reverse(self):
850+
x = paddle.rand([])
851+
x.stop_gradient = False
852+
out = paddle.reverse(x, axis=[])
853+
out.backward()
854+
self.assertEqual(x.shape, [])
855+
self.assertEqual(out.shape, [])
856+
self.assertEqual(out.grad.shape, [])
857+
759858

760859
class TestSundryAPIStatic(unittest.TestCase):
761860
def setUp(self):
@@ -1011,6 +1110,78 @@ def test_floor_divide(self):
10111110
np.testing.assert_array_equal(out3_1, out3_2)
10121111
np.testing.assert_array_equal(out3_2, np.asarray(1))
10131112

1113+
@prog_scope()
1114+
def test_reshape_list(self):
1115+
x1 = paddle.rand([])
1116+
x2 = paddle.rand([])
1117+
x3 = paddle.rand([])
1118+
x4 = paddle.rand([])
1119+
x1.stop_gradient = False
1120+
x2.stop_gradient = False
1121+
x3.stop_gradient = False
1122+
x4.stop_gradient = False
1123+
1124+
out1 = paddle.reshape(x1, [])
1125+
paddle.static.append_backward(out1)
1126+
1127+
out2 = paddle.reshape(x2, [1])
1128+
paddle.static.append_backward(out2)
1129+
1130+
out3 = paddle.reshape(x3, [-1])
1131+
paddle.static.append_backward(out3)
1132+
1133+
out4 = paddle.reshape(x4, [-1, 1])
1134+
paddle.static.append_backward(out4)
1135+
1136+
program = paddle.static.default_main_program()
1137+
res1, res2, res3, res4 = self.exe.run(
1138+
program, fetch_list=[out1, out2, out3, out4]
1139+
)
1140+
self.assertEqual(res1.shape, ())
1141+
self.assertEqual(res2.shape, (1,))
1142+
self.assertEqual(res3.shape, (1,))
1143+
self.assertEqual(res4.shape, (1, 1))
1144+
1145+
@prog_scope()
1146+
def test_reshape_tensor(self):
1147+
x1 = paddle.rand([])
1148+
x2 = paddle.rand([])
1149+
x3 = paddle.rand([])
1150+
x1.stop_gradient = False
1151+
x2.stop_gradient = False
1152+
x3.stop_gradient = False
1153+
1154+
new_shape = paddle.full([1], 1, "int32")
1155+
out1 = paddle.reshape(x1, new_shape)
1156+
paddle.static.append_backward(out1)
1157+
1158+
new_shape = paddle.full([1], -1, "int32")
1159+
out2 = paddle.reshape(x2, new_shape)
1160+
paddle.static.append_backward(out2)
1161+
1162+
new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
1163+
out3 = paddle.reshape(x3, new_shape)
1164+
paddle.static.append_backward(out3)
1165+
1166+
program = paddle.static.default_main_program()
1167+
res1, res2, res3 = self.exe.run(program, fetch_list=[out1, out2, out3])
1168+
self.assertEqual(res1.shape, (1,))
1169+
self.assertEqual(res2.shape, (1,))
1170+
self.assertEqual(res3.shape, (1, 1))
1171+
1172+
@prog_scope()
1173+
def test_reverse(self):
1174+
x = paddle.rand([])
1175+
x.stop_gradient = False
1176+
1177+
out = paddle.reverse(x, axis=[])
1178+
paddle.static.append_backward(out)
1179+
1180+
program = paddle.static.default_main_program()
1181+
res1, res2 = self.exe.run(program, fetch_list=[x, out])
1182+
self.assertEqual(res1.shape, ())
1183+
self.assertEqual(res2.shape, ())
1184+
10141185

10151186
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
10161187
class TestNoBackwardAPI(unittest.TestCase):

python/paddle/fluid/tests/unittests/xpu/test_zero_dim_tensor_xpu.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,96 @@ def test_floor_divide(self):
556556
np.testing.assert_array_equal(out3_1.numpy(), out3_2.numpy())
557557
np.testing.assert_array_equal(out3_2.numpy(), np.asarray(1))
558558

559+
def test_reshape_list(self):
560+
x = paddle.rand([])
561+
x.stop_gradient = False
562+
563+
out = paddle.reshape(x, [])
564+
out.backward()
565+
self.assertEqual(x.grad.shape, [])
566+
self.assertEqual(out.shape, [])
567+
self.assertEqual(out.grad.shape, [])
568+
569+
out = paddle.reshape(x, [1])
570+
out.backward()
571+
self.assertEqual(x.grad.shape, [])
572+
self.assertEqual(out.shape, [1])
573+
self.assertEqual(out.grad.shape, [1])
574+
575+
out = paddle.reshape(x, [-1])
576+
out.backward()
577+
self.assertEqual(x.grad.shape, [])
578+
self.assertEqual(out.shape, [1])
579+
self.assertEqual(out.grad.shape, [1])
580+
581+
out = paddle.reshape(x, [-1, 1])
582+
out.backward()
583+
self.assertEqual(x.grad.shape, [])
584+
self.assertEqual(out.shape, [1, 1])
585+
self.assertEqual(out.grad.shape, [1, 1])
586+
587+
def test_reshape_tensor(self):
588+
x = paddle.rand([1, 1])
589+
x.stop_gradient = False
590+
591+
out = paddle.reshape(x, [])
592+
out.backward()
593+
self.assertEqual(x.grad.shape, [1, 1])
594+
self.assertEqual(out.shape, [])
595+
self.assertEqual(out.grad.shape, [])
596+
597+
new_shape = paddle.full([], 1, "int32")
598+
out = paddle.reshape(x, new_shape)
599+
out.backward()
600+
self.assertEqual(x.grad.shape, [1, 1])
601+
self.assertEqual(out.shape, [1])
602+
self.assertEqual(out.grad.shape, [1])
603+
604+
new_shape = paddle.full([], -1, "int32")
605+
out = paddle.reshape(x, new_shape)
606+
out.backward()
607+
self.assertEqual(x.grad.shape, [1, 1])
608+
self.assertEqual(out.shape, [1])
609+
self.assertEqual(out.grad.shape, [1])
610+
611+
new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
612+
out = paddle.reshape(x, new_shape)
613+
out.backward()
614+
self.assertEqual(x.grad.shape, [1, 1])
615+
self.assertEqual(out.shape, [1, 1])
616+
self.assertEqual(out.grad.shape, [1, 1])
617+
618+
def test_reshape__list(self):
619+
x = paddle.rand([])
620+
out = paddle.reshape_(x, [])
621+
self.assertEqual(out.shape, [])
622+
623+
out = paddle.reshape_(x, [1])
624+
self.assertEqual(out.shape, [1])
625+
626+
out = paddle.reshape_(x, [-1])
627+
self.assertEqual(out.shape, [1])
628+
629+
out = paddle.reshape_(x, [-1, 1])
630+
self.assertEqual(out.shape, [1, 1])
631+
632+
def test_reshape__tensor(self):
633+
x = paddle.rand([1, 1])
634+
out = paddle.reshape_(x, [])
635+
self.assertEqual(out.shape, [])
636+
637+
new_shape = paddle.full([1], 1, "int32")
638+
out = paddle.reshape_(x, new_shape)
639+
self.assertEqual(out.shape, [1])
640+
641+
new_shape = paddle.full([1], -1, "int32")
642+
out = paddle.reshape_(x, new_shape)
643+
self.assertEqual(out.shape, [1])
644+
645+
new_shape = [paddle.full([], -1, "int32"), paddle.full([], 1, "int32")]
646+
out = paddle.reshape_(x, new_shape)
647+
self.assertEqual(out.shape, [1, 1])
648+
559649

560650
# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
561651
class TestNoBackwardAPI(unittest.TestCase):

python/paddle/tensor/manipulation.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3450,7 +3450,7 @@ def reshape(x, shape, name=None):
34503450
Args:
34513451
x (Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``
34523452
shape (list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1.
3453-
The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
3453+
The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [].
34543454
If ``shape`` is an Tensor, it should be an 1-D Tensor .
34553455
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
34563456
@@ -3574,10 +3574,6 @@ def get_attr_shape(list_shape):
35743574
shape.stop_gradient = True
35753575
inputs["Shape"] = shape
35763576
elif isinstance(shape, (list, tuple)):
3577-
assert len(shape) > 0, (
3578-
"The size of 'shape' in reshape can't be zero, "
3579-
"but received %s." % len(shape)
3580-
)
35813577
attrs["shape"] = get_attr_shape(shape)
35823578
if utils._contain_var(shape):
35833579
inputs['ShapeTensor'] = utils._convert_to_tensor_list(shape)

0 commit comments

Comments
 (0)