Skip to content

Commit edeea13

Browse files
committed
Support dumping HLOs
1 parent 43d4af6 commit edeea13

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

src/hlo_helper.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import yaml
2+
import os
3+
import glob
4+
import re
5+
import shutil # For potential safer renaming with move
6+
7+
def rename_hlo_dumps(config_path, hlo_dump_dir):
8+
"""
9+
Renames various HLO dump files with their corresponding configuration strings,
10+
handling arbitrary parameters from benchmark_sweep_params.
11+
12+
Args:
13+
config_path (str): Path to the YAML configuration file (e.g., simple_matmul.yaml).
14+
hlo_dump_dir (str): Directory where HLO dumps are stored (e.g., /tmp/hlo_graphs/).
15+
"""
16+
with open(config_path, 'r') as f:
17+
config = yaml.safe_load(f)
18+
19+
benchmark_params = []
20+
for benchmark in config.get('benchmarks', []):
21+
for params in benchmark.get('benchmark_sweep_params', []):
22+
benchmark_params.append(params)
23+
24+
if not benchmark_params:
25+
print("No benchmark sweep parameters found in the config file. Exiting.")
26+
return
27+
28+
print(f"Found {len(benchmark_params)} benchmark configurations in '{config_path}'.")
29+
30+
# Define the patterns for the files you want to rename
31+
file_patterns = [
32+
'*jit_f*before_optimizations*.txt',
33+
'*jit_f*.tpu_comp_env.txt',
34+
'*jit_f*.execution_options.txt'
35+
]
36+
37+
# Process each type of file
38+
for pattern in file_patterns:
39+
full_pattern = os.path.join(hlo_dump_dir, pattern)
40+
hlo_files = sorted(
41+
glob.glob(full_pattern),
42+
key=os.path.getmtime
43+
)
44+
45+
if not hlo_files:
46+
print(f"\nNo files found matching pattern: '{full_pattern}'. Skipping this type.")
47+
continue
48+
49+
print(f"\nProcessing {len(hlo_files)} files matching pattern: '{full_pattern}'")
50+
51+
if len(hlo_files) != len(benchmark_params):
52+
print(f" Warning: Number of files ({len(hlo_files)}) does not match number of configurations ({len(benchmark_params)}).")
53+
print(" This script assumes a one-to-one, ordered correspondence. Please verify the output carefully.")
54+
55+
for i, hlo_file_path in enumerate(hlo_files):
56+
if i < len(benchmark_params):
57+
params = benchmark_params[i]
58+
59+
# Dynamically create the config string from all parameters
60+
# Example: {'a': 10, 'b': 20} -> "a10_b20"
61+
config_parts = []
62+
for key, value in params.items():
63+
config_parts.append(f"{key}{value}")
64+
config_string = "_".join(config_parts)
65+
66+
if not config_string: # Handle cases where a param set might be empty
67+
print(f" Warning: Empty parameter set at index {i}. Skipping file '{os.path.basename(hlo_file_path)}'.")
68+
continue
69+
70+
directory, base_filename = os.path.split(hlo_file_path)
71+
72+
# Prepend the config string to the original filename
73+
new_filename = f"{config_string}_{base_filename}"
74+
new_file_path = os.path.join(directory, new_filename)
75+
76+
try:
77+
shutil.move(hlo_file_path, new_file_path)
78+
print(f" Renamed: '{base_filename}' -> '{new_filename}'")
79+
except Exception as e:
80+
print(f" Error renaming '{hlo_file_path}': {e}")
81+
else:
82+
print(f" Warning: No matching configuration for HLO file: '{os.path.basename(hlo_file_path)}'. Skipping.")
83+
84+
85+
86+
def rename_xla_dump(xla_dump_dir, benchmark_name, benchmark_param):
87+
"""
88+
Finds the latest XLA dump file matching '*jit_f*before_optimizations*.txt',
89+
then identifies all other files that share the same 'jit_f.[unique_id]' identifier
90+
and renames them to 'benchmark_name_serialized_params.original_suffix_with_extension'.
91+
"""
92+
93+
serialized_benchmark_param = str(benchmark_param)
94+
anchor_pattern = os.path.join(xla_dump_dir, '*jit_f*before_optimizations*.txt')
95+
matching_anchor_files = glob.glob(anchor_pattern)
96+
97+
if not matching_anchor_files:
98+
print(f"No files found for anchor pattern: '{anchor_pattern}'. No files will be renamed.")
99+
return
100+
101+
# Sort anchor files by modification time (latest first)
102+
matching_anchor_files.sort(key=os.path.getmtime, reverse=True)
103+
latest_anchor_file = matching_anchor_files[0]
104+
print(f"Latest anchor file found: '{latest_anchor_file}' (Modified: {datetime.fromtimestamp(os.path.getmtime(latest_anchor_file))})")
105+
106+
# Extract the common 'jit_f.[unique_id]' part from the anchor file.
107+
# This regex captures from 'jit_f.' up to the next '.' (before the specific suffix like '.before_optimizations')
108+
# Example: 'module_0080.jit_f.cl_747713181.before_optimizations.txt'
109+
# This will extract 'jit_f.cl_747713181'
110+
filename_base = os.path.basename(latest_anchor_file)
111+
jit_id_match = re.search(r'(jit_f\.[^.]+)', filename_base)
112+
113+
if not jit_id_match:
114+
print(f"Could not extract 'jit_f.[unique_id]' from '{filename_base}'. Cannot proceed with renaming.")
115+
return
116+
117+
common_jit_id_prefix = jit_id_match.group(1) # e.g., 'jit_f.cl_747713181'
118+
print(f"Extracted common JIT ID prefix for family: '{common_jit_id_prefix}'")
119+
120+
# Find all files in the directory that contain this specific common_jit_id_prefix
121+
# We are looking for files like 'module_XXX.jit_f.ID.suffix.txt'
122+
all_related_files_pattern = os.path.join(xla_dump_dir, f'*{common_jit_id_prefix}*')
123+
all_related_files = glob.glob(all_related_files_pattern)
124+
125+
if not all_related_files:
126+
print(f"No files found containing '{common_jit_id_prefix}'. This is unexpected if an anchor was found.")
127+
return
128+
129+
new_base_name = f"{benchmark_name}_{serialized_benchmark_param}"
130+
131+
print(f"\n--- Renaming files belonging to the '{common_jit_id_prefix}' family ---")
132+
for original_filepath in all_related_files:
133+
original_filename = os.path.basename(original_filepath)
134+
135+
# Find the specific suffix part *after* the common_jit_id_prefix.
136+
# This regex looks for the common_jit_id_prefix, then captures everything after it,
137+
# ensuring it starts with a dot if there's more.
138+
# Example: if original_filename is 'module_0080.jit_f.cl_747713181.after_codegen.txt'
139+
# and common_jit_id_prefix is 'jit_f.cl_747713181'
140+
# we want to capture '.after_codegen.txt'
141+
suffix_match = re.search(re.escape(common_jit_id_prefix) + r'(\..*)', original_filename)
142+
143+
if suffix_match:
144+
original_suffix_with_extension = suffix_match.group(1) # e.g., '.after_codegen.txt'
145+
else:
146+
print("shouldn't get here")
147+
148+
new_filename = f"{new_base_name}{original_suffix_with_extension}"
149+
new_filepath = os.path.join(xla_dump_dir, new_filename)
150+
151+
if original_filepath == new_filepath:
152+
print(f"Skipping: '{original_filename}' already has the desired name or path.")
153+
continue
154+
155+
try:
156+
os.rename(original_filepath, new_filepath)
157+
print(f"Renamed '{original_filename}' to '{new_filename}'")
158+
except OSError as e:
159+
print(f"Error renaming file '{original_filepath}' to '{new_filepath}': {e}")
160+
except Exception as e:
161+
print(f"An unexpected error occurred while renaming '{original_filepath}': {e}")

