diff --git a/flash_sparse_attn/ops/triton/activations.py b/flash_sparse_attn/ops/triton/activations.py index a46d0f1..83bc33e 100644 --- a/flash_sparse_attn/ops/triton/activations.py +++ b/flash_sparse_attn/ops/triton/activations.py @@ -95,6 +95,7 @@ def finalize( row_sum, scale_log2, final_scale, + IS_LOG2: tl.constexpr, ): """ Finalize online softmax by computing output scale and logsumexp. @@ -102,6 +103,7 @@ def finalize( :param row_max: Final maximum values per row of shape [BLOCK_M]. :param row_sum: Final sum values per row of shape [BLOCK_M]. :param final_scale: Scaling factor to be applied to the output. + :param IS_LOG2: Boolean flag indicating if the returned logsumexp should be in log2-space. :return row_scale: Final scaling factors per row of shape [BLOCK_M]. :return lse: Logsumexp values per row of shape [BLOCK_M]. @@ -109,13 +111,15 @@ def finalize( # if row_sum is zero or nan, set it to 1 to avoid division by zero acc_o_is_zero_or_nan = (row_sum == 0.0) | (row_sum != row_sum) row_scale = tl.where(acc_o_is_zero_or_nan, 1.0, 1.0 / row_sum) * final_scale - # ln2 = math.log(2.0) - ln2 = 0.6931471805599453 lse = tl.where( acc_o_is_zero_or_nan, float("-inf"), - (row_max * scale_log2 + tl.log2(row_sum)) * ln2, + row_max * scale_log2 + tl.log2(row_sum), ) + if not IS_LOG2: + # ln2 = math.log(2.0) + ln2 = 0.6931471805599453 + lse *= ln2 return row_scale, lse diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py index 7f2eacd..d60cd04 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -579,6 +579,7 @@ def _fwd_dense_base_kernel( row_sum=row_sum, scale_log2=softmax_scale_log2, final_scale=1.0, + IS_LOG2=IS_SPLIT_KV, ) acc_o = activations.rescale_o(acc_o, row_scale, LAZY_RESCALE=False) diff --git a/flash_sparse_attn/ops/triton/flash_fwd_combine.py b/flash_sparse_attn/ops/triton/flash_fwd_combine.py index 2b24956..385bdb7 100644 --- a/flash_sparse_attn/ops/triton/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/triton/flash_fwd_combine.py @@ -1,4 +1,3 @@ -import math import torch import triton import triton.language as tl @@ -34,14 +33,11 @@ def _fwd_combine_kernel( TILE_K: tl.constexpr, HAS_CU_SEQLENS_Q: tl.constexpr, HAS_SEQUSED_Q: tl.constexpr, - MAX_SPLITS: tl.constexpr, ): m_block = tl.program_id(0) - k_block = tl.program_id(1) - bh_idx = tl.program_id(2) + bh_idx = tl.program_id(1) batch_idx = bh_idx // num_heads_q head_idx = bh_idx - batch_idx * num_heads_q - offs_m = m_block * TILE_M + tl.arange(0, TILE_M) # Get seqlen info for this batch offset_q, actual_seqlen_q = seqlen_info.get_seqlen_info( @@ -52,7 +48,6 @@ def _fwd_combine_kernel( HAS_CU_SEQLENS=HAS_CU_SEQLENS_Q, HAS_SEQUSED=HAS_SEQUSED_Q, ) - mask_m = offs_m < actual_seqlen_q # Initialize base pointers out_part_base = seqlen_info.offset_batch_Q( @@ -85,13 +80,23 @@ def _fwd_combine_kernel( HAS_CU_SEQLENS_Q, USE_PADDED=False, ) + lse_base = seqlen_info.offset_batch_Q( + Lse + head_idx * stride_lh, + batch_idx, + offset_q, + 0, + stride_lb, + 1, + HAS_CU_SEQLENS_Q, + USE_PADDED=False, + ) # Create pointers out_part_ptrs = tl.make_block_ptr( base=out_part_base, shape=(num_splits, actual_seqlen_q, head_dim), strides=(stride_ops, stride_opm, 1), - offsets=(0, m_block * TILE_M, k_block * TILE_K), + offsets=(0, m_block * TILE_M, 0), block_shape=(1, TILE_M, TILE_K), order=(2, 1, 0), ) @@ -107,65 +112,54 @@ def _fwd_combine_kernel( base=out_base, shape=(actual_seqlen_q, head_dim), strides=(stride_om, 1), - offsets=(m_block * TILE_M, k_block * TILE_K), + offsets=(m_block * TILE_M, 0), block_shape=(TILE_M, TILE_K), order=(1, 0), ) + lse_ptrs = tl.make_block_ptr( + base=lse_base, + shape=(actual_seqlen_q,), + strides=(1,), + offsets=(m_block * TILE_M,), + block_shape=(TILE_M,), + order=(0,), + ) # Initialize accumulators - lse_vals = tl.full((MAX_SPLITS, TILE_M), float("-inf"), dtype=tl.float32) - max_lse = tl.full((TILE_M,), float("-inf"), dtype=tl.float32) + e_sum = tl.zeros((TILE_M,), dtype=tl.float32) + e_max = tl.full((TILE_M,), float("-inf"), dtype=tl.float32) acc_o = tl.zeros((TILE_M, TILE_K), dtype=tl.float32) - # Compute max across splits - for s in tl.static_range(MAX_SPLITS): - if s < num_splits: - lse_s = tl.sum(tl.load(lse_part_ptrs, boundary_check=(0, 1)), axis=0) - # boundary_check pads OOB with 0, fixup to -inf for correct softmax - lse_s = tl.where(mask_m, lse_s, float("-inf")) - lse_part_ptrs = tl.advance(lse_part_ptrs, (1, 0)) - else: - lse_s = tl.full((TILE_M,), float("-inf"), dtype=tl.float32) - lse_vals = tl.where( - (tl.arange(0, MAX_SPLITS) == s)[:, None], - lse_s[None, :], - lse_vals, - ) - max_lse = tl.maximum(max_lse, lse_s) + # Combine split outputs + for _ in tl.range(0, num_splits): + # Load partial LSE + lse_s = tl.sum(tl.load(lse_part_ptrs, boundary_check=(0, 1)), axis=0) - # if all -inf, use 0 to avoid nan in exp - max_lse = tl.where(max_lse == float("-inf"), 0.0, max_lse) + # Advance LSE pointers + lse_part_ptrs = tl.advance(lse_part_ptrs, (1, 0)) - # Compute normalized scales - sum_exp = tl.zeros((TILE_M,), dtype=tl.float32) - for s in tl.static_range(MAX_SPLITS): - lse_s = tl.sum( - tl.where( - (tl.arange(0, MAX_SPLITS) == s)[:, None], - lse_vals, - tl.full((MAX_SPLITS, TILE_M), 0.0, dtype=tl.float32), - ), - axis=0, - ) - exp_s = tl.where(s < num_splits, tl.exp(lse_s - max_lse), 0.0) - sum_exp += exp_s + # Compute normalized exponentials + new_e_max = tl.maximum(lse_s, e_max) + old_scale = tl.exp2(e_max - new_e_max) + exp_logic = tl.exp2(lse_s - new_e_max) + + # Load partial outputs + o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0) + + # Advance output pointers + out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0)) - inv_sum = tl.where((sum_exp == 0.0) | (sum_exp != sum_exp), 0.0, 1.0 / sum_exp) + # Compute scaled outputs + acc_o *= old_scale[:, None] + acc_o += exp_logic[:, None] * o_s - for s in tl.static_range(MAX_SPLITS): - if s < num_splits: - lse_s = tl.sum( - tl.where( - (tl.arange(0, MAX_SPLITS) == s)[:, None], - lse_vals, - tl.full((MAX_SPLITS, TILE_M), 0.0, dtype=tl.float32), - ), - axis=0, - ) - scale = tl.exp(lse_s - max_lse) * inv_sum - o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0) - out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0)) - acc_o += scale[:, None] * o_s + # Update e_sum and e_max + e_sum = e_sum * old_scale + exp_logic + e_max = new_e_max + + # Normalize output + inv_sum = tl.where((e_sum == 0.0) | (e_sum != e_sum), 0.0, 1.0 / e_sum) + acc_o *= inv_sum[:, None] # Store output tl.store( @@ -173,33 +167,14 @@ def _fwd_combine_kernel( acc_o.to(Out.dtype.element_ty), boundary_check=(0, 1), ) + + # Compute LSE + # ln2 = math.log(2.0) + ln2 = 0.6931471805599453 + lse = tl.where(e_sum > 0.0, (e_max + tl.log2(e_sum)) * ln2, float("-inf")) + # Store LSE - # Only from the first k_block to avoid duplicates - if k_block == 0: - lse_base = seqlen_info.offset_batch_Q( - Lse + head_idx * stride_lh, - batch_idx, - offset_q, - 0, - stride_lb, - 1, - HAS_CU_SEQLENS_Q, - USE_PADDED=False, - ) - lse_ptrs = tl.make_block_ptr( - base=lse_base, - shape=(actual_seqlen_q,), - strides=(1,), - offsets=(m_block * TILE_M,), - block_shape=(TILE_M,), - order=(0,), - ) - lse = tl.where( - sum_exp > 0.0, - max_lse + tl.log(sum_exp), - float("-inf"), - ) - tl.store(lse_ptrs, lse, boundary_check=(0,)) + tl.store(lse_ptrs, lse, boundary_check=(0,)) def _flash_attn_fwd_combine( @@ -218,7 +193,6 @@ def _flash_attn_fwd_combine( total_q, num_heads_q, head_dim = out_partial.shape[1:] batch_size = cu_seqlens_q.shape[0] - 1 seqlen_q = total_q - MAX_SPLITS = 1 << max(int(math.ceil(math.log2(max(num_splits, 1)))), 1) TILE_K = max(triton.next_power_of_2(head_dim), 16) @@ -232,7 +206,6 @@ def _flash_attn_fwd_combine( batch_size=batch_size, seqlen_q=seqlen_q, num_heads_q=num_heads_q, - head_dim=head_dim, ) _fwd_combine_kernel[grid]( @@ -262,7 +235,6 @@ def _flash_attn_fwd_combine( TILE_K=TILE_K, HAS_CU_SEQLENS_Q=cu_seqlens_q is not None, HAS_SEQUSED_Q=seqused_q is not None, - MAX_SPLITS=MAX_SPLITS, num_warps=num_warps, num_stages=num_stages, num_ctas=num_ctas, diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index 5b398f5..26f7642 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -911,6 +911,7 @@ def _fwd_gated_base_kernel( row_sum=row_sum, scale_log2=softmax_scale_log2, final_scale=1.0, + IS_LOG2=IS_SPLIT_KV, ) acc_o = activations.rescale_o(acc_o, row_scale, LAZY_RESCALE=False) diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index 4c840cd..1269d53 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -599,6 +599,7 @@ def _fwd_base_sparse_kernel( row_sum=row_sum, scale_log2=softmax_scale_log2, final_scale=1.0, + IS_LOG2=IS_SPLIT_KV, ) acc_o = activations.rescale_o(acc_o, row_scale, LAZY_RESCALE=False) diff --git a/flash_sparse_attn/ops/triton/launch_grid.py b/flash_sparse_attn/ops/triton/launch_grid.py index e1a6f0c..3982c03 100644 --- a/flash_sparse_attn/ops/triton/launch_grid.py +++ b/flash_sparse_attn/ops/triton/launch_grid.py @@ -64,7 +64,6 @@ def get_fwd_combine_grid( batch_size: int, seqlen_q: int, num_heads_q: int, - head_dim: int, ): """ Get the grid function for the forward combine kernel. @@ -72,7 +71,6 @@ def get_fwd_combine_grid( :param batch_size: Batch size :param seqlen_q: Sequence length of queries :param num_heads_q: Number of query heads - :param head_dim: Head dimension :return grid: Grid function """ @@ -80,7 +78,6 @@ def get_fwd_combine_grid( def grid(META): return ( triton.cdiv(seqlen_q, META["TILE_M"]), - triton.cdiv(head_dim, META["TILE_K"]), batch_size * num_heads_q, )