Skip to content

Commit 7ececef

Browse files
bmccannsoumith
authored andcommitted
allowing validation data to volatile
1 parent 1649caf commit 7ececef

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

OpenNMT/onmt/Dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
class Dataset(object):
88

9-
def __init__(self, srcData, tgtData, batchSize, cuda):
9+
def __init__(self, srcData, tgtData, batchSize, cuda, volatile=False):
1010
self.src = srcData
1111
if tgtData:
1212
self.tgt = tgtData
@@ -16,7 +16,8 @@ def __init__(self, srcData, tgtData, batchSize, cuda):
1616
self.cuda = cuda
1717

1818
self.batchSize = batchSize
19-
self.numBatches = (len(self.src) + batchSize - 1) // batchSize
19+
self.numBatches = len(self.src) // batchSize
20+
self.volatile = volatile
2021

2122
def _batchify(self, data, align_right=False):
2223
max_length = max(x.size(0) for x in data)
@@ -30,7 +31,7 @@ def _batchify(self, data, align_right=False):
3031
if self.cuda:
3132
out = out.cuda()
3233

33-
v = Variable(out)
34+
v = Variable(out, volatile=self.volatile)
3435
return v
3536

3637
def __getitem__(self, index):

OpenNMT/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ def main():
269269
trainData = onmt.Dataset(dataset['train']['src'],
270270
dataset['train']['tgt'], opt.batch_size, opt.gpus)
271271
validData = onmt.Dataset(dataset['valid']['src'],
272-
dataset['valid']['tgt'], opt.batch_size, opt.gpus)
272+
dataset['valid']['tgt'], opt.batch_size, opt.gpus,
273+
volatile=True)
273274

274275
dicts = dataset['dicts']
275276
print(' * vocabulary size. source = %d; target = %d' %

0 commit comments

Comments
 (0)