Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flash_sparse_attn/ops/cute/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper
pip install flash-attn-4
```

If you're on CUDA 13, install with the `cu13` extra for best performance:

```sh
pip install "flash-attn-4[cu13]"
```

## Usage

```python
Expand Down
82 changes: 49 additions & 33 deletions flash_sparse_attn/ops/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,28 @@
SingleTileVarlenScheduler,
)

# === TUNING KNOBS (agent-editable) ===
# Lookup table of per-configuration tuning parameters.
#
# Key layout:
#   (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool)
#
# Fields stored for each key:
#   ex2_emu_freq        : int, how often to use emulated exp2
#                         (0=all hardware exp2, higher=more emulation).
#                         SM103 has fast native exp2, so freq is 0 for those keys.
#   ex2_emu_start_frg   : int, fragment index to start emulation from.
#   num_regs_softmax    : int, register count for softmax warps (multiple of 8).
#   num_regs_correction : int, register count for correction warps (multiple of 8).
#
# num_regs_other is not stored in the table; it is derived as
#   512 - num_regs_softmax * 2 - num_regs_correction
_TUNING_CONFIG = {
    # ---- is_sm103 = False ----
    (True, False, 128, False): {
        "ex2_emu_freq": 10,
        "ex2_emu_start_frg": 1,
        "num_regs_softmax": 176,
        "num_regs_correction": 88,
    },
    (False, True, 128, False): {
        "ex2_emu_freq": 16,
        "ex2_emu_start_frg": 1,
        "num_regs_softmax": 192,
        "num_regs_correction": 72,
    },
    (True, False, 192, False): {
        "ex2_emu_freq": 16,
        "ex2_emu_start_frg": 0,
        "num_regs_softmax": 184,
        "num_regs_correction": 80,
    },
    (False, True, 192, False): {
        "ex2_emu_freq": 32,
        "ex2_emu_start_frg": 1,
        "num_regs_softmax": 192,
        "num_regs_correction": 72,
    },
    # ---- is_sm103 = True (hardware exp2 only, so emulation is disabled) ----
    (True, False, 128, True): {
        "ex2_emu_freq": 0,
        "ex2_emu_start_frg": 0,
        "num_regs_softmax": 176,
        "num_regs_correction": 80,
    },
    (False, True, 128, True): {
        "ex2_emu_freq": 0,
        "ex2_emu_start_frg": 0,
        "num_regs_softmax": 176,
        "num_regs_correction": 64,
    },
    (True, False, 192, True): {
        "ex2_emu_freq": 0,
        "ex2_emu_start_frg": 0,
        "num_regs_softmax": 176,
        "num_regs_correction": 64,
    },
    (False, True, 192, True): {
        "ex2_emu_freq": 0,
        "ex2_emu_start_frg": 0,
        "num_regs_softmax": 176,
        "num_regs_correction": 72,
    },
}
# === END TUNING KNOBS ===


class FlashAttentionForwardSm100:

