709 changes: 709 additions & 0 deletions docs/source/design/flashinfer_integration_issues.md

Large diffs are not rendered by default.

1,279 changes: 1,279 additions & 0 deletions tests/kernels/generate_flashinfer_traces.py

Large diffs are not rendered by default.

563 changes: 563 additions & 0 deletions tests/kernels/run_flashinfer_test.py

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions vllm/config/vllm.py
@@ -459,6 +459,23 @@ def __post_init__(self):
if self.compilation_config.pass_config.enable_async_tp:
self.compilation_config.pass_config.enable_sequence_parallelism = True

# Log when VLLM_USE_FLASHINFER master switch is enabled
if envs.VLLM_USE_FLASHINFER:
logger.info(
"VLLM_USE_FLASHINFER is enabled. FlashInfer will be used for: "
"attention, sampling, MoE, RMSNorm, activations, and all2all "
"(where applicable and supported by hardware)."
)

# NOTE: FlashInfer allreduce fusion (enable_fi_allreduce_fusion) is NOT
# auto-enabled here because it has known compatibility issues with
# FlashInfer 0.5.2/0.5.3 (the versions vLLM supports). The Python bindings
# exist but JIT compilation fails due to CUDA struct mismatches.
# Users who want to enable this feature should:
# 1. Set VLLM_USE_FLASHINFER_ALLREDUCE=1 explicitly
# 2. Use compilation_config.pass_config.enable_fi_allreduce_fusion=True
# 3. Verify they have a compatible FlashInfer build

if current_platform.support_static_graph_mode():
# if cudagraph_mode is not explicitly set by users, set default
# value
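
Editor's note on the NOTE above: the explicit opt-in path it describes can be sketched as below. This snippet is not part of the diff; the model name, TP size, and dict-style compilation_config are illustrative assumptions, and the fusion only works if your FlashInfer build actually compiles the allreduce kernels.

import os

# Set before importing vLLM so the flag is visible at startup (assumption:
# your FlashInfer build supports the allreduce fusion).
os.environ["VLLM_USE_FLASHINFER_ALLREDUCE"] = "1"

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
    tensor_parallel_size=2,  # the fusion only applies when TP > 1
    compilation_config={
        "pass_config": {"enable_fi_allreduce_fusion": True},
    },
)
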
6 changes: 5 additions & 1 deletion vllm/env_override.py
@@ -21,7 +21,11 @@
# see https://github.com/vllm-project/vllm/issues/10480
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1
try:
torch._inductor.config.compile_threads = 1
except AttributeError:
# torch._inductor.config may not exist in all PyTorch versions
pass

# ===================================================
# torch 2.9 Inductor PythonWrapperCodegen monkeypatch
68 changes: 58 additions & 10 deletions vllm/envs.py
@@ -34,6 +34,7 @@
VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm")
VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
VLLM_NO_USAGE_STATS: bool = False
VLLM_USE_FLASHINFER: bool = False
VLLM_DISABLE_FLASHINFER_PREFILL: bool = False
VLLM_DO_NOT_TRACK: bool = False
VLLM_USAGE_SOURCE: str = ""
@@ -162,6 +163,9 @@
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_USE_FLASHINFER_NORM: bool = False
VLLM_USE_FLASHINFER_ACTIVATION: bool = False
VLLM_USE_FLASHINFER_ALLREDUCE: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency"
)
@@ -599,6 +603,12 @@ def get_vllm_port() -> int | None:
"VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"
),
"VLLM_NO_USAGE_STATS": lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1",
# Master switch to enable all FlashInfer backends/kernels.
# When set to 1, enables FlashInfer for: attention, sampling, MoE,
# RMSNorm, activations, allreduce, and all2all.
"VLLM_USE_FLASHINFER": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER", "0"))
),
"VLLM_DISABLE_FLASHINFER_PREFILL": lambda: os.environ.get(
"VLLM_DISABLE_FLASHINFER_PREFILL", "0"
)
@@ -646,21 +656,23 @@ def get_vllm_port() -> int | None:
# - "FLASHINFER_MLA": use FlashInfer for MLA
# - "CUTLASS_MLA": use CUTLASS for MLA
# All possible options loaded dynamically from AttentionBackendEnum
# Defaults to FLASHINFER when VLLM_USE_FLASHINFER is set and no backend is chosen explicitly.
"VLLM_ATTENTION_BACKEND": env_with_choices(
"VLLM_ATTENTION_BACKEND",
None,
"FLASHINFER" if os.getenv("VLLM_USE_FLASHINFER", "0") == "1" else None,
lambda: list(
__import__(
"vllm.attention.backends.registry", fromlist=["AttentionBackendEnum"]
).AttentionBackendEnum.__members__.keys()
),
),
# If set, vllm will use flashinfer sampler
# If set, vllm will use flashinfer sampler.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_SAMPLER": lambda: bool(
int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"])
)
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ
else None,
else (True if os.getenv("VLLM_USE_FLASHINFER", "0") == "1" else None),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# (CPU backend only) CPU key-value cache space.
@@ -1178,33 +1190,66 @@ def get_vllm_port() -> int | None:
int(os.getenv("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "1"))
),
# Allow use of FlashInfer MoE kernels for fused moe ops.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_MOE_FP16": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16", "0"))
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP16",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# Allow use of FlashInfer MoE kernels for fused moe ops.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_MOE_FP8": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# Allow use of FlashInfer CUTLASS kernels for fused moe ops.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_MOE_FP4": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))
int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# Allow use of FlashInfer RMSNorm/LayerNorm kernels.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_NORM": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_NORM",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# Allow use of FlashInfer activation kernels (silu_and_mul, gelu_and_mul).
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_ACTIVATION": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_ACTIVATION",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# If set to 1, enable FlashInfer fused allreduce + RMSNorm for tensor
# parallel inference. Requires SM >= 90 (Hopper), TP > 1.
# NOTE: This is NOT auto-enabled by VLLM_USE_FLASHINFER because
# FlashInfer 0.5.2/0.5.3 (versions vLLM supports) have compatibility issues
# with the allreduce fusion - the Python bindings exist but JIT compilation
# fails. Only set this if you have verified your FlashInfer build works.
"VLLM_USE_FLASHINFER_ALLREDUCE": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_ALLREDUCE", "0"))
),
# If set to 1, use the FlashInfer
# MXFP8 (activation) x MXFP4 (weight) MoE backend.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# If set to 1, use the FlashInfer CUTLASS backend for
# MXFP8 (activation) x MXFP4 (weight) MoE.
# This is separate from the TRTLLMGEN path controlled by
# VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS", "0"))
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# If set to 1, use the FlashInfer
# BF16 (activation) x MXFP4 (weight) MoE backend.
# Falls back to VLLM_USE_FLASHINFER if not explicitly set.
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool(
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))
int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16",
os.getenv("VLLM_USE_FLASHINFER", "0")))
),
# Control the cache sized used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
@@ -1243,9 +1288,12 @@ def get_vllm_port() -> int | None:
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
# Defaults to flashinfer_all2allv when VLLM_USE_FLASHINFER is set and no backend is chosen explicitly.
"VLLM_ALL2ALL_BACKEND": env_with_choices(
"VLLM_ALL2ALL_BACKEND",
"allgather_reducescatter",
"flashinfer_all2allv"
if os.getenv("VLLM_USE_FLASHINFER", "0") == "1"
else "allgather_reducescatter",
[
"naive",
"pplx",
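
A minimal, self-contained sketch of the fallback precedence the new env entries in this file implement: an explicitly set per-feature flag always wins, and the VLLM_USE_FLASHINFER master switch only supplies the default. resolve_flag is an invented helper for illustration, not vLLM API.

import os


def resolve_flag(name: str, master: str = "VLLM_USE_FLASHINFER") -> bool:
    # Explicit per-feature setting wins; otherwise fall back to the master switch.
    return bool(int(os.getenv(name, os.getenv(master, "0"))))


os.environ["VLLM_USE_FLASHINFER"] = "1"
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "0"  # explicit opt-out

assert resolve_flag("VLLM_USE_FLASHINFER_NORM") is True      # inherits the master switch
assert resolve_flag("VLLM_USE_FLASHINFER_MOE_FP8") is False  # explicit setting wins

Note that VLLM_USE_FLASHINFER_ALLREDUCE intentionally skips this fallback and stays opt-in only, as its comment above explains.
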
36 changes: 34 additions & 2 deletions vllm/model_executor/layers/activation.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
@@ -18,10 +19,20 @@
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.collection_utils import LazyDict
from vllm.utils.flashinfer import has_flashinfer

logger = init_logger(__name__)


def _use_flashinfer_activation() -> bool:
"""Check if FlashInfer activation should be used."""
return (
envs.VLLM_USE_FLASHINFER_ACTIVATION
and has_flashinfer()
and current_platform.is_cuda()
)


@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
@@ -71,7 +82,10 @@ class SiluAndMul(CustomOp):

def __init__(self):
super().__init__()
if current_platform.is_cuda_alike():
self._use_flashinfer = _use_flashinfer_activation()
if self._use_flashinfer:
logger.info_once("Using FlashInfer silu_and_mul activation.")
elif current_platform.is_cuda_alike():
self.op = torch.ops._C.silu_and_mul
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
@@ -87,6 +101,11 @@ def forward_native(x: torch.Tensor) -> torch.Tensor:
return F.silu(x[..., :d]) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
if self._use_flashinfer:
from flashinfer.activation import silu_and_mul

return silu_and_mul(x)

d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -204,7 +223,10 @@ def __init__(self, approximate: str = "none"):
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
if current_platform.is_cuda_alike() or current_platform.is_cpu():
self._use_flashinfer = _use_flashinfer_activation()
if self._use_flashinfer:
logger.info_once("Using FlashInfer gelu_and_mul activation.")
elif current_platform.is_cuda_alike() or current_platform.is_cpu():
if approximate == "none":
self.op = torch.ops._C.gelu_and_mul
elif approximate == "tanh":
@@ -223,6 +245,16 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]

def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
if self._use_flashinfer:
if self.approximate == "tanh":
from flashinfer.activation import gelu_tanh_and_mul

return gelu_tanh_and_mul(x)
else:
from flashinfer.activation import gelu_and_mul

return gelu_and_mul(x)

d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
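
A hedged parity sketch for the new activation dispatch; it assumes a CUDA device, a FlashInfer install, and VLLM_USE_FLASHINFER_ACTIVATION=1 exported before vLLM is imported. Shapes and tolerances are arbitrary.

import torch

from vllm.model_executor.layers.activation import SiluAndMul

x = torch.randn(4, 2 * 4096, dtype=torch.float16, device="cuda")
layer = SiluAndMul()

ref = layer.forward_native(x)  # pure-PyTorch reference
out = layer.forward_cuda(x)    # uses flashinfer.activation.silu_and_mul when enabled

torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
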
39 changes: 35 additions & 4 deletions vllm/model_executor/layers/layernorm.py
@@ -6,22 +6,43 @@
import torch.nn as nn
import torch.nn.functional as F

import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import (
rms_norm_batch_invariant,
vllm_is_batch_invariant,
)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

logger = init_logger(__name__)


def _use_flashinfer_norm() -> bool:
"""Check if FlashInfer normalization should be used."""
return (
envs.VLLM_USE_FLASHINFER_NORM
and has_flashinfer()
and current_platform.is_cuda()
)


def rms_norm(
x: torch.Tensor, weight: torch.Tensor, variance_epsilon: float
) -> torch.Tensor:
from vllm import _custom_ops as ops

if vllm_is_batch_invariant():
return rms_norm_batch_invariant(x, weight, variance_epsilon)

if _use_flashinfer_norm():
from flashinfer.norm import rmsnorm

logger.info_once("Using FlashInfer rmsnorm.")
return rmsnorm(x, weight, variance_epsilon)

from vllm import _custom_ops as ops

out = torch.empty_like(x)
ops.rms_norm(
out,
@@ -38,12 +59,22 @@ def fused_add_rms_norm(
weight: torch.Tensor,
variance_epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops

if vllm_is_batch_invariant():
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual

if _use_flashinfer_norm():
from flashinfer.norm import fused_add_rmsnorm

logger.info_once("Using FlashInfer fused_add_rmsnorm.")
# FlashInfer's fused_add_rmsnorm mutates its inputs in place and returns
# None: residual becomes (x + residual) and x becomes rmsnorm(x + residual).
fused_add_rmsnorm(x, residual, weight, variance_epsilon)
return x, residual

from vllm import _custom_ops as ops

ops.fused_add_rms_norm(
x,
residual,
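
To make the in-place contract above concrete, a small sketch against FlashInfer's kernel directly; it assumes CUDA and a FlashInfer install, and the shapes and epsilon are arbitrary.

import torch
from flashinfer.norm import fused_add_rmsnorm

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

expected_residual = x + residual
fused_add_rmsnorm(x, residual, weight, 1e-6)  # mutates x and residual, returns None

# residual now holds the pre-norm sum; x holds rmsnorm(residual) scaled by weight.
torch.testing.assert_close(residual, expected_residual)
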
10 changes: 10 additions & 0 deletions vllm/v1/worker/worker_base.py
@@ -247,6 +247,16 @@ def init_worker(self, all_kwargs: list[dict[str, Any]]) -> None:
)
self.vllm_config.enable_trace_function_call_for_thread()

# Initialize FlashInfer-Bench tracing/adapters if environment variables are set
# This must happen early to patch flashinfer functions before they're imported
if os.environ.get("FIB_ENABLE_TRACING") or os.environ.get("FIB_ENABLE_APPLY"):
try:
import flashinfer_bench # noqa: F401
logger.info(
    "[FLASHINFER-BENCH] Initialized in worker process PID=%d", os.getpid()
)
except ImportError:
pass # flashinfer-bench not installed

from vllm.plugins import load_general_plugins

load_general_plugins()
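
For completeness, a hedged end-to-end sketch of exercising the tracing hook above. FIB_ENABLE_TRACING is the variable this worker checks; the model name is an arbitrary example, and flashinfer-bench must be installed separately (otherwise the import above is silently skipped).

import os

os.environ["FIB_ENABLE_TRACING"] = "1"   # picked up by init_worker above
os.environ["VLLM_USE_FLASHINFER"] = "1"  # route supported kernels through FlashInfer

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model choice
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))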