sambhavnoobcoder commented on Nov 25, 2025

Problem Statement

Flash Attention 4 represents a significant architectural shift in the flash-attention package:

  1. Different import path: FA4 uses flash_attn.cute submodule instead of the main flash_attn package
  2. API incompatibility: FA4's flash_attn_varlen_func has a different signature; it does NOT accept the max_seqlen_q and max_seqlen_k parameters (it computes them internally from cu_seqlens)
  3. New parameters: FA4 introduces new optimization parameters like learnable_sink, num_splits, and pack_gqa
  4. Removed parameters: FA4 removes dropout_p and alibi_slopes support
  5. Hardware requirements: FA4 requires CUDA compute capability SM 8.0+ (Ampere or newer), with optimal performance on SM 9.0+ (Hopper/Blackwell)

Without explicit FA4 support, users cannot leverage these improvements even when they have compatible hardware and the flash-attn package with CuTe DSL installed.

Solution Design

The solution maintains full backward compatibility while adding FA4 support through:

1. Detection Layer

Added is_flash_attn_4_available() function that checks:

  • CUDA availability
  • flash-attn package installation
  • Presence of flash_attn.cute submodule
  • GPU compute capability >= SM 8.0
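
Schematically, the check looks like the sketch below (a minimal sketch mirroring the four conditions above, not the actual implementation):

```python
import importlib.util

import torch


def is_flash_attn_4_available_sketch() -> bool:
    # A CUDA device must be present.
    if not torch.cuda.is_available():
        return False
    # The flash-attn package must be installed.
    if importlib.util.find_spec("flash_attn") is None:
        return False
    # FA4 ships in the flash_attn.cute submodule (CuTe DSL).
    try:
        if importlib.util.find_spec("flash_attn.cute") is None:
            return False
    except ModuleNotFoundError:
        return False
    # FA4 requires compute capability SM 8.0+ (Ampere or newer).
    return torch.cuda.get_device_capability() >= (8, 0)
```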

2. Priority-Based Auto-Selection

When attn_implementation=None, the selection order is:

FA4 > FA3 > FA2 > SDPA > Eager

FA4 receives the highest priority on compatible hardware; a sketch of the selection logic follows.
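
The sketch below shows the priority chain as a standalone helper rather than the actual selection code inside transformers; the probe callables stand in for the library's internal availability checks:

```python
from typing import Callable, Dict

PRIORITY = ("flash_attention_4", "flash_attention_3", "flash_attention_2", "sdpa")


def select_attention_implementation(probes: Dict[str, Callable[[], bool]]) -> str:
    # Walk the priority order and return the first backend whose probe succeeds;
    # "eager" is the universal fallback.
    for name in PRIORITY:
        if probes.get(name, lambda: False)():
            return name
    return "eager"
```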

3. Runtime Introspection

Created _is_using_fa4() helper that uses function signature inspection to detect FA4 vs FA2/FA3 at runtime. This enables conditional code paths without hardcoded version checks.
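
One way such an introspection check can be written, assuming the signature difference described in the problem statement (FA4's flash_attn_varlen_func lacks max_seqlen_q); this is a sketch, not the helper's actual body:

```python
import inspect


def looks_like_fa4(flash_varlen_fn) -> bool:
    # FA4's varlen kernel derives the max sequence lengths from cu_seqlens,
    # so max_seqlen_q is absent from its signature; FA2/FA3 require it.
    return "max_seqlen_q" not in inspect.signature(flash_varlen_fn).parameters
```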

4. Conditional Varlen Calls

Modified two critical call sites in _flash_attention_forward() to conditionally pass parameters:

  • FA4 path: Calls without max_seqlen_q and max_seqlen_k (they are computed internally from cu_seqlens)
  • FA2/FA3 path: Calls with explicit max_seqlen parameters (required by those kernels)
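
Schematically, the conditional call looks like the sketch below (argument names follow the FA2 varlen convention; the real call sites in _flash_attention_forward() carry additional arguments):

```python
import inspect


def varlen_attention(flash_varlen_fn, q, k, v, cu_seqlens_q, cu_seqlens_k,
                     max_seqlen_q, max_seqlen_k, **kwargs):
    params = inspect.signature(flash_varlen_fn).parameters
    if "max_seqlen_q" in params:
        # FA2/FA3 path: the max sequence lengths are required.
        return flash_varlen_fn(
            q, k, v, cu_seqlens_q, cu_seqlens_k,
            max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, **kwargs,
        )
    # FA4 path: the kernel computes the max sequence lengths from cu_seqlens.
    return flash_varlen_fn(q, k, v, cu_seqlens_q, cu_seqlens_k, **kwargs)
```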

5. Parameter Support

Extended _process_flash_attention_kwargs() to handle FA4-specific parameters, with automatic filtering based on introspection to maintain compatibility across versions.
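
The filtering idea, reduced to a sketch (the real _process_flash_attention_kwargs() does more, but the core is dropping options the resolved kernel does not accept):

```python
import inspect


def filter_flash_kwargs(flash_fn, flash_kwargs: dict) -> dict:
    # FA4-only options (learnable_sink, num_splits, pack_gqa) are dropped when
    # running on FA2/FA3, and FA2-only options (dropout_p, alibi_slopes) are
    # dropped when running on FA4.
    supported = inspect.signature(flash_fn).parameters
    return {name: value for name, value in flash_kwargs.items() if name in supported}
```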

6. Registration

Registered flash_attention_4 in AttentionInterface._global_mapping to enable explicit selection via attn_implementation="flash_attention_4".
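
For reference, the public AttentionInterface.register() API has the same effect as writing into _global_mapping; the forward function below is only a placeholder, not the wrapper added by this PR:

```python
from transformers import AttentionInterface


def flash_attention_4_forward(module, query, key, value, attention_mask, **kwargs):
    # Placeholder body; the actual FA4 forward wrapper lives in the
    # flash attention integration code.
    raise NotImplementedError


# Makes attn_implementation="flash_attention_4" resolvable by name.
AttentionInterface.register("flash_attention_4", flash_attention_4_forward)
```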

Implementation Details

Core Changes

Detection and Import

  • Added FA4 detection function with hardware capability checks in import_utils
  • Exported detection function in utils module
  • Modified import logic to handle flash_attn.cute submodule

Integration Layer

  • Updated helper functions to recognize FA4
  • Added introspection-based FA4 detection
  • Extended parameter processing for FA4-specific options
  • Implemented conditional varlen function calls at both call sites

Interface Registration

  • Registered FA4 in attention interface mapping

Testing Infrastructure

  • Added test decorator for FA4-specific tests
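
The decorator is not named in this summary; a typical shape for it, following the existing require_* pattern in transformers' test utilities, would be the sketch below (the decorator name and import path are assumptions based on the export described above):

```python
import unittest

from transformers.utils import is_flash_attn_4_available


def require_flash_attn_4(test_case):
    # Hypothetical name: skip the test unless FA4 kernels and compatible
    # hardware are detected.
    return unittest.skipUnless(
        is_flash_attn_4_available(), "test requires Flash Attention 4"
    )(test_case)
```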

New Files

Test Suite
Comprehensive test coverage including:

  • Detection function tests
  • Import tests with GPU
  • Basic forward pass tests
  • Causal attention tests
  • Varlen function API signature verification
  • HF integration tests
  • FA4-specific parameter tests (softcap, window_size)

Validation Script
A quick validation script, intended to be run on a GPU host over SSH, that checks:

  • CUDA environment and compute capability
  • Package installations
  • FA4 detection
  • Import functionality
  • API signature correctness
  • Basic forward pass
  • HF integration layer

Usage Examples
Demonstrates:

  • Explicit FA4 selection
  • Automatic implementation selection
  • Performance comparison across implementations

Testing Status

Automated Checks
Created and ran comprehensive verification script checking:

  • Detection function exists and exported
  • All imports configured correctly
  • FA4 integrated into helper functions
  • Import paths from flash_attn.cute configured
  • Introspection helper created
  • FA4 parameters added
  • Both varlen call sites protected with conditionals
  • AttentionInterface registration complete
  • Test decorator added
  • All files compile without errors

All 14 core integration checks and all 7 additional file checks passed.

Pending Testing (Requires GPU)

GPU Validation Required
Due to lack of CUDA GPU access during development, the following tests are pending:

  1. Basic Functionality

    • FA4 detection on real GPU
    • Import from flash_attn.cute
    • Basic forward pass execution
    • Varlen function calls
  2. Integration Tests

    • Full test suite execution
    • Model inference with FA4
    • Numerical accuracy comparison (FA4 vs FA2)
    • Performance benchmarking
  3. Real-World Usage

    • Testing with popular models (Llama, Mistral, Qwen2)
    • Testing with static and dynamic caches
    • Testing varlen sequences in production scenarios
    • Training with FA4 (if backward pass available)

Hardware Requirements

| Component | Requirement |
| --- | --- |
| GPU | NVIDIA with CUDA support |
| Compute capability | SM 8.0+ (Ampere/Hopper/Blackwell) |
| Optimal performance | SM 9.0+ (Hopper H100/H200, Blackwell) |
| CUDA | 11.8+ (12.8+ for Blackwell) |
| Software | flash-attn with CuTe DSL support |

Known Limitations

  1. Dropout Not Supported: FA4 has no dropout_p parameter; training with dropout automatically falls back to FA2/eager
  2. ALiBi Slopes Not Supported: Models using ALiBi (e.g., BLOOM) cannot use FA4 and fall back to FA2/eager
  3. Backward Pass: FA4 may be inference-only initially, depending on the flash-attn release
  4. Softcap in Backward: softcap != 0.0 may be restricted during the backward pass

All limitations are handled gracefully via automatic fallback.

Usage

Explicit FA4 Selection

Users can explicitly request FA4 when loading models.
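
For example (the checkpoint name is illustrative only):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",          # any FA-compatible checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_4",
)
```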

Auto-Selection (Recommended)

When no attention implementation is specified, transformers will automatically select the best available implementation, with FA4 receiving highest priority on compatible hardware.
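
With no attn_implementation argument, the priority chain above applies:

```python
from transformers import AutoModelForCausalLM

# FA4 is picked automatically when available; otherwise FA3/FA2/SDPA/eager.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
```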

Check Availability

Users can check if FA4 is available using the is_flash_attn_4_available() function.
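
Assuming the function is exported from transformers.utils as described above:

```python
from transformers.utils import is_flash_attn_4_available

print("FA4 available:", is_flash_attn_4_available())
```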

Fixes #42405

sambhavnoobcoder changed the title from "# Add Flash Attention 4 (CuTe DSL) Support" to "Add Flash Attention 4 (CuTe DSL) Support" on Nov 25, 2025
vasqu mentioned this pull request on Nov 26, 2025