
Commit 4760413

[Bugfix] Spec decode + structured output + spec model max len edge case (vllm-project#28298)
Signed-off-by: Andy Lo <andy@mistral.ai>
Parent: 26990d2

3 files changed (+36, -8 lines)

tests/v1/spec_decode/test_max_len.py

Lines changed: 30 additions & 3 deletions
@@ -7,6 +7,7 @@
 from tests.utils import get_attn_backend_list_based_on_platform
 from vllm import LLM, SamplingParams
 from vllm.platforms import current_platform
+from vllm.sampling_params import StructuredOutputsParams

 _PROMPTS = [
     "1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1",
@@ -56,8 +57,34 @@ def test_eagle_max_len(
             "method": "eagle",
             "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
             "num_speculative_tokens": num_speculative_tokens,
+            "max_model_len": 80,
         },
-        max_model_len=100,
+        max_model_len=200,
     )
-    sampling_params = SamplingParams(max_tokens=100, ignore_eos=True)
-    llm.generate(_PROMPTS, sampling_params)
+    sampling_params = SamplingParams(max_tokens=200, ignore_eos=True)
+    outputs = llm.generate(_PROMPTS, sampling_params)
+    for o in outputs:
+        assert o.outputs[0].finish_reason == "length", (
+            "This test is only meaningful if the output "
+            "is truncated due to max length"
+        )
+
+    sampling_params = SamplingParams(
+        max_tokens=200,
+        structured_outputs=StructuredOutputsParams(
+            regex="^" + "a b c d e " * 15 + "$"
+        ),
+    )
+    output = llm.generate(_PROMPTS, sampling_params)
+    for o in output:
+        assert o.prompt_token_ids is not None
+        assert (
+            len(o.prompt_token_ids)
+            < 80
+            < len(o.prompt_token_ids) + len(o.outputs[0].token_ids)
+            < 200
+        ), (
+            "This test is only meaningful if the output "
+            "is longer than the eagle max length"
+        )
+        assert o.outputs[0].text == "a b c d e " * 15
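
The regex in the new test admits exactly one string, so the final assertion can compare the decoded text verbatim even though generation has to run well past the draft model's max_model_len of 80. A quick standalone check of that property, with illustrative lengths rather than the test's real token counts:

import re

# "^" + "a b c d e " * 15 + "$" matches exactly one string, so a
# grammar-constrained decoder has no freedom in what it emits.
pattern = "^" + "a b c d e " * 15 + "$"
target = "a b c d e " * 15

assert re.fullmatch(pattern, target) is not None
assert re.fullmatch(pattern, target[:-1]) is None  # any other string is rejected

# Chained comparisons let a single assert pin the completion inside the
# (draft max len, target max len) window; the numbers here are made up.
prompt_len, output_len = 30, 120
assert prompt_len < 80 < prompt_len + output_len < 200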

vllm/v1/core/sched/scheduler.py

Lines changed: 4 additions & 4 deletions
@@ -325,6 +325,9 @@ def schedule(self) -> SchedulerOutput:
                 scheduled_spec_decode_tokens[request.request_id] = (
                     request.spec_token_ids
                 )
+                # New spec tokens will be set in `update_draft_token_ids` before the
+                # next step when applicable.
+                request.spec_token_ids = []

         # Encoder-related.
         if encoder_inputs_to_schedule:
@@ -1149,10 +1152,7 @@ def update_draft_token_ids(
                 continue

             # Add newly generated spec token ids to the request.
-            if not spec_token_ids:
-                # NOTE(woosuk): request.spec_token_ids should be updated.
-                request.spec_token_ids.clear()
-            elif self.structured_output_manager.should_advance(request):
+            if self.structured_output_manager.should_advance(request):
                 metadata = request.structured_output_request
                 request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                     spec_token_ids
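
This change moves the reset of request.spec_token_ids into schedule(): drafts are consumed for the current step and immediately dropped, so a request the drafter skips on the next step (for example, one already past the draft model's max length) can no longer re-schedule stale tokens. A minimal sketch of that hand-off, assuming a stand-in Request class and simplified function shapes rather than vLLM's actual ones:

from dataclasses import dataclass, field


@dataclass
class Request:
    request_id: str
    spec_token_ids: list[int] = field(default_factory=list)


def schedule_step(requests: list[Request]) -> dict[str, list[int]]:
    # Consume drafts for this step, then reset them; new spec tokens are
    # set by the drafter before the next step when applicable.
    scheduled: dict[str, list[int]] = {}
    for req in requests:
        if req.spec_token_ids:
            scheduled[req.request_id] = req.spec_token_ids
        req.spec_token_ids = []
    return scheduled


def update_draft_token_ids(req: Request, draft: list[int]) -> None:
    # Drafter side: only runs when the drafter actually proposed tokens.
    req.spec_token_ids = draft


r = Request("r1", spec_token_ids=[5, 6, 7])
assert schedule_step([r]) == {"r1": [5, 6, 7]}
# The drafter skips r on the next step, so update_draft_token_ids never
# runs, and the stale draft is not rescheduled:
assert schedule_step([r]) == {}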

vllm/v1/structured_output/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -269,9 +269,10 @@ def grammar_bitmask(
                 and token is not None
                 and not structured_output_request.grammar.is_terminated()
             ):
-                assert structured_output_request.grammar.accept_tokens(
+                accepted = structured_output_request.grammar.accept_tokens(
                     req_id, [token]
                 )
+                assert accepted, (token, req_id, scheduled_spec_decode_tokens)
                 state_advancements += 1
             cumulative_index += 1
         if state_advancements > 0:
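
The grammar_bitmask change keeps the same invariant but makes failures debuggable: binding the result of accept_tokens to a local first lets the assert attach the offending token, request id, and scheduled draft tokens to the AssertionError. A toy illustration of the pattern, with ToyGrammar as a stand-in for vLLM's grammar interface:

class ToyGrammar:
    # Stand-in grammar that accepts only even token ids.
    def accept_tokens(self, req_id: str, tokens: list[int]) -> bool:
        return all(t % 2 == 0 for t in tokens)


grammar = ToyGrammar()
scheduled_spec_decode_tokens = {"r1": [2, 3]}

try:
    for token in scheduled_spec_decode_tokens["r1"]:
        accepted = grammar.accept_tokens("r1", [token])
        # The message tuple travels with the AssertionError, unlike the
        # old bare `assert grammar.accept_tokens(...)`.
        assert accepted, (token, "r1", scheduled_spec_decode_tokens)
except AssertionError as err:
    print("rejected:", err)  # rejected: (3, 'r1', {'r1': [2, 3]})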
