
Commit b95b4c3

adamlerer authored and soumith committed
Remove language features leftovers
1 parent d88d377 commit b95b4c3

7 files changed: +60 -241 lines

OpenNMT/onmt/Dataset.py

Lines changed: 4 additions & 6 deletions
@@ -3,18 +3,16 @@
 class Dataset(object):
-    # FIXME: randomize
+
     def __init__(self, srcData, tgtData, batchSize, cuda):
-        self.src = srcData['words']
+        self.src = srcData
         if tgtData:
-            self.tgt = tgtData['words']
+            self.tgt = tgtData
             assert(len(self.src) == len(self.tgt))
         else:
             self.tgt = None
         self.cuda = cuda
-        # FIXME
-        # self.srcFeatures = srcData.features
-        # self.tgtFeatures = tgtData.features
+
         self.batchSize = batchSize
         self.numBatches = len(self.src) // batchSize
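The practical effect for callers: onmt.Dataset now receives the source and target index lists directly instead of {'words': ...} wrappers. A minimal sketch of the new call site, using made-up example data (the indices and batch size below are placeholders, not values from the repo):

import torch
import onmt

# Hypothetical sentences, already converted to word indices.
srcData = [torch.LongTensor([4, 17, 32]), torch.LongTensor([4, 9, 2, 11])]
tgtData = [torch.LongTensor([2, 5, 21, 3]), torch.LongTensor([2, 8, 3])]

# Before this commit: onmt.Dataset({'words': srcData}, {'words': tgtData}, 64, False)
# After it, the lists are passed straight through.
dataset = onmt.Dataset(srcData, tgtData, 64, False)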

OpenNMT/onmt/Models.py

Lines changed: 4 additions & 42 deletions
@@ -3,14 +3,6 @@
 from torch.autograd import Variable
 import onmt.modules
 
-
-def _makeFeatEmbedder(opt, dicts):
-    return onmt.FeaturesEmbedding(dicts['features'],
-                                  opt.feat_vec_exponent,
-                                  opt.feat_vec_size,
-                                  opt.feat_merge)
-
-
 class Encoder(nn.Module):
 
     def __init__(self, opt, dicts):
@@ -19,40 +11,25 @@ def __init__(self, opt, dicts):
         assert opt.rnn_size % self.num_directions == 0
         self.hidden_size = opt.rnn_size // self.num_directions
         inputSize = opt.word_vec_size
-        feat_lut = None
-        # Sequences with features.
-        if len(dicts['features']) > 0:
-            feat_lut = _makeFeatEmbedder(opt, dicts)
-            inputSize = inputSize + feat_lut.outputSize
 
         super(Encoder, self).__init__()
-        self.word_lut = nn.Embedding(dicts['words'].size(),
+        self.word_lut = nn.Embedding(dicts.size(),
                                      opt.word_vec_size,
                                      padding_idx=onmt.Constants.PAD)
         self.rnn = nn.LSTM(inputSize, self.hidden_size,
                            num_layers=opt.layers,
                            dropout=opt.dropout,
                            bidirectional=opt.brnn)
 
-
         # self.rnn.bias_ih_l0.data.div_(2)
         # self.rnn.bias_hh_l0.data.copy_(self.rnn.bias_ih_l0.data)
 
         if opt.pre_word_vecs_enc is not None:
             pretrained = torch.load(opt.pre_word_vecs_enc)
             self.word_lut.weight.copy_(pretrained)
 
-        self.has_features = feat_lut is not None
-        if self.has_features:
-            self.add_module('feat_lut', feat_lut)
-
     def forward(self, input, hidden=None):
-        if self.has_features:
-            word_emb = self.word_lut(input[0])
-            feat_emb = self.feat_lut(input[1])
-            emb = torch.cat([word_emb, feat_emb], 1)
-        else:
-            emb = self.word_lut(input)
+        emb = self.word_lut(input)
 
         if hidden is None:
             batch_size = emb.size(1)
@@ -70,7 +47,6 @@ def __init__(self, num_layers, input_size, rnn_size, dropout):
         super(StackedLSTM, self).__init__()
         self.dropout = nn.Dropout(dropout)
 
-
         self.layers = []
         for i in range(num_layers):
             layer = nn.LSTMCell(input_size, rnn_size)
@@ -104,14 +80,8 @@ def __init__(self, opt, dicts):
         if self.input_feed:
             input_size += opt.rnn_size
 
-        feat_lut = None
-        # Sequences with features.
-        if len(dicts['features']) > 0:
-            feat_lut = _makeFeatEmbedder(opt, dicts)
-            input_size = input_size + feat_lut.outputSize
-
         super(Decoder, self).__init__()
-        self.word_lut = nn.Embedding(dicts['words'].size(),
+        self.word_lut = nn.Embedding(dicts.size(),
                                      opt.word_vec_size,
                                      padding_idx=onmt.Constants.PAD)
         self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout)
@@ -127,17 +97,9 @@ def __init__(self, opt, dicts):
             pretrained = torch.load(opt.pre_word_vecs_dec)
             self.word_lut.weight.copy_(pretrained)
 
-        self.has_features = feat_lut is not None
-        if self.has_features:
-            self.add_module('feat_lut', feat_lut)
 
     def forward(self, input, hidden, context, init_output):
-        if self.has_features:
-            word_emb = self.word_lut(input[0])
-            feat_emb = self.feat_lut(input[1])
-            emb = torch.cat([word_emb, feat_emb], 1)
-        else:
-            emb = self.word_lut(input)
+        emb = self.word_lut(input)
 
         batch_size = input.size(1)
 
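What remains of the embedding path after removing the feature embedder is a single lookup built from one vocabulary object (dicts.size()). A self-contained sketch of that behaviour; the PAD value, vocabulary size and dimensions are placeholders rather than the repo's constants:

import torch
import torch.nn as nn

PAD = 0                                  # stand-in for onmt.Constants.PAD
vocab_size, word_vec_size = 100, 16      # placeholder sizes

# One embedding table per side, no torch.cat([word_emb, feat_emb], 1) branch.
word_lut = nn.Embedding(vocab_size, word_vec_size, padding_idx=PAD)

words = torch.LongTensor([[4, 7], [9, 2], [PAD, 5]])   # (seq_len, batch)
emb = word_lut(words)
print(emb.size())                        # torch.Size([3, 2, 16])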

OpenNMT/onmt/Translator.py

Lines changed: 3 additions & 10 deletions
@@ -18,11 +18,8 @@ def __init__(self, opt):
         else:
             self.model.cpu()
 
-        self.src_dict = checkpoint['dicts']['src']['words']
-        self.tgt_dict = checkpoint['dicts']['tgt']['words']
-
-        # if opt.phrase_table.len() > 0:
-        #     phraseTable = onmt.translate.PhraseTable.new(opt.phrase_table)
+        self.src_dict = checkpoint['dicts']['src']
+        self.tgt_dict = checkpoint['dicts']['tgt']
 
     def buildData(self, srcBatch, goldBatch):
         srcData = [self.src_dict.convertToIdx(b,
@@ -34,9 +31,7 @@ def buildData(self, srcBatch, goldBatch):
                                               onmt.Constants.BOS_WORD,
                                               onmt.Constants.EOS_WORD) for b in goldBatch]
 
-        return onmt.Dataset(
-            {'words': srcData},
-            {'words': tgtData} if tgtData else None,
+        return onmt.Dataset(srcData, tgtData,
                             self.opt.batch_size, self.opt.cuda)
 
     def buildTargetTokens(self, pred, src, attn):
@@ -53,7 +48,6 @@ def buildTargetTokens(self, pred, src, attn):
 
     def translateBatch(self, batch):
         srcBatch, tgtBatch = batch
-        sourceLength = srcBatch.size(0)
         batchSize = srcBatch.size(1)
         beamSize = self.opt.beam_size
 
@@ -179,7 +173,6 @@ def updateActive(t):
 
     def translate(self, srcBatch, goldBatch):
         dataset = self.buildData(srcBatch, goldBatch)
-        assert(len(dataset) == 1) # FIXME
         batch = dataset[0]
 
         pred, predScore, attn, goldScore = self.translateBatch(batch)
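Taken together with the Models.py change, the dictionary layout the code now expects looks roughly like this; a hypothetical illustration in which src_vocab and tgt_vocab stand in for the repo's dictionary objects:

# Illustration only, not actual repo code.
src_vocab, tgt_vocab = object(), object()

# Old layout: each side wrapped the word vocabulary (plus feature vocabularies).
dicts_before = {'src': {'words': src_vocab, 'features': []},
                'tgt': {'words': tgt_vocab, 'features': []}}

# New layout: checkpoint['dicts']['src'] / ['tgt'] are the word vocabularies
# themselves, which is what Translator.__init__ and buildData now assume.
dicts_after = {'src': src_vocab, 'tgt': tgt_vocab}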

OpenNMT/onmt/modules/GlobalAttention.py

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@
 import torch.nn as nn
 import math
 
+
 class GlobalAttention(nn.Module):
     def __init__(self, dim):
         super(GlobalAttention, self).__init__()
