From 1ca21547ed01064293670302f31eb095b7bb85b9 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 7 Apr 2026 15:27:06 +0000 Subject: [PATCH 1/5] Rewrite vendored CuTe namespace to flash_attn.cute before subtree merge --- flash_sparse_attn/ops/cute/__init__.py | 4 +- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 +++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 ++++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 ++++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 30 +++++++-------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 +++++++------- flash_sparse_attn/ops/cute/interface.py | 38 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- 23 files changed, 124 insertions(+), 124 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index 01de305..1b84363 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("flash-sparse-attn") + __version__ = version("fa4") except PackageNotFoundError: __version__ = "0.0.0" @@ -14,7 +14,7 @@ flash_attn_varlen_func, ) -from flash_sparse_attn.ops.cute.cute_dsl_utils import cute_compile_patched +from flash_attn.cute.cute_dsl_utils import cute_compile_patched # Patch cute.compile to optionally dump SASS cute.compile = cute_compile_patched diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index cdec2e5..7207780 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc +import flash_attn.cute.mma_sm100_desc as sm100_desc @cute.jit diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index cebd0bf..f210138 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index 7e8cc82..63e91bc 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index ea910f4..f19c8fb 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index d986ecb..a2dd98e 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index d4f1db6..824abdd 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 36e4e8d..76c8562 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index 698b19f..d93ea5c 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index 58d3b41..e06cd81 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import copy_utils -from flash_sparse_attn.ops.cute import pipeline -from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import copy_utils +from flash_attn.cute import pipeline +from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 0941ae2..556c59e 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index 054b044..e163350 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute import pipeline +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index 43926f7..4d47fab 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import Softmax -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1190,6 +1190,6 @@ def load_K_next(): # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def __getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 7f38c2b..4936202 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index fd42f62..76cf739 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -32,28 +32,28 @@ from quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -import flash_sparse_attn.ops.cute.pipeline as pipeline_custom -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +import flash_attn.cute.pipeline as pipeline_custom +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc -from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, StaticPersistentTileScheduler, diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index f9917d8..08d219a 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 8ae0e77..4108ce4 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( produce_block_sparse_loads, consume_block_sparse_loads, ) -from flash_sparse_attn.ops.cute import pipeline as pipeline_custom -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute import pipeline as pipeline_custom +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase +from flash_attn.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index 77ab487..f922cb2 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -34,35 +34,35 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache -from flash_sparse_attn.ops.cute.testing import is_fake_mode +from flash_attn.cute.cache_utils import get_jit_cache +from flash_attn.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 + from flash_attn.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute import fa_logging -from flash_sparse_attn.ops.cute.cute_dsl_utils import ( +from flash_attn.cute import utils +from flash_attn.cute import fa_logging +from flash_attn.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 -from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 -from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 -from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine - -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 +from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine + +from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, to_cute_block_sparse_tensors, normalize_block_sparse_config, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index ba53d13..6b5ca16 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +import flash_attn.cute.utils as utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index 484834f..e87df01 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index f84b35c..bf11acb 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute import utils +from flash_attn.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index 0eaa479..eed55a0 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32 from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index 73add39..c7067ae 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -16,8 +16,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.fast_math import clz +import flash_attn.cute.utils as utils +from flash_attn.cute.fast_math import clz class WorkTileInfo(cutlass.utils.WorkTileInfo): From 458bae5334690b3d872515552525c112ce52541d Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 7 Apr 2026 15:27:11 +0000 Subject: [PATCH 2/5] Merge branch 'sync/cute-upstream-temp' of /workspace/flash-sparse-attn/.ref_repo/flash-attention into sync/cute-worktree-20260407-152315 --- flash_sparse_attn/ops/cute/README.md | 3 +- .../ops/cute/block_sparse_utils.py | 14 +- flash_sparse_attn/ops/cute/block_sparsity.py | 43 +- flash_sparse_attn/ops/cute/cache_utils.py | 24 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 236 ++++++--- flash_sparse_attn/ops/cute/interface.py | 52 +- flash_sparse_attn/ops/cute/pyproject.toml | 1 + flash_sparse_attn/ops/cute/tile_scheduler.py | 448 ++++++++++++++++-- flash_sparse_attn/ops/cute/utils.py | 29 +- 10 files changed, 696 insertions(+), 156 deletions(-) diff --git a/flash_sparse_attn/ops/cute/README.md b/flash_sparse_attn/ops/cute/README.md index 653f7b1..c7f1b32 100644 --- a/flash_sparse_attn/ops/cute/README.md +++ b/flash_sparse_attn/ops/cute/README.md @@ -27,6 +27,7 @@ out = flash_attn_func(q, k, v, causal=True) ```sh git clone https://github.com/Dao-AILab/flash-attention.git cd flash-attention -pip install -e "flash_attn/cute[dev]" +pip install -e "flash_attn/cute[dev]" # CUDA 12.x +pip install -e "flash_attn/cute[dev,cu13]" # CUDA 13.x (e.g. B200) pytest tests/cute/ ``` diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index 63e91bc..52cb7e0 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -1348,18 +1348,18 @@ def _store_one_dQaccum_sm90( m_block, sdQaccum: cute.Tensor, gdQaccum: cute.Tensor, - num_mma_warp_groups: cutlass.Constexpr, + num_dQ_warp_groups: cutlass.Constexpr, num_threads_per_warp_group: cutlass.Constexpr, tma_copy_bytes_dQ, ): """Store dQaccum for a single m_block.""" - for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): - cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True) + for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups): + cute.arch.cp_async_bulk_wait_group(num_dQ_warp_groups - 1 - warp_group_idx, read=True) cute.arch.barrier_arrive( barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx, number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, ) - for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups): + for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups): cute.arch.barrier( barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx, number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE, @@ -1383,7 +1383,7 @@ def dQaccum_store_block_sparse_bwd_sm90( gdQaccum: cute.Tensor, subtile_factor: cutlass.Constexpr, m_block_max: int, - num_mma_warp_groups: cutlass.Constexpr, + num_dQ_warp_groups: cutlass.Constexpr, num_threads_per_warp_group: cutlass.Constexpr, tma_copy_bytes_dQ, ): @@ -1412,7 +1412,7 @@ def dQaccum_store_block_sparse_bwd_sm90( m_block, sdQaccum, gdQaccum, - num_mma_warp_groups, + num_dQ_warp_groups, num_threads_per_warp_group, tma_copy_bytes_dQ, ) @@ -1428,7 +1428,7 @@ def dQaccum_store_block_sparse_bwd_sm90( m_block, sdQaccum, gdQaccum, - num_mma_warp_groups, + num_dQ_warp_groups, num_threads_per_warp_group, tma_copy_bytes_dQ, ) diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index f19c8fb..3fad8c9 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -34,6 +34,23 @@ class BlockSparseTensorsTorch(NamedTuple): block_size: tuple[int, int] | None = None +def get_sparse_q_block_size( + tensors: BlockSparseTensorsTorch | None, + seqlen_q: int, +) -> int | None: + """Return the Q sparse block size, or None when sparsity is unset or ambiguous.""" + if tensors is None: + return None + if tensors.block_size is not None: + return tensors.block_size[0] + num_m_blocks = tensors.mask_block_idx.shape[2] + min_block_size = ceildiv(seqlen_q, num_m_blocks) + max_block_size = seqlen_q if num_m_blocks == 1 else (seqlen_q - 1) // (num_m_blocks - 1) + if min_block_size != max_block_size: + return None + return min_block_size + + def _expand_sparsity_tensor( tensor: torch.Tensor, expected_shape: Tuple[int, ...], @@ -81,6 +98,12 @@ def _check_and_expand_block( expanded_cnt = _expand_sparsity_tensor( cnt, expected_count_shape, f"{name}_block_cnt", context, hint ) + # [Note] Allow Compact block sparse indices + # Allow the last dimension (n_blocks) of idx to be <= expected, since + # FA4 only accesses indices 0..cnt-1 per query tile. This enables compact + # index tensors that avoid O(N^2) memory at long sequence lengths. + if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]: + expected_index_shape = (*expected_index_shape[:3], idx.shape[3]) expanded_idx = _expand_sparsity_tensor( idx, expected_index_shape, f"{name}_block_idx", context, hint ) @@ -140,17 +163,14 @@ def infer_block_sparse_expected_shapes( num_m_blocks = tensors.mask_block_idx.shape[2] if sparse_block_size_q is None: - min_block_size = ceildiv(seqlen_q, num_m_blocks) - if num_m_blocks == 1: - max_block_size = seqlen_q - else: - max_block_size = (seqlen_q - 1) // (num_m_blocks - 1) - if max_block_size != min_block_size and base_m_block != 1: + sparse_block_size_q = get_sparse_q_block_size(tensors, seqlen_q) + if sparse_block_size_q is None and base_m_block != 1: raise ValueError( f"Block sparse tensors{context} require explicit sparse_block_size[0] " f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}." ) - sparse_block_size_q = min_block_size + if sparse_block_size_q is None: + sparse_block_size_q = ceildiv(seqlen_q, num_m_blocks) if sparse_block_size_q % base_m_block != 0: raise ValueError( @@ -186,9 +206,11 @@ def infer_block_sparse_expected_shapes( raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.") if mask_block_cnt.shape[2] != mask_block_idx.shape[2]: raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.") - if mask_block_idx.shape[3] != expected_n_blocks: + # [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1 + # per query tile, so idx.shape[3] can be <= expected_n_blocks. + if mask_block_idx.shape[3] > expected_n_blocks: raise ValueError( - f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}." + f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}." ) if expected_m_blocks != num_m_blocks: raise ValueError( @@ -314,7 +336,7 @@ def normalize_block_sparse_config( ) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]: m_block_size, n_block_size = block_size if tensors.block_size is None: - sparse_block_size_q, sparse_block_size_kv = q_stage * m_block_size, n_block_size + sparse_block_size_q, sparse_block_size_kv = None, n_block_size else: sparse_block_size_q, sparse_block_size_kv = tensors.block_size if sparse_block_size_kv != n_block_size: @@ -401,6 +423,7 @@ def to_cute_block_sparse_tensors( """Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi""" if not is_block_sparsity_enabled(tensors): return None + ( mask_block_cnt, mask_block_idx, diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index 8606f04..f1b5970 100644 --- a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -1,7 +1,6 @@ # Manage Ahead-of-Time (AOT) compiled kernels import fcntl import hashlib -import logging import os import pickle import sys @@ -18,6 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction +from flash_attn.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. @@ -30,12 +30,6 @@ CompileKeyType: TypeAlias = tuple[Hashable, ...] CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function -logger = logging.getLogger(__name__) -_handler = logging.StreamHandler() -_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")) -logger.addHandler(_handler) -logger.setLevel(logging.DEBUG) - # Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1` CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1" @@ -222,13 +216,13 @@ def _try_load_from_storage(self, key: CompileKeyType) -> bool: label=sha256_hex, ): if obj_path.exists(): - logger.debug("Loading compiled function from disk: %s", obj_path) + fa_log(1, f"Loading compiled function from disk: {obj_path}") m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True) fn = getattr(m, self.EXPORT_FUNCTION_PREFIX) JITCache.__setitem__(self, key, fn) return True else: - logger.debug("Cache miss on disk for key hash %s", sha256_hex) + fa_log(1, f"Cache miss on disk for key hash {sha256_hex}") return False def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None: @@ -243,14 +237,14 @@ def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) - obj_path = self.cache_path / f"{sha256_hex}.o" if obj_path.exists(): # Another process already exported. - logger.debug("Skipping export, already on disk: %s", obj_path) + fa_log(1, f"Skipping export, already on disk: {obj_path}") return - logger.debug("Exporting compiled function to disk: %s", obj_path) + fa_log(1, f"Exporting compiled function to disk: {obj_path}") fn.export_to_c( object_file_path=str(obj_path), function_name=self.EXPORT_FUNCTION_PREFIX, ) - logger.debug("Successfully exported compiled function to disk: %s", obj_path) + fa_log(1, f"Successfully exported compiled function to disk: {obj_path}") def _key_to_hash(self, key: CompileKeyType) -> str: return hashlib.sha256(pickle.dumps(key)).hexdigest() @@ -262,7 +256,7 @@ def clear(self) -> None: """ Not only clear the in-memory cache. Also purge persistent compilation cache. """ - logger.debug("Clearing persistent cache at %s", self.cache_path) + fa_log(1, f"Clearing persistent cache at {self.cache_path}") super().clear() for child in self.cache_path.iterdir(): child.unlink() @@ -281,8 +275,8 @@ def get_jit_cache(name: str | None = None) -> JITCache: path = get_cache_path() / _compute_source_fingerprint() if name: path = path / name - logger.debug("Creating persistent JIT cache at %s", path) + fa_log(1, f"Creating persistent JIT cache at {path}") return JITPersistentCache(path) else: - logger.debug("Persistent cache disabled, using in-memory JIT cache") + fa_log(1, "Persistent cache disabled, using in-memory JIT cache") return JITCache() diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index e163350..f724b5a 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -1865,7 +1865,7 @@ def dQaccum_store( gdQaccum, subtile_factor=self.subtile_factor, m_block_max=m_block_max, - num_mma_warp_groups=self.num_wg_mma, + num_dQ_warp_groups=self.num_wg_dQ, num_threads_per_warp_group=self.num_threads_per_warp_group, tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"], ) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 76cf739..6e4fdbf 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -14,7 +14,7 @@ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py import math -from typing import Type, Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal from functools import partial import cuda.bindings.driver as cuda @@ -27,6 +27,7 @@ import cutlass.utils.blackwell_helpers as sm100_utils_basic from cutlass import pipeline from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait +from cutlass.utils import ClcDynamicPersistentTileScheduler from cutlass.base_dsl.arch import Arch from cutlass.cutlass_dsl import BaseDSL @@ -36,6 +37,7 @@ from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils import flash_attn.cute.pipeline as pipeline_custom +import cutlass.pipeline as cutlass_pipeline from flash_attn.cute.mask import AttentionMask from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK @@ -54,12 +56,39 @@ from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase from flash_attn.cute.tile_scheduler import ( + ClcState, + SchedulingMode, TileSchedulerArguments, + TileSchedulerProtocol, SingleTileScheduler, StaticPersistentTileScheduler, SingleTileLPTScheduler, SingleTileVarlenScheduler, ) +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid + +# === TUNING KNOBS (agent-editable) === +# Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) +# Values: +# ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation). +# SM103 has fast native exp2, so set freq=0 there. +# ex2_emu_start_frg: int — fragment index to start emulation from +# num_regs_softmax: int — register count for softmax warps (multiple of 8) +# num_regs_correction: int — register count for correction warps (multiple of 8) +# num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction +_TUNING_CONFIG = { + (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88}, + (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72}, + (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80}, + (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, + (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80}, + (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, +} +# === END TUNING KNOBS === + # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) @@ -106,6 +135,7 @@ def __init__( paged_kv_non_tma: bool = False, is_varlen_q: bool = False, use_2cta_instrs: bool = False, + use_clc_scheduler: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -179,6 +209,32 @@ def __init__( "Paged KV does not support irregular head dim" ) + self.use_clc_scheduler = ( + use_clc_scheduler + and self.use_tma_KV + and not self.overlap_sO_sQ + ) + self.sched_stages = 1 + if self.use_clc_scheduler: + assert self.cluster_shape_mn[1] == 1, f"CLC requires cluster N == 1: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] in (1, 2), f"bad CLC cluster M: {self.cluster_shape_mn}" + assert self.cluster_shape_mn[0] == self.cta_group_size, ( + f"CLC cluster M != cta_group_size: {self.cluster_shape_mn}, {self.cta_group_size}" + ) + + self.scheduling_mode = SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + + if is_varlen_q: + self.TileScheduler = SingleTileVarlenScheduler + elif self.is_causal or self.is_local or self.use_clc_scheduler: + self.TileScheduler = SingleTileLPTScheduler + elif self.is_persistent: + self.TileScheduler = StaticPersistentTileScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log(1, f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}, USE_2CTA={self.use_2cta_instrs}") + self.softmax0_warp_ids = (0, 1, 2, 3) self.softmax1_warp_ids = (4, 5, 6, 7) self.correction_warp_ids = (8, 9, 10, 11) @@ -219,6 +275,8 @@ def __init__( elif self.is_varlen_q: # fallback self.epilogue_warp_ids = (13, 14) + self.clc_scheduler_warp_id = self.empty_warp_ids[0] if self.use_clc_scheduler else None + self.tmem_s_offset = [0, self.n_block_size] # e.g., 0, 128 self.tmem_o_offset = [ self.tmem_s_offset[-1] + self.n_block_size + i * self.head_dim_v_padded @@ -551,19 +609,7 @@ def __call__( vO_layout = cute.make_layout((1, async_copy_elems)) gmem_tiled_copy_O = cute.make_tiled_copy_tv(atom_universal_copy, tO_layout, vO_layout) - if const_expr(mCuSeqlensQ is not None or mSeqUsedQ is not None): - TileScheduler = SingleTileVarlenScheduler - else: - if const_expr(self.is_causal or self.is_local): - TileScheduler = SingleTileLPTScheduler - else: - TileScheduler = ( - SingleTileScheduler - if const_expr(not self.is_persistent) - else StaticPersistentTileScheduler - ) - # For non-persistent 2CTA (use_cluster_idx), each cluster covers - # cta_tiler[0] * cta_group_size rows, so num_block must be divided accordingly + TileScheduler = self.TileScheduler _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1) tile_sched_args = TileSchedulerArguments( cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor), @@ -591,7 +637,9 @@ def __call__( cluster_shape_mn=self.cluster_shape_mn, use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, ) - tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, scheduling_mode=self.scheduling_mode + ) self.tile_scheduler_cls = TileScheduler grid_dim = TileScheduler.get_grid_shape(tile_sched_params) @@ -601,6 +649,9 @@ def __call__( cutlass.max(cute.cosize(sQ_layout), cute.cosize(sO_layout) * self.o_dtype.width // self.q_dtype.width) ) + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + @cute.struct class SharedStorage: # m_barriers for pipelines @@ -620,6 +671,13 @@ class SharedStorage: # Smem tensors # store row max and row sum sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + # CLC buffers placed here to utilize padding before sO's 1024-byte alignment. + # This avoids adding bytes at the end when we're at the smem limit. + # PipelineClcFetchAsync expects 2 * sched_stages mbarriers (full + empty). + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + # CLC response storage (16 bytes per stage, stored as 4 Int32s). + clc_response: cute.struct.MemRange[Int32, clc_response_size] + # Large TMA buffers with 1024-byte alignment sO: cute.struct.Align[ cute.struct.MemRange[self.o_dtype, sO_size], self.buffer_align_bytes ] @@ -787,8 +845,8 @@ def kernel( ThreadCooperativeGroup = partial(pipeline.CooperativeGroup, pipeline.Agent.Thread) mma_warp = ThreadCooperativeGroup(len([self.mma_warp_id])) - load_warps = ThreadCooperativeGroup(len(self.load_warp_ids)) tma_warp = ThreadCooperativeGroup(1) + load_threads = ThreadCooperativeGroup(len(self.load_warp_ids) * cute.arch.WARP_SIZE) softmax_warps = ThreadCooperativeGroup(len(self.softmax0_warp_ids)) softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE * len(self.softmax0_warp_ids)) # softmax_threads = ThreadCooperativeGroup(cute.arch.WARP_SIZE) @@ -822,13 +880,10 @@ def kernel( defer_sync=True, ) else: - cpasync_producer_group_q = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE - ) pipeline_q = pipeline_custom.PipelineAsyncUmma.create( barrier_storage=storage.mbar_load_Q.data_ptr(), num_stages=self.q_stage, - producer_group=cpasync_producer_group_q, + producer_group=load_threads, consumer_group=mma_warp, cta_layout_vmnk=cta_layout_vmnk, defer_sync=True, @@ -844,13 +899,10 @@ def kernel( defer_sync=True, ) else: - cpasync_producer_group = pipeline.CooperativeGroup( - pipeline.Agent.Thread, len(self.load_warp_ids) * cute.arch.WARP_SIZE - ) pipeline_kv = pipeline.PipelineAsyncUmma.create( barrier_storage=storage.mbar_load_KV.data_ptr(), num_stages=self.kv_stage, - producer_group=cpasync_producer_group, + producer_group=load_threads, consumer_group=mma_warp, cta_layout_vmnk=cta_layout_vmnk, defer_sync=True, @@ -986,17 +1038,69 @@ def kernel( window_size_right=window_size_right, qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, ) - TileSchedulerCls = partial(self.tile_scheduler_cls.create, tile_sched_params) - # Cluster wait before tensor memory alloc pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread + ) + num_clc_consumer_warps_per_cta = self.threads_per_cta // cute.arch.WARP_SIZE + # NB on CTA0 warp15 == scheduler on CTA1 == empty but still both consume + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = cutlass_pipeline.CooperativeGroup( + cutlass_pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + + block_idx = cute.arch.block_idx() + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + block_idx, + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=cutlass_pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=cutlass_pipeline.make_pipeline_state( + cutlass_pipeline.PipelineUserType.Producer, self.sched_stages + ), + ) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + else: + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + assert isinstance(tile_scheduler, TileSchedulerProtocol), f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + # /////////////////////////////////////////////////////////////////////////////// - # EMPTY + # EMPTY / CLC SCHEDULER WARP # /////////////////////////////////////////////////////////////////////////////// - for i in cutlass.range_constexpr(len(self.empty_warp_ids)): - if warp_idx == self.empty_warp_ids[i]: + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(tile_scheduler) + else: + self.empty_warp(tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) # /////////////////////////////////////////////////////////////////////////////// # LOAD @@ -1022,8 +1126,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1053,8 +1157,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) # Dealloc the tensor memory buffer tmem.relinquish_alloc_permit() @@ -1076,8 +1180,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, mma_tile_coord_v, + tile_scheduler=tile_scheduler, ) # /////////////////////////////////////////////////////////////////////////////// @@ -1109,11 +1213,11 @@ def kernel( num_splits=num_splits, SeqlenInfoCls=SeqlenInfoCls, AttentionMaskCls=AttentionMaskCls, - TileSchedulerCls=TileSchedulerCls, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, head_divmod=head_divmod, blocksparse_tensors=blocksparse_tensors, + tile_scheduler=tile_scheduler, ) if const_expr(not self.s0_s1_barrier): @@ -1157,8 +1261,8 @@ def kernel( block_info, num_splits, SeqlenInfoCls, - TileSchedulerCls, blocksparse_tensors, + tile_scheduler=tile_scheduler, ) tmem_alloc_barrier.arrive() @@ -1185,8 +1289,8 @@ def load( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler: TileSchedulerProtocol, ): num_load_threads = len(self.load_warp_ids) * cute.arch.WARP_SIZE tidx = cute.arch.thread_idx()[0] % num_load_threads @@ -1203,7 +1307,6 @@ def load( kv_producer_state = pipeline.make_pipeline_state( pipeline.PipelineUserType.Producer, self.kv_stage ) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1374,9 +1477,8 @@ def load( self.q_subtile_factor if self.q_subtile_factor is not None else 1, ) - tile_scheduler.prefetch_next_work() - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop if issue_kv_for_this_warp: @@ -1405,8 +1507,8 @@ def mma( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors], + tile_scheduler=None, ): tSrQ = tiled_mma_qk.make_fragment_A(sQ) tSrK = tiled_mma_qk.make_fragment_B(sK) @@ -1495,7 +1597,6 @@ def mma( ) P_full_O_rescaled_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -1666,8 +1767,7 @@ def mma( # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # We don't need pipeline_s_p_o.producer_tail() since there's no dangling mbarrier at the end @@ -1696,11 +1796,11 @@ def softmax_loop( num_splits: Int32, SeqlenInfoCls: Callable, AttentionMaskCls: Callable, - TileSchedulerCls: Callable, aux_tensors: Optional[list] = None, fastdiv_mods=(None, None), head_divmod=None, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): """Compute softmax on attention scores from QK matrix multiplication. @@ -1760,7 +1860,6 @@ def softmax_loop( warp_idx_in_wg = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2003,8 +2102,7 @@ def softmax_loop( # gLSE[tidx] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # This is equivalent to pipeline_sm_stats.producer_tail @@ -2174,8 +2272,8 @@ def correction_loop( block_info: BlockInfo, num_splits: Int32, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, blocksparse_tensors: Optional[BlockSparseTensors] = None, + tile_scheduler=None, ): tidx = cute.arch.thread_idx()[0] % (cute.arch.WARP_SIZE * len(self.correction_warp_ids)) warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % 4 @@ -2205,7 +2303,6 @@ def correction_loop( o_corr_consumer_phase = Int32(0) corr_epi_producer_phase = Int32(1) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2436,8 +2533,7 @@ def correction_loop( cute.make_tensor(lse_gmem_ptr, (1,))[0] = lse # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() # End of persistent scheduler loop # This is equivalent to pipeline_o_epi.consumer_tail() for the correction warps @@ -2644,11 +2740,10 @@ def epilogue_s2g( block_info: BlockInfo, num_splits: int, SeqlenInfoCls: Callable, - TileSchedulerCls: Callable, mma_tile_coord_v: Int32 = 0, + tile_scheduler=None, ): epi_consumer_phase = Int32(0) - tile_scheduler = TileSchedulerCls() work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx @@ -2704,8 +2799,39 @@ def epilogue_s2g( epi_consumer_phase ^= 1 # Advance to next tile - tile_scheduler.advance_to_next_work() - work_tile = tile_scheduler.get_current_work() + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def clc_scheduler_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + tile_scheduler.producer_tail() + + @cute.jit + def empty_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_tile = tile_scheduler.advance_to_next_work() def load_Q( self, diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index f922cb2..872960a 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -64,6 +64,7 @@ from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, + get_sparse_q_block_size, to_cute_block_sparse_tensors, normalize_block_sparse_config, normalize_block_sparse_config_bwd, @@ -124,20 +125,27 @@ class FwdConfig: intra_wg_overlap: bool -def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, use_block_sparsity): +def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, sparse_block_size_q=None): """Return FwdConfig for SM90 forward. Tile sizes and flags based on tile_size_fwd_sm90 in hopper/tile_size.h, adjusted for the Python kernel's different register/smem tradeoffs (benchmarked on H100 SXM). + + When sparse_block_size_q is set, tile_m must divide it. For head_dim <= 96 the + optimal tile_m=192 is used when compatible, otherwise we fall back to 128. """ if head_dim <= 64: # C++: 192×192 non-causal, 192×128 causal/local. # Python: 192×128 RS+OL is consistently best across seqlens. + if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: + return FwdConfig(128, 128, True, True) return FwdConfig(192, 128, True, True) elif head_dim <= 96: # C++: 192×144 noRS+OL for all cases. # Python: RS is catastrophic with 192× tiles (~300 vs ~600 TFLOPS). # noRS+OL is always required. Causal: 192×128 slightly better short seqlen. + if sparse_block_size_q is not None and sparse_block_size_q % 192 != 0: + return FwdConfig(128, 128, False, True) if is_causal or is_local: return FwdConfig(192, 128, False, True) else: @@ -151,7 +159,6 @@ def _tile_size_fwd_sm90(head_dim, head_dim_v, is_causal, is_local, use_block_spa tile_n = 64 if is_local else 80 return FwdConfig(128, tile_n, True, True) - @dataclass(frozen=True) class BwdConfig: m_block_size: int @@ -169,7 +176,7 @@ class BwdConfig: dQ_single_wg: bool = False -def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local): +def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=None): """Return BwdConfig for SM90. Configs based on C++ FA3 hopper/flash_bwd_launch_template.h, @@ -196,6 +203,8 @@ def _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local): # C++ FA3: causal/local: 64, 128; non-causal: 80, 128 with dQ_swapAB is_causal_or_local = causal or local m_block_size = 64 if is_causal_or_local else 80 + if sparse_block_size_q is not None and sparse_block_size_q % m_block_size != 0: + m_block_size = 64 return BwdConfig( m_block_size=m_block_size, n_block_size=128, @@ -448,7 +457,9 @@ def _flash_attn_fwd( causal, window_size_left, window_size_right, mask_mod ) - # In fake mode (CPU-only compilation), use a fake stream placeholder. + requested_use_clc_scheduler = utils._get_use_clc_scheduler_default() + requested_disable_2cta = utils._get_disable_2cta_default() + current_stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) # SM80/SM120: uses SM80 MMA, 128 threads (4 warps) @@ -468,7 +479,8 @@ def _flash_attn_fwd( elif arch // 10 == 8: fwd_cfg = FwdConfig(128, 64, True, True) # SM80, should tune elif arch // 10 == 9: - fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, use_block_sparsity) + sparse_q = get_sparse_q_block_size(block_sparse_tensors, seqlen_q) + fwd_cfg = _tile_size_fwd_sm90(head_dim, head_dim_v, causal, local, sparse_block_size_q=sparse_q) else: fwd_cfg = FwdConfig(tile_mn[0], tile_mn[1], fwd_cfg.mma_pv_is_rs, fwd_cfg.intra_wg_overlap) tile_m, tile_n = fwd_cfg.m_block_size, fwd_cfg.n_block_size @@ -517,6 +529,7 @@ def _flash_attn_fwd( use_2cta_instrs = ( arch // 10 in [10, 11] + and not requested_disable_2cta and not causal and not local and not is_split_kv @@ -621,6 +634,7 @@ def _flash_attn_fwd( q_subtile_factor, mma_pv_is_rs, intra_wg_overlap, + requested_use_clc_scheduler, fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: @@ -728,6 +742,7 @@ def _flash_attn_fwd( is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, q_subtile_factor=q_subtile_factor, use_2cta_instrs=use_2cta_instrs, + use_clc_scheduler=requested_use_clc_scheduler, ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity @@ -994,6 +1009,9 @@ def _flash_attn_bwd( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: arch = _get_device_arch() assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x" + sparse_q = None + if block_sparse_tensors is not None and arch // 10 == 9: + sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128 num_head, head_dim = q.shape[-2:] head_dim_v = v.shape[-1] @@ -1027,7 +1045,13 @@ def _flash_attn_bwd( assert mask_mod is None, "mask_mod backward not supported on SM 12.0" assert deterministic is False, "deterministic backward not supported on SM 12.0" elif arch // 10 == 9: - cfg = _tile_size_bwd_sm90(head_dim, head_dim_v, causal, local) + cfg = _tile_size_bwd_sm90( + head_dim, + head_dim_v, + causal, + local, + sparse_block_size_q=sparse_q, + ) m_block_size = cfg.m_block_size n_block_size = cfg.n_block_size num_stages_Q = cfg.num_stages_Q @@ -1056,10 +1080,13 @@ def _flash_attn_bwd( dKV_swapAB = False AtomLayoutMdQ = 1 AtomLayoutNdKV = 1 + requested_disable_2cta = utils._get_disable_2cta_default() disable_2cta = ( - score_mod is not None + requested_disable_2cta + or score_mod is not None or score_mod_bwd is not None or mask_mod is not None + or block_sparse_tensors is not None ) cluster_size = 2 if head_dim >= 128 and not disable_2cta else 1 use_2cta_instrs = cluster_size==2 @@ -1087,16 +1114,7 @@ def _flash_attn_bwd( num_head_kv = k.shape[-2] use_block_sparsity = block_sparse_tensors is not None - - # SM90 block-sparse backward: tile_m=64 is the GCD between a m_block_size that fits, - # the base block_m of 128 from forward, and block-sparse size for subtiling. - if arch // 10 == 9 and use_block_sparsity: - m_block_size = 64 - # dQ_swapAB tuning: use False when m_block_size=64 (same as causal case) - dQ_swapAB = False - - # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2 - subtile_factor = 2 + subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2 seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size seqlen_k_rounded = (seqlen_k + n_block_size - 1) // n_block_size * n_block_size num_n_blocks = seqlen_k_rounded // n_block_size diff --git a/flash_sparse_attn/ops/cute/pyproject.toml b/flash_sparse_attn/ops/cute/pyproject.toml index 2b0b60b..6ecf64d 100644 --- a/flash_sparse_attn/ops/cute/pyproject.toml +++ b/flash_sparse_attn/ops/cute/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ cu13 = ["nvidia-cutlass-dsl[cu13]>=4.4.2"] dev = [ "pytest", + "pytest-xdist", "ruff", ] diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index c7067ae..3ee4bc8 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, Tri Dao. -from typing import Optional, Tuple +from enum import IntEnum, auto +from typing import Optional, Tuple, Protocol, runtime_checkable from dataclasses import dataclass try: @@ -9,10 +10,12 @@ from typing_extensions import override import cutlass +from cutlass.pipeline import PipelineClcFetchAsync, PipelineState from cutlass._mlir import ir import cutlass.cute as cute from cutlass import Int32, const_expr from cutlass.cute import FastDivmodDivisor +from cutlass.utils import ClcDynamicPersistentTileScheduler, ClcDynamicPersistentTileSchedulerParams from quack.cute_dsl_utils import ParamsBase @@ -20,6 +23,67 @@ from flash_attn.cute.fast_math import clz +class SchedulingMode(IntEnum): + NONE = auto() + STATIC = auto() + DYNAMIC = auto() + CLC = auto() + + +@dataclass +class ClcState(ParamsBase): + """Owns the runtime state shared by CLC-capable tile schedulers. + + `FlashAttentionForwardSm100` constructs this state because it owns the CLC + response buffer, mbarrier storage, and launch geometry needed to initialize + the hardware scheduler and async pipeline. Individual tile schedulers then + consume this state and map the returned hardware work tiles into their own + logical `WorkTileInfo` coordinates. + + To add CLC support to a scheduler: + - implement `clc_problem_shape(params)` so the kernel can create the hardware scheduler + - accept `clc: ClcState | None` in `create(...)` / `__init__` + - map `clc.initial_work_tile_info()` and `clc.get_current_work()` into scheduler coordinates + """ + + _hw_scheduler: ClcDynamicPersistentTileScheduler + _pipeline: PipelineClcFetchAsync + _consumer_state: PipelineState + _producer_state: PipelineState + + @staticmethod + def create( + *, + hw_scheduler: ClcDynamicPersistentTileScheduler, + pipeline: PipelineClcFetchAsync, + consumer_state: PipelineState, + producer_state: PipelineState, + ) -> "ClcState": + return ClcState(hw_scheduler, pipeline, consumer_state, producer_state) + + def initial_work_tile_info(self): + return self._hw_scheduler.initial_work_tile_info() + + def get_current_work(self): + return self._hw_scheduler.get_current_work() + + def prefetch_next_work(self, *, loc=None, ip=None): + self._pipeline.producer_acquire(self._producer_state, loc=loc, ip=ip) + mbarrier_addr = self._pipeline.producer_get_barrier(self._producer_state, loc=loc, ip=ip) + self._hw_scheduler.advance_to_next_work(mbarrier_addr, loc=loc, ip=ip) + self._producer_state.advance(loc=loc, ip=ip) + + def consumer_wait(self, *, loc=None, ip=None): + self._pipeline.consumer_wait(self._consumer_state, loc=loc, ip=ip) + + def consumer_release(self, *, loc=None, ip=None): + self._pipeline.consumer_release(self._consumer_state, loc=loc, ip=ip) + self._consumer_state.advance(loc=loc, ip=ip) + + def producer_tail(self, *, loc=None, ip=None): + self._pipeline.producer_tail(self._producer_state, loc=loc, ip=ip) + + class WorkTileInfo(cutlass.utils.WorkTileInfo): """Altered WorkTileInfo which includes four axes: (block, head, batch, split)""" @@ -31,6 +95,47 @@ def __new_from_mlir_values__(self, values: list[ir.Value]) -> "WorkTileInfo": return WorkTileInfo(new_tile_idx, new_is_valid_tile) +@runtime_checkable +class TileSchedulerProtocol(Protocol): + """Protocol defining the interface all tile schedulers must implement. + + Schedulers are responsible for: + 1. Coordinate mapping: linear tile index -> (m_block, head, batch, split) + 2. Work distribution: how to get the next tile (static grid-stride vs CLC dynamic) + """ + + def get_current_work(self) -> WorkTileInfo: + """Get the current work tile coordinates.""" + ... + + def initial_work_tile_info(self) -> WorkTileInfo: + """Get the initial work tile for this CTA.""" + ... + + def advance_to_next_work(self, *, loc=None, ip=None): + """Consumer-side advance: move to next tile and return it. + + For static schedulers: grid-stride increment + get_current_work. + For CLC schedulers: consumer wait + get_current_work + consumer release + state advance. + """ + ... + + def prefetch_next_work(self, *, loc=None, ip=None) -> None: + """Producer-side prefetch of next work tile (no-op for static schedulers). + + For CLC schedulers: producer acquire + issue CLC query + producer state advance. + Only called by the scheduler warp. + """ + ... + + def producer_tail(self, *, loc=None, ip=None) -> None: + """Producer-side cleanup after the last tile. + + No-op for static schedulers. For CLC schedulers: pipeline producer_tail. + """ + ... + + @dataclass class TileSchedulerArguments(ParamsBase): num_block: Int32 @@ -89,15 +194,25 @@ def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileScheduler": if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): blk_coord = cute.arch.block_idx() else: - # All CTAs in a cluster must get the same block coordinate blk_coord = cute.arch.cluster_idx() return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) @@ -141,6 +256,10 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -186,18 +305,28 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"StaticPersistentTileScheduler only supports STATIC, got {scheduling_mode!r}" + ) return StaticPersistentTileScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "StaticPersistentTileScheduler": + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "StaticPersistentTileScheduler": if const_expr(cute.size(params.cluster_shape_m) == 1): tile_idx = cute.arch.block_idx()[0] else: tile_idx = cute.arch.cluster_idx()[0] return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -207,18 +336,14 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: hardware_info = cutlass.utils.HardwareInfo() sm_count = hardware_info.get_device_multiprocessor_count() - # Grid must be a multiple of cluster_shape_m for CUDA cluster launch. max_ctas = (sm_count // params.cluster_shape_m) * params.cluster_shape_m grid_x = cutlass.min(max_ctas, params.total_blocks_cluster * params.cluster_shape_m) return (grid_x, Int32(1), Int32(1)) - # @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: hn_idx, block_idx = divmod(self._tile_idx, self.params.num_block_cluster_divmod) batch_idx, head_idx = divmod(hn_idx, self.params.num_head_divmod) is_valid = self._tile_idx < self.params.total_blocks_cluster - # if cute.arch.thread_idx()[0] == 0: - # cute.printf("TileScheduler: tile_idx=%d, hn_idx=%d, block_idx=%d, batch_idx=%d, head_idx=%d, is_valid=%d", self._tile_idx, hn_idx, block_idx, batch_idx, head_idx, is_valid) return WorkTileInfo( (Int32(block_idx), Int32(head_idx), Int32(batch_idx), Int32(0)), is_valid ) @@ -234,6 +359,10 @@ def advance_to_next_work(self, *, loc=None, ip=None): self._tile_idx += cute.arch.grid_dim()[0] else: self._tile_idx += cute.arch.cluster_dim()[0] + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + pass def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -260,32 +389,41 @@ class Params(ParamsBase): total_blocks: Int32 num_splits: Int32 num_block: Int32 + num_head: Int32 + num_batch: Int32 l2_minor: Int32 - num_block_divmod: FastDivmodDivisor num_head_divmod: FastDivmodDivisor l2_minor_divmod: FastDivmodDivisor l2_major_divmod: FastDivmodDivisor l2_minor_residual_divmod: FastDivmodDivisor num_hb_quotient: Int32 + num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False + cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC + lpt: cutlass.Constexpr[bool] = True @staticmethod @cute.jit def create( - args: TileSchedulerArguments, *, loc=None, ip=None + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, ) -> "SingleTileLPTScheduler.Params": - # cute.printf(args.num_block, args.num_head, args.num_batch, args.seqlen_k, args.headdim, args.headdim_v, args.total_q, args.tile_shape_mn, args.qhead_per_kvhead_packgqa, args.element_size) + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size size_one_head = size_one_kv_head size_l2 = 50 * 1024 * 1024 # 40 MB for K & V # Swizzle is the size of each "section". Round swizzle to a power of 2 # Need to be careful about the case where only one head will fit # swizzle is how many heads can fit in L2 - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) - # Seems faster if swizzle if a power of 2 + # Seems faster if swizzle is a power of 2 log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) - # swizzle = 1 if size_l2 < size_one_head else (size_l2 // size_one_head) # If we're in the last section (called residual), we don't want to divide by # swizzle. Instead we want to divide by the remainder. num_hb_quotient = (args.num_head * args.num_batch) // swizzle @@ -293,37 +431,84 @@ def create( return SingleTileLPTScheduler.Params( total_blocks=args.num_block * args.num_head * args.num_batch, num_block=args.num_block, + num_head=args.num_head, + num_batch=args.num_batch, l2_minor=Int32(swizzle), - num_block_divmod=FastDivmodDivisor(args.num_block), num_head_divmod=FastDivmodDivisor(args.num_head), l2_minor_divmod=FastDivmodDivisor(swizzle), l2_major_divmod=FastDivmodDivisor(swizzle * args.num_block), - l2_minor_residual_divmod=FastDivmodDivisor( - max(num_hb_remainder, 1) - ), # don't divide by 0 + l2_minor_residual_divmod=FastDivmodDivisor(max(num_hb_remainder, 1)), num_hb_quotient=Int32(num_hb_quotient), num_splits=args.num_splits, + num_splits_divmod=FastDivmodDivisor(args.num_splits), is_split_kv=args.is_split_kv, + cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, + lpt=args.lpt, ) - def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): self.params = params self._tile_idx = tile_idx self._split_idx = split_idx + self.clc = clc self._loc = loc self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: - return SingleTileLPTScheduler.Params.create(args, loc=loc, ip=ip) + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileLPTScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) + + @staticmethod + def _clc_grid_shape(params: Params): + num_batch_splits = ( + params.num_batch * params.num_splits + if const_expr(params.is_split_kv) + else params.num_batch + ) + return ( + cute.round_up(params.num_block, params.cluster_shape_m), + params.num_head, + num_batch_splits, + ) @staticmethod @cute.jit - def create(params: Params, *, loc=None, ip=None) -> "SingleTileLPTScheduler": + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileLPTScheduler._clc_grid_shape(params), + cluster_shape_mnk=(params.cluster_shape_m, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileLPTScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler( + params, cute.arch.block_idx()[0], Int32(0), clc, loc=loc, ip=ip + ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileLPTScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) - # called by host @staticmethod def get_grid_shape( params: Params, @@ -331,10 +516,40 @@ def get_grid_shape( loc=None, ip=None, ) -> Tuple[Int32, Int32, Int32]: + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + return SingleTileLPTScheduler._clc_grid_shape(params) return (params.total_blocks, params.num_splits, Int32(1)) + @cute.jit + def clc_work_to_coords(self, work) -> WorkTileInfo: + """Convert CLC response (block, head, batch_split) to WorkTileInfo. + + CLC returns raw grid coordinates — no L2 swizzle (hardware decides order). + We only apply cluster division, optional LPT block reversal, and split_kv unpacking. + """ + block_idx = work.tile_idx[0] + if const_expr(self.params.cluster_shape_m > 1): + block_idx = block_idx // self.params.cluster_shape_m + if const_expr(self.params.lpt): + # Longest-processing-time-first: reverse block order + block_idx = self.params.num_block - 1 - block_idx + split_idx = Int32(0) + if const_expr(self.params.is_split_kv): + batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) + else: + batch_idx = work.tile_idx[2] + return WorkTileInfo( + (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), + work.is_valid_tile, + ) + @cute.jit def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.get_current_work() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) + # Static path: L2-swizzled coordinate mapping params = self.params # Implement LPT scheduling coordinate calculation bidhb, l2_mod = divmod(self._tile_idx, params.l2_major_divmod) @@ -348,25 +563,45 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: bidhb_actual = bidhb * params.l2_minor + bidhb_residual batch_idx, head_idx = divmod(bidhb_actual, params.num_head_divmod) # Longest-processing-time-first - block = params.num_block - 1 - block + if const_expr(params.lpt): + block = params.num_block - 1 - block is_valid = self._tile_idx < params.total_blocks return WorkTileInfo( (Int32(block), Int32(head_idx), Int32(batch_idx), Int32(self._split_idx)), is_valid ) + @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + work = self.clc.initial_work_tile_info() + self._tile_idx = work.tile_idx[0] + return self.clc_work_to_coords(work) return self.get_current_work(loc=loc, ip=ip) def prefetch_next_work(self, *, loc=None, ip=None): - pass + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx, self._split_idx]: + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -374,10 +609,13 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip([self.params, self._tile_idx, self._split_idx], self._values_pos): + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return self.__class__(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*obj_list, loc=self._loc) class SingleTileLPTBwdScheduler: @@ -436,7 +674,16 @@ def __init__(self, params: Params, tile_idx: Int32, *, loc=None, ip=None): self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + assert scheduling_mode == SchedulingMode.STATIC, ( + f"SingleTileLPTBwdScheduler only supports STATIC, got {scheduling_mode!r}" + ) return SingleTileLPTBwdScheduler.Params.create(args, loc=loc, ip=ip) @staticmethod @@ -487,6 +734,7 @@ def prefetch_next_work(self, *, loc=None, ip=None): def advance_to_next_work(self, *, loc=None, ip=None): # Single tile scheduler - set to invalid tile_idx to indicate no more work self._tile_idx = self.params.total_blocks + return self.get_current_work() def __extract_mlir_values__(self): values, self._values_pos = [], [] @@ -520,15 +768,25 @@ class Params(ParamsBase): is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False cluster_shape_m: cutlass.Constexpr[int] = 1 + scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC @staticmethod @cute.jit def create( - args: TileSchedulerArguments, *, loc=None, ip=None + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, ) -> "SingleTileVarlenScheduler.Params": + assert scheduling_mode in (SchedulingMode.STATIC, SchedulingMode.CLC), ( + f"Only STATIC and CLC are supported, got {scheduling_mode!r}" + ) size_l2 = 50 * 1024 * 1024 # 50 MB for K & V # if backward, this is qdo block size - kv_block_size = (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + kv_block_size = ( + (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + ) # if backward, add dqaccum block size to calculate swizzle if args.head_swizzle: kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] @@ -537,6 +795,11 @@ def create( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) assert args.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + # TODO: Support varlen CLC with cluster_shape_m > 1 by refactoring the + # flattened-tile decode so cluster unpacking semantics are explicit. + assert scheduling_mode != SchedulingMode.CLC or args.cluster_shape_mn[0] == 1, ( + "Varlen CLC currently requires cluster_shape_mn[0] == 1" + ) return SingleTileVarlenScheduler.Params( num_head=args.num_head, num_batch=args.num_batch, @@ -551,22 +814,65 @@ def create( is_split_kv=args.is_split_kv, head_swizzle=args.head_swizzle, cluster_shape_m=args.cluster_shape_mn[0], + scheduling_mode=scheduling_mode, ) - def __init__(self, params: Params, tile_idx: Int32, split_idx: Int32, *, loc=None, ip=None): + def __init__( + self, + params: Params, + tile_idx: Int32, + split_idx: Int32, + clc: ClcState | None = None, + *, + loc=None, + ip=None, + ): self.params = params self._tile_idx = tile_idx self._split_idx = split_idx self._is_first_block = True + self.clc = clc self._loc = loc self._ip = ip @staticmethod - def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: - return SingleTileVarlenScheduler.Params.create(args, loc=loc, ip=ip) + def to_underlying_arguments( + args: TileSchedulerArguments, + *, + scheduling_mode: SchedulingMode = SchedulingMode.STATIC, + loc=None, + ip=None, + ) -> Params: + return SingleTileVarlenScheduler.Params.create( + args, scheduling_mode=scheduling_mode, loc=loc, ip=ip + ) @staticmethod - def create(params: Params, *, loc=None, ip=None) -> "SingleTileVarlenScheduler": + @cute.jit + def clc_problem_shape(params: Params): + return ClcDynamicPersistentTileSchedulerParams( + problem_shape_ntile_mnl=SingleTileVarlenScheduler.get_grid_shape(params), + cluster_shape_mnk=(1, 1, 1), + ) + + @staticmethod + @cute.jit + def create( + params: Params, clc: ClcState | None = None, *, loc=None, ip=None + ) -> "SingleTileVarlenScheduler": + if const_expr(params.scheduling_mode == SchedulingMode.CLC): + block_idx = cute.arch.block_idx() + split_idx = Int32(0) + if const_expr(params.is_split_kv): + split_idx = block_idx[1] + return SingleTileVarlenScheduler( + params, + block_idx[0], + split_idx, + clc, + loc=loc, + ip=ip, + ) tile_idx, split_idx, _ = cute.arch.block_idx() return SingleTileVarlenScheduler(params, tile_idx, split_idx, loc=loc, ip=ip) @@ -582,7 +888,7 @@ def get_grid_shape( params.total_q + params.num_batch * (params.cluster_shape_m * params.tile_shape_mn[0] - 1) ) // params.tile_shape_mn[0] - # round down to nearest multiple of cluster since odd excess is always padding + # Round down to nearest multiple of cluster since odd excess is always padding. total_blocks_max = total_blocks_max // params.cluster_shape_m * params.cluster_shape_m return (total_blocks_max * params.num_head, params.num_splits, Int32(1)) @@ -610,7 +916,8 @@ def _get_num_m_blocks(self, lane: Int32, bidb_start: Int32) -> Int32: ) @cute.jit - def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + def _varlen_coord_map(self) -> WorkTileInfo: + """Map self._tile_idx to (block, head, batch) via warp-level prefix sums.""" params = self.params lane_idx = cute.arch.lane_idx() num_m_blocks = self._get_num_m_blocks(lane_idx, bidb_start=0) @@ -708,19 +1015,62 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: split_idx = self._split_idx if const_expr(params.is_split_kv) else Int32(0) return WorkTileInfo((Int32(block), Int32(head_idx), Int32(batch_idx), split_idx), is_valid) + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.get_current_work() + # Default to grid_dim (one past last valid flat index) so _varlen_coord_map + # returns is_valid=False when CLC is exhausted. CLC tile_idx is garbage when + # invalid, so we can't trust it. Local-then-assign avoids CuTe DSL structural + # mismatch on self inside the runtime if. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() + + @cute.jit def initial_work_tile_info(self, *, loc=None, ip=None): - return self.get_current_work(loc=loc, ip=ip) + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + clc_work = self.clc.initial_work_tile_info() + # See get_current_work for why grid_dim and local-then-assign. + new_tile_idx = cute.arch.grid_dim()[0] + new_split_idx = Int32(0) + if clc_work.is_valid_tile: + new_tile_idx = clc_work.tile_idx[0] + if const_expr(self.params.is_split_kv): + new_split_idx = clc_work.tile_idx[1] + self._tile_idx = new_tile_idx + self._split_idx = new_split_idx + return self._varlen_coord_map() def prefetch_next_work(self, *, loc=None, ip=None): - pass + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.prefetch_next_work(loc=loc, ip=ip) def advance_to_next_work(self, *, loc=None, ip=None): - # Single tile scheduler - set to invalid tile_idx to indicate no more work + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.consumer_wait(loc=loc, ip=ip) + work = self.get_current_work() + self.clc.consumer_release(loc=loc, ip=ip) + return work self._is_first_block = False + return self.get_current_work() + + def producer_tail(self, *, loc=None, ip=None): + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + self.clc.producer_tail(loc=loc, ip=ip) def __extract_mlir_values__(self): values, self._values_pos = [], [] - for obj in [self.params, self._tile_idx, self._split_idx]: + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj in objs: obj_values = cutlass.extract_mlir_values(obj) values += obj_values self._values_pos.append(len(obj_values)) @@ -728,10 +1078,10 @@ def __extract_mlir_values__(self): def __new_from_mlir_values__(self, values): obj_list = [] - for obj, n_items in zip( - [self.params, self._tile_idx, self._split_idx], - self._values_pos, - ): + objs = [self.params, self._tile_idx, self._split_idx] + if const_expr(self.params.scheduling_mode == SchedulingMode.CLC): + objs += [self.clc] + for obj, n_items in zip(objs, self._values_pos): obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) values = values[n_items:] - return SingleTileVarlenScheduler(*(tuple(obj_list)), loc=self._loc) + return self.__class__(*obj_list, loc=self._loc) diff --git a/flash_sparse_attn/ops/cute/utils.py b/flash_sparse_attn/ops/cute/utils.py index 2d8767c..3118661 100644 --- a/flash_sparse_attn/ops/cute/utils.py +++ b/flash_sparse_attn/ops/cute/utils.py @@ -3,12 +3,13 @@ import math import hashlib import inspect +import os from typing import Type, Callable, Optional, Tuple, overload import cutlass import cutlass.cute as cute -from cutlass import Float32, const_expr +from cutlass import Float32, Int32, const_expr from cutlass.cute import FastDivmodDivisor from cutlass.cutlass_dsl import T, dsl_user_op from cutlass._mlir.dialects import nvvm, llvm @@ -55,6 +56,17 @@ ), } +_fa_clc_enabled: bool = os.environ.get("FA_CLC", "0") == "1" +_fa_disable_2cta_enabled: bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" + + +def _get_use_clc_scheduler_default() -> bool: + return _fa_clc_enabled + + +def _get_disable_2cta_default() -> bool: + return _fa_disable_2cta_enabled + def _compute_base_hash(func: Callable) -> str: """Compute hash from source code or bytecode and closure values.""" @@ -250,6 +262,21 @@ def warp_reduce( return val +@dsl_user_op +def smid(*, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [], + "mov.u32 $0, %smid;", + "=r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op def fmax( a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None From 253f7e654441cc3e6be4d7bccbd318c14c944628 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 7 Apr 2026 15:27:15 +0000 Subject: [PATCH 3/5] Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute --- flash_sparse_attn/ops/cute/__init__.py | 4 +- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- flash_sparse_attn/ops/cute/cache_utils.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 +++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 ++++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 ++++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 34 ++++++++--------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 +++++++------- flash_sparse_attn/ops/cute/interface.py | 38 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- 24 files changed, 127 insertions(+), 127 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index 1b84363..01de305 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("fa4") + __version__ = version("flash-sparse-attn") except PackageNotFoundError: __version__ = "0.0.0" @@ -14,7 +14,7 @@ flash_attn_varlen_func, ) -from flash_attn.cute.cute_dsl_utils import cute_compile_patched +from flash_sparse_attn.ops.cute.cute_dsl_utils import cute_compile_patched # Patch cute.compile to optionally dump SASS cute.compile = cute_compile_patched diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index 7207780..cdec2e5 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_attn.cute.mma_sm100_desc as sm100_desc +import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc @cute.jit diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index f210138..cebd0bf 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index 52cb7e0..de8c6f0 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index 3fad8c9..9e34734 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index f1b5970..dc04970 100644 --- a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -17,7 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction -from flash_attn.cute.fa_logging import fa_log +from flash_sparse_attn.ops.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index a2dd98e..d986ecb 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index 824abdd..d4f1db6 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 76c8562..36e4e8d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index d93ea5c..698b19f 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from quack import copy_utils, layout_utils -from flash_attn.cute import utils -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index e06cd81..58d3b41 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import copy_utils -from flash_attn.cute import pipeline -from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import copy_utils +from flash_sparse_attn.ops.cute import pipeline +from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 556c59e..0941ae2 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index f724b5a..7980c2d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute import pipeline +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwd -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index 4d47fab..43926f7 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.pack_gqa import PackGQA -from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1190,6 +1190,6 @@ def load_K_next(): # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def __getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 4936202..7f38c2b 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 6e4fdbf..de73573 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -33,29 +33,29 @@ from quack import copy_utils, layout_utils -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -import flash_attn.cute.pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +import flash_sparse_attn.ops.cute.pipeline as pipeline_custom import cutlass.pipeline as cutlass_pipeline -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_attn.cute import mma_sm100_desc as sm100_desc -from flash_attn.cute import blackwell_helpers as sm100_utils -from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc +from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -65,8 +65,8 @@ SingleTileLPTScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute.fa_logging import fa_log, fa_printf -from flash_attn.cute.utils import smid +from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf +from flash_sparse_attn.ops.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index 08d219a..f9917d8 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 4108ce4..8ae0e77 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( produce_block_sparse_loads, consume_block_sparse_loads, ) -from flash_attn.cute import pipeline as pipeline_custom -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute import pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_attn.cute.flash_fwd import FlashAttentionForwardBase +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index 872960a..b5f046a 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -34,35 +34,35 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_attn.cute.cache_utils import get_jit_cache -from flash_attn.cute.testing import is_fake_mode +from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache +from flash_sparse_attn.ops.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_attn.cute import cute_dsl_ptxas # noqa: F401 + from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_attn.cute import utils -from flash_attn.cute import fa_logging -from flash_attn.cute.cute_dsl_utils import ( +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute import fa_logging +from flash_sparse_attn.ops.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 -from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 -from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 -from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine - -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine + +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index 6b5ca16..ba53d13 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr from quack import layout_utils -import flash_attn.cute.utils as utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index e87df01..484834f 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index bf11acb..f84b35c 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_attn.cute import utils +from flash_sparse_attn.ops.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index eed55a0..0eaa479 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32 from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index 3ee4bc8..9c15dd6 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -19,8 +19,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_attn.cute.utils as utils -from flash_attn.cute.fast_math import clz +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.fast_math import clz class SchedulingMode(IntEnum): From b59d7d30388b771a639b1c9185c1e95f6e1aedfd Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 7 Apr 2026 15:29:36 +0000 Subject: [PATCH 4/5] Refactor rewrite_cute_namespace.py to support dynamic namespace rewriting based on direction --- scripts/rewrite_cute_namespace.py | 52 ++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 5 deletions(-) diff --git a/scripts/rewrite_cute_namespace.py b/scripts/rewrite_cute_namespace.py index 5bbef92..5fe13bd 100644 --- a/scripts/rewrite_cute_namespace.py +++ b/scripts/rewrite_cute_namespace.py @@ -4,13 +4,29 @@ from pathlib import Path -def rewrite_python_files(target_dir: Path) -> int: +LOCAL_NAMESPACE = "flash_sparse_attn.ops.cute" +UPSTREAM_NAMESPACE = "flash_attn.cute" +LOCAL_PACKAGE_NAME = "flash-sparse-attn" +UPSTREAM_PACKAGE_NAME = "fa4" + + +def rewrite_python_files( + target_dir: Path, + *, + source_namespace: str, + destination_namespace: str, + source_package_name: str, + destination_package_name: str, +) -> int: changed = 0 for path in sorted(target_dir.rglob("*.py")): original = path.read_text(encoding="utf-8") - updated = original.replace("flash_attn.cute", "flash_sparse_attn.ops.cute") + updated = original.replace(source_namespace, destination_namespace) if path.name == "__init__.py": - updated = updated.replace('version("fa4")', 'version("flash-sparse-attn")') + updated = updated.replace( + f'version("{source_package_name}")', + f'version("{destination_package_name}")', + ) if updated != original: path.write_text(updated, encoding="utf-8") changed += 1 @@ -19,6 +35,12 @@ def rewrite_python_files(target_dir: Path) -> int: def main() -> int: parser = argparse.ArgumentParser(description="Rewrite vendored CuTe imports to the local package namespace.") + parser.add_argument( + "--direction", + choices=("local", "upstream"), + default="local", + help="Rewrite imports to the local namespace or back to the upstream namespace.", + ) parser.add_argument( "target_dir", nargs="?", @@ -31,8 +53,28 @@ def main() -> int: if not target_dir.exists(): raise SystemExit(f"Target directory does not exist: {target_dir}") - changed = rewrite_python_files(target_dir) - print(f"Rewrote CuTe namespace in {changed} file(s) under {target_dir}") + if args.direction == "local": + source_namespace = UPSTREAM_NAMESPACE + destination_namespace = LOCAL_NAMESPACE + source_package_name = UPSTREAM_PACKAGE_NAME + destination_package_name = LOCAL_PACKAGE_NAME + else: + source_namespace = LOCAL_NAMESPACE + destination_namespace = UPSTREAM_NAMESPACE + source_package_name = LOCAL_PACKAGE_NAME + destination_package_name = UPSTREAM_PACKAGE_NAME + + changed = rewrite_python_files( + target_dir, + source_namespace=source_namespace, + destination_namespace=destination_namespace, + source_package_name=source_package_name, + destination_package_name=destination_package_name, + ) + print( + f"Rewrote CuTe namespace from {source_namespace} to {destination_namespace} " + f"in {changed} file(s) under {target_dir}" + ) return 0 From 06e243ff9610affb9216bb384419d188ea8d56c5 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Tue, 7 Apr 2026 15:29:47 +0000 Subject: [PATCH 5/5] Enhance sync scripts with submodule initialization and namespace rewrite functions --- scripts/sync_cute_subtree.ps1 | 212 ++++++++++++++++++++++++++++------ scripts/sync_cute_subtree.sh | 110 +++++++++++++----- 2 files changed, 258 insertions(+), 64 deletions(-) diff --git a/scripts/sync_cute_subtree.ps1 b/scripts/sync_cute_subtree.ps1 index 1d5ac0b..8d23d47 100644 --- a/scripts/sync_cute_subtree.ps1 +++ b/scripts/sync_cute_subtree.ps1 @@ -13,6 +13,7 @@ param( Set-StrictMode -Version Latest $ErrorActionPreference = "Stop" +$PrepareMergeCommitMessage = "Rewrite vendored CuTe namespace to flash_attn.cute before subtree merge" $RewriteCommitMessage = "Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute" function Test-GitRemoteSpec { @@ -33,10 +34,14 @@ function Invoke-Git { ) if ($Repo) { - & git -C $Repo @Arguments + $output = & git -C $Repo @Arguments } else { - & git @Arguments + $output = & git @Arguments + } + + if ($output) { + $output | ForEach-Object { Write-Host $_ } } if ($LASTEXITCODE -ne 0) { @@ -110,6 +115,81 @@ function Get-DirtyStatus { return Get-GitOutput -Repo $Repo -Arguments @("status", "--porcelain") } +function Get-GitCommonDir { + param( + [Parameter(Mandatory = $true)] + [string]$Repo + ) + + $commonDir = (& git -C $Repo rev-parse --path-format=absolute --git-common-dir) 2>$null + if ($LASTEXITCODE -eq 0 -and $commonDir) { + return (($commonDir | Out-String).Trim()) + } + + $commonDir = Get-GitOutput -Repo $Repo -Arguments @("rev-parse", "--git-common-dir") + if (-not $commonDir) { + return $null + } + + if (-not [System.IO.Path]::IsPathRooted($commonDir)) { + return [System.IO.Path]::GetFullPath((Join-Path $Repo $commonDir)) + } + + return $commonDir +} + +function Get-SubmoduleStatusLine { + param( + [Parameter(Mandatory = $true)] + [string]$Repo, + [Parameter(Mandatory = $true)] + [string]$SubmodulePath + ) + + $output = (& git -C $Repo submodule status -- $SubmodulePath) 2>$null + if ($LASTEXITCODE -ne 0 -or -not $output) { + return $null + } + + return (($output | Select-Object -First 1 | Out-String).Trim()) +} + +function Ensure-SubmoduleInitialized { + param( + [Parameter(Mandatory = $true)] + [string]$Repo, + [Parameter(Mandatory = $true)] + [string]$SubmodulePath, + [Parameter(Mandatory = $true)] + [string]$Label + ) + + $statusLine = Get-SubmoduleStatusLine -Repo $Repo -SubmodulePath $SubmodulePath + if (-not $statusLine) { + return + } + + if ($statusLine[0] -eq '-') { + $commonDir = Get-GitCommonDir -Repo $Repo + $referenceRepo = $null + if ($commonDir) { + $referenceRepo = Join-Path $commonDir "modules" + foreach ($segment in $SubmodulePath -split '/') { + $referenceRepo = Join-Path $referenceRepo $segment + } + } + + if ($referenceRepo -and (Test-Path $referenceRepo)) { + Write-Host "Initializing $Label in $Repo from local git cache ..." + Invoke-Git -Repo $Repo -Arguments @("submodule", "update", "--init", "--reference", $referenceRepo, "--", $SubmodulePath) + } + else { + Write-Host "Initializing $Label in $Repo ..." + Invoke-Git -Repo $Repo -Arguments @("submodule", "update", "--init", "--", $SubmodulePath) + } + } +} + function Test-IsGitRepo { param( [Parameter(Mandatory = $true)] @@ -249,6 +329,76 @@ function Assert-SyncContainsUpstream { throw ($message -join [Environment]::NewLine) } +function Get-SyncResultProperty { + param( + [Parameter(Mandatory = $false)] + [object]$SyncResult, + [Parameter(Mandatory = $true)] + [string]$PropertyName + ) + + if ($null -eq $SyncResult) { + return $null + } + + if ($SyncResult -is [System.Array]) { + foreach ($item in $SyncResult) { + if ($null -ne $item -and $item.PSObject.Properties.Match($PropertyName).Count -gt 0) { + return $item.$PropertyName + } + } + + return $null + } + + if ($SyncResult.PSObject.Properties.Match($PropertyName).Count -gt 0) { + return $SyncResult.$PropertyName + } + + return $null +} + +function Invoke-NamespaceRewrite { + param( + [Parameter(Mandatory = $true)] + [string]$RewriteScript, + [Parameter(Mandatory = $true)] + [string]$TargetPath, + [Parameter(Mandatory = $true)] + [ValidateSet("local", "upstream")] + [string]$Direction + ) + + $rewriteOutput = & python $RewriteScript --direction $Direction $TargetPath + if ($rewriteOutput) { + $rewriteOutput | ForEach-Object { Write-Host $_ } + } + if ($LASTEXITCODE -ne 0) { + throw "python $RewriteScript --direction $Direction $TargetPath failed with exit code $LASTEXITCODE" + } +} + +function Commit-PrefixIfChanged { + param( + [Parameter(Mandatory = $true)] + [string]$Repo, + [Parameter(Mandatory = $true)] + [string]$Prefix, + [Parameter(Mandatory = $true)] + [string]$Message + ) + + $prefixStatus = Get-GitOutput -Repo $Repo -Arguments @("status", "--porcelain", "--", $Prefix) + if (-not $prefixStatus) { + return $false + } + + Ensure-GitIdentity -Repo $Repo + Invoke-Git -Repo $Repo -Arguments @("add", "--", $Prefix) + Invoke-Git -Repo $Repo -Arguments @("commit", "-m", $Message) + return $true +} + function Invoke-CoreSync { param( [Parameter(Mandatory = $true)] @@ -271,13 +421,15 @@ function Invoke-CoreSync { ) $cutlassRepo = Join-Path $WorkRepoRoot "csrc/cutlass" + $cutlassSubmodulePath = "csrc/cutlass" $targetPath = Join-Path $WorkRepoRoot $Prefix $startHead = Get-GitOutput -Repo $WorkRepoRoot -Arguments @("rev-parse", "HEAD") $localSplitBefore = $null - Invoke-Git -Repo $WorkRepoRoot -Arguments @("rev-parse", "--show-toplevel") | Out-Null - Invoke-Git -Repo $UpstreamRepoForSplit -Arguments @("rev-parse", "--show-toplevel") | Out-Null + Get-GitOutput -Repo $WorkRepoRoot -Arguments @("rev-parse", "--show-toplevel") | Out-Null + Get-GitOutput -Repo $UpstreamRepoForSplit -Arguments @("rev-parse", "--show-toplevel") | Out-Null + Ensure-SubmoduleInitialized -Repo $WorkRepoRoot -SubmodulePath $cutlassSubmodulePath -Label "csrc/cutlass submodule" Test-WorktreeClean -Repo $WorkRepoRoot -Label "Superproject" if ((Test-Path $cutlassRepo) -and (Test-IsGitRepo -Repo $cutlassRepo)) { Test-WorktreeClean -Repo $cutlassRepo -Label "csrc/cutlass submodule" @@ -310,8 +462,13 @@ function Invoke-CoreSync { throw "$Prefix does not exist yet. Run this script once with -Init first." } + Write-Host "Rewriting vendored CuTe imports to flash_attn.cute before subtree merge ..." + Invoke-NamespaceRewrite -RewriteScript $RewriteScript -TargetPath $targetPath -Direction upstream + [void](Commit-PrefixIfChanged -Repo $WorkRepoRoot -Prefix $Prefix -Message $PrepareMergeCommitMessage) + Write-Host "Pulling upstream updates into $Prefix ..." - Invoke-GitNoMergeEdit -Repo $WorkRepoRoot -Arguments @("subtree", "pull", "--prefix=$Prefix", $UpstreamRepoForSplit, $TempBranch) + Invoke-Git -Repo $WorkRepoRoot -Arguments @("fetch", $UpstreamRepoForSplit, $TempBranch) + Invoke-GitNoMergeEdit -Repo $WorkRepoRoot -Arguments @("merge", "-X", "theirs", "-Xsubtree=$Prefix", "FETCH_HEAD") } } finally { @@ -321,17 +478,8 @@ function Invoke-CoreSync { } Write-Host "Rewriting vendored CuTe imports to flash_sparse_attn.ops.cute ..." - & python $RewriteScript $targetPath - if ($LASTEXITCODE -ne 0) { - throw "python $RewriteScript $targetPath failed with exit code $LASTEXITCODE" - } - - $prefixStatus = Get-GitOutput -Repo $WorkRepoRoot -Arguments @("status", "--porcelain", "--", $Prefix) - if ($prefixStatus) { - Ensure-GitIdentity -Repo $WorkRepoRoot - Invoke-Git -Repo $WorkRepoRoot -Arguments @("add", "--", $Prefix) - Invoke-Git -Repo $WorkRepoRoot -Arguments @("commit", "-m", "Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute") - } + Invoke-NamespaceRewrite -RewriteScript $RewriteScript -TargetPath $targetPath -Direction local + [void](Commit-PrefixIfChanged -Repo $WorkRepoRoot -Prefix $Prefix -Message $RewriteCommitMessage) $localSplitAfter = Get-PrefixSplitCommit -Repo $WorkRepoRoot -Prefix $Prefix Assert-SyncContainsUpstream -Repo $WorkRepoRoot -UpstreamSplitCommit $splitCommit -LocalSplitCommit $localSplitAfter -PreviousLocalSplitCommit $localSplitBefore @@ -390,25 +538,14 @@ function Invoke-TemporaryWorktreeSync { } $currentStatus = Get-DirtyStatus -Repo $RepoRoot - $currentPrefixStatus = Get-GitOutput -Repo $RepoRoot -Arguments @("status", "--porcelain", "--", $Prefix) if ($currentStatus) { Write-Host "Stashing current worktree before cherry-picking synced commits back ..." Invoke-Git -Repo $RepoRoot -Arguments @("stash", "push", "-u", "-m", $stashName) $stashCreated = $true } - $commitsToCherryPick = @() - $applyRewriteAfterRestore = $false - foreach ($commit in $commits) { - if ($currentPrefixStatus -and (Get-CommitSubject -Repo $tempWorktree -Commit $commit) -eq $RewriteCommitMessage) { - $applyRewriteAfterRestore = $true - continue - } - $commitsToCherryPick += $commit - } - try { - foreach ($commit in $commitsToCherryPick) { + foreach ($commit in $commits) { Write-Host "Cherry-picking $commit back into current worktree ..." Ensure-GitIdentity -Repo $RepoRoot if ((Get-CommitParentCount -Repo $tempWorktree -Commit $commit) -gt 1) { @@ -433,15 +570,6 @@ function Invoke-TemporaryWorktreeSync { } } - if ($applyRewriteAfterRestore) { - Write-Host "Applying CuTe namespace rewrite in current worktree after restoring local changes ..." - & python $RewriteScript (Join-Path $RepoRoot $Prefix) - if ($LASTEXITCODE -ne 0) { - throw "python $RewriteScript $(Join-Path $RepoRoot $Prefix) failed with exit code $LASTEXITCODE" - } - Write-Host "Namespace rewrite was applied in the current worktree without creating an extra commit because local changes already exist under $Prefix." - } - return $result } finally { @@ -493,5 +621,13 @@ else { Write-Host "Done." Write-Host "Upstream source: $UpstreamRepo" Write-Host "Upstream cache used for subtree split: $upstreamRepoForSplit" -Write-Host "Synced commit range: $($syncResult.StartHead) -> $($syncResult.EndHead)" +$syncStartHead = Get-SyncResultProperty -SyncResult $syncResult -PropertyName "StartHead" +$syncEndHead = Get-SyncResultProperty -SyncResult $syncResult -PropertyName "EndHead" +if (-not $syncStartHead) { + $syncStartHead = Get-GitOutput -Repo $repoRoot -Arguments @("rev-parse", "HEAD") +} +if (-not $syncEndHead) { + $syncEndHead = Get-GitOutput -Repo $repoRoot -Arguments @("rev-parse", "HEAD") +} +Write-Host "Synced commit range: $syncStartHead -> $syncEndHead" Write-Host "Local edits inside $Prefix stay in this repo and future upstream changes can be merged by rerunning this script without -Init." diff --git a/scripts/sync_cute_subtree.sh b/scripts/sync_cute_subtree.sh index cc35e24..2188d9c 100644 --- a/scripts/sync_cute_subtree.sh +++ b/scripts/sync_cute_subtree.sh @@ -13,6 +13,7 @@ KEEP_TEMP_BRANCH=0 NO_TEMPORARY_WORKTREE=0 TEMP_WORKTREE_PATH="" TEMP_WORKTREE_BRANCH="" +PREPARE_MERGE_COMMIT_MESSAGE="Rewrite vendored CuTe namespace to flash_attn.cute before subtree merge" REWRITE_COMMIT_MESSAGE="Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute" UPSTREAM_SPLIT_REF="HEAD" @@ -80,6 +81,56 @@ dirty_status() { git -C "$repo" status --porcelain } +get_git_common_dir() { + local repo="$1" + local common_dir + + common_dir="$(git -C "$repo" rev-parse --path-format=absolute --git-common-dir 2>/dev/null | tail -n 1 | tr -d '\r')" + if [[ -z "$common_dir" ]]; then + common_dir="$(git -C "$repo" rev-parse --git-common-dir | tail -n 1 | tr -d '\r')" + if [[ -n "$common_dir" && ! "$common_dir" =~ ^([A-Za-z]:[\\/]|/) ]]; then + common_dir="$(cd "$repo" && cd "$common_dir" && pwd)" + fi + fi + + printf '%s\n' "$common_dir" +} + +get_submodule_status_line() { + local repo="$1" + local submodule_path="$2" + git -C "$repo" submodule status -- "$submodule_path" 2>/dev/null | head -n 1 | tr -d '\r' +} + +ensure_submodule_initialized() { + local repo="$1" + local submodule_path="$2" + local label="$3" + local status_line status_prefix common_dir reference_repo + + status_line="$(get_submodule_status_line "$repo" "$submodule_path")" + if [[ -z "$status_line" ]]; then + return + fi + + status_prefix="${status_line:0:1}" + if [[ "$status_prefix" == "-" ]]; then + common_dir="$(get_git_common_dir "$repo")" + reference_repo="" + if [[ -n "$common_dir" ]]; then + reference_repo="$common_dir/modules/$submodule_path" + fi + + if [[ -n "$reference_repo" && -d "$reference_repo" ]]; then + echo "Initializing $label in $repo from local git cache ..." + invoke_git -C "$repo" submodule update --init --reference "$reference_repo" -- "$submodule_path" + else + echo "Initializing $label in $repo ..." + invoke_git -C "$repo" submodule update --init -- "$submodule_path" + fi + fi +} + is_git_repo() { local repo="$1" git -C "$repo" rev-parse --show-toplevel >/dev/null 2>&1 @@ -192,9 +243,30 @@ cleanup_worktree() { fi } +run_namespace_rewrite() { + local target_path="$1" + local direction="$2" + python "$REWRITE_SCRIPT" --direction "$direction" "$target_path" +} + +commit_prefix_if_changed() { + local repo="$1" + local prefix="$2" + local message="$3" + + if [[ -z "$(git -C "$repo" status --porcelain -- "$prefix")" ]]; then + return 1 + fi + + ensure_git_identity "$repo" + invoke_git -C "$repo" add -- "$prefix" + invoke_git -C "$repo" commit -m "$message" +} + invoke_core_sync() { local work_repo_root="$1" local cutlass_repo="$work_repo_root/csrc/cutlass" + local cutlass_submodule_path="csrc/cutlass" local target_path="$work_repo_root/$PREFIX" local start_head local_split_before local_split_after start_head="$(git_output -C "$work_repo_root" rev-parse HEAD)" @@ -202,6 +274,7 @@ invoke_core_sync() { invoke_git -C "$work_repo_root" rev-parse --show-toplevel >/dev/null invoke_git -C "$UPSTREAM_REPO_FOR_SPLIT" rev-parse --show-toplevel >/dev/null + ensure_submodule_initialized "$work_repo_root" "$cutlass_submodule_path" "csrc/cutlass submodule" test_worktree_clean "$work_repo_root" "Superproject" if [[ -e "$cutlass_repo" ]] && is_git_repo "$cutlass_repo"; then test_worktree_clean "$cutlass_repo" "csrc/cutlass submodule" @@ -242,18 +315,20 @@ invoke_core_sync() { echo "$PREFIX does not exist yet. Run this script once with --init first." >&2 exit 1 fi + + echo "Rewriting vendored CuTe imports to flash_attn.cute before subtree merge ..." + run_namespace_rewrite "$target_path" upstream + commit_prefix_if_changed "$work_repo_root" "$PREFIX" "$PREPARE_MERGE_COMMIT_MESSAGE" || true + echo "Pulling upstream updates into $PREFIX ..." - invoke_git_no_merge_edit -C "$work_repo_root" subtree pull --prefix="$PREFIX" "$UPSTREAM_REPO_FOR_SPLIT" "$TEMP_BRANCH" + invoke_git -C "$work_repo_root" fetch "$UPSTREAM_REPO_FOR_SPLIT" "$TEMP_BRANCH" + invoke_git_no_merge_edit -C "$work_repo_root" merge -X theirs "-Xsubtree=$PREFIX" FETCH_HEAD fi echo "Rewriting vendored CuTe imports to flash_sparse_attn.ops.cute ..." - python "$REWRITE_SCRIPT" "$target_path" + run_namespace_rewrite "$target_path" local - if [[ -n "$(git -C "$work_repo_root" status --porcelain -- "$PREFIX")" ]]; then - ensure_git_identity "$work_repo_root" - invoke_git -C "$work_repo_root" add -- "$PREFIX" - invoke_git -C "$work_repo_root" commit -m "Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute" - fi + commit_prefix_if_changed "$work_repo_root" "$PREFIX" "$REWRITE_COMMIT_MESSAGE" || true local_split_after="$(get_prefix_split_commit "$work_repo_root" "$PREFIX")" assert_sync_contains_upstream "$work_repo_root" "$SPLIT_COMMIT" "$local_split_after" "$local_split_before" @@ -264,7 +339,7 @@ invoke_core_sync() { } invoke_temporary_worktree_sync() { - local timestamp temp_branch_name temp_worktree original_head stash_name current_status current_prefix_status commits commit cherry_pick_commits apply_rewrite_after_restore + local timestamp temp_branch_name temp_worktree original_head stash_name current_status commits commit timestamp="$(date +%Y%m%d-%H%M%S)" temp_branch_name="sync/cute-worktree-$timestamp" temp_worktree="$(cd "$REPO_ROOT/.." && pwd)/.cute-sync-worktree-$timestamp" @@ -284,25 +359,14 @@ invoke_temporary_worktree_sync() { fi current_status="$(dirty_status "$REPO_ROOT")" - current_prefix_status="$(git -C "$REPO_ROOT" status --porcelain -- "$PREFIX")" stash_name="sync-cute-autostash-$timestamp" if [[ -n "$current_status" ]]; then echo "Stashing current worktree before cherry-picking synced commits back ..." invoke_git -C "$REPO_ROOT" stash push -u -m "$stash_name" fi - cherry_pick_commits=() - apply_rewrite_after_restore=0 while IFS= read -r commit; do [[ -z "$commit" ]] && continue - if [[ -n "$current_prefix_status" ]] && [[ "$(get_commit_subject "$temp_worktree" "$commit")" == "$REWRITE_COMMIT_MESSAGE" ]]; then - apply_rewrite_after_restore=1 - continue - fi - cherry_pick_commits+=("$commit") - done <<< "$commits" - - for commit in "${cherry_pick_commits[@]}"; do echo "Cherry-picking $commit back into current worktree ..." ensure_git_identity "$REPO_ROOT" if [[ "$(get_commit_parent_count "$temp_worktree" "$commit")" -gt 1 ]]; then @@ -310,7 +374,7 @@ invoke_temporary_worktree_sync() { else invoke_git -C "$REPO_ROOT" cherry-pick "$commit" fi - done + done <<< "$commits" if [[ -n "$current_status" ]]; then echo "Restoring stashed local changes ..." @@ -319,12 +383,6 @@ invoke_temporary_worktree_sync() { exit 1 fi fi - - if [[ "$apply_rewrite_after_restore" -eq 1 ]]; then - echo "Applying CuTe namespace rewrite in current worktree after restoring local changes ..." - python "$REWRITE_SCRIPT" "$REPO_ROOT/$PREFIX" - echo "Namespace rewrite was applied in the current worktree without creating an extra commit because local changes already exist under $PREFIX." - fi } while [[ $# -gt 0 ]]; do