@@ -87,6 +87,7 @@ def benchmark_model(
     max_batch_size: int = 32,
     max_tokens: int = 128,
     max_model_len: int = 8192,
+    track_latency: bool = False,
 ) -> dict:
     """Benchmark a model with given requests."""
     print(f"\n{'=' * 70}")
@@ -130,12 +131,19 @@ def benchmark_model(
     print(f"\nRunning benchmark ({len(requests)} requests)...")
     start_time = time.time()
 
-    # Track per-request timing by generating in small batches
-    batch_size = min(max_batch_size, 8)  # Use smaller batches for timing granularity
+    # Determine batch size based on latency tracking
+    if track_latency:
+        # Use smaller batches for timing granularity (reduces throughput)
+        batch_size = min(max_batch_size, 8)
+        print(f" (Latency tracking enabled: batches of {batch_size})")
+    else:
+        # Use full batch size for maximum throughput
+        batch_size = max_batch_size
+        print(f" (Using batches of {batch_size} for max throughput)")
+
     outputs = []
     request_latencies = []
 
-    print(f" (Generating in batches of {batch_size} for latency tracking...)")
     for i in range(0, len(requests), batch_size):
         batch = requests[i : i + batch_size]
         batch_start = time.time()
@@ -145,10 +153,13 @@ def benchmark_model(
 
         outputs.extend(batch_outputs)
 
-        # Estimate per-request latency as batch_time / batch_size
-        # This is an approximation but better than nothing for batched generation
-        for _ in batch:
-            request_latencies.append(batch_time / len(batch))
+        # Track per-request latency only if enabled
+        if track_latency:
+            # Estimate per-request latency as batch_time / batch_size
+            # This is an approximation but better than nothing for batched
+            # generation
+            for _ in batch:
+                request_latencies.append(batch_time / len(batch))
 
     end_time = time.time()
     total_time = end_time - start_time
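(Note outside the patch: a minimal sketch of what the estimate above implies, assuming NumPy is available. The `p90_latency`/`p99_latency` keys match the result fields printed further down; `p50_latency` is assumed by analogy, and the helper name `summarize_latencies` is illustrative only.)

import numpy as np

def summarize_latencies(request_latencies: list[float]) -> dict:
    """Illustrative only: turn the per-request estimates into percentiles."""
    if not request_latencies:
        return {}
    lat = np.asarray(request_latencies)
    return {
        "p50_latency": float(np.percentile(lat, 50)),
        "p90_latency": float(np.percentile(lat, 90)),
        "p99_latency": float(np.percentile(lat, 99)),
    }

# Because each request in a batch is credited batch_time / len(batch), these
# percentiles reflect batch-level timing rather than true per-request latency;
# that is the trade-off --track-latency accepts for finer granularity.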
@@ -204,7 +215,7 @@ def benchmark_model(
         print(f"P90 Latency: {results['p90_latency'] * 1000:.2f} ms")
         print(f"P99 Latency: {results['p99_latency'] * 1000:.2f} ms")
     else:
-        print("P50/P90/P99: N/A (use --num-requests ≤20 for per-request latency)")
+        print("P50/P90/P99: N/A (use --track-latency for percentiles)")
     print(f"{'=' * 70}")
 
     return results
@@ -259,6 +270,14 @@ def main():
         action="store_true",
         help="Skip built-in model benchmark (only run custom)",
     )
+    parser.add_argument(
+        "--track-latency",
+        action="store_true",
+        help=(
+            "Enable fine-grained latency tracking (P50/P90/P99). "
+            "Uses smaller batches which reduces throughput."
+        ),
+    )
 
     args = parser.parse_args()
 
@@ -272,11 +291,11 @@ def main():
     print("DeepSeek V3 Benchmark: TorchTitan vs Built-in")
     print(f"{'#' * 70}")
     if run_custom and run_builtin:
-        print("Mode: Comparing Custom vs Built-in")
+        print("Mode: Comparing TorchTitan vs Built-in")
     elif run_custom:
-        print("Mode: Custom model only")
+        print("Mode: TorchTitan only")
     elif run_builtin:
-        print("Mode: Built-in model only")
+        print("Mode: Built-in only")
 
     # Import custom model if needed
     if run_custom:
@@ -294,7 +313,7 @@ def main():
     if run_custom:
         try:
             print("\n" + "=" * 70)
-            print("CUSTOM MODEL (TorchTitan + vLLM MLA)")
+            print("TorchTitan DeepSeek (Custom Implementation)")
             print("=" * 70)
             results["custom"] = benchmark_model(
                 model_name=args.model,
@@ -303,6 +322,7 @@ def main():
                 max_batch_size=args.max_batch_size,
                 max_tokens=args.max_tokens,
                 max_model_len=args.max_model_len,
+                track_latency=args.track_latency,
             )
         except Exception as e:
             print(f"\n❌ Custom model benchmark failed: {e}")
@@ -314,7 +334,7 @@ def main():
     if run_builtin:
         try:
             print("\n" + "=" * 70)
-            print("BUILT-IN vLLM MODEL")
+            print("Built-in DeepSeek (vLLM Native Implementation)")
             print("=" * 70)
             results["builtin"] = benchmark_model(
                 model_name=args.model,
@@ -323,6 +343,7 @@ def main():
                 max_batch_size=args.max_batch_size,
                 max_tokens=args.max_tokens,
                 max_model_len=args.max_model_len,
+                track_latency=args.track_latency,
             )
         except Exception as e:
             print(f"\n❌ Built-in model benchmark failed: {e}")
@@ -340,7 +361,9 @@ def main():
         builtin = results.get("builtin", {})
 
         if custom and builtin:
-            print(f"\n{'Metric':<25} {'Custom':<20} {'Built-in':<20} {'Speedup':<15}")
+            print(
+                f"\n{'Metric':<25} {'TorchTitan':<20} {'Built-in':<20} {'Speedup':<15}"
+            )
             print("-" * 80)
 
             metrics = [
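(Usage note, not part of the patch: the script's file name does not appear in this excerpt, so `benchmark_deepseek.py` below is a placeholder; the flags shown are the ones defined above, and the request count is illustrative.)

# Throughput-focused run: full batches, percentiles reported as N/A
python benchmark_deepseek.py --num-requests 100

# Latency-focused run: batches capped at min(max_batch_size, 8), P50/P90/P99 reported
python benchmark_deepseek.py --num-requests 100 --track-latency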