|
@@ -15,7 +15,7 @@
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw ')
+parser.add_argument('--dataset', required=True, help='cifar10 | lsun | imagenet | folder | lfw | fake')
 parser.add_argument('--dataroot', required=True, help='path to dataset')
 parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
 parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
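The new value is only advertised in the option's help string here; the actual dispatch on opt.dataset happens in the if/elif chain changed in the next hunk. A minimal standalone sketch of the parsing (the invocation is illustrative, not part of the script):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True,
                    help='cifar10 | lsun | imagenet | folder | lfw | fake')

# Illustrative invocation; the real script parses sys.argv.
opt = parser.parse_args(['--dataset', 'fake'])
print(opt.dataset)  # fake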
|
@@ -77,8 +77,10 @@
                                transforms.Scale(opt.imageSize),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
-                           ])
-                           )
+                           ]))
+elif opt.dataset == 'fake':
+    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
+                            transform=transforms.ToTensor())
 assert dataset
 dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                          shuffle=True, num_workers=int(opt.workers))
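For context, torchvision's dset.FakeData synthesizes random images on the fly, so this new branch lets the training loop be smoke-tested without downloading a dataset. A minimal standalone sketch, with the sample count and image size chosen for illustration (the script itself uses opt.imageSize):

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

# FakeData generates random images and labels; nothing is read from disk.
dataset = dset.FakeData(size=256,                        # illustrative sample count
                        image_size=(3, 64, 64),          # illustrative; the script uses opt.imageSize
                        transform=transforms.ToTensor())
loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
images, _ = next(iter(loader))
print(images.shape)  # torch.Size([64, 3, 64, 64])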
@@ -173,7 +175,7 @@ def forward(self, input):
         else:
             output = self.main(input)
 
-        return output.view(-1, 1)
+        return output.view(-1, 1).squeeze(1)
 
 
 netD = _netD(ngpu)
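The added .squeeze(1) turns the discriminator's (N, 1) output into a 1-D tensor of length N. The training loop is not shown in this diff, but the point of the change is shape agreement with a 1-D label tensor in the BCE loss; a minimal sketch with illustrative shapes:

import torch
import torch.nn as nn

N = 64
raw = torch.rand(N, 1, 1, 1)           # stand-in for the final conv + sigmoid output
label = torch.full((N,), 1.0)          # 1-D tensor of "real" labels

output = raw.view(-1, 1).squeeze(1)    # (N, 1, 1, 1) -> (N, 1) -> (N,)
loss = nn.BCELoss()(output, label)     # input and target shapes now match
print(output.shape, loss.item())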
|