1717import numpy as np
1818import paddle
1919import paddle .fluid as fluid
20+ from paddle .fluid .framework import _test_eager_guard
2021
2122from op_test import OpTest
2223
2324
25+ def graph_send_recv_wrapper (x ,
26+ src_index ,
27+ dst_index ,
28+ pool_type = "sum" ,
29+ out_size = None ,
30+ name = None ):
31+ return paddle .incubate .graph_send_recv (x , src_index , dst_index ,
32+ pool_type .lower (), out_size , name )
33+
34+
2435class TestGraphSendRecvMaxOp (OpTest ):
2536 def setUp (self ):
2637 paddle .enable_static ()
38+ self .python_api = graph_send_recv_wrapper
39+ self .python_out_sig = ["Out" ]
2740 self .op_type = "graph_send_recv"
2841 x = np .random .random ((10 , 20 )).astype ("float64" )
2942 index = np .random .randint (0 , 10 , (15 , 2 )).astype (np .int64 )
@@ -39,15 +52,18 @@ def setUp(self):
3952 self .outputs = {'Out' : out }
4053
4154 def test_check_output (self ):
42- self .check_output ()
55+ self .check_output (check_eager = True )
4356
4457 def test_check_grad (self ):
45- self .check_grad (['X' ], 'Out' , user_defined_grads = [self .gradient ])
58+ self .check_grad (
59+ ['X' ], 'Out' , user_defined_grads = [self .gradient ], check_eager = True )
4660
4761
4862class TestGraphSendRecvMinOp (OpTest ):
4963 def setUp (self ):
5064 paddle .enable_static ()
65+ self .python_api = graph_send_recv_wrapper
66+ self .python_out_sig = ["Out" ]
5167 self .op_type = "graph_send_recv"
5268 x = np .random .random ((10 , 20 )).astype ("float64" )
5369 index = np .random .randint (0 , 10 , (15 , 2 )).astype (np .int64 )
@@ -64,15 +80,18 @@ def setUp(self):
6480 self .outputs = {'Out' : out }
6581
6682 def test_check_output (self ):
67- self .check_output ()
83+ self .check_output (check_eager = True )
6884
6985 def test_check_grad (self ):
70- self .check_grad (['X' ], 'Out' , user_defined_grads = [self .gradient ])
86+ self .check_grad (
87+ ['X' ], 'Out' , user_defined_grads = [self .gradient ], check_eager = True )
7188
7289
7390class TestGraphSendRecvSumOp (OpTest ):
7491 def setUp (self ):
7592 paddle .enable_static ()
93+ self .python_api = graph_send_recv_wrapper
94+ self .python_out_sig = ["Out" ]
7695 self .op_type = "graph_send_recv"
7796 x = np .random .random ((10 , 20 )).astype ("float64" )
7897 index = np .random .randint (0 , 10 , (15 , 2 )).astype (np .int64 )
@@ -88,15 +107,17 @@ def setUp(self):
88107 self .outputs = {'Out' : out }
89108
90109 def test_check_output (self ):
91- self .check_output ()
110+ self .check_output (check_eager = True )
92111
93112 def test_check_grad (self ):
94- self .check_grad (['X' ], 'Out' )
113+ self .check_grad (['X' ], 'Out' , check_eager = True )
95114
96115
97116class TestGraphSendRecvMeanOp (OpTest ):
98117 def setUp (self ):
99118 paddle .enable_static ()
119+ self .python_api = graph_send_recv_wrapper
120+ self .python_out_sig = ["Out" ]
100121 self .op_type = "graph_send_recv"
101122 x = np .random .random ((10 , 20 )).astype ("float64" )
102123 index = np .random .randint (0 , 10 , (15 , 2 )).astype (np .int64 )
@@ -113,10 +134,10 @@ def setUp(self):
113134 self .outputs = {'Out' : out , 'Dst_count' : dst_count }
114135
115136 def test_check_output (self ):
116- self .check_output ()
137+ self .check_output (check_eager = True )
117138
118139 def test_check_grad (self ):
119- self .check_grad (['X' ], 'Out' )
140+ self .check_grad (['X' ], 'Out' , check_eager = True )
120141
121142
122143def compute_graph_send_recv_for_sum_mean (inputs , attributes ):
@@ -333,6 +354,12 @@ def test_set_outsize_gpu(self):
333354 {}\n {}, check diff!"
334355 .format (np_res_set_outsize , res_set_outsize ))
335356
357+ def test_api_eager_dygraph (self ):
358+ with _test_eager_guard ():
359+ self .test_dygraph ()
360+ self .test_int32_input ()
361+ self .test_set_outsize_gpu ()
362+
336363
337364if __name__ == '__main__' :
338365 unittest .main ()
0 commit comments