@@ -20,28 +20,28 @@
 
 import argparse
 import logging
-import random
 import sys
 import uuid
 
-from framework.benchmark_config import BenchmarkConfig, generate_all_configs
+from framework.benchmark_config import BenchmarkConfig, generate_all_configs, generate_main_configs
 from framework.benchmark_runner import BenchmarkRunner
 
 
 if __name__ == "__main__":
     # Parse arguments
     parser = argparse.ArgumentParser()
-    parser.add_argument("--output-dir", type=str, default="benchmark_results", help="Output dir for benchmark results")
+    parser.add_argument("--output-dir", type=str, default=None, help="Output dir for benchmark results")
     parser.add_argument("--log-level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR"], default="INFO")
     parser.add_argument("--model-id", type=str, help="Specific model ID to benchmark (if supported by benchmarks)")
 
-    parser.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations")
-    parser.add_argument("--iterations", type=int, default=20, help="Number of measurement iterations")
+    parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations")
+    parser.add_argument("--iterations", type=int, default=10, help="Number of measurement iterations")
 
     parser.add_argument("--batch-size", "-b", type=int, nargs="+", help="Batch size")
     parser.add_argument("--sequence-length", "-s", type=int, nargs="+", help="Sequence length")
     parser.add_argument("--num-tokens-to-generate", "-n", type=int, nargs="+", help="Number of tokens to generate")
 
+    parser.add_argument("--cross-generate", action="store_true", help="Cross-generate all combinations of configs")
     parser.add_argument("--num-tokens-to-profile", "-p", type=int, default=0, help="Number of tokens to profile")
 
     parser.add_argument("--commit-id", type=str, help="Git commit ID (if not provided, will auto-detect from git)")
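A minimal sketch (not part of the diff) of how the reworked CLI parses, using only the flags and defaults shown in the hunk above; the sample argv is invented for illustration.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--output-dir", type=str, default=None)
parser.add_argument("--warmup", type=int, default=3)
parser.add_argument("--iterations", type=int, default=10)
parser.add_argument("--batch-size", "-b", type=int, nargs="+")
parser.add_argument("--sequence-length", "-s", type=int, nargs="+")
parser.add_argument("--num-tokens-to-generate", "-n", type=int, nargs="+")
parser.add_argument("--cross-generate", action="store_true")

# A single (batch_size, sequence_length, num_tokens_to_generate) triple without
# --cross-generate takes the generate_main_configs branch in the next hunk.
args = parser.parse_args(["-b", "1", "-s", "128", "-n", "64"])
print(args.cross_generate, args.warmup, args.iterations)  # False 3 10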
@@ -69,42 +69,47 @@
 
     # If there is only one (batch_size, sequence_length, num_tokens_to_generate), we benchmark across configs
     elif len(args.batch_size) * len(args.sequence_length) * len(args.num_tokens_to_generate) == 1:
-        benchmark_configs = generate_all_configs(
+        if args.cross_generate:
+            benchmark_configs = generate_all_configs(
+                warmup_iterations=args.warmup,
+                measurement_iterations=args.iterations,
+                batch_size=args.batch_size[0],
+                sequence_length=args.sequence_length[0],
+                num_tokens_to_generate=args.num_tokens_to_generate[0],
+            )
+        else:
+            benchmark_configs = generate_main_configs(
+                warmup_iterations=args.warmup,
+                measurement_iterations=args.iterations,
+                batch_size=args.batch_size[0],
+                sequence_length=args.sequence_length[0],
+                num_tokens_to_generate=args.num_tokens_to_generate[0],
+            )
+
+    # Otherwise, we benchmark across all combinations of dimensions
+    else:
+        main_config = generate_main_configs(
             warmup_iterations=args.warmup,
             measurement_iterations=args.iterations,
             batch_size=args.batch_size[0],
             sequence_length=args.sequence_length[0],
             num_tokens_to_generate=args.num_tokens_to_generate[0],
-        )
-        random.shuffle(benchmark_configs)
-
-    # Otherwise, we benchmark across all combinations of dimensions
-    else:
-        kwargs = {
-            "warmup_iterations": args.warmup,
-            "measurement_iterations": args.iterations,
-            "gpu_monitoring": False,
-            "batch_size": args.batch_size[0],
-            "sequence_length": args.sequence_length[0],
-            "num_tokens_to_generate": args.num_tokens_to_generate[0],
-            "attn_implementation": "flex_attention",
-            "sdpa_backend": None,
-            "compile_mode": "default",
-            "kernelize": False,
-        }
+        )[0]
         benchmark_configs = []
         for num_tokens_to_generate in args.num_tokens_to_generate:
            for sequence_length in args.sequence_length:
                for batch_size in args.batch_size:
-                    kwargs["batch_size"] = batch_size
-                    kwargs["sequence_length"] = sequence_length
-                    kwargs["num_tokens_to_generate"] = num_tokens_to_generate
-                    benchmark_configs.append(BenchmarkConfig(**kwargs))
+                    cfg_dict = main_config.to_dict()
+                    cfg_dict["batch_size"] = batch_size
+                    cfg_dict["sequence_length"] = sequence_length
+                    cfg_dict["num_tokens_to_generate"] = num_tokens_to_generate
+                    cfg_dict.pop("name")
+                    benchmark_configs.append(BenchmarkConfig.from_dict(cfg_dict))
 
     runner = BenchmarkRunner(logger, args.output_dir, args.commit_id)
     results = runner.run_benchmarks(
         args.model_id,
-        benchmark_configs[:3],
+        benchmark_configs,
         args.num_tokens_to_profile,
         pretty_print_summary=True,
     )
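The per-combination loop in the last branch clones the main config through a dict round trip. Below is a self-contained sketch of that pattern using a hypothetical StubConfig stand-in; the real to_dict/from_dict/name behaviour lives in framework.benchmark_config and may differ.

from dataclasses import dataclass, asdict
from itertools import product

@dataclass
class StubConfig:
    # Hypothetical stand-in for BenchmarkConfig, only to illustrate the pattern.
    name: str
    warmup_iterations: int
    measurement_iterations: int
    batch_size: int
    sequence_length: int
    num_tokens_to_generate: int

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        # When "name" was popped by the caller, derive a placeholder one here.
        return cls(name=d.pop("name", "generated"), **d)

main_config = StubConfig("main", 3, 10, 1, 128, 64)

benchmark_configs = []
for batch_size, sequence_length, num_tokens_to_generate in product([1, 2], [128, 256], [64]):
    cfg_dict = main_config.to_dict()
    cfg_dict["batch_size"] = batch_size
    cfg_dict["sequence_length"] = sequence_length
    cfg_dict["num_tokens_to_generate"] = num_tokens_to_generate
    cfg_dict.pop("name")  # let from_dict assign a fresh name, as the diff does
    benchmark_configs.append(StubConfig.from_dict(cfg_dict))

print(len(benchmark_configs))  # 4 configs, one per (batch_size, sequence_length) pair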