@@ -38,6 +38,34 @@ def forward(self, input, hidden=None):
        return hidden_t, outputs


+class StackedLSTM(nn.Module):
+    def __init__(self, num_layers, input_size, rnn_size, dropout):
+        super(StackedLSTM, self).__init__()
+        self.dropout = nn.Dropout(dropout)
+        self.num_layers = num_layers
+        self.layers = nn.ModuleList()
+
+        for i in range(num_layers):
+            self.layers.append(nn.LSTMCell(input_size, rnn_size))
+            input_size = rnn_size
+
+    def forward(self, input, hidden):
+        h_0, c_0 = hidden  # each: num_layers x batch x rnn_size
+        h_1, c_1 = [], []
+        for i, layer in enumerate(self.layers):
+            h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
+            input = h_1_i
+            if i + 1 != self.num_layers:  # no dropout after the top layer
+                input = self.dropout(input)
+            h_1 += [h_1_i]
+            c_1 += [c_1_i]
+
+        h_1 = torch.stack(h_1)  # num_layers x batch x rnn_size
+        c_1 = torch.stack(c_1)
+
+        return input, (h_1, c_1)
+
+
class Decoder(nn.Module):

    def __init__(self, opt, dicts):
@@ -51,9 +79,7 @@ def __init__(self, opt, dicts):
        self.word_lut = nn.Embedding(dicts.size(),
                                     opt.word_vec_size,
                                     padding_idx=onmt.Constants.PAD)
-        self.rnn = nn.LSTM(input_size, opt.rnn_size,
-                           num_layers=opt.layers,
-                           dropout=opt.dropout)
+        self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout)
        self.attn = onmt.modules.GlobalAttention(opt.rnn_size)
        self.dropout = nn.Dropout(opt.dropout)

@@ -77,16 +103,16 @@ def forward(self, input, hidden, context, init_output):
        outputs = []
        output = init_output
        for i, emb_t in enumerate(emb.split(1)):
-            emb_t = emb_t
+            emb_t = emb_t.squeeze(0)  # 1 x batch x dim -> batch x dim
            if self.input_feed:
-                emb_t = torch.cat([emb_t, output], 2)
+                emb_t = torch.cat([emb_t, output], 1)  # both are 2-D now, so cat on dim 1

            output, hidden = self.rnn(emb_t, hidden)
-            output, attn = self.attn(output.squeeze(0), context.t())
-            output = self.dropout(output.unsqueeze(0))
+            output, attn = self.attn(output, context.t())
+            output = self.dropout(output)
            outputs += [output]

-        outputs = torch.cat(outputs, 0)
+        outputs = torch.stack(outputs)
        return outputs.transpose(0, 1), hidden, attn


@@ -105,7 +131,7 @@ def set_generate(self, enabled):
    def make_init_decoder_output(self, context):
        batch_size = context.size(1)
        h_size = (batch_size, self.decoder.hidden_size)
-        return Variable(context.data.new(1, *h_size).zero_(), requires_grad=False)
+        return Variable(context.data.new(*h_size).zero_(), requires_grad=False)

    def _fix_enc_hidden(self, h):
        # the encoder hidden is (layers*directions) x batch x dim
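
Note: a minimal sketch (not part of the diff) of how the new StackedLSTM is stepped one decoder timestep at a time. It assumes the StackedLSTM class above is in scope; the sizes are made up for illustration.

    # Sketch only: toy sizes, just to show the calling convention and shapes.
    import torch
    from torch.autograd import Variable

    num_layers, batch, input_size, rnn_size = 2, 5, 8, 16
    rnn = StackedLSTM(num_layers, input_size, rnn_size, dropout=0.3)

    # hidden is a pair of num_layers x batch x rnn_size tensors,
    # matching what StackedLSTM.forward returns via torch.stack
    h = Variable(torch.zeros(num_layers, batch, rnn_size))
    c = Variable(torch.zeros(num_layers, batch, rnn_size))

    emb_t = Variable(torch.zeros(batch, input_size))  # one timestep, already squeezed
    output, (h, c) = rnn(emb_t, (h, c))               # output: batch x rnn_size (top layer's h)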