2929from vllm .config import VllmConfig
3030from vllm .distributed import destroy_distributed_environment , destroy_model_parallel
3131from vllm .distributed .device_communicators .shm_broadcast import Handle , MessageQueue
32+ from vllm .distributed .kv_transfer .kv_connector .utils import KVOutputAggregator
3233from vllm .distributed .parallel_state import (
3334 get_dp_group ,
3435 get_ep_group ,
5758
5859
5960class FutureWrapper (Future ):
60- def __init__ (self , futures_queue : deque [tuple ["FutureWrapper" , Callable ]]):
61+ def __init__ (
62+ self ,
63+ futures_queue : deque [tuple ["FutureWrapper" , Callable ]],
64+ aggregate : Callable = lambda x : x ,
65+ ):
6166 self .futures_queue = futures_queue
67+ self .aggregate = aggregate
6268 super ().__init__ ()
6369
6470 def result (self , timeout = None ):
@@ -72,7 +78,7 @@ def result(self, timeout=None):
7278
7379 def wait_for_response (self , get_response : Callable ):
7480 try :
75- response = get_response ()
81+ response = self . aggregate ( get_response () )
7682 with suppress (InvalidStateError ):
7783 self .set_result (response )
7884 except Exception as e :
@@ -160,7 +166,6 @@ def _init_executor(self) -> None:
160166 self .futures_queue = deque [tuple [FutureWrapper , Callable ]]()
161167
162168 self .output_rank = self ._get_output_rank ()
163- self .has_connector = self .vllm_config .kv_transfer_config is not None
164169
165170 def start_worker_monitor (self ):
166171 workers = self .workers
@@ -199,44 +204,27 @@ def register_failure_callback(self, callback: FailureCallback):
199204 def execute_model ( # type: ignore[override]
200205 self , scheduler_output : SchedulerOutput , non_block : bool = False
201206 ) -> ModelRunnerOutput | None | Future [ModelRunnerOutput | None ]:
202- return self ._execute_with_aggregation (
203- "execute_model" , scheduler_output , non_block = non_block
207+ return self .collective_rpc (
208+ "execute_model" ,
209+ args = (scheduler_output ,),
210+ unique_reply_rank = self .output_rank ,
211+ non_block = non_block ,
212+ timeout = envs .VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS ,
213+ kv_output_aggregator = self .kv_output_aggregator ,
204214 )
205215
206216 def sample_tokens ( # type: ignore[override]
207217 self , grammar_output : GrammarOutput | None , non_block : bool = False
208218 ) -> ModelRunnerOutput | Future [ModelRunnerOutput ]:
209- return self ._execute_with_aggregation ( # type: ignore[return-value]
210- "sample_tokens" , grammar_output , non_block = non_block
211- )
212-
213- def _execute_with_aggregation (
214- self , method : str , * args , non_block : bool = False
215- ) -> ModelRunnerOutput | None | Future [ModelRunnerOutput | None ]:
216- if not self .has_connector :
217- # get output only from a single worker (output_rank)
218- return self .collective_rpc (
219- method ,
220- args = args ,
221- unique_reply_rank = self .output_rank ,
222- non_block = non_block ,
223- timeout = envs .VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS ,
224- )
225-
226- # get output from all workers
227- outputs = self .collective_rpc (
228- method ,
229- args = args ,
219+ return self .collective_rpc (
220+ "sample_tokens" ,
221+ args = (grammar_output ,),
222+ unique_reply_rank = self .output_rank ,
230223 non_block = non_block ,
231224 timeout = envs .VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS ,
225+ kv_output_aggregator = self .kv_output_aggregator ,
232226 )
233227
234- # aggregate all workers output to a single output
235- assert self .kv_output_aggregator is not None
236- if non_block :
237- return self .kv_output_aggregator .async_aggregate (outputs , self .output_rank )
238- return self .kv_output_aggregator .aggregate (outputs , self .output_rank )
239-
240228 def execute_dummy_batch (self ) -> None :
241229 self .collective_rpc ("execute_dummy_batch" , unique_reply_rank = self .output_rank )
242230
@@ -254,29 +242,34 @@ def collective_rpc( # type: ignore[override]
254242 kwargs : dict | None = None ,
255243 non_block : bool = False ,
256244 unique_reply_rank : int | None = None ,
245+ kv_output_aggregator : KVOutputAggregator = None ,
257246 ) -> Any | list [Any ] | Future [Any | list [Any ]]:
258- """Returns single result if unique_reply_rank is provided, otherwise list."""
247+ """Returns single result if unique_reply_rank and/or kv_output_aggregator
248+ is provided, otherwise list."""
259249
260250 if self .is_failed :
261251 raise RuntimeError ("Executor failed." )
262252
263253 deadline = None if timeout is None else time .monotonic () + timeout
264254 kwargs = kwargs or {}
265255
266- # NOTE: If the args are heterogeneous, then we pack them into a list,
267- # and unpack them in the method of every worker, because every worker
268- # knows their own rank.
256+ if kv_output_aggregator is not None :
257+ output_rank = None
258+ aggregate : Callable [[Any ], Any ] = partial (
259+ kv_output_aggregator .aggregate , output_rank = unique_reply_rank or 0
260+ )
261+ else :
262+ output_rank = unique_reply_rank
263+ aggregate = lambda x : x
269264
270265 if isinstance (method , str ):
271266 send_method = method
272267 else :
273268 send_method = cloudpickle .dumps (method , protocol = pickle .HIGHEST_PROTOCOL )
274- self .rpc_broadcast_mq .enqueue ((send_method , args , kwargs , unique_reply_rank ))
269+ self .rpc_broadcast_mq .enqueue ((send_method , args , kwargs , output_rank ))
275270
276271 workers = (
277- (self .workers [unique_reply_rank ],)
278- if unique_reply_rank is not None
279- else self .workers
272+ (self .workers [output_rank ],) if output_rank is not None else self .workers
280273 )
281274
282275 shutdown_event = self .shutdown_event
@@ -299,10 +292,10 @@ def get_response():
299292 " stack trace above for the root cause"
300293 )
301294 responses .append (result )
302- return responses [0 ] if unique_reply_rank is not None else responses
295+ return responses [0 ] if output_rank is not None else responses
303296
304297 if non_block :
305- future = FutureWrapper (self .futures_queue )
298+ future = FutureWrapper (self .futures_queue , aggregate = aggregate )
306299 self .futures_queue .appendleft ((future , get_response ))
307300 return future
308301
@@ -311,7 +304,7 @@ def get_response():
311304 future , get_fut_response = self .futures_queue .pop ()
312305 future .wait_for_response (get_fut_response )
313306
314- return get_response ()
307+ return aggregate ( get_response () )
315308
316309 @staticmethod
317310 def _ensure_worker_termination (worker_procs : list [BaseProcess ]):
0 commit comments