@@ -22,7 +22,7 @@ def __init__(
2222 self ,
2323 warmup_iterations : int = 5 ,
2424 measurement_iterations : int = 20 ,
25- gpu_monitoring : bool = True ,
25+ gpu_monitoring : bool = False , # False by default because it slows down the benchmark by a lot
2626 batch_size : int = 1 ,
2727 sequence_length : int = 128 ,
2828 num_tokens_to_generate : int = 128 ,
@@ -49,6 +49,9 @@ def __init__(
4949 self .compile_mode = compile_mode
5050 self .compile_options = compile_options if compile_options is not None else {}
5151 self .kernelize = kernelize
52+ # Constant parameters
53+ self .dtype = "torch.bfloat16"
54+ self .device = "cuda"
5255
5356 self .check_validity (skip_validity_check )
5457 self .name = name if name is not None else self .infer_name ()
@@ -63,14 +66,6 @@ def check_validity(self, skip_validity_check: bool = False) -> None:
6366 logger .warning ("Flash attention does not support compile mode. Turning off compile mode." )
6467 self .compile_mode = None
6568
66- @property
67- def device (self ) -> str :
68- return "cuda"
69-
70- @property
71- def dtype (self ) -> str :
72- return "torch.bfloat16"
73-
7469 @property
7570 def hash (self ) -> str :
7671 return hashlib .sha256 (json .dumps (self .to_dict ()).encode ()).hexdigest ()
@@ -87,7 +82,7 @@ def infer_name(self) -> str:
8782 "kernelized" if self .kernelize else "unkernelized" ,
8883 ])
8984
90- def to_dict (self ) -> dict [str , Union [ None , int , float , str ] ]:
85+ def to_dict (self ) -> dict [str , Any ]:
9186 return {
9287 "name" : self .name ,
9388 "warmup_iterations" : self .warmup_iterations ,
@@ -104,20 +99,21 @@ def to_dict(self) -> dict[str, Union[None, int, float, str]]:
10499 }
105100
106101 @classmethod
107- def from_dict (cls , data : dict [str , Any ]) -> "BenchmarkConfig" :
102+ def from_dict (cls , data : dict [str , Any ], skip_validity_check : bool = False ) -> "BenchmarkConfig" :
108103 return cls (
109- warmup_iterations = data [ "warmup_iterations" ] ,
110- measurement_iterations = data [ "measurement_iterations" ] ,
111- gpu_monitoring = data [ "gpu_monitoring" ] ,
112- batch_size = data [ "batch_size" ] ,
113- sequence_length = data [ "sequence_length" ] ,
114- num_tokens_to_generate = data [ "num_tokens_to_generate" ] ,
115- attn_implementation = data [ "attn_implementation" ] ,
116- sdpa_backend = data [ "sdpa_backend" ] ,
117- compile_mode = data [ "compile_mode" ] ,
118- compile_options = data [ "compile_options" ] ,
119- kernelize = data [ "kernelize" ] ,
104+ warmup_iterations = data . get ( "warmup_iterations" , 5 ) ,
105+ measurement_iterations = data . get ( "measurement_iterations" , 20 ) ,
106+ gpu_monitoring = data . get ( "gpu_monitoring" , False ) ,
107+ batch_size = data . get ( "batch_size" , 1 ) ,
108+ sequence_length = data . get ( "sequence_length" , 128 ) ,
109+ num_tokens_to_generate = data . get ( "num_tokens_to_generate" , 128 ) ,
110+ attn_implementation = data . get ( "attn_implementation" , "eager" ) ,
111+ sdpa_backend = data . get ( "sdpa_backend" ) ,
112+ compile_mode = data . get ( "compile_mode" ) ,
113+ compile_options = data . get ( "compile_options" ) ,
114+ kernelize = data . get ( "kernelize" , False ) ,
120115 name = data .get ("name" ),
116+ skip_validity_check = skip_validity_check ,
121117 )
122118
123119