
Commit 2784c34

Author: Yibing Liu (committed)
Commit message: improve parameters tuning
Parent: 38a1d26

3 files changed: +71 -47 lines changed


deep_speech_2/evaluate.py

Lines changed: 5 additions & 3 deletions
@@ -12,6 +12,7 @@
 from decoder import *
 from lm.lm_scorer import LmScorer
 from error_rate import wer
+import utils
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
@@ -67,12 +68,12 @@
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
     "--alpha",
-    default=0.26,
+    default=0.34,
     type=float,
     help="Parameter associated with language model. (default: %(default)f)")
 parser.add_argument(
     "--beta",
-    default=0.1,
+    default=0.35,
     type=float,
     help="Parameter associated with word count. (default: %(default)f)")
 parser.add_argument(
@@ -192,11 +193,12 @@ def evaluate():
         else:
             raise ValueError("Decoding method [%s] is not supported." %
                              decode_method)
-
+        print("Cur WER = %f" % (wer_sum / wer_counter))
     print("Final WER = %f" % (wer_sum / wer_counter))
 
 
 def main():
+    utils.print_arguments(args)
     paddle.init(use_gpu=args.use_gpu, trainer_count=1)
     evaluate()

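Both evaluate.py and tune.py now call utils.print_arguments(args) at the start of main() so the effective configuration is logged before decoding begins. utils.py itself is not among the three files changed by this commit; a minimal sketch of such a helper, assuming it does nothing more than dump the parsed argparse namespace, could look like:

    # Hypothetical sketch (utils.py is not shown in this diff): print every
    # parsed command-line argument as "name: value" for reproducibility.
    def print_arguments(args):
        print("-----------  Configuration Arguments -----------")
        for arg, value in sorted(vars(args).items()):
            print("%s: %s" % (arg, value))
        print("------------------------------------------------")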
deep_speech_2/infer.py

Lines changed: 2 additions & 2 deletions
@@ -94,12 +94,12 @@
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
     "--alpha",
-    default=0.26,
+    default=0.34,
     type=float,
     help="Parameter associated with language model. (default: %(default)f)")
 parser.add_argument(
     "--beta",
-    default=0.1,
+    default=0.35,
     type=float,
     help="Parameter associated with word count. (default: %(default)f)")
 parser.add_argument(

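The retuned defaults (alpha 0.26 to 0.34, beta 0.1 to 0.35) are the weights of the external scorer used during CTC beam search: per the argument help strings, alpha scales the language-model contribution and beta scales the word-count bonus. The actual LmScorer lives in lm/lm_scorer.py and is not part of this diff, so the following is only an illustrative sketch of the kind of score these two knobs control:

    import math

    # Illustrative only: the real LmScorer interface is not shown in this
    # commit. alpha weights the language-model probability and beta weights a
    # word-count bonus, matching the --alpha and --beta help strings.
    def external_score(sentence, lm_prob, alpha, beta):
        word_count = len(sentence.split())
        return alpha * math.log(lm_prob) + beta * word_count

Roughly speaking, a larger alpha leans harder on the language model, while a larger beta favors hypotheses with more words.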
deep_speech_2/tune.py

Lines changed: 64 additions & 42 deletions
@@ -16,10 +16,10 @@
 
 parser = argparse.ArgumentParser(description=__doc__)
 parser.add_argument(
-    "--num_samples",
+    "--batch_size",
     default=100,
     type=int,
-    help="Number of samples for parameters tuning. (default: %(default)s)")
+    help="Minibatch size for evaluation. (default: %(default)s)")
 parser.add_argument(
     "--num_conv_layers",
     default=2,
@@ -57,7 +57,7 @@
     help="Manifest path for normalizer. (default: %(default)s)")
 parser.add_argument(
     "--decode_manifest_path",
-    default='datasets/manifest.test',
+    default='datasets/manifest.dev',
     type=str,
     help="Manifest path for decoding. (default: %(default)s)")
 parser.add_argument(
@@ -82,17 +82,17 @@
     help="Path for language model. (default: %(default)s)")
 parser.add_argument(
     "--alpha_from",
-    default=0.1,
+    default=0.22,
     type=float,
     help="Where alpha starts from. (default: %(default)f)")
 parser.add_argument(
     "--num_alphas",
-    default=14,
+    default=10,
     type=int,
     help="Number of candidate alphas. (default: %(default)d)")
 parser.add_argument(
     "--alpha_to",
-    default=0.36,
+    default=0.40,
     type=float,
     help="Where alpha ends with. (default: %(default)f)")
 parser.add_argument(
@@ -102,12 +102,12 @@
     help="Where beta starts from. (default: %(default)f)")
 parser.add_argument(
     "--num_betas",
-    default=20,
+    default=7,
     type=float,
     help="Number of candidate betas. (default: %(default)d)")
 parser.add_argument(
     "--beta_to",
-    default=1.0,
+    default=0.35,
     type=float,
     help="Where beta ends with. (default: %(default)f)")
 parser.add_argument(
@@ -160,55 +160,77 @@ def tune():
     # prepare infer data
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.decode_manifest_path,
-        batch_size=args.num_samples,
+        batch_size=args.batch_size,
+        min_batch_size=1,
         sortagrad=False,
         shuffle_method=None)
-    # get one batch data for tuning
-    infer_data = batch_reader().next()
-
-    # run inference
-    infer_results = paddle.infer(
-        output_layer=output_probs, parameters=parameters, input=infer_data)
-    num_steps = len(infer_results) // len(infer_data)
-    probs_split = [
-        infer_results[i * num_steps:(i + 1) * num_steps]
-        for i in xrange(0, len(infer_data))
-    ]
+
+    # define inferer
+    inferer = paddle.inference.Inference(
+        output_layer=output_probs, parameters=parameters)
 
     # create grid for search
     cand_alphas = np.linspace(args.alpha_from, args.alpha_to, args.num_alphas)
     cand_betas = np.linspace(args.beta_from, args.beta_to, args.num_betas)
     params_grid = [(alpha, beta) for alpha in cand_alphas
                    for beta in cand_betas]
 
+    # external scorer
     ext_scorer = LmScorer(args.alpha_from, args.beta_from,
                           args.language_model_path)
-    ## tune parameters in loop
-    for alpha, beta in params_grid:
-        wer_sum, wer_counter = 0, 0
-        # reset scorer
-        ext_scorer.reset_params(alpha, beta)
-        # beam search using multiple processes
-        beam_search_results = ctc_beam_search_decoder_batch(
-            probs_split=probs_split,
-            vocabulary=data_generator.vocab_list,
-            beam_size=args.beam_size,
-            cutoff_prob=args.cutoff_prob,
-            blank_id=len(data_generator.vocab_list),
-            num_processes=args.num_processes_beam_search,
-            ext_scoring_func=ext_scorer, )
-        for i, beam_search_result in enumerate(beam_search_results):
-            target_transcription = ''.join([
-                data_generator.vocab_list[index] for index in infer_data[i][1]
-            ])
-            wer_sum += wer(target_transcription, beam_search_result[0][1])
-            wer_counter += 1
 
-        print("alpha = %f\tbeta = %f\tWER = %f" %
-              (alpha, beta, wer_sum / wer_counter))
+    wer_sum = [0.0 for i in xrange(len(params_grid))]
+    wer_counter = [0 for i in xrange(len(params_grid))]
+    ave_wer = [0.0 for i in xrange(len(params_grid))]
+    num_batches = 0
+
+    ## incremental tuning batch by batch
+    for infer_data in batch_reader():
+        # run inference
+        infer_results = inferer.infer(input=infer_data)
+        num_steps = len(infer_results) // len(infer_data)
+        probs_split = [
+            infer_results[i * num_steps:(i + 1) * num_steps]
+            for i in xrange(0, len(infer_data))
+        ]
+        # target transcription
+        target_transcription = [
+            ''.join([
+                data_generator.vocab_list[index] for index in infer_data[i][1]
+            ]) for i, probs in enumerate(probs_split)
+        ]
+
+        # grid search on current batch
+        for index, (alpha, beta) in enumerate(params_grid):
+            # reset scorer
+            ext_scorer.reset_params(alpha, beta)
+            beam_search_results = ctc_beam_search_decoder_batch(
+                probs_split=probs_split,
+                vocabulary=data_generator.vocab_list,
+                beam_size=args.beam_size,
+                blank_id=len(data_generator.vocab_list),
+                num_processes=args.num_processes_beam_search,
+                ext_scoring_func=ext_scorer,
+                cutoff_prob=args.cutoff_prob, )
+            for i, beam_search_result in enumerate(beam_search_results):
+                wer_sum[index] += wer(target_transcription[i],
+                                      beam_search_result[0][1])
+                wer_counter[index] += 1
+            ave_wer[index] = wer_sum[index] / wer_counter[index]
+            print("alpha = %f, beta = %f, WER = %f" %
                  (alpha, beta, ave_wer[index]))
+
+        # output tuning result til current batch
+        ave_wer_min = min(ave_wer)
+        min_index = ave_wer.index(ave_wer_min)
+        print("Finish batch %d, alpha_opt = %f, beta_opt = %f, WER_opt = %f\n" %
+              (num_batches, params_grid[min_index][0],
+               params_grid[min_index][1], ave_wer_min))
+        num_batches += 1
 
 
 def main():
+    utils.print_arguments(args)
     paddle.init(use_gpu=args.use_gpu, trainer_count=1)
     tune()

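The net effect in tune.py: instead of decoding one fixed batch of --num_samples utterances for every (alpha, beta) pair, the script now streams the dev manifest batch by batch, decodes each batch under every candidate pair, keeps a cumulative average WER per pair, and prints the best pair found so far after each batch. A stripped-down sketch of that incremental grid search, independent of PaddlePaddle (batches and decode_batch are placeholders standing in for the batch reader and the beam-search plus WER step):

    import numpy as np

    def incremental_grid_search(batches, decode_batch,
                                alpha_from, alpha_to, num_alphas,
                                beta_from, beta_to, num_betas):
        # Build the candidate grid with np.linspace, as tune.py does.
        grid = [(alpha, beta)
                for alpha in np.linspace(alpha_from, alpha_to, num_alphas)
                for beta in np.linspace(beta_from, beta_to, num_betas)]
        wer_sum = [0.0] * len(grid)
        wer_counter = [0] * len(grid)
        best = 0
        for batch_id, batch in enumerate(batches):
            for index, (alpha, beta) in enumerate(grid):
                # decode_batch is a placeholder: it should beam-search the
                # batch with this (alpha, beta) and return per-utterance WERs.
                errors = decode_batch(batch, alpha, beta)
                wer_sum[index] += sum(errors)
                wer_counter[index] += len(errors)
            ave_wer = [s / c for s, c in zip(wer_sum, wer_counter)]
            best = int(np.argmin(ave_wer))
            print("Finish batch %d, alpha_opt = %f, beta_opt = %f, WER_opt = %f"
                  % (batch_id, grid[best][0], grid[best][1], ave_wer[best]))
        return grid[best]

Because the per-pair averages accumulate over all batches seen so far, the reported optimum stabilizes as more of the dev set is consumed, and the run still yields a usable intermediate result if it is stopped early.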