
Commit 4c990dc

Add a trained seq2seq model to generate Chinese poetry.

1 parent e72c105

File tree

11 files changed: +423 −3 lines changed

generate_chinese_poetry/generate.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
#!/usr/bin/env python
#coding=utf-8

import os
import sys
import gzip
import logging
import numpy as np

import reader
import paddle.v2 as paddle
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

logger = logging.getLogger("paddle")
logger.setLevel(logging.WARNING)


def infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout):
    beam_result = inferer.infer(input=test_batch, field=["prob", "id"])
    # Every generated candidate ends at a -1 separator in the flat id array.
    gen_sen_idx = np.where(beam_result[1] == -1)[0]
    assert len(gen_sen_idx) == len(test_batch) * beam_size, ("%d vs. %d" % (
        len(gen_sen_idx), len(test_batch) * beam_size))

    start_pos, end_pos = 1, 0
    for i, sample in enumerate(test_batch):
        # Skip the start and end marks when printing the source sentence.
        fout.write("%s\n" %
                   (" ".join([id_to_text[w] for w in sample[0][1:-1]])))
        for j in xrange(beam_size):
            end_pos = gen_sen_idx[i * beam_size + j]
            fout.write("%s\n" % ("%.4f\t%s" % (beam_result[0][i][j], " ".join(
                id_to_text[w] for w in beam_result[1][start_pos:end_pos]))))
            # Step over the -1 separator and the next start mark.
            start_pos = end_pos + 2
        fout.write("\n")
        fout.flush()


def generate(model_path, word_dict_path, test_data_path, batch_size, beam_size,
             save_file, use_gpu):
    assert os.path.exists(model_path), "The trained model does not exist."

    paddle.init(use_gpu=use_gpu, trainer_count=1)
    with gzip.open(model_path, "r") as f:
        parameters = paddle.parameters.Parameters.from_tar(f)

    # Build the id-to-word mapping from the dictionary file.
    id_to_text = {}
    with open(word_dict_path, "r") as f:
        for i, line in enumerate(f):
            id_to_text[i] = line.strip().split("\t")[0]

    beam_gen = encoder_decoder_network(
        word_count=len(id_to_text),
        emb_dim=512,
        encoder_depth=3,
        encoder_hidden_dim=512,
        decoder_depth=3,
        decoder_hidden_dim=512,
        is_generating=True,
        beam_size=beam_size,
        max_length=10)

    inferer = paddle.inference.Inference(
        output_layer=beam_gen, parameters=parameters)

    test_batch = []
    with open(save_file, "w") as fout:
        for idx, item in enumerate(
                reader.gen_reader(test_data_path, word_dict_path)()):
            test_batch.append([item])
            if len(test_batch) == batch_size:
                infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout)
                test_batch = []

        if len(test_batch):
            infer_a_batch(inferer, test_batch, beam_size, id_to_text, fout)
            test_batch = []


if __name__ == "__main__":
    generate(
        model_path="models/pass_00025.tar.gz",
        word_dict_path="data/word_dict.txt",
        test_data_path="data/input.txt",
        save_file="gen_result.txt",
        batch_size=4,
        beam_size=5,
        use_gpu=True)
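A note on the decoding loop above: with field=["prob", "id"], the inference call returns the beam scores plus one flat id array in which each candidate begins with the start mark, ends before a -1 separator, which is why infer_a_batch walks gen_sen_idx and jumps end_pos + 2 (past the separator and the next start mark). A minimal, self-contained sketch of that parsing, with made-up ids (0 = <s>, 1 = <e>, everything else hypothetical):

import numpy as np

# Hypothetical flat beam output for one source with beam_size=2:
# <s> 5 7 <e> -1 <s> 5 9 8 <e> -1
flat_ids = np.array([0, 5, 7, 1, -1, 0, 5, 9, 8, 1, -1])
sep_idx = np.where(flat_ids == -1)[0]  # indices of the -1 separators

start = 1  # skip the leading <s> of the first candidate
for end in sep_idx:
    print(flat_ids[start:end])  # one candidate: [5 7 1], then [5 9 8 1]
    start = end + 2  # step over the -1 separator and the next <s>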
generate_chinese_poetry/network_conf.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
#!/usr/bin/env python
#coding=utf-8

import paddle.v2 as paddle
from paddle.v2.layer import parse_network

__all__ = ["encoder_decoder_network"]


def _bidirect_lstm_encoder(input, hidden_dim, depth):
    lstm_last = []
    for dirt in ["fwd", "bwd"]:
        for i in range(depth):
            # The first layer projects the input embeddings directly; deeper
            # layers project both the previous projection and the previous
            # LSTM output.
            input_proj = paddle.layer.mixed(
                name="__in_proj_%0d_%s__" % (i, dirt),
                size=hidden_dim * 4,
                bias_attr=True,
                input=[
                    paddle.layer.full_matrix_projection(input_proj),
                    paddle.layer.full_matrix_projection(
                        lstm, param_attr=paddle.attr.Param(initial_std=5e-4)),
                ] if i else [paddle.layer.full_matrix_projection(input)])
            lstm = paddle.layer.lstmemory(
                input=input_proj,
                bias_attr=paddle.attr.Param(initial_std=0.),
                param_attr=paddle.attr.Param(initial_std=5e-4),
                # Alternate the scan direction between stacked layers; the
                # "bwd" stack starts in the reversed direction.
                reverse=i % 2 if dirt == "fwd" else not i % 2)
        lstm_last.append(lstm)
    return paddle.layer.concat(input=lstm_last)


def _attended_decoder_step(word_count, enc_out, enc_out_proj,
                           decoder_hidden_dim, depth, trg_emb):
    decoder_memory = paddle.layer.memory(
        name="__decoder_0__", size=decoder_hidden_dim, boot_layer=None)

    context = paddle.networks.simple_attention(
        encoded_sequence=enc_out,
        encoded_proj=enc_out_proj,
        decoder_state=decoder_memory)

    for i in range(depth):
        # The first layer consumes the attention context and the target-word
        # embedding; deeper layers consume the previous projection and LSTM.
        input_proj = paddle.layer.mixed(
            act=paddle.activation.Linear(),
            size=decoder_hidden_dim * 4,
            bias_attr=False,
            input=[
                paddle.layer.full_matrix_projection(input_proj),
                paddle.layer.full_matrix_projection(lstm)
            ] if i else [
                paddle.layer.full_matrix_projection(context),
                paddle.layer.full_matrix_projection(trg_emb)
            ])
        lstm = paddle.networks.lstmemory_unit(
            input=input_proj,
            input_proj_layer_attr=paddle.attr.ExtraLayerAttribute(
                error_clipping_threshold=25.),
            out_memory=decoder_memory if not i else None,
            name="__decoder_%d__" % i,
            size=decoder_hidden_dim,
            act=paddle.activation.Tanh(),
            gate_act=paddle.activation.Sigmoid(),
            state_act=paddle.activation.Tanh())

    next_word = paddle.layer.fc(
        size=word_count,
        bias_attr=True,
        act=paddle.activation.Softmax(),
        input=lstm)
    return next_word


def encoder_decoder_network(word_count,
                            emb_dim,
                            encoder_depth,
                            encoder_hidden_dim,
                            decoder_depth,
                            decoder_hidden_dim,
                            beam_size=10,
                            max_length=15,
                            is_generating=False):
    src_emb = paddle.layer.embedding(
        input=paddle.layer.data(
            name="src_word_id",
            type=paddle.data_type.integer_value_sequence(word_count)),
        size=emb_dim,
        param_attr=paddle.attr.ParamAttr(name="__embedding__"))
    enc_out = _bidirect_lstm_encoder(
        input=src_emb, hidden_dim=encoder_hidden_dim, depth=encoder_depth)
    enc_out_proj = paddle.layer.fc(
        act=paddle.activation.Linear(),
        size=encoder_hidden_dim,
        bias_attr=False,
        input=enc_out)

    decoder_group_name = "decoder_group"
    group_inputs = [
        word_count, paddle.layer.StaticInput(input=enc_out),
        paddle.layer.StaticInput(input=enc_out_proj), decoder_hidden_dim,
        decoder_depth
    ]

    if is_generating:
        # Share the source embedding table with the generated target words.
        gen_trg_emb = paddle.layer.GeneratedInput(
            size=word_count,
            embedding_name="__embedding__",
            embedding_size=emb_dim)
        return paddle.layer.beam_search(
            name=decoder_group_name,
            step=_attended_decoder_step,
            input=group_inputs + [gen_trg_emb],
            bos_id=0,
            eos_id=1,
            beam_size=beam_size,
            max_length=max_length)
    else:
        trg_emb = paddle.layer.embedding(
            input=paddle.layer.data(
                name="trg_word_id",
                type=paddle.data_type.integer_value_sequence(word_count)),
            size=emb_dim,
            param_attr=paddle.attr.ParamAttr(name="__embedding__"))
        lbl = paddle.layer.data(
            name="trg_next_word",
            type=paddle.data_type.integer_value_sequence(word_count))
        next_word = paddle.layer.recurrent_group(
            name=decoder_group_name,
            step=_attended_decoder_step,
            input=group_inputs + [trg_emb])
        return paddle.layer.classification_cost(input=next_word, label=lbl)
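For a quick topology check, the training branch of encoder_decoder_network can be built and dumped with parse_network (already imported above); a minimal sketch, assuming a hypothetical 10000-word vocabulary:

import paddle.v2 as paddle
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

paddle.init(use_gpu=False, trainer_count=1)

# Training branch: returns the per-step next-word classification cost.
cost = encoder_decoder_network(
    word_count=10000,
    emb_dim=512,
    encoder_depth=3,
    encoder_hidden_dim=512,
    decoder_depth=3,
    decoder_hidden_dim=512,
    is_generating=False)
print(parse_network(cost))  # dumps the layer graph as a text protobuf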

generate_chinese_poetry/reader.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from utils import load_dict


def train_reader(data_file_path, word_dict_file):
    def reader():
        word_dict = load_dict(word_dict_file)

        unk_id = word_dict[u"<unk>"]
        bos_id = word_dict[u"<s>"]
        eos_id = word_dict[u"<e>"]

        with open(data_file_path, "r") as f:
            for line in f:
                line_split = line.strip().decode(
                    "utf8", errors="ignore").split("\t")
                if len(line_split) < 3: continue

                # The third column holds the poem; its sentences are
                # separated by periods.
                poetry = line_split[2].split(".")
                poetry_ids = []
                for sen in poetry:
                    if sen:
                        poetry_ids.append([bos_id] + [
                            word_dict.get(word, unk_id)
                            for word in "".join(sen.split())
                        ] + [eos_id])
                l = len(poetry_ids)
                if l < 2: continue
                # Every adjacent sentence pair yields one training sample:
                # (source, target input, target next-word label).
                for i in range(l - 1):
                    yield (poetry_ids[i], poetry_ids[i + 1][:-1],
                           poetry_ids[i + 1][1:])

    return reader


def gen_reader(data_file_path, word_dict_file):
    def reader():
        word_dict = load_dict(word_dict_file)

        unk_id = word_dict[u"<unk>"]
        bos_id = word_dict[u"<s>"]
        eos_id = word_dict[u"<e>"]

        with open(data_file_path, "r") as f:
            for line in f:
                input_line = "".join(
                    line.strip().decode("utf8", errors="ignore").split())
                yield [bos_id] + [
                    word_dict.get(word, unk_id) for word in input_line
                ] + [eos_id]

    return reader
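To make the sample layout of train_reader concrete: each adjacent sentence pair becomes one (source, target input, target label) triple, with the target shifted by one word so the model predicts the next word at every step. A sketch with hypothetical ids (<s> = 0, <e> = 1, content words made up):

src = [0, 5, 6, 1]  # sentence i:      <s> w w <e>
trg = [0, 7, 8, 1]  # sentence i + 1:  <s> w w <e>
sample = (src, trg[:-1], trg[1:])
print(sample)  # ([0, 5, 6, 1], [0, 7, 8], [7, 8, 1])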

generate_chinese_poetry/train.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
#!/usr/bin/env python
#coding=utf-8

import gzip
import os
import logging

import paddle.v2 as paddle
import reader
from paddle.v2.layer import parse_network
from network_conf import encoder_decoder_network

logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)


def save_model(save_path, parameters):
    with gzip.open(save_path, "w") as f:
        parameters.to_tar(f)


def load_initial_model(model_path, parameters):
    with gzip.open(model_path, "rb") as f:
        parameters.init_from_tar(f)


def main(num_passes,
         batch_size,
         use_gpu,
         trainer_count,
         save_dir_path,
         encoder_depth,
         decoder_depth,
         word_dict_path,
         train_data_path,
         init_model_path=""):
    if not os.path.exists(save_dir_path):
        os.mkdir(save_dir_path)

    # Initialize PaddlePaddle.
    paddle.init(use_gpu=use_gpu, trainer_count=trainer_count, parallel_nn=1)

    # Define the optimization method and the trainer instance.
    optimizer = paddle.optimizer.AdaDelta(
        learning_rate=1e-3,
        gradient_clipping_threshold=25.0,
        regularization=paddle.optimizer.L2Regularization(rate=8e-4),
        model_average=paddle.optimizer.ModelAverage(
            average_window=0.5, max_average_window=2500))

    cost = encoder_decoder_network(
        word_count=len(open(word_dict_path, "r").readlines()),
        emb_dim=512,
        encoder_depth=encoder_depth,
        encoder_hidden_dim=512,
        decoder_depth=decoder_depth,
        decoder_hidden_dim=512)

    parameters = paddle.parameters.create(cost)
    if init_model_path:
        load_initial_model(init_model_path, parameters)

    trainer = paddle.trainer.SGD(
        cost=cost, parameters=parameters, update_equation=optimizer)

    # Define the data reader.
    train_reader = paddle.batch(
        paddle.reader.shuffle(
            reader.train_reader(train_data_path, word_dict_path),
            buf_size=1024000),
        batch_size=batch_size)

    # Define the event_handler callback.
    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            # Save a checkpoint every 2000 batches.
            if (not event.batch_id % 2000) and event.batch_id:
                save_path = os.path.join(save_dir_path,
                                         "pass_%05d_batch_%05d.tar.gz" %
                                         (event.pass_id, event.batch_id))
                save_model(save_path, parameters)

            # Log the cost every 5 batches.
            if not event.batch_id % 5:
                logger.info("Pass %d, Batch %d, Cost %f, %s" %
                            (event.pass_id, event.batch_id, event.cost,
                             event.metrics))

        if isinstance(event, paddle.event.EndPass):
            save_path = os.path.join(save_dir_path,
                                     "pass_%05d.tar.gz" % event.pass_id)
            save_model(save_path, parameters)

    # Start training.
    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        num_passes=num_passes)


if __name__ == '__main__':
    main(
        num_passes=500,
        batch_size=4 * 500,
        use_gpu=True,
        trainer_count=4,
        encoder_depth=3,
        decoder_depth=3,
        save_dir_path="models",
        word_dict_path="data/word_dict.txt",
        train_data_path="data/song.poet.txt",
        init_model_path="")
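Training can be resumed from any archive written by save_model, since a non-empty init_model_path triggers load_initial_model before the trainer starts; a hypothetical invocation (the checkpoint name is illustrative, matching the path used by generate.py):

main(
    num_passes=100,
    batch_size=2000,
    use_gpu=True,
    trainer_count=4,
    encoder_depth=3,
    decoder_depth=3,
    save_dir_path="models",
    word_dict_path="data/word_dict.txt",
    train_data_path="data/song.poet.txt",
    init_model_path="models/pass_00025.tar.gz")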
