[DCP] Support dcp kv_cache interleave size > 1 #26696
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces support for configurable interleave size for KV cache in Decode Context Parallelism (DCP), which is a nice enhancement for flexibility. The changes also include refactoring the dcp_local_seq_lens computation into a utility function. The implementation is mostly solid, but I've identified a couple of areas for improvement. One is a misleading error message in an assertion, and the other is an opportunity to refactor a new utility function for better readability and efficiency. Addressing these points will improve the code quality.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Thanks for your contribution! Requesting changes to prevent merging until the test results can be obtained 👍
    tp_base: int = 4,
    pp_base: int = 1,
    dcp_base: int = 1,
    cp_kv_cache_interleave_size: int = 1,
Please call this `dcp_kv_cache_interleave_size`.
After prefill CP (#25852) is supported, this kv_cache_interleave_size will be used for both DCP and PCP; shall we keep this name for future use?
Sure; by that logic we should also update dcp_local_seq_lens to cp_local_seq_lens, but we can do that in the PCP PR.
vllm/v1/worker/block_table.py Outdated
    self,
    req_indices: np.ndarray,
    positions: np.ndarray,
    cp_kv_cache_interleave_size: int = 1,
Since this is a constant, pass it via the init.
Thanks for the review; we have already passed it via the init.
vllm/utils/__init__.py Outdated
        i += 1


def get_dcp_local_seq_lens(
We should find a better spot for this; this is too broad of a utils file for a feature-specific utility.
We've now put this function in vllm/v1/attention/backends/utils.py, the same place where CommonAttentionMetadata.dcp_local_seq_lens is defined; this should be a more appropriate spot.
vllm/v1/worker/gpu_model_runner.py Outdated
    # update seq_lens of decode reqs under DCP.
    if self.dcp_world_size > 1:
        self.dcp_local_seq_lens.gpu[:num_reqs] = get_dcp_local_seq_lens(
I think it might actually be better to compute get_dcp_local_seq_lens using host buffers and then do a non-blocking copy to self.dcp_local_seq_lens.gpu (see: CpuGpuBuffer.copy_to_gpu)
(then when async scheduling is enabled it will be overlapped)
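For illustration, here is a minimal sketch of the suggested host-buffer-then-non-blocking-copy pattern. The class below only mimics the idea behind CpuGpuBuffer; the names and fields are simplified assumptions, not the actual vLLM implementation.

```python
import torch

# Minimal sketch (assumed names, not vLLM's real CpuGpuBuffer): fill a pinned
# host buffer first, then issue a single non-blocking H2D copy so the transfer
# can overlap with other work, e.g. when async scheduling is enabled.
class CpuGpuBufferSketch:
    def __init__(self, size: int, device: str = "cuda"):
        self.cpu = torch.zeros(size, dtype=torch.int32, pin_memory=True)
        self.gpu = torch.zeros(size, dtype=torch.int32, device=device)

    def copy_to_gpu(self, n: int) -> torch.Tensor:
        # Non-blocking copy from pinned host memory into the GPU buffer.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)
        return self.gpu[:n]

# Usage idea: compute the local seq lens into buf.cpu[:num_reqs] with host-side
# ops, then call buf.copy_to_gpu(num_reqs) instead of assigning to .gpu directly.
```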
Modified as suggested, thanks for review
youzhedian left a comment
LGTM.
vllm/v1/worker/gpu_model_runner.py Outdated
    self.max_model_len = model_config.max_model_len
    self.dcp_world_size = self.parallel_config.decode_context_parallel_size
    try:
        self.dcp_rank = get_dcp_group().rank_in_group
Would it be better to delay this to the get_dcp_local_seq_lens call?
In some cases we might need to know how seq_len is split globally, rather than only the local seq_len on the current dcp_rank. For example, in our current NPU MLA implementation we need the global seq_len split to calculate a mask for the following update_lse (if no kv_cache is stored on some (d)cp_ranks, there's no need to do the corresponding update_lse). So we think it's better to return the full seq_len split result from get_dcp_local_seq_lens and let each dcp_rank select its corresponding part as needed.
nit: I think we can simplify this to:

    self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group

That way we'll still get the benefit of the assert in get_dcp_group(), and if a test sets self.dcp_world_size > 1 it should be initializing the DCP group anyway.
better way to get dcp_rank 👍 Modified as suggested
vllm/v1/worker/block_table.py Outdated
    self,
    req_indices: np.ndarray,
    positions: np.ndarray,
    cp_kv_cache_interleave_size: int = 1,
nit: maybe having no default value is better
Thanks for the review; we now pass this arg via the init since it's a constant.
LucasWilkinson left a comment
Overall looking really good! Thanks for all the cleanup; a couple of final comments.
    )
    # Note(qcs): The context lengths in the perspective of dcp rank0
    # padded to `dcp_kv_cache_interleave_size`.
    local_context_lens_cpu = (
Is this correct? I'm surprised this is dcp_rank agnostic, am I missing something?

Right now it's:

    local_context_lens_cpu = cdiv(
        context_lens_cpu, self.dcp_world_size * self.dcp_kv_cache_interleave_size
    ) * self.dcp_kv_cache_interleave_size

but I would have thought it should be something like:

    virtual_block_size = self.dcp_world_size * self.dcp_kv_cache_interleave_size
    partial_virtual_block = context_lens_cpu % virtual_block_size
    base = context_lens_cpu // virtual_block_size
    extra = min(
        max(0, partial_virtual_block - self.dcp_rank * self.dcp_kv_cache_interleave_size),
        self.dcp_kv_cache_interleave_size,
    )
    local_context_lens_cpu = base + extra

(Also, can we use cdiv here to make this more readable? i.e. from vllm.utils import cdiv.)
I've changed this comment to: "The max local context lengths padded to dcp_local_block_size."

> is this correct? im surprised this is dcp_rank agnostic, am i missing something?

Yes, this value is dcp_rank agnostic. The reason I initially mentioned rank 0 is that in the current implementation the sequence on rank 0 is typically the longest by default; replacing that description with "maximum length" is indeed more accurate.
Won't this over-gather on everything other than rank 0? Previously there was this correction code in reorg_kvcache to correct to the current rank:

    if rank > cp_target_rank and cp_chunk_seq_len:
        real_cp_chunk_seq_len = cp_chunk_seq_len - 1
    else:
        real_cp_chunk_seq_len = cp_chunk_seq_len

Also, this is no longer the rank 0 context length, since we could have a partial block on rank 0. Previously:

    # Note(hc): The context lengths in the perspective of dcp rank0.
    cp_context_lens_cpu = torch.ceil(
        context_lens_cpu.float() / self.dcp_world_size
    ).int()

was correct because with an interleave size of 1 there could never be a partial block on rank 1.

Am I missing something?
Sorry for the delay; we've identified another bug and are working on a fix.

> won't this over gather on everything other than rank 0?

Yes, the current approach would cause over-gathering on other ranks (a small illustration follows below). We've changed local_context_lens_cpu to accurately reflect the KV cache length on each device and avoid unnecessary gathering (1a94cb7).

However, one question remains: is the performance of all_gather and all_gatherv consistent? To my knowledge, all_gatherv suffers from performance degradation on NPUs; I'm not sure whether similar issues exist on GPUs or other devices.
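To make the over-gathering point concrete, here is a small illustration with made-up numbers (not taken from this PR):

```python
# Made-up example: dcp_world_size=4, interleave_size=2, context_len=13.
world, ilv, ctx = 4, 2, 13
rank_of_token = [(t // ilv) % world for t in range(ctx)]
exact_per_rank = [rank_of_token.count(r) for r in range(world)]  # [4, 4, 3, 2]
padded_per_rank = -(-ctx // (world * ilv)) * ilv                 # cdiv(13, 8) * 2 == 4
# Gathering `padded_per_rank` tokens from every rank over-fetches one token on
# rank 2 and two on rank 3; those extra tokens have to be dropped afterwards.
```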
Test results
DeepSeek-V2-Lite-Chat with FlashMLA
TP8 DCP8, concurrency=256
interleave 1
| dataset | version | metric | mode | vllm-api-stream-chat |
|---|---|---|---|---|
| gsm8kdataset | - | accuracy | gen | 69.14 |
interleave 8
| dataset | version | metric | mode | vllm-api-stream-chat |
|---|---|---|---|---|
| gsm8kdataset | - | accuracy | gen | 69.37 |
interleave 64
| dataset | version | metric | mode | vllm-api-stream-chat |
|---|---|---|---|---|
| gsm8kdataset | - | accuracy | gen | 68.69 |
New bug
Regarding the newly discovered bug: it occurs in the Triton MLA backend and existed prior to this PR (tested on 650b51f). Specifically, when testing with the gsm8k dataset, we observe some outputs consisting entirely of "!" as concurrency increases (2-3 cases at concurrency=5, and dozens of cases at concurrency=256, leading to a 6% drop in accuracy). This indicates NaN values occurring during the prefill phase, though the root cause has not yet been identified. Resolving it will likely require significant time, so we plan to address it in a subsequent PR.
CC @youzhedian
Got it; makes sense! You are correct, we should go with the all-gather; that was my bad 🤦 I didn't realize reorg_kvcache was un-padding. I think we need to improve the readability of this padded-all-gather-then-un-padding logic (apologies, this is basically tech debt from the original DCP PR, not you; I was never very happy with, and never totally understood, the reorg_kvcache code, and was hoping to refactor it but haven't had a chance).

Good call, let's go back to the all_gather approach, but try to make it more readable to avoid confusion. Some suggestions:
local_max_context_chunk -> padded_local_max_context_chunk_across_ranks
local_chunk_seq_lens -> padded_local_chunk_seq_lens
local_cu_chunk_seq_lens_cpu -> padded_local_cu_chunk_seq_lens_cpu
local_cu_seq_lens -> padded_local_cu_seq_lens
local_context_lens_allrank -> local_context_lens_allranks
reorg_kvcache <- add a comment indicating this removes the padding that was added to make the all-gather a fixed size across ranks, or maybe rename reorg_kvcache -> unpad_kvcache (a toy sketch of this pad/un-pad pattern follows after this list)
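To illustrate the pattern these renames describe, here is a toy sketch of a padded, fixed-size all-gather followed by per-rank un-padding. All names and the tensor layout are simplifications for illustration, not the actual reorg_kvcache code.

```python
import torch
import torch.distributed as dist

# Toy sketch: every rank pads its local KV chunk to the same length so a
# fixed-size all_gather can be used, then the per-rank padding is stripped
# (roughly the role the un-padding step in reorg_kvcache plays).
def gather_and_unpad(
    local_kv: torch.Tensor,            # [local_len, head_dim] local KV chunk
    local_len: int,
    local_lens_allranks: list[int],    # true (unpadded) lengths on every rank
    padded_len: int,                   # max length, identical across ranks
) -> torch.Tensor:
    world = dist.get_world_size()
    padded = torch.zeros(padded_len, *local_kv.shape[1:],
                         dtype=local_kv.dtype, device=local_kv.device)
    padded[:local_len] = local_kv[:local_len]
    gathered = torch.empty(world * padded_len, *local_kv.shape[1:],
                           dtype=local_kv.dtype, device=local_kv.device)
    dist.all_gather_into_tensor(gathered, padded)
    chunks = gathered.view(world, padded_len, *local_kv.shape[1:])
    # Drop the padding each rank contributed before concatenating.
    return torch.cat([chunks[r, :local_lens_allranks[r]] for r in range(world)], dim=0)
```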
Indeed, it took us considerable effort to understand this function initially. I've updated the comments and added examples for clarity.
Sure, I've renamed those variables based on your feedback. Thanks for the review
This pull request has merge conflicts that must be resolved before it can be merged.
LucasWilkinson left a comment
LGTM! Thanks for all the hard work, and sorry about the back and forth.
We should refactor reorg_kvcache at some point (maybe use a Triton kernel), but that can be done in the future; I appreciate the clarifying comments/renaming!
Purpose
1. cp_kv_cache_interleave_size support
In the DCP scenario, the kv_cache is split across DCP ranks. The current implementation (#23734) splits the kv_cache in a token-level interleaved style: token i is stored on the GPU whose dcp_rank == i % dcp_world_size.
To make PD (prefill/decode) disaggregation easier to support, we add the cp_kv_cache_interleave_size argument to control the interleave size of the kv_cache split: store interleave_size tokens on dcp rank i, then the next interleave_size tokens on dcp rank i+1, and so on (see the sketch below). The default value of cp_kv_cache_interleave_size is 1, which matches the original token-level interleaving. Setting cp_kv_cache_interleave_size to block_size splits the kv_cache in a block-level interleaved style, which makes PD disaggregation with dcp > 1 easy to support: D nodes only need to pull the corresponding kv_cache blocks, with no need to rearrange tokens within blocks.
Only DCP with cp_kv_cache_interleave_size is supported for now, but the (p)cp case has also been considered and should be easy to extend in the future.
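As a sketch of the token-to-rank mapping described above (the helper name is hypothetical and not part of this PR):

```python
# Hypothetical helper: which DCP rank stores token `i` when the kv_cache is
# split in chunks of `interleave_size` tokens across `dcp_world_size` ranks.
def token_to_dcp_rank(i: int, dcp_world_size: int, interleave_size: int = 1) -> int:
    return (i // interleave_size) % dcp_world_size

# interleave_size=1 reproduces the original token-level split ...
assert [token_to_dcp_rank(i, 4, 1) for i in range(8)] == [0, 1, 2, 3, 0, 1, 2, 3]
# ... while interleave_size=2 keeps 2 consecutive tokens on a rank before moving on.
assert [token_to_dcp_rank(i, 4, 2) for i in range(8)] == [0, 0, 1, 1, 2, 2, 3, 3]
```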
2. Move dcp_local_seq_lens computation to utils
Move the dcp_local_seq_lens computation into a utility function and pass the result via metadata, so other attention backends can reuse it.
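For reference, a hedged sketch of how the per-rank local sequence lengths can be derived from the global seq_lens under this interleaving; the name and signature are assumptions and may differ from the actual get_dcp_local_seq_lens:

```python
import torch

def get_dcp_local_seq_lens_sketch(
    seq_lens: torch.Tensor,        # [num_reqs] total tokens per request
    dcp_world_size: int,
    interleave_size: int = 1,
) -> torch.Tensor:                 # [num_reqs, dcp_world_size]
    """Sketch: how many tokens of each request each DCP rank holds locally."""
    virtual_block = dcp_world_size * interleave_size
    # Tokens from fully filled rounds are identical on every rank.
    base = (seq_lens // virtual_block) * interleave_size
    # The last partial round contributes at most one chunk per rank.
    rem = seq_lens % virtual_block
    ranks = torch.arange(dcp_world_size)
    extra = (rem.unsqueeze(-1) - ranks * interleave_size).clamp(0, interleave_size)
    return base.unsqueeze(-1) + extra

# Example: a 10-token request on 4 ranks with interleave 1 -> [[3, 3, 2, 2]].
print(get_dcp_local_seq_lens_sketch(torch.tensor([10]), 4, 1))
```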
Test Plan
Model: DeepSeek-V2-Lite-Chat
Dataset: gsm8k
Test Result
tp2 dcp2, original code
tp2 dcp2, interleave_size = 1
tp2 dcp2, interleave_size = 64