src/run_benchmark.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
import string
1515
from typing import Any, Callable, Dict, List, Tuple
1616
from benchmark_utils import maybe_write_metrics_file
17+
from hlo_helper import rename_xla_dump
1718
import jax
1819
import yaml
1920
import ray
2021
from concurrent.futures import ThreadPoolExecutor
22+
import os
2123

2224
COLLECTIVE_BENCHMARK_MAP = {
2325
"all_gather": "benchmark_collectives.all_gather_benchmark",
@@ -222,6 +224,10 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
222224
csv_path = benchmark_config.get("csv_path")
223225
trace_dir = benchmark_config.get("trace_dir")
224226
xlml_metrics_dir = benchmark_config.get("xlml_metrics_dir")
227+
xla_dump_dir = benchmark_config.get("xla_dump_dir")
228+
229+
if xla_dump_dir:
230+
os.environ["XLA_FLAGS"] = f"--xla_dump_to={xla_dump_dir}"
225231

226232
if not benchmark_name:
227233
raise ValueError("Each benchmark must have a 'benchmark_name'.")
@@ -274,6 +280,10 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
274280
test_start_time,
275281
test_end_time,
276282
)
283+
# Post process the xla dump
284+
if xla_dump_dir:
285+
rename_xla_dump(xla_dump_dir=xla_dump_dir, benchmark_name=benchmark_name, benchmark_param=filtered_benchmark_param)
286+
277287

278288
# Dump metrics to file.
279289
if csv_path:
@@ -282,6 +292,12 @@ def run_single_benchmark(benchmark_config: Dict[str, Any]):
282292
)
283293
write_to_csv(f"{csv_path}/{test_name}.csv", calculate_metrics_results)
284294

295+
if not xla_dump_dir:
296+
os.environ["XLA_FLAGS"] = ""
297+
298+
299+
300+
285301

286302
def main(config_path: str, multithreaded: bool):
287303
"""Main function."""

0 commit comments

Comments
 (0)