20 | 20 | parser.add_argument('--batchSize', type=int, default=64, help='input batch size') |
21 | 21 | parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network') |
22 | 22 | parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector') |
23 | | -parser.add_argument('--ngf', type=int, default=64) |
24 | | -parser.add_argument('--ndf', type=int, default=64) |
| 23 | +parser.add_argument('--ngf', type=int, default=64, help='number of generator filters') |
| 24 | +parser.add_argument('--ndf', type=int, default=64, help='number of discriminator filters') |
25 | 25 | parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for') |
26 | 26 | parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002') |
27 | 27 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5') |
28 | | -parser.add_argument('--cuda', action='store_true', default=False, help='enables cuda') |
29 | 28 | parser.add_argument('--dry-run', action='store_true', help='check a single training cycle works') |
30 | 29 | parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use') |
31 | 30 | parser.add_argument('--netG', default='', help="path to netG (to continue training)") |
32 | 31 | parser.add_argument('--netD', default='', help="path to netD (to continue training)") |
33 | 32 | parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints') |
34 | 33 | parser.add_argument('--manualSeed', type=int, help='manual seed') |
35 | 34 | parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set') |
36 | | -parser.add_argument('--mps', action='store_true', default=False, help='enables macOS GPU training') |
| 35 | +parser.add_argument('--accel', action='store_true', default=False, help='enables accelerator') |
37 | 36 |
38 | 37 | opt = parser.parse_args() |
39 | 38 | print(opt) |
51 | 50 |
52 | 51 | cudnn.benchmark = True |
53 | 52 |
54 | | -if torch.cuda.is_available() and not opt.cuda: |
55 | | - print("WARNING: You have a CUDA device, so you should probably run with --cuda") |
| 53 | +if opt.accel and torch.accelerator.is_available(): |
| 54 | + device = torch.accelerator.current_accelerator() |
| 55 | +else: |
| 56 | + device = torch.device("cpu") |
| 57 | + |
| 58 | +print(f"Using device: {device}") |
56 | 59 |
57 | | -if torch.backends.mps.is_available() and not opt.mps: |
58 | | - print("WARNING: You have mps device, to enable macOS GPU run with --mps") |
59 | | - |
60 | 60 | if opt.dataroot is None and str(opt.dataset).lower() != 'fake': |
61 | 61 | raise ValueError("`dataroot` parameter is required for dataset \"%s\"" % opt.dataset) |
62 | 62 |
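For reference, here is a minimal standalone sketch of the device-selection pattern the change above adopts. It assumes a PyTorch release that ships the `torch.accelerator` API (2.6 or newer); the `--accel` flag mirrors the argument added in this diff.

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--accel', action='store_true', default=False, help='enables accelerator')
opt = parser.parse_args()

# Use whichever accelerator backend PyTorch detects (CUDA, XPU, MPS, ...)
# when --accel is passed and one is available; otherwise stay on the CPU.
if opt.accel and torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Models and tensors are then placed with the usual .to(device) / device= calls,
# e.g. a batch of latent vectors for the generator:
noise = torch.randn(16, 100, 1, 1, device=device)
```

A single backend-agnostic switch replaces the separate `--cuda` and `--mps` flags, so the same code path can pick CUDA, MPS, or XPU depending on the build.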
106 | 106 | assert dataset |
107 | 107 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize, |
108 | 108 | shuffle=True, num_workers=int(opt.workers)) |
109 | | -use_mps = opt.mps and torch.backends.mps.is_available() |
110 | | -if opt.cuda: |
111 | | - device = torch.device("cuda:0") |
112 | | -elif use_mps: |
113 | | - device = torch.device("mps") |
114 | | -else: |
115 | | - device = torch.device("cpu") |
116 | 109 |
117 | 110 | ngpu = int(opt.ngpu) |
118 | 111 | nz = int(opt.nz) |
@@ -158,7 +151,8 @@ def __init__(self, ngpu): |
158 | 151 | ) |
159 | 152 |
160 | 153 | def forward(self, input): |
161 | | - if input.is_cuda and self.ngpu > 1: |
| 154 | + |
| 155 | + if (input.is_cuda or input.is_xpu) and self.ngpu > 1: |
162 | 156 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) |
163 | 157 | else: |
164 | 158 | output = self.main(input) |
@@ -198,7 +192,7 @@ def __init__(self, ngpu): |
198 | 192 | ) |
199 | 193 |
200 | 194 | def forward(self, input): |
201 | | - if input.is_cuda and self.ngpu > 1: |
| 195 | + if (input.is_cuda or input.is_xpu) and self.ngpu > 1: |
202 | 196 | output = nn.parallel.data_parallel(self.main, input, range(self.ngpu)) |
203 | 197 | else: |
204 | 198 | output = self.main(input) |
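Below is a self-contained sketch of the updated multi-GPU forward guard; `TinyNet` is a hypothetical stand-in for `netG`/`netD`, not part of the example itself. The batch is split with `nn.parallel.data_parallel` only when the input already lives on a CUDA or XPU device and more than one GPU was requested; otherwise the plain `Sequential` runs.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical stand-in for netG/netD, showing only the forward guard.
    def __init__(self, ngpu):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(True))

    def forward(self, input):
        # Split the batch across GPUs only when the input already sits on a
        # CUDA or XPU device and more than one GPU was requested; is_xpu
        # covers Intel GPUs reached through the accelerator path.
        if (input.is_cuda or input.is_xpu) and self.ngpu > 1:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)

# On CPU (or with ngpu == 1) the guard falls through to the plain module.
net = TinyNet(ngpu=1)
out = net(torch.randn(2, 3, 64, 64))
```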