Skip to content

Commit c90842c

Browse files
bmccann authored and soumith committed
reverting cudnn decoder to lstmcell
1 parent bf82a7b commit c90842c

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

OpenNMT/onmt/Models.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,34 @@ def forward(self, input, hidden=None):
3838
return hidden_t, outputs
3939

4040

41+
class StackedLSTM(nn.Module):
    """Multi-layer LSTM built from ``nn.LSTMCell`` that advances one time step.

    Each layer's hidden output feeds the next layer. Dropout is applied
    between layers only; the top layer's output is returned undropped so that
    it equals the last slice of the returned hidden state and callers can
    apply their own dropout afterwards.

    Args:
        num_layers: number of stacked ``nn.LSTMCell`` layers.
        input_size: feature size of the input to the first layer.
        rnn_size: hidden size of every layer (and the input size of every
            layer after the first).
        dropout: dropout probability used between layers.
    """

    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(StackedLSTM, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(nn.LSTMCell(input_size, rnn_size))
            # Every layer above the first consumes the previous layer's output.
            input_size = rnn_size

    def forward(self, input, hidden):
        """Run a single time step through every layer.

        Args:
            input: input for this step, as accepted by ``nn.LSTMCell``
                (typically ``batch x input_size``).
            hidden: tuple ``(h_0, c_0)``; each is indexed by layer, i.e.
                ``h_0[i]`` / ``c_0[i]`` is the state of layer ``i``.

        Returns:
            ``(output, (h_1, c_1))`` where ``output`` is the top layer's
            hidden state and ``h_1`` / ``c_1`` are the new per-layer states
            stacked along a new leading dimension.
        """
        h_0, c_0 = hidden
        h_1, c_1 = [], []
        for i, layer in enumerate(self.layers):
            h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
            input = h_1_i
            # ``i`` runs from 0 to num_layers - 1, so compare against i + 1;
            # the previous ``i != self.num_layers`` was always true and also
            # dropped out the final layer's output.
            if i + 1 != self.num_layers:
                input = self.dropout(input)
            h_1.append(h_1_i)
            c_1.append(c_1_i)

        h_1 = torch.stack(h_1)
        c_1 = torch.stack(c_1)

        return input, (h_1, c_1)
67+
68+
4169
class Decoder(nn.Module):
4270

4371
def __init__(self, opt, dicts):
@@ -51,9 +79,7 @@ def __init__(self, opt, dicts):
5179
self.word_lut = nn.Embedding(dicts.size(),
5280
opt.word_vec_size,
5381
padding_idx=onmt.Constants.PAD)
54-
self.rnn = nn.LSTM(input_size, opt.rnn_size,
55-
num_layers=opt.layers,
56-
dropout=opt.dropout)
82+
self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout)
5783
self.attn = onmt.modules.GlobalAttention(opt.rnn_size)
5884
self.dropout = nn.Dropout(opt.dropout)
5985

@@ -77,16 +103,16 @@ def forward(self, input, hidden, context, init_output):
77103
outputs = []
78104
output = init_output
79105
for i, emb_t in enumerate(emb.split(1)):
80-
emb_t = emb_t
106+
emb_t = emb_t.squeeze(0)
81107
if self.input_feed:
82-
emb_t = torch.cat([emb_t, output], 2)
108+
emb_t = torch.cat([emb_t, output], 1)
83109

84110
output, hidden = self.rnn(emb_t, hidden)
85-
output, attn = self.attn(output.squeeze(0), context.t())
86-
output = self.dropout(output.unsqueeze(0))
111+
output, attn = self.attn(output, context.t())
112+
output = self.dropout(output)
87113
outputs += [output]
88114

89-
outputs = torch.cat(outputs, 0)
115+
outputs = torch.stack(outputs)
90116
return outputs.transpose(0, 1), hidden, attn
91117

92118

@@ -105,7 +131,7 @@ def set_generate(self, enabled):
105131
def make_init_decoder_output(self, context):
106132
batch_size = context.size(1)
107133
h_size = (batch_size, self.decoder.hidden_size)
108-
return Variable(context.data.new(1, *h_size).zero_(), requires_grad=False)
134+
return Variable(context.data.new(*h_size).zero_(), requires_grad=False)
109135

110136
def _fix_enc_hidden(self, h):
111137
# the encoder hidden is (layers*directions) x batch x dim

0 commit comments

Comments
 (0)