Commit d2c89aa

Review compliance, start
1 parent 680d3da commit d2c89aa

File tree

1 file changed: 18 additions, 22 deletions


benchmark_v2/framework/benchmark_config.py

Lines changed: 18 additions & 22 deletions
@@ -22,7 +22,7 @@ def __init__(
         self,
         warmup_iterations: int = 5,
         measurement_iterations: int = 20,
-        gpu_monitoring: bool = True,
+        gpu_monitoring: bool = False,  # False by default because it slows down the benchmark by a lot
         batch_size: int = 1,
         sequence_length: int = 128,
         num_tokens_to_generate: int = 128,
@@ -49,6 +49,9 @@ def __init__(
         self.compile_mode = compile_mode
         self.compile_options = compile_options if compile_options is not None else {}
         self.kernelize = kernelize
+        # Constant parameters
+        self.dtype = "torch.bfloat16"
+        self.device = "cuda"

         self.check_validity(skip_validity_check)
         self.name = name if name is not None else self.infer_name()
@@ -63,14 +66,6 @@ def check_validity(self, skip_validity_check: bool = False) -> None:
             logger.warning("Flash attention does not support compile mode. Turning off compile mode.")
             self.compile_mode = None

-    @property
-    def device(self) -> str:
-        return "cuda"
-
-    @property
-    def dtype(self) -> str:
-        return "torch.bfloat16"
-
     @property
     def hash(self) -> str:
         return hashlib.sha256(json.dumps(self.to_dict()).encode()).hexdigest()
@@ -87,7 +82,7 @@ def infer_name(self) -> str:
             "kernelized" if self.kernelize else "unkernelized",
         ])

-    def to_dict(self) -> dict[str, Union[None, int, float, str]]:
+    def to_dict(self) -> dict[str, Any]:
         return {
             "name": self.name,
             "warmup_iterations": self.warmup_iterations,
@@ -104,20 +99,21 @@ def to_dict(self) -> dict[str, Union[None, int, float, str]]:
         }

     @classmethod
-    def from_dict(cls, data: dict[str, Any]) -> "BenchmarkConfig":
+    def from_dict(cls, data: dict[str, Any], skip_validity_check: bool = False) -> "BenchmarkConfig":
         return cls(
-            warmup_iterations=data["warmup_iterations"],
-            measurement_iterations=data["measurement_iterations"],
-            gpu_monitoring=data["gpu_monitoring"],
-            batch_size=data["batch_size"],
-            sequence_length=data["sequence_length"],
-            num_tokens_to_generate=data["num_tokens_to_generate"],
-            attn_implementation=data["attn_implementation"],
-            sdpa_backend=data["sdpa_backend"],
-            compile_mode=data["compile_mode"],
-            compile_options=data["compile_options"],
-            kernelize=data["kernelize"],
+            warmup_iterations=data.get("warmup_iterations", 5),
+            measurement_iterations=data.get("measurement_iterations", 20),
+            gpu_monitoring=data.get("gpu_monitoring", False),
+            batch_size=data.get("batch_size", 1),
+            sequence_length=data.get("sequence_length", 128),
+            num_tokens_to_generate=data.get("num_tokens_to_generate", 128),
+            attn_implementation=data.get("attn_implementation", "eager"),
+            sdpa_backend=data.get("sdpa_backend"),
+            compile_mode=data.get("compile_mode"),
+            compile_options=data.get("compile_options"),
+            kernelize=data.get("kernelize", False),
             name=data.get("name"),
+            skip_validity_check=skip_validity_check,
         )
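With this change, from_dict() tolerates config dicts with missing keys by falling back to the constructor defaults, and callers can skip validation explicitly. A minimal usage sketch follows; it assumes BenchmarkConfig is importable from benchmark_v2.framework.benchmark_config (import path inferred from the file path above) and that the constructor keywords match the signature shown in the diff:

# Hypothetical usage sketch; not part of this commit.
from benchmark_v2.framework.benchmark_config import BenchmarkConfig

# A partial dict: no gpu_monitoring or iteration keys, so the .get() defaults apply.
partial = {"batch_size": 4, "sequence_length": 256}
config = BenchmarkConfig.from_dict(partial, skip_validity_check=True)

assert config.gpu_monitoring is False        # new default introduced by this commit
assert config.device == "cuda"               # now a plain attribute set in __init__, not a property
assert config.dtype == "torch.bfloat16"

# to_dict() / from_dict() round-trip still works with the relaxed .get() lookups.
restored = BenchmarkConfig.from_dict(config.to_dict())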