@@ -737,89 +737,6 @@ def _init_pair_comm(self, pair, ring_id):
737737 sync = False ,
738738 )
739739
def _init_npu_pipeline_comm(self, startup_block):
    """Create the pairwise pipeline communicators for this rank on NPU.

    Every ``(src, dst)`` pair in ``self.pipeline_pair`` is looked up in
    ``self.pp_ring_map`` under the key ``src * 1000 + dst``.  The links
    that involve ``self.pp_rank`` are then initialised with
    ``self._init_pair_comm`` in four fixed steps, chosen by rank parity:

    1. even rank sends to next, odd rank receives from prev (0->1, 2->3)
    2. even rank receives from next, odd rank sends to prev (1->0, 3->2)
    3. odd rank sends to next, even rank receives from prev (1->2, 3->0)
    4. odd rank receives from next, even rank sends to prev (2->1, 0->3)

    Steps 3 and 4 are skipped when ``pp_degree == 2`` because steps 1 and
    2 already cover both directions of the only link.

    NOTE(review): the step order is presumably what keeps both ends of a
    link initialising the same communicator at the same time -- confirm
    before reordering anything here.

    Args:
        startup_block: Unused by this method; kept so the signature
            matches the call site.

    Raises:
        AssertionError: If ``pp_degree`` is odd, or if a pair containing
            this rank is left over after the four steps.
        KeyError: If a pair from ``pipeline_pair`` (or a step 1/2 pair)
            has no entry in ``pp_ring_map``.
        ValueError: If a step's pair is expected in ``pipeline_pair`` but
            is not there (raised by ``list.remove``).
    """
    assert self.pp_degree % 2 == 0, (
        "NPU pipeline communication requires an even pp_degree, "
        f"got {self.pp_degree}"
    )

    def ring_key(pair):
        # Key layout used by pp_ring_map: src * 1000 + dst.
        return pair[0] * 1000 + pair[1]

    # Scan all configured pairs: remember the largest ring id in use and
    # collect the pairs this rank takes part in.
    max_ring_id = -1
    my_pair = []
    for pair in self.pipeline_pair:
        ring_id = self.pp_ring_map[ring_key(pair)]
        max_ring_id = max(max_ring_id, ring_id)
        logger.info(f"pp pair:{pair}, ring_id: {ring_id}")
        if self.pp_rank in pair:
            my_pair.append(pair)

    # for example: self.pp_rank=2, self.pp_degree=4
    next_rank = (self.pp_rank + 1) % self.pp_degree
    prev_rank = (self.pp_rank - 1 + self.pp_degree) % self.pp_degree
    send_to_next_pair = (self.pp_rank, next_rank)  # 2->3
    recv_from_next_pair = (next_rank, self.pp_rank)  # 3->2
    recv_from_prev_pair = (prev_rank, self.pp_rank)  # 1->2
    send_to_prev_pair = (self.pp_rank, prev_rank)  # 2->1

    even = self.pp_rank % 2 == 0
    # On the first and last rank, steps 3 and 4 use the wrap-around link
    # (3->0 / 0->3), which is not listed in pp_ring_map and therefore is
    # not removed from ``my_pair``.
    is_edge_rank = self.pp_rank in (0, self.pp_degree - 1)

    def init_step(label, pair, fallback_ring_id=None, tracked=True):
        # Initialise one link, tick it off ``my_pair`` and log it.
        key = ring_key(pair)
        if fallback_ring_id is None:
            ring_id = self.pp_ring_map[key]
        else:
            ring_id = self.pp_ring_map.get(key, fallback_ring_id)
        self._init_pair_comm(pair, ring_id)
        if tracked:
            my_pair.remove(pair)
        logger.info(f"{label}: pp pair:{pair}, ring_id: {ring_id}")

    # 1. even send to next, odd recv from prev, 0->1, 2->3
    init_step(
        "pair0(even->odd)",
        send_to_next_pair if even else recv_from_prev_pair,
    )

    # 2. even recv from next, odd send to prev, 1->0, 3->2
    init_step(
        "pair1(even<-odd)",
        recv_from_next_pair if even else send_to_prev_pair,
    )

    # if pp_degree is 2, only need pair(0->1, 1->0)
    if self.pp_degree > 2:
        # 3. odd send to next, even recv from prev, 1->2, 3->0
        init_step(
            "pair2(odd->even)",
            recv_from_prev_pair if even else send_to_next_pair,
            fallback_ring_id=max_ring_id + 1,  # 3->0 not in pp_ring_map
            tracked=not is_edge_rank,
        )

        # 4. odd recv from next, even send to prev, 2->1, 0->3
        init_step(
            "pair3(odd<-even)",
            send_to_prev_pair if even else recv_from_next_pair,
            fallback_ring_id=max_ring_id + 2,  # 0->3 not in pp_ring_map
            tracked=not is_edge_rank,
        )

    # Anything left over is a pair that is not between adjacent stages.
    assert len(my_pair) == 0, (
        "Current pipeline does not support cross stage communication, "
        f"please check unexpected pair {my_pair}"
    )
822-
823740 def _init_pipeline_comm (self , startup_block ):
824741 # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
825742 if os .getenv ("PADDLE_MANUAL_PIPELINE_STAGE" , None ) is None :
@@ -834,7 +751,6 @@ def _init_pipeline_comm(self, startup_block):
834751 )
835752
836753 if core .is_compiled_with_custom_device ('npu' ):
837- self ._init_npu_pipeline_comm (startup_block )
838754 return
839755
840756 # GPU
0 commit comments