@@ -168,12 +168,11 @@ def correct_attn_out(
168168     return out, lse
169169
170170
171- def cp_lse_ag_out_rs(
171+ def _cp_lse_common(
172172     cp_attn_out: torch.Tensor,
173173     cp_attn_lse: torch.Tensor,
174174     cp_group: GroupCoordinator,
175175     ctx: CPTritonContext = None,
176-     return_lse=False,
177176 ):
178177     """
179178     cp_attn_out: [ B, H, D ]
@@ -195,6 +194,21 @@ def cp_lse_ag_out_rs(
195194     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
196195     out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
197196     assert out.is_contiguous()
197+     return out, lse
198+
199+
200+ def cp_lse_ag_out_rs(
201+     cp_attn_out: torch.Tensor,
202+     cp_attn_lse: torch.Tensor,
203+     cp_group: GroupCoordinator,
204+     ctx: CPTritonContext = None,
205+     return_lse: bool = False,
206+ ):
207+     """
208+     cp_attn_out: [ B, H, D ]
209+     cp_attn_lse: [ B, H ]
210+     """
211+     out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
198212     out = cp_group.reduce_scatter(out, dim=1)

200214     if return_lse:
@@ -215,22 +229,7 @@ def cp_lse_ag_out_ar(
215229     cp_attn_out: [ B, H, D ]
216230     cp_attn_lse: [ B, H ]
217231     """
218-     if cp_group.world_size == 1:
219-         return cp_attn_out
220-
221-     if ctx is None:
222-         ctx = CPTritonContext()
223-
224-     lses = torch.empty(
225-         (cp_group.world_size,) + cp_attn_lse.shape,
226-         dtype=cp_attn_lse.dtype,
227-         device=cp_attn_lse.device,
228-     )
229-
230-     cp_attn_lse = cp_attn_lse.contiguous()
231-     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
232-     out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
233-     assert out.is_contiguous()
232+     out, lse = _cp_lse_common(cp_attn_out, cp_attn_lse, cp_group, ctx=ctx)
234233     out = cp_group.all_reduce(out)
235234     return out
236235
0 commit comments