diff --git a/flash_sparse_attn/ops/triton/flash_backward_postprocess.py b/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_backward_postprocess.py rename to flash_sparse_attn/ops/triton/flash_bwd_postprocess.py diff --git a/flash_sparse_attn/ops/triton/flash_backward_preprocess.py b/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_backward_preprocess.py rename to flash_sparse_attn/ops/triton/flash_bwd_preprocess.py diff --git a/flash_sparse_attn/ops/triton/flash_dense_backward.py b/flash_sparse_attn/ops/triton/flash_dense_bwd.py similarity index 98% rename from flash_sparse_attn/ops/triton/flash_dense_backward.py rename to flash_sparse_attn/ops/triton/flash_dense_bwd.py index 85e88002..3a09b7d2 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_backward.py +++ b/flash_sparse_attn/ops/triton/flash_dense_bwd.py @@ -7,13 +7,13 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_backward_postprocess, - flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, mask, + flash_bwd_preprocess, + flash_bwd_postprocess, ) @@ -712,7 +712,7 @@ def _flash_dense_attn_base_backward( device=query.device, ) - flash_backward_preprocess._flash_attn_bwd_preprocess( + flash_bwd_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -793,7 +793,7 @@ def _flash_dense_attn_base_backward( num_ctas=num_ctas, ) - flash_backward_postprocess._flash_attn_bwd_postprocess( + flash_bwd_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, @@ -901,7 +901,7 @@ def _flash_dense_attn_varlen_base_backward( device=query.device, ) - flash_backward_preprocess._flash_attn_bwd_preprocess( + flash_bwd_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -985,7 +985,7 @@ def _flash_dense_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_backward_postprocess._flash_attn_bwd_postprocess( + flash_bwd_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, diff --git a/flash_sparse_attn/ops/triton/flash_dense_forward.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_dense_forward.py rename to flash_sparse_attn/ops/triton/flash_dense_fwd.py index 601a700c..d60cd040 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_forward.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -7,7 +7,6 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_decode_combine, utils, launch_template, launch_grid, @@ -15,6 +14,7 @@ block_info, activations, mask, + flash_fwd_combine, ) @@ -743,7 +743,7 @@ def _flash_dense_attn_base_forward( ) if is_split_kv: - flash_decode_combine._flash_attn_fwd_combine( + flash_fwd_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -894,7 +894,7 @@ def _flash_dense_attn_varlen_base_forward( ) if is_split_kv: - flash_decode_combine._flash_attn_fwd_combine( + flash_fwd_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_decode_combine.py b/flash_sparse_attn/ops/triton/flash_fwd_combine.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_decode_combine.py rename to flash_sparse_attn/ops/triton/flash_fwd_combine.py diff --git a/flash_sparse_attn/ops/triton/flash_gated_backward.py b/flash_sparse_attn/ops/triton/flash_gated_bwd.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_gated_backward.py rename to flash_sparse_attn/ops/triton/flash_gated_bwd.py index 3da2b240..c82df8ea 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_backward.py +++ b/flash_sparse_attn/ops/triton/flash_gated_bwd.py @@ -7,14 +7,14 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_backward_postprocess, - flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, activations, mask, + flash_bwd_preprocess, + flash_bwd_postprocess, ) @@ -1101,7 +1101,7 @@ def _flash_gated_attn_base_backward( device=query.device, ) - flash_backward_preprocess._flash_attn_bwd_preprocess( + flash_bwd_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1198,7 +1198,7 @@ def _flash_gated_attn_base_backward( num_ctas=num_ctas, ) - flash_backward_postprocess._flash_attn_bwd_postprocess( + flash_bwd_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, da_accum=da_accum, @@ -1324,7 +1324,7 @@ def _flash_gated_attn_varlen_base_backward( device=query.device, ) - flash_backward_preprocess._flash_attn_bwd_preprocess( + flash_bwd_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1424,7 +1424,7 @@ def _flash_gated_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_backward_postprocess._flash_attn_bwd_postprocess( + flash_bwd_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, da_accum=da_accum, diff --git a/flash_sparse_attn/ops/triton/flash_gated_forward.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_gated_forward.py rename to flash_sparse_attn/ops/triton/flash_gated_fwd.py index 3f9bc481..26f7642b 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_forward.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -7,7 +7,6 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_decode_combine, utils, launch_template, launch_grid, @@ -15,6 +14,7 @@ block_info, activations, mask, + flash_fwd_combine, ) @@ -1095,7 +1095,7 @@ def _flash_gated_attn_base_forward( ) if is_split_kv: - flash_decode_combine._flash_attn_fwd_combine( + flash_fwd_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -1266,7 +1266,7 @@ def _flash_gated_attn_varlen_base_forward( ) if is_split_kv: - flash_decode_combine._flash_attn_fwd_combine( + flash_fwd_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_backward.py b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_sparse_backward.py rename to flash_sparse_attn/ops/triton/flash_sparse_bwd.py index addfdbe2..2f58ca21 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_backward.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py @@ -7,13 +7,13 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_backward_postprocess, - flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, mask, + flash_bwd_preprocess, + flash_bwd_postprocess, ) @@ -773,7 +773,7 @@ def _flash_sparse_attn_base_backward( device=query.device, ) - flash_backward_preprocess._flash_attn_bwd_preprocess( + flash_bwd_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -855,7 +855,7 @@ def _flash_sparse_attn_base_backward( num_ctas=num_ctas, ) - flash_backward_postprocess._flash_attn_bwd_postprocess( + flash_bwd_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, @@ -965,7 +965,7 @@ def _flash_sparse_attn_varlen_base_backward( device=query.device, ) - flash_backward_preprocess._flash_attn_bwd_preprocess( + flash_bwd_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1050,7 +1050,7 @@ def _flash_sparse_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_backward_postprocess._flash_attn_bwd_postprocess( + flash_bwd_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_forward.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_sparse_forward.py rename to flash_sparse_attn/ops/triton/flash_sparse_fwd.py index 5db46941..1269d538 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_forward.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -7,7 +7,6 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, - flash_decode_combine, utils, launch_template, launch_grid, @@ -15,6 +14,7 @@ block_info, activations, mask, + flash_fwd_combine, ) @@ -766,7 +766,7 @@ def _flash_sparse_attn_base_forward( ) if is_split_kv: - flash_decode_combine._flash_attn_fwd_combine( + flash_fwd_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -920,7 +920,7 @@ def _flash_sparse_attn_varlen_base_forward( ) if is_split_kv: - flash_decode_combine._flash_attn_fwd_combine( + flash_fwd_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/interface.py b/flash_sparse_attn/ops/triton/interface.py index 43638f6a..aeb1f656 100644 --- a/flash_sparse_attn/ops/triton/interface.py +++ b/flash_sparse_attn/ops/triton/interface.py @@ -2,27 +2,27 @@ import torch -from flash_sparse_attn.ops.triton.flash_dense_forward import ( +from flash_sparse_attn.ops.triton.flash_dense_fwd import ( _flash_dense_attn_base_forward, _flash_dense_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_dense_backward import ( +from flash_sparse_attn.ops.triton.flash_dense_bwd import ( _flash_dense_attn_base_backward, _flash_dense_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_sparse_forward import ( +from flash_sparse_attn.ops.triton.flash_sparse_fwd import ( _flash_sparse_attn_base_forward, _flash_sparse_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_sparse_backward import ( +from flash_sparse_attn.ops.triton.flash_sparse_bwd import ( _flash_sparse_attn_base_backward, _flash_sparse_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_gated_forward import ( +from flash_sparse_attn.ops.triton.flash_gated_fwd import ( _flash_gated_attn_base_forward, _flash_gated_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_gated_backward import ( +from flash_sparse_attn.ops.triton.flash_gated_bwd import ( _flash_gated_attn_base_backward, _flash_gated_attn_varlen_base_backward, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5c1bb966..603f8cf5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,27 +4,27 @@ import torch -from flash_sparse_attn.ops.triton.flash_dense_backward import ( +from flash_sparse_attn.ops.triton.flash_dense_bwd import ( _flash_dense_attn_base_backward, _flash_dense_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_dense_forward import ( +from flash_sparse_attn.ops.triton.flash_dense_fwd import ( _flash_dense_attn_base_forward, _flash_dense_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_gated_backward import ( +from flash_sparse_attn.ops.triton.flash_gated_bwd import ( _flash_gated_attn_base_backward, _flash_gated_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_gated_forward import ( +from flash_sparse_attn.ops.triton.flash_gated_fwd import ( _flash_gated_attn_base_forward, _flash_gated_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_sparse_backward import ( +from flash_sparse_attn.ops.triton.flash_sparse_bwd import ( _flash_sparse_attn_base_backward, _flash_sparse_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_sparse_forward import ( +from flash_sparse_attn.ops.triton.flash_sparse_fwd import ( _flash_sparse_attn_base_forward, _flash_sparse_attn_varlen_base_forward, )