141 changes: 86 additions & 55 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -9,7 +9,7 @@
import time
import traceback
import weakref
from collections import namedtuple
from collections import deque, namedtuple
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Union

@@ -57,34 +57,62 @@
class RequestQueueItem:
id: int
request: Optional[ExecutorRequest] = None
is_canceled_request: bool = False
query: Optional[list] = None # only used in `StarAttention`

@property
def is_shutdown_request(self):
return self.id == SHUTDOWN_REQUEST_ID

@property
def is_normal_request(self):
return not (self.is_shutdown_request or self.is_canceled_request)

def _get_from_request_queue(request_queue,
timeout: Optional[datetime.timedelta],
max_req_count: int) -> List[RequestQueueItem]:

def _get_from_request_queue(
request_queue,
timeout: Optional[datetime.timedelta]) -> List[RequestQueueItem]:
items = []
timeout_secs = timeout.total_seconds() if timeout is not None else None
req_count = 0
try:
if request_queue.empty() and (timeout_secs is None or timeout_secs > 0):
# if queue is empty and want to wait, wait
items.append(request_queue.get(timeout=timeout_secs))
else:
# if not empty or don't want to wait, just return all items in queue
while req_count < max_req_count:
while True:
queue_item = request_queue.get_nowait()
items.append(queue_item)
if not queue_item.is_shutdown_request():
req_count += 1
except queue.Empty:
pass
return items


def _get_from_waiting_queue(
waiting_queue: deque[RequestQueueItem],
max_req_count: int,
) -> List[RequestQueueItem]:
"""Safely extracts up to max_req_count items from a deque.

Args:
waiting_queue: The queue to pop items from.
max_req_count: Maximum items to retrieve. Returns empty list if <=0.

Returns:
List of retrieved items (may be shorter than max_req_count if queue empties first).
"""
# Edge case handling
if max_req_count <= 0: # Handles negative/zero counts
return []

items = []
req_count = 0
while req_count < max_req_count and waiting_queue:
items.append(waiting_queue.popleft())
req_count += 1
return items
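A minimal, runnable sketch of the semantics above (the helper name `drain` and the sample values are illustrative, not part of the PR): the pop never blocks, and it stops at whichever comes first, the count limit or an empty deque.

    from collections import deque

    def drain(waiting_queue: deque, max_req_count: int) -> list:
        # Mirror of the drain loop above: bounded, non-blocking pops from the left.
        if max_req_count <= 0:
            return []
        items = []
        while len(items) < max_req_count and waiting_queue:
            items.append(waiting_queue.popleft())
        return items

    q = deque([1, 2, 3, 4, 5])
    assert drain(q, 3) == [1, 2, 3]   # takes at most the requested count
    assert drain(q, 10) == [4, 5]     # returns fewer when the queue empties first
    assert drain(q, 0) == []          # non-positive counts yield nothing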


@functools.cache
def _load_iteration_indexes(env_var: str):
spans = os.environ.get(env_var, None)
@@ -182,6 +210,7 @@ def __init__(self,
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
self.waiting_queue: deque[RequestQueueItem] = deque()

# profile config
self.profile_start_iters, self.profile_stop_iters = _load_iteration_indexes(
@@ -251,7 +280,7 @@ def __init__(self,
self.send_handles = [None] * self.num_micro_batches

self.inflight_req_ids = ReqIdsSet()
self.canceled_req_ids = ReqIdsSet()
self.canceled_req_ids = []

self.model_engine.warmup(self.resource_manager)
if self.draft_model_engine is not None:
@@ -368,7 +397,12 @@ def cancel_request(self, id: int):
Args:
id (int): The request id for which to cancel the response
"""
self.canceled_req_ids.insert(id)
try:
self.enqueue_lock.acquire()
self.request_queue.put(
RequestQueueItem(id, is_canceled_request=True))
finally:
self.enqueue_lock.release()
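A rough illustration of the new cancellation path (the `Item` dataclass and the module-level queue below are stand-ins, not the PR's types): a cancel is enqueued as a regular queue item under the same lock, so the executor loop observes it in arrival order relative to new requests instead of reading a shared id set.

    import queue
    import threading
    from dataclasses import dataclass

    @dataclass
    class Item:
        id: int
        is_canceled_request: bool = False

    request_queue: "queue.Queue[Item]" = queue.Queue()
    enqueue_lock = threading.Lock()

    def cancel_request(req_id: int) -> None:
        # Cancellation is just another queue item, serialized by the enqueue lock.
        with enqueue_lock:
            request_queue.put(Item(req_id, is_canceled_request=True))

    cancel_request(42)
    assert request_queue.get_nowait().is_canceled_request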

def shutdown(self):
"""
@@ -454,6 +488,11 @@ def enqueue_request(self,
def set_gather_responses(self, gather_all_responses):
self.gather_all_responses = gather_all_responses

@property
def should_stop_processing(self):
return self.is_shutdown and len(self.active_requests) == 0 and len(
self.waiting_queue) == 0

@contextmanager
def _profiler(self):
it = -1
@@ -710,12 +749,12 @@ def _executor_loop_pp(self):
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not self.is_shutdown or len(self.active_requests) > 0:
while not self.should_stop_processing:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.is_shutdown and len(self.active_requests) == 0:
if self.should_stop_processing:
break

if self.enable_iter_perf_stats:
@@ -839,7 +878,7 @@ def _executor_loop_pp(self):
if previous_batch is not None:
with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
self._update_requests(previous_batch.sample_state)
self._handle_cancelled_requests()
self._handle_canceled_requests()
finished_requests = self._handle_responses()
previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
self.resource_manager.update_resources(
@@ -861,12 +900,12 @@ def _executor_loop(self):
sample_state = None
iter_start_time = time.time()
iter_stats = None
while not self.is_shutdown or len(self.active_requests) > 0:
while not self.should_stop_processing:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.is_shutdown and len(self.active_requests) == 0:
if self.should_stop_processing:
break

if self.kv_cache_transceiver:
@@ -950,7 +989,7 @@ def _executor_loop(self):
for req in ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

self._handle_cancelled_requests()
self._handle_canceled_requests()
finished_requests = self._handle_responses()
self.resource_manager.update_resources(scheduled_batch)
if self.enable_kv_cache_events:
@@ -1006,12 +1045,12 @@ def _executor_loop_overlap(self):
with self._profiler() as profile_step:
iter_start_time = time.time()
iter_stats = None
while not self.is_shutdown or len(self.active_requests) > 0:
while not self.should_stop_processing:
profile_step()
if self.enable_iter_perf_stats:
iter_start_time = time.time()
new_requests = self._fetch_new_requests()
if self.is_shutdown and len(self.active_requests) == 0:
if self.should_stop_processing:
break

if self.kv_cache_transceiver:
@@ -1125,7 +1164,7 @@ def _process_previous_batch(self):
for req in self.previous_batch.ctx_transmission_reqs:
req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS

self._handle_cancelled_requests()
self._handle_canceled_requests()
finished_requests = self._handle_responses()
scheduled_requests = self.previous_batch.sample_state.scheduled_requests
self.resource_manager.update_resources(scheduled_requests)
@@ -1200,13 +1239,11 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
total_num_active_requests = len(self.active_requests)
total_max_num_active_requests = self.max_num_active_requests

timeout = None if total_num_active_requests == 0 else datetime.timedelta(
0)
timeout = None if (total_num_active_requests == 0) and len(
self.waiting_queue) == 0 else datetime.timedelta(0)
new_requests = []
if self.dist.rank == 0:
new_requests = _get_from_request_queue(
self.request_queue, timeout,
total_max_num_active_requests - total_num_active_requests)
new_requests = _get_from_request_queue(self.request_queue, timeout)

if self.dist.rank == 0:
py_logits_post_processors = self._collect_py_objects_from_requests(
@@ -1229,21 +1266,28 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:
# drop requests arriving after shutdown
valid_new_requests = []
for req_item in new_requests:
if req_item.is_shutdown_request():
if req_item.is_shutdown_request:
self.is_shutdown = True
break
elif req_item.is_canceled_request:
self.canceled_req_ids.append(req_item.id)
else:
valid_new_requests.append(req_item)
# Check if the beam width of the requests is equal to the max_beam_width
for req_item in valid_new_requests:
assert req_item.request.sampling_config.beam_width == self.max_beam_width, f"Request beam width {req_item.request.sampling_config.beam_width} is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
new_requests = valid_new_requests

if py_request_objects and (self.dist.tp_size > 1
or self.dist.has_pp) and self.dist.rank > 0:
for attr_name, req_obj_dict in py_request_objects:
self._attach_py_objects_to_requests(new_requests, attr_name,
req_obj_dict)
self._attach_py_objects_to_requests(valid_new_requests,
attr_name, req_obj_dict)

self.waiting_queue.extend(valid_new_requests)

new_requests = _get_from_waiting_queue(
self.waiting_queue,
total_max_num_active_requests - total_num_active_requests)

if not self.enable_attention_dp:
self._update_new_active_requests_queue_latency(new_requests)
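A toy walkthrough of the two-stage intake introduced here (values are made up; only the shape of the logic matches): everything fetched from the request queue is parked in the waiting deque first, and only as many requests as there is free capacity are promoted per iteration.

    from collections import deque

    max_num_active_requests = 4
    active_requests = ["a", "b", "c"]
    waiting_queue: deque = deque()

    # Stage 1: park everything drained from the request queue.
    waiting_queue.extend(["r1", "r2", "r3"])

    # Stage 2: promote only up to the free capacity for this iteration.
    free_slots = max_num_active_requests - len(active_requests)
    promoted = [waiting_queue.popleft()
                for _ in range(min(free_slots, len(waiting_queue)))]

    assert promoted == ["r1"]
    assert list(waiting_queue) == ["r2", "r3"]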
@@ -1339,7 +1383,7 @@ def _collect_py_objects_from_requests(
"""
req_id_to_obj = {}
for item in requests:
if item.is_shutdown_request():
if not item.is_normal_request:
continue
obj = getattr(item.request, attribute_name, None)
if obj is not None:
@@ -1926,41 +1970,28 @@ def _handle_errors(self, error_msg: Optional[str] = None):
def _terminate_request(self, request: LlmRequest):
self.resource_manager.free_resources(request)

@nvtx_range("_handle_cancelled_requests")
def _handle_cancelled_requests(self):
#TODO: properly handle canceled ids in pp case
if self.dist.has_tp:
self.canceled_req_ids = self.dist.broadcast(self.canceled_req_ids,
root=0)

@nvtx_range("_handle_canceled_requests")
def _handle_canceled_requests(self):
if len(self.canceled_req_ids) == 0:
return

cancelled_responses = {}
left_requests = []
# Tracks canceled requests for proper handling in overlap mode during `sampler.update_requests`.
self.canceled_requests = []
# cancel requests still in the waiting queue
self.waiting_queue = deque(req for req in self.waiting_queue
if req.id not in self.canceled_req_ids)

for request in self.active_requests:
req_id = request.py_request_id
if req_id in self.canceled_req_ids:
self._terminate_request(request)
# Mark requests as finished so we can reuse all existing code
# to clean up the KV cache resources.
request.finish_by_reason(FinishReason.CANCELLED)
request.decoding_iter = request.py_decoding_iter
cancelled_responses[req_id] = request.create_response(
False, self.dist.rank)
self.canceled_requests.append(request)
self.canceled_req_ids.erase(req_id)
else:
left_requests.append(request)
self.active_requests = left_requests

# When enable attention dp, each rank does not have full copy of requests
# so we need to remove the cancel requests not in the local rank
self.canceled_req_ids.clear()

# enqueue the cancelled requests' responses here, since they are no longer
# in active_requests and would otherwise be discarded in the sampler loop.
self._enqueue_responses(cancelled_responses)
if self.enable_attention_dp:
# TODO: revisit the cancel logic of attention dp
# When attention dp is enabled, each rank does not hold a full copy of the
# requests, so we need to drop the canceled requests that are not on the local rank
self.canceled_req_ids.clear()

@nvtx_range("_enqueue_responses")
def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -124,7 +124,7 @@ def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9):
logits_dim = logits.dim()
if logits_dim == 1:
logits = logits.unsqueeze(0)
assert logits_dim == 2, "logits should be 2D [batch_size, vocab_size]"
assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"

# sort the logits of each sample in descending order
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
2 changes: 0 additions & 2 deletions tests/unittest/llmapi/apps/_test_openai_misc.py
@@ -86,8 +86,6 @@ async def test_request_cancellation(server: RemoteOpenAIServer,
model_name: str):
# clunky test: send an ungodly amount of load in with short timeouts
# then ensure that it still responds quickly afterwards
pytest.skip("https://nvbugs/5310314")

chat_input = [{"role": "user", "content": "Write a long story"}]
client = server.get_async_client(timeout=0.5, max_retries=3)
tasks = []