                     help='seed for initializing training. ')
 parser.add_argument('--gpu', default=None, type=int,
                     help='GPU id to use.')
+parser.add_argument('--no-accel', action='store_true',
+                    help='disables accelerator')
 parser.add_argument('--multiprocessing-distributed', action='store_true',
                     help='Use multi-processing distributed training to launch '
                          'N processes per node, which has N GPUs. This is the '
@@ -104,8 +106,17 @@ def main():
 
     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
 
-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}")
+
+    if device.type == 'cuda':
+        ngpus_per_node = torch.accelerator.device_count()
         if ngpus_per_node == 1 and args.dist_backend == "nccl":
             warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
     else:
@@ -127,8 +138,15 @@ def main_worker(gpu, ngpus_per_node, args):
     global best_acc1
     args.gpu = gpu
 
-    if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        if args.gpu is not None:
+            torch.accelerator.set_device_index(args.gpu)
+            print("Use GPU: {} for training".format(args.gpu))
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")
 
     if args.distributed:
         if args.dist_url == "env://" and args.rank == -1:
@@ -147,16 +165,16 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()
 
-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not use_accel:
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
+        if device.type == 'cuda':
             if args.gpu is not None:
                 torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
+                model.cuda(device)
                 # When using a single GPU per process and per
                 # DistributedDataParallel, we need to divide the batch size
                 # ourselves based on the total number of GPUs of the current node.
@@ -168,29 +186,17 @@ def main_worker(gpu, ngpus_per_node, args):
                 # DistributedDataParallel will divide and allocate batch_size to all
                 # available GPUs if device_ids are not set
                 model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
-    else:
+    elif device.type == 'cuda':
         # DataParallel will divide and allocate batch_size to all available GPUs
         if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
             model.features = torch.nn.DataParallel(model.features)
             model.cuda()
         else:
             model = torch.nn.DataParallel(model).cuda()
-
-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
     else:
-        device = torch.device("cpu")
+        model.to(device)
+
+
     # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().to(device)
 
@@ -207,9 +213,9 @@ def main_worker(gpu, ngpus_per_node, args):
             print("=> loading checkpoint '{}'".format(args.resume))
             if args.gpu is None:
                 checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
+            else:
                 # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
+                loc = f'{device.type}:{args.gpu}'
                 checkpoint = torch.load(args.resume, map_location=loc)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
@@ -302,11 +308,14 @@ def main_worker(gpu, ngpus_per_node, args):
 
 
 def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    data_time = AverageMeter('Data', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.NONE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.NONE)
     progress = ProgressMeter(
         len(train_loader),
         [batch_time, data_time, losses, top1, top5],
@@ -349,18 +358,27 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):
 
 def validate(val_loader, model, criterion, args):
 
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
     def run_validate(loader, base_progress=0):
+
+        if use_accel:
+            device = torch.accelerator.current_accelerator()
+        else:
+            device = torch.device("cpu")
+
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                if use_accel:
+                    if args.gpu is not None and device.type == 'cuda':
+                        torch.accelerator.set_device_index(args.gpu)
+                        images = images.cuda(args.gpu, non_blocking=True)
+                        target = target.cuda(args.gpu, non_blocking=True)
+                    else:
+                        images = images.to(device)
+                        target = target.to(device)
 
                 # compute output
                 output = model(images)
@@ -379,10 +397,10 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)
 
-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.AVERAGE)
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
         [batch_time, losses, top1, top5],
@@ -422,8 +440,9 @@ class Summary(Enum):
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, use_accel, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
+        self.use_accel = use_accel
         self.fmt = fmt
         self.summary_type = summary_type
         self.reset()
@@ -440,11 +459,9 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count
 
-    def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
+    def all_reduce(self):
+        if self.use_accel:
+            device = torch.accelerator.current_accelerator()
         else:
             device = torch.device("cpu")
         total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)