Skip to content

Commit c9b41ee

Browse files
author
Yibing Liu
authored
Merge pull request #653 from kuke/add_infer
Add the demo script for inference
2 parents e826679 + b8fae8a commit c9b41ee

File tree

2 files changed

+121
-10
lines changed

2 files changed

+121
-10
lines changed

fluid/DeepASR/infer.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import os
6+
import argparse
7+
import paddle.v2.fluid as fluid
8+
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
9+
import data_utils.augmentor.trans_add_delta as trans_add_delta
10+
import data_utils.augmentor.trans_splice as trans_splice
11+
import data_utils.data_reader as reader
12+
from data_utils.util import lodtensor_to_ndarray
13+
14+
15+
def parse_args():
16+
parser = argparse.ArgumentParser("Inference for stacked LSTMP model.")
17+
parser.add_argument(
18+
'--batch_size',
19+
type=int,
20+
default=32,
21+
help='The sequence number of a batch data. (default: %(default)d)')
22+
parser.add_argument(
23+
'--device',
24+
type=str,
25+
default='GPU',
26+
choices=['CPU', 'GPU'],
27+
help='The device type. (default: %(default)s)')
28+
parser.add_argument(
29+
'--mean_var',
30+
type=str,
31+
default='data/global_mean_var_search26kHr',
32+
help="The path for feature's global mean and variance. "
33+
"(default: %(default)s)")
34+
parser.add_argument(
35+
'--infer_feature_lst',
36+
type=str,
37+
default='data/infer_feature.lst',
38+
help='The feature list path for inference. (default: %(default)s)')
39+
parser.add_argument(
40+
'--infer_label_lst',
41+
type=str,
42+
default='data/infer_label.lst',
43+
help='The label list path for inference. (default: %(default)s)')
44+
parser.add_argument(
45+
'--model_save_path',
46+
type=str,
47+
default='./checkpoints/deep_asr.pass_0.model/',
48+
help='The directory for saving model. (default: %(default)s)')
49+
args = parser.parse_args()
50+
return args
51+
52+
53+
def print_arguments(args):
54+
print('----------- Configuration Arguments -----------')
55+
for arg, value in sorted(vars(args).iteritems()):
56+
print('%s: %s' % (arg, value))
57+
print('------------------------------------------------')
58+
59+
60+
def split_infer_result(infer_seq, lod):
61+
infer_batch = []
62+
for i in xrange(0, len(lod[0]) - 1):
63+
infer_batch.append(infer_seq[lod[0][i]:lod[0][i + 1]])
64+
return infer_batch
65+
66+
67+
def infer(args):
68+
""" Gets one batch of feature data and predicts labels for each sample.
69+
"""
70+
71+
if not os.path.exists(args.model_save_path):
72+
raise IOError("Invalid model path!")
73+
74+
place = fluid.CUDAPlace(0) if args.device == 'GPU' else fluid.CPUPlace()
75+
exe = fluid.Executor(place)
76+
77+
# load model
78+
[infer_program, feed_dict,
79+
fetch_targets] = fluid.io.load_inference_model(args.model_save_path, exe)
80+
81+
ltrans = [
82+
trans_add_delta.TransAddDelta(2, 2),
83+
trans_mean_variance_norm.TransMeanVarianceNorm(args.mean_var),
84+
trans_splice.TransSplice()
85+
]
86+
87+
infer_data_reader = reader.DataReader(args.infer_feature_lst,
88+
args.infer_label_lst)
89+
infer_data_reader.set_transformers(ltrans)
90+
91+
feature_t = fluid.LoDTensor()
92+
one_batch = infer_data_reader.batch_iterator(args.batch_size, 1).next()
93+
(features, labels, lod) = one_batch
94+
feature_t.set(features, place)
95+
feature_t.set_lod([lod])
96+
97+
results = exe.run(infer_program,
98+
feed={feed_dict[0]: feature_t},
99+
fetch_list=fetch_targets,
100+
return_numpy=False)
101+
102+
probs, lod = lodtensor_to_ndarray(results[0])
103+
preds = probs.argmax(axis=1)
104+
infer_batch = split_infer_result(preds, lod)
105+
for index, sample in enumerate(infer_batch):
106+
print("result %d: " % index, sample, '\n')
107+
108+
109+
if __name__ == '__main__':
110+
args = parse_args()
111+
print_arguments(args)
112+
infer(args)

fluid/DeepASR/train.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,33 +72,34 @@ def parse_args():
7272
'--mean_var',
7373
type=str,
7474
default='data/global_mean_var_search26kHr',
75-
help='mean var path')
75+
help="The path for feature's global mean and variance. "
76+
"(default: %(default)s)")
7677
parser.add_argument(
7778
'--train_feature_lst',
7879
type=str,
7980
default='data/feature.lst',
80-
help='feature list path for training.')
81+
help='The feature list path for training. (default: %(default)s)')
8182
parser.add_argument(
8283
'--train_label_lst',
8384
type=str,
8485
default='data/label.lst',
85-
help='label list path for training.')
86+
help='The label list path for training. (default: %(default)s)')
8687
parser.add_argument(
8788
'--val_feature_lst',
8889
type=str,
8990
default='data/val_feature.lst',
90-
help='feature list path for validation.')
91+
help='The feature list path for validation. (default: %(default)s)')
9192
parser.add_argument(
9293
'--val_label_lst',
9394
type=str,
9495
default='data/val_label.lst',
95-
help='label list path for validation.')
96+
help='The label list path for validation. (default: %(default)s)')
9697
parser.add_argument(
9798
'--model_save_dir',
9899
type=str,
99100
default='./checkpoints',
100-
help='directory to save model. Do not save model if set to '
101-
'.')
101+
help="The directory for saving model. Do not save model if set to "
102+
"''. (default: %(default)s)")
102103
args = parser.parse_args()
103104
return args
104105

@@ -114,8 +115,6 @@ def train(args):
114115
"""train in loop.
115116
"""
116117

117-
# prediction, avg_cost, accuracy = stacked_lstmp_model(args.hidden_dim,
118-
# args.proj_dim, args.stacked_num, class_num=1749, args.parallel)
119118
prediction, avg_cost, accuracy = stacked_lstmp_model(
120119
hidden_dim=args.hidden_dim,
121120
proj_dim=args.proj_dim,
@@ -206,7 +205,7 @@ def test(exe):
206205
sys.stdout.flush()
207206
# run test
208207
val_cost, val_acc = test(exe)
209-
# save model
208+
# save model
210209
if args.model_save_dir != '':
211210
model_path = os.path.join(
212211
args.model_save_dir, "deep_asr.pass_" + str(pass_id) + ".model")

0 commit comments

Comments
 (0)