                     help='seed for initializing training. ')
 parser.add_argument('--gpu', default=None, type=int,
                     help='GPU id to use.')
+parser.add_argument('--no-accel', action='store_true',
+                    help='disables accelerator')
 parser.add_argument('--multiprocessing-distributed', action='store_true',
                     help='Use multi-processing distributed training to launch '
                          'N processes per node, which has N GPUs. This is the '
@@ -104,8 +106,17 @@ def main():

     args.distributed = args.world_size > 1 or args.multiprocessing_distributed

-    if torch.cuda.is_available():
-        ngpus_per_node = torch.cuda.device_count()
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")
+
+    print(f"Using device: {device}")
+
+    if device.type == 'cuda':
+        ngpus_per_node = torch.accelerator.device_count()
         if ngpus_per_node == 1 and args.dist_backend == "nccl":
             warnings.warn("nccl backend >=2.5 requires GPU count>1, see https://github.com/NVIDIA/nccl/issues/103 perhaps use 'gloo'")
     else:
@@ -127,8 +138,15 @@ def main_worker(gpu, ngpus_per_node, args):
     global best_acc1
     args.gpu = gpu

-    if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    if use_accel:
+        if args.gpu is not None:
+            torch.accelerator.set_device_index(args.gpu)
+            print("Use GPU: {} for training".format(args.gpu))
+        device = torch.accelerator.current_accelerator()
+    else:
+        device = torch.device("cpu")

     if args.distributed:
         if args.dist_url == "env://" and args.rank == -1:
@@ -147,16 +165,16 @@ def main_worker(gpu, ngpus_per_node, args):
     print("=> creating model '{}'".format(args.arch))
     model = models.__dict__[args.arch]()

-    if not torch.cuda.is_available() and not torch.backends.mps.is_available():
+    if not use_accel:
         print('using CPU, this will be slow')
     elif args.distributed:
         # For multiprocessing distributed, DistributedDataParallel constructor
         # should always set the single device scope, otherwise,
         # DistributedDataParallel will use all available devices.
-        if torch.cuda.is_available():
+        if device.type == 'cuda':
             if args.gpu is not None:
                 torch.cuda.set_device(args.gpu)
-                model.cuda(args.gpu)
+                model.cuda(device)
                 # When using a single GPU per process and per
                 # DistributedDataParallel, we need to divide the batch size
                 # ourselves based on the total number of GPUs of the current node.
@@ -168,29 +186,17 @@ def main_worker(gpu, ngpus_per_node, args):
                 # DistributedDataParallel will divide and allocate batch_size to all
                 # available GPUs if device_ids are not set
                 model = torch.nn.parallel.DistributedDataParallel(model)
-    elif args.gpu is not None and torch.cuda.is_available():
-        torch.cuda.set_device(args.gpu)
-        model = model.cuda(args.gpu)
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
-        model = model.to(device)
-    else:
+    elif device.type == 'cuda':
         # DataParallel will divide and allocate batch_size to all available GPUs
         if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
             model.features = torch.nn.DataParallel(model.features)
             model.cuda()
         else:
             model = torch.nn.DataParallel(model).cuda()
-
-    if torch.cuda.is_available():
-        if args.gpu:
-            device = torch.device('cuda:{}'.format(args.gpu))
-        else:
-            device = torch.device("cuda")
-    elif torch.backends.mps.is_available():
-        device = torch.device("mps")
     else:
-        device = torch.device("cpu")
+        model.to(device)
+
+
     # define loss function (criterion), optimizer, and learning rate scheduler
     criterion = nn.CrossEntropyLoss().to(device)

@@ -207,9 +213,9 @@ def main_worker(gpu, ngpus_per_node, args):
             print("=> loading checkpoint '{}'".format(args.resume))
             if args.gpu is None:
                 checkpoint = torch.load(args.resume)
-            elif torch.cuda.is_available():
+            else:
                 # Map model to be loaded to specified single gpu.
-                loc = 'cuda:{}'.format(args.gpu)
+                loc = f'{device.type}:{args.gpu}'
                 checkpoint = torch.load(args.resume, map_location=loc)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
@@ -302,11 +308,14 @@ def main_worker(gpu, ngpus_per_node, args):


 def train(train_loader, model, criterion, optimizer, epoch, device, args):
-    batch_time = AverageMeter('Time', ':6.3f')
-    data_time = AverageMeter('Data', ':6.3f')
-    losses = AverageMeter('Loss', ':.4e')
-    top1 = AverageMeter('Acc@1', ':6.2f')
-    top5 = AverageMeter('Acc@5', ':6.2f')
+
+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    data_time = AverageMeter('Data', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.NONE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.NONE)
     progress = ProgressMeter(
         len(train_loader),
         [batch_time, data_time, losses, top1, top5],
@@ -349,18 +358,27 @@ def train(train_loader, model, criterion, optimizer, epoch, device, args):

 def validate(val_loader, model, criterion, args):

+    use_accel = not args.no_accel and torch.accelerator.is_available()
+
     def run_validate(loader, base_progress=0):
+
+        if use_accel:
+            device = torch.accelerator.current_accelerator()
+        else:
+            device = torch.device("cpu")
+
         with torch.no_grad():
             end = time.time()
             for i, (images, target) in enumerate(loader):
                 i = base_progress + i
-                if args.gpu is not None and torch.cuda.is_available():
-                    images = images.cuda(args.gpu, non_blocking=True)
-                if torch.backends.mps.is_available():
-                    images = images.to('mps')
-                    target = target.to('mps')
-                if torch.cuda.is_available():
-                    target = target.cuda(args.gpu, non_blocking=True)
+                if use_accel:
+                    if args.gpu is not None and device.type == 'cuda':
+                        torch.accelerator.set_device_index(args.gpu)
+                        images = images.cuda(args.gpu, non_blocking=True)
+                        target = target.cuda(args.gpu, non_blocking=True)
+                    else:
+                        images = images.to(device)
+                        target = target.to(device)

                 # compute output
                 output = model(images)
@@ -379,10 +397,10 @@ def run_validate(loader, base_progress=0):
                 if i % args.print_freq == 0:
                     progress.display(i + 1)

-    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
-    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
-    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
-    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
+    batch_time = AverageMeter('Time', use_accel, ':6.3f', Summary.NONE)
+    losses = AverageMeter('Loss', use_accel, ':.4e', Summary.NONE)
+    top1 = AverageMeter('Acc@1', use_accel, ':6.2f', Summary.AVERAGE)
+    top5 = AverageMeter('Acc@5', use_accel, ':6.2f', Summary.AVERAGE)
     progress = ProgressMeter(
         len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
         [batch_time, losses, top1, top5],
@@ -422,8 +440,9 @@ class Summary(Enum):

 class AverageMeter(object):
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
+    def __init__(self, name, use_accel, fmt=':f', summary_type=Summary.AVERAGE):
         self.name = name
+        self.use_accel = use_accel
         self.fmt = fmt
         self.summary_type = summary_type
         self.reset()
@@ -440,11 +459,9 @@ def update(self, val, n=1):
         self.count += n
         self.avg = self.sum / self.count

-    def all_reduce(self):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-        elif torch.backends.mps.is_available():
-            device = torch.device("mps")
+    def all_reduce(self):
+        if self.use_accel:
+            device = torch.accelerator.current_accelerator()
         else:
             device = torch.device("cpu")
         total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
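
For context, a minimal standalone sketch of the device-selection pattern this diff adopts, assuming a PyTorch build that ships the torch.accelerator API; the helper name pick_device and the toy model below are illustrative, not part of the patch:

import torch

def pick_device(no_accel=False):
    # Mirror the patch's logic: prefer the current accelerator (cuda, mps, xpu, ...)
    # unless it is disabled or unavailable, otherwise fall back to CPU.
    if not no_accel and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

device = pick_device()
print(f"Using device: {device}")
model = torch.nn.Linear(8, 2).to(device)   # move parameters to the selected device
x = torch.randn(4, 8, device=device)       # keep inputs on the same device
print(model(x).shape)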