Skip to content

Commit 9ac0da0

Browse files
committed
add command line parser.
1 parent 65b355f commit 9ac0da0

File tree

5 files changed

+82
-55
lines changed

5 files changed

+82
-55
lines changed

generate_chinese_poetry/generate.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
#!/usr/bin/env python
2-
#coding=utf-8
3-
41
import os
52
import sys
63
import gzip
74
import logging
85
import numpy as np
6+
import click
97

108
import reader
119
import paddle.v2 as paddle
@@ -36,19 +34,45 @@ def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
3634
fout.flush
3735

3836

37+
@click.command("generate")
38+
@click.option(
39+
"--model_path",
40+
default="",
41+
help="The path of the trained model for generation.")
42+
@click.option(
43+
"--word_dict_path", required=True, help="The path of word dictionary.")
44+
@click.option(
45+
"--test_data_path",
46+
required=True,
47+
help="The path of input data for generation.")
48+
@click.option(
49+
"--batch_size",
50+
default=1,
51+
help="The number of training examples in one forward pass in generation.")
52+
@click.option(
53+
"--beam_size", default=5, help="The beam expansion in beam search.")
54+
@click.option(
55+
"--save_file",
56+
required=True,
57+
help="The file path to save the generated results.")
58+
@click.option(
59+
"--use_gpu", default=False, help="Whether to use GPU in generation.")
3960
def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
4061
save_file, use_gpu):
41-
assert os.path.exists(model_path), "trained model does not exist."
62+
assert os.path.exists(model_path), "The given model does not exist."
63+
assert os.path.exists(test_data_path), "The given test data does not exist."
4264

43-
paddle.init(use_gpu=use_gpu, trainer_count=1)
4465
with gzip.open(model_path, "r") as f:
4566
parameters = paddle.parameters.Parameters.from_tar(f)
4667

4768
id_to_text = {}
69+
assert os.path.exists(
70+
word_dict_path), "The given word dictionary path does not exist."
4871
with open(word_dict_path, "r") as f:
4972
for i, line in enumerate(f):
5073
id_to_text[i] = line.strip().split("\t")[0]
5174

75+
paddle.init(use_gpu=use_gpu, trainer_count=1)
5276
beam_gen = encoder_decoder_network(
5377
word_count=len(id_to_text),
5478
emb_dim=512,
@@ -78,11 +102,4 @@ def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
78102

79103

80104
if __name__ == "__main__":
81-
generate(
82-
model_path="models/pass_00025.tar.gz",
83-
word_dict_path="data/word_dict.txt",
84-
test_data_path="data/input.txt",
85-
save_file="gen_result.txt",
86-
batch_size=4,
87-
beam_size=5,
88-
use_gpu=True)
105+
generate()

generate_chinese_poetry/network_conf.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
#!/usr/bin/env python
2-
#coding=utf-8
3-
41
import paddle.v2 as paddle
52
from paddle.v2.layer import parse_network
63

generate_chinese_poetry/reader.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
41
from utils import load_dict
52

63

generate_chinese_poetry/train.py

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
#!/usr/bin/env python
2-
#coding=utf-8
3-
4-
import gzip
51
import os
2+
import gzip
63
import logging
4+
import click
75

86
import paddle.v2 as paddle
97
import reader
@@ -24,24 +22,59 @@ def load_initial_model(model_path, parameters):
2422
parameters.init_from_tar(f)
2523

2624

27-
def main(num_passes,
28-
batch_size,
29-
use_gpu,
30-
trainer_count,
31-
save_dir_path,
32-
encoder_depth,
33-
decoder_depth,
34-
word_dict_path,
35-
train_data_path,
36-
init_model_path=""):
25+
@click.command("train")
26+
@click.option(
27+
"--num_passes", default=10, help="Number of passes for the training task.")
28+
@click.option(
29+
"--batch_size",
30+
default=16,
31+
help="The number of training examples in one forward/backward pass.")
32+
@click.option(
33+
"--use_gpu", default=False, help="Whether to use gpu to train the model.")
34+
@click.option(
35+
"--trainer_count", default=1, help="The thread number used in training.")
36+
@click.option(
37+
"--save_dir_path",
38+
default="models",
39+
help="The path to saved the trained models.")
40+
@click.option(
41+
"--encoder_depth",
42+
default=3,
43+
help="The number of stacked LSTM layers in encoder.")
44+
@click.option(
45+
"--decoder_depth",
46+
default=3,
47+
help="The number of stacked LSTM layers in encoder.")
48+
@click.option(
49+
"--train_data_path", required=True, help="The path of trainning data.")
50+
@click.option(
51+
"--word_dict_path", required=True, help="The path of word dictionary.")
52+
@click.option(
53+
"--init_model_path",
54+
default="",
55+
help=("The path of a trained model used to initialized all "
56+
"the model parameters."))
57+
def train(num_passes,
58+
batch_size,
59+
use_gpu,
60+
trainer_count,
61+
save_dir_path,
62+
encoder_depth,
63+
decoder_depth,
64+
train_data_path,
65+
word_dict_path,
66+
init_model_path=""):
3767
if not os.path.exists(save_dir_path):
3868
os.mkdir(save_dir_path)
69+
assert os.path.exists(
70+
word_dict_path), "The given word dictionary does not exist."
71+
assert os.path.exists(
72+
train_data_path), "The given training data does not exist."
3973

4074
# initialize PaddlePaddle
41-
paddle.init(use_gpu=use_gpu, trainer_count=trainer_count, parallel_nn=1)
75+
paddle.init(use_gpu=use_gpu, trainer_count=trainer_count)
4276

4377
# define optimization method and the trainer instance
44-
# optimizer = paddle.optimizer.Adam(
4578
optimizer = paddle.optimizer.AdaDelta(
4679
learning_rate=1e-3,
4780
gradient_clipping_threshold=25.0,
@@ -74,7 +107,7 @@ def main(num_passes,
74107
# define the event_handler callback
75108
def event_handler(event):
76109
if isinstance(event, paddle.event.EndIteration):
77-
if (not event.batch_id % 2000) and event.batch_id:
110+
if (not event.batch_id % 1000) and event.batch_id:
78111
save_path = os.path.join(save_dir_path,
79112
"pass_%05d_batch_%05d.tar.gz" %
80113
(event.pass_id, event.batch_id))
@@ -94,15 +127,5 @@ def event_handler(event):
94127
reader=train_reader, event_handler=event_handler, num_passes=num_passes)
95128

96129

97-
if __name__ == '__main__':
98-
main(
99-
num_passes=500,
100-
batch_size=4 * 500,
101-
use_gpu=True,
102-
trainer_count=4,
103-
encoder_depth=3,
104-
decoder_depth=3,
105-
save_dir_path="models",
106-
word_dict_path="data/word_dict.txt",
107-
train_data_path="data/song.poet.txt",
108-
init_model_path="")
130+
if __name__ == "__main__":
131+
train()

generate_chinese_poetry/utils.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
#!/usr/bin/env python
2-
#coding=utf-8
3-
41
import os
52
import sys
63
import re
@@ -30,7 +27,3 @@ def find_optiaml_pass(log_file):
3027
cost_info.iteritems(),
3128
key=lambda x: sum(x[1]) / (len(x[1])),
3229
reverse=False)[0][0])
33-
34-
35-
if __name__ == '__main__':
36-
find_optiaml_pass('trained_models/models_first_round/train.log')

0 commit comments

Comments
 (0)