Merged
40 changes: 3 additions & 37 deletions data_utils/data.py
@@ -60,9 +60,6 @@ class DataGenerator(object):
                                      be passed forward directly without
                                      converting to index sequence.
     :type keep_transcription_text: bool
-    :param num_conv_layers: The number of convolution layer, used to compute
-                            the sequence length.
-    :type num_conv_layers: int
     """
 
     def __init__(self,
@@ -78,8 +75,7 @@ def __init__(self,
                  use_dB_normalization=True,
                  num_threads=multiprocessing.cpu_count() // 2,
                  random_seed=0,
-                 keep_transcription_text=False,
-                 num_conv_layers=2):
+                 keep_transcription_text=False):
         self._max_duration = max_duration
         self._min_duration = min_duration
         self._normalizer = FeatureNormalizer(mean_std_filepath)
@@ -100,7 +96,6 @@ def __init__(self,
         self._local_data = local()
         self._local_data.tar2info = {}
         self._local_data.tar2object = {}
-        self._num_conv_layers = num_conv_layers
 
     def process_utterance(self, filename, transcript):
         """Load, augment, featurize and normalize for speech data.
@@ -219,14 +214,7 @@ def feeding(self):
         :return: Data feeding dict.
         :rtype: dict
         """
-        feeding_dict = {
-            "audio_spectrogram": 0,
-            "transcript_text": 1,
-            "sequence_offset": 2,
-            "sequence_length": 3
-        }
-        for i in xrange(self._num_conv_layers):
-            feeding_dict["conv%d_index_range" % i] = len(feeding_dict)
+        feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}
         return feeding_dict
 
     @property
@@ -322,29 +310,7 @@ def _padding_batch(self, batch, padding_to=-1, flatten=False):
             padded_audio[:, :audio.shape[1]] = audio
             if flatten:
                 padded_audio = padded_audio.flatten()
-
-            # Stride size for conv0 is (3, 2)
-            # Stride size for conv1 to convN is (1, 2)
-            # Same as the network, hard-coded here
-            padded_instance = [padded_audio, text]
-            padded_conv0_h = (padded_audio.shape[0] - 1) // 2 + 1
-            padded_conv0_w = (padded_audio.shape[1] - 1) // 3 + 1
-            valid_w = (audio.shape[1] - 1) // 3 + 1
-            padded_instance += [
-                [0],  # sequence offset, always 0
-                [valid_w],  # valid sequence length
-                # Index ranges for channel, height and width
-                # Please refer scale_sub_region layer to see details
-                [1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
-            ]
-            pre_padded_h = padded_conv0_h
-            for i in xrange(self._num_conv_layers - 1):
-                padded_h = (pre_padded_h - 1) // 2 + 1
-                pre_padded_h = padded_h
-                padded_instance += [
-                    [1, 32, 1, padded_h, valid_w + 1, padded_conv0_w]
-                ]
-
+            padded_instance = [padded_audio, text, audio.shape[1]]
             new_batch.append(padded_instance)
         return new_batch
 
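With the conv-layer bookkeeping gone, a padded instance shrinks to three fields and the feeding dict to two. A minimal sketch of the new shapes (not part of the diff; all values made up):

import numpy as np

padded_audio = np.zeros((161, 600))   # (freq bins, padded time steps)
padded_audio[:, :427] = 1.0           # first 427 frames hold real data
text = [23, 5, 17]                    # token ids (raw text when
                                      # keep_transcription_text=True)

instance = [padded_audio, text, 427]  # what _padding_batch now appends:
                                      # spectrogram, text, unpadded width
feeding = {"audio_spectrogram": 0, "transcript_text": 1}  # new feeding()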
16 changes: 2 additions & 14 deletions deploy/demo_server.py
@@ -147,8 +147,7 @@ def start_server():
         augmentation_config='{}',
         specgram_type=args.specgram_type,
         num_threads=1,
-        keep_transcription_text=True,
-        num_conv_layers=args.num_conv_layers)
+        keep_transcription_text=True)
     # prepare ASR model
     ds2_model = DeepSpeech2Model(
         vocab_size=data_generator.vocab_size,
@@ -164,20 +163,9 @@ def start_server():
     # prepare ASR inference handler
     def file_to_transcript(filename):
         feature = data_generator.process_utterance(filename, "")
-        ins = []
-        conv0_h = (feature[0].shape[0] - 1) // 2 + 1
-        conv0_w = (feature[0].shape[1] - 1) // 3 + 1
-        ins += [feature[0], feature[1],
-                [0], [conv0_w],
-                [1, 32, 1, conv0_h, conv0_w + 1, conv0_w]]
-        pre_h = conv0_h
-        for i in xrange(args.num_conv_layers - 1):
-            h = (pre_h - 1) // 2 + 1
-            pre_h = h
-            ins += [[1, 32, 1, h, conv0_w + 1, conv0_w]]
 
         result_transcript = ds2_model.infer_batch(
-            infer_data=[ins],
+            infer_data=[feature],
             decoding_method=args.decoding_method,
             beam_alpha=args.alpha,
             beam_beta=args.beta,
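For reference, the shape bookkeeping the deleted handler code did by hand, and which DeepSpeech2Model._adapt_data now owns (see model_utils/model.py below), is plain stride arithmetic. A worked example with made-up spectrogram dimensions:

# conv0 strides are (3, 2) over (time, freq), hard-coded to match the network
freq_bins, time_steps = 161, 600      # feature[0].shape, hypothetical

conv0_h = (freq_bins - 1) // 2 + 1    # height after stride-2 conv -> 81
conv0_w = (time_steps - 1) // 3 + 1   # width after stride-3 conv -> 200

print(conv0_h, conv0_w)               # 81 200
# _adapt_data performs exactly this computation per instance, so the
# handler can pass the raw (spectrogram, transcript) feature straight in.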
3 changes: 1 addition & 2 deletions infer.py
@@ -69,8 +69,7 @@ def infer():
         augmentation_config='{}',
         specgram_type=args.specgram_type,
         num_threads=1,
-        keep_transcription_text=True,
-        num_conv_layers=args.num_conv_layers)
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.infer_manifest,
         batch_size=args.num_samples,
122 changes: 114 additions & 8 deletions model_utils/model.py
@@ -8,6 +8,8 @@
 import time
 import logging
 import gzip
+import copy
+import inspect
 from distutils.dir_util import mkpath
 import paddle.v2 as paddle
 from decoders.swig_wrapper import Scorer
@@ -48,6 +50,7 @@ def __init__(self, vocab_size, num_conv_layers, num_rnn_layers,
         self._inferer = None
         self._loss_inferer = None
         self._ext_scorer = None
+        self._num_conv_layers = num_conv_layers
         self.logger = logging.getLogger("")
         self.logger.setLevel(level=logging.INFO)
 
@@ -91,6 +94,11 @@ def train(self,
         if not os.path.exists(output_model_dir):
             mkpath(output_model_dir)
 
+        # adapt the feeding dict and readers according to the network
+        adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
+        adapted_train_batch_reader = self._adapt_data(train_batch_reader)
+        adapted_dev_batch_reader = self._adapt_data(dev_batch_reader)
+
         # prepare optimizer and trainer
         optimizer = paddle.optimizer.Adam(
             learning_rate=learning_rate,
@@ -128,7 +136,8 @@ def event_handler(event):
                       (time.time() - start_time, event.pass_id))
             else:
                 result = trainer.test(
-                    reader=dev_batch_reader, feeding=feeding_dict)
+                    reader=adapted_dev_batch_reader,
+                    feeding=adapted_feeding_dict)
                 print(
                     "\n------- Time: %d sec, Pass: %d, "
                     "ValidationCost: %s" %
@@ -140,11 +149,12 @@ def event_handler(event):
 
         # run train
         trainer.train(
-            reader=train_batch_reader,
+            reader=adapted_train_batch_reader,
             event_handler=event_handler,
             num_passes=num_passes,
-            feeding=feeding_dict)
+            feeding=adapted_feeding_dict)
 
+    # TODO(@pkuyym) merge this function into infer_batch
     def infer_loss_batch(self, infer_data):
         """Model inference. Infer the ctc loss for a batch of speech
         utterances.
@@ -205,15 +215,17 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
         if self._inferer == None:
             self._inferer = paddle.inference.Inference(
                 output_layer=self._log_probs, parameters=self._parameters)
+        adapted_feeding_dict = self._adapt_feeding_dict(feeding_dict)
+        adapted_infer_data = self._adapt_data(infer_data)
         # run inference
         infer_results = self._inferer.infer(
-            input=infer_data, feeding=feeding_dict)
-        start_pos = [0] * (len(infer_data) + 1)
-        for i in xrange(len(infer_data)):
-            start_pos[i + 1] = start_pos[i] + infer_data[i][3][0]
+            input=adapted_infer_data, feeding=adapted_feeding_dict)
+        start_pos = [0] * (len(adapted_infer_data) + 1)
+        for i in xrange(len(adapted_infer_data)):
+            start_pos[i + 1] = start_pos[i] + adapted_infer_data[i][3][0]
         probs_split = [
             infer_results[start_pos[i]:start_pos[i + 1]]
-            for i in xrange(0, len(infer_data))
+            for i in xrange(0, len(adapted_infer_data))
         ]
         # run decoder
         results = []
@@ -260,6 +272,100 @@ def infer_batch(self, infer_data, decoding_method, beam_alpha, beam_beta,
                             decoding_method)
         return results
 
+    def _adapt_feeding_dict(self, feeding_dict):
+        """Adapt feeding dict according to the network structure.
+
+        To remove the impact of padding, we add a scale_sub_region layer and
+        a sub_seq layer to the network. For the sub_seq layer,
+        'sequence_offset' and 'sequence_length' fields are appended. For each
+        scale_sub_region layer, a 'convN_index_range' field is appended.
+
+        :param feeding_dict: Feeding is a map of field name and tuple index
+                             of the data that the reader returns.
+        :type feeding_dict: dict|list
+        :return: Adapted feeding dict.
+        :rtype: dict|list
+        """
+        adapted_feeding_dict = copy.deepcopy(feeding_dict)
+        if isinstance(feeding_dict, dict):
+            adapted_feeding_dict["sequence_offset"] = len(adapted_feeding_dict)
+            adapted_feeding_dict["sequence_length"] = len(adapted_feeding_dict)
+            for i in xrange(self._num_conv_layers):
+                adapted_feeding_dict["conv%d_index_range" % i] = \
+                    len(adapted_feeding_dict)
+        elif isinstance(feeding_dict, list):
+            adapted_feeding_dict.append("sequence_offset")
+            adapted_feeding_dict.append("sequence_length")
+            for i in xrange(self._num_conv_layers):
+                adapted_feeding_dict.append("conv%d_index_range" % i)
+        else:
+            raise ValueError("Type of feeding_dict is %s, not supported." %
+                             type(feeding_dict))
+
+        return adapted_feeding_dict
+
+    def _adapt_data(self, data):
+        """Adapt data according to the network structure.
+
+        For each convolution layer in the conv_group, to remove the impact of
+        padding data, we multiply the padding part of the outputs of each
+        batch normalization layer by zero. A scale_sub_region layer is added
+        after each batch normalization layer to reset the padding data.
+        For the RNN layers, to remove the impact of padding data, we truncate
+        the padding part before the output data is fed into the first RNN
+        layer, using a sub_seq layer.
+
+        :param data: Data from data_provider.
+        :type data: list|function
+        :return: Adapted data.
+        :rtype: list|function
+        """
+
+        def adapt_instance(instance):
+            if len(instance) < 2 or len(instance) > 3:
+                raise ValueError("Size of instance should be 2 or 3.")
+            padded_audio = instance[0]
+            text = instance[1]
+            # no padding part
+            if len(instance) == 2:
+                audio_len = padded_audio.shape[1]
+            else:
+                audio_len = instance[2]
+            adapted_instance = [padded_audio, text]
+            # Stride size for conv0 is (3, 2)
+            # Stride size for conv1 to convN is (1, 2)
+            # Same as the network, hard-coded here
+            padded_conv0_h = (padded_audio.shape[0] - 1) // 2 + 1
+            padded_conv0_w = (padded_audio.shape[1] - 1) // 3 + 1
+            valid_w = (audio_len - 1) // 3 + 1
+            adapted_instance += [
+                [0],  # sequence offset, always 0
+                [valid_w],  # valid sequence length
+                # Index ranges for channel, height and width
+                # Please refer to the scale_sub_region layer for details
+                [1, 32, 1, padded_conv0_h, valid_w + 1, padded_conv0_w]
+            ]
+            pre_padded_h = padded_conv0_h
+            for i in xrange(self._num_conv_layers - 1):
+                padded_h = (pre_padded_h - 1) // 2 + 1
+                pre_padded_h = padded_h
+                adapted_instance += [
+                    [1, 32, 1, padded_h, valid_w + 1, padded_conv0_w]
+                ]
+            return adapted_instance
+
+        if isinstance(data, list):
+            return map(adapt_instance, data)
+        elif inspect.isgeneratorfunction(data):
+
+            def adapted_reader():
+                for instance in data():
+                    yield map(adapt_instance, instance)
+
+            return adapted_reader
+        else:
+            raise ValueError("Type of data is %s, not supported." % type(data))
+
     def _create_parameters(self, model_path=None):
         """Load or create model parameters."""
         if model_path is None:
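To make the adaptation concrete, here is a pure-Python rehearsal (not an import from the repo) of what _adapt_feeding_dict returns for the dict case with num_conv_layers=2; the resulting indices match the extra fields _adapt_data appends to each instance:

feeding_dict = {"audio_spectrogram": 0, "transcript_text": 1}

adapted = dict(feeding_dict)
adapted["sequence_offset"] = len(adapted)   # index 2, consumed by sub_seq
adapted["sequence_length"] = len(adapted)   # index 3, consumed by sub_seq
for i in range(2):                          # one per scale_sub_region layer
    adapted["conv%d_index_range" % i] = len(adapted)

assert adapted["conv0_index_range"] == 4
assert adapted["conv1_index_range"] == 5

Note that infer_batch relies on this ordering: adapted_infer_data[i][3][0] in the diff above reads field 3, the valid sequence length, to locate each utterance's probabilities.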
3 changes: 1 addition & 2 deletions test.py
@@ -70,8 +70,7 @@ def evaluate():
         augmentation_config='{}',
         specgram_type=args.specgram_type,
         num_threads=args.num_proc_data,
-        keep_transcription_text=True,
-        num_conv_layers=args.num_conv_layers)
+        keep_transcription_text=True)
     batch_reader = data_generator.batch_reader_creator(
         manifest_path=args.test_manifest,
         batch_size=args.batch_size,
6 changes: 2 additions & 4 deletions train.py
@@ -75,15 +75,13 @@ def train():
         max_duration=args.max_duration,
         min_duration=args.min_duration,
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data,
-        num_conv_layers=args.num_conv_layers)
+        num_threads=args.num_proc_data)
     dev_generator = DataGenerator(
         vocab_filepath=args.vocab_path,
         mean_std_filepath=args.mean_std_path,
         augmentation_config="{}",
         specgram_type=args.specgram_type,
-        num_threads=args.num_proc_data,
-        num_conv_layers=args.num_conv_layers)
+        num_threads=args.num_proc_data)
     train_batch_reader = train_generator.batch_reader_creator(
         manifest_path=args.train_manifest,
         batch_size=args.batch_size,
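Since batch_reader_creator() returns generator functions, the train/dev readers built here go through the inspect.isgeneratorfunction branch of _adapt_data. A stubbed, runnable sketch of that wrapping (adapt_instance below only fakes the metadata; the real method appends sequence offset/length and the conv index ranges):

import inspect

def adapt_instance(instance):
    # stub for DeepSpeech2Model's per-instance adaptation
    return list(instance) + [[0], [instance[2]]]

def batch_reader():  # stands in for batch_reader_creator()'s result
    yield [["spec-a", "text-a", 143], ["spec-b", "text-b", 101]]

assert inspect.isgeneratorfunction(batch_reader)

def adapted_reader():
    for batch in batch_reader():
        yield [adapt_instance(ins) for ins in batch]

for batch in adapted_reader():
    print(batch)  # each instance now carries the extra fields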