Update CuTe namespace and functionality #266
Conversation
…e/flash-sparse-attention/.ref_repo/flash-attention into sync/cute-worktree-20260421-164501
Pull request overview
This PR expands the CuTe FlashAttention implementation to support new Blackwell/SM100 capabilities (FP8 forward, MLA Qv-absorbed path, and top‑k KV gather), while updating scheduling/masking and softcap/score-mod plumbing.
Changes:
- Add SM100 FP8 forward support via per-(batch, kv_head) descale tensors and softmax max-offset handling.
- Introduce an SM100 MLA forward kernel (qv path) with optional top‑k KV gather using cpasync + bitmasking.
- Extend interfaces/testing utilities for qv, gather_kv_indices, return_lse, and softcap backward support.
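The FP8 descale plumbing can be illustrated with a small standalone sketch. This is pure Python with illustrative names (apply_descale, descale_q, descale_k are not the kernel's actual API); the real kernel folds the effective descale into its softmax/correction loops on device:

```python
# Sketch: per-(batch, kv_head) descale applied to raw FP8 Q@K^T scores.
# FP8 matmuls accumulate unscaled values; multiplying by the product of
# the Q and K descale factors restores real magnitudes before softmax.

def apply_descale(scores, descale_q, descale_k, batch, kv_head):
    """scores: rows of raw Q@K^T values for one (batch, kv_head) tile.

    descale_q / descale_k: dicts keyed by (batch, kv_head), standing in
    for the per-(batch, kv_head) descale tensors described above.
    """
    d = descale_q[(batch, kv_head)] * descale_k[(batch, kv_head)]
    return [[s * d for s in row] for row in scores]

scores = [[2.0, -1.0], [0.5, 4.0]]
descale_q = {(0, 0): 0.5}
descale_k = {(0, 0): 2.0}
out = apply_descale(scores, descale_q, descale_k, batch=0, kv_head=0)
# effective descale = 0.5 * 2.0 = 1.0, so scores pass through unchanged
```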
Reviewed changes
Copilot reviewed 20 out of 21 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| flash_sparse_attn/ops/cute/utils.py | Adds CUDA 12 detection for default 2CTA disabling; updates softcap score_mod signature and adds softcap backward score_mod. |
| flash_sparse_attn/ops/cute/topk_gather_kv.py | New cpasync KV gather manager with optional bitmask production for top‑k sparsity. |
| flash_sparse_attn/ops/cute/tile_scheduler.py | Adds use_cluster_idx flag to control cluster-index handling in CLC coordinate mapping. |
| flash_sparse_attn/ops/cute/testing.py | Extends reference attention to support return_lse and gather_kv_indices; improves padding mask generation. |
| flash_sparse_attn/ops/cute/softmax.py | Adds Boolean import and new SoftmaxSm100 helpers + max_offset support. |
| flash_sparse_attn/ops/cute/named_barrier.py | Adds named barriers for the new SM100 MLA 2CTA forward kernel. |
| flash_sparse_attn/ops/cute/mma_sm100_desc.py | Updates CUTLASS FP8 type names to Float8E4M3FN / Float8E5M2. |
| flash_sparse_attn/ops/cute/mask.py | Adds support for applying a packed bitmask during masking on SM100. |
| flash_sparse_attn/ops/cute/interface.py | Extends public API for FP8 forward (descale tensors), MLA qv, top‑k gather, and softcap backward wiring. |
| flash_sparse_attn/ops/cute/flash_fwd_sm100.py | Adds DescaleTensors param and FP8 tuning, plus effective descale application in softmax/correction loops. |
| flash_sparse_attn/ops/cute/flash_fwd_mla_sm100.py | New SM100 MLA forward kernel (including top‑k gather path) and local test/benchmark harness. |
| flash_sparse_attn/ops/cute/flash_fwd.py | Plumbs seqlen into score-mod application and adds apply_score_mod wrapper. |
| flash_sparse_attn/ops/cute/flash_bwd_sm90.py | Removes unused softcap parameter from SM90 bwd call signature. |
| flash_sparse_attn/ops/cute/flash_bwd_sm100.py | Removes unused softcap parameter from SM100 bwd call signature. |
| flash_sparse_attn/ops/cute/flash_bwd.py | Adds score_mod / score_mod_bwd plumbing and applies them in the SM80 backward path. |
| flash_sparse_attn/ops/cute/cute_dsl_utils.py | Adds FP8 torch→CuTe dtype mapping and fp8 DLPack uint8 workaround; adds kernel attribute dumping helper. |
| flash_sparse_attn/ops/cute/block_sparse_utils.py | Extends empty-tile correction to incorporate max-offset scaling. |
| flash_sparse_attn/ops/cute/blackwell_helpers.py | Generalizes tcgen05 MMA PTX emission to use the correct MMA “kind” for FP8/etc. |
| flash_sparse_attn/ops/cute/benchmark_flash_attention_fp8.py | New FP8 benchmarking script for SM100. |
| flash_sparse_attn/ops/cute/bench_utils.py | Updates FLOPs calculation for MLA (has_qv) and adds bandwidth estimators. |
| flash_sparse_attn/ops/cute/__init__.py | Removes global monkey-patching of cute.compile. |
```python
for i in cutlass.range_constexpr(ncol_packed):
    col_start = 32 * i  # mask is bit-packed into uint32
    curr_mask_val = rBitmask[i]
    for j in cutlass.range_constexpr(32):
        curr_col = col_start + j
        mask = (curr_mask_val >> j) & 1
        acc_S[curr_col] = acc_S[curr_col] if Boolean(mask) else -Float32.inf
```
The rBitmask masking path indexes acc_S by a raw column number (acc_S[curr_col]) and covers only 32 bits per Uint32, without accounting for which 64-column half of the tile the current warp is responsible for. This masks the wrong elements (and only half the tile) for a typical (tile_m, tile_n) fragment. Rework it to apply the bitmask using the fragment's (row, col) coordinates (e.g., via tScS_t2r / thr_tmem_load.partition_D) and incorporate the correct column offset for the warp/fragment portion being processed.
Suggested change:

```diff
-for i in cutlass.range_constexpr(ncol_packed):
-    col_start = 32 * i  # mask is bit-packed into uint32
-    curr_mask_val = rBitmask[i]
-    for j in cutlass.range_constexpr(32):
-        curr_col = col_start + j
-        mask = (curr_mask_val >> j) & 1
-        acc_S[curr_col] = acc_S[curr_col] if Boolean(mask) else -Float32.inf
+for i in cutlass.range(cute.size(tScS_t2r.shape), unroll_full=True):
+    curr_col = tScS_t2r[i][1]
+    mask_word_idx = curr_col // 32
+    mask_bit_idx = curr_col % 32
+    if mask_word_idx < ncol_packed:
+        curr_mask_val = rBitmask[mask_word_idx]
+        mask = (curr_mask_val >> mask_bit_idx) & 1
+        acc_S[i] = acc_S[i] if Boolean(mask) else -Float32.inf
+    else:
+        acc_S[i] = -Float32.inf
```
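The word/bit index arithmetic in a fix of this shape can be sanity-checked with a plain-Python model (a hypothetical standalone sketch, not the kernel code; each fragment element carries its own column coordinate, and packed uint32 words are indexed by curr_col // 32 with bit curr_col % 32):

```python
# Model: acc_s[i] holds the score for logical column cols[i];
# bitmask_words packs one keep/drop bit per column, 32 per word.
NEG_INF = float("-inf")

def apply_bitmask(acc_s, cols, bitmask_words, ncol_packed):
    out = list(acc_s)
    for i, curr_col in enumerate(cols):
        word_idx, bit_idx = divmod(curr_col, 32)
        if word_idx < ncol_packed:
            keep = (bitmask_words[word_idx] >> bit_idx) & 1
            if not keep:
                out[i] = NEG_INF
        else:
            # Column beyond the packed mask: always masked out.
            out[i] = NEG_INF
    return out

words = [0b1, 0b10]  # keep column 0 (word 0, bit 0) and column 33 (word 1, bit 1)
scores = apply_bitmask([1.0, 2.0, 3.0], [0, 33, 64], words, ncol_packed=2)
# column 64 falls past ncol_packed words, so it is forced to -inf
```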
```python
batch = q.shape[0]
topk_len = gather_kv_indices.shape[2]
if topk_len < seqlen_k:
    topk_index_mask = torch.full(
        (batch, seqlen_q, seqlen_k), False, device="cuda"
    ).scatter_(-1, gather_kv_indices, True)
```
attention_ref's gather_kv_indices handling assumes a 3D tensor (shape[2]) and hardcodes device="cuda" when building the mask. This will break for the documented varlen shape (total_q, topk_length) and for non-CUDA devices. Use scores.device/q.device and support both 2D and 3D gather_kv_indices (e.g., normalize to (batch,seqlen_q,topk) before masking).
Suggested change:

```diff
-batch = q.shape[0]
-topk_len = gather_kv_indices.shape[2]
-if topk_len < seqlen_k:
-    topk_index_mask = torch.full(
-        (batch, seqlen_q, seqlen_k), False, device="cuda"
-    ).scatter_(-1, gather_kv_indices, True)
+batch = scores.shape[0]
+if gather_kv_indices.ndim == 2:
+    expected_total_q = batch * seqlen_q
+    if gather_kv_indices.shape[0] != expected_total_q:
+        raise ValueError(
+            "Expected 2D gather_kv_indices to have shape "
+            f"({expected_total_q}, topk), got {tuple(gather_kv_indices.shape)}"
+        )
+    gather_kv_indices = rearrange(
+        gather_kv_indices, "(b t) k -> b t k", b=batch, t=seqlen_q
+    )
+elif gather_kv_indices.ndim == 3:
+    if gather_kv_indices.shape[0] != batch or gather_kv_indices.shape[1] != seqlen_q:
+        raise ValueError(
+            "Expected 3D gather_kv_indices to have shape "
+            f"({batch}, {seqlen_q}, topk), got {tuple(gather_kv_indices.shape)}"
+        )
+else:
+    raise ValueError(
+        "gather_kv_indices must be a 2D tensor of shape (total_q, topk) "
+        "or a 3D tensor of shape (batch, seqlen_q, topk)"
+    )
+topk_len = gather_kv_indices.shape[-1]
+if topk_len < seqlen_k:
+    topk_index_mask = torch.full(
+        (batch, seqlen_q, seqlen_k),
+        False,
+        device=scores.device,
+        dtype=torch.bool,
+    ).scatter_(-1, gather_kv_indices.to(device=scores.device), True)
```
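The shape normalization this suggestion describes can be modeled without torch in a few lines. The helper below (build_topk_mask is a hypothetical name, with nested Python lists standing in for tensors) accepts either a flat (total_q, topk) layout or a (batch, seqlen_q, topk) layout and builds the same boolean mask:

```python
# Normalize gather_kv_indices to (batch, seqlen_q, topk), then mark
# the selected KV positions True in a (batch, seqlen_q, seqlen_k) mask.

def build_topk_mask(gather_kv_indices, batch, seqlen_q, seqlen_k):
    if not isinstance(gather_kv_indices[0][0], list):
        # 2D varlen layout (total_q, topk): regroup rows by batch.
        assert len(gather_kv_indices) == batch * seqlen_q
        gather_kv_indices = [
            gather_kv_indices[b * seqlen_q:(b + 1) * seqlen_q]
            for b in range(batch)
        ]
    mask = [[[False] * seqlen_k for _ in range(seqlen_q)]
            for _ in range(batch)]
    for b in range(batch):
        for t in range(seqlen_q):
            for k in gather_kv_indices[b][t]:
                mask[b][t][k] = True  # scatter_(-1, indices, True)
    return mask

m = build_topk_mask([[0, 2], [1, 3]], batch=1, seqlen_q=2, seqlen_k=4)
```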
```python
if const_expr(not transpose and not self.disable_bitmask):
    row_non_interleaved = i * self.num_threads + self.thread_idx
    row_idx_non_interleaved = n_block * self.tile_n + row_non_interleaved
    self.rTopk_NonInterleaved[0] = self.mIndexTopk[row_idx_non_interleaved]
```

```python
@cute.jit
def compute_bitmask(
    self,
    producer_state_bitmask,
):
    lane_idx = cute.arch.lane_idx()
    assert cute.size(self.rTopk_NonInterleaved) == 1
    bitmask = Uint32(0)
```
CpasyncGatherKVManager.create() allows tile_n values where topk_indices_per_thread > 1 (e.g., tile_n=256), but compute_bitmask() asserts rTopk_NonInterleaved has size 1 and load_index_topk() only writes element [0]. Either constrain tile_n to 128 (or topk_indices_per_thread==1) with an explicit assert, or generalize rTopk_NonInterleaved/bitmask computation to handle multiple indices per thread.
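One possible generalization can be sketched in plain Python (a hypothetical helper, not the kernel's register-level code, which would also combine lane results across the warp): fold every index a thread owns into the packed words, rather than only element [0].

```python
# Sketch: a thread may own several top-k indices when
# topk_indices_per_thread > 1 (e.g., tile_n = 256 with 128 threads);
# set one bit per in-range index across all packed uint32 words.

def compute_bitmask_words(thread_topk_indices, tile_n):
    """thread_topk_indices: the (possibly >1) KV indices this thread
    loaded; returns the packed uint32 words covering the tile."""
    nwords = (tile_n + 31) // 32
    words = [0] * nwords
    for idx in thread_topk_indices:
        if 0 <= idx < tile_n:
            words[idx // 32] |= 1 << (idx % 32)
    return words

words = compute_bitmask_words([3, 130, 255], tile_n=256)
# idx 3 -> word 0 bit 3, idx 130 -> word 4 bit 2, idx 255 -> word 7 bit 31
```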
No description provided.