From 79a767009a5d69323e0e7698535c3d57fb1058ad Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:37:58 +0800 Subject: [PATCH 01/14] Rename flash_fwd_combine to flash_dec_combine Co-authored-by: Copilot --- .../{flash_fwd_combine.py => flash_dec_combine.py} | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) rename flash_sparse_attn/ops/triton/{flash_fwd_combine.py => flash_dec_combine.py} (97%) diff --git a/flash_sparse_attn/ops/triton/flash_fwd_combine.py b/flash_sparse_attn/ops/triton/flash_dec_combine.py similarity index 97% rename from flash_sparse_attn/ops/triton/flash_fwd_combine.py rename to flash_sparse_attn/ops/triton/flash_dec_combine.py index 385bdb7..8fe332d 100644 --- a/flash_sparse_attn/ops/triton/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/triton/flash_dec_combine.py @@ -6,7 +6,7 @@ @triton.jit -def _fwd_combine_kernel( +def _dec_combine_kernel( Out_partial, Lse_partial, Out, @@ -177,7 +177,7 @@ def _fwd_combine_kernel( tl.store(lse_ptrs, lse, boundary_check=(0,)) -def _flash_attn_fwd_combine( +def _flash_attn_dec_combine( out_partial: torch.Tensor, lse_partial: torch.Tensor, out: torch.Tensor, @@ -197,18 +197,18 @@ def _flash_attn_fwd_combine( TILE_K = max(triton.next_power_of_2(head_dim), 16) TILE_M, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_combine_launch_config( + launch_template.get_dec_combine_launch_config( tile_k=TILE_K, ) ) - grid = launch_grid.get_fwd_combine_grid( + grid = launch_grid.get_dec_combine_grid( batch_size=batch_size, seqlen_q=seqlen_q, num_heads_q=num_heads_q, ) - _fwd_combine_kernel[grid]( + _dec_combine_kernel[grid]( out_partial, lse_partial, out, From 7650a66d7edf19e05290979f2023aa3bb2d09256 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:38:13 +0800 Subject: [PATCH 02/14] Rename get_fwd_combine_launch_config to get_dec_combine_launch_config and update related comments Co-authored-by: Copilot --- flash_sparse_attn/ops/triton/launch_template.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_sparse_attn/ops/triton/launch_template.py b/flash_sparse_attn/ops/triton/launch_template.py index 8a6dd6b..d79b5d3 100644 --- a/flash_sparse_attn/ops/triton/launch_template.py +++ b/flash_sparse_attn/ops/triton/launch_template.py @@ -522,11 +522,11 @@ def get_bwd_gated_launch_config( raise NotImplementedError(f"Unsupported device type: {device.type}") -def get_fwd_combine_launch_config( +def get_dec_combine_launch_config( tile_k, ) -> tuple[int, int, int, int]: """ - Get launch configuration for forward combine kernel based on input parameters and device architecture. + Get launch configuration for decode combine kernel based on input parameters and device architecture. 
:param tile_k: Tile size in the K dimension @@ -538,7 +538,7 @@ def get_fwd_combine_launch_config( if arch == -1: raise NotImplementedError(f"Unsupported device: {device} with arch {arch}") - # NOTE: Setting num_ctas=2 for the forward kernel can trigger Triton's PlanCTA assertion + # NOTE: Setting num_ctas=2 for the decode combine kernel can trigger Triton's PlanCTA assertion # Setting num_ctas=1 for now to avoid this issue, but we may want to revisit this in the future if device.type == "cuda": # For A100 From 4dc2796f0b2f643c455d63b50d4c4485b054545b Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:38:19 +0800 Subject: [PATCH 03/14] Rename get_fwd_combine_grid to get_dec_combine_grid and update related docstring Co-authored-by: Copilot --- flash_sparse_attn/ops/triton/launch_grid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_sparse_attn/ops/triton/launch_grid.py b/flash_sparse_attn/ops/triton/launch_grid.py index 3982c03..1ae6d94 100644 --- a/flash_sparse_attn/ops/triton/launch_grid.py +++ b/flash_sparse_attn/ops/triton/launch_grid.py @@ -60,13 +60,13 @@ def grid(META): return grid -def get_fwd_combine_grid( +def get_dec_combine_grid( batch_size: int, seqlen_q: int, num_heads_q: int, ): """ - Get the grid function for the forward combine kernel. + Get the grid function for the decode combine kernel. :param batch_size: Batch size :param seqlen_q: Sequence length of queries From b511322978cdb55fe550524bcd4a138c3366e9ad Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:23:02 +0800 Subject: [PATCH 04/14] Fix comments to clarify pointer advancements in _bwd_postprocess_kernel --- flash_sparse_attn/ops/triton/flash_bwd_postprocess.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py b/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py index e47ca38..1b4b591 100644 --- a/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py @@ -90,7 +90,7 @@ def _bwd_postprocess_kernel( order=(1, 0), ) - # Advance dq_accum pointer + # Advance dq_accum pointers dq_accum_ptrs = tl.advance(dq_accum_ptrs, (m_block * TILE_M, 0)) # Load accumulators @@ -99,7 +99,7 @@ def _bwd_postprocess_kernel( # Scale dq dq = (acc_dq * scale).to(dQ.dtype.element_ty) - # Advance dq pointer + # Advance dq pointers dq_ptrs = tl.advance(dq_ptrs, (m_block * TILE_M, 0)) # Store dq @@ -144,7 +144,7 @@ def _bwd_postprocess_kernel( order=(0,), ) - # Advance da_accum pointer + # Advance da_accum pointers da_accum_ptrs = tl.advance(da_accum_ptrs, (m_block * TILE_M,)) # Load da accumulators @@ -153,7 +153,7 @@ def _bwd_postprocess_kernel( # Scale da da = (acc_da * scale).to(dA.dtype.element_ty) - # Advance da pointer + # Advance da pointers da_ptrs = tl.advance(da_ptrs, (m_block * TILE_M,)) # Store da From c6507bbbf478cb0463df78095df1bb8f7e2648df Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:23:16 +0800 Subject: [PATCH 05/14] Clarify comments for pointer advancements in _bwd_preprocess_kernel --- flash_sparse_attn/ops/triton/flash_bwd_preprocess.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py b/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py index 14df94b..da14061 100644 --- a/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py @@ -174,31 +174,31 @@ def 
_bwd_preprocess_kernel( # Initialize accumulators acc_dq = tl.zeros((TILE_M, TILE_K), dtype=tl.float32) - # Advance output pointer + # Advance output pointers o_ptrs = tl.advance(o_ptrs, (m_block * TILE_M, 0)) # Load o tile o_tile = tl.load(o_ptrs, boundary_check=(0, 1)).to(tl.float32) - # Advance do pointer + # Advance do pointers do_ptrs = tl.advance(do_ptrs, (m_block * TILE_M, 0)) # Load do tile do_tile = tl.load(do_ptrs, boundary_check=(0, 1)).to(tl.float32) - # Advance dpsum pointer + # Advance dpsum pointers dpsum_ptrs = tl.advance(dpsum_ptrs, (m_block * TILE_M,)) # Compute dpsum dpsum = tl.sum(o_tile * do_tile, axis=1) - # Advance acc_dq pointer + # Advance acc_dq pointers dq_accum_ptrs = tl.advance(dq_accum_ptrs, (m_block * TILE_M, 0)) # Store dpsum tl.store(dpsum_ptrs, dpsum, boundary_check=(0,)) - # Advance lse pointer + # Advance lse pointers lse_ptrs = tl.advance(lse_ptrs, (m_block * TILE_M,)) # Load lse tile @@ -211,7 +211,7 @@ def _bwd_preprocess_kernel( # Store dq_accum tl.store(dq_accum_ptrs, acc_dq, boundary_check=(0, 1)) - # Advance lse_log2 pointer + # Advance lse_log2 pointers lse_log2_ptrs = tl.advance(lse_log2_ptrs, (m_block * TILE_M,)) # Store lse_log2 From af18e0c807403949ba42c4dc5507087e5962d160 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:24:51 +0800 Subject: [PATCH 06/14] Rename flash_fwd_combine to flash_dec_combine and update related references Co-authored-by: Copilot --- .../ops/triton/flash_dense_fwd.py | 78 ++++++++----------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py index d60cd04..7209cac 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -14,7 +14,7 @@ block_info, activations, mask, - flash_fwd_combine, + flash_dec_combine, ) @@ -25,7 +25,6 @@ def _fwd_inner_dense_base_kernel( k_ptrs, v_ptrs, acc_o, - block_max, row_max, row_sum, softmax_scale_log2, @@ -47,7 +46,7 @@ def _fwd_inner_dense_base_kernel( # Compute attention scores acc_s = tl.dot(q_tile, k_tile) - # Advance key pointer + # Advance key pointers k_ptrs = tl.advance(k_ptrs, (0, -TILE_N)) if n_block > n_block_min: # Load next key tile @@ -73,13 +72,11 @@ def _fwd_inner_dense_base_kernel( ) # Apply online softmax - p, block_max, row_max, row_sum, row_scale, _ = activations.online_softmax( + p, row_max, row_sum, row_scale = activations.online_softmax( acc_s=acc_s, - block_max=block_max, row_max=row_max, row_sum=row_sum, - scale_log2=softmax_scale_log2, - softmax_threshold_log2=float("-inf"), + SCALE_LOG2=softmax_scale_log2, CHECK_INF=CHECK_INF, RESCALE_THRESHOLD=0.0, ) @@ -93,10 +90,10 @@ def _fwd_inner_dense_base_kernel( # Update output accumulator acc_o += tl.dot(p.to(v_tile.dtype), v_tile) - # Advance value pointer + # Advance value pointers v_ptrs = tl.advance(v_ptrs, (-TILE_N, 0)) - return k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum + return k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum @triton.jit @@ -404,7 +401,6 @@ def _fwd_dense_base_kernel( q_tile = tl.load(q_ptrs, boundary_check=(0, 1)) # Initialize accumulators - block_max = tl.full((), float("-inf"), dtype=tl.float32) row_max = tl.full((TILE_M,), float("-inf"), dtype=tl.float32) row_sum = tl.zeros((TILE_M,), dtype=tl.float32) acc_o = tl.zeros((TILE_M, TILE_K), dtype=tl.float32) @@ -415,14 +411,13 @@ def _fwd_dense_base_kernel( # Process n_blocks with masking if IS_CAUSAL or IS_LOCAL: for n_block in 
tl.range(n_block_max - 1, n_block_max_no_mask - 1, -1): - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = ( _fwd_inner_dense_base_kernel( q_tile=q_tile, k_tile=k_tile, k_ptrs=k_ptrs, v_ptrs=v_ptrs, acc_o=acc_o, - block_max=block_max, row_max=row_max, row_sum=row_sum, softmax_scale_log2=softmax_scale_log2, @@ -446,32 +441,29 @@ def _fwd_dense_base_kernel( # First iteration with seqlen masking n_block = n_block_max - 1 - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( - _fwd_inner_dense_base_kernel( - q_tile=q_tile, - k_tile=k_tile, - k_ptrs=k_ptrs, - v_ptrs=v_ptrs, - acc_o=acc_o, - block_max=block_max, - row_max=row_max, - row_sum=row_sum, - softmax_scale_log2=softmax_scale_log2, - m_block=m_block, - n_block=n_block, - n_block_min=n_block, - actual_seqlen_q=actual_seqlen_q, - actual_seqlen_k=actual_seqlen_k, - TILE_M=TILE_M, - TILE_N=TILE_N, - WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, - WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, - QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA, - IS_MASK=True, - MASK_CAUSAL=False, - MASK_LOCAL=False, - CHECK_INF=True, - ) + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = _fwd_inner_dense_base_kernel( + q_tile=q_tile, + k_tile=k_tile, + k_ptrs=k_ptrs, + v_ptrs=v_ptrs, + acc_o=acc_o, + row_max=row_max, + row_sum=row_sum, + softmax_scale_log2=softmax_scale_log2, + m_block=m_block, + n_block=n_block, + n_block_min=n_block, + actual_seqlen_q=actual_seqlen_q, + actual_seqlen_k=actual_seqlen_k, + TILE_M=TILE_M, + TILE_N=TILE_N, + WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT, + WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT, + QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA, + IS_MASK=True, + MASK_CAUSAL=False, + MASK_LOCAL=False, + CHECK_INF=True, ) n_block_max_no_mask = n_block_max - 1 @@ -497,14 +489,13 @@ def _fwd_dense_base_kernel( ) k_tile = tl.load(k_ptrs, boundary_check=(0, 1)) for n_block in tl.range(n_block_max_no_mask - 1, n_block_min_no_mask - 1, -1): - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = ( _fwd_inner_dense_base_kernel( q_tile=q_tile, k_tile=k_tile, k_ptrs=k_ptrs, v_ptrs=v_ptrs, acc_o=acc_o, - block_max=block_max, row_max=row_max, row_sum=row_sum, softmax_scale_log2=softmax_scale_log2, @@ -545,14 +536,13 @@ def _fwd_dense_base_kernel( ) k_tile = tl.load(k_ptrs, boundary_check=(0, 1)) for n_block in tl.range(n_block_min_no_mask - 1, n_block_min - 1, -1): - k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum = ( + k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = ( _fwd_inner_dense_base_kernel( q_tile=q_tile, k_tile=k_tile, k_ptrs=k_ptrs, v_ptrs=v_ptrs, acc_o=acc_o, - block_max=block_max, row_max=row_max, row_sum=row_sum, softmax_scale_log2=softmax_scale_log2, @@ -743,7 +733,7 @@ def _flash_dense_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -894,7 +884,7 @@ def _flash_dense_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, From 31ee17155a78bdd1af648603ae13621d53951eb1 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:25:40 +0800 Subject: [PATCH 07/14] Rename flash_fwd_combine to flash_dec_combine and update references in forward kernels --- .../ops/triton/flash_sparse_fwd.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git 
a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index 1269d53..94c2105 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -14,7 +14,7 @@ block_info, activations, mask, - flash_fwd_combine, + flash_dec_combine, ) @@ -48,7 +48,7 @@ def _fwd_inner_sparse_base_kernel( # Compute attention scores acc_s = tl.dot(q_tile, k_tile) - # Advance key pointer + # Advance key pointers k_ptrs = tl.advance(k_ptrs, (0, -TILE_N)) if n_block > n_block_min: # Load next key tile @@ -75,15 +75,14 @@ def _fwd_inner_sparse_base_kernel( # Apply online softmax p, block_max, row_max, row_sum, row_scale, skip_softmax = ( - activations.online_softmax( + activations.online_sparse_softmax( acc_s=acc_s, block_max=block_max, row_max=row_max, row_sum=row_sum, - scale_log2=softmax_scale_log2, - softmax_threshold_log2=softmax_threshold_log2, + SCALE_LOG2=softmax_scale_log2, + SOFTMAX_THRESHOLD_LOG2=softmax_threshold_log2, CHECK_INF=CHECK_INF, - RESCALE_THRESHOLD=0.0, ) ) @@ -97,7 +96,7 @@ def _fwd_inner_sparse_base_kernel( # Update output accumulator acc_o += tl.dot(p.to(v_tile.dtype), v_tile) - # Advance value pointer + # Advance value pointers v_ptrs = tl.advance(v_ptrs, (-TILE_N, 0)) return k_tile, k_ptrs, v_ptrs, acc_o, block_max, row_max, row_sum @@ -766,7 +765,7 @@ def _flash_sparse_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -920,7 +919,7 @@ def _flash_sparse_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, From 318de2902bd29f4f6aec335e37e93888938309d6 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:26:20 +0800 Subject: [PATCH 08/14] Rename flash_fwd_combine to flash_dec_combine and update references in gated attention kernels --- flash_sparse_attn/ops/triton/flash_gated_fwd.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index 26f7642..660f9d7 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -14,7 +14,7 @@ block_info, activations, mask, - flash_fwd_combine, + flash_dec_combine, ) @@ -99,15 +99,14 @@ def _fwd_inner_gated_base_kernel( # Apply online softmax p, block_max, row_max, row_sum, row_scale, skip_softmax = ( - activations.online_softmax( + activations.online_sparse_softmax( acc_s=acc_s, block_max=block_max, row_max=row_max, row_sum=row_sum, - scale_log2=softmax_scale_log2, - softmax_threshold_log2=softmax_threshold_log2, + SCALE_LOG2=softmax_scale_log2, + SOFTMAX_THRESHOLD_LOG2=softmax_threshold_log2, CHECK_INF=CHECK_INF, - RESCALE_THRESHOLD=0.0, ) ) @@ -121,7 +120,7 @@ def _fwd_inner_gated_base_kernel( # Update output accumulator acc_o += tl.dot(p.to(v_tile.dtype), v_tile) - # Advance value pointer + # Advance value pointers v_ptrs = tl.advance(v_ptrs, (-TILE_N, 0)) else: # Advance key and value pointers @@ -1095,7 +1094,7 @@ def _flash_gated_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -1266,7 +1265,7 @@ def _flash_gated_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + 
flash_dec_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, From f772c1f57cd4481fbda9b3decafd847b11a7d0db Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:26:50 +0800 Subject: [PATCH 09/14] Fix comment to clarify LSE pointer advancements in _bwd_inner_gated_base_kernel --- flash_sparse_attn/ops/triton/flash_gated_bwd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_sparse_attn/ops/triton/flash_gated_bwd.py b/flash_sparse_attn/ops/triton/flash_gated_bwd.py index c82df8e..3c15946 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_bwd.py @@ -130,7 +130,7 @@ def _bwd_inner_gated_base_kernel( # Load LSE lse_log2 = tl.load(lse_ptrs, boundary_check=(0,)) - # Advance LSE pointer + # Advance LSE pointers lse_ptrs = tl.advance(lse_ptrs, (TILE_M,)) # Compute attention weights From 40e1baf2dc63b2d5de049ca378545605ee200ad2 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:39:05 +0800 Subject: [PATCH 10/14] Standardize parameter naming from SCALE_LOG2 to scale_log2 in kernel functions --- flash_sparse_attn/ops/triton/flash_dense_fwd.py | 2 +- flash_sparse_attn/ops/triton/flash_gated_fwd.py | 2 +- flash_sparse_attn/ops/triton/flash_sparse_fwd.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py index 7209cac..05fb342 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -76,7 +76,7 @@ def _fwd_inner_dense_base_kernel( acc_s=acc_s, row_max=row_max, row_sum=row_sum, - SCALE_LOG2=softmax_scale_log2, + scale_log2=softmax_scale_log2, CHECK_INF=CHECK_INF, RESCALE_THRESHOLD=0.0, ) diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index 660f9d7..04cb2d6 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -104,7 +104,7 @@ def _fwd_inner_gated_base_kernel( block_max=block_max, row_max=row_max, row_sum=row_sum, - SCALE_LOG2=softmax_scale_log2, + scale_log2=softmax_scale_log2, SOFTMAX_THRESHOLD_LOG2=softmax_threshold_log2, CHECK_INF=CHECK_INF, ) diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index 94c2105..ea52a8a 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -80,7 +80,7 @@ def _fwd_inner_sparse_base_kernel( block_max=block_max, row_max=row_max, row_sum=row_sum, - SCALE_LOG2=softmax_scale_log2, + scale_log2=softmax_scale_log2, SOFTMAX_THRESHOLD_LOG2=softmax_threshold_log2, CHECK_INF=CHECK_INF, ) From a795d01130343081fd2d5df921681957a6b26ceb Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:40:15 +0800 Subject: [PATCH 11/14] Refactor online_softmax and online_sparse_softmax to improve clarity and functionality Co-authored-by: Copilot --- flash_sparse_attn/ops/triton/activations.py | 78 +++++++++++++++++---- 1 file changed, 63 insertions(+), 15 deletions(-) diff --git a/flash_sparse_attn/ops/triton/activations.py b/flash_sparse_attn/ops/triton/activations.py index 83bc33e..5dbcd8d 100644 --- a/flash_sparse_attn/ops/triton/activations.py +++ b/flash_sparse_attn/ops/triton/activations.py @@ -10,26 +10,82 @@ def check_inf(x): @triton.jit def online_softmax( acc_s, - block_max, row_max, row_sum, scale_log2, - 
softmax_threshold_log2, CHECK_INF: tl.constexpr, RESCALE_THRESHOLD: tl.constexpr, ): """ - Apply online softmax to acc_s, and update block_max, row_max and row_sum. + Apply online softmax to acc_s, and update row_max and row_sum. :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N]. - :param block_max: Running block-wise maximum scalar, init to -inf. :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf. :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0. :param scale_log2: Log2 of the scaling factor to be applied to acc_s. - :param softmax_threshold_log2: Threshold in log2-domain for block-level skip. If > -inf and block max is below threshold relative to running max, skip softmax update. :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0. :param RESCALE_THRESHOLD: Threshold for rescaling to avoid underflow. If <= 0, rescaling is disabled. + :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N]. + :return row_max_new: Updated maximum values per row of shape [BLOCK_M]. + :return row_sum_new: Updated sum values per row of shape [BLOCK_M]. + :return row_scale: Scaling factors per row of shape [BLOCK_M]. + """ + # Compute current row max + row_max_curr = tl.max(acc_s, axis=1) + + # Update row max + row_max_new = tl.maximum(row_max_curr, row_max) + + # Avoid exp(-inf - (-inf)) = nan by clamping -inf to 0 + if CHECK_INF: + row_max_new = check_inf(row_max_new) + + # Compute scaled differences to new row max + acc_scale_log2 = (row_max - row_max_new) * scale_log2 + + # Compute row scale + if RESCALE_THRESHOLD > 0.0: + # Triton can only skip computation at block granularity + if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD: + row_scale = tl.exp2(acc_scale_log2) + else: + row_max_new = row_max + row_scale = acc_scale_log2 * 0.0 + 1.0 + else: + row_scale = tl.exp2(acc_scale_log2) + + # Compute attention weights + p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) + + # Update row sum + row_sum_cur = tl.sum(p, axis=1) + row_sum_new = row_sum * row_scale + row_sum_cur + + return p, row_max_new, row_sum_new, row_scale + + +@triton.jit +def online_sparse_softmax( + acc_s, + block_max, + row_max, + row_sum, + scale_log2, + SOFTMAX_THRESHOLD_LOG2: tl.constexpr, + CHECK_INF: tl.constexpr, +): + """ + Apply online sparse softmax to acc_s, and update block_max, row_max and row_sum. + + :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N]. + :param block_max: Running block-wise maximum scalar, init to -inf. + :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf. + :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0. + :param scale_log2: Log2 of the scaling factor to be applied to acc_s. + :param SOFTMAX_THRESHOLD_LOG2: Threshold in log2-domain for block-level skip. If > -inf and block max is below threshold relative to running max, skip softmax update. + :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0. + :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N]. :return block_max_new: Updated block-wise maximum scalar. :return row_max_new: Updated maximum values per row of shape [BLOCK_M]. 
@@ -42,7 +98,7 @@ def online_softmax( # Update skip condition based on threshold block_max_diff_log2 = (block_max_curr - block_max) * scale_log2 - skip_softmax = block_max_diff_log2 < softmax_threshold_log2 + skip_softmax = block_max_diff_log2 < SOFTMAX_THRESHOLD_LOG2 # Return zero attention weights if skip_softmax: @@ -69,15 +125,7 @@ def online_softmax( acc_scale_log2 = (row_max - row_max_new) * scale_log2 # Compute row scale - if RESCALE_THRESHOLD > 0.0: - # Triton can only skip computation at block granularity - if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD: - row_scale = tl.exp2(acc_scale_log2) - else: - row_max_new = row_max - row_scale = acc_scale_log2 * 0.0 + 1.0 - else: - row_scale = tl.exp2(acc_scale_log2) + row_scale = tl.exp2(acc_scale_log2) # Compute attention weights p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2) From e7f411d4514321a0afb37c6322efd5b07acfa435 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:49:40 +0800 Subject: [PATCH 12/14] Rename _flash_attn_fwd_combine to _flash_attn_dec_combine in forward functions for consistency --- flash_sparse_attn/ops/triton/flash_dense_fwd.py | 4 ++-- flash_sparse_attn/ops/triton/flash_gated_fwd.py | 4 ++-- flash_sparse_attn/ops/triton/flash_sparse_fwd.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py index 05fb342..c25d572 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -733,7 +733,7 @@ def _flash_dense_attn_base_forward( ) if is_split_kv: - flash_dec_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, @@ -884,7 +884,7 @@ def _flash_dense_attn_varlen_base_forward( ) if is_split_kv: - flash_dec_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index 04cb2d6..b0b294e 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -1094,7 +1094,7 @@ def _flash_gated_attn_base_forward( ) if is_split_kv: - flash_dec_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, @@ -1265,7 +1265,7 @@ def _flash_gated_attn_varlen_base_forward( ) if is_split_kv: - flash_dec_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index ea52a8a..cf862f0 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -765,7 +765,7 @@ def _flash_sparse_attn_base_forward( ) if is_split_kv: - flash_dec_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, @@ -919,7 +919,7 @@ def _flash_sparse_attn_varlen_base_forward( ) if is_split_kv: - flash_dec_combine._flash_attn_fwd_combine( + flash_dec_combine._flash_attn_dec_combine( out_partial, lse_partial, out, From 7ef17e61f72cb240fe413404d2306786801bb952 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:52:45 +0800 Subject: [PATCH 13/14] Standardize parameter naming for softmax threshold in gated and sparse attention kernels --- 
flash_sparse_attn/ops/triton/flash_gated_fwd.py | 2 +- flash_sparse_attn/ops/triton/flash_sparse_fwd.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index b0b294e..0172d81 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -105,7 +105,7 @@ def _fwd_inner_gated_base_kernel( row_max=row_max, row_sum=row_sum, scale_log2=softmax_scale_log2, - SOFTMAX_THRESHOLD_LOG2=softmax_threshold_log2, + softmax_threshold_log2=softmax_threshold_log2, CHECK_INF=CHECK_INF, ) ) diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index cf862f0..e183b09 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -81,7 +81,7 @@ def _fwd_inner_sparse_base_kernel( row_max=row_max, row_sum=row_sum, scale_log2=softmax_scale_log2, - SOFTMAX_THRESHOLD_LOG2=softmax_threshold_log2, + softmax_threshold_log2=softmax_threshold_log2, CHECK_INF=CHECK_INF, ) ) From 41c6f0992a5e0bde39ff6e056d15e89a38ab62c5 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Wed, 22 Apr 2026 15:53:39 +0800 Subject: [PATCH 14/14] Standardize parameter naming for softmax threshold in online_sparse_softmax function Co-authored-by: Copilot --- flash_sparse_attn/ops/triton/activations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_sparse_attn/ops/triton/activations.py b/flash_sparse_attn/ops/triton/activations.py index 5dbcd8d..e5d5d9e 100644 --- a/flash_sparse_attn/ops/triton/activations.py +++ b/flash_sparse_attn/ops/triton/activations.py @@ -72,7 +72,7 @@ def online_sparse_softmax( row_max, row_sum, scale_log2, - SOFTMAX_THRESHOLD_LOG2: tl.constexpr, + softmax_threshold_log2, CHECK_INF: tl.constexpr, ): """ @@ -83,7 +83,7 @@ def online_sparse_softmax( :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf. :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0. :param scale_log2: Log2 of the scaling factor to be applied to acc_s. - :param SOFTMAX_THRESHOLD_LOG2: Threshold in log2-domain for block-level skip. If > -inf and block max is below threshold relative to running max, skip softmax update. + :param softmax_threshold_log2: Threshold in log2-domain for block-level skip. If > -inf and block max is below threshold relative to running max, skip softmax update. :param CHECK_INF: Boolean flag indicating if -inf row_max should be clamped to 0. :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N]. @@ -98,7 +98,7 @@ def online_sparse_softmax( # Update skip condition based on threshold block_max_diff_log2 = (block_max_curr - block_max) * scale_log2 - skip_softmax = block_max_diff_log2 < SOFTMAX_THRESHOLD_LOG2 + skip_softmax = block_max_diff_log2 < softmax_threshold_log2 # Return zero attention weights if skip_softmax:
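
Note (illustration, not part of the patches): the activations.py refactor in patch 11 splits the helper into `online_softmax` (dense path, `RESCALE_THRESHOLD`) and `online_sparse_softmax` (block-max tracking plus a log2-domain skip threshold). The sketch below restates the core online-softmax recurrence that both share, in plain NumPy, and checks it against a direct softmax. It is only a sketch under assumptions: tile size, tensor names, and the final normalization step are illustrative, and it omits the Triton-specific pieces (masking, `CHECK_INF` clamping of -inf rows, the `RESCALE_THRESHOLD` fast path, and the block-skip logic of `online_sparse_softmax`).

    # Plain NumPy sketch of the online-softmax recurrence used by the kernels:
    # running row_max / row_sum in the exp2 domain, with a per-tile row_scale
    # applied to the output accumulator before adding the new tile's contribution.
    import numpy as np

    def online_softmax_reference(q, k, v, tile_n=64):
        seqlen_q, head_dim = q.shape
        seqlen_k = k.shape[0]
        # softmax scale expressed in the log2 domain, as in softmax_scale_log2
        scale_log2 = (1.0 / np.sqrt(head_dim)) * np.log2(np.e)

        row_max = np.full(seqlen_q, -np.inf)
        row_sum = np.zeros(seqlen_q)
        acc_o = np.zeros((seqlen_q, head_dim))

        for start in range(0, seqlen_k, tile_n):
            k_tile = k[start:start + tile_n]
            v_tile = v[start:start + tile_n]
            acc_s = q @ k_tile.T                                  # raw scores for this tile

            row_max_new = np.maximum(row_max, acc_s.max(axis=1))  # update running max
            row_scale = np.exp2((row_max - row_max_new) * scale_log2)
            p = np.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)

            row_sum = row_sum * row_scale + p.sum(axis=1)         # rescale old sum, add new
            acc_o = acc_o * row_scale[:, None] + p @ v_tile       # rescale old output, add new
            row_max = row_max_new

        return acc_o / row_sum[:, None]                           # final normalization

    # Sanity check against a direct softmax over the full score matrix.
    rng = np.random.default_rng(0)
    q, k, v = (rng.standard_normal((n, 32)) for n in (128, 256, 256))
    scores = (q @ k.T) / np.sqrt(32)
    ref = np.exp(scores - scores.max(axis=1, keepdims=True))
    ref = (ref / ref.sum(axis=1, keepdims=True)) @ v
    assert np.allclose(online_softmax_reference(q, k, v), ref, atol=1e-6)

The sparse variant differs only in that it first compares the tile's block maximum against the running block maximum; if the log2-scaled difference falls below `softmax_threshold_log2`, the whole tile's softmax update is skipped, which is what the `skip_softmax` return value signals to the callers in flash_sparse_fwd.py and flash_gated_fwd.py.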