29 changes: 25 additions & 4 deletions tests/distributed/test_context_parallel.py
@@ -30,13 +30,15 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
pcp_size: int
eager_mode: bool
chunked_prefill: bool


class CPTestOptions(NamedTuple):
multi_node_only: bool
load_format: str | None = None
attn_backend: str = "FLASH_ATTN"


@dataclass
@@ -52,20 +54,25 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
pcp_base: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
attn_backend: str = "FLASH_ATTN",
):
parallel_setups = []
for eager_mode_val in [False]:
for pp_multiplier in [1]:
for dcp_multiplier in [0.5, 1]:
# TODO(qcs): Test the effect of enabling PCP and DCP together
# once they are compatible.
for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]):
for chunked_prefill_val in [True]:
parallel_setups.append(
ParallelSetup(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
pcp_size=int(pcp_multiplier * pcp_base),
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@@ -75,7 +82,9 @@ def detailed(
distributed_backends=["mp"],
runner=runner,
test_options=CPTestOptions(
multi_node_only=multi_node_only, load_format=load_format
multi_node_only=multi_node_only,
load_format=load_format,
attn_backend=attn_backend,
),
)

@@ -108,11 +117,12 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
pcp_size,
eager_mode,
chunked_prefill,
) = parallel_setup

multi_node_only, load_format = test_options
multi_node_only, load_format, attn_backend = test_options

model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")
@@ -155,7 +165,7 @@ def _compare_cp_with_tp(
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
"16",
]
if chunked_prefill:
common_args.append("--enable-chunked-prefill")
@@ -172,6 +182,10 @@ def _compare_cp_with_tp(
if hf_overrides:
common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

cp_env = tp_env = {
"VLLM_ATTENTION_BACKEND": attn_backend,
}

cp_args = [
*common_args,
"--tensor-parallel-size",
@@ -180,6 +194,8 @@
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--prefill-context-parallel-size",
str(pcp_size),
"--distributed-executor-backend",
distributed_backend,
]
@@ -198,19 +214,24 @@
model_id,
cp_args,
tp_args,
cp_env,
tp_env,
method=method,
max_wait_seconds=720,
)


CP_TEXT_GENERATION_MODELS = {
# [MLA attention only]
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(attn_backend="FLASHINFER"),
CPTestSettings.detailed(tp_base=2, attn_backend="FLASHINFER"),
],
}

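For reference, with the defaults above (tp_base=4, pcp_base=1, dcp_base=1) the zipped multipliers expand to three (pcp_size, dcp_size) combinations; a minimal sketch of that expansion, for illustration only and not part of the diff:

# Illustration only: mirrors the loop in CPTestSettings.detailed().
for pcp_multiplier, dcp_multiplier in zip([1, 2, 1], [0.5, 1, 1]):
    pcp_size = int(pcp_multiplier * 1)  # pcp_base = 1
    dcp_size = int(dcp_multiplier * 4)  # tp_base = 4
    print(pcp_size, dcp_size)  # -> (1, 2), then (2, 4), then (1, 4)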
13 changes: 13 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -127,6 +127,9 @@ class AttentionImpl(ABC, Generic[T]):
dcp_world_size: int
dcp_rank: int

pcp_world_size: int
pcp_rank: int

def __new__(cls, *args, **kwargs):
# use __new__ so that all subclasses will call this
self = super().__new__(cls)
@@ -139,6 +142,16 @@ def __new__(cls, *args, **kwargs):
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
try:
from vllm.distributed.parallel_state import get_pcp_group

self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0

self.need_to_return_lse_for_decode = (
self.dcp_world_size > 1 and self.can_return_lse_for_decode
)
33 changes: 31 additions & 2 deletions vllm/attention/ops/common.py
@@ -168,12 +168,11 @@ def correct_attn_out(
return out, lse


def cp_lse_ag_out_rs(
def _cp_lse_common(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
return_lse=False,
):
"""
cp_attn_out: [ B, H, D ]
@@ -195,6 +194,21 @@ def cp_lse_ag_out_rs(
lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
assert out.is_contiguous()
return out, lse


def cp_lse_ag_out_rs(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
return_lse: bool = False,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.reduce_scatter(out, dim=1)

if return_lse:
@@ -205,6 +219,21 @@
return out


def cp_lse_ag_out_ar(
cp_attn_out: torch.Tensor,
cp_attn_lse: torch.Tensor,
cp_group: GroupCoordinator,
ctx: CPTritonContext = None,
):
"""
cp_attn_out: [ B, H, D ]
cp_attn_lse: [ B, H ]
"""
out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
out = cp_group.all_reduce(out)
return out


@triton.jit
def _pack_seq_kernel(
x_ptr, # [N, D]
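The shared helper _cp_lse_common all-gathers the per-rank log-sum-exp values and rescales each rank's partial output; cp_lse_ag_out_rs then reduce-scatters the corrected output along the head dimension, while cp_lse_ag_out_ar all-reduces it. A minimal PyTorch sketch of the LSE correction itself, for illustration only (the actual path uses the Triton-backed correct_attn_out plus the group collectives):

import torch

def combine_partial_attn(outs: torch.Tensor, lses: torch.Tensor) -> torch.Tensor:
    # outs: [G, B, H, D] partial attention outputs from G context-parallel ranks
    # lses: [G, B, H]    matching per-rank log-sum-exp values
    lse_total = torch.logsumexp(lses, dim=0)             # [B, H]
    weights = torch.exp(lses - lse_total).unsqueeze(-1)  # [G, B, H, 1]
    return (weights * outs).sum(dim=0)                   # [B, H, D]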
8 changes: 7 additions & 1 deletion vllm/config/parallel.py
@@ -71,6 +71,8 @@ class ParallelConfig:
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
"""Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1
"""Number of prefill context parallel groups."""
data_parallel_size: int = 1
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
@@ -467,7 +469,11 @@ def __post_init__(self) -> None:
)

# Continue with the rest of the initialization
self.world_size = self.pipeline_parallel_size * self.tensor_parallel_size
self.world_size = (
self.pipeline_parallel_size
* self.tensor_parallel_size
* self.prefill_context_parallel_size
)

if self.distributed_executor_backend == "external_launcher":
logger.info("Using external launcher for distributed inference.")
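Since the prefill context parallel size now multiplies into world_size, the number of workers per replica grows accordingly; a small worked example under assumed sizes, for illustration only:

# Illustration only: world_size with PCP factored in.
pipeline_parallel_size = 1
tensor_parallel_size = 4
prefill_context_parallel_size = 2

world_size = (
    pipeline_parallel_size
    * tensor_parallel_size
    * prefill_context_parallel_size
)
assert world_size == 8  # 4 without PCP, 8 with pcp_size=2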
9 changes: 9 additions & 0 deletions vllm/config/vllm.py
@@ -359,6 +359,15 @@ def __post_init__(self):
):
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# Prefill context parallel does not support full CUDA graphs yet.
if self.parallel_config.prefill_context_parallel_size > 1:
logger.warning(
"Prefill context parallel (PCP) is enabled, which is "
"incompatible with full CUDA graphs. Set "
"cudagraph_mode to PIECEWISE."
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

# Decode context parallel does not support full CUDA graphs yet.
if self.parallel_config.decode_context_parallel_size > 1:
logger.warning(