Skip to content

Add return_max_attn_logit for QK-Clip support#1

Merged
WyldeCat merged 1 commit intomainfrom
jeesoo/return-max-attn-logit
Mar 26, 2026
Merged

Add return_max_attn_logit for QK-Clip support#1
WyldeCat merged 1 commit intomainfrom
jeesoo/return-max-attn-logit

Conversation

@WyldeCat
Copy link
Copy Markdown
Member

@WyldeCat WyldeCat commented Mar 20, 2026

Summary

FA4 forward pass에 return_max_attn_logit 파라미터 추가.
커널이 attention forward 중에 per-head max(Q@K^T / sqrt(d))를 추출하여 반환.
QK-Clip (Muon optimizer) 용도로, 기존 별도 full matmul 대비 ~10x 이상 효율적.

현재 SM100 (Blackwell) 만 지원. SM90 (Hopper)에서는 NotImplementedError.

Bug fix: rescale_threshold로 인한 max_attn_logit 오류

Online Softmax과 rescale_threshold

FA4는 online softmax 알고리즘을 사용하여 attention을 n_block 단위로 처리합니다.
각 n_block마다 row_max(현재까지 본 최대 attention score)를 갱신하고,
이전까지 누적된 output O를 rescale합니다:

n_block 0 (is_first=True):
  row_max = max(scores_0)
  O = softmax(scores_0) @ V_0

n_block 1:
  row_max_new = max(row_max, max(scores_1))
  acc_scale = exp2((row_max_old - row_max_new) * scale_log2)
  O = O * acc_scale + softmax(scores_1) @ V_1
  row_max = row_max_new

rescale_threshold=8.0은 O rescaling 비용을 줄이기 위한 최적화입니다:

acc_scale_ = (row_max_old - row_max_new) * scale_log2
if acc_scale_ >= -rescale_threshold:   # max 변화가 threshold 이내이면
    row_max_new = row_max_old          # row_max 업데이트 건너뜀
    acc_scale = 1.0                    # O rescaling도 건너뜀

왜 attention output과 LSE는 정확하고 max_attn_logit만 틀리는가

Softmax는 shift invariant합니다:

softmax(s)_i = exp(s_i - c) / Σ_j exp(s_j - c)   (임의의 상수 c에 대해 동일)

threshold가 rescale을 skip하면 모든 n_block이 동일한 shift 상수(row_max_old)를 사용하므로
attention output O = Σ(P_i @ V_i) / Σ(P_i)는 수학적으로 정확합니다.
LSE = log(Σ exp(s))도 동일한 이유로 정확합니다.

threshold=8.0의 실제 목적은 shift를 안 하면 커질 수 있는 exp2 값이 bf16 범위를
넘지 않도록 상한을 두는 것입니다 (exp2(8) = 256, bf16 safe).

그러나 max_attn_logit = row_max / √drow_max 자체에 직접 의존합니다.
threshold에 의해 row_max가 고정되면 이후 n_block의 더 큰 score가 반영되지 않습니다.

Fix

return_max_attn_logit=True일 때 rescale_threshold=0.0으로 설정하여
row_max가 항상 실제 최대값을 추적하도록 합니다.
return_max_attn_logit=False (기본값)일 때는 기존 최적화 그대로 유지.

TODO: true_row_max를 별도 register로 추적하면 threshold 유지 + 정확한 max_attn_logit이 가능 (~0% overhead). sScale shared memory 확장이 필요하며 별도 PR로 진행 예정.

Code Changes

flash_attn/cute/interface.py

  • return_max_attn_logit: bool = False 파라미터 추가
  • return_lse=True 필수 assert (커널이 LSE의 sScale 경로로 row_max 전달)
  • SM90에서 NotImplementedError
  • max_attn_logit 텐서 할당, CuTe tensor 변환, 커널 전달
  • host-side amax reduction → (num_heads,) 반환
  • compile_key에 포함 (별도 컴파일 캐시)
  • FlashAttnFunc / FlashAttnVarlenFunc: mark_non_differentiable 처리
  • flash_attn_func / flash_attn_varlen_func: 조건부 반환
  • bugfix: head_dim_v_paddedhead_dim 대신 head_dim_v 사용

flash_attn/cute/flash_fwd_sm100.py

  • has_max_attn_logit 파라미터 → softmax_loop, correction_loop에 전달
  • softmax_loop: rescale_threshold 조건부 설정 + mMaxAttnLogit 전달
  • correction_loop: LSE write loop 안에서 max_attn_logit 동시 write (동일 stats[] unpack, 중복 제거)

Test Results

B200 (SM100), 15 tests, all passed (rel diff = 0.000000):

Test Config
shape & return (SMALL) n_heads=10, V=128
no qk_clip returns None
output unchanged by qk_clip flag
raw MHA symmetric (seq=128) d=V=192, 1 tile
raw GQA asymmetric (seq=256) d=192 V=128, multi-tile
debug write pattern (seq=128) d=192 V=128
LSE asymmetric vs symmetric
DiffAttentionV2 SMALL (seq=256) n_heads=10
DiffAttentionV2 motif3 (seq=4096) n_heads=80, d=192, V=128
gradient flow
SWA causal (window=128, seq=512) d=192, V=128
SWA small window (window=32, seq=256) d=192, V=128
SWA DiffAttentionV2 (window=128) n_heads=10
SWA motif3 config (seq=4096) n_heads=80, window=128
4-node training (32 GPU, 30 steps) motif3_seq, qk_clip_interval=2

Benchmark

B200, BF16, bs=2, nh=80, nhkv=16:

