diff --git a/flash_sparse_attn/ops/triton/flash_bwd_postprocess.py b/flash_sparse_attn/ops/triton/flash_backward_postprocess.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_bwd_postprocess.py rename to flash_sparse_attn/ops/triton/flash_backward_postprocess.py diff --git a/flash_sparse_attn/ops/triton/flash_bwd_preprocess.py b/flash_sparse_attn/ops/triton/flash_backward_preprocess.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_bwd_preprocess.py rename to flash_sparse_attn/ops/triton/flash_backward_preprocess.py diff --git a/flash_sparse_attn/ops/triton/flash_fwd_combine.py b/flash_sparse_attn/ops/triton/flash_decode_combine.py similarity index 100% rename from flash_sparse_attn/ops/triton/flash_fwd_combine.py rename to flash_sparse_attn/ops/triton/flash_decode_combine.py diff --git a/flash_sparse_attn/ops/triton/flash_dense_bwd.py b/flash_sparse_attn/ops/triton/flash_dense_backward.py similarity index 98% rename from flash_sparse_attn/ops/triton/flash_dense_bwd.py rename to flash_sparse_attn/ops/triton/flash_dense_backward.py index 3a09b7d2..85e88002 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_backward.py @@ -7,13 +7,13 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_backward_postprocess, + flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, mask, - flash_bwd_preprocess, - flash_bwd_postprocess, ) @@ -712,7 +712,7 @@ def _flash_dense_attn_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -793,7 +793,7 @@ def _flash_dense_attn_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, @@ -901,7 +901,7 @@ def _flash_dense_attn_varlen_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -985,7 +985,7 @@ def _flash_dense_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_forward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_dense_fwd.py rename to flash_sparse_attn/ops/triton/flash_dense_forward.py index d60cd040..601a700c 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_forward.py @@ -7,6 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_decode_combine, utils, launch_template, launch_grid, @@ -14,7 +15,6 @@ block_info, activations, mask, - flash_fwd_combine, ) @@ -743,7 +743,7 @@ def _flash_dense_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -894,7 +894,7 @@ def _flash_dense_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_gated_bwd.py b/flash_sparse_attn/ops/triton/flash_gated_backward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_gated_bwd.py rename to flash_sparse_attn/ops/triton/flash_gated_backward.py index c82df8ea..3da2b240 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_backward.py @@ -7,14 +7,14 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_backward_postprocess, + flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, activations, mask, - flash_bwd_preprocess, - flash_bwd_postprocess, ) @@ -1101,7 +1101,7 @@ def _flash_gated_attn_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1198,7 +1198,7 @@ def _flash_gated_attn_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, da_accum=da_accum, @@ -1324,7 +1324,7 @@ def _flash_gated_attn_varlen_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1424,7 +1424,7 @@ def _flash_gated_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, da_accum=da_accum, diff --git a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_forward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_gated_fwd.py rename to flash_sparse_attn/ops/triton/flash_gated_forward.py index 26f7642b..3f9bc481 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_forward.py @@ -7,6 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_decode_combine, utils, launch_template, launch_grid, @@ -14,7 +15,6 @@ block_info, activations, mask, - flash_fwd_combine, ) @@ -1095,7 +1095,7 @@ def _flash_gated_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -1266,7 +1266,7 @@ def _flash_gated_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py b/flash_sparse_attn/ops/triton/flash_sparse_backward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_sparse_bwd.py rename to flash_sparse_attn/ops/triton/flash_sparse_backward.py index 2f58ca21..addfdbe2 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_backward.py @@ -7,13 +7,13 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_backward_postprocess, + flash_backward_preprocess, launch_template, launch_grid, seqlen_info, block_info, mask, - flash_bwd_preprocess, - flash_bwd_postprocess, ) @@ -773,7 +773,7 @@ def _flash_sparse_attn_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -855,7 +855,7 @@ def _flash_sparse_attn_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, @@ -965,7 +965,7 @@ def _flash_sparse_attn_varlen_base_backward( device=query.device, ) - flash_bwd_preprocess._flash_attn_bwd_preprocess( + flash_backward_preprocess._flash_attn_bwd_preprocess( out=out, dout=dout, dpsum=dpsum, @@ -1050,7 +1050,7 @@ def _flash_sparse_attn_varlen_base_backward( num_ctas=num_ctas, ) - flash_bwd_postprocess._flash_attn_bwd_postprocess( + flash_backward_postprocess._flash_attn_bwd_postprocess( dq_accum=dq_accum, dq=dq, scale=softmax_scale, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_forward.py similarity index 99% rename from flash_sparse_attn/ops/triton/flash_sparse_fwd.py rename to flash_sparse_attn/ops/triton/flash_sparse_forward.py index 1269d538..5db46941 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_forward.py @@ -7,6 +7,7 @@ from flash_sparse_attn.ops.triton import ( assert_inputs, + flash_decode_combine, utils, launch_template, launch_grid, @@ -14,7 +15,6 @@ block_info, activations, mask, - flash_fwd_combine, ) @@ -766,7 +766,7 @@ def _flash_sparse_attn_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, @@ -920,7 +920,7 @@ def _flash_sparse_attn_varlen_base_forward( ) if is_split_kv: - flash_fwd_combine._flash_attn_fwd_combine( + flash_decode_combine._flash_attn_fwd_combine( out_partial, lse_partial, out, diff --git a/flash_sparse_attn/ops/triton/interface.py b/flash_sparse_attn/ops/triton/interface.py index aeb1f656..43638f6a 100644 --- a/flash_sparse_attn/ops/triton/interface.py +++ b/flash_sparse_attn/ops/triton/interface.py @@ -2,27 +2,27 @@ import torch -from flash_sparse_attn.ops.triton.flash_dense_fwd import ( +from flash_sparse_attn.ops.triton.flash_dense_forward import ( _flash_dense_attn_base_forward, _flash_dense_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_dense_bwd import ( +from flash_sparse_attn.ops.triton.flash_dense_backward import ( _flash_dense_attn_base_backward, _flash_dense_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_sparse_fwd import ( +from flash_sparse_attn.ops.triton.flash_sparse_forward import ( _flash_sparse_attn_base_forward, _flash_sparse_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_sparse_bwd import ( +from flash_sparse_attn.ops.triton.flash_sparse_backward import ( _flash_sparse_attn_base_backward, _flash_sparse_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_gated_fwd import ( +from flash_sparse_attn.ops.triton.flash_gated_forward import ( _flash_gated_attn_base_forward, _flash_gated_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_gated_bwd import ( +from flash_sparse_attn.ops.triton.flash_gated_backward import ( _flash_gated_attn_base_backward, _flash_gated_attn_varlen_base_backward, ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 603f8cf5..5c1bb966 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,27 +4,27 @@ import torch -from flash_sparse_attn.ops.triton.flash_dense_bwd import ( +from flash_sparse_attn.ops.triton.flash_dense_backward import ( _flash_dense_attn_base_backward, _flash_dense_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_dense_fwd import ( +from flash_sparse_attn.ops.triton.flash_dense_forward import ( _flash_dense_attn_base_forward, _flash_dense_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_gated_bwd import ( +from flash_sparse_attn.ops.triton.flash_gated_backward import ( _flash_gated_attn_base_backward, _flash_gated_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_gated_fwd import ( +from flash_sparse_attn.ops.triton.flash_gated_forward import ( _flash_gated_attn_base_forward, _flash_gated_attn_varlen_base_forward, ) -from flash_sparse_attn.ops.triton.flash_sparse_bwd import ( +from flash_sparse_attn.ops.triton.flash_sparse_backward import ( _flash_sparse_attn_base_backward, _flash_sparse_attn_varlen_base_backward, ) -from flash_sparse_attn.ops.triton.flash_sparse_fwd import ( +from flash_sparse_attn.ops.triton.flash_sparse_forward import ( _flash_sparse_attn_base_forward, _flash_sparse_attn_varlen_base_forward, )