1717import paddle .v2 as paddle
1818import sys
1919import numpy
20+ import unittest
2021
2122
2223def parse_arg ():
@@ -74,18 +75,18 @@ def conv_net(img, label):
7475 return loss_net (conv_pool_2 , label )
7576
7677
77- def train (args , save_dirname = None ):
78- print ( "recognize digits with args: {0}" . format ( " " . join ( sys . argv [ 1 :])))
79-
78+ def train (nn_type , use_cuda , parallel , save_dirname ):
79+ if use_cuda and not fluid . core . is_compiled_with_cuda ():
80+ return
8081 img = fluid .layers .data (name = 'img' , shape = [1 , 28 , 28 ], dtype = 'float32' )
8182 label = fluid .layers .data (name = 'label' , shape = [1 ], dtype = 'int64' )
8283
83- if args . nn_type == 'mlp' :
84+ if nn_type == 'mlp' :
8485 net_conf = mlp
8586 else :
8687 net_conf = conv_net
8788
88- if args . parallel :
89+ if parallel :
8990 places = fluid .layers .get_places ()
9091 pd = fluid .layers .ParallelDo (places )
9192 with pd .do ():
@@ -107,7 +108,7 @@ def train(args, save_dirname=None):
107108 optimizer = fluid .optimizer .Adam (learning_rate = 0.001 )
108109 optimizer .minimize (avg_loss )
109110
110- place = fluid .CUDAPlace (0 ) if args . use_cuda else fluid .CPUPlace ()
111+ place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
111112
112113 exe = fluid .Executor (place )
113114 exe .run (fluid .default_startup_program ())
@@ -147,13 +148,14 @@ def train(args, save_dirname=None):
147148 'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}' .
148149 format (pass_id , batch_id + 1 ,
149150 float (avg_loss_val ), float (acc_val )))
151+ raise AssertionError ("Loss of recognize digits is too large" )
150152
151153
152- def infer (args , save_dirname = None ):
154+ def infer (use_cuda , save_dirname = None ):
153155 if save_dirname is None :
154156 return
155157
156- place = fluid .CUDAPlace (0 ) if args . use_cuda else fluid .CPUPlace ()
158+ place = fluid .CUDAPlace (0 ) if use_cuda else fluid .CPUPlace ()
157159 exe = fluid .Executor (place )
158160
159161 # Use fluid.io.load_inference_model to obtain the inference program desc,
@@ -174,11 +176,48 @@ def infer(args, save_dirname=None):
174176 print ("infer results: " , results [0 ])
175177
176178
177- if __name__ == '__main__' :
178- args = parse_arg ()
179- if not args .use_cuda and not args .parallel :
180- save_dirname = "recognize_digits_" + args .nn_type + ".inference.model"
179+ def main (use_cuda , parallel , nn_type ):
180+ if not use_cuda and not parallel :
181+ save_dirname = "recognize_digits_" + nn_type + ".inference.model"
181182 else :
182183 save_dirname = None
183- train (args , save_dirname )
184- infer (args , save_dirname )
184+
185+ train (
186+ nn_type = nn_type ,
187+ use_cuda = use_cuda ,
188+ parallel = parallel ,
189+ save_dirname = save_dirname )
190+ infer (use_cuda = use_cuda , save_dirname = save_dirname )
191+
192+
193+ class TestRecognizeDigits (unittest .TestCase ):
194+ pass
195+
196+
197+ def inject_test_method (use_cuda , parallel , nn_type ):
198+ def __impl__ (self ):
199+ prog = fluid .Program ()
200+ startup_prog = fluid .Program ()
201+ scope = fluid .core .Scope ()
202+ with fluid .scope_guard (scope ):
203+ with fluid .program_guard (prog , startup_prog ):
204+ main (use_cuda , parallel , nn_type )
205+
206+ fn = 'test_{0}_{1}_{2}' .format (nn_type , 'cuda'
207+ if use_cuda else 'cpu' , 'parallel'
208+ if parallel else 'normal' )
209+
210+ setattr (TestRecognizeDigits , fn , __impl__ )
211+
212+
213+ def inject_all_tests ():
214+ for use_cuda in (False , True ):
215+ for parallel in (False , True ):
216+ for nn_type in ('mlp' , 'conv' ):
217+ inject_test_method (use_cuda , parallel , nn_type )
218+
219+
220+ inject_all_tests ()
221+
222+ if __name__ == '__main__' :
223+ unittest .main ()
0 commit comments