@@ -48,17 +48,17 @@ def buildTargetTokens(self, pred, src, attn):
 
     def translateBatch(self, batch):
         srcBatch, tgtBatch = batch
-        batchSize = srcBatch.size(0)
+        batchSize = srcBatch.size(1)
         beamSize = self.opt.beam_size
 
         #  (1) run the encoder on the src
 
         # have to execute the encoder manually to deal with padding
         encStates = None
         context = []
-        for srcBatch_t in srcBatch.chunk(srcBatch.size(1), dim=1):
+        for srcBatch_t in srcBatch.split(1):
             encStates, context_t = self.model.encoder(srcBatch_t, hidden=encStates)
-            batchPadIdx = srcBatch_t.data.squeeze(1).eq(onmt.Constants.PAD).nonzero()
+            batchPadIdx = srcBatch_t.data.squeeze(0).eq(onmt.Constants.PAD).nonzero()
             if batchPadIdx.nelement() > 0:
                 batchPadIdx = batchPadIdx.squeeze(1)
                 encStates[0].data.index_fill_(1, batchPadIdx, 0)
@@ -73,7 +73,7 @@ def translateBatch(self, batch):
 
         # This mask is applied to the attention model inside the decoder
         # so that the attention ignores source padding
-        padMask = srcBatch.data.eq(onmt.Constants.PAD)
+        padMask = srcBatch.data.eq(onmt.Constants.PAD).t()
         def applyContextMask(m):
             if isinstance(m, onmt.modules.GlobalAttention):
                 m.applyMask(padMask)
@@ -88,8 +88,8 @@ def applyContextMask(m):
         initOutput = self.model.make_init_decoder_output(context)
 
         decOut, decStates, attn = self.model.decoder(
-            tgtBatch[:, :-1], decStates, context, initOutput)
-        for dec_t, tgt_t in zip(decOut.transpose(0, 1), tgtBatch.transpose(0, 1)[1:].data):
+            tgtBatch[:-1], decStates, context, initOutput)
+        for dec_t, tgt_t in zip(decOut, tgtBatch[1:].data):
             gen_t = self.model.generator.forward(dec_t)
             tgt_t = tgt_t.unsqueeze(1)
             scores = gen_t.data.gather(1, tgt_t)
@@ -107,7 +107,7 @@ def applyContextMask(m):
 
         decOut = self.model.make_init_decoder_output(context)
 
-        padMask = srcBatch.data.eq(onmt.Constants.PAD).unsqueeze(0).repeat(beamSize, 1, 1)
+        padMask = srcBatch.data.eq(onmt.Constants.PAD).t().unsqueeze(0).repeat(beamSize, 1, 1)
 
         batchIdx = list(range(batchSize))
         remainingSents = batchSize
@@ -120,9 +120,9 @@ def applyContextMask(m):
                                if not b.done]).t().contiguous().view(1, -1)
 
             decOut, decStates, attn = self.model.decoder(
-                Variable(input, volatile=True).transpose(0, 1), decStates, context, decOut)
+                Variable(input, volatile=True), decStates, context, decOut)
             # decOut: 1 x (beam*batch) x numWords
-            decOut = decOut.transpose(0, 1).squeeze(0)
+            decOut = decOut.squeeze(0)
             out = self.model.generator.forward(decOut)
 
             # batch x beam x numWords
@@ -177,7 +177,7 @@ def updateActive(t):
             scores, ks = beam[b].sortBest()
 
             allScores += [scores[:n_best]]
-            valid_attn = srcBatch.transpose(0, 1).data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
+            valid_attn = srcBatch.data[:, b].ne(onmt.Constants.PAD).nonzero().squeeze(1)
             hyps, attn = zip(*[beam[b].getHyp(k) for k in ks[:n_best]])
             attn = [a.index_select(1, valid_attn) for a in attn]
             allHyp += [hyps]
@@ -189,14 +189,13 @@ def translate(self, srcBatch, goldBatch):
         # (1) convert words to indexes
         dataset = self.buildData(srcBatch, goldBatch)
         batch = dataset[0]
-        batch = [x.transpose(0, 1) for x in batch]
 
         # (2) translate
         pred, predScore, attn, goldScore = self.translateBatch(batch)
 
         # (3) convert indexes to words
         predBatch = []
-        for b in range(batch[0].size(0)):
+        for b in range(batch[0].size(1)):
             predBatch.append(
                 [self.buildTargetTokens(pred[b][n], srcBatch[b], attn[b][n])
                  for n in range(self.opt.n_best)]
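
Note (not part of the commit): the hunks above all implement one change, moving the translator from batch-first (batch x seq_len) tensors to the time-first (seq_len x batch) layout. That is why the explicit transpose(0, 1) calls disappear, batchSize is now read from size(1), the per-timestep squeeze switches from dim 1 to dim 0, and the attention padMask picks up a .t(). A minimal sketch of that layout assumption, written against current PyTorch with made-up names and sizes rather than anything from the repository:

# Illustrative sketch only; PAD, the sizes, and the tensor names are invented here.
import torch

PAD = 0
seq_len, batch = 7, 4
src = torch.full((seq_len, batch), PAD, dtype=torch.long)  # time-first layout
src[:5] = 1                                                # last two timesteps stay as padding

batchSize = src.size(1)             # batch now lives in dim 1, matching the new size(1)
for src_t in src.split(1):          # one timestep at a time: src_t is 1 x batch
    pad_idx = src_t.squeeze(0).eq(PAD).nonzero()           # squeeze dim 0, not dim 1

padMask = src.eq(PAD).t()           # batch x seq_len, the shape the attention mask expects

In the commit itself the per-timestep pad indices then drive index_fill_ on the encoder hidden state, zeroing the positions that correspond to padding.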