Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flash_sparse_attn/ops/cute/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ out = flash_attn_func(q, k, v, causal=True)
```sh
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install -e "flash_attn/cute[dev]"
pip install -e "flash_attn/cute[dev]" # CUDA 12.x
pip install -e "flash_attn/cute[dev,cu13]" # CUDA 13.x (e.g. B200)
pytest tests/cute/
```
14 changes: 7 additions & 7 deletions flash_sparse_attn/ops/cute/block_sparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,18 +1348,18 @@ def _store_one_dQaccum_sm90(
m_block,
sdQaccum: cute.Tensor,
gdQaccum: cute.Tensor,
num_mma_warp_groups: cutlass.Constexpr,
num_dQ_warp_groups: cutlass.Constexpr,
num_threads_per_warp_group: cutlass.Constexpr,
tma_copy_bytes_dQ,
):
"""Store dQaccum for a single m_block."""
for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
cute.arch.cp_async_bulk_wait_group(num_mma_warp_groups - 1 - warp_group_idx, read=True)
for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups):
cute.arch.cp_async_bulk_wait_group(num_dQ_warp_groups - 1 - warp_group_idx, read=True)
cute.arch.barrier_arrive(
barrier_id=int(NamedBarrierBwd.dQEmptyWG0) + warp_group_idx,
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
)
for warp_group_idx in cutlass.range_constexpr(num_mma_warp_groups):
for warp_group_idx in cutlass.range_constexpr(num_dQ_warp_groups):
cute.arch.barrier(
barrier_id=int(NamedBarrierBwd.dQFullWG0) + warp_group_idx,
number_of_threads=num_threads_per_warp_group + cute.arch.WARP_SIZE,
Expand All @@ -1383,7 +1383,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
gdQaccum: cute.Tensor,
subtile_factor: cutlass.Constexpr,
m_block_max: int,
num_mma_warp_groups: cutlass.Constexpr,
num_dQ_warp_groups: cutlass.Constexpr,
num_threads_per_warp_group: cutlass.Constexpr,
tma_copy_bytes_dQ,
):
Expand Down Expand Up @@ -1412,7 +1412,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
m_block,
sdQaccum,
gdQaccum,
num_mma_warp_groups,
num_dQ_warp_groups,
num_threads_per_warp_group,
tma_copy_bytes_dQ,
)
Expand All @@ -1428,7 +1428,7 @@ def dQaccum_store_block_sparse_bwd_sm90(
m_block,
sdQaccum,
gdQaccum,
num_mma_warp_groups,
num_dQ_warp_groups,
num_threads_per_warp_group,
tma_copy_bytes_dQ,
)
43 changes: 33 additions & 10 deletions flash_sparse_attn/ops/cute/block_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,23 @@ class BlockSparseTensorsTorch(NamedTuple):
block_size: tuple[int, int] | None = None


def get_sparse_q_block_size(
    tensors: BlockSparseTensorsTorch | None,
    seqlen_q: int,
) -> int | None:
    """Return the Q-dimension sparse block size.

    Prefers the explicitly supplied ``tensors.block_size``; otherwise tries to
    infer the size from the m-block count of the index tensor. Returns None
    when no sparsity tensors are given or when ``seqlen_q`` and the m-block
    count do not pin down a single block size.
    """
    if tensors is None:
        return None
    explicit = tensors.block_size
    if explicit is not None:
        return explicit[0]
    n_m_blocks = tensors.mask_block_idx.shape[2]
    # Smallest block size that still covers seqlen_q with n_m_blocks tiles.
    lo = ceildiv(seqlen_q, n_m_blocks)
    # Largest block size for which n_m_blocks tiles are actually needed.
    if n_m_blocks == 1:
        hi = seqlen_q
    else:
        hi = (seqlen_q - 1) // (n_m_blocks - 1)
    # Unambiguous only when both bounds coincide.
    return lo if lo == hi else None


