Skip to content

Commit 9012fae

Browse files
committed
fix for 0.2
1 parent 5c2b513 commit 9012fae

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

dcgan/main.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616

1717
parser = argparse.ArgumentParser()
18-
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
18+
parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | fake')
1919
parser.add_argument('--dataroot', required=True, help='path to dataset')
2020
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
2121
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
@@ -77,8 +77,10 @@
7777
transforms.Scale(opt.imageSize),
7878
transforms.ToTensor(),
7979
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
80-
])
81-
)
80+
]))
81+
elif opt.dataset == 'fake':
82+
dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
83+
transform=transforms.ToTensor())
8284
assert dataset
8385
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
8486
shuffle=True, num_workers=int(opt.workers))
@@ -173,7 +175,7 @@ def forward(self, input):
173175
else:
174176
output = self.main(input)
175177

176-
return output.view(-1, 1)
178+
return output.view(-1, 1).squeeze(1)
177179

178180

179181
netD = _netD(ngpu)

0 commit comments

Comments
 (0)