[core] use kernels to support _flash_3_hub attention backend #12236
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Let's not do it this way, i.e., changing the behavior when FA3 is not installed. Instead, let's add a new backend, _flash_3_hf (or something similar), so that the user has to explicitly set it to download the kernel. As an end user, running any remote code should require some form of explicit consent, IMO, and this approach is better than defaulting to downloading from the Hub.
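A rough sketch of what the opt-in flow could look like from the user's side (the backend name `_flash_3_hf` was the initial suggestion; the PR title later settles on `_flash_3_hub`):

```python
# Illustrative usage under the explicit-opt-in proposal: nothing is fetched
# from the Hub unless the user names the Hub-backed backend themselves.
# `pipe` is assumed to be an already-loaded diffusers pipeline.
pipe.transformer.set_attention_backend("_flash_3_hub")  # explicit consent to download the kernel
```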
```python
@@ -514,6 +540,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.empty_like(query), query.new_empty(lse_shape)


@_custom_op("vllm_flash_attn3::_flash_attn_forward", mutates_args=(), device_types="cuda")
```
Currently, fullgraph tracing for compile is failing with:

```
torch._dynamo.exc.Unsupported: Operator does not support running with fake tensors
Explanation:
Hint: see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix
Developer debug context: unsupported operator: _vllm_flash_attn3_28fbd26_dirty.fwd.default
```

The error refers me to https://docs.pytorch.org/tutorials/advanced/python_custom_ops.html, but I couldn't figure out how to craft this input.
Any advice? @anijain2305
Code:

```python
import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda")
pipe.transformer.set_attention_backend("_flash_3_hub")
pipe.transformer.compile(fullgraph=True)

prompt = "A cat holding a sign that says 'hello world'"
with torch._dynamo.config.patch(error_on_recompile=True):
    image = pipe(
        prompt, num_inference_steps=28, guidance_scale=4.0, generator=torch.manual_seed(0)
    ).images[0]
image.save("output.png")
```
This means that the custom op is not compatible with torch.compile. There are two ways to go:
- Short term: allow graph breaks (fullgraph=False). This will compile around the custom op.
- Long term: make the custom op torch.compile compatible (https://docs.pytorch.org/tutorials/advanced/custom_ops_landing_page.html). This requires some serious work.
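For reference, a minimal sketch of the long-term route, following the custom-ops tutorial linked above. The op name, signature, and SDPA stand-in are illustrative, not the PR's actual code; the key piece is registering a fake implementation so Dynamo can trace the op with fake tensors:

```python
import torch

# Define a custom op plus a "fake" (meta) implementation. The fake impl is
# what was missing in the error above: it tells FakeTensorMode the output
# metadata (shape/dtype/device) without running the real kernel.
@torch.library.custom_op("mylib::fa3_forward", mutates_args=(), device_types="cuda")
def fa3_forward(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real FA3 kernel call.
    return torch.nn.functional.scaled_dot_product_attention(query, key, value)

@fa3_forward.register_fake
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Under fake tensors, only the output metadata matters.
    return torch.empty_like(query)
```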
Thanks, looking much better!
Getting the following error when trying to compile:

```
  File "/fsx/sayak/diffusers/check_fa3_backend.py", line 18, in <module>
    image = pipe(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 919, in __call__
    noise_pred = self.transformer(
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1771, in _wrapped_call_impl
    return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 745, in compile_wrapper
    raise e.with_traceback(None) from e.__cause__  # User compiler error
torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor
Explanation: torch.* ops that return a non-Tensor cannot be traced into the Dynamo FX graph output
Developer debug context: example_value type: bool; op: call_function; target: <function compiled_with_cxx11_abi at 0x7f67768c0160>

from user code:
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 720, in forward
    encoder_hidden_states, hidden_states = block(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 443, in forward
    attention_outputs = self.attn(
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 342, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 116, in __call__
    hidden_states = dispatch_attention_fn(
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 293, in dispatch_attention_fn
    return backend_fn(**kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 717, in _flash_attention_3_hub
    out, lse, *_ = flash_attn_3_hub_func(
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 226, in flash_attn_3_hub_func
    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/_dynamo/polyfills/__init__.py", line 193, in getattr_and_trace
    return fn(*args[2:], **kwargs)
  File "/fsx/sayak/diffusers/src/diffusers/models/attention_dispatch.py", line 217, in _load_fa3_hub
    fa3_hub = _get_fa3_from_hub()  # won't re-download if already present
  File "/fsx/sayak/diffusers/src/diffusers/utils/kernels_utils.py", line 18, in _get_fa3_from_hub
    flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 234, in get_kernel
    package_name, package_path = install_kernel(repo_id, revision=revision)
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 117, in install_kernel
    variant = build_variant()
  File "/fsx/sayak/miniconda3/envs/diffusers/lib/python3.10/site-packages/kernels/utils.py", line 64, in build_variant
    cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98"

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```

Maybe we can investigate this in a future PR.
Cc: @anijain2305
The recompilation issues are gone thanks to a recent fix from @danieldk. I will button up this PR and let you know once it's ready for another review. @a-r-r-o-w
```python
def flash_attn_3_hub_func(*args, **kwargs):
    return _load_fa3_hub().flash_attn_func(*args, **kwargs)
```
Rather than loading the kernel every time we invoke flash attention, it would be better to import the function at the top of the file, similar to the other FA backends.
diffusers/src/diffusers/models/attention_dispatch.py, lines 63 to 68 in e58711e:

```python
if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None
```
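A rough sketch of that suggestion applied to the Hub kernel. The guard name `_CAN_USE_FA3_HUB` is an assumption (e.g., an opt-in flag plus `kernels` availability), not the PR's actual code; `_get_fa3_from_hub` comes from the PR's `diffusers.utils.kernels_utils`:

```python
# Resolve the Hub kernel once at module import time, mirroring the FA3
# pattern above, instead of calling _load_fa3_hub() on every attention call.
if _CAN_USE_FA3_HUB:
    _fa3_hub = _get_fa3_from_hub()  # downloads once; cached afterwards
    flash_attn_3_hub_func = _fa3_hub.flash_attn_func
else:
    flash_attn_3_hub_func = None
```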
Sure, but see #12236 (comment). We shouldn't make remote calls to the Hub unless requested.
We can enable user permission via an environment variable, similar to the GGUF kernels:

```python
os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
```
God that simplifies a lot of stuff. Thanks, Dhruv!
Actually, wondering if we should just add a DIFFUSERS_ENABLE_HUB_KERNELS constant that is used for all kernel cases.
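A sketch of how such a shared constant might gate the download. The flag parsing mirrors the existing DIFFUSERS_GGUF_CUDA_KERNELS check quoted above; wiring it into `_load_fa3_hub` is an assumption, not the PR's final code:

```python
import os

# Single opt-in flag for all Hub-kernel cases, per the suggestion above.
DIFFUSERS_ENABLE_HUB_KERNELS = os.getenv(
    "DIFFUSERS_ENABLE_HUB_KERNELS", "false"
).lower() in ["1", "true", "yes"]

def _load_fa3_hub():
    # Refuse to touch the Hub unless the user has explicitly opted in.
    if not DIFFUSERS_ENABLE_HUB_KERNELS:
        raise RuntimeError(
            "The `_flash_3_hub` attention backend downloads a kernel from the Hub. "
            "Set DIFFUSERS_ENABLE_HUB_KERNELS=1 to opt in."
        )
    return _get_fa3_from_hub()  # won't re-download if already present
```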
What does this PR do?
Building FA3 from source is a time-consuming process, and we could use the kernels lib for this.
This PR attempts a dirty implementation of using kernels to set the FA3 backend when requested by the user. It's an extremely early implementation, so apologies in advance.
We could let users specify something like set_attention_backend("kernels-community/vllm-flash-attn3", interface="flash_attn_func") when they don't have the FA3 build locally available. But that's a matter for our discussion.
Additionally, this also helps keep diffusers Hub-first, as kernels provides a great way to leverage the platform, IMO.
Minimal code to test:
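A minimal sketch, assuming it mirrors the test snippet quoted earlier in this thread (without the compile-specific bits):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Opt into the Hub-provided FA3 kernel instead of a local FA3 build.
pipe.transformer.set_attention_backend("_flash_3_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(
    prompt, num_inference_steps=28, guidance_scale=4.0, generator=torch.manual_seed(0)
).images[0]
image.save("output.png")
```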
Some comments are inline.