Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions flash_sparse_attn/ops/cute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,11 @@
except PackageNotFoundError:
__version__ = "0.0.0"

import cutlass.cute as cute

from .interface import (
flash_attn_func,
flash_attn_varlen_func,
)

from flash_sparse_attn.ops.cute.cute_dsl_utils import cute_compile_patched

# Patch cute.compile to optionally dump SASS
cute.compile = cute_compile_patched


__all__ = [
"flash_attn_func",
"flash_attn_varlen_func",
Expand Down
42 changes: 40 additions & 2 deletions flash_sparse_attn/ops/cute/bench_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@


def flops(
batch, nheads, seqlen_q, seqlen_k, headdim, headdim_v, causal=False, window_size=(None, None)
batch,
nheads,
seqlen_q,
seqlen_k,
headdim,
headdim_v,
causal=False,
window_size=(None, None),
has_qv=False,
):
if causal:
avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
Expand All @@ -35,7 +43,37 @@ def flops(
else torch.full_like(row_idx, seqlen_k - 1)
)
avg_seqlen = (col_right - col_left + 1).float().mean().item()
return batch * nheads * 2 * seqlen_q * avg_seqlen * (headdim + headdim_v)
eff_headdim = headdim + headdim_v if has_qv else headdim
return batch * nheads * 2 * seqlen_q * avg_seqlen * (eff_headdim + headdim_v)


# ── Bandwidth calculation ────────────────────────────────────────────────────


def bandwidth_fwd_bytes(
batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2, has_qv=False
):
"""HBM traffic for one attention pass: read Q,K,V + write O."""
q = batch * nheads * seqlen_q * headdim
qv = batch * nheads * seqlen_q * headdim_v if has_qv else 0
k = batch * nheads_kv * seqlen_k * headdim
v = batch * nheads_kv * seqlen_k * headdim_v
o = batch * nheads * seqlen_q * headdim_v
return (q + qv + k + v + o) * dtype_bytes


def bandwidth_bwd_bytes(
batch, nheads, nheads_kv, seqlen_q, seqlen_k, headdim, headdim_v, dtype_bytes=2
):
"""HBM traffic for one attention pass: read Q,K,V,dO + write dQ,dK,dV."""
q = batch * nheads * seqlen_q * headdim
k = batch * nheads_kv * seqlen_k * headdim
v = batch * nheads_kv * seqlen_k * headdim_v
do = batch * nheads * seqlen_q * headdim_v
dq = q
dk = k
dv = v
return (q + k + v + do + dq + dk + dv) * dtype_bytes


# ── Reference attention ─────────────────────────────────────────────────────
Expand Down
Loading
Loading