From 53b843c84fae801252d0c94e66d313253b9bec25 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 16:46:33 +0800 Subject: [PATCH 1/3] Rewrite vendored CuTe namespace to flash_attn.cute before subtree merge --- flash_sparse_attn/ops/cute/__init__.py | 4 +- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- flash_sparse_attn/ops/cute/cache_utils.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 +++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 ++++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 ++++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 34 ++++++++--------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 +++++++------- flash_sparse_attn/ops/cute/interface.py | 38 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- 24 files changed, 127 insertions(+), 127 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index 01de305..1b84363 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("flash-sparse-attn") + __version__ = version("fa4") except PackageNotFoundError: __version__ = "0.0.0" @@ -14,7 +14,7 @@ flash_attn_varlen_func, ) -from flash_sparse_attn.ops.cute.cute_dsl_utils import cute_compile_patched +from flash_attn.cute.cute_dsl_utils import cute_compile_patched # Patch cute.compile to optionally dump SASS cute.compile = cute_compile_patched diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index cdec2e5..7207780 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc +import flash_attn.cute.mma_sm100_desc as sm100_desc @cute.jit diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index cebd0bf..f210138 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index de8c6f0..52cb7e0 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_sparse_attn.ops.cute.block_sparsity 
import BlockSparseTensors -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index 9e34734..3fad8c9 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index dc04970..f1b5970 100644 --- a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -17,7 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction -from flash_sparse_attn.ops.cute.fa_logging import fa_log +from flash_attn.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index d986ecb..a2dd98e 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_attn.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index d4f1db6..824abdd 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute.block_sparsity import BlockSparseTensors class 
FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 36e4e8d..76c8562 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index 698b19f..d93ea5c 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index 58d3b41..e06cd81 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import copy_utils -from flash_sparse_attn.ops.cute import pipeline -from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import copy_utils +from flash_attn.cute import pipeline +from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from 
flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 0941ae2..556c59e 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index 7980c2d..f724b5a 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute import pipeline +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute import barrier -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd -from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute import barrier +from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index 43926f7..4d47fab 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask 
import AttentionMask -from flash_sparse_attn.ops.cute.softmax import Softmax -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_attn.cute import ampere_helpers as sm80_utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.pack_gqa import PackGQA +from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1190,6 +1190,6 @@ def load_K_next(): # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def __getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 7f38c2b..4936202 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo +from flash_attn.cute import utils +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index de73573..6e4fdbf 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -33,29 +33,29 @@ from quack import copy_utils, layout_utils -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -import flash_sparse_attn.ops.cute.pipeline as pipeline_custom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +import flash_attn.cute.pipeline as pipeline_custom import cutlass.pipeline as cutlass_pipeline -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from 
flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc -from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_attn.cute import mma_sm100_desc as sm100_desc +from flash_attn.cute import blackwell_helpers as sm100_utils +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -65,8 +65,8 @@ SingleTileLPTScheduler, SingleTileVarlenScheduler, ) -from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf -from flash_sparse_attn.ops.cute.utils import smid +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index f9917d8..08d219a 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 8ae0e77..4108ce4 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute.mask import AttentionMask -from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK -from flash_sparse_attn.ops.cute.block_info import BlockInfo -from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors -from flash_sparse_attn.ops.cute.block_sparse_utils import ( +from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned +from flash_attn.cute import utils +from flash_attn.cute.mask import AttentionMask +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_attn.cute.block_sparse_utils import ( produce_block_sparse_loads, 
consume_block_sparse_loads, ) -from flash_sparse_attn.ops.cute import pipeline as pipeline_custom -from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager -from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_attn.cute import pipeline as pipeline_custom +from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.paged_kv import PagedKVManager +from flash_attn.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.tile_scheduler import ( +from flash_attn.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase +from flash_attn.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index b5f046a..872960a 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -34,35 +34,35 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache -from flash_sparse_attn.ops.cute.testing import is_fake_mode +from flash_attn.cute.cache_utils import get_jit_cache +from flash_attn.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 + from flash_attn.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_sparse_attn.ops.cute import utils -from flash_sparse_attn.ops.cute import fa_logging -from flash_sparse_attn.ops.cute.cute_dsl_utils import ( +from flash_attn.cute import utils +from flash_attn.cute import fa_logging +from flash_attn.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 -from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 -from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 -from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine - -from flash_sparse_attn.ops.cute.block_sparsity import ( +from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 +from flash_attn.cute.flash_fwd_sm120 import 
FlashAttentionForwardSm120 +from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine + +from flash_attn.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index ba53d13..6b5ca16 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +import flash_attn.cute.utils as utils +from flash_attn.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index 484834f..e87df01 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index f84b35c..bf11acb 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_sparse_attn.ops.cute import utils +from flash_attn.cute import utils from quack.cute_dsl_utils import ParamsBase from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index 0eaa479..eed55a0 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32 from quack import layout_utils -import flash_sparse_attn.ops.cute.utils as utils +import flash_attn.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index 9c15dd6..3ee4bc8 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -19,8 +19,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_sparse_attn.ops.cute.utils as utils -from flash_sparse_attn.ops.cute.fast_math import clz +import flash_attn.cute.utils as utils +from flash_attn.cute.fast_math import clz class SchedulingMode(IntEnum): From b8e264ece3257e302d4c4de8f7bba6de2fa39df5 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 16:46:33 +0800 Subject: [PATCH 2/3] Merge branch 'sync/cute-upstream-temp' of /Users/losercheems/Workspace/flash-sparse-attention/.ref_repo/flash-attention into sync/cute-worktree-20260421-164501 --- 
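Note on the bench_utils.py changes below: the new bandwidth helpers model HBM
traffic as one read of each input (Q, K, V, plus the optional QV projection)
and one write of each output. A minimal usage sketch, assuming the module
resolves as flash_attn.cute.bench_utils after the namespace rewrite in patch
1/3, and with illustrative shapes and a hypothetical `measured_seconds` taken
from a timing run:

    from flash_attn.cute.bench_utils import bandwidth_fwd_bytes

    # Bytes moved by one BF16 (2-byte) forward pass: batch 8, 16 MHA heads
    # (no GQA), seqlen 2048, headdim 128.
    nbytes = bandwidth_fwd_bytes(
        batch=8, nheads=16, nheads_kv=16,
        seqlen_q=2048, seqlen_k=2048,
        headdim=128, headdim_v=128,
    )
    achieved_gb_per_s = nbytes / measured_seconds / 1e9
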
flash_sparse_attn/ops/cute/__init__.py | 8 - flash_sparse_attn/ops/cute/bench_utils.py | 42 +- .../ops/cute/benchmark_flash_attention_fp8.py | 434 +++ .../ops/cute/blackwell_helpers.py | 68 +- .../ops/cute/block_sparse_utils.py | 7 +- flash_sparse_attn/ops/cute/cute_dsl_utils.py | 77 +- flash_sparse_attn/ops/cute/flash_bwd.py | 31 +- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 1 - flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 1 - flash_sparse_attn/ops/cute/flash_fwd.py | 38 +- .../ops/cute/flash_fwd_mla_sm100.py | 3440 +++++++++++++++++ flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 144 +- flash_sparse_attn/ops/cute/interface.py | 441 ++- flash_sparse_attn/ops/cute/mask.py | 18 +- flash_sparse_attn/ops/cute/mma_sm100_desc.py | 4 +- flash_sparse_attn/ops/cute/named_barrier.py | 9 + flash_sparse_attn/ops/cute/softmax.py | 40 +- flash_sparse_attn/ops/cute/testing.py | 26 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 11 +- flash_sparse_attn/ops/cute/topk_gather_kv.py | 274 ++ flash_sparse_attn/ops/cute/utils.py | 45 +- 21 files changed, 4937 insertions(+), 222 deletions(-) create mode 100644 flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py create mode 100644 flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py create mode 100644 flash_sparse_attn/ops/cute/topk_gather_kv.py diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index 1b84363..be32e14 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ b/flash_sparse_attn/ops/cute/__init__.py @@ -7,19 +7,11 @@ except PackageNotFoundError: __version__ = "0.0.0" -import cutlass.cute as cute - from .interface import ( flash_attn_func, flash_attn_varlen_func, ) -from flash_attn.cute.cute_dsl_utils import cute_compile_patched - -# Patch cute.compile to optionally dump SASS -cute.compile = cute_compile_patched - - __all__ = [ "flash_attn_func", "flash_attn_varlen_func", diff --git a/flash_sparse_attn/ops/cute/bench_utils.py b/flash_sparse_attn/ops/cute/bench_utils.py index 45cbcf1..f6ad96d 100644 --- a/flash_sparse_attn/ops/cute/bench_utils.py +++ b/flash_sparse_attn/ops/cute/bench_utils.py @@ -13,7 +13,15 @@ def flops( - batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None) + batch, + nheads, + seqlen_q, + seqlen_k, + headdim, + headdim_v, + causal=False, + window_size=(None, None), + has_qv=False, ): if causal: avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2 @@ -35,7 +43,37 @@ def flops( else torch.full_like(row_idx, seqlen_k - 1) ) avg_seqlen = (col_right - col_left + 1).float().mean().item() - return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v) + eff_headdim = headdim + headdim_v if has_qv else headdim + return batch * nheads * 2 * seqlen_q * avg_seqlen * (eff_headdim + headdim_v) + + +# ── Bandwidth calculation ──────────────────────────────────────────────────── + + +def bandwidth_fwd_bytes( + batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2, has_qv=False +): + """HBM traffic for one attention pass: read Q,K,V + write O.""" + q = batch * nheads * seqlen_q * headdim + qv = batch * nheads * seqlen_q * headdim_v if has_qv else 0 + k = batch * nheads_kv * seqlen_k * headdim + v = batch * nheads_kv * seqlen_k * headdim_v + o = batch * nheads * seqlen_q * headdim_v + return (q + qv + k + v + o) * dtype_bytes + + +def bandwidth_bwd_bytes( + batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2 +): + """HBM traffic for one attention pass: read Q,K,V,dO + write 
dQ,dK,dV.""" + q = batch * nheads * seqlen_q * headdim + k = batch * nheads_kv * seqlen_k * headdim + v = batch * nheads_kv * seqlen_k * headdim_v + do = batch * nheads * seqlen_q * headdim_v + dq = q + dk = k + dv = v + return (q + k + v + do + dq + dk + dv) * dtype_bytes # ── Reference attention ───────────────────────────────────────────────────── diff --git a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py new file mode 100644 index 0000000..c79e768 --- /dev/null +++ b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py @@ -0,0 +1,434 @@ +# Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. +# +# Run (recommended): +# python -m flash_attn.cute.benchmark_flash_attention_fp8 +# +# Notes: +# - This is intended to be used while bringing up FP8 support for SM100. +# - FP8 correctness depends on descales + max-offset scaling being implemented in the SM100 kernel. +# This script optionally checks output vs a BF16 PyTorch baseline on dequantized FP8 inputs. +# +# Adapted from: `hopper/benchmark_flash_attention_fp8.py` + +from __future__ import annotations + +import argparse +import inspect +import math +import time +from typing import Iterable + +import torch +from einops import rearrange + +from flash_attn.cute.benchmark import benchmark_forward +from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd + +try: + import cudnn +except ImportError: + cudnn = None + + +def _torch_float8_dtype(name: str) -> torch.dtype: + if name in ("fp8", "fp8_e4m3", "fp8_e4m3fn"): + return torch.float8_e4m3fn + if name in ("fp8_e5m2", "fp8_e5m2fn"): + return torch.float8_e5m2 + raise ValueError(f"Unsupported fp8 dtype name: {name}") + + +def _parse_int_list(csv: str) -> list[int]: + out: list[int] = [] + for part in csv.split(","): + part = part.strip() + if not part: + continue + out.append(int(part)) + return out + + +def attention_pytorch(qkv: torch.Tensor, causal: bool) -> torch.Tensor: + """ + qkv: (batch, seqlen, 3, nheads, headdim) + out: (batch, seqlen, nheads, headdim) + """ + batch_size, seqlen, _, nheads, d = qkv.shape + q, k, v = qkv.unbind(dim=2) + q = rearrange(q, "b t h d -> (b h) t d") + k = rearrange(k, "b s h d -> (b h) d s") + softmax_scale = 1.0 / math.sqrt(d) + scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device) + scores = rearrange( + torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale), "(b h) t s -> b h t s", h=nheads + ) + if causal: + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1) + output = torch.einsum("bhts,bshd->bthd", attention, v) + return output.to(dtype=qkv.dtype) + + +def flops(batch: int, seqlen: int, headdim: int, nheads: int, causal: bool) -> int: + # Matches the hopper benchmark’s convention. 
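+    # (4 = 2 matmuls, QK^T and PV, x 2 FLOPs per multiply-add; causal
+    # masking keeps roughly half of the seqlen^2 score positions.)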
+ return 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1) + + +def efficiency(flop: int, seconds: float) -> float: + return (flop / seconds / 1e12) if not math.isnan(seconds) else 0.0 + + +def time_fwd(fn, *args, repeats: int, **kwargs) -> float: + time.sleep(1) # reduce residual throttling effects between benchmarks + _, m = benchmark_forward(fn, *args, repeats=repeats, verbose=False, **kwargs) + return float(m.mean) + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + if torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + if torch_type == torch.float32: + return cudnn.data_type.FLOAT + if torch_type == torch.int32: + return cudnn.data_type.INT32 + if torch_type == torch.int64: + return cudnn.data_type.INT64 + if torch_type == torch.float8_e4m3fn: + return cudnn.data_type.FP8_E4M3 + if torch_type == torch.float8_e5m2: + return cudnn.data_type.FP8_E5M2 + raise ValueError("Unsupported tensor data type.") + + +def cudnn_sdpa_fp8_setup(qkv: torch.Tensor, seqlen_q: int, seqlen_k: int, causal: bool): + """Minimal cudnn.fp8 sdpa runner (optional).""" + assert cudnn is not None, "cudnn python bindings not available" + b, _, _, nheads, headdim = qkv.shape + o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device) + o_gpu_transposed = torch.as_strided( + o_gpu, + [b, nheads, seqlen_q, headdim], + [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1], + ) + amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device) + + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(qkv.dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + new_q = torch.as_strided( + qkv, + [b, nheads, seqlen_q, headdim], + [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=0, + ) + q = graph.tensor( + name="Q", + dim=list(new_q.shape), + stride=list(new_q.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + new_k = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim, + ) + k = graph.tensor( + name="K", + dim=list(new_k.shape), + stride=list(new_k.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + new_v = torch.as_strided( + qkv, + [b, nheads, seqlen_k, headdim], + [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1], + storage_offset=nheads * headdim * 2, + ) + v = graph.tensor( + name="V", + dim=list(new_v.shape), + stride=list(new_v.stride()), + data_type=convert_to_cudnn_type(qkv.dtype), + ) + + def _scale_tensor(): + return graph.tensor(dim=[1, 1, 1, 1], stride=[1, 1, 1, 1], data_type=cudnn.data_type.FLOAT) + + default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda") + descale_q = _scale_tensor() + descale_k = _scale_tensor() + descale_v = _scale_tensor() + descale_s = _scale_tensor() + scale_s = _scale_tensor() + scale_o = _scale_tensor() + + o, _, amax_s, amax_o = graph.sdpa_fp8( + q=q, + k=k, + v=v, + descale_q=descale_q, + descale_k=descale_k, + descale_v=descale_v, + descale_s=descale_s, + scale_s=scale_s, + scale_o=scale_o, + is_inference=True, + attn_scale=1.0 / math.sqrt(headdim), + use_causal_mask=causal, + name="sdpa", + ) + o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride()) + 
amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride()) + amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride()) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() + + variant_pack = { + q: new_q, + k: new_k, + v: new_v, + descale_q: default_scale_gpu, + descale_k: default_scale_gpu, + descale_v: default_scale_gpu, + descale_s: default_scale_gpu, + scale_s: default_scale_gpu, + scale_o: default_scale_gpu, + o: o_gpu_transposed, + amax_s: amax_s_gpu, + amax_o: amax_o_gpu, + } + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + + def run(): + graph.execute(variant_pack, workspace) + return o_gpu + + return run + + +def _maybe_pass_descales(callable_, **kwargs): + sig = inspect.signature(callable_) + return {k: v for k, v in kwargs.items() if k in sig.parameters} + + +def main(argv: Iterable[str] | None = None) -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--repeats", type=int, default=30) + parser.add_argument("--dim", type=int, default=2048) + parser.add_argument("--headdims", default="64,128") + parser.add_argument("--dtype", default="fp8_e4m3fn") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--check", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable correctness checks vs BF16 PyTorch baseline.", + ) + parser.add_argument( + "--check-quantization-only", + action="store_true", + help="Check FP8 kernel vs dequantized-FP8 baseline (quantization error only).", + ) + parser.add_argument("--atol-bf16", type=float, default=0.10) + parser.add_argument("--rtol-bf16", type=float, default=0.10) + parser.add_argument("--atol-fp8", type=float, default=0.50) + parser.add_argument("--rtol-fp8", type=float, default=0.50) + parser.add_argument("--run-cudnn", action="store_true") + args = parser.parse_args(list(argv) if argv is not None else None) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + major, minor = torch.cuda.get_device_capability() + if major != 10: + raise RuntimeError( + f"This benchmark is for SM100 (compute capability 10.x). Got {major}.{minor}." 
+ ) + + torch.manual_seed(args.seed) + device = "cuda" + fp8_dtype = _torch_float8_dtype(args.dtype) + headdim_vals = _parse_int_list(args.headdims) + bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)] + + methods = ["Pytorch", "FA4-CuTe-BF16", "FA4-CuTe-FP8"] + ( + ["cuDNN-FP8"] if args.run_cudnn and cudnn is not None else [] + ) + + fp8_failures = [] + + for headdim in headdim_vals: + for causal in (False, True): + for batch, seqlen in bs_seqlen_vals: + torch.cuda.empty_cache() + nheads = args.dim // headdim + if args.dim % headdim != 0: + raise ValueError(f"--dim must be divisible by headdim ({args.dim=} {headdim=})") + + q_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + k_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + v_bf16 = torch.randn( + batch, seqlen, nheads, headdim, device=device, dtype=torch.bfloat16 + ) + qkv_bf16 = torch.stack([q_bf16, k_bf16, v_bf16], dim=2) + + times = {} + speeds = {} + + out_ref_bf16 = None + try: + out_ref_bf16 = attention_pytorch(qkv_bf16, causal=causal) # warmup / reference + t = time_fwd(attention_pytorch, qkv_bf16, causal=causal, repeats=args.repeats) + times["Pytorch"] = t + except RuntimeError as e: + if "out of memory" in str(e).lower(): + times["Pytorch"] = float("nan") + out_ref_bf16 = None + else: + raise + + # FA4 / CuTe BF16 baseline + try: + softmax_scale = headdim**-0.5 + out_fa4_bf16, _ = flash_attn_cute_fwd( + q_bf16, k_bf16, v_bf16, softmax_scale=softmax_scale, causal=causal + ) # warmup / compile + t = time_fwd( + flash_attn_cute_fwd, + q_bf16, + k_bf16, + v_bf16, + softmax_scale=softmax_scale, + causal=causal, + repeats=args.repeats, + ) + times["FA4-CuTe-BF16"] = t + if args.check and out_ref_bf16 is not None: + torch.testing.assert_close( + out_fa4_bf16, + out_ref_bf16, + atol=args.atol_bf16, + rtol=args.rtol_bf16, + ) + except Exception as e: + # Treat as fatal: BF16 kernel should be usable for basic sanity checking. + raise RuntimeError("FA4-CuTe BF16 baseline failed") from e + + # FA4 / CuTe FP8 + q_fp8 = q_bf16.to(fp8_dtype) + k_fp8 = k_bf16.to(fp8_dtype) + v_fp8 = v_bf16.to(fp8_dtype) + + # Placeholder descales (FA3-style: per-(batch, kv_head)). 
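+                # (Identity descales exercise the FP8 path without rescaling;
+                # a real caller would derive them from per-tensor amax, e.g.
+                # q_bf16.abs().amax() / torch.finfo(fp8_dtype).max.)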
+ q_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + k_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + v_descale = torch.ones(batch, nheads, device=device, dtype=torch.float32) + + # Optional: FP8 reference baseline (dequantized FP8 -> PyTorch) for quantization-error-only checks + out_ref_fp8 = None + if args.check and args.check_quantization_only: + try: + # Dequantize FP8 inputs back to BF16 (applying descales) + q_ref_fp8 = (q_fp8.to(torch.bfloat16) * q_descale[:, None, :, None]).to( + torch.bfloat16 + ) + k_ref_fp8 = (k_fp8.to(torch.bfloat16) * k_descale[:, None, :, None]).to( + torch.bfloat16 + ) + v_ref_fp8 = (v_fp8.to(torch.bfloat16) * v_descale[:, None, :, None]).to( + torch.bfloat16 + ) + qkv_ref_fp8 = torch.stack([q_ref_fp8, k_ref_fp8, v_ref_fp8], dim=2) + out_ref_fp8 = attention_pytorch(qkv_ref_fp8, causal=causal) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + out_ref_fp8 = None + else: + raise + + fa4_kwargs = dict(softmax_scale=softmax_scale, causal=causal) + fa4_kwargs.update( + _maybe_pass_descales( + flash_attn_cute_fwd, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + ) + + try: + # Warmup/compile (will raise until FP8 is implemented) + out_fa4_fp8, _ = flash_attn_cute_fwd(q_fp8, k_fp8, v_fp8, **fa4_kwargs) + t = time_fwd( + flash_attn_cute_fwd, + q_fp8, + k_fp8, + v_fp8, + repeats=args.repeats, + **fa4_kwargs, + ) + times["FA4-CuTe-FP8"] = t + if args.check: + # Choose baseline: quantization-only (dequantized FP8) or full (BF16) + if args.check_quantization_only: + ref_baseline = out_ref_fp8 + else: + ref_baseline = out_ref_bf16 + + if ref_baseline is not None: + torch.testing.assert_close( + out_fa4_fp8, + ref_baseline, + atol=args.atol_fp8, + rtol=args.rtol_fp8, + ) + except Exception as e: + fp8_failures.append((causal, headdim, batch, seqlen, repr(e))) + times["FA4-CuTe-FP8"] = float("nan") + + if args.run_cudnn and cudnn is not None: + qkv_fp8 = qkv_bf16.to(fp8_dtype) + runner = cudnn_sdpa_fp8_setup(qkv_fp8, seqlen, seqlen, causal=causal) + _ = runner() # warmup + t = time_fwd(lambda: runner(), repeats=args.repeats) + times["cuDNN-FP8"] = t + + print(f"### causal={causal}, headdim={headdim}, batch={batch}, seqlen={seqlen} ###") + for method in methods: + t = times.get(method, float("nan")) + speeds[method] = efficiency(flops(batch, seqlen, headdim, nheads, causal), t) + if math.isnan(t): + print(f"{method} fwd: (skipped)") + else: + print(f"{method} fwd: {speeds[method]:.2f} TFLOPs/s, {t * 1e3:.3f} ms") + if math.isnan(times.get("FA4-CuTe-FP8", float("nan"))): + print("FA4-CuTe-FP8 status: FAILED") + + if fp8_failures: + print(f"\nFP8 failures: {len(fp8_failures)} (showing first 5)") + for causal, headdim, batch, seqlen, err in fp8_failures[:5]: + print(f"- causal={causal} headdim={headdim} batch={batch} seqlen={seqlen}: {err}") + return 1 + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index 7207780..4caadce 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -10,6 +10,24 @@ import flash_attn.cute.mma_sm100_desc as sm100_desc +def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: + if isinstance(op, tcgen05.mma.MmaF16BF16Op): + return "f16" + if isinstance(op, tcgen05.mma.MmaTF32Op): + return "tf32" + if isinstance(op, tcgen05.mma.MmaI8Op): + return "i8" + if 
isinstance(op, tcgen05.mma.MmaFP8Op): + return "f8f6f4" + if isinstance(op, tcgen05.mma.MmaMXF8Op): + return "mxf8f6f4" + if isinstance(op, tcgen05.mma.MmaMXF4Op): + return "mxf4" + if isinstance(op, tcgen05.mma.MmaMXF4NVF4Op): + return "mxf4nvf4" + raise TypeError(f"Unsupported tcgen05 MMA op kind: {type(op).__name__}") + + @cute.jit def gemm_w_idx( tiled_mma: cute.TiledMma, @@ -108,6 +126,7 @@ def gemm_ptx( sA_layout = sA.layout if sA is not None else None sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -177,7 +196,7 @@ def gemm_ptx( f"mov.b64 smem_desc_a, {{$1, {hex(smem_desc_a_hi)}}};\n\t" f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" + f"tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, p;\n\t" "}\n", "r,r,r,r", has_side_effects=True, @@ -198,7 +217,7 @@ def gemm_ptx( ".reg .b64 smem_desc_b;\n\t" f"mov.b64 smem_desc_b, {{$2, {hex(smem_desc_b_hi)}}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"tcgen05.mma.cta_group::1.kind::f16 [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" + f"tcgen05.mma.cta_group::1.kind::{kind} [$0], [$1], smem_desc_b, {hex(idesc)}, p;\n\t" "}\n", "r,r,r,r", has_side_effects=True, @@ -223,6 +242,7 @@ def gemm_ptx_loop( sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -310,14 +330,14 @@ def gemm_ptx_loop( f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) @@ -351,14 +371,14 @@ def gemm_ptx_loop( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, 
idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in cutlass.range_constexpr(1, cute.size(tCrA.shape[2])) ) @@ -394,6 +414,7 @@ def gemm_ptx_partial( sA_layout = sA.layout if sA is not None else tCrA.layout sB_layout = sB.layout idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): sA_swizzle = sA.iterator.type.swizzle_type smem_desc_base_a: int = const_expr( @@ -477,7 +498,7 @@ def gemm_ptx_partial( f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" @@ -486,7 +507,7 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -554,7 +575,7 @@ def gemm_ptx_partial( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" @@ -562,7 +583,7 @@ def gemm_ptx_partial( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range( 1, @@ -575,7 +596,7 @@ def gemm_ptx_partial( ( f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range(split_arrive_idx, cute.size(tCrA.shape[2])) ) @@ -613,6 +634,7 @@ def gemm_ptx_partial1( assert sA_layout is not None, "sA_layout must be provided when a_src is not TMEM" assert sA_swizzle is not None, "sA_swizzle must be provided when a_src is not TMEM" idesc: int = const_expr(sm100_desc.mma_op_to_idesc(op)) + kind = _tcgen05_mma_kind(op) if const_expr(not is_ts): smem_desc_base_a: 
int = const_expr( sm100_desc.make_smem_desc_base( @@ -706,14 +728,14 @@ def gemm_ptx_partial1( f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $4, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, {pred_str};\n\t" + "".join( ( f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {{$5, $6, $7, $8}}, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -751,13 +773,13 @@ def gemm_ptx_partial1( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $3, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, {pred_str};\n\t" + "".join( ( f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::1.kind::{kind} [$0], [tmem_a], smem_desc_b, idesc, {{$4, $5, $6, $7}}, 1;\n\t" ) for k in range(1, cute.size(tCrA.shape[2])) ) @@ -783,6 +805,7 @@ def gemm_ptx_precomputed( mbar_phase: Optional[Int32] = None, zero_init: bool | Boolean = False, cta_group: int = 1, + kind: str = "f16", ) -> None: # acc_tmem_addr += acc_offset is_ts = const_expr(smem_desc_base_a is None) @@ -842,7 +865,7 @@ def gemm_ptx_precomputed( f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t" @@ -851,7 +874,7 @@ def gemm_ptx_precomputed( f"add.s32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t" ) for k in range(1, num_k_tile) ) @@ -911,7 +934,7 @@ def gemm_ptx_precomputed( f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, 
smem_desc_b_hi}};\n\t" "setp.ne.b32 p, $2, 0;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t" + "".join( ( # f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t" @@ -919,7 +942,7 @@ def gemm_ptx_precomputed( f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range( 1, @@ -933,7 +956,7 @@ def gemm_ptx_precomputed( # f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t" f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t" f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t" ) for k in range(num_k_tile // 4 * 3, num_k_tile) ) @@ -1019,6 +1042,7 @@ def gemm_ptx_precomputed_varname( smem_offset: int, zero_init: bool | Boolean = False, cta_group: int = 1, + kind: str = "f16", ) -> None: is_ts = False num_k_tile = cute.size(tCrB_layout.shape[2]) @@ -1067,7 +1091,7 @@ def gemm_ptx_precomputed_varname( ) + "setp.ne.b32 p, $1, 0;\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b, idesc, {pred_str};\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], {smem_var_name_prefix}_0, smem_desc_b_0, {idesc_var_name}, {pred_str};\n\t" + "".join( ( # f"mov.b64 {{smem_desc_a_lo, smem_desc_a_hi}}, {smem_var_name_prefix}_{k};\n\t" @@ -1077,7 +1101,7 @@ def gemm_ptx_precomputed_varname( # f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, idesc, 1;\n\t" # f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b, {idesc_var_name}, 1;\n\t" - f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" + f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::{kind} [tmem_acc], {smem_var_name_prefix}_{k}, smem_desc_b_{k}, {idesc_var_name}, 1;\n\t" ) for k in range(1, num_k_tile) ) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index 52cb7e0..b19dcd3 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -667,6 +667,8 @@ def handle_block_sparse_empty_tile_correction_sm100( o_corr_consumer_phase: Int32, corr_epi_producer_phase: Int32, softmax_scale_log2: Float32, + max_offset: 
Float32, + max_offset_scale: Float32, mO_cur: Optional[cute.Tensor] = None, gO: Optional[cute.Tensor] = None, gmem_tiled_copy_O: Optional[cute.TiledCopy] = None, @@ -706,10 +708,11 @@ def handle_block_sparse_empty_tile_correction_sm100( if sink_val != -Float32.inf and (const_expr(not is_split_kv) or split_idx == 0): if row_max_value == -Float32.inf: row_max_value = sink_val * (LOG2_E / softmax_scale_log2) - row_sum_value = Float32(1.0) + row_sum_value = max_offset_scale else: row_sum_value = row_sum_value + cute.math.exp2( - sink_val * LOG2_E - row_max_value * softmax_scale_log2, fastmath=True + sink_val * LOG2_E - row_max_value * softmax_scale_log2 + max_offset, + fastmath=True, ) if tidx < m_block_size: scale_row_idx = tidx + stage * m_block_size diff --git a/flash_sparse_attn/ops/cute/cute_dsl_utils.py b/flash_sparse_attn/ops/cute/cute_dsl_utils.py index 79ebd9d..6dfad66 100644 --- a/flash_sparse_attn/ops/cute/cute_dsl_utils.py +++ b/flash_sparse_attn/ops/cute/cute_dsl_utils.py @@ -1,9 +1,7 @@ # Copyright (c) 2025, Tri Dao. -import os -import pathlib from typing import Tuple -from functools import partial, lru_cache +from functools import lru_cache import torch @@ -28,6 +26,8 @@ torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, } @@ -41,27 +41,6 @@ def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: return torch.cuda.get_device_capability(device) -def load_cubin_module_data_patched(cubin_data, filepath): - pathlib.Path(filepath).write_bytes(cubin_data) - return load_cubin_module_data_og(cubin_data) - - -def cute_compile_patched(*args, **kwargs): - """A patched version of cute.compile that dump the SASS to a file if CUTE_CUBIN_PATH is set.""" - cubin_path = os.getenv("CUTE_CUBIN_PATH", None) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = partial( - load_cubin_module_data_patched, filepath=cubin_path - ) - output = cute_compile_og(*args, **kwargs) - if cubin_path is not None: - cutlass.base_dsl.runtime.cuda.load_cubin_module_data = load_cubin_module_data_og - if extract is not None: - sass = extract(cubin_path, None) - pathlib.Path(cubin_path).with_suffix(".annotated.sass").write_text(sass) - return output - - def assume_strides_aligned(t): """Assume all strides except the last are divisible by 128 bits. @@ -82,7 +61,20 @@ def assume_tensor_aligned(t): def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True): """Convert torch tensor to cute tensor for TVM FFI. 
leading_dim=-1 defaults to t.ndim-1."""
-    tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
+    # NOTE: torch 2.9.1 doesn't support fp8 via DLPack, but the 2.11.0 nightly does.
+    # For now, export the raw bytes as uint8 and tell cutlass the correct element type;
+    # we can export fp8 directly once torch supports it.
+    if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+        tensor = from_dlpack(
+            t.view(torch.uint8).detach(),
+            assumed_align=assumed_align,
+            enable_tvm_ffi=enable_tvm_ffi,
+        )
+        tensor.element_type = (
+            cutlass.Float8E4M3FN if t.dtype == torch.float8_e4m3fn else cutlass.Float8E5M2
+        )
+    else:
+        tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
     if fully_dynamic:
         return tensor.mark_layout_dynamic()
     if leading_dim == -1:
@@ -127,3 +119,38 @@ def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
     patterns are not interchangeable.
     """
     return tuple(s == 0 for s in tensor.stride())
+
+
+# credit: monellz (https://github.com/NVIDIA/cutlass/issues/2658#issuecomment-3630564264)
+def dump_kernel_attributes(compiled_kernel):
+    from cuda.bindings import driver
+    from cutlass.utils import HardwareInfo
+    import torch
+
+    device_id = torch.cuda.current_device()
+    hardware_info = HardwareInfo(device_id=device_id)
+    cubin_data = compiled_kernel.artifacts.CUBIN
+    assert cubin_data is not None, "cubin_data is None, need '--keep-cubin' option when compiling"
+    cuda_library = hardware_info._checkCudaErrors(
+        driver.cuLibraryLoadData(cubin_data, None, None, 0, None, None, 0)
+    )
+    kernels = hardware_info._checkCudaErrors(driver.cuLibraryEnumerateKernels(1, cuda_library))
+    kernel = hardware_info._checkCudaErrors(driver.cuKernelGetFunction(kernels[0]))
+    # more metrics: https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1g5e92a1b0d8d1b82cb00dcfb2de15961b
+    local_size_bytes = hardware_info._checkCudaErrors(
+        driver.cuFuncGetAttribute(
+            driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES,
+            kernel,
+        )
+    )
+    num_regs = hardware_info._checkCudaErrors(
+        driver.cuFuncGetAttribute(
+            driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS,
+            kernel,
+        )
+    )
+
+    print("--- Kernel Info ---")
+    print(f"local_size_bytes: {local_size_bytes}")
+    print(f"num_regs: {num_regs}")
+    print("--- End Kernel Info ---")
diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py
index 824abdd..eeb7615 100644
--- a/flash_sparse_attn/ops/cute/flash_bwd.py
+++ b/flash_sparse_attn/ops/cute/flash_bwd.py
@@ -46,6 +46,8 @@ def __init__(
         AtomLayoutNdKV: int = 8,
         AtomLayoutMdQ: int = 1,
         V_in_regs: bool = False,
+        score_mod: cutlass.Constexpr | None = None,
+        score_mod_bwd: cutlass.Constexpr | None = None,
     ):
         """Initializes the configuration for a flash attention v2 kernel.
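The new score_mod / score_mod_bwd hooks are compile-time callables: in compute_one_m_block below, the forward hook receives the recomputed scores already multiplied by softmax_scale, and the backward hook receives the incoming gradient together with those same pre-mod scores (acc_S_pre). A minimal sketch of a matching pair, written over plain floats for clarity (in the kernel these operate on register fragments); the tanh softcap and the factory name are illustrative assumptions, not part of this patch:

import math

def make_softcap_mods(cap: float):
    # Hypothetical example only; the positional signature mirrors the call
    # sites below: score_mod(score, batch, head, q_idx, kv_idx, aux, fastdiv_mods).
    def score_mod(score, batch, head, q_idx, kv_idx, aux, fastdiv_mods):
        return cap * math.tanh(score / cap)

    # Backward hook: first argument is the incoming dP term, second is the
    # same pre-mod (scale-multiplied) score the forward hook saw.
    def score_mod_bwd(grad, score, batch, head, q_idx, kv_idx, aux, fastdiv_mods):
        t = math.tanh(score / cap)
        return grad * (1.0 - t * t)  # d/ds [cap * tanh(s/cap)] = 1 - tanh(s/cap)**2

    return score_mod, score_mod_bwd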
@@ -90,6 +92,8 @@ def __init__( self.Mma_dKV_is_RS = AtomLayoutMSdP == 1 and AtomLayoutNdKV == num_mma_warps and SdP_swapAB and not dKV_swapAB self.V_in_regs = V_in_regs self.share_QV_smem = V_in_regs + self.score_mod = score_mod + self.score_mod_bwd = score_mod_bwd @staticmethod def can_implement( @@ -377,7 +381,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, @@ -430,7 +433,7 @@ def __call__( tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) grid_dim = TileScheduler.get_grid_shape(tile_sched_params) - softmax_scale_log2 = softmax_scale * math.log2(math.e) + softmax_scale_log2, softmax_scale = utils.compute_softmax_scale_log2(softmax_scale, self.score_mod) self.kernel( mQ, mK, @@ -773,6 +776,7 @@ def kernel( smem_copy_params=smem_copy_params, gmem_copy_params=gmem_copy_params, load_Q_LSE=load_Q_LSE, load_dO_dPsum=load_dO_dPsum, m_block_max=m_block_max, + softmax_scale=softmax_scale, softmax_scale_log2=softmax_scale_log2, ) @@ -861,6 +865,7 @@ def compute_one_m_block( load_Q_LSE: Callable, load_dO_dPsum: Callable, m_block_max: cutlass.Int32, + softmax_scale: cutlass.Float32, softmax_scale_log2: cutlass.Float32, mask_fn: Optional[Callable] = None, ): @@ -890,13 +895,24 @@ def load_dO_next(): smem_copy_params.smem_thr_copy_QdO, smem_copy_params.smem_thr_copy_KV, swap_AB=self.SdP_swapAB, ) + acc_S_pre = cute.make_fragment_like(acc_S) + acc_S_pre.store(acc_S.load()) tLSErLSE = cute.make_fragment_like(smem_copy_params.tSsLSEMma[None, 0]) cute.autovec_copy( smem_copy_params.tSsLSEMma[None, smem_pipe_read_q if cutlass.const_expr(self.num_stages_Q > 1) else 0], tLSErLSE ) + acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) + acc_S_pre_mn = layout_utils.reshape_acc_to_mn(acc_S_pre) + if cutlass.const_expr(self.score_mod is not None): + for r in cutlass.range(cute.size(acc_S_mn, mode=[0]), unroll_full=True): + acc_S_mn[r, None].store( + self.score_mod( + acc_S_mn[r, None].load() * softmax_scale, + 0, 0, 0, 0, None, [], + ) + ) if cutlass.const_expr(mask_fn is not None): mask_fn(acc_S, m_block=m_block) - acc_S_mn = layout_utils.reshape_acc_to_mn(acc_S) bidx = 0 # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_S_mn) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 1: cute.print_tensor(tLSErLSE) @@ -926,7 +942,14 @@ def load_dO_next(): # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) assert cute.size(acc_dP_mn, mode=[0]) == cute.size(tLSErdPsum) for r in cutlass.range(cute.size(acc_dP_mn, mode=[0]), unroll_full=True): - acc_dP_mn[r, None].store(acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r])) + grad_val = acc_S_mn[r, None].load() * (acc_dP_mn[r, None].load() - tLSErdPsum[r]) + if cutlass.const_expr(self.score_mod_bwd is not None): + grad_val = self.score_mod_bwd( + grad_val, + acc_S_pre_mn[r, None].load() * softmax_scale, + 0, 0, 0, 0, None, [], + ) + acc_dP_mn[r, None].store(grad_val) # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == bidx: cute.print_tensor(acc_dP_mn) rP = cute.make_fragment_like(acc_S, self.dtype) rP.store(acc_S.load().to(self.dtype)) diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py 
b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index e06cd81..4b4083e 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -456,7 +456,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index f724b5a..c9a690d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -350,7 +350,6 @@ def __call__( mCuSeqlensK: Optional[cute.Tensor] = None, mSeqUsedQ: Optional[cute.Tensor] = None, mSeqUsedK: Optional[cute.Tensor] = None, - softcap: Float32 | float | None = None, window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, mdQ_semaphore: Optional[cute.Tensor] = None, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index 4d47fab..d1a43cf 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -27,7 +27,7 @@ from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned from flash_attn.cute import utils from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax +from flash_attn.cute.softmax import Softmax, apply_score_mod_inner from flash_attn.cute.seqlen_info import SeqlenInfoQK from flash_attn.cute.block_info import BlockInfo from flash_attn.cute.pack_gqa import PackGQA @@ -1145,8 +1145,8 @@ def load_V_next(): m_block, acc_S, n_block, - seqlen, softmax_scale=softmax.softmax_scale, + seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, ) @@ -1185,6 +1185,40 @@ def load_K_next(): ) # if const_expr(self.num_stages > 1): # load_K_next() + @cute.jit + def apply_score_mod( + self, + thr_mma_qk, + batch_idx, + head_idx, + m_block, + acc_S, + n_block, + softmax_scale, + seqlen, + aux_tensors: Optional[list] = None, + fastdiv_mods=None, + ): + # Prepare index tensor + cS = cute.make_identity_tensor((self.tile_m, self.tile_n)) + cS = cute.domain_offset((m_block * self.tile_m, n_block * self.tile_n), cS) + tScS = thr_mma_qk.partition_C(cS) + + apply_score_mod_inner( + acc_S, + tScS, + self.score_mod, + batch_idx, + head_idx, + softmax_scale, + self.vec_size, + self.qk_acc_dtype, + aux_tensors, + fastdiv_mods, + seqlen_info=seqlen, + constant_q_idx=None, + qhead_per_kvhead=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility diff --git a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py new file mode 100644 index 0000000..07cd99f --- /dev/null +++ b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py @@ -0,0 +1,3440 @@ +import math +import time +from functools import partial +from typing import Callable, Optional + +import torch +import torch.utils.benchmark as benchmark + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int64, Int32, Uint32, Boolean, const_expr +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.blackwell_helpers as sm100_utils 
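+# Shape cheat-sheet (summary comment; the values follow from the constructor
+# defaults below: cta_tile_m=64, tile_n=128, hdim=64, hdimv=512, 2-CTA MMA):
+#   cluster_tile_m = 2 * 64  = 128    rows of Q per 2-CTA cluster
+#   epi_tile       = (64, 256)        hdimv split in half (num_hdimv_splits=2)
+#   mma_tiler_QK   = (128, 128, 64)   (M, N, K) for S  = Q @ K^T
+#   mma_tiler_PVti = (128, 256, 128)  (M, N, K) for Oi = P @ Vt_i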
+from cutlass.utils import ClcDynamicPersistentTileScheduler + +from quack import copy_utils + +from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_attn.cute.block_info import BlockInfo +from flash_attn.cute.mask import AttentionMask +import flash_attn.cute.blackwell_helpers as fa_sm100_utils +from flash_attn.cute.softmax import SoftmaxSm100 +from flash_attn.cute.tile_scheduler import ( + ClcState, + SchedulingMode, + TileSchedulerArguments, + TileSchedulerProtocol, + SingleTileScheduler, + SingleTileLPTScheduler, + SingleTileVarlenScheduler, + ParamsBase, +) +from flash_attn.cute.fa_logging import fa_log, fa_printf +from flash_attn.cute.utils import smid + +from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager + +from flash_attn.cute.testing import attention_ref + +from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA + +from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes + + +class FlashAttentionMLAForwardSm100: + def __init__( + self, + is_causal: bool = False, + use_cpasync_load_KV: bool = False, + topk_length: int = 2048, + is_topk_gather: bool = True, + pack_gqa: bool = False, + qhead_per_kvhead: int = 1, + nheads_kv: int = 1, + hdim: int = 64, + hdimv: int = 512, + is_varlen_q: bool = False, + disable_bitmask: bool = False, + use_clc_scheduler: bool = True, + ): + self.is_causal = is_causal + self.is_local = False + self.pack_gqa = pack_gqa + self.qhead_per_kvhead = qhead_per_kvhead + self.nheads_kv = nheads_kv + self.is_varlen_q = is_varlen_q + self.use_tma_O = True + self.use_cpasync_load_KV = use_cpasync_load_KV + self.use_tma_KV = not use_cpasync_load_KV + self.topk_length = topk_length + self.is_topk_gather = is_topk_gather + if is_topk_gather: + assert pack_gqa + assert qhead_per_kvhead == 128, "require MQA 128 for DSA path" + assert use_cpasync_load_KV + # user-provided option if topk indices guaranteed in bounds + self.disable_bitmask = disable_bitmask + + # ==== tile scheduler ==== + self.is_persistent = False + self.use_clc_scheduler = use_clc_scheduler and not is_varlen_q + self.sched_stages = 1 + self.scheduling_mode = ( + SchedulingMode.CLC if self.use_clc_scheduler else SchedulingMode.STATIC + ) + + if const_expr(is_varlen_q): + self.TileScheduler = SingleTileVarlenScheduler + elif self.use_clc_scheduler: + self.TileScheduler = SingleTileLPTScheduler + else: + self.TileScheduler = SingleTileScheduler + + fa_log( + 1, + f"TileScheduler={self.TileScheduler.__name__}, scheduling_mode={self.scheduling_mode.name}", + ) + + # ==== thread info ==== + self.num_softmax_threads = 128 + self.num_epilogue_threads = 128 + self.num_load_threads = 32 + self.num_mma_threads = 32 + self.num_empty_threads = 32 if use_cpasync_load_KV else 64 + self.num_relay_threads = 32 if use_cpasync_load_KV else 0 + self.num_cpasync_load_threads = 128 if use_cpasync_load_KV else 0 + self.num_threads = ( + self.num_softmax_threads + + self.num_epilogue_threads + + self.num_load_threads + + self.num_mma_threads + + self.num_empty_threads + + self.num_relay_threads + + self.num_cpasync_load_threads + ) + self.num_warps = self.num_threads // 32 + assert self.num_warps == 12 or self.num_warps == 16 + self.softmax_warp_indices = (0, 1, 2, 3) + self.epilogue_warp_indices = (4, 5, 6, 7) + self.load_warp_id = 8 + self.mma_warp_id = 9 + self.clc_scheduler_warp_id = 10 + self.relay_warp_id = 11 + self.empty_warp_ids = tuple( + w + for w, active in [ + (self.relay_warp_id, not 
use_cpasync_load_KV), + (self.clc_scheduler_warp_id, not self.use_clc_scheduler), + ] + if active + ) + self.cpasync_load_warp_indices = (12, 13, 14, 15) + + # ==== register usage ==== + if self.num_warps == 16: + self.num_regs_load = 80 + self.num_regs_mma = 80 + self.num_regs_softmax = 208 + self.num_regs_epilogue = 128 + self.num_regs_cpasync = 96 if self.use_cpasync_load_KV else 0 + self.num_regs_other = 48 + else: + self.num_regs_load = 168 - 40 + self.num_regs_mma = 168 - 40 + self.num_regs_softmax = 168 + 80 + self.num_regs_epilogue = 168 - 40 + self.num_regs_cpasync = 0 + self.num_regs_other = 48 + + self.num_regs_per_thread = 168 if self.num_warps == 12 else 128 + self.num_regs_total = 504 if self.num_warps == 12 else 512 + + assert ( + self.num_regs_mma + + self.num_regs_softmax + + self.num_regs_epilogue + + self.num_regs_cpasync + <= self.num_regs_total + ) + + # ==== 2cta info ==== + self.use_2cta_instrs = True + self.cta_group = tcgen05.CtaGroup.TWO + self.cta_group_size = 2 + self.cluster_shape_mn = (2, 1) + self.cluster_shape_mnk = (2, 1, 1) + + # ==== problem shape info ==== + self.hdim = hdim + self.hdimv = hdimv + self.cta_tile_m = 64 + self.cluster_tile_m = self.cta_group_size * self.cta_tile_m + self.tile_n = 128 + assert ( + pack_gqa is False + or self.cluster_tile_m % qhead_per_kvhead == 0 + or qhead_per_kvhead % self.cluster_tile_m == 0 + ) + self.num_hdimv_splits = 2 # split hdimv in half for our Qv @ V^T and P @ V mmas. + assert hdimv % 32 == 0 + assert self.topk_length % (self.tile_n * 2) == 0 or not self.is_topk_gather + self.epi_tile = (self.cta_tile_m, self.hdimv // self.num_hdimv_splits) + + # ==== MMA info ==== + self.mma_tiler_QK = ( + self.cluster_tile_m, + self.tile_n, + self.hdim, + ) + self.mma_tiler_QviVi = ( + self.cluster_tile_m, + self.tile_n, + self.hdimv // self.num_hdimv_splits, + ) + self.mma_tiler_PVti = ( + self.cluster_tile_m, + self.hdimv // self.num_hdimv_splits, + self.tile_n, + ) + self.major_mode_Q = tcgen05.OperandMajorMode.K + self.major_mode_Qvi = tcgen05.OperandMajorMode.K + self.major_mode_K = tcgen05.OperandMajorMode.K + self.major_mode_Vi = tcgen05.OperandMajorMode.K + self.major_mode_Vti = tcgen05.OperandMajorMode.MN + self.major_mode_P = tcgen05.OperandMajorMode.K + self.operand_source_Q = tcgen05.OperandSource.SMEM + self.operand_source_Qvi = tcgen05.OperandSource.SMEM + self.operand_source_P = tcgen05.OperandSource.SMEM + + # ==== pipeline info ==== + self.num_stages_Q = 1 + self.num_stages_K = 1 + self.num_stages_Qvi = 1 + self.num_stages_Vi = 2 + self.num_stages_S = 2 + self.num_stages_P = 1 + self.num_stages_Oi = 1 + self.num_stages_sm_stats = 2 + self.num_stages_bitmask = 4 + assert self.num_stages_S == 2, "mainloops expect 2 stages for S" + + # ==== dtype info ==== + self.dtype_acc = Float32 + + # ==== TMEM info ==== + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.tmem_cols_S = self.tile_n // self.cta_group_size + self.tmem_cols_Oi = (self.hdimv // self.num_hdimv_splits) // self.cta_group_size + self.tmem_offset_S = [ + self.tmem_cols_S * stage for stage in range(self.num_stages_S) + ] # allocate 64 TMEM columns for each stage of S + self.tmem_offset_O0 = self.tmem_cols_S * self.num_stages_S + self.tmem_offset_O1 = self.tmem_offset_O0 + self.tmem_cols_Oi + self.tmem_offsets_O = [self.tmem_offset_O0, self.tmem_offset_O1] + self.total_tmem = self.tmem_offset_O1 + self.tmem_cols_Oi + assert self.total_tmem <= self.tmem_alloc_cols, ( + f"Total TMEM columns allocated 
{self.total_tmem} exceeds capacity {self.tmem_alloc_cols}" + ) + + def _get_shared_storage_cls(self): + self.buffer_align_bytes = 1024 + + def smem_struct_align(dtype, staged_layout): + return cute.struct.Align[ + cute.struct.MemRange[dtype, cute.cosize(staged_layout)], + self.buffer_align_bytes, + ] + + def mbar_struct(num_stages): + return cute.struct.MemRange[Int64, 2 * num_stages] + + (sQ_struct, sK_struct, sQv0_struct, sQv1_struct, sV0_struct, sV1_struct, sP_struct) = ( + smem_struct_align(dtype, layout) + for dtype, layout in [ + (self.dtype_Q, self.sQ_layout_staged), + (self.dtype_K, self.sK_layout_staged), + (self.dtype_Qv, self.sQvi_layout_staged), + (self.dtype_Qv, self.sQvi_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_V, self.sVi_layout_staged), + (self.dtype_P, self.sP_layout_staged), + ] + ) + sStats_struct = cute.struct.MemRange[Float32, cute.cosize(self.sStats_layout)] + sScale_struct = cute.struct.MemRange[Float32, cute.cosize(self.sScale_layout)] + sBitmask_struct = cute.struct.MemRange[Uint32, cute.cosize(self.sBitmask_layout)] + + ( + mbar_ptr_Q_struct, + mbar_ptr_K_struct, + mbar_ptr_Qv0_struct, + mbar_ptr_Qv1_struct, + mbar_ptr_V0_struct, + mbar_ptr_V1_struct, + mbar_ptr_S_struct, + mbar_ptr_P_struct, + mbar_ptr_O0_struct, + mbar_ptr_O1_struct, + mbar_sm_stats_struct, + mbar_bitmask_struct, + ) = ( + mbar_struct(n) + for n in [ + self.num_stages_Q, + self.num_stages_K, + self.num_stages_Qvi, + self.num_stages_Qvi, + self.num_stages_Vi, + self.num_stages_Vi, + self.num_stages_S, + self.num_stages_P, + self.num_stages_Oi, + self.num_stages_Oi, + self.num_stages_sm_stats, + self.num_stages_bitmask, + ] + ) + mbar_ptr_tmem_dealloc_struct = Int64 + tmem_holding_buf_struct = Int32 + + self.sched_stages = 1 + clc_response_size = self.sched_stages * 4 if self.use_clc_scheduler else 0 + clc_mbar_size = self.sched_stages * 2 if self.use_clc_scheduler else 0 + + @cute.struct + class SharedStorage: + mbar_ptr_Q: mbar_ptr_Q_struct + mbar_ptr_K: mbar_ptr_K_struct + mbar_ptr_Qv0: mbar_ptr_Qv0_struct + mbar_ptr_Qv1: mbar_ptr_Qv1_struct + mbar_ptr_V0: mbar_ptr_V0_struct + mbar_ptr_V1: mbar_ptr_V1_struct + mbar_ptr_S: mbar_ptr_S_struct + mbar_ptr_P: mbar_ptr_P_struct + mbar_ptr_O0: mbar_ptr_O0_struct + mbar_ptr_O1: mbar_ptr_O1_struct + mbar_ptr_K_cpasync: mbar_ptr_K_struct + mbar_ptr_V0_cpasync: mbar_ptr_V0_struct + mbar_ptr_V1_cpasync: mbar_ptr_V1_struct + mbar_ptr_sm_stats: mbar_sm_stats_struct + mbar_ptr_bitmask: mbar_bitmask_struct + mbar_ptr_tmem_dealloc: mbar_ptr_tmem_dealloc_struct + tmem_holding_buf: tmem_holding_buf_struct + clc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, clc_mbar_size] + clc_response: cute.struct.MemRange[Int32, clc_response_size] + sO_empty_mbar_ptr: cutlass.Int64 + + sRowMax: sStats_struct + sRowSum: sStats_struct + sScale: sScale_struct + sBitmask: sBitmask_struct + sQv0: sQv0_struct + sQv1: sQv1_struct + sQ: sQ_struct + sK: sK_struct + sV0: sV0_struct + sV1: sV1_struct + sP: sP_struct + + # print("smem bytes = ", SharedStorage.size_in_bytes()) + + return SharedStorage + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, # (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + mQv: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, d) if there is cu_seqlens_q + mK: cute.Tensor, # (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table + mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if 
there is page_table + mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q + mLSE: Optional[cute.Tensor], # (b, h, s_q) or (h, total_q) if there is cu_seqlens_q + softmax_scale: Float32, + mCuSeqlensQ: Optional[cute.Tensor] = None, # (b + 1) + mCuSeqlensK: Optional[cute.Tensor] = None, # (b + 1) + mSeqUsedQ: Optional[cute.Tensor] = None, # (b) + mSeqUsedK: Optional[cute.Tensor] = None, # (b) + mIndexTopk: Optional[ + cute.Tensor + ] = None, # (b, s_q, topk) or (total_q, topk) if there is cu_seqlens_q + mPageTable: Optional[cute.Tensor] = None, + window_size_left: Int32 | int | None = None, + window_size_right: Int32 | int | None = None, + # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). + stream: cuda.CUstream = None, + ): + # ==== asserts for unimplemented features ==== + assert mPageTable is None, "page table tbd for MLA" + + # ==== dtype info ==== + self.dtype_Q = mQ.element_type + self.dtype_K = mK.element_type + self.dtype_Qv = mQv.element_type + self.dtype_V = mV.element_type + self.dtype_P = mV.element_type + self.dtype_O = mO.element_type + + # ==== Prepare Tensors ==== + new_stride = lambda mX: ( + *(cute.assume(s, divby=128 // mX.element_type.width) for s in mX.stride[:-1]), + mX.stride[-1], + ) + mQ, mQv, mK, mV, mO = [ + cute.make_tensor(mX.iterator, cute.make_layout(mX.shape, stride=new_stride(mX))) + for mX in (mQ, mQv, mK, mV, mO) + ] + + # (b, s, h, d) -> (s, d, h, b) or + # (total, h, d) -> (total, d, h) or + # (num_pages, page_size, h_k, d) -> (page_size, d, h_k, num_pages) + QO_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensQ is None) else [0, 2, 1] + KV_layout_transpose = [1, 3, 2, 0] if const_expr(mCuSeqlensK is None) else [0, 2, 1] + mQ, mQv, mO = [ + cute.make_tensor(mX.iterator, cute.select(mX.layout, mode=QO_layout_transpose)) + for mX in (mQ, mQv, mO) + ] + mK, mV = [ + cute.make_tensor(mX.iterator, cute.select(mX.layout, mode=KV_layout_transpose)) + for mX in (mK, mV) + ] + # (s_k, dv, h_k, b) -> (dv, s_k, h_k, b) or + # (total_k, dv, h_k) -> (dv, total_k, h_k) + V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] + mVt = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) + # (b, h, s_q) -> (s_q, h, b) or (h, total_q) -> (total_q, h) + # (b, s_q, topk) -> (topk, s_q, b) or (total_q, topk) -> (topk, total_q) + LSE_layout_transpose = [2, 1, 0] if const_expr(mCuSeqlensQ is None) else [1, 0] + mLSE, mIndexTopk = ( + cute.make_tensor(t.iterator, cute.select(t.layout, mode=LSE_layout_transpose)) + if t is not None + else None + for t in (mLSE, mIndexTopk) + ) + topk_length_dynamic = mIndexTopk.shape[0] if mIndexTopk is not None else None + + self.o_layout = cutlass.utils.LayoutEnum.from_tensor(mO) + + mO_og = mO + if const_expr(self.pack_gqa): + mQ, mQv, mO = [ + pack_gqa_layout(mX, self.qhead_per_kvhead, self.nheads_kv, head_idx=2) + for mX in (mQ, mQv, mO) + ] + if const_expr(mLSE is not None): + mLSE = pack_gqa_layout(mLSE, self.qhead_per_kvhead, self.nheads_kv, head_idx=1) + + def split_hdimv(m, dim: int): + """Re-tile mode `dim` of tensor `m` from hdimv into (hdimv//S, S), + and return (slice0, slice1) where slice_i selects chunk i.""" + S = self.num_hdimv_splits + chunk = self.hdimv // S + split_shape = (*m.shape[:dim], (chunk, S), *m.shape[dim + 1 :]) + split_stride = (*m.layout.stride[:dim], (1, chunk), *m.layout.stride[dim + 1 :]) + split = cute.make_tensor(m.iterator, cute.make_layout(split_shape, stride=split_stride)) + ndim 
= len(split.shape) + slices = [ + split[(*([None] * dim), (None, i), *([None] * (ndim - dim - 1)))] for i in range(S) + ] + return slices + + # (seqlen_q, hdimv//2, nheads, batch) or (total_q, hdimv//2, nheads) + mQv0, mQv1 = split_hdimv(mQv, dim=1) + mV0, mV1 = split_hdimv(mV, dim=1) + # (hdimv//2, seqlen_k, nheads_k, batch) or (hdimv//2, total_k, nheads_k) + mVt0, mVt1 = split_hdimv(mVt, dim=0) + + # ==== Prepare MMAs ==== + # (local_var, dtype_a, major_a, major_b, mma_tiler, operand_source_a) + # fmt: off + _mma_specs = [ + ("tiled_mma_QK", self.dtype_Q, self.major_mode_Q, self.major_mode_K, self.mma_tiler_QK, self.operand_source_Q), + ("tiled_mma_QviVi", self.dtype_Qv, self.major_mode_Qvi, self.major_mode_Vi, self.mma_tiler_QviVi, self.operand_source_Qvi), + ("tiled_mma_PVti", self.dtype_P, self.major_mode_P, self.major_mode_Vti, self.mma_tiler_PVti, self.operand_source_P), + ] + tiled_mma_QK, tiled_mma_QviVi, tiled_mma_PVti = ( + sm100_utils.make_trivial_tiled_mma( + dtype_a, major_a, major_b, self.dtype_acc, self.cta_group, mma_tiler[:2], operand_source_a, + ) + for _, dtype_a, major_a, major_b, mma_tiler, operand_source_a in _mma_specs + ) + # fmt: on + + # ==== Prepare SMEM layouts and TMAs ==== + # (attr, make_fn, tiled_mma, mma_tiler, dtype, num_stages) + # fmt: off + _smem_layout_specs = [ + ("sQ_layout", sm100_utils.make_smem_layout_a, tiled_mma_QK, self.mma_tiler_QK, self.dtype_Q, self.num_stages_Q), + ("sK_layout", sm100_utils.make_smem_layout_b, tiled_mma_QK, self.mma_tiler_QK, self.dtype_K, self.num_stages_K), + ("sQvi_layout", sm100_utils.make_smem_layout_a, tiled_mma_QviVi, self.mma_tiler_QviVi, self.dtype_Qv, self.num_stages_Qvi), + ("sVi_layout", sm100_utils.make_smem_layout_b, tiled_mma_QviVi, self.mma_tiler_QviVi, self.dtype_V, self.num_stages_Vi), + ("sVti_layout", sm100_utils.make_smem_layout_b, tiled_mma_PVti, self.mma_tiler_PVti, self.dtype_V, self.num_stages_Vi), + ("sP_layout", sm100_utils.make_smem_layout_a, tiled_mma_PVti, self.mma_tiler_PVti, self.dtype_P, self.num_stages_P), + ] + for attr, make_fn, tiled_mma, mma_tiler, dtype, num_stages in _smem_layout_specs: + ab_kwarg = "a_dtype" if make_fn is sm100_utils.make_smem_layout_a else "b_dtype" + staged = make_fn( + tiled_mma=tiled_mma, + mma_tiler_mnk=mma_tiler, + num_stages=num_stages, + **{ab_kwarg: dtype}, + ) + setattr(self, f"{attr}_staged", staged) + setattr(self, attr, cute.select(staged, mode=[0, 1, 2])) + # fmt: on + + self.sStats_layout = cute.make_layout((self.cta_tile_m, self.cta_group_size)) + self.sScale_layout = cute.make_layout((self.cta_tile_m, self.num_stages_sm_stats)) + self.sBitmask_layout = cute.make_layout((self.tile_n // 32, self.num_stages_bitmask)) + + # fmt: off + for attr, dtype, layout in [ + ("tma_copy_bytes_Q", self.dtype_Q, self.sQ_layout), + ("tma_copy_bytes_K", self.dtype_K, self.sK_layout), + ("tma_copy_bytes_Qvi", self.dtype_Qv, self.sQvi_layout), + ("tma_copy_bytes_Vi", self.dtype_V, self.sVi_layout), + ]: + setattr(self, attr, cute.size_in_bytes(dtype, layout) * self.cta_group_size) + # fmt: on + + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_QK.thr_id.shape,) + ) + cta_shape = cta_layout_vmnk.shape + + def make_tma(make_fn, mX, smem_layout, mma_tiler, tiled_mma): + return make_fn(tma_load_op, mX, smem_layout, mma_tiler, tiled_mma, cta_shape) + + A, B = cute.nvgpu.make_tiled_tma_atom_A, cute.nvgpu.make_tiled_tma_atom_B + + # (atom_name, tensor_name, make_fn, m, 
smem_layout, mma_tiler, tiled_mma, kv_only) + # fmt: off + _tma_specs = [ + ("tma_atom_Q", "tma_tensor_Q", A, mQ, self.sQ_layout, self.mma_tiler_QK, tiled_mma_QK, False), + ("tma_atom_Qv0", "tma_tensor_Qv0", A, mQv0, self.sQvi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, False), + ("tma_atom_Qv1", "tma_tensor_Qv1", A, mQv1, self.sQvi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, False), + ("tma_atom_K", "tma_tensor_K", B, mK, self.sK_layout, self.mma_tiler_QK, tiled_mma_QK, True), + ("tma_atom_V0", "tma_tensor_V0", B, mV0, self.sVi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, True), + ("tma_atom_V1", "tma_tensor_V1", B, mV1, self.sVi_layout, self.mma_tiler_QviVi, tiled_mma_QviVi, True), + ("tma_atom_Vt0", "tma_tensor_Vt0", B, mVt0, self.sVti_layout, self.mma_tiler_PVti, tiled_mma_PVti, True), + ("tma_atom_Vt1", "tma_tensor_Vt1", B, mVt1, self.sVti_layout, self.mma_tiler_PVti, tiled_mma_PVti, True), + ] + _tmas = {} + for atom_name, tensor_name, make_fn, m, smem_layout, mma_tiler, tiled_mma, kv_only in _tma_specs: + _tmas[atom_name], _tmas[tensor_name] = ( + make_tma(make_fn, m, smem_layout, mma_tiler, tiled_mma) + if const_expr(not kv_only or self.use_tma_KV) + else (None, None) + ) + + (tma_atom_Q, tma_tensor_Q, + tma_atom_Qv0, tma_tensor_Qv0, + tma_atom_Qv1, tma_tensor_Qv1, + tma_atom_K, tma_tensor_K, + tma_atom_V0, tma_tensor_V0, + tma_atom_V1, tma_tensor_V1, + tma_atom_Vt0, tma_tensor_Vt0, + tma_atom_Vt1, tma_tensor_Vt1) = _tmas.values() + # fmt: on + + # ==== Set up Oi smem -> gmem tma store ==== + + self.overlap_sO_sV = True + if const_expr(self.overlap_sO_sV): + num_stages_sO = self.num_hdimv_splits * self.num_stages_Vi + else: + num_stages_sO = self.num_hdimv_splits + sO_layout = sm100_utils.make_smem_layout_epi( + self.dtype_O, self.o_layout, self.epi_tile, num_stages_sO + ) + self.ragged_tma_O = ( + self.use_tma_O + and self.is_varlen_q + and self.pack_gqa + and self.cta_tile_m % self.qhead_per_kvhead == 0 + ) + make_tiled_tma_atom_fn = ( + partial(make_packgqa_tiled_tma_atom, qhead_per_kvhead=self.qhead_per_kvhead, head_idx=2) + if const_expr(self.ragged_tma_O) + else cpasync.make_tiled_tma_atom + ) + if const_expr(self.use_tma_O): + mO_tma = mO_og if const_expr(self.ragged_tma_O) else mO + if const_expr(self.ragged_tma_O): + mO_tma = copy_utils.create_ragged_tensor_for_tma( + mO_tma, ragged_dim=0, ptr_shift=True + ) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + tma_atom_O, tma_tensor_O = make_tiled_tma_atom_fn( + tma_store_op, mO_tma, cute.select(sO_layout, mode=[0, 1]), self.epi_tile + ) + else: + tma_atom_O = None + tma_tensor_O = None + + # ==== Set up Oi rmem -> gmem copy ==== + universal_copy_bits = 128 + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype_O, + num_bits_per_copy=universal_copy_bits, + ) + thread_layout_O_r2g = cute.make_layout((64, 2), stride=(1, 64)) + value_layout_O_r2g = cute.make_layout( + (1, self.hdimv // self.num_hdimv_splits // self.cta_group_size) + ) + tiled_copy_O_r2g = cute.make_tiled_copy_tv( + atom=atom_universal_copy, + thr_layout=thread_layout_O_r2g, + val_layout=value_layout_O_r2g, + ) + + # ==== Allocate shared memory ==== + SharedStorage = self._get_shared_storage_cls() + + # ==== Tile scheduler ==== + + TileScheduler = self.TileScheduler + + tile_sched_args = TileSchedulerArguments( + num_block=cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tile_m), + num_head=cute.size(mQ.shape[2]), + num_batch=cute.size(mQ.shape[3]) + if const_expr(mCuSeqlensQ is None) + else 
cute.size(mCuSeqlensQ.shape[0] - 1), + num_splits=1, # todo: split_kv + seqlen_k=cute.size(mK.shape[0]), # todo: page table + headdim=self.hdim, + headdim_v=self.hdimv, + total_q=cute.size(mQ.shape[0]) + if const_expr(mCuSeqlensQ is not None) + else cute.size(mQ.shape[0]) * cute.size(mQ.shape[3]), + tile_shape_mn=( + self.cta_tile_m, + self.tile_n, + ), + mCuSeqlensQ=mCuSeqlensQ, + mSeqUsedQ=mSeqUsedQ, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + element_size=self.dtype_K.width // 8, + is_persistent=self.is_persistent, + # lpt=self.is_causal or self.is_local, + lpt=False, + is_split_kv=False, + cluster_shape_mn=self.cluster_shape_mn, + use_cluster_idx=False, + ) + tile_sched_params = TileScheduler.to_underlying_arguments( + tile_sched_args, scheduling_mode=self.scheduling_mode + ) + self.tile_scheduler_cls = TileScheduler + grid_dim = TileScheduler.get_grid_shape(tile_sched_params) + fa_printf(1, "grid = {}", grid_dim) + + # ==== Named Barrier ==== + self.cpasync_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Cpasync), + num_threads=self.num_cpasync_load_threads, + ) + self.softmax_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Softmax), + num_threads=self.num_softmax_threads, + ) + self.epi_barrier = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.Epilogue), + num_threads=self.num_epilogue_threads, + ) + # softmax -> correction + self.sm_stats_barrier_full = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.SoftmaxStatsFull), + num_threads=self.num_softmax_threads + self.num_epilogue_threads, + ) + self.sm_stats_barrier_empty = cutlass.pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.SoftmaxStatsEmpty), + num_threads=self.num_softmax_threads + self.num_epilogue_threads, + ) + + LOG2_E = math.log2(math.e) + softmax_scale_log2 = softmax_scale * LOG2_E + + # ==== Launch kernel ==== + self.kernel( + tma_tensor_Q, + tma_tensor_Qv0, + tma_tensor_Qv1, + tma_tensor_K if self.use_tma_KV else mK, + tma_tensor_V0 if self.use_tma_KV else mV0, + tma_tensor_V1 if self.use_tma_KV else mV1, + tma_tensor_Vt0 if self.use_tma_KV else mVt0, + tma_tensor_Vt1 if self.use_tma_KV else mVt1, + tma_tensor_O if self.use_tma_O else mO, + mLSE, + mCuSeqlensQ, + mCuSeqlensK, + mSeqUsedQ, + mSeqUsedK, + mIndexTopk, + tma_atom_Q, + tma_atom_Qv0, + tma_atom_Qv1, + tma_atom_K, + tma_atom_V0, + tma_atom_V1, + tma_atom_Vt0, + tma_atom_Vt1, + tma_atom_O, + tiled_copy_O_r2g, + self.sQ_layout_staged, + self.sK_layout_staged, + self.sQvi_layout_staged, + self.sVi_layout_staged, + self.sVti_layout_staged, + self.sP_layout_staged, + self.sStats_layout, + self.sScale_layout, + self.sBitmask_layout, + sO_layout, + tiled_mma_QK, + tiled_mma_QviVi, + tiled_mma_PVti, + softmax_scale, + softmax_scale_log2, + topk_length_dynamic, + tile_sched_params, + SharedStorage, + ).launch( + grid=grid_dim, + block=( + self.num_threads, + 1, + 1, + ), + cluster=self.cluster_shape_mnk, + smem=SharedStorage.size_in_bytes(), + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mQv0: cute.Tensor, + mQv1: cute.Tensor, + mK: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + mCuSeqlensQ: Optional[cute.Tensor], + mCuSeqlensK: Optional[cute.Tensor], + mSeqUsedQ: Optional[cute.Tensor], + mSeqUsedK: Optional[cute.Tensor], + mIndexTopk: Optional[cute.Tensor], 
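+        # TMA copy atoms: the K/V(t) atoms are None when the cp.async gather
+        # path is used (use_cpasync_load_KV), and tma_atom_O is None without
+        # use_tma_O.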
+ tma_atom_Q: cute.CopyAtom, + tma_atom_Qv0: cute.CopyAtom, + tma_atom_Qv1: cute.CopyAtom, + tma_atom_K: Optional[cute.CopyAtom], + tma_atom_V0: Optional[cute.CopyAtom], + tma_atom_V1: Optional[cute.CopyAtom], + tma_atom_Vt0: Optional[cute.CopyAtom], + tma_atom_Vt1: Optional[cute.CopyAtom], + tma_atom_O: Optional[cute.CopyAtom], + tiled_copy_O_r2g: cute.TiledCopy, + sQ_layout_staged: cute.ComposedLayout, + sK_layout_staged: cute.ComposedLayout, + sQvi_layout_staged: cute.ComposedLayout, + sVi_layout_staged: cute.ComposedLayout, + sVti_layout_staged: cute.ComposedLayout, + sP_layout_staged: cute.ComposedLayout, + sStats_layout: cute.Layout, + sScale_layout: cute.Layout, + sBitmask_layout: cute.Layout, + sO_layout: cute.ComposedLayout, + tiled_mma_QK: cute.TiledMma, + tiled_mma_QviVi: cute.TiledMma, + tiled_mma_PVti: cute.TiledMma, + softmax_scale: Float32, + softmax_scale_log2: Float32, + topk_length_dynamic: Optional[Int32], + tile_sched_params: ParamsBase, + SharedStorage: cutlass.Constexpr[Callable], + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + cta_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma_QK.thr_id.shape,) + ) + + cta_m_block, head_idx, batch_idx = cute.arch.block_idx() + cluster_m_block = cta_m_block // self.cta_group_size + mma_tile_coord_v = cta_m_block % cute.size(tiled_mma_QK.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + + # ==== Allocate SMEM ==== + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # ==== TMEM stuff ==== + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierFwdSm100_MLA2CTA.TmemPtr), + num_threads=self.num_mma_threads + self.num_softmax_threads + self.num_epilogue_threads, + ) + tmem = cutlass.utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=tmem_alloc_barrier, + allocator_warp_id=self.mma_warp_id, + is_two_cta=self.use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.mbar_ptr_tmem_dealloc, + ) + + # ==== Prefetch TMA descriptors ==== + if warp_idx == self.load_warp_id: + cpasync.prefetch_descriptor(tma_atom_Q) + cpasync.prefetch_descriptor(tma_atom_Qv0) + cpasync.prefetch_descriptor(tma_atom_Qv1) + if const_expr(self.use_tma_KV): + cpasync.prefetch_descriptor(tma_atom_K) + cpasync.prefetch_descriptor(tma_atom_V0) + cpasync.prefetch_descriptor(tma_atom_V1) + cpasync.prefetch_descriptor(tma_atom_Vt0) + cpasync.prefetch_descriptor(tma_atom_Vt1) + if const_expr(self.use_tma_O): + cpasync.prefetch_descriptor(tma_atom_O) + + # ==== Construct pipelines ==== + tma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + mma_warp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + sm_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_softmax_threads) + epi_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_epilogue_threads) + sm_threads_cluster = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_softmax_threads * self.cta_group_size + ) + epi_threads_cluster = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.num_epilogue_threads * self.cta_group_size + ) + + TmaUmma = pipeline.PipelineTmaUmma + AsyncUmma = pipeline.PipelineAsyncUmma + UmmaAsync = pipeline.PipelineUmmaAsync + Async = pipeline.PipelineAsync + + def make_pipeline(cls, mbar_ptr, num_stages, producer, consumer, tx_count=None): + return cls.create( + barrier_storage=mbar_ptr.data_ptr(), + num_stages=num_stages, + producer_group=producer, + consumer_group=consumer, + defer_sync=True, + 
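+                # cta_layout_vmnk is only passed for the UMMA-coupled pipeline
+                # types; tx_count (the TMA transaction byte count) only when
+                # the caller provides one.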
**({"cta_layout_vmnk": cta_layout_vmnk} if cls is not Async else {}), + **({"tx_count": tx_count} if tx_count is not None else {}), + ) + + # Unconditional pipelines + # fmt: off + pipeline_Q = make_pipeline(TmaUmma, storage.mbar_ptr_Q, self.num_stages_Q, tma_warp, mma_warp, self.tma_copy_bytes_Q) + pipeline_Qv0 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv0, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) + pipeline_Qv1 = make_pipeline(TmaUmma, storage.mbar_ptr_Qv1, self.num_stages_Qvi, tma_warp, mma_warp, self.tma_copy_bytes_Qvi) + pipeline_S = make_pipeline(UmmaAsync, storage.mbar_ptr_S, self.num_stages_S, mma_warp, sm_threads_cluster) + pipeline_P = make_pipeline(AsyncUmma, storage.mbar_ptr_P, self.num_stages_P, sm_threads_cluster, mma_warp) + pipeline_O0 = make_pipeline(UmmaAsync, storage.mbar_ptr_O0, self.num_stages_Oi, mma_warp, epi_threads_cluster) + pipeline_O1 = make_pipeline(UmmaAsync, storage.mbar_ptr_O1, self.num_stages_Oi, mma_warp, epi_threads_cluster) + pipeline_sm_stats = make_pipeline(Async, storage.mbar_ptr_sm_stats, self.num_stages_sm_stats, sm_threads, epi_threads) + + # K/V pipelines: type and producer depend on use_tma_KV + if const_expr(self.use_tma_KV): + pipeline_K = make_pipeline(TmaUmma, storage.mbar_ptr_K, self.num_stages_K, tma_warp, mma_warp, self.tma_copy_bytes_K) + pipeline_V0 = make_pipeline(TmaUmma, storage.mbar_ptr_V0, self.num_stages_Vi, tma_warp, mma_warp, self.tma_copy_bytes_Vi) + pipeline_V1 = make_pipeline(TmaUmma, storage.mbar_ptr_V1, self.num_stages_Vi, tma_warp, mma_warp, self.tma_copy_bytes_Vi) + pipeline_K_cpasync = pipeline_V0_cpasync = pipeline_V1_cpasync = pipeline_bitmask = None + else: + cpasync_load_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_cpasync_load_threads) + relay_warps_cluster = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.cta_group_size) + relay_threads = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_relay_threads) + + pipeline_K = make_pipeline(AsyncUmma, storage.mbar_ptr_K, self.num_stages_K, relay_warps_cluster, mma_warp) + pipeline_V0 = make_pipeline(AsyncUmma, storage.mbar_ptr_V0, self.num_stages_Vi, relay_warps_cluster, mma_warp) + pipeline_V1 = make_pipeline(AsyncUmma, storage.mbar_ptr_V1, self.num_stages_Vi, relay_warps_cluster, mma_warp) + pipeline_K_cpasync = make_pipeline(Async, storage.mbar_ptr_K_cpasync, self.num_stages_K, cpasync_load_threads, relay_threads) + pipeline_V0_cpasync = make_pipeline(Async, storage.mbar_ptr_V0_cpasync, self.num_stages_Vi, cpasync_load_threads, relay_threads) + pipeline_V1_cpasync = make_pipeline(Async, storage.mbar_ptr_V1_cpasync, self.num_stages_Vi, cpasync_load_threads, relay_threads) + pipeline_bitmask = ( + make_pipeline(Async, storage.mbar_ptr_bitmask, self.num_stages_bitmask, cpasync_load_threads, sm_threads) + if const_expr(self.is_topk_gather and not self.disable_bitmask) else None + ) + # fmt: on + + sO_empty_mbar_ptr = None + if const_expr(self.use_tma_O and self.overlap_sO_sV): + sO_empty_mbar_ptr = storage.sO_empty_mbar_ptr + if warp_idx == 0: + cute.arch.mbarrier_init(sO_empty_mbar_ptr, 1) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True) + + # ==== Get SMEM tensors ==== + # fmt: off + sQ, sK, sQv0, sQv1, sV0, sV1, sVt0, sVt1, sP = ( + store.get_tensor(layout.outer, swizzle=layout.inner) + for store, layout in [ + (storage.sQ, sQ_layout_staged), + (storage.sK, sK_layout_staged), + (storage.sQv0, sQvi_layout_staged), + (storage.sQv1, sQvi_layout_staged), + (storage.sV0, 
sVi_layout_staged), + (storage.sV1, sVi_layout_staged), + (storage.sV0, sVti_layout_staged), # sVt0 reuses sV0 storage + (storage.sV1, sVti_layout_staged), # sVt1 reuses sV1 storage + (storage.sP, sP_layout_staged), + ] + ) + # fmt: on + sRowMax = storage.sRowMax.get_tensor(sStats_layout) + sRowSum = storage.sRowSum.get_tensor(sStats_layout) + sScale = storage.sScale.get_tensor(sScale_layout) + sBitmask = None + if const_expr(self.is_topk_gather): + sBitmask = storage.sBitmask.get_tensor(sBitmask_layout) + + if const_expr(self.overlap_sO_sV): + sO_iterator = sV0.iterator + assert cute.cosize(sO_layout) <= cute.cosize(sVi_layout_staged) * self.num_hdimv_splits + else: + sO_iterator = sQv0.iterator + assert cute.cosize(sO_layout) <= cute.cosize(sQvi_layout_staged) * self.num_hdimv_splits + sO = cute.make_tensor( + cute.recast_ptr(sO_iterator, sO_layout.inner, self.dtype_O), sO_layout.outer + ) + + # ==== Get thread MMAs and accumulator fragments ==== + thr_mma_QK = tiled_mma_QK.get_slice(mma_tile_coord_v) + thr_mma_QviVi = tiled_mma_QviVi.get_slice(mma_tile_coord_v) + thr_mma_PVti = tiled_mma_PVti.get_slice(mma_tile_coord_v) + + acc_shape_QK = thr_mma_QK.partition_shape_C(self.mma_tiler_QK[:2]) + tStS = thr_mma_QK.make_fragment_C(cute.append(acc_shape_QK, self.num_stages_S)) + + acc_shape_PVi = thr_mma_PVti.partition_shape_C(self.mma_tiler_PVti[:2]) + tO0tO0 = thr_mma_PVti.make_fragment_C(acc_shape_PVi) + tO1tO1 = thr_mma_PVti.make_fragment_C(acc_shape_PVi) + tO0tO0 = cute.make_tensor(tO0tO0.iterator + self.tmem_offset_O0, tO0tO0.layout) + tO1tO1 = cute.make_tensor(tO1tO1.iterator + self.tmem_offset_O1, tO1tO1.layout) + + block_info = BlockInfo( + self.cta_tile_m * self.cta_group_size, + self.tile_n, + is_causal=self.is_causal, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + SeqlenInfoCls = partial( + SeqlenInfoQK.create, + seqlen_q_static=mQ.shape[0] if const_expr(not self.pack_gqa) else mQ.shape[0][1], + seqlen_k_static=mK.shape[0], + tile_m=self.cta_tile_m, + tile_n=self.tile_n, + mCuSeqlensQ=mCuSeqlensQ, + mCuSeqlensK=mCuSeqlensK, + mSeqUsedQ=mSeqUsedQ, + mSeqUsedK=mSeqUsedK, + ) + AttentionMaskCls = partial( + AttentionMask, + self.cta_tile_m * self.cta_group_size, + self.tile_n, + qhead_per_kvhead_packgqa=self.qhead_per_kvhead if const_expr(self.pack_gqa) else 1, + ) + + if const_expr(self.use_clc_scheduler): + clc_response_ptr = storage.clc_response.data_ptr() + clc_mbar_ptr = storage.clc_mbar_ptr.data_ptr() + + clc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_clc_consumer_warps_per_cta = self.num_threads // cute.arch.WARP_SIZE + num_clc_consumer_warps = num_clc_consumer_warps_per_cta * self.cta_group_size + clc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, cute.arch.WARP_SIZE * num_clc_consumer_warps + ) + clc = ClcState.create( + hw_scheduler=ClcDynamicPersistentTileScheduler.create( + self.tile_scheduler_cls.clc_problem_shape(tile_sched_params), + cute.arch.block_idx(), + cute.arch.grid_dim(), + clc_response_ptr, + ), + pipeline=pipeline.PipelineClcFetchAsync.create( + barrier_storage=clc_mbar_ptr, + num_stages=self.sched_stages, + producer_group=clc_pipeline_producer_group, + consumer_group=clc_pipeline_consumer_group, + tx_count=16, + cta_layout_vmnk=cta_layout_vmnk, + ), + consumer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.sched_stages + ), + producer_state=pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, 
self.sched_stages + ), + ) + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params, clc=clc) + else: + tile_scheduler = self.tile_scheduler_cls.create(tile_sched_params) + assert isinstance(tile_scheduler, TileSchedulerProtocol), ( + f"tile_scheduler is not a TileSchedulerProtocol: {type(tile_scheduler)}" + ) + + pipeline.pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk) + + if const_expr(self.use_clc_scheduler): + if warp_idx == self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + if is_leader_cta: + self.clc_scheduler_warp(tile_scheduler) + else: + self.empty_warp(tile_scheduler) + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i] and warp_idx != self.clc_scheduler_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_other) + self.empty_warp(tile_scheduler) + else: + for i in cutlass.range_constexpr(len(self.empty_warp_ids)): + if warp_idx == self.empty_warp_ids[i]: + cute.arch.setmaxregister_decrease(self.num_regs_other) + + if const_expr(self.use_cpasync_load_KV): + if warp_idx == self.relay_warp_id: + if const_expr(self.num_regs_load < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_load) + self.relay( + pipeline_K, + pipeline_V0, + pipeline_V1, + pipeline_K_cpasync, + pipeline_V0_cpasync, + pipeline_V1_cpasync, + sO_empty_mbar_ptr, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx in self.cpasync_load_warp_indices: + if const_expr(self.num_regs_cpasync < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_cpasync) + self.load_cpasync( + mIndexTopk, + mK, + mV0, + mV1, + mVt0, + mVt1, + sK, + sV0, + sV1, + sVt0, + sVt1, + sBitmask, + pipeline_K, + pipeline_V0, + pipeline_V1, + pipeline_K_cpasync, + pipeline_V0_cpasync, + pipeline_V1_cpasync, + pipeline_bitmask, + sO_empty_mbar_ptr, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx == self.load_warp_id: + if const_expr(self.num_regs_load < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_load) + self.load( + mQ, + mK, + mQv0, + mQv1, + mV0, + mV1, + mVt0, + mVt1, + sQ, + sK, + sQv0, + sQv1, + sV0, + sV1, + sVt0, + sVt1, + tma_atom_Q, + tma_atom_K, + tma_atom_Qv0, + tma_atom_Qv1, + tma_atom_V0, + tma_atom_V1, + tma_atom_Vt0, + tma_atom_Vt1, + pipeline_Q, + pipeline_K, + pipeline_Qv0, + pipeline_Qv1, + pipeline_V0, + pipeline_V1, + sO_empty_mbar_ptr, + thr_mma_QK, + thr_mma_QviVi, + thr_mma_PVti, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + + if warp_idx == self.mma_warp_id: + if const_expr(self.num_regs_mma < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_mma) + # ==== Allocate TMEM ==== + tmem.allocate(self.tmem_alloc_cols) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.mma( + sQ, + sK, + sQv0, + sQv1, + sV0, + sV1, + sVt0, + sVt1, + sP, + tiled_mma_QK, + tiled_mma_QviVi, + tiled_mma_PVti, + pipeline_Q, + pipeline_K, + pipeline_Qv0, + pipeline_Qv1, + pipeline_V0, + pipeline_V1, + pipeline_S, + pipeline_P, + pipeline_O0, + pipeline_O1, + sO_empty_mbar_ptr, + is_leader_cta, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem.relinquish_alloc_permit() + tmem_alloc_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + + if warp_idx in self.softmax_warp_indices: + 
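+            # Softmax warps are the one group whose register cap is raised
+            # rather than lowered (num_regs_softmax, e.g. 168 + 80 in the
+            # 12-warp config): the softmax mainloop keeps score fragments and
+            # row statistics resident in registers.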
cute.arch.setmaxregister_increase(self.num_regs_softmax) + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.softmax_loop( + softmax_scale, + softmax_scale_log2, + mLSE, + sRowMax, + sRowSum, + sScale, + sBitmask, + sP, + tStS, + thr_mma_QK, + pipeline_S, + pipeline_P, + pipeline_sm_stats, + pipeline_bitmask, + sO_empty_mbar_ptr, + AttentionMaskCls, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem_alloc_barrier.arrive() + + if warp_idx in self.epilogue_warp_indices: + if const_expr(self.num_regs_epilogue < self.num_regs_per_thread): + cute.arch.setmaxregister_decrease(self.num_regs_epilogue) + elif const_expr(self.num_regs_epilogue > self.num_regs_per_thread): + cute.arch.setmaxregister_increase(self.num_regs_epilogue) + + tmem.wait_for_alloc() + tmem_ptr = tmem.retrieve_ptr(self.dtype_acc) + self.correction_loop( + softmax_scale_log2, + mO, + mLSE, + tma_atom_O, + sRowMax, + sRowSum, + sScale, + sO, + tO0tO0, + tO1tO1, + pipeline_O0, + pipeline_O1, + pipeline_sm_stats, + sO_empty_mbar_ptr, + tiled_copy_O_r2g, + topk_length_dynamic, + block_info, + SeqlenInfoCls, + tile_scheduler=tile_scheduler, + ) + tmem_alloc_barrier.arrive() + + @cute.jit + def clc_scheduler_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + tile_scheduler.prefetch_next_work() + work_tile = tile_scheduler.advance_to_next_work() + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + if cute.arch.thread_idx()[0] == self.clc_scheduler_warp_id * cute.arch.WARP_SIZE: + fa_printf( + 3, + "[CLC] query sm={} cta={} (m_blk={},h={},b={},s={}) valid={}\n", + smid(), + cute.arch.block_idx()[0], + work_tile.tile_idx[0], + work_tile.tile_idx[1], + work_tile.tile_idx[2], + work_tile.tile_idx[3], + work_tile.is_valid_tile, + ) + tile_scheduler.producer_tail() + + @cute.jit + def empty_warp( + self, + tile_scheduler: TileSchedulerProtocol, + ): + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def relay( + self, + pipeline_K: pipeline.PipelineAsyncUmma, + pipeline_V0: pipeline.PipelineAsyncUmma, + pipeline_V1: pipeline.PipelineAsyncUmma, + pipeline_K_cpasync: pipeline.PipelineAsync, + pipeline_V0_cpasync: pipeline.PipelineAsync, + pipeline_V1_cpasync: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== Make pipeline states ==== + # pipeline_{K,V0,V1} producer + # pipeline_{K,V0,V1}_cpasync consumer + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + consumer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_K + ) + consumer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + consumer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + relay_K_fn = partial(self.relay_inner, pipeline_K_cpasync, pipeline_K) + relay_V0_fn = 
partial(self.relay_inner, pipeline_V0_cpasync, pipeline_V0) + relay_V1_fn = partial(self.relay_inner, pipeline_V1_cpasync, pipeline_V1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + + # ==== Prologue ==== + # relay K, V0, V1 + consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) + consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) + consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) + + # ==== Mainloop ==== + for _ in cutlass.range(num_n_blocks - 1, unroll=2): + # relay K, V0, V1, Vt0, Vt1 + consumer_state_K, producer_state_K = relay_K_fn(consumer_state_K, producer_state_K) + for _ in cutlass.range_constexpr(2): + consumer_state_V0, producer_state_V0 = relay_V0_fn( + consumer_state_V0, producer_state_V0 + ) + consumer_state_V1, producer_state_V1 = relay_V1_fn( + consumer_state_V1, producer_state_V1 + ) + + # ==== Epilogue === + # relay Vt0, Vt1 + consumer_state_V0, producer_state_V0 = relay_V0_fn(consumer_state_V0, producer_state_V0) + consumer_state_V1, producer_state_V1 = relay_V1_fn(consumer_state_V1, producer_state_V1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_K.producer_tail(producer_state_K) + pipeline_V0.producer_tail(producer_state_V0) + pipeline_V1.producer_tail(producer_state_V1) + + @cute.jit + def relay_inner( + self, + pipeline_cpasync: pipeline.PipelineAsync, + pipeline_mma: pipeline.PipelineAsyncUmma, + consumer_state: pipeline.PipelineState, + producer_state: pipeline.PipelineState, + ): + pipeline_cpasync.consumer_wait(consumer_state) + with cute.arch.elect_one(): + pipeline_mma.producer_commit(producer_state) + consumer_state.advance() + producer_state.advance() + return consumer_state, producer_state + + @cute.jit + def load_cpasync( + self, + mIndexTopk: cute.Tensor, + mK: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + sK: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + sBitmask: Optional[cute.Tensor], + pipeline_K: pipeline.PipelineAsyncUmma, + pipeline_V0: pipeline.PipelineAsyncUmma, + pipeline_V1: pipeline.PipelineAsyncUmma, + pipeline_K_cpasync: pipeline.PipelineAsync, + pipeline_V0_cpasync: pipeline.PipelineAsync, + pipeline_V1_cpasync: pipeline.PipelineAsync, + pipeline_bitmask: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== cpasync load warpgroup ==== + # Description: loads tiles of K, V, V0, V1 from gmem to smem using cpasync + # produces: K, V, V0, V1, bitmask + # consumes: - + + # TODO: use cpasync for non-topk paged attn + assert sBitmask is not None, "cpasync load meant to be used with topk gather" + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + tidx = cute.arch.thread_idx()[0] % self.num_cpasync_load_threads + 
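+        # Thread/warp indices are taken modulo the cpasync warpgroup size,
+        # giving ids that are local to this warpgroup (presumably so the gather
+        # manager can split the topk index rows across its own warps).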
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_cpasync_load_threads // 32 + ) + + # ==== Make pipeline states ==== + # producer: acquire PipelineAsyncUmma <- mma + # producer: commit PipelineAsync -> relay + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, + stages=self.num_stages_bitmask, + ) + if const_expr(self.use_tma_O): + producer_phase_O = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + # cluster_m_block == m_idx under MQA 128 assumption + m_idx = cluster_m_block + if const_expr(not seqlen.has_cu_seqlens_q): + mIndexTopk_cur = mIndexTopk[None, m_idx, batch_idx] + else: + offset_q = seqlen.offset_q + mIndexTopk_cur = mIndexTopk[None, m_idx + offset_q] + + if const_expr(self.is_causal): + seqlen_k_limit = m_idx + 1 + seqlen.seqlen_k - seqlen.seqlen_q + else: + seqlen_k_limit = seqlen.seqlen_k + cpasync_gather_kv_manager = CpasyncGatherKVManager.create( + mIndexTopk_cur, + sBitmask, + cta_rank_in_cluster, + tidx, + warp_idx, + self.topk_length, + seqlen_k_limit, + self.tile_n, + self.hdim, + self.hdimv, + self.num_hdimv_splits, + self.num_cpasync_load_threads, + mK.element_type, + self.cta_group_size, + pipeline_bitmask, + self.num_stages_bitmask, + self.cpasync_barrier, + self.disable_bitmask, + ) + + # (seqlen_k, hdim) or (seqlen_k, hdimv//2) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mV0_cur = seqlen.offset_batch_K(mV0, batch_idx, dim=3)[None, None, head_idx_kv] + mV1_cur = seqlen.offset_batch_K(mV1, batch_idx, dim=3)[None, None, head_idx_kv] + # (hdimv//2, seqlen_k) + if const_expr(not seqlen.has_cu_seqlens_k): + mVt0_cur = mVt0[None, None, head_idx_kv, batch_idx] + mVt1_cur = mVt1[None, None, head_idx_kv, batch_idx] + else: + mVt0_cur = cute.domain_offset((0, seqlen.offset_k), mVt0[None, None, head_idx_kv]) + mVt1_cur = cute.domain_offset((0, seqlen.offset_k), mVt1[None, None, head_idx_kv]) + # (hdimv//4, seqlen_k) + hdimv_split_per_cta = self.hdimv // self.num_hdimv_splits // self.cta_group_size + mVt0_cur = cute.tiled_divide(mVt0_cur, (hdimv_split_per_cta,))[ + None, cta_rank_in_cluster, None + ] + mVt1_cur = cute.tiled_divide(mVt1_cur, (hdimv_split_per_cta,))[ + None, cta_rank_in_cluster, None + ] + + load_K = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_K, + pipeline_K_cpasync, + sK, + False, + "K", + mK_cur, + ) + load_V0 = partial( + self.cpasync_gather_load_KV, + 
cpasync_gather_kv_manager, + pipeline_V0, + pipeline_V0_cpasync, + sV0, + False, + "V", + mV0_cur, + ) + load_V1 = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V1, + pipeline_V1_cpasync, + sV1, + False, + "V", + mV1_cur, + ) + load_Vt0 = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V0, + pipeline_V0_cpasync, + sVt0, + True, + "V", + mVt0_cur, + ) + load_Vt1 = partial( + self.cpasync_gather_load_KV, + cpasync_gather_kv_manager, + pipeline_V1, + pipeline_V1_cpasync, + sVt1, + True, + "V", + mVt1_cur, + ) + + # gather KV path processes n_blocks in increasing order + n_block = 0 + + # ==== Prologue ==== + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + + if const_expr(self.use_tma_O and self.overlap_sO_sV): + cute.arch.mbarrier_wait(sO_empty_mbar_ptr, phase=producer_phase_O) + producer_phase_O ^= 1 + + # ==== Mainloop ==== + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = n_block_group * self.num_stages_S + stage + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block + 1, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + # Vt0, Vt1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=True) + producer_state_V0 = load_Vt0(producer_state_V0) + producer_state_V1 = load_Vt1(producer_state_V1) + + # ==== Epilogue ==== + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = (num_n_block_groups - 1) * self.num_stages_S + stage + if const_expr(stage == 0): + # K, V0, V1 + cpasync_gather_kv_manager.load_index_topk(n_block + 1, transpose=False) + producer_state_K = load_K(producer_state_K) + producer_state_V0 = load_V0(producer_state_V0) + producer_state_V1 = load_V1(producer_state_V1) + if const_expr(not self.disable_bitmask): + producer_state_bitmask = cpasync_gather_kv_manager.compute_bitmask( + producer_state_bitmask + ) + + # Vt0, Vt1 + cpasync_gather_kv_manager.load_index_topk(n_block, transpose=True) + producer_state_V0 = load_Vt0(producer_state_V0) + producer_state_V1 = load_Vt1(producer_state_V1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_K_cpasync.producer_tail(producer_state_K) + pipeline_V0_cpasync.producer_tail(producer_state_V0) + pipeline_V1_cpasync.producer_tail(producer_state_V1) + if const_expr(not self.disable_bitmask): + pipeline_bitmask.producer_tail(producer_state_bitmask) + + @cute.jit + def cpasync_gather_load_KV( + self, + cpasync_gather_kv_manager: CpasyncGatherKVManager, + pipeline_mma: pipeline.PipelineAsyncUmma, + pipeline_cpasync: pipeline.PipelineAsync, + sX: cute.Tensor, + transpose: bool, + K_or_V: str, + mX: cute.Tensor, + producer_state: pipeline.PipelineState, + ): + stage, phase = producer_state.index, producer_state.phase + pipeline_mma.producer_acquire(producer_state) + cpasync_gather_kv_manager.load_X(mX, sX[None, None, None, stage], transpose, K_or_V) + 
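+        # The cp.async copies issued by load_X are committed as one group below,
+        # and the arrive on the cpasync pipeline's mbarrier lets the relay warp
+        # observe completion before it commits the corresponding UMMA stage.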
cute.arch.cp_async_commit_group() + pipeline_cpasync.sync_object_full.arrive_cp_async_mbarrier(stage) + producer_state.advance() + return producer_state + + @cute.jit + def load( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mQv0: cute.Tensor, + mQv1: cute.Tensor, + mV0: cute.Tensor, + mV1: cute.Tensor, + mVt0: cute.Tensor, + mVt1: cute.Tensor, + sQ: cute.Tensor, + sK: cute.Tensor, + sQv0: cute.Tensor, + sQv1: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: cute.Tensor, + sVt1: cute.Tensor, + tma_atom_Q: cute.CopyAtom, + tma_atom_K: cute.CopyAtom, + tma_atom_Qv0: cute.CopyAtom, + tma_atom_Qv1: cute.CopyAtom, + tma_atom_V0: cute.CopyAtom, + tma_atom_V1: cute.CopyAtom, + tma_atom_Vt0: cute.CopyAtom, + tma_atom_Vt1: cute.CopyAtom, + pipeline_Q: pipeline.PipelineAsync, + pipeline_K: pipeline.PipelineAsync, + pipeline_Qv0: pipeline.PipelineAsync, + pipeline_Qv1: pipeline.PipelineAsync, + pipeline_V0: pipeline.PipelineAsync, + pipeline_V1: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + thr_mma_QK: cute.ThrMma, + thr_mma_QviVi: cute.ThrMma, + thr_mma_PVti: cute.ThrMma, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== Load warp ==== + # Description: loads tiles of Q, Qv, K, V, V0, V1 from gmem to smem using TMA + # produces: Q, Qv, K, V, V0, V1 + # consumes: - + + mQvs = [mQv0, mQv1] + mVs = [mV0, mV1] + mVts = [mVt0, mVt1] + + sQvs = [sQv0, sQv1] + sVs = [sV0, sV1] + sVts = [sVt0, sVt1] + + tma_atom_Qvs = [tma_atom_Qv0, tma_atom_Qv1] + tma_atom_Vs = [tma_atom_V0, tma_atom_V1] + tma_atom_Vts = [tma_atom_Vt0, tma_atom_Vt1] + + pipeline_Qvs = [pipeline_Qv0, pipeline_Qv1] + pipeline_Vs = [pipeline_V0, pipeline_V1] + + # ==== Make pipeline states ==== + producer_state_Q = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Q + ) + producer_state_Qv0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Qvi + ) + producer_state_Qv1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Qvi + ) + if const_expr(self.use_tma_KV): + producer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_K + ) + producer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + producer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Vi + ) + if const_expr(self.use_tma_O): + producer_phase_O = Int32(1) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + head_idx_kv = ( + head_idx // self.qhead_per_kvhead if const_expr(not self.pack_gqa) else head_idx + ) + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + # ==== Partition GMEM tensors ==== + # (seqlen_q, hdim or hdimv//2) + mQ_cur = seqlen.offset_batch_Q(mQ, batch_idx, dim=3)[None, None, head_idx] + 
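+            # Qv (and V/Vt) are split along hdimv into num_hdimv_splits slices,
+            # each with its own gmem view, TMA atom, and pipeline (the *0/*1
+            # pairs threaded through this method).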
mQvs_cur = [ + seqlen.offset_batch_Q(mQvs[split], batch_idx, dim=3)[None, None, head_idx] + for split in range(self.num_hdimv_splits) + ] + # (mma_tile_m, hdim or hdimv//2) + gQ = cute.local_tile( + mQ_cur, + (self.mma_tiler_QK[0], self.mma_tiler_QK[2]), + (cluster_m_block, 0), + ) + gQvs = [ + cute.local_tile( + mQvs_cur[split], + (self.mma_tiler_QviVi[0], self.mma_tiler_QviVi[2]), + (cluster_m_block, 0), + ) + for split in range(self.num_hdimv_splits) + ] + tSgQ = thr_mma_QK.partition_A(gQ) + tSgQvs = [ + thr_mma_QviVi.partition_A(gQvs[split]) for split in range(self.num_hdimv_splits) + ] + tQsQ, tQgQ = cpasync.tma_partition( + atom=tma_atom_Q, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sQ, 0, 3), + gmem_tensor=cute.group_modes(tSgQ, 0, 3), + ) + tQvsQvs, tQvgQvs = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sQv, 0, 3), + gmem_tensor=cute.group_modes(tSgQv, 0, 3), + ) + for tma_atom, sQv, tSgQv in zip(tma_atom_Qvs, sQvs, tSgQvs) + ] + ) + + if const_expr(self.use_tma_KV): + # (seqlen_k, hdim) or (seqlen_k, hdimv//2) + mK_cur = seqlen.offset_batch_K(mK, batch_idx, dim=3)[None, None, head_idx_kv] + mVs_cur = [ + seqlen.offset_batch_K(mVs[split], batch_idx, dim=3)[None, None, head_idx_kv] + for split in range(self.num_hdimv_splits) + ] + # (hdimv//2, seqlen_k) + if const_expr(not seqlen.has_cu_seqlens_k): + mVts_cur = [ + mVts[split][None, None, head_idx_kv, batch_idx] + for split in range(self.num_hdimv_splits) + ] + else: + mVts_cur = [ + cute.domain_offset( + (0, seqlen.offset_k), mVts[split][None, None, head_idx_kv] + ) + for split in range(self.num_hdimv_splits) + ] + # (tile_n, hdim or hdimv//2, num_n_blocks) + gK = cute.local_tile( + mK_cur, + (self.mma_tiler_QK[1], self.mma_tiler_QK[2]), + (None, 0), + ) + gVs = [ + cute.local_tile( + mVs_cur[split], + (self.mma_tiler_QviVi[1], self.mma_tiler_QviVi[2]), + (None, 0), + ) + for split in range(self.num_hdimv_splits) + ] + # (hdim or hdimv//2, tile_n, num_n_blocks) + gVts = [ + cute.local_tile( + mVts_cur[split], + (self.mma_tiler_PVti[1], self.mma_tiler_PVti[2]), + (0, None), + ) + for split in range(self.num_hdimv_splits) + ] + tSgK = thr_mma_QK.partition_B(gK) + tSgVs = [ + thr_mma_QviVi.partition_B(gVs[split]) for split in range(self.num_hdimv_splits) + ] + tOgVts = [ + thr_mma_PVti.partition_B(gVts[split]) for split in range(self.num_hdimv_splits) + ] + tKsK, tKgK = cpasync.tma_partition( + atom=tma_atom_K, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sK, 0, 3), + gmem_tensor=cute.group_modes(tSgK, 0, 3), + ) + tVsVs, tVgVs = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sV, 0, 3), + gmem_tensor=cute.group_modes(tSgV, 0, 3), + ) + for tma_atom, sV, tSgV in zip(tma_atom_Vs, sVs, tSgVs) + ] + ) + tVtsVts, tVtgVts = zip( + *[ + cpasync.tma_partition( + atom=tma_atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + smem_tensor=cute.group_modes(sVt, 0, 3), + gmem_tensor=cute.group_modes(tOgV, 0, 3), + ) + for tma_atom, sVt, tOgV in zip(tma_atom_Vts, sVts, tOgVts) + ] + ) + + load_Q = partial(self.load_inner, tma_atom_Q, tQgQ, tQsQ, pipeline_Q) + load_Qv = partial(self.load_inner, tma_atom_Qvs, tQvgQvs, tQvsQvs, pipeline_Qvs) + if const_expr(self.use_tma_KV): + load_K = partial(self.load_inner, tma_atom_K, tKgK, tKsK, pipeline_K) + load_V = partial(self.load_inner, tma_atom_Vs, tVgVs, tVsVs, 
pipeline_Vs) + load_Vt = partial(self.load_inner, tma_atom_Vts, tVtgVts, tVtsVts, pipeline_Vs) + + # ==== Load stationary operands ==== + + # copy Q, Qvi gmem -> smem + producer_state_Q = load_Q(producer_state_Q) + producer_state_Qv0 = load_Qv(producer_state_Qv0, split=0) + producer_state_Qv1 = load_Qv(producer_state_Qv1, split=1) + + if const_expr(self.use_tma_KV): + # ==== Prologue ==== + n_block_first = n_block_max - 1 + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block_first) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block_first, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block_first, split=1) + + if const_expr(self.use_tma_O and self.overlap_sO_sV): + cute.arch.mbarrier_wait(sO_empty_mbar_ptr, phase=producer_phase_O) + producer_phase_O ^= 1 + + # ==== Main loop ==== + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = n_block_max - 1 - n_block_group * self.num_stages_S - stage + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block - 1) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block - 1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block - 1, split=1) + # copy Vti gmem -> smem + producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) + producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) + + # ==== Epilogue ==== + num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 + for stage in cutlass.range(num_final_n_blocks, unroll_full=True): + n_block = num_final_n_blocks - 1 - stage + if n_block > 0: + # copy K gmem -> smem + producer_state_K = load_K(producer_state_K, n_block=n_block - 1) + # copy Vi gmem -> smem + producer_state_V0 = load_V(producer_state_V0, n_block=n_block - 1, split=0) + producer_state_V1 = load_V(producer_state_V1, n_block=n_block - 1, split=1) + # copy Vti gmem -> smem + producer_state_V0 = load_Vt(producer_state_V0, n_block=n_block, split=0) + producer_state_V1 = load_Vt(producer_state_V1, n_block=n_block, split=1) + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + pipeline_Q.producer_tail(producer_state_Q) + pipeline_Qv0.producer_tail(producer_state_Qv0) + pipeline_Qv1.producer_tail(producer_state_Qv1) + if const_expr(self.use_tma_KV): + pipeline_K.producer_tail(producer_state_K) + pipeline_V0.producer_tail(producer_state_V0) + pipeline_V1.producer_tail(producer_state_V1) + + @cute.jit + def load_inner( + self, + tma_atom: cute.CopyAtom, + tXgX: cute.Tensor, + tXsX: cute.Tensor, + load_pipeline: pipeline.PipelineAsync, + producer_state: pipeline.PipelineState, + n_block: Optional[Int32] = None, + split: Optional[Int32] = None, + ): + stage = producer_state.index + if const_expr(split is not None): + tma_atom = tma_atom[split] + tXgX = tXgX[split] + tXsX = tXsX[split] + load_pipeline = load_pipeline[split] + if const_expr(n_block is not None): + tXgX = tXgX[(None, n_block)] + tXsX = tXsX[(None, stage)] + + load_pipeline.producer_acquire(producer_state) + tma_bar_ptr = load_pipeline.producer_get_barrier(producer_state) + cute.copy(tma_atom, tXgX, tXsX, tma_bar_ptr=tma_bar_ptr) + producer_state.advance() + return producer_state + + @cute.jit + def mma( + self, + sQ: cute.Tensor, + sK: cute.Tensor, + sQv0: cute.Tensor, + sQv1: cute.Tensor, + sV0: cute.Tensor, + sV1: cute.Tensor, + sVt0: 
cute.Tensor, + sVt1: cute.Tensor, + sP: cute.Tensor, + tiled_mma_QK: cute.TiledMma, + tiled_mma_QviVi: cute.TiledMma, + tiled_mma_PVti: cute.TiledMma, + pipeline_Q: pipeline.PipelineAsync, + pipeline_K: pipeline.PipelineAsync, + pipeline_Qv0: pipeline.PipelineAsync, + pipeline_Qv1: pipeline.PipelineAsync, + pipeline_V0: pipeline.PipelineAsync, + pipeline_V1: pipeline.PipelineAsync, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_O0: pipeline.PipelineAsync, + pipeline_O1: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + is_leader_cta: Boolean, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== mma warp ==== + # Description: Computes Q @ K^T, Qv @ V^T, and P @ V + # Produces: S, O + # Consumes: Q, K, Qv, V, P + + pipelines_V = [pipeline_V0, pipeline_V1] + pipelines_Qv = [pipeline_Qv0, pipeline_Qv1] + pipelines_O = [pipeline_O0, pipeline_O1] + + sQvs = [sQv0, sQv1] + sVs = [sV0, sV1] + sVts = [sVt0, sVt1] + + # Set accumulate = True for Qv @ V^T since we are accumulating on the Q @ K^T result + tiled_mma_QviVi.set(tcgen05.Field.ACCUMULATE, True) + + # Operands for S = Q @ K^T + tSrQ = tiled_mma_QK.make_fragment_A(sQ) + tSrK = tiled_mma_QK.make_fragment_B(sK) + + # Operands for S += Qv @ V^T + tSrQvs = [ + tiled_mma_QviVi.make_fragment_A(sQvs[split]) for split in range(self.num_hdimv_splits) + ] + tSrVs = [ + tiled_mma_QviVi.make_fragment_B(sVs[split]) for split in range(self.num_hdimv_splits) + ] + + # Operands for Oi = P @ Vi + tOrP = tiled_mma_PVti.make_fragment_A(sP) + tOrVts = [ + tiled_mma_PVti.make_fragment_B(sVts[split]) for split in range(self.num_hdimv_splits) + ] + + # GEMM functions + gemm_QK = [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_QK.op, + self.tmem_offset_S[stage], + tCrA=tSrQ[None, None, None, 0], + sA=sQ[None, None, None, 0], + zero_init=True, + cta_group=self.cta_group_size, + ) + for stage in range(self.num_stages_S) + ] + gemms_QvV = [ + [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_QviVi.op, + self.tmem_offset_S[stage], + tCrA=tSrQvs[split][None, None, None, 0], + sA=sQvs[split][None, None, None, 0], + zero_init=False, + cta_group=self.cta_group_size, + ) + for stage in range(self.num_stages_S) + ] + for split in range(self.num_hdimv_splits) + ] + gemms_PVt = [ + partial( + fa_sm100_utils.gemm_ptx_partial, + tiled_mma_PVti.op, + self.tmem_offsets_O[split], + tOrP[None, None, None, 0], + sA=sP[None, None, None, 0], + cta_group=self.cta_group_size, + ) + for split in range(self.num_hdimv_splits) + ] + + consumer_state_Q = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Q + ) + consumer_state_K = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_K + ) + consumer_state_Qv0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Qvi + ) + consumer_state_Qv1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Qvi + ) + consumer_state_V0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + consumer_state_V1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Vi + ) + producer_state_S = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_S + ) + consumer_state_P = pipeline.make_pipeline_state( + 
pipeline.PipelineUserType.Consumer, stages=self.num_stages_P + ) + producer_state_O0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Oi + ) + producer_state_O1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_Oi + ) + + mma_QK = partial(self.mma_inner, gemm_QK, pipeline_K, tSrK, sK) + mma_QvV = partial(self.mma_inner, gemms_QvV, pipelines_V, tSrVs, sVs) + mma_PVt = partial(self.mma_inner, gemms_PVt, pipelines_V, tOrVts, sVts) + + work_tile = tile_scheduler.initial_work_tile_info() + O_should_accumulate = False + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + # n_block_max = self.topk_length // self.tile_n + n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + if is_leader_cta: + pipeline_Q.consumer_wait(consumer_state_Q) + pipeline_Qv0.consumer_wait(consumer_state_Qv0) + pipeline_Qv1.consumer_wait(consumer_state_Qv1) + + consumer_states_V = [consumer_state_V0, consumer_state_V1] + producer_states_O = [producer_state_O0, producer_state_O1] + + # ==== Prologue ==== + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=0) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=0, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + + # ==== Mainloop ==== + for _ in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + next_stage = const_expr((stage + 1) % self.num_stages_S) + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=next_stage) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=next_stage, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + # Oi += P @ Vi + pipeline_P.consumer_wait(consumer_state_P) + for split in cutlass.range_constexpr(self.num_hdimv_splits): + producer_state_Oi = producer_states_O[split] + pipelines_O[split].producer_acquire(producer_state_Oi) + consumer_states_V[split] = mma_PVt( + consumer_states_V[split], + split=split, + zero_init=not O_should_accumulate, + ) + pipelines_O[split].producer_commit(producer_state_Oi) + producer_state_Oi.advance() + producer_states_O[split] = producer_state_Oi + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + O_should_accumulate = True + + # ==== Epilogue ==== + num_final_n_blocks = self.num_stages_S if even_n_blocks else self.num_stages_S - 1 + for stage in cutlass.range_constexpr(self.num_stages_S): + n_block = num_final_n_blocks - 1 - stage + if const_expr(stage == 0): + if n_block > 0: + pipeline_S.producer_acquire(producer_state_S) + # S = Q @ K^T + consumer_state_K = mma_QK(consumer_state_K, stage=stage + 1) + # S += Qvi @ Vi^T + for split in cutlass.range_constexpr(self.num_hdimv_splits): 
+ consumer_states_V[split] = mma_QvV( + consumer_states_V[split], stage=stage + 1, split=split + ) + pipeline_S.producer_commit(producer_state_S) + producer_state_S.advance() + if n_block >= 0: + # Oi += P @ Vi + pipeline_P.consumer_wait(consumer_state_P) + for split in cutlass.range_constexpr(self.num_hdimv_splits): + producer_state_Oi = producer_states_O[split] + pipelines_O[split].producer_acquire(producer_state_Oi) + consumer_states_V[split] = mma_PVt( + consumer_states_V[split], + split=split, + zero_init=not O_should_accumulate, + ) + pipelines_O[split].producer_commit(producer_state_Oi) + producer_state_Oi.advance() + producer_states_O[split] = producer_state_Oi + pipeline_P.consumer_release(consumer_state_P) + consumer_state_P.advance() + O_should_accumulate = True + + consumer_state_V0, consumer_state_V1 = consumer_states_V + producer_state_O0, producer_state_O1 = producer_states_O + + pipeline_Q.consumer_release(consumer_state_Q) + + # if we overlap sOi with sQvi for tma store, need to acquire signal + if const_expr(self.use_tma_O and not self.overlap_sO_sV): + pipeline_O0.producer_tail(producer_state_O0.clone()) + pipeline_O1.producer_tail(producer_state_O1.clone()) + + pipeline_Qv0.consumer_release(consumer_state_Qv0) + pipeline_Qv1.consumer_release(consumer_state_Qv1) + consumer_state_Q.advance() + consumer_state_Qv0.advance() + consumer_state_Qv1.advance() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + O_should_accumulate = False + + pipeline_S.producer_tail(producer_state_S) + pipeline_O0.producer_tail(producer_state_O0) + pipeline_O1.producer_tail(producer_state_O1) + + @cute.jit + def mma_inner( + self, + gemm, + load_pipeline, + tCrB, + sB, + consumer_state: pipeline.PipelineState, + stage: Optional[Int32] = None, + split: Optional[Int32] = None, + zero_init: Optional[bool] = None, + ): + if const_expr(split is not None): + gemm = gemm[split] + load_pipeline = load_pipeline[split] + tCrB = tCrB[split] + sB = sB[split] + if const_expr(stage is not None): + gemm = gemm[stage] + + smem_stage = consumer_state.index + tCrB_cur = tCrB[None, None, None, smem_stage] + sB_cur = sB[None, None, None, smem_stage] + + load_pipeline.consumer_wait(consumer_state) + if const_expr(zero_init is not None): + gemm(tCrB=tCrB_cur, sB=sB_cur, zero_init=zero_init) + else: + gemm(tCrB=tCrB_cur, sB=sB_cur) + load_pipeline.consumer_release(consumer_state) + consumer_state.advance() + return consumer_state + + @cute.jit + def softmax_loop( + self, + softmax_scale: Float32, + softmax_scale_log2: Float32, + mLSE: Optional[cute.Tensor], + sRowMax: cute.Tensor, + sRowSum: cute.Tensor, + sScale: cute.Tensor, + sBitmask: Optional[cute.Tensor], + sP: cute.Tensor, + tStS: cute.Tensor, + thr_mma_QK: cute.ThrMma, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + pipeline_bitmask: Optional[pipeline.PipelineAsync], + sO_empty_mbar_ptr: Optional[cute.Pointer], + AttentionMaskCls: Callable, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + # ==== softmax warpgroup ==== + # Description: computes softmax on S and writes the result to P + # Produces: P, softmax stats + # Consumes: S, bitmask (for topk sparsity) + + tidx = cute.arch.thread_idx()[0] % self.num_softmax_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_softmax_threads // 32 + ) + + tSAcc = tStS[(None, None), 0, 0, 0] + 
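+        # tSAcc views stage 0 of S in TMEM and fixes the copy-atom shapes;
+        # tSAcc_staged below keeps one view per S stage so each softmax step
+        # can read its own stage directly.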
tSAcc_staged = [tStS[(None, None), 0, 0, stage] for stage in range(self.num_stages_S)] + + cS = cute.make_identity_tensor(self.mma_tiler_QK[:2]) # (128, 128) + tScS = thr_mma_QK.partition_C(cS)[(None, None), 0, 0] # (64, 128) + + # S tmem -> rmem copy objects + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.dtype_acc, + ) + tmem_load_tiled = tcgen05.make_tmem_copy(tmem_load_atom, tSAcc) + tmem_load_thr = tmem_load_tiled.get_slice(tidx) + # S tmem -> rmem copy operands + tStS_t2r = tmem_load_thr.partition_S(tSAcc) # (((32, 32), 1), 1, 2) + tStS_t2r_staged = [ + tmem_load_thr.partition_S(tSAcc_staged[stage]) for stage in range(self.num_stages_S) + ] + tScS_t2r = tmem_load_thr.partition_D(tScS) + tSrS_t2r = cute.make_rmem_tensor(tScS_t2r.shape, self.dtype_acc) + + # P rmem -> smem copy objects + universal_copy_bits = 128 + smem_store_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.dtype_P, + num_bits_per_copy=universal_copy_bits, + ) + smem_store_tiled = cute.make_tiled_copy_D(smem_store_atom, tmem_load_tiled) + smem_store_thr = smem_store_tiled.get_slice(tidx) + # P rmem -> smem copy operands + sP_slice = sP[None, None, None, 0] + sP_mn = cute.make_tensor( + sP_slice.iterator, + cute.make_layout( + ( + (sP_slice.shape[0][0], sP_slice.shape[1]), + (sP_slice.shape[0][1], sP_slice.shape[2]), + ), + stride=( + (sP_slice.stride[0][0], sP_slice.stride[1]), + (sP_slice.stride[0][1], sP_slice.stride[2]), + ), + ), + ) + sP_smem_view = smem_store_thr.partition_D(sP_mn) + + consumer_state_S = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_S + ) + producer_state_P = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_P + ) + producer_state_sm_stats = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, stages=self.num_stages_sm_stats + ) + consumer_state_bitmask = None + if const_expr(self.is_topk_gather and not self.disable_bitmask): + consumer_state_bitmask = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_bitmask + ) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + even_n_blocks = num_n_blocks % 2 == 0 and num_n_blocks > 0 + num_n_block_groups = cute.ceil_div(num_n_blocks, self.num_stages_S) + + mask = AttentionMaskCls(seqlen) + mask_fn = partial( + mask.apply_mask_sm100, + m_block=cluster_m_block, + thr_mma=thr_mma_QK, + thr_tmem_load=tmem_load_thr, + mask_causal=self.is_causal, + mask_local=self.is_local, + batch_idx=batch_idx, + head_idx=head_idx, + r2p=False, # TODO: fix r2p for 2cta + ) + disable_mask = self.disable_bitmask and self.is_topk_gather + + softmax = SoftmaxSm100.create( + softmax_scale_log2, + rescale_threshold=8.0 if const_expr(self.dtype_Q.width == 16) else 0.0, + softmax_scale=softmax_scale, + ) + softmax.reset() + + softmax_step_fn = partial( + self.softmax_step, + softmax, + sRowMax, + sScale, + sBitmask, + tStS_t2r_staged, + tSrS_t2r, + sP_smem_view, + tmem_load_thr, + 
smem_store_thr, + pipeline_S, + pipeline_P, + pipeline_sm_stats, + pipeline_bitmask, + tidx, + warp_idx, + ) + + ### first iteration ### + n_block = n_block_max - 1 + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 0, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=True) + if not const_expr(disable_mask) + else None, + is_first=True, + ) + n_block -= 1 + + ### Separate iterations with causal masking + # note: For square mma tile, can mask at most 1 n_block_group + if const_expr((self.is_causal or self.is_local) and not self.is_topk_gather): + n_block_min_causal_local_mask = block_info.get_n_block_min_causal_local_mask( + seqlen, cluster_m_block, n_block_min + ) + num_masked_n_blocks = n_block_max - 1 - n_block_min_causal_local_mask + num_masked_n_block_groups = min( + num_n_block_groups - 1, cute.ceil_div(num_masked_n_blocks, self.num_stages_S) + ) + num_n_block_groups -= num_masked_n_block_groups + for _ in cutlass.range(num_masked_n_block_groups, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1 - stage, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False), + ) + n_block -= 1 + + ### Mainloop ### + for n_block_group in cutlass.range(num_n_block_groups - 1, unroll=1): + for stage in cutlass.range_constexpr(self.num_stages_S): + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1 - stage, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False) + if const_expr(self.is_topk_gather and not self.disable_bitmask) + else None, + ) + n_block -= 1 + + ### last iteration if even ### + # always mask to simplify logic + if even_n_blocks: + ( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + ) = softmax_step_fn( + consumer_state_S, + producer_state_P, + producer_state_sm_stats, + consumer_state_bitmask, + 1, + n_block, + mask_fn=partial(mask_fn, mask_seqlen=False) + if not const_expr(disable_mask) + else None, + ) + n_block -= 1 + + # write row max and sum to smem + sRowSum[tidx % self.cta_tile_m, warp_idx // self.cta_group_size] = softmax.row_sum[0] + if const_expr(mLSE is not None): + if tidx < self.cta_tile_m: + sRowMax[tidx, 0] = softmax.row_max[0] + self.sm_stats_barrier_full.arrive() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + self.sm_stats_barrier_empty.arrive_and_wait() + + pipeline_P.producer_tail(producer_state_P) + pipeline_sm_stats.producer_tail(producer_state_sm_stats) + + @cute.jit + def softmax_step( + self, + softmax: SoftmaxSm100, + sRowMax: cute.Tensor, + sScale: cute.Tensor, + sBitmask: Optional[cute.Tensor], + tStS_t2r_staged: cute.Tensor, + tSrS_t2r: cute.Tensor, + sP_smem_view: cute.Tensor, + tmem_load_thr: cute.CopyAtom, + smem_store_thr: cute.CopyAtom, + pipeline_S: pipeline.PipelineAsync, + pipeline_P: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + pipeline_bitmask: Optional[pipeline.PipelineAsync], + tidx: Int32, + warp_idx: Int32, + consumer_state_S: pipeline.PipelineState, + producer_state_P: 
pipeline.PipelineState, + producer_state_sm_stats: pipeline.PipelineState, + consumer_state_bitmask: Optional[pipeline.PipelineState], + stage: cutlass.Constexpr[Int32], + n_block: Int32, + mask_fn: Optional[Callable] = None, + is_first: Boolean = False, + ): + tSrP = cute.make_rmem_tensor(tSrS_t2r.shape, self.dtype_P) + rP_smem_view = smem_store_thr.retile(tSrP) + + pipeline_S.consumer_wait(consumer_state_S) + cute.copy(tmem_load_thr, tStS_t2r_staged[stage], tSrS_t2r) + cute.arch.fence_view_async_tmem_load() + pipeline_S.consumer_release(consumer_state_S) + + rBitmask = None + if const_expr(self.is_topk_gather and not self.disable_bitmask): + assert pipeline_bitmask is not None + assert consumer_state_bitmask is not None + pipeline_bitmask.consumer_wait(consumer_state_bitmask) + rBitmask = cute.make_rmem_tensor((self.tile_n // 64,), dtype=Uint32) + bitmask_col_offset = self.tile_n // 64 if warp_idx >= 2 else 0 + for i in cutlass.range_constexpr(cute.size(rBitmask)): + rBitmask[i] = sBitmask[bitmask_col_offset + i, consumer_state_bitmask.index] + + if const_expr(mask_fn is not None): + mask_fn(tSrS_t2r, n_block=n_block, rBitmask=rBitmask) + + # compute threadwise row_max + row_max = softmax.compute_row_max_local(tSrS_t2r.load(), is_first) + self.softmax_barrier.arrive_and_wait() + + # 2-thread reduce row_max through smem + assert self.cta_tile_m * self.cta_group_size == 128 + sRowMax[tidx % self.cta_tile_m, warp_idx // self.cta_group_size] = row_max + self.softmax_barrier.arrive_and_wait() + # must release after barrier sync + if const_expr(self.is_topk_gather and not self.disable_bitmask): + pipeline_bitmask.consumer_release(consumer_state_bitmask) + row_max0 = sRowMax[tidx % self.cta_tile_m, 0] + row_max1 = sRowMax[tidx % self.cta_tile_m, 1] + row_max = max(row_max0, row_max1) + + row_max, acc_scale = softmax.update_row_max_from_local(row_max, is_first) + + # note: acc_scales agree for paired threads + pipeline_sm_stats.producer_acquire(producer_state_sm_stats) + if warp_idx < self.cta_group_size: + sScale[tidx % self.cta_tile_m, producer_state_sm_stats.index] = acc_scale + pipeline_sm_stats.producer_commit(producer_state_sm_stats) + + # x -> scale_log2*x-rowmax + softmax.scale_subtract_rowmax(tSrS_t2r, row_max) + + # x -> exp2(x) + softmax.apply_exp2_convert(tSrS_t2r, tSrP) + + pipeline_P.producer_acquire(producer_state_P) + cute.copy(smem_store_thr, rP_smem_view, sP_smem_view) + cute.arch.fence_view_async_shared() + pipeline_P.producer_commit(producer_state_P) + + consumer_state_S.advance() + producer_state_P.advance() + producer_state_sm_stats.advance() + if const_expr(self.is_topk_gather and not self.disable_bitmask): + consumer_state_bitmask.advance() + + softmax.update_row_sum(tSrS_t2r.load(), acc_scale, is_first) + + return consumer_state_S, producer_state_P, producer_state_sm_stats, consumer_state_bitmask + + @cute.jit + def correction_loop( + self, + softmax_scale_log2: Float32, + mO: cute.Tensor, + mLSE: Optional[cute.Tensor], + tma_atom_O: Optional[cute.CopyAtom], + sRowMax: cute.Tensor, + sRowSum: cute.Tensor, + sScale: cute.Tensor, + sO: cute.Tensor, + tO0tO0: cute.Tensor, + tO1tO1: cute.Tensor, + pipeline_O0: pipeline.PipelineAsync, + pipeline_O1: pipeline.PipelineAsync, + pipeline_sm_stats: pipeline.PipelineAsync, + sO_empty_mbar_ptr: Optional[cute.Pointer], + tiled_copy_O_r2g: cute.TiledCopy, + topk_length_dynamic: Optional[Int32], + block_info: BlockInfo, + SeqlenInfoCls: Callable, + tile_scheduler: TileSchedulerProtocol, + ): + ### ==== correction/epilogue warpgroup 
==== + # Correction: copy scale smem -> rmem, copy O tmem -> rmem, rescale O, store O rmem -> tmem + # Epilogue: copy O tmem -> rmem, do final scaling of O, store O rmem -> gmem, + # optionally store LSE + # Produces: - + # Consumes: O, softmax stats + + tidx = cute.arch.thread_idx()[0] % self.num_epilogue_threads + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) % ( + self.num_epilogue_threads // 32 + ) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + leader_warp = warp_idx == 0 + + tO0tO0 = tO0tO0[(None, None), 0, 0] # (64, (128, 2)) + tO1tO1 = tO1tO1[(None, None), 0, 0] # (64, (128, 2)) + tOtOs = [tO0tO0, tO1tO1] + + # tuneable parameter + corr_tile_size = math.gcd(32, self.tmem_cols_Oi) + + tmem_load_atom_O = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.dtype_acc, + ) + tmem_store_atom_O = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.dtype_acc, + ) + thr_tmem_load_O = tcgen05.make_tmem_copy(tmem_load_atom_O, tO0tO0).get_slice(tidx) + thr_tmem_store_O = tcgen05.make_tmem_copy(tmem_store_atom_O, tO0tO0).get_slice(tidx) + + # ((32,1),1,4) + tOtOs_t2r = [ + thr_tmem_load_O.partition_S(tOtOs[split]) for split in range(self.num_hdimv_splits) + ] + tOtOs_r2t = [ + thr_tmem_store_O.partition_D(tOtOs[split]) for split in range(self.num_hdimv_splits) + ] + + cOi = cute.make_identity_tensor((self.cta_tile_m, self.hdimv // self.num_hdimv_splits)) + thr_tiled_copy_O_r2g = tiled_copy_O_r2g.get_slice(tidx) + tOicOi = thr_tiled_copy_O_r2g.partition_S(cOi) + + tOicOi_t2r = thr_tmem_load_O.partition_D(tOicOi[(None, None), 0, 0]) + + pipelines_O = [pipeline_O0, pipeline_O1] + + consumer_state_O0 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Oi + ) + consumer_state_O1 = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_Oi + ) + consumer_state_sm_stats = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, stages=self.num_stages_sm_stats + ) + + do_correction_rescale = partial( + self.correction_rescale, + thr_tmem_load_O, + thr_tmem_store_O, + tOicOi_t2r, + ) + + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + cta_m_block, head_idx, batch_idx, _ = work_tile.tile_idx + cluster_m_block = cta_m_block // self.cta_group_size + + seqlen = SeqlenInfoCls(batch_idx) + if const_expr(self.is_topk_gather): + n_block_min = 0 + n_block_max = self.topk_length // self.tile_n + # n_block_max = topk_length_dynamic // self.tile_n + else: + n_block_min, n_block_max = block_info.get_n_block_min_max( + seqlen, + cluster_m_block, + ) + num_n_blocks = n_block_max - n_block_min + + consumer_states_O = [consumer_state_O0, consumer_state_O1] + + # acquire first signal and release immediately + pipeline_sm_stats.consumer_wait(consumer_state_sm_stats) + pipeline_sm_stats.consumer_release(consumer_state_sm_stats) + consumer_state_sm_stats.advance() + + for _ in cutlass.range(num_n_blocks - 1, unroll=1): + pipeline_sm_stats.consumer_wait(consumer_state_sm_stats) + scale = sScale[tidx % self.cta_tile_m, consumer_state_sm_stats.index] + should_rescale = cute.arch.vote_ballot_sync(scale < 1.0) != 0 + pipeline_sm_stats.consumer_release(consumer_state_sm_stats) + consumer_state_sm_stats.advance() + + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_state_Oi = consumer_states_O[split] + 
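+                    # vote_ballot_sync above lets the whole warp agree on
+                    # should_rescale, so the TMEM round-trip in
+                    # correction_rescale is skipped whenever no lane saw a
+                    # scale below 1.0.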
pipelines_O[split].consumer_wait(consumer_state_Oi) + if should_rescale: + do_correction_rescale( + tOtOs_t2r[split], + tOtOs_r2t[split], + scale, + ) + pipelines_O[split].consumer_release(consumer_state_Oi) + consumer_state_Oi.advance() + consumer_states_O[split] = consumer_state_Oi + + # (seqlen_q, hdimv) + mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3, ragged=self.ragged_tma_O)[ + None, None, head_idx + ] + # (cta_tile_m, hdimv//2, 2) + gO = cute.local_tile( + mO_cur, + (self.cta_tile_m, self.hdimv // self.num_hdimv_splits), + (cta_m_block, None), + ) + tOgO = thr_tiled_copy_O_r2g.partition_D(gO) + # ((32, 1), 1, 4) + tOrOs_t2r = [ + cute.make_rmem_tensor(tOicOi_t2r.shape, self.dtype_acc) + for split in range(self.num_hdimv_splits) + ] + tOrOs_r2g_f32 = [ + thr_tiled_copy_O_r2g.retile(tOrOs_t2r[split]) + for split in range(self.num_hdimv_splits) + ] + tOrOs_r2g = [ + cute.make_rmem_tensor_like(tOrOs_r2g_f32[split], self.dtype_O) + for split in range(self.num_hdimv_splits) + ] + if const_expr(self.use_tma_O): + tOsO = thr_tiled_copy_O_r2g.partition_D(sO) + store_O, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_O, + 0, + cute.make_layout(1), + sO, + gO, + ) + + self.sm_stats_barrier_full.arrive_and_wait() + + row_sum0 = sRowSum[tidx % self.cta_tile_m, 0] + row_sum1 = sRowSum[tidx % self.cta_tile_m, 1] + row_sum = row_sum0 + row_sum1 + acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + + self.sm_stats_barrier_empty.arrive() + + seqlen_q = ( + seqlen.seqlen_q + if const_expr(not self.pack_gqa) + else seqlen.seqlen_q * self.qhead_per_kvhead + ) + + # compute and store lse to gmem + if const_expr(mLSE is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mLSE_cur = mLSE[None, head_idx, batch_idx] + else: + lse_offset = ( + seqlen.offset_q if const_expr(not self.pack_gqa) else (0, seqlen.offset_q) + ) + mLSE_cur = cute.domain_offset((lse_offset,), mLSE[None, head_idx]) + gLSE = cute.local_tile(mLSE_cur, (self.cta_tile_m,), (cta_m_block,)) + if tidx < self.cta_tile_m: + row_max = sRowMax[tidx, 0] + LN2 = math.log(2.0) + lse = ( + (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) + * LN2 + if not acc_O_mn_row_is_zero_or_nan + else -Float32.inf + ) + if tidx < seqlen_q - cta_m_block * self.cta_tile_m: + gLSE[tidx] = lse + + row_idx = cta_m_block * self.cta_tile_m + tOicOi[0][0] + + for split in cutlass.range_constexpr(self.num_hdimv_splits): + consumer_state_Oi = consumer_states_O[split] + pipelines_O[split].consumer_wait(consumer_state_Oi) + # copy Oi tmem -> rmem + cute.copy( + thr_tmem_load_O, + tOtOs_t2r[split], + tOrOs_t2r[split], + ) + + # scale and downcast Oi + tOrOs_r2g[split].store((tOrOs_r2g_f32[split].load() * scale).to(self.dtype_O)) + + if const_expr(not self.use_tma_O): + # copy Oi rmem -> gmem + if row_idx < seqlen_q: + cute.copy( + thr_tiled_copy_O_r2g, + tOrOs_r2g[split], + tOgO[None, None, None, split], + ) + else: + # copy Oi rmem -> smem + if const_expr(self.overlap_sO_sV): + # last slot for Vti is always 1, 3 + sO_idx = 1 + 2 * split + else: + sO_idx = split + cute.copy( + thr_tiled_copy_O_r2g, + tOrOs_r2g[split], + tOsO[None, None, None, sO_idx], + ) + cute.arch.fence_view_async_shared() + self.epi_barrier.arrive_and_wait() + # tma store Oi smem -> gmem + if leader_warp: + store_O(src_idx=sO_idx, dst_idx=split) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(1 - split, read=True) + if const_expr(split == 1 
and self.overlap_sO_sV): + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(sO_empty_mbar_ptr) + + consumer_state_O0, consumer_state_O1 = consumer_states_O + + cute.arch.fence_view_async_tmem_load() + pipeline_O0.consumer_release(consumer_state_O0) + pipeline_O1.consumer_release(consumer_state_O1) + consumer_state_O0.advance() + consumer_state_O1.advance() + + # Advance to next tile + work_tile = tile_scheduler.advance_to_next_work() + + @cute.jit + def correction_rescale( + self, + thr_tmem_load: cute.CopyAtom, + thr_tmem_store: cute.CopyAtom, + tOcO_t2r: cute.Tensor, + tOtO_t2r: cute.Tensor, + tOtO_r2t: cute.Tensor, + scale: Float32, + ): + tOrO_t2r_frg = cute.make_rmem_tensor_like(tOcO_t2r[None, None, 0], self.dtype_acc) + + for i in cutlass.range_constexpr(cute.size(tOtO_t2r, mode=[2])): + tOtO_t2r_cur = tOtO_t2r[None, None, i] + tOtO_r2t_cur = tOtO_r2t[None, None, i] + + cute.copy(thr_tmem_load, tOtO_t2r_cur, tOrO_t2r_frg) + for j in cutlass.range(0, cute.size(tOrO_t2r_frg), 2, unroll_full=True): + tOrO_t2r_frg[j], tOrO_t2r_frg[j + 1] = cute.arch.mul_packed_f32x2( + (tOrO_t2r_frg[j], tOrO_t2r_frg[j + 1]), (scale, scale) + ) + cute.copy(thr_tmem_store, tOrO_t2r_frg, tOtO_r2t_cur) + cute.arch.fence_view_async_tmem_store() + + +def test_mla_kernel( + seqlen_q=2048, + seqlen_k=2048, + topk_length=2048, + nheads=1, + batch=1, + iter=0, + compile_cache=dict(), + validate=True, + seed=0, + gather_kv=True, + pack_gqa=False, + is_causal=False, + varlen_q=False, + varlen_k=False, + disable_bitmask=False, +): + torch.manual_seed(seed) + hdim = 64 + hdimv = 512 + softmax_scale = 1.0 / math.sqrt(hdim + hdimv) + + nheads_kv = 1 + qhead_per_kvhead = nheads + + compile_key = ( + is_causal, + gather_kv, + topk_length if gather_kv else None, + pack_gqa, + qhead_per_kvhead, + nheads_kv, + varlen_q, + varlen_k, + disable_bitmask, + ) + if compile_key not in compile_cache: + total_q_dummy = batch * seqlen_q + total_k_dummy = batch * seqlen_k + + if varlen_q: + Q = torch.randn(total_q_dummy, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q_dummy, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q_dummy, dtype=torch.float32, device="cuda") + index_topk = ( + torch.rand(total_q_dummy, topk_length, device="cuda") + .argsort(dim=-1) + .to(torch.int32) + ) + cu_seqlens_q_dummy = torch.arange( + 0, (batch + 1) * seqlen_q, seqlen_q, dtype=torch.int32, device="cuda" + ) + else: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda") + .argsort(dim=-1) + .to(torch.int32) + ) + + if varlen_k: + K = torch.randn(total_k_dummy, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(total_k_dummy, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + cu_seqlens_k_dummy = torch.arange( + 0, (batch + 1) * seqlen_k, seqlen_k, dtype=torch.int32, device="cuda" + ) + else: + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + + mQ = from_dlpack(Q, 
assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + mLSE = from_dlpack(lse, assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) + else: + mIndexTopk = None + + compile_kwargs = dict(mIndexTopk=mIndexTopk) + if varlen_q: + compile_kwargs["mCuSeqlensQ"] = from_dlpack(cu_seqlens_q_dummy, assumed_align=4) + if varlen_k: + compile_kwargs["mCuSeqlensK"] = from_dlpack(cu_seqlens_k_dummy, assumed_align=4) + + kernel = cute.compile( + FlashAttentionMLAForwardSm100( + is_causal=is_causal, + use_cpasync_load_KV=gather_kv, + topk_length=topk_length if gather_kv else 2048, + is_topk_gather=gather_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=nheads_kv, + is_varlen_q=varlen_q, + disable_bitmask=disable_bitmask, + ), + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, + **compile_kwargs, + options="--keep-ptx --keep-cubin --generate-line-info", + ) + dump_kernel_attributes(kernel) + compile_cache[compile_key] = kernel + + # ================================================================ + # ---- Generate variable seqlens for this run ---- + if varlen_q: + torch.manual_seed(seed + 1000) + # When causal without varlen_k, every per-batch seqlen_q must not exceed seqlen_k. + max_seqlen_q = seqlen_k if (is_causal and not varlen_k) else seqlen_q + seqlens_q = torch.randint(1, max_seqlen_q + 1, (batch,), dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = seqlens_q.cumsum(0).to(torch.int32).cuda() + total_q = cu_seqlens_q[-1].item() + else: + seqlens_q = torch.full((batch,), seqlen_q, dtype=torch.int32) + total_q = None # unused + + if varlen_k: + torch.manual_seed(seed + 2000) + # Each batch item must have at least topk_length keys so topk gather is valid. + min_seqlen_k = topk_length if gather_kv else 1 + seqlens_k = torch.randint(min_seqlen_k, seqlen_k + 1, (batch,), dtype=torch.int32) + # When causal, every batch item needs seqlens_k[b] >= seqlens_q[b]. 
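+        # (torch.maximum below clamps seqlens_k up to seqlens_q elementwise.)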
+ if is_causal: + seqlens_k = torch.maximum(seqlens_k, seqlens_q) + cu_seqlens_k = torch.zeros(batch + 1, dtype=torch.int32, device="cuda") + cu_seqlens_k[1:] = seqlens_k.cumsum(0).to(torch.int32).cuda() + total_k = cu_seqlens_k[-1].item() + else: + seqlens_k = torch.full((batch,), seqlen_k, dtype=torch.int32) + total_k = None # unused + + torch.manual_seed(seed) # restore main seed before drawing actual tensors + + # ---- Allocate Q / Qv / O / lse ---- + if varlen_q: + Q = torch.randn(total_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(total_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(nheads, total_q, dtype=torch.float32, device="cuda") + else: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + lse = torch.empty(batch, nheads, seqlen_q, dtype=torch.float32, device="cuda") + + # ---- Allocate K / V ---- + if varlen_k: + K = torch.randn(total_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(total_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + else: + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + + # ---- Generate index_topk with per-batch valid ranges when varlen_k ---- + # index_topk shape: (total_q, topk_length) if varlen_q else (batch, seqlen_q, topk_length) + if gather_kv: + topk_parts = [] + for b in range(batch): + sl_q_b = seqlens_q[b].item() + sl_k_b = seqlens_k[b].item() + # Draw topk_length unique indices from [0, sl_k_b) for each query in this batch item. 
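+            # (argsort of i.i.d. uniform noise yields a random permutation of
+            # [0, sl_k_b), so the first topk_length entries are distinct.)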
+            topk_b = (
+                torch.rand(sl_q_b, sl_k_b, device="cuda")
+                .argsort(dim=-1)[..., :topk_length]
+                .to(torch.int32)
+            )  # (sl_q_b, topk_length), all < sl_k_b
+            topk_parts.append(topk_b)
+
+        if varlen_q:
+            index_topk = torch.cat(topk_parts, dim=0)  # (total_q, topk_length)
+        else:
+            index_topk = torch.stack(topk_parts, dim=0)  # (batch, seqlen_q, topk_length)
+    else:
+        index_topk = None
+
+    # ---- Reference computation (per-batch loop covers all four varlen combos) ----
+    O_ref_list, O_pt_list, lse_ref_list, lse_pt_list = [], [], [], []
+    for b in range(batch):
+        qs = cu_seqlens_q[b].item() if varlen_q else b * seqlen_q
+        qe = cu_seqlens_q[b + 1].item() if varlen_q else (b + 1) * seqlen_q
+        ks = cu_seqlens_k[b].item() if varlen_k else b * seqlen_k
+        ke = cu_seqlens_k[b + 1].item() if varlen_k else (b + 1) * seqlen_k
+
+        Q_b = Q[qs:qe].unsqueeze(0) if varlen_q else Q[b : b + 1]  # (1, sl_q, nheads, hdim)
+        Qv_b = Qv[qs:qe].unsqueeze(0) if varlen_q else Qv[b : b + 1]  # (1, sl_q, nheads, hdimv)
+        K_b = K[ks:ke].unsqueeze(0) if varlen_k else K[b : b + 1]  # (1, sl_k, nheads_kv, hdim)
+        V_b = V[ks:ke].unsqueeze(0) if varlen_k else V[b : b + 1]  # (1, sl_k, nheads_kv, hdimv)
+        if gather_kv:
+            topk_b = index_topk[qs:qe].unsqueeze(0) if varlen_q else index_topk[b : b + 1]
+        else:
+            topk_b = None
+
+        O_b, _, lse_b = attention_ref(
+            Q_b, K_b, V_b, qv=Qv_b, causal=is_causal, return_lse=True, gather_kv_indices=topk_b
+        )
+        O_pt_b, _, lse_pt_b = attention_ref(
+            Q_b,
+            K_b,
+            V_b,
+            qv=Qv_b,
+            causal=is_causal,
+            upcast=False,
+            reorder_ops=True,
+            return_lse=True,
+            gather_kv_indices=topk_b,
+        )
+        O_ref_list.append(O_b.squeeze(0))
+        O_pt_list.append(O_pt_b.squeeze(0))
+        lse_ref_list.append(lse_b.squeeze(0))
+        lse_pt_list.append(lse_pt_b.squeeze(0))
+
+    # Concatenate along the leading token dim for varlen, or stack a new batch dim
+    # otherwise; LSE is (nheads, total_q) for varlen vs (batch, nheads, seqlen_q).
+    if varlen_q:
+        O_ref = torch.cat(O_ref_list, dim=0)  # (total_q, nheads, hdimv)
+        O_pt = torch.cat(O_pt_list, dim=0)
+        lse_ref = torch.cat(lse_ref_list, dim=-1)  # (nheads, total_q)
+        lse_pt = torch.cat(lse_pt_list, dim=-1)
+    else:
+        O_ref = torch.stack(O_ref_list, dim=0)  # (batch, seqlen_q, nheads, hdimv)
+        O_pt = torch.stack(O_pt_list, dim=0)
+        lse_ref = torch.stack(lse_ref_list, dim=0)  # (batch, nheads, seqlen_q)
+        lse_pt = torch.stack(lse_pt_list, dim=0)
+
+    rtol = 2
+    # (O_ref + 0.3 - 0.3 - O_ref) isolates pure floating-point rounding noise at this
+    # magnitude, so atol tracks the numerical floor instead of a hard-coded constant.
+    atol = 2 * (O_ref + 0.3 - 0.3 - O_ref).abs().max().item()
+
+    # ---- CuTe tensor wrappers ----
+    mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1)
+    mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1)
+    mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1)
+    mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1)
+    mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1)
+    mLSE = from_dlpack(lse, assumed_align=4).mark_layout_dynamic(leading_dim=lse.ndim - 1)
+    if index_topk is not None:
+        mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic(
+            leading_dim=index_topk.ndim - 1
+        )
+    else:
+        mIndexTopk = None
+
+    run_kwargs = dict(mIndexTopk=mIndexTopk)
+    if varlen_q:
+        run_kwargs["mCuSeqlensQ"] = from_dlpack(cu_seqlens_q, assumed_align=4)
+    if varlen_k:
+        run_kwargs["mCuSeqlensK"] = from_dlpack(cu_seqlens_k, assumed_align=4)
+
+    # ---- Run kernel ----
+    compile_cache[compile_key](
+        mQ,
+        mQv,
+        mK,
+        mV,
+        mO,
+        mLSE,
+        softmax_scale,
+        **run_kwargs,
+    )
+
+    print(f"Pytorch max O diff: {(O_pt
- O_ref).abs().max().item()}") + print(f"Pytorch mean O diff: {(O_pt - O_ref).abs().mean().item()}") + print(f"Max abs diff O, O_ref: {(O - O_ref).abs().max().item()}") + print(f"Mean abs diff O, O_ref: {(O - O_ref).abs().mean().item()}") + + # print(f"Pytorch LSE max diff: {(lse_pt - lse_ref).abs().max().item()}") + # print(f"Pytorch LSE mean diff: {(lse_pt - lse_ref).abs().mean().item()}") + # print(f"Max abs diff LSE: {(lse - lse_ref).abs().max().item()}") + # print(f"Mean abs diff LSE: {(lse - lse_ref).abs().mean().item()}") + + if validate: + assert (O - O_ref).abs().max().item() <= rtol * (O_pt - O_ref).abs().max().item() + atol + varlen_tag = "" + if varlen_q: + varlen_tag += f", total_q:{total_q}" + if varlen_k: + varlen_tag += f", total_k:{total_k}" + print( + f"batch:{batch:3d}, nheads:{nheads:3d}, seqlen_q:{seqlen_q:5d}, seqlen_k:{seqlen_k:5d}" + f"{varlen_tag}, iter:{iter:2d} PASSED" + ) + else: + print(mO) + print( + f"batch:{batch:3d}, nheads:{nheads:3d}, seqlen_q:{seqlen_q:5d}, seqlen_k:{seqlen_k:5d}" + f", iter:{iter:2d} RUN (NOT TESTING CORRECTNESS)" + ) + + return None + + +def timeit(fn, *args, **kwargs): + # Synchronize before timing + torch.cuda.synchronize() + + # Warmup + for _ in range(10): + fn(*args, **kwargs) + + # Benchmark using PyTorch's Timer + t = benchmark.Timer( + stmt="fn(*args, **kwargs)", globals={"fn": fn, "args": args, "kwargs": kwargs} + ) + + # Time it multiple runs + measurement = t.timeit(20) # 20 repeats + avg_time = measurement.mean # Average time in seconds + + time.sleep(1) + + return avg_time + + +def benchmark_mla_kernel( + batch=1, + seqlen_q=2048, + seqlen_k=2048, + topk_length=2048, + nheads=128, + hdim=64, + hdimv=512, + compile_cache=dict(), + gather_kv=True, + is_causal=False, + disable_bitmask=False, +): + assert hdim == 64, "hdim must be 64" + assert hdimv == 512, "hdimv must be 512" + + qhead_per_kvhead = nheads + nheads_kv = 1 + pack_gqa = True + softmax_scale = 1.0 / math.sqrt(hdim + hdimv) + + compile_key = ( + is_causal, + gather_kv, + topk_length if gather_kv else None, + pack_gqa, + qhead_per_kvhead, + nheads_kv, + disable_bitmask, + ) + if compile_key not in compile_cache: + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + ) + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) + else: + mIndexTopk = None + + mLSE = None + + kernel = cute.compile( + FlashAttentionMLAForwardSm100( + is_causal=is_causal, + use_cpasync_load_KV=gather_kv, + topk_length=topk_length if gather_kv else 2048, + is_topk_gather=gather_kv, + pack_gqa=pack_gqa, + 
qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=nheads_kv, + disable_bitmask=disable_bitmask, + ), + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, + mIndexTopk=mIndexTopk, + ) + compile_cache[compile_key] = kernel + + Q = torch.randn(batch, seqlen_q, nheads, hdim, dtype=torch.bfloat16, device="cuda") + Qv = torch.randn(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + K = torch.randn(batch, seqlen_k, nheads_kv, hdim, dtype=torch.bfloat16, device="cuda") + V = torch.randn(batch, seqlen_k, nheads_kv, hdimv, dtype=torch.bfloat16, device="cuda") + O = torch.empty(batch, seqlen_q, nheads, hdimv, dtype=torch.bfloat16, device="cuda") + + index_topk = ( + torch.rand(batch, seqlen_q, topk_length, device="cuda").argsort(dim=-1).to(torch.int32) + ) + + mQ = from_dlpack(Q, assumed_align=16).mark_layout_dynamic(leading_dim=Q.ndim - 1) + mQv = from_dlpack(Qv, assumed_align=16).mark_layout_dynamic(leading_dim=Qv.ndim - 1) + mK = from_dlpack(K, assumed_align=16).mark_layout_dynamic(leading_dim=K.ndim - 1) + mV = from_dlpack(V, assumed_align=16).mark_layout_dynamic(leading_dim=V.ndim - 1) + mO = from_dlpack(O, assumed_align=16).mark_layout_dynamic(leading_dim=O.ndim - 1) + if gather_kv: + mIndexTopk = from_dlpack(index_topk, assumed_align=16).mark_layout_dynamic( + leading_dim=index_topk.ndim - 1 + ) + else: + mIndexTopk = None + mLSE = None + + exec_time_in_s = timeit( + compile_cache[compile_key], + mQ, + mQv, + mK, + mV, + mO, + mLSE, + softmax_scale, + mIndexTopk=mIndexTopk, + ) + + seqlen_k_eff = topk_length if gather_kv else seqlen_k + + FLOPs = 2 * batch * nheads * seqlen_q * seqlen_k_eff * (hdim + 2 * hdimv) + if is_causal and not gather_kv: + FLOPs /= 2 + + TFLOPS = FLOPs / exec_time_in_s / 1e12 + + q_bytes = 2 * batch * nheads * seqlen_q * hdim + qv_bytes = 2 * batch * nheads * seqlen_q * hdimv + k_bytes = 2 * batch * nheads_kv * seqlen_k_eff * hdim + v_bytes = 2 * batch * nheads_kv * seqlen_k_eff * hdimv + o_bytes = 2 * batch * nheads * seqlen_q * hdimv + total_bytes = q_bytes + qv_bytes + k_bytes + v_bytes + o_bytes + TBs = total_bytes / exec_time_in_s / 1e12 + + print( + f"batch: {batch}, seqlen_q: {seqlen_q}, seqlen_k: {seqlen_k}, nheads: {nheads}, -> {exec_time_in_s * 1e3:.2f} ms, {TFLOPS:.2f} TFLOPS, {TBs:.2f} TBs" + ) + + +if __name__ == "__main__": + run_test = True + run_benchmark = True + gather_kv = False + is_causal = True + pack_gqa = True + topk_length = 2048 + varlen_q = False + varlen_k = False + disable_bitmask = True + validate = True + + if run_test: + if not gather_kv: + seqlen_q_test_values = range(1, 4002, 400) + seqlen_k_test_values = range(1, 4002, 400) + else: + seqlen_q_test_values = range(1, 1001, 200) + seqlen_k_test_values = range(topk_length, 9001, 2000) + seqlen_q_test_values = [1] + seqlen_k_test_values = [4096] + nheads_test_values = [128] + batch_test_values = [4] + test_configs = [ + ( + batch, + nheads, + seqlen_q, + seqlen_k, + ) + for batch in batch_test_values + for nheads in nheads_test_values + for seqlen_q in seqlen_q_test_values + for seqlen_k in seqlen_k_test_values + ] + iters_per_config = 1 + compile_cache = dict() + print("=" * 40) + print("Testing MLA Kernel") + print("=" * 40) + for config in test_configs: + batch, nheads, seqlen_q, seqlen_k = config + # if is_causal and seqlen_k < seqlen_q: + # continue + for iter in range(iters_per_config): + test_mla_kernel( + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + topk_length=topk_length, + nheads=nheads, + batch=batch, + iter=iter, + compile_cache=compile_cache, + 
validate=validate, + seed=0, + gather_kv=gather_kv, + pack_gqa=pack_gqa, + is_causal=is_causal, + varlen_q=varlen_q, + varlen_k=varlen_k, + disable_bitmask=disable_bitmask, + ) + if run_benchmark: + if gather_kv: + seqlen_q_benchmark_values = [1] + seqlen_k_benchmark_values = [8192 * 2] + nheads_benchmark_values = [128] + batch_benchmark_values = [512] + else: + seqlen_q_benchmark_values = [1] + seqlen_k_benchmark_values = [8192 * 2] + nheads_benchmark_values = [128] + batch_benchmark_values = [512] + seqlen_q_benchmark_values = [4096] + seqlen_k_benchmark_values = [4096] + nheads_benchmark_values = [16] + batch_benchmark_values = [8] + benchmark_configs = [ + ( + batch, + nheads, + seqlen_q, + seqlen_k, + ) + for batch in batch_benchmark_values + for nheads in nheads_benchmark_values + for seqlen_q in seqlen_q_benchmark_values + for seqlen_k in seqlen_k_benchmark_values + ] + compile_cache = dict() + print("=" * 40) + print("Benchmarking MLA Kernel") + print("=" * 40) + for config in benchmark_configs: + batch, nheads, seqlen_q, seqlen_k = config + benchmark_mla_kernel( + batch=batch, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + topk_length=topk_length, + nheads=nheads, + gather_kv=gather_kv, + is_causal=is_causal, + disable_bitmask=disable_bitmask, + compile_cache=compile_cache, + ) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 6e4fdbf..75e767c 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -14,7 +14,7 @@ # https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py import math -from typing import Tuple, Callable, Optional, Literal +from typing import Tuple, Callable, Optional, Literal, NamedTuple from functools import partial import cuda.bindings.driver as cuda @@ -87,29 +87,23 @@ (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, } +_FP8_TUNING_CONFIG = { + (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 160, 'num_regs_correction': 72}, +} +_FP8_SMALL_HDIM_REGS = { + False: {"num_regs_softmax": 168, "num_regs_correction": 96, "num_regs_other": 80}, + True: {"num_regs_softmax": 152, "num_regs_correction": 96, "num_regs_other": 112}, +} # === END TUNING KNOBS === -# === TUNING KNOBS (agent-editable) === -# Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) -# Values: -# ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation). -# SM103 has fast native exp2, so set freq=0 there. 
-# ex2_emu_start_frg: int — fragment index to start emulation from -# num_regs_softmax: int — register count for softmax warps (multiple of 8) -# num_regs_correction: int — register count for correction warps (multiple of 8) -# num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction -_TUNING_CONFIG = { - (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88}, - (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72}, - (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80}, - (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, - (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80}, - (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, - (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, - (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, -} -# === END TUNING KNOBS === +class DescaleTensors(NamedTuple): + q_descale: Optional[cute.Tensor] = None + k_descale: Optional[cute.Tensor] = None + v_descale: Optional[cute.Tensor] = None + + def __new_from_mlir_values__(self, values): + return DescaleTensors(*((*values, None, None, None)[:3])) class FlashAttentionForwardSm100: @@ -209,6 +203,7 @@ def __init__( "Paged KV does not support irregular head dim" ) + # ClC does not compose with these other features, so disable even if requested self.use_clc_scheduler = ( use_clc_scheduler and self.use_tma_KV @@ -372,6 +367,7 @@ def __call__( window_size_left: Int32 | int | None = None, window_size_right: Int32 | int | None = None, learnable_sink: Optional[cute.Tensor] = None, + descale_tensors: Optional[DescaleTensors] = None, blocksparse_tensors: Optional[BlockSparseTensors] = None, aux_tensors: Optional[list] = None, # Always keep stream as the last parameter (EnvStream: obtained implicitly via TVM FFI). 
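For readers following the descale plumbing: the semantics that DescaleTensors feeds into _load_effective_descales further down can be emulated in plain PyTorch. A minimal sketch, assuming MHA (one KV head per query head, so the (batch, nheads_kv) descale tensors index directly by head) and fp32 reference math; fp8_attention_ref and its argument names are illustrative only, not part of this patch:

    import torch

    def fp8_attention_ref(q, k, v, q_descale=None, k_descale=None, v_descale=None,
                          softmax_scale=None):
        """Emulate the kernel's FP8 descale handling in fp32 (MHA case).

        q, k, v:    (batch, seqlen, nheads, hdim) FP8 tensors
        *_descale:  (batch, nheads) fp32 tensors, or None for identity
        """
        b, _, h, d = q.shape
        if softmax_scale is None:
            softmax_scale = d ** -0.5
        qf, kf, vf = q.float(), k.float(), v.float()
        # Unspecified descales default to identity, as in _load_effective_descales.
        ident = torch.ones(b, h, device=q.device)
        qd = q_descale if q_descale is not None else ident
        kd = k_descale if k_descale is not None else ident
        vd = v_descale if v_descale is not None else ident
        # Q and K descales multiply under Q @ K^T, so they fold into the softmax
        # scale (the kernel folds them into softmax_scale_log2 when no score_mod).
        s = torch.einsum("bqhd,bkhd->bhqk", qf, kf) * softmax_scale * (qd * kd).view(b, h, 1, 1)
        p = s.softmax(dim=-1)
        # V's descale is a uniform output rescale; the kernel applies it in the
        # correction warps together with the 1/row_sum normalization.
        o = torch.einsum("bhqk,bkhd->bqhd", p, vf) * vd.view(b, 1, h, 1)
        return o.to(torch.bfloat16)

For GQA, head_idx is first mapped to its KV head via _kv_head_idx, so the descales are indexed per KV head. The related max_offset machinery appears to bias the exp2 argument by 8 for FP8 so that P is scaled by 2^8 before quantization; the factor is carried in row_sum (hence max_offset_scale = 256) and subtracted back out of the LSE.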
@@ -427,6 +423,24 @@ def __call__( raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") if const_expr(self.q_dtype != self.v_dtype): raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + if const_expr(self.q_dtype.width == 8): + paged_kv_non_tma = not self.use_tma_KV + if const_expr(self.head_dim_padded < 96): + fp8_regs = _FP8_SMALL_HDIM_REGS[paged_kv_non_tma] + self.num_regs_softmax = fp8_regs["num_regs_softmax"] + self.num_regs_correction = fp8_regs["num_regs_correction"] + self.num_regs_other = fp8_regs["num_regs_other"] + else: + fp8_tune = _FP8_TUNING_CONFIG.get( + (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103), {} + ) + if const_expr("ex2_emu_freq" in fp8_tune): + self._tune = {**self._tune, **fp8_tune} + self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0 + if const_expr(not paged_kv_non_tma and "num_regs_softmax" in fp8_tune): + self.num_regs_softmax = fp8_tune["num_regs_softmax"] + self.num_regs_correction = fp8_tune["num_regs_correction"] + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction self._setup_attributes() self.use_tma_O = ( self.arch >= Arch.sm_90 @@ -726,6 +740,7 @@ class SharedStorage: window_size_left, window_size_right, learnable_sink, + descale_tensors, blocksparse_tensors, sQ_layout, sK_layout, @@ -772,6 +787,7 @@ def kernel( window_size_left: Optional[Int32], window_size_right: Optional[Int32], learnable_sink: Optional[cute.Tensor], + descale_tensors: Optional[DescaleTensors], blocksparse_tensors: Optional[BlockSparseTensors], sQ_layout: cute.ComposedLayout, sK_layout: cute.ComposedLayout, @@ -1200,6 +1216,7 @@ def kernel( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, softmax_scale=softmax_scale, + descale_tensors=descale_tensors, thr_mma_qk=thr_mma_qk, sScale=sScale, mLSE=mLSE, @@ -1255,6 +1272,7 @@ def kernel( sm_stats_barrier, pipeline_o_epi, learnable_sink, + descale_tensors, gmem_tiled_copy_O, tma_atom_O, softmax_scale_log2, @@ -1520,6 +1538,7 @@ def mma( qk_mma_op, pv_mma_op = tiled_mma_qk.op, tiled_mma_pv.op qk_mma_idesc, pv_mma_idesc = sm100_desc.mma_op_to_idesc(qk_mma_op), sm100_desc.mma_op_to_idesc(pv_mma_op) + qk_mma_kind = sm100_utils._tcgen05_mma_kind(qk_mma_op) q_smem_base = sm100_desc.smem_desc_base_from_tensor(sQ, sm100_desc.Major.K) k_smem_base = sm100_desc.smem_desc_base_from_tensor(sK, sm100_desc.Major.K) v_smem_base = sm100_desc.smem_desc_base_from_tensor(sV, sm100_desc.Major.MN) @@ -1548,6 +1567,7 @@ def mma( tCrB_layout=tSrK[None, None, None, 0].layout, smem_var_name_prefix=f"fa_fwd_q_smem_desc", idesc_var_name=f"fa_fwd_qk_mma_idesc", + kind=qk_mma_kind, smem_offset=-sQ_stage_stride if stage == 0 else sQ_stage_stride, zero_init=True, cta_group=self.cta_group_size, @@ -1776,12 +1796,39 @@ def mma( # pipeline_o_acc.producer_acquire() inside the loop. 
# for both softmax0 and softmax1 warp group + @cute.jit + def _kv_head_idx(self, head_idx: Int32) -> Int32: + """Map query-head tile index -> KV-head index (FA3 descale semantics).""" + if cutlass.const_expr(self.pack_gqa): + return head_idx + return head_idx // self.qhead_per_kvhead + + @cute.jit + def _load_effective_descales( + self, + descale_tensors: Optional[DescaleTensors], + batch_idx: Int32, + kv_head_idx: Int32, + ) -> Tuple[Float32, Float32]: + """Load effective QK and V descales, defaulting unspecified tensors to identity.""" + qk_descale = Float32(1.0) + v_descale = Float32(1.0) + if cutlass.const_expr(descale_tensors is not None): + if cutlass.const_expr(descale_tensors.q_descale is not None): + qk_descale = qk_descale * Float32(descale_tensors.q_descale[batch_idx, kv_head_idx]) + if cutlass.const_expr(descale_tensors.k_descale is not None): + qk_descale = qk_descale * Float32(descale_tensors.k_descale[batch_idx, kv_head_idx]) + if cutlass.const_expr(descale_tensors.v_descale is not None): + v_descale = Float32(descale_tensors.v_descale[batch_idx, kv_head_idx]) + return qk_descale, v_descale + @cute.jit def softmax_loop( self, stage: int | Int32, softmax_scale_log2: Float32, - softmax_scale: Float32, + softmax_scale: Float32 | None, + descale_tensors: Optional[DescaleTensors], thr_mma_qk: cute.core.ThrMma, tStS: cute.Tensor, # ((TILE_M, TILE_N), 1, 1, q_stage) sScale: cute.Tensor, @@ -1847,7 +1894,10 @@ def softmax_loop( ) tStScale_r2t = thr_tmem_store_scale.partition_D(tStScale) tmem_store_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(16)), Float32 + tcgen05.copy.St32x32bOp( + tcgen05.copy.Repetition(8 if const_expr(self.q_dtype.width == 8) else 16) + ), + Float32, ) thr_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP).get_slice(tidx) tStP_r2t = thr_tmem_store.partition_D(tStP) # (((16,32),1),1,4) @@ -1863,6 +1913,7 @@ def softmax_loop( work_tile = tile_scheduler.initial_work_tile_info() while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + kv_head_idx = self._kv_head_idx(head_idx) seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) @@ -1917,10 +1968,26 @@ def softmax_loop( else: mask_fn_none = None + qk_descale, _ = self._load_effective_descales(descale_tensors, batch_idx, kv_head_idx) + + max_offset = 8 if cutlass.const_expr(self.q_dtype.width == 8) else 0 + if const_expr(self.score_mod is None): + softmax_scale_log2_eff = softmax_scale_log2 * qk_descale + softmax_scale_eff = None + else: + softmax_scale_log2_eff = softmax_scale_log2 + softmax_scale_eff = softmax_scale * qk_descale + + rescale_threshold = ( + 8.0 if const_expr(self.q_dtype.width == 16) else + 4.0 if const_expr(self.q_dtype.width == 8) else + 0.0 + ) softmax = SoftmaxSm100.create( - softmax_scale_log2, - rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, - softmax_scale=softmax_scale, + softmax_scale_log2_eff, + rescale_threshold=rescale_threshold, + softmax_scale=softmax_scale_eff, + max_offset=max_offset, ) softmax.reset() @@ -2266,6 +2333,7 @@ def correction_loop( sm_stats_barrier: pipeline.NamedBarrier, pipeline_o_epi: pipeline.PipelineAsync, learnable_sink: Optional[cute.Tensor], + descale_tensors: Optional[DescaleTensors], gmem_tiled_copy_O: cute.TiledCopy, tma_atom_O: cute.CopyAtom, softmax_scale_log2: Float32, @@ -2306,6 +2374,17 @@ def correction_loop( work_tile = tile_scheduler.initial_work_tile_info() while 
work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx + kv_head_idx = self._kv_head_idx(head_idx) + qk_descale, v_descale = self._load_effective_descales(descale_tensors, batch_idx, kv_head_idx) + if const_expr(self.score_mod is None): + softmax_scale_log2_eff = softmax_scale_log2 * qk_descale + else: + softmax_scale_log2_eff = softmax_scale_log2 + + max_offset = Float32(8.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(0.0) + max_offset_scale = ( + Float32(256.0) if cutlass.const_expr(self.q_dtype.width == 8) else Float32(1.0) + ) seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) @@ -2408,15 +2487,16 @@ def correction_loop( if const_expr(not self.is_split_kv) or split_idx == 0: if row_max == -Float32.inf: # It's possible to have an empty row with splitKV. - row_max = sink_val * (LOG2_E / softmax_scale_log2) - row_sum = Float32(1.0) + row_max = sink_val * (LOG2_E / softmax_scale_log2_eff) + row_sum = max_offset_scale else: row_sum += cute.math.exp2( - sink_val * LOG2_E - row_max * softmax_scale_log2, fastmath=True + sink_val * LOG2_E - row_max * softmax_scale_log2_eff + max_offset, fastmath=True ) acc_O_mn_row_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum stats[stage] = (row_sum, row_max, acc_O_mn_row_is_zero_or_nan) scale = cute.arch.rcp_approx(row_sum if not acc_O_mn_row_is_zero_or_nan else 1.0) + scale = scale * v_descale # Wait for the last O to be ready from the MMA warp pipeline_o_acc.consumer_wait_w_index_phase(stage, o_corr_consumer_phase) if const_expr(not self.use_correction_warps_for_epi): @@ -2480,7 +2560,9 @@ def correction_loop( sm_stats_consumer_phase, o_corr_consumer_phase, corr_epi_producer_phase, - softmax_scale_log2, + softmax_scale_log2_eff, + max_offset, + max_offset_scale, mO_cur, gO, gmem_tiled_copy_O_for_empty_tile, @@ -2507,7 +2589,7 @@ def correction_loop( # cute.printf("row_sum = {}, row_max = {}, acc_O_mn_row_is_zero_or_nan = {}\n", row_sum, row_max, acc_O_mn_row_is_zero_or_nan) LN2 = math.log(2.0) lse = ( - (row_max * softmax_scale_log2 + cute.math.log2(row_sum, fastmath=True)) * LN2 + (row_max * softmax_scale_log2_eff + (cute.math.log2(row_sum, fastmath=True) - max_offset)) * LN2 if not acc_O_mn_row_is_zero_or_nan else -Float32.inf ) diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index 872960a..17b703b 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -1,24 +1,6 @@ # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'll need install nvidia-cutlass-dsl==4.2.0. -# Supported features: -# - BF16 & FP16 dtype -# - noncausal & causal attention -# - MHA, GQA, MQA -# - hdim 64, 96, 128. -# - (hdim_qk, hdim_v) = (192, 128) for Blackwell (i.e. DeepSeek shape) -# - varlen -# - sliding window -# - bwd pass for Ampere (will also run on Hopper/Blackwell, but will be slow) - -# Features not supported yet: -# - split (i.e. 
FlashDecoding)
-# - tuned block sizes
-# - paged KV
-# - append KV to existing KV cache
-# - FP8
-# - bwd pass optimized for Hopper/Blackwell
-
 import os
 import math
 from dataclasses import dataclass
@@ -52,7 +34,7 @@
 )
 from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80
 from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90
-from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100
+from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors
 from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120
 from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess
 from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80
@@ -61,6 +43,7 @@
 from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120
 from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess
 from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine
+from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100
 
 from flash_attn.cute.block_sparsity import (
     BlockSparseTensorsTorch,
@@ -102,6 +85,7 @@ def _get_device_arch():
 def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int, alignment: int) -> None:
     """Validate head dimension constraints based on compute capability."""
     is_deepseek_shape = head_dim == 192 and head_dim_v == 128
+    is_deepseek_mla_absorbed_shape = head_dim == 64 and head_dim_v == 512
     is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
     is_sm90_range = 8 <= head_dim <= 256 and 8 <= head_dim_v <= 256
@@ -111,7 +95,8 @@ def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int,
             f"head_dim and head_dim_v must be between 8 and 256 and divisible by {alignment}."
         )
     elif compute_capability in [10, 11]:
-        assert (is_standard_range or is_deepseek_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
+        assert (is_standard_range or is_deepseek_shape or is_deepseek_mla_absorbed_shape) and head_dim % alignment == 0 and head_dim_v % alignment == 0, (
             f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
-            f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, or (192, 128) for DeepSeek."
+            f"head_dim and head_dim_v must be between 8 and 128 and divisible by {alignment}, "
+            f"(192, 128) for DeepSeek, or (64, 512) for weight-absorbed MLA."
) @@ -253,11 +237,12 @@ def _validate_tensor(t, name, expected_shape, expected_dtype, expected_device): if not is_fake_mode(): assert t.is_cuda, f"{name} must be on CUDA" - torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, + torch.float8_e4m3fn: cutlass.Float8E4M3FN, + torch.float8_e5m2: cutlass.Float8E5M2, } @@ -298,12 +283,14 @@ def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -325,6 +312,10 @@ def _flash_attn_fwd( out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, aux_tensors: Optional[list[torch.Tensor]] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. @@ -340,6 +331,7 @@ def _flash_attn_fwd( aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. """ q, k, v = [maybe_contiguous(t) for t in (q, k, v)] + q_descale, k_descale, v_descale = [maybe_contiguous(t) for t in (q_descale, k_descale, v_descale)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] @@ -385,7 +377,9 @@ def _flash_attn_fwd( assert seqused_k is None or seqused_k.shape == (batch_size,), ( "seqused_k must have shape (batch_size,)" ) - assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype in [torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2], ( + "inputs must be float16, bfloat16, fp8 e4m3fn, or fp8 e5m2" + ) assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: @@ -406,6 +400,9 @@ def _flash_attn_fwd( q, k, v, + q_descale, + k_descale, + v_descale, cu_seqlens_q, cu_seqlens_k, seqused_q, @@ -421,14 +418,17 @@ def _flash_attn_fwd( if arch // 10 not in [8, 12]: _validate_head_dims(head_dim, head_dim_v, arch // 10, alignment) if softmax_scale is None: - softmax_scale = 1.0 / math.sqrt(head_dim) + softmax_scale = 1.0 / math.sqrt(head_dim) if qv is None else 1.0 / math.sqrt(head_dim + head_dim_v) if softcap == 0.0: softcap = None qhead_per_kvhead = num_head // num_head_kv if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 - out_torch_dtype = q.dtype + is_fp8 = q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + if is_fp8 and (q.requires_grad or k.requires_grad or v.requires_grad): + raise NotImplementedError("FA4 CuTe FP8 backward is not supported yet (forward-only).") + out_torch_dtype = torch.bfloat16 if is_fp8 else q.dtype device = q.device q_batch_seqlen_shape = (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) @@ -450,7 +450,24 @@ def _flash_attn_fwd( elif lse is not None: _validate_tensor(lse, "lse", lse_shape, torch.float32, device) + if seqlen_k == 0: + out.zero_() + if lse is not 
None: + lse.fill_(float("-inf")) + return out, lse + + if is_fp8: + for t, name in ((q_descale, "q_descale"), (k_descale, "k_descale"), (v_descale, "v_descale")): + if t is not None: + _validate_tensor(t, name, (batch_size, num_head_kv), torch.float32, device) + else: + assert q_descale is None and k_descale is None and v_descale is None, ( + "q_descale/k_descale/v_descale are only supported for FP8 inputs" + ) + dtype = torch2cute_dtype_map[q.dtype] + if is_fp8: + assert arch // 10 == 10, "FP8 is only supported on SM100 (compute capability 10.x) for FA4 CuTe." use_block_sparsity = block_sparse_tensors is not None causal, local, window_size_left, window_size_right = _resolve_causal_local_window( @@ -492,11 +509,16 @@ def _flash_attn_fwd( # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False + + if pack_gqa and qv is not None and 128 % qhead_per_kvhead != 0: + pack_gqa = False if max_seqlen_q is None: max_seqlen_q = seqlen_q if cu_seqlens_q is None else total_q if max_seqlen_k is None: max_seqlen_k = seqlen_k + if cu_seqlens_k is None and seqused_k is None: + min_seqlen_k = seqlen_k seqlen_q_packgqa = max_seqlen_q * qhead_per_kvhead if arch // 10 == 10: q_stage = 2 if seqlen_q_packgqa > tile_m else 1 @@ -515,7 +537,7 @@ def _flash_attn_fwd( # SplitKV uses float32 partial output, which doubles the O buffer size # in shared memory, causing OOM for diff-headdim (192, 128) if arch // 10 in [10, 11] and head_dim != head_dim_v and num_splits > 1: - if num_n_blocks >= 64: + if num_n_blocks >= 64 and head_dim_v != 512: tile_n = 64 num_n_blocks = (seqlen_k_loaded + tile_n - 1) // tile_n num_splits = num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, 128) @@ -543,13 +565,16 @@ def _flash_attn_fwd( and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) ) - # hash score and mask mods for compile cache - score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False - mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False - if softcap is not None: assert score_mod is None, "softcap and score_mod cannot be used together" score_mod = utils.create_softcap_scoremod(softcap) + elif score_mod is not None: + if arch // 10 == 8: + raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.") + + # hash score and mask mods for compile cache + score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False + mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False is_varlen = ( cu_seqlens_q is not None @@ -558,6 +583,13 @@ def _flash_attn_fwd( or seqused_k is not None ) + # CLC regressed for varlen MHA and dense noncausal. Imbalanced varlen shapes + # keep more K/V blocks in flight and hurt L2; dense noncausal mostly just + # pays work-stealing overhead. 
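+    # Net effect: when requested, CLC remains enabled only for dense causal/local
+    # and varlen GQA runs.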
+ is_varlen_mha = is_varlen and qhead_per_kvhead == 1 + is_dense_noncausal = not is_varlen and not causal and not local + use_clc_scheduler = requested_use_clc_scheduler and not is_varlen_mha and not is_dense_noncausal + if mask_mod is not None: if is_varlen: raise NotImplementedError( @@ -602,6 +634,44 @@ def _flash_attn_fwd( else: aux_tensor_metadata = None + if qv is not None: + assert arch // 10 in [10, 11], "only support Blackwell arch with qv" + assert qv.shape[:-1] == q.shape[:-1] + assert qv.shape[-1] == head_dim_v + assert head_dim == 64 and head_dim_v == 512, "only support MLA weight absorbed shape with qv" + assert not local, "local not yet supported with qv" + assert page_table is None, "page table not yet supported with qv" + assert q_descale is None and k_descale is None and v_descale is None, ( + "q_descale/k_descale/v_descale are not yet supported with qv" + ) + + assert not is_split_kv, "split kv not supported with qv" + assert learnable_sink is None + assert softcap is None + assert score_mod is None + assert mask_mod is None + + qv = maybe_contiguous(qv) + + gather_kv_length = 2048 + sparse_kv = gather_kv_indices is not None + disable_sparse_kv_bitmask = False + if sparse_kv: + assert gather_kv_indices.shape[:-1] == q.shape[:-2] + gather_kv_length = gather_kv_indices.shape[-1] + assert gather_kv_length % 256 == 0 + if min_seqlen_k is None or causal: + disable_sparse_kv_bitmask = False + else: + # seqlen_k_boundary = min_seqlen_k - max_seqlen_q + 1 if causal else min_seqlen_k + seqlen_k_boundary = min_seqlen_k + disable_sparse_kv_bitmask = seqlen_k_boundary >= gather_kv_length + else: + assert gather_kv_indices is None, "gather_kv_indices is only supported with qv" + gather_kv_length = None + sparse_kv = None + disable_sparse_kv_bitmask = None + compile_key = ( dtype, head_dim, @@ -622,6 +692,9 @@ def _flash_attn_fwd( window_size_left is not None, window_size_right is not None, learnable_sink is not None, + q_descale is not None, + k_descale is not None, + v_descale is not None, tile_m, tile_n, q_stage, @@ -635,6 +708,10 @@ def _flash_attn_fwd( mma_pv_is_rs, intra_wg_overlap, requested_use_clc_scheduler, + qv is not None, + gather_kv_length, + sparse_kv, + disable_sparse_kv_bitmask, fa_logging.get_fa_log_level(), ) if compile_key not in _flash_attn_fwd.compile_cache: @@ -665,6 +742,33 @@ def _flash_attn_fwd( else: lse_tensor = None + q_descale_tensor = ( + to_cute_tensor(q_descale, assumed_align=4, leading_dim=1) + if q_descale is not None + else None + ) + k_descale_tensor = ( + to_cute_tensor(k_descale, assumed_align=4, leading_dim=1) + if k_descale is not None + else None + ) + v_descale_tensor = ( + to_cute_tensor(v_descale, assumed_align=4, leading_dim=1) + if v_descale is not None + else None + ) + descale_tensors_tensor = ( + DescaleTensors( + q_descale=q_descale_tensor, + k_descale=k_descale_tensor, + v_descale=v_descale_tensor, + ) + if q_descale_tensor is not None + or k_descale_tensor is not None + or v_descale_tensor is not None + else None + ) + sparse_tensors = None if normalized_block_sparse_tensors is not None: sparse_tensors = to_cute_block_sparse_tensors(normalized_block_sparse_tensors) @@ -674,6 +778,9 @@ def _flash_attn_fwd( if aux_tensors is not None: cute_aux_tensors = [to_cute_aux_tensor(buf) for buf in aux_tensors] + qv_tensor = to_cute_tensor(qv) if qv is not None else None + gather_kv_indices_tensor = to_cute_tensor(gather_kv_indices) if gather_kv_indices is not None else None + if arch // 10 == 8: assert page_table is None, "paged KV not 
supported on SM 8.0" assert not is_split_kv, "SplitKV not supported on SM 8.0" @@ -719,31 +826,44 @@ def _flash_attn_fwd( paged_kv_non_tma=page_size not in [None, tile_n], ) elif arch // 10 in [10, 11]: - fa_fwd = FlashAttentionForwardSm100( - head_dim, - head_dim_v, - qhead_per_kvhead=qhead_per_kvhead, - is_causal=causal, - is_local=local, - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - m_block_size=tile_m, - n_block_size=tile_n, - q_stage=q_stage, - is_persistent=not causal - and not local - and cu_seqlens_q is None - and seqused_q is None - and not is_split_kv, - score_mod=score_mod, - mask_mod=mask_mod, - has_aux_tensors=aux_tensors is not None, - paged_kv_non_tma=page_size not in [None, tile_n], - is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, - q_subtile_factor=q_subtile_factor, - use_2cta_instrs=use_2cta_instrs, - use_clc_scheduler=requested_use_clc_scheduler, - ) + if qv is not None: + fa_fwd = FlashAttentionMLAForwardSm100( + is_causal=causal, + use_cpasync_load_KV=sparse_kv, + topk_length=gather_kv_length, + is_topk_gather=sparse_kv, + pack_gqa=pack_gqa, + qhead_per_kvhead=qhead_per_kvhead, + nheads_kv=num_head_kv, + is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + disable_bitmask=disable_sparse_kv_bitmask, + ) + else: + fa_fwd = FlashAttentionForwardSm100( + head_dim, + head_dim_v, + qhead_per_kvhead=qhead_per_kvhead, + is_causal=causal, + is_local=local, + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + m_block_size=tile_m, + n_block_size=tile_n, + q_stage=q_stage, + is_persistent=not causal + and not local + and cu_seqlens_q is None + and seqused_q is None + and not is_split_kv, + score_mod=score_mod, + mask_mod=mask_mod, + has_aux_tensors=aux_tensors is not None, + paged_kv_non_tma=page_size not in [None, tile_n], + is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + q_subtile_factor=q_subtile_factor, + use_2cta_instrs=use_2cta_instrs, + use_clc_scheduler=use_clc_scheduler, + ) elif arch // 10 == 12: # SM120 (Blackwell GeForce / DGX Spark): uses SM80 MMA with SM120 SMEM capacity assert not use_block_sparsity, "Block sparsity not supported on SM 12.0" @@ -771,51 +891,113 @@ def _flash_attn_fwd( f"Unsupported compute capability: {arch}. 
Supported: 8.x, 9.x, 10.x, 11.x, 12.x" ) # TODO: check @can_implement - _flash_attn_fwd.compile_cache[compile_key] = cute.compile( - fa_fwd, - q_tensor, - k_tensor, - v_tensor, - o_tensor, - lse_tensor, - softmax_scale, - cu_seqlens_q_tensor, - cu_seqlens_k_tensor, - seqused_q_tensor, - seqused_k_tensor, - page_table_tensor, - window_size_left, - window_size_right, - learnable_sink_tensor, - sparse_tensors, - cute_aux_tensors, - current_stream, - options="--enable-tvm-ffi", - ) + if qv is not None: + _flash_attn_fwd.compile_cache[compile_key] = cute.compile( + fa_fwd, + q_tensor, + qv_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + gather_kv_indices_tensor, + page_table_tensor, + window_size_left, + window_size_right, + current_stream, + options="--enable-tvm-ffi", + ) + else: + compile_args = [ + fa_fwd, + q_tensor, + k_tensor, + v_tensor, + o_tensor, + lse_tensor, + softmax_scale, + cu_seqlens_q_tensor, + cu_seqlens_k_tensor, + seqused_q_tensor, + seqused_k_tensor, + page_table_tensor, + window_size_left, + window_size_right, + learnable_sink_tensor, + sparse_tensors, + cute_aux_tensors, + current_stream, + ] + if arch // 10 in [10, 11]: + compile_args.insert(-3, descale_tensors_tensor) + _flash_attn_fwd.compile_cache[compile_key] = cute.compile(*compile_args, options="--enable-tvm-ffi") # In "fake mode", we will take torch fake tensors as input and the expected behaviors are: # - Use those fake metadata to populate compilation cache # - Return "fake" output tensors, which could be needed in follow-up fake operations # Thus, we skip the actual kernel invocation here. if not is_fake_mode(): - _flash_attn_fwd.compile_cache[compile_key]( - q.detach(), - k.detach(), - v.detach(), - out.detach() if not is_split_kv else out_partial, - lse_partial if is_split_kv else lse, - softmax_scale, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - page_table, - window_size_left, - window_size_right, - learnable_sink, - normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, - aux_tensors, + q_call, k_call, v_call = q.detach(), k.detach(), v.detach() + qv_call = qv.detach() if qv is not None else None + if is_fp8: + # need uint8 workaround until we pin torch >= 2.11.0 where fp8 export is supported + q_call = q_call.view(torch.uint8) + k_call = k_call.view(torch.uint8) + v_call = v_call.view(torch.uint8) + if qv_call is not None: + qv_call = qv_call.view(torch.uint8) + descale_tensors = ( + DescaleTensors(q_descale=q_descale, k_descale=k_descale, v_descale=v_descale) + if q_descale is not None or k_descale is not None or v_descale is not None + else None ) + if qv is not None: + _flash_attn_fwd.compile_cache[compile_key]( + q_call, + qv_call, + k_call, + v_call, + out.detach(), + lse, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + gather_kv_indices, + page_table, + window_size_left, + window_size_right, + ) + else: + call_args = [ + q_call, + k_call, + v_call, + out.detach() if not is_split_kv else out_partial, + lse_partial if is_split_kv else lse, + softmax_scale, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + page_table, + window_size_left, + window_size_right, + learnable_sink, + ] + if arch // 10 in [10, 11]: + call_args.append(descale_tensors) + call_args.extend([ + normalized_block_sparse_tensors[:4] if normalized_block_sparse_tensors is not None else None, + aux_tensors, + ]) + 
_flash_attn_fwd.compile_cache[compile_key](*call_args) if is_split_kv: _flash_attn_fwd_combine( out_partial, @@ -1171,12 +1353,20 @@ def _flash_attn_bwd( pack_gqa = qhead_per_kvhead > 1 # pack_gqa backward not yet supported in bwd pack_gqa = False - if score_mod is not None: + + if softcap != 0.0: + assert score_mod is None and score_mod_bwd is None, ( + "softcap and score_mod/score_mod_bwd cannot be used together" + ) + score_mod = utils.create_softcap_scoremod(softcap) + score_mod_bwd = utils.create_softcap_scoremod_bwd(softcap) + elif score_mod is not None: assert score_mod_bwd is not None, "score_mod_bwd is required when score_mod is provided" - assert softcap == 0.0, "softcap and score_mod are mutually exclusive (different log2 scaling)" assert cu_seqlens_q is None and cu_seqlens_k is None, ( "varlen + score_mod not supported in bwd yet" ) + if arch // 10 == 8: + raise NotImplementedError("Custom user-provided score_mod is not supported on SM8x architectures.") device = q.device out_torch_dtype = q.dtype @@ -1322,7 +1512,6 @@ def _flash_attn_bwd( causal, window_size_left is not None, window_size_right is not None, - softcap != 0.0, m_block_size, n_block_size, num_threads, @@ -1352,6 +1541,9 @@ def _flash_attn_bwd( get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), + # Prevent TVM stride poisoning when only one block is present. + (seqlen_q_rounded // m_block_size == 1), + (seqlen_k_rounded // n_block_size == 1), ) else: compile_key = ( @@ -1363,7 +1555,6 @@ def _flash_attn_bwd( causal, window_size_left is not None, window_size_right is not None, - softcap != 0.0, m_block_size, n_block_size, num_threads, @@ -1385,6 +1576,9 @@ def _flash_attn_bwd( get_broadcast_dims(k), get_broadcast_dims(v), get_broadcast_dims(dout), + # Prevent TVM stride poisoning when only one block is present. 
+ (seqlen_q_rounded // m_block_size == 1), + (seqlen_k_rounded // n_block_size == 1), ) if compile_key not in _flash_attn_bwd.compile_cache: q_tensor, k_tensor, v_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ @@ -1427,6 +1621,8 @@ def _flash_attn_bwd( AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs=V_in_regs, + score_mod=score_mod, + score_mod_bwd=score_mod_bwd, ) elif arch // 10 == 9: fa_bwd_obj = FlashAttentionBackwardSm90( @@ -1498,7 +1694,6 @@ def _flash_attn_bwd( cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, - None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore_tensor, @@ -1525,7 +1720,6 @@ def _flash_attn_bwd( cu_seqlens_k, seqused_q, seqused_k, - None, # softcap - not yet supported in backward window_size_left, window_size_right, dQ_semaphore, @@ -1583,6 +1777,8 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -1613,6 +1809,7 @@ def forward( q, k, v, + qv=qv, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], @@ -1624,6 +1821,7 @@ def forward( mask_mod=mask_mod, block_sparse_tensors=block_sparse_tensors, return_lse=return_lse, + gather_kv_indices=gather_kv_indices, ) ctx.save_for_backward(q, k, v, out, lse) ctx.softmax_scale = softmax_scale @@ -1657,7 +1855,7 @@ def backward(ctx, dout, dlse): deterministic=ctx.deterministic, dlse=dlse, ) - return dq, dk, dv, *((None,) * 20) # Extra Nones is fine + return dq, dk, dv, *((None,) * 30) # Extra Nones is fine class FlashAttnVarlenFunc(torch.autograd.Function): @@ -1667,12 +1865,15 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - cu_seqlens_q: Optional[torch.Tensor], - cu_seqlens_k: Optional[torch.Tensor], + qv: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, + gather_kv_indices: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1690,12 +1891,14 @@ def forward( q, k, v, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, + qv=qv, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + seqused_q=seqused_q, + seqused_k=seqused_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, + min_seqlen_k=min_seqlen_k, page_table=page_table, softmax_scale=softmax_scale, causal=causal, @@ -1708,6 +1911,7 @@ def forward( score_mod=score_mod, aux_tensors=aux_tensors, return_lse=return_lse, + gather_kv_indices=gather_kv_indices, ) ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ctx.softmax_scale = softmax_scale @@ -1724,7 +1928,6 @@ def forward( @staticmethod def backward(ctx, dout, dlse): q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors - assert ctx.softcap == 0.0 if not ctx.return_lse: dlse = None if dout is None: @@ -1751,13 +1954,15 @@ def backward(ctx, dout, dlse): dlse=dlse, ) - return dq, dk, dv, *((None,) * 20) + return dq, dk, dv, *((None,) * 30) def flash_attn_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, + 
gather_kv_indices: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), @@ -1778,6 +1983,8 @@ def flash_attn_func( q, k, v, + qv, + gather_kv_indices, softmax_scale, causal, window_size, @@ -1800,12 +2007,15 @@ def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + qv: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, max_seqlen_q: Optional[int] = None, max_seqlen_k: Optional[int] = None, + min_seqlen_k: Optional[int] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + gather_kv_indices: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, @@ -1819,16 +2029,33 @@ def flash_attn_varlen_func( aux_tensors: Optional[list] = None, return_lse: bool = False, ): + """ + Explanation of some optional arguments: + + qv: we write the MLA weight absorbed formula as + O = softmax(scale * (Q @ K.T + Qv @ V.T)) @ V + where Q = q_pe, Qv = q_nope, K = pe_cache, V = kv_cache. + + gather_kv_indices: a tensor of shape (batch, seqlen_q, gather_kv_length) or + (total_q, gather_kv_length) if there is cu_seqlens_q. + Currently, only used for topk sparsity with MLA absorption kernel. + + min_seqlen_k: for varlen, specifies the minimum kv sequence length for any batch. + Used with gather_kv_indices to determine if we need oob masking. + """ return FlashAttnVarlenFunc.apply( q, k, v, + qv, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, max_seqlen_q, max_seqlen_k, + min_seqlen_k, + gather_kv_indices, page_table, softmax_scale, causal, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index 6b5ca16..99e7008 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -5,7 +5,7 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32, Int32, Uint32, const_expr +from cutlass import Float32, Int32, Uint32, const_expr, Boolean from quack import layout_utils import flash_attn.cute.utils as utils @@ -384,6 +384,8 @@ def apply_mask_sm100( fastdiv_mods=(None, None), head_divmod=None, check_q_boundary: bool = False, + r2p: bool = True, + rBitmask: Optional[cute.Tensor] = None, ) -> None: assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" acc_shape = (self.tile_m, self.tile_n) @@ -397,8 +399,18 @@ def apply_mask_sm100( if n_block < 0: n_block = 0 seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - r2p = True - if const_expr(not mask_causal and not mask_local and mask_mod is None): + + if const_expr(rBitmask is not None): + ncol_packed = const_expr(cute.size(rBitmask.shape[0])) + for i in cutlass.range_constexpr(ncol_packed): + col_start = 32 * i # mask is bit-packed into uint32 + curr_mask_val = rBitmask[i] + for j in cutlass.range_constexpr(32): + curr_col = col_start + j + mask = (curr_mask_val >> j) & 1 + acc_S[curr_col] = acc_S[curr_col] if Boolean(mask) else -Float32.inf + + elif const_expr(not mask_causal and not mask_local and mask_mod is None): if const_expr(mask_seqlen): if const_expr(not r2p): for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True): diff --git a/flash_sparse_attn/ops/cute/mma_sm100_desc.py b/flash_sparse_attn/ops/cute/mma_sm100_desc.py index ab8dd09..9a2adfc 100644 --- a/flash_sparse_attn/ops/cute/mma_sm100_desc.py +++ 
b/flash_sparse_attn/ops/cute/mma_sm100_desc.py @@ -83,9 +83,9 @@ def to_UMMA_format(cutlass_type) -> int: if cutlass_type is cutlass.TFloat32: return F16F32Format.TF32 # Float-8 / Float-6 / Float-4 – add whenever CUTLASS exposes them - if cutlass_type is cutlass.FloatE4M3FN: + if cutlass_type is cutlass.Float8E4M3FN: return MXF8F6F4Format.E4M3 - if cutlass_type is cutlass.FloatE5M2: + if cutlass_type is cutlass.Float8E5M2: return MXF8F6F4Format.E5M2 raise TypeError(f"Unsupported CUTLASS scalar type for A/B: {cutlass_type!r}") diff --git a/flash_sparse_attn/ops/cute/named_barrier.py b/flash_sparse_attn/ops/cute/named_barrier.py index dd0d198..c4536da 100644 --- a/flash_sparse_attn/ops/cute/named_barrier.py +++ b/flash_sparse_attn/ops/cute/named_barrier.py @@ -45,3 +45,12 @@ class NamedBarrierBwdSm100(enum.IntEnum): Compute = enum.auto() dQaccReduce = enum.auto() TmemPtr = enum.auto() + + +class NamedBarrierFwdSm100_MLA2CTA(enum.IntEnum): + Epilogue = enum.auto() + TmemPtr = enum.auto() + Cpasync = enum.auto() + Softmax = enum.auto() + SoftmaxStatsFull = enum.auto() + SoftmaxStatsEmpty = enum.auto() diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index eed55a0..9369e0d 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -7,7 +7,7 @@ import cutlass import cutlass.cute as cute -from cutlass import Float32 +from cutlass import Float32, Boolean from quack import layout_utils import flash_attn.cute.utils as utils @@ -169,12 +169,14 @@ def rescale_O(self, acc_O: cute.Tensor, row_scale: cute.Tensor) -> None: @dataclass class SoftmaxSm100(Softmax): rescale_threshold: cutlass.Constexpr[float] = 0.0 + max_offset: cutlass.Constexpr[int] = 0 @staticmethod def create( scale_log2: Float32, rescale_threshold: cutlass.Constexpr[float] = 0.0, softmax_scale: Float32 | None = None, + max_offset: cutlass.Constexpr[int] = 0, ): num_rows = 1 arch = 100 @@ -188,8 +190,40 @@ def create( arch, softmax_scale, rescale_threshold=rescale_threshold, + max_offset=max_offset, ) + @cute.jit + def compute_row_max_local(self, acc_S_row: cute.TensorSSA, is_first: Boolean) -> Float32: + if cutlass.const_expr(is_first): + row_max_new = self._compute_row_max(acc_S_row) + else: + row_max_old = self.row_max[0] + row_max_new = self._compute_row_max(acc_S_row, init_val=row_max_old) + return row_max_new + + @cute.jit + def update_row_max_from_local( + self, + row_max_new: Float32, + is_first: Boolean, + ) -> Tuple[Float32, Float32]: + if cutlass.const_expr(is_first): + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale = 0.0 + else: + row_max_old = self.row_max[0] + row_max_safe = row_max_new if row_max_new != -cutlass.Float32.inf else 0.0 + acc_scale_ = (row_max_old - row_max_safe) * self.scale_log2 + acc_scale = cute.math.exp2(acc_scale_) + if cutlass.const_expr(self.rescale_threshold > 0.0): + if acc_scale_ >= -self.rescale_threshold: + row_max_new = row_max_old + row_max_safe = row_max_old + acc_scale = 1.0 + self.row_max[0] = row_max_new + return row_max_safe, acc_scale + @cute.jit def update_row_max(self, acc_S_row: cute.TensorSSA, is_first: int) -> Tuple[Float32, Float32]: if cutlass.const_expr(is_first): @@ -227,11 +261,13 @@ def scale_subtract_rowmax( ): assert cute.size(acc_S_row.shape) % 2 == 0, "acc_S_row must have an even number of elements" row_max_scaled = row_max * self.scale_log2 + max_offset = Float32(self.max_offset) + bias = max_offset - row_max_scaled for i in cutlass.range(0, 
cute.size(acc_S_row.shape), 2, unroll_full=True): acc_S_row[i], acc_S_row[i + 1] = cute.arch.fma_packed_f32x2( (acc_S_row[i], acc_S_row[i + 1]), (self.scale_log2, self.scale_log2), - (-row_max_scaled, -row_max_scaled), + (bias, bias), ) @cute.jit diff --git a/flash_sparse_attn/ops/cute/testing.py b/flash_sparse_attn/ops/cute/testing.py index 6e3c40e..6e4bfed 100644 --- a/flash_sparse_attn/ops/cute/testing.py +++ b/flash_sparse_attn/ops/cute/testing.py @@ -91,20 +91,23 @@ def pad_input(hidden_states, indices, batch, seqlen): return rearrange(output, "(b s) ... -> b s ...", b=batch) -def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False): +def generate_random_padding_mask( + max_seqlen, batch_size, device, mode="random", zero_lengths=False, min_seqlen=None +): assert mode in ["full", "random", "third"] + min_seqlen = min_seqlen if min_seqlen is not None else 0 if zero_lengths else 1 if mode == "full": lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) elif mode == "random": lengths = torch.randint( - max(0 if zero_lengths else 1, max_seqlen - 20), + max(min_seqlen, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device, ) else: lengths = torch.randint( - max(0 if zero_lengths else 1, max_seqlen // 3), + max(min_seqlen, max_seqlen // 3), max_seqlen + 1, (batch_size, 1), device=device, @@ -343,6 +346,8 @@ def attention_ref( upcast=True, reorder_ops=False, intermediate_dtype=None, + return_lse=False, + gather_kv_indices=None, ): if causal: window_size = (window_size[0], 0) @@ -399,10 +404,21 @@ def attention_ref( local_mask = ( torch.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask ) + if gather_kv_indices is not None: + batch = q.shape[0] + topk_len = gather_kv_indices.shape[2] + if topk_len < seqlen_k: + topk_index_mask = torch.full( + (batch, seqlen_q, seqlen_k), False, device="cuda" + ).scatter_(-1, gather_kv_indices, True) + scores.masked_fill_(rearrange(~topk_index_mask, "b t s -> b 1 t s"), float("-inf")) if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: scores = scores + attn_bias + # After all masks are applied, before softmax: + # scores shape: [b, h, t, s] + lse = torch.logsumexp(scores, dim=-1) # [b, h, t] if learnable_sink is None: attention = torch.softmax(scores, dim=-1).to(v.dtype) else: @@ -414,6 +430,8 @@ def attention_ref( normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp( learnable_sink - logits_or_sinks_max ) + # LSE with sink: log(Z) = log(normalizer) + max + lse = (torch.log(normalizer.squeeze(-1)) + logits_or_sinks_max.squeeze(-1)).to(dtype_og) attention = (unnormalized_scores / normalizer).to(v.dtype) if query_padding_mask is not None: attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) @@ -431,6 +449,8 @@ def attention_ref( output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + if return_lse: + return output.to(dtype_og), attention.to(dtype_og), lse.to(dtype_og) return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index 3ee4bc8..ae57858 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -402,6 +402,7 @@ class Params(ParamsBase): 
cluster_shape_m: cutlass.Constexpr[int] = 1 scheduling_mode: cutlass.Constexpr[SchedulingMode] = SchedulingMode.STATIC lpt: cutlass.Constexpr[bool] = True + use_cluster_idx: cutlass.Constexpr[bool] = True @staticmethod @cute.jit @@ -445,6 +446,7 @@ def create( cluster_shape_m=args.cluster_shape_mn[0], scheduling_mode=scheduling_mode, lpt=args.lpt, + use_cluster_idx=args.use_cluster_idx, ) def __init__( @@ -532,12 +534,19 @@ def clc_work_to_coords(self, work) -> WorkTileInfo: block_idx = block_idx // self.params.cluster_shape_m if const_expr(self.params.lpt): # Longest-processing-time-first: reverse block order - block_idx = self.params.num_block - 1 - block_idx + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + num_block = self.params.num_block // self.params.cluster_shape_m + else: + num_block = self.params.num_block + block_idx = num_block - 1 - block_idx split_idx = Int32(0) if const_expr(self.params.is_split_kv): batch_idx, split_idx = divmod(work.tile_idx[2], self.params.num_splits_divmod) else: batch_idx = work.tile_idx[2] + if const_expr(self.params.cluster_shape_m > 1 and not self.params.use_cluster_idx): + bidx_in_cluster = cute.arch.block_in_cluster_idx() + block_idx = block_idx * self.params.cluster_shape_m + bidx_in_cluster[0] return WorkTileInfo( (Int32(block_idx), Int32(work.tile_idx[1]), Int32(batch_idx), Int32(split_idx)), work.is_valid_tile, diff --git a/flash_sparse_attn/ops/cute/topk_gather_kv.py b/flash_sparse_attn/ops/cute/topk_gather_kv.py new file mode 100644 index 0000000..67169fb --- /dev/null +++ b/flash_sparse_attn/ops/cute/topk_gather_kv.py @@ -0,0 +1,274 @@ +from typing import Type, Optional +from dataclasses import dataclass +import operator + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync +from cutlass import Int32, Uint32, const_expr, Boolean + +from flash_attn.cute import utils +from flash_attn.cute.utils import warp_reduce +from quack.cute_dsl_utils import ParamsBase + +import math + + +@dataclass +class CpasyncGatherKVManager(ParamsBase): + mIndexTopk: cute.Tensor + sBitmask: cute.Tensor + + cta_rank_in_cluster: Int32 + thread_idx: Int32 + warp_idx: Int32 + + topk_length: Int32 + seqlen_k_limit: Int32 + tile_n: Int32 + num_threads: cutlass.Constexpr[Int32] + hdim: cutlass.Constexpr[Int32] + hdim_v: cutlass.Constexpr[Int32] + num_hdimv_splits: cutlass.Constexpr[Int32] + cta_group_size: cutlass.Constexpr[Int32] + + gmem_threads_per_row: cutlass.Constexpr[Int32] + topk_indices_per_thread: Int32 + async_copy_elems: Int32 + + gmem_tiled_copy_KV: cute.TiledCopy + gmem_thr_copy_KV: cute.TiledCopy + + rTopk: cute.Tensor + rTopkHalf: cute.Tensor + # for bitmask + rTopk_NonInterleaved: cute.Tensor + + pipeline_bitmask: Optional[pipeline.PipelineAsync] + cpasync_barrier: pipeline.NamedBarrier + + disable_bitmask: cutlass.Constexpr[Boolean] + + @staticmethod + def create( + mIndexTopk: cute.Tensor, + sBitmask: cute.Tensor, + cta_rank_in_cluster: Int32, + thread_idx: Int32, + warp_idx: Int32, + topk_length: Int32, + seqlen_k_limit: Int32, + tile_n: cutlass.Constexpr[Int32], + hdim: cutlass.Constexpr[Int32], + hdim_v: cutlass.Constexpr[Int32], + num_hdimv_splits: cutlass.Constexpr[Int32], + num_threads: cutlass.Constexpr[Int32], + dtype: Type[cutlass.Numeric], + cta_group_size: cutlass.Constexpr[Int32], + pipeline_bitmask: Optional[pipeline.PipelineAsync], + num_stages_bitmask: cutlass.Constexpr[Int32], + cpasync_barrier: pipeline.NamedBarrier, + 
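+        # When set, the per-warp validity bitmask (compute_bitmask) and the
+        # predicated copies in load_X are skipped; every gathered row is
+        # assumed to be in range.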
disable_bitmask: cutlass.Constexpr[Boolean], + ): + assert tile_n % num_threads == 0 + assert num_threads == 128 + assert hdim % 64 == 0 + assert (hdim_v // num_hdimv_splits // cta_group_size) % 64 == 0 + assert num_threads % cute.arch.WARP_SIZE == 0 + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // dtype.width + dtype_bytes = dtype.width // 8 + # assumes hdim is never part of transposed operand + gmem_k_block_size = math.gcd( + hdim, + hdim_v // num_hdimv_splits // cta_group_size, + 128 // dtype_bytes, + ) + assert gmem_k_block_size % async_copy_elems == 0 + gmem_threads_per_row = gmem_k_block_size // async_copy_elems + assert cute.arch.WARP_SIZE % gmem_threads_per_row == 0 + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + dtype, + num_bits_per_copy=universal_copy_bits, + ) + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, async_copy_elems)) + gmem_tiled_copy_KV = cute.make_tiled_copy_tv(atom_async_copy, thr_layout, val_layout) + gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(thread_idx) + topk_indices_per_thread = tile_n // num_threads + + rTopk = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + rTopkHalf = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + rTopk_NonInterleaved = cute.make_rmem_tensor((topk_indices_per_thread,), Int32) + + return CpasyncGatherKVManager( + mIndexTopk, + sBitmask, + cta_rank_in_cluster, + thread_idx, + warp_idx, + topk_length, + seqlen_k_limit, + tile_n, + num_threads, + hdim, + hdim_v, + num_hdimv_splits, + cta_group_size, + gmem_threads_per_row, + topk_indices_per_thread, + async_copy_elems, + gmem_tiled_copy_KV, + gmem_thr_copy_KV, + rTopk, + rTopkHalf, + rTopk_NonInterleaved, + pipeline_bitmask, + cpasync_barrier, + disable_bitmask, + ) + + @cute.jit + def load_index_topk( + self, + n_block: Int32, + transpose: bool, + ): + entries_per_thread = self.topk_indices_per_thread + rTopk = self.rTopk if const_expr(transpose) else self.rTopkHalf + + for i in cutlass.range_constexpr(entries_per_thread): + row = ( + i * self.num_threads + + (self.thread_idx % self.gmem_threads_per_row) + * (self.num_threads // self.gmem_threads_per_row) + + (self.thread_idx // self.gmem_threads_per_row) + ) + # need this if not offset in load_X + # if const_expr(not transpose): + # row += self.cta_rank_in_cluster * (self.tile_n//self.cta_group_size) + # row = row % self.tile_n + row_idx = n_block * self.tile_n + row + rTopk[i] = self.mIndexTopk[row_idx] + + if const_expr(not transpose and not self.disable_bitmask): + row_non_interleaved = i * self.num_threads + self.thread_idx + row_idx_non_interleaved = n_block * self.tile_n + row_non_interleaved + self.rTopk_NonInterleaved[0] = self.mIndexTopk[row_idx_non_interleaved] + + @cute.jit + def compute_bitmask( + self, + producer_state_bitmask, + ): + lane_idx = cute.arch.lane_idx() + assert cute.size(self.rTopk_NonInterleaved) == 1 + bitmask = Uint32(0) + + # Step 1. Construct per-thread bitmask + topk_idx = self.rTopk_NonInterleaved[0] + is_valid = topk_idx >= 0 and topk_idx < self.seqlen_k_limit + if is_valid: + bitmask = Uint32(1 << lane_idx) + + # Step 2. Warp shuffle bitwise OR = add since indices are exclusive. 
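+        # Each lane contributes at most its own bit (1 << lane_idx), so the
+        # per-lane masks are disjoint and integer addition is equivalent to a
+        # bitwise OR across the warp.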
+ bitmask = warp_reduce(bitmask, operator.add) + + self.pipeline_bitmask.producer_acquire(producer_state_bitmask) + # store to smem and sync threads + if lane_idx == 0: + self.sBitmask[self.warp_idx, producer_state_bitmask.index] = bitmask + self.cpasync_barrier.arrive_and_wait() + + self.pipeline_bitmask.producer_commit(producer_state_bitmask) + producer_state_bitmask.advance() + return producer_state_bitmask + + @cute.jit + def compute_X_ptr( + self, + mX: cute.Tensor, + transpose: bool, + ): + entries_per_thread = self.topk_indices_per_thread + tPrXPtr = cute.make_rmem_tensor((entries_per_thread,), cutlass.Int64) + tPrRowValid = cute.make_rmem_tensor((entries_per_thread,), cutlass.Int32) + rTopk = self.rTopk if const_expr(transpose) else self.rTopkHalf + + for i in cutlass.range_constexpr(entries_per_thread): + topk_idx = rTopk[i] + if const_expr(not self.disable_bitmask): + row_valid = topk_idx >= 0 and topk_idx < self.seqlen_k_limit + tPrRowValid[i] = row_valid + if const_expr(not transpose): + tPrXPtr[i] = utils.elem_pointer(mX, (topk_idx, 0)).toint() + else: + tPrXPtr[i] = utils.elem_pointer(mX, (0, topk_idx)).toint() + + return tPrXPtr, tPrRowValid + + @cute.jit + def load_X( + self, + mX: cute.Tensor, + sX: cute.Tensor, + transpose: bool, + K_or_V: str, + ): + assert K_or_V in ("K", "V") + cta_tile_n = self.tile_n if const_expr(transpose) else self.tile_n // self.cta_group_size + head_dim = self.hdim if const_expr(K_or_V == "K") else self.hdim_v // self.num_hdimv_splits + if const_expr(transpose): + head_dim = head_dim // self.cta_group_size + order = (1, 0) if const_expr(transpose) else (0, 1) + + sX_nd_layout = cute.make_ordered_layout((cta_tile_n, head_dim), order=order) + sX_nd = cute.composition(sX, sX_nd_layout) + + cX = cute.make_identity_tensor((cta_tile_n, head_dim)) + tXsX = self.gmem_thr_copy_KV.partition_D(sX_nd) + tXcX = self.gmem_thr_copy_KV.partition_S(cX) + + tPrXPtr, tPrRowValid = self.compute_X_ptr(mX, transpose) + + if const_expr(not transpose): + offset = self.cta_rank_in_cluster * (self.gmem_threads_per_row // self.cta_group_size) + else: + offset = 0 + + for m in cutlass.range_constexpr(cute.size(tXsX, mode=[1])): + if const_expr(not self.disable_bitmask): + row_valid = utils.shuffle_sync( + tPrRowValid[m // self.gmem_threads_per_row], + (m + offset) % self.gmem_threads_per_row, + width=self.gmem_threads_per_row, + ) + should_load = cute.make_fragment_like(tXsX[(0, None), m, 0], Boolean) + should_load.fill(Boolean(row_valid)) + x_ptr_i64 = utils.shuffle_sync( + tPrXPtr[m // self.gmem_threads_per_row], + (m + offset) % self.gmem_threads_per_row, + width=self.gmem_threads_per_row, + ) + x_gmem_ptr = cute.make_ptr( + mX.element_type, x_ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + mX_cur = cute.make_tensor(x_gmem_ptr, cute.make_layout((head_dim,))) + mX_cur_copy = cute.tiled_divide(mX_cur, (self.async_copy_elems,)) + + for k in cutlass.range_constexpr(cute.size(tXsX, mode=[2])): + ki = tXcX[0, 0, k][1] // self.async_copy_elems + mX_cur_copy_ki = mX_cur_copy[None, ki] + tXsX_k = tXsX[None, m, k] + mX_cur_copy_ki = cute.make_tensor(mX_cur_copy_ki.iterator, tXsX_k.layout) + cute.copy( + self.gmem_tiled_copy_KV, + mX_cur_copy_ki, + tXsX_k, + pred=should_load if const_expr(not self.disable_bitmask) else None, + ) diff --git a/flash_sparse_attn/ops/cute/utils.py b/flash_sparse_attn/ops/cute/utils.py index 3118661..a392305 100644 --- a/flash_sparse_attn/ops/cute/utils.py +++ b/flash_sparse_attn/ops/cute/utils.py @@ -60,12 +60,33 @@ _fa_disable_2cta_enabled: 
bool = os.environ.get("FA_DISABLE_2CTA", "0") == "1" +def _is_cuda_12() -> bool: + """Check if the CUDA toolkit version is 12.x. + + 2CTA forward non-causal has a codegen regression on CUDA 12 that causes + ~18% slowdown compared to 1CTA. This is fixed in CUDA 13.x. + """ + try: + import torch + + cuda_version = torch.version.cuda + if cuda_version is not None: + major = cuda_version.split(".")[0] + return int(major) == 12 + except Exception: + pass + return False + + +_fa_disable_2cta_cuda12: bool = _is_cuda_12() + + def _get_use_clc_scheduler_default() -> bool: return _fa_clc_enabled def _get_disable_2cta_default() -> bool: - return _fa_disable_2cta_enabled + return _fa_disable_2cta_enabled or _fa_disable_2cta_cuda12 def _compute_base_hash(func: Callable) -> str: @@ -126,16 +147,28 @@ def hash_callable( def create_softcap_scoremod(softcap_val): - inv_softcap = 1.0 / softcap_val - @cute.jit - def scoremod_premask_fn(acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, aux_tensors): - scores = acc_S_SSA * inv_softcap - return scores * cute.math.tanh(scores, fastmath=True) + def scoremod_premask_fn( + acc_S_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors + ): + scores = acc_S_SSA / softcap_val + return softcap_val * cute.math.tanh(scores, fastmath=True) return scoremod_premask_fn +def create_softcap_scoremod_bwd(softcap_val): + @cute.jit + def scoremod_bwd_fn( + grad_out_SSA, score_SSA, batch_idx, head_idx, q_idx, kv_idx, seqlen_info, aux_tensors + ): + scores = score_SSA / softcap_val + tanh_scores = cute.math.tanh(scores, fastmath=True) + return grad_out_SSA * (1.0 - tanh_scores * tanh_scores) + + return scoremod_bwd_fn + + LOG2_E = math.log2(math.e) From 97082f6677f1d9848f4b09026943c14dac4d1e25 Mon Sep 17 00:00:00 2001 From: LoserCheems Date: Tue, 21 Apr 2026 16:46:33 +0800 Subject: [PATCH 3/3] Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute --- flash_sparse_attn/ops/cute/__init__.py | 2 +- .../ops/cute/benchmark_flash_attention_fp8.py | 6 +-- .../ops/cute/blackwell_helpers.py | 2 +- flash_sparse_attn/ops/cute/block_info.py | 2 +- .../ops/cute/block_sparse_utils.py | 4 +- flash_sparse_attn/ops/cute/block_sparsity.py | 2 +- flash_sparse_attn/ops/cute/cache_utils.py | 2 +- .../ops/cute/compute_block_sparsity.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd.py | 14 +++---- .../ops/cute/flash_bwd_postprocess.py | 10 ++--- .../ops/cute/flash_bwd_preprocess.py | 6 +-- flash_sparse_attn/ops/cute/flash_bwd_sm100.py | 28 ++++++------- flash_sparse_attn/ops/cute/flash_bwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_bwd_sm90.py | 24 +++++------ flash_sparse_attn/ops/cute/flash_fwd.py | 24 +++++------ .../ops/cute/flash_fwd_combine.py | 6 +-- .../ops/cute/flash_fwd_mla_sm100.py | 26 ++++++------ flash_sparse_attn/ops/cute/flash_fwd_sm100.py | 34 ++++++++-------- flash_sparse_attn/ops/cute/flash_fwd_sm120.py | 2 +- flash_sparse_attn/ops/cute/flash_fwd_sm90.py | 28 ++++++------- flash_sparse_attn/ops/cute/interface.py | 40 +++++++++---------- flash_sparse_attn/ops/cute/mask.py | 4 +- flash_sparse_attn/ops/cute/pack_gqa.py | 2 +- flash_sparse_attn/ops/cute/paged_kv.py | 2 +- flash_sparse_attn/ops/cute/softmax.py | 4 +- flash_sparse_attn/ops/cute/tile_scheduler.py | 4 +- flash_sparse_attn/ops/cute/topk_gather_kv.py | 4 +- 27 files changed, 145 insertions(+), 145 deletions(-) diff --git a/flash_sparse_attn/ops/cute/__init__.py b/flash_sparse_attn/ops/cute/__init__.py index be32e14..3dbaebd 100644 --- a/flash_sparse_attn/ops/cute/__init__.py +++ 
b/flash_sparse_attn/ops/cute/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import PackageNotFoundError, version try: - __version__ = version("fa4") + __version__ = version("flash-sparse-attn") except PackageNotFoundError: __version__ = "0.0.0" diff --git a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py index c79e768..09d5143 100644 --- a/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py +++ b/flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py @@ -1,7 +1,7 @@ # Benchmark FP8 attention for FA4 (CuTe-DSL) on SM100. # # Run (recommended): -# python -m flash_attn.cute.benchmark_flash_attention_fp8 +# python -m flash_sparse_attn.ops.cute.benchmark_flash_attention_fp8 # # Notes: # - This is intended to be used while bringing up FP8 support for SM100. @@ -21,8 +21,8 @@ import torch from einops import rearrange -from flash_attn.cute.benchmark import benchmark_forward -from flash_attn.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd +from flash_sparse_attn.ops.cute.benchmark import benchmark_forward +from flash_sparse_attn.ops.cute.interface import _flash_attn_fwd as flash_attn_cute_fwd try: import cudnn diff --git a/flash_sparse_attn/ops/cute/blackwell_helpers.py b/flash_sparse_attn/ops/cute/blackwell_helpers.py index 4caadce..4bc9fe5 100644 --- a/flash_sparse_attn/ops/cute/blackwell_helpers.py +++ b/flash_sparse_attn/ops/cute/blackwell_helpers.py @@ -7,7 +7,7 @@ from cutlass.cute.nvgpu import tcgen05 from cutlass._mlir.dialects import llvm -import flash_attn.cute.mma_sm100_desc as sm100_desc +import flash_sparse_attn.ops.cute.mma_sm100_desc as sm100_desc def _tcgen05_mma_kind(op: cute.nvgpu.tcgen05.mma.MmaOp) -> str: diff --git a/flash_sparse_attn/ops/cute/block_info.py b/flash_sparse_attn/ops/cute/block_info.py index f210138..cebd0bf 100644 --- a/flash_sparse_attn/ops/cute/block_info.py +++ b/flash_sparse_attn/ops/cute/block_info.py @@ -6,7 +6,7 @@ import cutlass.cute as cute from cutlass import Int32, const_expr -from flash_attn.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK, SeqlenInfoQKNewK @dataclass(frozen=True) diff --git a/flash_sparse_attn/ops/cute/block_sparse_utils.py b/flash_sparse_attn/ops/cute/block_sparse_utils.py index b19dcd3..66edb6a 100644 --- a/flash_sparse_attn/ops/cute/block_sparse_utils.py +++ b/flash_sparse_attn/ops/cute/block_sparse_utils.py @@ -15,8 +15,8 @@ from quack import copy_utils # Import data structures from block_sparsity -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd # NOTE [SM100 block-sparse empty tiles: mbarrier contract] diff --git a/flash_sparse_attn/ops/cute/block_sparsity.py b/flash_sparse_attn/ops/cute/block_sparsity.py index 3fad8c9..9e34734 100644 --- a/flash_sparse_attn/ops/cute/block_sparsity.py +++ b/flash_sparse_attn/ops/cute/block_sparsity.py @@ -7,7 +7,7 @@ import cutlass.cute as cute import torch -from flash_attn.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor +from flash_sparse_attn.ops.cute.cute_dsl_utils import get_broadcast_dims, to_cute_tensor def ceildiv(a: int, b: int) -> int: diff --git a/flash_sparse_attn/ops/cute/cache_utils.py b/flash_sparse_attn/ops/cute/cache_utils.py index f1b5970..dc04970 100644 --- 
a/flash_sparse_attn/ops/cute/cache_utils.py +++ b/flash_sparse_attn/ops/cute/cache_utils.py @@ -17,7 +17,7 @@ import cutlass.cute as cute import tvm_ffi from cutlass.cutlass_dsl import JitCompiledFunction -from flash_attn.cute.fa_logging import fa_log +from flash_sparse_attn.ops.cute.fa_logging import fa_log # Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols # (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen. diff --git a/flash_sparse_attn/ops/cute/compute_block_sparsity.py b/flash_sparse_attn/ops/cute/compute_block_sparsity.py index a2dd98e..d986ecb 100644 --- a/flash_sparse_attn/ops/cute/compute_block_sparsity.py +++ b/flash_sparse_attn/ops/cute/compute_block_sparsity.py @@ -6,13 +6,13 @@ import torch from cutlass import Boolean, Int8, Int32, const_expr -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensors, BlockSparseTensorsTorch, to_cute_block_sparse_tensors, ) -from flash_attn.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK class BlockSparsityKernel: diff --git a/flash_sparse_attn/ops/cute/flash_bwd.py b/flash_sparse_attn/ops/cute/flash_bwd.py index eeb7615..081caa1 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd.py +++ b/flash_sparse_attn/ops/cute/flash_bwd.py @@ -15,14 +15,14 @@ import cutlass.utils as utils_basic from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments -from flash_attn.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors class FlashAttentionBackwardSm80: diff --git a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py index 76c8562..36e4e8d 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_postprocess.py @@ -18,13 +18,13 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK import cutlass.cute.nvgpu.tcgen05 as tcgen05 from quack.cute_dsl_utils import 
ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py index d93ea5c..698b19f 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_preprocess.py @@ -25,10 +25,10 @@ from quack import copy_utils, layout_utils -from flash_attn.cute import utils -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py index 4b4083e..104fdb5 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm100.py @@ -16,27 +16,27 @@ import quack.activation from quack import layout_utils -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import copy_utils -from flash_attn.cute import pipeline -from flash_attn.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import copy_utils +from flash_sparse_attn.ops.cute import pipeline +from flash_sparse_attn.ops.cute.blackwell_helpers import gemm_w_idx, gemm_ptx_w_idx # noqa +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, # noqa SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwdSm100 -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwdSm100 +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, get_block_sparse_iteration_info_bwd, get_m_block_from_iter_bwd, diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py index 556c59e..0941ae2 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 +from 
flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 class FlashAttentionBackwardSm120(FlashAttentionBackwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py index c9a690d..755e580 100644 --- a/flash_sparse_attn/ops/cute/flash_bwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_bwd_sm90.py @@ -17,24 +17,24 @@ from quack import sm90_utils from quack.sm90_utils import gemm_zero_init, gemm_w_idx -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute import pipeline +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute import pipeline from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTBwdScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute import barrier -from flash_attn.cute.named_barrier import NamedBarrierBwd -from flash_attn.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute import barrier +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierBwd +from flash_sparse_attn.ops.cute.softmax import apply_score_mod_inner, apply_score_mod_bwd_inner +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_q_block_count_bwd, produce_block_sparse_q_loads_bwd_sm90, consume_block_sparse_mma_bwd_sm90, diff --git a/flash_sparse_attn/ops/cute/flash_fwd.py b/flash_sparse_attn/ops/cute/flash_fwd.py index d1a43cf..399efd3 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd.py +++ b/flash_sparse_attn/ops/cute/flash_fwd.py @@ -23,17 +23,17 @@ from quack import copy_utils from quack import layout_utils -from flash_attn.cute import ampere_helpers as sm80_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.pack_gqa import PackGQA -from flash_attn.cute.named_barrier import NamedBarrierFwd -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments +from flash_sparse_attn.ops.cute import ampere_helpers as sm80_utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from 
flash_sparse_attn.ops.cute.pack_gqa import PackGQA +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.tile_scheduler import SingleTileScheduler, SingleTileVarlenScheduler, TileSchedulerArguments class FlashAttentionForwardBase: @@ -1224,6 +1224,6 @@ def apply_score_mod( # SM90 forward pass moved to flash_fwd_sm90.py; re-export for backward compatibility def __getattr__(name): if name == "FlashAttentionForwardSm90": - from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 + from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 return FlashAttentionForwardSm90 raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/flash_sparse_attn/ops/cute/flash_fwd_combine.py b/flash_sparse_attn/ops/cute/flash_fwd_combine.py index 4936202..7f38c2b 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_combine.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_combine.py @@ -12,9 +12,9 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Float32, Int32, Boolean, const_expr -from flash_attn.cute import utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute.seqlen_info import SeqlenInfo +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfo from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py index 07cd99f..046457b 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py @@ -19,13 +19,13 @@ from quack import copy_utils -from flash_attn.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.mask import AttentionMask -import flash_attn.cute.blackwell_helpers as fa_sm100_utils -from flash_attn.cute.softmax import SoftmaxSm100 -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.pack_gqa import pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.mask import AttentionMask +import flash_sparse_attn.ops.cute.blackwell_helpers as fa_sm100_utils +from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100 +from flash_sparse_attn.ops.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -35,16 +35,16 @@ SingleTileVarlenScheduler, ParamsBase, ) -from flash_attn.cute.fa_logging import fa_log, fa_printf -from flash_attn.cute.utils import smid +from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf +from flash_sparse_attn.ops.cute.utils import smid -from flash_attn.cute.topk_gather_kv import CpasyncGatherKVManager +from flash_sparse_attn.ops.cute.topk_gather_kv import CpasyncGatherKVManager -from flash_attn.cute.testing import attention_ref +from flash_sparse_attn.ops.cute.testing import attention_ref -from flash_attn.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100_MLA2CTA -from flash_attn.cute.cute_dsl_utils import dump_kernel_attributes +from 
flash_sparse_attn.ops.cute.cute_dsl_utils import dump_kernel_attributes class FlashAttentionMLAForwardSm100: diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index 75e767c..040bb43 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -33,29 +33,29 @@ from quack import copy_utils, layout_utils -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -import flash_attn.cute.pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +import flash_sparse_attn.ops.cute.pipeline as pipeline_custom import cutlass.pipeline as cutlass_pipeline -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import SoftmaxSm100, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import SoftmaxSm100, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( get_total_block_count, produce_block_sparse_loads_sm100, softmax_block_sparse_sm100, handle_block_sparse_empty_tile_correction_sm100, ) -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout -from flash_attn.cute import mma_sm100_desc as sm100_desc -from flash_attn.cute import blackwell_helpers as sm100_utils -from flash_attn.cute.named_barrier import NamedBarrierFwdSm100 +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout +from flash_sparse_attn.ops.cute import mma_sm100_desc as sm100_desc +from flash_sparse_attn.ops.cute import blackwell_helpers as sm100_utils +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwdSm100 from cutlass.cute import FastDivmodDivisor from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( ClcState, SchedulingMode, TileSchedulerArguments, @@ -65,8 +65,8 @@ SingleTileLPTScheduler, SingleTileVarlenScheduler, ) -from flash_attn.cute.fa_logging import fa_log, fa_printf -from flash_attn.cute.utils import smid +from flash_sparse_attn.ops.cute.fa_logging import fa_log, fa_printf +from flash_sparse_attn.ops.cute.utils import smid # === TUNING KNOBS (agent-editable) === # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py index 08d219a..f9917d8 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm120.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm120.py @@ -8,7 +8,7 @@ import cutlass import cutlass.utils as utils_basic -from flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 class FlashAttentionForwardSm120(FlashAttentionForwardSm80): diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py 
b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py index 4108ce4..8ae0e77 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm90.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm90.py @@ -21,23 +21,23 @@ from quack import layout_utils from quack import sm90_utils -from flash_attn.cute.cute_dsl_utils import assume_tensor_aligned -from flash_attn.cute import utils -from flash_attn.cute.mask import AttentionMask -from flash_attn.cute.softmax import Softmax, apply_score_mod_inner -from flash_attn.cute.seqlen_info import SeqlenInfoQK -from flash_attn.cute.block_info import BlockInfo -from flash_attn.cute.block_sparsity import BlockSparseTensors -from flash_attn.cute.block_sparse_utils import ( +from flash_sparse_attn.ops.cute.cute_dsl_utils import assume_tensor_aligned +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.mask import AttentionMask +from flash_sparse_attn.ops.cute.softmax import Softmax, apply_score_mod_inner +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.block_info import BlockInfo +from flash_sparse_attn.ops.cute.block_sparsity import BlockSparseTensors +from flash_sparse_attn.ops.cute.block_sparse_utils import ( produce_block_sparse_loads, consume_block_sparse_loads, ) -from flash_attn.cute import pipeline as pipeline_custom -from flash_attn.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom -from flash_attn.cute.paged_kv import PagedKVManager -from flash_attn.cute.named_barrier import NamedBarrierFwd +from flash_sparse_attn.ops.cute import pipeline as pipeline_custom +from flash_sparse_attn.ops.cute.pack_gqa import PackGQA, pack_gqa_layout, make_packgqa_tiled_tma_atom +from flash_sparse_attn.ops.cute.paged_kv import PagedKVManager +from flash_sparse_attn.ops.cute.named_barrier import NamedBarrierFwd from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.tile_scheduler import ( +from flash_sparse_attn.ops.cute.tile_scheduler import ( TileSchedulerArguments, SingleTileScheduler, SingleTileLPTScheduler, @@ -45,7 +45,7 @@ ) from cutlass.cute import FastDivmodDivisor -from flash_attn.cute.flash_fwd import FlashAttentionForwardBase +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardBase class FlashAttentionForwardSm90(FlashAttentionForwardBase): diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index 17b703b..cd2e005 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -16,36 +16,36 @@ import cutlass.cute as cute from cutlass import Int32, Float32 from quack.compile_utils import make_fake_tensor as fake_tensor -from flash_attn.cute.cache_utils import get_jit_cache -from flash_attn.cute.testing import is_fake_mode +from flash_sparse_attn.ops.cute.cache_utils import get_jit_cache +from flash_sparse_attn.ops.cute.testing import is_fake_mode if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: - from flash_attn.cute import cute_dsl_ptxas # noqa: F401 + from flash_sparse_attn.ops.cute import cute_dsl_ptxas # noqa: F401 # Patch to dump ptx and then use system ptxas to compile to cubin cute_dsl_ptxas.patch() -from flash_attn.cute import utils -from flash_attn.cute import fa_logging -from flash_attn.cute.cute_dsl_utils import ( +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute import fa_logging +from flash_sparse_attn.ops.cute.cute_dsl_utils import ( to_cute_tensor, to_cute_aux_tensor, get_aux_tensor_metadata, get_broadcast_dims, ) -from 
flash_attn.cute.flash_fwd import FlashAttentionForwardSm80 -from flash_attn.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 -from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors -from flash_attn.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 -from flash_attn.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess -from flash_attn.cute.flash_bwd import FlashAttentionBackwardSm80 -from flash_attn.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 -from flash_attn.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 -from flash_attn.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 -from flash_attn.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess -from flash_attn.cute.flash_fwd_combine import FlashAttentionForwardCombine -from flash_attn.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 - -from flash_attn.cute.block_sparsity import ( +from flash_sparse_attn.ops.cute.flash_fwd import FlashAttentionForwardSm80 +from flash_sparse_attn.ops.cute.flash_fwd_sm90 import FlashAttentionForwardSm90 +from flash_sparse_attn.ops.cute.flash_fwd_sm100 import FlashAttentionForwardSm100, DescaleTensors +from flash_sparse_attn.ops.cute.flash_fwd_sm120 import FlashAttentionForwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_preprocess import FlashAttentionBackwardPreprocess +from flash_sparse_attn.ops.cute.flash_bwd import FlashAttentionBackwardSm80 +from flash_sparse_attn.ops.cute.flash_bwd_sm90 import FlashAttentionBackwardSm90 +from flash_sparse_attn.ops.cute.flash_bwd_sm100 import FlashAttentionBackwardSm100 +from flash_sparse_attn.ops.cute.flash_bwd_sm120 import FlashAttentionBackwardSm120 +from flash_sparse_attn.ops.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess +from flash_sparse_attn.ops.cute.flash_fwd_combine import FlashAttentionForwardCombine +from flash_sparse_attn.ops.cute.flash_fwd_mla_sm100 import FlashAttentionMLAForwardSm100 + +from flash_sparse_attn.ops.cute.block_sparsity import ( BlockSparseTensorsTorch, get_sparse_q_block_size, to_cute_block_sparse_tensors, diff --git a/flash_sparse_attn/ops/cute/mask.py b/flash_sparse_attn/ops/cute/mask.py index 99e7008..c7832f9 100644 --- a/flash_sparse_attn/ops/cute/mask.py +++ b/flash_sparse_attn/ops/cute/mask.py @@ -8,8 +8,8 @@ from cutlass import Float32, Int32, Uint32, const_expr, Boolean from quack import layout_utils -import flash_attn.cute.utils as utils -from flash_attn.cute.seqlen_info import SeqlenInfoQK +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK MaskGenFn: TypeAlias = Callable[[int], Uint32] MASK_R2P_CHUNK_SIZE: int = 32 diff --git a/flash_sparse_attn/ops/cute/pack_gqa.py b/flash_sparse_attn/ops/cute/pack_gqa.py index e87df01..484834f 100644 --- a/flash_sparse_attn/ops/cute/pack_gqa.py +++ b/flash_sparse_attn/ops/cute/pack_gqa.py @@ -9,7 +9,7 @@ from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils def pack_gqa_layout(T, qhead_per_kvhead, nheads_kv, head_idx): diff --git a/flash_sparse_attn/ops/cute/paged_kv.py b/flash_sparse_attn/ops/cute/paged_kv.py index bf11acb..f84b35c 100644 --- a/flash_sparse_attn/ops/cute/paged_kv.py +++ b/flash_sparse_attn/ops/cute/paged_kv.py @@ -6,7 +6,7 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, const_expr -from flash_attn.cute import utils +from flash_sparse_attn.ops.cute import utils from quack.cute_dsl_utils import ParamsBase 
from cutlass.cute import FastDivmodDivisor diff --git a/flash_sparse_attn/ops/cute/softmax.py b/flash_sparse_attn/ops/cute/softmax.py index 9369e0d..79eecd4 100644 --- a/flash_sparse_attn/ops/cute/softmax.py +++ b/flash_sparse_attn/ops/cute/softmax.py @@ -10,9 +10,9 @@ from cutlass import Float32, Boolean from quack import layout_utils -import flash_attn.cute.utils as utils +import flash_sparse_attn.ops.cute.utils as utils from quack.cute_dsl_utils import ParamsBase -from flash_attn.cute.seqlen_info import SeqlenInfoQK +from flash_sparse_attn.ops.cute.seqlen_info import SeqlenInfoQK @dataclass diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index ae57858..002e7cd 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -19,8 +19,8 @@ from quack.cute_dsl_utils import ParamsBase -import flash_attn.cute.utils as utils -from flash_attn.cute.fast_math import clz +import flash_sparse_attn.ops.cute.utils as utils +from flash_sparse_attn.ops.cute.fast_math import clz class SchedulingMode(IntEnum): diff --git a/flash_sparse_attn/ops/cute/topk_gather_kv.py b/flash_sparse_attn/ops/cute/topk_gather_kv.py index 67169fb..5795ae0 100644 --- a/flash_sparse_attn/ops/cute/topk_gather_kv.py +++ b/flash_sparse_attn/ops/cute/topk_gather_kv.py @@ -8,8 +8,8 @@ from cutlass.cute.nvgpu import cpasync from cutlass import Int32, Uint32, const_expr, Boolean -from flash_attn.cute import utils -from flash_attn.cute.utils import warp_reduce +from flash_sparse_attn.ops.cute import utils +from flash_sparse_attn.ops.cute.utils import warp_reduce from quack.cute_dsl_utils import ParamsBase import math
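
The softcap score-mod pair added in utils.py above is self-consistent: the backward
function is the analytic derivative of the forward softcap, since
d/ds [softcap * tanh(s / softcap)] = 1 - tanh(s / softcap)^2. Below is a minimal
sketch of that identity in plain PyTorch; it is illustrative only and not part of
the patch (the kernels use the CuTe DSL versions above, and the helper names here
are hypothetical).

    import torch

    def softcap_fwd(s, softcap):
        # mirrors create_softcap_scoremod: softcap * tanh(s / softcap)
        return softcap * torch.tanh(s / softcap)

    def softcap_bwd(grad_out, s, softcap):
        # mirrors create_softcap_scoremod_bwd: grad * (1 - tanh(s/softcap)^2)
        t = torch.tanh(s / softcap)
        return grad_out * (1.0 - t * t)

    # Check the analytic backward against autograd in float64.
    s = torch.randn(8, dtype=torch.float64, requires_grad=True)
    out = softcap_fwd(s, 30.0)
    out.backward(torch.ones_like(out))
    assert torch.allclose(s.grad, softcap_bwd(torch.ones_like(out), s.detach(), 30.0))

Relatedly, the max_offset bias folded into scale_subtract_rowmax (softmax.py above)
relies on exp2(scale_log2 * s + (max_offset - scale_log2 * row_max))
= 2^max_offset * exp2(scale_log2 * (s - row_max)): a constant per-row factor of
2^max_offset on every probability, which divides out when the row is normalized
by its sum.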