
Commit 272c8f1

code cleanup and fix scheduler_block_size

Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
1 parent: f0ab17c

File tree

6 files changed: +25 -45 lines
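In brief: the commit deduplicates the context-parallel LSE-correction path in vllm/attention/ops/common.py into a shared _cp_lse_common helper, removes four thin get_*_context_model_parallel_* wrapper functions in favor of reading get_pcp_group() / get_dcp_group() directly, tightens a few assertion messages, and fixes scheduler_block_size to account for prefill_context_parallel_size.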

vllm/attention/ops/common.py

Lines changed: 17 additions & 18 deletions
@@ -168,12 +168,11 @@ def correct_attn_out(
     return out, lse


-def cp_lse_ag_out_rs(
+def _cp_lse_common(
     cp_attn_out: torch.Tensor,
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext = None,
-    return_lse=False,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -195,6 +194,21 @@ def cp_lse_ag_out_rs(
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
     out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
     assert out.is_contiguous()
+    return out, lse
+
+
+def cp_lse_ag_out_rs(
+    cp_attn_out: torch.Tensor,
+    cp_attn_lse: torch.Tensor,
+    cp_group: GroupCoordinator,
+    ctx: CPTritonContext = None,
+    return_lse: bool = False,
+):
+    """
+    cp_attn_out: [ B, H, D ]
+    cp_attn_lse: [ B, H ]
+    """
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
     out = cp_group.reduce_scatter(out, dim=1)

     if return_lse:
@@ -215,22 +229,7 @@ def cp_lse_ag_out_ar(
     cp_attn_out: [ B, H, D ]
     cp_attn_lse: [ B, H ]
     """
-    if cp_group.world_size == 1:
-        return cp_attn_out
-
-    if ctx is None:
-        ctx = CPTritonContext()
-
-    lses = torch.empty(
-        (cp_group.world_size,) + cp_attn_lse.shape,
-        dtype=cp_attn_lse.dtype,
-        device=cp_attn_lse.device,
-    )
-
-    cp_attn_lse = cp_attn_lse.contiguous()
-    lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
-    assert out.is_contiguous()
+    out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
     out = cp_group.all_reduce(out)
     return out
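Both public entry points now share _cp_lse_common, which all-gathers the per-rank LSEs and corrects the partial outputs; they differ only in the final collective (reduce-scatter vs. all-reduce). For intuition, here is a minimal PyTorch sketch of the log-sum-exp merge that correct_attn_out is understood to perform; this is an assumption about its semantics, not the Triton kernel from this file:

import torch

def merge_by_lse(outs: torch.Tensor, lses: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Combine per-rank partial attention outputs using their LSEs.

    outs: [world_size, B, H, D] partial outputs; lses: [world_size, B, H].
    """
    lse_total = torch.logsumexp(lses, dim=0)            # [B, H]
    weights = torch.exp(lses - lse_total.unsqueeze(0))  # softmax weights over ranks
    out = (weights.unsqueeze(-1) * outs).sum(dim=0)     # [B, H, D]
    return out, lse_total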

vllm/distributed/parallel_state.py

Lines changed: 0 additions & 20 deletions
@@ -1093,16 +1093,6 @@ def get_pcp_group() -> GroupCoordinator:
     return _PCP


-def get_prefill_context_model_parallel_world_size():
-    """Return world size for the tensor model parallel group."""
-    return get_pcp_group().world_size
-
-
-def get_prefill_context_model_parallel_rank():
-    """Return my rank for the tensor model parallel group."""
-    return get_pcp_group().rank_in_group
-
-
 @deprecated(
     "`get_pipeline_model_parallel_group` has been replaced with "
     "`get_pp_group` and may be removed in v0.12. Please use "
@@ -1476,16 +1466,6 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group


-def get_decode_context_model_parallel_world_size():
-    """Return world size for the decode context model parallel group."""
-    return get_dcp_group().world_size
-
-
-def get_decode_context_model_parallel_rank():
-    """Return my rank for the decode context model parallel group."""
-    return get_dcp_group().rank_in_group
-
-
 def get_node_count() -> int:
     """Return the total number of nodes in the distributed environment."""
     assert _NODE_COUNT is not None, "distributed environment is not initialized"
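The removed accessors were thin wrappers around the group coordinators; call sites now read the group objects directly, as the fused_moe diffs below show. An illustrative migration, using only names that appear in this commit:

from vllm.distributed import get_pcp_group

# was: get_prefill_context_model_parallel_world_size()
pcp_world_size = get_pcp_group().world_size
# was: get_prefill_context_model_parallel_rank()
pcp_rank = get_pcp_group().rank_in_group
# the removed decode-side wrappers map to get_dcp_group() the same way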

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from vllm.config import ParallelConfig
 from vllm.distributed import (
     get_dp_group,
-    get_prefill_context_model_parallel_rank,
+    get_pcp_group,
     get_tensor_model_parallel_rank,
 )
 from vllm.logger import init_logger
@@ -763,7 +763,7 @@ def flatten_tp_across_dp(dp_rank: int):
         dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
         tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
         pcp_size = pcp_size_
-        pcp_rank = get_prefill_context_model_parallel_rank() if pcp_size_ > 1 else 0
+        pcp_rank = get_pcp_group().rank_in_group if pcp_size_ > 1 else 0

         if not use_ep:
             return FusedMoEParallelConfig(

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 2 deletions
@@ -18,7 +18,7 @@
 from vllm.distributed import (
     get_dp_group,
     get_ep_group,
-    get_prefill_context_model_parallel_world_size,
+    get_pcp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
@@ -1103,7 +1103,7 @@ def __init__(
         pcp_size_ = (
             pcp_size
             if pcp_size is not None
-            else get_prefill_context_model_parallel_world_size()
+            else get_pcp_group().world_size
         )

         self.is_sequence_parallel = is_sequence_parallel

vllm/v1/core/single_type_kv_cache_manager.py

Lines changed: 3 additions & 3 deletions
@@ -341,7 +341,7 @@ def find_longest_cache_hit(
             "SlidingWindowManager can only be used for sliding window groups"
         )
         assert dcp_world_size == 1, "DCP not support sliding window attn now."
-        assert pcp_world_size == 1, "CP not support sliding window attn now."
+        assert pcp_world_size == 1, "PCP not support sliding window attn now."

         # The number of contiguous blocks needed for prefix cache hit.
         # -1 since the input token itself is also included in the window
@@ -481,7 +481,7 @@ def find_longest_cache_hit(
             "Hybrid KV cache is not supported for " + "eagle + chunked local attention."
         )
         assert dcp_world_size == 1, "DCP not support chunked local attn now."
-        assert pcp_world_size == 1, "CP not support chunked local attn now."
+        assert pcp_world_size == 1, "PCP not support chunked local attn now."
         max_num_blocks = max_length // kv_cache_spec.block_size
         if max_length > 0:
             local_attention_start_idx = (
@@ -572,7 +572,7 @@ def find_longest_cache_hit(
             "MambaManager can only be used for mamba groups"
         )
         assert dcp_world_size == 1, "DCP not support mamba now."
-        assert pcp_world_size == 1, "CP not support mamba now."
+        assert pcp_world_size == 1, "PCP not support mamba now."
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
             [] for _ in range(len(kv_cache_group_ids))
         )

vllm/v1/engine/core.py

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ def __init__(
         scheduler_block_size = (
             vllm_config.cache_config.block_size
             * vllm_config.parallel_config.decode_context_parallel_size
+            * vllm_config.parallel_config.prefill_context_parallel_size
         )

         self.scheduler: SchedulerInterface = Scheduler(
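This is the scheduler_block_size fix from the commit title: the scheduler's logical block must multiply in every context-parallel factor, presumably because each DCP/PCP rank holds only a slice of each logical block's KV entries. A worked example with assumed config values:

block_size = 16  # cache_config.block_size (assumed value)
dcp_size = 2     # decode_context_parallel_size (assumed value)
pcp_size = 2     # prefill_context_parallel_size (assumed value)

# Before this commit the PCP factor was missing, so the scheduler sized
# blocks as 16 * 2 = 32 tokens instead of 16 * 2 * 2 = 64.
scheduler_block_size = block_size * dcp_size * pcp_size
assert scheduler_block_size == 64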
