
Commit 9a9fda1

luccafong and Lu Fang authored
[Core] Support Local Chunked Attention for Hybrid KV Cache (vllm-project#19351)
Signed-off-by: Lucia Fang <fanglu@fb.com>
Signed-off-by: Lu Fang <fanglu@meta.com>
Signed-off-by: Lu Fang <fanglu@fb.com>
Co-authored-by: Lu Fang <fanglu@meta.com>
1 parent 466e878 commit 9a9fda1

File tree

9 files changed: +351 -19 lines changed


tests/v1/core/test_specialized_manager.py

Lines changed: 155 additions & 2 deletions
@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import random
+
 import torch
 
 from vllm.v1.core.block_pool import BlockPool
 from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
                                          KVCacheBlock)
-from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager
-from vllm.v1.kv_cache_interface import SlidingWindowSpec
+from vllm.v1.core.single_type_kv_cache_manager import (
+    ChunkedLocalAttentionManager, SlidingWindowManager)
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        SlidingWindowSpec)
 
 
 def get_sliding_window_manager(sliding_window_spec, block_pool):
@@ -17,6 +21,80 @@ def get_sliding_window_manager(sliding_window_spec, block_pool):
                                 kv_cache_group_id=0)
 
 
+def get_chunked_local_attention_manager(chunked_local_attention_spec,
+                                        block_pool):
+    return ChunkedLocalAttentionManager(chunked_local_attention_spec,
+                                        block_pool,
+                                        caching_hash_fn=lambda x: x,
+                                        kv_cache_group_id=0)
+
+
+def test_chunked_local_attention_possible_cached_prefix():
+    block_size = 2
+    chunked_local_attention_spec = ChunkedLocalAttentionSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    manager = get_chunked_local_attention_manager(chunked_local_attention_spec,
+                                                  block_pool)
+
+    def run_one_case(block_is_cached, tail_token, expect_length):
+        block_hash_list = [
+            BlockHash(i, ()) for i in range(len(block_is_cached))
+        ]
+
+        block_pool.cached_block_hash_to_block.clear()
+
+        # Mock the block pool with the cached blocks
+        for i, (block_hash,
+                is_cached) in enumerate(zip(block_hash_list, block_is_cached)):
+            if is_cached:
+                block_pool.cached_block_hash_to_block[BlockHashWithGroupId(
+                    block_hash, 0)] = {
+                        i: block_pool.blocks[i + 10],
+                    }
+
+        computed_blocks = manager.find_longest_cache_hit(
+            block_hashes=block_hash_list,
+            max_length=len(block_hash_list) * block_size + tail_token,
+            kv_cache_group_ids=[0],
+            block_pool=block_pool,
+            kv_cache_spec=chunked_local_attention_spec,
+            use_eagle=False)[0]
+        assert len(computed_blocks) == expect_length
+
+        assert all(block == block_pool.null_block
+                   for block in computed_blocks[:(expect_length - 1) // 2])
+
+    run_one_case([True], 0, 1)
+    run_one_case([True], 1, 1)
+    run_one_case([True, False], 0, 2)
+    run_one_case([True, False], 1, 2)
+    run_one_case([True, True], 0, 2)
+    run_one_case([True, True], 1, 2)
+    run_one_case([True, True, False], 0, 2)
+    run_one_case([True, True, False], 1, 2)
+    run_one_case([True, True, True], 0, 3)
+    run_one_case([True, True, True], 1, 3)
+    run_one_case([True, True, True, False], 0, 4)
+    run_one_case([True, True, True, False], 1, 4)
+    run_one_case([random.choice([True, False])] * 8 + [True], 1, 9)
+    run_one_case([random.choice([True, False])] * 8 + [False], 1, 8)
+    run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 10)
+    run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 10)
+    run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10)
+
+
 def test_sliding_window_possible_cached_prefix():
     block_size = 2
     sliding_window_spec = SlidingWindowSpec(
@@ -84,6 +162,58 @@ def run_one_case(block_is_cached, expect_length):
     ], 8)
 
 
+def test_chunked_local_attention_remove_skipped_blocks():
+    attention_spec = ChunkedLocalAttentionSpec(
+        block_size=2,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True)
+
+    manager = get_chunked_local_attention_manager(attention_spec, block_pool)
+
+    null_block_id = block_pool.null_block.block_id
+
+    def id_to_block_table(ids) -> list[KVCacheBlock]:
+        return [
+            KVCacheBlock(id_)
+            if id_ != null_block_id else block_pool.null_block for id_ in ids
+        ]
+
+    def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]):
+        for block, id_ in zip(block_table, ids):
+            if id_ == null_block_id:
+                assert block == block_pool.null_block
+            else:
+                assert block.block_id == id_
+
+    original_block_ids = [
+        1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010
+    ]
+    block_table = id_to_block_table(original_block_ids)
+    manager.req_to_blocks["test"] = block_table
+
+    manager.remove_skipped_blocks("test", 0)
+    assert_block_id(block_table, original_block_ids)
+
+    # For 4th token (0-indexed), token 0-3 is out of the local attention window.
+    manager.remove_skipped_blocks("test", 4)
+    assert_block_id(block_table, [null_block_id] * 2)
+
+    # For 6th token (0-indexed), token 4 - 6 are in local attention window,
+    # token 0 - 3 are out, 2 blocks can be removed.
+    manager.remove_skipped_blocks("test", 6)
+    assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:])
+    # For 12th token (0-indexed),
+    # token 0-11 are out, 6 block can be removed.
+    manager.remove_skipped_blocks("test", 12)
+    assert_block_id(block_table, [null_block_id] * 6)
+
+
 def test_sliding_window_remove_skipped_blocks():
     sliding_window_spec = SlidingWindowSpec(
         block_size=2,
@@ -172,3 +302,26 @@ def test_get_num_blocks_to_allocate():
                                               cached_blocks_1) == 20
     assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
                                               cached_blocks_2) == 15
+
+
+def test_chunked_local_attention_get_num_blocks_to_allocate():
+    block_size = 2
+    attention_spec = ChunkedLocalAttentionSpec(
+        block_size=block_size,
+        num_kv_heads=1,
+        head_size=1,
+        dtype=torch.float32,
+        attention_chunk_size=4,  # Placeholder value, not related to test result
+        use_mla=False,
+    )
+
+    block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True)
+    manager = get_chunked_local_attention_manager(attention_spec, block_pool)
+    cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)]
+    cached_blocks_2 = [block_pool.null_block for _ in range(5)
+                       ] + [KVCacheBlock(i + 1) for i in range(5)]
+
+    assert manager.get_num_blocks_to_allocate("1", 20 * block_size,
+                                              cached_blocks_1) == 20
+    assert manager.get_num_blocks_to_allocate("2", 20 * block_size,
+                                              cached_blocks_2) == 15
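The expected lengths in these tests follow from the chunked local attention window arithmetic: tokens attend only within their attention_chunk_size-token chunk, so blocks that lie entirely before the current chunk boundary can be served by the null block. A minimal sketch of that arithmetic, not part of this commit; the helper names below are hypothetical, using block_size=2 and attention_chunk_size=4 as in the fixtures above:

def first_token_in_local_window(num_tokens: int,
                                attention_chunk_size: int) -> int:
    # Tokens before the start of the current chunk fall outside the local
    # attention window.
    return (num_tokens // attention_chunk_size) * attention_chunk_size


def num_skippable_blocks(num_tokens: int, attention_chunk_size: int,
                         block_size: int) -> int:
    # Whole blocks that hold only out-of-window tokens can be swapped for the
    # null block (and freed).
    return first_token_in_local_window(num_tokens,
                                       attention_chunk_size) // block_size


assert num_skippable_blocks(4, 4, 2) == 2   # tokens 0-3 out of the window
assert num_skippable_blocks(6, 4, 2) == 2   # tokens 0-3 out, 4-6 still local
assert num_skippable_blocks(12, 4, 2) == 6  # tokens 0-11 out of the window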

vllm/attention/layer.py

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ def __init__(
                 kv_sharing_target_layer_name, **extra_impl_args)
         self.backend = backend_name_to_enum(attn_backend.get_name())
         self.dtype = dtype
+        self.use_irope = extra_impl_args.get("use_irope", False)
 
         # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
         # torch.compile works by registering the attention as one giant
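The new attribute records whether the layer was built with iRoPE-style chunked local attention. A hypothetical consumer, not shown in this commit, might aggregate the flag across layers to populate the use_local_attention argument added to use_cascade_attention below:

# Hypothetical helper (assumption, not from this commit): derive the
# cascade-attention guard input from per-layer Attention.use_irope flags.
def any_layer_uses_local_attention(attn_layers) -> bool:
    return any(getattr(layer, "use_irope", False) for layer in attn_layers)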

vllm/config.py

Lines changed: 7 additions & 0 deletions
@@ -4722,6 +4722,13 @@ def __post_init__(self):
         if self.kv_events_config is not None:
             # Hybrid KV cache manager is not compatible with KV events.
             self.scheduler_config.disable_hybrid_kv_cache_manager = True
+        if self.model_config is not None and \
+            self.model_config.attention_chunk_size is not None and \
+            self.speculative_config is not None and \
+            self.speculative_config.use_eagle():
+            # Hybrid KV cache manager is not yet supported with chunked
+            # local attention + eagle.
+            self.scheduler_config.disable_hybrid_kv_cache_manager = True
 
     def update_sizes_for_sequence_parallelism(self,
                                               possible_sizes: list) -> list:
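In effect, the new check falls back to the unified (non-hybrid) KV cache manager whenever a model with chunked local attention (attention_chunk_size set) is served with EAGLE speculative decoding. A hypothetical restatement of the predicate, not the commit's code:

def disable_hybrid_manager_for_eagle(attention_chunk_size,
                                     uses_eagle: bool) -> bool:
    # Chunked local attention combined with EAGLE is not yet supported by the
    # hybrid KV cache manager, so the scheduler falls back to the unified one.
    return attention_chunk_size is not None and uses_eagle


assert disable_hybrid_manager_for_eagle(8192, uses_eagle=True)
assert not disable_hybrid_manager_for_eagle(None, uses_eagle=True)
assert not disable_hybrid_manager_for_eagle(8192, uses_eagle=False)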

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 1 deletion
@@ -538,6 +538,7 @@ def use_cascade_attention(
     num_kv_heads: int,
     use_alibi: bool,
     use_sliding_window: bool,
+    use_local_attention: bool,
     num_sms: int,
 ) -> bool:
     """Decide whether to use cascade attention.
@@ -553,7 +554,7 @@ def use_cascade_attention(
     if common_prefix_len < 256:
         return False
     # Cascade attention is currently not supported with these variants.
-    if use_alibi or use_sliding_window:
+    if use_alibi or use_sliding_window or use_local_attention:
         return False
     # Too few queries. Probably not worth using cascade attention.
     # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
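Both backends gain the same use_local_attention parameter; the FlashAttention backend additionally wires it into the early-exit guard so cascade attention is never chosen for chunked local (iRoPE) layers. A minimal standalone sketch of that guard, assuming the caller derives the flag from the layer's use_irope setting; the real use_cascade_attention takes more parameters than shown here:

def cascade_allowed(common_prefix_len: int, use_alibi: bool,
                    use_sliding_window: bool,
                    use_local_attention: bool) -> bool:
    if common_prefix_len < 256:
        return False
    # Cascade attention is not supported with ALiBi, sliding window, or
    # chunked local (iRoPE) attention.
    if use_alibi or use_sliding_window or use_local_attention:
        return False
    return True


assert cascade_allowed(1024, False, False, use_local_attention=True) is False
assert cascade_allowed(1024, False, False, use_local_attention=False) is True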

vllm/v1/attention/backends/utils.py

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ def use_cascade_attention(
     num_kv_heads: int,
     use_alibi: bool,
     use_sliding_window: bool,
+    use_local_attention: bool,
     num_sms: int,
 ) -> bool:
     return False

vllm/v1/core/kv_cache_utils.py

Lines changed: 16 additions & 3 deletions
@@ -11,7 +11,8 @@
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit
-from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
+from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec, KVCacheSpec,
                                         KVCacheTensor, SlidingWindowSpec)
 from vllm.v1.metrics.stats import PrefixCacheStats
@@ -976,7 +977,11 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
         isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
     has_sliding_window = any(
         isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
-    if has_full_attention and has_sliding_window:
+    has_chunked_local_attention = any(
+        isinstance(spec, ChunkedLocalAttentionSpec)
+        for spec in kv_cache_spec.values())
+    if has_full_attention and (has_sliding_window
+                               or has_chunked_local_attention):
         for layer_name, spec in kv_cache_spec.items():
             if isinstance(spec, SlidingWindowSpec):
                 kv_cache_spec[layer_name] = FullAttentionSpec(
@@ -987,6 +992,15 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
                     use_mla=spec.use_mla,
                     sliding_window=spec.sliding_window,
                 )
+            elif isinstance(spec, ChunkedLocalAttentionSpec):
+                kv_cache_spec[layer_name] = FullAttentionSpec(
+                    block_size=spec.block_size,
+                    num_kv_heads=spec.num_kv_heads,
+                    head_size=spec.head_size,
+                    dtype=spec.dtype,
+                    use_mla=spec.use_mla,
+                    attention_chunk_size=spec.attention_chunk_size,
+                )
 
     if is_hybrid(kv_cache_spec):
         raise ValueError("Hybrid KV cache manager is disabled but failed to "
@@ -1010,7 +1024,6 @@ def get_kv_cache_config(
         The generated KVCacheConfigs
     """
     check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
-
     if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
         unify_hybrid_kv_cache_specs(kv_cache_spec)