def __init__(
Expand Down Expand Up @@ -141,8 +163,10 @@ def __init__(
# Does S1 need to wait for S0 to finish
# self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local)
is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f
# self.enable_ex2_emu = self.head_dim_padded <= 128 and not is_sm103
self.enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103
self.is_sm103 = is_sm103
# enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic
_default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103
self.enable_ex2_emu = _default_enable_ex2_emu
self.s0_s1_barrier = False
self.overlap_sO_sQ = (
(self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or
Expand Down Expand Up @@ -210,31 +234,26 @@ def __init__(
# vec buffer for row_max & row_sum
self.tmem_vec_offset = self.tmem_s_offset

# Look up tuning config for register counts and ex2_emu params
_tune_key = (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103)
self._tune = _TUNING_CONFIG.get(_tune_key, {})
if "ex2_emu_freq" in self._tune:
self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0
if self.head_dim_padded < 96:
self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
self.num_regs_correction = 64
self.num_regs_other = 48 if not paged_kv_non_tma else 80
else:
# self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184
if not self.enable_ex2_emu:
self.num_regs_softmax = 192 if not paged_kv_non_tma else 184
if not paged_kv_non_tma and "num_regs_softmax" in self._tune:
self.num_regs_softmax = self._tune["num_regs_softmax"]
self.num_regs_correction = self._tune["num_regs_correction"]
elif not paged_kv_non_tma:
self.num_regs_softmax = 192
self.num_regs_correction = 80
else:
# self.num_regs_softmax = 200 if not paged_kv_non_tma else 184
self.num_regs_softmax = 192 if not paged_kv_non_tma else 184
# self.num_regs_softmax = 176
# self.num_regs_correction = 96
# self.num_regs_correction = 64 if self.is_causal or self.is_local else 80
if not self.enable_ex2_emu:
self.num_regs_correction = 80 if not paged_kv_non_tma else 64
else:
# self.num_regs_correction = 64
self.num_regs_correction = 80 if not paged_kv_non_tma else 64
# self.num_regs_other = 32
# self.num_regs_other = 64
# self.num_regs_other = 80
self.num_regs_other = 48 if not paged_kv_non_tma else 80
# self.num_regs_other = 96 if self.is_causal or self.is_local else 80
# self.num_regs_other = 64 if self.is_causal or self.is_local else 80
self.num_regs_softmax = 184
self.num_regs_correction = 64
self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction

self.buffer_align_bytes = 1024

Expand Down Expand Up @@ -358,21 +377,14 @@ def __call__(
and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0)
and not (self.pack_gqa and self.is_split_kv)
)
# This can be tuned
# This is currently very ad-hoc, we should tune it systematically
self.ex2_emu_freq = 0
# self.ex2_emu_start_frg = 1 if self.is_causal else 0
self.ex2_emu_start_frg = 1
self.ex2_emu_start_frg = self._tune.get("ex2_emu_start_frg", 1)
if const_expr(self.enable_ex2_emu):
self.ex2_emu_freq = 16
if const_expr(self.head_dim_padded == 128 and self.use_2cta_instrs):
self.ex2_emu_freq = 12
self.ex2_emu_freq = self._tune.get("ex2_emu_freq", 16)
if const_expr(
self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local
):
self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10
if const_expr(self.head_dim_padded > 64 and self.is_causal):
self.ex2_emu_freq = 10
self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else self._tune.get("ex2_emu_freq", 10)

cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
q_major_mode = tcgen05.OperandMajorMode.K
Expand Down Expand Up @@ -487,7 +499,7 @@ def __call__(
tma_atom_Q = None
async_copy_elems = 128 // self.q_dtype.width
num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids)
threads_per_row = self.head_dim_padded // async_copy_elems
threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads)
gmem_tiled_copy_Q = copy_utils.tiled_copy_2d(
self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True
)
Comment on lines 500 to 505
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy_utils.tiled_copy_2d is defined as tiled_copy_2d(dtype, major_mode_size, num_threads, is_async=False) in flash_sparse_attn/ops/cute/copy_utils.py, but this call passes an extra positional argument (async_copy_elems) and uses threads_per_row as the second argument. As written, this will raise a TypeError when use_tma_Q is false (and even if it didn’t, it would be using the helper with the wrong semantics). Update the call to match the helper’s signature (i.e., pass the actual major-mode size in elements and use the keyword is_async=), or update tiled_copy_2d/call sites consistently if the intended API includes num_copy_elems.

Copilot uses AI. Check for mistakes.
Expand Down Expand Up @@ -550,8 +562,11 @@ def __call__(
if const_expr(not self.is_persistent)
else StaticPersistentTileScheduler
)
# For non-persistent 2CTA (use_cluster_idx), each cluster covers
# cta_tiler[0] * cta_group_size rows, so num_block must be divided accordingly
_num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1)
tile_sched_args = TileSchedulerArguments(
cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]),
cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor),
cute.size(mQ.shape[2]),
cute.size(mQ.shape[3])
if const_expr(mCuSeqlensQ is None)
Expand All @@ -574,6 +589,7 @@ def __call__(
lpt=self.is_causal or self.is_local,
is_split_kv=self.is_split_kv,
cluster_shape_mn=self.cluster_shape_mn,
use_cluster_idx=not self.is_persistent and self.cta_group_size > 1,
)
tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args)
self.tile_scheduler_cls = TileScheduler
Expand Down
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def _flash_attn_fwd(
and seqused_q is None
and not use_block_sparsity
and page_size in [None, 128]
and int(math.ceil(head_dim / 16) * 16) == 128
and int(math.ceil(head_dim / 16) * 16) in [128, 192]
and int(math.ceil(head_dim_v / 16) * 16) == 128
and seqlen_q_packgqa > 2 * tile_m
and (tile_m % qhead_per_kvhead == 0 or not pack_gqa)
Expand Down
11 changes: 11 additions & 0 deletions flash_sparse_attn/ops/cute/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
]

