 import random
 import string
 from typing import Any, Callable, Dict, List, Tuple
-from benchmark_utils import maybe_write_metrics_file, rename_xla_dump
+from benchmark_utils import maybe_write_metrics_file, rename_xla_dump, MetricsStatistics
 import jax
 import yaml
 import ray
 from concurrent.futures import ThreadPoolExecutor
 import os
 import copy
-
+import pandas as pd
+import ast
+import json
 
 COLLECTIVE_BENCHMARK_MAP = {
     "all_gather": "benchmark_collectives.all_gather_benchmark",
@@ -197,24 +199,70 @@ def generate_benchmark_params_sweeping(
 
 
 def write_to_csv(csv_path: str, calculate_metrics_results: List[Dict[str, Any]]):
-    """Write the metrics results to a CSV file."""
+    """Writes benchmark metrics to a CSV file.
+
+    This function takes a list of dictionaries, where each dictionary contains
+    the 'metadata' and 'metrics' from a benchmark run. It processes each
+    dictionary by flattening it, calculating additional statistics for specific
+    fields (like 'ici_average_time_ms_list'), and then converting it into a
+    pandas DataFrame. All resulting DataFrames are concatenated and written to
+    the specified CSV file.
+
+    Args:
+        csv_path: The path to the output CSV file.
+        calculate_metrics_results: A list of dictionaries with benchmark results.
+    """
     if not calculate_metrics_results:
         raise ValueError("0 metrics results are collected.")
     if not isinstance(calculate_metrics_results[0], dict):
         raise ValueError("metrics result is not a dict.")
-    # Open the CSV file for writing
-    with open(csv_path, mode="w", newline="") as csv_file:
-        # Use the keys from the first item as the headers
 
-        headers = calculate_metrics_results[0].keys()
+    def flatten_dict(current_dict: Dict, output_dict: Dict) -> Dict:
+        """Recursively flattens a nested dictionary."""
+        for key, val in current_dict.items():
+            if isinstance(val, dict):
+                output_dict = flatten_dict(val, output_dict)
+            else:
+                # Try to evaluate string-formatted literals (e.g., "[1, 2, 3]").
+                try:
+                    output_dict[key] = ast.literal_eval(val)
+                except (ValueError, SyntaxError, TypeError):
+                    # Not a parseable string literal: keep the value as-is.
+                    output_dict[key] = val
+        return output_dict
+
+    def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame:
+        """Converts a single benchmark result dictionary to a pandas DataFrame."""
+        flattened_dict = flatten_dict(target_dict, {})
+
+        # TODO(user): Generalize this hard-coded value if needed.
+        flattened_dict["dtype"] = "bfloat16"
+
+        # This section is specific to collective benchmarks that produce
+        # 'ici_average_time_ms_list'.
+        if "ici_average_time_ms_list" in flattened_dict:
+            # Calculate statistics for the timing list.
+            ici_average_time_ms_statistics = MetricsStatistics(
+                metrics_list=flattened_dict["ici_average_time_ms_list"],
+                metrics_name="ici_average_time_ms",
+            ).statistics
+            for key, val in ici_average_time_ms_statistics.items():
+                flattened_dict["ici_average_time_ms_" + key] = val
+
+            # Convert the list to a JSON string for CSV storage.
+            flattened_dict["ici_average_time_ms_list"] = json.dumps(
+                flattened_dict["ici_average_time_ms_list"]
+            )
+
+        df = pd.DataFrame(flattened_dict, index=[0])
+        return df
+
+    # Create a list of DataFrames and concatenate them once for efficiency.
+    df_list = [convert_dict_to_df(each) for each in calculate_metrics_results]
+    df = pd.concat(df_list, ignore_index=True)
 
-        # Initialize a DictWriter with the headers
-        writer = csv.DictWriter(csv_file, fieldnames=headers)
-        writer.writeheader()  # Write the header row
+    df.to_csv(csv_path, index=False)
 
-        # Iterate through each result and write to the CSV
-        for each in calculate_metrics_results:
-            writer.writerow(each)  # Write each row
     print(f"Metrics written to CSV at {csv_path}.")
 
 
|
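Below is a hedged usage sketch of the rewritten `write_to_csv`. The nested result shape, field names, and values are illustrative assumptions, not fixtures from this repository:

```python
# Hypothetical input; keys and values are illustrative only.
results = [
    {
        "metadata": {"benchmark_name": "all_gather", "matrix_dim": 1024},
        # String-encoded list, which flatten_dict parses via ast.literal_eval.
        "metrics": {"ici_average_time_ms_list": "[0.12, 0.11, 0.13]"},
    },
]
write_to_csv("/tmp/metrics.csv", results)
# The CSV gains one row per result, with flattened columns such as
# benchmark_name, matrix_dim, dtype, the derived ici_average_time_ms_*
# statistics, and a JSON-encoded ici_average_time_ms_list.
```

Building all per-run DataFrames first and calling `pd.concat` once keeps the write linear in the number of results, rather than repeatedly reallocating a growing DataFrame.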
|
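The new import also assumes `MetricsStatistics` in `benchmark_utils` exposes a `.statistics` mapping from statistic name to value. A minimal sketch of that assumed contract follows; the real class is not part of this diff and may compute different statistics:

```python
# Assumed contract only; the real MetricsStatistics lives in benchmark_utils
# and may differ.
import statistics as stats
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class MetricsStatistics:
    metrics_list: List[float]
    metrics_name: str
    statistics: Dict[str, float] = field(init=False)

    def __post_init__(self) -> None:
        # Short statistic names; write_to_csv prefixes them with the metric
        # name, yielding columns like ici_average_time_ms_avg.
        self.statistics = {
            "avg": stats.mean(self.metrics_list),
            "median": stats.median(self.metrics_list),
            "min": min(self.metrics_list),
            "max": max(self.metrics_list),
        }
```

Under this contract, `MetricsStatistics(metrics_list=[0.12, 0.11, 0.13], metrics_name="ici_average_time_ms").statistics` yields a plain dict ready for the key-prefixing loop in `convert_dict_to_df`.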