2 changes: 1 addition & 1 deletion .github/workflows/self-scheduled-flash-attn-caller.yml
@@ -56,5 +56,5 @@ jobs:
runner_type: "a10"
report_repo_id: hf-internal-testing/transformers_flash_attn_ci
commit_sha: ${{ github.sha }}
pytest_marker: "flash_attn_test or flash_attn_3_test"
pytest_marker: "flash_attn_test or flash_attn_3_test or flash_attn_4_test"
secrets: inherit
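With the extra marker in `pytest_marker`, the scheduled A10 job now collects flash attention 4 tests alongside the FA2/FA3 ones. Assuming a local checkout with the same markers registered, the selection can be reproduced with `pytest -m "flash_attn_test or flash_attn_3_test or flash_attn_4_test" tests/`, and narrowed to FA4 only with `pytest -m flash_attn_4_test`.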
1 change: 1 addition & 0 deletions conftest.py
@@ -89,6 +89,7 @@ def pytest_configure(config):
config.addinivalue_line("markers", "torch_export_test: mark test which tests torch export functionality")
config.addinivalue_line("markers", "flash_attn_test: mark test which tests flash attention functionality")
config.addinivalue_line("markers", "flash_attn_3_test: mark test which tests flash attention 3 functionality")
config.addinivalue_line("markers", "flash_attn_4_test: mark test which tests flash attention 4 functionality")

os.environ["DISABLE_SAFETENSORS_CONVERSION"] = "true"

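Registering the marker in `pytest_configure` is what lets individual tests be tagged and filtered. A minimal, hypothetical sketch of a tagged test (the test name and body are placeholders, not part of this PR):

import pytest

@pytest.mark.flash_attn_4_test
def test_fa4_matches_eager_outputs():
    # placeholder body; real tests additionally gate on hardware and FA4 availability
    ...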
1 change: 1 addition & 0 deletions pyproject.toml
@@ -63,6 +63,7 @@ line-ending = "auto"
addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_4_test: marks tests related to flash attention 4 (deselect with '-m \"not flash_attn_4_test\"')",
"flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
1 change: 1 addition & 0 deletions src/transformers/masking_utils.py
@@ -636,6 +636,7 @@ class AttentionMaskInterface(GeneralInterface):
"eager": eager_mask,
"flash_attention_2": flash_attention_mask,
"flash_attention_3": flash_attention_mask,
"flash_attention_4": flash_attention_mask,
"flex_attention": flex_attention_mask,
}

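Because `flash_attention_4` maps to the same `flash_attention_mask` builder used for FA2/FA3, code that resolves the mask function by implementation name picks it up with no further changes. Roughly, assuming the module-level `ALL_MASK_ATTENTION_FUNCTIONS` instance of this interface (not shown in this diff):

from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS

mask_fn = ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_4"]  # same callable as for flash_attention_2/3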
69 changes: 47 additions & 22 deletions src/transformers/modeling_flash_attention_utils.py
@@ -23,19 +23,21 @@
from .utils import (
is_flash_attn_2_available,
is_flash_attn_3_available,
is_flash_attn_4_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_npu_available,
is_torch_xpu_available,
logging,
)
from .utils.import_utils import is_tracing


logger = logging.get_logger(__name__)


# TODO Deprecate when all models have the attention interface
def flash_attn_supports_top_left_mask():
if is_flash_attn_3_available():
if is_flash_attn_3_available() or is_flash_attn_4_available():
return False
if is_flash_attn_2_available():
return not is_flash_attn_greater_or_equal_2_10()
@@ -48,7 +50,8 @@ def flash_attn_supports_top_left_mask():
# TODO Deprecate when all models have the attention interface
def is_flash_attn_available():
return (
is_flash_attn_3_available()
is_flash_attn_4_available()
or is_flash_attn_3_available()
or is_flash_attn_2_available()
or is_torch_npu_available()
or is_torch_xpu_available()
@@ -84,13 +87,16 @@ def _lazy_imports(implementation: Optional[str], attention_wrapper: Optional[Cal
"""
is_fa2 = is_flash_attn_2_available()
is_fa3 = is_flash_attn_3_available()
is_fa4 = is_flash_attn_4_available()

pad_input, unpad_input = _pad_input, _unpad_input

is_paged = implementation.startswith("paged|")
implementation = implementation.split("|")[1] if is_paged else implementation

if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
if (implementation == "flash_attention_2" and is_fa2) or (
implementation is None and is_fa2 and not is_fa3 and not is_fa4
):
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input
elif is_torch_npu_available():
@@ -99,8 +105,10 @@ def _lazy_imports(implementation: Optional[str], attention_wrapper: Optional[Cal
from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
else:
if implementation == "flash_attention_3" or (implementation is None and is_fa3):
if implementation == "flash_attention_3" or (implementation is None and is_fa3 and not is_fa4):
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
elif implementation == "flash_attention_4" or (implementation is None and is_fa4):
from flash_attn.cute import flash_attn_func, flash_attn_varlen_func
# Kernels fallback
else:
from .integrations.hub_kernels import load_and_register_attn_kernel
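With this dispatch, `flash_attention_4` resolves to the kernels exposed by `flash_attn.cute`, and FA4 is also preferred over FA2/FA3 when no explicit implementation is requested. A sketch of opting in at load time, assuming the rest of the PR wires the name through `attn_implementation` (the checkpoint id is illustrative):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # illustrative checkpoint
    attn_implementation="flash_attention_4",
)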
@@ -212,7 +220,7 @@ def _unpad_input(hidden_states, attention_mask, unused_mask=None):
seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
max_seqlen_in_batch = seqlens_in_batch.max()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

return (
@@ -261,9 +269,7 @@ def _get_unpad_data(attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.T
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# NOTE: Similar to the `.item()` in prepare_fa2_from_position_ids, with torch compile,
# this might cause a graph break
max_seqlen_in_batch = seqlens_in_batch.max().item()
max_seqlen_in_batch = seqlens_in_batch.max()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
@@ -384,12 +390,6 @@ def prepare_fa_kwargs_from_position_ids(position_ids):
# We should use cu_seq_lens instead of position_ids to get the max length since position_ids is not always increasing
# for some models (e.g. qwen2-vl).
max_length_q = cu_seq_lens_q.diff().max()
# NOTE: With torch compile, this will cause a graph break if you don't set
# `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` in the environment or call
# `torch._dynamo.config.capture_scalar_outputs = True` before doing the forward pass.
# This is a limitation of flash attention API, as the function `flash_attn_varlen_func`
# requires `max_length_q`, `max_length_k` to be passed as `int` and not `torch.Tensor`.
max_length_q = max_length_q.item()
max_length_k = max_length_q

return (cu_seq_lens_q, cu_seq_lens_k), (max_length_q, max_length_k)
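Note that `max_length_q` (and `max_length_k`) now stay `torch.Tensor`s; the `.item()` conversion is deferred to `_process_flash_attention_kwargs` below. For intuition: with `cu_seq_lens_q = tensor([0, 3, 7, 9])` (three packed sequences of lengths 3, 4 and 2), `cu_seq_lens_q.diff()` is `tensor([3, 4, 2])` and its `.max()` is `tensor(4)`, the longest sequence in the batch.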
@@ -499,6 +499,8 @@ def _process_flash_attention_kwargs(
softcap: Optional[float] = None,
deterministic: Optional[bool] = None,
s_aux: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int | torch.IntTensor] = None,
max_seqlen_k: Optional[int | torch.IntTensor] = None,
supports_mapping: Optional[dict[str, bool]] = None,
**kwargs,
):
@@ -529,6 +531,10 @@
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
s_aux (`torch.Tensor`, *optional*):
Attention sink auxiliary that adds a `bias` to the attention calculation via an additional head.
max_seqlen_q (`Union[int, torch.IntTensor]`, *optional*):
The maximum sequence length in the query tensor during a varlen forward.
max_seqlen_k (`Union[int, torch.IntTensor]`, *optional*):
The maximum sequence length in the key/value tensor during a varlen forward.
Return:
flash_kwargs (`dict`):
A dict of kwargs that are requested and supported.
@@ -561,6 +567,23 @@
if supports_mapping["s_aux"] and s_aux is not None:
flash_kwargs["s_aux"] = s_aux

# The flash attention API has a limitation: `flash_attn_varlen_func` may require
# `max_seqlen_q` and `max_seqlen_k` to be passed as `int` rather than `torch.Tensor`.
#
# You can either set
# - Env: `TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1`
# - Before compiling: `torch._dynamo.config.capture_scalar_outputs = True`
# to allow torch compile to handle scalar outputs in those cases.
if supports_mapping["max_seqlen_q"] and max_seqlen_q is not None:
if not isinstance(max_seqlen_q, int) and is_tracing(max_seqlen_q):
max_seqlen_q = max_seqlen_q.item()
flash_kwargs["max_seqlen_q"] = max_seqlen_q

if supports_mapping["max_seqlen_k"] and max_seqlen_k is not None:
if not isinstance(max_seqlen_k, int) and is_tracing(max_seqlen_k):
max_seqlen_k = max_seqlen_k.item()
flash_kwargs["max_seqlen_k"] = max_seqlen_k

return flash_kwargs
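A minimal sketch of the compile-time setting referenced in the comment above, for when the lengths reach this point as traced tensors (the `model` object is a placeholder):

import torch

# Let dynamo capture `.item()` on traced tensors instead of breaking the graph.
torch._dynamo.config.capture_scalar_outputs = True

compiled_forward = torch.compile(model.forward)  # `model` assumed to be defined elsewhere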


@@ -615,7 +638,8 @@ def _flash_attention_forward(
)

# Extract the flash attention kwargs that have been requested (and are supported by the implementation)
flash_kwargs = process_flash_kwargs_fn(
flash_kwargs = partial(
process_flash_kwargs_fn,
query_length=query_length,
key_length=key_states.size(1),
is_causal=is_causal,
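`flash_kwargs` is now a `functools.partial` rather than a fully built dict, so the shared arguments are bound once here and the varlen-only `max_seqlen_(q|k)` values are supplied at each call site below. The pattern in isolation (names are illustrative, not from this PR):

from functools import partial

def build_kwargs(query_length, is_causal, max_seqlen_q=None, max_seqlen_k=None):
    # Only include the optional lengths when the caller provides them.
    kwargs = {"query_length": query_length, "is_causal": is_causal}
    if max_seqlen_q is not None:
        kwargs.update(max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k)
    return kwargs

make_kwargs = partial(build_kwargs, query_length=128, is_causal=True)
dense_kwargs = make_kwargs()                                    # no varlen lengths
varlen_kwargs = make_kwargs(max_seqlen_q=64, max_seqlen_k=64)   # lengths supplied late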
@@ -651,15 +675,15 @@
if "mps" in str(q.device):
cu_seq_lens_k = cu_seq_lens_k.clone()

# Newer fa versions no longer accept `max_seqlen_(q|k)`
final_flash_kwargs = flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k)
out_unpad = flash_varlen_fn(
q,
k,
v,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
**flash_kwargs,
**final_flash_kwargs,
)
if isinstance(out_unpad, tuple):
out_unpad = out_unpad[0]
@@ -682,15 +706,15 @@
if "mps" in str(q.device):
cu_seq_lens_k = cu_seq_lens_k.clone()

# Newer fa versions no longer accept `max_seqlen_(q|k)`
final_flash_kwargs = flash_kwargs(max_seqlen_q=max_length_q, max_seqlen_k=max_length_k)
out = flash_varlen_fn(
q,
k,
v,
cu_seqlens_q=cu_seq_lens_q,
cu_seqlens_k=cu_seq_lens_k,
max_seqlen_q=max_length_q,
max_seqlen_k=max_length_k,
**flash_kwargs,
**final_flash_kwargs,
)
if isinstance(out, tuple):
out = out[0]
@@ -699,7 +723,8 @@

# No padding
else:
out = flash_fn(query_states, key_states, value_states, **flash_kwargs)
final_flash_kwargs = flash_kwargs()
out = flash_fn(query_states, key_states, value_states, **final_flash_kwargs)
if isinstance(out, tuple):
out = out[0]
