Skip to content

Commit 49e5e51

Browse files
author
yangyaming
committed
Add loading model function for train.py.
1 parent 5974ea9 commit 49e5e51

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

deep_speech_2/train.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from model import deep_speech2
1212
from audio_data_utils import DataGenerator
1313
import numpy as np
14+
import os
1415

1516
#TODO: add WER metric
1617

@@ -78,6 +79,11 @@
7879
default='data/eng_vocab.txt',
7980
type=str,
8081
help="Vocabulary filepath. (default: %(default)s)")
82+
parser.add_argument(
83+
"--init_model_path",
84+
default='models/params.tar.gz',
85+
type=str,
86+
help="Model path for initialization. (default: %(default)s)")
8187
args = parser.parse_args()
8288

8389

@@ -114,8 +120,13 @@ def train():
114120
rnn_size=args.rnn_layer_size,
115121
is_inference=False)
116122

117-
# create parameters and optimizer
118-
parameters = paddle.parameters.create(cost)
123+
# create/load parameters and optimizer
124+
if args.init_model_path is None:
125+
parameters = paddle.parameters.create(cost)
126+
else:
127+
assert os.path.isfile(args.init_model_path), "Invalid model."
128+
parameters = paddle.parameters.Parameters.from_tar(
129+
gzip.open(args.init_model_path))
119130
optimizer = paddle.optimizer.Adam(
120131
learning_rate=args.adam_learning_rate, gradient_clipping_threshold=400)
121132
trainer = paddle.trainer.SGD(

0 commit comments

Comments
 (0)