diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 6a05aac215c6..4db70f4626e2 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,6 +17,7 @@
 import inspect
 import math
 from enum import Enum
+from functools import lru_cache
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

 import torch
@@ -26,6 +27,7 @@
     is_flash_attn_3_available,
     is_flash_attn_available,
     is_flash_attn_version,
+    is_kernels_available,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -131,7 +133,6 @@ def wrap(func):
     _custom_op = custom_op_no_op
     _register_fake = register_fake_no_op

-
 logger = get_logger(__name__)  # pylint: disable=invalid-name

 # TODO(aryan): Add support for the following:
@@ -153,6 +154,8 @@ class AttentionBackendName(str, Enum):
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
+    _FLASH_3_HUB = "_flash_3_hub"
+    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub"  # not supported yet.

     # PyTorch native
     FLEX = "flex"
@@ -207,6 +210,22 @@ def list_backends(cls):
         return list(cls._backends.keys())


+@lru_cache(maxsize=None)
+def _load_fa3_hub():
+    from ..utils.kernels_utils import _get_fa3_from_hub
+
+    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
+    if fa3_hub is None:
+        raise RuntimeError(
+            "Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
+        )
+    return fa3_hub
+
+
+def flash_attn_3_hub_func(*args, **kwargs):
+    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
+
+
 @contextlib.contextmanager
 def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
     """
@@ -351,6 +370,13 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
         )

+    # TODO: add support for the Hub variant of FA3 varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+        if not is_kernels_available():
+            raise RuntimeError(
+                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+            )
+
     elif backend in [
         AttentionBackendName.SAGE,
         AttentionBackendName.SAGE_VARLEN,
@@ -514,6 +540,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
     return torch.empty_like(query), query.new_empty(lse_shape)


+@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3_hub(
+    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    out, lse = flash_attn_3_hub_func(query, key, value)
+    lse = lse.permute(0, 2, 1)
+    return out, lse
+
+
+@_register_fake("vllm_flash_attn3::flash_attn")
+def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    batch_size, seq_len, num_heads, head_dim = query.shape
+    lse_shape = (batch_size, seq_len, num_heads)
+    return torch.empty_like(query), query.new_empty(lse_shape)
+
+
 # ===== Attention backends =====


@@ -657,6 +699,41 @@ def _flash_attention_3(
     return (out, lse) if return_attn_probs else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    window_size: Tuple[int, int] = (-1, -1),
+    softcap: float = 0.0,
+    deterministic: bool = False,
+    return_attn_probs: bool = False,
+) -> torch.Tensor:
+    out, lse, *_ = flash_attn_3_hub_func(
+        q=query,
+        k=key,
+        v=value,
+        softmax_scale=scale,
+        causal=is_causal,
+        qv=None,
+        q_descale=None,
+        k_descale=None,
+        v_descale=None,
+        window_size=window_size,
+        softcap=softcap,
+        num_splits=1,
+        pack_gqa=None,
+        deterministic=deterministic,
+        sm_margin=0,
+    )
+    return (out, lse) if return_attn_probs else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
diff --git a/src/diffusers/utils/kernels_utils.py b/src/diffusers/utils/kernels_utils.py
new file mode 100644
index 000000000000..dddc9ede21e7
--- /dev/null
+++ b/src/diffusers/utils/kernels_utils.py
@@ -0,0 +1,22 @@
+from ..utils import get_logger
+from .import_utils import is_kernels_available
+
+
+logger = get_logger(__name__)
+
+
+_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"
+
+
+def _get_fa3_from_hub():
+    if not is_kernels_available():
+        return None
+    else:
+        from kernels import get_kernel
+
+        try:
+            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
+            return flash_attn_3_hub
+        except Exception as e:
+            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            raise
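
A minimal smoke-test sketch of the new backend (illustrative only, not part of the patch). It assumes a CUDA machine with a GPU supported by FlashAttention-3, `pip install kernels`, and network access to the Hub on the first call (the fetched kernel is cached afterwards); tensors use the `(batch, seq_len, heads, head_dim)` layout expected by the FA3 wrappers in `attention_dispatch.py`.

```python
import torch

from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _check_attention_backend_requirements,
    _flash_attention_3_hub,
    attention_backend,
)

# Fails early with a clear error if the `kernels` package is missing.
_check_attention_backend_requirements(AttentionBackendName._FLASH_3_HUB)

batch, seq_len, heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Call the newly registered backend directly; output keeps the query's shape.
out = _flash_attention_3_hub(q, k, v, is_causal=False)
print(out.shape)  # torch.Size([2, 128, 8, 64])

# Or select the backend for a model forward pass via the existing context manager.
with attention_backend(AttentionBackendName._FLASH_3_HUB):
    ...  # run a diffusers model/pipeline here
```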