Skip to content

Commit 07878a3

Browse files
authored
rm _init_npu_pipeline_comm (#53150)
1 parent 43b950f commit 07878a3

File tree

1 file changed

+0
-84
lines changed

1 file changed

+0
-84
lines changed

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -737,89 +737,6 @@ def _init_pair_comm(self, pair, ring_id):
737737
sync=False,
738738
)
739739

740-
def _init_npu_pipeline_comm(self, startup_block):
741-
assert (self.pp_degree % 2) == 0
742-
743-
max_ring_id = -1
744-
my_pair = []
745-
for pair in self.pipeline_pair:
746-
pair_key = pair[0] * 1000 + pair[1]
747-
ring_id = self.pp_ring_map[pair_key]
748-
max_ring_id = max(max_ring_id, ring_id)
749-
logger.info(f"pp pair:{pair}, ring_id: {ring_id}")
750-
751-
if self.pp_rank in pair:
752-
my_pair.append(pair)
753-
754-
# for example: self.pp_rank=2, self.pp_degree=4
755-
send_to_next_pair = (
756-
self.pp_rank,
757-
(self.pp_rank + 1) % self.pp_degree,
758-
) # 2->3
759-
recv_from_next_pair = (
760-
(self.pp_rank + 1) % self.pp_degree,
761-
self.pp_rank,
762-
) # 3->2
763-
recv_from_prev_pair = (
764-
(self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
765-
self.pp_rank,
766-
) # 1->2
767-
send_to_prev_pair = (
768-
self.pp_rank,
769-
(self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
770-
) # 2->1
771-
772-
even = (self.pp_rank % 2) == 0
773-
774-
# 1. even send to next, odd recv from prev, 0->1, 2->3
775-
pair = send_to_next_pair if even else recv_from_prev_pair
776-
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
777-
self._init_pair_comm(pair, ring_id)
778-
my_pair.remove(pair)
779-
logger.info(f"pair0(even->odd): pp pair:{pair}, ring_id: {ring_id}")
780-
781-
# 2. even recv from next, odd send to prev, 1->0, 3->2
782-
pair = recv_from_next_pair if even else send_to_prev_pair
783-
ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
784-
self._init_pair_comm(pair, ring_id)
785-
my_pair.remove(pair)
786-
logger.info(f"pair1(even<-odd): pp pair:{pair}, ring_id: {ring_id}")
787-
788-
# if pp_degree is 2, only need pair(0->1, 1->0)
789-
if self.pp_degree > 2:
790-
# 3. odd send to next, even recv from prev, 1->2, 3->0
791-
pair = send_to_next_pair if not even else recv_from_prev_pair
792-
ring_id = self.pp_ring_map.get(
793-
pair[0] * 1000 + pair[1], max_ring_id + 1
794-
) # 3->0 not in pp_ring_map
795-
self._init_pair_comm(pair, ring_id)
796-
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
797-
my_pair.remove(pair)
798-
logger.info(
799-
"pair2(odd->even): pp pair:{}, ring_id: {}".format(
800-
pair, ring_id
801-
)
802-
)
803-
804-
# 4. odd recv from next, even send to prev, 2->1, 0->3
805-
pair = recv_from_next_pair if not even else send_to_prev_pair
806-
ring_id = self.pp_ring_map.get(
807-
pair[0] * 1000 + pair[1], max_ring_id + 2
808-
) # 0->3 not in pp_ring_map
809-
self._init_pair_comm(pair, ring_id)
810-
if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
811-
my_pair.remove(pair)
812-
logger.info(
813-
"pair3(odd<-even): pp pair:{}, ring_id: {}".format(
814-
pair, ring_id
815-
)
816-
)
817-
818-
assert len(my_pair) == 0, (
819-
"Current pipeline does not support cross stage communication, "
820-
"please check unexpected pair {}".format(my_pair)
821-
)
822-
823740
def _init_pipeline_comm(self, startup_block):
824741
# TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
825742
if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
@@ -834,7 +751,6 @@ def _init_pipeline_comm(self, startup_block):
834751
)
835752

836753
if core.is_compiled_with_custom_device('npu'):
837-
self._init_npu_pipeline_comm(startup_block)
838754
return
839755

840756
# GPU

0 commit comments

Comments
 (0)