
@zhwesky2010 (Contributor) commented Dec 9, 2021

PR types

New features

PR changes

APIs

Describe

1. New API: paddle.clone

Returns a copy of the input Tensor. It always performs a real Tensor copy.

In addition, this OP supports gradient propagation.

import paddle

x = paddle.ones([2])
x.stop_gradient = False
clone_x = paddle.clone(x)
y = clone_x**3
y.backward()
print(clone_x.grad)  # [3., 3.]
print(x.grad)        # [3., 3.]
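The gradient values in the example can be sanity-checked without Paddle: since y = clone_x ** 3, the analytic gradient is 3 * clone_x ** 2, which is 3 at clone_x = 1. A minimal pure-Python finite-difference sketch (the `numerical_grad` helper is illustrative, not part of Paddle):

```python
def numerical_grad(f, x, eps=1e-6):
    """Central finite-difference approximation of df/dx at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

f = lambda c: c ** 3          # same function as y = clone_x ** 3
analytic = 3 * 1.0 ** 2       # 3.0, the value reported for each grad entry
numeric = numerical_grad(f, 1.0)
assert abs(numeric - analytic) < 1e-4
```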


2. New API: paddle.Tensor.element_size

Returns the size in bytes of an element in the Tensor.

import paddle

x = paddle.static.data(name='x', shape=[3, 2], dtype='bool')
x.element_size()  # 1

x = paddle.static.data(name='x', shape=[3, 2], dtype='int16')
x.element_size()  # 2

x = paddle.static.data(name='x', shape=[3, 2], dtype='float16')
x.element_size()  # 2

x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
x.element_size()  # 4

x = paddle.static.data(name='x', shape=[3, 2], dtype='float64')
x.element_size()  # 8
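The byte widths in the example follow directly from the dtypes. A plain-Python sketch of the mapping (the `element_size` function and `DTYPE_SIZE` table below are illustrative, not Paddle's implementation):

```python
import struct

# Illustrative dtype -> bytes-per-element table, mirroring the example above.
# float16 has no struct format code, so its size is hard-coded.
DTYPE_SIZE = {
    'bool': 1,                        # stored as one byte per element
    'int16': struct.calcsize('h'),    # 2
    'float16': 2,                     # IEEE 754 half precision
    'float32': struct.calcsize('f'),  # 4
    'float64': struct.calcsize('d'),  # 8
}

def element_size(dtype):
    """Return the size in bytes of one element of the given dtype."""
    return DTYPE_SIZE[dtype]
```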

3. New API: paddle.nn.utils.parameters_to_vector

Flattens a list of parameters into a single 1-D Tensor.

import paddle

linear = paddle.nn.Linear(10, 15)
paddle.nn.utils.parameters_to_vector(linear.parameters())
# 1-D Tensor with shape [165]
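The 165 in the example is just the total element count of the layer's parameters: a [10, 15] weight (150 elements) plus a [15] bias. A plain-Python sketch of the flattening (the helpers below are illustrative, not Paddle's implementation):

```python
def flatten(param):
    """Recursively flatten a nested list into a flat list of scalars."""
    if not isinstance(param, list):
        return [param]
    out = []
    for item in param:
        out.extend(flatten(item))
    return out

def parameters_to_vector(params):
    """Concatenate the flattened parameters into one flat vector."""
    vec = []
    for p in params:
        vec.extend(flatten(p))
    return vec

weight = [[0.0] * 15 for _ in range(10)]  # shape [10, 15] -> 150 elements
bias = [0.0] * 15                         # shape [15]     -> 15 elements
vec = parameters_to_vector([weight, bias])
assert len(vec) == 165
```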


4. New API: paddle.nn.utils.vector_to_parameters

Transforms a 1-D Tensor back into the given parameters.

import paddle

linear1 = paddle.nn.Linear(
    10, 15, paddle.ParamAttr(paddle.nn.initializer.Constant(3.)))
vec = paddle.nn.utils.parameters_to_vector(linear1.parameters())

linear2 = paddle.nn.Linear(10, 15)
# copy weight of linear1 to linear2
paddle.nn.utils.vector_to_parameters(vec, linear2.parameters())
# weight: Tensor(shape=[10, 15], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
#     [[3. , ..., 3. ],
#      [..., ..., ...],
#      [3. , ..., 3. ]])
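vector_to_parameters is the inverse operation: slice the flat vector by each parameter's element count, then reshape each chunk to the parameter's shape. An illustrative plain-Python sketch (the function below is a stand-in, not Paddle's implementation):

```python
from functools import reduce
import operator

def vector_to_parameters(vec, shapes):
    """Split a flat list back into chunks shaped like each parameter."""
    params, start = [], 0
    for shape in shapes:
        numel = reduce(operator.mul, shape, 1)
        chunk = vec[start:start + numel]
        if len(shape) == 2:  # rebuild rows for a 2-D parameter
            rows, cols = shape
            chunk = [chunk[r * cols:(r + 1) * cols] for r in range(rows)]
        params.append(chunk)
        start += numel
    return params

vec = [3.0] * 165  # e.g. linear1's constant-initialized parameters, flattened
weight, bias = vector_to_parameters(vec, [[10, 15], [15]])
assert len(weight) == 10 and len(weight[0]) == 15
assert len(bias) == 15
```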


paddle-bot-old bot commented Dec 9, 2021

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

A reviewer (Contributor) commented:

These two descriptive sentences are not clear enough; please revise them.

zhwesky2010 (Contributor Author) replied Dec 21, 2021:

Done

@zhwesky2010 zhwesky2010 requested a review from Xreki December 15, 2021 03:00
@zhwesky2010 zhwesky2010 force-pushed the add_clone_elementsize branch from 572fcf6 to cd894ae Compare December 20, 2021 06:45
@zhwesky2010 zhwesky2010 force-pushed the add_clone_elementsize branch from cd894ae to 7582067 Compare December 20, 2021 15:23
@zhwesky2010 zhwesky2010 force-pushed the add_clone_elementsize branch 2 times, most recently from 0d28fde to 92be1b7 Compare December 21, 2021 03:38
A reviewer (Contributor) commented:

Tn addition -> In addition

zhwesky2010 (Contributor Author) replied:

Done, thx

@zhwesky2010 zhwesky2010 force-pushed the add_clone_elementsize branch from 92be1b7 to 0d56823 Compare December 22, 2021 02:48
TCChenlong previously approved these changes Dec 22, 2021

TCChenlong (Contributor) left a comment:

LGTM

shape = param.shape
numel = reduce(lambda x, y: x * y, shape)
end = start + numel
slice_data = _C_ops.slice(vec, None, None, 'axes', [0],
A reviewer (Contributor) commented:

Could a single split replace the multiple slice calls here?

zhwesky2010 (Contributor Author) replied Dec 22, 2021:

Done, the multiple slice calls have been removed.

slice_data = _C_ops.slice(vec, None, None, 'axes', [0],
'infer_flags', [1], 'starts',
[start], 'ends', [end])
_C_ops.reshape2_(slice_data, None, 'shape', shape)
A reviewer (Contributor) commented:

Why does the inplace interface need to be used explicitly here?

zhwesky2010 (Contributor Author) replied Dec 22, 2021:

core.ops.reshape2_ uses the inplace mechanism, while core.ops.reshape2 uses the view mechanism. The set of OPs that use the view mechanism is hard-coded; there are currently four of them.

Under the inplace mechanism, the input and the output are the same var.
Under the view mechanism, a separate output var must be created, but that output var shares its buffer with the input var, so neither mechanism incurs a copy cost.
Using the inplace version here keeps the code simpler, since no temporary output var needs to be created.
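The view-vs-copy distinction discussed above can be illustrated in plain Python, with a memoryview standing in for a Paddle var: a "view" is a new object that shares the underlying buffer, while a copy has separate storage (illustrative only; this is not Paddle code):

```python
import array

buf = array.array('f', [1.0, 2.0, 3.0, 4.0])

# "View": a new object over the same storage, reshaped to [2, 2].
# (memoryview.cast requires going through a byte view to change shape.)
view = memoryview(buf).cast('B').cast('f', (2, 2))
buf[0] = 9.0
assert view[0, 0] == 9.0   # shared buffer: the view sees the write

# "Copy": separate storage, unaffected by later writes to buf.
copy = array.array('f', buf)
buf[1] = 7.0
assert copy[1] == 2.0      # the copy does not see the write
```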

helper.append_op(
type='assign',
inputs={'X': slice_data},
outputs={'Out': param})
A reviewer (Contributor) commented:

vector_to_parameters may be relatively slow.

zhwesky2010 (Contributor Author) replied Dec 22, 2021:

The assign has been removed in favor of the most simplified implementation:

parameters_to_vector: one concat OP plus a fully inplace reshape_ OP
vector_to_parameters: one split OP plus a fully inplace reshape_ OP

Both APIs now use the inplace reshape_ OP throughout, so each one effectively costs only a single concat or split OP, which guarantees the performance.
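The simplified design above (one concat to pack, one split to unpack) can be sketched in plain Python; `concat` and `split` below are list-based stand-ins for the Paddle OPs, not the real kernels:

```python
def concat(chunks):
    """Pack all chunks into one flat vector (stand-in for the concat OP)."""
    return [x for chunk in chunks for x in chunk]

def split(vec, sections):
    """Unpack the vector by section lengths (stand-in for the split OP)."""
    out, start = [], 0
    for n in sections:
        out.append(vec[start:start + n])
        start += n
    return out

sections = [150, 15]                     # numel of the weight and the bias
vec = concat([[1.0] * 150, [2.0] * 15])  # parameters_to_vector: one concat
w, b = split(vec, sections)              # vector_to_parameters: one split
assert len(w) == 150 and len(b) == 15
assert concat([w, b]) == vec             # the round trip preserves the data
```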

TCChenlong (Contributor) left a comment:

LGTM

@zhwesky2010 zhwesky2010 reopened this Dec 23, 2021
luotao1 (Contributor) left a comment:

LGTM for CMakeLists.txt

@zhwesky2010 zhwesky2010 merged commit 0eb03ed into PaddlePaddle:develop Dec 23, 2021