@@ -128,10 +128,11 @@ def __init__(self, ngpu):
         )

     def forward(self, input):
-        gpu_ids = None
-        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu >= 1:
-            gpu_ids = range(self.ngpu)
-        return nn.parallel.data_parallel(self.main, input, gpu_ids)
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+        return output


 netG = _netG(ngpu)
@@ -167,10 +168,11 @@ def __init__(self, ngpu):
         )

     def forward(self, input):
-        gpu_ids = None
-        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu >= 1:
-            gpu_ids = range(self.ngpu)
-        output = nn.parallel.data_parallel(self.main, input, gpu_ids)
+        if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
+            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
+        else:
+            output = self.main(input)
+
         return output.view(-1, 1)
0 commit comments