Add return_max_attn_logit for QK-Clip support #1
Merged
### CuTe DSL Primer (for reviewers)

The kernel code in this PR is written in the NVIDIA CuTe DSL. Knowing a few core patterns is enough to follow the changes.

**Tensor slicing & tiling**

```python
# Select a specific head/batch from a global-memory tensor (similar to NumPy indexing)
mMAL_cur = mMaxAttnLogit[None, head_idx, batch_idx]
# → if mMaxAttnLogit is (seqlen, heads, batch), returns the 1D slice for head_idx

# Apply a sequence offset (the start position of each sequence in varlen)
mMAL_cur = cute.domain_offset((offset,), mMaxAttnLogit[None, head_idx])
# → shifts the start of the 1D tensor by offset (pointer arithmetic)

# Split a 1D tensor into fixed-size blocks and select the m_tile_idx-th block
gMAL = cute.local_tile(mMAL_cur, (m_block_size,), (m_tile_idx,))
# → mMAL_cur[m_tile_idx * m_block_size : (m_tile_idx + 1) * m_block_size]
# e.g. m_block_size=128, m_tile_idx=1 → rows 128~255
```

**SM100 warp specialization**

The SM100 kernel runs three kinds of warp groups concurrently:
```python
# The correction warp's max_attn_logit write (flash_fwd_sm100.py)
for stage in range(q_stage):
    m_tile_idx = (m_block * q_stage + stage) * cta_group_size + mma_tile_coord_v
    gMAL = cute.local_tile(mMAL_cur, (m_block_size,), (m_tile_idx,))
    row_sum, row_max, _ = stats[stage]  # values read from sScale
    row_logit = row_max * softmax_scale_log2 * LN2  # = row_max / sqrt(d)
    if tidx < seqlen_q - m_tile_idx * m_block_size:  # bounds check
        gMAL[tidx] = row_logit
```
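For reviewers who don't read the DSL, here is a plain-NumPy analogy of the loop above. This is an illustration only, with made-up sizes, not the kernel code:

```python
import numpy as np

# Made-up sizes for illustration
seqlen_q, m_block_size, head_dim = 300, 128, 128
softmax_scale_log2 = 1.0 / (np.sqrt(head_dim) * np.log(2))  # scale in the log2 domain
LN2 = np.log(2)

mMAL_cur = np.zeros(seqlen_q, dtype=np.float32)            # 1D slice for one head
row_max = np.random.rand(m_block_size).astype(np.float32)  # per-row max (as read from sScale)

m_tile_idx = 2                                 # the last, partial tile: rows 256..299
start = m_tile_idx * m_block_size              # what local_tile selects
n_valid = min(m_block_size, seqlen_q - start)  # the `tidx < ...` bounds check
# in the kernel each thread tidx writes one row; vectorized over the tile here:
mMAL_cur[start:start + n_valid] = (row_max * softmax_scale_log2 * LN2)[:n_valid]
```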
wanyaworld reviewed Mar 24, 2026
ca1207 reviewed Mar 26, 2026
ca1207 reviewed Mar 26, 2026
wanyaworld approved these changes Mar 26, 2026
Force-pushed from e34bd90 to 3fdd2fc (Compare)
Add `return_max_attn_logit` parameter to FA4's forward API that returns per-head `max(Q@K^T / sqrt(d))` during the attention forward pass. Used by QK-Clip (Muon optimizer) to efficiently obtain attention logit statistics without a separate full matmul.

SM100 kernel changes:
- softmax_loop: disable rescale_threshold when max_attn_logit is requested, so row_max tracks the true per-row maximum
- correction_loop: write max_attn_logit alongside LSE in a single loop, using the same row_max from the same stats[] unpack

Interface changes:
- _flash_attn_fwd: allocate max_attn_logit tensor, pass to kernel, host-side amax reduction to (num_heads,)
- flash_attn_func / flash_attn_varlen_func: return (out, lse, max_logit) when return_max_attn_logit=True
- Assert return_lse=True when return_max_attn_logit=True (kernel uses LSE's sScale path to communicate row_max)
- SM90: raise NotImplementedError (untested)
- Bugfix: head_dim_v_padded used head_dim instead of head_dim_v

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed from 3fdd2fc to 4872e41 (Compare)
ca1207 approved these changes Mar 26, 2026
wanyaworld approved these changes Mar 26, 2026
## Summary

Adds a `return_max_attn_logit` parameter to the FA4 forward pass. During the attention forward pass, the kernel extracts and returns the per-head `max(Q@K^T / sqrt(d))`. This is intended for QK-Clip (the Muon optimizer) and is more than ~10x as efficient as the existing approach of a separate full matmul.
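A minimal usage sketch of the new interface (return layout as described in this PR; the shapes and the exact import path are assumptions):

```python
import torch
from flash_attn.cute.interface import flash_attn_func  # path per this PR's file layout

q = torch.randn(2, 4096, 80, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 4096, 16, 128, device="cuda", dtype=torch.bfloat16)

# return_lse=True is required: the kernel reuses LSE's sScale path to carry row_max
out, lse, max_logit = flash_attn_func(
    q, k, v, return_lse=True, return_max_attn_logit=True
)
print(max_logit.shape)  # (80,): per-head max(Q@K^T / sqrt(d)), amax-reduced on the host
```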
Currently only SM100 (Blackwell) is supported; SM90 (Hopper) raises `NotImplementedError`.

## Bug fix: `max_attn_logit` error caused by `rescale_threshold`

### Online softmax and `rescale_threshold`

FA4 uses the online-softmax algorithm, processing attention one n_block at a time.
At each n_block it updates row_max (the maximum attention score seen so far) and rescales the previously accumulated output O (shown here as the un-normalized accumulator):

`O_new = O_old * exp(row_max_old - row_max_new) + P_new @ V_new`

`rescale_threshold=8.0` is an optimization that reduces the cost of rescaling O: if a block's maximum exceeds the running row_max by no more than the threshold, the rescale (and the row_max update) is skipped.

### Why the attention output and LSE are exact, and only max_attn_logit is wrong
Softmax is shift invariant: for any constant c, `exp(s_i - c) / Σ_j exp(s_j - c)` does not depend on c. When the threshold skips a rescale, every n_block uses the same shift constant (row_max_old), so the attention output `O = Σ(P_i @ V_i) / Σ(P_i)` is mathematically exact. `LSE = log(Σ exp(s))` is exact for the same reason.
The actual purpose of `threshold=8.0` is to cap the exp2 values, which can grow when the shift is skipped, so they stay within bf16 range (`exp2(8) = 256`, bf16-safe). However, `max_attn_logit = row_max / √d` depends directly on row_max itself: once the threshold freezes row_max, larger scores in later n_blocks are never reflected.
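A toy NumPy illustration of this behavior, using one row and "blocks" of size 1 (not the kernel's actual code path):

```python
import numpy as np

scores = np.array([1.0, 2.0, 5.0, 9.0])  # one row's attention scores, one per block
v = np.array([1.0, 2.0, 3.0, 4.0])       # matching (scalar) V values
threshold = 8.0

m, acc_num, acc_den = -np.inf, 0.0, 0.0
for s, vi in zip(scores, v):
    if s - m > threshold:                # rescale only on a big enough jump
        acc_num *= np.exp(m - s)
        acc_den *= np.exp(m - s)
        m = s
    p = np.exp(s - m)                    # the shift constant m may be stale
    acc_num += p * vi
    acc_den += p

out = acc_num / acc_den
w = np.exp(scores - scores.max())
exact = (w / w.sum()) @ v
print(np.isclose(out, exact))  # True: the output is exact (shift invariance)
print(m, scores.max())         # 1.0 vs 9.0: the tracked max is stale
```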
### Fix

When `return_max_attn_logit=True`, set `rescale_threshold=0.0` so that row_max always tracks the true maximum. When `return_max_attn_logit=False` (the default), the existing optimization is kept unchanged.
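The fix itself reduces to a one-line conditional; a sketch (the helper name is hypothetical, `8.0` is the default described above):

```python
def pick_rescale_threshold(has_max_attn_logit: bool) -> float:
    # softmax_loop: track the true per-row max when max_attn_logit is requested
    return 0.0 if has_max_attn_logit else 8.0
```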
## Code Changes

`flash_attn/cute/interface.py`
- Add a `return_max_attn_logit: bool = False` parameter
- Assert that `return_lse=True` (the kernel passes row_max through LSE's sScale path)
- SM90: raise `NotImplementedError`
- Allocate the `max_attn_logit` tensor, convert it to a CuTe tensor, pass it to the kernel
- Host-side `amax` reduction → return `(num_heads,)` (see the sketch after this list)
- Include the flag in `compile_key` (separate compile cache)
- `FlashAttnFunc` / `FlashAttnVarlenFunc`: `mark_non_differentiable` handling
- `flash_attn_func` / `flash_attn_varlen_func`: conditional return
- Bugfix: `head_dim_v_padded` used `head_dim` instead of `head_dim_v`

`flash_attn/cute/flash_fwd_sm100.py`
- `has_max_attn_logit` parameter → forwarded to `softmax_loop` and `correction_loop`
- `softmax_loop`: conditionally set `rescale_threshold` + pass `mMaxAttnLogit`
- `correction_loop`: write max_attn_logit inside the LSE write loop (same `stats[]` unpack, no duplicated work)
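The host-side `amax` reduction listed above, sketched with hypothetical code (assuming the kernel fills a buffer laid out `(seqlen, heads, batch)` as in the primer):

```python
import torch

seqlen_q, num_heads, batch = 4096, 80, 2  # example sizes
# buffer the kernel writes per-row logits into, laid out (seqlen, heads, batch)
max_attn_logit_buf = torch.full(
    (seqlen_q, num_heads, batch), float("-inf"),
    device="cuda", dtype=torch.float32,
)
# ... kernel launch fills the valid rows ...
max_attn_logit = max_attn_logit_buf.amax(dim=(0, 2))  # → (num_heads,)
```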
## Test Results

B200 (SM100), 15 tests, all passed (rel diff = 0.000000).
## Benchmark
B200, BF16, bs=2, nh=80, nhkv=16:
- `return_max_attn_logit=False` (default): 0% overhead (separate compile key)
- vs `_compute_qk_clip()` (a separate full Q@K^T matmul): more than ~10x as efficient

## Limitations

- SM90: raises `NotImplementedError`
- Incompatible with `score_mod`, `softcap`, and `num_splits > 1`
- Forces `pack_gqa=False`
- Requires `return_lse=True`

🤖 Generated with Claude Code