[BUG FIX] Refactor CuTe namespace and enhance sync scripts #264
LoserCheems merged 5 commits into main
…n/.ref_repo/flash-attention into sync/cute-worktree-20260407-152315
…ting based on direction
Pull request overview
Refactors CuTe namespace rewriting to support bidirectional (local ↔ upstream) rewrites and updates the sync scripts to better handle upstream merges/submodules, while also introducing CLC-capable scheduling paths and related kernel/runtime toggles.
Changes:
- Add --direction {local,upstream} support to rewrite_cute_namespace.py and update sync scripts to run rewrites pre/post merge and to init required submodules.
- Extend CuTe tile scheduling to support CLC (dynamic persistent tile scheduler) and wire it into the SM100 forward kernel.
- Improve block-sparsity handling (compact index tensors) and adjust forward/backward tiling decisions based on sparse Q block size.
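As a rough illustration of the first change, a direction-aware rewrite could look like the sketch below. The two package names and the helper names `rewrite_namespace` and `parse_args` are assumptions made for this example; the actual rewrite_cute_namespace.py may use different names and handle more cases.

```python
import argparse
import re

# Assumed package names, for illustration only; the real script may differ.
LOCAL_PKG = "flash_sparse_attn.ops.cute"
UPSTREAM_PKG = "flash_attn.cute"


def rewrite_namespace(text: str, direction: str) -> str:
    """Rewrite dotted package references toward the requested side."""
    if direction == "local":
        src, dst = UPSTREAM_PKG, LOCAL_PKG
    elif direction == "upstream":
        src, dst = LOCAL_PKG, UPSTREAM_PKG
    else:
        raise ValueError(f"unknown direction: {direction!r}")
    # Match the package as a whole dotted name, not inside a longer identifier.
    pattern = re.compile(rf"(?<![\w.]){re.escape(src)}(?!\w)")
    return pattern.sub(dst, text)


def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command line; --direction accepts only the two known sides."""
    parser = argparse.ArgumentParser(description="Rewrite CuTe package references.")
    parser.add_argument("--direction", choices=["local", "upstream"], default="local")
    return parser.parse_args(argv)
```

With these assumed names, running the rewrite toward upstream before a merge and toward local afterwards is a round trip, which is what lets the sync scripts call it on both sides of the merge.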
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| scripts/sync_cute_subtree.sh | Adds submodule init + bidirectional namespace rewrite and adjusts upstream merge flow. |
| scripts/sync_cute_subtree.ps1 | PowerShell parity for submodule init, rewrite direction, and more robust output/reporting. |
| scripts/rewrite_cute_namespace.py | Makes namespace/package rewrites configurable via --direction. |
| flash_sparse_attn/ops/cute/utils.py | Adds env-driven defaults and a new smid() DSL op. |
| flash_sparse_attn/ops/cute/tile_scheduler.py | Introduces scheduling modes, CLC state, and protocol updates for schedulers. |
| flash_sparse_attn/ops/cute/interface.py | Threads new sparse-Q-aware tiling knobs and CLC/2CTA toggles into fwd/bwd compilation/dispatch. |
| flash_sparse_attn/ops/cute/flash_fwd_sm100.py | Wires in CLC scheduler plumbing and passes a shared tile scheduler instance through kernel stages. |
| flash_sparse_attn/ops/cute/block_sparsity.py | Adds get_sparse_q_block_size and allows compact block-sparse index tensors. |
| flash_sparse_attn/ops/cute/cache_utils.py | Switches cache logging to fa_log. |
| flash_sparse_attn/ops/cute/flash_bwd_sm90.py | Renames argument to match dQ warp-group naming. |
| flash_sparse_attn/ops/cute/block_sparse_utils.py | Propagates renamed dQ warp-group argument. |
| flash_sparse_attn/ops/cute/README.md | Documents CUDA 12.x vs CUDA 13.x dev install extras. |
| flash_sparse_attn/ops/cute/pyproject.toml | Adds pytest-xdist to dev dependencies. |

    def create(
        params: Params, clc: ClcState | None = None, *, loc=None, ip=None
    ) -> "StaticPersistentTileScheduler":
        if const_expr(cute.size(params.cluster_shape_m) == 1):
            tile_idx = cute.arch.block_idx()[0]
        else:
            tile_idx = cute.arch.cluster_idx()[0]
        return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)

StaticPersistentTileScheduler.create() checks cute.size(params.cluster_shape_m) == 1, but cluster_shape_m is a scalar int constexpr. cute.size() is intended for shapes/tuples and is likely to fail or behave unexpectedly here; compare directly against 1 instead.

    assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x"
    sparse_q = None
    if block_sparse_tensors is not None and arch // 10 == 9:
        sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128

When block_sparse_tensors.block_size is absent, backward currently hardcodes sparse_q = 128. This can diverge from the actual sparse Q block size (which can often be inferred from mask_block_idx / seqlen_q, similar to forward via get_sparse_q_block_size). Consider inferring it (or erroring if ambiguous) to avoid silently wrong tiling/subtiling.
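One way to infer the block size, or fail when it is ambiguous, is sketched below. `infer_sparse_q_block_size`, its candidate list, and the assumption that the number of Q blocks can be read from the mask index tensor's shape are all hypothetical; this is not the signature of get_sparse_q_block_size.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer division rounded up."""
    return -(-a // b)


def infer_sparse_q_block_size(seqlen_q: int, num_q_blocks: int,
                              candidates=(64, 128, 256)) -> int:
    """Return the single candidate block size consistent with the mask shape.

    Hypothetical helper: a candidate fits when it splits seqlen_q into
    exactly num_q_blocks blocks. Zero or several fits raise instead of
    silently falling back to a default.
    """
    if seqlen_q <= 0 or num_q_blocks <= 0:
        raise ValueError("seqlen_q and num_q_blocks must be positive")
    matches = [b for b in candidates if ceil_div(seqlen_q, b) == num_q_blocks]
    if len(matches) != 1:
        raise ValueError(
            f"cannot infer sparse Q block size for seqlen_q={seqlen_q}, "
            f"num_q_blocks={num_q_blocks}: candidates matching = {matches}"
        )
    return matches[0]
```

Short sequences are the ambiguous case: a sequence of length 100 held in one block fits both 128 and 256, so the sketch raises rather than guessing.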

      # NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
    - subtile_factor = 2
    + subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2

subtile_factor = sparse_q // m_block_size will silently floor if sparse_q is not an exact multiple of m_block_size, which can break block-sparse indexing assumptions. It would be safer to assert divisibility (and/or choose m_block_size accordingly) before computing subtile_factor.
Suggested change:

    - subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2
    + if sparse_q is None:
    +     subtile_factor = 2
    + else:
    +     if sparse_q % m_block_size != 0:
    +         raise ValueError(
    +             f"sparse_q ({sparse_q}) must be an exact multiple of "
    +             f"m_block_size ({m_block_size})"
    +         )
    +     subtile_factor = sparse_q // m_block_size


        ...

    def advance_to_next_work(self, *, loc=None, ip=None):
        """Consumer-side advance: move to next tile and return it.

        For static schedulers: grid-stride increment + get_current_work.
        For CLC schedulers: consumer wait + get_current_work + consumer release + state advance.
        """
        ...

TileSchedulerProtocol.advance_to_next_work is documented as returning the next WorkTileInfo, and implementations in this file return it, but the protocol method signature has no return annotation. Adding -> WorkTileInfo would make the interface contract explicit and help type-checkers catch mismatches.
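A plain-Python sketch of what the annotated protocol buys is shown below. The `WorkTileInfo` fields, `ToyStaticScheduler`, and `drain` are stand-ins invented for this example, and the `loc`/`ip` keyword arguments of the real methods are left out for brevity.

```python
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class WorkTileInfo:
    """Stand-in for the real type; the fields are illustrative only."""
    tile_idx: int
    is_valid: bool


class TileSchedulerProtocol(Protocol):
    def get_current_work(self) -> WorkTileInfo: ...

    def advance_to_next_work(self) -> WorkTileInfo:
        """Move to the next tile and return it."""
        ...


class ToyStaticScheduler:
    """Grid-stride scheduler over a fixed tile count (no DSL involved)."""

    def __init__(self, start: int, stride: int, num_tiles: int) -> None:
        self._idx = start
        self._stride = stride
        self._num_tiles = num_tiles

    def get_current_work(self) -> WorkTileInfo:
        return WorkTileInfo(self._idx, self._idx < self._num_tiles)

    def advance_to_next_work(self) -> WorkTileInfo:
        # Grid-stride increment followed by get_current_work.
        self._idx += self._stride
        return self.get_current_work()


def drain(scheduler: TileSchedulerProtocol) -> list[int]:
    """Collect every valid tile index the scheduler hands out."""
    tiles = []
    work = scheduler.get_current_work()
    while work.is_valid:
        tiles.append(work.tile_idx)
        work = scheduler.advance_to_next_work()
    return tiles
```

With the return annotation in place, a type-checker can flag an implementation whose `advance_to_next_work` returns nothing, since `drain` reads `work.is_valid` from the result.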

    # === TUNING KNOBS (agent-editable) ===
    # Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool)

The tuning-knobs section appears duplicated: _TUNING_CONFIG is defined once above and then the same tuning header starts again immediately after. This looks like an accidental copy/paste that results in _TUNING_CONFIG being redefined (the later definition wins).
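Accidental redefinitions like this can be caught mechanically. The sketch below uses the standard `ast` module to list names assigned more than once at module top level; `duplicate_top_level_assignments` and the `SAMPLE` source are hypothetical and not part of this PR.

```python
import ast
from collections import Counter


def duplicate_top_level_assignments(source: str) -> dict[str, int]:
    """Map each name assigned more than once at module top level to its count."""
    counts = Counter()
    for node in ast.parse(source).body:
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign):
            targets = [node.target]
        else:
            continue
        for target in targets:
            # Only plain names; attribute and subscript targets are ignored.
            if isinstance(target, ast.Name):
                counts[target.id] += 1
    return {name: n for name, n in counts.items() if n > 1}


# Illustrative module source with a copy/pasted definition.
SAMPLE = (
    "_TUNING_CONFIG = {1: 'first'}\n"
    "OTHER = 3\n"
    "_TUNING_CONFIG = {1: 'second'}\n"
)
```

Executing `SAMPLE` leaves `_TUNING_CONFIG` bound to the second dictionary, which is the "later definition wins" behavior described in the comment.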