@@ -198,9 +198,6 @@ def forward(self, input):
198198 input , label = input .cuda (), label .cuda ()
199199 noise , fixed_noise = noise .cuda (), fixed_noise .cuda ()
200200
201- input = Variable (input )
202- label = Variable (label )
203- noise = Variable (noise )
204201fixed_noise = Variable (fixed_noise )
205202
206203# setup optimizer
@@ -216,21 +213,25 @@ def forward(self, input):
216213 netD .zero_grad ()
217214 real_cpu , _ = data
218215 batch_size = real_cpu .size (0 )
219- input .data .resize_ (real_cpu .size ()).copy_ (real_cpu )
220- label .data .resize_ (batch_size ).fill_ (real_label )
221-
222- output = netD (input )
223- errD_real = criterion (output , label )
216+ if opt .cuda :
217+ real_cpu = real_cpu .cuda ()
218+ input .resize_as_ (real_cpu ).copy_ (real_cpu )
219+ label .resize_ (batch_size ).fill_ (real_label )
220+ inputv = Variable (input )
221+ labelv = Variable (label )
222+
223+ output = netD (inputv )
224+ errD_real = criterion (output , labelv )
224225 errD_real .backward ()
225226 D_x = output .data .mean ()
226227
227228 # train with fake
228- noise .data . resize_ (batch_size , nz , 1 , 1 )
229- noise . data . normal_ ( 0 , 1 )
230- fake = netG (noise )
231- label .data . fill_ (fake_label )
229+ noise .resize_ (batch_size , nz , 1 , 1 ). normal_ ( 0 , 1 )
230+ noisev = Variable ( noise )
231+ fake = netG (noisev )
232+ labelv = Variable ( label .fill_ (fake_label ) )
232233 output = netD (fake .detach ())
233- errD_fake = criterion (output , label )
234+ errD_fake = criterion (output , labelv )
234235 errD_fake .backward ()
235236 D_G_z1 = output .data .mean ()
236237 errD = errD_real + errD_fake
@@ -240,9 +241,9 @@ def forward(self, input):
240241 # (2) Update G network: maximize log(D(G(z)))
241242 ###########################
242243 netG .zero_grad ()
243- label .data . fill_ (real_label ) # fake labels are real for generator cost
244+ labelv = Variable ( label .fill_ (real_label ) ) # fake labels are real for generator cost
244245 output = netD (fake )
245- errG = criterion (output , label )
246+ errG = criterion (output , labelv )
246247 errG .backward ()
247248 D_G_z2 = output .data .mean ()
248249 optimizerG .step ()
0 commit comments