File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -383,8 +383,16 @@ def conversion_dtype(self):
383383 return  torch .bfloat16 
384384
385385 def  _get_autocast_with_kwargs (self ):
386+  kwargs  =  {}
387+ 
388+  # Set the default data-type based on the accelerator. 
389+  if  self .benchmark_experiment .accelerator  ==  "cuda" :
390+  kwargs ["dtype" ] =  torch .float16 
391+  else :
392+  # Both CPU and TPU autocast mode defaults to bfloat16. 
393+  kwargs ["dtype" ] =  torch .bfloat16 
394+ 
386395 if  self .use_amp ():
387-  kwargs  =  {"dtype" : torch .bfloat16 }
388396 if  self .benchmark_experiment .xla :
389397 # Should call device specific autocast implementations. 
390398 # PyTorch/XLA autocast does not run with dynamo, though: 
@@ -394,7 +402,6 @@ def _get_autocast_with_kwargs(self):
394402 else :
395403 autocast  =  torch .cuda .amp .autocast 
396404 else :
397-  kwargs  =  {}
398405 autocast  =  contextlib .nullcontext 
399406 return  (autocast , kwargs )
400407
                         You can’t perform that action at this time. 
           
                  
0 commit comments