Commit f8afd97

[scheduler] support cp
Signed-off-by: LookAround <lixushi@huawei.com>
1 parent 5aeb925 commit f8afd97

File tree

13 files changed: +171 −39 lines

vllm/config/parallel.py

Lines changed: 3 additions & 1 deletion
@@ -67,6 +67,8 @@ class ParallelConfig:
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
     """Number of tensor parallel groups."""
+    context_parallel_size: int = 1
+    """Number of context parallel groups."""
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""

@@ -349,7 +351,7 @@ def __post_init__(self) -> None:

         # Continue with the rest of the initialization
         self.world_size = self.pipeline_parallel_size * \
-            self.tensor_parallel_size
+            self.tensor_parallel_size * self.context_parallel_size

         if self.data_parallel_size_local > self.data_parallel_size:
             raise ValueError(
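With context parallelism folded in, the per-replica world size becomes the product of all three model-parallel degrees. A minimal sketch of the math (illustrative sizes, not defaults):

# World-size math after this change (see ParallelConfig.__post_init__ above).
# The sizes below are illustrative.
pipeline_parallel_size = 2
tensor_parallel_size = 4
context_parallel_size = 2

# PP stages x TP shards x CP splits must each be backed by distinct ranks,
# so the degrees multiply:
world_size = (pipeline_parallel_size * tensor_parallel_size *
              context_parallel_size)
assert world_size == 16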

vllm/distributed/parallel_state.py

Lines changed: 47 additions & 6 deletions
@@ -982,6 +982,24 @@ def get_pp_group() -> GroupCoordinator:
     return _PP


+_CP: Optional[GroupCoordinator] = None
+
+
+def get_cp_group() -> GroupCoordinator:
+    assert _CP is not None, ("context parallel group is not initialized")
+    return _CP
+
+
+def get_context_model_parallel_world_size():
+    """Return world size for the context model parallel group."""
+    return get_cp_group().world_size
+
+
+def get_context_model_parallel_rank():
+    """Return my rank for the context model parallel group."""
+    return get_cp_group().rank_in_group
+
+
 @deprecated("`get_pipeline_model_parallel_group` has been replaced with "
             "`get_pp_group` and may be removed in v0.12. Please use "
             "`get_pp_group` instead.")

@@ -1088,6 +1106,7 @@ def init_distributed_environment(world_size: int = -1,
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
+    context_model_parallel_size: int = 1,
     decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
 ) -> None:

@@ -1138,7 +1157,7 @@ def initialize_model_parallel(
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
         -1, data_parallel_size, pipeline_model_parallel_size,
-        tensor_model_parallel_size)  # noqa
+        context_model_parallel_size, tensor_model_parallel_size)  # noqa

     # Build the tensor model-parallel groups.
     global _TP

@@ -1174,7 +1193,7 @@ def initialize_model_parallel(
     global _PP
     assert _PP is None, (
         "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(2, 3).reshape(
+    group_ranks = all_ranks.transpose(2, 4).reshape(
         -1, pipeline_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,

@@ -1185,7 +1204,7 @@ def initialize_model_parallel(
     global _DP
     assert _DP is None, ("data parallel group is already initialized")
     group_ranks = all_ranks.transpose(1,
-                                      3).reshape(-1,
+                                      4).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _DP = init_model_parallel_group(group_ranks,

@@ -1196,23 +1215,34 @@ def initialize_model_parallel(
     global _EP
     assert _EP is None, ("expert parallel group is already initialized")
     group_ranks = all_ranks.transpose(1, 2).reshape(
-        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+        -1, data_parallel_size * tensor_model_parallel_size * context_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _EP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="ep")

+    global _CP
+    assert _CP is None, ("context parallel group is already initialized")
+    group_ranks = all_ranks.transpose(3, 4).reshape(
+        -1, context_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _CP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="cp")
+
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
+        "DP rank %s, PP rank %s, TP rank %s, EP rank %s, CP rank %s", rank, world_size,
         _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
-        _EP.rank_in_group)
+        _EP.rank_in_group, _CP.rank_in_group)


 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
     pipeline_model_parallel_size: int,
+    context_model_parallel_size: int,
     decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
 ) -> None:

@@ -1225,6 +1255,7 @@ def ensure_model_parallel_initialized(
     if not model_parallel_is_initialized():
         initialize_model_parallel(tensor_model_parallel_size,
                                   pipeline_model_parallel_size,
                                   context_model_parallel_size,
                                   decode_context_model_parallel_size, backend)
         return

@@ -1238,6 +1269,11 @@ def ensure_model_parallel_initialized(
         "pipeline parallel group already initialized, but of unexpected size. "
         f"got: {pp_world_size=} vs. "
         f"wanted: {pipeline_model_parallel_size=}")
+    cp_world_size = get_cp_group().world_size
+    assert (cp_world_size == context_model_parallel_size), (
+        "context parallel group already initialized, but of unexpected size: "
+        f"{cp_world_size=} vs. "
+        f"{context_model_parallel_size=}")


 def prepare_communication_buffer_for_model(model: torch.nn.Module):

@@ -1345,6 +1381,11 @@ def destroy_model_parallel():
         _EP.destroy()
     _EP = None

+    global _CP
+    if _CP:
+        _CP.destroy()
+    _CP = None
+

 def destroy_distributed_environment():
     global _WORLD, _NODE_COUNT
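The group construction above hinges on the rank grid gaining a fifth axis: (external_dp, dp, pp, cp, tp), with tp fastest-moving. Each parallel group is then formed by transposing its axis to the end, reshaping, and unbinding. A small standalone check of that pattern (sizes chosen purely for illustration):

# Standalone check of the grid/transpose/unbind pattern used in
# initialize_model_parallel above; sizes are illustrative.
import torch

ext_dp, dp, pp, cp, tp = 1, 2, 1, 2, 2
all_ranks = torch.arange(ext_dp * dp * pp * cp * tp).reshape(
    ext_dp, dp, pp, cp, tp)

# TP groups: tp is already the last axis, so a plain reshape suffices.
tp_groups = [x.tolist() for x in all_ranks.reshape(-1, tp).unbind(0)]

# CP groups: swap cp to the last axis (transpose(3, 4)), then split.
cp_groups = [x.tolist()
             for x in all_ranks.transpose(3, 4).reshape(-1, cp).unbind(0)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(cp_groups)  # [[0, 2], [1, 3], [4, 6], [5, 7]]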

vllm/engine/arg_utils.py

Lines changed: 6 additions & 1 deletion
@@ -318,6 +318,7 @@ class EngineArgs:
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = \
         ParallelConfig.decode_context_parallel_size
+    context_parallel_size: int = ParallelConfig.context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
     data_parallel_start_rank: Optional[int] = None

@@ -653,6 +654,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--decode-context-parallel-size", "-dcp",
             **parallel_kwargs["decode_context_parallel_size"])
+        parallel_group.add_argument(
+            "--context-parallel-size", "-cp",
+            **parallel_kwargs["context_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
         parallel_group.add_argument(

@@ -1310,6 +1314,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            context_parallel_size=self.context_parallel_size,
             data_parallel_size=self.data_parallel_size,
             data_parallel_rank=self.data_parallel_rank or 0,
             data_parallel_external_lb=data_parallel_external_lb,

@@ -1369,7 +1374,7 @@ def create_engine_config(
             long_prefill_token_threshold=self.long_prefill_token_threshold,
             disable_hybrid_kv_cache_manager=self.
             disable_hybrid_kv_cache_manager,
-            async_scheduling=self.async_scheduling,
+            async_scheduling=self.async_scheduling
         )

         if not model_config.is_multimodal_model and self.default_mm_loras:
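The new flag plumbs through EngineArgs like the existing parallel sizes. A hedged usage sketch (assumes a build containing this commit; the model name is only an example):

# Usage sketch: constructing EngineArgs with the new field. Assumes this
# commit is applied; the model name is illustrative.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",
    tensor_parallel_size=4,
    context_parallel_size=2,  # CLI equivalent: --context-parallel-size / -cp
)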

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 14 additions & 6 deletions
@@ -7,7 +7,7 @@

 import vllm.envs as envs
 from vllm.config import ParallelConfig
-from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
+from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank, get_context_model_parallel_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)

@@ -566,9 +566,11 @@ def biased_moe_quant_config(
 @dataclass
 class FusedMoEParallelConfig:
     tp_size: int
+    cp_size: int
     dp_size: int
     ep_size: int
     tp_rank: int
+    cp_rank: int
     dp_rank: int
     ep_rank: int

@@ -594,15 +596,15 @@ def use_deepep_ll_kernels(self):
                 and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

     @staticmethod
-    def make(tp_size_: int, dp_size_: int,
+    def make(tp_size_: int, dp_size_: int, cp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
         """
         Determine MoE parallel configuration. Based on the input `tp_size_`,
         `dp_size_` and vllm's parallel config, determine what
         level's of parallelism to use in the fused moe layer.

         Args:
             tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
             dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
             vllm_parallel_config (ParallelConfig): vLLM's parallel config
                 object which contains the `enable_expert_parallel` flag.

@@ -675,16 +677,20 @@ def flatten_tp_across_dp(dp_rank: int):
             tp_rank = dp_rank * tp_size_ + tp_rank
             return tp_size, tp_rank

-        use_ep = (dp_size_ * tp_size_ > 1
+        use_ep = (dp_size_ * tp_size_ * cp_size_ > 1
                   and vllm_parallel_config.enable_expert_parallel)

         dp_size = dp_size_
         dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
         tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
+        cp_size = cp_size_
+        cp_rank = get_context_model_parallel_rank() if cp_size_ > 1 else 0

         if not use_ep:
             return FusedMoEParallelConfig(tp_size=tp_size,
                                           tp_rank=tp_rank,
+                                          cp_size=cp_size,
+                                          cp_rank=cp_rank,
                                           dp_size=dp_size,
                                           dp_rank=dp_rank,
                                           ep_size=1,

@@ -694,10 +700,12 @@ def flatten_tp_across_dp(dp_rank: int):
         assert use_ep
         # In EP, each device owns a set of experts fully. There is no tensor
         # parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
-        ep_size = tp_size
-        ep_rank = tp_rank
+        ep_size = tp_size * cp_size
+        ep_rank = tp_rank + tp_size * cp_rank
         return FusedMoEParallelConfig(tp_size=1,
                                       tp_rank=0,
+                                      cp_size=1,
+                                      cp_rank=0,
                                       dp_size=dp_size,
                                       dp_rank=dp_rank,
                                       ep_size=ep_size,
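In the expert-parallel branch, TP and CP ranks are flattened into a single EP rank via ep_rank = tp_rank + tp_size * cp_rank, so every (cp_rank, tp_rank) pair owns a distinct expert shard. A quick worked check with illustrative sizes:

# Worked check of the EP flattening above: tp_size=4, cp_size=2 yields
# eight distinct EP ranks, one per (cp_rank, tp_rank) pair.
tp_size, cp_size = 4, 2
ep_ranks = {(cp, tp): tp + tp_size * cp
            for cp in range(cp_size) for tp in range(tp_size)}
assert sorted(ep_ranks.values()) == list(range(tp_size * cp_size))
assert ep_ranks[(1, 2)] == 6  # cp_rank=1, tp_rank=2 -> EP rank 6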

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,7 @@
 from vllm.config.parallel import ExpertPlacementStrategy
 from vllm.distributed import (get_dp_group, get_ep_group,
                               get_tensor_model_parallel_world_size,
+                              get_context_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.eplb.eplb_state import EplbState
 from vllm.forward_context import ForwardContext, get_forward_context

@@ -828,6 +829,7 @@ def __init__(
         tp_size: Optional[int] = None,
         ep_size: Optional[int] = None,
         dp_size: Optional[int] = None,
+        cp_size: Optional[int] = None,
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",

@@ -849,6 +851,8 @@ def __init__(
             get_tensor_model_parallel_world_size())
         dp_size_ = (dp_size
                     if dp_size is not None else get_dp_group().world_size)
+        cp_size_ = (cp_size
+                    if cp_size is not None else get_context_model_parallel_world_size())

         self.is_sequence_parallel = is_sequence_parallel
         if self.is_sequence_parallel:

@@ -859,6 +863,7 @@ def __init__(
             FusedMoEParallelConfig.make(
                 tp_size_=tp_size_,
                 dp_size_=dp_size_,
+                cp_size_=cp_size_,
                 vllm_parallel_config=vllm_config.parallel_config))

         self.global_num_experts = num_experts + num_redundant_experts
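As with tp_size and dp_size, an explicit cp_size argument wins and the CP group's world size is the fallback. A minimal sketch of that defaulting rule in isolation (the getter here is a stand-in, not the real vLLM call):

# Sketch of the cp_size defaulting added to FusedMoE.__init__; the getter
# below is a stand-in for vLLM's get_context_model_parallel_world_size.
from typing import Optional

def get_context_model_parallel_world_size() -> int:
    return 2  # stand-in: the real call reads the _CP GroupCoordinator

def resolve_cp_size(cp_size: Optional[int] = None) -> int:
    # Explicit argument wins; otherwise fall back to the CP group size.
    return (cp_size if cp_size is not None
            else get_context_model_parallel_world_size())

assert resolve_cp_size() == 2
assert resolve_cp_size(1) == 1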
