Skip to content

Commit dd46209

Browse files
committed
Add more unitest for grid sample API
test=develop
1 parent de31410 commit dd46209

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

python/paddle/fluid/tests/unittests/test_affine_grid_function.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@ def __init__(self,
2626
theta_shape=(20, 2, 3),
2727
output_shape=[20, 2, 5, 7],
2828
align_corners=True,
29-
dtype="float32"):
29+
dtype="float32",
30+
invalid_theta=False,
31+
variable_output_shape=False):
3032
super(AffineGridTestCase, self).__init__(methodName)
3133

3234
self.theta_shape = theta_shape
3335
self.output_shape = output_shape
3436
self.align_corners = align_corners
3537
self.dtype = dtype
38+
self.invalid_theta = invalid_theta
39+
self.variable_output_shape = variable_output_shape
3640

3741
def setUp(self):
3842
self.theta = np.random.randn(*(self.theta_shape)).astype(self.dtype)
@@ -70,9 +74,12 @@ def functional(self, place):
7074
return y_np
7175

7276
def paddle_dygraph_layer(self):
73-
theta_var = dg.to_variable(self.theta)
77+
theta_var = dg.to_variable(
78+
self.theta) if not self.invalid_theta else "invalid"
79+
output_shape = dg.to_variable(
80+
self.output_shape) if variable_output_shape else self.output_shape
7481
y_var = F.affine_grid(
75-
theta_var, self.output_shape, align_corners=self.align_corners)
82+
theta_var, output_shape, align_corners=self.align_corners)
7683
y_np = y_var.numpy()
7784
return y_np
7885

@@ -108,6 +115,9 @@ def add_cases(suite):
108115
suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=True))
109116

110117
suite.addTest(AffineGridTestCase(methodName='runTest', align_corners=False))
118+
suite.addTest(
119+
AffineGridTestCase(
120+
methodName='runTest', variable_output_shape=True))
111121

112122
suite.addTest(
113123
AffineGridTestCase(
@@ -121,6 +131,10 @@ def add_error_cases(suite):
121131
suite.addTest(
122132
AffineGridErrorTestCase(
123133
methodName='runTest', output_shape="not_valid"))
134+
suite.addTest(
135+
AffineGridErrorTestCase(
136+
methodName='runTest',
137+
invalid_theta=True)) # to test theta not variable error checking
124138

125139

126140
def load_tests(loader, standard_tests, pattern):

0 commit comments

Comments
 (0)