diff --git a/flash_sparse_attn/ops/triton/activations.py b/flash_sparse_attn/ops/triton/activations.py index 83bc33e..e5d5d9e 100644 --- a/flash_sparse_attn/ops/triton/activations.py +++ b/flash_sparse_attn/ops/triton/activations.py @@ -9,6 +9,64 @@ def check_inf(x): @triton.jit def online_softmax( + acc_s, + row_max, + row_sum, + scale_log2, + CHECK_INF: tl.constexpr, + RESCALE_THRESHOLD: tl.constexpr, +): + """ + Apply online softmax to acc_s, and update row_max and row_sum. + + :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N]. + :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf. + :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0. + :param scale_log2: Log2 of the scaling factor to be applied to acc_s. + :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0. + :param RESCALE_THRESHOLD: Threshold for rescaling to avoid underflow. If <= 0, rescaling is disabled. + + :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N]. + :return row_max_new: Updated maximum values per row of shape [BLOCK_M]. + :return row_sum_new: Updated sum values per row of shape [BLOCK_M]. + :return row_scale: Scaling factors per row of shape [BLOCK_M]. + """ + # Compute current row max + row_max_curr = tl.max(acc_s, axis=1) + + # Update row max + row_max_new = tl.maximum(row_max_curr, row_max) + + # Avoid exp(-inf - (-inf)) = nan by clamping -inf to 0 + if CHECK_INF: + row_max_new = check_inf(row_max_new) + + # Compute scaled differences to new row max + acc_scale_log2 = (row_max - row_max_new) * scale_log2 + + # Compute row scale + if RESCALE_THRESHOLD > 0.0: + # Triton can only skip computation at block granularity + if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD: + row_scale = tl.exp2(acc_scale_log2) + else: + row_max_new = row_max + row_scale = acc_scale_log2 * 0.0 + 1.0 + else: + row_scale = tl.exp2(acc_scale_log2) + + # Compute attention weights + p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) + + # Update row sum + row_sum_cur = tl.sum(p, axis=1) + row_sum_new = row_sum * row_scale + row_sum_cur + + return p, row_max_new, row_sum_new, row_scale + + +@triton.jit +def online_sparse_softmax( acc_s, block_max, row_max, @@ -16,10 +74,9 @@ def online_softmax( scale_log2, softmax_threshold_log2, CHECK_INF: tl.constexpr, - RESCALE_THRESHOLD: tl.constexpr, ): """ - Apply online softmax to acc_s, and update block_max, row_max and row_sum. + Apply online sparse softmax to acc_s, and update block_max, row_max and row_sum. :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N]. :param block_max: Running block-wise maximum scalar, init to -inf. @@ -28,7 +85,6 @@ def online_softmax( :param scale_log2: Log2 of the scaling factor to be applied to acc_s. :param softmax_threshold_log2: Threshold in log2-domain for block-level skip. If > -inf and block max is below threshold relative to running max, skip softmax update. :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0. - :param RESCALE_THRESHOLD: Threshold for rescaling to avoid underflow. If <= 0, rescaling is disabled. :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N]. :return block_max_new: Updated block-wise maximum scalar. @@ -69,15 +125,7 @@ def online_softmax( acc_scale_log2 = (row_max - row_max_new) * scale_log2 # Compute row scale - if RESCALE_THRESHOLD > 0.0: - # Triton can only skip computation at block granularity - if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD: - row_scale = tl.exp2(acc_scale_log2) - else: - row_max_new = row_max - row_scale = acc_scale_log2 * 0.0 + 1.0 - else: - row_scale = tl.exp2(acc_scale_log2) + row_scale = tl.exp2(acc_scale_log2) # Compute attention weights p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) diff --git a/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py b/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py index e47ca38..1b4b591 100644 --- a/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py @@ -90,7 +90,7 @@ def _bwd_postprocess_kernel( order=(1, 0), ) - # Advance dq_accum pointer + # Advance dq_accum pointers dq_accum_ptrs = tl.advance(dq_accum_ptrs, (m_block * TILE_M, 0)) # Load accumulators @@ -99,7 +99,7 @@ def _bwd_postprocess_kernel( # Scale dq dq = (acc_dq * scale).to(dQ.dtype.element_ty) - # Advance dq pointer + # Advance dq pointers dq_ptrs = tl.advance(dq_ptrs, (m_block * TILE_M, 0)) # Store dq @@ -144,7 +144,7 @@ def _bwd_postprocess_kernel( order=(0,), ) - # Advance da_accum pointer + # Advance da_accum pointers da_accum_ptrs = tl.advance(da_accum_ptrs, (m_block * TILE_M,)) # Load da accumulators @@ -153,7 +153,7 @@ def _bwd_postprocess_kernel( # Scale da da = (acc_da * scale).to(dA.dtype.element_ty) - # Advance da pointer + # Advance da pointers da_ptrs = tl.advance(da_ptrs, (m_block * TILE_M,)) # Store da diff --git a/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py b/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py index 14df94b..da14061 100644 --- a/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py @@ -174,31 +174,31 @@ def _bwd_preprocess_kernel( # Initialize accumulators acc_dq = tl.zeros((TILE_M, TILE_K), dtype=tl.float32) - # Advance output pointer + # Advance output pointers o_ptrs = tl.advance(o_ptrs, (m_block * TILE_M, 0)) # Load o tile o_tile = tl.load(o_ptrs, boundary_check=(0, 1)).to(tl.float32) - # Advance do pointer + # Advance do pointers do_ptrs = tl.advance(do_ptrs, (m_block * TILE_M, 0)) # Load do tile do_tile = tl.load(do_ptrs, boundary_check=(0, 1)).to(tl.float32) - # Advance dpsum pointer + # Advance dpsum pointers dpsum_ptrs = tl.advance(dpsum_ptrs, (m_block * TILE_M,)) # Compute dpsum dpsum = tl.sum(o_tile * do_tile, axis=1) - # Advance acc_dq pointer + # Advance acc_dq pointers dq_accum_ptrs = tl.advance(dq_accum_ptrs, (m_block * TILE_M, 0)) # Store dpsum tl.store(dpsum_ptrs, dpsum, boundary_check=(0,)) - # Advance lse pointer + # Advance lse pointers lse_ptrs = tl.advance(lse_ptrs, (m_block * TILE_M,)) # Load lse tile @@ -211,7 +211,7 @@ def _bwd_preprocess_kernel( # Store dq_accum tl.store(dq_accum_ptrs, acc_dq, boundary_check=(0, 1)) - # Advance lse_log2 pointer + # Advance lse_log2 pointers lse_log2_ptrs = tl.advance(lse_log2_ptrs, (m_block * TILE_M,)) # Store lse_log2 diff --git a/flash_sparse_attn/ops/triton/flash_fwd_combine.py b/flash_sparse_attn/ops/triton/flash_dec_combine.py similarity index 97% rename from flash_sparse_attn/ops/triton/flash_fwd_combine.py rename to flash_sparse_attn/ops/triton/flash_dec_combine.py index 385bdb7..8fe332d 100644 --- a/flash_sparse_attn/ops/triton/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/triton/flash_dec_combine.py @@ -6,7 +6,7 @@ @triton.jit -def _fwd_combine_kernel( +def _dec_combine_kernel( Out_partial, Lse_partial, Out, @@ -177,7 +177,7 @@ def _fwd_combine_kernel( tl.store(lse_ptrs, lse, boundary_check=(0,)) -def _flash_attn_fwd_combine( +def _flash_attn_dec_combine( out_partial: torch.Tensor, lse_partial: torch.Tensor, out: torch.Tensor, @@ -197,18 +197,18 @@ def _flash_attn_fwd_combine( TILE_K = max(triton.next_power_of_2(head_dim), 16) TILE_M, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_combine_launch_config( + launch_template.get_dec_combine_launch_config( tile_k=TILE_K, ) ) - grid = launch_grid.get_fwd_combine_grid( + grid = launch_grid.get_dec_combine_grid( batch_size=batch_size, seqlen_q=seqlen_q, num_heads_q=num_heads_q, ) - _fwd_combine_kernel[grid]( + _dec_combine_kernel[grid]( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py index d60cd04..c25d572 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -14,7 +14,7 @@ block_info, activations, mask, - flash_fwd_combine, + flash_dec_combine, ) @@ -25,7 +25,6 @@ def _fwd_inner_dense_base_kernel( k_ptrs, v_ptrs, acc_o, - block_max, row_max, row_sum, softmax_scale_log2, @@ -47,7 +46,7 @@ def _fwd_inner_dense_base_kernel( # Compute attention scores acc_s = tl.dot(q_tile, k_tile) - # Advance key pointer + # Advance key pointers k_ptrs = tl.advance(k_ptrs, (0, -TILE_N)) if n_block > n_block_min: # Load next key tile @@ -73,13 +72,11 @@ def _fwd_inner_dense_base_kernel( ) # Apply online softmax - p, block_max, row_max, row_sum, row_scale, _ = activations.online_softmax( + p, row_max, row_sum, row_scale = activations.online_softmax( acc_s=acc_s, - block_max=block_max, row_max=row_max, row_sum=row_sum, scale_log2=softmax_scale_log2, - softmax_threshold_log2=float("-inf"), CHECK_INF=CHECK_INF, RESCALE_THRESHOLD=0.0, ) @@ -93,10 +90,10 @@ def _fwd_inner_dense_base_kernel( # Update output accumulator acc_o += tl.dot(p.to(v_tile.dtype), v_tile) - # Advance value pointer + # Advance value pointers v_ptrs = tl.advance(v_ptrs, (-TILE_N, 0)) - return k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum + return k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum @triton.jit @@ -404,7 +401,6 @@ def _fwd_dense_base_kernel( q_tile = tl.load(q_ptrs, boundary_check=(0, 1)) # Initialize accumulators - block_max = tl.full((), float("-inf"), dtype=tl.float32) row_max = tl.full((TILE_M,), float("-inf"), dtype=tl.float32) row_sum = tl.zeros((TILE_M,), dtype=tl.float32) acc_o = tl.zeros((TILE_M, TILE_K), dtype=tl.float32) @@ -415,14 +411,13 @@ def _fwd_dense_base_kernel( # Process n_blocks with masking if IS_CAUSAL or IS_LOCAL: for n_block in tl.range(n_block_max - 1, n_block_max_no_mask - 1, -1): - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = ( _fwd_inner_dense_base_kernel( q_tile=q_tile, k_tile=k_tile, k_ptrs=k_ptrs, v_ptrs=v_ptrs, acc_o=acc_o, - block_max=block_max, row_max=row_max, row_sum=row_sum, softmax_scale_log2=softmax_scale_log2, @@ -446,32 +441,29 @@ def _fwd_dense_base_kernel( # First iteration with seqlen masking n_block = n_block_max - 1 - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( - _fwd_inner_dense_base_kernel( - q_tile=q_tile, - k_tile=k_tile, - k_ptrs=k_ptrs, - v_ptrs=v_ptrs, - acc_o=acc_o, - block_max=block_max, - row_max=row_max, - row_sum=row_sum, - softmax_scale_log2=softmax_scale_log2, - m_block=m_block, - n_block=n_block, - n_block_min=n_block, - actual_seqlen_q=actual_seqlen_q, - actual_seqlen_k=actual_seqlen_k, - TILE_M=TILE_M, - TILE_N=TILE_N, - WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, - WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, - QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA, - IS_MASK=True, - MASK_CAUSAL=False, - MASK_LOCAL=False, - CHECK_INF=True, - ) + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = _fwd_inner_dense_base_kernel( + q_tile=q_tile, + k_tile=k_tile, + k_ptrs=k_ptrs, + v_ptrs=v_ptrs, + acc_o=acc_o, + row_max=row_max, + row_sum=row_sum, + softmax_scale_log2=softmax_scale_log2, + m_block=m_block, + n_block=n_block, + n_block_min=n_block, + actual_seqlen_q=actual_seqlen_q, + actual_seqlen_k=actual_seqlen_k, + TILE_M=TILE_M, + TILE_N=TILE_N, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA, + IS_MASK=True, + MASK_CAUSAL=False, + MASK_LOCAL=False, + CHECK_INF=True, ) n_block_max_no_mask = n_block_max - 1 @@ -497,14 +489,13 @@ def _fwd_dense_base_kernel( ) k_tile = tl.load(k_ptrs, boundary_check=(0, 1)) for n_block in tl.range(n_block_max_no_mask - 1, n_block_min_no_mask - 1, -1): - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = ( _fwd_inner_dense_base_kernel( q_tile=q_tile, k_tile=k_tile, k_ptrs=k_ptrs, v_ptrs=v_ptrs, acc_o=acc_o, - block_max=block_max, row_max=row_max, row_sum=row_sum, softmax_scale_log2=softmax_scale_log2, @@ -545,14 +536,13 @@ def _fwd_dense_base_kernel( ) k_tile = tl.load(k_ptrs, boundary_check=(0, 1)) for n_block in tl.range(n_block_min_no_mask - 1, n_block_min - 1, -1): - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = ( _fwd_inner_dense_base_kernel( q_tile=q_tile, k_tile=k_tile, k_ptrs=k_ptrs, v_ptrs=v_ptrs, acc_o=acc_o, - block_max=block_max, row_max=row_max, row_sum=row_sum, softmax_scale_log2=softmax_scale_log2, @@ -743,7 +733,7 @@ def _flash_dense_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, @@ -894,7 +884,7 @@ def _flash_dense_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_gated_bwd.py b/flash_sparse_attn/ops/triton/flash_gated_bwd.py index c82df8e..3c15946 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_bwd.py @@ -130,7 +130,7 @@ def _bwd_inner_gated_base_kernel( # Load LSE lse_log2 = tl.load(lse_ptrs, boundary_check=(0,)) - # Advance LSE pointer + # Advance LSE pointers lse_ptrs = tl.advance(lse_ptrs, (TILE_M,)) # Compute attention weights diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index 26f7642..0172d81 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -14,7 +14,7 @@ block_info, activations, mask, - flash_fwd_combine, + flash_dec_combine, ) @@ -99,7 +99,7 @@ def _fwd_inner_gated_base_kernel( # Apply online softmax p, block_max, row_max, row_sum, row_scale, skip_softmax = ( - activations.online_softmax( + activations.online_sparse_softmax( acc_s=acc_s, block_max=block_max, row_max=row_max, @@ -107,7 +107,6 @@ def _fwd_inner_gated_base_kernel( scale_log2=softmax_scale_log2, softmax_threshold_log2=softmax_threshold_log2, CHECK_INF=CHECK_INF, - RESCALE_THRESHOLD=0.0, ) ) @@ -121,7 +120,7 @@ def _fwd_inner_gated_base_kernel( # Update output accumulator acc_o += tl.dot(p.to(v_tile.dtype), v_tile) - # Advance value pointer + # Advance value pointers v_ptrs = tl.advance(v_ptrs, (-TILE_N, 0)) else: # Advance key and value pointers @@ -1095,7 +1094,7 @@ def _flash_gated_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, @@ -1266,7 +1265,7 @@ def _flash_gated_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index 1269d53..e183b09 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -14,7 +14,7 @@ block_info, activations, mask, - flash_fwd_combine, + flash_dec_combine, ) @@ -48,7 +48,7 @@ def _fwd_inner_sparse_base_kernel( # Compute attention scores acc_s = tl.dot(q_tile, k_tile) - # Advance key pointer + # Advance key pointers k_ptrs = tl.advance(k_ptrs, (0, -TILE_N)) if n_block > n_block_min: # Load next key tile @@ -75,7 +75,7 @@ def _fwd_inner_sparse_base_kernel( # Apply online softmax p, block_max, row_max, row_sum, row_scale, skip_softmax = ( - activations.online_softmax( + activations.online_sparse_softmax( acc_s=acc_s, block_max=block_max, row_max=row_max, @@ -83,7 +83,6 @@ def _fwd_inner_sparse_base_kernel( scale_log2=softmax_scale_log2, softmax_threshold_log2=softmax_threshold_log2, CHECK_INF=CHECK_INF, - RESCALE_THRESHOLD=0.0, ) ) @@ -97,7 +96,7 @@ def _fwd_inner_sparse_base_kernel( # Update output accumulator acc_o += tl.dot(p.to(v_tile.dtype), v_tile) - # Advance value pointer + # Advance value pointers v_ptrs = tl.advance(v_ptrs, (-TILE_N, 0)) return k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum @@ -766,7 +765,7 @@ def _flash_sparse_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, @@ -920,7 +919,7 @@ def _flash_sparse_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/launch_grid.py b/flash_sparse_attn/ops/triton/launch_grid.py index 3982c03..1ae6d94 100644 --- a/flash_sparse_attn/ops/triton/launch_grid.py +++ b/flash_sparse_attn/ops/triton/launch_grid.py @@ -60,13 +60,13 @@ def grid(META): return grid -def get_fwd_combine_grid( +def get_dec_combine_grid( batch_size: int, seqlen_q: int, num_heads_q: int, ): """ - Get the grid function for the forward combine kernel. + Get the grid function for the decode combine kernel. :param batch_size: Batch size :param seqlen_q: Sequence length of queries diff --git a/flash_sparse_attn/ops/triton/launch_template.py b/flash_sparse_attn/ops/triton/launch_template.py index 8a6dd6b..d79b5d3 100644 --- a/flash_sparse_attn/ops/triton/launch_template.py +++ b/flash_sparse_attn/ops/triton/launch_template.py @@ -522,11 +522,11 @@ def get_bwd_gated_launch_config( raise NotImplementedError(f"Unsupported device type: {device.type}") -def get_fwd_combine_launch_config( +def get_dec_combine_launch_config( tile_k, ) -> tuple[int, int, int, int]: """ - Get launch configuration for forward combine kernel based on input parameters and device architecture. + Get launch configuration for decode combine kernel based on input parameters and device architecture. :param tile_k: Tile size in the K dimension @@ -538,7 +538,7 @@ def get_fwd_combine_launch_config( if arch == -1: raise NotImplementedError(f"Unsupported device: {device} with arch {arch}") - # NOTE: Setting num_ctas=2 for the forward kernel can trigger Triton's PlanCTA assertion + # NOTE: Setting num_ctas=2 for the decode combine kernel can trigger Triton's PlanCTA assertion # Setting num_ctas=1 for now to avoid this issue, but we may want to revisit this in the future if device.type == "cuda": # For A100