
Commit f8e5ae6

Better continuous batching tests (#42699)
* No more size 0 cuda graph
* Better tests for CB
* compile fix for CB test
* style
* More cleanup and cuda exclusive
* Returned to slow tests
* Change decorator order
* Restore XPU change
* Rebase fixes
1 parent 86644be commit f8e5ae6

File tree

2 files changed: +206 −298 lines


src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 3 additions & 2 deletions
@@ -66,7 +66,7 @@ def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
     interval_size = max_value // nb_intervals
     if interval_size == 0:
         return max_value
-    padded = ceil(size / interval_size) * interval_size
+    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
     return min(padded, max_value)
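For context on the hunk above, here is a minimal, self-contained sketch of the patched helper (reimplemented outside the library, with illustrative argument values). With the added guard, a size of 0 now pads up to one full interval instead of 0, which appears to be what the "No more size 0 cuda graph" bullet in the commit message refers to.

```python
from math import ceil

def pad_by_intervals(size: int, max_value: int, nb_intervals: int) -> int:
    """Pad `size` up to the next multiple of max_value // nb_intervals (patched version)."""
    interval_size = max_value // nb_intervals
    if interval_size == 0:
        return max_value
    # The `else interval_size` branch is the fix: a size of 0 now pads to one
    # full interval instead of 0, so the padded size is never empty.
    padded = ceil(size / interval_size) * interval_size if size > 0 else interval_size
    return min(padded, max_value)

# Quick check of the edge case addressed by the commit (hypothetical values):
assert pad_by_intervals(0, max_value=256, nb_intervals=8) == 32    # was 0 before the patch
assert pad_by_intervals(33, max_value=256, nb_intervals=8) == 64   # rounds up to the next interval
assert pad_by_intervals(300, max_value=256, nb_intervals=8) == 256 # capped at max_value
```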

@@ -713,6 +713,7 @@ def _process_logit(self, batch_data: dict, logits: torch.Tensor, logit_processor
         # Handle shape compatibility: logit processors expect 2D tensors [batch_size, vocab_size]
         # but continuous batching always produces 3D tensors [batch_size, seq_len, vocab_size]
         batch_size, seq_len, vocab_size = logits.shape
+        # NOTE: to be an exact match with generate, we should also convert logits2d to float32 here, but it's not needed in practice
         logits_2d = logits.view(batch_size * seq_len, vocab_size)
         input_ids_2d = batch_data["input_ids"].view(batch_size * seq_len)
         # Process with 2D tensors
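The NOTE added in this hunk concerns the reshape that bridges continuous batching's 3D logits and the 2D interface that logit processors expect. A hedged sketch of that pattern follows, using made-up shapes and a standard TemperatureLogitsWarper as a stand-in for whatever processors the real call site applies; it is not the library's code path.

```python
import torch
from transformers import TemperatureLogitsWarper

# Illustrative shapes: 2 sequences, 3 positions each, vocabulary of 5 tokens.
batch_size, seq_len, vocab_size = 2, 3, 5
logits = torch.randn(batch_size, seq_len, vocab_size)           # 3D, as produced by continuous batching
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Logit processors work on [batch, vocab], so fold batch and sequence dims together.
# (The diff's NOTE points out that generate() would also cast to float32 here.)
logits_2d = logits.view(batch_size * seq_len, vocab_size)
input_ids_2d = input_ids.view(batch_size * seq_len)

# Apply a processor row-wise; temperature warping ignores input_ids, so the flat shape is fine here.
processed_2d = TemperatureLogitsWarper(temperature=0.7)(input_ids_2d, logits_2d)

# Reshape back to 3D for the rest of the continuous-batching loop.
processed = processed_2d.view(batch_size, seq_len, vocab_size)
```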
@@ -869,7 +870,7 @@ def stop(self, block: bool = True, timeout: float | None = None) -> None:
             logger.warning("\nBatch processor was not initialized.")
         else:
             if self.batch_processor.cache.use_prefix_sharing:
-                logger.warning(
+                logger.info(
                     f"\nPrefix sharing was on. Total prefix length: {self.batch_processor.cache._total_prefix_length}"
                 )
