Skip to content

Commit 289eb6c

Browse files
authored
[Core] Simplify async KV output aggregation (vllm-project#28327)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 19d91ec commit 289eb6c

File tree

4 files changed

+45
-153
lines changed

4 files changed

+45
-153
lines changed

tests/v1/executor/test_executor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import pytest
1111

12+
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
1213
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
1314
from vllm.sampling_params import SamplingParams
1415
from vllm.v1.engine.async_llm import AsyncLLM
@@ -28,12 +29,19 @@ def collective_rpc(
2829
kwargs: dict | None = None,
2930
non_block: bool = False,
3031
unique_reply_rank: int | None = None,
32+
kv_output_aggregator: KVOutputAggregator = None,
3133
) -> Any | list[Any] | Future[Any | list[Any]]:
3234
# Drop marker to show that this was run
3335
with open(".marker", "w"):
3436
...
3537
return super().collective_rpc(
36-
method, timeout, args, kwargs, non_block, unique_reply_rank
38+
method,
39+
timeout,
40+
args,
41+
kwargs,
42+
non_block,
43+
unique_reply_rank,
44+
kv_output_aggregator,
3745
)
3846

3947

tests/v1/kv_connector/unit/test_output_aggregator.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
from concurrent.futures import Future
43

54
import pytest
65

@@ -86,74 +85,6 @@ def test_aggregate_workers_output():
8685
assert aggregated.invalid_block_ids == {3, 4, 5}
8786

8887

89-
def test_async_aggregate_workers_output():
90-
aggregator = KVOutputAggregator(expected_finished_count=2)
91-
92-
future: Future[list[DummyModelRunnerOutput]] = Future()
93-
result_future = aggregator.async_aggregate(future)
94-
95-
output1 = DummyModelRunnerOutput()
96-
output2 = DummyModelRunnerOutput()
97-
future.set_result([output1, output2])
98-
99-
assert result_future.done()
100-
aggregated = result_future.result()
101-
assert aggregated is output1
102-
aggregated = aggregated.kv_connector_output
103-
assert aggregated.finished_sending is None
104-
assert aggregated.finished_recving is None
105-
assert not aggregated.invalid_block_ids
106-
107-
future = Future()
108-
result_future = aggregator.async_aggregate(future)
109-
110-
output1 = DummyModelRunnerOutput(
111-
finished_sending={"req1"}, finished_recving={"req2"}
112-
)
113-
output2 = DummyModelRunnerOutput(invalid_block_ids={1})
114-
future.set_result([output1, output2])
115-
116-
assert result_future.done()
117-
aggregated = result_future.result()
118-
assert aggregated is output1
119-
aggregated = aggregated.kv_connector_output
120-
assert aggregated.finished_sending is None
121-
assert aggregated.finished_recving is None
122-
assert aggregated.invalid_block_ids == {1}
123-
124-
future = Future()
125-
result_future = aggregator.async_aggregate(future)
126-
127-
output1 = DummyModelRunnerOutput(invalid_block_ids={2})
128-
output2 = DummyModelRunnerOutput(finished_sending={"req1"})
129-
future.set_result([output1, output2])
130-
131-
assert result_future.done()
132-
aggregated = result_future.result()
133-
assert aggregated is output1
134-
aggregated = aggregated.kv_connector_output
135-
assert aggregated.finished_sending == {"req1"}
136-
assert aggregated.finished_recving is None
137-
assert aggregated.invalid_block_ids == {2}
138-
139-
future = Future()
140-
result_future = aggregator.async_aggregate(future)
141-
142-
output1 = DummyModelRunnerOutput(invalid_block_ids={3, 4})
143-
output2 = DummyModelRunnerOutput(
144-
finished_recving={"req2"}, invalid_block_ids={4, 5}
145-
)
146-
future.set_result([output1, output2])
147-
148-
assert result_future.done()
149-
aggregated = result_future.result()
150-
assert aggregated is output1
151-
aggregated = aggregated.kv_connector_output
152-
assert aggregated.finished_sending is None
153-
assert aggregated.finished_recving == {"req2"}
154-
assert aggregated.invalid_block_ids == {3, 4, 5}
155-
156-
15788
def test_aggregate_workers_output_with_expected_finished_count():
15889
# We create the aggregator expecting to collect from 4 workers
15990
aggregator = KVOutputAggregator(expected_finished_count=4)

vllm/distributed/kv_transfer/kv_connector/utils.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
KV cache helper for store.
55
"""
66

7-
import contextlib
8-
from collections.abc import Sequence
9-
from concurrent.futures import CancelledError, Future
107
from typing import TYPE_CHECKING, Literal
118

129
import torch
@@ -220,43 +217,6 @@ def update_finished_set(
220217

221218
return output
222219

223-
def async_aggregate(
224-
self,
225-
output_future: Future[Sequence[ModelRunnerOutput | None]],
226-
output_rank: int = 0,
227-
) -> Future[ModelRunnerOutput | None]:
228-
"""Takes a future that resolves to a list of outputs and returns a future
229-
which resolves to a single aggregated output."""
230-
result_future: Future[ModelRunnerOutput | None] = Future()
231-
232-
def callback(fut):
233-
if result_future.done():
234-
return
235-
try:
236-
result_future.set_result(self.aggregate(fut.result(), output_rank))
237-
except CancelledError:
238-
result_future.cancel()
239-
except Exception as e:
240-
result_future.set_exception(e)
241-
242-
output_future.add_done_callback(callback)
243-
244-
from vllm.v1.executor.multiproc_executor import FutureWrapper
245-
246-
if isinstance(output_future, FutureWrapper):
247-
# Due to the threadless implementation of multiproc FutureWrapper,
248-
# we must block on the delegate future's result() method.
249-
delegate_result = result_future.result
250-
251-
def result(timeout=None):
252-
with contextlib.suppress(Exception):
253-
output_future.result(timeout=timeout)
254-
return delegate_result()
255-
256-
result_future.result = result # type: ignore[method-assign]
257-
258-
return result_future
259-
260220

261221
def _make_src_and_dst_indices(
262222
src_block_ids: list[int],

vllm/v1/executor/multiproc_executor.py

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from vllm.config import VllmConfig
3030
from vllm.distributed import destroy_distributed_environment, destroy_model_parallel
3131
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
32+
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
3233
from vllm.distributed.parallel_state import (
3334
get_dp_group,
3435
get_ep_group,
@@ -57,8 +58,13 @@
5758

5859

5960
class FutureWrapper(Future):
60-
def __init__(self, futures_queue: deque[tuple["FutureWrapper", Callable]]):
61+
def __init__(
62+
self,
63+
futures_queue: deque[tuple["FutureWrapper", Callable]],
64+
aggregate: Callable = lambda x: x,
65+
):
6166
self.futures_queue = futures_queue
67+
self.aggregate = aggregate
6268
super().__init__()
6369

6470
def result(self, timeout=None):
@@ -72,7 +78,7 @@ def result(self, timeout=None):
7278

7379
def wait_for_response(self, get_response: Callable):
7480
try:
75-
response = get_response()
81+
response = self.aggregate(get_response())
7682
with suppress(InvalidStateError):
7783
self.set_result(response)
7884
except Exception as e:
@@ -160,7 +166,6 @@ def _init_executor(self) -> None:
160166
self.futures_queue = deque[tuple[FutureWrapper, Callable]]()
161167

162168
self.output_rank = self._get_output_rank()
163-
self.has_connector = self.vllm_config.kv_transfer_config is not None
164169

165170
def start_worker_monitor(self):
166171
workers = self.workers
@@ -199,44 +204,27 @@ def register_failure_callback(self, callback: FailureCallback):
199204
def execute_model( # type: ignore[override]
200205
self, scheduler_output: SchedulerOutput, non_block: bool = False
201206
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
202-
return self._execute_with_aggregation(
203-
"execute_model", scheduler_output, non_block=non_block
207+
return self.collective_rpc(
208+
"execute_model",
209+
args=(scheduler_output,),
210+
unique_reply_rank=self.output_rank,
211+
non_block=non_block,
212+
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
213+
kv_output_aggregator=self.kv_output_aggregator,
204214
)
205215

206216
def sample_tokens( # type: ignore[override]
207217
self, grammar_output: GrammarOutput | None, non_block: bool = False
208218
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
209-
return self._execute_with_aggregation( # type: ignore[return-value]
210-
"sample_tokens", grammar_output, non_block=non_block
211-
)
212-
213-
def _execute_with_aggregation(
214-
self, method: str, *args, non_block: bool = False
215-
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
216-
if not self.has_connector:
217-
# get output only from a single worker (output_rank)
218-
return self.collective_rpc(
219-
method,
220-
args=args,
221-
unique_reply_rank=self.output_rank,
222-
non_block=non_block,
223-
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
224-
)
225-
226-
# get output from all workers
227-
outputs = self.collective_rpc(
228-
method,
229-
args=args,
219+
return self.collective_rpc(
220+
"sample_tokens",
221+
args=(grammar_output,),
222+
unique_reply_rank=self.output_rank,
230223
non_block=non_block,
231224
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
225+
kv_output_aggregator=self.kv_output_aggregator,
232226
)
233227

234-
# aggregate all workers output to a single output
235-
assert self.kv_output_aggregator is not None
236-
if non_block:
237-
return self.kv_output_aggregator.async_aggregate(outputs, self.output_rank)
238-
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
239-
240228
def execute_dummy_batch(self) -> None:
241229
self.collective_rpc("execute_dummy_batch", unique_reply_rank=self.output_rank)
242230

@@ -254,29 +242,34 @@ def collective_rpc( # type: ignore[override]
254242
kwargs: dict | None = None,
255243
non_block: bool = False,
256244
unique_reply_rank: int | None = None,
245+
kv_output_aggregator: KVOutputAggregator = None,
257246
) -> Any | list[Any] | Future[Any | list[Any]]:
258-
"""Returns single result if unique_reply_rank is provided, otherwise list."""
247+
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
248+
is provided, otherwise list."""
259249

260250
if self.is_failed:
261251
raise RuntimeError("Executor failed.")
262252

263253
deadline = None if timeout is None else time.monotonic() + timeout
264254
kwargs = kwargs or {}
265255

266-
# NOTE: If the args are heterogeneous, then we pack them into a list,
267-
# and unpack them in the method of every worker, because every worker
268-
# knows their own rank.
256+
if kv_output_aggregator is not None:
257+
output_rank = None
258+
aggregate: Callable[[Any], Any] = partial(
259+
kv_output_aggregator.aggregate, output_rank=unique_reply_rank or 0
260+
)
261+
else:
262+
output_rank = unique_reply_rank
263+
aggregate = lambda x: x
269264

270265
if isinstance(method, str):
271266
send_method = method
272267
else:
273268
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
274-
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, unique_reply_rank))
269+
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
275270

276271
workers = (
277-
(self.workers[unique_reply_rank],)
278-
if unique_reply_rank is not None
279-
else self.workers
272+
(self.workers[output_rank],) if output_rank is not None else self.workers
280273
)
281274

282275
shutdown_event = self.shutdown_event
@@ -299,10 +292,10 @@ def get_response():
299292
" stack trace above for the root cause"
300293
)
301294
responses.append(result)
302-
return responses[0] if unique_reply_rank is not None else responses
295+
return responses[0] if output_rank is not None else responses
303296

304297
if non_block:
305-
future = FutureWrapper(self.futures_queue)
298+
future = FutureWrapper(self.futures_queue, aggregate=aggregate)
306299
self.futures_queue.appendleft((future, get_response))
307300
return future
308301

@@ -311,7 +304,7 @@ def get_response():
311304
future, get_fut_response = self.futures_queue.pop()
312305
future.wait_for_response(get_fut_response)
313306

314-
return get_response()
307+
return aggregate(get_response())
315308

316309
@staticmethod
317310
def _ensure_worker_termination(worker_procs: list[BaseProcess]):

0 commit comments

Comments
 (0)