@@ -359,10 +359,9 @@ class ChunkedContextMetadata:
         workspace: torch.Tensor

         # for mla DCP
-        cp_chunk_seq_lens: list[list[int]] | None = None
-        origin_context_lens: list[int] | None = None
-        cp_cu_seq_lens: torch.Tensor | None = None
-        chunk_size: int | None = None
+        local_chunk_seq_lens: list[list[int]] | None = None
+        local_context_lens_allrank: list[list[int]] | None = None
+        local_cu_seq_lens: torch.Tensor | None = None
         cu_seq_lens_lst: list[list[int]] | None = None

     block_table: torch.Tensor
@@ -555,7 +554,8 @@ def __init__(
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
-        self.dcp_kv_cache_interleave_size = parallel_config.dcp_kv_cache_interleave_size
+        self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
+        self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size

         # Don't try to access the runner on AMD
         if self.aot_schedule:
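
The rename above reads more clearly once the layout is spelled out: `dcp_kv_cache_interleave_size` is the per-rank block size of the interleaved DCP KV-cache layout, and one "virtual" block covers that many tokens on every DCP rank. A minimal sketch of the relationship, assuming the block-round-robin token-to-rank assignment implied by the interleave size (the constants and the `owning_rank` helper are illustrative, not vLLM API):

```python
# Illustrative only: how dcp_local_block_size and dcp_virtual_block_size relate,
# assuming tokens are assigned to DCP ranks in round-robin blocks of
# dcp_local_block_size tokens.
dcp_world_size = 4
dcp_local_block_size = 16  # renamed from dcp_kv_cache_interleave_size
dcp_virtual_block_size = dcp_local_block_size * dcp_world_size  # 64


def owning_rank(token_idx: int) -> int:
    # Hypothetical helper: which DCP rank holds this context token's KV.
    return (token_idx // dcp_local_block_size) % dcp_world_size


# One virtual block (64 tokens) touches each rank exactly once.
assert [owning_rank(t) for t in (0, 16, 32, 48, 64)] == [0, 1, 2, 3, 0]
```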
@@ -784,15 +784,6 @@ def build(
             reqs_start = num_decodes  # prefill_start

             context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
-            # Note(hc): The context lengths in the perspective of dcp rank0.
-            cp_context_lens_cpu = (
-                torch.ceil(
-                    context_lens_cpu.float()
-                    / (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
-                ).int()
-                * self.dcp_kv_cache_interleave_size
-            )
-            origin_context_lens = context_lens_cpu.tolist()
             max_context_len_cpu = context_lens_cpu.max().item()
             num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
             prefill_query_start_loc = (
@@ -848,32 +839,55 @@ def build(
                 )

                 if self.dcp_world_size > 1:
+                    local_context_lens_allrank = get_dcp_local_seq_lens(
+                        context_lens_cpu,
+                        self.dcp_world_size,
+                        None,
+                        self.dcp_local_block_size,
+                    )
+                    # Note(qcs): The max local context lengths
+                    # padded to `dcp_local_block_size`.
+                    local_context_lens_cpu = (
+                        cdiv(
+                            context_lens_cpu,
+                            self.dcp_virtual_block_size,
+                        )
+                        * self.dcp_local_block_size
+                    )
                     # Note(hc): The above max_context_chunk already enforces
                     # block_size alignment, DCP just need the block_size can
                     # be divisible by dcp_world_size, because DCP use
                     # cp_gather_cache which not require `cp_chunk_starts`
                     # aligned to page_size.
                     assert max_context_chunk % self.dcp_world_size == 0
-                    cp_max_context_chunk = max_context_chunk // self.dcp_world_size
-                    cp_chunk_starts = (
+                    local_max_context_chunk = (
+                        cdiv(
+                            max_context_chunk,
+                            self.dcp_virtual_block_size,
+                        )
+                        * self.dcp_local_block_size
+                    )
+                    local_chunk_starts = (
                         torch.arange(num_chunks, dtype=torch.int32)
                         .unsqueeze(1)
                         .expand(-1, num_prefills)
-                        * cp_max_context_chunk
+                        * local_max_context_chunk
                     )
-                    cp_chunk_ends = torch.min(
-                        cp_context_lens_cpu.unsqueeze(0),
-                        cp_chunk_starts + cp_max_context_chunk,
+                    local_chunk_ends = torch.min(
+                        local_context_lens_cpu.unsqueeze(0),
+                        local_chunk_starts + local_max_context_chunk,
                     )
-                    cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0)
+                    local_chunk_seq_lens = (
+                        local_chunk_ends - local_chunk_starts
+                    ).clamp(min=0)

-                    cp_cu_seq_lens_cpu = torch.zeros(
+                    local_cu_chunk_seq_lens_cpu = torch.zeros(
                         num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
                     )
                     torch.cumsum(
-                        cp_chunk_seq_lens,
+                        local_chunk_seq_lens,
                         dim=1,
-                        out=cp_cu_seq_lens_cpu[:, 1:],
+                        out=local_cu_chunk_seq_lens_cpu[:, 1:],
                         dtype=torch.int32,
                     )

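A toy recomputation of the padding arithmetic above, under the same assumed interleaved layout (plain Python, not vLLM code): each rank's padded local context length is `cdiv(context_len, dcp_virtual_block_size) * dcp_local_block_size`, i.e. the per-rank share rounded up to whole local blocks. The numbers below are illustrative.

```python
# Toy check of the padded per-rank length used for local_context_lens_cpu
# and local_max_context_chunk above (illustrative numbers only).
def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division


dcp_world_size, dcp_local_block_size = 4, 16
dcp_virtual_block_size = dcp_local_block_size * dcp_world_size  # 64

context_len = 970
padded_local_len = cdiv(context_len, dcp_virtual_block_size) * dcp_local_block_size
assert padded_local_len == 256  # 16 virtual blocks -> 16 local blocks per rank
# Under a block-round-robin assignment the largest real per-rank share here is
# 250 tokens, so the padded value only over-allocates chunk bookkeeping slots;
# the true per-rank lengths come from get_dcp_local_seq_lens.
```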
@@ -885,15 +899,16 @@ def build(
                 if self.dcp_world_size > 1:
                     chunked_context_metadata = chunked_context_metadata_cls(
                         cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
-                        starts=cp_chunk_starts.to(device, non_blocking=True),
-                        seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(),
+                        starts=local_chunk_starts.to(device, non_blocking=True),
+                        seq_tot=local_chunk_seq_lens.sum(dim=1).tolist(),
                         max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                         seq_lens=chunk_seq_lens,
                         workspace=self.chunked_prefill_workspace,
-                        cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(),
-                        origin_context_lens=origin_context_lens,
-                        cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True),
-                        chunk_size=max_context_chunk,
+                        local_chunk_seq_lens=local_chunk_seq_lens.tolist(),
+                        local_context_lens_allrank=local_context_lens_allrank.tolist(),
+                        local_cu_seq_lens=local_cu_chunk_seq_lens_cpu.to(
+                            device, non_blocking=True
+                        ),
                         cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
                     )
                 else:
@@ -970,70 +985,48 @@ def build(
 def reorg_kvcache(
     allgatered_kv_c_normed: torch.Tensor,
     allgatered_k_pe: torch.Tensor,
-    cp_chunk_seq_lens_lst: list[int],
-    origin_context_lens: list[int],
-    cp_world_size: int,
+    local_chunk_seq_lens_lst: list[int],
+    local_context_lens_allrank: list[list[int]],
     sum_seq_len: int,
     max_seq_len: int,
-    chunk_size: int,
-    chunk_idx: int,
     toks: int,
-    interleave_size: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """
     reorg kvcache after cp local gather to tp layout for attn kernel.

     Args:
-        cp_chunk_seq_lens_lst: chunk context lengths under CP.
-        origin_context_lens: origin full context lengths under CP.
-        cp_world_size: CP size.
+        local_chunk_seq_lens_lst: local chunk context lengths
+            under current CP rank.
+        local_context_lens_allrank: local context lengths on each CP rank.
         sum_seq_len: the sum of cp_chunk_seq_lens_lst.
         max_seq_len: the max value of cp_chunk_seq_lens_lst.
-        chunk_size: equals to max_context_chunk from
-            chunked_context_metadata building.
-        chunk_idx: chunk idx of chunked_prefill.
         toks: the number of tokens for local gather cache.
-        interleave_size: Interleave size of kv_cache storage.
     """
     kv_c_segments = []
     k_pe_segments = []
     src_token_idx = 0
     max_seq_len_check = 0
-    local_context_lens_allrank = get_dcp_local_seq_lens(
-        torch.Tensor(origin_context_lens),
-        cp_world_size,
-        None,
-        interleave_size,
-    )
-    # print(origin_context_lens, local_context_lens_allrank)
-    for cp_chunk_seq_len, origin_context_len, local_context_lens in zip(
-        cp_chunk_seq_lens_lst, origin_context_lens, local_context_lens_allrank
+    for local_chunk_seq_len, local_context_lens in zip(
+        local_chunk_seq_lens_lst, local_context_lens_allrank
     ):
-        chunk_context_len = chunk_size
-        if cp_chunk_seq_len != 0:
-            chunk_context_len = min(
-                chunk_context_len, origin_context_len - chunk_size * chunk_idx
-            )
-
         cur_seq_len = 0
-        for rank in range(cp_world_size):
-            real_cp_chunk_seq_len = local_context_lens[rank]
-            if real_cp_chunk_seq_len != 0:
+        for rank, local_context_len in enumerate(local_context_lens):
+            if local_context_len != 0:
                 kv_c_segment = allgatered_kv_c_normed[
                     rank * toks + src_token_idx : rank * toks
                     + src_token_idx
-                    + real_cp_chunk_seq_len
+                    + local_context_len
                 ]
                 k_pe_segment = allgatered_k_pe[
                     rank * toks + src_token_idx : rank * toks
                     + src_token_idx
-                    + real_cp_chunk_seq_len
+                    + local_context_len
                 ]
                 kv_c_segments.append(kv_c_segment)
                 k_pe_segments.append(k_pe_segment)
-                cur_seq_len += real_cp_chunk_seq_len
+                cur_seq_len += local_context_len
         max_seq_len_check = max(max_seq_len_check, cur_seq_len)
-        src_token_idx += cp_chunk_seq_len
+        src_token_idx += local_chunk_seq_len
     reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
     reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
     assert reorganized_kv_c_normed.shape[0] == sum_seq_len
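To make the new `reorg_kvcache` contract concrete, here is a minimal numeric sketch (hypothetical sizes and lengths, not vLLM code) of the gather-then-reorganize step: the all-gathered buffer stacks each rank's `toks`-row workspace, and for every request the per-rank slices are concatenated in rank order so the attention kernel sees one contiguous context per request.

```python
# Minimal sketch of the reorg pattern with 2 hypothetical DCP ranks.
import torch

toks = 8                                        # workspace rows per rank
local_chunk_seq_lens_lst = [3, 2]               # this chunk's per-request lengths on the local rank
local_context_lens_allrank = [[3, 3], [2, 1]]   # per-request context lengths held by rank 0 / rank 1

# Rows 0..7 stand in for rank 0's workspace copy, rows 8..15 for rank 1's.
allgathered = torch.arange(2 * toks, dtype=torch.float32).unsqueeze(1)

segments, src = [], 0
for chunk_len, per_rank_lens in zip(local_chunk_seq_lens_lst, local_context_lens_allrank):
    for rank, n in enumerate(per_rank_lens):
        if n:
            # Same slicing pattern as reorg_kvcache: rank offset + running token index.
            segments.append(allgathered[rank * toks + src : rank * toks + src + n])
    src += chunk_len
reorganized = torch.cat(segments, dim=0)

# Request 0 contributes rows [0, 1, 2] from rank 0 and [8, 9, 10] from rank 1;
# request 1 contributes [3, 4] and [11].
assert reorganized[:, 0].tolist() == [0, 1, 2, 8, 9, 10, 3, 4, 11]
```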
@@ -1591,10 +1584,9 @@ def _context_parallel_compute_prefill_context(
         assert attn_metadata.prefill is not None
         prefill_metadata = attn_metadata.prefill
         assert prefill_metadata.chunked_context is not None
-        assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None
-        assert prefill_metadata.chunked_context.origin_context_lens is not None
-        assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None
-        assert prefill_metadata.chunked_context.chunk_size is not None
+        assert prefill_metadata.chunked_context.local_chunk_seq_lens is not None
+        assert prefill_metadata.chunked_context.local_context_lens_allrank is not None
+        assert prefill_metadata.chunked_context.local_cu_seq_lens is not None
         assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None

         output = None
@@ -1607,7 +1599,7 @@ def _context_parallel_compute_prefill_context(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
                 block_table=prefill_metadata.block_table,
-                cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i],
+                cu_seq_lens=prefill_metadata.chunked_context.local_cu_seq_lens[i],
                 batch_size=attn_metadata.num_prefills,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
             )
@@ -1637,17 +1629,13 @@ def _context_parallel_compute_prefill_context(
             kv_c_normed, k_pe = reorg_kvcache(
                 allgatered_kv_c_normed,
                 allgatered_k_pe,
-                cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[
+                local_chunk_seq_lens_lst=prefill_metadata.chunked_context.local_chunk_seq_lens[
                     i
                 ],
-                origin_context_lens=prefill_metadata.chunked_context.origin_context_lens,
-                cp_world_size=dcp_world_size,
+                local_context_lens_allrank=prefill_metadata.chunked_context.local_context_lens_allrank,
                 sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
                 max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
-                chunk_size=prefill_metadata.chunked_context.chunk_size,
-                chunk_idx=i,
                 toks=toks,
-                interleave_size=self.dcp_kv_cache_interleave_size,
             )

             kv_nope = self.kv_b_proj(kv_c_normed)[0].view(