Skip to content

Commit 5791cf6

Browse files
committed
Update the benchmarks to enable XLA dumping.
1 parent 4ca21ec commit 5791cf6

File tree

4 files changed, with 17 additions and 9 deletions.

src/benchmark_attention.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -156,20 +156,20 @@ def pallas_flash_attention_benchmark(
156156
"""Benchmarks the Pallas flash attention kernel."""
157157

158158
@partial(jax.jit, static_argnames=["causal"])
159-
def pallas_attention(q, k, v, causal):
159+
def f(q, k, v, causal):
160160
return pallas_flash_attention.mha_reference(
161161
q, k, v, ab=None, segment_ids=None, causal=causal
162162
)
163163

164164
# Generate QKV.
165165
q, k, v = generate_qkv(batch, seq_len, d_model, num_heads)
166166
# Run once
167-
output = pallas_attention(q, k, v, causal)
167+
output = f(q, k, v, causal)
168168
jax.block_until_ready(output)
169169

170170
# Run benchmark
171171
time_ms_list = simple_timeit(
172-
pallas_attention,
172+
f,
173173
q,
174174
k,
175175
v,

src/benchmark_convolution.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ def f(x, kernel, mode):
6060

6161
# Time the operation
6262
time_ms_list = simple_timeit(
63-
jitted_f,
63+
f,
6464
x,
6565
kernel,
6666
padding_mode,

src/benchmark_utils.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -246,12 +246,10 @@ def rename_xla_dump(
246246
matching_anchor_files.sort(key=os.path.getmtime, reverse=True)
247247
latest_anchor_file = matching_anchor_files[0]
248248

249-
# Extract the common 'jit_f.[unique_id]' part from the anchor file.
250-
# This regex captures from 'jit_f.' up to the next '.' (before the specific suffix like '.before_optimizations')
251249
# Example: 'module_0080.jit_f.cl_747713181.before_optimizations.txt'
252-
# This will extract 'jit_f.cl_747713181'
250+
# This will extract 'module_0080.jit_f.cl_747713181'
253251
filename_base = os.path.basename(latest_anchor_file)
254-
jit_id_match = re.search(r"(jit_f\.[^.]+)", filename_base)
252+
jit_id_match = re.search(r"(module.*jit_f\.[^.]+)", filename_base)
255253

256254
if not jit_id_match:
257255
print(

src/run_benchmark.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
import ray
2020
from concurrent.futures import ThreadPoolExecutor
2121
import os
22+
import copy
23+
2224

2325
COLLECTIVE_BENCHMARK_MAP = {
2426
"all_gather": "benchmark_collectives.all_gather_benchmark",
@@ -240,6 +242,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
240242
# Run the benchmark
241243
calculate_metrics_results = []
242244
for benchmark_param in benchmark_params:
245+
original_benchmark_param = copy.deepcopy(benchmark_param)
243246
benchmark_param = preprocess_benchmark_param(
244247
benchmark_param, trace_dir=trace_dir
245248
)
@@ -286,7 +289,7 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
286289
tmp_xla_dump_dir=TMP_XLA_DUMP_DIR,
287290
dest_xla_dump_dir=xla_dump_dir,
288291
benchmark_name=benchmark_name,
289-
benchmark_param=filtered_benchmark_param,
292+
benchmark_param=original_benchmark_param,
290293
)
291294

292295
# Dump metrics to file.
@@ -305,6 +308,13 @@ def main(config_path: str, multithreaded: bool):
305308
if not benchmarks or not isinstance(benchmarks, list):
306309
raise ValueError("Configuration must contain a 'benchmarks' list.")
307310

311+
# Clear the tmp dirs.
312+
if os.path.exists(TMP_XLA_DUMP_DIR):
313+
for filename in os.listdir(TMP_XLA_DUMP_DIR):
314+
file_path = os.path.join(TMP_XLA_DUMP_DIR, filename)
315+
if os.path.isfile(file_path):
316+
os.remove(file_path)
317+
308318
if multithreaded:
309319
ray.init(
310320
runtime_env=ray.runtime_env.RuntimeEnv(

0 commit comments

Comments
 (0)