Skip to content

Commit d91dab0

Browse files
authored
Merge pull request #74 from qingqing01/ds2
Support variable input batch and SortaGrad.
2 parents d67d362 + cb6da07 commit d91dab0

File tree

2 files changed

+87
-76
lines changed

2 files changed

+87
-76
lines changed

deep_speech_2/audio_data_utils.py

Lines changed: 63 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99
import soundfile
1010
import numpy as np
11+
import itertools
1112
import os
1213

1314
RANDOM_SEED = 0
@@ -62,6 +63,7 @@ def __init__(self,
6263
self.__stride_ms__ = stride_ms
6364
self.__window_ms__ = window_ms
6465
self.__max_frequency__ = max_frequency
66+
self.__epoc__ = 0
6567
self.__random__ = random.Random(RANDOM_SEED)
6668
# load vocabulary (dictionary)
6769
self.__vocab_dict__, self.__vocab_list__ = \
@@ -245,43 +247,56 @@ def __padding_batch__(self, batch, padding_to=-1, flatten=False):
245247
new_batch.append((padded_audio, text))
246248
return new_batch
247249

248-
def instance_reader_creator(self,
249-
manifest_path,
250-
sort_by_duration=True,
251-
shuffle=False):
250+
def __batch_shuffle__(self, manifest, batch_size):
    """
    Sort the instances by duration and shuffle them batch-wise.

    The instances have different lengths, so they cannot be combined
    into a single matrix multiplication directly. The usual approach is
    to sort the training examples by length, combine only similarly-sized
    instances into minibatches, and pad with silence when necessary so
    that all instances in a batch have the same length. This batch
    shuffle function groups similarly-sized instances into minibatches
    and then shuffles the minibatches (not the individual instances).

    1. Sort the audio clips by duration (in place).
    2. Generate a random number `k`, k in [0, batch_size).
    3. Set the first `k` instances aside so that the minibatch boundaries
       differ between calls, then cut the rest into minibatches of
       exactly `batch_size` instances.
    4. Shuffle the minibatches.
    5. Append the instances that did not fill a whole minibatch, followed
       by the `k` instances set aside in step 3, so that every instance
       appears exactly once in the result.

    :param manifest: List of manifest entries; each entry must provide a
                     "duration" key. The list is sorted in place.
    :type manifest: list
    :param batch_size: Batch size. This size is also used as the upper
                       bound of the random shift for batch shuffle.
    :type batch_size: int
    :return: Batch shuffled manifest.
    :rtype: list
    """
    manifest.sort(key=lambda x: x["duration"])
    shift_len = self.__random__.randint(0, batch_size - 1)
    shifted = manifest[shift_len:]
    num_batches = len(shifted) // batch_size
    # Build real lists of full batches: random.shuffle needs a mutable
    # sequence, and a zip() result is only an iterator on Python 3.
    batches = [
        shifted[index * batch_size:(index + 1) * batch_size]
        for index in range(num_batches)
    ]
    self.__random__.shuffle(batches)
    batch_manifest = [instance for batch in batches for instance in batch]
    # Slice the leftover tail from an explicit start index. A negative
    # slice such as manifest[-res_len:] returns the WHOLE list when the
    # remainder is zero, which would duplicate every instance.
    batch_manifest.extend(shifted[num_batches * batch_size:])
    batch_manifest.extend(manifest[:shift_len])
    return batch_manifest
284+
285+
def instance_reader_creator(self, manifest):
252286
"""
253287
Instance reader creator for audio data. Create a callable function to
254288
produce instances of data.
255289
256290
Instance: a tuple of a numpy ndarray of audio spectrogram and a list of
257291
tokenized and indexed transcription text.
258292
259-
:param manifest_path: Filepath of manifest for audio clip files.
260-
:type manifest_path: basestring
261-
:param sort_by_duration: Sort the audio clips by duration if set True
262-
(for SortaGrad).
263-
:type sort_by_duration: bool
264-
:param shuffle: Shuffle the audio clips if set True.
265-
:type shuffle: bool
293+
:param manifest: List of manifest entries for audio clip files.
294+
:type manifest: list
266295
:return: Data reader function.
267296
:rtype: callable
268297
"""
269-
if sort_by_duration and shuffle:
270-
sort_by_duration = False
271-
logger.warn("When shuffle set to true, "
272-
"sort_by_duration is forced to set False.")
273298

274299
def reader():
275-
# read manifest
276-
manifest = self.__read_manifest__(
277-
manifest_path=manifest_path,
278-
max_duration=self.__max_duration__,
279-
min_duration=self.__min_duration__)
280-
# sort (by duration) or shuffle manifest
281-
if sort_by_duration:
282-
manifest.sort(key=lambda x: x["duration"])
283-
if shuffle:
284-
self.__random__.shuffle(manifest)
285300
# extract spectrogram feature
286301
for instance in manifest:
287302
spectrogram = self.__audio_featurize__(
@@ -296,8 +311,8 @@ def batch_reader_creator(self,
296311
batch_size,
297312
padding_to=-1,
298313
flatten=False,
299-
sort_by_duration=True,
300-
shuffle=False):
314+
sortagrad=False,
315+
batch_shuffle=False):
301316
"""
302317
Batch data reader creator for audio data. Create a callable function to
303318
produce batches of data.
@@ -317,20 +332,32 @@ def batch_reader_creator(self,
317332
:param flatten: If set True, audio data will be flatten to be a 1-dim
318333
ndarray. Otherwise, 2-dim ndarray. Default is False.
319334
:type flatten: bool
320-
:param sort_by_duration: Sort the audio clips by duration if set True
321-
(for SortaGrad).
322-
:type sort_by_duration: bool
323-
:param shuffle: Shuffle the audio clips if set True.
324-
:type shuffle: bool
335+
:param sortagrad: Sort the audio clips by duration in the first epoch
336+
if set True.
337+
:type sortagrad: bool
338+
:param batch_shuffle: Shuffle the audio clips if set True. It is
339+
not a thorough instance-wise shuffle, but a
340+
specific batch-wise shuffle. For more details,
341+
please see `__batch_shuffle__` function.
342+
:type batch_shuffle: bool
325343
:return: Batch reader function, producing batches of data when called.
326344
:rtype: callable
327345
"""
328346

329347
def batch_reader():
330-
instance_reader = self.instance_reader_creator(
348+
# read manifest
349+
manifest = self.__read_manifest__(
331350
manifest_path=manifest_path,
332-
sort_by_duration=sort_by_duration,
333-
shuffle=shuffle)
351+
max_duration=self.__max_duration__,
352+
min_duration=self.__min_duration__)
353+
354+
# sort (by duration) or shuffle manifest
355+
if self.__epoc__ == 0 and sortagrad:
356+
manifest.sort(key=lambda x: x["duration"])
357+
elif batch_shuffle:
358+
manifest = self.__batch_shuffle__(manifest, batch_size)
359+
360+
instance_reader = self.instance_reader_creator(manifest)
334361
batch = []
335362
for instance in instance_reader():
336363
batch.append(instance)
@@ -339,6 +366,7 @@ def batch_reader():
339366
batch = []
340367
if len(batch) > 0:
341368
yield self.__padding_batch__(batch, padding_to, flatten)
369+
self.__epoc__ += 1
342370

343371
return batch_reader
344372

deep_speech_2/train.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -93,23 +93,27 @@ def train():
9393
"""
9494
DeepSpeech2 training.
9595
"""
96+
9697
# initialize data generator
97-
data_generator = DataGenerator(
98-
vocab_filepath=args.vocab_filepath,
99-
normalizer_manifest_path=args.normalizer_manifest_path,
100-
normalizer_num_samples=200,
101-
max_duration=20.0,
102-
min_duration=0.0,
103-
stride_ms=10,
104-
window_ms=20)
98+
def data_generator():
99+
return DataGenerator(
100+
vocab_filepath=args.vocab_filepath,
101+
normalizer_manifest_path=args.normalizer_manifest_path,
102+
normalizer_num_samples=200,
103+
max_duration=20.0,
104+
min_duration=0.0,
105+
stride_ms=10,
106+
window_ms=20)
105107

108+
train_generator = data_generator()
109+
test_generator = data_generator()
106110
# create network config
107-
dict_size = data_generator.vocabulary_size()
111+
dict_size = train_generator.vocabulary_size()
112+
# paddle.data_type.dense_array is used for variable batch input.
113+
# the size 161 * 161 is only a placeholder value and the real shape
114+
# of input batch data will be set at each batch.
108115
audio_data = paddle.layer.data(
109-
name="audio_spectrogram",
110-
height=161,
111-
width=2000,
112-
type=paddle.data_type.dense_vector(322000))
116+
name="audio_spectrogram", type=paddle.data_type.dense_array(161 * 161))
113117
text_data = paddle.layer.data(
114118
name="transcript_text",
115119
type=paddle.data_type.integer_value_sequence(dict_size))
@@ -136,28 +140,16 @@ def train():
136140
cost=cost, parameters=parameters, update_equation=optimizer)
137141