[project.optional-dependencies]
cu13 = ["nvidia-cutlass-dsl[cu13]>=4.4.2"]
dev = [
"pytest",
"ruff",
Expand All @@ -51,6 +52,16 @@ tag_regex = "^fa4-v(?P<version>.+)$"
git_describe_command = "git describe --dirty --tags --long --match 'fa4-v*'"
fallback_version = "0.0.0"

[[tool.uv.index]]
name = "pytorch-cu130"
url = "https://download.pytorch.org/whl/cu130"
explicit = true

[tool.uv.sources]
torch = [
{ index = "pytorch-cu130", marker = "extra == 'cu13'" },
]

[tool.ruff]
line-length = 100

Expand Down
36 changes: 23 additions & 13 deletions flash_sparse_attn/ops/cute/tile_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class TileSchedulerArguments(ParamsBase):
lpt: cutlass.Constexpr[bool] = False
is_split_kv: cutlass.Constexpr[bool] = False
head_swizzle: cutlass.Constexpr[bool] = False
use_cluster_idx: cutlass.Constexpr[bool] = False


class SingleTileScheduler:
Expand All @@ -63,6 +64,7 @@ class Params(ParamsBase):
num_splits_divmod: FastDivmodDivisor
is_split_kv: cutlass.Constexpr[bool] = False
cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1)
use_cluster_idx: cutlass.Constexpr[bool] = False

@staticmethod
def create(
Expand All @@ -76,6 +78,7 @@ def create(
FastDivmodDivisor(args.num_splits),
args.is_split_kv,
args.cluster_shape_mn,
args.use_cluster_idx,
)

def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None):
Expand All @@ -91,13 +94,11 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None)

@staticmethod
def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler":
# if const_expr(cute.size(params.cluster_shape_mn) == 1):
# blk_coord = cute.arch.block_idx()
# else:
# # All CTAs in a cluster must get the same block coordinate
# blk_coord = cute.arch.cluster_idx()
# Temporary set to block_idx until we sort out the best way to handle cluster
blk_coord = cute.arch.block_idx()
if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx):
blk_coord = cute.arch.block_idx()
else:
# All CTAs in a cluster must get the same block coordinate
blk_coord = cute.arch.cluster_idx()
return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip)

# called by host
Expand All @@ -110,8 +111,13 @@ def get_grid_shape(
) -> Tuple[Int32, Int32, Int32]:
# TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1)
assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported"
if const_expr(params.use_cluster_idx):
# Grid must have num_block * cluster_m physical blocks so that there are num_block clusters
grid_x = params.num_block * params.cluster_shape_mn[0]
else:
grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0])
return (
cute.round_up(params.num_block, params.cluster_shape_mn[0]),
grid_x,
params.num_head * params.num_splits,
params.num_batch,
)
Expand Down Expand Up @@ -395,8 +401,8 @@ def create(
) -> "SingleTileLPTBwdScheduler.Params":
size_l2 = 50 * 1024 * 1024
size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size
# size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
size_one_dqaccum_head = 0
size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4
# size_one_dqaccum_head = 0
size_one_head = size_one_qdo_head + size_one_dqaccum_head
log2_floor = lambda n: 31 - clz(n)
swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head))
Expand Down Expand Up @@ -521,9 +527,12 @@ def create(
args: TileSchedulerArguments, *, loc=None, ip=None
) -> "SingleTileVarlenScheduler.Params":
size_l2 = 50 * 1024 * 1024 # 50 MB for K & V
max_kvblock_in_l2 = size_l2 // (
(args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
)
# if backward, this is qdo block size
kv_block_size = (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1]
# if backward, add dqaccum block size to calculate swizzle
if args.head_swizzle:
kv_block_size += args.headdim * 4 * args.tile_shape_mn[1]
max_kvblock_in_l2 = size_l2 // kv_block_size
assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, (
"At least one of mCuSeqlensQ or mSeqUsedQ must be provided"
)
Expand Down Expand Up @@ -654,6 +663,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo:
num_n_blocks = (
num_m_blocks
* params.tile_shape_mn[0]
* params.cluster_shape_m
// params.qhead_per_kvhead_packgqa
// params.tile_shape_mn[1]
)
Expand Down
Loading
Loading