Skip to content

Commit 59b01b2

Browse files
committed
Refactor the write_to_csv function in src/run_benchmark.py to enhance the readability and analytical usability of the CSV output. Also, update requirements.txt with any new dependencies.
.
1 parent dcf1b63 commit 59b01b2

File tree

2 files changed

+63
-13
lines changed

2 files changed

+63
-13
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ absl-py
33
flax
44
jaxlib
55
numpy
6+
pandas
67
jsonlines
78
ray[default]
89
# For profiling

src/run_benchmark.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
import random
1414
import string
1515
from typing import Any, Callable, Dict, List, Tuple
16-
from benchmark_utils import maybe_write_metrics_file, rename_xla_dump
16+
from benchmark_utils import maybe_write_metrics_file, rename_xla_dump, MetricsStatistics
1717
import jax
1818
import yaml
1919
import ray
2020
from concurrent.futures import ThreadPoolExecutor
2121
import os
2222
import copy
23-
23+
import pandas as pd
24+
import ast
25+
import json
2426

2527
COLLECTIVE_BENCHMARK_MAP = {
2628
"all_gather": "benchmark_collectives.all_gather_benchmark",
@@ -197,24 +199,71 @@ def generate_benchmark_params_sweeping(
197199

198200

199201
def write_to_csv(csv_path: str, calculate_metrics_results: List[Dict[str, Any]]):
200-
"""Write the metrics results to a CSV file."""
202+
"""Writes benchmark metrics to a CSV file.
203+
204+
This function takes a list of dictionaries, where each dictionary contains
205+
the 'metadata' and 'metrics' from a benchmark run. It processes each
206+
dictionary by flattening it, calculating additional statistics for specific
207+
fields (like 'ici_average_time_ms_list'), and then converting it into a
208+
pandas DataFrame. All resulting DataFrames are concatenated and written to
209+
the specified CSV file.
210+
211+
Args:
212+
csv_path: The path to the output CSV file.
213+
calculate_metrics_results: A list of dictionaries with benchmark results.
214+
"""
201215
if not calculate_metrics_results:
202216
raise ValueError("0 metrics results are collected.")
203217
if not isinstance(calculate_metrics_results[0], dict):
204218
raise ValueError("metrics result is not a dict.")
205-
# Open the CSV file for writing
206-
with open(csv_path, mode="w", newline="") as csv_file:
207-
# Use the keys from the first item as the headers
208219

209-
headers = calculate_metrics_results[0].keys()
220+
def flatten_dict(current_dict: Dict, output_dict: Dict) -> Dict:
221+
"""Recursively flattens a nested dictionary."""
222+
for key, val in current_dict.items():
223+
if isinstance(val, Dict):
224+
output_dict = flatten_dict(val, output_dict)
225+
else:
226+
# Try to evaluate string-formatted literals (e.g., "[1, 2, 3]")
227+
try:
228+
output_dict[key] = ast.literal_eval(val)
229+
except (ValueError, SyntaxError, TypeError):
230+
# If it's not a valid literal, keep it as a string.
231+
output_dict[key] = val
232+
return output_dict
233+
234+
def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame:
235+
"""Converts a single benchmark result dictionary to a pandas DataFrame."""
236+
flattened_dict = flatten_dict(target_dict, {})
237+
238+
# TODO(user): Generalize this hard-coded value if needed.
239+
flattened_dict["dtype"] = "bfloat16"
240+
241+
# This section is specific to collective benchmarks that produce
242+
# 'ici_average_time_ms_list'.
243+
if "ici_average_time_ms_list" in flattened_dict:
244+
# Calculate statistics for the timing list.
245+
ici_average_time_ms_statistics = MetricsStatistics(
246+
metrics_list=flattened_dict["ici_average_time_ms_list"],
247+
metrics_name="ici_average_time_ms",
248+
).statistics
249+
for key, val in ici_average_time_ms_statistics.items():
250+
flattened_dict["ici_average_time_ms_" + key] = val
251+
252+
253+
# Convert list to JSON string for CSV storage.
254+
flattened_dict["ici_average_time_ms_list"] = json.dumps(
255+
flattened_dict["ici_average_time_ms_list"]
256+
)
257+
258+
df = pd.DataFrame(flattened_dict, index=[0])
259+
return df
260+
261+
# Create a list of DataFrames and concatenate them once for efficiency.
262+
df_list = [convert_dict_to_df(each) for each in calculate_metrics_results]
263+
df = pd.concat(df_list, ignore_index=True)
210264

211-
# Initialize a DictWriter with the headers
212-
writer = csv.DictWriter(csv_file, fieldnames=headers)
213-
writer.writeheader() # Write the header row
265+
df.to_csv(csv_path, index=False)
214266

215-
# Iterate through each result and write to the CSV
216-
for each in calculate_metrics_results:
217-
writer.writerow(each) # Write each row
218267
print(f"Metrics written to CSV at {csv_path}.")
219268

220269

0 commit comments

Comments
 (0)