diff --git a/docs/source/design/flashinfer_integration_issues.md b/docs/source/design/flashinfer_integration_issues.md new file mode 100644 index 000000000000..2ec408500f34 --- /dev/null +++ b/docs/source/design/flashinfer_integration_issues.md @@ -0,0 +1,709 @@ +# FlashInfer Integration Issues in vLLM + +This document details the current state of FlashInfer integration in vLLM, including known issues, broken functionality, and recommendations for the FlashInfer team. + +**Environment tested:** +- vLLM: v0.8.3rc2.dev6045 +- FlashInfer: 0.5.2, 0.5.3 (both tested) +- flashinfer-cubin: 0.5.2, 0.5.3 (both tested) +- PyTorch: 2.9.0 +- CUDA: 12.8 (runtime), FlashInfer packages compiled against CUDA ≤12.6 +- Hardware: NVIDIA H200 (SM 9.0 / Hopper) + +**Summary of issues by FlashInfer version:** + +| Issue | FlashInfer 0.5.2 | FlashInfer 0.5.3 | Fixed? | +|-------|------------------|------------------|--------| +| AllReduce Fusion JIT (`std::optional` bug) | ❌ Broken | ❌ Broken | **✅ Yes (Nov 2025)** | +| FP8 MoE CUDA Version (needs 12.7+) | ❌ Broken | ❌ Broken | **No** | +| MXFP4 MoE on SM90 | ❌ Broken | ❌ Broken | **No** | +| Attention Sinks on Hopper | ❌ Not supported | ❌ Not supported | **No** | + +--- + +## Table of Contents + +1. [Overview of FlashInfer Integration](#overview-of-flashinfer-integration) +2. [Working Features](#working-features) +3. [Known Issues](#known-issues) + - [Issue 1: AllReduce Fusion JIT Compilation Failure](#issue-1-allreduce-fusion-jit-compilation-failure) + - [Issue 2: FP8 MoE CUDA Version Mismatch](#issue-2-fp8-moe-cuda-version-mismatch) + - [Issue 3: MXFP4 MoE SM90 Backend Broken](#issue-3-mxfp4-moe-sm90-backend-broken) + - [Issue 4: Attention Sinks Not Supported on Hopper](#issue-4-attention-sinks-not-supported-on-hopper) +4. [Skipped Tests](#skipped-tests) +5. [Environment Variables Reference](#environment-variables-reference) +6. [Missing Operators for Complete FlashInfer Backend](#missing-operators-for-complete-flashinfer-backend) +7. [Recommendations for FlashInfer Team](#recommendations-for-flashinfer-team) + +--- + +## Overview of FlashInfer Integration + +vLLM integrates FlashInfer for multiple operators: + +| Operator | Environment Variable | Status on Hopper (SM90) | Status on Blackwell (SM100) | +|----------|---------------------|------------------------|----------------------------| +| Attention | `VLLM_ATTENTION_BACKEND=FLASHINFER` | ✅ Works (no sinks) | ✅ Works | +| Attention with Sinks | - | ❌ Needs TRTLLM | ✅ Via TRTLLM | +| Top-k/Top-p Sampling | `VLLM_USE_FLASHINFER_SAMPLER` | ✅ Works | ✅ Works | +| RMSNorm | `VLLM_USE_FLASHINFER_NORM` | ✅ Works | ✅ Works | +| Activations (SiLU, GELU) | `VLLM_USE_FLASHINFER_ACTIVATION` | ✅ Works | ✅ Works | +| MoE FP16/BF16 | `VLLM_USE_FLASHINFER_MOE_FP16` | ✅ Works | ✅ Works | +| MoE FP8 | `VLLM_USE_FLASHINFER_MOE_FP8` | ❌ CUDA version issue | ✅ Works | +| MoE MXFP4 | `VLLM_USE_FLASHINFER_MOE_MXFP4_BF16` | ❌ Broken | ✅ Works | +| AllReduce Fusion | `VLLM_USE_FLASHINFER_ALLREDUCE` | ✅ Works (with fix) | ✅ Works (with fix) | +| All2All | `VLLM_ALL2ALL_BACKEND=flashinfer_all2allv` | ✅ Works | ✅ Works | + +--- + +## Working Features + +These FlashInfer features work correctly on Hopper (H100/H200): + +### 1. Attention (without sinks) +- File: `vllm/v1/attention/backends/flashinfer.py` +- Works for standard models like Llama, Qwen without attention sinks + +### 2. Sampling +- File: `vllm/v1/sample/ops/topk_topp_sampler.py` +- Uses `flashinfer.sampling.top_k_top_p_sampling_from_probs` + +### 3. 
RMSNorm
+- File: `vllm/model_executor/layers/layernorm.py`
+- Uses `flashinfer.norm.rmsnorm` and `flashinfer.norm.fused_add_rmsnorm`
+- **Note:** `fused_add_rmsnorm` is in-place and returns `None`. The vLLM integration must return `(x, residual)` after the call.
+
+### 4. Activations
+- File: `vllm/model_executor/layers/activation.py`
+- Uses `flashinfer.activation.silu_and_mul`, `gelu_and_mul`, `gelu_tanh_and_mul`
+
+### 5. MoE FP16/BF16
+- File: `vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py`
+- Uses FlashInfer CUTLASS MoE for unquantized models
+
+---
+
+## Known Issues
+
+### Issue 1: AllReduce Fusion JIT Compilation Failure
+
+**Severity:** High
+**Affected versions:** FlashInfer 0.5.2, 0.5.3 (including with cubins installed)
+**Environment variable:** `VLLM_USE_FLASHINFER_ALLREDUCE`
+**Compilation pass:** `enable_fi_allreduce_fusion`
+
+#### Symptom
+
+When enabling `enable_fi_allreduce_fusion=True` in vLLM's compilation config, JIT compilation fails with:
+
+```
+/data-fast/.venv/lib/python3.12/site-packages/flashinfer/data/include/flashinfer/comm/trtllm_allreduce_fusion.cuh(487): error: namespace "std" has no member "optional"
+/data-fast/.venv/lib/python3.12/site-packages/flashinfer/data/include/flashinfer/comm/trtllm_allreduce_fusion.cuh(487): error: identifier "batchIdx" is undefined
+/data-fast/.venv/lib/python3.12/site-packages/flashinfer/data/include/flashinfer/comm/trtllm_allreduce_fusion.cuh(707): error: identifier "AllReduceFusionPattern" is undefined
+```
+
+#### Root Cause: C++ Namespace Mismatch in CUDA Header
+
+**File:** `flashinfer/data/include/flashinfer/comm/trtllm_allreduce_fusion.cuh`
+
+The header file has a namespace mismatch:
+
+```cpp
+// Line 10 - includes CUDA's libcudacxx optional
+#include <cuda/std/optional>
+
+// Line 487-489 - uses std::optional (wrong namespace!)
+template <class SFType, ...>
+__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
+    std::optional<int> batchIdx,  // ERROR: should be cuda::std::optional
+    int rowIdx,
+    int colIdx,
+    std::optional<int> numRows,   // ERROR: should be cuda::std::optional
+    ...
+)
+```
+
+**The problem:** CUDA's libcudacxx provides `cuda::std::optional`, not `std::optional`. When NVCC compiles this header, it cannot find `std::optional` because `<optional>` (the C++17 standard library header) is never included.
+
+This causes a cascade of errors:
+1. `std::optional` is undefined → compilation fails at line 487
+2. `batchIdx` parameter becomes undefined
+3. 
All subsequent code using `AllReduceFusionPattern` (defined at line 690) fails because compilation never reaches that point + +#### Installing cubins Does NOT Fix This + +Even with `flashinfer-cubin==0.5.2` installed, the issue persists because: +- Cubins provide pre-compiled kernels for specific operations +- The JIT compilation wrapper still needs to compile code that includes the broken header +- The header file bug affects all code paths that include it + +#### Code Path + +``` +vllm/config/compilation.py + -> PassConfig(enable_fi_allreduce_fusion=True) + -> vllm/compilation/pass_manager.py:103 + -> AllReduceFusionPass(config) + -> vllm/compilation/collective_fusion.py + -> call_trtllm_fused_allreduce_norm() + -> flashinfer.comm.trtllm_allreduce_fusion() + -> JIT compilation of wrapper code + -> #include "flashinfer/comm/trtllm_allreduce_fusion.cuh" + -> COMPILATION FAILS +``` + +#### How to Reproduce + +```python +from vllm import LLM +from vllm.config import CompilationConfig, PassConfig + +# Enable the allreduce fusion pass +pass_config = PassConfig( + enable_fi_allreduce_fusion=True, + enable_noop=True, +) +compilation_config = CompilationConfig(pass_config=pass_config) + +# This will fail during model loading +llm = LLM( + model='Qwen/Qwen3-30B-A3B-Instruct-2507', + tensor_parallel_size=2, + compilation_config=compilation_config, +) +``` + +#### Workaround + +vLLM does NOT auto-enable this feature even when `VLLM_USE_FLASHINFER=1` is set: + +```python +# vllm/envs.py - VLLM_USE_FLASHINFER_ALLREDUCE does NOT fallback to VLLM_USE_FLASHINFER +"VLLM_USE_FLASHINFER_ALLREDUCE": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_ALLREDUCE", "0")) +), +``` + +**Note:** Running vLLM in eager mode (`enforce_eager=True`) does NOT help avoid this issue. The `enable_fi_allreduce_fusion` pass modifies the model's forward pass in a way that expects torch.compile to perform the actual fusion. Without compilation, the modified code path breaks. + +#### Fix Applied (November 2025) + +**Status:** ✅ **FIXED** - Patch available + +The bug has been fixed by changing all occurrences of `std::optional` and `std::nullopt` to their CUDA namespace equivalents. 
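+
+To check whether a given FlashInfer installation still needs the patch, the header can be scanned for unqualified `std::optional` / `std::nullopt` usages. The snippet below is a minimal sketch; it assumes a PyPI wheel layout, where the header ships under `flashinfer/data/include/` (the same path the installation script below modifies):
+
+```python
+import os
+import re
+
+import flashinfer
+
+# Locate the header inside the installed flashinfer package.
+header = os.path.join(
+    os.path.dirname(flashinfer.__file__),
+    "data", "include", "flashinfer", "comm", "trtllm_allreduce_fusion.cuh",
+)
+with open(header) as f:
+    text = f.read()
+
+# Count std::optional / std::nullopt usages that are not already cuda::std::...
+bad = re.findall(r"(?<!cuda::)std::(optional|nullopt)", text)
+if bad:
+    print(f"Header still needs the fix ({len(bad)} unqualified std:: usages)")
+else:
+    print("Header looks patched (all usages are cuda::std::)")
+```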
+ +**Complete Patch:** + +```diff +--- a/data/include/flashinfer/comm/trtllm_allreduce_fusion.cuh ++++ b/data/include/flashinfer/comm/trtllm_allreduce_fusion.cuh +@@ -443,8 +443,8 @@ inline int getSMRegisters() { + return regs_per_block; + } + +-inline __device__ int64_t get_sf_out_offset_128x4(std::optional batchIdx, int mIdx, int kIdx, +- std::optional numRows, int numCols) { ++inline __device__ int64_t get_sf_out_offset_128x4(cuda::std::optional batchIdx, int mIdx, int kIdx, ++ cuda::std::optional numRows, int numCols) { + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + +@@ -484,8 +484,8 @@ inline __device__ int64_t get_sf_out_offset_128x4(std::optional batchIdx, i + } + + template +-__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional batchIdx, int rowIdx, +- int colIdx, std::optional numRows, ++__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(cuda::std::optional batchIdx, int rowIdx, ++ int colIdx, cuda::std::optional numRows, + int numCols, SFType* SFout, + QuantizationSFLayout layout) { + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +@@ -949,7 +949,7 @@ + if constexpr (GetQuantType == QuantType::kFP4) { + // NOTE(Yingyi): might update later + auto sf_out = utils::cvt_quant_to_fp4_get_sf_out_offset( +- std::nullopt /* batchIdx */, token_id, m_access_id_in_token, std::nullopt /* numRows */, ++ cuda::std::nullopt /* batchIdx */, token_id, m_access_id_in_token, cuda::std::nullopt /* numRows */, + m_params.hidden_dim, reinterpret_cast(m_params.scale_out), m_params.layout); + reinterpret_cast(m_params.quant_out)[m_access_id] = + utils::cvt_warp_fp16_to_fp4(val, m_scale_factor, sf_out); +``` + +**Installation Script:** + +An automated installation script is available that installs FlashInfer 0.5.2 from PyPI and applies the fix: + +```bash +#!/bin/bash +# install_flashinfer_with_fix.sh +# Installs FlashInfer 0.5.2 with AllReduce Fusion fix applied + +pip install flashinfer-python==0.5.2 + +# Find installation path +FLASHINFER_PATH=$(python -c "import flashinfer, os; print(os.path.dirname(flashinfer.__file__))") +TARGET_FILE="$FLASHINFER_PATH/data/include/flashinfer/comm/trtllm_allreduce_fusion.cuh" + +# Apply fixes +sed -i 's/\bstd::optional batchIdx/cuda::std::optional batchIdx/g' "$TARGET_FILE" +sed -i 's/\bstd::optional numRows/cuda::std::optional numRows/g' "$TARGET_FILE" +sed -i 's/std::nullopt /cuda::std::nullopt /g' "$TARGET_FILE" + +echo "✓ Fix applied successfully" +``` + +**Locations Fixed:** +- Line 446-447: Function `get_sf_out_offset_128x4` signature +- Line 487-488: Function `cvt_quant_to_fp4_get_sf_out_offset` signature +- Line 952: Function call with `nullopt` arguments + +**Verification:** + +```python +# Test that the fix works +from flashinfer.comm import trtllm_allreduce_fusion +print("✓ AllReduce fusion module imports successfully") + +# Run with vLLM +from vllm.config import CompilationConfig, PassConfig +pass_config = PassConfig(enable_fi_allreduce_fusion=True) +compilation_config = CompilationConfig(pass_config=pass_config) +# JIT compilation will now succeed +``` + +**Testing:** + +```bash +# Test with AllReduce fusion enabled (automatically enabled for TP >= 2) +VLLM_USE_FLASHINFER=1 python tests/kernels/run_flashinfer_test.py \ + --model qwen --tp 2 + +# Verify the fix was applied +python tests/kernels/test_allreduce_fusion_fix.py +``` + +#### Fix Required in FlashInfer (Upstream) + +**Option A (Preferred):** Change `std::optional` to 
`cuda::std::optional` in the header:
+
+```cpp
+// Line 487-489 - fix namespace
+template <class SFType, ...>
+__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
+    cuda::std::optional<int> batchIdx,  // FIXED
+    int rowIdx,
+    int colIdx,
+    cuda::std::optional<int> numRows,   // FIXED
+    ...
+)
+```
+
+**Option B:** Add a using declaration at the top of the file:
+
+```cpp
+#include <cuda/std/optional>
+// Add this line:
+namespace std { using cuda::std::optional; }
+```
+
+**Option C:** Include the standard C++17 `<optional>` header and ensure compilation with `-std=c++17`:
+
+```cpp
+#include <optional>  // C++17 standard library
+```
+
+---
+
+### Issue 2: FP8 MoE CUDA Version Mismatch
+
+**Severity:** Medium
+**Affected versions:** FlashInfer 0.5.2, 0.5.3 (both compiled against CUDA ≤12.6)
+**Affected models:** Qwen3-30B-A3B-Instruct-2507-FP8, other FP8 MoE models
+**Environment variable:** `VLLM_USE_FLASHINFER_MOE_FP8`
+
+#### Symptom
+
+```python
+NotImplementedError: FP8 block scaling not implemented for CUDA 12.6 or lower.
+```
+
+#### Root Cause
+
+FlashInfer's FP8 MoE kernel requires CUDA 12.7+, but both the FlashInfer 0.5.2 and 0.5.3 packages were compiled against CUDA ≤12.6. Even though the host system has CUDA 12.8, FlashInfer checks its own compile-time CUDA version.
+
+**Tested:** The issue persists in both FlashInfer 0.5.2 and 0.5.3.
+
+#### Code Path
+
+```
+vllm/model_executor/layers/fused_moe/fused_moe.py
+  -> FusedMoE.forward_impl()
+    -> FlashInfer FP8 MoE kernel
+      -> Runtime check fails
+```
+
+#### Detection Methods
+
+Different ways of querying the CUDA version report different results:
+- `torch.cuda.runtime_version()` returns 12080 (12.8.0)
+- `torch.version.cuda` returns "12.8"
+- FlashInfer's internal check reports CUDA 12.6 (its compile-time version)
+
+#### Workaround
+
+Skip FP8 MoE tests on systems where FlashInfer was compiled against CUDA < 12.7.
+
+#### Recommendation for FlashInfer Team
+
+1. Clearly document the minimum CUDA version requirement for FP8 MoE
+2. Consider providing wheels compiled against different CUDA versions
+3. 
Improve the error message to indicate it's a compile-time vs runtime version mismatch + +--- + +### Issue 3: MXFP4 MoE SM90 Backend Broken + +**Severity:** High +**Affected versions:** FlashInfer 0.5.2, 0.5.3 (SM90 kernels not implemented) +**Affected models:** GPT-OSS-120B and other MXFP4 models on Hopper +**Environment variable:** `VLLM_USE_FLASHINFER_MOE_MXFP4_BF16` + +#### Symptom + +``` +File "flashinfer/data/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu", line 214 +RuntimeError: Check failed: (false) is false: Could not construct fused moe op with the requested +input combination Activation: bfloat16, Weight: uint8, Output: bfloat16 +``` + +#### Root Cause + +vLLM's `mxfp4.py` has a backend called `SM90_FI_MXFP4_BF16` that claims to support FlashInfer MXFP4 MoE on Hopper (SM90): + +```python +# vllm/model_executor/layers/quantization/mxfp4.py:96-102 +if ( + current_platform.is_device_capability(90) # Hopper + and has_flashinfer() + and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 +): + logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") + return Mxfp4Backend.SM90_FI_MXFP4_BF16 +``` + +However, when this backend is selected, the code path leads to `flashinfer_cutlass_fused_moe` which only has SM100 implementations: +- The CUDA file is literally named `flashinfer_cutlass_fused_moe_sm100_binding.cu` +- There is no corresponding `sm90_binding.cu` file + +#### Code Path + +``` +vllm/model_executor/layers/quantization/mxfp4.py + -> get_mxfp4_backend() returns SM90_FI_MXFP4_BF16 + -> Mxfp4MoEMethod.apply() + -> flashinfer_cutlass_fused_moe() + -> flashinfer/fused_moe/core.py:cutlass_fused_moe() + -> get_cutlass_fused_moe_module(device_arch) + -> Loads sm100_binding.cu on SM90 hardware + -> Kernel initialization fails +``` + +#### Workaround + +Disable FlashInfer MXFP4 MoE on Hopper to fall back to Marlin/Triton: + +```bash +export VLLM_USE_FLASHINFER_MOE_MXFP4_BF16=0 +export VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=0 +export VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=0 +``` + +#### Recommendation for FlashInfer Team + +1. **Option A:** Implement SM90 CUTLASS kernels for MXFP4 MoE +2. **Option B:** Remove/gate the SM90 code path in vLLM if FlashInfer doesn't plan to support it +3. Add architecture detection in `get_cutlass_fused_moe_module()` to fail gracefully on unsupported architectures + +--- + +### Issue 4: Attention Sinks Not Supported on Hopper + +**Severity:** Medium +**Affected versions:** FlashInfer 0.5.2, 0.5.3 (requires TRTLLM on SM100) +**Affected models:** GPT-OSS-120B and other models using attention sinks +**File:** `vllm/v1/attention/backends/flashinfer.py` + +#### Symptom + +``` +ValueError: Selected backend AttentionBackendEnum.FLASHINFER is not valid for this configuration. 
+Reason: ['sink setting not supported'] +``` + +#### Root Cause + +FlashInfer's attention sink support requires TRTLLM attention, which is only available on Blackwell (SM100): + +```python +# vllm/v1/attention/backends/flashinfer.py:353-366 +@classmethod +def supports_sink(cls) -> bool: + """FlashInfer supports sinks when TRTLLM attention is available (SM100).""" + from vllm.utils.flashinfer import ( + force_use_trtllm_attention, + supports_trtllm_attention, + ) + if force_use_trtllm_attention() is False: + return False + return supports_trtllm_attention() + +# vllm/utils/flashinfer.py:257-267 +def supports_trtllm_attention() -> bool: + if vllm_is_batch_invariant(): + return False + # Requires SM100 and NVIDIA artifactory to be accessible + return current_platform.is_device_capability(100) and has_nvidia_artifactory() +``` + +#### Workaround + +Use FlashAttention3 (`FLASH_ATTN`) for attention on Hopper when using models with attention sinks: + +```bash +export VLLM_ATTENTION_BACKEND=FLASH_ATTN +``` + +#### Recommendation for FlashInfer Team + +1. Consider implementing attention sink support for SM90 without requiring TRTLLM +2. Document that sink support requires SM100 + TRTLLM + +--- + +## Skipped Tests + +### Test: Qwen3-30B-A3B-Instruct-2507-FP8 + +**Reason:** FlashInfer FP8 MoE requires CUDA 12.7+, but FlashInfer 0.5.2 and 0.5.3 were both compiled against CUDA ≤12.6 + +```python +# tests/kernels/run_flashinfer_test.py +if model_name == "qwen": + print(f"\n⚠ Skipping {result_name}: FlashInfer FP8 MoE kernel requires " + "CUDA 12.7+ (FlashInfer package compiled against older CUDA)") + results[result_name] = RESULT_SKIPPED + continue +``` + +### Test: GPT-OSS-120B on Hopper (partial skip) + +**Reason:** FlashInfer MXFP4 MoE and FlashInfer attention (with sinks) not supported on SM90 + +The test runs but uses fallback backends: +- Attention: FlashAttention3 instead of FlashInfer +- MoE: Triton/Marlin instead of FlashInfer CUTLASS + +--- + +## Environment Variables Reference + +### Master Switch + +| Variable | Description | Default | +|----------|-------------|---------| +| `VLLM_USE_FLASHINFER` | Enable FlashInfer for all supported operators | `0` | + +### Individual Feature Flags + +| Variable | Description | Auto-enabled by master? 
| Notes | +|----------|-------------|------------------------|-------| +| `VLLM_ATTENTION_BACKEND` | Attention backend selection | Yes → `FLASHINFER` | | +| `VLLM_USE_FLASHINFER_SAMPLER` | Top-k/Top-p sampling | Yes | | +| `VLLM_USE_FLASHINFER_NORM` | RMSNorm | Yes | | +| `VLLM_USE_FLASHINFER_ACTIVATION` | SiLU/GELU activations | Yes | | +| `VLLM_USE_FLASHINFER_MOE_FP16` | FP16/BF16 MoE | Yes | | +| `VLLM_USE_FLASHINFER_MOE_FP8` | FP8 MoE | Yes | Requires CUDA 12.7+ | +| `VLLM_USE_FLASHINFER_MOE_FP4` | FP4 MoE (NVFP4) | Yes | | +| `VLLM_USE_FLASHINFER_MOE_MXFP4_BF16` | MXFP4 MoE with BF16 activation | Yes | **Broken on SM90** | +| `VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8` | MXFP4 MoE with MXFP8 activation | Yes | SM100 only | +| `VLLM_USE_FLASHINFER_ALLREDUCE` | AllReduce fusion | **Yes (auto-enabled for TP≥2)** | ✅ **Fixed (Nov 2025)** | +| `VLLM_ALL2ALL_BACKEND` | All2All communication backend | Yes → `flashinfer_all2allv` | | + +--- + +## Missing Operators for Complete FlashInfer Backend + +For vLLM to use FlashInfer as a **complete backend** without relying on vLLM's custom CUDA kernels, FlashInfer would need to implement the following operators: + +### Critical (Required for basic inference) + +| Operator | Current Status | vLLM Implementation | +|----------|---------------|---------------------| +| **Embedding** | ❌ Not available | `vllm/model_executor/layers/vocab_parallel_embedding.py` uses PyTorch's `F.embedding` | +| **Dense Linear (FP16/BF16)** | ⚠️ Partial | FlashInfer has FP8/FP4 GEMM but standard linear uses cuBLAS/Triton | +| **LayerNorm** | ✅ Available | `flashinfer.norm.layernorm` (in addition to RMSNorm) | +| **Softmax** | ✅ Available | `flashinfer.sampling.softmax` | +| **RoPE** | ✅ Available | `flashinfer.rope.*` | + +### Quantization Kernels (Required for quantized models) + +| Quantization | Current Status | Notes | +|--------------|---------------|-------| +| **AWQ** | ❌ Not available | vLLM uses Marlin/custom kernels | +| **GPTQ** | ❌ Not available | vLLM uses Marlin/ExLlama/custom kernels | +| **GGUF** | ❌ Not available | vLLM uses custom kernels | +| **Marlin** | ❌ Not available | vLLM's optimized W4A16 kernels | +| **FP8 (compressed-tensors)** | ⚠️ Partial | FlashInfer has FP8 GEMM but not all compression schemes | +| **W8A8** | ⚠️ Partial | FlashInfer has `mm_fp8` but not CUTLASS W8A8 | + +### Architecture-Specific Kernels + +| Kernel | Current Status | Used By | +|--------|---------------|---------| +| **MLA Attention** | ✅ Available | DeepSeek-V2, DeepSeek-V3 (`flashinfer.mla`, `flashinfer.xqa_mla`) | +| **Mamba/SSM** | ❌ Not available | Mamba, Jamba, Zamba (`vllm/model_executor/layers/mamba/ops/`) | +| **GDN Attention** | ❌ Not available | Specialized attention variants | +| **KDA Attention** | ❌ Not available | Key-value decomposed attention | +| **Tree Attention** | ❌ Not available | Speculative decoding | + +### Communication Kernels + +| Kernel | Current Status | Notes | +|--------|---------------|-------| +| **AllReduce** | ✅ **Fixed (Nov 2025)** | `flashinfer.comm.trtllm_allreduce_fusion` - patch applied | +| **All2All** | ✅ Available | `flashinfer_all2allv` backend works | +| **Custom AllReduce (one-shot)** | ⚠️ Partial | FlashInfer has `trtllm_custom_all_reduce` | +| **Quick AllReduce** | ❌ Not available | vLLM's `quick_all_reduce` is separate | + +### LoRA Kernels + +| Kernel | Current Status | Notes | +|--------|---------------|-------| +| **Punica BGMV** | ❌ Not available | Batched GEMV for LoRA (`vllm/lora/punica_wrapper/`) | +| **Punica SGMV** | ❌ 
Not available | Segmented GEMV for LoRA | +| **LoRA expand/shrink** | ❌ Not available | Used in multi-LoRA inference | + +### KV Cache Operations + +| Operation | Current Status | Notes | +|-----------|---------------|-------| +| **Append paged KV cache** | ✅ Available | `flashinfer.append_paged_kv_cache` | +| **Append MLA KV cache** | ✅ Available | `flashinfer.append_paged_mla_kv_cache` | +| **Copy blocks** | ❌ Not available | vLLM uses custom kernel for block copying | +| **Reshape/swap blocks** | ❌ Not available | vLLM uses custom kernels | + +### Miscellaneous + +| Kernel | Current Status | Notes | +|--------|---------------|-------| +| **Fused cross-entropy loss** | ❌ Not available | Used for training/fine-tuning | +| **Rotary embedding (batched)** | ✅ Available | `flashinfer.rope.*` | +| **Speculative sampling** | ✅ Available | `flashinfer.chain_speculative_sampling` | +| **Top-k/Top-p sampling** | ✅ Available | `flashinfer.sampling.*` | +| **Min-p sampling** | ✅ Available | `flashinfer.min_p_sampling_from_probs` | + +### Summary: What's Needed for Full FlashInfer Backend + +To run vLLM entirely on FlashInfer kernels (no vLLM custom ops), FlashInfer would need: + +1. **Embedding kernels** (vocab-parallel embedding lookup) +2. **Dense linear (non-quantized)** - Or defer to cuBLAS, which is current behavior +3. **AWQ/GPTQ/Marlin quantization kernels** for W4A16 models +4. **Mamba/SSM kernels** for state-space models +5. **LoRA kernels** (Punica BGMV/SGMV) for multi-adapter inference +6. **KV cache block copy/swap** kernels +7. ~~**Fix AllReduce fusion JIT bugs**~~ ✅ **FIXED (Nov 2025)** + +**Note:** Many of these are specialized kernels that may not make sense for FlashInfer's scope. The goal is not necessarily to replace everything, but to identify gaps for users who want maximum FlashInfer utilization. + +--- + +## Recommendations for FlashInfer Team + +### Priority 1: ~~Fix AllReduce Fusion JIT (C++ Namespace Bug)~~ ✅ FIXED + +**Status:** ✅ **FIXED (November 2025)** - Patch available above in Issue 1 section. + +The `trtllm_allreduce_fusion` kernel JIT compilation failure has been fixed. See the complete patch and installation instructions in the [Issue 1](#issue-1-allreduce-fusion-jit-compilation-failure) section above. + +### Priority 2: MXFP4 MoE SM90 Support + +Either implement SM90 CUTLASS kernels or clearly document that MXFP4 MoE is Blackwell-only. Currently vLLM claims SM90 support but it doesn't work. + +**Suggested actions:** +- Add `flashinfer_cutlass_fused_moe_sm90_binding.cu` implementation, OR +- Return an error from `get_cutlass_fused_moe_module()` for SM90, OR +- Coordinate with vLLM team to remove the `SM90_FI_MXFP4_BF16` backend + +### Priority 3: FP8 MoE CUDA Version + +Provide clearer error messages distinguishing between: +- Compile-time CUDA version of the FlashInfer package +- Runtime CUDA version of the system + +Consider shipping multiple wheel variants for different CUDA versions. + +### Priority 4: Attention Sinks on Hopper + +Consider implementing attention sink support that doesn't require TRTLLM, enabling FlashInfer attention for models like GPT-OSS on Hopper hardware. + +### Priority 5: FlashInfer-Bench Tracing with torch.compile + +**Status:** ❌ Not working - requires `enforce_eager=True` + +FlashInfer-Bench tracing currently only works when vLLM runs in eager mode (`enforce_eager=True`). When torch.compile is enabled: + +1. 
vLLM uses `custom_ops: ['none']` by default, which disables custom ops and uses `forward_native` (pure PyTorch) instead of `forward_cuda` (which calls FlashInfer) +2. Even if custom ops are enabled, the FlashInfer-Bench adapter wrappers are traced by torch.compile +3. During tracing, lazy imports trigger `find_spec` which torch.dynamo marks as "skipped" +4. This causes `torch._dynamo.exc.Unsupported` errors + +**Workaround:** Use `enforce_eager=True` when generating traces: + +```python +llm = LLM( + model="...", + enforce_eager=True, # Required for FlashInfer-Bench tracing +) +``` + +**Suggested actions for FlashInfer-Bench team:** +- Investigate using `torch.compiler.allow_in_graph` or custom graph break handling +- Consider tracing at the CUDA kernel level rather than Python wrapper level +- Add support for torch.compile by pre-importing all dependencies at module load time + +--- + +## Test Script + +A test script is available at `tests/kernels/run_flashinfer_test.py` to verify FlashInfer integration. + +**Note:** The test script automatically sets `VLLM_USE_FLASHINFER=1` and enables AllReduce fusion for TP >= 2. + +```bash +# Test all FlashInfer features (FlashInfer automatically enabled) +python tests/kernels/run_flashinfer_test.py --model all + +# Test specific model +python tests/kernels/run_flashinfer_test.py --model qwen +python tests/kernels/run_flashinfer_test.py --model llama +python tests/kernels/run_flashinfer_test.py --model gpt-oss + +# Test with FP8 quantization +python tests/kernels/run_flashinfer_test.py --model llama --fp8 + +# Test with specific TP size (AllReduce auto-enabled for TP >= 2) +python tests/kernels/run_flashinfer_test.py --model qwen --tp 2 +``` + +--- + +*Document last updated: November 29, 2025* +*vLLM version: 0.8.3rc2.dev6048* +*FlashInfer versions tested: 0.5.2, 0.5.3* + +**Updates:** +- **November 29, 2025**: AllReduce Fusion bug fixed - patch and installation script added +- **November 28, 2025**: Initial document created + diff --git a/tests/kernels/generate_flashinfer_traces.py b/tests/kernels/generate_flashinfer_traces.py new file mode 100644 index 000000000000..94ec2aff708d --- /dev/null +++ b/tests/kernels/generate_flashinfer_traces.py @@ -0,0 +1,1279 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Generate FlashInfer Bench traces from vLLM inference. + +This script runs inference on various models with FlashInfer enabled and +captures workload traces using FlashInfer-Bench. These traces can then be +used to optimize FlashInfer operations with custom kernels. + +The script automatically creates Definition files for vLLM's FlashInfer operations +based on the model's configuration (hidden sizes, head dims, etc.). 
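Each Definition is written to a JSON file under <output_dir>/definitions/, and the output
+directory is also used as the FlashInfer-Bench dataset path (FIB_DATASET_PATH) for the
+captured traces.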
+ +Traced operators: +- Attention: + - GQA Paged Decode (attention decode with paged KV cache) + - GQA Paged Prefill (attention prefill with paged KV cache) + - GQA Ragged Prefill (attention prefill with ragged tensors) + - MLA Paged (Multi-head Latent Attention for DeepSeek models) +- Normalization: + - RMSNorm (fused_add_rmsnorm) +- Sampling: + - Top-k sampling + - Top-p sampling + - Top-k + Top-p sampling +- Activations: + - SiLU and mul + - GELU and mul + - GELU tanh and mul +- Positional Encoding: + - RoPE (apply_rope_with_cos_sin_cache_inplace) +- MoE: + - CUTLASS fused MoE + - TRT-LLM FP8 MoE + - TRT-LLM FP4 MoE +- GEMM (for quantized models): + - mm_fp4 (FP4 matrix multiplication) + - bmm_fp8 (FP8 batched matrix multiplication) + - grouped_gemm_nt_masked (MoE CUTEDSL grouped GEMM) +- Communication: + - AllReduce fusion + - MnnvlMoe dispatch/combine (All2All) + +NOTE: Tracing requires `enforce_eager=True` because FlashInfer-Bench adapters +are not compatible with torch.compile. See flashinfer_integration_issues.md +for details. + +Reference: https://bench.flashinfer.ai/docs/start/quickstart + +Installation: + pip install flashinfer-bench --no-deps + +Usage: + python tests/kernels/generate_flashinfer_traces.py --model qwen + python tests/kernels/generate_flashinfer_traces.py --model llama + python tests/kernels/generate_flashinfer_traces.py --model all --output-dir /path/to/traces +""" + +import argparse +import json +import os +import sys +from pathlib import Path + + +def check_compatibility(): + """Check if the environment is compatible for trace generation. + + Returns: + tuple: (is_compatible, error_message) + """ + # Check torch version + try: + import torch + torch_version = torch.__version__ + except ImportError: + return False, "torch is not installed" + + # Check if flashinfer-bench is installed + try: + import importlib.util + spec = importlib.util.find_spec("flashinfer_bench") + if spec is None: + return False, ( + "flashinfer-bench is not installed.\n" + "Install with: pip install flashinfer-bench --no-deps" + ) + except Exception as e: + return False, f"Error checking flashinfer-bench: {e}" + + # Check if vLLM can be imported + try: + import vllm # noqa: F401 + except Exception as e: + return False, f"vLLM import failed: {e}" + + return True, f"Compatible (torch {torch_version})" + + +def create_rmsnorm_definitions(hidden_sizes: list[int], output_dir: str): + """Create RMSNorm definitions for given hidden sizes.""" + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + for hidden_size in hidden_sizes: + def_name = f"fused_add_rmsnorm_h{hidden_size}" + + definition = Definition( + name=def_name, + op_type="rmsnorm", + axes={"M": AxisVar(), "H": AxisConst(value=hidden_size)}, + inputs={ + "hidden_states": TensorSpec(shape=["M", "H"], dtype="bfloat16"), + "residual": TensorSpec(shape=["M", "H"], dtype="bfloat16"), + "weight": TensorSpec(shape=["H"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["M", "H"], dtype="bfloat16")}, + reference="def run(hidden_states, residual, weight):\n return hidden_states\n", + ) + + definitions[def_name] = definition + + # Save to file + def_file = definitions_dir / f"{def_name}.json" + with open(def_file, 'w') as f: + json.dump(definition.model_dump(), f, indent=2) + + return definitions + + +def create_attention_definitions( + num_heads: int, + num_kv_heads: int, 
+ head_dim: int, + output_dir: str, + page_size: int = 16, +): + """Create attention definitions for GQA paged decode/prefill. + + Note: For paged attention, the k_cache/v_cache have shape [num_pages, num_kv_heads, head_dim] + after normalization (NHD layout). The batch dimension B is separate from the cache dimension. + """ + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + # GQA Paged Decode definition + # q: [B, H, D] - batch of queries + # k_cache/v_cache: [N, KV, D] - paged KV cache (N = total pages, independent of B) + # kv_indptr: [B1] - indptr for batch (length = batch_size + 1) + # kv_indices: [I] - indices into cache + decode_def_name = f"gqa_paged_decode_h{num_heads}_kv{num_kv_heads}_d{head_dim}_ps{page_size}" + decode_definition = Definition( + name=decode_def_name, + op_type="gqa_paged_decode", + axes={ + "B": AxisVar(), # Batch size (queries) + "B1": AxisVar(), # Batch size + 1 (for indptr) + "N": AxisVar(), # Number of cache pages (independent of B) + "I": AxisVar(), # Number of indices + "H": AxisConst(value=num_heads), + "KV": AxisConst(value=num_kv_heads), + "D": AxisConst(value=head_dim), + }, + inputs={ + "q": TensorSpec(shape=["B", "H", "D"], dtype="bfloat16"), + "k_cache": TensorSpec(shape=["N", "KV", "D"], dtype="bfloat16"), + "v_cache": TensorSpec(shape=["N", "KV", "D"], dtype="bfloat16"), + "kv_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "kv_indices": TensorSpec(shape=["I"], dtype="int32"), + "sm_scale": TensorSpec(shape=[], dtype="float32"), + "page_size": TensorSpec(shape=[], dtype="int32"), + }, + outputs={"output": TensorSpec(shape=["B", "H", "D"], dtype="bfloat16")}, + reference="def run(q, k_cache, v_cache, kv_indptr, kv_indices, sm_scale, page_size):\n return q\n", + ) + definitions[decode_def_name] = decode_definition + + # Save decode definition + def_file = definitions_dir / f"{decode_def_name}.json" + with open(def_file, 'w') as f: + json.dump(decode_definition.model_dump(), f, indent=2) + + # GQA Paged Prefill (causal) definition + # q: [T, H, D] - total tokens across batch + # k_cache/v_cache: [N, KV, D] - paged KV cache + # qo_indptr: [B1] - indptr for query batch (length = batch_size + 1) + # kv_indptr: [B1] - indptr for KV batch (length = batch_size + 1) + # kv_indices: [I] - indices into cache + prefill_def_name = f"gqa_paged_prefill_causal_h{num_heads}_kv{num_kv_heads}_d{head_dim}_ps{page_size}" + prefill_definition = Definition( + name=prefill_def_name, + op_type="gqa_paged_prefill", + axes={ + "T": AxisVar(), # Total tokens + "B1": AxisVar(), # Batch size + 1 (for indptr) + "N": AxisVar(), # Number of cache pages + "I": AxisVar(), # Number of indices + "H": AxisConst(value=num_heads), + "KV": AxisConst(value=num_kv_heads), + "D": AxisConst(value=head_dim), + }, + inputs={ + "q": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16"), + "k_cache": TensorSpec(shape=["N", "KV", "D"], dtype="bfloat16"), + "v_cache": TensorSpec(shape=["N", "KV", "D"], dtype="bfloat16"), + "qo_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "kv_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "kv_indices": TensorSpec(shape=["I"], dtype="int32"), + "sm_scale": TensorSpec(shape=[], dtype="float32"), + "page_size": TensorSpec(shape=[], dtype="int32"), + "causal": TensorSpec(shape=[], dtype="bool"), + }, + outputs={"output": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16")}, + reference="def 
run(q, k_cache, v_cache, qo_indptr, kv_indptr, kv_indices, sm_scale, page_size, causal):\n return q\n", + ) + definitions[prefill_def_name] = prefill_definition + + # Save prefill definition + def_file = definitions_dir / f"{prefill_def_name}.json" + with open(def_file, 'w') as f: + json.dump(prefill_definition.model_dump(), f, indent=2) + + # GQA Ragged Prefill (causal) definition + # q: [T, H, D] - total tokens + # k: [T, KV, D] - keys (same token count as q) + # v: [T, KV, D] - values (same token count as q) + # qo_indptr: [B1] - indptr for query batch (length = batch_size + 1) + # kv_indptr: [B1] - indptr for KV batch (length = batch_size + 1) + ragged_def_name = f"gqa_ragged_prefill_causal_h{num_heads}_kv{num_kv_heads}_d{head_dim}" + ragged_definition = Definition( + name=ragged_def_name, + op_type="gqa_ragged_prefill", + axes={ + "T": AxisVar(), # Total tokens + "B1": AxisVar(), # Batch size + 1 (for indptr) + "H": AxisConst(value=num_heads), + "KV": AxisConst(value=num_kv_heads), + "D": AxisConst(value=head_dim), + }, + inputs={ + "q": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16"), + "k": TensorSpec(shape=["T", "KV", "D"], dtype="bfloat16"), + "v": TensorSpec(shape=["T", "KV", "D"], dtype="bfloat16"), + "qo_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "kv_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "sm_scale": TensorSpec(shape=[], dtype="float32"), + "causal": TensorSpec(shape=[], dtype="bool"), + }, + outputs={"output": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16")}, + reference="def run(q, k, v, qo_indptr, kv_indptr, sm_scale, causal):\n return q\n", + ) + definitions[ragged_def_name] = ragged_definition + + # Save ragged definition + def_file = definitions_dir / f"{ragged_def_name}.json" + with open(def_file, 'w') as f: + json.dump(ragged_definition.model_dump(), f, indent=2) + + return definitions + + +def create_mla_definitions( + num_heads: int, + kv_lora_rank: int, + qk_rope_head_dim: int, + output_dir: str, + page_size: int = 16, +): + """Create MLA attention definitions for DeepSeek models. + + Note: For MLA paged attention, the cache tensors have shape [N, CKV] or [N, KPE] + where N is the total number of cache entries (independent of batch size). 
+ """ + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + # MLA Paged Decode definition + # q_nope: [B, H, CKV] - batch of queries (non-positional) + # q_pe: [B, H, KPE] - batch of queries (positional) + # ckv_cache: [N, CKV] - paged compressed KV cache + # kpe_cache: [N, KPE] - paged key positional encoding cache + # kv_indptr: [B1] - indptr for batch (length = batch_size + 1) + decode_def_name = f"mla_paged_decode_h{num_heads}_ckv{kv_lora_rank}_kpe{qk_rope_head_dim}_ps{page_size}" + decode_definition = Definition( + name=decode_def_name, + op_type="mla_paged_decode", + axes={ + "B": AxisVar(), # Batch size + "B1": AxisVar(), # Batch size + 1 (for indptr) + "N": AxisVar(), # Number of cache entries + "I": AxisVar(), # Number of indices + "H": AxisConst(value=num_heads), + "CKV": AxisConst(value=kv_lora_rank), + "KPE": AxisConst(value=qk_rope_head_dim), + }, + inputs={ + "q_nope": TensorSpec(shape=["B", "H", "CKV"], dtype="bfloat16"), + "q_pe": TensorSpec(shape=["B", "H", "KPE"], dtype="bfloat16"), + "ckv_cache": TensorSpec(shape=["N", "CKV"], dtype="bfloat16"), + "kpe_cache": TensorSpec(shape=["N", "KPE"], dtype="bfloat16"), + "kv_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "kv_indices": TensorSpec(shape=["I"], dtype="int32"), + "sm_scale": TensorSpec(shape=[], dtype="float32"), + "page_size": TensorSpec(shape=[], dtype="int32"), + }, + outputs={"output": TensorSpec(shape=["B", "H", "CKV"], dtype="bfloat16")}, + reference="def run(q_nope, q_pe, ckv_cache, kpe_cache, kv_indptr, kv_indices, sm_scale, page_size):\n return q_nope\n", + ) + definitions[decode_def_name] = decode_definition + + # Save decode definition + def_file = definitions_dir / f"{decode_def_name}.json" + with open(def_file, 'w') as f: + json.dump(decode_definition.model_dump(), f, indent=2) + + # MLA Paged Prefill (causal) definition + # q_nope: [T, H, CKV] - total tokens queries (non-positional) + # q_pe: [T, H, KPE] - total tokens queries (positional) + # ckv_cache: [N, CKV] - paged compressed KV cache + # kpe_cache: [N, KPE] - paged key positional encoding cache + # qo_indptr: [B1] - indptr for query batch (length = batch_size + 1) + # kv_indptr: [B1] - indptr for KV batch (length = batch_size + 1) + prefill_def_name = f"mla_paged_prefill_causal_h{num_heads}_ckv{kv_lora_rank}_kpe{qk_rope_head_dim}_ps{page_size}" + prefill_definition = Definition( + name=prefill_def_name, + op_type="mla_paged_prefill", + axes={ + "T": AxisVar(), # Total tokens + "B1": AxisVar(), # Batch size + 1 (for indptr) + "N": AxisVar(), # Number of cache entries + "I": AxisVar(), # Number of indices + "H": AxisConst(value=num_heads), + "CKV": AxisConst(value=kv_lora_rank), + "KPE": AxisConst(value=qk_rope_head_dim), + }, + inputs={ + "q_nope": TensorSpec(shape=["T", "H", "CKV"], dtype="bfloat16"), + "q_pe": TensorSpec(shape=["T", "H", "KPE"], dtype="bfloat16"), + "ckv_cache": TensorSpec(shape=["N", "CKV"], dtype="bfloat16"), + "kpe_cache": TensorSpec(shape=["N", "KPE"], dtype="bfloat16"), + "qo_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "kv_indptr": TensorSpec(shape=["B1"], dtype="int32"), + "kv_indices": TensorSpec(shape=["I"], dtype="int32"), + "sm_scale": TensorSpec(shape=[], dtype="float32"), + "page_size": TensorSpec(shape=[], dtype="int32"), + "causal": TensorSpec(shape=[], dtype="bool"), + }, + outputs={"output": TensorSpec(shape=["T", "H", "CKV"], 
dtype="bfloat16")}, + reference="def run(q_nope, q_pe, ckv_cache, kpe_cache, qo_indptr, kv_indptr, kv_indices, sm_scale, page_size, causal):\n return q_nope\n", + ) + definitions[prefill_def_name] = prefill_definition + + # Save prefill definition + def_file = definitions_dir / f"{prefill_def_name}.json" + with open(def_file, 'w') as f: + json.dump(prefill_definition.model_dump(), f, indent=2) + + return definitions + + +def create_sampling_definitions(vocab_size: int, output_dir: str): + """Create sampling definitions for top-k/top-p sampling. + + Note: k and p parameters are not included in definitions as they can be + tensors with per-batch values, not just scalars. + """ + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + # Top-k sampling + topk_def_name = f"top_k_sampling_v{vocab_size}" + topk_definition = Definition( + name=topk_def_name, + op_type="sampling_top_k", + axes={ + "B": AxisVar(), # Batch size + "V": AxisConst(value=vocab_size), + }, + inputs={ + "probs": TensorSpec(shape=["B", "V"], dtype="float32"), + }, + outputs={"output": TensorSpec(shape=["B"], dtype="int64")}, + reference="def run(probs):\n return torch.argmax(probs, dim=-1)\n", + ) + definitions[topk_def_name] = topk_definition + + def_file = definitions_dir / f"{topk_def_name}.json" + with open(def_file, 'w') as f: + json.dump(topk_definition.model_dump(), f, indent=2) + + # Top-p sampling + topp_def_name = f"top_p_sampling_v{vocab_size}" + topp_definition = Definition( + name=topp_def_name, + op_type="sampling_top_p", + axes={ + "B": AxisVar(), + "V": AxisConst(value=vocab_size), + }, + inputs={ + "probs": TensorSpec(shape=["B", "V"], dtype="float32"), + }, + outputs={"output": TensorSpec(shape=["B"], dtype="int64")}, + reference="def run(probs):\n return torch.argmax(probs, dim=-1)\n", + ) + definitions[topp_def_name] = topp_definition + + def_file = definitions_dir / f"{topp_def_name}.json" + with open(def_file, 'w') as f: + json.dump(topp_definition.model_dump(), f, indent=2) + + # Top-k + Top-p sampling + topkp_def_name = f"top_k_top_p_sampling_v{vocab_size}" + topkp_definition = Definition( + name=topkp_def_name, + op_type="sampling_top_k_top_p", + axes={ + "B": AxisVar(), + "V": AxisConst(value=vocab_size), + }, + inputs={ + "logits": TensorSpec(shape=["B", "V"], dtype="float32"), + }, + outputs={"output": TensorSpec(shape=["B"], dtype="int64")}, + reference="def run(logits):\n return torch.argmax(logits, dim=-1)\n", + ) + definitions[topkp_def_name] = topkp_definition + + def_file = definitions_dir / f"{topkp_def_name}.json" + with open(def_file, 'w') as f: + json.dump(topkp_definition.model_dump(), f, indent=2) + + return definitions + + +def create_activation_definitions(hidden_sizes: list[int], output_dir: str): + """Create activation definitions for SiLU, GELU.""" + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + for hidden_size in hidden_sizes: + # SiLU and mul + silu_def_name = f"silu_and_mul_d{hidden_size}" + silu_definition = Definition( + name=silu_def_name, + op_type="activation_silu", + axes={ + "T": AxisVar(), # Total tokens + "D": AxisConst(value=hidden_size), + }, + inputs={ + "input": TensorSpec(shape=["T", "D"], dtype="bfloat16"), + }, + outputs={"output": 
TensorSpec(shape=["T", "D"], dtype="bfloat16")}, + reference="def run(input):\n return input\n", + ) + definitions[silu_def_name] = silu_definition + + def_file = definitions_dir / f"{silu_def_name}.json" + with open(def_file, 'w') as f: + json.dump(silu_definition.model_dump(), f, indent=2) + + # GELU and mul + gelu_def_name = f"gelu_and_mul_d{hidden_size}" + gelu_definition = Definition( + name=gelu_def_name, + op_type="activation_gelu", + axes={ + "T": AxisVar(), + "D": AxisConst(value=hidden_size), + }, + inputs={ + "input": TensorSpec(shape=["T", "D"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["T", "D"], dtype="bfloat16")}, + reference="def run(input):\n return input\n", + ) + definitions[gelu_def_name] = gelu_definition + + def_file = definitions_dir / f"{gelu_def_name}.json" + with open(def_file, 'w') as f: + json.dump(gelu_definition.model_dump(), f, indent=2) + + # GELU tanh and mul + gelu_tanh_def_name = f"gelu_tanh_and_mul_d{hidden_size}" + gelu_tanh_definition = Definition( + name=gelu_tanh_def_name, + op_type="activation_gelu_tanh", + axes={ + "T": AxisVar(), + "D": AxisConst(value=hidden_size), + }, + inputs={ + "input": TensorSpec(shape=["T", "D"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["T", "D"], dtype="bfloat16")}, + reference="def run(input):\n return input\n", + ) + definitions[gelu_tanh_def_name] = gelu_tanh_definition + + def_file = definitions_dir / f"{gelu_tanh_def_name}.json" + with open(def_file, 'w') as f: + json.dump(gelu_tanh_definition.model_dump(), f, indent=2) + + return definitions + + +def create_rope_definitions(head_dims: list[int], output_dir: str): + """Create RoPE definitions.""" + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + for head_dim in head_dims: + for interleave in [False, True]: + interleave_str = "_interleave" if interleave else "" + def_name = f"rope_inplace_h{head_dim}{interleave_str}" + definition = Definition( + name=def_name, + op_type="rope_inplace", + axes={ + "T": AxisVar(), # Total tokens + "H": AxisVar(), # Number of heads + "D": AxisConst(value=head_dim), + "S": AxisVar(), # Max sequence length in cache + }, + inputs={ + "q": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16"), + "k": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16"), + "cos_cache": TensorSpec(shape=["S", "D"], dtype="bfloat16"), + "sin_cache": TensorSpec(shape=["S", "D"], dtype="bfloat16"), + "pos_ids": TensorSpec(shape=["T"], dtype="int64"), + "interleave": TensorSpec(shape=[], dtype="bool"), + }, + outputs={ + "q_out": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16"), + "k_out": TensorSpec(shape=["T", "H", "D"], dtype="bfloat16"), + }, + reference="def run(q, k, cos_cache, sin_cache, pos_ids, interleave):\n return q, k\n", + ) + definitions[def_name] = definition + + def_file = definitions_dir / f"{def_name}.json" + with open(def_file, 'w') as f: + json.dump(definition.model_dump(), f, indent=2) + + return definitions + + +def create_moe_definitions( + num_experts: int, + hidden_size: int, + intermediate_size: int, + topk: int, + output_dir: str, +): + """Create MoE definitions for CUTLASS and TRT-LLM MoE kernels.""" + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + # CUTLASS 
MoE + cutlass_def_name = f"cutlass_moe_e{num_experts}_h{hidden_size}_i{intermediate_size}_k{topk}" + cutlass_definition = Definition( + name=cutlass_def_name, + op_type="fused_moe_cutlass", + axes={ + "T": AxisVar(), # Total tokens + "E": AxisConst(value=num_experts), + "H": AxisConst(value=hidden_size), + "I": AxisConst(value=intermediate_size), + "K": AxisConst(value=topk), + }, + inputs={ + "hidden_states": TensorSpec(shape=["T", "H"], dtype="bfloat16"), + "w1": TensorSpec(shape=["E", "I", "H"], dtype="bfloat16"), + "w2": TensorSpec(shape=["E", "H", "I"], dtype="bfloat16"), + "topk_weights": TensorSpec(shape=["T", "K"], dtype="float32"), + "topk_ids": TensorSpec(shape=["T", "K"], dtype="int32"), + }, + outputs={"output": TensorSpec(shape=["T", "H"], dtype="bfloat16")}, + reference="def run(hidden_states, w1, w2, topk_weights, topk_ids):\n return hidden_states\n", + ) + definitions[cutlass_def_name] = cutlass_definition + + def_file = definitions_dir / f"{cutlass_def_name}.json" + with open(def_file, 'w') as f: + json.dump(cutlass_definition.model_dump(), f, indent=2) + + # TRT-LLM FP8 MoE + fp8_def_name = f"trtllm_fp8_moe_e{num_experts}_h{hidden_size}_i{intermediate_size}_k{topk}" + fp8_definition = Definition( + name=fp8_def_name, + op_type="fused_moe_fp8", + axes={ + "T": AxisVar(), + "E": AxisConst(value=num_experts), + "H": AxisConst(value=hidden_size), + "I": AxisConst(value=intermediate_size), + "K": AxisConst(value=topk), + }, + inputs={ + "hidden_states": TensorSpec(shape=["T", "H"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["T", "H"], dtype="bfloat16")}, + reference="def run(hidden_states):\n return hidden_states\n", + ) + definitions[fp8_def_name] = fp8_definition + + def_file = definitions_dir / f"{fp8_def_name}.json" + with open(def_file, 'w') as f: + json.dump(fp8_definition.model_dump(), f, indent=2) + + # TRT-LLM FP4 MoE + fp4_def_name = f"trtllm_fp4_moe_e{num_experts}_h{hidden_size}" + fp4_definition = Definition( + name=fp4_def_name, + op_type="fused_moe_fp4", + axes={ + "T": AxisVar(), + "E": AxisConst(value=num_experts), + "H": AxisConst(value=hidden_size), + }, + inputs={ + "hidden_states": TensorSpec(shape=["T", "H"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["T", "H"], dtype="bfloat16")}, + reference="def run(hidden_states):\n return hidden_states\n", + ) + definitions[fp4_def_name] = fp4_definition + + def_file = definitions_dir / f"{fp4_def_name}.json" + with open(def_file, 'w') as f: + json.dump(fp4_definition.model_dump(), f, indent=2) + + return definitions + + +def create_gemm_definitions(hidden_size: int, intermediate_size: int, output_dir: str): + """Create GEMM definitions for MLP layers. 
+ + MLP typically uses: + - Gate projection: [hidden_size] -> [intermediate_size] + - Up projection: [hidden_size] -> [intermediate_size] + - Down projection: [intermediate_size] -> [hidden_size] + """ + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + # FP4 GEMM for gate/up projection (hidden -> intermediate) + # Note: FP4 data is packed as int8 (2 FP4 values per byte) + fp4_gate_def_name = f"mm_fp4_k{hidden_size}_n{intermediate_size}" + fp4_gate_definition = Definition( + name=fp4_gate_def_name, + op_type="gemm_fp4", + axes={ + "M": AxisVar(), # Batch/sequence dimension + "K": AxisConst(value=hidden_size // 2), # FP4 packed (2 values per byte) + "N": AxisConst(value=intermediate_size), + }, + inputs={ + "a": TensorSpec(shape=["M", "K"], dtype="int8"), # FP4 packed as int8 + "b": TensorSpec(shape=["N", "K"], dtype="int8"), # FP4 packed, transposed + "a_descale": TensorSpec(shape=["M"], dtype="float32"), + "b_descale": TensorSpec(shape=["N"], dtype="float32"), + }, + outputs={"output": TensorSpec(shape=["M", "N"], dtype="bfloat16")}, + reference="def run(a, b, a_descale, b_descale):\n return torch.zeros(a.shape[0], b.shape[0])\n", + ) + definitions[fp4_gate_def_name] = fp4_gate_definition + + def_file = definitions_dir / f"{fp4_gate_def_name}.json" + with open(def_file, 'w') as f: + json.dump(fp4_gate_definition.model_dump(), f, indent=2) + + # FP4 GEMM for down projection (intermediate -> hidden) + fp4_down_def_name = f"mm_fp4_k{intermediate_size}_n{hidden_size}" + fp4_down_definition = Definition( + name=fp4_down_def_name, + op_type="gemm_fp4", + axes={ + "M": AxisVar(), + "K": AxisConst(value=intermediate_size // 2), # FP4 packed + "N": AxisConst(value=hidden_size), + }, + inputs={ + "a": TensorSpec(shape=["M", "K"], dtype="int8"), # FP4 packed as int8 + "b": TensorSpec(shape=["N", "K"], dtype="int8"), + "a_descale": TensorSpec(shape=["M"], dtype="float32"), + "b_descale": TensorSpec(shape=["N"], dtype="float32"), + }, + outputs={"output": TensorSpec(shape=["M", "N"], dtype="bfloat16")}, + reference="def run(a, b, a_descale, b_descale):\n return torch.zeros(a.shape[0], b.shape[0])\n", + ) + definitions[fp4_down_def_name] = fp4_down_definition + + def_file = definitions_dir / f"{fp4_down_def_name}.json" + with open(def_file, 'w') as f: + json.dump(fp4_down_definition.model_dump(), f, indent=2) + + # FP8 GEMM for gate/up projection + fp8_gate_def_name = f"mm_fp8_k{hidden_size}_n{intermediate_size}" + fp8_gate_definition = Definition( + name=fp8_gate_def_name, + op_type="gemm_fp8", + axes={ + "M": AxisVar(), + "K": AxisConst(value=hidden_size), + "N": AxisConst(value=intermediate_size), + }, + inputs={ + "A": TensorSpec(shape=["M", "K"], dtype="float8_e4m3fn"), + "B": TensorSpec(shape=["K", "N"], dtype="float8_e4m3fn"), + "A_scale": TensorSpec(shape=["M"], dtype="float32"), + "B_scale": TensorSpec(shape=["N"], dtype="float32"), + }, + outputs={"output": TensorSpec(shape=["M", "N"], dtype="bfloat16")}, + reference="def run(A, B, A_scale, B_scale):\n return torch.zeros(A.shape[0], B.shape[1])\n", + ) + definitions[fp8_gate_def_name] = fp8_gate_definition + + def_file = definitions_dir / f"{fp8_gate_def_name}.json" + with open(def_file, 'w') as f: + json.dump(fp8_gate_definition.model_dump(), f, indent=2) + + # FP8 GEMM for down projection + fp8_down_def_name = f"mm_fp8_k{intermediate_size}_n{hidden_size}" + 
fp8_down_definition = Definition( + name=fp8_down_def_name, + op_type="gemm_fp8", + axes={ + "M": AxisVar(), + "K": AxisConst(value=intermediate_size), + "N": AxisConst(value=hidden_size), + }, + inputs={ + "A": TensorSpec(shape=["M", "K"], dtype="float8_e4m3fn"), + "B": TensorSpec(shape=["K", "N"], dtype="float8_e4m3fn"), + "A_scale": TensorSpec(shape=["M"], dtype="float32"), + "B_scale": TensorSpec(shape=["N"], dtype="float32"), + }, + outputs={"output": TensorSpec(shape=["M", "N"], dtype="bfloat16")}, + reference="def run(A, B, A_scale, B_scale):\n return torch.zeros(A.shape[0], B.shape[1])\n", + ) + definitions[fp8_down_def_name] = fp8_down_definition + + def_file = definitions_dir / f"{fp8_down_def_name}.json" + with open(def_file, 'w') as f: + json.dump(fp8_down_definition.model_dump(), f, indent=2) + + return definitions + + +def create_comm_definitions(hidden_sizes: list[int], output_dir: str): + """Create communication definitions for AllReduce.""" + from flashinfer_bench.data import ( + AxisConst, AxisVar, Definition, TensorSpec + ) + + definitions_dir = Path(output_dir) / "definitions" + definitions_dir.mkdir(parents=True, exist_ok=True) + + definitions = {} + + for hidden_size in hidden_sizes: + # AllReduce fusion + allreduce_def_name = f"allreduce_fusion_h{hidden_size}" + allreduce_definition = Definition( + name=allreduce_def_name, + op_type="comm_allreduce", + axes={ + "T": AxisVar(), # Total tokens + "H": AxisConst(value=hidden_size), + }, + inputs={ + "input": TensorSpec(shape=["T", "H"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["T", "H"], dtype="bfloat16")}, + reference="def run(input):\n return input\n", + ) + definitions[allreduce_def_name] = allreduce_definition + + def_file = definitions_dir / f"{allreduce_def_name}.json" + with open(def_file, 'w') as f: + json.dump(allreduce_definition.model_dump(), f, indent=2) + + # MnnvlMoe dispatch + dispatch_def_name = f"mnnvl_moe_dispatch_h{hidden_size}" + dispatch_definition = Definition( + name=dispatch_def_name, + op_type="moe_dispatch", + axes={ + "T": AxisVar(), + "H": AxisConst(value=hidden_size), + }, + inputs={ + "hidden_states": TensorSpec(shape=["T", "H"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["T", "H"], dtype="bfloat16")}, + reference="def run(hidden_states):\n return hidden_states\n", + ) + definitions[dispatch_def_name] = dispatch_definition + + def_file = definitions_dir / f"{dispatch_def_name}.json" + with open(def_file, 'w') as f: + json.dump(dispatch_definition.model_dump(), f, indent=2) + + # MnnvlMoe combine + combine_def_name = f"mnnvl_moe_combine_h{hidden_size}" + combine_definition = Definition( + name=combine_def_name, + op_type="moe_combine", + axes={ + "T": AxisVar(), + "H": AxisConst(value=hidden_size), + }, + inputs={ + "hidden_states": TensorSpec(shape=["T", "H"], dtype="bfloat16"), + }, + outputs={"output": TensorSpec(shape=["T", "H"], dtype="bfloat16")}, + reference="def run(hidden_states):\n return hidden_states\n", + ) + definitions[combine_def_name] = combine_definition + + def_file = definitions_dir / f"{combine_def_name}.json" + with open(def_file, 'w') as f: + json.dump(combine_definition.model_dump(), f, indent=2) + + return definitions + + +def setup_tracing_for_model(model_config, output_dir: str, page_size: int = 16): + """Set up tracing definitions and configs for a specific model. 
+ + Args: + model_config: The model configuration from transformers + output_dir: Directory to store traces + page_size: KV cache page size (default 16, matching vLLM default) + """ + from flashinfer_bench.data import TraceSet + from flashinfer_bench.tracing import TracingConfig, enable_tracing + + all_definitions = {} + + # Determine hidden sizes to trace based on model + hidden_sizes = [] + + # Add model's hidden size + if hasattr(model_config, 'hidden_size'): + hidden_sizes.append(model_config.hidden_size) + + # For MoE models, add intermediate size + if hasattr(model_config, 'intermediate_size'): + hidden_sizes.append(model_config.intermediate_size) + + # Add some common sizes + for size in [2048, 4096, 7168, 8192]: + if size not in hidden_sizes: + hidden_sizes.append(size) + + # Get model dimensions + num_heads = getattr(model_config, 'num_attention_heads', 32) + num_kv_heads = getattr(model_config, 'num_key_value_heads', num_heads) + head_dim = getattr(model_config, 'head_dim', None) + if head_dim is None and hasattr(model_config, 'hidden_size'): + head_dim = model_config.hidden_size // num_heads + vocab_size = getattr(model_config, 'vocab_size', 32000) + intermediate_size = getattr(model_config, 'intermediate_size', 4096) + + # MoE parameters + num_experts = getattr(model_config, 'num_local_experts', None) + if num_experts is None: + num_experts = getattr(model_config, 'num_experts', 8) + topk = getattr(model_config, 'num_experts_per_tok', 2) + + print(f"\n✓ Creating definitions for model with:") + print(f" hidden_sizes: {hidden_sizes}") + print(f" vocab_size: {vocab_size}") + print(f" num_heads: {num_heads}, kv_heads: {num_kv_heads}, head_dim: {head_dim}") + + # 1. RMSNorm definitions + print(f"✓ Creating RMSNorm definitions") + rmsnorm_defs = create_rmsnorm_definitions(hidden_sizes, output_dir) + all_definitions.update(rmsnorm_defs) + + # 2. Attention definitions + if head_dim: + print(f"✓ Creating attention definitions (page_size={page_size})") + attn_defs = create_attention_definitions(num_heads, num_kv_heads, head_dim, output_dir, page_size=page_size) + all_definitions.update(attn_defs) + + # 3. MLA definitions (DeepSeek models) + kv_lora_rank = getattr(model_config, 'kv_lora_rank', None) + qk_rope_head_dim = getattr(model_config, 'qk_rope_head_dim', None) + if kv_lora_rank and qk_rope_head_dim: + print(f"✓ Creating MLA definitions") + mla_defs = create_mla_definitions(num_heads, kv_lora_rank, qk_rope_head_dim, output_dir, page_size=page_size) + all_definitions.update(mla_defs) + + # 4. Sampling definitions + print(f"✓ Creating sampling definitions (vocab_size={vocab_size})") + sampling_defs = create_sampling_definitions(vocab_size, output_dir) + all_definitions.update(sampling_defs) + + # 5. Activation definitions + print(f"✓ Creating activation definitions (SiLU, GELU)") + activation_defs = create_activation_definitions(hidden_sizes, output_dir) + all_definitions.update(activation_defs) + + # 6. RoPE definitions + if head_dim: + head_dims = [head_dim] + # Add common head dims + for hd in [64, 128, 256]: + if hd not in head_dims: + head_dims.append(hd) + print(f"✓ Creating RoPE definitions (head_dims={head_dims})") + rope_defs = create_rope_definitions(head_dims, output_dir) + all_definitions.update(rope_defs) + + # 7. 
MoE definitions (if model has MoE) + if hasattr(model_config, 'num_local_experts') or hasattr(model_config, 'num_experts'): + print(f"✓ Creating MoE definitions (experts={num_experts}, topk={topk})") + moe_defs = create_moe_definitions( + num_experts=num_experts, + hidden_size=model_config.hidden_size, + intermediate_size=intermediate_size, + topk=topk, + output_dir=output_dir, + ) + all_definitions.update(moe_defs) + + # 8. Communication definitions (AllReduce, All2All) + print(f"✓ Creating communication definitions") + comm_defs = create_comm_definitions(hidden_sizes, output_dir) + all_definitions.update(comm_defs) + + # 9. GEMM definitions (for MLP layers) + print(f"✓ Creating GEMM definitions (FP4/FP8 for MLP)") + gemm_defs = create_gemm_definitions( + hidden_size=model_config.hidden_size, + intermediate_size=intermediate_size, + output_dir=output_dir, + ) + all_definitions.update(gemm_defs) + + # Create TraceSet + trace_set = TraceSet( + root=output_dir, + definitions=all_definitions, + solutions={}, + traces={} + ) + + # Create tracing configs for all definitions + tracing_configs = {} + for def_name in all_definitions.keys(): + tracing_configs[def_name] = TracingConfig( + input_dump_policy="dump_all", + filter_policy="keep_first_by_axes" + ) + + # Enable tracing + runtime = enable_tracing(dataset_path=output_dir, tracing_configs=tracing_configs) + + print(f"\n✓ Tracing enabled for {len(tracing_configs)} operation types") + + return runtime + + +def generate_traces( + model_id: str, + model_name: str, + tensor_parallel_size: int, + num_prompts: int, + max_tokens: int, + output_dir: str, + quantization: str | None = None, + trust_remote_code: bool = True, + extra_env_vars: dict | None = None, +): + """Generate FlashInfer traces for a specific model.""" + + # Set any extra environment variables + if extra_env_vars: + for key, value in extra_env_vars.items(): + os.environ[key] = value + + # Set FlashInfer environment variables + os.environ.setdefault("VLLM_USE_FLASHINFER", "1") + os.environ.setdefault("VLLM_USE_FLASHINFER_NORM", "1") # Enable FlashInfer RMSNorm + + # Set FlashInfer-Bench tracing environment variables + # CRITICAL: These must be set BEFORE vLLM spawns worker processes + os.environ["FIB_ENABLE_TRACING"] = "1" + os.environ["FIB_DATASET_PATH"] = output_dir + os.environ["FIB_ENABLE_APPLY"] = "1" + + print(f"\n✓ Environment variables set for tracing:") + print(f" FIB_ENABLE_TRACING=1") + print(f" FIB_DATASET_PATH={output_dir}") + print(f" FIB_ENABLE_APPLY=1") + + print(f"\n{'='*70}") + print(f"Generating traces for: {model_name}") + print(f"Model ID: {model_id}") + print(f"TP: {tensor_parallel_size}, Quantization: {quantization or 'none'}") + print(f"Output: {output_dir}") + print(f"{'='*70}\n") + + # Import vLLM + from vllm import LLM, SamplingParams + from transformers import AutoConfig + + # Load model config to get dimensions + print("Loading model config...") + model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code) + print(f"✓ Model config loaded") + print(f" Hidden size: {model_config.hidden_size}") + if hasattr(model_config, 'intermediate_size'): + print(f" Intermediate size: {model_config.intermediate_size}") + if hasattr(model_config, 'num_attention_heads'): + print(f" Attention heads: {model_config.num_attention_heads}") + if hasattr(model_config, 'num_key_value_heads'): + print(f" KV heads: {model_config.num_key_value_heads}") + + # Set up tracing with model-specific definitions + runtime = setup_tracing_for_model(model_config, 
output_dir) + + # Initialize the model + # NOTE: enforce_eager=True is REQUIRED for FlashInfer-Bench tracing to work. + # When using torch.compile, the adapter wrappers are traced and compiled away, + # preventing the tracing logic from executing at runtime. + # See docs/source/design/flashinfer_integration_issues.md for details. + llm_kwargs = { + "model": model_id, + "tensor_parallel_size": tensor_parallel_size, + "max_model_len": 2048, + "trust_remote_code": trust_remote_code, + "gpu_memory_utilization": 0.7, + "enforce_eager": True, # REQUIRED for FlashInfer-Bench tracing + } + + if quantization: + llm_kwargs["quantization"] = quantization + + print(f"\nInitializing LLM (enforce_eager=True required for tracing)...") + llm = LLM(**llm_kwargs) + + # Generate diverse prompts + prompts = [ + "What is 2+2?", + "Hello!", + "What is the capital of France? Please answer briefly.", + "Write a haiku about programming.", + "Explain machine learning in simple terms.", + "Write a short story about a robot.", + "Compare the Renaissance and Enlightenment periods.", + ] + + test_prompts = [prompts[i % len(prompts)] for i in range(num_prompts)] + + sampling_params = SamplingParams( + temperature=0.7, + top_p=0.9, + max_tokens=max_tokens, + ) + + print(f"\nRunning inference with {len(test_prompts)} prompts...") + outputs = llm.generate(test_prompts, sampling_params) + + print(f"\n✓ Generated {len(outputs)} outputs") + if outputs: + print(f"Sample: {outputs[0].outputs[0].text[:100]}...") + + del llm + + # Flush traces + if runtime: + runtime.flush() + + print(f"\n✓ Traces flushed to {output_dir}") + + # Check if any traces were actually captured + import glob + workload_files = glob.glob(os.path.join(output_dir, "workloads/**/*.jsonl"), recursive=True) + trace_count = 0 + for f in workload_files: + with open(f) as fh: + trace_count += len(fh.readlines()) + + if trace_count > 0: + print(f"✓ Captured {trace_count} workload trace(s)") + for f in workload_files: + with open(f) as fh: + count = len(fh.readlines()) + print(f" - {os.path.basename(f)}: {count} traces") + else: + print(f"\n⚠ No traces captured.") + print(f"This may happen if:") + print(f" 1. FlashInfer operators are not being used (check VLLM_USE_FLASHINFER=1)") + print(f" 2. The adapters are not matching the operator signatures") + print(f" 3. 
The model architecture doesn't use the traced operators") + + return trace_count > 0 + + +def main(): + parser = argparse.ArgumentParser( + description="Generate FlashInfer Bench traces from vLLM inference" + ) + parser.add_argument( + "--model", type=str, default="qwen", + choices=["qwen", "llama", "gpt-oss", "all"], + help="Model to generate traces for", + ) + parser.add_argument( + "--output-dir", type=str, default=None, + help="Directory to store traces", + ) + parser.add_argument( + "--num-prompts", type=int, default=5, + help="Number of prompts to run", + ) + parser.add_argument( + "--max-tokens", type=int, default=128, + help="Maximum tokens per prompt", + ) + parser.add_argument( + "--tp", type=int, default=2, + help="Tensor parallel size", + ) + + args = parser.parse_args() + + print("=" * 70) + print("FLASHINFER BENCH TRACE GENERATION") + print("=" * 70) + print("\nNOTE: This script uses enforce_eager=True because FlashInfer-Bench") + print(" tracing is not compatible with torch.compile.") + print(" See docs/source/design/flashinfer_integration_issues.md") + + # Check compatibility first + is_compatible, message = check_compatibility() + if not is_compatible: + print(f"\n✗ ERROR: {message}") + return 1 + + print(f"\n✓ {message}") + + output_dir = args.output_dir or str(Path.home() / ".cache" / "flashinfer_bench" / "vllm_traces") + Path(output_dir).mkdir(parents=True, exist_ok=True) + print(f"Output directory: {output_dir}") + + models = { + "qwen": { + "model_id": "Qwen/Qwen3-30B-A3B-Instruct-2507", + "model_name": "Qwen3-30B-A3B MoE", + "tensor_parallel_size": args.tp, + }, + "llama": { + "model_id": "meta-llama/Llama-3.1-70B-Instruct", + "model_name": "Llama-3.1-70B", + "tensor_parallel_size": max(args.tp, 4), + }, + "gpt-oss": { + "model_id": "openai/gpt-oss-120b", + "model_name": "GPT-OSS-120B", + "tensor_parallel_size": max(args.tp, 4), + "extra_env_vars": {"VLLM_ATTENTION_BACKEND": "FLASH_ATTN"}, + }, + } + + models_to_run = list(models.keys()) if args.model == "all" else [args.model] + results = {} + + for model_key in models_to_run: + config = models[model_key] + model_output_dir = os.path.join(output_dir, model_key) + Path(model_output_dir).mkdir(parents=True, exist_ok=True) + + try: + success = generate_traces( + model_id=config["model_id"], + model_name=config["model_name"], + tensor_parallel_size=config["tensor_parallel_size"], + num_prompts=args.num_prompts, + max_tokens=args.max_tokens, + output_dir=model_output_dir, + extra_env_vars=config.get("extra_env_vars"), + ) + results[model_key] = "✓ SUCCESS" if success else "⚠ NO TRACES" + except Exception as e: + print(f"\n✗ Error: {e}") + import traceback + traceback.print_exc() + results[model_key] = f"✗ ERROR: {e}" + + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + for model_key, result in results.items(): + print(f" {model_key}: {result}") + + print(f"\nTraces saved to: {output_dir}") + print("\nTraced operators:") + print(" Attention:") + print(" ✓ GQA Paged Decode/Prefill") + print(" ✓ GQA Ragged Prefill") + print(" ✓ MLA Paged (DeepSeek)") + print(" Normalization:") + print(" ✓ RMSNorm (fused_add_rmsnorm)") + print(" Sampling:") + print(" ✓ Top-k/Top-p sampling") + print(" Activations:") + print(" ✓ SiLU, GELU, GELU-tanh") + print(" Positional Encoding:") + print(" ✓ RoPE") + print(" MoE:") + print(" ✓ CUTLASS MoE, TRT-LLM FP8/FP4 MoE") + print(" GEMM (quantized models only):") + print(" ✓ mm_fp4 (FP4 matmul)") + print(" ✓ bmm_fp8 (FP8 batched matmul)") + print(" ✓ grouped_gemm_nt_masked (MoE 
CUTEDSL)") + print(" Communication:") + print(" ✓ AllReduce, MnnvlMoe All2All") + print("\nTo analyze traces:") + print(f" flashinfer-bench run --local {output_dir}") + + return 0 if all("SUCCESS" in r for r in results.values()) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/kernels/run_flashinfer_test.py b/tests/kernels/run_flashinfer_test.py new file mode 100644 index 000000000000..16ecc0cda9f1 --- /dev/null +++ b/tests/kernels/run_flashinfer_test.py @@ -0,0 +1,563 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Simple standalone script to test FlashInfer integration with vLLM. + +This script verifies that FlashInfer is properly integrated and used for +all supported operators when running inference. + +Note: VLLM_USE_FLASHINFER is automatically set to 1 by this script. + AllReduce fusion is automatically enabled for TP >= 2. + +Usage: + # Test with a smaller model (Qwen3 MoE - fits on fewer GPUs) + python tests/kernels/run_flashinfer_test.py --model qwen + + # Test with Llama-70B (requires 4+ GPUs) + python tests/kernels/run_flashinfer_test.py --model llama + + # Test with GPT-OSS-120B (OpenAI's open-source model with MXFP4 quantization) + python tests/kernels/run_flashinfer_test.py --model gpt-oss + + # Test all models + python tests/kernels/run_flashinfer_test.py --model all + + # Test with FP8 quantization + python tests/kernels/run_flashinfer_test.py --model qwen --fp8 + + # Specify tensor parallel size + python tests/kernels/run_flashinfer_test.py --model qwen --tp 2 + + # Disable AllReduce fusion (if needed) + VLLM_USE_FLASHINFER_ALLREDUCE=0 python tests/kernels/run_flashinfer_test.py --model qwen --tp 2 + +Related scripts: + # Generate FlashInfer-Bench traces for kernel optimization + python tests/kernels/generate_flashinfer_traces.py --model qwen + + # NOTE: Tracing requires enforce_eager=True (not compatible with torch.compile) + # See docs/source/design/flashinfer_integration_issues.md for details +""" + +import argparse +import os +import sys +import time + +# This script is for testing FlashInfer - always enable it +os.environ["VLLM_USE_FLASHINFER"] = "1" + +# Set logging level to see FlashInfer activation messages +os.environ.setdefault("VLLM_LOGGING_LEVEL", "INFO") + + +def check_prerequisites(): + """Check that all prerequisites are met.""" + print("=" * 70) + print("CHECKING PREREQUISITES") + print("=" * 70) + + # Check CUDA + import torch + if not torch.cuda.is_available(): + print("❌ CUDA is not available!") + return False + + gpu_count = torch.cuda.device_count() + print(f"✓ CUDA available with {gpu_count} GPU(s)") + + for i in range(gpu_count): + props = torch.cuda.get_device_properties(i) + print(f" GPU {i}: {props.name} (SM {props.major}.{props.minor}, " + f"{props.total_memory / 1024**3:.1f} GB)") + + # Check FlashInfer + try: + import flashinfer + print(f"✓ FlashInfer installed (version: {flashinfer.__version__})") + except ImportError: + print("❌ FlashInfer is not installed!") + print(" Install with: pip install flashinfer-python") + return False + + # Check FlashInfer modules (core features) + modules_to_check = [ + ("flashinfer.norm", "rmsnorm", True), # (module, func, required) + ("flashinfer.activation", "silu_and_mul", True), + ("flashinfer.sampling", "top_k_top_p_sampling_from_probs", True), + # Note: allreduce_fusion exists in 0.5.2/0.5.3 but has JIT compilation issues + ("flashinfer.comm", "trtllm_allreduce_fusion", False), + ] + + 
for module_name, func_name, required in modules_to_check: + try: + module = __import__(module_name, fromlist=[func_name]) + if hasattr(module, func_name): + note = " (has JIT issues in 0.5.2/0.5.3)" if "allreduce" in func_name else "" + print(f" ✓ {module_name}.{func_name}{note}") + else: + status = "⚠" if not required else "❌" + print(f" {status} {module_name}.{func_name} not found") + except ImportError as e: + status = "⚠" if not required else "❌" + print(f" {status} {module_name} import failed: {e}") + + # Check vLLM environment variables + import vllm.envs as envs + + print("\n" + "=" * 70) + print("VLLM FLASHINFER ENVIRONMENT VARIABLES") + print("=" * 70) + + env_checks = [ + ("VLLM_USE_FLASHINFER", envs.VLLM_USE_FLASHINFER, "master switch"), + ("VLLM_ATTENTION_BACKEND", envs.VLLM_ATTENTION_BACKEND, "attention"), + ("VLLM_USE_FLASHINFER_SAMPLER", envs.VLLM_USE_FLASHINFER_SAMPLER, "sampling"), + ("VLLM_USE_FLASHINFER_NORM", envs.VLLM_USE_FLASHINFER_NORM, "RMSNorm"), + ("VLLM_USE_FLASHINFER_ACTIVATION", envs.VLLM_USE_FLASHINFER_ACTIVATION, "activations"), + ("VLLM_USE_FLASHINFER_ALLREDUCE", envs.VLLM_USE_FLASHINFER_ALLREDUCE, "allreduce"), + ("VLLM_USE_FLASHINFER_MOE_FP16", envs.VLLM_USE_FLASHINFER_MOE_FP16, "MoE FP16"), + ("VLLM_USE_FLASHINFER_MOE_FP8", envs.VLLM_USE_FLASHINFER_MOE_FP8, "MoE FP8"), + ("VLLM_USE_FLASHINFER_MOE_FP4", envs.VLLM_USE_FLASHINFER_MOE_FP4, "MoE FP4"), + ("VLLM_ALL2ALL_BACKEND", envs.VLLM_ALL2ALL_BACKEND, "all2all"), + ] + + all_set = True + for name, value, description in env_checks: + if name == "VLLM_USE_FLASHINFER" and not value: + print(f"❌ {name} = {value} (MUST be True!)") + all_set = False + elif name == "VLLM_ATTENTION_BACKEND": + expected = "FLASHINFER" + status = "✓" if value == expected else "⚠" + print(f"{status} {name} = {value}") + elif name == "VLLM_USE_FLASHINFER_ALLREDUCE": + # Allreduce is NOT auto-enabled (has JIT issues in FlashInfer 0.5.2/0.5.3) + status = "✓" if value else "○" + note = " (not auto-enabled, has JIT issues in FI 0.5.2/0.5.3)" if not value else "" + print(f"{status} {name} = {value}{note}") + else: + status = "✓" if value else "○" + print(f"{status} {name} = {value}") + + if not envs.VLLM_USE_FLASHINFER: + print("\n❌ VLLM_USE_FLASHINFER is not set!") + print(" Run with: VLLM_USE_FLASHINFER=1 python run_flashinfer_test.py ...") + return False + + return True + + +# Return values for run_inference_test +RESULT_PASS = "PASS" +RESULT_FAIL = "FAIL" +RESULT_SKIPPED = "SKIPPED" + + +def run_inference_test( + model_id: str, + tp_size: int, + quantization: str | None = None, + enable_allreduce: bool = False, +): + """Run inference test with the specified model. + + Returns: + Tuple of (result_status, skip_reason) where result_status is one of + RESULT_PASS, RESULT_FAIL, or RESULT_SKIPPED. skip_reason is set only + when RESULT_SKIPPED is returned. + """ + from vllm import LLM, SamplingParams + + print("\n" + "=" * 70) + print(f"RUNNING INFERENCE TEST") + print("=" * 70) + print(f"Model: {model_id}") + print(f"Tensor Parallel Size: {tp_size}") + print(f"Quantization: {quantization or 'None'}") + print(f"AllReduce Fusion: {'ENABLED' if enable_allreduce else 'DISABLED'}") + print("=" * 70) + + # Test prompts + prompts = [ + "What is the capital of France? 
Answer in one sentence.", + "Explain what artificial intelligence is in simple terms.", + "Write a short poem about coding.", + ] + + # Initialize LLM + print("\nInitializing LLM...") + start_time = time.time() + + llm_kwargs = { + "model": model_id, + "tensor_parallel_size": tp_size, + "max_model_len": 4096, + "trust_remote_code": True, + "enforce_eager": False, # Enable compilation for fusion passes + } + + if quantization: + llm_kwargs["quantization"] = quantization + + # Add AllReduce fusion if requested + if enable_allreduce and tp_size >= 2: + from vllm.config import CompilationConfig, PassConfig + pass_config = PassConfig( + enable_fi_allreduce_fusion=True, + enable_noop=True, + ) + llm_kwargs["compilation_config"] = CompilationConfig(pass_config=pass_config) + print("✓ AllReduce fusion compilation config enabled") + + try: + llm = LLM(**llm_kwargs) + except NotImplementedError as e: + # Handle FlashInfer FP8 CUDA version incompatibility + error_msg = str(e) + if "FP8 block scaling not implemented" in error_msg: + print(f"⚠ FlashInfer FP8 kernel not available: {error_msg}") + return RESULT_SKIPPED, "FlashInfer FP8 MoE requires newer CUDA (see FlashInfer package)" + raise + except RuntimeError as e: + # The error may be wrapped in a RuntimeError from worker process + error_msg = str(e) + if "FP8 block scaling not implemented" in error_msg: + print(f"⚠ FlashInfer FP8 kernel not available: {error_msg}") + return RESULT_SKIPPED, "FlashInfer FP8 MoE requires newer CUDA (see FlashInfer package)" + print(f"❌ Failed to initialize LLM: {e}") + import traceback + traceback.print_exc() + return RESULT_FAIL, None + except Exception as e: + print(f"❌ Failed to initialize LLM: {e}") + import traceback + traceback.print_exc() + return RESULT_FAIL, None + + init_time = time.time() - start_time + print(f"✓ LLM initialized in {init_time:.1f}s") + + # Set up sampling parameters + sampling_params = SamplingParams( + temperature=0.7, + top_p=0.9, + top_k=50, + max_tokens=150, + ) + + # Run inference + print("\nRunning inference...") + start_time = time.time() + + try: + outputs = llm.generate(prompts, sampling_params) + except Exception as e: + print(f"❌ Inference failed: {e}") + import traceback + traceback.print_exc() + return RESULT_FAIL, None + + inference_time = time.time() - start_time + print(f"✓ Inference completed in {inference_time:.1f}s") + + # Display outputs + print("\n" + "-" * 70) + print("GENERATED OUTPUTS") + print("-" * 70) + + all_valid = True + for i, output in enumerate(outputs): + print(f"\n[Prompt {i+1}]: {prompts[i]}") + print(f"[Output {i+1}]: {output.outputs[0].text}") + + # Basic validation + generated = output.outputs[0].text.strip() + if len(generated) < 10: + print(f"⚠ Warning: Output seems too short") + all_valid = False + elif "error" in generated.lower() or "exception" in generated.lower(): + print(f"⚠ Warning: Output may contain error message") + + # Cleanup + print("\nCleaning up...") + del llm + import torch + torch.cuda.empty_cache() + + return RESULT_PASS if all_valid else RESULT_FAIL, None + + +def main(): + parser = argparse.ArgumentParser( + description="Test FlashInfer integration with vLLM", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__ + ) + parser.add_argument( + "--model", + choices=["qwen", "llama", "gpt-oss", "all"], + default="qwen", + help="Model to test (default: qwen). gpt-oss uses MXFP4 quantization." 
+ ) + parser.add_argument( + "--tp", + type=int, + default=None, + help="Tensor parallel size (default: auto-detect based on GPU count)" + ) + parser.add_argument( + "--fp8", + action="store_true", + help="Use FP8 quantization (uses pre-quantized FP8 model variants)" + ) + parser.add_argument( + "--skip-prereq", + action="store_true", + help="Skip prerequisite checks" + ) + parser.add_argument( + "--enable-allreduce", + action="store_true", + help="Force enable AllReduce fusion (enabled by default for TP >= 2)" + ) + + args = parser.parse_args() + + # Check prerequisites + if not args.skip_prereq: + if not check_prerequisites(): + sys.exit(1) + + import torch + gpu_count = torch.cuda.device_count() + + # Get CUDA version for compatibility checks + # Use runtime_version() which is more reliable than torch.version.cuda + cuda_version = None + cuda_version_str = "unknown" + if torch.cuda.is_available(): + try: + # Get CUDA runtime version (e.g., 12060 for CUDA 12.6.0) + runtime_version = torch.cuda.runtime_version() + if runtime_version: + major = runtime_version // 1000 + minor = (runtime_version % 1000) // 10 + cuda_version = (major, minor) + cuda_version_str = f"{major}.{minor}" + except Exception: + # Fallback to compiled version string + cuda_str = torch.version.cuda + if cuda_str: + parts = cuda_str.split('.') + cuda_version = (int(parts[0]), int(parts[1])) + cuda_version_str = cuda_str + + print(f"\nDetected CUDA version: {cuda_version_str}") + + # Model configurations + # When --fp8 is passed, use the FP8 variant of the model + # Note: For FP8 models, the quantization is auto-detected from model config + # (e.g., RedHatAI models use compressed-tensors, Qwen FP8 uses fp8) + # We should NOT pass quantization="fp8" for models that have it in their config. + models = { + "qwen": { + "model_id": "Qwen/Qwen3-30B-A3B-Instruct-2507", + "model_id_fp8": "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8", + "min_tp": 1, + "recommended_tp": 2, + # Qwen FP8 needs explicit quantization arg + "fp8_needs_quant_arg": True, + }, + "llama": { + "model_id": "meta-llama/Llama-3.1-70B-Instruct", + "model_id_fp8": "RedHatAI/Meta-Llama-3.1-70B-Instruct-FP8", + "min_tp": 4, + "recommended_tp": 8, + # RedHatAI uses compressed-tensors, auto-detected from config + "fp8_needs_quant_arg": False, + }, + # GPT-OSS-120B: OpenAI's open-source reasoning model + # See: https://docs.vllm.ai/projects/recipes/en/latest/OpenAI/GPT-OSS.html + # Uses MXFP4 quantization with different backends per GPU: + # + # ATTENTION SINKS: + # - GPT-OSS uses "attention sinks" for memory efficiency + # - FlashInfer supports sinks ONLY on Blackwell (SM100) via TRTLLM attention + # - On Hopper (SM90), must use FlashAttention3 for attention + # + # MoE: + # - Blackwell (SM100): FlashInfer CUTLASS MXFP4 MoE + # - Hopper (SM90): Marlin/Triton MXFP4 MoE (FI SM90 backend is broken) + "gpt-oss": { + "model_id": "openai/gpt-oss-120b", + "model_id_fp8": None, # No FP8 variant, uses MXFP4 + "min_tp": 1, + "recommended_tp": 2, # TP=2 recommended for best performance + "fp8_needs_quant_arg": False, + # Special handling: model uses built-in MXFP4 quantization + "is_mxfp4": True, + # Model uses attention sinks: + # - Hopper: FlashInfer doesn't support sinks, use FlashAttention3 + # - Blackwell: FlashInfer supports sinks via TRTLLM attention + "skip_flashinfer_attention_hopper_only": True, + # FlashInfer MXFP4 MoE is ONLY for Blackwell (SM 100). + # On Hopper (SM 90), vLLM's SM90_FI_MXFP4_BF16 backend is broken - + # it tries to use SM100 kernels. 
Use Marlin/Triton instead. + "env_vars_blackwell_only": { + "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": "1", + }, + # Empty env_vars for Hopper - uses Marlin/Triton MXFP4 by default + "env_vars": {}, + }, + } + + # Determine which models to test + if args.model == "all": + models_to_test = list(models.keys()) + else: + models_to_test = [args.model] + + # Run tests + results = {} + + for model_name in models_to_test: + config = models[model_name] + + # Determine TP size + if args.tp: + tp_size = args.tp + else: + tp_size = min(gpu_count, config["recommended_tp"]) + + # Enable AllReduce by default if TP >= 2, unless explicitly disabled + if args.enable_allreduce or (tp_size >= 2 and os.getenv("VLLM_USE_FLASHINFER_ALLREDUCE") != "0"): + enable_allreduce = True + os.environ["VLLM_USE_FLASHINFER_ALLREDUCE"] = "1" + if tp_size >= 2: + print(f"\n✓ AllReduce Fusion ENABLED for {model_name} (TP={tp_size})") + if not args.enable_allreduce: + print(" (automatically enabled for TP >= 2, set VLLM_USE_FLASHINFER_ALLREDUCE=0 to disable)") + else: + enable_allreduce = False + if tp_size >= 2: + print(f"\n○ AllReduce Fusion DISABLED for {model_name} (set VLLM_USE_FLASHINFER_ALLREDUCE=1 to enable)") + + # Check if we have enough GPUs + if gpu_count < config["min_tp"]: + print(f"\n⚠ Skipping {model_name}: requires {config['min_tp']} GPUs, " + f"but only {gpu_count} available") + results[model_name] = RESULT_SKIPPED + continue + + # Select model ID and quantization based on --fp8 flag + result_name = f"{model_name}-fp8" if args.fp8 else model_name + + if args.fp8: + # Check if model has FP8 variant + if config.get("model_id_fp8") is None: + print(f"\n⚠ Skipping {result_name}: Model does not have an FP8 variant") + if config.get("is_mxfp4"): + print(f" Note: {model_name} uses MXFP4 quantization instead of FP8") + results[result_name] = RESULT_SKIPPED + continue + + # Skip Qwen FP8 - FlashInfer FP8 MoE requires CUDA 12.7+ but the + # FlashInfer package is compiled against CUDA 12.6 + if model_name == "qwen": + print(f"\n⚠ Skipping {result_name}: FlashInfer FP8 MoE kernel requires " + "CUDA 12.7+ (FlashInfer package compiled against older CUDA)") + results[result_name] = RESULT_SKIPPED + continue + + model_id = config["model_id_fp8"] + + # Only pass quantization arg if the model needs it + # (some models auto-detect from config, e.g., compressed-tensors) + quantization = "fp8" if config.get("fp8_needs_quant_arg", True) else None + else: + model_id = config["model_id"] + quantization = None + + # Check GPU architecture for architecture-specific env vars + device_capability = torch.cuda.get_device_capability(0) + is_blackwell = device_capability[0] >= 10 # SM 100+ + + # Set model-specific environment variables + model_env_vars = config.get("env_vars", {}) + for env_name, env_value in model_env_vars.items(): + os.environ[env_name] = env_value + print(f" Setting {env_name}={env_value}") + + # Set Blackwell-only environment variables (e.g., FlashInfer MXFP4 MoE) + if is_blackwell: + blackwell_env_vars = config.get("env_vars_blackwell_only", {}) + for env_name, env_value in blackwell_env_vars.items(): + os.environ[env_name] = env_value + print(f" Setting {env_name}={env_value} (Blackwell only)") + elif config.get("env_vars_blackwell_only"): + print(f" Note: Skipping FlashInfer MXFP4 MoE (requires Blackwell/SM100+)") + print(f" Using Marlin/Triton MXFP4 MoE instead (Hopper/SM90)") + # Explicitly disable FlashInfer MXFP4 MoE on non-Blackwell to use Marlin/Triton + # Note: VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 is the correct env 
var name + os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] = "0" + os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "0" + os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS"] = "0" + + # Check if FlashInfer attention should be skipped for attention sinks + # GPT-OSS uses attention sinks: + # - On Hopper (SM90): FlashInfer doesn't support sinks, use FlashAttention3 + # - On Blackwell (SM100): FlashInfer supports sinks via TRTLLM attention + original_attn_backend = os.environ.get("VLLM_ATTENTION_BACKEND") + if config.get("skip_flashinfer_attention_hopper_only", False) and not is_blackwell: + # Override attention backend on Hopper - FlashInfer doesn't support + # attention sinks without TRTLLM (which requires SM100). + os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN" + print(f" Note: Using FLASH_ATTN for attention (model uses attention sinks)") + print(f" FlashInfer supports sinks only on Blackwell via TRTLLM") + elif config.get("skip_flashinfer_attention_hopper_only", False) and is_blackwell: + print(f" Note: Using FlashInfer attention with TRTLLM (supports sinks on Blackwell)") + + # Run test + result, skip_reason = run_inference_test(model_id, tp_size, quantization, enable_allreduce) + + # Restore attention backend for next test + if original_attn_backend is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = original_attn_backend + elif "VLLM_ATTENTION_BACKEND" in os.environ: + del os.environ["VLLM_ATTENTION_BACKEND"] + + if result == RESULT_SKIPPED and skip_reason: + print(f"\n⚠ {result_name}: SKIPPED - {skip_reason}") + + results[result_name] = result + + # Print summary + print("\n" + "=" * 70) + print("TEST SUMMARY") + print("=" * 70) + + all_passed = True + for model_name, result in results.items(): + if result == RESULT_PASS: + print(f"✓ {model_name}: {result}") + elif result == RESULT_SKIPPED: + print(f"○ {model_name}: {result}") + else: + print(f"❌ {model_name}: {result}") + all_passed = False + + print("=" * 70) + + if all_passed: + print("All tests PASSED! FlashInfer is working correctly.") + sys.exit(0) + else: + print("Some tests FAILED. Check the output above for details.") + sys.exit(1) + + +if __name__ == "__main__": + main() + diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 9342564aa3d3..1b4fc905ef98 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -459,6 +459,23 @@ def __post_init__(self): if self.compilation_config.pass_config.enable_async_tp: self.compilation_config.pass_config.enable_sequence_parallelism = True + # Log when VLLM_USE_FLASHINFER master switch is enabled + if envs.VLLM_USE_FLASHINFER: + logger.info( + "VLLM_USE_FLASHINFER is enabled. FlashInfer will be used for: " + "attention, sampling, MoE, RMSNorm, activations, and all2all " + "(where applicable and supported by hardware)." + ) + + # NOTE: FlashInfer allreduce fusion (enable_fi_allreduce_fusion) is NOT + # auto-enabled here because it has known compatibility issues with + # FlashInfer 0.5.2/0.5.3 (the versions vLLM supports). The Python bindings + # exist but JIT compilation fails due to CUDA struct mismatches. + # Users who want to enable this feature should: + # 1. Set VLLM_USE_FLASHINFER_ALLREDUCE=1 explicitly + # 2. Use compilation_config.pass_config.enable_fi_allreduce_fusion=True + # 3. 
Verify they have a compatible FlashInfer build + if current_platform.support_static_graph_mode(): # if cudagraph_mode is not explicitly set by users, set default # value diff --git a/vllm/env_override.py b/vllm/env_override.py index 9ae1af3af46c..1b7987d542aa 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -21,7 +21,11 @@ # see https://github.com/vllm-project/vllm/issues/10480 os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" # see https://github.com/vllm-project/vllm/issues/10619 -torch._inductor.config.compile_threads = 1 +try: + torch._inductor.config.compile_threads = 1 +except AttributeError: + # torch._inductor.config may not exist in all PyTorch versions + pass # =================================================== # torch 2.9 Inductor PythonWrapperCodegen monkeypatch diff --git a/vllm/envs.py b/vllm/envs.py index 56558548d398..b9a6aae9e87c 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -34,6 +34,7 @@ VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" VLLM_NO_USAGE_STATS: bool = False + VLLM_USE_FLASHINFER: bool = False VLLM_DISABLE_FLASHINFER_PREFILL: bool = False VLLM_DO_NOT_TRACK: bool = False VLLM_USAGE_SOURCE: str = "" @@ -162,6 +163,9 @@ VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False + VLLM_USE_FLASHINFER_NORM: bool = False + VLLM_USE_FLASHINFER_ACTIVATION: bool = False + VLLM_USE_FLASHINFER_ALLREDUCE: bool = False VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) @@ -599,6 +603,12 @@ def get_vllm_port() -> int | None: "VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai" ), "VLLM_NO_USAGE_STATS": lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + # Master switch to enable all FlashInfer backends/kernels. + # When set to 1, enables FlashInfer for: attention, sampling, MoE, + # RMSNorm, activations, allreduce, and all2all. + "VLLM_USE_FLASHINFER": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER", "0")) + ), "VLLM_DISABLE_FLASHINFER_PREFILL": lambda: os.environ.get( "VLLM_DISABLE_FLASHINFER_PREFILL", "0" ) @@ -646,21 +656,23 @@ def get_vllm_port() -> int | None: # - "FLASHINFER_MLA": use FlashInfer for MLA # - "CUTLASS_MLA": use CUTLASS for MLA # All possible options loaded dynamically from AttentionBackendEnum + # Falls back to FLASHINFER when VLLM_USE_FLASHINFER is set. "VLLM_ATTENTION_BACKEND": env_with_choices( "VLLM_ATTENTION_BACKEND", - None, + "FLASHINFER" if os.getenv("VLLM_USE_FLASHINFER", "0") == "1" else None, lambda: list( __import__( "vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"] ).AttentionBackendEnum.__members__.keys() ), ), - # If set, vllm will use flashinfer sampler + # If set, vllm will use flashinfer sampler. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. "VLLM_USE_FLASHINFER_SAMPLER": lambda: bool( int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]) ) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ - else None, + else (True if os.getenv("VLLM_USE_FLASHINFER", "0") == "1" else None), # Pipeline stage partition strategy "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), # (CPU backend only) CPU key-value cache space. @@ -1178,33 +1190,66 @@ def get_vllm_port() -> int | None: int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1")) ), # Allow use of FlashInfer MoE kernels for fused moe ops. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. 
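+    # Example of how this fallback resolves: with VLLM_USE_FLASHINFER=1 and
+    # VLLM_USE_FLASHINFER_MOE_FP16 unset, the inner os.getenv returns "1", so
+    # the value is bool(int("1")) -> True; an explicit
+    # VLLM_USE_FLASHINFER_MOE_FP16=0 still wins and yields False.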
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0")) + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", + os.getenv("VLLM_USE_FLASHINFER", "0"))) ), # Allow use of FlashInfer MoE kernels for fused moe ops. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. "VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0")) + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", + os.getenv("VLLM_USE_FLASHINFER", "0"))) ), # Allow use of FlashInfer CUTLASS kernels for fused moe ops. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. "VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0")) + int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", + os.getenv("VLLM_USE_FLASHINFER", "0"))) + ), + # Allow use of FlashInfer RMSNorm/LayerNorm kernels. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. + "VLLM_USE_FLASHINFER_NORM": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_NORM", + os.getenv("VLLM_USE_FLASHINFER", "0"))) + ), + # Allow use of FlashInfer activation kernels (silu_and_mul, gelu_and_mul). + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. + "VLLM_USE_FLASHINFER_ACTIVATION": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_ACTIVATION", + os.getenv("VLLM_USE_FLASHINFER", "0"))) + ), + # If set to 1, enable FlashInfer fused allreduce + RMSNorm for tensor + # parallel inference. Requires SM >= 90 (Hopper), TP > 1. + # NOTE: This is NOT auto-enabled by VLLM_USE_FLASHINFER because + # FlashInfer 0.5.2/0.5.3 (versions vLLM supports) have compatibility issues + # with the allreduce fusion - the Python bindings exist but JIT compilation + # fails. Only set this if you have verified your FlashInfer build works. + "VLLM_USE_FLASHINFER_ALLREDUCE": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_ALLREDUCE", "0")) ), # If set to 1, use the FlashInfer # MXFP8 (activation) x MXFP4 (weight) MoE backend. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0")) + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", + os.getenv("VLLM_USE_FLASHINFER", "0"))) ), # If set to 1, use the FlashInfer CUTLASS backend for # MXFP8 (activation) x MXFP4 (weight) MoE. # This is separate from the TRTLLMGEN path controlled by # VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0")) + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", + os.getenv("VLLM_USE_FLASHINFER", "0"))) ), # If set to 1, use the FlashInfer # BF16 (activation) x MXFP4 (weight) MoE backend. + # Falls back to VLLM_USE_FLASHINFER if not explicitly set. "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool( - int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0")) + int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", + os.getenv("VLLM_USE_FLASHINFER", "0"))) ), # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. @@ -1243,9 +1288,12 @@ def get_vllm_port() -> int | None: # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl + # Falls back to flashinfer_all2allv when VLLM_USE_FLASHINFER is set. 
"VLLM_ALL2ALL_BACKEND": env_with_choices( "VLLM_ALL2ALL_BACKEND", - "allgather_reducescatter", + "flashinfer_all2allv" + if os.getenv("VLLM_USE_FLASHINFER", "0") == "1" + else "allgather_reducescatter", [ "naive", "pplx", diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 3471ee327cf8..f8437fbaa1a2 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F +import vllm.envs as envs from vllm.distributed import ( divide, get_tensor_model_parallel_rank, @@ -18,10 +19,20 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils.collection_utils import LazyDict +from vllm.utils.flashinfer import has_flashinfer logger = init_logger(__name__) +def _use_flashinfer_activation() -> bool: + """Check if FlashInfer activation should be used.""" + return ( + envs.VLLM_USE_FLASHINFER_ACTIVATION + and has_flashinfer() + and current_platform.is_cuda() + ) + + @CustomOp.register("fatrelu_and_mul") class FatreluAndMul(CustomOp): """An activation function for FATReLU. @@ -71,7 +82,10 @@ class SiluAndMul(CustomOp): def __init__(self): super().__init__() - if current_platform.is_cuda_alike(): + self._use_flashinfer = _use_flashinfer_activation() + if self._use_flashinfer: + logger.info_once("Using FlashInfer silu_and_mul activation.") + elif current_platform.is_cuda_alike(): self.op = torch.ops._C.silu_and_mul elif current_platform.is_xpu(): from vllm._ipex_ops import ipex_ops @@ -87,6 +101,11 @@ def forward_native(x: torch.Tensor) -> torch.Tensor: return F.silu(x[..., :d]) * x[..., d:] def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + if self._use_flashinfer: + from flashinfer.activation import silu_and_mul + + return silu_and_mul(x) + d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) @@ -204,7 +223,10 @@ def __init__(self, approximate: str = "none"): self.approximate = approximate if approximate not in ("none", "tanh"): raise ValueError(f"Unknown approximate mode: {approximate}") - if current_platform.is_cuda_alike() or current_platform.is_cpu(): + self._use_flashinfer = _use_flashinfer_activation() + if self._use_flashinfer: + logger.info_once("Using FlashInfer gelu_and_mul activation.") + elif current_platform.is_cuda_alike() or current_platform.is_cpu(): if approximate == "none": self.op = torch.ops._C.gelu_and_mul elif approximate == "tanh": @@ -223,6 +245,16 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + if self._use_flashinfer: + if self.approximate == "tanh": + from flashinfer.activation import gelu_tanh_and_mul + + return gelu_tanh_and_mul(x) + else: + from flashinfer.activation import gelu_and_mul + + return gelu_and_mul(x) + d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 8cc374ac9155..4213db43f644 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -6,22 +6,43 @@ import torch.nn as nn import torch.nn.functional as F +import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp from 
vllm.model_executor.layers.batch_invariant import ( rms_norm_batch_invariant, vllm_is_batch_invariant, ) +from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer + +logger = init_logger(__name__) + + +def _use_flashinfer_norm() -> bool: + """Check if FlashInfer normalization should be used.""" + return ( + envs.VLLM_USE_FLASHINFER_NORM + and has_flashinfer() + and current_platform.is_cuda() + ) def rms_norm( x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float ) -> torch.Tensor: - from vllm import _custom_ops as ops - if vllm_is_batch_invariant(): return rms_norm_batch_invariant(x, weight, variance_epsilon) + + if _use_flashinfer_norm(): + from flashinfer.norm import rmsnorm + + logger.info_once("Using FlashInfer rmsnorm.") + return rmsnorm(x, weight, variance_epsilon) + + from vllm import _custom_ops as ops + out = torch.empty_like(x) ops.rms_norm( out, @@ -38,12 +59,22 @@ def fused_add_rms_norm( weight: torch.Tensor, variance_epsilon: float, ) -> tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - if vllm_is_batch_invariant(): return rms_norm_batch_invariant( x + residual, weight, variance_epsilon ), x + residual + + if _use_flashinfer_norm(): + from flashinfer.norm import fused_add_rmsnorm + + logger.info_once("Using FlashInfer fused_add_rmsnorm.") + # FlashInfer's fused_add_rmsnorm is in-place and returns None + # It modifies x and residual in-place: x = rmsnorm(x + residual), residual = x + residual + fused_add_rmsnorm(x, residual, weight, variance_epsilon) + return x, residual + + from vllm import _custom_ops as ops + ops.fused_add_rms_norm( x, residual, diff --git a/vllm/v1/worker/worker_base.py b/vllm/v1/worker/worker_base.py index 57e7037e946e..14cf470dbf52 100644 --- a/vllm/v1/worker/worker_base.py +++ b/vllm/v1/worker/worker_base.py @@ -247,6 +247,16 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None: ) self.vllm_config.enable_trace_function_call_for_thread() + # Initialize FlashInfer-Bench tracing/adapters if environment variables are set + # This must happen early to patch flashinfer functions before they're imported + if os.environ.get("FIB_ENABLE_TRACING") or os.environ.get("FIB_ENABLE_APPLY"): + try: + import flashinfer_bench # noqa: F401 + import logging + logger.info(f"[FLASHINFER-BENCH] Initialized in worker process PID={os.getpid()}") + except ImportError: + pass # flashinfer-bench not installed + from vllm.plugins import load_general_plugins load_general_plugins()