From 6de62bf363af85e7d045ba9fcb2a2ae9bc06b33e Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:25:16 +0800 Subject: [PATCH 1/3] Implement Flash Sparse Attention Forward Kernel and Update Interfaces - Renamed `flash_sparse_fwd.py` to `flash_sparse_forward.py`, the file containing the implementation of the Flash Sparse Attention forward kernel, including the `_fwd_inner_sparse_base_kernel` and `_fwd_base_sparse_kernel` functions. - Updated the `interface.py` file to import the Flash Sparse Attention forward functions from the renamed `flash_sparse_forward` module. - Renamed modules and their imports for consistency, changing `flash_dense_fwd` to `flash_dense_forward`, `flash_dense_bwd` to `flash_dense_backward`, `flash_sparse_fwd` to `flash_sparse_forward`, `flash_sparse_bwd` to `flash_sparse_backward`, `flash_gated_fwd` to `flash_gated_forward`, `flash_gated_bwd` to `flash_gated_backward`, `flash_bwd_preprocess` to `flash_backward_preprocess`, `flash_bwd_postprocess` to `flash_backward_postprocess`, and `flash_fwd_combine` to `flash_combine`. --- ..._postprocess.py => flash_backward_postprocess.py} | 0 ...wd_preprocess.py => flash_backward_preprocess.py} | 0 .../{flash_fwd_combine.py => flash_combine.py} | 0 .../{flash_dense_bwd.py => flash_dense_backward.py} | 12 ++++++------ .../{flash_dense_fwd.py => flash_dense_forward.py} | 6 +++--- .../{flash_gated_bwd.py => flash_gated_backward.py} | 12 ++++++------ .../{flash_gated_fwd.py => flash_gated_forward.py} | 6 +++--- ...{flash_sparse_bwd.py => flash_sparse_backward.py} | 12 ++++++------ .../{flash_sparse_fwd.py => flash_sparse_forward.py} | 6 +++--- flash_sparse_attn/ops/triton/interface.py | 12 ++++++------ 10 files changed, 33 insertions(+), 33 deletions(-) rename flash_sparse_attn/ops/triton/{flash_bwd_postprocess.py => flash_backward_postprocess.py} (100%) rename flash_sparse_attn/ops/triton/{flash_bwd_preprocess.py => flash_backward_preprocess.py} (100%) rename flash_sparse_attn/ops/triton/{flash_fwd_combine.py => flash_combine.py} (100%) rename flash_sparse_attn/ops/triton/{flash_dense_bwd.py => flash_dense_backward.py} (98%) rename 
flash_sparse_attn/ops/triton/{flash_dense_fwd.py => flash_dense_forward.py} (99%) rename flash_sparse_attn/ops/triton/{flash_gated_bwd.py => flash_gated_backward.py} (99%) rename flash_sparse_attn/ops/triton/{flash_gated_fwd.py => flash_gated_forward.py} (99%) rename flash_sparse_attn/ops/triton/{flash_sparse_bwd.py => flash_sparse_backward.py} (99%) rename flash_sparse_attn/ops/triton/{flash_sparse_fwd.py => flash_sparse_forward.py} (99%) diff --git a/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py b/flash_sparse_attn/ops/triton/flash_backward_postprocess.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_bwd_postprocess.py rename to flash_sparse_attn/ops/triton/flash_backward_postprocess.py diff --git a/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py b/flash_sparse_attn/ops/triton/flash_backward_preprocess.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_bwd_preprocess.py rename to flash_sparse_attn/ops/triton/flash_backward_preprocess.py diff --git a/flash_sparse_attn/ops/triton/flash_fwd_combine.py b/flash_sparse_attn/ops/triton/flash_combine.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_fwd_combine.py rename to flash_sparse_attn/ops/triton/flash_combine.py diff --git a/flash_sparse_attn/ops/triton/flash_dense_bwd.py b/flash_sparse_attn/ops/triton/flash_dense_backward.py similarity index 98% rename from flash_sparse_attn/ops/triton/flash_dense_bwd.py rename to flash_sparse_attn/ops/triton/flash_dense_backward.py index 3a09b7d2..85e88002 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_backward.py @@ -7,13 +7,13 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_backward_postprocess, + flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, mask, - flash_bwd_preprocess, - flash_bwd_postprocess, ) @@ -712,7 +712,7 @@ def _flash_dense_attn_base_backward( device=query.device, ) - 
flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -793,7 +793,7 @@ def _flash_dense_attn_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, @@ -901,7 +901,7 @@ def _flash_dense_attn_varlen_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -985,7 +985,7 @@ def _flash_dense_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_forward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_dense_fwd.py rename to flash_sparse_attn/ops/triton/flash_dense_forward.py index d60cd040..2c9a289d 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_forward.py @@ -7,6 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_combine, utils, launch_template, launch_grid, @@ -14,7 +15,6 @@ block_info, activations, mask, - flash_fwd_combine, ) @@ -743,7 +743,7 @@ def _flash_dense_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -894,7 +894,7 @@ def _flash_dense_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_gated_bwd.py b/flash_sparse_attn/ops/triton/flash_gated_backward.py similarity index 99% rename from 
flash_sparse_attn/ops/triton/flash_gated_bwd.py rename to flash_sparse_attn/ops/triton/flash_gated_backward.py index c82df8ea..3da2b240 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_backward.py @@ -7,14 +7,14 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_backward_postprocess, + flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, activations, mask, - flash_bwd_preprocess, - flash_bwd_postprocess, ) @@ -1101,7 +1101,7 @@ def _flash_gated_attn_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1198,7 +1198,7 @@ def _flash_gated_attn_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, da_accum=da_accum, @@ -1324,7 +1324,7 @@ def _flash_gated_attn_varlen_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1424,7 +1424,7 @@ def _flash_gated_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, da_accum=da_accum, diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_forward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_gated_fwd.py rename to flash_sparse_attn/ops/triton/flash_gated_forward.py index 26f7642b..a4e5ff22 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_forward.py @@ -7,6 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_combine, utils, launch_template, launch_grid, @@ -14,7 +15,6 @@ 
block_info, activations, mask, - flash_fwd_combine, ) @@ -1095,7 +1095,7 @@ def _flash_gated_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -1266,7 +1266,7 @@ def _flash_gated_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py b/flash_sparse_attn/ops/triton/flash_sparse_backward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_sparse_bwd.py rename to flash_sparse_attn/ops/triton/flash_sparse_backward.py index 2f58ca21..addfdbe2 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_backward.py @@ -7,13 +7,13 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_backward_postprocess, + flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, mask, - flash_bwd_preprocess, - flash_bwd_postprocess, ) @@ -773,7 +773,7 @@ def _flash_sparse_attn_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -855,7 +855,7 @@ def _flash_sparse_attn_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, @@ -965,7 +965,7 @@ def _flash_sparse_attn_varlen_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1050,7 +1050,7 @@ def _flash_sparse_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( 
dq_accum=dq_accum, dq=dq, scale=softmax_scale, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_forward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_sparse_fwd.py rename to flash_sparse_attn/ops/triton/flash_sparse_forward.py index 1269d538..fa653807 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_forward.py @@ -7,6 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_combine, utils, launch_template, launch_grid, @@ -14,7 +15,6 @@ block_info, activations, mask, - flash_fwd_combine, ) @@ -766,7 +766,7 @@ def _flash_sparse_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -920,7 +920,7 @@ def _flash_sparse_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/interface.py b/flash_sparse_attn/ops/triton/interface.py index aeb1f656..43638f6a 100644 --- a/flash_sparse_attn/ops/triton/interface.py +++ b/flash_sparse_attn/ops/triton/interface.py @@ -2,27 +2,27 @@ import torch -from flash_sparse_attn.ops.triton.flash_dense_fwd import ( +from flash_sparse_attn.ops.triton.flash_dense_forward import ( _flash_dense_attn_base_forward, _flash_dense_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_dense_bwd import ( +from flash_sparse_attn.ops.triton.flash_dense_backward import ( _flash_dense_attn_base_backward, _flash_dense_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_sparse_fwd import ( +from flash_sparse_attn.ops.triton.flash_sparse_forward import ( _flash_sparse_attn_base_forward, _flash_sparse_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_sparse_bwd import ( +from 
flash_sparse_attn.ops.triton.flash_sparse_backward import ( _flash_sparse_attn_base_backward, _flash_sparse_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_gated_fwd import ( +from flash_sparse_attn.ops.triton.flash_gated_forward import ( _flash_gated_attn_base_forward, _flash_gated_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_gated_bwd import ( +from flash_sparse_attn.ops.triton.flash_gated_backward import ( _flash_gated_attn_base_backward, _flash_gated_attn_varlen_base_backward, ) From f39993f4014824220fef82ec9459dbc41518c5fa Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:25:32 +0800 Subject: [PATCH 2/3] Fix import paths in test_utils.py for backward and forward functions --- tests/test_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 603f8cf5..5c1bb966 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,27 +4,27 @@ import torch -from flash_sparse_attn.ops.triton.flash_dense_bwd import ( +from flash_sparse_attn.ops.triton.flash_dense_backward import ( _flash_dense_attn_base_backward, _flash_dense_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_dense_fwd import ( +from flash_sparse_attn.ops.triton.flash_dense_forward import ( _flash_dense_attn_base_forward, _flash_dense_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_gated_bwd import ( +from flash_sparse_attn.ops.triton.flash_gated_backward import ( _flash_gated_attn_base_backward, _flash_gated_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_gated_fwd import ( +from flash_sparse_attn.ops.triton.flash_gated_forward import ( _flash_gated_attn_base_forward, _flash_gated_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_sparse_bwd import ( +from flash_sparse_attn.ops.triton.flash_sparse_backward import ( _flash_sparse_attn_base_backward, 
_flash_sparse_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_sparse_fwd import ( +from flash_sparse_attn.ops.triton.flash_sparse_forward import ( _flash_sparse_attn_base_forward, _flash_sparse_attn_varlen_base_forward, ) From 6ce30f9398305f6323625eb7164b33e174710bde Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:29:23 +0800 Subject: [PATCH 3/3] Refactor flash_combine references to flash_decode_combine in forward functions --- .../triton/{flash_combine.py => flash_decode_combine.py} | 0 flash_sparse_attn/ops/triton/flash_dense_forward.py | 6 +++--- flash_sparse_attn/ops/triton/flash_gated_forward.py | 6 +++--- flash_sparse_attn/ops/triton/flash_sparse_forward.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) rename flash_sparse_attn/ops/triton/{flash_combine.py => flash_decode_combine.py} (100%) diff --git a/flash_sparse_attn/ops/triton/flash_combine.py b/flash_sparse_attn/ops/triton/flash_decode_combine.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_combine.py rename to flash_sparse_attn/ops/triton/flash_decode_combine.py diff --git a/flash_sparse_attn/ops/triton/flash_dense_forward.py b/flash_sparse_attn/ops/triton/flash_dense_forward.py index 2c9a289d..601a700c 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_forward.py +++ b/flash_sparse_attn/ops/triton/flash_dense_forward.py @@ -7,7 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_combine, + flash_decode_combine, utils, launch_template, launch_grid, @@ -743,7 +743,7 @@ def _flash_dense_attn_base_forward( ) if is_split_kv: - flash_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -894,7 +894,7 @@ def _flash_dense_attn_varlen_base_forward( ) if is_split_kv: - flash_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_gated_forward.py 
b/flash_sparse_attn/ops/triton/flash_gated_forward.py index a4e5ff22..3f9bc481 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_forward.py +++ b/flash_sparse_attn/ops/triton/flash_gated_forward.py @@ -7,7 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_combine, + flash_decode_combine, utils, launch_template, launch_grid, @@ -1095,7 +1095,7 @@ def _flash_gated_attn_base_forward( ) if is_split_kv: - flash_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -1266,7 +1266,7 @@ def _flash_gated_attn_varlen_base_forward( ) if is_split_kv: - flash_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_forward.py b/flash_sparse_attn/ops/triton/flash_sparse_forward.py index fa653807..5db46941 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_forward.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_forward.py @@ -7,7 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_combine, + flash_decode_combine, utils, launch_template, launch_grid, @@ -766,7 +766,7 @@ def _flash_sparse_attn_base_forward( ) if is_split_kv: - flash_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -920,7 +920,7 @@ def _flash_sparse_attn_varlen_base_forward( ) if is_split_kv: - flash_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out,