@@ -26,13 +26,17 @@ def __init__(self,
2626 theta_shape = (20 , 2 , 3 ),
2727 output_shape = [20 , 2 , 5 , 7 ],
2828 align_corners = True ,
29-  dtype = "float32" ):
29+  dtype = "float32" ,
30+  invalid_theta = False ,
31+  variable_output_shape = False ):
3032 super (AffineGridTestCase , self ).__init__ (methodName )
3133
3234 self .theta_shape  =  theta_shape 
3335 self .output_shape  =  output_shape 
3436 self .align_corners  =  align_corners 
3537 self .dtype  =  dtype 
38+  self .invalid_theta  =  invalid_theta 
39+  self .variable_output_shape  =  variable_output_shape 
3640
3741 def  setUp (self ):
3842 self .theta  =  np .random .randn (* (self .theta_shape )).astype (self .dtype )
@@ -70,9 +74,12 @@ def functional(self, place):
7074 return  y_np 
7175
7276 def  paddle_dygraph_layer (self ):
73-  theta_var  =  dg .to_variable (self .theta )
77+  theta_var  =  dg .to_variable (
78+  self .theta ) if  not  self .invalid_theta  else  "invalid" 
79+  output_shape  =  dg .to_variable (
80+  self .output_shape ) if  variable_output_shape  else  self .output_shape 
7481 y_var  =  F .affine_grid (
75-  theta_var , self . output_shape , align_corners = self .align_corners )
82+  theta_var , output_shape , align_corners = self .align_corners )
7683 y_np  =  y_var .numpy ()
7784 return  y_np 
7885
@@ -108,6 +115,9 @@ def add_cases(suite):
108115 suite .addTest (AffineGridTestCase (methodName = 'runTest' , align_corners = True ))
109116
110117 suite .addTest (AffineGridTestCase (methodName = 'runTest' , align_corners = False ))
118+  suite .addTest (
119+  AffineGridTestCase (
120+  methodName = 'runTest' , variable_output_shape = True ))
111121
112122 suite .addTest (
113123 AffineGridTestCase (
@@ -121,6 +131,10 @@ def add_error_cases(suite):
121131 suite .addTest (
122132 AffineGridErrorTestCase (
123133 methodName = 'runTest' , output_shape = "not_valid" ))
134+  suite .addTest (
135+  AffineGridErrorTestCase (
136+  methodName = 'runTest' ,
137+  invalid_theta = True )) # to test theta not variable error checking 
124138
125139
126140def  load_tests (loader , standard_tests , pattern ):
0 commit comments