Config Baseline w/ max_logit Overhead
motif3 asymmetric (d=192 V=128) non-causal sq=4096 1.286ms 1.320ms +2.6%
motif3 asymmetric causal sq=4096 0.689ms 0.726ms +5.4%
symmetric (d=V=192) non-causal sq=4096 2.018ms 2.091ms +3.6%
symmetric causal sq=4096 1.087ms 1.136ms +4.5%
motif3 causal sq=1024 0.073ms 0.101ms +37.3%
motif3 causal sq=8192 2.642ms 2.824ms +6.9%
  • return_max_attn_logit=False (default): overhead 0% (별도 compile key)
  • vs _compute_qk_clip() (별도 full Q@K^T matmul): ~10x 이상 효율적

Limitations

  • SM100 (Blackwell) 전용. SM90 (Hopper)에서는 NotImplementedError
  • score_mod, softcap, num_splits > 1과 비호환
  • pack_gqa=False 강제
  • return_lse=True 필수
  • Non-differentiable (forward only)

🤖 Generated with Claude Code

@WyldeCat
Copy link
Copy Markdown
Member Author

WyldeCat commented Mar 24, 2026

CuTe DSL Primer (for reviewers)

이 PR의 커널 코드는 NVIDIA CuTe DSL로 작성됨. 핵심 패턴을 이해하면 변경사항을 따라갈 수 있음.

텐서 slicing & tiling

# Global memory 텐서에서 특정 head/batch를 선택 (NumPy의 indexing과 유사)
mMAL_cur = mMaxAttnLogit[None, head_idx, batch_idx]
# → mMaxAttnLogit이 (seqlen, heads, batch) 일 때, head_idx번째 head의 1D slice 반환

# Sequence offset 적용 (varlen에서 각 sequence의 시작 위치)
mMAL_cur = cute.domain_offset((offset,), mMaxAttnLogit[None, head_idx])
# → 1D 텐서의 시작점을 offset만큼 이동 (pointer arithmetic)

# 1D 텐서를 고정 크기 블록으로 나누고, m_tile_idx번째 블록 선택
gMAL = cute.local_tile(mMAL_cur, (m_block_size,), (m_tile_idx,))
# → mMAL_cur[m_tile_idx * m_block_size : (m_tile_idx+1) * m_block_size]
# 예: m_block_size=128, m_tile_idx=1 → rows 128~255

SM100 warp specialization

SM100 커널은 3종류의 warp group이 동시에 실행됨:

MMA warps     : Q@K, P@V 행렬곱 (tensor core)
Softmax warps : online softmax (row_max, row_sum 계산) → sScale 공유메모리에 저장
Correction warps : O rescale, LSE/max_attn_logit gmem write

max_attn_logit write는 correction warp에서 수행:

  • Softmax warp가 row_max를 sScale (shared memory)에 저장
  • Correction warp가 sScale에서 읽어 row_max * softmax_scale = row_max / sqrt(d) 로 변환 후 gmem에 write
  • 각 thread (tidx 0~127)가 한 row씩 담당 — 단순 1:1 매핑
# Correction warp의 max_attn_logit write (flash_fwd_sm100.py)
for stage in range(q_stage):
    m_tile_idx = (m_block * q_stage + stage) * cta_group_size + mma_tile_coord_v
    gMAL = cute.local_tile(mMAL_cur, (m_block_size,), (m_tile_idx,))
    row_sum, row_max, _ = stats[stage]    # sScale에서 읽은 값
    row_logit = row_max * softmax_scale_log2 * LN2   # = row_max / sqrt(d)
    if tidx < seqlen_q - m_tile_idx * m_block_size:  # bounds check
        gMAL[tidx] = row_logit

Comment thread flash_attn/cute/flash_fwd_sm100.py
Comment thread flash_attn/cute/flash_fwd_sm100.py
Comment thread flash_attn/cute/interface.py Outdated
Comment thread flash_attn/cute/interface.py
@WyldeCat WyldeCat force-pushed the jeesoo/return-max-attn-logit branch 7 times, most recently from e34bd90 to 3fdd2fc Compare March 26, 2026 06:33
Add `return_max_attn_logit` parameter to FA4's forward API that returns
per-head `max(Q@K^T / sqrt(d))` during the attention forward pass.
Used by QK-Clip (Muon optimizer) to efficiently obtain attention logit
statistics without a separate full matmul.

SM100 kernel changes:
- softmax_loop: disable rescale_threshold when max_attn_logit is
  requested, so row_max tracks the true per-row maximum
- correction_loop: write max_attn_logit alongside LSE in a single loop,
  using the same row_max from the same stats[] unpack

Interface changes:
- _flash_attn_fwd: allocate max_attn_logit tensor, pass to kernel,
  host-side amax reduction to (num_heads,)
- flash_attn_func / flash_attn_varlen_func: return (out, lse, max_logit)
  when return_max_attn_logit=True
- Assert return_lse=True when return_max_attn_logit=True (kernel uses
  LSE's sScale path to communicate row_max)
- SM90: raise NotImplementedError (untested)
- Bugfix: head_dim_v_padded used head_dim instead of head_dim_v

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@WyldeCat WyldeCat force-pushed the jeesoo/return-max-attn-logit branch from 3fdd2fc to 4872e41 Compare March 26, 2026 06:36
@WyldeCat WyldeCat merged commit c4fa1d1 into main Mar 26, 2026
@WyldeCat WyldeCat deleted the jeesoo/return-max-attn-logit branch March 26, 2026 07:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants