11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
4+ import random
5+
46import torch
57
68from vllm .v1 .core .block_pool import BlockPool
79from vllm .v1 .core .kv_cache_utils import (BlockHash , BlockHashWithGroupId ,
810 KVCacheBlock )
9- from vllm .v1 .core .single_type_kv_cache_manager import SlidingWindowManager
10- from vllm .v1 .kv_cache_interface import SlidingWindowSpec
11+ from vllm .v1 .core .single_type_kv_cache_manager import (
12+ ChunkedLocalAttentionManager , SlidingWindowManager )
13+ from vllm .v1 .kv_cache_interface import (ChunkedLocalAttentionSpec ,
14+ SlidingWindowSpec )
1115
1216
1317def get_sliding_window_manager (sliding_window_spec , block_pool ):
@@ -17,6 +21,80 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
1721 kv_cache_group_id = 0 )
1822
1923
24+ def get_chunked_local_attention_manager (chunked_local_attention_spec ,
25+ block_pool ):
26+ return ChunkedLocalAttentionManager (chunked_local_attention_spec ,
27+ block_pool ,
28+ caching_hash_fn = lambda x : x ,
29+ kv_cache_group_id = 0 )
30+
31+
32+ def test_chunked_local_attention_possible_cached_prefix ():
33+ block_size = 2
34+ chunked_local_attention_spec = ChunkedLocalAttentionSpec (
35+ block_size = block_size ,
36+ num_kv_heads = 1 ,
37+ head_size = 1 ,
38+ dtype = torch .float32 ,
39+ attention_chunk_size = 4 ,
40+ use_mla = False ,
41+ )
42+
43+ block_pool = BlockPool (num_gpu_blocks = 100 , enable_caching = True )
44+ manager = get_chunked_local_attention_manager (chunked_local_attention_spec ,
45+ block_pool )
46+
47+ def run_one_case (block_is_cached , tail_token , expect_length ):
48+ block_hash_list = [
49+ BlockHash (i , ()) for i in range (len (block_is_cached ))
50+ ]
51+
52+ block_pool .cached_block_hash_to_block .clear ()
53+
54+ # Mock the block pool with the cached blocks
55+ for i , (block_hash ,
56+ is_cached ) in enumerate (zip (block_hash_list , block_is_cached )):
57+ if is_cached :
58+ block_pool .cached_block_hash_to_block [BlockHashWithGroupId (
59+ block_hash , 0 )] = {
60+ i : block_pool .blocks [i + 10 ],
61+ }
62+
63+ computed_blocks = manager .find_longest_cache_hit (
64+ block_hashes = block_hash_list ,
65+ max_length = len (block_hash_list ) * block_size + tail_token ,
66+ kv_cache_group_ids = [0 ],
67+ block_pool = block_pool ,
68+ kv_cache_spec = chunked_local_attention_spec ,
69+ use_eagle = False )[0 ]
70+ assert len (computed_blocks ) == expect_length
71+
72+ assert all (block == block_pool .null_block
73+ for block in computed_blocks [:(expect_length - 1 ) // 2 ])
74+
75+ run_one_case ([True ], 0 , 1 )
76+ run_one_case ([True ], 1 , 1 )
77+ run_one_case ([True , False ], 0 , 2 )
78+ run_one_case ([True , False ], 1 , 2 )
79+ run_one_case ([True , True ], 0 , 2 )
80+ run_one_case ([True , True ], 1 , 2 )
81+ run_one_case ([True , True , False ], 0 , 2 )
82+ run_one_case ([True , True , False ], 1 , 2 )
83+ run_one_case ([True , True , True ], 0 , 3 )
84+ run_one_case ([True , True , True ], 1 , 3 )
85+ run_one_case ([True , True , True , False ], 0 , 4 )
86+ run_one_case ([True , True , True , False ], 1 , 4 )
87+ run_one_case ([random .choice ([True , False ])] * 8 + [True ], 1 , 9 )
88+ run_one_case ([random .choice ([True , False ])] * 8 + [False ], 1 , 8 )
89+ run_one_case ([random .choice ([True , False ])] * 8 + [True , True ], 1 , 10 )
90+ run_one_case ([random .choice ([True , False ])] * 8 + [True , False ], 0 , 10 )
91+ run_one_case ([random .choice ([True , False ])] * 8 + [True , False ], 1 , 10 )
92+ run_one_case ([random .choice ([True , False ])] * 8 + [False , True ], 0 , 10 )
93+ run_one_case ([random .choice ([True , False ])] * 8 + [False , True ], 1 , 10 )
94+ run_one_case ([random .choice ([True , False ])] * 8 + [False , False ], 0 , 10 )
95+ run_one_case ([random .choice ([True , False ])] * 8 + [False , False ], 1 , 10 )
96+
97+
2098def test_sliding_window_possible_cached_prefix ():
2199 block_size = 2
22100 sliding_window_spec = SlidingWindowSpec (
@@ -84,6 +162,58 @@ def run_one_case(block_is_cached, expect_length):
84162 ], 8 )
85163
86164
165+ def test_chunked_local_attention_remove_skipped_blocks ():
166+ attention_spec = ChunkedLocalAttentionSpec (
167+ block_size = 2 ,
168+ num_kv_heads = 1 ,
169+ head_size = 1 ,
170+ dtype = torch .float32 ,
171+ attention_chunk_size = 4 ,
172+ use_mla = False ,
173+ )
174+
175+ block_pool = BlockPool (num_gpu_blocks = 2000 , enable_caching = True )
176+
177+ manager = get_chunked_local_attention_manager (attention_spec , block_pool )
178+
179+ null_block_id = block_pool .null_block .block_id
180+
181+ def id_to_block_table (ids ) -> list [KVCacheBlock ]:
182+ return [
183+ KVCacheBlock (id_ )
184+ if id_ != null_block_id else block_pool .null_block for id_ in ids
185+ ]
186+
187+ def assert_block_id (block_table : list [KVCacheBlock ], ids : list [int ]):
188+ for block , id_ in zip (block_table , ids ):
189+ if id_ == null_block_id :
190+ assert block == block_pool .null_block
191+ else :
192+ assert block .block_id == id_
193+
194+ original_block_ids = [
195+ 1000 , 1001 , 1002 , 1003 , 1004 , 1005 , 1006 , 1007 , 1008 , 1009 , 1010
196+ ]
197+ block_table = id_to_block_table (original_block_ids )
198+ manager .req_to_blocks ["test" ] = block_table
199+
200+ manager .remove_skipped_blocks ("test" , 0 )
201+ assert_block_id (block_table , original_block_ids )
202+
203+ # For 4th token (0-indexed), token 0-3 is out of the local attention window.
204+ manager .remove_skipped_blocks ("test" , 4 )
205+ assert_block_id (block_table , [null_block_id ] * 2 )
206+
207+ # For 6th token (0-indexed), token 4 - 6 are in local attention window,
208+ # token 0 - 3 are out, 2 blocks can be removed.
209+ manager .remove_skipped_blocks ("test" , 6 )
210+ assert_block_id (block_table , [null_block_id ] * 2 + original_block_ids [2 :])
211+ # For 12th token (0-indexed),
212+ # token 0-11 are out, 6 block can be removed.
213+ manager .remove_skipped_blocks ("test" , 12 )
214+ assert_block_id (block_table , [null_block_id ] * 6 )
215+
216+
87217def test_sliding_window_remove_skipped_blocks ():
88218 sliding_window_spec = SlidingWindowSpec (
89219 block_size = 2 ,
@@ -172,3 +302,26 @@ def test_get_num_blocks_to_allocate():
172302 cached_blocks_1 ) == 20
173303 assert manager .get_num_blocks_to_allocate ("2" , 20 * block_size ,
174304 cached_blocks_2 ) == 15
305+
306+
307+ def test_chunked_local_attention_get_num_blocks_to_allocate ():
308+ block_size = 2
309+ attention_spec = ChunkedLocalAttentionSpec (
310+ block_size = block_size ,
311+ num_kv_heads = 1 ,
312+ head_size = 1 ,
313+ dtype = torch .float32 ,
314+ attention_chunk_size = 4 , # Placeholder value, not related to test result
315+ use_mla = False ,
316+ )
317+
318+ block_pool = BlockPool (num_gpu_blocks = 100 , enable_caching = True )
319+ manager = get_chunked_local_attention_manager (attention_spec , block_pool )
320+ cached_blocks_1 = [KVCacheBlock (i + 1 ) for i in range (10 )]
321+ cached_blocks_2 = [block_pool .null_block for _ in range (5 )
322+ ] + [KVCacheBlock (i + 1 ) for i in range (5 )]
323+
324+ assert manager .get_num_blocks_to_allocate ("1" , 20 * block_size ,
325+ cached_blocks_1 ) == 20
326+ assert manager .get_num_blocks_to_allocate ("2" , 20 * block_size ,
327+ cached_blocks_2 ) == 15
0 commit comments