11# SPDX-License-Identifier: Apache-2.0
22"""KV-Cache Utilities."""
3+ from collections import deque
34from collections .abc import Sequence
45from dataclasses import dataclass
5- from typing import Any , List , NamedTuple , Optional , Tuple , TypedDict
6+ from typing import Any , List , NamedTuple , Optional , Tuple
67
78from vllm .config import VllmConfig
89from vllm .logger import init_logger
@@ -28,14 +29,52 @@ class BlockHashType(NamedTuple):
2829 extra_keys : Optional [Any ] = None
2930
3031
31- class PrefixCachingMetrics ( TypedDict ) :
32- """Metrics for prefix caching."""
32+ class PrefixCachingMetrics :
33+ """Metrics for prefix caching with a hit rate of the most recent N requests.
3334
34- query_total : int
35- """The total number of queries."""
35+ Args:
36+ interval: The number of the most recent requests to aggregate.
37+ Defaults to 1000.
38+ """
39+
40+ def __init__ (self , interval : int = 1000 ):
41+ self .interval = interval
42+ self .aggregated_query_total = 0
43+ self .aggregated_query_hit = 0
44+ self .request_queries : deque [Tuple [int , int ]] = deque ()
3645
37- query_hit : int
38- """The number of queries that hit the prefix cache."""
46+ def add_request_query (self , num_queries : int , num_hits : int ):
47+ """Add a request to the metrics. This function is called when
48+ a new request is being scheduled and is looking for computed blocks.
49+ When there are more than `interval` requests, the oldest request
50+ is removed from the metrics.
51+
52+ Args:
53+ num_queries: The number of queries in the request.
54+ num_hits: The number of hits in the request.
55+ """
56+
57+ self .request_queries .append ((num_queries , num_hits ))
58+ if len (self .request_queries ) > self .interval :
59+ old_num_queries , old_num_hits = self .request_queries .popleft ()
60+ self .aggregated_query_total -= old_num_queries
61+ self .aggregated_query_hit -= old_num_hits
62+
63+ self .aggregated_query_total += num_queries
64+ self .aggregated_query_hit += num_hits
65+
66+ def reset (self ):
67+ """Reset the metrics."""
68+ self .aggregated_query_total = 0
69+ self .aggregated_query_hit = 0
70+ self .request_queries .clear ()
71+
72+ @property
73+ def hit_rate (self ) -> float :
74+ """Calculate the hit rate for the past N requests."""
75+ if self .aggregated_query_total == 0 :
76+ return 0.0
77+ return self .aggregated_query_hit / self .aggregated_query_total
3978
4079
4180@dataclass
0 commit comments