Add support for upstream split reference in sync scripts#263
Add support for upstream split reference in sync scripts#263LoserCheems merged 2 commits intomainfrom
Conversation
…e-worktree-20260326-102028
There was a problem hiding this comment.
Pull request overview
Adds support for syncing the vendored CuTe subtree from the upstream repo’s default branch/ref (including “split ref” handling), improves subtree merge/cherry-pick behavior, and updates the CuTe-based FlashAttention-4 implementation to support additional SM100/2-CTA configurations and CUDA 13 packaging.
Changes:
- Update sync scripts (bash + PowerShell) to resolve upstream default ref, disable merge auto-edit, and handle merge commits when replaying sync commits.
- Extend CuTe tile scheduling / SM100 forward kernel tuning (cluster indexing support, register/ex2 emulation tuning, and enable 2-CTA path for head_dim padded 192).
- Add CUDA 13 optional dependency + uv index/source wiring and document the
cu13install extra.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| scripts/sync_cute_subtree.sh | Resolve upstream default ref for remote caches; avoid merge editor; verify sync includes upstream split; handle merge commits during cherry-pick. |
| scripts/sync_cute_subtree.ps1 | PowerShell parity for upstream default ref resolution, merge auto-edit disabling, and merge-aware cherry-picking. |
| flash_sparse_attn/ops/cute/tile_scheduler.py | Add use_cluster_idx to support cluster-based block coordinate selection and adjusted grid sizing. |
| flash_sparse_attn/ops/cute/interface.py | Broaden 2-CTA eligibility to include head_dim padded 192. |
| flash_sparse_attn/ops/cute/flash_fwd_sm100.py | Add tuning table + cluster-index scheduling hooks; adjust Q copy thread layout logic. |
| flash_sparse_attn/ops/cute/pyproject.toml | Add cu13 extra and uv index/source configuration for CUDA 13 torch wheels. |
| flash_sparse_attn/ops/cute/README.md | Document installing with flash-attn-4[cu13]. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| async_copy_elems = 128 // self.q_dtype.width | ||
| num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids) | ||
| threads_per_row = self.head_dim_padded // async_copy_elems | ||
| threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads) | ||
| gmem_tiled_copy_Q = copy_utils.tiled_copy_2d( | ||
| self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True | ||
| ) |
There was a problem hiding this comment.
copy_utils.tiled_copy_2d is defined as tiled_copy_2d(dtype, major_mode_size, num_threads, is_async=False) in flash_sparse_attn/ops/cute/copy_utils.py, but this call passes an extra positional argument (async_copy_elems) and uses threads_per_row as the second argument. As written, this will raise a TypeError when use_tma_Q is false (and even if it didn’t, it would be using the helper with the wrong semantics). Update the call to match the helper’s signature (i.e., pass the actual major-mode size in elements and use the keyword is_async=), or update tiled_copy_2d/call sites consistently if the intended API includes num_copy_elems.
Summary
Root Cause
Changes
Reproduction
Tests
Compatibility
Checklist