Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/paddle/nn/utils/transform_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,34 @@ def vector_to_parameters(
Tensor(shape=[], dtype=bool, place=Place(cpu), stop_gradient=True,
True)
"""
assert len(vec.shape) == 1
origin_shapes = []
sections = []
total_elements = 0
for param in parameters:
shape = param.shape
origin_shapes.append(shape)
numel = reduce(lambda x, y: x * y, shape, 1)
total_elements += numel
sections.append(numel)

if len(sections) == 1:
sections.append(0)

if in_dygraph_mode():
with paddle.base.dygraph.no_grad():
res = _C_ops.split(vec, sections, 0)
res = []
if total_elements == vec.shape[0]:
res = _C_ops.split(vec, sections, 0)
elif total_elements < vec.shape[0]:
pointer = 0
for section in sections:
res.append(vec[pointer : pointer + section])
pointer += section
else:
raise ValueError(
"The total_elements of vec should be equal to or larger than the number of elements in parameters."
)
for i in range(0, len(parameters)):
res[i]._share_underline_tensor_to(parameters[i])
else:
Expand Down
42 changes: 42 additions & 0 deletions test/legacy_test/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,5 +150,47 @@ def test_parambase_to_vector_zero(self):
self.assertEqual(vec.shape, [15])


class TestVectorToParam(unittest.TestCase):
def test_vector_to_param_zerosize(self):
# test the case that the parameters contains zero size tensor
with guard():
vec = paddle.randn([18], dtype='float32')
param1 = paddle.empty([5], dtype='float32')
param2 = paddle.empty([5], dtype='float32')
param3 = paddle.empty([8], dtype='float32')
param4 = paddle.empty([0], dtype='float32')
params = [param1, param2, param3, param4]
paddle.nn.utils.vector_to_parameters(vec, params)
# concat the parameters and get the original vector
vec_ = paddle.concat(params, axis=0)
np.testing.assert_array_equal(vec_.numpy(), vec.numpy())

def test_vector_to_param1(self):
# test the case that the sum of parameter's elements less than vector elements
with guard():
vec = paddle.randn([18], dtype='float32')
param1 = paddle.empty([5], dtype='float32')
param2 = paddle.empty([5], dtype='float32')
param3 = paddle.empty([7], dtype='float32')
params = [param1, param2, param3]
paddle.nn.utils.vector_to_parameters(vec, params)
# concat the parameters and get the original vector
vec_ = paddle.concat(params, axis=0)
np.testing.assert_array_equal(vec_.numpy(), vec[:17].numpy())

def test_vector_to_param2(self):
# test the case that the sum of parameter's elements grater than vector elements
def _test_vector_to_param():
with guard():
vec = paddle.randn([18], dtype='float32')
param1 = paddle.empty([5], dtype='float32')
param2 = paddle.empty([5], dtype='float32')
param3 = paddle.empty([9], dtype='float32')
params = [param1, param2, param3]
paddle.nn.utils.vector_to_parameters(vec, params)

self.assertRaises(ValueError, _test_vector_to_param)


if __name__ == '__main__':
unittest.main()
Loading