1- #!/usr/bin/env python
2- #coding=utf-8
3-
4- import gzip
51import os
2+ import gzip
63import logging
4+ import click
75
86import paddle .v2 as paddle
97import reader
@@ -24,24 +22,59 @@ def load_initial_model(model_path, parameters):
2422 parameters .init_from_tar (f )
2523
2624
27- def main (num_passes ,
28- batch_size ,
29- use_gpu ,
30- trainer_count ,
31- save_dir_path ,
32- encoder_depth ,
33- decoder_depth ,
34- word_dict_path ,
35- train_data_path ,
36- init_model_path = "" ):
25+ @click .command ("train" )
26+ @click .option (
27+ "--num_passes" , default = 10 , help = "Number of passes for the training task." )
28+ @click .option (
29+ "--batch_size" ,
30+ default = 16 ,
31+ help = "The number of training examples in one forward/backward pass." )
32+ @click .option (
33+ "--use_gpu" , default = False , help = "Whether to use gpu to train the model." )
34+ @click .option (
35+ "--trainer_count" , default = 1 , help = "The thread number used in training." )
36+ @click .option (
37+ "--save_dir_path" ,
38+ default = "models" ,
39+ help = "The path to save the trained models." )
40+ @click .option (
41+ "--encoder_depth" ,
42+ default = 3 ,
43+ help = "The number of stacked LSTM layers in encoder." )
44+ @click .option (
45+ "--decoder_depth" ,
46+ default = 3 ,
47+ help = "The number of stacked LSTM layers in decoder." )
48+ @click .option (
49+ "--train_data_path" , required = True , help = "The path of training data." )
50+ @click .option (
51+ "--word_dict_path" , required = True , help = "The path of word dictionary." )
52+ @click .option (
53+ "--init_model_path" ,
54+ default = "" ,
55+ help = ("The path of a trained model used to initialize all "
56+ "the model parameters." ))
57+ def train (num_passes ,
58+ batch_size ,
59+ use_gpu ,
60+ trainer_count ,
61+ save_dir_path ,
62+ encoder_depth ,
63+ decoder_depth ,
64+ train_data_path ,
65+ word_dict_path ,
66+ init_model_path = "" ):
3767 if not os .path .exists (save_dir_path ):
3868 os .mkdir (save_dir_path )
69+ assert os .path .exists (
70+ word_dict_path ), "The given word dictionary does not exist."
71+ assert os .path .exists (
72+ train_data_path ), "The given training data does not exist."
3973
4074 # initialize PaddlePaddle
41- paddle .init (use_gpu = use_gpu , trainer_count = trainer_count , parallel_nn = 1 )
75+ paddle .init (use_gpu = use_gpu , trainer_count = trainer_count )
4276
4377 # define optimization method and the trainer instance
44- # optimizer = paddle.optimizer.Adam(
4578 optimizer = paddle .optimizer .AdaDelta (
4679 learning_rate = 1e-3 ,
4780 gradient_clipping_threshold = 25.0 ,
@@ -74,7 +107,7 @@ def main(num_passes,
74107 # define the event_handler callback
75108 def event_handler (event ):
76109 if isinstance (event , paddle .event .EndIteration ):
77- if (not event .batch_id % 2000 ) and event .batch_id :
110+ if (not event .batch_id % 1000 ) and event .batch_id :
78111 save_path = os .path .join (save_dir_path ,
79112 "pass_%05d_batch_%05d.tar.gz" %
80113 (event .pass_id , event .batch_id ))
@@ -94,15 +127,5 @@ def event_handler(event):
94127 reader = train_reader , event_handler = event_handler , num_passes = num_passes )
95128
96129
97- if __name__ == '__main__' :
98- main (
99- num_passes = 500 ,
100- batch_size = 4 * 500 ,
101- use_gpu = True ,
102- trainer_count = 4 ,
103- encoder_depth = 3 ,
104- decoder_depth = 3 ,
105- save_dir_path = "models" ,
106- word_dict_path = "data/word_dict.txt" ,
107- train_data_path = "data/song.poet.txt" ,
108- init_model_path = "" )
130+ if __name__ == "__main__" :
131+ train ()
0 commit comments