Skip to content

Commit cb6da07

Browse files
committed
add more comments and update train.py
1 parent 7fb1fdd commit cb6da07

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

deep_speech_2/audio_data_utils.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,25 +247,34 @@ def __padding_batch__(self, batch, padding_to=-1, flatten=False):
247247
new_batch.append((padded_audio, text))
248248
return new_batch
249249

250-
def __batch_shuffle__(self, manifest, batch_shuffle_size):
250+
def __batch_shuffle__(self, manifest, batch_size):
251251
"""
252+
The instances have different lengths and they cannot be
253+
combined into a single matrix multiplication. It usually
254+
sorts the training examples by length and combines only
255+
similarly-sized instances into minibatches, pads with
256+
silence when necessary so that all instances in a batch
257+
have the same length. This batch shuffle function is used
258+
to make similarly-sized instances into minibatches and
259+
make a batch-wise shuffle.
260+
252261
1. Sort the audio clips by duration.
253-
2. Generate a random number `k`, k in [0, batch_shuffle_size).
262+
2. Generate a random number `k`, k in [0, batch_size).
254263
3. Randomly remove `k` instances in order to make different mini-batches,
255-
then make minibatches and each minibatch size is batch_shuffle_size.
264+
then make minibatches and each minibatch size is batch_size.
256265
4. Shuffle the minibatches.
257266
258267
:param manifest: manifest file.
259268
:type manifest: list
260-
:param batch_shuffle_size: This size is uesed to generate a random number,
261-
it usually equals to batch size.
262-
:type batch_shuffle_size: int
269+
:param batch_size: Batch size. This size is also used to generate
270+
a random number for batch shuffle.
271+
:type batch_size: int
263272
:return: batch shuffled manifest.
264273
:rtype: list
265274
"""
266275
manifest.sort(key=lambda x: x["duration"])
267-
shift_len = self.__random__.randint(0, batch_shuffle_size - 1)
268-
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_shuffle_size)
276+
shift_len = self.__random__.randint(0, batch_size - 1)
277+
batch_manifest = zip(*[iter(manifest[shift_len:])] * batch_size)
269278
self.__random__.shuffle(batch_manifest)
270279
batch_manifest = list(sum(batch_manifest, ()))
271280
res_len = len(manifest) - shift_len - len(batch_manifest)
@@ -327,8 +336,9 @@ def batch_reader_creator(self,
327336
if set True.
328337
:type sortagrad: bool
329338
:param batch_shuffle: Shuffle the audio clips if set True. It is
330-
not a thorough instance-wise shuffle,
331-
but a specific batch-wise shuffle.
339+
not a thorough instance-wise shuffle, but a
340+
specific batch-wise shuffle. For more details,
341+
please see `__batch_shuffle__` function.
332342
:type batch_shuffle: bool
333343
:return: Batch reader function, producing batches of data when called.
334344
:rtype: callable

deep_speech_2/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,12 @@ def data_generator():
143143
train_batch_reader = train_generator.batch_reader_creator(
144144
manifest_path=args.train_manifest_path,
145145
batch_size=args.batch_size,
146-
sortagrad=True,
147-
shuffle=True)
146+
sortagrad=True if args.init_model_path is None else False,
147+
batch_shuffle=True)
148148
test_batch_reader = test_generator.batch_reader_creator(
149149
manifest_path=args.dev_manifest_path,
150150
batch_size=args.batch_size,
151-
shuffle=False)
151+
batch_shuffle=False)
152152
feeding = train_generator.data_name_feeding()
153153

154154
# create event handler

0 commit comments

Comments
 (0)