From a0bd45a7247f45c97cfac9ec50b100919a5bdb26 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:06:21 +0800 Subject: [PATCH 1/5] Rewrite vendored CuTe namespace to flash_attn.cute before subtree merge --- flash_sparse_attn/ops/cute/__init__.py | 2 +- .../ops/cute/benchmark_flash_attention_fp8.py | 6 +-- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- flash_sparse_attn/ops/cute/cache_utils.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 ++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 +++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 +++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- .../ops/cute/flash_fwd_mla_sm100.py | 26 ++++++------ flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 34 ++++++++-------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 ++++++------- flash_sparse_attn/ops/cute/interface.py | 40 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- flash_sparse_attn/ops/cute/topk_gather_kv.py | 4 +- 27 files changed, 145 insertions(+), 145 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index 3dbaebd..be32e14 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("flash-sparse-attn") + __version__ = version("fa4") except PackageNotFoundError: __version__ = "0.0.0" diff --git a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py index 09d5143..c79e768 100644 --- a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py +++ b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py @@ -1,7 +1,7 @@ # Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. # # Run (recommended): -# python -m flash_sparse_attn.ops.cute.benchmark_flash_attention_fp8 +# python -m flash_attn.cute.benchmark_flash_attention_fp8 # # Notes: # - This is intended to be used while bringing up FP8 support for SM100. 
@@ -21,8 +21,8 @@ import torch from einops import rearrange -from flash_sparse_attn.ops.cute.benchmark import benchmark_forward -from flash_sparse_attn.ops.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd +from flash_attn.cute.benchmark import benchmark_forward +from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd try: import cudnn diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index 4bc9fe5..4caadce 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc +import flash_attn.cute.mma_sm100_desc as sm100_desc def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index cebd0bf..f210138 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index 66edb6a..b19dcd3 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index 9e34734..3fad8c9 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index dc04970..f1b5970 100644 --- a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -17,7 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction -from flash_sparse_attn.ops.cute.fa_logging import fa_log +from flash_attn.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. 
diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index d986ecb..a2dd98e 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index 081caa1..eeb7615 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 36e4e8d..76c8562 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index 698b19f..d93ea5c 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from 
quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index 104fdb5..4b4083e 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import copy_utils -from flash_sparse_attn.ops.cute import pipeline -from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import copy_utils +from flash_attn.cute import pipeline +from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 0941ae2..556c59e 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index 755e580..c9a690d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from 
flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute import pipeline +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index 399efd3..d1a43cf 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1224,6 +1224,6 @@ def apply_score_mod( # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def 
__getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 7f38c2b..4936202 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py index 046457b..07cd99f 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py @@ -19,13 +19,13 @@ from quack import copy_utils -from flash_sparse_attn.ops.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.mask import AttentionMask -import flash_sparse_attn.ops.cute.blackwell_helpers as fa_sm100_utils -from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100 -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.mask import AttentionMask +import flash_attn.cute.blackwell_helpers as fa_sm100_utils +from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -35,16 +35,16 @@ SingleTileVarlenScheduler, ParamsBase, ) -from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf -from flash_sparse_attn.ops.cute.utils import smid +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid -from flash_sparse_attn.ops.cute.topk_gather_kv import CpasyncGatherKVManager +from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager -from flash_sparse_attn.ops.cute.testing import attention_ref +from flash_attn.cute.testing import attention_ref -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA -from flash_sparse_attn.ops.cute.cute_dsl_utils import dump_kernel_attributes +from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes class FlashAttentionMLAForwardSm100: diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 040bb43..75e767c 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -33,29 +33,29 @@ from quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from 
flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -import flash_sparse_attn.ops.cute.pipeline as pipeline_custom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +import flash_attn.cute.pipeline as pipeline_custom import cutlass.pipeline as cutlass_pipeline -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc -from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -65,8 +65,8 @@ SingleTileLPTScheduler, SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf -from flash_sparse_attn.ops.cute.utils import smid +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index f9917d8..08d219a 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 8ae0e77..4108ce4 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax 
import Softmax, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( produce_block_sparse_loads, consume_block_sparse_loads, ) -from flash_sparse_attn.ops.cute import pipeline as pipeline_custom -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute import pipeline as pipeline_custom +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase +from flash_attn.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index cd2e005..17b703b 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -16,36 +16,36 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache -from flash_sparse_attn.ops.cute.testing import is_fake_mode +from flash_attn.cute.cache_utils import get_jit_cache +from flash_attn.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 + from flash_attn.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute import fa_logging -from flash_sparse_attn.ops.cute.cute_dsl_utils import ( +from flash_attn.cute import utils +from flash_attn.cute import fa_logging +from flash_attn.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors -from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_sparse_attn.ops.cute.flash_bwd import 
FlashAttentionBackwardSm80 -from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 -from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine -from flash_sparse_attn.ops.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 - -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors +from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 +from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 + +from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index c7832f9..99e7008 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr, Boolean from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +import flash_attn.cute.utils as utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index 484834f..e87df01 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index f84b35c..bf11acb 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute import utils +from flash_attn.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index 79eecd4..9369e0d 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32, Boolean from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from 
flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index 002e7cd..ae57858 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -19,8 +19,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.fast_math import clz +import flash_attn.cute.utils as utils +from flash_attn.cute.fast_math import clz class SchedulingMode(IntEnum): diff --git a/flash_sparse_attn/ops/cute/topk_gather_kv.py b/flash_sparse_attn/ops/cute/topk_gather_kv.py index 5795ae0..67169fb 100644 --- a/flash_sparse_attn/ops/cute/topk_gather_kv.py +++ b/flash_sparse_attn/ops/cute/topk_gather_kv.py @@ -8,8 +8,8 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, Uint32, const_expr, Boolean -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.utils import warp_reduce +from flash_attn.cute import utils +from flash_attn.cute.utils import warp_reduce from quack.cute_dsl_utils import ParamsBase import math From 2f7ea0cbcdd09db7523d135e3cc849a139d09566 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:16:47 +0800 Subject: [PATCH 2/5] Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute --- flash_sparse_attn/ops/cute/__init__.py | 2 +- .../ops/cute/benchmark_flash_attention_fp8.py | 6 +-- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- flash_sparse_attn/ops/cute/cache_utils.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 ++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 +++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 +++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- .../ops/cute/flash_fwd_mla_sm100.py | 26 ++++++------ flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 34 ++++++++-------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 ++++++------- flash_sparse_attn/ops/cute/interface.py | 40 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- flash_sparse_attn/ops/cute/topk_gather_kv.py | 4 +- 27 files changed, 145 insertions(+), 145 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index be32e14..3dbaebd 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("fa4") + __version__ = version("flash-sparse-attn") except PackageNotFoundError: __version__ = "0.0.0" diff --git a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py index c79e768..09d5143 100644 --- a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py +++ 
b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py @@ -1,7 +1,7 @@ # Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. # # Run (recommended): -# python -m flash_attn.cute.benchmark_flash_attention_fp8 +# python -m flash_sparse_attn.ops.cute.benchmark_flash_attention_fp8 # # Notes: # - This is intended to be used while bringing up FP8 support for SM100. @@ -21,8 +21,8 @@ import torch from einops import rearrange -from flash_attn.cute.benchmark import benchmark_forward -from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd +from flash_sparse_attn.ops.cute.benchmark import benchmark_forward +from flash_sparse_attn.ops.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd try: import cudnn diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index 4caadce..4bc9fe5 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_attn.cute.mma_sm100_desc as sm100_desc +import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index f210138..cebd0bf 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index b19dcd3..66edb6a 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index 3fad8c9..9e34734 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index f1b5970..dc04970 100644 --- a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -17,7 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction -from flash_attn.cute.fa_logging import fa_log +from flash_sparse_attn.ops.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. 
diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index a2dd98e..d986ecb 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index eeb7615..081caa1 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 76c8562..36e4e8d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index d93ea5c..698b19f 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from 
quack import copy_utils, layout_utils -from flash_attn.cute import utils -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index 4b4083e..104fdb5 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import copy_utils -from flash_attn.cute import pipeline -from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import copy_utils +from flash_sparse_attn.ops.cute import pipeline +from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 556c59e..0941ae2 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index c9a690d..755e580 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from 
flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute import pipeline +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwd -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index d1a43cf..399efd3 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.pack_gqa import PackGQA -from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1224,6 +1224,6 @@ def apply_score_mod( # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def 
__getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 4936202..7f38c2b 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py index 07cd99f..046457b 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py @@ -19,13 +19,13 @@ from quack import copy_utils -from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.mask import AttentionMask -import flash_attn.cute.blackwell_helpers as fa_sm100_utils -from flash_attn.cute.softmax import SoftmaxSm100 -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.mask import AttentionMask +import flash_sparse_attn.ops.cute.blackwell_helpers as fa_sm100_utils +from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100 +from flash_sparse_attn.ops.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -35,16 +35,16 @@ SingleTileVarlenScheduler, ParamsBase, ) -from flash_attn.cute.fa_logging import fa_log, fa_printf -from flash_attn.cute.utils import smid +from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf +from flash_sparse_attn.ops.cute.utils import smid -from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager +from flash_sparse_attn.ops.cute.topk_gather_kv import CpasyncGatherKVManager -from flash_attn.cute.testing import attention_ref +from flash_sparse_attn.ops.cute.testing import attention_ref -from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA -from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes +from flash_sparse_attn.ops.cute.cute_dsl_utils import dump_kernel_attributes class FlashAttentionMLAForwardSm100: diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 75e767c..040bb43 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -33,29 +33,29 @@ from quack import copy_utils, layout_utils -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.cute_dsl_utils import 
assume_tensor_aligned -from flash_attn.cute import utils -import flash_attn.cute.pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +import flash_sparse_attn.ops.cute.pipeline as pipeline_custom import cutlass.pipeline as cutlass_pipeline -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_attn.cute import mma_sm100_desc as sm100_desc -from flash_attn.cute import blackwell_helpers as sm100_utils -from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc +from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -65,8 +65,8 @@ SingleTileLPTScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute.fa_logging import fa_log, fa_printf -from flash_attn.cute.utils import smid +from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf +from flash_sparse_attn.ops.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index 08d219a..f9917d8 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 4108ce4..8ae0e77 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner -from flash_attn.cute.seqlen_info 
import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( produce_block_sparse_loads, consume_block_sparse_loads, ) -from flash_attn.cute import pipeline as pipeline_custom -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute import pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_attn.cute.flash_fwd import FlashAttentionForwardBase +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index 17b703b..cd2e005 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -16,36 +16,36 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_attn.cute.cache_utils import get_jit_cache -from flash_attn.cute.testing import is_fake_mode +from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache +from flash_sparse_attn.ops.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_attn.cute import cute_dsl_ptxas # noqa: F401 + from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_attn.cute import utils -from flash_attn.cute import fa_logging -from flash_attn.cute.cute_dsl_utils import ( +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute import fa_logging +from flash_sparse_attn.ops.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors -from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 -from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 
-from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine -from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 - -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors +from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_sparse_attn.ops.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 + +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index 99e7008..c7832f9 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr, Boolean from quack import layout_utils -import flash_attn.cute.utils as utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index e87df01..484834f 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index bf11acb..f84b35c 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_attn.cute import utils +from flash_sparse_attn.ops.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index 9369e0d..79eecd4 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32, Boolean from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.seqlen_info import 
SeqlenInfoQK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index ae57858..002e7cd 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -19,8 +19,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_attn.cute.utils as utils -from flash_attn.cute.fast_math import clz +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.fast_math import clz class SchedulingMode(IntEnum): diff --git a/flash_sparse_attn/ops/cute/topk_gather_kv.py b/flash_sparse_attn/ops/cute/topk_gather_kv.py index 67169fb..5795ae0 100644 --- a/flash_sparse_attn/ops/cute/topk_gather_kv.py +++ b/flash_sparse_attn/ops/cute/topk_gather_kv.py @@ -8,8 +8,8 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, Uint32, const_expr, Boolean -from flash_attn.cute import utils -from flash_attn.cute.utils import warp_reduce +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.utils import warp_reduce from quack.cute_dsl_utils import ParamsBase import math From 1811618cd43aacc413e843407d1e1b3e0bf8c525 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:19:10 +0800 Subject: [PATCH 3/5] Rewrite vendored CuTe namespace to flash_attn.cute before subtree merge --- flash_sparse_attn/ops/cute/__init__.py | 2 +- .../ops/cute/benchmark_flash_attention_fp8.py | 6 +-- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- flash_sparse_attn/ops/cute/cache_utils.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 ++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 +++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 +++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- .../ops/cute/flash_fwd_mla_sm100.py | 26 ++++++------ flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 34 ++++++++-------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 ++++++------- flash_sparse_attn/ops/cute/interface.py | 40 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- flash_sparse_attn/ops/cute/topk_gather_kv.py | 4 +- 27 files changed, 145 insertions(+), 145 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index 3dbaebd..be32e14 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("flash-sparse-attn") + __version__ = version("fa4") except PackageNotFoundError: __version__ = "0.0.0" diff --git a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py index 09d5143..c79e768 100644 --- a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py +++ 
b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py @@ -1,7 +1,7 @@ # Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. # # Run (recommended): -# python -m flash_sparse_attn.ops.cute.benchmark_flash_attention_fp8 +# python -m flash_attn.cute.benchmark_flash_attention_fp8 # # Notes: # - This is intended to be used while bringing up FP8 support for SM100. @@ -21,8 +21,8 @@ import torch from einops import rearrange -from flash_sparse_attn.ops.cute.benchmark import benchmark_forward -from flash_sparse_attn.ops.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd +from flash_attn.cute.benchmark import benchmark_forward +from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd try: import cudnn diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index 4bc9fe5..4caadce 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc +import flash_attn.cute.mma_sm100_desc as sm100_desc def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index cebd0bf..f210138 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index 66edb6a..b19dcd3 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index 9e34734..3fad8c9 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index dc04970..f1b5970 100644 --- a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -17,7 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction -from flash_sparse_attn.ops.cute.fa_logging import fa_log +from flash_attn.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. 
diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index d986ecb..a2dd98e 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index 081caa1..eeb7615 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 36e4e8d..76c8562 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index 698b19f..d93ea5c 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from 
quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index 104fdb5..4b4083e 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import copy_utils -from flash_sparse_attn.ops.cute import pipeline -from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import copy_utils +from flash_attn.cute import pipeline +from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 0941ae2..556c59e 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index 755e580..c9a690d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from 
flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute import pipeline +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index 399efd3..d1a43cf 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1224,6 +1224,6 @@ def apply_score_mod( # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def 
__getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 7f38c2b..4936202 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py index 046457b..07cd99f 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py @@ -19,13 +19,13 @@ from quack import copy_utils -from flash_sparse_attn.ops.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.mask import AttentionMask -import flash_sparse_attn.ops.cute.blackwell_helpers as fa_sm100_utils -from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100 -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.mask import AttentionMask +import flash_attn.cute.blackwell_helpers as fa_sm100_utils +from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -35,16 +35,16 @@ SingleTileVarlenScheduler, ParamsBase, ) -from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf -from flash_sparse_attn.ops.cute.utils import smid +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid -from flash_sparse_attn.ops.cute.topk_gather_kv import CpasyncGatherKVManager +from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager -from flash_sparse_attn.ops.cute.testing import attention_ref +from flash_attn.cute.testing import attention_ref -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA -from flash_sparse_attn.ops.cute.cute_dsl_utils import dump_kernel_attributes +from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes class FlashAttentionMLAForwardSm100: diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 040bb43..75e767c 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -33,29 +33,29 @@ from quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from 
flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -import flash_sparse_attn.ops.cute.pipeline as pipeline_custom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +import flash_attn.cute.pipeline as pipeline_custom import cutlass.pipeline as cutlass_pipeline -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc -from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -65,8 +65,8 @@ SingleTileLPTScheduler, SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf -from flash_sparse_attn.ops.cute.utils import smid +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index f9917d8..08d219a 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 8ae0e77..4108ce4 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax 
import Softmax, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( produce_block_sparse_loads, consume_block_sparse_loads, ) -from flash_sparse_attn.ops.cute import pipeline as pipeline_custom -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute import pipeline as pipeline_custom +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase +from flash_attn.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index cd2e005..17b703b 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -16,36 +16,36 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache -from flash_sparse_attn.ops.cute.testing import is_fake_mode +from flash_attn.cute.cache_utils import get_jit_cache +from flash_attn.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 + from flash_attn.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute import fa_logging -from flash_sparse_attn.ops.cute.cute_dsl_utils import ( +from flash_attn.cute import utils +from flash_attn.cute import fa_logging +from flash_attn.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors -from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_sparse_attn.ops.cute.flash_bwd import 
FlashAttentionBackwardSm80 -from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 -from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine -from flash_sparse_attn.ops.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 - -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors +from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 +from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 + +from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index c7832f9..99e7008 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr, Boolean from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +import flash_attn.cute.utils as utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index 484834f..e87df01 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index f84b35c..bf11acb 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute import utils +from flash_attn.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index 79eecd4..9369e0d 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32, Boolean from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from 
flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index 002e7cd..ae57858 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -19,8 +19,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.fast_math import clz +import flash_attn.cute.utils as utils +from flash_attn.cute.fast_math import clz class SchedulingMode(IntEnum): diff --git a/flash_sparse_attn/ops/cute/topk_gather_kv.py b/flash_sparse_attn/ops/cute/topk_gather_kv.py index 5795ae0..67169fb 100644 --- a/flash_sparse_attn/ops/cute/topk_gather_kv.py +++ b/flash_sparse_attn/ops/cute/topk_gather_kv.py @@ -8,8 +8,8 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, Uint32, const_expr, Boolean -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.utils import warp_reduce +from flash_attn.cute import utils +from flash_attn.cute.utils import warp_reduce from quack.cute_dsl_utils import ParamsBase import math From abc1eea5c3e5d83eca07fff10f92cb1019fbcae9 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:19:11 +0800 Subject: [PATCH 4/5] Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute --- flash_sparse_attn/ops/cute/__init__.py | 2 +- .../ops/cute/benchmark_flash_attention_fp8.py | 6 +-- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- flash_sparse_attn/ops/cute/cache_utils.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 ++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 +++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 +++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- .../ops/cute/flash_fwd_mla_sm100.py | 26 ++++++------ flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 34 ++++++++-------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 ++++++------- flash_sparse_attn/ops/cute/interface.py | 40 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- flash_sparse_attn/ops/cute/topk_gather_kv.py | 4 +- 27 files changed, 145 insertions(+), 145 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index be32e14..3dbaebd 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("fa4") + __version__ = version("flash-sparse-attn") except PackageNotFoundError: __version__ = "0.0.0" diff --git a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py index c79e768..09d5143 100644 --- a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py +++ 
b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py @@ -1,7 +1,7 @@ # Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. # # Run (recommended): -# python -m flash_attn.cute.benchmark_flash_attention_fp8 +# python -m flash_sparse_attn.ops.cute.benchmark_flash_attention_fp8 # # Notes: # - This is intended to be used while bringing up FP8 support for SM100. @@ -21,8 +21,8 @@ import torch from einops import rearrange -from flash_attn.cute.benchmark import benchmark_forward -from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd +from flash_sparse_attn.ops.cute.benchmark import benchmark_forward +from flash_sparse_attn.ops.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd try: import cudnn diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index 4caadce..4bc9fe5 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_attn.cute.mma_sm100_desc as sm100_desc +import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index f210138..cebd0bf 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index b19dcd3..66edb6a 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index 3fad8c9..9e34734 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index f1b5970..dc04970 100644 --- a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -17,7 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction -from flash_attn.cute.fa_logging import fa_log +from flash_sparse_attn.ops.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. 
diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index a2dd98e..d986ecb 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index eeb7615..081caa1 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 76c8562..36e4e8d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index d93ea5c..698b19f 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from 
quack import copy_utils, layout_utils -from flash_attn.cute import utils -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index 4b4083e..104fdb5 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import copy_utils -from flash_attn.cute import pipeline -from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import copy_utils +from flash_sparse_attn.ops.cute import pipeline +from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 556c59e..0941ae2 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index c9a690d..755e580 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from 
flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute import pipeline +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwd -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index d1a43cf..399efd3 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.pack_gqa import PackGQA -from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1224,6 +1224,6 @@ def apply_score_mod( # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def 
__getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 4936202..7f38c2b 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py index 07cd99f..046457b 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py @@ -19,13 +19,13 @@ from quack import copy_utils -from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.mask import AttentionMask -import flash_attn.cute.blackwell_helpers as fa_sm100_utils -from flash_attn.cute.softmax import SoftmaxSm100 -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.mask import AttentionMask +import flash_sparse_attn.ops.cute.blackwell_helpers as fa_sm100_utils +from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100 +from flash_sparse_attn.ops.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -35,16 +35,16 @@ SingleTileVarlenScheduler, ParamsBase, ) -from flash_attn.cute.fa_logging import fa_log, fa_printf -from flash_attn.cute.utils import smid +from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf +from flash_sparse_attn.ops.cute.utils import smid -from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager +from flash_sparse_attn.ops.cute.topk_gather_kv import CpasyncGatherKVManager -from flash_attn.cute.testing import attention_ref +from flash_sparse_attn.ops.cute.testing import attention_ref -from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA -from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes +from flash_sparse_attn.ops.cute.cute_dsl_utils import dump_kernel_attributes class FlashAttentionMLAForwardSm100: diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 75e767c..040bb43 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -33,29 +33,29 @@ from quack import copy_utils, layout_utils -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.cute_dsl_utils import 
assume_tensor_aligned -from flash_attn.cute import utils -import flash_attn.cute.pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +import flash_sparse_attn.ops.cute.pipeline as pipeline_custom import cutlass.pipeline as cutlass_pipeline -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_attn.cute import mma_sm100_desc as sm100_desc -from flash_attn.cute import blackwell_helpers as sm100_utils -from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc +from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -65,8 +65,8 @@ SingleTileLPTScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute.fa_logging import fa_log, fa_printf -from flash_attn.cute.utils import smid +from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf +from flash_sparse_attn.ops.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index 08d219a..f9917d8 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 4108ce4..8ae0e77 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner -from flash_attn.cute.seqlen_info 
import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( produce_block_sparse_loads, consume_block_sparse_loads, ) -from flash_attn.cute import pipeline as pipeline_custom -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute import pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_attn.cute.flash_fwd import FlashAttentionForwardBase +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index 17b703b..cd2e005 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -16,36 +16,36 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_attn.cute.cache_utils import get_jit_cache -from flash_attn.cute.testing import is_fake_mode +from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache +from flash_sparse_attn.ops.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_attn.cute import cute_dsl_ptxas # noqa: F401 + from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_attn.cute import utils -from flash_attn.cute import fa_logging -from flash_attn.cute.cute_dsl_utils import ( +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute import fa_logging +from flash_sparse_attn.ops.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors -from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 -from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 
-from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine -from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 - -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors +from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_sparse_attn.ops.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 + +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index 99e7008..c7832f9 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr, Boolean from quack import layout_utils -import flash_attn.cute.utils as utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index e87df01..484834f 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index bf11acb..f84b35c 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_attn.cute import utils +from flash_sparse_attn.ops.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index 9369e0d..79eecd4 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32, Boolean from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.seqlen_info import 
SeqlenInfoQK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index ae57858..002e7cd 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -19,8 +19,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_attn.cute.utils as utils -from flash_attn.cute.fast_math import clz +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.fast_math import clz class SchedulingMode(IntEnum): diff --git a/flash_sparse_attn/ops/cute/topk_gather_kv.py b/flash_sparse_attn/ops/cute/topk_gather_kv.py index 67169fb..5795ae0 100644 --- a/flash_sparse_attn/ops/cute/topk_gather_kv.py +++ b/flash_sparse_attn/ops/cute/topk_gather_kv.py @@ -8,8 +8,8 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, Uint32, const_expr, Boolean -from flash_attn.cute import utils -from flash_attn.cute.utils import warp_reduce +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.utils import warp_reduce from quack.cute_dsl_utils import ParamsBase import math From aa03d81f02ce39e3a831f72d1efa353ac36a95e9 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 17:20:11 +0800 Subject: [PATCH 5/5] Enhance sync script with cherry-pick functionality and improve merge conflict handling Co-authored-by: Copilot --- scripts/sync_cute_subtree.sh | 38 ++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/scripts/sync_cute_subtree.sh b/scripts/sync_cute_subtree.sh index 2188d9c..24eaa53 100644 --- a/scripts/sync_cute_subtree.sh +++ b/scripts/sync_cute_subtree.sh @@ -14,7 +14,9 @@ NO_TEMPORARY_WORKTREE=0 TEMP_WORKTREE_PATH="" TEMP_WORKTREE_BRANCH="" PREPARE_MERGE_COMMIT_MESSAGE="Rewrite vendored CuTe namespace to flash_attn.cute before subtree merge" +MERGE_COMMIT_MESSAGE="Merge upstream CuTe subtree updates" REWRITE_COMMIT_MESSAGE="Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute" +MERGE_CONFLICT_OPTION="ours" UPSTREAM_SPLIT_REF="HEAD" usage() { @@ -263,6 +265,34 @@ commit_prefix_if_changed() { invoke_git -C "$repo" commit -m "$message" } +cherry_pick_commit_back() { + local repo="$1" + local commit="$2" + local parent_count + + parent_count="$(get_commit_parent_count "$repo" "$commit")" + if [[ "$parent_count" -gt 1 ]]; then + if invoke_git -C "$REPO_ROOT" cherry-pick -m 1 "$commit"; then + return 0 + fi + else + if invoke_git -C "$REPO_ROOT" cherry-pick "$commit"; then + return 0 + fi + fi + + if git -C "$REPO_ROOT" rev-parse -q --verify CHERRY_PICK_HEAD >/dev/null 2>&1 \ + && git -C "$REPO_ROOT" diff --cached --quiet \ + && git -C "$REPO_ROOT" diff --quiet; then + echo "Cherry-pick $commit became empty after conflict resolution. Skipping ..." + invoke_git -C "$REPO_ROOT" cherry-pick --skip + return 0 + fi + + echo "Cherry-pick $commit failed with conflicts. Resolve manually and continue." >&2 + exit 1 +} + invoke_core_sync() { local work_repo_root="$1" local cutlass_repo="$work_repo_root/csrc/cutlass" @@ -322,7 +352,7 @@ invoke_core_sync() { echo "Pulling upstream updates into $PREFIX ..." 
invoke_git -C "$work_repo_root" fetch "$UPSTREAM_REPO_FOR_SPLIT" "$TEMP_BRANCH" - invoke_git_no_merge_edit -C "$work_repo_root" merge -X theirs "-Xsubtree=$PREFIX" FETCH_HEAD + invoke_git_no_merge_edit -C "$work_repo_root" merge -m "$MERGE_COMMIT_MESSAGE" "-X$MERGE_CONFLICT_OPTION" "-Xsubtree=$PREFIX" FETCH_HEAD fi echo "Rewriting vendored CuTe imports to flash_sparse_attn.ops.cute ..." @@ -369,11 +399,7 @@ invoke_temporary_worktree_sync() { [[ -z "$commit" ]] && continue echo "Cherry-picking $commit back into current worktree ..." ensure_git_identity "$REPO_ROOT" - if [[ "$(get_commit_parent_count "$temp_worktree" "$commit")" -gt 1 ]]; then - invoke_git -C "$REPO_ROOT" cherry-pick -m 1 "$commit" - else - invoke_git -C "$REPO_ROOT" cherry-pick "$commit" - fi + cherry_pick_commit_back "$temp_worktree" "$commit" done <<< "$commits" if [[ -n "$current_status" ]]; then
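
The new cherry_pick_commit_back helper above treats a cherry-pick that stops while leaving no staged and no unstaged changes as an "empty" pick and skips it rather than aborting the whole sync. Below is a minimal standalone sketch of that same detection pattern, included only for illustration: it is not part of the patch, the repository path and commit id are placeholders, and it assumes Git 2.23+ for `git cherry-pick --skip`.

# Illustrative sketch, not part of the patch; "$repo" and "$commit" are placeholders.
repo="/path/to/checkout"   # hypothetical working tree
commit="0123abc"           # hypothetical commit to replay

if ! git -C "$repo" cherry-pick "$commit"; then
    # A cherry-pick that stops but leaves no staged or unstaged changes is
    # "empty": its changes are already present, so skipping it is safe.
    if git -C "$repo" rev-parse -q --verify CHERRY_PICK_HEAD >/dev/null 2>&1 \
        && git -C "$repo" diff --cached --quiet \
        && git -C "$repo" diff --quiet; then
        git -C "$repo" cherry-pick --skip
    else
        echo "Real conflicts: resolve them, then run 'git cherry-pick --continue'." >&2
        exit 1
    fi
fi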