
Commit 75b54ee

ysiraichi authored and bhavya01 committed
[benchmarks] Fix AMP data-type. (#6550)
1 parent 44d218c commit 75b54ee

File tree

1 file changed (+9, -2 lines)


benchmarks/torchbench_model.py

Lines changed: 9 additions & 2 deletions
@@ -383,8 +383,16 @@ def conversion_dtype(self):
     return torch.bfloat16
 
   def _get_autocast_with_kwargs(self):
+    kwargs = {}
+
+    # Set the default data-type based on the accelerator.
+    if self.benchmark_experiment.accelerator == "cuda":
+      kwargs["dtype"] = torch.float16
+    else:
+      # Both CPU and TPU autocast mode defaults to bfloat16.
+      kwargs["dtype"] = torch.bfloat16
+
     if self.use_amp():
-      kwargs = {"dtype": torch.bfloat16}
       if self.benchmark_experiment.xla:
         # Should call device specific autocast implementations.
         # PyTorch/XLA autocast does not run with dynamo, though:
@@ -394,7 +402,6 @@ def _get_autocast_with_kwargs(self):
       else:
         autocast = torch.cuda.amp.autocast
     else:
-      kwargs = {}
       autocast = contextlib.nullcontext
     return (autocast, kwargs)
 
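For context, the substance of the fix is the accelerator-dependent AMP data-type: CUDA AMP conventionally runs in float16, while CPU and TPU (XLA) autocast default to bfloat16. Below is a minimal standalone sketch of that selection, independent of the benchmark harness; the helper name default_amp_dtype is illustrative and not part of the repository.

import torch


def default_amp_dtype(accelerator: str) -> torch.dtype:
  # Mirrors the selection added in this commit: float16 on CUDA,
  # bfloat16 everywhere else (CPU and TPU).
  return torch.float16 if accelerator == "cuda" else torch.bfloat16


# Example: CPU autocast with the selected dtype.
with torch.autocast(device_type="cpu", dtype=default_amp_dtype("cpu")):
  x = torch.randn(4, 4)
  y = x @ x  # matmul runs under bfloat16 autocast on CPU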

0 commit comments