Skip to content

Commit 0dee89a

Browse files
bmccann authored and soumith committed
allowing the option of single device
1 parent f485f7b commit 0dee89a

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

OpenNMT/train.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
See README for specific formatting instructions.""")
8888

8989
# GPU
90-
parser.add_argument('-gpu', default=[], nargs='+', type=int,
90+
parser.add_argument('-gpus', default=[], nargs='+', type=int,
9191
help="Use CUDA")
9292

9393
parser.add_argument('-log_interval', type=int, default=50,
@@ -96,15 +96,15 @@
9696
# help="Seed for random initialization")
9797

9898
opt = parser.parse_args()
99-
opt.cuda = len(opt.gpu)
99+
opt.cuda = len(opt.gpus)
100100

101101
print(opt)
102102

103103
if torch.cuda.is_available() and not opt.cuda:
104104
print("WARNING: You have a CUDA device, so you should probably run with -cuda")
105105

106106
if opt.cuda:
107-
cuda.set_device(opt.gpu[0])
107+
cuda.set_device(opt.gpus[0])
108108

109109
def NMTCriterion(vocabSize):
110110
weight = torch.ones(vocabSize)
@@ -118,7 +118,7 @@ def NMTCriterion(vocabSize):
118118
def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
119119
# compute generations one piece at a time
120120
loss = 0
121-
outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)
121+
outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval).contiguous()
122122

123123
batch_size = outputs.size(1)
124124
outputs_split = torch.split(outputs, opt.max_generator_batches)
@@ -257,9 +257,11 @@ def main():
257257
generator = nn.Sequential(
258258
nn.Linear(opt.rnn_size, dicts['tgt'].size()),
259259
nn.LogSoftmax())
260-
generator = nn.DataParallel(generator, device_ids=opt.gpu)
260+
if opt.cuda > 1:
261+
generator = nn.DataParallel(generator, device_ids=opt.gpus)
261262
model = onmt.Models.NMTModel(encoder, decoder, generator)
262-
model = nn.DataParallel(model, device_ids=opt.gpu)
263+
if opt.cuda > 1:
264+
model = nn.DataParallel(model, device_ids=opt.gpus)
263265
if opt.cuda:
264266
model.cuda()
265267
else:

0 commit comments

Comments
 (0)