10 changes: 7 additions & 3 deletions flash_sparse_attn/ops/triton/activations.py
@@ -95,27 +95,31 @@ def finalize(
     row_sum,
     scale_log2,
     final_scale,
+    IS_LOG2: tl.constexpr,
 ):
     """
     Finalize online softmax by computing output scale and logsumexp.
 
     :param row_max: Final maximum values per row of shape [BLOCK_M].
     :param row_sum: Final sum values per row of shape [BLOCK_M].
     :param final_scale: Scaling factor to be applied to the output.
+    :param IS_LOG2: Boolean flag indicating if the returned logsumexp should be in log2-space.
 
     :return row_scale: Final scaling factors per row of shape [BLOCK_M].
     :return lse: Logsumexp values per row of shape [BLOCK_M].
     """
     # if row_sum is zero or nan, set it to 1 to avoid division by zero
     acc_o_is_zero_or_nan = (row_sum == 0.0) | (row_sum != row_sum)
     row_scale = tl.where(acc_o_is_zero_or_nan, 1.0, 1.0 / row_sum) * final_scale
-    # ln2 = math.log(2.0)
-    ln2 = 0.6931471805599453
     lse = tl.where(
         acc_o_is_zero_or_nan,
         float("-inf"),
-        (row_max * scale_log2 + tl.log2(row_sum)) * ln2,
+        row_max * scale_log2 + tl.log2(row_sum),
     )
+    if not IS_LOG2:
+        # ln2 = math.log(2.0)
+        ln2 = 0.6931471805599453
+        lse *= ln2
     return row_scale, lse
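
For context: the forward kernels accumulate softmax statistics in log2-space (tl.exp2/tl.log2), and IS_LOG2 lets the caller defer the natural-log conversion. A minimal NumPy sketch of the two return conventions; finalize_ref is a hypothetical reference helper, not code from this PR:

import numpy as np

def finalize_ref(row_max, row_sum, scale_log2, final_scale=1.0, is_log2=False):
    # Mirrors finalize(): row_max/row_sum come from an online softmax
    # accumulated in log2-space.
    bad = (row_sum == 0.0) | np.isnan(row_sum)
    row_scale = np.where(bad, 1.0, 1.0 / row_sum) * final_scale
    lse = np.where(bad, -np.inf, row_max * scale_log2 + np.log2(row_sum))
    if not is_log2:
        lse = lse * np.log(2.0)  # convert log2-space LSE to natural log
    return row_scale, lse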


1 change: 1 addition & 0 deletions flash_sparse_attn/ops/triton/flash_dense_fwd.py
@@ -579,6 +579,7 @@ def _fwd_dense_base_kernel(
         row_sum=row_sum,
         scale_log2=softmax_scale_log2,
         final_scale=1.0,
+        IS_LOG2=IS_SPLIT_KV,
     )
     acc_o = activations.rescale_o(acc_o, row_scale, LAZY_RESCALE=False)

140 changes: 56 additions & 84 deletions flash_sparse_attn/ops/triton/flash_fwd_combine.py
@@ -1,4 +1,3 @@
-import math
 import torch
 import triton
 import triton.language as tl
@@ -34,14 +33,11 @@ def _fwd_combine_kernel(
     TILE_K: tl.constexpr,
     HAS_CU_SEQLENS_Q: tl.constexpr,
     HAS_SEQUSED_Q: tl.constexpr,
-    MAX_SPLITS: tl.constexpr,
 ):
     m_block = tl.program_id(0)
-    k_block = tl.program_id(1)
-    bh_idx = tl.program_id(2)
+    bh_idx = tl.program_id(1)
     batch_idx = bh_idx // num_heads_q
     head_idx = bh_idx - batch_idx * num_heads_q
-    offs_m = m_block * TILE_M + tl.arange(0, TILE_M)
 
     # Get seqlen info for this batch
     offset_q, actual_seqlen_q = seqlen_info.get_seqlen_info(
@@ -52,7 +48,6 @@
         HAS_CU_SEQLENS=HAS_CU_SEQLENS_Q,
         HAS_SEQUSED=HAS_SEQUSED_Q,
     )
-    mask_m = offs_m < actual_seqlen_q
 
     # Initialize base pointers
     out_part_base = seqlen_info.offset_batch_Q(
@@ -85,13 +80,23 @@
         HAS_CU_SEQLENS_Q,
         USE_PADDED=False,
     )
+    lse_base = seqlen_info.offset_batch_Q(
+        Lse + head_idx * stride_lh,
+        batch_idx,
+        offset_q,
+        0,
+        stride_lb,
+        1,
+        HAS_CU_SEQLENS_Q,
+        USE_PADDED=False,
+    )
 
     # Create pointers
     out_part_ptrs = tl.make_block_ptr(
         base=out_part_base,
         shape=(num_splits, actual_seqlen_q, head_dim),
         strides=(stride_ops, stride_opm, 1),
-        offsets=(0, m_block * TILE_M, k_block * TILE_K),
+        offsets=(0, m_block * TILE_M, 0),
         block_shape=(1, TILE_M, TILE_K),
         order=(2, 1, 0),
     )
@@ -107,99 +112,69 @@
         base=out_base,
         shape=(actual_seqlen_q, head_dim),
         strides=(stride_om, 1),
-        offsets=(m_block * TILE_M, k_block * TILE_K),
+        offsets=(m_block * TILE_M, 0),
         block_shape=(TILE_M, TILE_K),
         order=(1, 0),
     )
+    lse_ptrs = tl.make_block_ptr(
+        base=lse_base,
+        shape=(actual_seqlen_q,),
+        strides=(1,),
+        offsets=(m_block * TILE_M,),
+        block_shape=(TILE_M,),
+        order=(0,),
+    )
 
     # Initialize accumulators
-    lse_vals = tl.full((MAX_SPLITS, TILE_M), float("-inf"), dtype=tl.float32)
-    max_lse = tl.full((TILE_M,), float("-inf"), dtype=tl.float32)
+    e_sum = tl.zeros((TILE_M,), dtype=tl.float32)
+    e_max = tl.full((TILE_M,), float("-inf"), dtype=tl.float32)
     acc_o = tl.zeros((TILE_M, TILE_K), dtype=tl.float32)
 
-    # Compute max across splits
-    for s in tl.static_range(MAX_SPLITS):
-        if s < num_splits:
-            lse_s = tl.sum(tl.load(lse_part_ptrs, boundary_check=(0, 1)), axis=0)
-            # boundary_check pads OOB with 0, fixup to -inf for correct softmax
-            lse_s = tl.where(mask_m, lse_s, float("-inf"))
-            lse_part_ptrs = tl.advance(lse_part_ptrs, (1, 0))
-        else:
-            lse_s = tl.full((TILE_M,), float("-inf"), dtype=tl.float32)
-        lse_vals = tl.where(
-            (tl.arange(0, MAX_SPLITS) == s)[:, None],
-            lse_s[None, :],
-            lse_vals,
-        )
-        max_lse = tl.maximum(max_lse, lse_s)
-
-    # if all -inf, use 0 to avoid nan in exp
-    max_lse = tl.where(max_lse == float("-inf"), 0.0, max_lse)
-
-    # Compute normalized scales
-    sum_exp = tl.zeros((TILE_M,), dtype=tl.float32)
-    for s in tl.static_range(MAX_SPLITS):
-        lse_s = tl.sum(
-            tl.where(
-                (tl.arange(0, MAX_SPLITS) == s)[:, None],
-                lse_vals,
-                tl.full((MAX_SPLITS, TILE_M), 0.0, dtype=tl.float32),
-            ),
-            axis=0,
-        )
-        exp_s = tl.where(s < num_splits, tl.exp(lse_s - max_lse), 0.0)
-        sum_exp += exp_s
-
-    inv_sum = tl.where((sum_exp == 0.0) | (sum_exp != sum_exp), 0.0, 1.0 / sum_exp)
-
-    for s in tl.static_range(MAX_SPLITS):
-        if s < num_splits:
-            lse_s = tl.sum(
-                tl.where(
-                    (tl.arange(0, MAX_SPLITS) == s)[:, None],
-                    lse_vals,
-                    tl.full((MAX_SPLITS, TILE_M), 0.0, dtype=tl.float32),
-                ),
-                axis=0,
-            )
-            scale = tl.exp(lse_s - max_lse) * inv_sum
-            o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0)
-            out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0))
-            acc_o += scale[:, None] * o_s
+    # Combine split outputs
+    for _ in tl.range(0, num_splits):
+        # Load partial LSE
+        lse_s = tl.sum(tl.load(lse_part_ptrs, boundary_check=(0, 1)), axis=0)
+
+        # Advance LSE pointers
+        lse_part_ptrs = tl.advance(lse_part_ptrs, (1, 0))
+
+        # Compute normalized exponentials
+        new_e_max = tl.maximum(lse_s, e_max)
+        old_scale = tl.exp2(e_max - new_e_max)
+        exp_logic = tl.exp2(lse_s - new_e_max)
+
+        # Load partial outputs
+        o_s = tl.sum(tl.load(out_part_ptrs, boundary_check=(0, 1, 2)), axis=0)
+
+        # Advance output pointers
+        out_part_ptrs = tl.advance(out_part_ptrs, (1, 0, 0))
+
+        # Compute scaled outputs
+        acc_o *= old_scale[:, None]
+        acc_o += exp_logic[:, None] * o_s
+
+        # Update e_sum and e_max
+        e_sum = e_sum * old_scale + exp_logic
+        e_max = new_e_max
Comment on lines +141 to +158 (review comment, Copilot AI, Apr 14, 2026):

The new online log-sum-exp update can produce NaNs when a split has no valid blocks and writes LSE=-inf. In that case e_max starts at -inf and new_e_max can also be -inf, so tl.exp(e_max - new_e_max) and tl.exp(lse_s - new_e_max) evaluate exp(-inf - -inf) -> NaN, which then contaminates e_sum/acc_o (and inv_sum=0 won't reliably clear NaNs). Please add an explicit guard for the new_e_max == -inf (or e_max == -inf & lse_s == -inf) case to keep old_scale finite (typically 1.0) and exp_logic at 0.0, matching the prior behavior that avoided NaNs when all contributions were -inf.
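
A guard along the lines the reviewer suggests could look like the sketch below (untested; variable names follow the diff, and whether or how the authors addressed this is not shown here):

        # Hypothetical fix: when both the running max and the incoming LSE are
        # -inf, exp2(-inf - -inf) is NaN; pin old_scale to 1.0 and exp_logic to
        # 0.0 so empty splits contribute nothing, matching the prior behavior.
        new_e_max = tl.maximum(lse_s, e_max)
        all_neg_inf = new_e_max == float("-inf")
        old_scale = tl.where(all_neg_inf, 1.0, tl.exp2(e_max - new_e_max))
        exp_logic = tl.where(all_neg_inf, 0.0, tl.exp2(lse_s - new_e_max))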

+    # Normalize output
+    inv_sum = tl.where((e_sum == 0.0) | (e_sum != e_sum), 0.0, 1.0 / e_sum)
+    acc_o *= inv_sum[:, None]
 
     # Store output
     tl.store(
         out_ptrs,
         acc_o.to(Out.dtype.element_ty),
         boundary_check=(0, 1),
     )
 
-    # Store LSE
-    # Only from the first k_block to avoid duplicates
-    if k_block == 0:
-        lse_base = seqlen_info.offset_batch_Q(
-            Lse + head_idx * stride_lh,
-            batch_idx,
-            offset_q,
-            0,
-            stride_lb,
-            1,
-            HAS_CU_SEQLENS_Q,
-            USE_PADDED=False,
-        )
-        lse_ptrs = tl.make_block_ptr(
-            base=lse_base,
-            shape=(actual_seqlen_q,),
-            strides=(1,),
-            offsets=(m_block * TILE_M,),
-            block_shape=(TILE_M,),
-            order=(0,),
-        )
-        lse = tl.where(
-            sum_exp > 0.0,
-            max_lse + tl.log(sum_exp),
-            float("-inf"),
-        )
-        tl.store(lse_ptrs, lse, boundary_check=(0,))
+    # Compute LSE
+    # ln2 = math.log(2.0)
+    ln2 = 0.6931471805599453
+    lse = tl.where(e_sum > 0.0, (e_max + tl.log2(e_sum)) * ln2, float("-inf"))
+
+    # Store LSE
+    tl.store(lse_ptrs, lse, boundary_check=(0,))
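
For intuition, the new loop is a streaming logsumexp merge over splits. Below is a NumPy cross-check of the combine step, assuming partial LSEs arrive in log2-space as finalize(..., IS_LOG2=True) now produces; combine_ref is a hypothetical reference, not part of the PR, and it shares the all-(-inf) edge case flagged in the review above:

import numpy as np

def combine_ref(o_parts, lse_parts):
    # o_parts: [num_splits, M, K] partial outputs
    # lse_parts: [num_splits, M] partial logsumexps in log2-space
    num_splits, M, K = o_parts.shape
    e_max = np.full(M, -np.inf)
    e_sum = np.zeros(M)
    acc_o = np.zeros((M, K))
    for o_s, lse_s in zip(o_parts, lse_parts):
        new_e_max = np.maximum(lse_s, e_max)
        old_scale = np.exp2(e_max - new_e_max)  # rescale the old accumulator
        exp_logic = np.exp2(lse_s - new_e_max)  # weight of this split
        acc_o = acc_o * old_scale[:, None] + exp_logic[:, None] * o_s
        e_sum = e_sum * old_scale + exp_logic
        e_max = new_e_max
    inv_sum = np.where(e_sum == 0.0, 0.0, 1.0 / e_sum)
    out = acc_o * inv_sum[:, None]
    # final LSE converted back to natural log, as in the kernel
    lse = np.where(e_sum > 0.0, (e_max + np.log2(e_sum)) * np.log(2.0), -np.inf)
    return out, lse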


@@ -218,7 +193,6 @@ def _flash_attn_fwd_combine(
     total_q, num_heads_q, head_dim = out_partial.shape[1:]
     batch_size = cu_seqlens_q.shape[0] - 1
     seqlen_q = total_q
-    MAX_SPLITS = 1 << max(int(math.ceil(math.log2(max(num_splits, 1)))), 1)
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
@@ -232,7 +206,6 @@
         batch_size=batch_size,
         seqlen_q=seqlen_q,
         num_heads_q=num_heads_q,
-        head_dim=head_dim,
     )

@@ -262,7 +235,6 @@ _fwd_combine_kernel[grid](
         TILE_K=TILE_K,
         HAS_CU_SEQLENS_Q=cu_seqlens_q is not None,
         HAS_SEQUSED_Q=seqused_q is not None,
-        MAX_SPLITS=MAX_SPLITS,
         num_warps=num_warps,
         num_stages=num_stages,
         num_ctas=num_ctas,
1 change: 1 addition & 0 deletions flash_sparse_attn/ops/triton/flash_gated_fwd.py
@@ -911,6 +911,7 @@ def _fwd_gated_base_kernel(
         row_sum=row_sum,
         scale_log2=softmax_scale_log2,
         final_scale=1.0,
+        IS_LOG2=IS_SPLIT_KV,
     )
     acc_o = activations.rescale_o(acc_o, row_scale, LAZY_RESCALE=False)

1 change: 1 addition & 0 deletions flash_sparse_attn/ops/triton/flash_sparse_fwd.py
@@ -599,6 +599,7 @@ def _fwd_base_sparse_kernel(
         row_sum=row_sum,
         scale_log2=softmax_scale_log2,
         final_scale=1.0,
+        IS_LOG2=IS_SPLIT_KV,
     )
     acc_o = activations.rescale_o(acc_o, row_scale, LAZY_RESCALE=False)

3 changes: 0 additions & 3 deletions flash_sparse_attn/ops/triton/launch_grid.py
@@ -64,23 +64,20 @@ def get_fwd_combine_grid(
     batch_size: int,
     seqlen_q: int,
     num_heads_q: int,
-    head_dim: int,
 ):
     """
     Get the grid function for the forward combine kernel.
 
     :param batch_size: Batch size
     :param seqlen_q: Sequence length of queries
     :param num_heads_q: Number of query heads
-    :param head_dim: Head dimension
 
     :return grid: Grid function
     """
 
     def grid(META):
         return (
             triton.cdiv(seqlen_q, META["TILE_M"]),
-            triton.cdiv(head_dim, META["TILE_K"]),
             batch_size * num_heads_q,
         )
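
Since each combine program now covers the full head_dim, the launch grid drops to 2-D (m-blocks x batch*heads). A quick sanity check of the returned closure; TILE_M=128 is an assumed tile size, normally supplied by the kernel launch metadata:

grid = get_fwd_combine_grid(batch_size=2, seqlen_q=1024, num_heads_q=8)
print(grid({"TILE_M": 128}))  # (8, 16): ceil(1024/128) m-blocks, 2*8 batch-head programs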
