72 changes: 60 additions & 12 deletions flash_sparse_attn/ops/triton/activations.py
@@ -9,17 +9,74 @@ def check_inf(x):

@triton.jit
def online_softmax(
    acc_s,
    row_max,
    row_sum,
    scale_log2,
    CHECK_INF: tl.constexpr,
    RESCALE_THRESHOLD: tl.constexpr,
):
Comment on lines 11 to +18

Copilot AI Apr 22, 2026


This PR is described as a pure rename/comment clarity change, but this file introduces a new online_softmax implementation and renames the previous implementation to online_sparse_softmax (plus signature/return-value changes used by dense kernels). Please update the PR description to reflect these functional refactors, or split the refactor into a separate PR for easier review/risk assessment.

"""
Apply online softmax to acc_s, and update row_max and row_sum.

Comment on lines 11 to 21

Copilot AI Apr 22, 2026


The PR description says this is a pure rename/comment-clarification refactor, but this file also changes the softmax API (signature changes to online_softmax and introduces online_sparse_softmax). If this is intentional, please update the PR description to reflect the behavioral/API surface changes; otherwise, consider limiting the PR to naming/comment changes only.

    :param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
    :param row_max: Current maximum values per row of shape [BLOCK_M], init to -inf.
    :param row_sum: Current sum values per row of shape [BLOCK_M], init to 0.
    :param scale_log2: Log2 of the scaling factor to be applied to acc_s.
    :param CHECK_INF: Boolean flag indicating whether -inf row_max should be clamped to 0.
    :param RESCALE_THRESHOLD: Threshold for rescaling to avoid underflow. If <= 0, rescaling is disabled.

    :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N].
    :return row_max_new: Updated maximum values per row of shape [BLOCK_M].
    :return row_sum_new: Updated sum values per row of shape [BLOCK_M].
    :return row_scale: Scaling factors per row of shape [BLOCK_M].
    """
    # Compute current row max
    row_max_curr = tl.max(acc_s, axis=1)

    # Update row max
    row_max_new = tl.maximum(row_max_curr, row_max)

    # Avoid exp(-inf - (-inf)) = nan by clamping -inf to 0
    if CHECK_INF:
        row_max_new = check_inf(row_max_new)

    # Compute scaled differences to new row max
    acc_scale_log2 = (row_max - row_max_new) * scale_log2

    # Compute row scale
    if RESCALE_THRESHOLD > 0.0:
        # Triton can only skip computation at block granularity
        if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD:
            row_scale = tl.exp2(acc_scale_log2)
        else:
            row_max_new = row_max
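            # A ones tensor shaped like acc_scale_log2: multiply by 0.0, add 1.0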
            row_scale = acc_scale_log2 * 0.0 + 1.0
    else:
        row_scale = tl.exp2(acc_scale_log2)

    # Compute attention weights
    p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)

    # Update row sum
    row_sum_cur = tl.sum(p, axis=1)
    row_sum_new = row_sum * row_scale + row_sum_cur

    return p, row_max_new, row_sum_new, row_scale
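For intuition, the recurrence above can be mirrored on the host; the sketch below is illustrative NumPy only (not part of this diff) and omits the CHECK_INF and RESCALE_THRESHOLD handling:

import numpy as np

def online_softmax_reference(score_blocks, scale_log2):
    # Running state, matching the kernel's row_max / row_sum
    block_m = score_blocks[0].shape[0]
    row_max = np.full(block_m, -np.inf)
    row_sum = np.zeros(block_m)
    for acc_s in score_blocks:
        row_max_new = np.maximum(acc_s.max(axis=1), row_max)
        # row_scale corrects the previous partial sum for the new running max
        row_scale = np.exp2((row_max - row_max_new) * scale_log2)
        p = np.exp2((acc_s - row_max_new[:, None]) * scale_log2)
        row_sum = row_sum * row_scale + p.sum(axis=1)
        row_max = row_max_new
    return row_max, row_sum

# Sanity check: the streamed denominator matches the one-shot softmax denominator.
# scores = np.random.randn(4, 64)
# m, s = online_softmax_reference(np.split(scores, 4, axis=1), scale_log2=1.0)
# assert np.allclose(s, np.exp2(scores - m[:, None]).sum(axis=1))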


@triton.jit
def online_sparse_softmax(
    acc_s,
    block_max,
    row_max,
    row_sum,
    scale_log2,
    softmax_threshold_log2,
    CHECK_INF: tl.constexpr,
    RESCALE_THRESHOLD: tl.constexpr,
):
Comment on lines +74 to +77

Copilot AI Apr 22, 2026


SOFTMAX_THRESHOLD_LOG2 is declared as tl.constexpr, but call sites pass a runtime value computed inside the kernel (e.g., via seqlen_info.get_softmax_threshold). Triton requires tl.constexpr arguments to be compile-time constants, so this is likely to fail compilation; make the threshold a regular runtime argument (remove tl.constexpr) or restructure so the threshold is known at specialization time.

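To illustrate the distinction the comment draws (hypothetical snippet, not from this PR): a tl.constexpr parameter must be known when the kernel is specialized, while a runtime value has to be passed as an ordinary argument, which can still be compared elementwise inside the kernel.

import triton
import triton.language as tl

@triton.jit
def _threshold_kernel(X, threshold_log2,             # ordinary argument: may vary at runtime
                      HAS_THRESHOLD: tl.constexpr):  # constexpr: fixed at compile time
    offs = tl.arange(0, 128)
    x = tl.load(X + offs)
    if HAS_THRESHOLD:  # branch resolved during specialization; dead code is pruned
        x = tl.where(x < threshold_log2, 0.0, x)  # runtime comparison is fine
    tl.store(X + offs, x)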
"""
Apply online softmax to acc_s, and update block_max, row_max and row_sum.
Apply online sparse softmax to acc_s, and update block_max, row_max and row_sum.

:param acc_s: Attention scores tensor of shape [BLOCK_M, BLOCK_N].
:param block_max: Running block-wise maximum scalar, init to -inf.
@@ -28,7 +85,6 @@ def online_softmax(
    :param scale_log2: Log2 of the scaling factor to be applied to acc_s.
    :param softmax_threshold_log2: Threshold in the log2 domain for block-level skipping. If > -inf and the block max is below the threshold relative to the running max, the softmax update is skipped.
    :param CHECK_INF: Boolean flag indicating whether -inf row_max should be clamped to 0.
-    :param RESCALE_THRESHOLD: Threshold for rescaling to avoid underflow. If <= 0, rescaling is disabled.

    :return p: Softmax probabilities tensor of shape [BLOCK_M, BLOCK_N].
    :return block_max_new: Updated block-wise maximum scalar.
@@ -69,15 +125,7 @@ def online_softmax(
    acc_scale_log2 = (row_max - row_max_new) * scale_log2

    # Compute row scale
-    if RESCALE_THRESHOLD > 0.0:
-        # Triton can only skip computation at block granularity
-        if tl.min(acc_scale_log2) < -RESCALE_THRESHOLD:
-            row_scale = tl.exp2(acc_scale_log2)
-        else:
-            row_max_new = row_max
-            row_scale = acc_scale_log2 * 0.0 + 1.0
-    else:
-        row_scale = tl.exp2(acc_scale_log2)
+    row_scale = tl.exp2(acc_scale_log2)

    # Compute attention weights
    p = tl.exp2(acc_s * scale_log2 - row_max_new[:, None] * scale_log2)
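For intuition, the block-level skip that softmax_threshold_log2 enables can be sketched on the host as follows (assumed semantics based on the docstring above, not the kernel's actual code):

def should_skip_block(block_max_curr, row_max_running, scale_log2, softmax_threshold_log2):
    # Every probability in this block is bounded by
    # exp2((block_max_curr - row_max_running) * scale_log2); when that bound falls
    # below exp2(softmax_threshold_log2), the block contributes negligibly to the
    # row sums and the softmax update can be skipped.
    return (softmax_threshold_log2 > float("-inf")
            and (block_max_curr - row_max_running) * scale_log2 < softmax_threshold_log2)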
8 changes: 4 additions & 4 deletions flash_sparse_attn/ops/triton/flash_bwd_postprocess.py
@@ -90,7 +90,7 @@ def _bwd_postprocess_kernel(
        order=(1, 0),
    )

-    # Advance dq_accum pointer
+    # Advance dq_accum pointers
    dq_accum_ptrs = tl.advance(dq_accum_ptrs, (m_block * TILE_M, 0))

    # Load accumulators
@@ -99,7 +99,7 @@ def _bwd_postprocess_kernel(

    # Scale dq
    dq = (acc_dq * scale).to(dQ.dtype.element_ty)

-    # Advance dq pointer
+    # Advance dq pointers
    dq_ptrs = tl.advance(dq_ptrs, (m_block * TILE_M, 0))

    # Store dq
@@ -144,7 +144,7 @@ def _bwd_postprocess_kernel(
        order=(0,),
    )

-    # Advance da_accum pointer
+    # Advance da_accum pointers
    da_accum_ptrs = tl.advance(da_accum_ptrs, (m_block * TILE_M,))

    # Load da accumulators
@@ -153,7 +153,7 @@ def _bwd_postprocess_kernel(

    # Scale da
    da = (acc_da * scale).to(dA.dtype.element_ty)

-    # Advance da pointer
+    # Advance da pointers
    da_ptrs = tl.advance(da_ptrs, (m_block * TILE_M,))

    # Store da
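All four comment fixes in this file touch the same block-pointer idiom; for reference, a minimal self-contained version of the pattern looks roughly like this (illustrative kernel, not part of the PR):

import triton
import triton.language as tl

@triton.jit
def _scale_copy_kernel(Src, Dst, seqlen, scale, TILE_M: tl.constexpr):
    m_block = tl.program_id(0)

    # Build block pointers once, then advance them tile by tile -- the
    # tl.make_block_ptr / tl.advance idiom used by the kernels above
    src_ptrs = tl.make_block_ptr(
        base=Src, shape=(seqlen,), strides=(1,),
        offsets=(0,), block_shape=(TILE_M,), order=(0,),
    )
    dst_ptrs = tl.make_block_ptr(
        base=Dst, shape=(seqlen,), strides=(1,),
        offsets=(0,), block_shape=(TILE_M,), order=(0,),
    )

    # Advance src pointers to this program's tile
    src_ptrs = tl.advance(src_ptrs, (m_block * TILE_M,))
    x = tl.load(src_ptrs, boundary_check=(0,)).to(tl.float32)

    # Advance dst pointers and store the scaled tile
    dst_ptrs = tl.advance(dst_ptrs, (m_block * TILE_M,))
    tl.store(dst_ptrs, (x * scale).to(Dst.dtype.element_ty), boundary_check=(0,))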
12 changes: 6 additions & 6 deletions flash_sparse_attn/ops/triton/flash_bwd_preprocess.py
@@ -174,31 +174,31 @@ def _bwd_preprocess_kernel(
    # Initialize accumulators
    acc_dq = tl.zeros((TILE_M, TILE_K), dtype=tl.float32)

-    # Advance output pointer
+    # Advance output pointers
    o_ptrs = tl.advance(o_ptrs, (m_block * TILE_M, 0))

    # Load o tile
    o_tile = tl.load(o_ptrs, boundary_check=(0, 1)).to(tl.float32)

-    # Advance do pointer
+    # Advance do pointers
    do_ptrs = tl.advance(do_ptrs, (m_block * TILE_M, 0))

    # Load do tile
    do_tile = tl.load(do_ptrs, boundary_check=(0, 1)).to(tl.float32)

-    # Advance dpsum pointer
+    # Advance dpsum pointers
    dpsum_ptrs = tl.advance(dpsum_ptrs, (m_block * TILE_M,))

    # Compute dpsum
    dpsum = tl.sum(o_tile * do_tile, axis=1)

-    # Advance acc_dq pointer
+    # Advance acc_dq pointers
    dq_accum_ptrs = tl.advance(dq_accum_ptrs, (m_block * TILE_M, 0))

    # Store dpsum
    tl.store(dpsum_ptrs, dpsum, boundary_check=(0,))

-    # Advance lse pointer
+    # Advance lse pointers
    lse_ptrs = tl.advance(lse_ptrs, (m_block * TILE_M,))

    # Load lse tile
@@ -211,7 +211,7 @@ def _bwd_preprocess_kernel(
    # Store dq_accum
    tl.store(dq_accum_ptrs, acc_dq, boundary_check=(0, 1))

-    # Advance lse_log2 pointer
+    # Advance lse_log2 pointers
    lse_log2_ptrs = tl.advance(lse_log2_ptrs, (m_block * TILE_M,))

    # Store lse_log2
@@ -6,7 +6,7 @@


@triton.jit
-def _fwd_combine_kernel(
+def _dec_combine_kernel(
    Out_partial,
    Lse_partial,
    Out,
@@ -177,7 +177,7 @@ def _fwd_combine_kernel(
    tl.store(lse_ptrs, lse, boundary_check=(0,))


-def _flash_attn_fwd_combine(
+def _flash_attn_dec_combine(
    out_partial: torch.Tensor,
    lse_partial: torch.Tensor,
    out: torch.Tensor,
@@ -197,18 +197,18 @@ def _flash_attn_fwd_combine(
    TILE_K = max(triton.next_power_of_2(head_dim), 16)
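    # e.g. head_dim = 80 -> TILE_K = 128; head_dim = 8 -> TILE_K = 16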

    TILE_M, num_warps, num_stages, num_ctas = (
-        launch_template.get_fwd_combine_launch_config(
+        launch_template.get_dec_combine_launch_config(
            tile_k=TILE_K,
        )
    )

-    grid = launch_grid.get_fwd_combine_grid(
+    grid = launch_grid.get_dec_combine_grid(
        batch_size=batch_size,
        seqlen_q=seqlen_q,
        num_heads_q=num_heads_q,
    )

-    _fwd_combine_kernel[grid](
+    _dec_combine_kernel[grid](
        out_partial,
        lse_partial,
        out,