 from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
-from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus

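+# EOS id assigned to every request built by create_requests; sampling this id
+# is expected to finish a request unless ignore_eos is set.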
+EOS_TOKEN_ID = 50256
+

 def create_scheduler(
     model: str = "facebook/opt-125m",
@@ -38,6 +40,7 @@ def create_scheduler(
     return Scheduler(scheduler_config,
                      model_config,
                      cache_config,
+                     speculative_config=None,
                      lora_config=None,
                      log_stats=True)

@@ -46,8 +49,12 @@ def create_requests(
     num_requests: int,
     num_tokens: int = 10,
     mm_positions: Optional[List[PlaceholderRange]] = None,
+    max_tokens: int = 16,
+    stop_token_ids: Optional[List[int]] = None,
 ):
-    sampling_params = SamplingParams()
+    sampling_params = SamplingParams(ignore_eos=False,
+                                     max_tokens=max_tokens,
+                                     stop_token_ids=stop_token_ids)
     requests = []
     for i in range(num_requests):
         if mm_positions is not None:
@@ -64,7 +71,7 @@ def create_requests(
             multi_modal_inputs=mm_inputs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
-            eos_token_id=None,
+            eos_token_id=EOS_TOKEN_ID,
             arrival_time=0,
         )
         requests.append(request)
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
     model_runner_output = ModelRunnerOutput(
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
-        sampled_token_ids=[0] * len(requests),
+        sampled_token_ids=[[0] for _ in range(len(requests))],
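+        # one list of sampled token ids per request, rather than a flat list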
         logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
     assert requests[2].request_id not in output.num_scheduled_tokens


+def test_stop_via_update_from_output():
+    """Test stopping behavior through update_from_output"""
+    scheduler = create_scheduler()
+
+    # Test case 1: Stop on EOS token
+    requests = create_requests(num_requests=2, max_tokens=10)
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
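+    # Request 0 is scheduled for a single decode token; request 1 also has one
+    # speculative token scheduled, so it is allotted two tokens this step.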
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 1,
+                                           requests[1].request_id: 2
+                                       },
+                                       total_num_scheduled_tokens=3,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [],
+                                           requests[1].request_id: [10]
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        # First request hits EOS, second continues
+        sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]],
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped, second continues
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_STOPPED
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
+    assert list(requests[1].output_token_ids) == [10, 11]
+
+    # Test case 2: Stop on custom stop token
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=2,
+                               max_tokens=10,
+                               stop_token_ids=[42, 43])
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
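+    # Request 0 has two speculative tokens scheduled ([10, 42]), so several
+    # tokens are sampled in one step and the custom stop id appears mid-sequence.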
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 3,
+                                           requests[1].request_id: 2
+                                       },
+                                       total_num_scheduled_tokens=5,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [10, 42],
+                                           requests[1].request_id: [13]
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        # First request hits stop token
+        sampled_token_ids=[[10, 42, 12], [13, 14]],
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped on custom token
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_STOPPED
+    assert requests[0].stop_reason == 42
+    assert requests[0].request_id in scheduler.finished_req_ids
+    assert list(requests[0].output_token_ids) == [10, 42]
+    assert list(requests[1].output_token_ids) == [13, 14]
+
+    # Test case 3: Stop on max tokens
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=2, max_tokens=2)
+    for req in requests:
+        req.num_computed_tokens = req.num_tokens
+        scheduler.requests[req.request_id] = req
+        scheduler.running.append(req)
+        scheduler.scheduled_req_ids.add(req.request_id)
+
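+    # Request 0 is allotted 3 tokens (1 decode + 2 speculative), more than its
+    # max_tokens budget of 2.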
+    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
+                                       scheduled_cached_reqs=[],
+                                       num_scheduled_tokens={
+                                           requests[0].request_id: 3,
+                                           requests[1].request_id: 1
+                                       },
+                                       total_num_scheduled_tokens=4,
+                                       scheduled_encoder_inputs={},
+                                       scheduled_spec_decode_tokens={
+                                           requests[0].request_id: [10, 11],
+                                           requests[1].request_id: []
+                                       },
+                                       num_common_prefix_blocks=0,
+                                       finished_req_ids=set(),
+                                       free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[req.request_id for req in requests],
+        req_id_to_index={
+            req.request_id: i
+            for i, req in enumerate(requests)
+        },
+        # First request exceeds max_tokens
+        sampled_token_ids=[[10, 11, 12], [13]],
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify first request stopped due to length
+    assert len(scheduler.running) == 1
+    assert scheduler.running[0].request_id == requests[1].request_id
+    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
+    assert requests[0].request_id in scheduler.finished_req_ids
+    # Output truncated to max_tokens
+    assert list(requests[0].output_token_ids) == [10, 11]
+    assert list(requests[1].output_token_ids) == [13]
+
+    # Test case 4: Ignore EOS flag
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=1, max_tokens=10)
+    requests[0].sampling_params.ignore_eos = True
+    requests[0].num_computed_tokens = requests[0].num_tokens
+    scheduler.requests[requests[0].request_id] = requests[0]
+    scheduler.running.append(requests[0])
+    scheduler.scheduled_req_ids.add(requests[0].request_id)
+
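+    # The speculative tokens deliberately include EOS; with ignore_eos set it
+    # must not be treated as a stop condition.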
+    scheduler_output = SchedulerOutput(
+        scheduled_new_reqs=[],
+        scheduled_cached_reqs=[],
+        num_scheduled_tokens={requests[0].request_id: 3},
+        total_num_scheduled_tokens=3,
+        scheduled_encoder_inputs={},
+        scheduled_spec_decode_tokens={
+            requests[0].request_id: [EOS_TOKEN_ID, 10]
+        },
+        num_common_prefix_blocks=0,
+        finished_req_ids=set(),
+        free_encoder_input_ids=[])
+
+    model_output = ModelRunnerOutput(
+        req_ids=[requests[0].request_id],
+        req_id_to_index={requests[0].request_id: 0},
+        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
+        logprobs=None,
+        prompt_logprobs_dict={})
+
+    scheduler.update_from_output(scheduler_output, model_output)
+
+    # Verify request continues past EOS
+    assert len(scheduler.running) == 1
+    assert not requests[0].is_finished()
+    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
+
+
 def test_schedule_concurrent_batches():
     scheduler = create_scheduler(
         max_num_batched_tokens=1024,
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
         req_id_to_index={requests[0].request_id: 0},
-        sampled_token_ids=[0],
+        sampled_token_ids=[[0]],
         logprobs=None,
         prompt_logprobs_dict={},
     )
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
     model_runner_output = ModelRunnerOutput(
         req_ids=[requests[1].request_id],
         req_id_to_index={requests[1].request_id: 0},
-        sampled_token_ids=[0],
+        sampled_token_ids=[[0]],
         logprobs=None,
         prompt_logprobs_dict={},
     )