7 changes: 7 additions & 0 deletions tests/distributed/test_context_parallel.py
@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
dcp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool

@@ -52,6 +53,7 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
@@ -66,6 +68,7 @@ def detailed(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
dcp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(dcp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
@@ -207,6 +213,7 @@ def _compare_cp_with_tp(
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
2 changes: 2 additions & 0 deletions tests/v1/worker/test_gpu_model_runner.py
@@ -951,6 +951,7 @@ def test_hybrid_block_table_initialization():
max_num_reqs = 10
max_num_blocks_per_req = 20
max_num_batched_tokens = 512
dcp_kv_cache_interleave_size = 8

block_table = BlockTable(
block_size=block_size,
@@ -960,6 +961,7 @@
pin_memory=False,
device=torch.device(DEVICE),
kernel_block_size=kernel_block_sizes[0],
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
)

# Verify hybrid block configuration
1 change: 1 addition & 0 deletions vllm/attention/ops/common.py
@@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel(
lse = tl.load(lses_ptr + lse_offsets)
lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
lse_max = tl.max(lse, axis=0)
lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
lse -= lse_max
lse_exp = tl.exp(lse)
lse_acc = tl.sum(lse_exp, axis=0)
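The lse_max clamp added here guards the fully masked case: when every entry of lse is -inf, subtracting a -inf maximum turns the whole row into NaN. A small torch sketch, separate from the Triton kernel, of the failure mode and the fix:

import torch

# Illustrative only: mirrors the kernel's correction math on the host.
# If every lse entry is -inf (this rank attends to nothing), the max is
# -inf and lse - lse_max becomes -inf - (-inf) = nan.
lse = torch.full((4,), -float("inf"))
print(torch.exp(lse - lse.max()))  # tensor([nan, nan, nan, nan])

# Clamping the max to 0 first keeps the row a harmless all-zero contribution.
lse_max = torch.where(lse.max() == -float("inf"), torch.tensor(0.0), lse.max())
print(torch.exp(lse - lse_max))  # tensor([0., 0., 0., 0.])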
11 changes: 11 additions & 0 deletions vllm/config/parallel.py
@@ -227,6 +227,17 @@ class is dynamically inherited by the worker class. This is used to inject
not change with DCP; it simply reuses the GPUs of the TP group, and tp_size
needs to be divisible by dcp_size."""

dcp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using dcp or cp > 1,
store interleave_size tokens on (d)cp i,
then store next interleave_size tokens on (d)cp i+1.
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
Interleave_size=block_size: block-level align, first fill the block on first rank,
token is stored on rank i+1 block j after rank i block j is full.
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
Block_size should be divisible by dcp_kv_cache_interleave_size.
"""

_api_process_count: int = Field(default=1, gt=0)
"""
The number of API processes initialized.
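To make the docstring's placement rule concrete, here is a minimal sketch (a hypothetical helper, not part of this PR) of which rank stores a given token:

# Hypothetical helper illustrating the interleaved KV placement rule.
def rank_for_token(token_idx: int, dcp_size: int, interleave_size: int) -> int:
    return (token_idx // interleave_size) % dcp_size

# interleave_size=1 round-robins tokens across ranks...
assert [rank_for_token(i, 2, 1) for i in range(4)] == [0, 1, 0, 1]
# ...while interleave_size=4 fills four tokens on rank 0 before moving on.
assert [rank_for_token(i, 2, 4) for i in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]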
17 changes: 17 additions & 0 deletions vllm/config/vllm.py
@@ -608,6 +608,23 @@ def __post_init__(self):
)
current_platform.check_and_update_config(self)

assert (
self.parallel_config.dcp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be "
"greater than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
)

assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."

# Do this after all the updates to compilation_config.mode
if self.compilation_config.mode == CompilationMode.VLLM_COMPILE:
self.compilation_config.set_splitting_ops_for_v1()
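As a quick sanity check of the new constraint (illustrative values; assuming block_size=16): interleave sizes 1, 8, and 16 pass, 3 fails divisibility, and 32 exceeds the block size:

block_size = 16
for interleave in (1, 8, 16, 3, 32):
    ok = interleave <= block_size and block_size % interleave == 0
    print(interleave, "ok" if ok else "rejected")  # 1/8/16 ok, 3/32 rejected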
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -385,6 +385,7 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: int | None = None
data_parallel_start_rank: int | None = None
@@ -750,6 +751,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"-dcp",
**parallel_kwargs["decode_context_parallel_size"],
)
parallel_group.add_argument(
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
)
parallel_group.add_argument(
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
)
@@ -1518,6 +1523,7 @@ def create_engine_config(
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
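For reference, the new flag rides alongside the existing DCP flag; a hypothetical invocation (model and sizes are illustrative, echoing the test matrix above):

vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat \
    --tensor-parallel-size 2 \
    --decode-context-parallel-size 2 \
    --dcp-kv-cache-interleave-size 64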
13 changes: 11 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -43,6 +43,7 @@
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -238,6 +239,10 @@ def __init__(
self.dcp_world_size = 1
self.dcp_rank = 0

self.dcp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)

self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
@@ -352,8 +357,12 @@ def schedule(
- common_attn_metadata.query_start_loc_cpu[:-1]
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
-dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
-    self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
-)
+dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
+    dcp_context_kv_lens_cpu,
+    self.dcp_world_size,
+    self.dcp_rank,
+    self.dcp_kv_cache_interleave_size,
+)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
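The builder now delegates the per-rank context-length math to get_dcp_local_seq_lens. A hedged sketch of what such a helper has to compute, with the signature inferred from the call site above (not the actual implementation):

import torch

def get_dcp_local_seq_lens_sketch(
    seq_lens: torch.Tensor,  # total context KV tokens per request
    dcp_world_size: int,
    dcp_rank: int,
    interleave_size: int,
) -> torch.Tensor:
    cycle = dcp_world_size * interleave_size
    # Every full cycle of the interleave pattern gives this rank
    # exactly interleave_size tokens.
    full = (seq_lens // cycle) * interleave_size
    # The partial last cycle contributes between 0 and interleave_size more,
    # depending on how far the remainder reaches past this rank's slot.
    rem = seq_lens % cycle - dcp_rank * interleave_size
    return full + rem.clamp(min=0, max=interleave_size)

# With interleave_size=1 this is token-level round-robin: each rank gets
# seq_lens // world_size tokens, plus one if its rank index is below
# seq_lens % world_size.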