
Commit 1599fb7

[Serve] Make batching work with multiplexing (#59334)
fixes #56633

- [x] Add documentation
- [x] Update `get_multiplexed_model_id` to check whether we are in a batch context first
- [x] Update logic
- [x] Add tests
- [x] Does not introduce any backwards incompatibility: previously the system did not provide any guarantee about the contents of a batch, and now we add a constraint that guarantees each batch contains requests for the same model
- [x] Execute sub-batches concurrently

The thing I dislike about this implementation is that it does not fill the batch when the replica is responsible for more than two models and incoming traffic is distributed equally between those models, because the current implementation fills the batch first and then divides it.

Metric | Baseline (42905 reqs) | Master (27526 reqs) | Δ Change (Master − Baseline)
-- | -- | -- | --
Requests | 42,905 | 27,526 | −15,379
Fails | 0 | 0 | 0
Median (ms) | 290 | 300 | +10 ms
95%ile (ms) | 560 | 570 | +10 ms
99%ile (ms) | 620 | 640 | +20 ms
Average (ms) | 327.41 | 332.96 | +5.55 ms
Min (ms) | 61 | 80 | +19 ms
Max (ms) | 764 | 802 | +38 ms
Avg Size (bytes) | 13 | 13 | 0
Current RPS | 299 | 293 | −6
Current Failures/s | 0 | 0 | 0

---------

Signed-off-by: abrar <abrar@anyscale.com>
1 parent 5334493 commit 1599fb7
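The approach described above — fill a batch, split it by model ID, then run the per-model sub-batches concurrently — and the under-filling trade-off it mentions can be illustrated with a minimal standalone sketch. The names and data shapes below are illustrative only; the actual change is in `python/ray/serve/batching.py` further down.

```python
# Minimal, self-contained sketch of the fill-then-split behavior described
# above. All names and data shapes here are illustrative, not Serve internals.
import asyncio
from collections import defaultdict
from typing import Dict, List, Tuple

Request = Tuple[str, str]  # (model_id, payload)


def split_by_model_id(batch: List[Request]) -> List[List[Request]]:
    """Group requests into sub-batches that each target a single model."""
    groups: Dict[str, List[Request]] = defaultdict(list)
    for model_id, payload in batch:
        groups[model_id].append((model_id, payload))
    return list(groups.values())


async def handle_sub_batch(sub_batch: List[Request]) -> None:
    model_id = sub_batch[0][0]  # every request in a sub-batch shares this ID
    print(f"model={model_id} sub-batch size={len(sub_batch)}")


async def main() -> None:
    # A "full" batch of 9 requests spread evenly across 3 models: after the
    # split, each sub-batch holds only a third of the original batch, which
    # is the under-filling trade-off noted in the commit message.
    batch = [(f"model_{i % 3}", f"payload_{i}") for i in range(9)]
    # Run the per-model sub-batches concurrently, mirroring the approach of
    # dispatching them together rather than one after another.
    await asyncio.gather(*(handle_sub_batch(sb) for sb in split_by_model_id(batch)))


asyncio.run(main())
```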

File tree

5 files changed: +426 -78 lines changed

doc/source/serve/doc_code/multiplexed.py

Lines changed: 31 additions & 0 deletions
@@ -66,3 +66,34 @@ async def __call__(self, request: starlette.requests.Request):
 serve.run(Upstream.bind(Downstream.bind()))
 resp = requests.get("http://localhost:8000")
 # __serve_model_composition_example_end__
+
+
+# __serve_multiplexed_batching_example_begin__
+from typing import List  # noqa: E402
+from starlette.requests import Request
+
+
+@serve.deployment(max_ongoing_requests=15)
+class BatchedMultiplexModel:
+    @serve.multiplexed(max_num_models_per_replica=3)
+    async def get_model(self, model_id: str):
+        # Load and return your model here
+        return model_id
+
+    @serve.batch(max_batch_size=10, batch_wait_timeout_s=0.1)
+    async def batched_predict(self, inputs: List[str]) -> List[str]:
+        # Get the model ID - this works correctly inside batched functions
+        # because all requests in the batch target the same model
+        model_id = serve.get_multiplexed_model_id()
+        model = await self.get_model(model_id)
+
+        # Process the batch with the loaded model
+        return [f"{model}:{inp}" for inp in inputs]
+
+    async def __call__(self, request: Request):
+        # Extract input from the request body
+        input_text = await request.body()
+        return await self.batched_predict(input_text.decode())
+
+
+# __serve_multiplexed_batching_example_end__
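For reference, a client would target a specific model by setting the `serve_multiplexed_model_id` request header, which is the convention used in the existing multiplexing docs. The snippet below is a usage sketch and is not part of the diff:

```python
# Usage sketch (not part of this commit): clients select a model through the
# "serve_multiplexed_model_id" header; concurrent requests carrying the same
# model ID can then be grouped into a single batch by the deployment above.
import requests

for model_id in ["model_a", "model_a", "model_b"]:
    resp = requests.post(
        "http://localhost:8000",
        data=b"hello",
        headers={"serve_multiplexed_model_id": model_id},
    )
    print(resp.text)  # expected shape: "<model_id>:hello"
```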

doc/source/serve/model-multiplexing.md

Lines changed: 16 additions & 0 deletions
@@ -84,3 +84,19 @@ When using model composition, you can send requests from an upstream deployment
 :start-after: __serve_model_composition_example_begin__
 :end-before: __serve_model_composition_example_end__
 ```
+
+## Using model multiplexing with batching
+
+You can combine model multiplexing with the `@serve.batch` decorator for efficient batched inference. When you use both features together, Ray Serve automatically splits batches by model ID to ensure each batch contains only requests for the same model. This prevents issues where a single batch would contain requests targeting different models.
+
+The following example shows how to combine multiplexing with batching:
+
+```{literalinclude} doc_code/multiplexed.py
+:language: python
+:start-after: __serve_multiplexed_batching_example_begin__
+:end-before: __serve_multiplexed_batching_example_end__
+```
+
+:::{note}
+`serve.get_multiplexed_model_id()` works correctly inside functions decorated with `@serve.batch`. Ray Serve guarantees that all requests in a batch have the same `multiplexed_model_id`, so you can safely use this value to load and apply the appropriate model for the entire batch.
+:::

python/ray/serve/api.py

Lines changed: 13 additions & 0 deletions
@@ -890,6 +890,11 @@ def get_multiplexed_model_id() -> str:
     This is used with a function decorated with `@serve.multiplexed`
     to retrieve the model ID for the current request.
 
+    When called from within a batched function (decorated with `@serve.batch`),
+    this returns the multiplexed model ID that is common to all requests in
+    the current batch. This works because batches are automatically split
+    by model ID to ensure all requests in a batch target the same model.
+
     .. code-block:: python
 
         import ray
@@ -911,6 +916,14 @@ def get_multiplexed_model_id() -> str:
         def my_deployment_function(request):
             assert serve.get_multiplexed_model_id() == "model_1"
     """
+    # First check if we're inside a batch context. If so, get the model ID
+    # from the batch request context. All requests in a batch are guaranteed
+    # to have the same multiplexed_model_id (batches are split by model ID).
+    batch_request_context = ray.serve.context._get_serve_batch_request_context()
+    if batch_request_context:
+        return batch_request_context[0].multiplexed_model_id
+
+    # Fall back to the regular request context
     _request_context = ray.serve.context._get_serve_request_context()
     return _request_context.multiplexed_model_id
 
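The lookup order added here — batch context first, then the per-request context — is a plain fallback pattern over context variables. A generic, self-contained sketch with hypothetical names (not the actual Serve context helpers):

```python
# Generic sketch of the fallback pattern above: prefer the batch-level context
# when it is populated, otherwise fall back to the single-request context.
# All names here are hypothetical stand-ins, not Serve internals.
from contextvars import ContextVar
from dataclasses import dataclass
from typing import List


@dataclass
class _Ctx:
    multiplexed_model_id: str = ""


_batch_ctx: ContextVar[List[_Ctx]] = ContextVar("batch_ctx", default=[])
_request_ctx: ContextVar[_Ctx] = ContextVar("request_ctx", default=_Ctx())


def current_model_id() -> str:
    batch = _batch_ctx.get()
    if batch:
        # Safe because every entry in a batch shares the same model ID.
        return batch[0].multiplexed_model_id
    return _request_ctx.get().multiplexed_model_id


_request_ctx.set(_Ctx("model_1"))
assert current_model_id() == "model_1"  # no batch context: per-request value
_batch_ctx.set([_Ctx("model_2"), _Ctx("model_2")])
assert current_model_id() == "model_2"  # batch context takes precedence
```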

python/ray/serve/batching.py

Lines changed: 131 additions & 78 deletions
@@ -435,105 +435,158 @@ async def _assign_func_results(
             for future in futures:
                 _set_exception_if_not_done(future, e)
 
+    def _split_batch_by_model_id(
+        self, batch: List[_SingleRequest]
+    ) -> List[List[_SingleRequest]]:
+        """Split a batch into sub-batches based on multiplexed_model_id.
+
+        When using model multiplexing with batching, requests for different models
+        may end up in the same batch. This method ensures that each sub-batch only
+        contains requests for the same model, preventing issues where a single batch
+        contains requests for different models.
+
+        If no requests have a multiplexed_model_id set, returns the original batch
+        as a single sub-batch.
+
+        Args:
+            batch: The batch of requests to split.
+
+        Returns:
+            A list of sub-batches, where each sub-batch contains requests for the
+            same multiplexed_model_id.
+        """
+        # Group requests by their multiplexed_model_id
+        model_id_to_requests: Dict[str, List[_SingleRequest]] = {}
+        for request in batch:
+            model_id = request.request_context.multiplexed_model_id
+            if model_id not in model_id_to_requests:
+                model_id_to_requests[model_id] = []
+            model_id_to_requests[model_id].append(request)
+
+        # Return sub-batches for each model_id
+        return list(model_id_to_requests.values())
+
     async def _process_batches(self, func: Callable) -> None:
         """Loops infinitely and processes queued request batches."""
         # When asyncio task is created, the task will inherit the request context from the current context.
         # So we unset the request context so the current context is not inherited by the task, _process_batch.
         serve.context._unset_request_context()
         while not self._loop.is_closed():
-            batch, computed_batch_size = await self.wait_for_batch()
-            promise = self._process_batch(func, batch, computed_batch_size)
+            batch, _ = await self.wait_for_batch()
+
+            # Split batch by multiplexed_model_id to ensure requests for different
+            # models are processed in separate batches. This is necessary when using
+            # model multiplexing with batching, as a single batch containing requests
+            # for different models would not work correctly.
+            sub_batches = self._split_batch_by_model_id(batch)
+
+            # Process all sub-batches together under a single semaphore permit.
+            # This ensures sub-batches from the same original batch run concurrently
+            # rather than being serialized by the semaphore.
+            promise = self._process_sub_batches(func, sub_batches)
             task = asyncio.create_task(promise)
             self.tasks.add(task)
             self.curr_iteration_start_times[task] = time.time()
             task.add_done_callback(self._handle_completed_task)
 
-    async def _process_batch(
-        self, func: Callable, batch: List[_SingleRequest], computed_batch_size: int
+    async def _process_sub_batches(
+        self, func: Callable, sub_batches: List[List[_SingleRequest]]
     ) -> None:
-        """Processes queued request batch."""
+        """Processes multiple sub-batches concurrently under a single semaphore permit.
+
+        This method acquires the semaphore once and then processes all sub-batches
+        in parallel, ensuring that sub-batches from the same original batch don't
+        compete for semaphore permits.
+        """
         # NOTE: this semaphore caps the number of concurrent batches specified by `max_concurrent_batches`
         async with self.semaphore:
-            # Remove requests that have been cancelled from the batch. If
-            # all requests have been cancelled, simply return and wait for
-            # the next batch.
-            original_batch_len = len(batch)
-            batch = [req for req in batch if not req.future.cancelled()]
-            if len(batch) == 0:
-                return
-
-            # Record batch utilization metric.
-            # Use computed_batch_size from wait_for_batch for efficiency.
-            # If requests were cancelled, we need to recompute since the batch changed.
-            if len(batch) != original_batch_len:
-                computed_batch_size = self._compute_batch_size(batch)
-
-            # Calculate and record batch utilization percentage.
-            batch_utilization_percent = (
-                computed_batch_size / self.max_batch_size
-            ) * 100
-            self._batch_utilization_histogram.observe(
-                batch_utilization_percent, tags={"function_name": self._function_name}
-            )
+            # Create tasks for each sub-batch. We use asyncio.create_task() instead
+            # of passing coroutines directly to asyncio.gather() because create_task
+            # copies the current context, giving each sub-batch its own isolated
+            # contextvars. This prevents concurrent sub-batches from overwriting
+            # each other's _serve_batch_request_context, which would cause
+            # get_multiplexed_model_id() to return wrong values.
+            tasks = [
+                asyncio.create_task(self._process_batch_inner(func, sub_batch))
+                for sub_batch in sub_batches
+            ]
+            await asyncio.gather(*tasks)
+
+    async def _process_batch_inner(
+        self, func: Callable, batch: List[_SingleRequest]
+    ) -> None:
+        """Processes a single batch without acquiring the semaphore.
 
-            # Record actual batch size (number of requests in the batch computed by the batch_size_fn).
-            self._batch_size_histogram.observe(
-                computed_batch_size, tags={"function_name": self._function_name}
-            )
+        This is the inner implementation called by _process_sub_batches after
+        the semaphore has already been acquired.
+        """
+        # Remove requests that have been cancelled from the batch. If
+        # all requests have been cancelled, simply return and wait for
+        # the next batch.
+        batch = [req for req in batch if not req.future.cancelled()]
+        if len(batch) == 0:
+            return
 
-            # Increment batches processed counter.
-            self._batches_processed_counter.inc(
-                tags={"function_name": self._function_name}
-            )
+        # Compute batch size for this sub-batch. Each sub-batch may have a different
+        # size, especially when splitting by model_id, so we compute it here.
+        computed_batch_size = self._compute_batch_size(batch)
 
-            futures = [item.future for item in batch]
+        # Calculate and record batch utilization percentage.
+        batch_utilization_percent = (computed_batch_size / self.max_batch_size) * 100
+        self._batch_utilization_histogram.observe(
+            batch_utilization_percent, tags={"function_name": self._function_name}
+        )
 
-            # Most of the logic in the function should be wrapped in this try-
-            # except block, so the futures' exceptions can be set if an exception
-            # occurs. Otherwise, the futures' requests may hang indefinitely.
-            batch_execution_start_time = time.time()
-            try:
-                self_arg = batch[0].self_arg
-                args, kwargs = _batch_args_kwargs(
-                    [item.flattened_args for item in batch]
-                )
+        # Record actual batch size (number of requests in the batch computed by the batch_size_fn).
+        self._batch_size_histogram.observe(
+            computed_batch_size, tags={"function_name": self._function_name}
+        )
 
-                # Method call.
-                if self_arg is not None:
-                    func_future_or_generator = func(self_arg, *args, **kwargs)
-                # Normal function call.
-                else:
-                    func_future_or_generator = func(*args, **kwargs)
+        # Increment batches processed counter.
+        self._batches_processed_counter.inc(tags={"function_name": self._function_name})
 
-                # Add individual request context to the batch request context
-                serve.context._set_batch_request_context(
-                    [req.request_context for req in batch]
-                )
+        futures = [item.future for item in batch]
 
-                if isasyncgenfunction(func):
-                    func_generator = func_future_or_generator
-                    await self._consume_func_generator(
-                        func_generator, futures, len(batch)
-                    )
-                else:
-                    func_future = func_future_or_generator
-                    await self._assign_func_results(func_future, futures, len(batch))
-
-                # Reset the batch request context after the batch is processed
-                serve.context._set_batch_request_context([])
-            except Exception as e:
-                logger.exception("_process_batch ran into an unexpected exception.")
-
-                for future in futures:
-                    _set_exception_if_not_done(future, e)
-            finally:
-                # Record batch execution time.
-                batch_execution_time_ms = (
-                    time.time() - batch_execution_start_time
-                ) * 1000
-                self._batch_execution_time_histogram.observe(
-                    batch_execution_time_ms, tags={"function_name": self._function_name}
-                )
+        # Most of the logic in the function should be wrapped in this try-
+        # except block, so the futures' exceptions can be set if an exception
+        # occurs. Otherwise, the futures' requests may hang indefinitely.
+        batch_execution_start_time = time.time()
+        try:
+            self_arg = batch[0].self_arg
+            args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
+
+            # Method call.
+            if self_arg is not None:
+                func_future_or_generator = func(self_arg, *args, **kwargs)
+            # Normal function call.
+            else:
+                func_future_or_generator = func(*args, **kwargs)
+
+            # Add individual request context to the batch request context
+            serve.context._set_batch_request_context(
+                [req.request_context for req in batch]
+            )
+
+            if isasyncgenfunction(func):
+                func_generator = func_future_or_generator
+                await self._consume_func_generator(func_generator, futures, len(batch))
+            else:
+                func_future = func_future_or_generator
+                await self._assign_func_results(func_future, futures, len(batch))
+
+            # Reset the batch request context after the batch is processed
+            serve.context._set_batch_request_context([])
+        except Exception as e:
+            logger.exception("_process_batch ran into an unexpected exception.")
+
+            for future in futures:
+                _set_exception_if_not_done(future, e)
+        finally:
+            # Record batch execution time.
+            batch_execution_time_ms = (time.time() - batch_execution_start_time) * 1000
+            self._batch_execution_time_histogram.observe(
+                batch_execution_time_ms, tags={"function_name": self._function_name}
+            )
 
     def _handle_completed_task(self, task: asyncio.Task) -> None:
         self.tasks.remove(task)
0 commit comments

Comments
 (0)