Commit d7b5ae7

[refactor] rename and clean code
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
1 parent 72e3a0f commit d7b5ae7

3 files changed: +70 additions, −79 deletions

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 1 deletion
@@ -233,8 +233,9 @@ def __init__(
         self.dcp_world_size = 1
         self.dcp_rank = 0

-        self.dcp_kv_cache_interleave_size = \
+        self.dcp_kv_cache_interleave_size = (
             self.parallel_config.dcp_kv_cache_interleave_size
+        )

         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()

vllm/v1/attention/backends/mla/common.py

Lines changed: 64 additions & 76 deletions
@@ -359,10 +359,9 @@ class ChunkedContextMetadata
        workspace: torch.Tensor

        # for mla DCP
-       cp_chunk_seq_lens: list[list[int]] | None = None
-       origin_context_lens: list[int] | None = None
-       cp_cu_seq_lens: torch.Tensor | None = None
-       chunk_size: int | None = None
+       local_chunk_seq_lens: list[list[int]] | None = None
+       local_context_lens_allrank: list[list[int]] | None = None
+       local_cu_seq_lens: torch.Tensor | None = None
        cu_seq_lens_lst: list[list[int]] | None = None

        block_table: torch.Tensor
@@ -555,7 +554,8 @@ def __init__(
         # DCP might not be initialized in testing
         self.dcp_world_size = 1
         self.dcp_rank = 0
-        self.dcp_kv_cache_interleave_size = parallel_config.dcp_kv_cache_interleave_size
+        self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
+        self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size

        # Don't try to access the runner on AMD
        if self.aot_schedule:
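
Note for reviewers: the rename makes the two sizes explicit — `dcp_local_block_size` is the old interleave size, and `dcp_virtual_block_size` is that size times the DCP world size, i.e. one full round of blocks across all ranks. A standalone sketch (illustrative, not vLLM code; `map_token_to_rank` is a hypothetical helper) of the layout these two sizes describe, assuming tokens are dealt out round-robin in local-block-sized chunks:

# Sketch: map a logical token index to (rank, local index) under the
# interleaved DCP KV-cache layout implied by the two attributes above.
def map_token_to_rank(token_idx: int, local_block_size: int, world_size: int) -> tuple[int, int]:
    virtual_block_size = local_block_size * world_size  # one local block per rank
    vblock, offset = divmod(token_idx, virtual_block_size)
    rank = offset // local_block_size  # owner rank inside this virtual block
    local_idx = vblock * local_block_size + offset % local_block_size
    return rank, local_idx

# world_size=4, local_block_size=2: tokens 0-1 land on rank 0, 2-3 on
# rank 1, ..., and 8-9 wrap back around to rank 0.
for t in range(10):
    print(t, map_token_to_rank(t, local_block_size=2, world_size=4))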
@@ -784,15 +784,6 @@ def build(
        reqs_start = num_decodes  # prefill_start

        context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
-       # Note(hc): The context lengths in the perspective of dcp rank0.
-       cp_context_lens_cpu = (
-           torch.ceil(
-               context_lens_cpu.float()
-               / (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
-           ).int()
-           * self.dcp_kv_cache_interleave_size
-       )
-       origin_context_lens = context_lens_cpu.tolist()
        max_context_len_cpu = context_lens_cpu.max().item()
        num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
        prefill_query_start_loc = (
@@ -848,32 +839,55 @@ def build(
848839
)
849840

850841
if self.dcp_world_size > 1:
842+
local_context_lens_allrank = get_dcp_local_seq_lens(
843+
context_lens_cpu,
844+
self.dcp_world_size,
845+
None,
846+
self.dcp_local_block_size,
847+
)
848+
# Note(qcs): The max local context lengths
849+
# padded to `dcp_local_block_size`.
850+
local_context_lens_cpu = (
851+
cdiv(
852+
context_lens_cpu,
853+
self.dcp_virtual_block_size,
854+
)
855+
* self.dcp_local_block_size
856+
)
851857
# Note(hc): The above max_context_chunk already enforces
852858
# block_size alignment, DCP just need the block_size can
853859
# be divisible by dcp_world_size, because DCP use
854860
# cp_gather_cache which not require `cp_chunk_starts`
855861
# aligned to page_size.
856862
assert max_context_chunk % self.dcp_world_size == 0
857-
cp_max_context_chunk = max_context_chunk // self.dcp_world_size
858-
cp_chunk_starts = (
863+
local_max_context_chunk = (
864+
cdiv(
865+
max_context_chunk,
866+
self.dcp_virtual_block_size,
867+
)
868+
* self.dcp_local_block_size
869+
)
870+
local_chunk_starts = (
859871
torch.arange(num_chunks, dtype=torch.int32)
860872
.unsqueeze(1)
861873
.expand(-1, num_prefills)
862-
* cp_max_context_chunk
874+
* local_max_context_chunk
863875
)
864-
cp_chunk_ends = torch.min(
865-
cp_context_lens_cpu.unsqueeze(0),
866-
cp_chunk_starts + cp_max_context_chunk,
876+
local_chunk_ends = torch.min(
877+
local_context_lens_cpu.unsqueeze(0),
878+
local_chunk_starts + local_max_context_chunk,
867879
)
868-
cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0)
880+
local_chunk_seq_lens = (
881+
local_chunk_ends - local_chunk_starts
882+
).clamp(min=0)
869883

870-
cp_cu_seq_lens_cpu = torch.zeros(
884+
local_cu_chunk_seq_lens_cpu = torch.zeros(
871885
num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
872886
)
873887
torch.cumsum(
874-
cp_chunk_seq_lens,
888+
local_chunk_seq_lens,
875889
dim=1,
876-
out=cp_cu_seq_lens_cpu[:, 1:],
890+
out=local_cu_chunk_seq_lens_cpu[:, 1:],
877891
dtype=torch.int32,
878892
)
879893
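The new `cdiv(x, dcp_virtual_block_size) * dcp_local_block_size` expression replaces the old `torch.ceil`-based computation: it bounds how many of a request's context tokens any single rank can hold, rounded up to a whole local block. A worked scalar example (illustrative, outside this diff):

# Padded per-rank context length, as computed in the hunk above.
def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division

world_size, local_block_size = 4, 2
virtual_block_size = local_block_size * world_size  # 8 tokens per full round

for context_len in (7, 8, 13, 16):
    padded_local = cdiv(context_len, virtual_block_size) * local_block_size
    print(context_len, "->", padded_local)
# 7 -> 2, 8 -> 2, 13 -> 4, 16 -> 4: each rank holds at most `padded_local`
# of the context tokens, padded up to a whole local block.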

@@ -885,15 +899,16 @@ def build(
            if self.dcp_world_size > 1:
                chunked_context_metadata = chunked_context_metadata_cls(
                    cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                   starts=cp_chunk_starts.to(device, non_blocking=True),
-                   seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
+                   starts=local_chunk_starts.to(device, non_blocking=True),
+                   seq_tot=local_chunk_seq_lens.sum(dim=1).tolist(),
                    max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                    seq_lens=chunk_seq_lens,
                    workspace=self.chunked_prefill_workspace,
-                   cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
-                   origin_context_lens=origin_context_lens,
-                   cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True),
-                   chunk_size=max_context_chunk,
+                   local_chunk_seq_lens=local_chunk_seq_lens.tolist(),
+                   local_context_lens_allrank=local_context_lens_allrank.tolist(),
+                   local_cu_seq_lens=local_cu_chunk_seq_lens_cpu.to(
+                       device, non_blocking=True
+                   ),
                    cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                )
            else:
@@ -970,70 +985,48 @@ def build(
 def reorg_kvcache(
     allgatered_kv_c_normed: torch.Tensor,
     allgatered_k_pe: torch.Tensor,
-    cp_chunk_seq_lens_lst: list[int],
-    origin_context_lens: list[int],
-    cp_world_size: int,
+    local_chunk_seq_lens_lst: list[int],
+    local_context_lens_allrank: list[list[int]],
     sum_seq_len: int,
     max_seq_len: int,
-    chunk_size: int,
-    chunk_idx: int,
     toks: int,
-    interleave_size: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     reorg kvcache after cp local gather to tp layout for attn kernel.

     Args:
-        cp_chunk_seq_lens_lst: chunk context lengths under CP.
-        origin_context_lens: origin full context lengths under CP.
-        cp_world_size: CP size.
+        local_chunk_seq_lens_lst: local chunk context lengths
+            under current CP rank.
+        local_context_lens_allrank: local context lengths on each CP rank.
        sum_seq_len: the sum of cp_chunk_seq_lens_lst.
        max_seq_len: the max value of cp_chunk_seq_lens_lst.
-       chunk_size: equals to max_context_chunk from
-           chunked_context_metadata building.
-       chunk_idx: chunk idx of chunked_prefill.
        toks: the number of tokens for local gather cache.
-       interleave_size: Interleave size of kv_cache storage.
    """
    kv_c_segments = []
    k_pe_segments = []
    src_token_idx = 0
    max_seq_len_check = 0
-   local_context_lens_allrank = get_dcp_local_seq_lens(
-       torch.Tensor(origin_context_lens),
-       cp_world_size,
-       None,
-       interleave_size,
-   )
-   # print(origin_context_lens, local_context_lens_allrank)
-   for cp_chunk_seq_len, origin_context_len, local_context_lens in zip(
-       cp_chunk_seq_lens_lst, origin_context_lens, local_context_lens_allrank
+   for local_chunk_seq_len, local_context_lens in zip(
+       local_chunk_seq_lens_lst, local_context_lens_allrank
    ):
-       chunk_context_len = chunk_size
-       if cp_chunk_seq_len != 0:
-           chunk_context_len = min(
-               chunk_context_len, origin_context_len - chunk_size * chunk_idx
-           )
-
        cur_seq_len = 0
-       for rank in range(cp_world_size):
-           real_cp_chunk_seq_len = local_context_lens[rank]
-           if real_cp_chunk_seq_len != 0:
+       for rank, local_context_len in enumerate(local_context_lens):
+           if local_context_len != 0:
                kv_c_segment = allgatered_kv_c_normed[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
-                   + real_cp_chunk_seq_len
+                   + local_context_len
                ]
                k_pe_segment = allgatered_k_pe[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
-                   + real_cp_chunk_seq_len
+                   + local_context_len
                ]
                kv_c_segments.append(kv_c_segment)
                k_pe_segments.append(k_pe_segment)
-               cur_seq_len += real_cp_chunk_seq_len
+               cur_seq_len += local_context_len
        max_seq_len_check = max(max_seq_len_check, cur_seq_len)
-       src_token_idx += cp_chunk_seq_len
+       src_token_idx += local_chunk_seq_len
    reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
    reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
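For review, a condensed standalone sketch of the gather pattern `reorg_kvcache` now performs (illustrative, not the vLLM function; it assumes, as the `rank * toks` offsets imply, that the all-gathered buffer stacks each rank's `toks`-token workspace back to back):

import torch

def reorg_sketch(
    allgathered: torch.Tensor,                    # [world_size * toks, d]
    local_chunk_seq_lens: list[int],              # this rank's chunk length per request
    local_context_lens_allrank: list[list[int]],  # per request, per rank lengths
    toks: int,
) -> torch.Tensor:
    segments = []
    src = 0
    for chunk_len, per_rank in zip(local_chunk_seq_lens, local_context_lens_allrank):
        # Walk ranks in order so each request's context tokens come out
        # contiguous and in sequence order for the attention kernel.
        for rank, n in enumerate(per_rank):
            if n:
                start = rank * toks + src
                segments.append(allgathered[start : start + n])
        src += chunk_len  # every rank uses the same per-request workspace offset
    return torch.cat(segments, dim=0)

Passing `local_context_lens_allrank` in precomputed from `build()` is what lets the commit drop the `chunk_size`/`chunk_idx`/`interleave_size` parameters and the per-call `get_dcp_local_seq_lens` recomputation.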
@@ -1591,10 +1584,9 @@ def _context_parallel_compute_prefill_context(
        assert attn_metadata.prefill is not None
        prefill_metadata = attn_metadata.prefill
        assert prefill_metadata.chunked_context is not None
-       assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
-       assert prefill_metadata.chunked_context.origin_context_lens is not None
-       assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
-       assert prefill_metadata.chunked_context.chunk_size is not None
+       assert prefill_metadata.chunked_context.local_chunk_seq_lens is not None
+       assert prefill_metadata.chunked_context.local_context_lens_allrank is not None
+       assert prefill_metadata.chunked_context.local_cu_seq_lens is not None
        assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None

        output = None
@@ -1607,7 +1599,7 @@ def _context_parallel_compute_prefill_context(
                src_cache=kv_c_and_k_pe_cache,
                dst=workspace,
                block_table=prefill_metadata.block_table,
-               cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
+               cu_seq_lens=prefill_metadata.chunked_context.local_cu_seq_lens[i],
                batch_size=attn_metadata.num_prefills,
                seq_starts=prefill_metadata.chunked_context.starts[i],
            )
@@ -1637,17 +1629,13 @@ def _context_parallel_compute_prefill_context(
            kv_c_normed, k_pe = reorg_kvcache(
                allgatered_kv_c_normed,
                allgatered_k_pe,
-               cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[
+               local_chunk_seq_lens_lst=prefill_metadata.chunked_context.local_chunk_seq_lens[
                    i
                ],
-               origin_context_lens=prefill_metadata.chunked_context.origin_context_lens,
-               cp_world_size=dcp_world_size,
+               local_context_lens_allrank=prefill_metadata.chunked_context.local_context_lens_allrank,
                sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
                max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
-               chunk_size=prefill_metadata.chunked_context.chunk_size,
-               chunk_idx=i,
                toks=toks,
-               interleave_size=self.dcp_kv_cache_interleave_size,
            )

            kv_nope = self.kv_b_proj(kv_c_normed)[0].view(

vllm/v1/attention/backends/utils.py

Lines changed: 4 additions & 2 deletions
@@ -1013,7 +1013,9 @@ def get_dcp_local_seq_lens(
        )
    else:
        rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
-   seq_lens_tiled = seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
+   seq_lens_tiled = (
+       seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
+   )
    base = (
        seq_lens_tiled
        // dcp_kv_cache_interleave_size

@@ -1027,4 +1029,4 @@ def get_dcp_local_seq_lens(
        dcp_kv_cache_interleave_size,
    )
    dcp_local_seq_lens = base + remainder
-   return dcp_local_seq_lens.squeeze(1)
\ No newline at end of file
+   return dcp_local_seq_lens.squeeze(1)
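
As a sanity check on the `base + remainder` split above, a scalar re-derivation (an assumption for illustration only; the real function is vectorized over `seq_lens` and ranks, and only fragments of it appear in this diff):

# How many tokens a given rank owns when seq_len tokens are dealt
# round-robin in interleave-sized blocks across world_size ranks.
def local_seq_len(seq_len: int, world_size: int, rank: int, interleave: int) -> int:
    full_blocks = seq_len // interleave              # complete interleave blocks
    tail = seq_len % interleave                      # tokens in the final partial block
    base = (full_blocks // world_size) * interleave  # whole rounds every rank gets
    extra = full_blocks % world_size                 # ranks holding one extra full block
    if rank < extra:
        return base + interleave
    if rank == extra:
        return base + tail
    return base

print([local_seq_len(13, world_size=4, rank=r, interleave=2) for r in range(4)])
# [4, 4, 3, 2] -- sums to 13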
