diff --git a/flash_sparse_attn/ops/cute/README.md b/flash_sparse_attn/ops/cute/README.md index 61aa412..653f7b1 100644 --- a/flash_sparse_attn/ops/cute/README.md +++ b/flash_sparse_attn/ops/cute/README.md @@ -8,6 +8,12 @@ FlashAttention-4 is a CuTeDSL-based implementation of FlashAttention for Hopper pip install flash-attn-4 ``` +If you're on CUDA 13, install with the `cu13` extra for best performance: + +```sh +pip install "flash-attn-4[cu13]" +``` + ## Usage ```python diff --git a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py index ac26fcb..fd42f62 100644 --- a/flash_sparse_attn/ops/cute/flash_fwd_sm100.py +++ b/flash_sparse_attn/ops/cute/flash_fwd_sm100.py @@ -61,6 +61,28 @@ SingleTileVarlenScheduler, ) +# === TUNING KNOBS (agent-editable) === +# Keys: (use_2cta_instrs: bool, is_causal: bool, head_dim_padded: int, is_sm103: bool) +# Values: +# ex2_emu_freq: int — how often to use emulated exp2 (0=all hardware exp2, higher=more emulation). +# SM103 has fast native exp2, so set freq=0 there. +# ex2_emu_start_frg: int — fragment index to start emulation from +# num_regs_softmax: int — register count for softmax warps (multiple of 8) +# num_regs_correction: int — register count for correction warps (multiple of 8) +# num_regs_other is derived: 512 - num_regs_softmax * 2 - num_regs_correction +_TUNING_CONFIG = { + (True, False, 128, False): {'ex2_emu_freq': 10, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 176, 'num_regs_correction': 88}, + (False, True, 128, False): {'ex2_emu_freq': 16, 'ex2_emu_start_frg': 1, 'num_regs_softmax': 192, 'num_regs_correction': 72}, + (True, False, 192, False): {"ex2_emu_freq": 16, "ex2_emu_start_frg": 0, "num_regs_softmax": 184, "num_regs_correction": 80}, + (False, True, 192, False): {"ex2_emu_freq": 32, "ex2_emu_start_frg": 1, "num_regs_softmax": 192, "num_regs_correction": 72}, + (True, False, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 80}, + (False, True, 128, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (True, False, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 64}, + (False, True, 192, True): {"ex2_emu_freq": 0, "ex2_emu_start_frg": 0, "num_regs_softmax": 176, "num_regs_correction": 72}, +} +# === END TUNING KNOBS === + + class FlashAttentionForwardSm100: def __init__( @@ -141,8 +163,10 @@ def __init__( # Does S1 need to wait for S0 to finish # self.s0_s1_barrier = self.head_dim_padded in [64, 96] and (not self.is_causal and not self.is_local) is_sm103 = self.arch >= Arch.sm_103 and self.arch <= Arch.sm_103f - # self.enable_ex2_emu = self.head_dim_padded <= 128 and not is_sm103 - self.enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 + self.is_sm103 = is_sm103 + # enable_ex2_emu is derived: True if tuning config has freq > 0, else fallback to default logic + _default_enable_ex2_emu = (self.head_dim_padded <= 128 or (self.head_dim_padded == 192 and self.use_2cta_instrs and not self.is_causal and not self.is_local)) and not is_sm103 + self.enable_ex2_emu = _default_enable_ex2_emu self.s0_s1_barrier = False self.overlap_sO_sQ = ( (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or @@ -210,31 +234,26 @@ def __init__( # vec buffer for row_max & row_sum self.tmem_vec_offset = self.tmem_s_offset + # Look up tuning config for register counts and ex2_emu params + _tune_key = (self.use_2cta_instrs, self.is_causal, self.head_dim_padded, self.is_sm103) + self._tune = _TUNING_CONFIG.get(_tune_key, {}) + if "ex2_emu_freq" in self._tune: + self.enable_ex2_emu = self._tune["ex2_emu_freq"] > 0 if self.head_dim_padded < 96: self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 self.num_regs_correction = 64 self.num_regs_other = 48 if not paged_kv_non_tma else 80 else: - # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 - if not self.enable_ex2_emu: - self.num_regs_softmax = 192 if not paged_kv_non_tma else 184 + if not paged_kv_non_tma and "num_regs_softmax" in self._tune: + self.num_regs_softmax = self._tune["num_regs_softmax"] + self.num_regs_correction = self._tune["num_regs_correction"] + elif not paged_kv_non_tma: + self.num_regs_softmax = 192 + self.num_regs_correction = 80 else: - # self.num_regs_softmax = 200 if not paged_kv_non_tma else 184 - self.num_regs_softmax = 192 if not paged_kv_non_tma else 184 - # self.num_regs_softmax = 176 - # self.num_regs_correction = 96 - # self.num_regs_correction = 64 if self.is_causal or self.is_local else 80 - if not self.enable_ex2_emu: - self.num_regs_correction = 80 if not paged_kv_non_tma else 64 - else: - # self.num_regs_correction = 64 - self.num_regs_correction = 80 if not paged_kv_non_tma else 64 - # self.num_regs_other = 32 - # self.num_regs_other = 64 - # self.num_regs_other = 80 - self.num_regs_other = 48 if not paged_kv_non_tma else 80 - # self.num_regs_other = 96 if self.is_causal or self.is_local else 80 - # self.num_regs_other = 64 if self.is_causal or self.is_local else 80 + self.num_regs_softmax = 184 + self.num_regs_correction = 64 + self.num_regs_other = 512 - self.num_regs_softmax * 2 - self.num_regs_correction self.buffer_align_bytes = 1024 @@ -358,21 +377,14 @@ def __call__( and not (self.pack_gqa and self.m_block_size % self.qhead_per_kvhead != 0) and not (self.pack_gqa and self.is_split_kv) ) - # This can be tuned - # This is currently very ad-hoc, we should tune it systematically self.ex2_emu_freq = 0 - # self.ex2_emu_start_frg = 1 if self.is_causal else 0 - self.ex2_emu_start_frg = 1 + self.ex2_emu_start_frg = self._tune.get("ex2_emu_start_frg", 1) if const_expr(self.enable_ex2_emu): - self.ex2_emu_freq = 16 - if const_expr(self.head_dim_padded == 128 and self.use_2cta_instrs): - self.ex2_emu_freq = 12 + self.ex2_emu_freq = self._tune.get("ex2_emu_freq", 16) if const_expr( self.pack_gqa and self.head_dim_padded > 64 and not self.is_causal and not self.is_local ): - self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else 10 - if const_expr(self.head_dim_padded > 64 and self.is_causal): - self.ex2_emu_freq = 10 + self.ex2_emu_freq = 32 if mCuSeqlensQ is not None or mSeqUsedQ is not None else self._tune.get("ex2_emu_freq", 10) cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE q_major_mode = tcgen05.OperandMajorMode.K @@ -487,7 +499,7 @@ def __call__( tma_atom_Q = None async_copy_elems = 128 // self.q_dtype.width num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids) - threads_per_row = self.head_dim_padded // async_copy_elems + threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads) gmem_tiled_copy_Q = copy_utils.tiled_copy_2d( self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True ) @@ -550,8 +562,11 @@ def __call__( if const_expr(not self.is_persistent) else StaticPersistentTileScheduler ) + # For non-persistent 2CTA (use_cluster_idx), each cluster covers + # cta_tiler[0] * cta_group_size rows, so num_block must be divided accordingly + _num_block_divisor = self.cta_tiler[0] * (self.cta_group_size if not self.is_persistent and self.cta_group_size > 1 else 1) tile_sched_args = TileSchedulerArguments( - cute.ceil_div(cute.size(mQ.shape[0]), self.cta_tiler[0]), + cute.ceil_div(cute.size(mQ.shape[0]), _num_block_divisor), cute.size(mQ.shape[2]), cute.size(mQ.shape[3]) if const_expr(mCuSeqlensQ is None) @@ -574,6 +589,7 @@ def __call__( lpt=self.is_causal or self.is_local, is_split_kv=self.is_split_kv, cluster_shape_mn=self.cluster_shape_mn, + use_cluster_idx=not self.is_persistent and self.cta_group_size > 1, ) tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) self.tile_scheduler_cls = TileScheduler diff --git a/flash_sparse_attn/ops/cute/interface.py b/flash_sparse_attn/ops/cute/interface.py index a4558e1..77ab487 100644 --- a/flash_sparse_attn/ops/cute/interface.py +++ b/flash_sparse_attn/ops/cute/interface.py @@ -524,7 +524,7 @@ def _flash_attn_fwd( and seqused_q is None and not use_block_sparsity and page_size in [None, 128] - and int(math.ceil(head_dim / 16) * 16) == 128 + and int(math.ceil(head_dim / 16) * 16) in [128, 192] and int(math.ceil(head_dim_v / 16) * 16) == 128 and seqlen_q_packgqa > 2 * tile_m and (tile_m % qhead_per_kvhead == 0 or not pack_gqa) diff --git a/flash_sparse_attn/ops/cute/pyproject.toml b/flash_sparse_attn/ops/cute/pyproject.toml index 40bd734..2b0b60b 100644 --- a/flash_sparse_attn/ops/cute/pyproject.toml +++ b/flash_sparse_attn/ops/cute/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ ] [project.optional-dependencies] +cu13 = ["nvidia-cutlass-dsl[cu13]>=4.4.2"] dev = [ "pytest", "ruff", @@ -51,6 +52,16 @@ tag_regex = "^fa4-v(?P.+)$" git_describe_command = "git describe --dirty --tags --long --match 'fa4-v*'" fallback_version = "0.0.0" +[[tool.uv.index]] +name = "pytorch-cu130" +url = "https://download.pytorch.org/whl/cu130" +explicit = true + +[tool.uv.sources] +torch = [ + { index = "pytorch-cu130", marker = "extra == 'cu13'" }, +] + [tool.ruff] line-length = 100 diff --git a/flash_sparse_attn/ops/cute/tile_scheduler.py b/flash_sparse_attn/ops/cute/tile_scheduler.py index 8f60f50..73add39 100644 --- a/flash_sparse_attn/ops/cute/tile_scheduler.py +++ b/flash_sparse_attn/ops/cute/tile_scheduler.py @@ -51,6 +51,7 @@ class TileSchedulerArguments(ParamsBase): lpt: cutlass.Constexpr[bool] = False is_split_kv: cutlass.Constexpr[bool] = False head_swizzle: cutlass.Constexpr[bool] = False + use_cluster_idx: cutlass.Constexpr[bool] = False class SingleTileScheduler: @@ -63,6 +64,7 @@ class Params(ParamsBase): num_splits_divmod: FastDivmodDivisor is_split_kv: cutlass.Constexpr[bool] = False cluster_shape_mn: cutlass.Constexpr[Tuple[int, int]] = (1, 1) + use_cluster_idx: cutlass.Constexpr[bool] = False @staticmethod def create( @@ -76,6 +78,7 @@ def create( FastDivmodDivisor(args.num_splits), args.is_split_kv, args.cluster_shape_mn, + args.use_cluster_idx, ) def __init__(self, params: Params, blk_coord: cute.Coord, *, loc=None, ip=None): @@ -91,13 +94,11 @@ def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) @staticmethod def create(params: Params, *, loc=None, ip=None) -> "SingleTileScheduler": - # if const_expr(cute.size(params.cluster_shape_mn) == 1): - # blk_coord = cute.arch.block_idx() - # else: - # # All CTAs in a cluster must get the same block coordinate - # blk_coord = cute.arch.cluster_idx() - # Temporary set to block_idx until we sort out the best way to handle cluster - blk_coord = cute.arch.block_idx() + if const_expr(cute.size(params.cluster_shape_mn) == 1 or not params.use_cluster_idx): + blk_coord = cute.arch.block_idx() + else: + # All CTAs in a cluster must get the same block coordinate + blk_coord = cute.arch.cluster_idx() return SingleTileScheduler(params, blk_coord, loc=loc, ip=ip) # called by host @@ -110,8 +111,13 @@ def get_grid_shape( ) -> Tuple[Int32, Int32, Int32]: # TODO: this hard-codes the fact that we only use cluster = (1, 1) or (2, 1) assert params.cluster_shape_mn[1] == 1, "Only cluster_shape_mn[1] == 1 is supported" + if const_expr(params.use_cluster_idx): + # Grid must have num_block * cluster_m physical blocks so that there are num_block clusters + grid_x = params.num_block * params.cluster_shape_mn[0] + else: + grid_x = cute.round_up(params.num_block, params.cluster_shape_mn[0]) return ( - cute.round_up(params.num_block, params.cluster_shape_mn[0]), + grid_x, params.num_head * params.num_splits, params.num_batch, ) @@ -395,8 +401,8 @@ def create( ) -> "SingleTileLPTBwdScheduler.Params": size_l2 = 50 * 1024 * 1024 size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size - # size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 - size_one_dqaccum_head = 0 + size_one_dqaccum_head = args.seqlen_k * (args.headdim) * 4 + # size_one_dqaccum_head = 0 size_one_head = size_one_qdo_head + size_one_dqaccum_head log2_floor = lambda n: 31 - clz(n) swizzle = 1 if size_l2 < size_one_head else (1 << log2_floor(size_l2 // size_one_head)) @@ -521,9 +527,12 @@ def create( args: TileSchedulerArguments, *, loc=None, ip=None ) -> "SingleTileVarlenScheduler.Params": size_l2 = 50 * 1024 * 1024 # 50 MB for K & V - max_kvblock_in_l2 = size_l2 // ( - (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] - ) + # if backward, this is qdo block size + kv_block_size = (args.headdim + args.headdim_v) * args.element_size * args.tile_shape_mn[1] + # if backward, add dqaccum block size to calculate swizzle + if args.head_swizzle: + kv_block_size += args.headdim * 4 * args.tile_shape_mn[1] + max_kvblock_in_l2 = size_l2 // kv_block_size assert args.mCuSeqlensQ is not None or args.mSeqUsedQ is not None, ( "At least one of mCuSeqlensQ or mSeqUsedQ must be provided" ) @@ -654,6 +663,7 @@ def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: num_n_blocks = ( num_m_blocks * params.tile_shape_mn[0] + * params.cluster_shape_m // params.qhead_per_kvhead_packgqa // params.tile_shape_mn[1] ) diff --git a/scripts/sync_cute_subtree.ps1 b/scripts/sync_cute_subtree.ps1 index 4685294..1d5ac0b 100644 --- a/scripts/sync_cute_subtree.ps1 +++ b/scripts/sync_cute_subtree.ps1 @@ -44,6 +44,29 @@ function Invoke-Git { } } +function Invoke-GitNoMergeEdit { + param( + [Parameter(Mandatory = $true)] + [string[]]$Arguments, + [string]$Repo + ) + + $previousValue = $env:GIT_MERGE_AUTOEDIT + $env:GIT_MERGE_AUTOEDIT = "no" + + try { + Invoke-Git -Arguments $Arguments -Repo $Repo + } + finally { + if ($null -eq $previousValue) { + Remove-Item Env:GIT_MERGE_AUTOEDIT -ErrorAction SilentlyContinue + } + else { + $env:GIT_MERGE_AUTOEDIT = $previousValue + } + } +} + function Get-GitOutput { param( [Parameter(Mandatory = $true)] @@ -121,6 +144,26 @@ function Ensure-GitIdentity { } } +function Resolve-RemoteDefaultRef { + param( + [Parameter(Mandatory = $true)] + [string]$Repo + ) + + $remoteRef = (& git -C $Repo symbolic-ref --quiet --short refs/remotes/origin/HEAD) 2>$null + if ($LASTEXITCODE -eq 0 -and $remoteRef) { + return ($remoteRef | Out-String).Trim() + } + + & git -C $Repo remote set-head origin --auto *> $null + $remoteRef = (& git -C $Repo symbolic-ref --quiet --short refs/remotes/origin/HEAD) 2>$null + if ($LASTEXITCODE -eq 0 -and $remoteRef) { + return ($remoteRef | Out-String).Trim() + } + + return "origin/main" +} + function Get-CommitSubject { param( [Parameter(Mandatory = $true)] @@ -132,6 +175,80 @@ function Get-CommitSubject { return Get-GitOutput -Repo $Repo -Arguments @("log", "-1", "--format=%s", $Commit) } +function Get-CommitParentCount { + param( + [Parameter(Mandatory = $true)] + [string]$Repo, + [Parameter(Mandatory = $true)] + [string]$Commit + ) + + $revLine = Get-GitOutput -Repo $Repo -Arguments @("rev-list", "--parents", "-n", "1", $Commit) + if (-not $revLine) { + return 0 + } + + $fields = $revLine -split '\s+' + return [Math]::Max(0, $fields.Count - 1) +} + +function Get-PrefixSplitCommit { + param( + [Parameter(Mandatory = $true)] + [string]$Repo, + [Parameter(Mandatory = $true)] + [string]$Prefix + ) + + return Get-GitOutput -Repo $Repo -Arguments @("subtree", "split", "--prefix=$Prefix", "HEAD") +} + +function Test-CommitIsAncestor { + param( + [Parameter(Mandatory = $true)] + [string]$Repo, + [Parameter(Mandatory = $true)] + [string]$OlderCommit, + [Parameter(Mandatory = $true)] + [string]$NewerCommit + ) + + & git -C $Repo merge-base --is-ancestor $OlderCommit $NewerCommit *> $null + return $LASTEXITCODE -eq 0 +} + +function Assert-SyncContainsUpstream { + param( + [Parameter(Mandatory = $true)] + [string]$Repo, + [Parameter(Mandatory = $true)] + [string]$UpstreamSplitCommit, + [Parameter(Mandatory = $true)] + [string]$LocalSplitCommit, + [string]$PreviousLocalSplitCommit + ) + + if (Test-CommitIsAncestor -Repo $Repo -OlderCommit $UpstreamSplitCommit -NewerCommit $LocalSplitCommit) { + return + } + + $message = @( + "Subtree sync did not incorporate the upstream split commit.", + "Upstream split commit: $UpstreamSplitCommit", + "Local prefix split after sync: $LocalSplitCommit" + ) + + if ($PreviousLocalSplitCommit) { + $message += "Local prefix split before sync: $PreviousLocalSplitCommit" + if ($PreviousLocalSplitCommit -eq $LocalSplitCommit) { + $message += "The local prefix split did not change even though upstream has newer CuTe commits." + } + } + + $message += "git subtree pull reported success, but the vendored prefix still does not contain the upstream split lineage." + throw ($message -join [Environment]::NewLine) +} + function Invoke-CoreSync { param( [Parameter(Mandatory = $true)] @@ -146,6 +263,8 @@ function Invoke-CoreSync { [string]$TempBranch, [Parameter(Mandatory = $true)] [string]$RewriteScript, + [Parameter(Mandatory = $true)] + [string]$UpstreamSplitRef, [switch]$Init, [switch]$SkipFetch, [switch]$KeepTempBranch @@ -154,6 +273,7 @@ function Invoke-CoreSync { $cutlassRepo = Join-Path $WorkRepoRoot "csrc/cutlass" $targetPath = Join-Path $WorkRepoRoot $Prefix $startHead = Get-GitOutput -Repo $WorkRepoRoot -Arguments @("rev-parse", "HEAD") + $localSplitBefore = $null Invoke-Git -Repo $WorkRepoRoot -Arguments @("rev-parse", "--show-toplevel") | Out-Null Invoke-Git -Repo $UpstreamRepoForSplit -Arguments @("rev-parse", "--show-toplevel") | Out-Null @@ -168,10 +288,14 @@ function Invoke-CoreSync { Invoke-Git -Repo $UpstreamRepoForSplit -Arguments @("fetch", "origin") } - Write-Host "Splitting upstream history for $UpstreamPrefix ..." - $splitCommit = Get-GitOutput -Repo $UpstreamRepoForSplit -Arguments @("subtree", "split", "--prefix=$UpstreamPrefix", "HEAD") + Write-Host "Splitting upstream history for $UpstreamPrefix from $UpstreamSplitRef ..." + $splitCommit = Get-GitOutput -Repo $UpstreamRepoForSplit -Arguments @("subtree", "split", "--prefix=$UpstreamPrefix", $UpstreamSplitRef) Invoke-Git -Repo $UpstreamRepoForSplit -Arguments @("branch", "-f", $TempBranch, $splitCommit) + if ((-not $Init) -and (Test-Path $targetPath)) { + $localSplitBefore = Get-PrefixSplitCommit -Repo $WorkRepoRoot -Prefix $Prefix + } + try { if ($Init) { if (Test-Path $targetPath) { @@ -179,7 +303,7 @@ function Invoke-CoreSync { } Write-Host "Adding subtree into $Prefix ..." - Invoke-Git -Repo $WorkRepoRoot -Arguments @("subtree", "add", "--prefix=$Prefix", $UpstreamRepoForSplit, $TempBranch) + Invoke-GitNoMergeEdit -Repo $WorkRepoRoot -Arguments @("subtree", "add", "--prefix=$Prefix", $UpstreamRepoForSplit, $TempBranch) } else { if (-not (Test-Path $targetPath)) { @@ -187,7 +311,7 @@ function Invoke-CoreSync { } Write-Host "Pulling upstream updates into $Prefix ..." - Invoke-Git -Repo $WorkRepoRoot -Arguments @("subtree", "pull", "--prefix=$Prefix", $UpstreamRepoForSplit, $TempBranch) + Invoke-GitNoMergeEdit -Repo $WorkRepoRoot -Arguments @("subtree", "pull", "--prefix=$Prefix", $UpstreamRepoForSplit, $TempBranch) } } finally { @@ -209,6 +333,9 @@ function Invoke-CoreSync { Invoke-Git -Repo $WorkRepoRoot -Arguments @("commit", "-m", "Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute") } + $localSplitAfter = Get-PrefixSplitCommit -Repo $WorkRepoRoot -Prefix $Prefix + Assert-SyncContainsUpstream -Repo $WorkRepoRoot -UpstreamSplitCommit $splitCommit -LocalSplitCommit $localSplitAfter -PreviousLocalSplitCommit $localSplitBefore + $endHead = Get-GitOutput -Repo $WorkRepoRoot -Arguments @("rev-parse", "HEAD") return [PSCustomObject]@{ StartHead = $startHead @@ -230,6 +357,8 @@ function Invoke-TemporaryWorktreeSync { [string]$TempBranch, [Parameter(Mandatory = $true)] [string]$RewriteScript, + [Parameter(Mandatory = $true)] + [string]$UpstreamSplitRef, [switch]$Init, [switch]$SkipFetch, [switch]$KeepTempBranch @@ -247,9 +376,9 @@ function Invoke-TemporaryWorktreeSync { $stashCreated = $false try { - $result = Invoke-CoreSync -WorkRepoRoot $tempWorktree -UpstreamRepoForSplit $UpstreamRepoForSplit -Prefix $Prefix -UpstreamPrefix $UpstreamPrefix -TempBranch $TempBranch -RewriteScript $RewriteScript -Init:$Init -SkipFetch:$SkipFetch -KeepTempBranch:$KeepTempBranch + $result = Invoke-CoreSync -WorkRepoRoot $tempWorktree -UpstreamRepoForSplit $UpstreamRepoForSplit -Prefix $Prefix -UpstreamPrefix $UpstreamPrefix -TempBranch $TempBranch -RewriteScript $RewriteScript -UpstreamSplitRef $UpstreamSplitRef -Init:$Init -SkipFetch:$SkipFetch -KeepTempBranch:$KeepTempBranch - $commitListOutput = Get-GitOutput -Repo $tempWorktree -Arguments @("rev-list", "--reverse", "HEAD", "^$originalHead") + $commitListOutput = Get-GitOutput -Repo $tempWorktree -Arguments @("rev-list", "--reverse", "--first-parent", "HEAD", "^$originalHead") $commits = @() if ($commitListOutput) { $commits = $commitListOutput -split "`r?`n" | Where-Object { $_ } @@ -282,7 +411,12 @@ function Invoke-TemporaryWorktreeSync { foreach ($commit in $commitsToCherryPick) { Write-Host "Cherry-picking $commit back into current worktree ..." Ensure-GitIdentity -Repo $RepoRoot - Invoke-Git -Repo $RepoRoot -Arguments @("cherry-pick", $commit) + if ((Get-CommitParentCount -Repo $tempWorktree -Commit $commit) -gt 1) { + Invoke-Git -Repo $RepoRoot -Arguments @("cherry-pick", "-m", "1", $commit) + } + else { + Invoke-Git -Repo $RepoRoot -Arguments @("cherry-pick", $commit) + } } } catch { @@ -323,6 +457,7 @@ $rewriteScript = Join-Path $repoRoot "scripts/rewrite_cute_namespace.py" $cacheRepo = Join-Path $repoRoot $CacheDir $upstreamRepoForSplit = $null +$upstreamSplitRef = "HEAD" if (Test-GitRemoteSpec -Value $UpstreamRepo) { if (-not (Test-Path $cacheRepo)) { @@ -337,23 +472,22 @@ if (Test-GitRemoteSpec -Value $UpstreamRepo) { Write-Host "Updating cached upstream origin URL ..." Invoke-Git -Repo $upstreamRepoForSplit -Arguments @("remote", "set-url", "origin", $UpstreamRepo) } + + $upstreamSplitRef = Resolve-RemoteDefaultRef -Repo $upstreamRepoForSplit } else { $upstreamRepoForSplit = (Resolve-Path $UpstreamRepo).Path } -$cutlassRepo = Join-Path $repoRoot "csrc/cutlass" -$targetPath = Join-Path $repoRoot $Prefix - $dirtyStatus = Get-DirtyStatus -Repo $repoRoot if ($dirtyStatus -and -not $NoTemporaryWorktree) { - $syncResult = Invoke-TemporaryWorktreeSync -RepoRoot $repoRoot -UpstreamRepoForSplit $upstreamRepoForSplit -Prefix $Prefix -UpstreamPrefix $UpstreamPrefix -TempBranch $TempBranch -RewriteScript $rewriteScript -Init:$Init -SkipFetch:$SkipFetch -KeepTempBranch:$KeepTempBranch + $syncResult = Invoke-TemporaryWorktreeSync -RepoRoot $repoRoot -UpstreamRepoForSplit $upstreamRepoForSplit -Prefix $Prefix -UpstreamPrefix $UpstreamPrefix -TempBranch $TempBranch -RewriteScript $rewriteScript -UpstreamSplitRef $upstreamSplitRef -Init:$Init -SkipFetch:$SkipFetch -KeepTempBranch:$KeepTempBranch } else { if ($dirtyStatus) { throw "Superproject has uncommitted changes and -NoTemporaryWorktree was set.`n$dirtyStatus" } - $syncResult = Invoke-CoreSync -WorkRepoRoot $repoRoot -UpstreamRepoForSplit $upstreamRepoForSplit -Prefix $Prefix -UpstreamPrefix $UpstreamPrefix -TempBranch $TempBranch -RewriteScript $rewriteScript -Init:$Init -SkipFetch:$SkipFetch -KeepTempBranch:$KeepTempBranch + $syncResult = Invoke-CoreSync -WorkRepoRoot $repoRoot -UpstreamRepoForSplit $upstreamRepoForSplit -Prefix $Prefix -UpstreamPrefix $UpstreamPrefix -TempBranch $TempBranch -RewriteScript $rewriteScript -UpstreamSplitRef $upstreamSplitRef -Init:$Init -SkipFetch:$SkipFetch -KeepTempBranch:$KeepTempBranch } Write-Host "Done." diff --git a/scripts/sync_cute_subtree.sh b/scripts/sync_cute_subtree.sh index e164c89..cc35e24 100644 --- a/scripts/sync_cute_subtree.sh +++ b/scripts/sync_cute_subtree.sh @@ -14,6 +14,7 @@ NO_TEMPORARY_WORKTREE=0 TEMP_WORKTREE_PATH="" TEMP_WORKTREE_BRANCH="" REWRITE_COMMIT_MESSAGE="Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute" +UPSTREAM_SPLIT_REF="HEAD" usage() { cat <<'EOF' @@ -46,6 +47,14 @@ invoke_git() { git "$@" } +invoke_git_no_merge_edit() { + if [[ $# -lt 1 ]]; then + echo "invoke_git_no_merge_edit requires arguments" >&2 + exit 1 + fi + GIT_MERGE_AUTOEDIT=no git "$@" +} + git_output() { if [[ $# -lt 1 ]]; then echo "git_output requires arguments" >&2 @@ -98,12 +107,80 @@ ensure_git_identity() { fi } +resolve_remote_default_ref() { + local repo="$1" + local remote_ref + + remote_ref="$(git -C "$repo" symbolic-ref --quiet --short refs/remotes/origin/HEAD 2>/dev/null || true)" + if [[ -n "$remote_ref" ]]; then + printf '%s\n' "$remote_ref" + return + fi + + git -C "$repo" remote set-head origin --auto >/dev/null 2>&1 || true + remote_ref="$(git -C "$repo" symbolic-ref --quiet --short refs/remotes/origin/HEAD 2>/dev/null || true)" + if [[ -n "$remote_ref" ]]; then + printf '%s\n' "$remote_ref" + return + fi + + printf 'origin/main\n' +} + get_commit_subject() { local repo="$1" local commit="$2" git -C "$repo" log -1 --format=%s "$commit" } +get_prefix_split_commit() { + local repo="$1" + local prefix="$2" + git -C "$repo" subtree split --prefix="$prefix" HEAD | tail -n 1 | tr -d '\r' +} + +commit_is_ancestor() { + local repo="$1" + local older_commit="$2" + local newer_commit="$3" + + git -C "$repo" merge-base --is-ancestor "$older_commit" "$newer_commit" +} + +assert_sync_contains_upstream() { + local repo="$1" + local upstream_split_commit="$2" + local local_split_commit="$3" + local previous_local_split_commit="${4:-}" + + if commit_is_ancestor "$repo" "$upstream_split_commit" "$local_split_commit"; then + return + fi + + echo "Subtree sync did not incorporate the upstream split commit." >&2 + echo "Upstream split commit: $upstream_split_commit" >&2 + echo "Local prefix split after sync: $local_split_commit" >&2 + if [[ -n "$previous_local_split_commit" ]]; then + echo "Local prefix split before sync: $previous_local_split_commit" >&2 + if [[ "$previous_local_split_commit" == "$local_split_commit" ]]; then + echo "The local prefix split did not change even though upstream has newer CuTe commits." >&2 + fi + fi + echo "git subtree pull reported success, but the vendored prefix still does not contain the upstream split lineage." >&2 + exit 1 +} + +get_commit_parent_count() { + local repo="$1" + local commit="$2" + local rev_line + local rev_fields + + rev_line="$(git -C "$repo" rev-list --parents -n 1 "$commit")" + read -r -a rev_fields <<< "$rev_line" + printf '%s\n' "$(( ${#rev_fields[@]} - 1 ))" +} + cleanup_worktree() { if [[ -n "$TEMP_WORKTREE_PATH" ]]; then git -C "$REPO_ROOT" worktree remove --force "$TEMP_WORKTREE_PATH" >/dev/null 2>&1 || true @@ -119,7 +196,7 @@ invoke_core_sync() { local work_repo_root="$1" local cutlass_repo="$work_repo_root/csrc/cutlass" local target_path="$work_repo_root/$PREFIX" - local start_head + local start_head local_split_before local_split_after start_head="$(git_output -C "$work_repo_root" rev-parse HEAD)" invoke_git -C "$work_repo_root" rev-parse --show-toplevel >/dev/null @@ -135,10 +212,16 @@ invoke_core_sync() { invoke_git -C "$UPSTREAM_REPO_FOR_SPLIT" fetch origin fi - echo "Splitting upstream history for $UPSTREAM_PREFIX ..." - SPLIT_COMMIT="$(git_output -C "$UPSTREAM_REPO_FOR_SPLIT" subtree split --prefix="$UPSTREAM_PREFIX" HEAD | tail -n 1 | tr -d '\r')" + echo "Splitting upstream history for $UPSTREAM_PREFIX from $UPSTREAM_SPLIT_REF ..." + SPLIT_COMMIT="$(git_output -C "$UPSTREAM_REPO_FOR_SPLIT" subtree split --prefix="$UPSTREAM_PREFIX" "$UPSTREAM_SPLIT_REF" | tail -n 1 | tr -d '\r')" invoke_git -C "$UPSTREAM_REPO_FOR_SPLIT" branch -f "$TEMP_BRANCH" "$SPLIT_COMMIT" + if [[ "$INIT" -eq 0 ]] && [[ -e "$target_path" ]]; then + local_split_before="$(get_prefix_split_commit "$work_repo_root" "$PREFIX")" + else + local_split_before="" + fi + cleanup_core() { if [[ "$KEEP_TEMP_BRANCH" -eq 0 ]]; then git -C "$UPSTREAM_REPO_FOR_SPLIT" update-ref -d "refs/heads/$TEMP_BRANCH" >/dev/null 2>&1 || true @@ -153,14 +236,14 @@ invoke_core_sync() { exit 1 fi echo "Adding subtree into $PREFIX ..." - invoke_git -C "$work_repo_root" subtree add --prefix="$PREFIX" "$UPSTREAM_REPO_FOR_SPLIT" "$TEMP_BRANCH" + invoke_git_no_merge_edit -C "$work_repo_root" subtree add --prefix="$PREFIX" "$UPSTREAM_REPO_FOR_SPLIT" "$TEMP_BRANCH" else if [[ ! -e "$target_path" ]]; then echo "$PREFIX does not exist yet. Run this script once with --init first." >&2 exit 1 fi echo "Pulling upstream updates into $PREFIX ..." - invoke_git -C "$work_repo_root" subtree pull --prefix="$PREFIX" "$UPSTREAM_REPO_FOR_SPLIT" "$TEMP_BRANCH" + invoke_git_no_merge_edit -C "$work_repo_root" subtree pull --prefix="$PREFIX" "$UPSTREAM_REPO_FOR_SPLIT" "$TEMP_BRANCH" fi echo "Rewriting vendored CuTe imports to flash_sparse_attn.ops.cute ..." @@ -172,6 +255,9 @@ invoke_core_sync() { invoke_git -C "$work_repo_root" commit -m "Rewrite vendored CuTe namespace to flash_sparse_attn.ops.cute" fi + local_split_after="$(get_prefix_split_commit "$work_repo_root" "$PREFIX")" + assert_sync_contains_upstream "$work_repo_root" "$SPLIT_COMMIT" "$local_split_after" "$local_split_before" + END_HEAD="$(git_output -C "$work_repo_root" rev-parse HEAD)" SYNC_START_HEAD="$start_head" SYNC_END_HEAD="$END_HEAD" @@ -191,7 +277,7 @@ invoke_temporary_worktree_sync() { invoke_core_sync "$temp_worktree" - commits="$(git -C "$temp_worktree" rev-list --reverse HEAD "^$original_head")" + commits="$(git -C "$temp_worktree" rev-list --reverse --first-parent HEAD "^$original_head")" if [[ -z "$commits" ]]; then echo "No new subtree commits were created." return @@ -219,7 +305,11 @@ invoke_temporary_worktree_sync() { for commit in "${cherry_pick_commits[@]}"; do echo "Cherry-picking $commit back into current worktree ..." ensure_git_identity "$REPO_ROOT" - invoke_git -C "$REPO_ROOT" cherry-pick "$commit" + if [[ "$(get_commit_parent_count "$temp_worktree" "$commit")" -gt 1 ]]; then + invoke_git -C "$REPO_ROOT" cherry-pick -m 1 "$commit" + else + invoke_git -C "$REPO_ROOT" cherry-pick "$commit" + fi done if [[ -n "$current_status" ]]; then @@ -309,6 +399,10 @@ else UPSTREAM_REPO_FOR_SPLIT="$(cd "$UPSTREAM_REPO" && pwd)" fi +if is_git_remote_spec "$UPSTREAM_REPO"; then + UPSTREAM_SPLIT_REF="$(resolve_remote_default_ref "$UPSTREAM_REPO_FOR_SPLIT")" +fi + SYNC_START_HEAD="$(git_output -C "$REPO_ROOT" rev-parse HEAD)" SYNC_END_HEAD="$SYNC_START_HEAD"