@@ -48,29 +48,32 @@ def buildTargetTokens(self, pred, src, attn):
 
     def translateBatch(self, batch):
         srcBatch, tgtBatch = batch
-        batchSize = srcBatch.size(1)
+        batchSize = srcBatch.size(0)
         beamSize = self.opt.beam_size
 
         # (1) run the encoder on the src
 
         # have to execute the encoder manually to deal with padding
         encStates = None
         context = []
-        for srcBatch_t in srcBatch.chunk(srcBatch.size(0)):
+        for srcBatch_t in srcBatch.chunk(srcBatch.size(1), dim=1):
             encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates)
-            batchPadIdx = srcBatch_t.data.squeeze(0).eq(onmt.Constants.PAD).nonzero()
+            batchPadIdx = srcBatch_t.data.squeeze(1).eq(onmt.Constants.PAD).nonzero()
             if batchPadIdx.nelement() > 0:
                 batchPadIdx = batchPadIdx.squeeze(1)
                 encStates[0].data.index_fill_(1, batchPadIdx, 0)
                 encStates[1].data.index_fill_(1, batchPadIdx, 0)
             context += [context_t]
 
+        encStates = (self.model._fix_enc_hidden(encStates[0]),
+                     self.model._fix_enc_hidden(encStates[1]))
+
         context = torch.cat(context)
         rnnSize = context.size(2)
 
         # This mask is applied to the attention model inside the decoder
         # so that the attention ignores source padding
-        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
+        padMask = srcBatch.data.eq(onmt.Constants.PAD)
         def applyContextMask(m):
             if isinstance(m, onmt.modules.GlobalAttention):
                 m.applyMask(padMask)
@@ -85,8 +88,8 @@ def applyContextMask(m):
         initOutput = self.model.make_init_decoder_output(context)
 
         decOut, decStates, attn = self.model.decoder(
-            tgtBatch[:-1], decStates, context, initOutput)
-        for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
+            tgtBatch[:, :-1], decStates, context, initOutput)
+        for dec_t, tgt_t in zip(decOut.transpose(0, 1), tgtBatch.transpose(0, 1)[1:].data):
             gen_t = self.model.generator.forward(dec_t)
             tgt_t = tgt_t.unsqueeze(1)
             scores = gen_t.data.gather(1, tgt_t)
@@ -104,7 +107,7 @@ def applyContextMask(m):
 
         decOut = self.model.make_init_decoder_output(context)
 
-        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)
+        padMask = srcBatch.data.eq(onmt.Constants.PAD).unsqueeze(0).repeat(beamSize, 1, 1)
 
         batchIdx = list(range(batchSize))
         remainingSents = batchSize
@@ -117,9 +120,9 @@ def applyContextMask(m):
                                  if not b.done]).t().contiguous().view(1, -1)
 
             decOut, decStates, attn = self.model.decoder(
-                Variable(input), decStates, context, decOut)
+                Variable(input).transpose(0, 1), decStates, context, decOut)
             # decOut: 1 x (beam*batch) x numWords
-            decOut = decOut.squeeze(0)
+            decOut = decOut.transpose(0, 1).squeeze(0)
             out = self.model.generator.forward(decOut)
 
             # batch x beam x numWords
@@ -174,7 +177,7 @@ def updateActive(t):
             scores, ks = beam[b].sortBest()
 
             allScores += [scores[:n_best]]
-            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
+            valid_attn = srcBatch.transpose(0, 1).data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
             hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
             attn = [a.index_select(1, valid_attn) for a in attn]
             allHyp += [hyps]
@@ -186,13 +189,14 @@ def translate(self, srcBatch, goldBatch):
         # (1) convert words to indexes
         dataset = self.buildData(srcBatch, goldBatch)
         batch = dataset[0]
+        batch = [x.transpose(0, 1) for x in batch]
 
         # (2) translate
         pred, predScore, attn, goldScore = self.translateBatch(batch)
 
         # (3) convert indexes to words
         predBatch = []
-        for b in range(batch[0].size(1)):
+        for b in range(batch[0].size(0)):
             predBatch.append(
                 [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n])
                  for n in range(self.opt.n_best)]
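
The net effect of the diff is that translateBatch now works on batch-first tensors: the batches coming out of buildData are transposed to (batch, seq_len) in translate(), batchSize is read from dim 0, the attention pad mask drops its .t(), and transpose(0, 1) calls appear wherever the model still expects time-major input. A minimal sketch of that layout (not part of the commit), using an assumed padding index of 0 in place of onmt.Constants.PAD and made-up token ids:

import torch

PAD = 0  # stand-in for onmt.Constants.PAD; the real value comes from the vocabulary

# two source sentences after the transpose in translate(): shape (batch, seq_len)
srcBatch = torch.LongTensor([[4, 5, 6, PAD],
                             [7, 8, PAD, PAD]])

batchSize = srcBatch.size(0)          # 2, read from dim 0 as in the patched code
padMask = srcBatch.eq(PAD)            # already (batch, seq_len), so no .t() needed

# parts of the model still consume time-major input, hence the added
# transpose(0, 1) calls at the decoder boundary
timeMajor = srcBatch.transpose(0, 1)  # (seq_len, batch)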