Member

@sayakpaul sayakpaul commented Aug 25, 2025

What does this PR do?

Building FA3 from source is time-consuming, so we could use the kernels library to fetch a prebuilt kernel instead.

This PR attempts a dirty implementation of using kernels to set the FA3 backend when requested by the user. It's an extremely early implementation, so apologies in advance.

We could let users specify something like set_attention_backend("kernels-community/vllm-flash-attn3", interface="flash_attn_func") when they don't have the FA3 build locally available. But that's a matter for our discussion.

This also helps keep diffusers Hub-first, as kernels provides a great way to leverage the platform, IMO.

Minimal code to test:

import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("_flash_3_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]

image.save("output.png")

Some comments are inline.

@sayakpaul sayakpaul requested review from DN6 and a-r-r-o-w August 25, 2025 16:59
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@a-r-r-o-w a-r-r-o-w left a comment

Let's not do it this way, i.e., changing the behavior when FA3 is not installed. Instead, let's add a new backend, _flash_3_hf (or something similar), so that the user has to set it explicitly to download the kernel. As an end user, running any remote code should require some form of explicit consent, IMO, and this approach is better than defaulting to downloading from the Hub.
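To make the "explicit consent" idea concrete, here is a minimal sketch of backend resolution that never falls back to the Hub on its own. It is not the diffusers implementation: resolve_backend, the AttentionBackend enum, and the "native" and "_flash_3" names are hypothetical, and the two boolean arguments stand in for whatever availability checks the library actually performs.

```python
from enum import Enum


class AttentionBackend(str, Enum):
    # Hypothetical backend names, used only for this illustration.
    NATIVE = "native"
    FLASH_3 = "_flash_3"          # locally built FA3
    FLASH_3_HUB = "_flash_3_hub"  # FA3 fetched from the Hub, opt-in only


def resolve_backend(name: str, fa3_installed: bool, kernels_installed: bool) -> AttentionBackend:
    """Return the requested backend or fail loudly; never switch backends silently."""
    backend = AttentionBackend(name)  # raises ValueError for unknown names
    if backend is AttentionBackend.FLASH_3 and not fa3_installed:
        # Deliberately no fallback to the Hub variant here: downloading
        # remote code must be something the user asked for by name.
        raise RuntimeError(
            "FA3 is not installed locally; request '_flash_3_hub' explicitly "
            "to download the kernel from the Hub."
        )
    if backend is AttentionBackend.FLASH_3_HUB and not kernels_installed:
        raise RuntimeError("The '_flash_3_hub' backend needs the kernels package.")
    return backend
```

With this shape, a missing local build is an error with an actionable message rather than a trigger for a download.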

@@ -514,6 +540,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
    return torch.empty_like(query), query.new_empty(lse_shape)


@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
Member Author

@sayakpaul sayakpaul Aug 26, 2025

Currently, fullgraph tracing for compile is failing with:

torch._dynamo.exc.Unsupported: Operator does not support running with fake tensors
  Explanation: 
  Hint: see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix

  Developer debug context: unsupported operator: _vllm_flash_attn3_28fbd26_dirty.fwd.default

I was referred to https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html but couldn't figure out how to craft this input.

Any advice? @anijain2305

Code:

import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("_flash_3_hub")
pipe.transformer.compile(fullgraph=True)

prompt = "A cat holding a sign that says 'hello world'"

with torch._dynamo.config.patch(error_on_recompile=True):
    image = pipe(
        prompt, num_inference_steps=28, guidance_scale=4.0, generator=torch.manual_seed(0)
    ).images[0]
    image.save("output.png")

Contributor

This means that the custom op is not compatible with torch.compile. There are two ways to go:

  1. Short term - Allow graph breaks (fullgraph=False) - This will compile around the custom op.
  2. Long term - Make the custom op torch.compile compatible - https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html - This requires some serious work.
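For the long-term option, the core requirement is that the op has a "fake" implementation that only computes output shapes and dtypes, so the compiler can trace it without running the kernel. Below is a minimal sketch using torch.library.custom_op and register_fake (available in recent PyTorch releases). The op name demo_attn::attention_forward is made up, and plain softmax attention stands in for the real FA3 kernel; this is not the code from the PR.

```python
try:
    import torch
    from torch.library import custom_op
except ImportError:  # PyTorch missing or too old for torch.library.custom_op
    torch = None

if torch is not None:

    # Hypothetical op name; "demo_attn" is not a real library namespace.
    @custom_op("demo_attn::attention_forward", mutates_args=())
    def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Plain softmax attention standing in for the opaque CUDA kernel.
        scale = q.shape[-1] ** -0.5
        weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
        return weights @ v

    @attention_forward.register_fake
    def _(q, k, v):
        # Shape/dtype-only implementation: it reads no tensor data, so it
        # also works on the fake tensors used during tracing.
        return q.new_empty((*q.shape[:-1], v.shape[-1]))
```

The short-term option needs no code changes beyond compiling with fullgraph=False, which lets the compiler break the graph around the unsupported op.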

@sayakpaul sayakpaul changed the title from "[wip][core] use kernels for FA3 when the build is not locally available" to "[core] use kernels for FA3 when the build is not locally available" Aug 26, 2025
@sayakpaul sayakpaul changed the title from "[core] use kernels for FA3 when the build is not locally available" to "[core] use kernels to support _flash_3_hub attention backend" Aug 26, 2025
@sayakpaul sayakpaul marked this pull request as ready for review August 26, 2025 13:01
@sayakpaul sayakpaul requested a review from a-r-r-o-w August 26, 2025 13:01
Member

@a-r-r-o-w a-r-r-o-w left a comment

thanks, looking much better!

@sayakpaul
Member Author

Getting the following error when trying to compile:

File "/fsx/sayak/diffusers/check_fa3_backend.py", line 18, in <module>
    image = pipe(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 919, in __call__
    noise_pred = self.transformer(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor
  Explanation: torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output


  Developer debug context: example_value type: bool; op: call_function; target: <function compiled_with_cxx11_abi at 0x7f67768c0160>


from user code:
   File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
    encoder_hidden_states, hidden_states = block(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
    attention_outputs = self.attn(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
    hidden_states = dispatch_attention_fn(
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 293, in dispatch_attention_fn
    return backend_fn(**kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 717, in _flash_attention_3_hub
    out, lse, *_ = flash_attn_3_hub_func(
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 226, in flash_attn_3_hub_func
    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 193, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 217, in _load_fa3_hub
    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
  File "/fsx/sayak/diffusers/src/diffusers/utils/kernels_utils.py", line 18, in _get_fa3_from_hub
    flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 234, in get_kernel
    package_name, package_path = install_kernel(repo_id, revision=revision)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 117, in install_kernel
    variant = build_variant()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 64, in build_variant
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Maybe we can investigate this in a future PR.

Code:

import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")

pipe.transformer.set_attention_backend("_flash_3_hub")
pipe.transformer.compile(fullgraph=True)

prompt = "A cat holding a sign that says 'hello world'"

with torch._dynamo.config.patch(error_on_recompile=True):
    image = pipe(
        prompt, num_inference_steps=28, guidance_scale=4.0, generator=torch.manual_seed(0)
    ).images[0]
    image.save("output.png")

Cc: @anijain2305

@sayakpaul sayakpaul requested a review from a-r-r-o-w August 26, 2025 15:07
@sayakpaul
Member Author

The recompilation issues are gone thanks to a recent fix from @danieldk. I will button up this PR and let you know once it’s ready for another review. @a-r-r-o-w



def flash_attn_3_hub_func(*args, **kwargs):
    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
Collaborator

Rather than loading the kernel every time we invoke flash attention, it would be better to import the function at the top of the file, similar to the other FA backends.

if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None

Member Author

Sure but see
#12236 (comment)

We shouldn’t make remote calls to the Hub unless requested.

Collaborator

We can gate this behind user permission via an environment variable, similar to the GGUF kernels:

os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]

Member Author

God that simplifies a lot of stuff. Thanks, Dhruv!

Collaborator

Actually, I'm wondering if we should just add a DIFFUSERS_ENABLE_HUB_KERNELS constant that is used for all kernel cases.
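A rough sketch of what such a shared constant could look like, reusing the truthy-value check from the GGUF snippet above. The helper names env_flag and require_hub_kernels are hypothetical, and where the constant would live in diffusers is left open.

```python
import os

_TRUTHY = ("1", "true", "yes")


def env_flag(name, default="false", environ=None):
    """Read a boolean opt-in flag; `environ` is injectable to ease testing."""
    environ = os.environ if environ is None else environ
    return environ.get(name, default).lower() in _TRUTHY


# Evaluated once at import time, like a module-level constant.
DIFFUSERS_ENABLE_HUB_KERNELS = env_flag("DIFFUSERS_ENABLE_HUB_KERNELS")


def require_hub_kernels(enabled=None):
    """Raise unless the user has opted in to downloading kernels from the Hub."""
    enabled = DIFFUSERS_ENABLE_HUB_KERNELS if enabled is None else enabled
    if not enabled:
        raise RuntimeError(
            "Hub kernels are disabled; set DIFFUSERS_ENABLE_HUB_KERNELS=yes to opt in."
        )
```

Every Hub-backed kernel path could then call require_hub_kernels() before loading anything, which keeps the consent check in one place.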