138142
# prepare data reader
139-
train_batch_reader_sortagrad = data_generator.batch_reader_creator(
140-
manifest_path=args.train_manifest_path,
141-
batch_size=args.batch_size,
142-
padding_to=2000,
143-
flatten=True,
144-
sort_by_duration=True,
145-
shuffle=False)
146-
train_batch_reader_nosortagrad = data_generator.batch_reader_creator(
143+
train_batch_reader = train_generator.batch_reader_creator(
147144
manifest_path=args.train_manifest_path,
148145
batch_size=args.batch_size,
149-
padding_to=2000,
150-
flatten=True,
151-
sort_by_duration=False,
152-
shuffle=True)
153-
test_batch_reader = data_generator.batch_reader_creator(
146+
sortagrad=True if args.init_model_path is None else False,
147+
batch_shuffle=True)
148+
test_batch_reader = test_generator.batch_reader_creator(
154149
manifest_path=args.dev_manifest_path,
155150
batch_size=args.batch_size,
156-
padding_to=2000,
157-
flatten=True,
158-
sort_by_duration=False,
159-
shuffle=False)
160-
feeding = data_generator.data_name_feeding()
151+
batch_shuffle=False)
152+
feeding = train_generator.data_name_feeding()
161153

162154
# create event handler
163155
def event_handler(event):
@@ -183,17 +175,8 @@ def event_handler(event):
183175
time.time() - start_time, event.pass_id, result.cost)
184176

185177
# run train
186-
# first pass with sortagrad
187-
if args.use_sortagrad:
188-
trainer.train(
189-
reader=train_batch_reader_sortagrad,
190-
event_handler=event_handler,
191-
num_passes=1,
192-
feeding=feeding)
193-
args.num_passes -= 1
194-
# other passes without sortagrad
195178
trainer.train(
196-
reader=train_batch_reader_nosortagrad,
179+
reader=train_batch_reader,
197180
event_handler=event_handler,
198181
num_passes=args.num_passes,
199182
feeding=feeding)

0 commit comments

Comments
 (0)