1- #!/usr/bin/env python 
2- #coding=utf-8 
3- 
4- import  gzip 
51import  os 
2+ import  gzip 
63import  logging 
4+ import  click 
75
86import  paddle .v2  as  paddle 
97import  reader 
@@ -24,24 +22,59 @@ def load_initial_model(model_path, parameters):
2422 parameters .init_from_tar (f )
2523
2624
27- def  main (num_passes ,
28-  batch_size ,
29-  use_gpu ,
30-  trainer_count ,
31-  save_dir_path ,
32-  encoder_depth ,
33-  decoder_depth ,
34-  word_dict_path ,
35-  train_data_path ,
36-  init_model_path = "" ):
25+ @click .command ("train" ) 
26+ @click .option ( 
27+  "--num_passes" , default = 10 , help = "Number of passes for the training task." ) 
28+ @click .option ( 
29+  "--batch_size" , 
30+  default = 16 , 
31+  help = "The number of training examples in one forward/backward pass." ) 
32+ @click .option ( 
33+  "--use_gpu" , default = False , help = "Whether to use gpu to train the model." ) 
34+ @click .option ( 
35+  "--trainer_count" , default = 1 , help = "The thread number used in training." ) 
36+ @click .option ( 
37+  "--save_dir_path" , 
38+  default = "models" , 
39+  help = "The path to saved the trained models." ) 
40+ @click .option ( 
41+  "--encoder_depth" , 
42+  default = 3 , 
43+  help = "The number of stacked LSTM layers in encoder." ) 
44+ @click .option ( 
45+  "--decoder_depth" , 
46+  default = 3 , 
47+  help = "The number of stacked LSTM layers in encoder." ) 
48+ @click .option ( 
49+  "--train_data_path" , required = True , help = "The path of trainning data." ) 
50+ @click .option ( 
51+  "--word_dict_path" , required = True , help = "The path of word dictionary." ) 
52+ @click .option ( 
53+  "--init_model_path" , 
54+  default = "" , 
55+  help = ("The path of a trained model used to initialized all "  
56+  "the model parameters." )) 
57+ def  train (num_passes ,
58+  batch_size ,
59+  use_gpu ,
60+  trainer_count ,
61+  save_dir_path ,
62+  encoder_depth ,
63+  decoder_depth ,
64+  train_data_path ,
65+  word_dict_path ,
66+  init_model_path = "" ):
3767 if  not  os .path .exists (save_dir_path ):
3868 os .mkdir (save_dir_path )
69+  assert  os .path .exists (
70+  word_dict_path ), "The given word dictionary does not exist." 
71+  assert  os .path .exists (
72+  train_data_path ), "The given training data does not exist." 
3973
4074 # initialize PaddlePaddle 
41-  paddle .init (use_gpu = use_gpu , trainer_count = trainer_count ,  parallel_nn = 1 )
75+  paddle .init (use_gpu = use_gpu , trainer_count = trainer_count )
4276
4377 # define optimization method and the trainer instance 
44-  # optimizer = paddle.optimizer.Adam( 
4578 optimizer  =  paddle .optimizer .AdaDelta (
4679 learning_rate = 1e-3 ,
4780 gradient_clipping_threshold = 25.0 ,
@@ -74,7 +107,7 @@ def main(num_passes,
74107 # define the event_handler callback 
75108 def  event_handler (event ):
76109 if  isinstance (event , paddle .event .EndIteration ):
77-  if  (not  event .batch_id  %  2000 ) and  event .batch_id :
110+  if  (not  event .batch_id  %  1000 ) and  event .batch_id :
78111 save_path  =  os .path .join (save_dir_path ,
79112 "pass_%05d_batch_%05d.tar.gz"  % 
80113 (event .pass_id , event .batch_id ))
@@ -94,15 +127,5 @@ def event_handler(event):
94127 reader = train_reader , event_handler = event_handler , num_passes = num_passes )
95128
96129
97- if  __name__  ==  '__main__' :
98-  main (
99-  num_passes = 500 ,
100-  batch_size = 4  *  500 ,
101-  use_gpu = True ,
102-  trainer_count = 4 ,
103-  encoder_depth = 3 ,
104-  decoder_depth = 3 ,
105-  save_dir_path = "models" ,
106-  word_dict_path = "data/word_dict.txt" ,
107-  train_data_path = "data/song.poet.txt" ,
108-  init_model_path = "" )
130+ if  __name__  ==  "__main__" :
131+  train ()
0 commit comments