@@ -435,105 +435,158 @@ async def _assign_func_results(
             for future in futures:
                 _set_exception_if_not_done(future, e)

+    def _split_batch_by_model_id(
+        self, batch: List[_SingleRequest]
+    ) -> List[List[_SingleRequest]]:
+        """Split a batch into sub-batches based on multiplexed_model_id.
+
+        When using model multiplexing with batching, requests for different
+        models may end up in the same batch. This method splits such a batch
+        so that each sub-batch contains requests for only a single model.
+
+        If no requests have a multiplexed_model_id set, returns the original batch
+        as a single sub-batch.
+
+        Args:
+            batch: The batch of requests to split.
+
+        Returns:
+            A list of sub-batches, where each sub-batch contains requests for the
+            same multiplexed_model_id.
+        """
+        # Group requests by their multiplexed_model_id.
+        model_id_to_requests: Dict[str, List[_SingleRequest]] = {}
+        for request in batch:
+            model_id = request.request_context.multiplexed_model_id
+            if model_id not in model_id_to_requests:
+                model_id_to_requests[model_id] = []
+            model_id_to_requests[model_id].append(request)
+
+        # Return sub-batches for each model_id.
+        return list(model_id_to_requests.values())
+
     async def _process_batches(self, func: Callable) -> None:
         """Loops infinitely and processes queued request batches."""
         # When asyncio task is created, the task will inherit the request context from the current context.
         # So we unset the request context so the current context is not inherited by the task, _process_batch.
         serve.context._unset_request_context()
         while not self._loop.is_closed():
-            batch, computed_batch_size = await self.wait_for_batch()
-            promise = self._process_batch(func, batch, computed_batch_size)
+            batch, _ = await self.wait_for_batch()
+
+            # Split the batch by multiplexed_model_id so that requests for
+            # different models are processed in separate batches. This is
+            # needed when model multiplexing is combined with batching, since
+            # a single batch must not mix requests intended for different models.
+            sub_batches = self._split_batch_by_model_id(batch)
+
+            # Process all sub-batches together under a single semaphore permit.
+            # This ensures sub-batches from the same original batch run concurrently
+            # rather than being serialized by the semaphore.
+            promise = self._process_sub_batches(func, sub_batches)
             task = asyncio.create_task(promise)
             self.tasks.add(task)
             self.curr_iteration_start_times[task] = time.time()
             task.add_done_callback(self._handle_completed_task)

-    async def _process_batch(
-        self, func: Callable, batch: List[_SingleRequest], computed_batch_size: int
+    async def _process_sub_batches(
+        self, func: Callable, sub_batches: List[List[_SingleRequest]]
     ) -> None:
-        """Processes queued request batch."""
+        """Processes multiple sub-batches concurrently under a single semaphore permit.
+
+        This method acquires the semaphore once and then processes all sub-batches
+        in parallel, ensuring that sub-batches from the same original batch don't
+        compete for semaphore permits.
+        """
         # NOTE: this semaphore caps the number of concurrent batches specified by `max_concurrent_batches`
         async with self.semaphore:
-            # Remove requests that have been cancelled from the batch. If
-            # all requests have been cancelled, simply return and wait for
-            # the next batch.
-            original_batch_len = len(batch)
-            batch = [req for req in batch if not req.future.cancelled()]
-            if len(batch) == 0:
-                return
-
-            # Record batch utilization metric.
-            # Use computed_batch_size from wait_for_batch for efficiency.
-            # If requests were cancelled, we need to recompute since the batch changed.
-            if len(batch) != original_batch_len:
-                computed_batch_size = self._compute_batch_size(batch)
-
-            # Calculate and record batch utilization percentage.
-            batch_utilization_percent = (
-                computed_batch_size / self.max_batch_size
-            ) * 100
-            self._batch_utilization_histogram.observe(
-                batch_utilization_percent, tags={"function_name": self._function_name}
-            )
+            # Create tasks for each sub-batch. We use asyncio.create_task() instead
+            # of passing coroutines directly to asyncio.gather() because create_task
+            # copies the current context, giving each sub-batch its own isolated
+            # contextvars. This prevents concurrent sub-batches from overwriting
+            # each other's _serve_batch_request_context, which would cause
+            # get_multiplexed_model_id() to return wrong values.
+            tasks = [
+                asyncio.create_task(self._process_batch_inner(func, sub_batch))
+                for sub_batch in sub_batches
+            ]
+            await asyncio.gather(*tasks)
+
+    async def _process_batch_inner(
+        self, func: Callable, batch: List[_SingleRequest]
+    ) -> None:
+        """Processes a single batch without acquiring the semaphore.

-            # Record actual batch size (number of requests in the batch computed by the batch_size_fn).
-            self._batch_size_histogram.observe(
-                computed_batch_size, tags={"function_name": self._function_name}
-            )
+        This is the inner implementation called by _process_sub_batches after
+        the semaphore has already been acquired.
+        """
+        # Remove requests that have been cancelled from the batch. If
+        # all requests have been cancelled, simply return and wait for
+        # the next batch.
+        batch = [req for req in batch if not req.future.cancelled()]
+        if len(batch) == 0:
+            return

-            # Increment batches processed counter.
-            self._batches_processed_counter.inc(
-                tags={"function_name": self._function_name}
-            )
+        # Compute batch size for this sub-batch. Each sub-batch may have a different
+        # size, especially when splitting by model_id, so we compute it here.
+        computed_batch_size = self._compute_batch_size(batch)

-            futures = [item.future for item in batch]
+        # Calculate and record batch utilization percentage.
+        batch_utilization_percent = (computed_batch_size / self.max_batch_size) * 100
+        self._batch_utilization_histogram.observe(
+            batch_utilization_percent, tags={"function_name": self._function_name}
+        )

-            # Most of the logic in the function should be wrapped in this try-
-            # except block, so the futures' exceptions can be set if an exception
-            # occurs. Otherwise, the futures' requests may hang indefinitely.
-            batch_execution_start_time = time.time()
-            try:
-                self_arg = batch[0].self_arg
-                args, kwargs = _batch_args_kwargs(
-                    [item.flattened_args for item in batch]
-                )
+        # Record actual batch size (number of requests in the batch computed by the batch_size_fn).
+        self._batch_size_histogram.observe(
+            computed_batch_size, tags={"function_name": self._function_name}
+        )

-                # Method call.
-                if self_arg is not None:
-                    func_future_or_generator = func(self_arg, *args, **kwargs)
-                # Normal function call.
-                else:
-                    func_future_or_generator = func(*args, **kwargs)
+        # Increment batches processed counter.
+        self._batches_processed_counter.inc(tags={"function_name": self._function_name})

-                # Add individual request context to the batch request context
-                serve.context._set_batch_request_context(
-                    [req.request_context for req in batch]
-                )
+        futures = [item.future for item in batch]

-                if isasyncgenfunction(func):
-                    func_generator = func_future_or_generator
-                    await self._consume_func_generator(
-                        func_generator, futures, len(batch)
-                    )
-                else:
-                    func_future = func_future_or_generator
-                    await self._assign_func_results(func_future, futures, len(batch))
-
-                # Reset the batch request context after the batch is processed
-                serve.context._set_batch_request_context([])
-            except Exception as e:
-                logger.exception("_process_batch ran into an unexpected exception.")
-
-                for future in futures:
-                    _set_exception_if_not_done(future, e)
-            finally:
-                # Record batch execution time.
-                batch_execution_time_ms = (
-                    time.time() - batch_execution_start_time
-                ) * 1000
-                self._batch_execution_time_histogram.observe(
-                    batch_execution_time_ms, tags={"function_name": self._function_name}
-                )
+        # Most of the logic in the function should be wrapped in this try-
+        # except block, so the futures' exceptions can be set if an exception
+        # occurs. Otherwise, the futures' requests may hang indefinitely.
+        batch_execution_start_time = time.time()
+        try:
+            self_arg = batch[0].self_arg
+            args, kwargs = _batch_args_kwargs([item.flattened_args for item in batch])
+
+            # Method call.
+            if self_arg is not None:
+                func_future_or_generator = func(self_arg, *args, **kwargs)
+            # Normal function call.
+            else:
+                func_future_or_generator = func(*args, **kwargs)
+
+            # Add individual request context to the batch request context
+            serve.context._set_batch_request_context(
+                [req.request_context for req in batch]
+            )
+
+            if isasyncgenfunction(func):
+                func_generator = func_future_or_generator
+                await self._consume_func_generator(func_generator, futures, len(batch))
+            else:
+                func_future = func_future_or_generator
+                await self._assign_func_results(func_future, futures, len(batch))
+
+            # Reset the batch request context after the batch is processed
+            serve.context._set_batch_request_context([])
+        except Exception as e:
+            logger.exception("_process_batch ran into an unexpected exception.")
+
+            for future in futures:
+                _set_exception_if_not_done(future, e)
+        finally:
+            # Record batch execution time.
+            batch_execution_time_ms = (time.time() - batch_execution_start_time) * 1000
+            self._batch_execution_time_histogram.observe(
+                batch_execution_time_ms, tags={"function_name": self._function_name}
+            )

     def _handle_completed_task(self, task: asyncio.Task) -> None:
         self.tasks.remove(task)
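
A minimal sketch of the grouping behavior behind _split_batch_by_model_id, using hypothetical FakeContext/FakeRequest stand-ins for Serve's internal request types: requests that share a multiplexed_model_id land in the same sub-batch, and requests with no model ID set all share the empty-string group.

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeContext:
    # Hypothetical stand-in for the per-request context; "" means no model ID set.
    multiplexed_model_id: str = ""


@dataclass
class FakeRequest:
    # Hypothetical stand-in for Serve's internal _SingleRequest.
    request_context: FakeContext


def split_batch_by_model_id(batch: List[FakeRequest]) -> List[List[FakeRequest]]:
    # Same grouping idea as the method in the diff: bucket requests by model ID.
    model_id_to_requests: Dict[str, List[FakeRequest]] = {}
    for request in batch:
        model_id = request.request_context.multiplexed_model_id
        model_id_to_requests.setdefault(model_id, []).append(request)
    return list(model_id_to_requests.values())


batch = [
    FakeRequest(FakeContext("model_a")),
    FakeRequest(FakeContext("model_b")),
    FakeRequest(FakeContext("model_a")),
    FakeRequest(FakeContext()),  # no multiplexed_model_id set
]
sub_batches = split_batch_by_model_id(batch)
# Three sub-batches: ["model_a", "model_a"], ["model_b"], [""]
print([[r.request_context.multiplexed_model_id for r in sub] for sub in sub_batches])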
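
The comment in _process_sub_batches relies on asyncio.create_task() snapshotting the caller's contextvars context, so concurrent sub-batches cannot clobber each other's request context. A self-contained sketch of that behavior, using a hypothetical model_id_var in place of Serve's internal _serve_batch_request_context:

import asyncio
import contextvars

# Hypothetical stand-in for Serve's internal batch request context variable.
model_id_var = contextvars.ContextVar("model_id", default="")


async def process_sub_batch(model_id: str) -> None:
    # Each task sets the variable in its own copy of the context ...
    model_id_var.set(model_id)
    await asyncio.sleep(0.01)  # yield so the other task runs and sets its own value
    # ... so the value read back is still the one this task set.
    assert model_id_var.get() == model_id


async def main() -> None:
    # create_task() copies the current context for each task, which is what
    # keeps concurrent sub-batches from overwriting each other's context.
    tasks = [
        asyncio.create_task(process_sub_batch("model_a")),
        asyncio.create_task(process_sub_batch("model_b")),
    ]
    await asyncio.gather(*tasks)
    print("each sub-batch kept its own model ID")


asyncio.run(main())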
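
Why the semaphore is acquired once per original batch rather than once per sub-batch: with max_concurrent_batches set to 1, per-sub-batch permits would serialize the sub-batches, while a single permit lets them overlap. A rough sketch of that pattern with plain asyncio, where the 0.1 s sleep stands in for running the wrapped function:

import asyncio
import time

# One permit, analogous to max_concurrent_batches=1.
semaphore = asyncio.Semaphore(1)


async def process_sub_batch(name: str) -> None:
    await asyncio.sleep(0.1)  # stands in for the wrapped function call
    print(f"finished {name}")


async def process_sub_batches(sub_batches: list) -> None:
    # Acquire the permit once for the whole original batch, then let the
    # sub-batches run concurrently under it.
    async with semaphore:
        await asyncio.gather(*(process_sub_batch(name) for name in sub_batches))


async def main() -> None:
    start = time.time()
    await process_sub_batches(["model_a", "model_b", "model_c"])
    # Roughly 0.1 s rather than 0.3 s, because the sub-batches overlap instead
    # of queueing for separate permits.
    print(f"elapsed: {time.time() - start:.2f}s")


asyncio.run(main())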