@@ -88,6 +88,88 @@ def _correct_attn_cp_out_kernel(
     tl.store(new_output_ptr + output_offsets, output)
 
 
+@triton.jit
+def _correct_attn_cp_out_kernel_for_flashinfer(
+    outputs_ptr,
+    new_output_ptr,
+    lses_ptr,
+    vlse_ptr,
+    outputs_stride_B,
+    outputs_stride_H,
+    outputs_stride_D,
+    lses_stride_N,
+    lses_stride_B,
+    lses_stride_H,
+    lse_idx,
+    HEAD_DIM: tl.constexpr,
+    N_ROUNDED: tl.constexpr,
+):
107+ """
108+ Apply the all-gathered lses to correct each local rank's attention
109+ output. we still need perform a cross-rank reduction to obtain the
110+ final attention output.
111+
112+ Args:
113+ outputs_ptr (triton.PointerType):
114+ Pointer to input tensor of shape [ B, H, D ]
115+ lses_ptr (triton.PointerType):
116+ Pointer to input tensor of shape [ N, B, H ]
117+ new_output_ptr (triton.PointerType):
118+ Pointer to output tensor of shape [ B, H, D ]
119+ vlse_ptr (triton.PointerType):
120+ Pointer to output tensor of shape [ B, H ]
121+ """
+    batch_idx = tl.program_id(axis=0).to(tl.int64)
+    head_idx = tl.program_id(axis=1).to(tl.int64)
+    d_offsets = tl.arange(0, HEAD_DIM)
+    num_n_offsets = tl.arange(0, N_ROUNDED)
+
+    # shape = [N]
+    lse_offsets = (
+        num_n_offsets * lses_stride_N
+        + batch_idx * lses_stride_B
+        + head_idx * lses_stride_H
+    )
+
+    # compute the merged LSE across ranks with a numerically stable
+    # log-sum-exp in base 2; NaN/+inf entries are mapped to -inf so
+    # they drop out of the sum
+    lse = tl.load(lses_ptr + lse_offsets)
+    lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse)
+    lse_max = tl.max(lse, axis=0)
+    lse_max = tl.where(lse_max == -float("inf"), 0, lse_max)
+    lse -= lse_max
+    lse_exp = tl.exp2(lse)
+    lse_acc = tl.sum(lse_exp, axis=0)
+    lse = tl.log2(lse_acc)
+    lse += lse_max
+
+    lse_offsets = batch_idx * lses_stride_B + head_idx * lses_stride_H
+    tl.store(vlse_ptr + lse_offsets, lse)
+
+    # shape = [D]
+    output_offsets = (
+        batch_idx * outputs_stride_B
+        + head_idx * outputs_stride_H
+        + d_offsets * outputs_stride_D
+    )
+
+    # correct the local output: scale by exp2(local_lse - merged_lse)
+    lse_offset = (
+        lse_idx * lses_stride_N + batch_idx * lses_stride_B + head_idx * lses_stride_H
+    )
+    lse_tmp = tl.load(lses_ptr + lse_offset)
+    lse_finally = lse_tmp - lse
+    lse_finally = tl.where(
+        (lse_finally != lse_finally) | (lse_finally == float("inf")),
+        -float("inf"),
+        lse_finally,
+    )
+    factor = tl.exp2(lse_finally)
+    output = tl.load(outputs_ptr + output_offsets)
+    output = output * factor
+
+    tl.store(new_output_ptr + output_offsets, output)
+
+
 class CPTritonContext:
     """The CPTritonContext is used to avoid recompilation of the Triton JIT."""
 
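For intuition, the new kernel computes, per (batch, head), a numerically stable log-sum-exp of the all-gathered LSEs in base 2 and rescales the local partial output by `exp2(local_lse - merged_lse)`. The eager PyTorch sketch below mirrors that logic (the function name and eager formulation are illustrative, not part of this PR; shapes follow the kernel docstring):

```python
import torch

def correct_attn_cp_out_ref(outputs: torch.Tensor,  # [B, H, D] local partial output
                            lses: torch.Tensor,     # [N, B, H] all-gathered base-2 LSEs
                            lse_idx: int):          # this rank's index along N
    # Map NaN/inf LSEs (e.g. fully masked heads) to -inf so they vanish from the sum.
    lses = torch.where(lses.isnan() | lses.isinf(), float("-inf"), lses)
    # Numerically stable log-sum-exp over ranks, in base 2.
    lse_max = lses.amax(dim=0)
    lse_max = torch.where(lse_max == float("-inf"), 0.0, lse_max)
    vlse = torch.log2(torch.exp2(lses - lse_max).sum(dim=0)) + lse_max  # [B, H]
    # Rescale the local output by exp2(local_lse - merged_lse).
    scale = lses[lse_idx] - vlse
    scale = torch.where(scale.isnan() | (scale == float("inf")), float("-inf"), scale)
    return outputs * torch.exp2(scale).unsqueeze(-1), vlse
```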
@@ -102,7 +184,7 @@ def call_kernel(self, kernel, grid, *regular_args, **const_args):
 
 
 def correct_attn_out(
-    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext
+    out: torch.Tensor, lses: torch.Tensor, cp_rank: int, ctx: CPTritonContext, is_lse_base_on_e: bool = True,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Correct the attention output using the all-gathered lses.
 
@@ -164,8 +246,10 @@ def correct_attn_out(
         cp_rank,
     )
     const_args = {"HEAD_DIM": D, "N_ROUNDED": N}
-
-    ctx.call_kernel(_correct_attn_cp_out_kernel, grid, *regular_args, **const_args)
+    correct_attn_kernel = _correct_attn_cp_out_kernel
+    if not is_lse_base_on_e:
+        correct_attn_kernel = _correct_attn_cp_out_kernel_for_flashinfer
+    ctx.call_kernel(correct_attn_kernel, grid, *regular_args, **const_args)
     return out, lse
 
 
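The dispatch keeps the original base-e kernel as the default; `is_lse_base_on_e=False` routes to the FlashInfer variant, which differs only in using `exp2`/`log2` in place of `exp`/`log`. The rescaling factor is identical as long as the LSE base is consistent, as this quick standalone check (illustrative only, not part of the PR) shows:

```python
import math
import torch

torch.manual_seed(0)
lse_e = torch.randn(4, 8, 16)      # per-rank LSEs expressed in base e
lse_2 = lse_e * math.log2(math.e)  # the same quantities expressed in base 2

merged_e = torch.logsumexp(lse_e, dim=0)
merged_2 = torch.log2(torch.exp2(lse_2).sum(dim=0))

# Per-rank rescaling factors agree regardless of the base used.
factor_e = torch.exp(lse_e[0] - merged_e)
factor_2 = torch.exp2(lse_2[0] - merged_2)
assert torch.allclose(factor_e, factor_2, atol=1e-5)
```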
@@ -175,6 +259,7 @@ def cp_lse_ag_out_rs(
     cp_group: GroupCoordinator,
     ctx: CPTritonContext = None,
     return_lse=False,
+    is_lse_base_on_e=True,
 ):
     """
     cp_attn_out: [B, H, D]
@@ -194,7 +279,7 @@ def cp_lse_ag_out_rs(
 
     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx, is_lse_base_on_e=is_lse_base_on_e)
     out = cp_group.reduce_scatter(out, dim=1)
 
     if return_lse:
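A caller using a backend that produces base-2 LSEs (FlashInfer, as the new kernel's name suggests) would then opt out of the base-e default. This is a hypothetical call site; `cp_attn_out`, `cp_attn_lse`, and `cp_group` stand in for the caller's own tensors and group:

```python
# Select the exp2-based correction kernel added in this PR for base-2 LSEs.
out = cp_lse_ag_out_rs(
    cp_attn_out,   # [B, H, D]
    cp_attn_lse,   # [B, H]
    cp_group,
    is_lse_base_on_e=False,
)
```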