@@ -87,7 +87,7 @@
                     See README for specific formatting instructions.""")

 # GPU
-parser.add_argument('-gpu', default=[], nargs='+', type=int,
+parser.add_argument('-gpus', default=[], nargs='+', type=int,
                     help="Use CUDA")

 parser.add_argument('-log_interval', type=int, default=50,
@@ -96,15 +96,15 @@
 #                     help="Seed for random initialization")

 opt = parser.parse_args()
-opt.cuda = len(opt.gpu)
+opt.cuda = len(opt.gpus)

 print(opt)

 if torch.cuda.is_available() and not opt.cuda:
     print("WARNING: You have a CUDA device, so you should probably run with -cuda")

 if opt.cuda:
-    cuda.set_device(opt.gpu[0])
+    cuda.set_device(opt.gpus[0])

 def NMTCriterion(vocabSize):
     weight = torch.ones(vocabSize)
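Note on the hunk above: opt.cuda = len(opt.gpus) doubles as both a boolean flag and a device count, so an empty -gpus list means CPU, one entry means a single GPU, and more than one enables DataParallel later in the file. A minimal standalone sketch of that convention (the parse_args input here is illustrative, not from the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-gpus', default=[], nargs='+', type=int,
                    help="Use CUDA on the listed devices")
opt = parser.parse_args(['-gpus', '0', '1'])   # e.g. "-gpus 0 1" on the CLI
opt.cuda = len(opt.gpus)   # 0 -> CPU, 1 -> one GPU, >1 -> DataParallel
assert opt.cuda == 2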
@@ -118,7 +118,7 @@ def NMTCriterion(vocabSize):
 def memoryEfficientLoss(outputs, targets, generator, crit, eval=False):
     # compute generations one piece at a time
     loss = 0
-    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval)
+    outputs = Variable(outputs.data, requires_grad=(not eval), volatile=eval).contiguous()

     batch_size = outputs.size(1)
     outputs_split = torch.split(outputs, opt.max_generator_batches)
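Note on the .contiguous() change above: a likely motivation is that the decoder output arrives through a transpose or similar reshaping, so its storage is non-contiguous and view()-style calls in the chunked loss would otherwise fail. A quick illustration of that failure mode with plain tensors (independent of the old Variable API):

import torch

t = torch.ones(4, 3).t()           # a transpose makes the tensor non-contiguous
assert not t.is_contiguous()
# t.view(-1) here would raise a RuntimeError asking for .contiguous()
flat = t.contiguous().view(-1)     # copying into contiguous storage fixes it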
@@ -257,9 +257,11 @@ def main():
     generator = nn.Sequential(
         nn.Linear(opt.rnn_size, dicts['tgt'].size()),
         nn.LogSoftmax())
-    generator = nn.DataParallel(generator, device_ids=opt.gpu)
+    if opt.cuda > 1:
+        generator = nn.DataParallel(generator, device_ids=opt.gpus)
     model = onmt.Models.NMTModel(encoder, decoder, generator)
-    model = nn.DataParallel(model, device_ids=opt.gpu)
+    if opt.cuda > 1:
+        model = nn.DataParallel(model, device_ids=opt.gpus)
     if opt.cuda:
         model.cuda()
     else:
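Note on the gating above: nn.DataParallel only pays off with more than one device; on a single GPU it just adds scatter/gather overhead (and changes attribute access on the wrapped module), which is why the wrap is now conditional on opt.cuda > 1. A minimal sketch of the same pattern, with a stand-in module and a hypothetical device list:

import torch.nn as nn

model = nn.Linear(8, 8)            # stand-in for the NMT model
gpus = [0]                         # e.g. a single-GPU run
if len(gpus) > 1:                  # mirrors "if opt.cuda > 1" in the diff
    model = nn.DataParallel(model, device_ids=gpus)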