Add Flash Attention 4 (CuTe DSL) Support #42404
Problem Statement
Flash Attention 4 represents a significant architectural shift in the flash-attention package:
- New module location: the `flash_attn.cute` submodule instead of the main `flash_attn` package
- `flash_attn_varlen_func` has a different signature - it does NOT accept `max_seqlen_q` and `max_seqlen_k` parameters (it calculates them internally from `cu_seqlens`)
- New parameters: `learnable_sink`, `num_splits`, and `pack_gqa`
- No `dropout_p` and `alibi_slopes` support

Without explicit FA4 support, users cannot leverage these improvements even when they have compatible hardware and the flash-attn package with CuTe DSL installed.
Solution Design
The solution maintains full backward compatibility while adding FA4 support through:
1. Detection Layer
Added an `is_flash_attn_4_available()` function that checks for the presence of the `flash_attn.cute` submodule.
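A minimal sketch of the detection idea, assuming availability is keyed purely off whether the `flash_attn.cute` submodule can be found (the actual helper in this PR may also verify versions or hardware):

```python
import importlib.util


def is_flash_attn_4_available() -> bool:
    # FA4 ships its CuTe DSL kernels under flash_attn.cute, so the submodule's
    # presence is a cheap proxy for FA4 support.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    return importlib.util.find_spec("flash_attn.cute") is not None
```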
2. Priority-Based Auto-Selection
When `attn_implementation=None`, the auto-selection order gives FA4 the highest priority on compatible hardware for optimal performance.
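As a rough illustration of priority-based selection, reusing the availability helper sketched above. Only "FA4 first on compatible hardware" comes from this PR; the rest of the ordering below is an assumption, and this is not the actual dispatch code:

```python
from transformers.utils import is_flash_attn_2_available


def pick_attention_implementation() -> str:
    # Hypothetical helper: return the first implementation whose availability
    # check passes. FA4 sits at the top of the list per this PR; the order of
    # the remaining backends is illustrative only.
    candidates = [
        ("flash_attention_4", is_flash_attn_4_available),
        ("flash_attention_2", is_flash_attn_2_available),
        ("sdpa", lambda: True),  # PyTorch SDPA as a general-purpose fallback
    ]
    for name, is_available in candidates:
        if is_available():
            return name
    return "eager"
```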
3. Runtime Introspection
Created a `_is_using_fa4()` helper that uses function signature inspection to detect FA4 vs FA2/FA3 at runtime. This enables conditional code paths without hardcoded version checks.
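A sketch of one way such a helper can work, assuming detection keys off the missing `max_seqlen_q` parameter described above (the actual heuristic in this PR may inspect other parameters):

```python
import inspect


def _is_using_fa4(flash_varlen_fn) -> bool:
    # FA4's flash_attn_varlen_func drops max_seqlen_q/max_seqlen_k from its
    # signature, so their absence distinguishes it from FA2/FA3 at runtime
    # without comparing package versions.
    params = inspect.signature(flash_varlen_fn).parameters
    return "max_seqlen_q" not in params
```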
4. Conditional Varlen Calls
Modified two critical call sites in `_flash_attention_forward()` to conditionally pass parameters: FA4 is not given `max_seqlen_q` and `max_seqlen_k`, since it calculates them internally.
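A condensed sketch of the conditional call pattern, assuming a resolved `flash_varlen_fn` and the `_is_using_fa4()` helper above; the real `_flash_attention_forward()` handles many more arguments:

```python
def _call_varlen(flash_varlen_fn, q, k, v, cu_seqlens_q, cu_seqlens_k,
                 max_seqlen_q, max_seqlen_k, **extra_kwargs):
    # Arguments shared by every flash-attention generation.
    kwargs = dict(cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, **extra_kwargs)
    if not _is_using_fa4(flash_varlen_fn):
        # FA2/FA3 still need the max sequence lengths; FA4 recomputes them
        # from cu_seqlens internally, so passing them would raise a TypeError.
        kwargs["max_seqlen_q"] = max_seqlen_q
        kwargs["max_seqlen_k"] = max_seqlen_k
    return flash_varlen_fn(q, k, v, **kwargs)
```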
5. Parameter Support
Extended `_process_flash_attention_kwargs()` to handle FA4-specific parameters, with automatic filtering based on introspection to maintain compatibility across versions.
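A sketch of signature-based filtering, assuming FA4-only keys such as `learnable_sink`, `num_splits`, and `pack_gqa` are simply dropped when the resolved kernel does not accept them; the actual `_process_flash_attention_kwargs()` does more than this:

```python
import inspect


def _filter_flash_kwargs(flash_fn, flash_kwargs: dict) -> dict:
    # Keep only the keyword arguments the resolved kernel actually accepts, so
    # FA4-specific options (learnable_sink, num_splits, pack_gqa) are silently
    # dropped on FA2/FA3, and FA2-only options are dropped on FA4.
    accepted = inspect.signature(flash_fn).parameters
    return {k: v for k, v in flash_kwargs.items() if k in accepted}
```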
6. Registration
Registered `flash_attention_4` in `AttentionInterface._global_mapping` to enable explicit selection via `attn_implementation="flash_attention_4"`.
Implementation Details
Core Changes
Detection and Import
- Detection of the `flash_attn.cute` submodule
Integration Layer
Interface Registration
Testing Infrastructure
New Files
Test Suite
Comprehensive test coverage for the FA4 integration.
Validation Script
Quick validation script intended to be run over SSH on a GPU machine.
Usage Examples
Demonstrates the usage patterns shown in the Usage section below.
Testing Status
Automated Checks
Created and ran a comprehensive verification script: all 14 core integration checks passed, plus 7 additional file checks.
Pending Testing (Requires GPU)
GPU Validation Required
Due to lack of CUDA GPU access during development, the following tests are pending:
- Basic Functionality
- Integration Tests
- Real-World Usage
Hardware Requirements
Known Limitations
- No `dropout_p` parameter - training with dropout will automatically fall back to FA2/eager
- `softcap != 0.0` may be restricted during the backward pass

All limitations are handled gracefully via automatic fallback.
Usage
Explicit FA4 Selection
Users can explicitly request FA4 when loading models.
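For example (the checkpoint name is only a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",               # placeholder checkpoint
    attn_implementation="flash_attention_4",  # explicit FA4 selection added by this PR
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```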
Auto-Selection (Recommended)
When no attention implementation is specified, transformers will automatically select the best available implementation, with FA4 receiving highest priority on compatible hardware.
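For example, omitting `attn_implementation` lets the priority-based selection described above run (placeholder checkpoint):

```python
import torch
from transformers import AutoModelForCausalLM

# No attn_implementation argument: transformers picks the best available
# backend, preferring FA4 on compatible hardware when flash-attn with the
# CuTe DSL is installed.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",   # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```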
Check Availability
Users can check whether FA4 is available using the `is_flash_attn_4_available()` function.
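For example, assuming the helper is exported next to the existing availability utilities in `transformers.utils` (the exact import location depends on this PR):

```python
from transformers.utils import is_flash_attn_4_available  # assumed import location

if is_flash_attn_4_available():
    attn_implementation = "flash_attention_4"
else:
    # Fall back when FA4 (flash_attn.cute) is not installed.
    attn_implementation = "sdpa"
print(f"Using attention implementation: {attn_implementation}")
```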
Fixes #42405