Commit 9639b0d

Refactor EngineCoreOutputs
1 parent bc1bdec commit 9639b0d

10 files changed: +112 -62 lines changed


benchmarks/backend_request_func.py

Lines changed: 6 additions & 1 deletion
@@ -272,10 +272,15 @@ async def async_request_openai_completions(
         try:
             async with session.post(url=api_url, json=payload,
                                     headers=headers) as response:
+                #print(f"RES = {response.status}")
                 if response.status == 200:
                     first_chunk_received = False
+
+                    #print(response)
+
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
+                        #print(f"CB = {chunk_bytes}")
                         if not chunk_bytes:
                             continue
 
@@ -313,7 +318,7 @@ async def async_request_openai_completions(
                     else:
                         output.success = False
                         output.error = (
-                            "Never received a valid chunk to calculate TTFT."
+                            "Never received a valid chunk to calculate TTFT. "
                             "This response will be marked as failed!")
                 output.generated_text = generated_text
                 output.latency = most_recent_timestamp - st

benchmarks/benchmark_serving.py

Lines changed: 4 additions & 1 deletion
@@ -564,7 +564,10 @@ async def benchmark(
     )
     test_output = await request_func(request_func_input=test_input)
     if not test_output.success:
-        raise ValueError(
+        #raise ValueError(
+        #    "Initial test run failed - Please make sure benchmark arguments "
+        #    f"are correctly specified. Error: {test_output.error}")
+        print(
             "Initial test run failed - Please make sure benchmark arguments "
             f"are correctly specified. Error: {test_output.error}")
     else:

vllm/v1/core/scheduler.py

Lines changed: 26 additions & 16 deletions
@@ -11,10 +11,10 @@
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
-from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
+from vllm.v1.engine import EngineCoreOutputs
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
+from vllm.v1.request import FinishReason, Request, RequestStatus
 
 if TYPE_CHECKING:
     from vllm.multimodal import MultiModalKwargs
@@ -413,11 +413,19 @@ def update_from_output(
         sampled_token_ids = model_runner_output.sampled_token_ids
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         new_running: List[Request] = []
-        outputs: List[EngineCoreOutput] = []
+        output = EngineCoreOutputs(request_ids=[],
+                                   new_token_id_offsets=[],
+                                   new_token_ids=[],
+                                   finished=[],
+                                   finish_reason={},
+                                   stop_reason=[],
+                                   scheduler_stats=None
+                                   )
 
         # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
         # loop can be a performance bottleneck. We should do our best to avoid
         # expensive operations inside the loop.
+        offset = 0
         for request in self.running:
             req_id = request.request_id
             num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
@@ -455,30 +463,32 @@ def update_from_output(
             # TODO: Update the KV cache manager for prefix caching.
 
             # Check for stop and update request state.
-            # This must be called before we make the EngineCoreOutput.
+            # This must be called before we make the EngineCoreOutputs.
             stopped = self._check_stop(request)
             if stopped:
                 self._free_request(request)
 
-            # Add EngineCoreOutput for this Request.
-            output = EngineCoreOutput(
-                request_id=req_id,
-                new_token_ids=request.output_token_ids[-num_new_tokens:],
-                finished=request.is_finished(),
-                finish_reason=request.get_finished_reason(),
-                stop_reason=request.stop_reason)
-            outputs.append(output)
+            # not a list of outputs here
+
+            # Add EngineCoreOutputs for this Request.
+            output.request_ids.append(req_id)
+            output.new_token_id_offsets.append(offset)
+            output.new_token_ids += request.output_token_ids[-num_new_tokens:]
+            output.finished.append(request.is_finished())
+            if request.get_finished_reason() is not None:
+                output.finish_reason[req_id] = request.get_finished_reason()
+            #print(f"req stop = {request.stop_reason}, {request.status}")
+            output.stop_reason.append(request.stop_reason)
+            offset = offset + 1  # move out of if?
 
             # Breakout of the loop.
             if stopped:
                 continue
 
             new_running.append(request)
         self.running = new_running
-        return EngineCoreOutputs(
-            outputs=outputs,
-            scheduler_stats=self.make_stats(),
-        )
+        output.scheduler_stats = self.make_stats()
+        return output
 
     def _check_stop(self, request: Request) -> bool:
         if (request.num_tokens >= self.max_model_len
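
For context on the loop above: the commit replaces one EngineCoreOutput struct per request with parallel per-request lists plus a single flat new_token_ids buffer indexed by new_token_id_offsets (here the offset is bumped by one per request, i.e. one decode token each). The following is a minimal standalone sketch of the general offsets-into-a-flat-buffer pattern; FlatOutputs and append_request are hypothetical names, not part of the commit, and each offset records the running length of the buffer so it also covers multi-token updates.

from dataclasses import dataclass, field
from typing import List


@dataclass
class FlatOutputs:
    # Parallel per-request columns plus one flat token buffer.
    request_ids: List[str] = field(default_factory=list)
    new_token_id_offsets: List[int] = field(default_factory=list)
    new_token_ids: List[int] = field(default_factory=list)
    finished: List[bool] = field(default_factory=list)


def append_request(out: FlatOutputs, req_id: str,
                   token_ids: List[int], is_finished: bool) -> None:
    # Record where this request's tokens start, then extend the flat buffer;
    # the next request's offset (or the buffer length) marks the end.
    out.request_ids.append(req_id)
    out.new_token_id_offsets.append(len(out.new_token_ids))
    out.new_token_ids.extend(token_ids)
    out.finished.append(is_finished)


out = FlatOutputs()
append_request(out, "req-0", [11, 12], is_finished=False)
append_request(out, "req-1", [21], is_finished=True)
assert out.new_token_id_offsets == [0, 2]
assert out.new_token_ids == [11, 12, 21]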

vllm/v1/engine/__init__.py

Lines changed: 18 additions & 13 deletions
@@ -2,7 +2,7 @@
 
 import enum
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 import msgspec
 
@@ -59,17 +59,17 @@ class EngineCoreRequest:
     lora_request: Optional["LoRARequest"]
 
 
-class EngineCoreOutput(
-        msgspec.Struct,
-        array_like=True,  # type: ignore[call-arg]
-        omit_defaults=True,  # type: ignore[call-arg]
-        gc=False):  # type: ignore[call-arg]
-
-    request_id: str
-    new_token_ids: List[int]
-    finished: bool
-    finish_reason: Optional[FinishReason] = None
-    stop_reason: Union[int, str, None] = None
+#class EngineCoreOutput(
+#        msgspec.Struct,
+#        array_like=True,  # type: ignore[call-arg]
+#        omit_defaults=True,  # type: ignore[call-arg]
+#        gc=False):  # type: ignore[call-arg]
+#
+#    request_id: str
+#    new_token_ids: List[int]
+#    finished: bool
+#    finish_reason: Optional[FinishReason] = None
+#    stop_reason: Union[int, str, None] = None
 
 
 class EngineCoreOutputs(
@@ -82,7 +82,12 @@ class EngineCoreOutputs(
     # e.g. columnwise layout
 
     # [num_reqs]
-    outputs: List[EngineCoreOutput]
+    request_ids: List[str]
+    new_token_id_offsets: List[int]
+    new_token_ids: List[int]
+    finished: List[bool]
+    finish_reason: Dict[str, FinishReason]  # Union[List, Dict]?
+    stop_reason: List[Union[int, str, None]]
    scheduler_stats: SchedulerStats
 
 
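As a sanity check on the columnar layout, here is a hedged sketch (the ColumnarOutputs class is hypothetical, not the vLLM struct above) showing that an array_like msgspec struct of flat primitive lists round-trips through msgpack as a single nested array, with no per-request struct objects on the wire.

from typing import List

import msgspec


class ColumnarOutputs(msgspec.Struct, array_like=True, omit_defaults=True):
    request_ids: List[str]
    new_token_id_offsets: List[int]
    new_token_ids: List[int]
    finished: List[bool]


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(ColumnarOutputs)

batch = ColumnarOutputs(request_ids=["a", "b"],
                        new_token_id_offsets=[0, 1],
                        new_token_ids=[101, 202],
                        finished=[False, True])
wire = encoder.encode(batch)          # one msgpack array of flat lists
assert decoder.decode(wire) == batch  # round-trips losslessly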
vllm/v1/engine/async_llm.py

Lines changed: 14 additions & 6 deletions
@@ -249,19 +249,27 @@ async def _run_output_handler(self):
                 # Split outputs into chunks of at most
                 # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                 # event loop for too long.
-                num_outputs = len(outputs.outputs)
+                num_outputs = len(outputs.new_token_id_offsets)
+
                 if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
-                    slices = (outputs.outputs, )
+                    slices = ((0, num_outputs), )
                 else:
-                    slices = np.array_split(
-                        outputs.outputs,
+                    slices = []
+                    parts = np.linspace(
+                        num_outputs,
                         cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+                    last = 0
+                    for i in parts:
+                        slices.append((last, i))
+                        last = i
+                    print(f"slices = {slices}")
 
                 iteration_stats = None
-                for i, outputs_slice in enumerate(slices):
+                for i, slice in enumerate(slices):
+                    slice_start, slice_end = slice
                     # 2) Process EngineCoreOutputs.
                     processed_outputs = self.output_processor.process_outputs(
-                        outputs_slice, iteration_stats)
+                        outputs, slice_start, slice_end, iteration_stats)
                     # NOTE: RequestOutputs are pushed to their queues.
                     assert not processed_outputs.request_outputs
                     iteration_stats = processed_outputs.iteration_stats
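
The handler above now passes (start, end) index pairs into process_outputs instead of materialized sub-lists. A minimal standalone way to produce such pairs, capping each chunk at the processing chunk size, is sketched below; the chunk_bounds helper is hypothetical, not the commit's np.linspace-based code, and it chunks greedily rather than into near-equal parts like the old np.array_split.

from typing import List, Tuple


def chunk_bounds(num_outputs: int, chunk_size: int) -> List[Tuple[int, int]]:
    # (start, end) pairs covering [0, num_outputs), each at most chunk_size wide.
    return [(start, min(start + chunk_size, num_outputs))
            for start in range(0, num_outputs, chunk_size)]


assert chunk_bounds(5, 2) == [(0, 2), (2, 4), (4, 5)]
assert chunk_bounds(0, 2) == []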

vllm/v1/engine/core.py

Lines changed: 9 additions & 1 deletion
@@ -123,7 +123,14 @@ def step(self) -> EngineCoreOutputs:
 
         if not self.scheduler.has_unfinished_requests():
             return EngineCoreOutputs(
-                outputs=[], scheduler_stats=self.scheduler.make_stats())
+                request_ids=[],
+                new_token_id_offsets=[],
+                new_token_ids=[],
+                finished=[],
+                finish_reason={},
+                stop_reason=[],
+                scheduler_stats=self.scheduler.make_stats()
+            )
 
         scheduler_output = self.scheduler.schedule()
         output = self.model_executor.execute_model(scheduler_output)
@@ -299,5 +306,6 @@ def process_output_socket(self, output_path: str):
         with zmq_socket_ctx(output_path, zmq.constants.PUSH) as socket:
             while True:
                 outputs = self.output_queue.get()
+                #print(outputs)
                 encoder.encode_into(outputs, buffer)
                 socket.send_multipart((buffer, ), copy=False)
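
Both step() above and Scheduler.update_from_output now spell out every empty column. One way to keep that field list in a single place would be a small factory; this is a hedged sketch, and the empty_outputs helper is hypothetical, not part of the commit.

from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats


def empty_outputs(scheduler_stats: SchedulerStats) -> EngineCoreOutputs:
    # Hypothetical helper: build an EngineCoreOutputs with all columns empty,
    # so the field list is not repeated in core.py and scheduler.py.
    return EngineCoreOutputs(request_ids=[],
                             new_token_id_offsets=[],
                             new_token_ids=[],
                             finished=[],
                             finish_reason={},
                             stop_reason=[],
                             scheduler_stats=scheduler_stats)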

vllm/v1/engine/detokenizer.py

Lines changed: 4 additions & 6 deletions
@@ -8,7 +8,7 @@
 from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
-from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
+from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequest, FinishReason
 
 logger = init_logger(__name__)
 
@@ -98,18 +98,16 @@ def from_new_request(
 
     def update_from_output(
         self,
-        output: EngineCoreOutput,
+        new_token_ids: List[int],
+        finish_reason: Optional[FinishReason],
+        stop_reason: Union[int, str, None],
     ) -> Optional[DetokenizerOutput]:
         """
         Update RequestState for the request_id by:
             1) Detokenize the new token ids incrementally.
             2) Update the RequestOutput with the new text.
         """
 
-        new_token_ids = output.new_token_ids
-        finish_reason = output.finish_reason
-        stop_reason = output.stop_reason
-
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.

vllm/v1/engine/llm_engine.py

Lines changed: 4 additions & 2 deletions
@@ -143,12 +143,14 @@ def add_request(
 
     def step(self) -> List[RequestOutput]:
 
-        # 1) Get EngineCoreOutput from the EngineCore.
+        # 1) Get EngineCoreOutputs from the EngineCore.
         outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
         processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs)
+            outputs,
+            0,
+            len(outputs.request_ids))
 
         # 3) Abort any reqs that finished due to stop strings.
         self.engine_core.abort_requests(processed_outputs.reqs_to_abort)

vllm/v1/engine/output_processor.py

Lines changed: 24 additions & 12 deletions
@@ -7,7 +7,7 @@
 from vllm.outputs import RequestOutput
 from vllm.transformers_utils.detokenizer_utils import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
+from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequest
 from vllm.v1.engine.detokenizer import (DetokenizerOutput,
                                         IncrementalDetokenizer)
 from vllm.v1.metrics.stats import IterationStats, RequestStateStats
@@ -106,59 +106,71 @@ def add_request(
 
     def process_outputs(
         self,
-        engine_core_outputs: List[EngineCoreOutput],
+        engine_core_outputs: EngineCoreOutputs,
+        first: int,
+        last: int,
         iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
         1) Compute stats for logging
         2) Detokenize
         3) Create and handle RequestOutput objects:
-            * If there is a queue (for usage with AsyncLLM),
+            * If there is a queue (for usage with AsyncLLM), 
               put the RequestOutput objects into the queue for
               handling by the per-request generate() tasks.
 
-            * If there is no queue (for usage with LLMEngine),
+            * If there is no queue (for usage with LLMEngine), 
               return a list of RequestOutput objects.
 
         ****************** NOTE FOR DEVELOPERS ******************
 
         VLLM V1 minimizes the number of python loops over the full
-        batch to ensure system overheads are minimized. This is the
+        batch to ensure system overheads are minimized. This is the 
         only function that should loop over EngineCoreOutputs.
 
         If you need to touch every element of the batch, implement a
         method called XXXClass.update_from_output() to be called
         within the loop below. For examples, see:
            * IterationStats.update_from_output()
            * Detokenizer.update_from_output()
-
+        
         TODO(rob): add Protocol makes update_from_output explicit.
-
+        
         **********************************************************
         """
 
         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
         if not iteration_stats:
             iteration_stats = IterationStats(self.log_stats)
-        for engine_core_output in engine_core_outputs:
-            req_id = engine_core_output.request_id
+        for i, req_id in enumerate(engine_core_outputs.request_ids[first:last]):
             req_state = self.request_states.get(req_id)
             if req_state is None:
                 # Ignore output for already-aborted request.
                 continue
 
+            num_tokens = last - first  # might not be robust
+            start = engine_core_outputs.new_token_id_offsets[i]
+            end = engine_core_outputs.new_token_id_offsets[i + 1] if i < num_tokens - 1 else -1
+            # better way to do this?
+            new_token_ids = engine_core_outputs.new_token_ids[start:end]
+
             # 1) Compute stats for this iteration.
-            iteration_stats.update_from_output(engine_core_output,
+            iteration_stats.update_from_output(num_tokens,
                                                req_state.is_prefilling,
                                                req_state.prompt_len,
                                                req_state.stats)
             req_state.is_prefilling = False
 
             # 2) Detokenize the token ids into text.
+            #print(f"finish = {engine_core_outputs.finish_reason.get(req_id)}")
+            #print(f"stop = {engine_core_outputs.stop_reason[i + first]}")
             detokenizer_output = req_state.detokenizer.update_from_output(
-                engine_core_output)
+                new_token_ids,
+                engine_core_outputs.finish_reason.get(req_id),
+                engine_core_outputs.stop_reason[i + first],
+            )
 
             # 3) Create and handle RequestOutput objects.
             if detokenizer_output is not None:
@@ -177,7 +189,7 @@ def process_outputs(
                 assert detokenizer_output.finish_reason is not None
 
                 self.request_states.pop(req_id)
-                if not engine_core_output.finished:
+                if not engine_core_outputs.finished[i]:
                     # If req not finished in EngineCore, but Detokenizer
                     # detected stop string, abort needed in EngineCore.
                     reqs_to_abort.append(req_id)
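
The slicing above derives each request's tokens from its offset and the next request's offset, with a placeholder endpoint for the last request (flagged in the diff with "# better way to do this?"). Below is a standalone sketch of the end-exclusive variant, assuming offsets store running token counts as in the sketch under scheduler.py; the tokens_for helper is hypothetical, not part of the commit.

from typing import List


def tokens_for(i: int, offsets: List[int], flat_token_ids: List[int]) -> List[int]:
    # offsets[i] is where request i's tokens start in the flat buffer; the next
    # offset (or the buffer length for the last request) is the exclusive end.
    start = offsets[i]
    end = offsets[i + 1] if i + 1 < len(offsets) else len(flat_token_ids)
    return flat_token_ids[start:end]


offsets = [0, 2, 3]           # three requests
flat = [11, 12, 21, 31, 32]   # request 2 produced two tokens
assert tokens_for(0, offsets, flat) == [11, 12]
assert tokens_for(2, offsets, flat) == [31, 32]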
