Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 89 additions & 15 deletions python/paddle/distributed/fleet/meta_parallel/dualpipev.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def _forward_backward_compute(
backward_phase: int,
micro_datasets=None,
combine_backward_event_to_wait=None,
pass_pp_stream=False,
) -> None:
if self.forward_only:
self._forward_compute(forward_phase, micro_datasets)
Expand Down Expand Up @@ -319,8 +320,12 @@ def _forward_backward_compute(
backward_grads,
self.scaler,
combine_bw_event_to_wait=combine_backward_event_to_wait,
pp_stream=self.pp_group.process_group.get_stream(
paddle.framework._current_expected_place_()
pp_stream=(
self.pp_group.process_group.get_stream(
paddle.framework._current_expected_place_()
)
if pass_pp_stream
else None
),
)
)
Expand All @@ -339,7 +344,9 @@ def _forward_backward_compute(
backward_phase, backward_acc_id, input_grads=backward_input_grads
)

def _commit_and_wait_comm(self) -> None:
def _commit_and_wait_comm(
self, p2p_overlap=False, use_outer_event_wait=False
) -> None:
common_forward_ops_num = (
len(self.comm_forward_ops)
if self.comm_forward_ops is not None
Expand All @@ -355,18 +362,26 @@ def _commit_and_wait_comm(self) -> None:
paddle.device.current_stream().stream_base
)

use_stream_wait_event = self._overlap_p2p_comm and deep_ep is not None
use_stream_wait_event = (
p2p_overlap and self._overlap_p2p_comm and deep_ep is not None
)

pp_raw_stream = self.pp_group.process_group.get_stream(
paddle.framework._current_expected_place_()
)
if use_outer_event_wait:
self.pp_group.process_group.set_outer_wait(True)

if common_forward_ops_num > 0:
fwd_reqs = batch_isend_irecv(self.comm_forward_ops)

if not use_stream_wait_event:
for req in fwd_reqs:
req.wait()

if use_outer_event_wait:
self.pp_group.process_group.set_outer_wait(False)

if use_stream_wait_event:
forward_event_to_wait = deep_ep.get_event_from_custom_stream(
pp_raw_stream
Expand Down Expand Up @@ -524,29 +539,49 @@ def _forward_backward_pass(
backward_phase: int,
micro_datasets=None,
recv0: bool = True,
first_chunk=False,
last_chunk=False,
main_stage=False,
last_stage_and_first_chunk=False,
) -> None:
if recv0:
self._recv_forward(forward_phase)
self._recv_backward(backward_phase)

use_outer_wait = (
self._overlap_p2p_comm
need_send_forward = not (
self.is_pipeline_first_stage() and forward_phase == 1
) or (self.is_pipeline_last_stage() and forward_phase == 0)
need_send_backward = not (
self.is_pipeline_first_stage() and backward_phase == 0
) or (self.is_pipeline_last_stage() and backward_phase == 1)

use_outer_event_wait = (
main_stage
and not first_chunk
and self._overlap_p2p_comm
and deep_ep is not None
and (len(self.comm_forward_ops) > 0)
and (need_send_forward and need_send_backward)
)

if use_outer_wait:
self.pp_group.process_group.set_outer_wait(True)
pass_pp_stream = (
main_stage
and not last_chunk
and self._overlap_p2p_comm
and deep_ep is not None
and (need_send_forward and need_send_backward)
and (not last_stage_and_first_chunk)
)

combine_bw_wait_event = self._commit_and_wait_comm()
combine_bw_wait_event = self._commit_and_wait_comm(
not last_chunk, use_outer_event_wait
)

if use_outer_wait:
self.pp_group.process_group.set_outer_wait(False)
self._forward_backward_compute(
forward_phase,
backward_phase,
micro_datasets,
combine_backward_event_to_wait=combine_bw_wait_event,
pass_pp_stream=pass_pp_stream,
)

self._send_forward(forward_phase)
Expand Down Expand Up @@ -663,7 +698,11 @@ def forward_backward_pipeline(

# Step 4 (Main step): nF0B1F1B0
step_4 = self.accumulate_steps - num_ranks * 2 + rank + 1
have_step5 = num_ranks - rank - 1 > 0
# Update code to support send/recv overlap
# Only support send/recv overlap in MainStep
for i in range(step_4):
is_last_chunk = i + 1 == step_4
if i == 0:
if self.is_pipeline_last_stage():
# NOTE: We don't overlap these two passes to further reduce bubble size.
Expand All @@ -674,13 +713,48 @@ def forward_backward_pipeline(
self._backward_pass(1, send=False)
self._send_forward(0)
self._send_backward(1)

self._forward_backward_pass(
1,
0,
micro_datasets,
first_chunk=True,
last_chunk=is_last_chunk,
main_stage=True,
)
else:
self._forward_backward_pass(
0, 1, micro_datasets, recv0=False
0,
1,
micro_datasets,
recv0=False,
first_chunk=True,
main_stage=True,
)

self._forward_backward_pass(
1,
0,
micro_datasets,
last_chunk=is_last_chunk,
main_stage=True,
)
else:
self._forward_backward_pass(0, 1, micro_datasets)
self._forward_backward_pass(1, 0, micro_datasets)

self._forward_backward_pass(
0,
1,
micro_datasets,
main_stage=True,
last_stage_and_first_chunk=self.is_pipeline_last_stage(),
)
self._forward_backward_pass(
1,
0,
micro_datasets,
last_chunk=is_last_chunk,
main_stage=True,
)

# Step 5: nB1F1B0
step_5 = num_ranks - rank - 1
Expand Down
Loading