Skip to content

Commit 644166e

Browse files
authored
Add Warmup Loop for JIT ops (#30)
* Fix num_runs loop for DCN collective tests * Adding warmup logic * Added warmup_tries param as an example to the sample config
1 parent 84a73b6 commit 644166e

File tree

4 files changed

+37
-6
lines changed

4 files changed

+37
-6
lines changed

configs/sample_benchmark_collectives.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,27 @@ benchmarks:
22
- benchmark_name: ppermute
33
benchmark_sweep_params:
44
- {matrix_dim_range: {start: 1024, end: 20000, increase_by: 1024}, dtype: "bfloat16", dcn_size_range: 1, ici_size_range: 4}
5+
warmup_tries: 10
56
trace_dir: "/tmp/microbenchmarks/collectives"
67
csv_path: "/tmp/microbenchmarks/collectives"
78
xla_dump_dir: "/tmp/microbenchmarks/collective/hlo_graphs"
89
- benchmark_name: all_gather
910
benchmark_sweep_params:
1011
- {matrix_dim_range: {start: 1024, end: 20000, increase_by: 1024}, dtype: "bfloat16", dcn_size_range: 1, ici_size_range: 4}
12+
warmup_tries: 10
1113
trace_dir: "/tmp/microbenchmarks/collectives"
1214
csv_path: "/tmp/microbenchmarks/collectives"
1315
xla_dump_dir: "/tmp/microbenchmarks/collective/hlo_graphs"
1416
- benchmark_name: psum
1517
benchmark_sweep_params:
1618
- {matrix_dim_range: {start: 1024, end: 20000, increase_by: 1024}, dtype: "bfloat16", dcn_size_range: 1, ici_size_range: 4}
19+
warmup_tries: 10
1720
trace_dir: "/tmp/microbenchmarks/collectives"
1821
csv_path: "/tmp/microbenchmarks/collectives"
1922
xla_dump_dir: "/tmp/microbenchmarks/collective/hlo_graphs"
2023
- benchmark_name: psum_scatter
2124
benchmark_sweep_params:
2225
- {matrix_dim_range: {start: 1024, end: 20000, increase_by: 1024}, dtype: "bfloat16", dcn_size_range: 1, ici_size_range: 4}
26+
warmup_tries: 10
2327
trace_dir: "/tmp/microbenchmarks/collectives"
2428
csv_path: "/tmp/microbenchmarks/collectives"

src/benchmark_collectives.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def psum_benchmark(
5757
ici_size: int,
5858
num_runs: int = 1,
5959
trace_dir: str = None,
60+
warmup_tries: int = 10,
6061
) -> Dict[str, Any]:
6162
"""Benchmarks the psum collective operation.
6263
@@ -90,6 +91,7 @@ def f(x):
9091
jitted_op,
9192
sharded_matrix,
9293
matrix_dim=matrix_dim,
94+
warmup_tries=warmup_tries,
9395
tries=num_runs,
9496
task="psum_dcn_op",
9597
trace_dir=trace_dir,
@@ -110,6 +112,7 @@ def f(x):
110112
jitted_op,
111113
sharded_matrix,
112114
matrix_dim=matrix_dim,
115+
warmup_tries=warmup_tries,
113116
tries=num_runs,
114117
task="psum_ici_op",
115118
trace_dir=trace_dir,
@@ -189,6 +192,7 @@ def psum_scatter_benchmark(
189192
ici_size: int,
190193
num_runs: int = 1,
191194
trace_dir: str = None,
195+
warmup_tries: int = 10,
192196
) -> Dict[str, Any]:
193197
"""Benchmarks the psum_scatter collective operation.
194198
@@ -224,6 +228,7 @@ def f(x):
224228
jitted_op,
225229
sharded_matrix,
226230
matrix_dim=matrix_dim,
231+
warmup_tries=warmup_tries,
227232
tries=num_runs,
228233
task="psum_scatter_dcn_op",
229234
trace_dir=trace_dir,
@@ -244,6 +249,7 @@ def f(x):
244249
jitted_op,
245250
sharded_matrix,
246251
matrix_dim=matrix_dim,
252+
warmup_tries=warmup_tries,
247253
tries=num_runs,
248254
task="psum_scatter_ici_op",
249255
trace_dir=trace_dir,
@@ -322,6 +328,7 @@ def all_gather_benchmark(
322328
dtype: jnp.dtype,
323329
dcn_size: int,
324330
ici_size: int,
331+
warmup_tries: int = 10,
325332
num_runs: int = 1,
326333
trace_dir: str = None,
327334
) -> Dict[str, Any]:
@@ -364,6 +371,7 @@ def f(x):
364371
jitted_op,
365372
sharded_matrix,
366373
matrix_dim=matrix_dim,
374+
warmup_tries=warmup_tries,
367375
tries=num_runs,
368376
task="all_gather_dcn_op",
369377
trace_dir=trace_dir,
@@ -390,6 +398,7 @@ def f(x):
390398
jitted_op,
391399
sharded_matrix,
392400
matrix_dim=matrix_dim,
401+
warmup_tries=warmup_tries,
393402
tries=num_runs,
394403
task="all_gather_ici_op",
395404
trace_dir=trace_dir,
@@ -469,6 +478,7 @@ def ppermute_benchmark(
469478
ici_size: int,
470479
num_runs: int = 1,
471480
trace_dir: str = None,
481+
warmup_tries: int = 10,
472482
) -> Dict[str, Any]:
473483
"""Benchmarks the ppermute collective operation.
474484
@@ -506,6 +516,7 @@ def f(x):
506516
jitted_op,
507517
sharded_matrix,
508518
matrix_dim=matrix_dim,
519+
warmup_tries=warmup_tries,
509520
tries=num_runs,
510521
task="ppermute_dcn_op",
511522
trace_dir=trace_dir,
@@ -527,6 +538,7 @@ def f(x):
527538
jitted_op,
528539
sharded_matrix,
529540
matrix_dim=matrix_dim,
541+
warmup_tries=warmup_tries,
530542
tries=num_runs,
531543
task="ppermute_ici_op",
532544
trace_dir=trace_dir,
@@ -598,6 +610,7 @@ def all_to_all_benchmark(
598610
ici_size: int,
599611
num_runs: int = 1,
600612
trace_dir: str = None,
613+
warmup_tries: int = 10,
601614
) -> Dict[str, Any]:
602615
"""Benchmarks the all_to_all collective operation.
603616
@@ -634,6 +647,7 @@ def f(x):
634647
jitted_op,
635648
sharded_matrix,
636649
matrix_dim=matrix_dim,
650+
warmup_tries=warmup_tries,
637651
tries=num_runs,
638652
task="all_to_all_dcn_op",
639653
trace_dir=trace_dir,
@@ -660,6 +674,7 @@ def f(x):
660674
jitted_op,
661675
sharded_matrix,
662676
matrix_dim=matrix_dim,
677+
warmup_tries=warmup_tries,
663678
tries=num_runs,
664679
task="all_to_all_ici_op",
665680
trace_dir=trace_dir,

src/benchmark_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,19 @@
1919
import shutil
2020

2121

22-
def simple_timeit(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
22+
def simple_timeit(f, *args, matrix_dim=None, warmup_tries = 10, tries=10, task=None, trace_dir=None) -> float:
2323
"""Simple utility to time a function for multiple runs."""
2424
assert task is not None
2525

2626
if trace_dir:
2727
return timeit_from_trace(f, *args, matrix_dim=matrix_dim, tries=tries, task=task, trace_dir=trace_dir)
2828

29+
# warmup loop
30+
print(f"Running warmup loop with {warmup_tries} tries...")
31+
for _ in range(warmup_tries):
32+
data = f(*args)
33+
jax.block_until_ready(data)
2934
outcomes_ms = []
30-
jax.block_until_ready(f(*args)) # warm it up!
3135
for _ in range(tries):
3236
jax.devices() # Force synchronization across devices
3337
s = datetime.datetime.now()
@@ -97,13 +101,17 @@ def is_local_directory_path(dir: str) -> bool:
97101
return dir.startswith("/") or dir.startswith("./") or dir.startswith("../")
98102

99103

100-
def timeit_from_trace(f, *args, matrix_dim=None, tries=10, task=None, trace_dir=None) -> float:
104+
def timeit_from_trace(f, *args, matrix_dim=None, warmup_tries=10, tries=10, task=None, trace_dir=None) -> float:
101105
"""
102106
Time a function with jax.profiler and get the run time from the trace.
103107
"""
104108
LOCAL_TRACE_DIR = "/tmp/microbenchmarks_tmptrace"
105109

106-
jax.block_until_ready(f(*args)) # warm it up!
110+
# warmup loop
111+
print(f"Running warmup loop with {warmup_tries} tries...")
112+
for _ in range(warmup_tries):
113+
data = f(*args)
114+
jax.block_until_ready(data)
107115

108116
if matrix_dim is not None:
109117
trace_name = f"{task}_dim_{matrix_dim}"

src/run_benchmark.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
279279
trace_dir = benchmark_config.get("trace_dir")
280280
xlml_metrics_dir = benchmark_config.get("xlml_metrics_dir")
281281
xla_dump_dir = benchmark_config.get("xla_dump_dir")
282+
warmup_tries = benchmark_config.get("warmup_tries")
283+
warmup_tries = warmup_tries if warmup_tries is not None else 1000
282284

283285
if not benchmark_name:
284286
raise ValueError("Each benchmark must have a 'benchmark_name'.")
@@ -299,7 +301,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
299301
test_start_time = (
300302
datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z"
301303
) # "Z" indicates UTC
302-
benchmark_results = benchmark_func(**benchmark_param)
304+
benchmark_results = benchmark_func(**benchmark_param, warmup_tries=warmup_tries)
303305
test_end_time = (
304306
datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + "Z"
305307
)
@@ -400,6 +402,8 @@ def run_benchmark_multithreaded(benchmark_config):
400402
csv_path = benchmark_config.get("csv_path")
401403
if not benchmark_name:
402404
raise ValueError("Each benchmark must have a 'benchmark_name'.")
405+
warmup_tries = benchmark_config.get("warmup_tries")
406+
warmup_tries = warmup_tries if warmup_tries is not None else 1000
403407

404408
# Get the benchmark function
405409
benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name)
@@ -427,7 +431,7 @@ def run_benchmark_multithreaded(benchmark_config):
427431
with ThreadPoolExecutor(max_workers=num_hosts) as executor:
428432
# Create a mapping of futures to their corresponding parameters
429433
future_to_param = {
430-
executor.submit(benchmark_func, **benchmark_param): benchmark_param
434+
executor.submit(benchmark_func, **benchmark_param, warmup_tries=warmup_tries): benchmark_param
431435
for benchmark_param in preprocessed_benchmark_params
432436
}
433437

0 commit comments

Comments
 (0)