Skip to content

Commit ac5b745

Browse files
committed
fix dcgan
1 parent f89a371 commit ac5b745

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

dcgan/main.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,11 @@ def __init__(self, ngpu):
128128
)
129129

130130
def forward(self, input):
131-
gpu_ids = None
132-
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu >= 1:
133-
gpu_ids = range(self.ngpu)
134-
return nn.parallel.data_parallel(self.main, input, gpu_ids)
131+
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
132+
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
133+
else:
134+
output = self.main(input)
135+
return output
135136

136137

137138
netG = _netG(ngpu)
@@ -167,10 +168,11 @@ def __init__(self, ngpu):
167168
)
168169

169170
def forward(self, input):
170-
gpu_ids = None
171-
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > =1:
172-
gpu_ids = range(self.ngpu)
173-
output = nn.parallel.data_parallel(self.main, input, gpu_ids)
171+
if isinstance(input.data, torch.cuda.FloatTensor) and self.ngpu > 1:
172+
output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
173+
else:
174+
output = self.main(input)
175+
174176
return output.view(-1, 1)
175177

176178

0 commit comments

Comments
 (0)