1515from functools import reduce
1616
1717import paddle
18- from paddle .fluid .framework import Variable , in_dygraph_mode
19- from paddle .fluid .layer_helper import LayerHelper
20- from paddle .fluid .data_feeder import check_variable_and_dtype
18+ from paddle .fluid .framework import dygraph_only , _dygraph_tracer , _varbase_creator
2119from paddle import _C_ops
2220
2321
22+ #input==output, inplace strategy of reshape has no cost almostly
23+ def _inplace_reshape_dygraph (x , shape ):
24+ x_shape = _varbase_creator (dtype = x .dtype )
25+ _dygraph_tracer ().trace_op (
26+ type = "reshape2" ,
27+ inputs = {'X' : x },
28+ outputs = {'Out' : x ,
29+ 'XShape' : x_shape },
30+ attrs = {'shape' : shape },
31+ stop_gradient = True )
32+
33+
34+ @dygraph_only
2435def parameters_to_vector (parameters , name = None ):
2536 """
2637 Flatten parameters to a 1-D Tensor.
@@ -44,44 +55,25 @@ def parameters_to_vector(parameters, name=None):
4455 # 1-D Tensor: [165]
4556
4657 """
47- vec_list = []
48- if in_dygraph_mode ():
49- for param in parameters :
50- vec , _ = _C_ops .reshape2 (param , None , 'shape' , [- 1 ])
51- vec_list .append (vec )
52- return _C_ops .concat (vec_list , 'axis' , 0 )
53-
54- helper = LayerHelper ("parameters_to_vector" , ** locals ())
55- param_dtype = parameters [0 ].dtype
56- for id , param in enumerate (parameters ):
57- check_variable_and_dtype (
58- param , 'parameters[{}]' .format (id ),
59- ['bool' , 'float16' , 'float32' , 'float64' , 'int32' , 'int64' ],
60- "parameters_to_vector" )
61- if param .dtype != param_dtype :
62- raise TypeError (
63- "All the Tensors in the parameters must have the same data type."
64- )
65- vec = helper .create_variable_for_type_inference (dtype = param_dtype )
66- x_shape = helper .create_variable_for_type_inference (dtype = param_dtype )
67- # use View strategy that don't have Tensor Copy
68- helper .append_op (
69- type = 'reshape2' ,
70- inputs = {'X' : param },
71- outputs = {'Out' : vec ,
72- 'XShape' : x_shape },
73- attrs = {'shape' : [- 1 ]})
74- vec_list .append (vec )
75-
76- param_vec = helper .create_variable_for_type_inference (dtype = param_dtype )
77- helper .append_op (
58+ dtype = parameters [0 ].dtype
59+ origin_shapes = []
60+ for param in parameters :
61+ origin_shapes .append (param .shape )
62+ _inplace_reshape_dygraph (param , [- 1 ])
63+
64+ out = _varbase_creator (dtype = dtype )
65+ _dygraph_tracer ().trace_op (
7866 type = 'concat' ,
79- inputs = {'X' : vec_list },
80- outputs = {'Out' : param_vec },
81- attrs = {'axis' : 0 })
82- return param_vec
67+ inputs = {'X' : parameters },
68+ outputs = {'Out' : [out ]},
69+ attrs = {'axis' : 0 },
70+ stop_gradient = True )
71+ for i , param in enumerate (parameters ):
72+ _inplace_reshape_dygraph (param , origin_shapes [i ])
73+ return out
8374
8475
76+ @dygraph_only
8577def vector_to_parameters (vec , parameters , name = None ):
8678 """
8779 Transform a Tensor with 1-D shape to the parameters.
@@ -109,61 +101,22 @@ def vector_to_parameters(vec, parameters, name=None):
109101 # [..., ..., ...],
110102 # [3. , ..., 3. ]])
111103 """
112- start = 0
113- helper = LayerHelper ("vector_to_parameters" , ** locals ())
114- if in_dygraph_mode ():
115- with paddle .no_grad ():
116- for param in parameters :
117- shape = param .shape
118- numel = reduce (lambda x , y : x * y , shape )
119- end = start + numel
120- slice_data = _C_ops .slice (vec , None , None , 'axes' , [0 ],
121- 'infer_flags' , [1 ], 'starts' ,
122- [start ], 'ends' , [end ])
123- _C_ops .reshape2_ (slice_data , None , 'shape' , shape )
124- helper .append_op (
125- type = 'assign' ,
126- inputs = {'X' : slice_data },
127- outputs = {'Out' : param })
128- start += numel
129- return
130-
131- check_variable_and_dtype (
132- vec , 'x' , ['bool' , 'float16' , 'float32' , 'float64' , 'int32' , 'int64' ],
133- "vector_to_parameters" )
134- assert len (vec .shape ) == 1 , "'vec' must be a Tensor with 1-D shape."
135-
104+ origin_shapes = []
105+ sections = []
136106 for param in parameters :
137107 shape = param .shape
108+ origin_shapes .append (shape )
138109 numel = reduce (lambda x , y : x * y , shape )
139- end = start + numel
140-
141- slice_data = helper .create_variable_for_type_inference (
142- dtype = param .dtype )
143- helper .append_op (
144- type = 'slice' ,
145- inputs = {'Input' : vec },
146- outputs = {'Out' : slice_data },
147- attrs = {
148- 'axes' : [0 ],
149- 'infer_flags' : [1 ],
150- 'starts' : [start ],
151- 'ends' : [end ]
152- })
153-
154- # avoid backward for parameters
155- slice_data .stop_gradient = True
156- x_shape = helper .create_variable_for_type_inference (dtype = param .dtype )
157- out = helper .create_variable_for_type_inference (dtype = param .dtype )
158-
159- # use Inplace strategy that don't have Tensor Copy
160- helper .append_op (
161- type = 'reshape2' ,
162- inputs = {'X' : slice_data },
163- outputs = {'Out' : slice_data ,
164- 'XShape' : x_shape },
165- attrs = {'shape' : shape })
166-
167- helper .append_op (
168- type = 'assign' , inputs = {'X' : slice_data }, outputs = {'Out' : param })
169- start += numel
110+ sections .append (numel )
111+
112+ _dygraph_tracer ().trace_op (
113+ type = 'split' ,
114+ inputs = {'X' : [vec ]},
115+ outputs = {'Out' : parameters },
116+ attrs = {'axis' : 0 ,
117+ 'sections' : sections },
118+ stop_gradient = True )
119+
120+ for i , param in enumerate (parameters ):
121+ _inplace_reshape_dygraph (param , origin_shapes [i ])
122+ return
0 commit comments