@@ -203,7 +203,13 @@ def _cp_lse_common(
203203
204204 cp_attn_lse = cp_attn_lse .contiguous ()
205205 lses = cp_group .all_gather (cp_attn_lse , dim = 0 ).view_as (lses )
206- out , lse = correct_attn_out (cp_attn_out , lses , cp_group .rank_in_group , ctx , is_lse_base_on_e = is_lse_base_on_e )
206+ out , lse = correct_attn_out (
207+ cp_attn_out ,
208+ lses ,
209+ cp_group .rank_in_group ,
210+ ctx ,
211+ is_lse_base_on_e = is_lse_base_on_e ,
212+ )
207213 return out , lse
208214
209215
@@ -219,7 +225,9 @@ def cp_lse_ag_out_rs(
219225 cp_attn_out: [ B, H, D ]
220226 cp_attn_lse: [ B, H ]
221227 """
222- out , lse = _cp_lse_common (cp_attn_out , cp_attn_lse , cp_group , ctx = ctx , is_lse_base_on_e = is_lse_base_on_e )
228+ out , lse = _cp_lse_common (
229+ cp_attn_out , cp_attn_lse , cp_group , ctx = ctx , is_lse_base_on_e = is_lse_base_on_e
230+ )
223231 out = cp_group .reduce_scatter (out , dim = 1 )
224232
225233 if return_lse :
@@ -242,7 +250,9 @@ def cp_lse_ag_out_ar(
242250 cp_attn_out: [ B, H, D ]
243251 cp_attn_lse: [ B, H ]
244252 """
245- out , lse = _cp_lse_common (cp_attn_out , cp_attn_lse , cp_group , ctx = ctx , is_lse_base_on_e = is_lse_base_on_e )
253+ out , lse = _cp_lse_common (
254+ cp_attn_out , cp_attn_lse , cp_group , ctx = ctx , is_lse_base_on_e = is_lse_base_on_e
255+ )
246256 out = cp_group .all_reduce (out )
247257
248258 if return_lse :
0 commit comments