def _expand_sparsity_tensor(
tensor: torch.Tensor,
expected_shape: Tuple[int, ...],
Expand Down Expand Up @@ -81,6 +98,12 @@ def _check_and_expand_block(
expanded_cnt = _expand_sparsity_tensor(
cnt, expected_count_shape, f"{name}_block_cnt", context, hint
)
# [Note] Allow Compact block sparse indices
# Allow the last dimension (n_blocks) of idx to be <= expected, since
# FA4 only accesses indices 0..cnt-1 per query tile. This enables compact
# index tensors that avoid O(N^2) memory at long sequence lengths.
if idx.ndim == 4 and idx.shape[3] <= expected_index_shape[3]:
expected_index_shape = (*expected_index_shape[:3], idx.shape[3])
expanded_idx = _expand_sparsity_tensor(
idx, expected_index_shape, f"{name}_block_idx", context, hint
)
Expand Down Expand Up @@ -140,17 +163,14 @@ def infer_block_sparse_expected_shapes(
num_m_blocks = tensors.mask_block_idx.shape[2]

if sparse_block_size_q is None:
min_block_size = ceildiv(seqlen_q, num_m_blocks)
if num_m_blocks == 1:
max_block_size = seqlen_q
else:
max_block_size = (seqlen_q - 1) // (num_m_blocks - 1)
if max_block_size != min_block_size and base_m_block != 1:
sparse_block_size_q = get_sparse_q_block_size(tensors, seqlen_q)
if sparse_block_size_q is None and base_m_block != 1:
raise ValueError(
f"Block sparse tensors{context} require explicit sparse_block_size[0] "
f"to disambiguate block size for seqlen_q={seqlen_q} and num_m_blocks={num_m_blocks}."
)
sparse_block_size_q = min_block_size
if sparse_block_size_q is None:
sparse_block_size_q = ceildiv(seqlen_q, num_m_blocks)

if sparse_block_size_q % base_m_block != 0:
raise ValueError(
Expand Down Expand Up @@ -186,9 +206,11 @@ def infer_block_sparse_expected_shapes(
raise ValueError(f"Block sparse tensors{context} {dim_name} dim must be {tgt} or 1.")
if mask_block_cnt.shape[2] != mask_block_idx.shape[2]:
raise ValueError(f"Block sparse tensors{context} must share the same m-block dimension.")
if mask_block_idx.shape[3] != expected_n_blocks:
# [Note] Allow Compact block sparse indices: FA4 only accesses indices 0..cnt-1
# per query tile, so idx.shape[3] can be <= expected_n_blocks.
if mask_block_idx.shape[3] > expected_n_blocks:
raise ValueError(
f"Block sparse tensors{context} n-block dimension must be {expected_n_blocks}."
f"Block sparse tensors{context} n-block dimension must be <= {expected_n_blocks}."
)
if expected_m_blocks != num_m_blocks:
raise ValueError(
Expand Down Expand Up @@ -314,7 +336,7 @@ def normalize_block_sparse_config(
) -> tuple[BlockSparseTensorsTorch, Tuple[Tuple[bool, ...], ...] | None, int]:
m_block_size, n_block_size = block_size
if tensors.block_size is None:
sparse_block_size_q, sparse_block_size_kv = q_stage * m_block_size, n_block_size
sparse_block_size_q, sparse_block_size_kv = None, n_block_size
else:
sparse_block_size_q, sparse_block_size_kv = tensors.block_size
if sparse_block_size_kv != n_block_size:
Expand Down Expand Up @@ -401,6 +423,7 @@ def to_cute_block_sparse_tensors(
"""Convert torch block sparsity tensors to CuTe tensors, optionally for tvm ffi"""
if not is_block_sparsity_enabled(tensors):
return None

(
mask_block_cnt,
mask_block_idx,
Expand Down
24 changes: 9 additions & 15 deletions flash_sparse_attn/ops/cute/cache_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Manage Ahead-of-Time (AOT) compiled kernels
import fcntl
import hashlib
import logging
import os
import pickle
import sys
Expand All @@ -18,6 +17,7 @@
import cutlass.cute as cute
import tvm_ffi
from cutlass.cutlass_dsl import JitCompiledFunction
from flash_sparse_attn.ops.cute.fa_logging import fa_log

# Pre-load cute DSL runtime libraries with RTLD_GLOBAL so that their symbols
# (e.g. _cudaLibraryLoadData) are visible to .so modules loaded later via dlopen.
Expand All @@ -30,12 +30,6 @@
CompileKeyType: TypeAlias = tuple[Hashable, ...]
CallableFunction: TypeAlias = JitCompiledFunction | tvm_ffi.Function

logger = logging.getLogger(__name__)
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(asctime)s.%(msecs)03d %(levelname)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
logger.addHandler(_handler)
logger.setLevel(logging.DEBUG)


# Enable cache via `FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED=1`
CUTE_DSL_CACHE_ENABLED: bool = os.getenv("FLASH_ATTENTION_CUTE_DSL_CACHE_ENABLED", "0") == "1"
Expand Down Expand Up @@ -222,13 +216,13 @@ def _try_load_from_storage(self, key: CompileKeyType) -> bool:
label=sha256_hex,
):
if obj_path.exists():
logger.debug("Loading compiled function from disk: %s", obj_path)
fa_log(1, f"Loading compiled function from disk: {obj_path}")
m = cute.runtime.load_module(str(obj_path), enable_tvm_ffi=True)
fn = getattr(m, self.EXPORT_FUNCTION_PREFIX)
JITCache.__setitem__(self, key, fn)
return True
else:
logger.debug("Cache miss on disk for key hash %s", sha256_hex)
fa_log(1, f"Cache miss on disk for key hash {sha256_hex}")
return False

def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -> None:
Expand All @@ -243,14 +237,14 @@ def _try_export_to_storage(self, key: CompileKeyType, fn: JitCompiledFunction) -
obj_path = self.cache_path / f"{sha256_hex}.o"
if obj_path.exists():
# Another process already exported.
logger.debug("Skipping export, already on disk: %s", obj_path)
fa_log(1, f"Skipping export, already on disk: {obj_path}")
return
logger.debug("Exporting compiled function to disk: %s", obj_path)
fa_log(1, f"Exporting compiled function to disk: {obj_path}")
fn.export_to_c(
object_file_path=str(obj_path),
function_name=self.EXPORT_FUNCTION_PREFIX,
)
logger.debug("Successfully exported compiled function to disk: %s", obj_path)
fa_log(1, f"Successfully exported compiled function to disk: {obj_path}")

def _key_to_hash(self, key: CompileKeyType) -> str:
return hashlib.sha256(pickle.dumps(key)).hexdigest()
Expand All @@ -262,7 +256,7 @@ def clear(self) -> None:
"""
Not only clear the in-memory cache. Also purge persistent compilation cache.
"""
logger.debug("Clearing persistent cache at %s", self.cache_path)
fa_log(1, f"Clearing persistent cache at {self.cache_path}")
super().clear()
for child in self.cache_path.iterdir():
child.unlink()
Expand All @@ -281,8 +275,8 @@ def get_jit_cache(name: str | None = None) -> JITCache:
path = get_cache_path() / _compute_source_fingerprint()
if name:
path = path / name
logger.debug("Creating persistent JIT cache at %s", path)
fa_log(1, f"Creating persistent JIT cache at {path}")
return JITPersistentCache(path)
else:
logger.debug("Persistent cache disabled, using in-memory JIT cache")
fa_log(1, "Persistent cache disabled, using in-memory JIT cache")
return JITCache()
2 changes: 1 addition & 1 deletion flash_sparse_attn/ops/cute/flash_bwd_sm90.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,7 +1865,7 @@ def dQaccum_store(
gdQaccum,
subtile_factor=self.subtile_factor,
m_block_max=m_block_max,
num_mma_warp_groups=self.num_wg_mma,
num_dQ_warp_groups=self.num_wg_dQ,
num_threads_per_warp_group=self.num_threads_per_warp_group,
tma_copy_bytes_dQ=self.tma_copy_bytes["dQ"],
)
Expand Down
Loading
Loading