
Update CuTe namespace and functionality#266

Merged
LoserCheems merged 3 commits into main from optim-combine-func
Apr 21, 2026

Conversation

@LoserCheems
Collaborator

No description provided.

Copilot AI review requested due to automatic review settings on April 21, 2026 at 08:53.
@LoserCheems LoserCheems merged commit eecc9bb into main Apr 21, 2026
3 checks passed
Contributor

Copilot AI left a comment


Pull request overview

This PR expands the CuTe FlashAttention implementation to support new Blackwell/SM100 capabilities (FP8 forward, MLA Qv-absorbed path, and top‑k KV gather), while updating scheduling/masking and softcap/score-mod plumbing.

Changes:

  • Add SM100 FP8 forward support via per-(batch, kv_head) descale tensors and softmax max-offset handling.
  • Introduce an SM100 MLA forward kernel (qv path) with optional top‑k KV gather using cpasync + bitmasking.
  • Extend interfaces/testing utilities for qv, gather_kv_indices, return_lse, and softcap backward support.
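The top-k KV gather described above bit-packs the selected columns into uint32 words, one bit per column, which the masking path then uses to keep or discard score entries. A minimal pure-Python sketch of that packing convention (the function names here are illustrative, not from the PR):

```python
NEG_INF = float("-inf")

def pack_topk_bitmask(topk_indices, tile_n):
    """Pack selected column indices into tile_n // 32 uint32-style words.

    Bit j of word i corresponds to column 32 * i + j; a set bit means
    the column was selected by the top-k gather and should be kept.
    """
    words = [0] * (tile_n // 32)
    for col in topk_indices:
        words[col // 32] |= 1 << (col % 32)
    return words

def apply_bitmask(scores, words):
    """Fill score columns whose bit is unset with -inf, mirroring how
    masked-out attention scores are neutralized before softmax."""
    return [
        s if (words[c // 32] >> (c % 32)) & 1 else NEG_INF
        for c, s in enumerate(scores)
    ]
```

This mirrors the layout assumed by the kernel's masking loop (32 columns per packed word), but at Python speed and without any fragment partitioning.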

Reviewed changes

Copilot reviewed 20 out of 21 changed files in this pull request and generated 3 comments.

Summary per file:

| File | Description |
| --- | --- |
| `flash_sparse_attn/ops/cute/utils.py` | Adds CUDA 12 detection for default 2CTA disabling; updates the softcap `score_mod` signature and adds a softcap backward `score_mod`. |
| `flash_sparse_attn/ops/cute/topk_gather_kv.py` | New cpasync KV gather manager with optional bitmask production for top‑k sparsity. |
| `flash_sparse_attn/ops/cute/tile_scheduler.py` | Adds a `use_cluster_idx` flag to control cluster-index handling in CLC coordinate mapping. |
| `flash_sparse_attn/ops/cute/testing.py` | Extends reference attention to support `return_lse` and `gather_kv_indices`; improves padding-mask generation. |
| `flash_sparse_attn/ops/cute/softmax.py` | Adds a `Boolean` import and new `SoftmaxSm100` helpers, plus `max_offset` support. |
| `flash_sparse_attn/ops/cute/named_barrier.py` | Adds named barriers for the new SM100 MLA 2CTA forward kernel. |
| `flash_sparse_attn/ops/cute/mma_sm100_desc.py` | Updates CUTLASS FP8 type names to `Float8E4M3FN` / `Float8E5M2`. |
| `flash_sparse_attn/ops/cute/mask.py` | Adds support for applying a packed bitmask during masking on SM100. |
| `flash_sparse_attn/ops/cute/interface.py` | Extends the public API for FP8 forward (descale tensors), MLA `qv`, top‑k gather, and softcap backward wiring. |
| `flash_sparse_attn/ops/cute/flash_fwd_sm100.py` | Adds a `DescaleTensors` param and FP8 tuning, plus effective descale application in the softmax/correction loops. |
| `flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py` | New SM100 MLA forward kernel (including the top‑k gather path) and a local test/benchmark harness. |
| `flash_sparse_attn/ops/cute/flash_fwd.py` | Plumbs seqlen into score-mod application and adds an `apply_score_mod` wrapper. |
| `flash_sparse_attn/ops/cute/flash_bwd_sm90.py` | Removes the unused softcap parameter from the SM90 bwd call signature. |
| `flash_sparse_attn/ops/cute/flash_bwd_sm100.py` | Removes the unused softcap parameter from the SM100 bwd call signature. |
| `flash_sparse_attn/ops/cute/flash_bwd.py` | Adds `score_mod` / `score_mod_bwd` plumbing and applies them in the SM80 backward path. |
| `flash_sparse_attn/ops/cute/cute_dsl_utils.py` | Adds FP8 torch→CuTe dtype mapping and an fp8 DLPack uint8 workaround; adds a kernel attribute dumping helper. |
| `flash_sparse_attn/ops/cute/block_sparse_utils.py` | Extends empty-tile correction to incorporate max-offset scaling. |
| `flash_sparse_attn/ops/cute/blackwell_helpers.py` | Generalizes tcgen05 MMA PTX emission to use the correct MMA "kind" for FP8 and friends. |
| `flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py` | New FP8 benchmarking script for SM100. |
| `flash_sparse_attn/ops/cute/bench_utils.py` | Updates the FLOPs calculation for MLA (`has_qv`) and adds bandwidth estimators. |
| `flash_sparse_attn/ops/cute/__init__.py` | Removes global monkey-patching of `cute.compile`. |
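The softcap `score_mod` updated in `utils.py` follows the usual tanh soft-capping used in attention kernels; a hedged pure-Python sketch of the forward transform and the backward factor it would need (function names here are illustrative, and the PR's actual signatures may differ):

```python
import math

def softcap_score_mod(score, softcap):
    """Forward soft-capping: smoothly bounds a raw attention score
    to the open interval (-softcap, softcap)."""
    return softcap * math.tanh(score / softcap)

def softcap_score_mod_bwd(score, softcap):
    """Backward factor: d/dscore of softcap * tanh(score / softcap)
    is sech^2(score / softcap) = 1 - tanh^2(score / softcap)."""
    t = math.tanh(score / softcap)
    return 1.0 - t * t
```

The backward variant is the piece the PR wires into the SM80 backward path; multiplying the incoming score gradient by this factor recovers the chain rule through the capping.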


Comment on lines +405 to +411
```python
for i in cutlass.range_constexpr(ncol_packed):
    col_start = 32 * i  # mask is bit-packed into uint32
    curr_mask_val = rBitmask[i]
    for j in cutlass.range_constexpr(32):
        curr_col = col_start + j
        mask = (curr_mask_val >> j) & 1
        acc_S[curr_col] = acc_S[curr_col] if Boolean(mask) else -Float32.inf
```

Copilot AI Apr 21, 2026


The rBitmask masking path indexes acc_S by column number (acc_S[curr_col]) and covers only 32 bits per Uint32, without accounting for which 64-column half the current warp is responsible for. This will mask the wrong elements (and only half the tile) for a typical (tile_m, tile_n) fragment. Rework this to apply the bitmask using the fragment's (row, col) coordinates (e.g., via tScS_t2r / thr_tmem_load.partition_D) and incorporate the correct column offset for the warp/fragment portion being processed.

Suggested change

```diff
-for i in cutlass.range_constexpr(ncol_packed):
-    col_start = 32 * i  # mask is bit-packed into uint32
-    curr_mask_val = rBitmask[i]
-    for j in cutlass.range_constexpr(32):
-        curr_col = col_start + j
-        mask = (curr_mask_val >> j) & 1
-        acc_S[curr_col] = acc_S[curr_col] if Boolean(mask) else -Float32.inf
+for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
+    curr_col = tScS_t2r[i][1]
+    mask_word_idx = curr_col // 32
+    mask_bit_idx = curr_col % 32
+    if mask_word_idx < ncol_packed:
+        curr_mask_val = rBitmask[mask_word_idx]
+        mask = (curr_mask_val >> mask_bit_idx) & 1
+        acc_S[i] = acc_S[i] if Boolean(mask) else -Float32.inf
+    else:
+        acc_S[i] = -Float32.inf
```

Comment on lines +408 to +413
```python
batch = q.shape[0]
topk_len = gather_kv_indices.shape[2]
if topk_len < seqlen_k:
    topk_index_mask = torch.full(
        (batch, seqlen_q, seqlen_k), False, device="cuda"
    ).scatter_(-1, gather_kv_indices, True)
```

Copilot AI Apr 21, 2026


attention_ref's gather_kv_indices handling assumes a 3D tensor (shape[2]) and hardcodes device="cuda" when building the mask. This will break for the documented varlen shape (total_q, topk_length) and for non-CUDA devices. Use scores.device/q.device and support both 2D and 3D gather_kv_indices (e.g., normalize to (batch,seqlen_q,topk) before masking).

Suggested change

```diff
-batch = q.shape[0]
-topk_len = gather_kv_indices.shape[2]
-if topk_len < seqlen_k:
-    topk_index_mask = torch.full(
-        (batch, seqlen_q, seqlen_k), False, device="cuda"
-    ).scatter_(-1, gather_kv_indices, True)
+batch = scores.shape[0]
+if gather_kv_indices.ndim == 2:
+    expected_total_q = batch * seqlen_q
+    if gather_kv_indices.shape[0] != expected_total_q:
+        raise ValueError(
+            "Expected 2D gather_kv_indices to have shape "
+            f"({expected_total_q}, topk), got {tuple(gather_kv_indices.shape)}"
+        )
+    gather_kv_indices = rearrange(
+        gather_kv_indices, "(b t) k -> b t k", b=batch, t=seqlen_q
+    )
+elif gather_kv_indices.ndim == 3:
+    if gather_kv_indices.shape[0] != batch or gather_kv_indices.shape[1] != seqlen_q:
+        raise ValueError(
+            "Expected 3D gather_kv_indices to have shape "
+            f"({batch}, {seqlen_q}, topk), got {tuple(gather_kv_indices.shape)}"
+        )
+else:
+    raise ValueError(
+        "gather_kv_indices must be a 2D tensor of shape (total_q, topk) "
+        "or a 3D tensor of shape (batch, seqlen_q, topk)"
+    )
+topk_len = gather_kv_indices.shape[-1]
+if topk_len < seqlen_k:
+    topk_index_mask = torch.full(
+        (batch, seqlen_q, seqlen_k),
+        False,
+        device=scores.device,
+        dtype=torch.bool,
+    ).scatter_(-1, gather_kv_indices.to(device=scores.device), True)
```
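The shape normalization in the suggestion above, reshaping a varlen (total_q, topk) index tensor into (batch, seqlen_q, topk), can be sketched without torch using nested lists (pure-Python, illustrative only; the real code would operate on tensors):

```python
def normalize_gather_indices(indices, batch, seqlen_q):
    """Accept gather indices as (total_q, topk) nested lists or already
    as (batch, seqlen_q, topk), returning the 3D form.

    Varlen layout assumption: flat row b * seqlen_q + t holds the
    indices for query position t of batch element b.
    """
    # A 2D input has scalars at depth two; a 3D input has lists there.
    if indices and not isinstance(indices[0][0], list):
        if len(indices) != batch * seqlen_q:
            raise ValueError("2D gather indices must have batch * seqlen_q rows")
        return [indices[b * seqlen_q:(b + 1) * seqlen_q] for b in range(batch)]
    if len(indices) != batch or len(indices[0]) != seqlen_q:
        raise ValueError("3D gather indices must be (batch, seqlen_q, topk)")
    return indices
```

This is the same reshape the suggested `rearrange(..., "(b t) k -> b t k")` performs, made explicit so the row-to-(batch, position) mapping is visible.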

Comment on lines +159 to +171
```python
        if const_expr(not transpose and not self.disable_bitmask):
            row_non_interleaved = i * self.num_threads + self.thread_idx
            row_idx_non_interleaved = n_block * self.tile_n + row_non_interleaved
            self.rTopk_NonInterleaved[0] = self.mIndexTopk[row_idx_non_interleaved]

    @cute.jit
    def compute_bitmask(
        self,
        producer_state_bitmask,
    ):
        lane_idx = cute.arch.lane_idx()
        assert cute.size(self.rTopk_NonInterleaved) == 1
        bitmask = Uint32(0)
```

Copilot AI Apr 21, 2026


CpasyncGatherKVManager.create() allows tile_n values where topk_indices_per_thread > 1 (e.g., tile_n=256), but compute_bitmask() asserts rTopk_NonInterleaved has size 1 and load_index_topk() only writes element [0]. Either constrain tile_n to 128 (or topk_indices_per_thread==1) with an explicit assert, or generalize rTopk_NonInterleaved/bitmask computation to handle multiple indices per thread.
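One way to generalize, as the comment suggests: derive how many indices each thread owns from tile_n and the thread count, and load one register slot per owned index instead of writing only slot [0]. A pure-Python sketch under the same non-interleaved layout the kernel uses (`i * num_threads + thread_idx`); the helper names are illustrative, not the PR's API:

```python
def topk_indices_per_thread(tile_n, num_threads=128):
    """Each thread owns tile_n / num_threads gathered indices;
    reject tile sizes that do not divide evenly across threads."""
    if tile_n % num_threads != 0:
        raise ValueError("tile_n must be a multiple of the thread count")
    return tile_n // num_threads

def load_index_topk(index_topk, tile_n, thread_idx, num_threads=128):
    """Generalized per-thread load: slot k of this thread holds the
    index at flat position thread_idx + k * num_threads (non-interleaved),
    matching the kernel's row_non_interleaved addressing."""
    per_thread = topk_indices_per_thread(tile_n, num_threads)
    return [index_topk[thread_idx + k * num_threads] for k in range(per_thread)]
```

With tile_n = 128 this degenerates to the current single-slot behavior, which is what the existing assert enforces; with tile_n = 256 each thread would carry two slots, and the bitmask computation would loop over both.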
