[BUG FIX] Refactor CuTe namespace and enhance sync scripts#264

Merged
LoserCheems merged 5 commits into main from optim_triton_version
Apr 7, 2026
Conversation

@LoserCheems
Collaborator

Summary

  • This update refactors CuTe namespace handling to support bidirectional rewriting and enhances the sync scripts with submodule initialization and pre/post-merge rewrite steps.

Root Cause

  • The previous namespace handling was static and did not support dynamic rewriting, leading to integration issues.

Changes

  • Refactored the namespace handling to allow dynamic rewriting based on direction.
  • Enhanced sync scripts to include submodule initialization and namespace rewrite functions.

Reproduction

  • Clone the repository and run the sync scripts with different namespace directions to observe the changes.

Tests

  • Updated tests to validate the new namespace handling and sync script functionality.

Compatibility

  • Changes ensure backward compatibility with existing namespace usage.

Checklist

  • Linked issue provided
  • Adds or updates tests
  • Updates docs if needed
  • No perf regressions

Copilot AI review requested due to automatic review settings April 7, 2026 15:30
LoserCheems merged commit 69fa203 into main Apr 7, 2026
3 checks passed
Contributor

Copilot AI left a comment


Pull request overview

Refactors CuTe namespace rewriting to support bidirectional (local ↔ upstream) rewrites and updates the sync scripts to better handle upstream merges/submodules, while also introducing CLC-capable scheduling paths and related kernel/runtime toggles.

Changes:

  • Add --direction {local,upstream} support to rewrite_cute_namespace.py and update sync scripts to run rewrites pre/post merge and to init required submodules.
  • Extend CuTe tile scheduling to support CLC (dynamic persistent tile scheduler) and wire it into the SM100 forward kernel.
  • Improve block-sparsity handling (compact index tensors) and adjust forward/backward tiling decisions based on sparse Q block size.
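The bidirectional rewrite described above can be sketched as a simple direction-aware string transform. This is an illustrative sketch only, not the actual scripts/rewrite_cute_namespace.py; the LOCAL_NS and UPSTREAM_NS values are assumptions.

```python
LOCAL_NS = "flash_sparse_attn.ops.cute"   # hypothetical local package path
UPSTREAM_NS = "cutlass.cute"              # hypothetical upstream package path

def rewrite_namespace(source: str, direction: str) -> str:
    """Rewrite package references in `source` toward `direction`:
    'local' maps upstream names to local ones, 'upstream' the reverse."""
    if direction == "local":
        return source.replace(UPSTREAM_NS, LOCAL_NS)
    if direction == "upstream":
        return source.replace(LOCAL_NS, UPSTREAM_NS)
    raise ValueError(f"unknown direction: {direction!r}")
```

Running the 'upstream' direction before an upstream merge and 'local' after it keeps the subtree diff-able against upstream while the working tree stays on local names.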

Reviewed changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 5 comments.

File Description
scripts/sync_cute_subtree.sh Adds submodule init + bidirectional namespace rewrite and adjusts upstream merge flow.
scripts/sync_cute_subtree.ps1 PowerShell parity for submodule init, rewrite direction, and more robust output/reporting.
scripts/rewrite_cute_namespace.py Makes namespace/package rewrites configurable via --direction.
flash_sparse_attn/ops/cute/utils.py Adds env-driven defaults and a new smid() DSL op.
flash_sparse_attn/ops/cute/tile_scheduler.py Introduces scheduling modes, CLC state, and protocol updates for schedulers.
flash_sparse_attn/ops/cute/interface.py Threads new sparse-Q-aware tiling knobs and CLC/2CTA toggles into fwd/bwd compilation/dispatch.
flash_sparse_attn/ops/cute/flash_fwd_sm100.py Wires in CLC scheduler plumbing and passes a shared tile scheduler instance through kernel stages.
flash_sparse_attn/ops/cute/block_sparsity.py Adds get_sparse_q_block_size and allows compact block-sparse index tensors.
flash_sparse_attn/ops/cute/cache_utils.py Switches cache logging to fa_log.
flash_sparse_attn/ops/cute/flash_bwd_sm90.py Renames argument to match dQ warp-group naming.
flash_sparse_attn/ops/cute/block_sparse_utils.py Propagates renamed dQ warp-group argument.
flash_sparse_attn/ops/cute/README.md Documents CUDA 12.x vs CUDA 13.x dev install extras.
flash_sparse_attn/ops/cute/pyproject.toml Adds pytest-xdist to dev dependencies.


Comment on lines +321 to 328
def create(
    params: Params, clc: ClcState | None = None, *, loc=None, ip=None
) -> "StaticPersistentTileScheduler":
    if const_expr(cute.size(params.cluster_shape_m) == 1):
        tile_idx = cute.arch.block_idx()[0]
    else:
        tile_idx = cute.arch.cluster_idx()[0]
    return StaticPersistentTileScheduler(params, tile_idx, loc=loc, ip=ip)

Copilot AI Apr 7, 2026


StaticPersistentTileScheduler.create() checks cute.size(params.cluster_shape_m) == 1, but cluster_shape_m is a scalar int constexpr. cute.size() is intended for shapes/tuples and is likely to fail or behave unexpectedly here; compare directly against 1 instead.
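A plain-Python analogy of the distinction (cute.size and const_expr are CuTe DSL constructs; the toy function below only mimics the shape-product behavior):

```python
def toy_size(shape):
    """Toy stand-in for cute.size: the product of a shape tuple's extents."""
    total = 1
    for extent in shape:
        total *= extent
    return total

cluster_shape = (2, 1, 1)
print(toy_size(cluster_shape))   # fine: shapes are iterable

cluster_shape_m = 2              # scalar constexpr-like int
# toy_size(cluster_shape_m) would raise TypeError ('int' is not iterable);
# for a scalar, the direct comparison `cluster_shape_m == 1` is what is wanted.
```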

Comment on lines 1011 to 1015
assert arch // 10 in [9, 10, 11, 12], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x, 12.x"
sparse_q = None
if block_sparse_tensors is not None and arch // 10 == 9:
    sparse_q = block_sparse_tensors.block_size[0] if block_sparse_tensors.block_size is not None else 128


Copilot AI Apr 7, 2026


When block_sparse_tensors.block_size is absent, backward currently hardcodes sparse_q = 128. This can diverge from the actual sparse Q block size (which can often be inferred from mask_block_idx / seqlen_q, similar to forward via get_sparse_q_block_size). Consider inferring it (or erroring if ambiguous) to avoid silently wrong tiling/subtiling.
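The inference the reviewer suggests might look like the following sketch. The function name and the assumption that the mask index tensor's Q-block row count covers seqlen_q are hypothetical; the real get_sparse_q_block_size may work differently.

```python
import math

def infer_sparse_q_block_size(num_mask_q_blocks: int, seqlen_q: int) -> int:
    """Hypothetical inference: if the mask index tensor has `num_mask_q_blocks`
    rows covering `seqlen_q` positions, the Q block size is the ceiling quotient."""
    if num_mask_q_blocks <= 0:
        raise ValueError("mask index tensor must have at least one Q-block row")
    return math.ceil(seqlen_q / num_mask_q_blocks)
```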


# NB: this could be derived from the block_sparse_tensors but for now we hardcode it to 2
- subtile_factor = 2
+ subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2

Copilot AI Apr 7, 2026


subtile_factor = sparse_q // m_block_size will silently floor if sparse_q is not an exact multiple of m_block_size, which can break block-sparse indexing assumptions. It would be safer to assert divisibility (and/or choose m_block_size accordingly) before computing subtile_factor.

Suggested change
- subtile_factor = sparse_q // m_block_size if sparse_q is not None else 2
+ if sparse_q is None:
+     subtile_factor = 2
+ else:
+     if sparse_q % m_block_size != 0:
+         raise ValueError(
+             f"sparse_q ({sparse_q}) must be an exact multiple of "
+             f"m_block_size ({m_block_size})"
+         )
+     subtile_factor = sparse_q // m_block_size

Comment on lines +113 to +121
...

def advance_to_next_work(self, *, loc=None, ip=None):
    """Consumer-side advance: move to next tile and return it.

    For static schedulers: grid-stride increment + get_current_work.
    For CLC schedulers: consumer wait + get_current_work + consumer release + state advance.
    """
    ...

Copilot AI Apr 7, 2026


TileSchedulerProtocol.advance_to_next_work is documented as returning the next WorkTileInfo, and implementations in this file return it, but the protocol method signature has no return annotation. Adding -> WorkTileInfo would make the interface contract explicit and help type-checkers catch mismatches.
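The annotated protocol could look like the sketch below. WorkTileInfo's fields and the GridStrideScheduler implementation are invented stand-ins for illustration, not the PR's actual classes.

```python
from typing import Protocol

class WorkTileInfo:
    """Stand-in for the scheduler's work-tile descriptor; fields are assumptions."""
    def __init__(self, tile_idx: int, is_valid: bool) -> None:
        self.tile_idx = tile_idx
        self.is_valid = is_valid

class TileSchedulerProtocol(Protocol):
    # The explicit return annotation makes the interface contract
    # checkable by type-checkers instead of implied by docstrings.
    def advance_to_next_work(self, *, loc=None, ip=None) -> WorkTileInfo: ...

class GridStrideScheduler:
    """Minimal conforming implementation, for illustration only."""
    def __init__(self, num_tiles: int, stride: int) -> None:
        self._tile = 0
        self._num_tiles = num_tiles
        self._stride = stride

    def advance_to_next_work(self, *, loc=None, ip=None) -> WorkTileInfo:
        # grid-stride increment, then report whether the tile is in range
        self._tile += self._stride
        return WorkTileInfo(self._tile, self._tile < self._num_tiles)
```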

Comment on lines 93 to 94
# === TUNING KNOBS (agent-editable) ===
# Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool)

Copilot AI Apr 7, 2026


The tuning-knobs section appears duplicated: _TUNING_CONFIG is defined once above and then the same tuning header starts again immediately after. This looks like an accidental copy/paste that results in _TUNING_CONFIG being redefined (the later definition wins).

