
Commit 352a2e0

Benchmark simplification (#42408)
* Renames
* Added the timestamps to request
* Better rename for prompt_ids
* Merged the two timing functions
* Style
* Remove the first timestamp for generate timing
* Fix nit in comment
* Re-introduce timestamps
* Now upload two versions of the results: full and summarized
* Make summarized result more summarized
* Fix wrong file name
* Dumb fix
1 parent 7094f1e commit 352a2e0

6 files changed: +178 -157 lines changed

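The central change below is that the separate continuous-batching and standard timing helpers are merged into a single time_generate, which now returns one list of token timestamps per request, normalized against the wall-clock start of generation. A minimal illustrative sketch of just that normalization step (the numeric values are hypothetical stand-ins for time.perf_counter() readings):

import torch

# Hypothetical absolute timestamps for two requests with three generated tokens each
wall_time_0 = 1000.0
timestamps = [
    [1000.1, 1000.2, 1000.3],
    [1000.1, 1000.3, 1000.5],
]

# Same normalization as the merged time_generate below: shift by the generation start time
relative = torch.tensor(timestamps).sub(wall_time_0).tolist()
print(relative)  # roughly [[0.1, 0.2, 0.3], [0.1, 0.3, 0.5]], up to float32 precision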

benchmark_v2/framework/benchmark_runner.py

Lines changed: 77 additions & 93 deletions
@@ -10,6 +10,7 @@
 from queue import Queue
 from typing import Any
 
+import numpy as np
 import torch
 from datasets import Dataset
 from huggingface_hub import HfApi
@@ -208,10 +209,11 @@ def run_benchmark(
         self.logger.info(f"Running benchmark scenario: {config.name}")
 
         # Quick validation: try one measurement first to see if this scenario works
-        generate_fn = self.time_generate_batch if config.continuous_batching else self.time_generate
         flush_memory()
-        e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
-            max_new_tokens=1, gpu_monitor=None
+        e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics = self.time_generate(
+            max_new_tokens=config.num_tokens_to_generate,
+            use_continuous_batching=config.continuous_batching,
+            gpu_monitor=None,
         )
         if e2e_latency < 0:
             self.logger.warning(f"Skipping config {config.name}: {e2e_latency = } (no GPU monitoring)")
@@ -220,18 +222,23 @@ def run_benchmark(
         # Warmup runs
         self.logger.info(f"Warming up with {config.warmup_iterations} iterations...")
         for _ in trange(config.warmup_iterations, desc="Warmup"):
-            _ = generate_fn(max_new_tokens=config.num_tokens_to_generate)
+            _ = self.time_generate(
+                max_new_tokens=config.num_tokens_to_generate,
+                use_continuous_batching=config.continuous_batching,
+                gpu_monitor=None,
+            )
         self.logger.info("Warmup over.")
 
         # Measurement runs
         result = BenchmarkResult()
         self.logger.info(f"Benchmarking with {config.measurement_iterations} iterations.")
         for _ in trange(config.measurement_iterations, desc="Benchmarking"):
-            e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics = generate_fn(
+            e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics = self.time_generate(
                 max_new_tokens=config.num_tokens_to_generate,
+                use_continuous_batching=config.continuous_batching,
                 gpu_monitor=(GPUMonitor(logger=self.logger) if config.gpu_monitoring else None),
             )
-            result.accumulate(e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics)
+            result.accumulate(e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics)
         self.logger.info("Benchmarking done. Cleaning up.")
 
         # Profile if needed
@@ -249,75 +256,50 @@ def run_benchmark(
             "config": config,
         }
 
-    # TODO: refactor `generate_batch` to handle streaming so we can use it here
-    def time_generate_batch(
-        self,
-        max_new_tokens: int,
-        gpu_monitor: GPUMonitor | None = None,
-    ) -> tuple[float, list[float], str, GPURawMetrics | None]:
-        if gpu_monitor is not None:
-            gpu_monitor.start()
-        # Prepare inputs
-        inputs = self.inputs["input_ids"].tolist()
-        timestamps = []
-        last_result_generated_tokens = None
-        wall_time_0 = time.perf_counter()
-        # We disable prefix sharing because all prompts are the same
-        with self.model.continuous_batching_context_manager(allow_prefix_sharing=False) as manager:
-            manager.add_requests(inputs, max_new_tokens=max_new_tokens, streaming=True)
-            unfinished_requests = len(inputs)
-            while unfinished_requests > 0:
-                # NOTE: I don't like having the extra if stmt here, but hopefully won't degrade perf too much
-                result = manager.get_result()
-                if result is not None:
-                    timestamps.append(time.perf_counter() - wall_time_0)  # FIXME: the timestamps are wrong
-                    if result.is_finished():
-                        last_result_generated_tokens = result.generated_tokens
-                        unfinished_requests -= 1
-                elif not manager.is_running():
-                    raise RuntimeError("Generation thread exited unexpectedly")
-        # Post-processing
-        wall_time_1 = time.perf_counter()
-        e2e_latency = wall_time_1 - wall_time_0
-        gpu_metrics = gpu_monitor.stop_and_collect() if gpu_monitor is not None else None
-        decoded_output = self.tokenizer.decode(last_result_generated_tokens, skip_special_tokens=True)
-        shape_and_decoded_output = f"{(1, len(last_result_generated_tokens))} | {decoded_output}"
-        return e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics
-
     def time_generate(
         self,
         max_new_tokens: int,
+        use_continuous_batching: bool = False,
         gpu_monitor: GPUMonitor | None = None,
     ) -> tuple[float, list[float], str, GPURawMetrics | None]:
-        """Time the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
         # Prepare gpu monitoring if needed
         if gpu_monitor is not None:
             gpu_monitor.start()
-        # Prepare streamer
-        streamer = BenchmarkStreamer()
+
         # Generate and time
-        wall_time_0 = time.perf_counter()
-        outputs = self.model.generate(
-            **self.inputs,
-            max_new_tokens=max_new_tokens,
-            streamer=streamer,
-        )
+        if use_continuous_batching:
+            inputs = self.inputs["input_ids"].tolist()
+            wall_time_0 = time.perf_counter()
+            results = self.model.generate_batch(inputs, allow_prefix_sharing=False, record_timestamps=True)
+        else:
+            streamer = BenchmarkStreamer()
+            wall_time_0 = time.perf_counter()
+            results = self.model.generate(**self.inputs, streamer=streamer)
+
         wall_time_1 = time.perf_counter()
-        # Stop gpu monitoring if needed
         gpu_metrics = gpu_monitor.stop_and_collect() if gpu_monitor is not None else None
-        # Check if generation had the right number of tokens
+
+        # Retrieve timestamps and results in a way that allows similar post-processing
         input_tokens = self.inputs["input_ids"].size(-1)
-        batch_size, output_tokens = outputs.shape
-        new_tokens = output_tokens - input_tokens
-        if new_tokens != max_new_tokens:
-            raise RuntimeError(f"Generated {new_tokens} tokens, expected {max_new_tokens}")
+        if use_continuous_batching:
+            timestamps = [result.timestamps for result in results.values()]
+            results = torch.tensor([result.generated_tokens for result in results.values()])
+        else:
+            timestamps = [streamer.timestamps[1:]]  # skip the first timestamp because it's the input tokens
+            results = results[:, input_tokens:]
+
+        # Check if generation had the right number of tokens
+        if results.size(-1) != max_new_tokens:
+            raise RuntimeError(f"Generated {results.size(-1)} tokens, expected {max_new_tokens}")
+
         # Decode outputs
-        decoded_output = self.tokenizer.decode(outputs[0, input_tokens:], skip_special_tokens=True)
-        shape_and_decoded_output = f"{tuple(outputs.shape)} | {decoded_output}"
-        # Compute intermediate quantities
+        decoded_output = self.tokenizer.decode(results[0], skip_special_tokens=True)
+        shape_and_decoded_output = f"{tuple(results.shape)} | {decoded_output}"
+
+        # Compute metrics
         e2e_latency = wall_time_1 - wall_time_0
-        token_generation_times = [t - wall_time_0 for t in streamer.timestamps[1:]]
-        return e2e_latency, token_generation_times, shape_and_decoded_output, gpu_metrics
+        timestamps = torch.tensor(timestamps).sub(wall_time_0).tolist()
+        return e2e_latency, timestamps, shape_and_decoded_output, gpu_metrics
 
     def profile_generate(self, num_tokens_to_profile: int, config_name: str) -> None:
         """Profile the latency of a call to model.generate() with the given (inputs) and (max_new_tokens)."""
@@ -431,36 +413,38 @@ def push_results_to_hub(self, dataset_id: str, results: dict[Any, Any], timestam
                 "PUSH_TO_HUB_TOKEN is not set, cannot push results to the Hub. When setting dataset_id, please also set the PUSH_TO_HUB_TOKEN environment variable."
             )
 
+        api = HfApi()
         n_results = len(results)
-        self.logger.info(f"Pushing {n_results} results to: {dataset_id}")
-        rows = []
-        for cfg_hash, entry in results.items():
-            row = {
-                "benchmark_config_hash": cfg_hash,
-                "config": entry["config"].to_dict(),
-                "measurements": entry["measurements"].to_dict(),
-                "metadata": entry["metadata"].to_dict(),
-            }
-            rows.append(row)
-
-        ds = Dataset.from_list(rows)
-        with tempfile.TemporaryDirectory() as tmp:
-            jsonl_path = os.path.join(tmp, "data.jsonl")
-            with open(jsonl_path, "w") as f:
-                json_lines = []
-                for ex in ds:
-                    json_lines.append(json.dumps(ex, ensure_ascii=False))
-                f.write("\n".join(json_lines))
-
-            api = HfApi()
-            # NOTE: we expect the repository to already exist
-            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if not timestamp else timestamp
-            file_name = f"benchmark_run_{timestamp}.jsonl"
-            api.upload_file(
-                path_or_fileobj=jsonl_path,
-                path_in_repo=file_name,
-                repo_id=dataset_id,
-                repo_type="dataset",
-                token=PUSH_TO_HUB_TOKEN,
-            )
-        self.logger.info(f"Successfully uploaded results to: {dataset_id}")
+        for summarized in [False, True]:
+            self.logger.info(f"Pushing {n_results} results to: {dataset_id} with {summarized = }")
+            rows = []
+            for cfg_hash, entry in results.items():
+                row = {
+                    "benchmark_config_hash": cfg_hash,
+                    "config": entry["config"].to_dict(),
+                    "measurements": entry["measurements"].to_dict(summarized=summarized),
+                    "metadata": entry["metadata"].to_dict(),
+                }
+                rows.append(row)
+
+            ds = Dataset.from_list(rows)
+            with tempfile.TemporaryDirectory() as tmp:
+                file_name = "summarized_results" if summarized else "full_results"
+                jsonl_path = os.path.join(tmp, f"{file_name}.jsonl")
+                with open(jsonl_path, "w") as f:
+                    json_lines = []
+                    for ex in ds:
+                        json_lines.append(json.dumps(ex, ensure_ascii=False))
+                    f.write("\n".join(json_lines))
+
+                # NOTE: we expect the repository to already exist
+                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") if not timestamp else timestamp
+                file_name = file_name + "/" + f"benchmark_run_{timestamp}.jsonl"
+                api.upload_file(
+                    path_or_fileobj=jsonl_path,
+                    path_in_repo=file_name,
+                    repo_id=dataset_id,
+                    repo_type="dataset",
+                    token=PUSH_TO_HUB_TOKEN,
+                )
+            self.logger.info(f"Successfully uploaded results to: {dataset_id} with {summarized = }")

benchmark_v2/framework/data_classes.py

Lines changed: 26 additions & 18 deletions
@@ -89,31 +89,35 @@ class BenchmarkResult:
 
     def __init__(self) -> None:
         self.e2e_latency = []
+        self._timestamps = []
         self.time_to_first_token = []
        self.inter_token_latency = []
         self.shape_and_decoded_outputs = []
         self.gpu_metrics = []
 
-    def compute_itl(self, token_generation_times: list[float]) -> list[float]:
-        return (token_generation_times[-1] - token_generation_times[0]) / len(token_generation_times)
-
     def accumulate(
         self,
         e2e_latency: float,
-        token_generation_times: list[float],
+        timestamps: list[float],
         shape_and_decoded_output: str,
         gpu_metrics: GPURawMetrics | None,
     ) -> None:
         self.e2e_latency.append(e2e_latency)
-        self.time_to_first_token.append(token_generation_times[0])
-        # inter-token latency is already an average in itself
-        self.inter_token_latency.append(self.compute_itl(token_generation_times))
+        self._timestamps.append(timestamps)
+        self._accumulate_ttft_and_itl(timestamps)
         self.shape_and_decoded_outputs.append(shape_and_decoded_output)
         self.gpu_metrics.append(gpu_metrics)
 
-    def to_dict(self) -> dict[str, None | int | float]:
-        # Save GPU metrics as None if it contains only None values
-        if all(gm is None for gm in self.gpu_metrics):
+    def _accumulate_ttft_and_itl(self, timestamps: list[float]) -> None:
+        timestamps = np.array(timestamps)
+        tftt = np.min(timestamps[:, 0])
+        itl = np.mean(timestamps[:, -1] - timestamps[:, 0]) / (timestamps.shape[1] - 1)
+        self.time_to_first_token.append(tftt)
+        self.inter_token_latency.append(itl)
+
+    def to_dict(self, summarized: bool = False) -> dict[str, Any]:
+        # Save GPU metrics as None if it contains only None values or if we are summarizing
+        if summarized or all(gm is None for gm in self.gpu_metrics):
             gpu_metrics = None
         else:
             gpu_metrics = [gm.to_dict() for gm in self.gpu_metrics]
@@ -123,6 +127,7 @@ def to_dict(self) -> dict[str, None | int | float]:
             "inter_token_latency": self.inter_token_latency,
             "shape_and_decoded_outputs": self.shape_and_decoded_outputs,
             "gpu_metrics": gpu_metrics,
+            "timestamps": None if summarized else self._timestamps,
         }
 
     @classmethod
@@ -132,16 +137,19 @@ def from_dict(cls, data: dict[str, None | int | float]) -> "BenchmarkResult":
             gpu_metrics = [None for _ in range(len(data["e2e_latency"]))]
         else:
             gpu_metrics = [GPURawMetrics.from_dict(gm) for gm in data["gpu_metrics"]]
+        # Handle timestamps, which can be saved as None to reduce file size
+        if data["timestamps"] is None:
+            timestamps = [None for _ in range(len(data["e2e_latency"]))]
+        else:
+            timestamps = data["timestamps"]
         # Create a new instance and accumulate the data
         new_instance = cls()
-        for i in range(len(data["e2e_latency"])):
-            new_instance.accumulate(
-                e2e_latency=data["e2e_latency"][i],
-                time_to_first_token=data["time_to_first_token"][i],
-                inter_token_latency=data["inter_token_latency"][i],
-                shape_and_decoded_output=data["shape_and_decoded_outputs"][i],
-                gpu_metrics=gpu_metrics[i],
-            )
+        new_instance.e2e_latency = data["e2e_latency"]
+        new_instance._timestamps = timestamps
+        new_instance.time_to_first_token = data["time_to_first_token"]
+        new_instance.inter_token_latency = data["inter_token_latency"]
+        new_instance.shape_and_decoded_outputs = data["shape_and_decoded_outputs"]
+        new_instance.gpu_metrics = gpu_metrics
         return new_instance
 
     def get_throughput(self, total_generated_tokens: int) -> list[float]:
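For reference, the metric derivation implemented by the new _accumulate_ttft_and_itl above: time-to-first-token is the earliest first-token timestamp across requests, and inter-token latency is the per-request span (last timestamp minus first) averaged over requests and divided by the number of inter-token gaps. A self-contained example with made-up numbers:

import numpy as np

# Two hypothetical requests, four generated tokens each, timestamps in seconds relative to generation start
timestamps = np.array([
    [0.10, 0.15, 0.20, 0.25],
    [0.12, 0.18, 0.24, 0.30],
])

ttft = np.min(timestamps[:, 0])                                                  # 0.10 s
itl = np.mean(timestamps[:, -1] - timestamps[:, 0]) / (timestamps.shape[1] - 1)  # mean(0.15, 0.18) / 3 = 0.055 s
print(ttft, itl)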

src/transformers/generation/continuous_batching/cache.py

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ def mark_blocks_as_complete(self, state: RequestState) -> None:
         self._block_manager.mark_blocks_as_complete(
             num_complete_blocks=num_complete_blocks,
             allocated_blocks=cm.block_table[state.request_id],
-            prompt_ids=(state.full_prompt_ids + state.static_outputs),
+            prompt_ids=(state.initial_tokens + state.generated_tokens),
         )
 
