 from torch.autograd import Variable
 import onmt.modules
 
-
-def _makeFeatEmbedder(opt, dicts):
-    return onmt.FeaturesEmbedding(dicts['features'],
-                                  opt.feat_vec_exponent,
-                                  opt.feat_vec_size,
-                                  opt.feat_merge)
-
-
 class Encoder(nn.Module):
 
     def __init__(self, opt, dicts):
@@ -19,40 +11,25 @@ def __init__(self, opt, dicts):
         assert opt.rnn_size % self.num_directions == 0
         self.hidden_size = opt.rnn_size // self.num_directions
         inputSize = opt.word_vec_size
-        feat_lut = None
-        # Sequences with features.
-        if len(dicts['features']) > 0:
-            feat_lut = _makeFeatEmbedder(opt, dicts)
-            inputSize = inputSize + feat_lut.outputSize
 
         super(Encoder, self).__init__()
-        self.word_lut = nn.Embedding(dicts['words'].size(),
+        self.word_lut = nn.Embedding(dicts.size(),
                                      opt.word_vec_size,
                                      padding_idx=onmt.Constants.PAD)
         self.rnn = nn.LSTM(inputSize, self.hidden_size,
                            num_layers=opt.layers,
                            dropout=opt.dropout,
                            bidirectional=opt.brnn)
 
-
         # self.rnn.bias_ih_l0.data.div_(2)
         # self.rnn.bias_hh_l0.data.copy_(self.rnn.bias_ih_l0.data)
 
         if opt.pre_word_vecs_enc is not None:
             pretrained = torch.load(opt.pre_word_vecs_enc)
             self.word_lut.weight.copy_(pretrained)
 
-        self.has_features = feat_lut is not None
-        if self.has_features:
-            self.add_module('feat_lut', feat_lut)
-
     def forward(self, input, hidden=None):
-        if self.has_features:
-            word_emb = self.word_lut(input[0])
-            feat_emb = self.feat_lut(input[1])
-            emb = torch.cat([word_emb, feat_emb], 1)
-        else:
-            emb = self.word_lut(input)
+        emb = self.word_lut(input)
 
         if hidden is None:
             batch_size = emb.size(1)
@@ -70,7 +47,6 @@ def __init__(self, num_layers, input_size, rnn_size, dropout):
         super(StackedLSTM, self).__init__()
         self.dropout = nn.Dropout(dropout)
 
-
         self.layers = []
         for i in range(num_layers):
             layer = nn.LSTMCell(input_size, rnn_size)
@@ -104,14 +80,8 @@ def __init__(self, opt, dicts):
         if self.input_feed:
             input_size += opt.rnn_size
 
-        feat_lut = None
-        # Sequences with features.
-        if len(dicts['features']) > 0:
-            feat_lut = _makeFeatEmbedder(opt, dicts)
-            input_size = input_size + feat_lut.outputSize
-
         super(Decoder, self).__init__()
-        self.word_lut = nn.Embedding(dicts['words'].size(),
+        self.word_lut = nn.Embedding(dicts.size(),
                                      opt.word_vec_size,
                                      padding_idx=onmt.Constants.PAD)
         self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout)
@@ -127,17 +97,9 @@ def __init__(self, opt, dicts):
             pretrained = torch.load(opt.pre_word_vecs_dec)
             self.word_lut.weight.copy_(pretrained)
 
-        self.has_features = feat_lut is not None
-        if self.has_features:
-            self.add_module('feat_lut', feat_lut)
 
     def forward(self, input, hidden, context, init_output):
-        if self.has_features:
-            word_emb = self.word_lut(input[0])
-            feat_emb = self.feat_lut(input[1])
-            emb = torch.cat([word_emb, feat_emb], 1)
-        else:
-            emb = self.word_lut(input)
+        emb = self.word_lut(input)
 
         batch_size = input.size(1)
 
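For context, a minimal usage sketch of the simplified embedding path, not part of the patch above: after this change, Encoder and Decoder receive the vocabulary object itself (anything exposing size()) instead of a {'words': ..., 'features': ...} dict, and forward() embeds the raw tensor of word indices with no feature concatenation. The import path, FakeDict, and the option values below are illustrative assumptions, not taken from the patch.

import torch
from torch.autograd import Variable
from argparse import Namespace

from onmt.Models import Encoder  # assumed location of the Encoder patched above

class FakeDict(object):
    """Hypothetical stand-in for a vocabulary object that exposes size()."""
    def size(self):
        return 100

opt = Namespace(word_vec_size=16, rnn_size=32, layers=1, dropout=0.0,
                brnn=False, pre_word_vecs_enc=None)

encoder = Encoder(opt, FakeDict())               # was Encoder(opt, {'words': vocab, 'features': []})
src = Variable(torch.LongTensor(5, 2).fill_(3))  # (seq_len, batch) word indices
emb = encoder.word_lut(src)                      # (5, 2, 16) word embeddings, no feature concat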