Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import inspect
import math
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import torch
Expand All @@ -26,6 +27,7 @@
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
Expand Down Expand Up @@ -131,7 +133,6 @@ def wrap(func):
_custom_op = custom_op_no_op
_register_fake = register_fake_no_op


logger = get_logger(__name__) # pylint: disable=invalid-name

# TODO(aryan): Add support for the following:
Expand All @@ -153,6 +154,8 @@ class AttentionBackendName(str, Enum):
FLASH_VARLEN = "flash_varlen"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.

# PyTorch native
FLEX = "flex"
Expand Down Expand Up @@ -207,6 +210,22 @@ def list_backends(cls):
return list(cls._backends.keys())


@lru_cache(maxsize=None)
def _load_fa3_hub():
from ..utils.kernels_utils import _get_fa3_from_hub

fa3_hub = _get_fa3_from_hub() # won't re-download if already present
if fa3_hub is None:
raise RuntimeError(
"Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
)
return fa3_hub


def flash_attn_3_hub_func(*args, **kwargs):
return _load_fa3_hub().flash_attn_func(*args, **kwargs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than loading the kernel every time we invoke flash attention, it would be better to import the function at the top of the file, similar to the other FA backends.

if _CAN_USE_FLASH_ATTN_3:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure but see
#12236 (comment)

We shouldn’t make remote calls to the Hub unless requested.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can enable user permission via env variable similar to GGUF kernels

os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

God that simplifies a lot of stuff. Thanks, Dhruv!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually wondering if we just add a DIFFUSERS_ENABLE_HUB_KERNELS constant that is used for all kernel cases.



@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
Expand Down Expand Up @@ -351,6 +370,13 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)

# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)

elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
Expand Down Expand Up @@ -514,6 +540,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
return torch.empty_like(query), query.new_empty(lse_shape)


@_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3_hub(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = flash_attn_3_hub_func(query, key, value)
lse = lse.permute(0, 2, 1)
return out, lse


@_register_fake("vllm_flash_attn3::flash_attn")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, num_heads, head_dim = query.shape
lse_shape = (batch_size, seq_len, num_heads)
return torch.empty_like(query), query.new_empty(lse_shape)


# ===== Attention backends =====


Expand Down Expand Up @@ -657,6 +699,41 @@ def _flash_attention_3(
return (out, lse) if return_attn_probs else out


@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
) -> torch.Tensor:
out, lse, *_ = flash_attn_3_hub_func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
return (out, lse) if return_attn_probs else out


@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
Expand Down
22 changes: 22 additions & 0 deletions src/diffusers/utils/kernels_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from ..utils import get_logger
from .import_utils import is_kernels_available


logger = get_logger(__name__)


_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"


def _get_fa3_from_hub():
if not is_kernels_available():
return None
else:
from kernels import get_kernel

try:
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
return flash_attn_3_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
raise
Loading