Skip to content

Commit 988ee51

Browse files
bmccannsoumith
authored andcommitted
state_dicts for translation and optimizer
1 parent 4af62f9 commit 988ee51

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

OpenNMT/onmt/Models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def forward(self, input):
149149
self._fix_enc_hidden(enc_hidden[1]))
150150

151151
out, dec_hidden, _attn = self.decoder(tgt, enc_hidden, context, init_output)
152-
if hasattr(self, 'generate') and self.generate:
152+
if hasattr(self, 'generator') and self.generate:
153153
out = self.generator(out)
154154

155155
return out

OpenNMT/onmt/Translator.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import onmt
2+
import torch.nn as nn
23
import torch
34
from torch.autograd import Variable
45

@@ -9,17 +10,34 @@ def __init__(self, opt):
910
self.tt = torch.cuda if opt.cuda else torch
1011

1112
checkpoint = torch.load(opt.model)
12-
self.model = checkpoint['model']
1313

14-
self.model.eval()
14+
model_opt = checkpoint['opt']
15+
self.src_dict = checkpoint['dicts']['src']
16+
self.tgt_dict = checkpoint['dicts']['tgt']
17+
18+
encoder = onmt.Models.Encoder(model_opt, self.src_dict)
19+
decoder = onmt.Models.Decoder(model_opt, self.tgt_dict)
20+
model = onmt.Models.NMTModel(encoder, decoder)
21+
22+
generator = nn.Sequential(
23+
nn.Linear(model_opt.rnn_size, self.tgt_dict.size()),
24+
nn.LogSoftmax())
25+
26+
model.load_state_dict(checkpoint['model'])
27+
generator.load_state_dict(checkpoint['generator'])
1528

1629
if opt.cuda:
17-
self.model.cuda()
30+
model.cuda()
31+
generator.cuda()
1832
else:
19-
self.model.cpu()
33+
model.cpu()
34+
generator.cpu()
35+
36+
model.generator = generator
37+
38+
self.model = model
39+
self.model.eval()
2040

21-
self.src_dict = checkpoint['dicts']['src']
22-
self.tgt_dict = checkpoint['dicts']['tgt']
2341

2442
def buildData(self, srcBatch, goldBatch):
2543
srcData = [self.src_dict.convertToIdx(b,

OpenNMT/train.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@
100100

101101
# GPU
102102
parser.add_argument('-gpus', default=[], nargs='+', type=int,
103-
help="Use CUDA")
103+
help="Use CUDA on the listed devices.")
104104

105105
parser.add_argument('-log_interval', type=int, default=50,
106106
help="Print stats at this interval.")
@@ -255,7 +255,8 @@ def trainEpoch(epoch):
255255
'dicts': dataset['dicts'],
256256
'opt': opt,
257257
'epoch': epoch,
258-
'optim': optim,
258+
'optimizer': optim.optimizer.state_dict(),
259+
'last_ppl': optim.last_ppl,
259260
}
260261
torch.save(checkpoint,
261262
'%s_acc_%.2f_ppl_%.2f_e%d.pt' % (opt.save_model, 100*valid_acc, valid_ppl, epoch))
@@ -299,12 +300,14 @@ def main():
299300
print('Loading model from checkpoint at %s' % opt.train_from)
300301
model.load_state_dict(checkpoint['model'])
301302
generator.load_state_dict(checkpoint['generator'])
302-
optim = checkpoint['optim']
303303
opt.start_epoch = checkpoint['epoch'] + 1
304304

305305
if len(opt.gpus) >= 1:
306306
model.cuda()
307307
generator.cuda()
308+
else:
309+
model.cpu()
310+
generator.cpu()
308311

309312
if len(opt.gpus) > 1:
310313
model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
@@ -316,11 +319,15 @@ def main():
316319
for p in model.parameters():
317320
p.data.uniform_(-opt.param_init, opt.param_init)
318321

319-
optim = onmt.Optim(
320-
model.parameters(), opt.optim, opt.learning_rate, opt.max_grad_norm,
321-
lr_decay=opt.learning_rate_decay,
322-
start_decay_at=opt.start_decay_at
323-
)
322+
optim = onmt.Optim(
323+
model.parameters(), opt.optim, opt.learning_rate, opt.max_grad_norm,
324+
lr_decay=opt.learning_rate_decay,
325+
start_decay_at=opt.start_decay_at
326+
)
327+
328+
if opt.train_from:
329+
optim.last_ppl = checkpoint['last_ppl']
330+
optim.optimizer.load_state_dict(checkpoint['optimizer'])
324331

325332
nParams = sum([p.nelement() for p in model.parameters()])
326333
print('* number of parameters: %d' % nParams)

0 commit comments

Comments
 (0)