@@ -114,5 +114,164 @@ def view_api_processing(self, var):
114114 return paddle .flatten (var )
115115
116116
117+ # NOTE(chenfeiyu): TestCases for operators with complex view mechanism.
118+ # a kind of view that shares storage between complex and real tensors
119+ # though the storage retains, the shape and data type of the two tensors differ
120+ class TestDygraphViewAsComplexReuseAllocation (unittest .TestCase ):
121+ def setUp (self ):
122+ self .init_shape ()
123+ self .init_dtype ()
124+
125+ def init_shape (self ):
126+ self .input_shape = [2 , 3 , 2 ]
127+ self .output_shape = [2 , 3 ]
128+
129+ def init_dtype (self ):
130+ self .input_dtype = paddle .float32
131+ self .output_dtype = paddle .complex64
132+
133+ def view_api_processing (self , var ):
134+ return paddle .view_as_complex (var )
135+
136+ def test_view_api (self ):
137+ var = paddle .rand (self .input_shape , dtype = self .input_dtype )
138+ view_var = self .view_api_processing (var )
139+
140+ self .assertEqual (var .shape , self .input_shape )
141+ self .assertEqual (view_var .shape , self .output_shape )
142+
143+ view_var .add_ (paddle .to_tensor ([1.0 ]))
144+ var_numpy = var .numpy ()
145+ view_var_numpy = view_var .numpy ()
146+ self .assertTrue (
147+ np .array_equal (
148+ np .take (
149+ var_numpy , 0 , axis = - 1 ), view_var_numpy .real ))
150+ self .assertTrue (
151+ np .array_equal (
152+ np .take (
153+ var_numpy , 1 , axis = - 1 ), view_var_numpy .imag ))
154+
155+ def test_forward_version (self ):
156+ var = paddle .rand (self .input_shape , dtype = self .input_dtype )
157+ self .assertEqual (var .inplace_version , 0 )
158+ view_var = self .view_api_processing (var )
159+ self .assertEqual (view_var .inplace_version , 0 )
160+
161+ view_var .add_ (paddle .to_tensor ([1.0 ]))
162+ self .assertEqual (var .inplace_version , 1 )
163+ self .assertEqual (view_var .inplace_version , 1 )
164+
165+ view_var_2 = self .view_api_processing (var )
166+ self .assertEqual (view_var_2 .inplace_version , 1 )
167+
168+ var .add_ (paddle .to_tensor ([1.0 ]))
169+ self .assertEqual (view_var .inplace_version , 2 )
170+ self .assertEqual (view_var_2 .inplace_version , 2 )
171+
172+ def test_backward_error (self ):
173+ # It raises an error because the inplace operator will result
174+ # in incorrect gradient computation.
175+ with paddle .fluid .dygraph .guard ():
176+ var_a = paddle .ones (shape = self .input_shape , dtype = self .input_dtype )
177+ var_a .stop_gradient = False
178+
179+ var_b = var_a ** 2
180+
181+ # Here, the gradient computation will use the value of var_b
182+ var_c = var_b ** 2
183+ view_var_b = self .view_api_processing (var_b )
184+ view_var_b .add_ (paddle .to_tensor (
185+ [1.0 ])) # var_b is modified inplace
186+
187+ loss = paddle .nn .functional .relu (var_c )
188+ with self .assertRaisesRegexp (
189+ RuntimeError ,
190+ "received tensor_version:{} != wrapper_version_snapshot:{}" .
191+ format (1 , 0 )):
192+ loss .backward ()
193+
194+
195+ class TestDygraphViewAsRealReuseAllocation (unittest .TestCase ):
196+ def setUp (self ):
197+ self .init_shape ()
198+ self .init_dtype ()
199+
200+ def init_shape (self ):
201+ self .input_shape = [2 , 3 ]
202+ self .output_shape = [2 , 3 , 2 ]
203+
204+ def init_dtype (self ):
205+ self .input_dtype = paddle .complex64
206+ self .output_dtype = paddle .float32
207+
208+ def view_api_processing (self , var ):
209+ return paddle .view_as_real (var )
210+
211+ def test_view_api (self ):
212+ var = (
213+ paddle .rand (self .input_shape ) + 1j * paddle .rand (self .input_shape )
214+ ).astype (self .input_dtype )
215+ view_var = self .view_api_processing (var )
216+
217+ self .assertEqual (var .shape , self .input_shape )
218+ self .assertEqual (view_var .shape , self .output_shape )
219+
220+ view_var [0 , 0 , 0 ] = 2.0
221+ var_numpy = var .numpy ()
222+ view_var_numpy = view_var .numpy ()
223+ self .assertTrue (
224+ np .array_equal (
225+ np .take (
226+ view_var_numpy , 0 , axis = - 1 ), var_numpy .real ))
227+ self .assertTrue (
228+ np .array_equal (
229+ np .take (
230+ view_var_numpy , 1 , axis = - 1 ), var_numpy .imag ))
231+
232+ def test_forward_version (self ):
233+ var = (
234+ paddle .rand (self .input_shape ) + 1j * paddle .rand (self .input_shape )
235+ ).astype (self .input_dtype )
236+ self .assertEqual (var .inplace_version , 0 )
237+
238+ view_var = self .view_api_processing (var )
239+ self .assertEqual (view_var .inplace_version , 0 )
240+
241+ view_var [0 , 0 , 0 ] = 2.0
242+ self .assertEqual (var .inplace_version , 1 )
243+ self .assertEqual (view_var .inplace_version , 1 )
244+
245+ view_var_2 = self .view_api_processing (var )
246+ self .assertEqual (view_var_2 .inplace_version , 1 )
247+
248+ var .add_ (paddle .to_tensor ([1.0 ]))
249+ self .assertEqual (view_var .inplace_version , 2 )
250+ self .assertEqual (view_var_2 .inplace_version , 2 )
251+
252+ def test_backward_error (self ):
253+ # It raises an error because the inplace operator will result
254+ # in incorrect gradient computation.
255+ with paddle .fluid .dygraph .guard ():
256+ var_a = (paddle .ones (shape = self .input_shape ) + 1j * paddle .ones (
257+ shape = self .input_shape )).astype (self .input_dtype )
258+ var_a .stop_gradient = False
259+
260+ var_b = paddle .conj (var_a )
261+
262+ # Here, the gradient computation will use the value of var_b
263+ var_c = var_b * var_b
264+ view_var_b = self .view_api_processing (var_b )
265+ view_var_b .add_ (paddle .to_tensor (
266+ [1.0 ])) # var_b is modified inplace
267+
268+ loss = paddle .abs (var_c ).sum ()
269+ with self .assertRaisesRegexp (
270+ RuntimeError ,
271+ "received tensor_version:{} != wrapper_version_snapshot:{}" .
272+ format (1 , 0 )):
273+ loss .backward ()
274+
275+
117276if __name__ == "__main__" :
118277 unittest .main ()
0 commit comments