Skip to content

Commit 5f5caed

Browse files
msaroufimpytorchmergebot
authored andcommitted
do not cast all inputs in benchmarks (pytorch#108456)
Fixes why stable diffusion is not showing up in inference dashboard even though it shows up in training dashboard The reason is stable diffusion in torchbench has a line like `input_tensor = input_tensor.long().to(self.device)` and if you cast this to a bfloat16 the inference will fail <img width="1705" alt="Screenshot 2023-09-01 at 4 37 49 PM" src="https://github.com/pytorch/pytorch/assets/3282513/ada0d381-1af0-4378-8e8b-2375b39c3713"> Pull Request resolved: pytorch#108456 Approved by: https://github.com/cpuhrsch
1 parent b8af8ac commit 5f5caed

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

benchmarks/dynamo/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ class CI(NamedTuple):
305305
"dlrm",
306306
}
307307

308+
DO_NOT_CAST_INPUTS = {"stable_diffusion"}
309+
308310

309311
def model_specified_by_path(path_and_class_str):
310312
return ":" in path_and_class_str
@@ -3482,8 +3484,11 @@ def detect_and_mark_batch(t):
34823484
torch.cuda.set_per_process_memory_fraction(
34833485
args.per_process_memory_fraction
34843486
)
3487+
if model_name in DO_NOT_CAST_INPUTS:
3488+
model, _ = runner.cast_based_on_args(model, example_inputs)
34853489

3486-
model, example_inputs = runner.cast_based_on_args(model, example_inputs)
3490+
else:
3491+
model, example_inputs = runner.cast_based_on_args(model, example_inputs)
34873492
runner.run_one_model(
34883493
name,
34893494
model,

0 commit comments

Comments
 (0)