@@ -119,7 +119,10 @@ def swap_row(self, src: int, tgt: int) -> None:
119119 self .block_table .np [src_tgt ] = self .block_table .np [tgt_src ]
120120
121121 def compute_slot_mapping (
122- self , req_indices : np .ndarray , positions : np .ndarray
122+ self ,
123+ req_indices : np .ndarray ,
124+ positions : np .ndarray ,
125+ cp_kv_cache_interleave_size : int = 1 ,
123126 ) -> None :
124127 # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
125128 # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
@@ -144,9 +147,19 @@ def compute_slot_mapping(
144147 # Use virtual_block_size for mask calculation, which marks local
145148 # tokens.
146149 virtual_block_offsets = positions % virtual_block_size
147- mask = virtual_block_offsets % self .dcp_world_size == self .dcp_rank
150+ mask = (
151+ virtual_block_offsets
152+ // cp_kv_cache_interleave_size
153+ % self .dcp_world_size
154+ == self .dcp_rank
155+ )
148156 # Calculate local block_offsets
149- block_offsets = virtual_block_offsets // self .dcp_world_size
157+ block_offsets = (
158+ virtual_block_offsets
159+ // (self .dcp_world_size * cp_kv_cache_interleave_size )
160+ * cp_kv_cache_interleave_size
161+ + virtual_block_offsets % cp_kv_cache_interleave_size
162+ )
150163 # Calculate slot_mapping
151164 slot_mapping = block_numbers * self .block_size + block_offsets
152165 # Write final slots, use -1 for not-local
@@ -284,10 +297,17 @@ def swap_row(self, src: int, tgt: int) -> None:
284297 block_table .swap_row (src , tgt )
285298
286299 def compute_slot_mapping (
287- self , req_indices : np .ndarray , positions : np .ndarray
300+ self ,
301+ req_indices : np .ndarray ,
302+ positions : np .ndarray ,
303+ cp_kv_cache_interleave_size : int = 1 ,
288304 ) -> None :
289305 for block_table in self .block_tables :
290- block_table .compute_slot_mapping (req_indices , positions )
306+ block_table .compute_slot_mapping (
307+ req_indices ,
308+ positions ,
309+ cp_kv_cache_interleave_size ,
310+ )
291311
292312 def commit_block_table (self , num_reqs : int ) -> None :
293313 for block_table in self .block_tables :
0 commit comments