Skip to content

Commit 1c6d9d2

Browse files
bartolsthoornsoumith
authored andcommitted
Change reusing of Variables (pytorch#150)
1 parent dc10cd8 commit 1c6d9d2

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

dcgan/main.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,6 @@ def forward(self, input):
198198
input, label = input.cuda(), label.cuda()
199199
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
200200

201-
input = Variable(input)
202-
label = Variable(label)
203-
noise = Variable(noise)
204201
fixed_noise = Variable(fixed_noise)
205202

206203
# setup optimizer
@@ -216,21 +213,25 @@ def forward(self, input):
216213
netD.zero_grad()
217214
real_cpu, _ = data
218215
batch_size = real_cpu.size(0)
219-
input.data.resize_(real_cpu.size()).copy_(real_cpu)
220-
label.data.resize_(batch_size).fill_(real_label)
221-
222-
output = netD(input)
223-
errD_real = criterion(output, label)
216+
if opt.cuda:
217+
real_cpu = real_cpu.cuda()
218+
input.resize_as_(real_cpu).copy_(real_cpu)
219+
label.resize_(batch_size).fill_(real_label)
220+
inputv = Variable(input)
221+
labelv = Variable(label)
222+
223+
output = netD(inputv)
224+
errD_real = criterion(output, labelv)
224225
errD_real.backward()
225226
D_x = output.data.mean()
226227

227228
# train with fake
228-
noise.data.resize_(batch_size, nz, 1, 1)
229-
noise.data.normal_(0, 1)
230-
fake = netG(noise)
231-
label.data.fill_(fake_label)
229+
noise.resize_(batch_size, nz, 1, 1).normal_(0, 1)
230+
noisev = Variable(noise)
231+
fake = netG(noisev)
232+
labelv = Variable(label.fill_(fake_label))
232233
output = netD(fake.detach())
233-
errD_fake = criterion(output, label)
234+
errD_fake = criterion(output, labelv)
234235
errD_fake.backward()
235236
D_G_z1 = output.data.mean()
236237
errD = errD_real + errD_fake
@@ -240,9 +241,9 @@ def forward(self, input):
240241
# (2) Update G network: maximize log(D(G(z)))
241242
###########################
242243
netG.zero_grad()
243-
label.data.fill_(real_label) # fake labels are real for generator cost
244+
labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost
244245
output = netD(fake)
245-
errG = criterion(output, label)
246+
errG = criterion(output, labelv)
246247
errG.backward()
247248
D_G_z2 = output.data.mean()
248249
optimizerG.step()

0 commit comments

Comments
 (0)