
Commit 2108a57

[DCP] Support dcp kv_cache interleave size > 1 (vllm-project#26696)

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: Qiu <qiuchunshuo@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>

1 parent 4760413 commit 2108a57

File tree

12 files changed: +202 −79 lines changed

tests/distributed/test_context_parallel.py

Lines changed: 7 additions & 0 deletions

@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
+    dcp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool

@@ -52,6 +53,7 @@ def detailed(
         tp_base: int = 4,
         pp_base: int = 1,
         dcp_base: int = 1,
+        dcp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,

@@ -66,6 +68,7 @@ def detailed(
             tp_size=tp_base,
             pp_size=pp_multiplier * pp_base,
             dcp_size=int(dcp_multiplier * tp_base),
+            dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
             eager_mode=eager_mode_val,
             chunked_prefill=chunked_prefill_val,
         )

@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
         tp_size,
         pp_size,
         dcp_size,
+        dcp_kv_cache_interleave_size,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup

@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
         str(pp_size),
         "--decode-context-parallel-size",
         str(dcp_size),
+        "--dcp-kv-cache-interleave-size",
+        str(dcp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,
     ]

@@ -207,6 +213,7 @@ def _compare_cp_with_tp(
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
+        CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
     ],
     "bigcode/gpt_bigcode-santacoder": [
         CPTestSettings.detailed(),

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 2 additions & 0 deletions

@@ -951,6 +951,7 @@ def test_hybrid_block_table_initialization():
     max_num_reqs = 10
     max_num_blocks_per_req = 20
     max_num_batched_tokens = 512
+    dcp_kv_cache_interleave_size = 8

     block_table = BlockTable(
         block_size=block_size,

@@ -960,6 +961,7 @@ def test_hybrid_block_table_initialization():
         pin_memory=False,
         device=torch.device(DEVICE),
         kernel_block_size=kernel_block_sizes[0],
+        dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
     )

     # Verify hybrid block configuration

vllm/attention/ops/common.py

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
     lse = tl.load(lses_ptr + lse_offsets)
     lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
     lse_max = tl.max(lse, axis=0)
+    lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
     lse -= lse_max
     lse_exp = tl.exp(lse)
     lse_acc = tl.sum(lse_exp, axis=0)
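
The added line guards the degenerate case where every CP rank reports an LSE of -inf for a head (no valid KV tokens anywhere, which becomes common once KV is interleaved in chunks): without it, the subsequent lse -= lse_max would compute -inf minus -inf and poison the output with NaNs. A rough PyTorch rendering of the merge this Triton kernel performs (a sketch for intuition only; merge_cp_outputs and the tensor shapes are assumptions, not the vLLM implementation):

import torch

def merge_cp_outputs(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    # outs: [cp_ranks, heads, head_dim] partial outputs; lses: [cp_ranks, heads]
    bad = torch.isnan(lses) | (lses == float("inf"))
    lses = torch.where(bad, torch.full_like(lses, float("-inf")), lses)
    lse_max = lses.max(dim=0).values
    # The patched line: if all ranks are -inf, use 0 so the subtraction
    # below yields -inf (weight 0) instead of NaN.
    lse_max = torch.where(lse_max == float("-inf"), torch.zeros_like(lse_max), lse_max)
    w = torch.exp(lses - lse_max)                      # [cp_ranks, heads]
    w = w / w.sum(dim=0).clamp_min(torch.finfo(w.dtype).tiny)
    return (outs * w.unsqueeze(-1)).sum(dim=0)         # [heads, head_dim]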

vllm/config/parallel.py

Lines changed: 11 additions & 0 deletions

@@ -227,6 +227,17 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

+    dcp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using dcp or cp > 1:
+    store interleave_size tokens on (d)cp rank i, then the next
+    interleave_size tokens on (d)cp rank i+1.
+    interleave_size=1 means token-level alignment: token i is stored on
+    rank i % (d)cp_size.
+    interleave_size=block_size means block-level alignment: block j on
+    rank i+1 starts to fill only after block j on rank i is full.
+    block_size must be greater than or equal to, and divisible by,
+    dcp_kv_cache_interleave_size.
+    """

     _api_process_count: int = Field(default=1, gt=0)
     """
     The number of API processes initialized.
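
To make the layout in the new docstring concrete, here is a minimal sketch of where a token lands under this scheme (illustrative only; dcp_rank_of_token is a hypothetical helper, not part of this change):

def dcp_rank_of_token(i: int, dcp_size: int, interleave_size: int) -> int:
    # Tokens are distributed round-robin in chunks of interleave_size.
    return (i // interleave_size) % dcp_size

# interleave_size=1: token-level round-robin, rank = i % dcp_size.
# interleave_size=block_size: a whole block fills on rank i before the
# same logical block starts filling on rank i+1.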

vllm/config/vllm.py

Lines changed: 17 additions & 0 deletions

@@ -608,6 +608,23 @@ def __post_init__(self):
         )
         current_platform.check_and_update_config(self)

+        assert (
+            self.parallel_config.dcp_kv_cache_interleave_size
+            <= self.cache_config.block_size
+            and self.cache_config.block_size
+            % self.parallel_config.dcp_kv_cache_interleave_size == 0
+        ), (
+            f"block_size ({self.cache_config.block_size}) must be "
+            "greater than or equal to, and divisible by, "
+            "dcp_kv_cache_interleave_size "
+            f"({self.parallel_config.dcp_kv_cache_interleave_size})."
+        )
+
+        assert (
+            self.parallel_config.dcp_kv_cache_interleave_size == 1
+            or self.speculative_config is None
+        ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported yet."

         # Do this after all the updates to compilation_config.mode
         if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
             self.compilation_config.set_splitting_ops_for_v1()
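
For instance, with block_size = 16 (an assumed example value, not taken from this diff), interleave sizes 1, 2, 4, 8, and 16 pass both checks, while 3 fails divisibility and 32 fails the upper bound:

block_size = 16  # assumed example value
for interleave in (1, 2, 4, 8, 16, 3, 32):
    ok = interleave <= block_size and block_size % interleave == 0
    print(interleave, "ok" if ok else "rejected")  # 3 and 32 are rejected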

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions

@@ -385,6 +385,7 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
+    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: int | None = None
     data_parallel_start_rank: int | None = None

@@ -750,6 +751,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
            "-dcp",
            **parallel_kwargs["decode_context_parallel_size"],
        )
+       parallel_group.add_argument(
+           "--dcp-kv-cache-interleave-size",
+           **parallel_kwargs["dcp_kv_cache_interleave_size"],
+       )
        parallel_group.add_argument(
            "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
        )

@@ -1518,6 +1523,7 @@ def create_engine_config(
            worker_cls=self.worker_cls,
            worker_extension_cls=self.worker_extension_cls,
            decode_context_parallel_size=self.decode_context_parallel_size,
+           dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
            _api_process_count=self._api_process_count,
            _api_process_rank=self._api_process_rank,
        )
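
With the flag registered, the interleave size can be set from the command line alongside the existing DCP flag; for example (model and parallel sizes are illustrative, mirroring the test settings above):

vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat \
    --tensor-parallel-size 2 \
    --decode-context-parallel-size 2 \
    --dcp-kv-cache-interleave-size 64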

vllm/v1/attention/backends/flash_attn.py

Lines changed: 11 additions & 2 deletions

@@ -43,6 +43,7 @@
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
+    get_dcp_local_seq_lens,
     get_kv_cache_layout,
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

@@ -238,6 +239,10 @@ def __init__(
         self.dcp_world_size = 1
         self.dcp_rank = 0

+        self.dcp_kv_cache_interleave_size = (
+            self.parallel_config.dcp_kv_cache_interleave_size
+        )
+
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )

@@ -352,8 +357,12 @@ def schedule(
             - common_attn_metadata.query_start_loc_cpu[:-1]
         )
         dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
-        dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
-            self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+
+        dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
+            dcp_context_kv_lens_cpu,
+            self.dcp_world_size,
+            self.dcp_rank,
+            self.dcp_kv_cache_interleave_size,
         )
         dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
         max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
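
The two deleted lines were the hard-coded token-level split (interleave size 1); get_dcp_local_seq_lens generalizes the per-rank context-length computation to chunked layouts. A scalar sketch of the arithmetic, under the chunk-round-robin layout described in vllm/config/parallel.py (illustrative only; the real helper is vectorized over a tensor of sequence lengths, and local_kv_len is a hypothetical name):

def local_kv_len(total_len: int, world_size: int, rank: int,
                 interleave_size: int) -> int:
    # Token t is stored on rank (t // interleave_size) % world_size.
    chunk = interleave_size * world_size            # tokens in one full round
    base = (total_len // chunk) * interleave_size   # from complete rounds
    rem = total_len % chunk                         # leftover partial round
    # The partial round fills rank 0's chunk first, then rank 1's, and so on.
    return base + min(max(rem - rank * interleave_size, 0), interleave_size)

# Example: total_len=10, world_size=2, interleave_size=4 ->
# rank 0 holds tokens 0-3 and 8-9 (6 tokens), rank 1 holds tokens 4-7 (4).
# With interleave_size=1 this reduces to token-level round-robin:
# rank r holds ceil((total_len - r) / world_size) tokens.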
