Skip to content

Commit 77e44ab

Browse files
authored
Merge pull request #640 from zhxfl/fix-639
Augmentation should compute frame_dim
2 parents b1c3796 + 000f685 commit 77e44ab

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

fluid/DeepASR/data_utils/data_reader.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,6 @@ class DataReader(object):
121121
corresponding description file.
122122
label_file_list (str): File containing paths of label data file and
123123
corresponding description file.
124-
frame_dim (int): The final feature dimension of one frame after all
125-
augmentation applied.
126124
drop_frame_len (int): Samples whose label length above the value will be
127125
dropped.
128126
process_num (int): Number of processes for processing data.
@@ -137,21 +135,18 @@ class DataReader(object):
137135
random_seed (int): Random seed.
138136
"""
139137

140-
def __init__(
141-
self,
142-
feature_file_list,
143-
label_file_list,
144-
frame_dim=120 * 11, # @TODO augmentor is responsible for the value
145-
drop_frame_len=512,
146-
process_num=10,
147-
sample_buffer_size=1024,
148-
sample_info_buffer_size=1024,
149-
batch_buffer_size=1024,
150-
shuffle_block_num=1,
151-
random_seed=0):
138+
def __init__(self,
139+
feature_file_list,
140+
label_file_list,
141+
drop_frame_len=512,
142+
process_num=10,
143+
sample_buffer_size=1024,
144+
sample_info_buffer_size=1024,
145+
batch_buffer_size=1024,
146+
shuffle_block_num=1,
147+
random_seed=0):
152148
self._feature_file_list = feature_file_list
153149
self._label_file_list = label_file_list
154-
self._frame_dim = frame_dim
155150
self._drop_frame_len = drop_frame_len
156151
self._shuffle_block_num = shuffle_block_num
157152
self._block_info_list = None
@@ -300,8 +295,9 @@ def read_bytes(fpath, start, size):
300295

301296
def batch_iterator(self, batch_size, minimum_batch_size):
302297
def batch_to_ndarray(batch_samples, lod):
303-
batch_feature = np.zeros(
304-
(lod[-1], self._frame_dim), dtype="float32")
298+
assert len(batch_samples)
299+
frame_dim = batch_samples[0][0].shape[1]
300+
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
305301
batch_label = np.zeros((lod[-1], 1), dtype="int64")
306302
start = 0
307303
for sample in batch_samples:

0 commit comments

Comments
 (0)