Skip to content

Commit b3580ec

Browse files
authored
Merge pull request #8000 from reyoung/feature/make_recognize_digits_normal_unittest
Make recognize digits as a normal python unittest
2 parents c1ac5b6 + 8fb53b7 commit b3580ec

File tree

4 files changed

+55
-41
lines changed

4 files changed

+55
-41
lines changed

paddle/inference/tests/book/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ cc_test(test_inference_recognize_digits_mlp
44
DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
55
ARGS --dirname=${PYTHON_TESTS_DIR}/book/recognize_digits_mlp.inference.model)
66
set_tests_properties(test_inference_recognize_digits_mlp
7-
PROPERTIES DEPENDS test_recognize_digits_mlp_cpu)
7+
PROPERTIES DEPENDS test_recognize_digits)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
recognize_digits_*.inference.model

python/paddle/v2/fluid/tests/book/CMakeLists.txt

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,6 @@
11
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
22
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
33

4-
list(REMOVE_ITEM TEST_OPS test_recognize_digits)
5-
py_test(test_recognize_digits_mlp_cpu
6-
SRCS test_recognize_digits.py
7-
ARGS mlp)
8-
py_test(test_recognize_digits_mlp_cuda
9-
SRCS test_recognize_digits.py
10-
ARGS mlp --use_cuda)
11-
py_test(test_recognize_digits_conv_cpu
12-
SRCS test_recognize_digits.py
13-
ARGS conv)
14-
py_test(test_recognize_digits_conv_cuda
15-
SRCS test_recognize_digits.py
16-
ARGS conv --use_cuda)
17-
py_test(test_recognize_digits_mlp_cpu_parallel
18-
SRCS test_recognize_digits.py
19-
ARGS mlp --parallel)
20-
py_test(test_recognize_digits_mlp_cuda_parallel
21-
SRCS test_recognize_digits.py
22-
ARGS mlp --use_cuda --parallel)
23-
py_test(test_recognize_digits_conv_cpu_parallel
24-
SRCS test_recognize_digits.py
25-
ARGS conv --parallel)
26-
py_test(test_recognize_digits_conv_cuda_parallel
27-
SRCS test_recognize_digits.py
28-
ARGS conv --use_cuda --parallel)
29-
304
# default test
315
foreach(src ${TEST_OPS})
326
py_test(${src} SRCS ${src}.py)

python/paddle/v2/fluid/tests/book/test_recognize_digits.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import paddle.v2 as paddle
1818
import sys
1919
import numpy
20+
import unittest
2021

2122

2223
def parse_arg():
@@ -74,18 +75,18 @@ def conv_net(img, label):
7475
return loss_net(conv_pool_2, label)
7576

7677

77-
def train(args, save_dirname=None):
78-
print("recognize digits with args: {0}".format(" ".join(sys.argv[1:])))
79-
78+
def train(nn_type, use_cuda, parallel, save_dirname):
79+
if use_cuda and not fluid.core.is_compiled_with_cuda():
80+
return
8081
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
8182
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
8283

83-
if args.nn_type == 'mlp':
84+
if nn_type == 'mlp':
8485
net_conf = mlp
8586
else:
8687
net_conf = conv_net
8788

88-
if args.parallel:
89+
if parallel:
8990
places = fluid.layers.get_places()
9091
pd = fluid.layers.ParallelDo(places)
9192
with pd.do():
@@ -107,7 +108,7 @@ def train(args, save_dirname=None):
107108
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
108109
optimizer.minimize(avg_loss)
109110

110-
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
111+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
111112

112113
exe = fluid.Executor(place)
113114
exe.run(fluid.default_startup_program())
@@ -147,13 +148,14 @@ def train(args, save_dirname=None):
147148
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
148149
format(pass_id, batch_id + 1,
149150
float(avg_loss_val), float(acc_val)))
151+
raise AssertionError("Loss of recognize digits is too large")
150152

151153

152-
def infer(args, save_dirname=None):
154+
def infer(use_cuda, save_dirname=None):
153155
if save_dirname is None:
154156
return
155157

156-
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
158+
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
157159
exe = fluid.Executor(place)
158160

159161
# Use fluid.io.load_inference_model to obtain the inference program desc,
@@ -174,11 +176,48 @@ def infer(args, save_dirname=None):
174176
print("infer results: ", results[0])
175177

176178

177-
if __name__ == '__main__':
178-
args = parse_arg()
179-
if not args.use_cuda and not args.parallel:
180-
save_dirname = "recognize_digits_" + args.nn_type + ".inference.model"
179+
def main(use_cuda, parallel, nn_type):
180+
if not use_cuda and not parallel:
181+
save_dirname = "recognize_digits_" + nn_type + ".inference.model"
181182
else:
182183
save_dirname = None
183-
train(args, save_dirname)
184-
infer(args, save_dirname)
184+
185+
train(
186+
nn_type=nn_type,
187+
use_cuda=use_cuda,
188+
parallel=parallel,
189+
save_dirname=save_dirname)
190+
infer(use_cuda=use_cuda, save_dirname=save_dirname)
191+
192+
193+
class TestRecognizeDigits(unittest.TestCase):
194+
pass
195+
196+
197+
def inject_test_method(use_cuda, parallel, nn_type):
198+
def __impl__(self):
199+
prog = fluid.Program()
200+
startup_prog = fluid.Program()
201+
scope = fluid.core.Scope()
202+
with fluid.scope_guard(scope):
203+
with fluid.program_guard(prog, startup_prog):
204+
main(use_cuda, parallel, nn_type)
205+
206+
fn = 'test_{0}_{1}_{2}'.format(nn_type, 'cuda'
207+
if use_cuda else 'cpu', 'parallel'
208+
if parallel else 'normal')
209+
210+
setattr(TestRecognizeDigits, fn, __impl__)
211+
212+
213+
def inject_all_tests():
214+
for use_cuda in (False, True):
215+
for parallel in (False, True):
216+
for nn_type in ('mlp', 'conv'):
217+
inject_test_method(use_cuda, parallel, nn_type)
218+
219+
220+
inject_all_tests()
221+
222+
if __name__ == '__main__':
223+
unittest.main()

0 commit comments

Comments
 (0)