From 644d57d53191b94d9e50a4765891c498790d924b Mon Sep 17 00:00:00 2001 From: CSWYF3634076 <58356743+CSWYF3634076@users.noreply.github.com> Date: Wed, 27 Aug 2025 12:02:55 +0800 Subject: [PATCH 001/125] [Model] Add Ernie4.5 VL Model Support (#22514) Signed-off-by: wangyafeng --- docs/models/supported_models.md | 1 + examples/offline_inference/vision_language.py | 32 + requirements/test.in | 1 + requirements/test.txt | 3 + .../multimodal/processing/test_common.py | 1 + tests/models/registry.py | 2 + .../rotary_embedding/ernie45_vl_rope.py | 72 + .../layers/rotary_embedding/mrope.py | 123 ++ vllm/model_executor/models/ernie45_vl.py | 1504 +++++++++++++++++ vllm/model_executor/models/ernie45_vl_moe.py | 723 ++++++++ vllm/model_executor/models/registry.py | 1 + 11 files changed, 2463 insertions(+) create mode 100644 vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py create mode 100644 vllm/model_executor/models/ernie45_vl.py create mode 100644 vllm/model_executor/models/ernie45_vl_moe.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 74f3a9d1cdb5..19ce8c06724f 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -616,6 +616,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | ✅︎ | ✅︎ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | | `DonutForConditionalGeneration`^ | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | +| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. 
| ✅︎ | ✅︎ | ⚠️ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 8d97ba266826..4e879666f61d 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -173,6 +173,37 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: ) +# Ernie4.5-VL +def run_ernie45_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "baidu/ERNIE-4.5-VL-28B-A3B-PT" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={modality: 1}, + trust_remote_code=True, + ) + + if modality == "image": + placeholder = "Picture 1:<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + elif modality == "video": + placeholder = "Video 1:<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + prompts = [ + ( + f"<|begin_of_sentence|>User: {question}{placeholder}\n" + "Assistant: " + ) + for question in questions + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # Florence2 def run_florence2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1602,6 +1633,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData: "chameleon": run_chameleon, "command_a_vision": run_command_a_vision, "deepseek_vl_v2": run_deepseek_vl2, + "ernie45_vl": run_ernie45_vl, "florence2": run_florence2, "fuyu": run_fuyu, "gemma3": run_gemma3, diff --git a/requirements/test.in b/requirements/test.in index 098a9242bc3a..92c577c50163 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -54,3 +54,4 @@ runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 terratorch==1.1rc2 # required for PrithviMAE test +decord==0.6.0 diff --git a/requirements/test.txt b/requirements/test.txt index 8b872752d875..0c27c9bb67e8 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -156,6 +156,8 @@ datasets==3.0.2 # mteb decorator==5.1.1 # via librosa +decord==0.6.0 + # via -r requirements/test.in dill==0.3.8 # via # datasets @@ -493,6 +495,7 @@ numpy==1.26.4 # contourpy # cupy-cuda12x # datasets + # decord # einx # encodec # evaluate diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 74ca10d32609..6361cb9b5586 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -272,6 +272,7 @@ def _test_processing_correctness_one( "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", "naver-clova-ix/donut-base-finetuned-docvqa", + "baidu/ERNIE-4.5-VL-28B-A3B-PT", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", diff --git a/tests/models/registry.py b/tests/models/registry.py index 20c7c3af6776..f2c09d3e8452 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -396,6 +396,8 @@ def check_available_online( transformers_version_reason="HF model is not compatible.", # noqa: E501 hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}), # noqa: E501 "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"), + "Ernie4_5_VLMoeForConditionalGeneration": _HfExamplesInfo("baidu/ERNIE-4.5-VL-28B-A3B-PT", # noqa: E501 + trust_remote_code=True), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "Gemma3nForConditionalGeneration": 
_HfExamplesInfo("google/gemma-3n-E2B-it", # noqa: E501 diff --git a/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py new file mode 100644 index 000000000000..05322e56f262 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from .common import apply_rotary_emb_dispatch +from .mrope import MRotaryEmbedding + + +class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding): + """3D rotary positional embedding. 3D is t:time h:height w:width""" + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert positions.ndim == 1 or positions.ndim == 2 + assert key is not None + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + section_h = self.mrope_section[0] # 22 + section_w = self.mrope_section[1] # 22 + section_t = self.mrope_section[2] # 20 + assert section_h == section_w + # Split according to [h w h w h w h w... t t t...] + section_cos_t = cos[..., -section_t:] + section_cos_h = cos[..., :section_h + section_w:2] + section_cos_w = cos[..., 1:section_h + section_w:2] + + cos_t, cos_h, cos_w = section_cos_t[0], section_cos_h[ + 1], section_cos_w[2] + cos_hw = torch.stack([cos_h, cos_w], + dim=-1).reshape(cos_h.shape[:-1] + + (cos_h.shape[-1] * 2, )) + cos = torch.cat([cos_hw, cos_t], dim=-1) + + section_sin_t = sin[..., -section_t:] + section_sin_h = sin[..., :section_h + section_w:2] + section_sin_w = sin[..., 1:section_h + section_w:2] + + sin_t, sin_h, sin_w = section_sin_t[0], section_sin_h[ + 1], section_sin_w[2] + sin_hw = torch.stack([sin_h, sin_w], + dim=-1).reshape(sin_h.shape[:-1] + + (sin_h.shape[-1] * 2, )) + sin = torch.cat([sin_hw, sin_t], dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, + self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, + self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index a091cfb74329..e374aa9bebf9 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -393,6 +393,15 @@ def get_input_positions_tensor( context_len=context_len, seq_len=seq_len, ) + elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]: + return cls._ernie_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) else: return cls._vl_get_input_positions_tensor( input_tokens=input_tokens, @@ -513,6 +522,120 @@ def _glm4v_get_input_positions_tensor( 
len(input_tokens)).item() return llm_positions, mrope_position_delta + @classmethod + def _ernie_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + """Get mrope input positions and delta value for Ernie VL.""" + + image_token_id = hf_config.im_patch_id + video_start_token_id = hf_config.video_start_token_id + video_end_token_id = hf_config.video_end_token_id + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + llm_pos_ids_list: list = [] + + if not (image_grid_thw is None and video_grid_thw is None): + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + + input_token_type: list[str] = [] + video_check_flg = False + for token in input_tokens: + if token == video_start_token_id: + video_check_flg = True + elif token == video_end_token_id: + video_check_flg = False + + if (token == image_token_id) and (video_check_flg is False): + input_token_type.append("image") + elif (token == image_token_id) and (video_check_flg is True): + input_token_type.append("video") + else: + input_token_type.append("text") + + input_type_group: list[tuple[str, int, int]] = [] + for key, group_iter in itertools.groupby( + enumerate(input_token_type), lambda x: x[1]): + group_list = list(group_iter) + start_index = group_list[0][0] + end_index = group_list[-1][0] + 1 + input_type_group.append((key, start_index, end_index)) + + video_frame_num = 1 + mm_data_idx = 0 + for modality_type, start_idx, end_idx in input_type_group: + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + if modality_type == "image": + t, h, w = ( + image_grid_thw[mm_data_idx][0], + image_grid_thw[mm_data_idx][1], + image_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_conv_size, w // spatial_conv_size + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + mm_data_idx += 1 + + elif modality_type == "video": + t, h, w = ( + video_grid_thw[mm_data_idx][0], + video_grid_thw[mm_data_idx][1], + video_grid_thw[mm_data_idx][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = (t // + temporal_conv_size, + h // + spatial_conv_size, + w // + spatial_conv_size) + + for t_idx in range(llm_grid_t): + t_index = torch.tensor(t_idx).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view( + 1, -1, 1).expand(1, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view( + 1, 1, -1).expand(1, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx) + + mm_data_idx += 1 + video_frame_num += 1 + + else: + text_len = end_idx - start_idx + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + + st_idx) + video_frame_num = 1 + + else: + text_len = len(input_tokens) + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1)) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions 
= llm_positions[:, context_len:seq_len] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + return llm_positions, mrope_position_delta + @classmethod def _vl_get_input_positions_tensor( cls, diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py new file mode 100644 index 000000000000..d880fc434e20 --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl.py @@ -0,0 +1,1504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Erine VL model compatible with HuggingFace weights.""" +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Any, Callable, Literal, Optional, TypedDict, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import ImageSize, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, PromptReplacement, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.platforms import _Backend, current_platform +from vllm.sequence import IntermediateTensors + +from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, WeightsMapper, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +_MAX_FRAMES_PER_VIDEO = 16 + +# === Vision Transformer === # + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> 
torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + apply_rotary_emb = apply_rotary_emb_torch + if current_platform.is_cuda(): + from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb + output = apply_rotary_emb(t_, cos, sin).type_as(t) + return output + + +def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): + """All-gather the input tensor interleavely across model parallel group.""" + import torch.distributed as dist + gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)] + dist.all_gather(gathered_tensors, + local_tensor, + group=parallel_state.get_tp_group().device_group) + + gathered_tensors_split = [ + torch.split(tensor, hidden_size // tp_size, -1) + for tensor in gathered_tensors + ] + ordered_tensors = [ + tensor for pair in zip(*gathered_tensors_split) for tensor in pair + ] + result_tensor = torch.cat(ordered_tensors, dim=-1) + return result_tensor + + +class Ernie4_5_VisionAttention(nn.Module): + """VisionAttention using VLLM framework APIs""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + self.tp_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, self.tp_size) + + self.qkv = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.hidden_size_per_attention_head, + total_num_heads=num_heads, + total_num_kv_heads=num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS, + _Backend.ROCM_AITER_FA + }: + raise RuntimeError( + f"Ernie45-VL does not support {self.attn_backend} backend now." 
+ ) + self.is_flash_attn_backend = self.attn_backend in { + _Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA + } + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + if self.tp_size > 1: + qkv = all_gather_interleave(qkv, self.qkv.hidden_size, + self.tp_size) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] + if self.tp_size > 1: + splitter = partial(dist_utils.split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.is_flash_attn_backend: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + if self.attn_backend == _Backend.ROCM_AITER_FA: + from aiter import flash_attn_varlen_func + else: + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0.0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + # Execute attention entry by entry for speed & less VRAM. 
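# Editor's note -- illustrative sketch, not part of this patch. The SDPA branch
# below relies on `cu_seqlens` (cumulative patch counts per image/frame): each
# slice [cu_seqlens[i-1]:cu_seqlens[i]] is attended to on its own, which avoids
# materializing a block-diagonal mask over all visual tokens. A minimal,
# self-contained version of that idea with made-up shapes:

import torch
import torch.nn.functional as F

def sdpa_varlen(q, k, v, cu_seqlens):
    """q/k/v: (1, total_tokens, num_heads, head_dim); cu_seqlens: (n_items+1,)."""
    outs = []
    for i in range(1, len(cu_seqlens)):
        s, e = cu_seqlens[i - 1], cu_seqlens[i]
        # (1, seq, heads, dim) -> (1, heads, seq, dim), the layout SDPA expects
        qi, ki, vi = (x[:, s:e].transpose(1, 2) for x in (q, k, v))
        oi = F.scaled_dot_product_attention(qi, ki, vi)
        outs.append(oi.transpose(1, 2))
    return torch.cat(outs, dim=1)

q = k = v = torch.randn(1, 6, 2, 8)                   # two items: 4 + 2 tokens
out = sdpa_varlen(q, k, v, torch.tensor([0, 4, 6]))   # same shape as q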
+ outputs = [] + for i in range(1, len(cu_seqlens)): + start_idx = cu_seqlens[i - 1] + end_idx = cu_seqlens[i] + q_i = q[:, start_idx:end_idx] + k_i = k[:, start_idx:end_idx] + v_i = v[:, start_idx:end_idx] + q_i, k_i, v_i = (rearrange(x, "b s h d -> b h s d") + for x in [q_i, k_i, v_i]) + output_i = F.scaled_dot_product_attention(q_i, + k_i, + v_i, + dropout_p=0.0) + output_i = rearrange(output_i, "b h s d -> b s h d ") + outputs.append(output_i) + context_layer = torch.cat(outputs, dim=1) + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None, + device=q.device) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Ernie4_5_VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int, + act_layer: type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.fc1 = ColumnParallelLinear(in_features, + hidden_features, + quant_config=quant_config, + prefix=f"{prefix}.fc1") + self.act = act_layer() + self.fc2 = RowParallelLinear(hidden_features, + in_features, + quant_config=quant_config, + prefix=f"{prefix}.fc2") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +class Ernie4_5_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: type[nn.Module] = QuickGELU, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Ernie4_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.mlp = Ernie4_5_VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + max_seqlen: Optional[int] = None, # Only used for Flash Attention + seqlens: Optional[list[int]] = None, # Only used for xFormers + ) -> torch.Tensor: + + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states + + +class Ernie4_5_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + in_channels: int = 3, + embed_dim: int = 1280, + prefix="", + ) -> None: + + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + self.proj = nn.Linear(in_channels * patch_size * patch_size, + embed_dim, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.to(target_dtype) + 
hidden_states = self.proj(hidden_states) + + return hidden_states + + +class Ernie4_5_VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.inv_freq = 1.0 / theta**( + torch.arange(start=0, end=dim, step=2, dtype=torch.float32) / dim) + + def forward(self, seqlen: int) -> torch.Tensor: + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(input=seq, vec2=self.inv_freq) + return freqs + + +class Ernie4_5_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + + super().__init__() + patch_size = vision_config.patch_size + spatial_merge_size = vision_config.spatial_merge_size + in_channels = vision_config.in_channels + hidden_size = vision_config.hidden_size + embed_dim = vision_config.embed_dim + depth = vision_config.depth + num_heads = vision_config.num_heads + mlp_ratio = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + self.num_heads = num_heads + self.embed_dim = embed_dim + + self.patch_embed = Ernie4_5_VisionPatchEmbed( + patch_size=patch_size, + in_channels=in_channels, + embed_dim=embed_dim, + prefix=f"{prefix}.patch_embed", + ) + + norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Ernie4_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Ernie4_5_VisionBlock(dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + assert (hidden_size == embed_dim + ), "vit's config.hidden must be equal to config.embed_dim" + self.ln = nn.LayerNorm(hidden_size, eps=1e-6) + + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def compute_attn_mask_seqlen( + self, cu_seqlens: torch.Tensor + ) -> tuple[Optional[int], Optional[list[int]]]: + max_seqlen, seqlens = None, None + if self.attn_backend == _Backend.FLASH_ATTN: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + elif self.attn_backend == _Backend.XFORMERS: + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + return max_seqlen, seqlens + + def forward(self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + num_pad=0) -> torch.Tensor: + + 
hidden_states = self.patch_embed(hidden_states) + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + rotary_pos_emb = rotary_pos_emb.to(hidden_states.device) + + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + + if num_pad > 0: + cu_seqlens = F.pad(cu_seqlens, (1, 1), value=0) + cu_seqlens[-1] = cu_seqlens[-2] + num_pad + else: + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + + # add batch size + if hidden_states.ndim == 2: + hidden_states = hidden_states.unsqueeze(dim=1) + + # pre-compute seqlens for attn mask to reduce cuMemcpy operations + max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens) + + for i, blk in enumerate(self.blocks): + hidden_states = blk( + hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + max_seqlen=max_seqlen, + seqlens=seqlens, + ) + + final_output = self.ln(hidden_states) + + if final_output.ndim == 3: + final_output = final_output.squeeze(dim=1) + + return final_output + + def load_weights(self, weights) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +# === Vision Inputs === # + + +class Ernie4_5_VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs + + +class Ernie4_5_VLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. 
+ """ + + +Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs + +# === Vision Processor === # + + +def round_by_factor(number: Union[int, float], factor: int) -> int: + return round(number / factor) * factor + + +def ceil_by_factor(number: Union[int, float], factor: int) -> int: + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: Union[int, float], factor: int) -> int: + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 4 * 28 * 28, + max_pixels: int = 16384 * 28 * 28, +): + MAX_RATIO = 200 + if max(height, width) / min(height, width) > MAX_RATIO: + if height > width: + new_width = max(factor, round_by_factor(width, factor)) + new_height = floor_by_factor(new_width * MAX_RATIO, factor) + else: + new_height = max(factor, round_by_factor(height, factor)) + new_width = floor_by_factor(new_height * MAX_RATIO, factor) + + height = new_height + width = new_width + + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + + if min_pixels > h_bar * w_bar or h_bar * w_bar > max_pixels: + raise ValueError(f"encounter invalid h_bar: {h_bar}, w_bar: {w_bar}") + + return h_bar, w_bar + + +class VariableResolutionResamplerModel(nn.Module): + + def __init__(self, + in_dim, + out_dim, + spatial_conv_size, + temporal_conv_size, + config, + prefix: str = "") -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.config = config + self.spatial_conv_size = spatial_conv_size + self.temporal_conv_size = temporal_conv_size + self.use_temporal_conv = config.use_temporal_conv + + # compress 2d conv(picture) to 1d + self.spatial_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size) + # compress 3d conv(video) to 1d + self.temporal_dim = (self.in_dim * self.spatial_conv_size * + self.spatial_conv_size * self.temporal_conv_size) + + self.spatial_linear1 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear1", + ) + + self.spatial_gelu = nn.GELU() + + self.spatial_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.spatial_linear2", + ) + + self.spatial_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + if self.use_temporal_conv: + self.temporal_linear1 = ColumnParallelLinear( + self.temporal_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear1", + ) + + self.temporal_gelu = nn.GELU() + + self.temporal_linear2 = ColumnParallelLinear( + self.spatial_dim, + self.spatial_dim, + bias=True, + gather_output=True, + quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.temporal_linear2", + ) + + self.temporal_norm = nn.LayerNorm(self.spatial_dim, eps=1e-6) + + self.mlp = ColumnParallelLinear( + self.spatial_dim, + self.out_dim, + bias=True, + gather_output=True, + 
quant_config=getattr(config, 'quant_config', None), + prefix=f"{prefix}.mlp", + ) + + self.after_norm = RMSNorm(hidden_size=out_dim, + eps=getattr(config, 'rms_norm_eps', 1e-6)) + + def spatial_conv_reshape(self, x, spatial_conv_size): + S, C = x.shape + x = x.reshape([-1, C * (spatial_conv_size**2)]) + return x + + def forward(self, x, grid_thw): + + def fwd_spatial(x): + x = self.spatial_conv_reshape(x, self.spatial_conv_size) + + x, _ = self.spatial_linear1(x) + x = self.spatial_gelu(x) + x, _ = self.spatial_linear2(x) + x = self.spatial_norm(x) + + return x + + def fwd_placeholder(x, grid_thw, to_tensor=False): + + grid_thw_cpu = grid_thw.cpu().numpy() + grid_t, grid_hw = grid_thw_cpu[:, 0], grid_thw_cpu[:, 1:] + grid_hw_after_conv = grid_hw.prod(-1) // (self.spatial_conv_size** + 2) + + tokens_per_img_or_vid = grid_thw_cpu.prod(-1) // ( + self.spatial_conv_size**2) + batch_offset = np.empty(tokens_per_img_or_vid.size, + dtype=tokens_per_img_or_vid.dtype) + batch_offset[0] = 0 + batch_offset[1:] = tokens_per_img_or_vid.cumsum()[:-1] + + slice_offsets = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(0, temporoal_size, 2): + slice_offsets.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets = torch.tensor(np.concatenate(slice_offsets, + axis=-1)).to(x.device) + + slice_offsets2 = [] + for temporoal_size, spatial_size, b_offset in zip( + grid_t, grid_hw_after_conv, batch_offset): + for temp_offset in range(1 if temporoal_size > 1 else 0, + temporoal_size, 2): + slice_offsets2.append( + np.arange( + b_offset + (temp_offset) * spatial_size, + b_offset + (temp_offset + 1) * spatial_size, + )) + slice_offsets2 = torch.tensor( + np.concatenate(slice_offsets2, axis=-1)).to(x.device) + + x_timestep_1 = torch.index_select(x, dim=0, index=slice_offsets) + x_timestep_2 = torch.index_select(x, dim=0, index=slice_offsets2) + x = torch.concat([x_timestep_1, x_timestep_2], dim=-1) + return x + + def fwd_temporal(x): + x, _ = self.temporal_linear1(x) + x = self.temporal_gelu(x) + x, _ = self.temporal_linear2(x) + x = self.temporal_norm(x) + return x + + def fwd_mlp(x): + x, _ = self.mlp(x) + x = self.after_norm(x) + return x + + x = fwd_spatial(x) + if self.use_temporal_conv: + x = fwd_placeholder(x, grid_thw) + x = fwd_temporal(x) + x = fwd_mlp(x) + return x + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(use_fast=True, **kwargs) + + def get_image_processor(self, **kwargs: object): + return self.get_hf_processor(**kwargs).image_processor + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": None, "video": None} + + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + image_processor: Optional[Any], + ) -> tuple[ImageSize, 
int]: + if image_processor is None: + image_processor = self.get_image_processor() + hf_config = self.get_hf_config() + vision_config = hf_config.vision_config + + patch_size = vision_config.patch_size + spatial_conv_size = hf_config.spatial_conv_size + temporal_conv_size = hf_config.temporal_conv_size + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * spatial_conv_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + grid_t = max(num_frames // temporal_conv_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (spatial_conv_size**2) + + return preprocessed_size, num_vision_tokens + + def get_num_image_tokens( + self, + *, + image_width: int, + image_height: int, + image_processor: Optional[Any], + ) -> int: + _, num_image_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + image_processor=image_processor, + ) + return num_image_tokens + + def get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + image_processor: Optional[Any], + ) -> int: + _, num_video_tokens = self._get_vision_info( + image_width=image_width, + image_height=image_height, + num_frames=num_frames, + image_processor=image_processor, + ) + return num_video_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + image_processor=None, + ) + return max_image_size + + def get_max_image_tokens(self) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_image_tokens = self.get_num_image_tokens( + image_width=target_width, + image_height=target_height, + image_processor=None, + ) + return num_image_tokens + + def _get_max_video_frames(self, max_tokens: int) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + next_max_tokens = self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=next_num_frames, + image_processor=None, + ) + + if next_max_tokens > max_tokens: + break + + num_frames = next_num_frames + + # If the number of frames is odd, discard one frame. 
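# Editor's note -- illustrative sketch, not part of this patch. The loop above
# grows the frame count until the projected video token count would exceed the
# remaining token budget, and the check below then rounds an odd result down,
# since frames are consumed in temporal pairs (temporal_conv_size). A standalone
# model of the same search with a made-up per-pair token cost:

def largest_even_frame_count(budget: int, tokens_per_frame_pair: int = 256) -> int:
    frames = 0
    # keep adding frames while one more frame would still fit in the budget
    while (frames + 2) // 2 * tokens_per_frame_pair <= budget:
        frames += 1
    return frames - (frames % 2)  # discard one frame if the count is odd

assert largest_even_frame_count(1024) == 8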
+ if num_frames % 2 != 0: + num_frames -= 1 + + return num_frames + + def get_num_frames_with_most_features( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + max_images = mm_counts.get("image", 0) + max_videos = mm_counts.get("video", 0) + + max_image_tokens = self.get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + max_frames_per_video = min(max_total_frames // max(max_videos, 1), + _MAX_FRAMES_PER_VIDEO) + + return max(max_frames_per_video, 2) + + def get_max_video_tokens( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> int: + target_width, target_height = self.get_image_size_with_most_features() + + return self.get_num_video_tokens( + image_width=target_width, + image_height=target_height, + num_frames=self.get_num_frames_with_most_features( + seq_len, mm_counts), + image_processor=None, + ) + + +class Ernie4_5VLMultiModalProcessor( + BaseMultiModalProcessor[Ernie4_5_VLProcessingInfo]): + + def _pixel_values_norm( + self, + pixel_values: torch.Tensor, + mm_kwargs: object, + ) -> torch.Tensor: + hf_config = self.info.get_hf_config() + vision_config = hf_config.vision_config + image_processor = self.info.get_image_processor(**mm_kwargs) + image_mean_tensor = torch.tensor(image_processor.image_mean, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + image_std_tensor = torch.tensor(image_processor.image_std, + dtype=torch.float32).reshape( + [1, 3, 1, 1]) + rescale_factor = torch.tensor(image_processor.rescale_factor, + dtype=torch.float32) + patch_size_squared = vision_config.patch_size**2 + + image_mean_tensor = (image_mean_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + image_std_tensor = (image_std_tensor.squeeze( + [-2, -1]).repeat_interleave(patch_size_squared, -1)) + + if not image_mean_tensor.is_contiguous(): + image_mean_tensor = image_mean_tensor.contiguous() + if not image_std_tensor.is_contiguous(): + image_std_tensor = image_std_tensor.contiguous() + + pixel_values = (rescale_factor * pixel_values.to(torch.float32) - + image_mean_tensor) / image_std_tensor + pixel_values = pixel_values.to(hf_config.torch_dtype) + return pixel_values + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # when the prompt is not empty but the multimodal data is empty, + # directly invoke the tokenizer. + if "images" not in mm_data and "videos" not in mm_data and prompt != "": + tokenizer = self.info.get_tokenizer() + prompt_ids = tokenizer.encode(prompt) + tokenizer_output = BatchFeature(dict(input_ids=[prompt_ids]), + tensor_type="pt") + return tokenizer_output + + if "images" not in mm_data: + mm_data["images"] = [] + if "videos" not in mm_data: + mm_data["videos"] = [] + processor_output = self.info.ctx.call_hf_processor( + self.info.get_hf_processor(**mm_kwargs), + dict(text=[prompt], + images=mm_data["images"], + videos=mm_data["videos"]), + dict(**mm_kwargs, **tok_kwargs), + ) + + # Divide the processor_output into two modalities: image and video. 
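# Editor's note -- illustrative sketch, not part of this patch. The HF processor
# returns a single flat `images` tensor plus one `grid_thw` covering all items;
# rows with grid_t > 1 are treated as video, the rest as images, and the flat
# pixel tensor is split at the total number of image patches (image patches are
# assumed to come first, matching the slicing below). A small standalone
# illustration with dummy shapes:

import torch

grid_thw = torch.tensor([[1, 4, 4], [2, 4, 4]])    # one image, one 2-frame video
pixel_values = torch.randn(int(grid_thw.prod(dim=1).sum()), 588)  # dummy feature dim

video_mask = grid_thw[:, 0] > 1                    # grid_t > 1  -> video modality
image_grid_thw, video_grid_thw = grid_thw[~video_mask], grid_thw[video_mask]
num_image_patches = int(image_grid_thw.prod(dim=1).sum())
pixel_values_images = pixel_values[:num_image_patches]       # (16, 588)
pixel_values_videos = pixel_values[num_image_patches:]       # (32, 588)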
+ if processor_output is not None: + pixel_values = processor_output['images'] + if pixel_values is not None: + processor_output['images'] = self._pixel_values_norm( + pixel_values, mm_kwargs) + for key in list(processor_output.keys()): + if processor_output[key] is None: + del processor_output[key] + continue + if key == "grid_thw": + grid_thw = processor_output['grid_thw'] + pixel_values_all = processor_output['images'] + # Identify elements where the first + # dimension is greater than 1 and + # treat them as the video modality + mask = grid_thw[:, 0] > 1 + processor_output["video_grid_thw"] = grid_thw[mask] + processor_output["image_grid_thw"] = grid_thw[~mask] + image_patch_num = processor_output["image_grid_thw"].prod( + dim=1).sum() + processor_output[ + 'pixel_values'] = pixel_values_all[:image_patch_num] + processor_output['pixel_values_videos'] = pixel_values_all[ + image_patch_num:] + del processor_output['images'] + + return processor_output + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + before_placeholder = { + "image": "<|image@placeholder|>", + "video": "<|video@placeholder|>" + } + + after_placeholder = { + # image and video have same placeholder + "image": "<|IMAGE_PLACEHOLDER|>", + "video": "<|IMAGE_PLACEHOLDER|>" + } + + merge_length = hf_processor.spatial_conv_size**2 + + def get_replacement_ernie45vl(item_idx: int, modality: str): + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + if modality == "video": + num_tokens = int(grid_thw.prod( + )) // hf_processor.temporal_conv_size // merge_length + else: + num_tokens = int(grid_thw.prod()) // merge_length + return after_placeholder[modality] * num_tokens + + return [ + PromptReplacement( + modality=modality, + target=before_placeholder[modality], + replacement=partial(get_replacement_ernie45vl, + modality=modality), + ) for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + + image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3))) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3))) + video_grid_sizes = video_grid_thw.prod(-1) + + return dict( + pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", video_grid_sizes), + video_grid_thw=MultiModalFieldConfig.batched("video"), + ) + + +class Ernie4_5_VLDummyInputsBuilder( + BaseDummyInputsBuilder[Ernie4_5_VLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + prompt = "" + for i in range(num_images): + prompt += (f"Picture {i+1}:" + "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>") + + for i in range(num_videos): + prompt += (f"Video {i+1}:" + "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>") + return prompt + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_videos = 
mm_counts.get("video", 0) + + target_width, target_height = \ + self.info.get_image_size_with_most_features() + target_num_frames = \ + self.info.get_num_frames_with_most_features(seq_len, mm_counts) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos(width=target_width, + height=target_height, + num_frames=target_num_frames, + num_videos=num_videos) + } + + +@MULTIMODAL_REGISTRY.register_processor( + Ernie4_5VLMultiModalProcessor, + info=Ernie4_5_VLProcessingInfo, + dummy_inputs=Ernie4_5_VLDummyInputsBuilder) +class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + # model.resampler_model.-> language_model.model.resampler_model. + # language_model.model.resampler_model. -> resampler_model. + "language_model.model.resampler_model.": "resampler_model.", + }, + # resampler_weight_mappings + orig_to_new_substr={ + "spatial_linear.0.": "spatial_linear1.", + "spatial_linear.2.": "spatial_linear2.", + "spatial_linear.3.": "spatial_norm.", + "temporal_linear.0.": "temporal_linear1.", + "temporal_linear.2.": "temporal_linear2.", + "temporal_linear.3.": "temporal_norm.", + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|IMAGE_START|><|image@placeholder|><|IMAGE_END|>" + if modality.startswith("video"): + return "<|VIDEO_START|><|video@placeholder|><|VIDEO_END|>" + + raise ValueError("Only image or video modality is supported") + + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_model = Ernie4_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.language_model = Ernie4_5_VLMoeForCausalLM( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.resampler_model = VariableResolutionResamplerModel( + self.config.pixel_hidden_size, + self.config.hidden_size, + self.config.spatial_conv_size, + self.config.temporal_conv_size, + config=self.config, + prefix=maybe_prefix(prefix, "resampler_model")) + + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + """compute logits""" + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def _vision_forward( + self, + pixel_values: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + if grid_thw is not None: + grid_thw = grid_thw[grid_thw > 0] + if grid_thw.numel() % 3 != 0: + raise ValueError( + f"grid_thw has {grid_thw.numel()} elements after filtering," + "which is not divisible by 3.") + grid_thw = 
grid_thw.reshape(-1, 3) + # example: [[1,64,64],[2,80,80]] -> [[1,64,64],[1,80,80],[1,80,80]] + grid_thw = F.pad( + torch.repeat_interleave(grid_thw[:, 1:], grid_thw[:, 0], 0), + [1, 0, 0, 0], + value=1, + ) + image_features = self.vision_model(pixel_values, grid_thw) + return image_features + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + if getattr(self.config, "im_patch_id", None) is not None: + self.visual_token_mask = ( + input_ids == self.config.im_patch_id).reshape(-1, 1) + else: + self.visual_token_mask = None + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Ernie4_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Ernie4_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Ernie4_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input( + self, + image_input: Ernie4_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values = image_input["pixel_values"].type( + self.vision_model.dtype) + image_features = self._vision_forward(pixel_values=pixel_values, + grid_thw=grid_thw) + image_embeds = self.resampler_model(image_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Ernie4_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + pixel_values_videos = video_input["pixel_values_videos"].type( + self.vision_model.dtype) + video_features = self._vision_forward(pixel_values=pixel_values_videos, + grid_thw=grid_thw) + video_embeds = 
self.resampler_model(video_features, grid_thw) + + merge_size = self.vision_model.spatial_merge_size + sizes = (grid_thw.prod(-1) // + self.config.temporal_conv_size) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + + if multimodal_embeddings is None: + return inputs_embeds + + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings(input_ids, inputs_embeds, + multimodal_embeddings, + [self.config.im_patch_id]) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + if self.visual_token_mask is not None: + + if self.visual_token_mask.shape[0] != inputs_embeds.shape[0]: + padding_len = inputs_embeds.shape[ + 0] - self.visual_token_mask.shape[0] + # right pad False + pad = torch.zeros( + (padding_len, self.visual_token_mask.shape[1]), + dtype=self.visual_token_mask.dtype, + device=self.visual_token_mask.device) + self.visual_token_mask = torch.cat( + [self.visual_token_mask, pad], dim=0) + + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model( + **forward_kwargs, + **kwargs, + ) + + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/ernie45_vl_moe.py b/vllm/model_executor/models/ernie45_vl_moe.py new file mode 100644 index 
000000000000..f56c09843515 --- /dev/null +++ b/vllm/model_executor/models/ernie45_vl_moe.py @@ -0,0 +1,723 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The Baidu team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Ernie VL model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +# from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import ( + Ernie4_5_VLRotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .ernie45_moe import Ernie4_5_MoeMLP +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Ernie4_5_VLMoeMLP(Ernie4_5_MoeMLP): + pass + + +class Ernie4_5_VLMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: Optional[int] = None, + rope_theta: float = 500000, + rope_scaling: Optional[dict[str, Any]] = None, + freq_allocation: int = 20, + max_position_embeddings: int = 131072, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + layer_idx = extract_layer_index(prefix) if len(prefix) > 0 else 0 + self.layer_idx = layer_idx + self.hidden_size = hidden_size + tp_size =
get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + t_rope = freq_allocation + h_rope = (self.head_dim // 2 - freq_allocation) // 2 + w_rope = (self.head_dim // 2 - freq_allocation) // 2 + + self.rotary_emb = Ernie4_5_VLRotaryEmbedding( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + dtype=torch.get_default_dtype(), + mrope_section=[h_rope, w_rope, t_rope]) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.qkv_proj(hidden_states) + + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + + # Attention + attn_output = self.attn(q, k, v) + # Output projection + output, _ = self.o_proj(attn_output) + return output + + +class Ernie4_5_VLMoeMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + self.tp_size = get_tensor_model_parallel_world_size() + self.has_shared_experts = (getattr(config, "moe_num_shared_experts", 0) + > 0) + self.hidden_size = config.hidden_size + + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + + if self.tp_size > max_moe_num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {moe_num_experts}.") + + moe_layer_start_index = config.moe_layer_start_index + text_moe_layer_start_index = moe_layer_start_index[0] + vision_moe_layer_start_index = moe_layer_start_index[1] + moe_layer_end_index = config.moe_layer_end_index + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + text_moe_layer_end_index = moe_layer_end_index[0] + vision_moe_layer_end_index = moe_layer_end_index[1] + + assert config.moe_num_experts[0] == config.moe_num_experts[1] + 
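# moe_statics in the checkpoint provides a single (2, num_experts)
 + # routing-score correction bias: row 0 feeds the text-expert router and
 + # row 1 the vision-expert router (see the FusedMoE blocks below). +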
self.e_score_correction_bias = nn.Parameter( + torch.empty(2, config.moe_num_experts[0])) + + assert text_moe_layer_start_index <= text_moe_layer_end_index + + if layer_idx >= text_moe_layer_start_index and \ + layer_idx <= text_moe_layer_end_index: + self.text_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[0], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.text_experts_gate") + + self.text_experts = FusedMoE( + num_experts=config.moe_num_experts[0], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[0], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[0], + prefix=f"{prefix}.text_experts") + else: + self.text_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + assert vision_moe_layer_start_index <= vision_moe_layer_end_index + if layer_idx >= vision_moe_layer_start_index and \ + layer_idx <= vision_moe_layer_end_index: + self.vision_experts_gate = ReplicatedLinear( + config.hidden_size, + config.moe_num_experts[1], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.vision_experts_gate") + + self.vision_experts = FusedMoE( + num_experts=config.moe_num_experts[1], + top_k=config.moe_k, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size[1], + reduce_results=False, + renormalize=True, + quant_config=quant_config, + e_score_correction_bias=self.e_score_correction_bias[1], + prefix=f"{prefix}.vision_experts") + else: + self.vision_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + if self.has_shared_experts: + intermediate_size = (config.moe_intermediate_size[0] * + config.moe_num_shared_experts) + self.shared_experts = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.shared_experts", + reduce_results=self.text_experts. 
+ must_reduce_shared_expert_outputs()) + + def forward( + self, + hidden_states: torch.Tensor, + visual_token_mask: torch.Tensor, + **kwargs: object, + ) -> torch.Tensor: + + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.has_shared_experts: + shared_output = self.shared_experts(hidden_states) + + if visual_token_mask is not None and visual_token_mask.any(): + # assert visual_token_mask.shape[0] != hidden_states.shape[0] + visual_token_mask = visual_token_mask.repeat( + 1, self.hidden_size).bool() + text_token_mask = ~visual_token_mask + final_hidden_states = torch.zeros_like(hidden_states) + + text_hidden_states = hidden_states[text_token_mask].reshape( + -1, self.hidden_size) + vision_hidden_states = hidden_states[visual_token_mask].reshape( + -1, self.hidden_size) + + text_router_logits, _ = self.text_experts_gate(text_hidden_states) + final_hidden_states[text_token_mask] = self.text_experts( + hidden_states=text_hidden_states, + router_logits=text_router_logits).flatten() + + vision_router_logits, _ = self.vision_experts_gate( + vision_hidden_states) + final_hidden_states[visual_token_mask] = self.vision_experts( + hidden_states=vision_hidden_states, + router_logits=vision_router_logits).flatten() + else: + # text modal input processing directly + text_router_logits, _ = self.text_experts_gate(hidden_states) + + final_hidden_states = self.text_experts( + hidden_states=hidden_states, router_logits=text_router_logits) + + if self.has_shared_experts and \ + shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = ( + self.text_experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + + return final_hidden_states.view(orig_shape) + + +class Ernie4_5_VLMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 500000) + rope_scaling = getattr(config, "rope_scaling", None) + freq_allocation = getattr(config, "freq_allocation", 20) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + + self.self_attn = Ernie4_5_VLMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=getattr(config, 'head_dim', None), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + freq_allocation=freq_allocation, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'use_bias', False), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + layer_idx = extract_layer_index(prefix) + self.layer_idx = layer_idx + + # MoE + moe_layer_start_index = config.moe_layer_start_index + min_moe_layer_start_index = min(moe_layer_start_index) + moe_layer_end_index = getattr( + config, "moe_layer_end_index", + [config.num_hidden_layers - 1, config.num_hidden_layers - 1]) + max_moe_layer_end_index = max(moe_layer_end_index) + assert min_moe_layer_start_index <= max_moe_layer_end_index + moe_num_experts = config.moe_num_experts + max_moe_num_experts = max(moe_num_experts) + moe_layer_interval = getattr(config, "moe_layer_interval", 1) + use_moe = getattr(config, "use_moe", 
max_moe_num_experts > 0) + + if (use_moe and ((layer_idx + 1) % moe_layer_interval == 0) + and layer_idx >= min_moe_layer_start_index + and layer_idx <= max_moe_layer_end_index): + self.mlp = Ernie4_5_VLMoeMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Ernie4_5_VLMoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + use_bias=getattr(config, 'use_bias', False), + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + visual_token_mask: Optional[torch.Tensor], + **kwargs: object, + ) -> torch.Tensor: + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + if isinstance(self.mlp, Ernie4_5_VLMoeMoE): + hidden_states = self.mlp(hidden_states, visual_token_mask, + **kwargs) + else: + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# Since Ernie VL distinguishes between text experts and vision experts, +# enabling torch.compile will cause errors. +# @support_torch_compile( +# dynamic_arg_dims={ +# "input_ids": 0, +# "positions": -1, +# "intermediate_tensors": 0, +# "inputs_embeds": 0, +# "visual_token_mask": 0, +# }) +class Ernie4_5_VLMoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.im_patch_id = config.im_patch_id + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Ernie4_5_VLMoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + visual_token_mask: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + 
hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual, + visual_token_mask, **kwargs) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +# only used as text backbone for ernie4.5-vl +class Ernie4_5_VLMoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Ernie4_5_VLMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=max(self.config.moe_num_experts)) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.config.tie_word_embeddings and name.endswith( + "lm_head.weight"): + loaded_params.add("lm_head.weight") + continue + # MTP will be supported soon. 
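+ # Vision tower (vision_model) and resampler weights belong to the
+ # multimodal wrapper in ernie45_vl.py, so this text-only backbone skips them.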
+ if "mtp" in name or \ + "vision_model" in name or \ + "resampler_model" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Distinguish between vision experts and text experts + if "mlp.experts" in name: + moe_offset = int(name.split(".")[-3]) + vision_expert_start_idx = self.config.moe_num_experts[0] + is_text_expert = \ + moe_offset <= vision_expert_start_idx - 1 + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace( + f".experts.{moe_offset}", + f".vision_experts.{moe_offset-vision_expert_start_idx}" + ) + + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + + if weight_name not in name: + continue + + # Distinguish between vision experts and text experts + moe_offset = int(name.split(".")[-3]) + is_text_expert = \ + moe_offset <= self.config.moe_num_experts[0] - 1 + + name = name.replace(weight_name, param_name) + if is_text_expert: + name = name.replace(".experts.", ".text_experts.") + else: + name = name.replace(".experts.", ".vision_experts.") + + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Distinguish between vision expert gate + # and text expert gate + if name.endswith("mlp.gate.weight"): + name = name.replace("gate.weight", + "text_experts_gate.weight") + loaded_weight = loaded_weight.T + elif name.endswith("mlp.gate.weight_1"): + name = name.replace("gate.weight_1", + "vision_experts_gate.weight") + loaded_weight = loaded_weight.T + + if "e_score_correction_bias" in name: + name = name.replace(".moe_statics.", ".") + + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index ebf78771e40a..c65c58d4a047 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -206,6 +206,7 @@ "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "Cohere2VisionForConditionalGeneration": ("cohere2_vision", "Cohere2VisionForConditionalGeneration"), # noqa: E501 "DeepseekVLV2ForCausalLM": ("deepseek_vl2", "DeepseekVLV2ForCausalLM"), + "Ernie4_5_VLMoeForConditionalGeneration": ("ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration"), # noqa: E501 "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ("gemma3n_mm", "Gemma3nForConditionalGeneration"), # noqa: E501 From 32102644213a6367d10ec3a92ae76fb0004f3a52 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 26 Aug 2025 21:58:59 -0700 Subject: [PATCH 002/125] [Frontend] Add --log-error-stack to print stack trace for error response (#22960) Signed-off-by: Chen Zhang --- vllm/entrypoints/openai/api_server.py | 10 ++++++++++ vllm/entrypoints/openai/cli_args.py | 2 ++ vllm/entrypoints/openai/serving_chat.py | 4 +++- vllm/entrypoints/openai/serving_classification.py | 2 ++ vllm/entrypoints/openai/serving_completion.py | 2 ++ vllm/entrypoints/openai/serving_embedding.py | 4 +++- vllm/entrypoints/openai/serving_engine.py | 9 +++++++++ vllm/entrypoints/openai/serving_pooling.py | 4 +++- vllm/entrypoints/openai/serving_responses.py | 2 ++ vllm/entrypoints/openai/serving_score.py | 4 +++- vllm/entrypoints/openai/serving_tokenization.py | 4 +++- vllm/entrypoints/openai/serving_transcription.py | 8 ++++++-- vllm/entrypoints/openai/speech_to_text.py | 4 +++- 13 files changed, 51 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index db02767fdfd7..9a2470649c8d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -1749,6 +1749,7 @@ async def init_app_state( enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_chat = OpenAIServingChat( engine_client, @@ -1767,6 +1768,7 @@ async def init_app_state( enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, enable_log_outputs=args.enable_log_outputs, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, @@ -1776,6 +1778,7 @@ async def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, + log_error_stack=args.log_error_stack, ) if "generate" in supported_tasks else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, @@ -1784,6 +1787,7 @@ async def 
init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "encode" in supported_tasks else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, @@ -1792,12 +1796,14 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) if "embed" in supported_tasks else None state.openai_serving_classification = ServingClassification( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "classify" in supported_tasks else None enable_serving_reranking = ("classify" in supported_tasks and getattr( @@ -1807,6 +1813,7 @@ async def init_app_state( model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if ("embed" in supported_tasks or enable_serving_reranking) else None state.openai_serving_tokenization = OpenAIServingTokenization( @@ -1816,18 +1823,21 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, + log_error_stack=args.log_error_stack, ) state.openai_serving_transcription = OpenAIServingTranscription( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.openai_serving_translation = OpenAIServingTranslation( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, + log_error_stack=args.log_error_stack, ) if "transcription" in supported_tasks else None state.enable_server_load_tracking = args.enable_server_load_tracking diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 6e4eff5c8024..d0b5d013eb9e 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -180,6 +180,8 @@ class FrontendArgs: h11_max_header_count: int = H11_MAX_HEADER_COUNT_DEFAULT """Maximum number of HTTP headers allowed in a request for h11 parser. Helps mitigate header abuse. 
Default: 256.""" + log_error_stack: bool = envs.VLLM_SERVER_DEV_MODE + """If set to True, log the stack trace of error responses""" @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 7e0e62778097..1c0ffdfb9189 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -76,13 +76,15 @@ def __init__( enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack) self.response_role = response_role self.chat_template = chat_template diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 377f7f684717..1d510d0b60a2 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -129,12 +129,14 @@ def __init__( models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, + log_error_stack=log_error_stack, ) async def create_classify( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index a0ce65409403..b81fd63ece7a 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -59,6 +59,7 @@ def __init__( return_tokens_as_token_ids: bool = False, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__( engine_client=engine_client, @@ -67,6 +68,7 @@ def __init__( request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = ( diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 9dcad8e391c6..45c1932f1873 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -593,11 +593,13 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0f4a7c0186b6..a97935e109ef 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,6 +5,7 @@ import json import sys import time +import traceback from collections.abc import AsyncGenerator, 
Iterable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus @@ -205,6 +206,7 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, enable_force_include_usage: bool = False, + log_error_stack: bool = False, ): super().__init__() @@ -222,6 +224,7 @@ def __init__( self._async_tokenizer_pool: dict[AnyTokenizer, AsyncMicrobatchTokenizer] = {} + self.log_error_stack = log_error_stack def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: """ @@ -412,6 +415,12 @@ def create_error_response( message: str, err_type: str = "BadRequestError", status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + if self.log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() return ErrorResponse(error=ErrorInfo( message=message, type=err_type, code=status_code.value)) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 38745d001ade..e8cb1aed8459 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -58,11 +58,13 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 67eec2d523e3..899cb07b2b37 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -88,6 +88,7 @@ def __init__( enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, enable_log_outputs: bool = False, + log_error_stack: bool = False, ) -> None: super().__init__( engine_client=engine_client, @@ -96,6 +97,7 @@ def __init__( request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, enable_force_include_usage=enable_force_include_usage, + log_error_stack=log_error_stack, ) self.chat_template = chat_template diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index c246274514db..37838e22a400 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -47,11 +47,13 @@ def __init__( models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], + log_error_stack: bool = False, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) async def _embedding_score( self, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 58d720474768..2f258255d5f1 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -39,11 +39,13 @@ def __init__( request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, + log_error_stack: bool = False, ) -> None: 
super().__init__(engine_client=engine_client, model_config=model_config, models=models, - request_logger=request_logger) + request_logger=request_logger, + log_error_stack=log_error_stack) self.chat_template = chat_template self.chat_template_content_format: Final = chat_template_content_format diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py index 0d6989fe91bf..9ba58d442522 100644 --- a/vllm/entrypoints/openai/serving_transcription.py +++ b/vllm/entrypoints/openai/serving_transcription.py @@ -32,13 +32,15 @@ def __init__( *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="transcribe") + task_type="transcribe", + log_error_stack=log_error_stack) async def create_transcription( self, audio_data: bytes, request: TranscriptionRequest, @@ -88,13 +90,15 @@ def __init__( *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids, - task_type="translate") + task_type="translate", + log_error_stack=log_error_stack) async def create_translation( self, audio_data: bytes, request: TranslationRequest, diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index de2619a78f8e..1cbd7dba393f 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -53,12 +53,14 @@ def __init__( request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, task_type: Literal["transcribe", "translate"] = "transcribe", + log_error_stack: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, models=models, request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids) + return_tokens_as_token_ids=return_tokens_as_token_ids, + log_error_stack=log_error_stack) self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) From 142ac0803045b3a3edcd7aa58fe079872903a30c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 26 Aug 2025 21:59:14 -0700 Subject: [PATCH 003/125] [Frontend] Optimize beam search performance by limiting concurrency (#23599) Signed-off-by: Chen Zhang --- benchmarks/benchmark_throughput.py | 1 - tests/conftest.py | 8 +- tests/samplers/test_beam_search.py | 53 ++++++++++ vllm/entrypoints/llm.py | 152 ++++++++++++++++------------- 4 files changed, 143 insertions(+), 71 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c7f290e1eb88..6b24b8c8f3c6 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -96,7 +96,6 @@ def run_vllm( end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" - prompts = [request.prompt for request in requests] # output_len should be the same for all requests. 
output_len = requests[0].expected_output_len for request in requests: diff --git a/tests/conftest.py b/tests/conftest.py index 2bf88abb0f6c..f8bfdfc8e625 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1022,15 +1022,17 @@ def generate_beam_search( images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, + concurrency_limit: Optional[int] = None, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - outputs = self.llm.beam_search( - inputs, - BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) + outputs = self.llm.beam_search(inputs, + BeamSearchParams(beam_width=beam_width, + max_tokens=max_tokens), + concurrency_limit=concurrency_limit) returned_outputs = [] for output in outputs: token_ids = [x.tokens for x in output.sequences] diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index bdf48c7687b2..cc9a88a255f9 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -67,6 +67,59 @@ def test_beam_search_single_input( f"vLLM: {vllm_output_ids}") +@pytest.mark.skip_v1 # FIXME: This fails on V1 right now. +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", MAX_TOKENS) +@pytest.mark.parametrize("beam_width", BEAM_WIDTHS) +def test_beam_search_with_concurrency_limit( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + beam_width: int, +) -> None: + # example_prompts[1]&[3]&[7] fails due to unknown reason even without + # concurency limit. skip them for now. + example_prompts = (example_prompts[:8]) + concurrency_limit = 2 + assert len(example_prompts) > concurrency_limit + with vllm_runner(model, dtype=dtype) as vllm_model: + outputs_with_limit = vllm_model.generate_beam_search( + example_prompts, + beam_width, + max_tokens, + concurrency_limit=concurrency_limit) + outputs_without_limit = [] + + for i in range(0, len(example_prompts), concurrency_limit): + outputs_without_limit.extend( + vllm_model.generate_beam_search( + example_prompts[i:i + concurrency_limit], beam_width, + max_tokens)) + + correct = True + for i in range(len(example_prompts)): + output_ids_with_limit, output_texts_with_limit = outputs_with_limit[i] + output_ids_without_limit, output_texts_without_limit = ( + outputs_without_limit[i]) + for j, (text_with_limit, text_without_limit) in enumerate( + zip(output_texts_with_limit, output_texts_without_limit)): + print(f">>>{j}-th with limit output:") + print(text_with_limit) + print(f">>>{j}-th without limit output:") + print(text_without_limit) + assert len(output_ids_with_limit) == len(output_ids_without_limit) + for j in range(len(output_ids_with_limit)): + if output_ids_with_limit[j] != output_ids_without_limit[j]: + print(f"Test{i} output{j}:\n+limit: {output_ids_with_limit}\n" + f"-limit: {output_ids_without_limit}") + correct = False + assert correct + + @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", MAX_TOKENS) @pytest.mark.parametrize("beam_width", MM_BEAM_WIDTHS) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8816ff56d684..72b6123670b7 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -523,6 +523,7 @@ def beam_search( params: BeamSearchParams, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, use_tqdm: bool = False, + 
concurrency_limit: Optional[int] = None, ) -> list[BeamSearchOutput]: """ Generate sequences using beam search. @@ -533,6 +534,8 @@ def beam_search( params: The beam search parameters. lora_request: LoRA request to use for generation, if any. use_tqdm: Whether to use tqdm to display the progress bar. + concurrency_limit: The maximum number of concurrent requests. + If None, the number of concurrent requests is unlimited. """ # TODO: how does beam search work together with length penalty, # frequency, penalty, and stopping criteria, etc.? @@ -551,6 +554,15 @@ def beam_search( length_penalty, ) + if use_tqdm and concurrency_limit is not None: + logger.warning( + "Progress bar is not supported when using concurrency_limit. " + "Disabling progress bar.") + use_tqdm = False + + if concurrency_limit is None: + concurrency_limit = len(prompts) + def create_tokens_prompt_from_beam( beam: BeamSearchSequence) -> TokensPrompt: token_prompt_kwargs: TokensPrompt = { @@ -595,73 +607,79 @@ def create_tokens_prompt_from_beam( **mm_kwargs, ), ) - token_iter = range(max_tokens) - if use_tqdm: - token_iter = tqdm(token_iter, - desc="Beam search", - unit="token", - unit_scale=False) - logger.warning( - "The progress bar shows the upper bound on token steps and " - "may finish early due to stopping conditions. It does not " - "reflect instance-level progress.") - - for _ in token_iter: - all_beams: list[BeamSearchSequence] = list( - sum((instance.beams for instance in instances), [])) - pos = [0] + list( - itertools.accumulate( - len(instance.beams) for instance in instances)) - instance_start_and_end: list[tuple[int, int]] = list( - zip(pos[:-1], pos[1:])) - - if len(all_beams) == 0: - break - - # create the corresponding batch entries for prompt & optional lora - prompts_batch, lora_req_batch = zip( - *[(create_tokens_prompt_from_beam(beam), beam.lora_request) - for beam in all_beams]) - - # only runs for one step - # we don't need to use tqdm here - output = self.generate(prompts_batch, - sampling_params=beam_search_params, - use_tqdm=False, - lora_request=lora_req_batch) - - for (start, end), instance in zip(instance_start_and_end, - instances): - instance_new_beams = [] - for i in range(start, end): - current_beam = all_beams[i] - result = output[i] - - if result.outputs[0].logprobs is not None: - # if `result.outputs[0].logprobs` is None, it means - # the sequence is completed because of the max-model-len - # or abortion. we don't need to add it to the new beams. - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - logprobs=current_beam.logprobs + [logprobs], - lora_request=current_beam.lora_request, - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob, - multi_modal_data=current_beam.multi_modal_data, - mm_processor_kwargs=current_beam. 
- mm_processor_kwargs) - - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - instance.completed.append(new_beam) - else: - instance_new_beams.append(new_beam) - sorted_beams = sorted(instance_new_beams, - key=sort_beams_key, - reverse=True) - instance.beams = sorted_beams[:beam_width] + for prompt_start in range(0, len(prompts), concurrency_limit): + instances_batch = instances[prompt_start:prompt_start + + concurrency_limit] + + token_iter = range(max_tokens) + if use_tqdm: + token_iter = tqdm(token_iter, + desc="Beam search", + unit="token", + unit_scale=False) + logger.warning( + "The progress bar shows the upper bound on token steps and " + "may finish early due to stopping conditions. It does not " + "reflect instance-level progress.") + for _ in token_iter: + all_beams: list[BeamSearchSequence] = list( + sum((instance.beams for instance in instances_batch), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances_batch)) + instance_start_and_end: list[tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + # create corresponding batch entries for prompt & optional lora + prompts_batch, lora_req_batch = zip( + *[(create_tokens_prompt_from_beam(beam), beam.lora_request) + for beam in all_beams]) + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False, + lora_request=lora_req_batch) + + for (start, end), instance in zip(instance_start_and_end, + instances_batch): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the + # max-model-len or abortion. we don't need to add + # it to the new beams. + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + + [logprobs], + lora_request=current_beam.lora_request, + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob, + multi_modal_data=current_beam. + multi_modal_data, + mm_processor_kwargs=current_beam. 
+ mm_processor_kwargs) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=sort_beams_key, + reverse=True) + instance.beams = sorted_beams[:beam_width] outputs = [] for instance in instances: From d272415e57c95da63c798c22c7d87cc5c0cda21f Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Wed, 27 Aug 2025 01:00:21 -0400 Subject: [PATCH 004/125] [Quantization] Expand compressed-tensors MoE matching logic to support NFP4 + FP8 MoEs (#22674) Signed-off-by: Dipika Sikka Signed-off-by: Dipika --- .../compressed_tensors/compressed_tensors.py | 13 +++---- .../compressed_tensors_moe.py | 36 +++++++++++++++++-- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ce74375aab42..245cf122ebab 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -425,6 +425,10 @@ def _get_scheme_from_parts( weight_quant: BaseModel, input_quant: BaseModel, format: Optional[str] = None) -> "CompressedTensorsScheme": + + # use the per-layer format if defined, otherwise, use global format + format = format if format is not None else self.quant_format + # Detect If Mixed Precision if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() @@ -437,14 +441,14 @@ def _get_scheme_from_parts( actorder=weight_quant.actorder) if self._is_wNa16_group_channel(weight_quant, input_quant): - if (self.quant_format == CompressionFormat.marlin_24.value + if (format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): assert weight_quant.symmetric return CompressedTensorsW4A16Sparse24( strategy=weight_quant.strategy, num_bits=weight_quant.num_bits, group_size=weight_quant.group_size) - if (self.quant_format == CompressionFormat.pack_quantized.value + if (format == CompressionFormat.pack_quantized.value and weight_quant.num_bits in WNA16_SUPPORTED_BITS): return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, @@ -453,10 +457,7 @@ def _get_scheme_from_parts( group_size=weight_quant.group_size, actorder=weight_quant.actorder) - act_quant_format = is_activation_quantization_format( - format - ) if format is not None else is_activation_quantization_format( - self.quant_format) + act_quant_format = is_activation_quantization_format(format) if act_quant_format: if self._is_fp4a4_nvfp4(weight_quant, input_quant): if cutlass_fp4_supported( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 1ee3478aa4f4..6279bb8b6057 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -22,6 +22,8 @@ is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + find_matched_target) from vllm.model_executor.layers.quantization.utils import 
replace_parameter from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1, @@ -65,12 +67,40 @@ def __init_(self, moe: FusedMoEConfig): @staticmethod def get_moe_method( quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 - layer: torch.nn.Module, + layer: torch.nn.Module ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. - weight_quant = quant_config.target_scheme_map["Linear"].get("weights") - input_quant = quant_config.target_scheme_map["Linear"].get( + # Check if a using "Linear" to select scheems + if "Linear" in quant_config.target_scheme_map: + matched_target = "Linear" + else: + # May have instead defined the linear layers in the fused model + + fused_layers = [ + "re:.*down_proj.*", "re:.*gate_proj.*", "re:.*up_proj.*" + ] + current_scheme = None + for fused_layer in fused_layers: + # Check if one of the fused layers are defined in quant_config + matched_target = find_matched_target( + layer_name=fused_layer, + module=layer, + targets=quant_config.target_scheme_map.keys(), + fused_mapping=quant_config.packed_modules_mapping) + + # Only valid if down_proj, gate_proj, and up_proj + # are mapped to the same quant scheme in the quant_config + if current_scheme is None: + current_scheme = quant_config.target_scheme_map.get( + matched_target) + else: + assert current_scheme == quant_config.target_scheme_map.get( + matched_target) + + weight_quant = quant_config.target_scheme_map[matched_target].get( + "weights") + input_quant = quant_config.target_scheme_map[matched_target].get( "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): From fce10dbed5441b4f918b23a2b63aae72bc00a2f6 Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Wed, 27 Aug 2025 13:33:27 +0800 Subject: [PATCH 005/125] [XPU] Add xpu torch.compile support (#22609) Signed-off-by: Kunshang Ji --- .buildkite/scripts/hardware_ci/run-xpu-test.sh | 1 + vllm/attention/layer.py | 3 +-- vllm/compilation/fix_functionalization.py | 8 ++++++++ vllm/platforms/cpu.py | 4 ++++ vllm/platforms/cuda.py | 4 ++++ vllm/platforms/interface.py | 8 ++++++++ vllm/platforms/rocm.py | 4 ++++ vllm/platforms/xpu.py | 15 ++++++--------- 8 files changed, 36 insertions(+), 11 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index 445cd2735c19..73f3e63fbf5f 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -31,6 +31,7 @@ docker run \ set -e echo $ZE_AFFINITY_MASK VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 -O3 -O.cudagraph_mode=NONE VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp cd tests diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 2d288bcbe0c9..237802afccde 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -190,8 +190,7 @@ def __init__( # torch.compile works by 
registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = not current_platform.is_cuda_alike( - ) and not current_platform.is_cpu() + self.use_direct_call = not current_platform.opaque_attention_op() self.use_output = self.attn_backend.accept_output_buffer compilation_config = get_current_vllm_config().compilation_config diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 286221d32c1e..60ae14331879 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -9,6 +9,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from vllm.logger import init_logger +from vllm.platforms import current_platform from .fx_utils import is_func from .vllm_inductor_pass import VllmInductorPass @@ -26,6 +27,13 @@ class FixFunctionalizationPass(VllmInductorPass): """ def __call__(self, graph: torch.fx.Graph): + # XPU does not support auto-functionalization yet. + # Will enable this when switching to vllm-xpu-kernels. + if current_platform.is_xpu(): + logger.debug("XPU platform does not support fix functionalization " + "pass currently.") + return + self.begin() self.dump_graph(graph, "before_fix_functionalization") diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index c748595a7153..5686fae5cd7d 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -335,3 +335,7 @@ def default_v1(cls, model_config) -> bool: return (cls.supports_v1(model_config) and arch in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM, CpuArchEnum.S390X)) + + @classmethod + def opaque_attention_op(cls) -> bool: + return True diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c0e0fe35e402..5cbb7346436e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -442,6 +442,10 @@ def supports_v1(cls, model_config: "ModelConfig") -> bool: def use_custom_allreduce(cls) -> bool: return True + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_static_graph_wrapper_cls(cls) -> str: return "vllm.compilation.cuda_graph.CUDAGraphWrapper" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f6c17de86d05..01f3e2d977bc 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -509,6 +509,14 @@ def use_custom_allreduce(cls) -> bool: """ return False + @classmethod + def opaque_attention_op(cls) -> bool: + """ + Returns True if we register attention as one giant opaque custom op + on the current platform + """ + return False + @classmethod def validate_request( cls, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 85b2fe2e480c..c6d14aa87c7f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -411,6 +411,10 @@ def use_custom_allreduce(cls) -> bool: supported_archs = ['gfx94', 'gfx95'] return any(gfx in gcn_arch for gfx in supported_archs) + @classmethod + def opaque_attention_op(cls) -> bool: + return True + @classmethod def get_cu_count(cls, device_id: int = 0) -> int: return torch.cuda.get_device_properties( diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 235e5d8294e5..84f4cd725646 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -90,21 +90,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if cache_config and cache_config.block_size is None: cache_config.block_size = 64 - # FIXME: Temporarily forcing eager
mode - # remove after t.compile support stabilizes. - if (envs.VLLM_USE_V1 and model_config is not None - and not vllm_config.model_config.enforce_eager): - from vllm.config import CompilationLevel - vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501 - # lazy import to avoid circular import from vllm.config import CUDAGraphMode compilation_config = vllm_config.compilation_config if compilation_config.cudagraph_mode is None or \ compilation_config.cudagraph_mode.max_cudagraph_mode() \ != CUDAGraphMode.NONE: - logger.info("[XPU] CUDA graph is not supported on XPU, " - "disabling cudagraphs.") + logger.info("[XPU] CUDA graph is not supported on XPU, disabling " + "cudagraphs. Fallback to cudagraph_mode=NONE") compilation_config.cudagraph_mode = CUDAGraphMode.NONE # check and update parallel config @@ -182,3 +175,7 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype): "Intel Arc A770 have bfloat16 accuracy known issue. " "You can use float16 instead by explicitly setting the " "`dtype` flag in CLI, for example: --dtype=half.") + + @classmethod + def opaque_attention_op(cls) -> bool: + return True From 9de25c294b92e42a12d1fbbb3ab3f633fa80291c Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 27 Aug 2025 13:51:50 +0800 Subject: [PATCH 006/125] [CI/Build] Remove redundant LoRA model tests (#23706) Signed-off-by: Jee Jee Li --- tests/lora/conftest.py | 5 -- tests/lora/test_baichuan.py | 112 ------------------------------------ tests/lora/test_phi.py | 71 ----------------------- 3 files changed, 188 deletions(-) delete mode 100644 tests/lora/test_baichuan.py delete mode 100644 tests/lora/test_phi.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index cba573b63c04..3475993ff8f0 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -216,11 +216,6 @@ def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") -@pytest.fixture(scope="session") -def phi2_lora_files(): - return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") - - @pytest.fixture def reset_default_device(): """ diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py deleted file mode 100644 index 774ebb9db210..000000000000 --- a/tests/lora/test_baichuan.py +++ /dev/null @@ -1,112 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -import vllm -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_PATH = "baichuan-inc/Baichuan-7B" - -PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. 
concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format(query="How many singers do we have?"), - PROMPT_TEMPLATE.format( - query= - "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - query= - "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 - ), - ] - print(prompts) - sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_baichuan_lora(baichuan_lora_files): - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True) - - expected_lora_output = [ - "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501 - "SELECT name , country , age FROM singer ORDER BY age ASC", - ] - - output1 = do_sample(llm, baichuan_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i] == expected_lora_output[i] - output2 = do_sample(llm, baichuan_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i] == expected_lora_output[i] - - -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_baichuan_tensor_parallel_equality(baichuan_lora_files, - num_gpus_available, fully_sharded): - if num_gpus_available < 4: - pytest.skip(f"Not enough GPUs for tensor parallelism {4}") - - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) - - del llm_tp1 - cleanup_dist_env_and_memory() - - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=2, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) - - del llm_tp2 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp2 - - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - max_lora_rank=64, - tensor_parallel_size=4, - trust_remote_code=True, - fully_sharded_loras=fully_sharded) - output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) - - del llm_tp4 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp4 diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py deleted file mode 100644 index 3090941e6367..000000000000 --- a/tests/lora/test_phi.py +++ /dev/null @@ -1,71 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import vllm -from vllm.lora.request import 
LoRARequest - -MODEL_PATH = "microsoft/phi-2" - -PROMPT_TEMPLATE = "### Instruct: {sql_prompt}\n\n### Context: {context}\n\n### Output:" # noqa: E501 - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: - prompts = [ - PROMPT_TEMPLATE.format( - sql_prompt= - "Which catalog publisher has published the most catalogs?", - context="CREATE TABLE catalogs (catalog_publisher VARCHAR);"), - PROMPT_TEMPLATE.format( - sql_prompt= - "Which trip started from the station with the largest dock count? Give me the trip id.", # noqa: E501 - context= - "CREATE TABLE trip (id VARCHAR, start_station_id VARCHAR); CREATE TABLE station (id VARCHAR, dock_count VARCHAR);" # noqa: E501 - ), - PROMPT_TEMPLATE.format( - sql_prompt= - "How many marine species are found in the Southern Ocean?", # noqa: E501 - context= - "CREATE TABLE marine_species (name VARCHAR(50), common_name VARCHAR(50), location VARCHAR(50));" # noqa: E501 - ), - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=64, - stop="### End") - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. - generated_texts: list[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -def test_phi2_lora(phi2_lora_files): - # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, - # Otherwise, the lora-test will fail due to CUDA OOM. - llm = vllm.LLM(MODEL_PATH, - max_model_len=1024, - enable_lora=True, - max_loras=2, - enforce_eager=True, - enable_chunked_prefill=True) - - expected_lora_output = [ - "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501 - "SELECT trip.id FROM trip JOIN station ON trip.start_station_id = station.id WHERE station.dock_count = (SELECT MAX(dock_count) FROM station);", # noqa: E501 - "SELECT COUNT(*) FROM marine_species WHERE location = 'Southern Ocean';", # noqa: E501 - ] - - output1 = do_sample(llm, phi2_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i].startswith(expected_lora_output[i]) - output2 = do_sample(llm, phi2_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i].startswith(expected_lora_output[i]) From 8dbf6ed7be3f8602257ce1879825d4b5e3554d67 Mon Sep 17 00:00:00 2001 From: "rongfu.leng" Date: Wed, 27 Aug 2025 13:54:39 +0800 Subject: [PATCH 007/125] [Bugfix] fix when config.yaml config value is list parse error (#23528) Signed-off-by: rongfu.leng --- tests/utils_/test_utils.py | 41 ++++++++++++++++++++++++++++++++++++++ vllm/utils/__init__.py | 9 +++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/tests/utils_/test_utils.py b/tests/utils_/test_utils.py index 084d82dee11b..04195ea0cf92 100644 --- a/tests/utils_/test_utils.py +++ b/tests/utils_/test_utils.py @@ -5,13 +5,17 @@ import asyncio import hashlib import json +import os import pickle import socket +import tempfile from collections.abc import AsyncIterator +from pathlib import Path from unittest.mock import patch import pytest import torch +import yaml import zmq from transformers import AutoTokenizer from vllm_test_utils.monitor import monitor @@ -991,3 +995,40 @@ def child_thread_func(): child_thread.join(timeout=5) if 
child_thread.is_alive(): pytest.fail("Child thread failed to exit properly") + + +def test_load_config_file(tmp_path): + # Define the configuration data + config_data = { + "enable-logging": True, + "list-arg": ["item1", "item2"], + "port": 12323, + "tensor-parallel-size": 4 + } + + # Write the configuration data to a temporary YAML file + config_file_path = tmp_path / "config.yaml" + with open(config_file_path, "w") as config_file: + yaml.dump(config_data, config_file) + + # Initialize the parser + parser = FlexibleArgumentParser() + + # Call the function with the temporary file path + processed_args = parser.load_config_file(str(config_file_path)) + + # Expected output + expected_args = [ + "--enable-logging", + "--list-arg", + "item1", + "item2", + "--port", + "12323", + "--tensor-parallel-size", + "4", + ] + + # Assert that the processed arguments match the expected output + assert processed_args == expected_args + os.remove(str(config_file_path)) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 7c34a858c0a2..60bddc5b500b 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -1974,7 +1974,7 @@ def _pull_args_from_config(self, args: list[str]) -> list[str]: file_path = args[index + 1] - config_args = self._load_config_file(file_path) + config_args = self.load_config_file(file_path) # 0th index is for {serve,chat,complete} # optionally followed by model_tag (only for serve) @@ -2005,7 +2005,7 @@ def _pull_args_from_config(self, args: list[str]) -> list[str]: return args - def _load_config_file(self, file_path: str) -> list[str]: + def load_config_file(self, file_path: str) -> list[str]: """Loads a yaml file and returns the key value pairs as a flattened list with argparse like pattern ```yaml @@ -2046,6 +2046,11 @@ def _load_config_file(self, file_path: str) -> list[str]: if isinstance(value, bool) and key not in store_boolean_arguments: if value: processed_args.append('--' + key) + elif isinstance(value, list): + if value: + processed_args.append('--' + key) + for item in value: + processed_args.append(str(item)) else: processed_args.append('--' + key) processed_args.append(str(value)) From 69244e67e6822f1c15816f887659e1ccc18c2632 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 27 Aug 2025 14:19:13 +0800 Subject: [PATCH 008/125] [Core] Use key-only cache for `BaseMultiModalProcessor` (#23018) Signed-off-by: DarkLight1337 --- docs/configuration/conserving_memory.md | 2 +- docs/configuration/optimization.md | 44 +- .../multimodal/processing/test_common.py | 8 +- tests/multimodal/test_cache.py | 182 +++++++- vllm/config/__init__.py | 26 +- vllm/engine/arg_utils.py | 14 +- vllm/engine/llm_engine.py | 15 +- vllm/inputs/preprocess.py | 22 +- vllm/inputs/registry.py | 12 +- .../models/hyperclovax_vision.py | 7 +- vllm/model_executor/models/llava.py | 8 +- vllm/model_executor/models/minicpmv.py | 40 +- vllm/model_executor/models/mistral3.py | 8 +- vllm/model_executor/models/phi3v.py | 20 +- vllm/model_executor/models/phi4mm.py | 21 +- vllm/model_executor/models/tarsier.py | 7 +- vllm/multimodal/cache.py | 405 +++++++++++++++++- vllm/multimodal/inputs.py | 38 +- vllm/multimodal/processing.py | 187 ++++---- vllm/multimodal/profiling.py | 4 +- vllm/multimodal/registry.py | 90 ++-- vllm/v1/engine/async_llm.py | 3 +- vllm/v1/engine/core.py | 17 +- vllm/v1/engine/llm_engine.py | 3 +- vllm/v1/engine/mm_input_cache.py | 121 ------ vllm/v1/engine/processor.py | 29 +- vllm/v1/worker/gpu_model_runner.py | 3 + vllm/v1/worker/tpu_model_runner.py | 3 + 
vllm/v1/worker/utils.py | 9 +- 29 files changed, 954 insertions(+), 394 deletions(-) delete mode 100644 vllm/v1/engine/mm_input_cache.py diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 058eba5fe0b1..efda9c8e019e 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) +- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index bb47e1b90f08..3eaf2185a559 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 to avoid CPU resource exhaustion. !!! note - [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled - because it requires a one-to-one correspondence between API and engine core processes. + API server scale-out disables [multi-modal IPC caching](#ipc-caching) + because it requires a one-to-one correspondance between API and engine core processes. -## Multi-Modal Caching + This does not impact [multi-modal processor caching](#processor-caching). -### Processor Cache +## Multi-Modal Caching -By default, the multi-modal processor cache is enabled to avoid repeatedly processing -the same multi-modal inputs via Hugging Face `AutoProcessor`, +Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, which commonly occurs in multi-turn conversations. -You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` -(default 4 GiB per API process + 4 GiB per engine core process). -If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. +### Processor Caching + +Multi-modal processor caching is automatically enabled +to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. + +### IPC Caching + +Multi-modal IPC caching is automatically enabled when +there is a one-to-one correspondance between API (`P0`) and engine core (`P1`) processes, +to avoid repeatedly transferring the same multi-modal inputs between them. + +### Configuration + +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). + +If you do not benefit much from the cache, you can disable both IPC +and processor caching completely via `mm_processor_cache_gb=0`. Examples: @@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) ``` + +### Cache Placement + +Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: + +| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. 
Memory | +|-------------------|-------------|------------|------------|-------------| +| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | +| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| ❌ | ❌ | N/A | N/A | `0` | + +K: Stores the hashes of multi-modal items +V: Stores the processed tensor data of multi-modal items diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 6361cb9b5586..3ff4360b8334 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -14,8 +14,9 @@ from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache +from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, cached_tokenizer_from_config, encode_tokens) @@ -63,6 +64,8 @@ def _test_processing_correctness( revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + # Ensure that the cache can fit all of the data + mm_processor_cache_gb=2048, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) @@ -71,8 +74,7 @@ def _test_processing_correctness( model_config, tokenizer=cached_tokenizer_from_config(model_config), ) - # Ensure that it can fit all of the data - cache = ProcessingCache(capacity_gb=2048) + cache = MultiModalProcessorOnlyCache(model_config) processing_info = factories.info(ctx) supported_mm_limits = processing_info.get_supported_mm_limits() diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 088cd00db2e0..44c05db2278f 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,32 +1,64 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np import pytest import torch -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.multimodal.cache import (MultiModalCache, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + processor_cache_from_config, + receiver_cache_from_config) +from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, MultiModalSharedField) +from vllm.multimodal.processing import PromptInsertion +from vllm.multimodal.registry import MultiModalRegistry + +def _dummy_elem( + modality: str, + key: str, + size: int, + *, + rng: Optional[np.random.RandomState] = None, +): + if rng is None: + data = torch.empty((size, ), dtype=torch.int8) + else: + data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8)) -def _dummy_elem(modality: str, key: str, size: int): return MultiModalFieldElem( modality=modality, key=key, - data=torch.empty((size, ), dtype=torch.int8), + data=data, field=MultiModalSharedField(1), ) -def _dummy_item(modality: str, size_by_key: dict[str, int]): +def _dummy_item( + modality: str, + size_by_key: dict[str, int], + *, + rng: Optional[np.random.RandomState] = None, +): return 
MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() + _dummy_elem(modality, key, size, rng=rng) + for key, size in size_by_key.items() ]) -def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]): +def _dummy_items( + size_by_key_modality: dict[str, dict[str, int]], + *, + rng: Optional[np.random.RandomState] = None, +): return MultiModalKwargsItems.from_seq([ - _dummy_item(modality, size_by_key) + _dummy_item(modality, size_by_key, rng=rng) for modality, size_by_key in size_by_key_modality.items() ]) @@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size): cache[""] = item assert cache.currsize == expected_size - cache[""] = MultiModalCacheItemMetadata.wraps(item) + prompt_update = PromptInsertion("dummy", "target", "insertion") \ + .resolve(0) + + cache[""] = MultiModalProcessorCacheItem(item, [prompt_update]) + assert cache.currsize == expected_size + + cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update]) assert cache.currsize == expected_size + + +def _create_vllm_config( + *, + mm_processor_cache_gb: float, + enable_ipc: bool, +): + return VllmConfig( + model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb), + parallel_config=ParallelConfig( + data_parallel_size=1 if enable_ipc else 2), + ) + + +def _compare_caches( + config_0: VllmConfig, + config_1: VllmConfig, + *, + item_capacity: int = 8, + hit_rate: float = 0.5, + max_items_per_iter: int = 3, + is_cached_calls_per_iter: int, + n_iter: int = 100, + seed: int = 0, +): + mm_registry = MultiModalRegistry() + cache_0_p0 = processor_cache_from_config(config_0, mm_registry) + cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) + cache_1_p0 = processor_cache_from_config(config_1, mm_registry) + cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + + cache_size_gb = max( + config_0.model_config.mm_processor_cache_gb, + config_1.model_config.mm_processor_cache_gb, + ) + item_size_gb = int(cache_size_gb / item_capacity) + + rng = np.random.RandomState(seed) + all_items = [ + _dummy_item("item", {"key": item_size_gb}, rng=rng) + for _ in range(int(item_capacity / hit_rate)) + ] + all_hashes = [ + MultiModalHasher.hash_kwargs(item=item.get_data()) + for item in all_items + ] + + # Should not be used since there is nothing to convert to text + prompt_update = PromptInsertion("dummy", "target", "insertion") + + for it in range(n_iter): + num_items_to_select = rng.randint(0, max_items_per_iter) + item_idxs_to_select = rng.choice(len(all_items), num_items_to_select) + + selected_items = [all_items[idx] for idx in item_idxs_to_select] + selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select] + + if cache_0_p0 is None: + cache_0_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ + item for item, _ in cache_0_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_1_p0 is None: + cache_1_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ + item for item, _ in cache_1_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_0_p1 is None: + cache_0_p1_out = cache_0_p0_out + else: + cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, + selected_hashes) + + if cache_1_p1 is None: + cache_1_p1_out = cache_1_p0_out 
+ else: + cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, + selected_hashes) + + assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}" + + +@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3]) +def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): + cache_size_gb = 1 / (1 << 20) + + vllm_config_ipc_enabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + vllm_config_ipc_disabled = _create_vllm_config( + mm_processor_cache_gb=0, + enable_ipc=False, + ) + vllm_config_cache_disabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + + _compare_caches( + vllm_config_ipc_enabled, + vllm_config_ipc_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_ipc_disabled, + vllm_config_cache_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_cache_disabled, + vllm_config_ipc_enabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index cd0e17977ede..ac6f51df9549 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -437,7 +437,7 @@ class ModelConfig: from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. """ - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """The size (in GiB) of the multi-modal processor cache, which is used to avoid re-processing past multi-modal inputs. @@ -884,12 +884,6 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: return None - def set_mm_processor_cache_gb(self, value: int) -> None: - mm_config = self.get_multimodal_config() - - self.mm_processor_cache_gb = value - mm_config.mm_processor_cache_gb = value - def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -1697,22 +1691,6 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None - @property - def enable_mm_processor_cache(self) -> bool: - """Whether the multi-modal processor cache should be enabled.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - - def get_mm_input_cache_gb(self) -> int: - mm_config = self.multimodal_config - if mm_config is None: - return 0 - - return envs.VLLM_MM_INPUT_CACHE_GIB - @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding @@ -2561,7 +2539,7 @@ class MultiModalConfig: `{"num_crops": 4}`. 
""" - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """ The size (in GiB) of the multi-modal processor cache, which is used to diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f24c50ad7326..9e7c95ea5205 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -351,7 +351,7 @@ class EngineArgs: mm_processor_kwargs: Optional[Dict[str, Any]] = \ MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED - mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields @@ -1293,18 +1293,6 @@ def create_engine_config( worker_extension_cls=self.worker_extension_cls, ) - if model_config.is_multimodal_model: - dp_supports_mm_processor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not dp_supports_mm_processor_cache - and model_config.mm_processor_cache_gb > 0): - logger.warning( - "Multi-modal processor cache is disabled because " - "it is not compatible with data parallelism when " - "there does not exist a one-to-one correspondance " - "between API and engine core processes.") - model_config.set_mm_processor_cache_gb(0) - speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cbd714c159eb..03c2f0375da4 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -36,6 +36,7 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) @@ -250,9 +251,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=processor_only_cache_from_config( + self.model_config, mm_registry), + ) self.model_executor = executor_class(vllm_config=vllm_config) @@ -840,8 +845,8 @@ def has_unfinished_requests_for_virtual_engine( def reset_mm_cache(self) -> bool: """Reset the multi-modal cache.""" - return self.input_preprocessor.mm_registry.reset_processor_cache( - self.model_config) + self.input_preprocessor.clear_cache() + return True def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: """Reset prefix cache for all devices.""" diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 3f521012e82a..f0d0cab3df3d 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -11,6 +11,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.transformers_utils.tokenizer import AnyTokenizer @@ 
-32,12 +33,14 @@ def __init__( model_config: ModelConfig, tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer self.mm_registry = mm_registry + self.mm_processor_cache = mm_processor_cache def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: @@ -261,8 +264,11 @@ def _process_multimodal( """ tokenizer = self._get_mm_tokenizer(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -286,8 +292,12 @@ async def _process_multimodal_async( """ tokenizer = await self._get_mm_tokenizer_async(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) + if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -860,3 +870,7 @@ async def preprocess_async( tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, ) + + def clear_cache(self) -> None: + if self.mm_processor_cache is not None: + self.mm_processor_cache.clear_cache() diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ef146fdfbf97..f0b392e9767a 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -223,20 +223,26 @@ def dummy_data_for_profiling( The model is identified by ``model_config``. """ # Avoid circular import + from vllm.multimodal.cache import processor_only_cache_from_config from vllm.sequence import SequenceData if not model_config.is_multimodal_model: seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) return DummyData(seq_data=seq_data) + cache = processor_only_cache_from_config(model_config, mm_registry) + # Encoder dummy data does not contain multi-modal data if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data( - model_config, seq_len) + enc_data = mm_registry.get_encoder_dummy_data(model_config, + seq_len, + cache=cache) seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) return DummyData(seq_data=seq_data) - dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) + dec_data = mm_registry.get_decoder_dummy_data(model_config, + seq_len, + cache=cache) return DummyData( seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index eeb8291c7784..53f0585541b1 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -33,12 +33,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + 
PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -367,7 +368,7 @@ def _build_hcxvision_hf_processor( info: HCXVisionProcessingInfo, dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, HCXVisionProcessingInfo): return HCXVisionMultiModalProcessor( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index bc53982c938c..0ee26b68345c 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -22,14 +22,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves @@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a2a71bdd12b3..c22d871ab20d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -58,7 +58,8 @@ VideoItem, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + PromptUpdate, PromptUpdateDetails, + ResolvedPromptUpdate, _seq2text) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -744,6 +745,43 @@ def get_video_replacement(item_idx: int): for modality, pattern in placeholders ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor() + version = self.info.get_model_version() + + text = _seq2text(tokenizer, cached_update.content.full) + prev_item_idx = cached_update.item_idx + + if version == (2, 0) or version == (2, 5): + im_start = image_processor.im_start_token + im_end = image_processor.im_end_token + else: + im_start = image_processor.im_id_start + im_end = image_processor.im_id_end + + new_update = new_update.with_content( + PromptUpdateDetails.select_text( + text.replace( + f"{im_start}{prev_item_idx}{im_end}", + f"{im_start}{new_item_idx}{im_end}", 
+ 1, + ), + "", + )) + + return new_update + def _get_mm_fields_config( self, hf_inputs: BatchFeature, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 438513433d3b..08948960b275 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -22,14 +22,14 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -322,7 +322,7 @@ def _build_mistral3_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: assert isinstance(info, Mistral3ProcessingInfo) return Mistral3MultiModalProcessor( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 61e09d56046c..4522c7043d01 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -41,7 +41,8 @@ BaseProcessingInfo, MultiModalPromptUpdates, PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) + PromptReplacement, PromptUpdate, + ResolvedPromptUpdate) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -440,6 +441,23 @@ def get_replacement_phi3v(item_idx: int): ) ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + + return new_update + def _apply_prompt_updates( self, token_ids: list[int], diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 5129770e8d49..211cbd9c819c 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -27,7 +27,7 @@ MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, ResolvedPromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -850,6 +850,25 @@ def get_audio_replacement_phi4mm(item_idx: int): ), ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + 
image_tokens: list[str] = self.info.image_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + elif cached_update.modality == "audio": + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + new_update = new_update.with_target(audio_tokens[new_item_idx]) + + return new_update + @MULTIMODAL_REGISTRY.register_processor( Phi4MMMultiModalProcessor, diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 9b9cca8c6bd3..c66867315e55 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -25,12 +25,13 @@ from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves @@ -332,7 +333,7 @@ def _build_tarsier_hf_processor( info: _I_Tarsier, dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, TarsierProcessingInfo): return TarsierMultiModalProcessor( diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 5cec8e71fb26..0e81cb6d4d19 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import sys -from collections.abc import Mapping -from dataclasses import dataclass -from typing import TypeVar, Union +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union import torch +from typing_extensions import TypeAlias, override from vllm.logger import init_logger from vllm.utils import GiB_bytes, LRUCache @@ -15,24 +16,67 @@ MultiModalKwargsItem, MultiModalKwargsItems, NestedTensors) +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + + from .processing import ResolvedPromptUpdate + from .registry import MultiModalRegistry + logger = init_logger(__name__) -@dataclass -class MultiModalCacheItemMetadata: - size: int +class MultiModalProcessorCacheItem: + """ + The data to store inside `MultiModalProcessorOnlyCache`. - @classmethod - def wraps(cls, value: "MultiModalCacheValue"): - return cls(size=MultiModalCache.get_item_size(value)) + Args: + item: The processed tensor data corresponding to a multi-modal item. + prompt_updates: The prompt updates corresponding to `item`. + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item = item + self.prompt_updates = prompt_updates + + +class MultiModalProcessorCacheItemMetadata: + """ + The metadata to store inside `MultiModalProcessorSenderCache`. 
+ + Args: + item: The processed tensor data corresponding to a multi-modal item. + Since P1 already stores the tensor data, we only store its size + metadata in P0 to reduce memory usage. The size metadata is still + needed to keep the same cache eviction policy as P0. + prompt_updates: The prompt updates corresponding to `item`. + This needs to stay on P0 because for some models, they are + dependent on the processed tensor data (cached on P1). + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item_size = MultiModalCache.get_item_size(item) + self.prompt_updates = prompt_updates MultiModalCacheValue = Union[ + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, MultiModalKwargsItems, MultiModalKwargsItem, MultiModalKwargs, Mapping[str, NestedTensors], - MultiModalCacheItemMetadata, ] _V = TypeVar("_V", bound=MultiModalCacheValue) @@ -47,8 +91,10 @@ def get_leaf_size( *, debug: bool = False, ) -> int: - if isinstance(leaf, MultiModalFieldElem): - return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalProcessorCacheItem): + return cls.get_leaf_size(leaf.item) + if isinstance(leaf, MultiModalProcessorCacheItemMetadata): + return leaf.item_size # These are not subclasses of dict if isinstance(leaf, MultiModalKwargsItems): @@ -58,13 +104,13 @@ def get_leaf_size( if isinstance(leaf, MultiModalKwargs): return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalFieldElem): + return cls.get_item_size(leaf.data) # type: ignore + # sys.getsizeof doesn't work for tensors if isinstance(leaf, torch.Tensor): return leaf.nbytes - if isinstance(leaf, MultiModalCacheItemMetadata): - return leaf.size - return sys.getsizeof(leaf) @classmethod @@ -98,3 +144,332 @@ def get_lru_cache( GiB_bytes * capacity_gb, getsizeof=lambda x: cls.get_item_size(x, debug=debug), ) + + +_I = TypeVar("_I", contravariant=True) +_O = TypeVar("_O", covariant=True) + + +class BaseMultiModalCache(ABC, Generic[_I, _O]): + """ + Abstract base class to read/write multi-modal items from cache. + + The idea of multi-modal caching is based on having a client and server + where the client executes in the frontend process (=P0) and + the server in the core process (=P1). The data flow is as follows: + + ``` + is_cached() x N get_and_update() + P0: From API -----------------> -----------------> To P1 + + get_and_update() + P1: From P0 -----------------> To model + ``` + + `is_cached()` can be called any number of times in P0. However, + `get_and_update()` must be called in P0 and P1 one after another + so that their cache eviction order remains the same. + + This ensures that the keys in P0 and P1 caches are mirrored, + allowing us to determine whether a key is cached in P1 by looking + up the P0 cache, without having to communicate with P1. + """ + + @abstractmethod + def get_and_update_item( + self, + mm_item: _I, + mm_hash: str, + ) -> _O: + """ + Possibly update a multi-modal item based on whether it is + in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_item: The multi-modal item to update. + mm_hash: The hash of `mm_item`. + + Returns: + The update multi-modal item. 
+ """ + raise NotImplementedError + + def get_and_update( + self, + mm_items: Sequence[_I], + mm_hashes: list[str], + ) -> list[_O]: + """ + Possibly update a sequence of multi-modal items based on whether they + are in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_items: The multi-modal items to update. + mm_hashes: The hash of each item in `mm_items`. + + Returns: + A new list of updated multi-modal items. + """ + assert len(mm_items) == len(mm_hashes) + + return [ + self.get_and_update_item(mm_item, mm_hash) + for mm_item, mm_hash in zip(mm_items, mm_hashes) + ] + + @abstractmethod + def clear_cache(self) -> None: + """Clear the underlying cache.""" + raise NotImplementedError + + +MultiModalProcessorCacheInItem: TypeAlias = \ + Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]] + + +MultiModalProcessorCacheOutItem: TypeAlias = \ + tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]] + + +class BaseMultiModalProcessorCache( + BaseMultiModalCache[MultiModalProcessorCacheInItem, + MultiModalProcessorCacheOutItem]): + """The required interface for caches on P0.""" + + @abstractmethod + def is_cached_item(self, mm_hash: str) -> bool: + """ + Check whether a multi-modal item is + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hash: The hash of the item to check. + + Returns: + `True` if the item is cached, otherwise `False`. + """ + raise NotImplementedError + + def is_cached(self, mm_hashes: list[str]) -> list[bool]: + """ + Check whether a sequence of multi-modal items are + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hashes: The hash of each item to check. + + Returns: + For each item, `True` if the item is cached, otherwise `False`. + """ + return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + + +class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is disabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes + tensor data and metadata) into the cache, and return the input. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItem, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item.item, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. 
+ + - If the item is not in the cache, store the metadata of that item so + that the eviction policy remains the same as the cache on P1, + and return the input. + By only storing the metadata, we avoid keeping the data itself in + memory inside P0. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItemMetadata, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return None, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def _enable_processor_cache( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +) -> bool: + if not mm_registry.supports_multimodal_inputs(model_config): + return False + + mm_config = model_config.get_multimodal_config() + return mm_config.mm_processor_cache_gb > 0 + + +def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: + parallel_config = vllm_config.parallel_config + supports_ipc_cache = (parallel_config.data_parallel_size == 1 + or parallel_config.data_parallel_external_lb) + + return supports_ipc_cache + + +def processor_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalProcessorCache]: + """Return a `BaseMultiModalProcessorCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return MultiModalProcessorOnlyCache(model_config) + + return MultiModalProcessorSenderCache(model_config) + + +def processor_only_cache_from_config( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +): + """Return a `MultiModalProcessorOnlyCache`, if enabled.""" + if not _enable_processor_cache(model_config, mm_registry): + return None + + return MultiModalProcessorOnlyCache(model_config) + + +class BaseMultiModalReceiverCache( + BaseMultiModalCache[Optional[MultiModalKwargsItem], + MultiModalKwargsItem]): + """The required interface for caches on P1.""" + + +class MultiModalReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 when IPC caching is enabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes tensor + data) into the cache, and return the input. 
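+
+    A minimal usage sketch for the P1 (engine core) side, assuming the
+    hashes were computed on P0 and that items P0 already knows are cached
+    here arrive as `None` so their tensor data is not re-sent
+    (`maybe_items`, `mm_hashes` and `model_config` are illustrative names):
+
+    ```
+    cache = MultiModalReceiverCache(model_config)
+    # `maybe_items` is aligned with `mm_hashes`; entries that P0 skipped
+    # sending (already cached on P1) are None and get filled from the cache.
+    full_items = cache.get_and_update(maybe_items, mm_hashes)
+    ```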
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalKwargsItem, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = mm_item + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalReceiverCache]: + """Return a `BaseMultiModalReceiverCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + return MultiModalReceiverCache(model_config) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 581f9a109cce..2c0ebaced67e 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,11 +7,11 @@ from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, + cast, final) import numpy as np -from typing_extensions import NotRequired, TypeAlias, deprecated +from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated from vllm.utils import LazyLoader, full_groupby, is_list_of from vllm.utils.jsontree import JSONTree, json_map_leaves @@ -668,7 +668,15 @@ def get_data(self) -> dict[str, NestedTensors]: return {key: elem.data for key, elem in self.items()} -class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): +_I = TypeVar( + "_I", + MultiModalKwargsItem, + Optional[MultiModalKwargsItem], + default=MultiModalKwargsItem, +) + + +class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): """ A dictionary of [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s @@ -714,27 +722,37 @@ def from_seq(items: Sequence[MultiModalKwargsItem]): items_by_modality = full_groupby(items, key=lambda x: x.modality) return MultiModalKwargsItems(items_by_modality) - def __getitem__(self, modality: str): + def __getitem__(self, modality: str) -> Sequence[_I]: if modality not in self: raise KeyError(f"Modality {modality!r} not found. 
" f"Available modalities: {set(self.keys())}") - return super().__getitem__(modality) + return super().__getitem__(modality) # type: ignore[return-value] def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) - for items in self.values(): - for item in items: + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError("Cannot build data from empty " + f"mm_items[{modality}][{i}]") + for key, elem in item.items(): elems_by_key[key].append(elem) return MultiModalKwargs({ key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) - for key, elems in elems_by_key.items() if len(elems) > 0 + for key, elems in elems_by_key.items() }) +MultiModalKwargsOptionalItems: TypeAlias = Union[ + MultiModalKwargsItems[MultiModalKwargsItem], + MultiModalKwargsItems[Optional[MultiModalKwargsItem]], +] + + class MultiModalKwargs(UserDict[str, NestedTensors]): """ A dictionary that represents the keyword arguments to @@ -898,7 +916,7 @@ class MultiModalInputs(TypedDict): token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" - mm_kwargs: MultiModalKwargsItems + mm_kwargs: MultiModalKwargsOptionalItems """Keyword arguments to be directly passed to the model after batching.""" mm_hashes: "MultiModalHashDict" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 8c225e2a3c08..6ecdf80d4aa6 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, Sequence) -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, @@ -20,12 +20,11 @@ encode_tokens) from vllm.utils import flatten_2d_lists, full_groupby -from .cache import MultiModalCache from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargsItem, MultiModalKwargsItems, - PlaceholderRange) + MultiModalKwargsOptionalItems, PlaceholderRange) from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser) @@ -34,6 +33,7 @@ from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin + from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -557,6 +557,15 @@ def iter_matches( return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) + def with_target(self, target: UpdateTarget): + return replace(self, target=target) + + def with_content(self, content: PromptUpdateInfo): + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + return replace(self, content=content) + class _TokenMatch(NamedTuple): start_idx: int @@ -865,21 +874,6 @@ def find_mm_placeholders( return dict(full_groupby_modality(it)) -class ProcessingCache(MultiModalCache): - - def __init__(self, capacity_gb: float) -> None: - super().__init__() - - self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem) - - self.get = self._cache.get - self.put = self._cache.put - self.reset = self._cache.clear - - -_CacheItemOrHash = Union[MultiModalKwargsItem, str] - - class 
BaseProcessingInfo: """Base class to provide the information necessary for data processing.""" @@ -982,7 +976,7 @@ def get_mm_max_tokens_per_item( class MultiModalProcessingInfo(NamedTuple): - kwargs: MultiModalKwargsItems + kwargs: MultiModalKwargsOptionalItems hashes: MultiModalHashes prompt_updates: MultiModalPromptUpdates @@ -994,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Not to be confused with `transformers.ProcessorMixin`. """ - def __init__(self, - info: _I, - dummy_inputs: "BaseDummyInputsBuilder[_I]", - *, - cache: Optional[ProcessingCache] = None) -> None: + def __init__( + self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional["BaseMultiModalProcessorCache"] = None, + ) -> None: super().__init__() self.info = info @@ -1355,32 +1351,6 @@ def _apply_hf_processor_main( return prompt_ids, mm_processed_data, False - def _get_cache_missing_items( - self, - cache: ProcessingCache, - mm_data_items: MultiModalDataItems, - mm_hashes: MultiModalHashes, - ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]: - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = { - modality: [(h if (v := cache.get(h)) is None else v) - for h in hashes] - for modality, hashes in mm_hashes.items() - } - - mm_missing_idxs = { - modality: [ - idx for idx, item_or_hash in enumerate(items_or_hashes) - if isinstance(item_or_hash, str) - ] - for modality, items_or_hashes in mm_cache_items_or_hashes.items() - } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } - - return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data) - def _hash_mm_items( self, mm_items: MultiModalDataItems, @@ -1401,28 +1371,92 @@ def _hash_mm_items( for modality, items in mm_items.items() } + def _get_cache_missing_items( + self, + cache: "BaseMultiModalProcessorCache", + mm_data_items: MultiModalDataItems, + mm_hashes: MultiModalHashes, + ) -> MultiModalDataItems: + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } + + mm_missing_idxs = { + modality: [ + idx for idx, item_is_cached in enumerate(items_is_cached) + if not item_is_cached + ] + for modality, items_is_cached in mm_is_cached.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + + return self._to_mm_items(mm_missing_data) + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + """ + Override this if other attributes of `ResolvedPromptUpdate` + also need to be recomputed after retrieving from the cache. 
+ """ + return replace(cached_update, item_idx=new_item_idx) + def _merge_mm_kwargs( self, - cache: ProcessingCache, - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], + cache: "BaseMultiModalProcessorCache", + mm_hashes: MultiModalHashes, mm_missing_kwargs: MultiModalKwargsItems, - ) -> MultiModalKwargsItems: + mm_missing_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: + # Need to calculate this at the beginning to avoid skipping cache logic + # for subsequently repeated items in the same modality + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } + mm_missing_next_idx = defaultdict[str, int](lambda: 0) - merged_items = defaultdict[str, list[MultiModalKwargsItem]](list) - for modality, items_or_hashes in mm_cache_items_or_hashes.items(): - for item_or_hash in items_or_hashes: - if isinstance(item_or_hash, str): - kw_item = mm_missing_kwargs[modality][ - mm_missing_next_idx[modality]] - cache.put(item_or_hash, kw_item) + merged_kwargs = defaultdict[str, + list[Optional[MultiModalKwargsItem]]](list) + merged_prompt_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, hashes in mm_hashes.items(): + missing_kwargs = mm_missing_kwargs.get(modality, []) + missing_prompt_updates = mm_missing_prompt_updates.get( + modality, []) + + for item_idx, item_hash in enumerate(hashes): + kwargs: Optional[MultiModalKwargsItem] + if not mm_is_cached[modality][item_idx]: + missing_next_idx = mm_missing_next_idx[modality] + kwargs = missing_kwargs[missing_next_idx] + updates = missing_prompt_updates[missing_next_idx] + mm_missing_next_idx[modality] += 1 + + item = kwargs, updates else: - kw_item = item_or_hash + item = None + + kwargs, updates = cache.get_and_update_item(item, item_hash) + + merged_kwargs[modality].append(kwargs) + merged_prompt_updates[modality].append([ + self._recompute_cached_prompt_update(update, item_idx) + for update in updates + ]) - merged_items[modality].append(kw_item) + mm_kwargs = MultiModalKwargsItems(merged_kwargs) + mm_prompt_updates = dict(merged_prompt_updates) - return MultiModalKwargsItems(merged_items) + return mm_kwargs, mm_prompt_updates def _apply_hf_processor( self, @@ -1490,10 +1524,8 @@ def _cached_apply_hf_processor( mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs) - ( - mm_cache_items_or_hashes, - mm_missing_data_items, - ) = self._get_cache_missing_items( + + mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, mm_hashes=mm_hashes, @@ -1520,16 +1552,17 @@ def _cached_apply_hf_processor( hf_processor_mm_kwargs), ) - mm_kwargs = self._merge_mm_kwargs( - cache, - mm_cache_items_or_hashes=mm_cache_items_or_hashes, - mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates = self._get_mm_prompt_updates( + mm_missing_data_items, + hf_processor_mm_kwargs, + mm_missing_kwargs, ) - mm_prompt_updates = self._get_mm_prompt_updates( - mm_data_items, - hf_processor_mm_kwargs, - mm_kwargs, + mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( + cache, + mm_hashes=mm_hashes, + mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates=mm_missing_prompt_updates, ) mm_info = MultiModalProcessingInfo( @@ -1614,7 +1647,7 @@ def _apply_prompt_updates( def _validate_mm_kwargs( self, - mm_kwargs: MultiModalKwargsItems, + mm_kwargs: MultiModalKwargsOptionalItems, mm_item_counts: Mapping[str, int], ) -> None: for modality, 
item_count in mm_item_counts.items(): @@ -1655,7 +1688,7 @@ def _maybe_apply_prompt_updates( self, mm_items: MultiModalDataItems, prompt_ids: list[int], - mm_kwargs: MultiModalKwargsItems, + mm_kwargs: MultiModalKwargsOptionalItems, mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index ea2efbdd8b52..ffc69a2db60a 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -13,7 +13,7 @@ from vllm.logger import init_logger from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalKwargsItems, + MultiModalInputs, MultiModalKwargsOptionalItems, MultiModalPlaceholderDict) from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, EncDecMultiModalProcessor) @@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple): """Dummy data used for profiling.""" prompt_token_ids: list[int] - multi_modal_data: MultiModalKwargsItems + multi_modal_data: MultiModalKwargsOptionalItems multi_modal_placeholders: MultiModalPlaceholderDict diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 8cd9e5604872..38adbf8f3536 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from functools import lru_cache from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn @@ -13,8 +12,9 @@ cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - ProcessingCache) +from .cache import (BaseMultiModalProcessorCache, + processor_only_cache_from_config) +from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) @@ -65,7 +65,7 @@ def __call__( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[_I]: ... @@ -80,20 +80,13 @@ def build_processor( self, ctx: InputProcessingContext, *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ): info = self.info(ctx) dummy_inputs_builder = self.dummy_inputs(info) return self.processor(info, dummy_inputs_builder, cache=cache) -# Make sure a different cache is used for each model config -# NOTE: ModelConfig is not hashable so it cannot be passed directly -@lru_cache(maxsize=1) -def _get_processor_cache(model_id: str, capacity_gb: int): - return ProcessingCache(capacity_gb) if capacity_gb > 0 else None - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. 
@@ -103,31 +96,6 @@ def __init__(self) -> None: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - def _get_processor_cache(self, model_config: "ModelConfig"): - model_id = model_config.model - capacity_gb = model_config.mm_processor_cache_gb - return _get_processor_cache(model_id, capacity_gb) - - def reset_processor_cache(self, model_config: "ModelConfig") -> bool: - """Reset the multi-modal processing cache.""" - if processor_cache := self._get_processor_cache(model_config): - processor_cache.reset() - - return True # Success - - def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool: - """Whether the multi-modal input cache should be enabled. - NOTE: This is put under MultiModalRegistry on purpose to respect - text-only mode for multimodal models. - """ - - if not self.supports_multimodal_inputs(model_config): - return False - - mm_config = model_config.get_multimodal_config() - - return mm_config.mm_processor_cache_gb > 0 - def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: """ Checks if the model supports multimodal inputs. @@ -157,6 +125,8 @@ def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: def get_max_tokens_per_item_by_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -165,11 +135,11 @@ def get_max_tokens_per_item_by_modality( if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) return profiler.get_mm_max_contiguous_tokens( seq_len, @@ -182,6 +152,8 @@ def get_max_tokens_per_item_by_modality( def get_max_tokens_per_item_by_nonzero_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -192,15 +164,19 @@ def get_max_tokens_per_item_by_nonzero_modality( This is currently directly used only in V1 for profiling the memory usage of a model. """ - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() if mm_limits[key] > 0 } + # TODO: Remove once V0 is gone def get_max_tokens_by_modality( self, model_config: "ModelConfig", @@ -209,14 +185,19 @@ def get_max_tokens_by_modality( Get the maximum number of tokens from each modality for profiling the memory usage of a model. 
""" - mm_limits = self.get_mm_limits_per_prompt(model_config) + cache = processor_only_cache_from_config(model_config, self) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() } + # TODO: Remove once V0 is gone def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens @@ -227,6 +208,8 @@ def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: def get_mm_limits_per_prompt( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality @@ -235,7 +218,7 @@ def get_mm_limits_per_prompt( if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() @@ -303,7 +286,7 @@ def create_processor( model_config: "ModelConfig", *, tokenizer: Optional[AnyTokenizer] = None, - disable_cache: Optional[bool] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. @@ -311,15 +294,10 @@ def create_processor( if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if disable_cache is None: - disable_cache = not model_config.enable_mm_processor_cache - model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] ctx = self._create_processing_ctx(model_config, tokenizer) - cache = None if disable_cache else self._get_processor_cache( - model_config) return factories.build_processor(ctx, cache=cache) @@ -328,13 +306,15 @@ def get_decoder_dummy_data( model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) @@ -352,13 +332,15 @@ def get_encoder_dummy_data( model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. 
""" - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 342d7b24f8e9..dbea0b610b31 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -597,8 +597,7 @@ async def stop_profile(self) -> None: await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 32765cda6482..b61482806184 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -22,6 +22,7 @@ from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -38,7 +39,6 @@ EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig @@ -128,8 +128,9 @@ def __init__(self, ) self.use_spec_decode = vllm_config.speculative_config is not None - self.mm_input_cache_server = MultiModalInputCacheServer( - vllm_config.model_config, MULTIMODAL_REGISTRY) + self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + self.mm_receiver_cache = receiver_cache_from_config( + vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously @@ -370,7 +371,8 @@ def reset_mm_cache(self): logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_input_cache_server.reset() + if self.mm_receiver_cache is not None: + self.mm_receiver_cache.clear_cache() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -435,10 +437,11 @@ def preprocess_add_request( assert request.mm_kwargs is not None # Note on thread safety: no race condition. - # `mm_input_cache_server` is reset at the end of LLMEngine init, + # `mm_receiver_cache` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. 
- request.mm_kwargs = self.mm_input_cache_server.get_and_update( - request.mm_kwargs, request.mm_hashes) + if self.mm_receiver_cache is not None: + request.mm_kwargs = self.mm_receiver_cache.get_and_update( + request.mm_kwargs, request.mm_hashes) req = Request.from_engine_core_request(request, self.request_block_hasher) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 5a00a930951c..7130f666ef19 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -271,8 +271,7 @@ def stop_profile(self): self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py deleted file mode 100644 index aa7dc62fd4ac..000000000000 --- a/vllm/v1/engine/mm_input_cache.py +++ /dev/null @@ -1,121 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional - -from vllm.multimodal import MultiModalRegistry -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from vllm.multimodal.inputs import MultiModalKwargsItem -from vllm.utils import is_list_of - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -# The idea of multimodal input caching is based on having a client and -# a server, where the client executes in the frontend process (=P0) and the -# server in the core process (=P1). -# -# -- P0: -# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of -# each input multi-modal item (e.g. image), -# - BaseMultiModalProcessor processes the input items into `mm_kwargs`, -# which are MultiModalKwargsItem instances that each correspond to an -# input multi-modal item. -# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding -# `mm_hash` for each item. It stores the `mm_hash` as keys and the size -# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking -# up additional memory in P0. -# - The `mm_hash` is always sent to P1. -# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached -# in MultiModalInputCacheServer. -# -# -- P1: -# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0), -# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`. -# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0), -# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`. -# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to -# the engine for model execution. -# -# Both Client and Server must perform cache update and eviction based on the -# same item size. This ensures that the keys of MultiModalInputCacheClient -# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0 -# whether a key is cached in MultiModalInputCacheServer by querying -# MultiModalInputCacheClient without having to communicate with P1. 
- - -class MultiModalInputCacheClient: - """Used by P0 to check whether multi-modal kwargs are cached in P1.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalCacheItemMetadata, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[MultiModalKwargsItem], - mm_hashes: list[str], - ) -> list[Optional[MultiModalKwargsItem]]: - if not self.enabled: - return list(mm_kwargs) - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[Optional[MultiModalKwargsItem]]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if self.mm_cache.get(mm_hash) is not None: - out_mm_items.append(None) - else: - self.mm_cache[mm_hash] = \ - MultiModalCacheItemMetadata.wraps(mm_item) - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() - - -class MultiModalInputCacheServer: - """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalKwargsItem, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[Optional[MultiModalKwargsItem]], - mm_hashes: list[str], - ) -> list[MultiModalKwargsItem]: - if not self.enabled: - mm_kwargs_lst = list(mm_kwargs) - assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem) - return mm_kwargs_lst - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[MultiModalKwargsItem]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if mm_item is None: - out_mm_items.append(self.mm_cache[mm_hash]) - else: - self.mm_cache[mm_hash] = mm_item - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 300b0713b2ff..7ed60156626b 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -11,6 +11,7 @@ from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions @@ -18,7 +19,6 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.backend_lm_format_enforcer import ( @@ -47,16 +47,17 @@ def __init__( self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) - self.mm_input_cache_client = MultiModalInputCacheClient( - self.model_config, mm_registry) + self.mm_registry = mm_registry + self.mm_processor_cache = processor_cache_from_config( + vllm_config, mm_registry) - @property - def 
mm_registry(self): - return self.input_preprocessor.mm_registry + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=self.mm_processor_cache, + ) def _validate_logprobs( self, @@ -310,7 +311,7 @@ def process_inputs( # in the input sequence. sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - orig_sorted_mm_inputs = [ + sorted_mm_inputs = [ decoder_mm_inputs[modality][idx] for modality, idx in sorted_mm_idxs ] @@ -323,11 +324,6 @@ def process_inputs( for modality, idx in sorted_mm_idxs ] - sorted_mm_inputs = self.mm_input_cache_client.get_and_update( - orig_sorted_mm_inputs, - sorted_mm_hashes, - ) - return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], @@ -415,3 +411,6 @@ def _validate_model_input( # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def clear_cache(self) -> None: + self.input_preprocessor.clear_cache() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f1ceaaae62a7..053aaf4f968e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2186,10 +2186,13 @@ def _get_mm_dummy_batch( max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4a485b7e077d..d36423660427 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1813,10 +1813,13 @@ def _get_mm_dummy_batch( max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index b96473e7b164..82ede5ad8eb1 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -10,6 +10,7 @@ from vllm.config import ModelConfig, SchedulerConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget @@ -33,14 +34,18 @@ def __init__( self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry + self.cache = cache = processor_only_cache_from_config( + model_config, mm_registry) self.max_model_len = model_config.max_model_len self.max_num_reqs = scheduler_config.max_num_seqs - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, + cache=cache) max_tokens_by_modality = 
mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) + .get_max_tokens_per_item_by_nonzero_modality(model_config, + cache=cache) encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( scheduler_config, From 64466778397482e0cb9ff9f6b320ca6d9dc567ae Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Wed, 27 Aug 2025 15:27:14 +0800 Subject: [PATCH 009/125] [XPU]fix cuda event used in XPU model runner (#23708) Signed-off-by: Kunshang Ji --- vllm/v1/worker/xpu_model_runner.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 59f8d0fcf5bd..fb892211f19d 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from contextlib import contextmanager from typing import TYPE_CHECKING import torch @@ -22,7 +23,8 @@ def __init__( vllm_config: VllmConfig, device: torch.device, ): - super().__init__(vllm_config, device) + with _torch_cuda_wrapper(): + super().__init__(vllm_config, device) # FIXME: To be verified. self.cascade_attn_enabled = False @@ -31,3 +33,21 @@ def _init_device_properties(self) -> None: def _sync_device(self) -> None: torch.xpu.synchronize() + + +@contextmanager +def _torch_cuda_wrapper(): + + class _EventPlaceholder: + + def __init__(self, *args, **kwargs) -> None: + self.record = lambda: None + self.synchronize = lambda: None + + try: + # replace cuda Event with xpu Event, this should work by default + torch.cuda.Event = torch.xpu.Event + yield + finally: + # if anything goes wrong, just patch it with a placeholder + torch.cuda.Event = _EventPlaceholder From 91e382c935c2905c29f3ca22c658e03e8f02deaa Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 27 Aug 2025 16:11:15 +0800 Subject: [PATCH 010/125] [CI/Build] Remove redundant register in model init tests (#23715) Signed-off-by: DarkLight1337 --- tests/models/test_initialization.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index bbd3da982af8..b4d516233b4b 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -38,11 +38,6 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, model_arch=model_arch, exist_overrides=model_info.hf_overrides) - if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): - from vllm.model_executor.models.llama4 import Llama4ForCausalLM - from vllm.model_executor.models.registry import ModelRegistry - ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) - # Avoid calling model.forward() def _initialize_kv_caches_v0(self) -> None: self.cache_config.num_gpu_blocks = 0 From 5bd9f841581a3a9e9eecdd8764240575bb28e391 Mon Sep 17 00:00:00 2001 From: Michael Yao Date: Wed, 27 Aug 2025 17:50:09 +0800 Subject: [PATCH 011/125] [Docs] Fix an admonition important (#23726) Signed-off-by: windsonsea --- docs/configuration/optimization.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 3eaf2185a559..a8eab9985c8b 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -164,7 +164,7 @@ llm = LLM( ) ``` -!! important +!!! 
important Batch-level DP is not to be confused with API request-level DP (which is instead controlled by `data_parallel_size`). From 6578e873655859462758c5c51e51f876f2aa24a3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Aug 2025 02:52:45 -0700 Subject: [PATCH 012/125] Optimize input preparation for FlashInfer [2/N] (#23174) Signed-off-by: Woosuk Kwon --- vllm/v1/attention/backends/flashinfer.py | 80 ++++++++++++++++-------- 1 file changed, 54 insertions(+), 26 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 941d2a4d7f1a..f948157c2b57 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import ClassVar, Optional, Union +import numpy as np import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, @@ -22,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( QuantKey, kFp8StaticTensorSym, kNvfp4Quant) from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import cdiv, is_pin_memory_available from vllm.utils.flashinfer import (supports_trtllm_attention, use_trtllm_attention) @@ -230,6 +232,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device="cpu", pin_memory=pin_memory) + self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy() self.paged_kv_indices_cpu = torch.zeros(max_num_pages, dtype=torch.int32, device="cpu", @@ -238,10 +241,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device="cpu", pin_memory=pin_memory) - - self.block_table_arange = torch.arange(max_num_pages_per_req, - dtype=torch.int32, - device=self.device) + self.paged_kv_last_page_len_np = ( + self.paged_kv_last_page_len_cpu.numpy()) def _get_workspace_buffer(self): if self._workspace_buffer is None: @@ -317,9 +318,10 @@ def build(self, max_seq_len = common_attn_metadata.max_seq_len seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() block_table_tensor = common_attn_metadata.block_table_tensor - block_table_bounds_cpu = (seq_lens_cpu + page_size - 1) // page_size + num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size use_cascade = common_prefix_len > 0 if use_cascade: @@ -342,37 +344,41 @@ def build(self, # Remove the blocks of the shared prefix from all requests. 
block_table_tensor = block_table_tensor[:, num_common_kv_blocks:] - block_table_bounds_cpu -= num_common_kv_blocks + num_blocks_np -= num_common_kv_blocks else: shared_qo_indptr_cpu = None shared_kv_page_indptr_cpu = None shared_kv_page_indices_cpu = None shared_kv_last_page_len_cpu = None - max_num_blocks = block_table_bounds_cpu.max().item() - block_table_bounds = block_table_bounds_cpu.to(self.device, - non_blocking=True) - mask = (self.block_table_arange[:max_num_blocks].unsqueeze(0) - < block_table_bounds.unsqueeze(1)) + # write self.paged_kv_indptr_cpu inplace (0-index is always 0) + np.cumsum( + num_blocks_np, + dtype=np.int32, + out=self.paged_kv_indptr_np[1:num_reqs + 1], + ) + paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1] + paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1], + non_blocking=True) + # write self.paged_kv_indices inplace - num_actual_pages = torch.sum(mask) + num_actual_pages = num_blocks_np.sum().item() paged_kv_indices = self.paged_kv_indices[:num_actual_pages] - torch.masked_select(block_table_tensor[:, :max_num_blocks], - mask, - out=paged_kv_indices) - - # write self.paged_kv_indptr_cpu inplace (0-index is always 0) - torch.cumsum(block_table_bounds_cpu, - dim=0, - dtype=torch.int32, - out=self.paged_kv_indptr_cpu[1:1 + num_reqs]) + _copy_page_indices_kernel[(num_reqs, )]( + paged_kv_indices, + block_table_tensor, + block_table_tensor.stride(0), + paged_kv_indptr, + BLOCK_SIZE=1024, + ) - paged_kv_last_page_len_cpu = seq_lens_cpu % page_size # write self.paged_kv_last_page_len_cpu inplace - torch.where(paged_kv_last_page_len_cpu == 0, - torch.tensor(page_size), - paged_kv_last_page_len_cpu, - out=self.paged_kv_last_page_len_cpu[:num_reqs]) + paged_kv_last_page_len_np = seq_lens_np % page_size + self.paged_kv_last_page_len_np[:num_reqs] = np.where( + paged_kv_last_page_len_np == 0, + page_size, + paged_kv_last_page_len_np, + ) # Check if any layer uses sinks (requires TRTLLM attention) has_sinks = self.global_hyperparameters.has_sinks @@ -1002,3 +1008,25 @@ def fast_plan_decode( self._sm_scale = sm_scale self._rope_scale = rope_scale self._rope_theta = rope_theta + + +@triton.jit +def _copy_page_indices_kernel( + page_indices, + block_table, + block_table_stride, + cu_num_blocks, + BLOCK_SIZE: tl.constexpr, +): + req_idx = tl.program_id(0) + row_ptr = block_table + req_idx * block_table_stride + start_idx = tl.load(cu_num_blocks + req_idx) + end_idx = tl.load(cu_num_blocks + req_idx + 1) + num_blocks = end_idx - start_idx + + offset = tl.arange(0, BLOCK_SIZE) + for i in tl.range(0, num_blocks, BLOCK_SIZE): + block_ids = tl.load(row_ptr + i + offset, mask=i + offset < num_blocks) + tl.store(page_indices + start_idx + i + offset, + block_ids, + mask=i + offset < num_blocks) From 04ff1e43fb6e2e675170d0c90399290f8925abb7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Aug 2025 03:25:00 -0700 Subject: [PATCH 013/125] [Misc] Move CpuGpuBuffer to vllm/v1/utils.py (#23728) Signed-off-by: Woosuk Kwon --- vllm/v1/utils.py | 29 +++++++++++++++++++++++++++++ vllm/v1/worker/cpu_model_runner.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 6 +++--- vllm/v1/worker/utils.py | 29 ----------------------------- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index b5750c82db02..8f9face6fbf2 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -96,6 +96,35 @@ def __repr__(self): return f"ConstantList({self._x})" +class CpuGpuBuffer: + + def __init__( + self, + *args, + dtype: torch.dtype, + device: 
torch.device, + pin_memory: bool, + ): + self.cpu = torch.zeros(*args, + dtype=dtype, + device="cpu", + pin_memory=pin_memory) + self.np = self.cpu.numpy() + self.gpu = self.cpu.to(device) + + def copy_to_gpu(self, n: Optional[int] = None) -> torch.Tensor: + if n is None: + return self.gpu.copy_(self.cpu, non_blocking=True) + return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) + + def copy_to_cpu(self, n: Optional[int] = None) -> torch.Tensor: + """NOTE: Because this method is non-blocking, explicit synchronization + is needed to ensure the data is copied to CPU.""" + if n is None: + return self.cpu.copy_(self.gpu, non_blocking=True) + return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) + + def get_engine_client_zmq_addr(local_only: bool, host: str, port: int = 0) -> str: diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py index 137578f0e608..742e553b77e0 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -10,8 +10,8 @@ from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.v1.attention.backends.cpu_attn import TorchSDPAMetadataBuilderV1 +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from vllm.v1.worker.utils import CpuGpuBuffer if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 053aaf4f968e..d93460d618e7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -78,14 +78,14 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer +from vllm.v1.utils import CpuGpuBuffer from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.kv_connector_model_runner_mixin import ( KVConnectorModelRunnerMixin, KVConnectorOutput) from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from .utils import (AttentionGroup, CpuGpuBuffer, MultiModalBudget, - bind_kv_cache, gather_mm_placeholders, - initialize_kv_cache_for_kv_sharing, +from .utils import (AttentionGroup, MultiModalBudget, bind_kv_cache, + gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) if TYPE_CHECKING: diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 82ede5ad8eb1..f40753468766 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -303,32 +303,3 @@ def bind_kv_cache( for layer_name, kv_cache in kv_caches.items(): # NOTE: Use list because of v0 PP virtual engine. 
forward_context[layer_name].kv_cache = [kv_cache] - - -class CpuGpuBuffer: - - def __init__( - self, - *args, - dtype: torch.dtype, - device: torch.device, - pin_memory: bool, - ): - self.cpu = torch.zeros(*args, - dtype=dtype, - device="cpu", - pin_memory=pin_memory) - self.np = self.cpu.numpy() - self.gpu = self.cpu.to(device) - - def copy_to_gpu(self, n: Optional[int] = None) -> None: - if n is None: - return self.gpu.copy_(self.cpu, non_blocking=True) - return self.gpu[:n].copy_(self.cpu[:n], non_blocking=True) - - def copy_to_cpu(self, n: Optional[int] = None) -> None: - """NOTE: Because this method is non-blocking, explicit synchronization - is needed to ensure the data is copied to CPU.""" - if n is None: - return self.cpu.copy_(self.gpu, non_blocking=True) - return self.cpu[:n].copy_(self.gpu[:n], non_blocking=True) From 11eddf02f0234f79435d747f2d3dce117ab39aa1 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 27 Aug 2025 03:45:04 -0700 Subject: [PATCH 014/125] [FlashInfer] Cache hyper params in metadata builder (#23732) Signed-off-by: Woosuk Kwon --- vllm/v1/attention/backends/flashinfer.py | 30 ++++++++++++------------ 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index f948157c2b57..1115fc606b05 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -214,6 +214,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # TODO: discard this for trtllm-gen backend self.global_hyperparameters = infer_global_hyperparameters( get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl)) + self.sm_scale = self.global_hyperparameters.sm_scale + self.window_left = self.global_hyperparameters.window_left + self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap + self.has_sinks = self.global_hyperparameters.has_sinks # Preparing persistent buffers (device-side) self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, @@ -381,8 +385,6 @@ def build(self, ) # Check if any layer uses sinks (requires TRTLLM attention) - has_sinks = self.global_hyperparameters.has_sinks - prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, self.num_kv_heads, num_prefill_tokens, @@ -390,7 +392,7 @@ def build(self, self.cache_dtype, self.q_data_type, is_prefill=True, - has_sinks=has_sinks) + has_sinks=self.has_sinks) decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, self.num_kv_heads, num_decode_tokens, @@ -398,7 +400,7 @@ def build(self, self.cache_dtype, self.q_data_type, is_prefill=False, - has_sinks=has_sinks) + has_sinks=self.has_sinks) attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, @@ -433,9 +435,9 @@ def build(self, self.head_dim, self.page_size, causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters.logits_soft_cap, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, ) @@ -472,10 +474,9 @@ def build(self, self.head_dim, self.page_size, causal=True, - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. 
- logits_soft_cap, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, ) @@ -525,10 +526,9 @@ def build(self, self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, + sm_scale=self.sm_scale, + window_left=self.window_left, + logits_soft_cap=self.logits_soft_cap, q_data_type=self.q_data_type, kv_data_type=self.kv_cache_dtype, ) From e03940762b43812fccd3c214bda60201cff9d16a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 27 Aug 2025 18:59:35 +0800 Subject: [PATCH 015/125] [CI/Build] Reduce LoRA layer test cases (#23721) Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 72 ++++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 92db023babc2..6e2dda464d8e 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -243,7 +243,7 @@ def check_punica_wrapper(punica_wrapper) -> bool: @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -347,7 +347,7 @@ def create_random_embedding_layer(): @torch.inference_mode() # @pytest.mark.skip( # reason="Fails when loras are in any slot other than the first.") -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000]) @pytest.mark.parametrize("stage", STAGES) @@ -486,7 +486,7 @@ def create_random_embedding_layer(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 256512]) @pytest.mark.parametrize("stage", STAGES) @@ -620,12 +620,15 @@ def _pretest(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) -def test_linear_replicated(dist_init, num_loras, device, stage, - bias_enabled) -> None: +def test_linear_replicated( + dist_init, + num_loras, + device, + stage, +) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -634,10 +637,11 @@ def test_linear_replicated(dist_init, num_loras, device, stage, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16, + ) def create_random_linear_replicated_layer(): @@ -651,10 +655,6 @@ def create_random_linear_replicated_layer(): lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( 
lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -734,14 +734,13 @@ def create_random_linear_replicated_layer(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("orientation", ["row", "column"]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -750,11 +749,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_random_linear_parallel_layer(): if orientation == "row": @@ -777,10 +777,7 @@ def create_random_linear_parallel_layer(): lora_linear.create_lora_weights(max_loras, lora_config) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == 1) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in range(NUM_RANDOM_SEEDS): @@ -860,14 +857,13 @@ def create_random_linear_parallel_layer(): @torch.inference_mode() -@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("num_loras", [1, 2, 4]) @pytest.mark.parametrize("repeats", [1, 2, 3]) @pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", DEVICES) @pytest.mark.parametrize("stage", STAGES) -@pytest.mark.parametrize("bias_enabled", [True, False]) def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, - device, stage, bias_enabled) -> None: + device, stage) -> None: if current_platform.is_cuda_alike(): torch.cuda.set_device(device) @@ -876,11 +872,12 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, torch.set_default_device(device) punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras) assert check_punica_wrapper(punica_wrapper) - lora_config = LoRAConfig(max_loras=max_loras, - max_lora_rank=8, - fully_sharded_loras=fully_shard, - lora_dtype=torch.float16, - bias_enabled=bias_enabled) + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + fully_sharded_loras=fully_shard, + lora_dtype=torch.float16, + ) def create_column_parallel_packed_layer(): if repeats == 2: @@ -924,10 +921,7 @@ class FakeConfig: model_config=FakeConfig()) assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len( lora_linear.lora_b_stacked) == n_slices) - if bias_enabled: - assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices - else: - assert lora_linear.lora_bias_stacked is None + return linear, lora_linear for i in 
range(NUM_RANDOM_SEEDS): From 8f0d7eaea87409a54ccaed76995b59c6b0a3d4cf Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Wed, 27 Aug 2025 19:57:38 +0800 Subject: [PATCH 016/125] [XPU] Fix OOM issue for data parallel with Ray backend (#22500) Signed-off-by: Fanli Lin Signed-off-by: Fanli Lin Co-authored-by: Cyrus Leung --- vllm/v1/engine/core.py | 27 ++++++++++++++++++--------- vllm/v1/engine/utils.py | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b61482806184..a7038e2d2c26 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -39,7 +39,8 @@ EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses +from vllm.v1.engine.utils import (EngineHandshakeMetadata, EngineZmqAddresses, + get_device_indices) from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats @@ -1169,22 +1170,30 @@ def __init__( # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501 # and get_accelerator_ids_for_accelerator_resource() in worker.py # of ray. - self._set_cuda_visible_devices(vllm_config, local_dp_rank) + self._set_visible_devices(vllm_config, local_dp_rank) super().__init__(vllm_config, local_client, "", executor_class, log_stats) - def _set_cuda_visible_devices(self, vllm_config: VllmConfig, - local_dp_rank: int): + def _set_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int): from vllm.platforms import current_platform - device_control_env_var = current_platform.device_control_env_var + if current_platform.is_xpu(): + pass + else: + device_control_env_var = current_platform.device_control_env_var + self._set_cuda_visible_devices(vllm_config, local_dp_rank, + device_control_env_var) + + def _set_cuda_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int, + device_control_env_var: str): world_size = vllm_config.parallel_config.world_size # Set CUDA_VISIBLE_DEVICES or equivalent. try: - os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * - world_size, (local_dp_rank + 1) * world_size)) + value = get_device_indices(device_control_env_var, local_dp_rank, + world_size) + os.environ[device_control_env_var] = value except IndexError as e: raise Exception( f"Error setting {device_control_env_var}: " diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 62f229e28693..56ef8477d267 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -164,19 +164,33 @@ def set_device_control_env_var(vllm_config: VllmConfig, """ world_size = vllm_config.parallel_config.world_size evar = current_platform.device_control_env_var + + value = get_device_indices(evar, local_dp_rank, world_size) + with patch.dict(os.environ, values=((evar, value), )): + yield + + +def get_device_indices(device_control_env_var: str, local_dp_rank: int, + world_size: int): + """ + Returns a comma-separated string of device indices for the specified + data parallel rank. + + For example, if world_size=2 and local_dp_rank=1, and there are 4 devices, + this will select devices 2 and 3 for local_dp_rank=1. 
+ """ try: value = ",".join( str(current_platform.device_id_to_physical_device_id(i)) for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)) except IndexError as e: - raise Exception(f"Error setting {evar}: " + raise Exception(f"Error setting {device_control_env_var}: " f"local range: [{local_dp_rank * world_size}, " f"{(local_dp_rank + 1) * world_size}) " "base value: " - f"\"{os.getenv(evar)}\"") from e - with patch.dict(os.environ, values=((evar, value), )): - yield + f"\"{os.getenv(device_control_env_var)}\"") from e + return value class CoreEngineActorManager: @@ -254,6 +268,19 @@ def __init__( dp_vllm_config = copy.deepcopy(vllm_config) dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count + + # Ray XPU known issue: dpctl initializes the GPU runtime early, so + # setting device env vars in Ray actor's initialization method + # will not affect device selection. See: + # https://github.com/ray-project/ray/blob/master/python/ray/_private/accelerators/intel_gpu.py#L56 # noqa: E501 + if current_platform.is_xpu(): + device_evar = current_platform.device_control_env_var + device_indices = get_device_indices(device_evar, local_index, + world_size) + actor_env_vars = self.env_vars_dict.copy() + actor_env_vars[device_evar] = device_indices + runtime_env = RuntimeEnv(env_vars=actor_env_vars) + actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, From 1f7a9c95e4b2a1e02b19e94fd7371443f08b2e4b Mon Sep 17 00:00:00 2001 From: Michael Yao Date: Wed, 27 Aug 2025 20:37:52 +0800 Subject: [PATCH 017/125] [Docs] Fix a 1-2-3 list and style issues in tpu.md (#23729) Signed-off-by: windsonsea --- docs/configuration/tpu.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/configuration/tpu.md b/docs/configuration/tpu.md index ac2b6baffd14..e456077e0495 100644 --- a/docs/configuration/tpu.md +++ b/docs/configuration/tpu.md @@ -45,30 +45,30 @@ This initial compilation time ranges significantly and is impacted by many of th ### Optimize based on your data -#### max model len vs. most model len +#### max-model-len vs. most-model-len ![most_model_len](../assets/design/tpu/most_model_len.png) -If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most model len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. +If most of your requests are shorter than the maximum model length but you still need to accommodate occasional longer requests, setting a high maximum model length can negatively impact performance. In these cases, you can try introducing most-model-len by specifying the `VLLM_TPU_MOST_MODEL_LEN` environment variable. For example, 1% requests are 32k length and 99% requests are 2k length. You can pass 32k into `--max-model-len 32768` and use `VLLM_TPU_MOST_MODEL_LEN=2048`. -The requests get subdivided into max-model-len and most-model-len categories, for the latter category, we can gain better performance since the server can process more requests at a time. +The requests get subdivided into max-model-len and most-model-len categories, for the latter category, you can gain better performance since the server can process more requests at a time. 
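As a concrete starting point, a minimal sketch of the 1%/99% example above might look like the following. The model name is only a placeholder, and `VLLM_TPU_MOST_MODEL_LEN` is assumed to be set before the engine is constructed so the TPU backend picks it up (exporting it in the launching shell, e.g. `VLLM_TPU_MOST_MODEL_LEN=2048 vllm serve ...`, works the same way).

```python
import os

# Set before importing/constructing the engine so the TPU backend sees it.
os.environ["VLLM_TPU_MOST_MODEL_LEN"] = "2048"  # covers the 99% of short requests

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",  # placeholder model name
    max_model_len=32768,                 # still admits the occasional 32k request
)
outputs = llm.generate(["Hello from TPU!"], SamplingParams(max_tokens=32))
```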
#### Padding -For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128: 128, 256, etc. +For online serving with latency requirements, consider switching to bucket padding by setting the `VLLM_TPU_BUCKET_PADDING_GAP` environment variable. Because of the layout of the TPU, try using increments of 128 (e.g., 128, 256, etc.) -The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about tpu padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: +The server pads the requests into fixed lengths before sending them to the model to avoid recompilation. To read more about TPU padding, see [here](https://cloud.google.com/tpu/docs/performance-guide#xla-efficiencies). Currently, there are 2 ways to pad the requests: -1) the default exponential padding (pad to the nearest power of 2) -2) bucket padding (pad to the nearest linearly increasing bucket). +1. the default exponential padding (pad to the nearest power of 2) +2. bucket padding (pad to the nearest linearly increasing bucket). When using bucket padding, the buckets start from 16, end at max_model_len, and increment by `VLLM_TPU_BUCKET_PADDING_GAP`. For example, max_model_len=512, padding_gap=64, the buckets will be [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]. -The fewer tokens we pad, the less unnecessary computation TPU does, the better performance we can get. For example, if num_tokens=300, with exponential padding, we pad to 512, with the bucket_padding above, we pad to 320. +The fewer tokens you pad, the less unnecessary computation TPU does, the better performance you can get. For example, if num_tokens=300, with exponential padding, you pad to 512, with the bucket_padding above, you pad to 320. However, you need to be careful to choose the padding gap. If the gap is too small, it means the number of buckets is large, leading to increased warmup (precompile) time and higher memory to store the compiled graph. Too many compiled graphs may lead to HBM OOM. Conversely, an overly large gap yields no performance improvement compared to the default exponential padding. 
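To make the trade-off above easier to see, here is a rough, self-contained sketch of the two padding schemes. It is not the server's actual padding code, just the arithmetic from the example (buckets hard-coded to start at 16, 32, 64 as described above):

```python
def exponential_pad(num_tokens: int, min_pad: int = 16) -> int:
    """Pad to the nearest power-of-2 bucket, starting from `min_pad`."""
    padded = min_pad
    while padded < num_tokens:
        padded *= 2
    return padded


def bucket_pad(num_tokens: int, max_model_len: int, padding_gap: int) -> int:
    """Pad to the nearest bucket; buckets grow linearly by `padding_gap`."""
    buckets = [16, 32, 64]  # mirrors the example buckets above
    while buckets[-1] + padding_gap <= max_model_len:
        buckets.append(buckets[-1] + padding_gap)
    return next(b for b in buckets if b >= num_tokens)


# num_tokens=300, max_model_len=512, padding_gap=64
print(exponential_pad(300))      # -> 512
print(bucket_pad(300, 512, 64))  # -> 320
```

With these numbers, bucket padding wastes 20 padded positions for the request instead of 212.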
From 9d30de44698e1e337e4736ff62b83ebe1bbd4d40 Mon Sep 17 00:00:00 2001 From: tc-mb <157115220+tc-mb@users.noreply.github.com> Date: Wed, 27 Aug 2025 20:38:00 +0800 Subject: [PATCH 018/125] [model] Support MiniCPM-V 4.5 (#23586) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: tc-mb Signed-off-by: Xin Yang Signed-off-by: Abatom Signed-off-by: chzhang Signed-off-by: Pate Motter Signed-off-by: Terrencezzj Signed-off-by: Woosuk Kwon Signed-off-by: simon-mo Signed-off-by: mgoin Signed-off-by: Siyuan Fu Signed-off-by: siyuanf Signed-off-by: Weiliang Liu Signed-off-by: Michael Goin Signed-off-by: yewentao256 Signed-off-by: DarkLight1337 Signed-off-by: Luka Govedič Signed-off-by: Zijing Liu Signed-off-by: Zijing Liu Signed-off-by: jiabin.00 Signed-off-by: zjy0516 Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Jee Jee Li Signed-off-by: tc-mb <157115220+tc-mb@users.noreply.github.com> Signed-off-by: Roger Wang Signed-off-by: Roger Wang Signed-off-by: Huy Do Signed-off-by: Matúš Námešný Signed-off-by: Guillaume Calmettes Signed-off-by: Chen Zhang Signed-off-by: oye93 Signed-off-by: Julien Lin Signed-off-by: Didier Durand Signed-off-by: Tianyu Li Signed-off-by: Hongxia Yang Signed-off-by: Yuekai Zhang Signed-off-by: vllmellm Signed-off-by: jiang1.li Signed-off-by: Zerohertz Signed-off-by: Hyogeun Oh (오효근) Signed-off-by: Thomas Parnell Signed-off-by: Russell Bryant Signed-off-by: Isotr0py Signed-off-by: Huzaifa Sidhpurwala Signed-off-by: Federico <65908512+coval3nte@users.noreply.github.com> Signed-off-by: Zixuan Zhang Signed-off-by: wuhang Signed-off-by: czhu-cohere Signed-off-by: Wei Wei Signed-off-by: Yiheng Xu Signed-off-by: Chenheli Hua Signed-off-by: wangyafeng Co-authored-by: Xin Yang <105740670+xyang16@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Zhonghua Deng Co-authored-by: Chaojun Zhang Co-authored-by: Pate Motter Co-authored-by: Terrence Zhao <32208165+Terrencezzj@users.noreply.github.com> Co-authored-by: Woosuk Kwon Co-authored-by: Simon Mo Co-authored-by: Michael Goin Co-authored-by: weiliang Co-authored-by: Siyuan Fu Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič Co-authored-by: Zijing Liu Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Jiangyun Zhu Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Raghavan Co-authored-by: Jee Jee Li Co-authored-by: Cyrus Leung Co-authored-by: Roger Wang Co-authored-by: Roger Wang Co-authored-by: knlnguyen1802 Co-authored-by: Huy Do Co-authored-by: Matúš Námešný Co-authored-by: Guillaume Calmettes Co-authored-by: Chen Zhang Co-authored-by: En Ouyang Co-authored-by: Li, Jiang Co-authored-by: nvjullin Co-authored-by: Didier Durand <2927957+didier-durand@users.noreply.github.com> Co-authored-by: TianyuLi0 <116711075+TianyuLi0@users.noreply.github.com> Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Co-authored-by: Yuekai Zhang Co-authored-by: vllmellm Co-authored-by: Hyogeun Oh (오효근) Co-authored-by: Thomas Parnell Co-authored-by: Russell Bryant Co-authored-by: Lukas Geiger Co-authored-by: Isotr0py Co-authored-by: Huzaifa Sidhpurwala 
Co-authored-by: Federico <65908512+coval3nte@users.noreply.github.com> Co-authored-by: zixuanzhang226 Co-authored-by: wuhang Co-authored-by: yzds <41983536+youzhedian@users.noreply.github.com> Co-authored-by: hongchao Co-authored-by: czhu-cohere Co-authored-by: Wei Co-authored-by: Yiheng Xu Co-authored-by: Aaron Pham Co-authored-by: Chenheli Hua Co-authored-by: CSWYF3634076 <58356743+CSWYF3634076@users.noreply.github.com> --- docs/models/supported_models.md | 2 +- tests/models/registry.py | 2 +- vllm/model_executor/models/minicpmv.py | 314 +++++++++++++++++- .../chat_templates/registry.py | 11 + .../chat_templates/template_minicpmv45.jinja | 93 ++++++ 5 files changed, 407 insertions(+), 15 deletions(-) create mode 100644 vllm/transformers_utils/chat_templates/template_minicpmv45.jinja diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 19ce8c06724f..35a5fa0c2e42 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -638,7 +638,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I+ + V+ | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | | `MiniCPMO` | MiniCPM-O | T + IE+ + VE+ + AE+ | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + IE+ + VE+ | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, `openbmb/MiniCPM-V-4_5`, etc. | ✅︎ | | ✅︎ | | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + IE+ | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | | `MllamaForConditionalGeneration` | Llama 3.2 | T + I+ | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. 
| | | | diff --git a/tests/models/registry.py b/tests/models/registry.py index f2c09d3e8452..ee546e7af85c 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -451,7 +451,7 @@ def check_available_online( "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6", trust_remote_code=True), "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5", - extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"}, # noqa: E501 + extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4", "4.5": "openbmb/MiniCPM-V-4_5"}, # noqa: E501 trust_remote_code=True), "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01", # noqa: E501 trust_remote_code=True, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index c22d871ab20d..2d785c30fd7d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -27,12 +27,14 @@ from collections import defaultdict from collections.abc import Iterable, Mapping, Sequence from functools import partial +from itertools import chain from typing import Annotated, Any, Callable, Literal, Optional, Union import numpy as np import torch import torch.types from torch import nn +from torch.nn.init import trunc_normal_ from transformers import BatchFeature, PretrainedConfig from typing_extensions import TypeVar @@ -47,10 +49,11 @@ from vllm.model_executor.models.minicpm import MiniCPMForCausalLM from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM +from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - NestedTensors) + MultiModalKwargsItems, NestedTensors) from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageProcessorItems, ImageSize, ModalityData, ModalityDataItems, @@ -218,6 +221,187 @@ def forward(self, x: torch.Tensor, return x +class Resampler4_5(Resampler2_5): + + def __init__(self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: tuple[int, int] = (70, 70), + max_temporal_size: int = 36000, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__(num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + max_size, + quant_config=quant_config, + prefix=prefix) + + trunc_normal_(self.query, std=.02) + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size) + self.apply(self._init_weights) + + def get_1d_sincos_pos_embed_from_temporal_size(self, embed_dim: int, + pos: np.ndarray): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. 
/ 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + def _set_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu") -> None: + temporal_size = np.arange(max_temporal_size, dtype=np.float32) + pos_embed = torch.from_numpy( + self.get_1d_sincos_pos_embed_from_temporal_size( + self.embed_dim, temporal_size)).float().to(device) + self.register_buffer("temporal_pos_embed", pos_embed, persistent=False) + + def _adjust_temporal_pos_cache(self, + max_temporal_size: int, + device: torch.types.Device = "cpu"): + if max_temporal_size > self.max_temporal_size: + self.max_temporal_size = max_temporal_size + self._set_temporal_pos_cache(self.max_temporal_size, device) + + def _init_weights(self, m: Union[nn.Linear, nn.LayerNorm]): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: torch.Tensor, + # temporal_ids for high refresh rate videos + temporal_ids=None + ) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + temporal_pos_emb = False + temporal_ids_flatten = None + if temporal_ids is not None: + # example: [[-1], [-1], [2, 6, 9]] + temporal_ids_flatten = list(chain.from_iterable(temporal_ids)) + max_temporal_size = max(temporal_ids_flatten, default=0) + if max_temporal_size > -1: + temporal_pos_emb = True + if max_temporal_size > self.max_temporal_size: + self._adjust_temporal_pos_cache(max_temporal_size, device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros((bs, max_patch_len), + dtype=torch.bool, + device=device) + + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + q = self.ln_q(self.query) # Q * D + + pos_embed_2d = [] + pos_embed_temporal = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i] + if temporal_pos_emb: + if temporal_ids_flatten[i] == -1: + pos_embed_temporal.append( + torch.zeros(self.embed_dim, dtype=dtype, + device=device)) + else: + pos_embed_temporal.append(self.temporal_pos_embed[ + temporal_ids_flatten[i]].to(dtype)) # D + + pos_embed_2d.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( + (tgt_h * tgt_w, -1)).to(dtype)) # patches * D + key_padding_mask[i, patch_len[i]:] = True + + pos_embed_2d = torch.nn.utils.rnn.pad_sequence( + pos_embed_2d, batch_first=True, + padding_value=0.0).permute(1, 0, 2) # BLD => L * B * D + + k = x + v = x + pos_embed_2d + if pos_embed_temporal: + k += torch.stack(pos_embed_temporal, dim=0) + bs = len(temporal_ids) + merge_k = [] + merge_v = [] + merge_key_padding_mask = [] + + start = 0 + for tp in temporal_ids: + end = start + len(tp) + # L * (end-start) * D -> (end-start) * L * D + # -> 1 * L*(end-start) * D + merge_k.append(k[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_v.append(v[:, start:end, :].permute(1, 0, 2).reshape( + -1, self.embed_dim)) + merge_key_padding_mask.append( + key_padding_mask[start:end, :].reshape(-1, 1)) + + 
start = end + + k = torch.nn.utils.rnn.pad_sequence(merge_k, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + v = torch.nn.utils.rnn.pad_sequence(merge_v, + batch_first=True, + padding_value=0.0).permute( + 1, 0, 2) # L*(end-start) + key_padding_mask = torch.nn.utils.rnn.pad_sequence( + merge_key_padding_mask, batch_first=True, + padding_value=True).squeeze(-1) + + out = self.attn( + self._repeat(q, bs), # Q * B * D + k, # L * B * D + L * B * D + v, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]: version_float = getattr(config, "version", None) @@ -354,9 +538,7 @@ def get_model_version(self): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: mm_limits = {"image": None} - if self.get_model_version() == (2, - 6) or self.get_model_version() == (4, - 0): + if self.get_model_version() in {(2, 6), (4, 0), (4, 5)}: mm_limits["video"] = None return mm_limits @@ -637,8 +819,7 @@ def _base_call_hf_processor( out_keys: set[str], ) -> dict[str, NestedTensors]: # This processor supports zipping prompt and mm_data together - if self.info.get_model_version() == ( - 2, 6) or self.info.get_model_version() == (4, 0): + if self.info.get_model_version() in {(2, 6), (4, 0), (4, 5)}: inputs = super()._call_hf_processor( prompt=prompts, # type: ignore mm_data=mm_data, @@ -816,7 +997,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # and config class self.config = config self.multimodal_config = multimodal_config - self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.version = get_version_by_config(self.config) self.llm = self.init_llm(vllm_config=vllm_config, @@ -1364,11 +1544,9 @@ def init_vision_module( prefix: str = "", ) -> nn.Module: quant_config = self._maybe_ignore_quant_config(quant_config) - model = Idefics2VisionTransformer( - config.vision_config, - quant_config=quant_config, - prefix=prefix, - use_data_parallel=self.use_data_parallel) + model = Idefics2VisionTransformer(config.vision_config, + quant_config=quant_config, + prefix=prefix) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model @@ -1436,11 +1614,121 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) +class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + assert self.version == (4, 5) + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)): + return None + return quant_config + + def init_llm( + self, + vllm_config: VllmConfig, + prefix: str = "", + ) -> nn.Module: + return Qwen3ForCausalLM(vllm_config=vllm_config, prefix=prefix) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + model = Idefics2VisionTransformer(config.vision_config, + quant_config=quant_config, + prefix=prefix) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + 
return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> nn.Module: + quant_config = self._maybe_ignore_quant_config(quant_config) + with set_default_torch_dtype(torch.float16): + # The resampler in 4.0 remains consistent with the one in 2.5/2.6. + resampler = Resampler4_5(num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + prefix=prefix) + + return resampler.to(device=current_platform.device_type, + dtype=torch.get_default_dtype()) + + def get_vision_hidden_states( + self, data: MiniCPMVImagePixelInputs) -> torch.Tensor: + pixel_values = data["pixel_values"] + tgt_sizes = data["tgt_sizes"] + temporal_ids = data.get('temporal_ids', None) + + B = len(pixel_values) + P = pixel_values[0].shape[-2] + L = max(item.shape[-1] for item in pixel_values) + device = pixel_values[0].device + dtype = pixel_values[0].dtype + + all_pixel_values = torch.zeros((B, 3, P, L), + dtype=dtype, + device=device) + all_temporal_ids = None if temporal_ids is None else flatten_2d_lists( + temporal_ids) + for i, pixel_values_item in enumerate(pixel_values): + L_item = pixel_values_item.shape[-1] + all_pixel_values[i, ..., :L_item] = pixel_values_item + + num_patches = tgt_sizes.prod(-1) + max_patches = num_patches.max().item() + assert isinstance(max_patches, int) + + patch_attn_mask = torch.zeros((B, max_patches), + dtype=torch.bool, + device=device) + for i, num_patches_item in enumerate(num_patches): + patch_attn_mask[i, :num_patches_item] = True + + vision_embedding = self.vpm( + all_pixel_values, + patch_attention_mask=patch_attn_mask.unsqueeze(1), + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes, all_temporal_ids) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, + skip_prefixes=["apm.", "audio", "tts"]) + return loader.load_weights(weights) + + _SUPPORT_VERSION = { (2, 0): MiniCPMV2_0, (2, 5): MiniCPMV2_5, (2, 6): MiniCPMV2_6, (4, 0): MiniCPMV4_0, + (4, 5): MiniCPMV4_5, } diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py index e0ef7f0999d4..d09c5fa924fb 100644 --- a/vllm/transformers_utils/chat_templates/registry.py +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -20,6 +20,16 @@ def _get_qwen_chat_template_fallback( return CHAT_TEMPLATES_DIR / "template_basic.jinja" +def _get_minicpmv_chat_template_fallback( + tokenizer_name_or_path: str) -> Optional[Path]: + # MiniCPM-V-4.5 version uses a dedicated template + if "4.5" in tokenizer_name_or_path or "4_5" in tokenizer_name_or_path: + return CHAT_TEMPLATES_DIR / "template_minicpmv45.jinja" + + # Other versions use chatml template + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + # yapf: disable _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", @@ -27,6 +37,7 @@ def _get_qwen_chat_template_fallback( "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "minicpmv": _get_minicpmv_chat_template_fallback, "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", "qwen": _get_qwen_chat_template_fallback, } diff --git a/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja 
b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja new file mode 100644 index 000000000000..661ebd1cf5c1 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/template_minicpmv45.jinja @@ -0,0 +1,93 @@ +{%- set enable_thinking = enable_thinking | default(false) %} +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} + +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} + +{%- for message in messages %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set content = message.content %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is defined and message.reasoning_content is not none %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in message.content %} + {%- set content = message.content.split('')[-1].lstrip('\n') %} + {%- set reasoning_content = message.content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- message.content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- 
endif %} + {%- if enable_thinking is defined and enable_thinking is true %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file From 8c13820f0b203976eab8e821c102234a73f338cd Mon Sep 17 00:00:00 2001 From: cndoit18 Date: Wed, 27 Aug 2025 20:42:20 +0800 Subject: [PATCH 019/125] [Bugfix] Fix task field initialization when PYTHONOPTIMIZE is enabled (#23718) Signed-off-by: cndoit18 --- vllm/worker/pooling_model_runner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index 8d8d9b4d0503..3e1950798dbf 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -199,8 +199,9 @@ def _prepare_pooling( pooling_params = seq_group_metadata.pooling_params assert pooling_params is not None - assert (task := pooling_params.task) is not None, ( - "You did not set `task` in the API") + + task = pooling_params.task + assert task is not None, "You did not set `task` in the API" model = cast(VllmModelForPooling, self.model) to_update = model.pooler.get_pooling_updates(task) From a403d0fa41cc68e3b6da4e1097dc896fde2f1a6a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 27 Aug 2025 05:50:47 -0700 Subject: [PATCH 020/125] [Misc] Remove unnecessary `_send_reconfig_message()` in `core_client.py` (#23127) Signed-off-by: Nick Hill --- vllm/v1/engine/core_client.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 079dd9a7d38d..65f7abc97110 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1190,21 +1190,6 @@ async def _abort_requests(self, request_ids: list[str], await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) - async def _send_reconfig_message( - self, reconfig_request: ReconfigureDistributedRequest, - engine: EngineIdentity) -> asyncio.Future: - """Send reconfiguration message and return the result future without - waiting for completion.""" - call_id = uuid.uuid1().int >> 64 - future = asyncio.get_running_loop().create_future() - self.utility_results[call_id] = future - message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( - (self.client_index, call_id, "reinitialize_distributed", - (reconfig_request, )))) - await self._send_input_message(message, engine, reconfig_request) - self._ensure_output_queue_task() - return future - async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: """Scale elastic EP data parallel size""" cur_data_parallel_size = len(self.core_engines) @@ -1214,7 +1199,7 @@ async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: f"different from cur_data_parallel_size {cur_data_parallel_size}") assert self.vllm_config.parallel_config.data_parallel_backend == \ - "ray", ("Only ray DP backend supports scaling elastic EP") + "ray", "Only ray DP backend supports scaling elastic EP" scale_up = new_data_parallel_size > cur_data_parallel_size @@ -1246,9 +1231,10 @@ async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, data_parallel_master_ip, new_data_parallel_master_port=self.vllm_config.parallel_config. 
data_parallel_master_port) - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) logger.info("All reconfigure messages sent, starting engine creation") @@ -1318,9 +1304,10 @@ async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, if cur_dp_rank >= new_data_parallel_size: reconfig_request.new_data_parallel_rank = \ ReconfigureRankType.SHUTDOWN_CURRENT_RANK - future = await self._send_reconfig_message(reconfig_request, - engine) - reconfig_futures.append(future) + coro = self._call_utility_async("reinitialize_distributed", + reconfig_request, + engine=engine) + reconfig_futures.append(asyncio.create_task(coro)) for _ in range(new_data_parallel_size, cur_data_parallel_size): self.core_engines.pop() From 704432af3c129b7a57fca9b059eefe214159f836 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Wed, 27 Aug 2025 14:51:54 +0200 Subject: [PATCH 021/125] [V1] [Hybrid] Disable prefix caching by default for hybrid or mamba-based models (#23716) Signed-off-by: Thomas Parnell --- docs/usage/v1_guide.md | 10 ++++++---- vllm/model_executor/models/config.py | 9 +++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 64bd0d9bf507..20234e761133 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -107,14 +107,16 @@ to enable simultaneous generation and embedding using the same engine instance i #### Mamba Models Models using selective state-space mechanisms instead of standard transformer attention are supported. -Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. +Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. +Please note that prefix caching is not yet supported for these models. Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, -`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that -these models currently require disabling prefix caching in V1. +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). +Please note that prefix caching is not yet supported for these models. Hybrid models with mechanisms different to Mamba are also supported (e.g, `MiniMaxText01ForCausalLM`, `MiniMaxM1ForCausalLM`). -Please note that these models currently require disabling prefix caching and enforcing eager mode in V1. +Please note that prefix caching is not yet supported for these models. +It is also necessary to enforce eager mode for these models in V1. 
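As a rough illustration of the note above, eager mode can be requested explicitly when constructing the engine; with this change, prefix caching is disabled automatically for mamba-based and hybrid models, so no extra flag is needed for that. The checkpoint name below is used only as an example of a non-Mamba hybrid architecture:

```python
from vllm import LLM

# Minimal sketch: force eager mode for a hybrid (non-Mamba mechanism) model in V1.
llm = LLM(
    model="MiniMaxAI/MiniMax-Text-01",  # example hybrid model
    enforce_eager=True,                 # required for these models in V1
    trust_remote_code=True,
)
```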
#### Encoder-Decoder Models diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index f62209326b98..88b3154de2cb 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -292,12 +292,13 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: return model_config = vllm_config.model_config + cache_config = vllm_config.cache_config compilation_config = vllm_config.compilation_config - model_cls, _ = ModelRegistry.resolve_model_cls( - model_config.architecture, - model_config=model_config, - ) + # TODO(tdoublep): remove once prefix caching is enabled + cache_config.enable_prefix_caching = False + logger.info("Hybrid or mamba-based model detected: disabling prefix " + "caching since it is not yet supported.") # TODO(tdoublep): remove as full cuda graph support is added FCG_NOT_SUPPORTED_MODELS = [ From 5eeef1b90852917b300ed67b98e341eb846ba2e9 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 27 Aug 2025 21:24:09 +0800 Subject: [PATCH 022/125] [Model] Explicit `default_pooling_type` interface (#23736) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/bert.py | 4 +-- vllm/model_executor/models/bert_with_rope.py | 5 ++-- vllm/model_executor/models/gritlm.py | 2 +- vllm/model_executor/models/interfaces.py | 19 +------------ vllm/model_executor/models/interfaces_base.py | 28 +++++++++++++++++++ vllm/model_executor/models/internlm2.py | 3 +- vllm/model_executor/models/modernbert.py | 3 +- .../models/prithvi_geospatial_mae.py | 7 +++-- vllm/model_executor/models/qwen2_rm.py | 3 +- vllm/model_executor/models/registry.py | 7 +++-- vllm/model_executor/models/roberta.py | 3 +- 11 files changed, 51 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 22b6c4401213..b34ca5cbe963 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -28,8 +28,8 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import (SupportsCrossEncoding, SupportsQuant, - default_pooling_type) +from .interfaces import SupportsCrossEncoding, SupportsQuant +from .interfaces_base import default_pooling_type from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py index 129450927e56..dcb7e75456cd 100644 --- a/vllm/model_executor/models/bert_with_rope.py +++ b/vllm/model_executor/models/bert_with_rope.py @@ -27,13 +27,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (SupportsQuant, - default_pooling_type) from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from .interfaces import SupportsQuant +from .interfaces_base import default_pooling_type + class BertWithRopeEmbedding(nn.Module): diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 3f6790269ae6..1b3d541c65cf 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -20,7 +20,7 @@ from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config -from .interfaces import 
default_pooling_type +from .interfaces_base import default_pooling_type logger = init_logger(__name__) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 9415e67924e7..22f005849e86 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -3,7 +3,7 @@ from collections.abc import Iterable, Mapping, MutableSequence from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, - TypeVar, Union, overload, runtime_checkable) + Union, overload, runtime_checkable) import numpy as np import torch @@ -641,23 +641,6 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) -_T = TypeVar("_T", bound=type[torch.nn.Module]) - - -def default_pooling_type(pooling_type: str): - """Set default_pooling_type decorator. """ - - def func(model: _T) -> _T: - model.default_pooling_type = pooling_type # type: ignore - return model - - return func - - -def get_default_pooling_type(model: Union[type[object], object]) -> str: - return getattr(model, "default_pooling_type", "LAST") - - class SupportsQuant: """The interface required for all models that support quantization.""" diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 697fa020deb4..19a3ef1a3b80 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): MRO of your model class. """ + default_pooling_type: ClassVar[str] = "LAST" + """ + Indicates the + [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][] + to use by default. + + You can use the + [vllm.model_executor.models.interfaces_base.default_pooling_type][] + decorator to conveniently set this field. 
+ """ + pooler: Pooler """The pooler is only called on TP rank 0.""" @@ -165,3 +176,20 @@ def is_pooling_model( return False return getattr(model, "is_pooling_model", False) + + +_T = TypeVar("_T", bound=type[nn.Module]) + + +def default_pooling_type(pooling_type: str): + """Decorator to set `VllmModelForPooling.default_pooling_type`.""" + + def func(model: _T) -> _T: + model.default_pooling_type = pooling_type # type: ignore + return model + + return func + + +def get_default_pooling_type(model: Union[type[object], object]) -> str: + return getattr(model, "default_pooling_type", "LAST") diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index d0c4bf5450d6..26bc48ffbd9b 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -31,7 +31,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 72290bf2ee29..477855586128 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -26,7 +26,8 @@ from vllm.sequence import IntermediateTensors from vllm.tasks import PoolingTask -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type from .utils import WeightsMapper, maybe_prefix diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index 59e9f3e8a47b..f46d6375e1f6 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -27,9 +27,6 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import ( - IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput, - default_pooling_type) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (ImageItem, ModalityData, @@ -43,6 +40,10 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from .interfaces import (IsAttentionFree, MultiModalEmbeddings, + SupportsMultiModalWithRawInput) +from .interfaces_base import default_pooling_type + def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): # This model receives in input a multi-dimensional tensor representing diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index e0a30e04c602..421b43563bad 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -18,7 +18,8 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP, default_pooling_type +from .interfaces import SupportsLoRA, SupportsPP +from .interfaces_base import default_pooling_type from .qwen2 import 
Qwen2Model from .utils import AutoWeightsLoader, maybe_prefix diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index c65c58d4a047..196b5f35e1e4 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -25,11 +25,12 @@ from vllm.transformers_utils.dynamic_module import ( try_get_class_from_dynamic_module) -from .interfaces import (get_default_pooling_type, has_inner_state, has_noops, - is_attention_free, is_hybrid, supports_cross_encoding, +from .interfaces import (has_inner_state, has_noops, is_attention_free, + is_hybrid, supports_cross_encoding, supports_multimodal, supports_multimodal_raw_input, supports_pp, supports_transcription, supports_v0_only) -from .interfaces_base import is_pooling_model, is_text_generation_model +from .interfaces_base import (get_default_pooling_type, is_pooling_model, + is_text_generation_model) logger = init_logger(__name__) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 49a37342c67f..2bfa51162910 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -22,7 +22,8 @@ from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel -from .interfaces import SupportsCrossEncoding, default_pooling_type +from .interfaces import SupportsCrossEncoding +from .interfaces_base import default_pooling_type class RobertaEmbedding(nn.Module): From 8dd2baa5978f123974177023d6efab731153a2f4 Mon Sep 17 00:00:00 2001 From: rebel-hongseok Date: Wed, 27 Aug 2025 22:25:49 +0900 Subject: [PATCH 023/125] Add vLLM Korea Meetup in the README.md and meetups.md (#23746) Signed-off-by: rebel-hongseok Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- README.md | 1 + docs/community/meetups.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index ef5b43588953..8812aac4ea26 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* 🔥 - [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). +- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). 
diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 61ea44220ad2..d76238cb3179 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -3,6 +3,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: - [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) +- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). From 16dc4052b004261b547fc50fe7b20e2d2fbf915d Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 27 Aug 2025 14:39:48 +0100 Subject: [PATCH 024/125] Fix pre-commit on main (#23747) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/community/meetups.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/community/meetups.md b/docs/community/meetups.md index d76238cb3179..221a7bd96213 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -3,7 +3,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: - [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) -- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). +- [vLLM Korea Meetup](https://luma.com/cgcgprmh), August 19th 2025. [[Slides]](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view). - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). 
From fe8d7b6f03e7d8a36ffb6931397fc81ee594dd64 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 27 Aug 2025 21:41:22 +0800 Subject: [PATCH 025/125] [Model] Interface to enable batch-level DP support (#23733) Signed-off-by: DarkLight1337 Signed-off-by: Cyrus Leung Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/configuration/optimization.md | 7 +++++-- vllm/config/__init__.py | 7 +++++++ vllm/model_executor/models/interfaces.py | 11 +++++++++++ vllm/model_executor/models/minicpmv.py | 2 ++ vllm/model_executor/models/mllama4.py | 2 ++ vllm/model_executor/models/qwen2_5_vl.py | 2 ++ vllm/model_executor/models/registry.py | 9 +++++++-- vllm/model_executor/models/step3_vl.py | 2 ++ 8 files changed, 38 insertions(+), 4 deletions(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index a8eab9985c8b..b11ccb5c0027 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -168,8 +168,11 @@ llm = LLM( Batch-level DP is not to be confused with API request-level DP (which is instead controlled by `data_parallel_size`). -The availability of batch-level DP is based on model implementation. -Currently, the following models support `mm_encoder_tp_mode="data"`: +Batch-level DP needs to be implemented on a per-model basis, +and enabled by setting `supports_encoder_tp_data = True` in the model class. +Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature. + +Known supported models: - Llama4 () - MiniCPM-V-4 () diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index ac6f51df9549..e3fb6d796def 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -872,6 +872,13 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self._model_info.supports_multimodal: + if (self.mm_encoder_tp_mode == "data" and + not self._model_info.supports_multimodal_encoder_tp_data): + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. " + "Falling back to `--mm-encoder-tp-mode weights`.") + self.mm_encoder_tp_mode = "weights" + return MultiModalConfig( limit_per_prompt=self.limit_mm_per_prompt, media_io_kwargs=self.media_io_kwargs, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 22f005849e86..506732fed361 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + supports_encoder_tp_data: ClassVar[bool] = False + """ + A flag that indicates whether this model supports + `multimodal_config.mm_encoder_tp_mode="data"`. 
+ """ + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: """ @@ -137,6 +143,11 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) +def supports_multimodal_encoder_tp_data( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_encoder_tp_data", False) + + @runtime_checkable class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): """The interface required for all multi-modal models.""" diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2d785c30fd7d..0181bfeebda0 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1521,6 +1521,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ], } + supports_encoder_tp_data = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (4, 0) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 595bdd17cf2c..ac9b968f7a0c 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -716,6 +716,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, "gate_up_proj": ["gate_proj", "up_proj"], } + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 648ba81eb387..b528083b7c9c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -868,6 +868,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.": "language_model.model.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 196b5f35e1e4..80eac78cdfad 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -27,8 +27,10 @@ from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, supports_v0_only) + supports_multimodal, + supports_multimodal_encoder_tp_data, + supports_multimodal_raw_input, supports_pp, + supports_transcription, supports_v0_only) from .interfaces_base import (get_default_pooling_type, is_pooling_model, is_text_generation_model) @@ -324,6 +326,7 @@ class _ModelInfo: supports_cross_encoding: bool supports_multimodal: bool supports_multimodal_raw_input: bool + supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -343,6 +346,8 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_multimodal_raw_input=supports_multimodal_raw_input(model), + supports_multimodal_encoder_tp_data= + supports_multimodal_encoder_tp_data(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f8877b584b19..f379d2c15fb6 100644 --- 
a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -867,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, "lm_head.": "language_model.lm_head.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): From 513c1fe255f7d4ec3e91f7f5c2dd2d97c0460765 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Wed, 27 Aug 2025 14:55:12 +0100 Subject: [PATCH 026/125] Only run `get_attr_docs` if generating help text (#23723) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/engine/arg_utils.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9e7c95ea5205..3399d505e363 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -152,9 +152,17 @@ def is_online_quantization(quantization: Any) -> bool: return quantization in ["inc"] +NEEDS_HELP = ( + "--help" in (argv := sys.argv) # vllm SUBCOMMAND --help + or (argv0 := argv[0]).endswith("mkdocs") # mkdocs SUBCOMMAND + or argv0.endswith("mkdocs/__main__.py") # python -m mkdocs SUBCOMMAND +) + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: - cls_docs = get_attr_docs(cls) + # Save time only getting attr docs if we're generating help text + cls_docs = get_attr_docs(cls) if NEEDS_HELP else {} kwargs = {} for field in fields(cls): # Get the set of possible types for the field @@ -172,7 +180,7 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: # Get the help text for the field name = field.name - help = cls_docs[name].strip() + help = cls_docs.get(name, "").strip() # Escape % for argparse help = help.replace("%", "%%") @@ -254,6 +262,9 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: def get_kwargs(cls: ConfigType) -> dict[str, Any]: """Return argparse kwargs for the given Config dataclass. + If `--help` or `mkdocs` are not present in the command line command, the + attribute documentation will not be included in the help output. + The heavy computation is cached via functools.lru_cache, and a deep copy is returned so callers can mutate the dictionary without affecting the cached version. 
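
The cache-then-copy behaviour described in the `get_kwargs` docstring above can be illustrated with a standalone sketch. This is not vLLM's actual `arg_utils` code: `ExampleConfig`, the simplified `NEEDS_HELP` check, and the placeholder doc strings are assumptions for illustration only. The point is why the attribute-doc work is skipped unless help text will actually be rendered, and why the cached dictionary is deep-copied before being handed to callers.

```python
# Illustrative sketch only -- not vLLM's real arg_utils implementation.
# ExampleConfig, the simplified NEEDS_HELP check, and the fake doc strings
# are stand-ins; the point is lazy help text plus cache-then-deep-copy.
import copy
import functools
import sys
from dataclasses import dataclass, fields
from typing import Any

# Skip the expensive attribute-doc extraction unless help will be rendered.
NEEDS_HELP = "--help" in sys.argv or sys.argv[0].endswith("mkdocs")


@dataclass
class ExampleConfig:
    max_model_len: int = 4096
    enable_chunked_prefill: bool = False


@functools.lru_cache(maxsize=None)
def _compute_kwargs(cls: type) -> dict[str, dict[str, Any]]:
    # Pretend the doc lookup is the heavy part; only do it when needed.
    docs = ({f.name: f"{f.name} option" for f in fields(cls)}
            if NEEDS_HELP else {})
    return {
        f.name: {"default": f.default, "help": docs.get(f.name, "")}
        for f in fields(cls)
    }


def get_kwargs(cls: type) -> dict[str, dict[str, Any]]:
    # Deep copy so callers can mutate the result without touching the cache.
    return copy.deepcopy(_compute_kwargs(cls))


if __name__ == "__main__":
    kwargs = get_kwargs(ExampleConfig)
    kwargs["max_model_len"]["help"] = "mutated by caller"
    # The cached entry is unaffected because get_kwargs returned a deep copy.
    assert (_compute_kwargs(ExampleConfig)["max_model_len"]["help"]
            != "mutated by caller")
```

Under these assumptions, repeated `get_kwargs` calls pay the `fields()`/doc cost once per class, while mutations of the returned dictionary never leak back into the cached copy — the same trade-off the patch above makes for `_compute_kwargs`.
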
From 3af47c3cc693f432b59658019891393385aa0e2a Mon Sep 17 00:00:00 2001 From: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Date: Wed, 27 Aug 2025 10:09:08 -0400 Subject: [PATCH 027/125] [Feature] Add Hopper DeepGEMM E8M0 for DeepSeekV3.1 scale_fmt (#23666) Signed-off-by: yewentao256 Signed-off-by: youkaichao Co-authored-by: youkaichao --- tests/kernels/moe/test_block_fp8.py | 5 +- tests/kernels/moe/test_deepep_deepgemm_moe.py | 7 ++- vllm/envs.py | 8 ++- .../layers/fused_moe/batched_deep_gemm_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 7 ++- .../layers/fused_moe/triton_deep_gemm_moe.py | 6 +-- .../model_executor/layers/quantization/fp8.py | 9 ++-- .../layers/quantization/utils/fp8_utils.py | 4 +- vllm/transformers_utils/config.py | 18 +++++++ vllm/utils/deep_gemm.py | 53 +++++++++---------- 10 files changed, 68 insertions(+), 53 deletions(-) diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index 9e4eaf221f24..ecc57acc6796 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -16,7 +16,7 @@ fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used dg_available = has_deep_gemm() @@ -226,8 +226,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), - reason="Not E8M0 scale MOE") +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 1e922be47f2b..36a98522a658 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,8 +20,7 @@ FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from ...utils import multi_gpu_test from .parallel_utils import ProcessGroupInfo, parallel_launch @@ -374,7 +373,7 @@ def _test_deepep_deepgemm_moe( @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): @@ -432,7 +431,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, @multi_gpu_test(num_gpus=2) @requires_deep_ep @requires_deep_gemm -@pytest.mark.skipif(is_blackwell_deep_gemm_e8m0_used(), +@pytest.mark.skipif(is_deep_gemm_e8m0_used(), reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], diff --git a/vllm/envs.py b/vllm/envs.py index 66c7c2c7f2c4..35735b552575 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -131,6 +131,7 @@ VLLM_TPU_USING_PATHWAYS: bool = False VLLM_USE_DEEP_GEMM: bool = False 
VLLM_USE_DEEP_GEMM_E8M0: bool = True + VLLM_USE_DEEP_GEMM_E8M0_HOPPER: bool = False VLLM_SKIP_DEEP_GEMM_WARMUP: bool = False VLLM_USE_FUSED_MOE_GROUPED_TOPK: bool = True VLLM_USE_FLASHINFER_MOE_FP8: bool = False @@ -954,9 +955,12 @@ def get_vllm_port() -> Optional[int]: lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs. - # E8M0 is faster on B200 but may reduce accuracy. "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))), + # TODO(wentao): unify the two E8M0 flags after verifying the correctness. + # Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs. + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER": + lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "0"))), # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm # JIT all the required kernels before model execution so there is no # JIT'ing in the hot-path. However, this warmup increases the engine @@ -1244,6 +1248,8 @@ def compute_hash() -> str: "VLLM_USE_FLASHINFER_SAMPLER", "VLLM_DISABLED_KERNELS", "VLLM_USE_DEEP_GEMM", + "VLLM_USE_DEEP_GEMM_E8M0", + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER", "VLLM_USE_TRTLLM_FP4_GEMM", "VLLM_USE_FUSED_MOE_GROUPED_TOPK", "VLLM_USE_FLASHINFER_MOE_FP8", diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index c4d680af932f..a5326dfe84f6 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, - is_blackwell_deep_gemm_e8m0_used) + is_deep_gemm_e8m0_used) logger = init_logger(__name__) @@ -174,7 +174,7 @@ def silu_mul_fp8_quant_deep_gemm( eps, fp8_min, fp8_max, - is_blackwell_deep_gemm_e8m0_used(), + is_deep_gemm_e8m0_used(), BLOCK=group_size, NUM_STAGES=4, num_warps=1, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 84dafcf00d82..17a5c735a57f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -40,7 +40,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -1431,9 +1431,8 @@ def fused_experts(hidden_states: torch.Tensor, # E8M0 scale, which means we requantize the weight and input to the specific # scale. Fallen back to cutlass or triton for some cases would cause # accuracy issue. 
- if (allow_deep_gemm and use_fp8_w8a8 - and (is_blackwell_deep_gemm_e8m0_used() - or _valid_deep_gemm(hidden_states, w1, w2))): + if (allow_deep_gemm and use_fp8_w8a8 and + (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))): assert apply_router_weight_on_input is False assert is_act_and_mul, ( "DeepGemm only supports is_act_and_mul=True for now.") diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 486ca881df48..6cd81d97f029 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -10,7 +10,7 @@ DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts -from vllm.utils.deep_gemm import is_blackwell_deep_gemm_e8m0_used +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -107,7 +107,7 @@ def workspace_shapes( # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. - if self.allow_deep_gemm and (is_blackwell_deep_gemm_e8m0_used() + if self.allow_deep_gemm and (is_deep_gemm_e8m0_used() or _valid_deep_gemm_shape(M, N, K)): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( @@ -143,7 +143,7 @@ def apply( ): use_deep_gemm = (self.allow_deep_gemm and (_valid_deep_gemm(hidden_states, w1, w2) - or is_blackwell_deep_gemm_e8m0_used())) + or is_deep_gemm_e8m0_used())) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d45d368b582d..be358cfa949f 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -48,8 +48,7 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, - is_deep_gemm_supported) +from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used, is_deep_gemm_supported from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: @@ -427,7 +426,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # On B200, if E8M0 for DeepGemm is used, we need to # requantize the weight and input to the specific scale # at the same time. - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None block_sz = tuple(layer.weight_block_size) requant_weight_ue8m0_inplace( @@ -734,7 +733,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm and not is_blackwell_deep_gemm_e8m0_used(): + if self.allow_deep_gemm and not is_deep_gemm_e8m0_used(): # Lazy import to avoid CUDA initialization problems. 
if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ @@ -871,7 +870,7 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w13_input_scale del layer.w2_input_scale - if is_blackwell_deep_gemm_e8m0_used(): + if is_deep_gemm_e8m0_used(): assert layer.weight_block_size is not None # Re-quantise the expert weights so their scales are UE8M0. block_sz = tuple(layer.weight_block_size) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ab1d5383f465..7b324dce3c36 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -20,7 +20,7 @@ from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op -from vllm.utils.deep_gemm import (is_blackwell_deep_gemm_e8m0_used, +from vllm.utils.deep_gemm import (is_deep_gemm_e8m0_used, should_use_deepgemm_for_fp8_linear) logger = init_logger(__name__) @@ -385,7 +385,7 @@ def per_token_group_quant_fp8( scaling factor. """ if use_ue8m0 is None: - use_ue8m0 = is_blackwell_deep_gemm_e8m0_used() + use_ue8m0 = is_deep_gemm_e8m0_used() dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 2cd799e5eb5a..bec792465bfb 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -501,6 +501,24 @@ def get_config( if quantization_config is not None: config.quantization_config = quantization_config + # auto-enable DeepGEMM UE8M0 on Hopper if model config requests it + scale_fmt = quantization_config.get("scale_fmt", None) + if scale_fmt in ("ue8m0", ): + if not envs.is_set("VLLM_USE_DEEP_GEMM_E8M0_HOPPER"): + os.environ["VLLM_USE_DEEP_GEMM_E8M0_HOPPER"] = "1" + logger.info_once( + ("Detected quantization_config.scale_fmt=%s; " + "enabling Hopper UE8M0."), + scale_fmt, + ) + elif not envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.warning_once( + ("Model config requests UE8M0 " + "(quantization_config.scale_fmt=%s), but " + "VLLM_USE_DEEP_GEMM_E8M0_HOPPER=0 is set; " + "Hopper UE8M0 disabled."), + scale_fmt, + ) if hf_overrides_kw: logger.debug("Overriding HF config with %s", hf_overrides_kw) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index b0bc3a79eb0a..cd1dbfb813fe 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -31,34 +31,33 @@ def is_deep_gemm_supported() -> bool: @functools.cache -def is_blackwell_deep_gemm_e8m0_used() -> bool: +def is_deep_gemm_e8m0_used() -> bool: """Return ``True`` if vLLM is configured to use DeepGEMM " - "E8M0 scale on a Blackwell-class GPU. + "E8M0 scale on a Hopper or Blackwell-class GPU. 
""" if not is_deep_gemm_supported(): - logger.debug_once( + logger.info_once( "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system.") return False - if not envs.VLLM_USE_DEEP_GEMM_E8M0: - logger.debug_once("DeepGEMM E8M0 disabled: VLLM_USE_DEEP_GEMM_E8M0=0.") - return False - _lazy_init() if _fp8_gemm_nt_impl is None: - logger.debug_once( - "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") + logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found") return False - enabled = (current_platform.is_cuda() - and current_platform.has_device_capability(100)) - if enabled: - logger.debug_once("DeepGEMM E8M0 enabled on Blackwell GPU.") - else: - logger.debug_once( - "DeepGEMM E8M0 disabled: not running on Blackwell GPU.") - return enabled + if current_platform.is_device_capability(100) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0: + logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.") + return True + + if current_platform.is_device_capability(90) and \ + envs.VLLM_USE_DEEP_GEMM_E8M0_HOPPER: + logger.info_once("DeepGEMM E8M0 enabled on Hopper GPU.") + return True + + logger.info_once("DeepGEMM E8M0 disabled on current configuration.") + return False def _missing(*_: Any, **__: Any) -> NoReturn: @@ -124,20 +123,18 @@ def fp8_gemm_nt(*args, **kwargs): _lazy_init() if _fp8_gemm_nt_impl is None: return _missing(*args, **kwargs) - return _fp8_gemm_nt_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + return _fp8_gemm_nt_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): _lazy_init() if _grouped_impl is None: return _missing(*args, **kwargs) - return _grouped_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + return _grouped_impl(*args, + disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), + **kwargs) def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): @@ -145,9 +142,7 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): if _grouped_masked_impl is None: return _missing(*args, **kwargs) return _grouped_masked_impl( - *args, - disable_ue8m0_cast=not is_blackwell_deep_gemm_e8m0_used(), - **kwargs) + *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) def _ceil_to_ue8m0(x: torch.Tensor): @@ -211,7 +206,7 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "per_block_cast_to_fp8", - "is_blackwell_deep_gemm_e8m0_used", + "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", "should_use_deepgemm_for_fp8_linear", ] From 841490434aaee4b1c8d8427112af740b6662f384 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Wed, 27 Aug 2025 22:45:17 +0800 Subject: [PATCH 028/125] [Model] Enable native HF format InternVL support (#23742) Signed-off-by: Isotr0py --- docs/models/supported_models.md | 1 + .../multimodal/generation/test_common.py | 29 +++++++++---------- tests/models/registry.py | 3 +- vllm/model_executor/models/registry.py | 1 + 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 35a5fa0c2e42..20cf75873af7 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -629,6 +629,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. 
| ✅︎ | | ✅︎ | | `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ | | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ | | `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 96208f8eda62..2b60faae8ec0 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -222,21 +222,6 @@ }, marks=[large_gpu_mark(min_gb=32)], ), - # Check "auto" with fallback to transformers - "internvl-transformers": VLMTestInfo( - models=["OpenGVLab/InternVL3-1B-hf"], - test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), - prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 - img_idx_to_prompt=lambda idx: "", - max_model_len=4096, - use_tokenizer_eos=True, - image_size_factors=[(0.25, 0.5, 1.0)], - vllm_runner_kwargs={ - "model_impl": "auto", - }, - auto_cls=AutoModelForImageTextToText, - marks=[pytest.mark.core_model], - ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -461,6 +446,20 @@ use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), + "intern_vl-hf": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO, + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "", + video_idx_to_prompt=lambda idx: "