10 changes: 10 additions & 0 deletions python/paddle/distributed/auto_parallel/pipelining/schedules.py
@@ -522,6 +522,9 @@ def _step_microbatches(
for work in bwd_sends_to_wait:
work.wait()

# Synchronize the gradients of shared parameters.
self._stage._sync_shared_param_grads()


class Schedule1F1B(PipelineScheduleSingle):
"""
@@ -681,6 +684,9 @@ def _step_microbatches(
# Return losses if there is a container passed in
self._update_losses(self._stage, losses)

# Synchronize the gradients of shared parameters.
self._stage._sync_shared_param_grads()


class PipelineScheduleMulti(_PipelineSchedule):
"""
@@ -979,6 +985,10 @@ def _step_microbatches(
# Return losses if there is a container passed in
self._update_losses(self._stages, losses)

# Synchronize the gradients of shared parameters.
for stage in self._stages:
stage._sync_shared_param_grads()
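
All three schedule hooks above follow the same contract: tied-weight gradients are reduced once per step, only after every microbatch has run and the last backward sends have drained. A hedged sketch of that ordering (names are illustrative, not the exact Paddle internals):

```python
# Ordering contract for shared-parameter gradient sync (illustrative names).
for mb in range(n_microbatches):
    run_forward(mb)
    run_backward(mb)                  # grads accumulate locally into param.grad
for work in bwd_sends_to_wait:
    work.wait()                       # drain P2P gradient sends first
stage._sync_shared_param_grads()      # then one allreduce per tied pair
```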


def _get_1f1b_rank_ops(
n_local_stages,
174 changes: 169 additions & 5 deletions python/paddle/distributed/auto_parallel/pipelining/stage.py
@@ -17,15 +17,17 @@
import logging
from abc import ABC, abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Union
from typing import Any, Callable, Union

import paddle
import paddle.distributed as dist
from paddle import nn
from paddle.base.framework import EagerParamBase
from paddle.distributed.auto_parallel.api import (
dtensor_from_local,
dtensor_to_local,
)
from paddle.distributed.communication.group import Group

from ._backward import stage_backward
from .utils import (
@@ -40,9 +42,6 @@
map_structure,
)

if TYPE_CHECKING:
from paddle.distributed.communication.group import Group

logger = logging.getLogger(__name__)


@@ -845,7 +844,7 @@ class PipelineStage(_PipelineStageBase):
A class representing a pipeline stage in a pipeline parallelism setup.

PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from
one chunk feed into inputs of the next chunk, with no skip connections.
one chunk feed into inputs of the next chunk. Additionally, synchronization of parameters shared across stages is also supported.

PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to
stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each
@@ -858,6 +857,9 @@ class PipelineStage(_PipelineStageBase):
input_args (TensorMeta|tuple[TensorMeta, ...]|None): The input arguments for the layer.
output_args (TensorMeta|tuple[TensorMeta, ...]|None): The output arguments for the layer.
group (Group, None): The process group for distributed training. If None, default group.
shared_parameters (list[dict[str, list[EagerParamBase]]]|None): A list of dictionaries defining shared parameter
pairs between pipeline stages. Each dictionary represents a unique parameter pair with:
- "params" (list[EagerParamBase], required): Exactly 2 parameters to share across stages.
The "group" and "shared_param" entries are populated internally and may already be present when
the same list is reused across stage builds (e.g. in VPP mode).
"""

def __init__(
@@ -868,13 +870,18 @@ def __init__(
input_args: TensorMeta | tuple[TensorMeta, ...] | None = None,
output_args: TensorMeta | tuple[TensorMeta, ...] | None = None,
group: Group | None = None,
shared_parameters: list[dict[str, list[EagerParamBase]]] | None = None,
):
super().__init__(layer, stage_index, num_stages, group)
self.inputs: list[paddle.Tensor] | None = None
self.inputs_meta: tuple[TensorMeta, ...] | None = None
# output's grad meta-info
self.grads_meta: tuple[TensorMeta, ...] | None = None

# Synchronize shared parameters on the current rank.
self.shared_parameters = shared_parameters
self._sync_shared_param()

if input_args is None:
assert output_args is None, (
"If specifying output_args, input_args must also be specified. "
@@ -927,6 +934,163 @@ def stage_global_rank(peer_rank):

logger.debug(dbg_str)

def _sync_shared_param(self):
if self.shared_parameters is None:
# 1. Default: no shared parameters to process.
self.shared_parameters = []
return

# 2. Validate parameters.
# TODO(xuexixi): Currently, shared parameter information relies on user input, so strict validation is required here.
# A more robust interface implementation is desired in the future.
self._validate_shared_parameter_pair()

# 3. Build shared parameter information for the current rank.
self._init_shared_group()

# 4. Synchronize the initialized shared parameters.
# When initializing the stage, perform broadcast synchronization on the shared parameters.
for idx, a_map in enumerate(self.shared_parameters):
shared_param = a_map["shared_param"]
if shared_param is None or not shared_param._is_initialized():
# Skip processing shared parameters that are not assigned to the current rank.
continue
group = a_map.get("group")
assert group is not None and dist.get_rank() in group.ranks
logger.debug(
f"Call `broadcast` for synchronization of shared parameter pair at index {idx}",
)
with paddle.no_grad():
paddle.distributed.broadcast(
shared_param._local_value(),
src=group.ranks[0],
group=group,
)
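
For reference, a minimal standalone sketch of the broadcast performed above, assuming ranks 0 and 3 each hold one copy of a tied weight; after the call both copies start from rank 0's values (the real code broadcasts `shared_param._local_value()` of a dist tensor rather than a dense weight):

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group(ranks=[0, 3])       # assumed pair owning the two copies
weight = paddle.nn.Linear(16, 16).weight   # per-rank random init
if dist.get_rank() in group.ranks:
    with paddle.no_grad():
        dist.broadcast(weight, src=group.ranks[0], group=group)
```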

def _validate_shared_parameter_pair(self):
# Validate shared_parameters structure.
assert isinstance(
self.shared_parameters, list
), f"Expected `shared_parameters` to be a list, but got {type(self.shared_parameters).__name__}."

# Validate every shared parameter pair.
for idx, a_shared_map in enumerate(self.shared_parameters):
# Validate map structure.
assert isinstance(
a_shared_map, dict
), f"Invalid shared parameter pair: expected dict, but got {type(a_shared_map).__name__}."
assert len(a_shared_map) <= 3, (
f"shared_parameters[{idx}] exceeds size limit (max 3 keys). "
f"Allowed: ['params', 'group', 'shared_param'], got: {list(a_shared_map.keys())}"
)
# Validate required 'params' entry.
params = a_shared_map.get("params")
assert (
params is not None
), f"Required key 'params' not found in shared_parameters[{idx}]. Available keys: {list(a_shared_map)}."
assert isinstance(params, (list, tuple)) and len(
params
) == 2, f"Each shared parameter pair must contain exactly 2 parameters in a list or tuple, but got {len(params)}."
# Validate parameter types and placements.
param_1, param_2 = params
assert isinstance(param_1, EagerParamBase) and isinstance(
param_2, EagerParamBase
), (
f"Shared parameter expects parameters are 'EagerParamBase' type, but got "
f"'{type(param_1).__name__}' and '{type(param_2).__name__}' respectively."
)
assert param_1.placements == param_2.placements, (
f"Shared parameters must have identical placements for optimal performance, "
f"but got a mismatch: {param_1.placements} vs {param_2.placements}."
)
# Validate process meshes.
ranks_1 = param_1.process_mesh.process_ids
ranks_2 = param_2.process_mesh.process_ids
assert len(ranks_1) == len(ranks_2), (
f"Shared parameters must be on process meshes of the same size, "
f"but got {len(ranks_1)} vs {len(ranks_2)} ranks."
)
assert (
ranks_1 != ranks_2
), f"Shared parameters must be on different stage meshes, but both are on {ranks_1}."

# In VPP mode, the same shared_parameters list is reused across stage builds. To avoid redundant group creation,
# the 'shared_param' and 'group' attributes may already exist, as they are created during the `_init_shared_group` call.
# Validate optional 'group' entry.
if "group" in a_shared_map:
group = a_shared_map["group"]
assert group is None or isinstance(
group, Group
), f"Expected 'shared_parameters[{idx}][\"group\"]' is 'Group' or None, but got '{type(a_shared_map['group']).__name__}'."
# Validate optional 'shared_param' entry.
if "shared_param" in a_shared_map:
shared_param = a_shared_map["shared_param"]
assert shared_param is None or shared_param in (
param_1,
param_2,
), f"Expected 'shared_parameters[{idx}][\"shared_param\"]' to be one of the two params or None."

def _init_shared_group(self):
# Retrieve, for the current rank, the parameter to be shared and the required communication group,
# and store them in the 'shared_param' and 'group' entries of each shared_parameters dict respectively:
# - group (Group, optional): communication group for sharing the current parameter pair on the current rank (auto-created if missing)
# - shared_param (EagerParamBase, optional): parameter to be shared on the current rank, one of 'params'; if None,
# no sharing is required on this rank (auto-populated if missing)
get_group_from_ranks = {}
for idx, a_map in enumerate(self.shared_parameters):
params = a_map["params"]
ranks_1 = params[0].process_mesh.process_ids
ranks_2 = params[1].process_mesh.process_ids
cur_rank = dist.get_rank()

# Build communication groups for every shared parameter pair.
for rank_1, rank_2 in zip(ranks_1, ranks_2):
group_ranks = tuple(sorted([rank_1, rank_2]))
if "group" in a_map:
# In VPP mode, since `shared_parameters` is reused across stage creations,
# the 'group' may already exist, avoiding redundant group creation.
if cur_rank in group_ranks:
assert group_ranks == tuple(
a_map["group"].ranks
), f"Shared Parameter group ranks mismatch: expected {group_ranks}, but got {a_map['group'].ranks}. "
else:
if group_ranks not in get_group_from_ranks:
get_group_from_ranks[group_ranks] = dist.new_group(
ranks=list(group_ranks)
)
if cur_rank in group_ranks:
# Record the communication group associated with the current rank.
a_map["group"] = get_group_from_ranks[group_ranks]
logger.debug(
f"Build communication group {a_map['group'].name} for shared parameter pair at index {idx} on rank {cur_rank}"
)

# Find the shared parameter on the current rank.
# `shared_param` defaults to None, meaning no shared parameter exists on the current rank.
cur_param = None
if cur_rank in ranks_1:
cur_param = params[0]
elif cur_rank in ranks_2:
cur_param = params[1]
# Record shared parameter associated with the current rank.
a_map["shared_param"] = cur_param

def _sync_shared_param_grads(self):
# After the stage scheduling ends, perform allreduce synchronization
# on the gradients of shared parameters.
for idx, a_map in enumerate(self.shared_parameters):
shared_param = a_map["shared_param"]
if shared_param is None or not shared_param._is_initialized():
# Skip processing shared parameters that are not assigned to the current rank.
continue
group = a_map.get("group")
assert group is not None and dist.get_rank() in group.ranks
logger.debug(
f"Call `all_reduce` for gradient synchronization of shared parameter pair at index {idx}",
)
with paddle.no_grad():
paddle.distributed.all_reduce(
shared_param.grad._local_value(),
op=paddle.distributed.ReduceOp.SUM,
group=group,
)
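
The allreduce-SUM reproduces the gradient a genuinely shared parameter would receive: each copy holds the partial gradient from its own stage, and the sum is the total. Because both copies start identical (broadcast at init) and now see identical gradients, identical optimizer steps keep them equal. A hedged standalone sketch, with `weight` one local copy and `group` the two-rank group from `_init_shared_group`:

```python
import paddle
import paddle.distributed as dist

with paddle.no_grad():
    dist.all_reduce(weight.grad, op=dist.ReduceOp.SUM, group=group)
# Both ranks now hold grad_stage_a + grad_stage_b.
```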

def _shape_inference(
self,
args: tuple[Any, ...],
2 changes: 2 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -192,6 +192,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
FLAGS_embedding_deterministic=1
NVIDIA_TF32_OVERRIDE=0)
py_test_modules(test_PP_Schedules MODULES test_PP_Schedules)
py_test_modules(test_pipeline_sync_shared_parameters MODULES
test_pipeline_sync_shared_parameters)
py_test_modules(
test_context_parallel
MODULES