@@ -48,6 +48,10 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         # set torchrun variables
         self.local_rank = int(os.environ["LOCAL_RANK"])
         self.global_rank = int(os.environ["RANK"])
+        # set device
+        self.acc = torch.accelerator.current_accelerator()
+        self.device: torch.device = torch.device(f"{self.acc}:{self.local_rank}")
+        self.device_type = self.device.type
         # data stuff
         self.train_dataset = train_dataset
         self.train_loader = self._prepare_dataloader(train_dataset)
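The hunk above swaps a hard-coded CUDA assumption for runtime accelerator discovery. A minimal standalone sketch of the same pattern, assuming PyTorch 2.6+ (where torch.accelerator is available) and a torchrun-style LOCAL_RANK variable; the CPU fallback is added here for illustration and is not part of this diff:

    import os
    import torch

    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    acc = torch.accelerator.current_accelerator()  # e.g. device(type='cuda'); None on CPU-only builds
    if acc is not None:
        device = torch.device(f"{acc.type}:{local_rank}")  # "cuda:0", "xpu:0", ...
    else:
        device = torch.device("cpu")
    device_type = device.type  # the string the AMP calls below key off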
@@ -58,7 +62,7 @@ def __init__(self, trainer_config: TrainerConfig, model, optimizer, train_datase
         self.optimizer = optimizer
         self.save_every = self.config.save_every
         if self.config.use_amp:
-            self.scaler = torch.cuda.amp.GradScaler()
+            self.scaler = torch.amp.GradScaler(self.device_type)
         # load snapshot if available. only necessary on the first node.
         if self.config.snapshot_path is None:
             self.config.snapshot_path = "snapshot.pt"
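torch.cuda.amp.GradScaler is deprecated in favor of the device-generic torch.amp.GradScaler, which takes the device type string as its first argument. A short sketch of the replacement (the hard-coded device_type here is a stand-in for the accelerator lookup above):

    import torch

    device_type = "cuda"  # or "cpu", "xpu", ... from torch.accelerator
    # old, CUDA-only spelling: scaler = torch.cuda.amp.GradScaler()
    scaler = torch.amp.GradScaler(device_type)  # new, device-agnostic spelling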
@@ -93,7 +97,7 @@ def _load_snapshot(self):


     def _run_batch(self, source, targets, train: bool = True) -> float:
-        with torch.set_grad_enabled(train), torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=(self.config.use_amp)):
+        with torch.set_grad_enabled(train), torch.amp.autocast(device_type=self.device_type, dtype=torch.float16, enabled=(self.config.use_amp)):
             _, loss = self.model(source, targets)

         if train:
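With device_type threaded through, autocast no longer pins the example to CUDA. A sketch of a complete scaled AMP step in the same shape as _run_batch, following the usual GradScaler recipe; model, optimizer, and scaler are stand-in names, and the zero_grad placement is illustrative rather than copied from this file:

    import torch

    def amp_step(model, optimizer, scaler, source, targets, device_type, use_amp=True):
        with torch.amp.autocast(device_type=device_type, dtype=torch.float16, enabled=use_amp):
            _, loss = model(source, targets)   # forward pass runs in mixed precision
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
        scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
        scaler.update()
        return loss.item()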
@@ -119,7 +123,7 @@ def _run_epoch(self, epoch: int, dataloader: DataLoader, train: bool = True):
             targets = targets.to(self.local_rank)
             batch_loss = self._run_batch(source, targets, train)
             if iter % 100 == 0:
-                print(f"[GPU{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")
+                print(f"[RANK{self.global_rank}] Epoch {epoch} | Iter {iter} | {step_type} Loss {batch_loss:.5f}")

     def _save_snapshot(self, epoch):
         # capture snapshot