1 file changed: +13 -4 lines changed
Original file line number | Diff line number | Diff line change
@@ -1189,11 +1189,20 @@ def _report_kv_cache_config(
11891189 // len (kv_cache_config .kv_cache_groups )
11901190 * min_block_size
11911191 )
1192- if vllm_config .parallel_config .decode_context_parallel_size > 1 :
1193- num_tokens *= vllm_config .parallel_config .decode_context_parallel_size
1192+ if (
1193+ vllm_config .parallel_config .prefill_context_parallel_size *
1194+ vllm_config .parallel_config .decode_context_parallel_size > 1
1195+ ):
1196+ num_tokens *= (vllm_config .parallel_config .prefill_context_parallel_size *
1197+ vllm_config .parallel_config .decode_context_parallel_size )
1198+ cp_size = (vllm_config .parallel_config .prefill_context_parallel_size *
1199+ vllm_config .parallel_config .decode_context_parallel_size )
11941200 logger .info (
1195- "Multiplying the GPU KV cache size by the dcp_world_size %d." ,
1196- vllm_config .parallel_config .decode_context_parallel_size ,
1201+ "Multiplying the GPU KV cache size by the cp_world_size %d "
1202+ "(pcp_world_size %d * dcp_world_size %d)." ,
1203+ cp_size ,
1204+ vllm_config .parallel_config .prefill_context_parallel_size ,
1205+ vllm_config .parallel_config .decode_context_parallel_size
11971206 )
11981207 num_tokens_str = f"{ num_tokens :,} "
11991208 logger .info ("GPU KV cache size: %s tokens" , num_tokens_str )
You can’t perform that action at this time.
0 commit comments