Skip to content

Commit 2738ca1

Browse files
author
Yibing Liu
authored
Merge pull request #636 from kuke/refactor_model
Refactor model conf: add profiling, parallel running, model saving etc
2 parents 6759125 + d6d819c commit 2738ca1

File tree

7 files changed

+547
-229
lines changed

7 files changed

+547
-229
lines changed

fluid/DeepASR/data_utils/data_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ def ordered_feeding_task(sample_info_queue):
225225
@suppress_complaints(verbose=self._verbose)
226226
def ordered_processing_task(sample_info_queue, sample_queue, out_order):
227227
if self._verbose == 0:
228-
signal.signal(signal.SIGTERM, suppress_signal())
229-
signal.signal(signal.SIGINT, suppress_signal())
228+
signal.signal(signal.SIGTERM, suppress_signal)
229+
signal.signal(signal.SIGINT, suppress_signal)
230230

231231
def read_bytes(fpath, start, size):
232232
f = open(fpath, 'r')

fluid/DeepASR/data_utils/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from six import reraise
66
from tblib import Traceback
77

8+
import numpy as np
9+
810

911
def to_lodtensor(data, place):
1012
"""convert tensor to lodtensor

fluid/DeepASR/model_utils/model.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
5+
import paddle.v2 as paddle
6+
import paddle.v2.fluid as fluid
7+
8+
9+
def stacked_lstmp_model(hidden_dim,
10+
proj_dim,
11+
stacked_num,
12+
class_num,
13+
parallel=False,
14+
is_train=True):
15+
""" The model for DeepASR. The main structure is composed of stacked
16+
identical LSTMP (LSTM with recurrent projection) layers.
17+
18+
When running in training and validation phase, the feeding dictionary
19+
is {'feature', 'label'}, fed by the LodTensor for feature data and
20+
label data respectively. And in inference, only `feature` is needed.
21+
22+
Args:
23+
hidden_dim(int): The hidden state's dimension of the LSTMP layer.
24+
proj_dim(int): The projection size of the LSTMP layer.
25+
stacked_num(int): The number of stacked LSTMP layers.
26+
parallel(bool): Run in parallel or not, default `False`.
27+
is_train(bool): Run in training phase or not, default `True`.
28+
class_dim(int): The number of output classes.
29+
"""
30+
31+
# network configuration
32+
def _net_conf(feature, label):
33+
seq_conv1 = fluid.layers.sequence_conv(
34+
input=feature,
35+
num_filters=1024,
36+
filter_size=3,
37+
filter_stride=1,
38+
bias_attr=True)
39+
bn1 = fluid.layers.batch_norm(
40+
input=seq_conv1,
41+
act="sigmoid",
42+
is_test=not is_train,
43+
momentum=0.9,
44+
epsilon=1e-05,
45+
data_layout='NCHW')
46+
47+
stack_input = bn1
48+
for i in range(stacked_num):
49+
fc = fluid.layers.fc(input=stack_input,
50+
size=hidden_dim * 4,
51+
bias_attr=True)
52+
proj, cell = fluid.layers.dynamic_lstmp(
53+
input=fc,
54+
size=hidden_dim * 4,
55+
proj_size=proj_dim,
56+
bias_attr=True,
57+
use_peepholes=True,
58+
is_reverse=False,
59+
cell_activation="tanh",
60+
proj_activation="tanh")
61+
bn = fluid.layers.batch_norm(
62+
input=proj,
63+
act="sigmoid",
64+
is_test=not is_train,
65+
momentum=0.9,
66+
epsilon=1e-05,
67+
data_layout='NCHW')
68+
stack_input = bn
69+
70+
prediction = fluid.layers.fc(input=stack_input,
71+
size=class_num,
72+
act='softmax')
73+
74+
cost = fluid.layers.cross_entropy(input=prediction, label=label)
75+
avg_cost = fluid.layers.mean(x=cost)
76+
acc = fluid.layers.accuracy(input=prediction, label=label)
77+
return prediction, avg_cost, acc
78+
79+
# data feeder
80+
feature = fluid.layers.data(
81+
name="feature", shape=[-1, 120 * 11], dtype="float32", lod_level=1)
82+
label = fluid.layers.data(
83+
name="label", shape=[-1, 1], dtype="int64", lod_level=1)
84+
85+
if parallel:
86+
# When the execution place is specified to CUDAPlace, the program will
87+
# run on all $CUDA_VISIBLE_DEVICES GPUs. Otherwise the program will
88+
# run on all CPU devices.
89+
places = fluid.layers.get_places()
90+
pd = fluid.layers.ParallelDo(places)
91+
with pd.do():
92+
feat_ = pd.read_input(feature)
93+
label_ = pd.read_input(label)
94+
prediction, avg_cost, acc = _net_conf(feat_, label_)
95+
for out in [avg_cost, acc]:
96+
pd.write_output(out)
97+
98+
# get mean loss and acc through every devices.
99+
avg_cost, acc = pd()
100+
avg_cost = fluid.layers.mean(x=avg_cost)
101+
acc = fluid.layers.mean(x=acc)
102+
else:
103+
prediction, avg_cost, acc = _net_conf(feature, label)
104+
105+
return prediction, avg_cost, acc

fluid/DeepASR/stacked_dynamic_lstm.py

Lines changed: 0 additions & 227 deletions
This file was deleted.

fluid/DeepASR/tools/_init_paths.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""Add the parent directory to $PYTHONPATH"""
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
6+
import os.path
7+
import sys
8+
9+
10+
def add_path(path):
11+
if path not in sys.path:
12+
sys.path.insert(0, path)
13+
14+
15+
this_dir = os.path.dirname(__file__)
16+
17+
# Add project path to PYTHONPATH
18+
proj_path = os.path.join(this_dir, '..')
19+
add_path(proj_path)

0 commit comments

Comments
 (0)