From 619f11d13f67218970135afec6987ee5e838675e Mon Sep 17 00:00:00 2001 From: yongming-qin Date: Sat, 29 Nov 2025 17:12:52 -0800 Subject: [PATCH 1/2] Add files for openvla referring deepseek-vl2. They both use timm for vision. Signed-off-by: yongming-qin --- vllm/model_executor/models/openvla.py | 723 ++++++++++++++++++ vllm/transformers_utils/configs/openvla.py | 143 ++++ vllm/transformers_utils/processors/openvla.py | 157 ++++ 3 files changed, 1023 insertions(+) create mode 100644 vllm/model_executor/models/openvla.py create mode 100644 vllm/transformers_utils/configs/openvla.py create mode 100644 vllm/transformers_utils/processors/openvla.py diff --git a/vllm/model_executor/models/openvla.py b/vllm/model_executor/models/openvla.py new file mode 100644 index 000000000000..73861e0ed503 --- /dev/null +++ b/vllm/model_executor/models/openvla.py @@ -0,0 +1,723 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Inference-only OpenVLA model compatible with HuggingFace weights.""" + +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import partial +from typing import Annotated, Literal, TypeAlias + +import torch +import torch.nn as nn +from transformers import BatchFeature + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.models.transformers.utils import replace_linear_class +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, + MultiModalUUIDDict, +) +from vllm.multimodal.parse import ( + ImageEmbeddingItems, + ImageProcessorItems, + ImageSize, + MultiModalDataItems, +) +from vllm.multimodal.processing import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + MultiModalProcessingInfo, + PromptIndexTargets, + PromptInsertion, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.openvla import OpenVLAConfig +from vllm.transformers_utils.processors.openvla import OpenVLAProcessor +from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config +from vllm.utils.collection_utils import is_list_of +from vllm.utils.tensor_schema import TensorSchema, TensorShape +from vllm.utils.torch_utils import set_default_torch_dtype + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +# The image token +_IMAGE_TOKEN = "" + + +class OpenVLAImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + """ + + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("b", "c", "h", "w")] + + +class OpenVLAImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - n: Number of image tokens + - h: Hidden size (must match language model backbone) + """ + + type: Literal["image_embeds"] + data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("b", "n", "h")] + + +OpenVLAImageInputs: TypeAlias = ( + OpenVLAImagePixelInputs | OpenVLAImageEmbeddingInputs +) + + +# === Utility 
Functions for Monkey-Patching === +def unpack_tuple(fn): + """Unpack tuple return value to single value.""" + def wrapper(*args, **kwargs): + result = fn(*args, **kwargs) + return result[0] if isinstance(result, tuple) else result + return wrapper + + +# HF Transformers overwrites parameters with names containing `gamma`; we're going to patch VisionBackbone.LayerScale. +# =>> TIMM :: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L109 +# =>> Transformers :: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L3960 +def _ls_new_forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.scale_factor) if self.inplace else x * self.scale_factor + + +def ls_apply_patch(ls_module): + """Apply patch to LayerScale module to use scale_factor instead of gamma.""" + if hasattr(ls_module, 'gamma'): + ls_module.scale_factor = nn.Parameter(ls_module.gamma.clone()) + ls_module.forward = _ls_new_forward.__get__(ls_module, type(ls_module)) + del ls_module.gamma + + +# === Prismatic Vision Backbone === +class PrismaticVisionBackbone(nn.Module): + """Vision backbone using timm ViT.""" + + def __init__( + self, + use_fused_vision_backbone: bool, + image_sizes: list[int], + timm_model_ids: list[str], + timm_override_act_layers: list[str | None], + ) -> None: + super().__init__() + self.use_fused_vision_backbone = use_fused_vision_backbone + + try: + import timm + from timm.models.vision_transformer import LayerScale + except ImportError as e: + raise ImportError("Please install timm") from e + + assert len(timm_model_ids) <= 2, "Prismatic models only support up to 2 (fused) vision backbones!" + + # Create main featurizer + self.featurizer = timm.create_model( + timm_model_ids[0], + pretrained=False, + num_classes=0, + img_size=image_sizes[0], + act_layer=timm_override_act_layers[0] if timm_override_act_layers else None, + ) + # Monkey-patch forward to return second-to-last layer patches + self.featurizer.forward = unpack_tuple( + partial(self.featurizer.get_intermediate_layers, n={len(self.featurizer.blocks) - 2}) + ) + self.embed_dim = self.featurizer.embed_dim + + # Create fused featurizer if needed + if self.use_fused_vision_backbone: + self.fused_featurizer = timm.create_model( + timm_model_ids[1], + pretrained=False, + num_classes=0, + img_size=image_sizes[1], + act_layer=timm_override_act_layers[1] if len(timm_override_act_layers) > 1 else None, + ) + self.fused_featurizer.forward = unpack_tuple( + partial(self.fused_featurizer.get_intermediate_layers, n={len(self.fused_featurizer.blocks) - 2}) + ) + self.embed_dim += self.fused_featurizer.embed_dim + + # Patch LayerScale for HF compatibility + for module in self.featurizer.modules(): + if isinstance(module, LayerScale): + ls_apply_patch(module) + + if self.use_fused_vision_backbone: + for module in self.fused_featurizer.modules(): + if isinstance(module, LayerScale): + ls_apply_patch(module) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """Run image through featurizer.""" + if not self.use_fused_vision_backbone: + return self.featurizer(pixel_values) + + # Split channel-stacked input + img, img_fused = torch.split(pixel_values, [3, 3], dim=1) + patches, patches_fused = self.featurizer(img), self.fused_featurizer(img_fused) + return torch.cat([patches, patches_fused], dim=2) + + +# === Prismatic Projector === +class PrismaticProjector(nn.Module): + """MLP projector from vision features to LLM hidden size.""" + + def __init__(self, 
use_fused_vision_backbone: bool, vision_dim: int, llm_dim: int) -> None: + super().__init__() + self.use_fused_vision_backbone = use_fused_vision_backbone + self.vision_dim, self.llm_dim = vision_dim, llm_dim + + if not self.use_fused_vision_backbone: + self.fc1 = nn.Linear(self.vision_dim, self.llm_dim, bias=True) + self.fc2 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + else: + initial_projection_dim = 4 * vision_dim + self.fc1 = nn.Linear(self.vision_dim, initial_projection_dim, bias=True) + self.fc2 = nn.Linear(initial_projection_dim, self.llm_dim, bias=True) + self.fc3 = nn.Linear(self.llm_dim, self.llm_dim, bias=True) + self.act_fn1 = nn.GELU() + self.act_fn2 = nn.GELU() + + def forward(self, img_patches: torch.Tensor) -> torch.Tensor: + """Project vision features to LLM hidden size.""" + if not self.use_fused_vision_backbone: + projected_features = self.fc1(img_patches) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + else: + projected_features = self.fc1(img_patches) + projected_features = self.act_fn1(projected_features) + projected_features = self.fc2(projected_features) + projected_features = self.act_fn2(projected_features) + projected_features = self.fc3(projected_features) + return projected_features + + +# === Processing Info === +class OpenVLAProcessingInfo(BaseProcessingInfo): + """Processing info for OpenVLA model.""" + + def get_hf_config(self): + # Try to get vLLM's OpenVLAConfig, but fall back to the model's own config + # if it was loaded with trust_remote_code=True + try: + return self.ctx.get_hf_config(OpenVLAConfig) + except TypeError: + # If type check fails, the config is from the model's own configuration_prismatic.py + # We can still use it as it has the same structure + hf_config = self.ctx.model_config.hf_config + # Verify it has the necessary attributes + if not hasattr(hf_config, 'image_sizes') or not hasattr(hf_config, 'timm_model_ids'): + raise ValueError( + "Config does not have required attributes. " + "Expected config with 'image_sizes' and 'timm_model_ids' attributes." 
+ ) + return hf_config + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(OpenVLAProcessor, **kwargs) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"image": None} + + def get_num_image_tokens( + self, *, image_width: int, image_height: int, cropping: bool = True + ) -> int: + """Calculate number of image tokens based on image size.""" + hf_config = self.get_hf_config() + image_size = hf_config.image_sizes[0] # Use first image size + + # Calculate patches: (image_size / patch_size) ^ 2 + # For siglip-vit-so400m with patch14, image_size=224: (224/14)^2 = 16^2 = 256 + # But we need to check the actual patch size from the vision model + # For now, assume patch_size=14 for siglip models + patch_size = 14 # Default for siglip models + num_patches_per_side = image_size // patch_size + num_image_tokens = num_patches_per_side * num_patches_per_side + + return num_image_tokens + + def get_image_size_with_most_features(self) -> ImageSize: + """Get image size that produces most features.""" + hf_config = self.get_hf_config() + image_size = hf_config.image_sizes[0] + return ImageSize(width=image_size, height=image_size) + + +# === Dummy Inputs Builder === +class OpenVLADummyInputsBuilder(BaseDummyInputsBuilder[OpenVLAProcessingInfo]): + """Dummy inputs builder for OpenVLA.""" + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + # For OpenVLA, images are inserted as embeddings after BOS token + # We need to return dummy text that will be tokenized to include at least BOS + # so that PromptInsertion can work properly + # Return a space character which will be tokenized to at least one token + # (the tokenizer will add BOS if configured) + num_images = mm_counts.get("image", 0) + if num_images > 0: + # Return a space to ensure at least one token after tokenization + # This allows the insertion mechanism to work correctly + return " " + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + max_image_size = self.info.get_image_size_with_most_features() + image_overrides = mm_options.get("image") if mm_options else None + + return { + "image": self._get_dummy_images( + width=max_image_size.width, + height=max_image_size.height, + num_images=num_images, + overrides=image_overrides, + ) + } + + +# === MultiModal Processor === +class OpenVLAMultiModalProcessor( + BaseMultiModalProcessor[OpenVLAProcessingInfo] +): + """MultiModal processor for OpenVLA.""" + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if not mm_data: + tokenizer = self.info.get_tokenizer() + return tokenizer(prompt, add_special_tokens=True, return_tensors="pt") + + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + # OpenVLA uses PromptInsertion, so the HF processor doesn't apply updates + # vLLM will apply them via _apply_prompt_updates + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, 
MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + """Insert image tokens after BOS token (position 1).""" + tokenizer = self.info.get_tokenizer() + bos_token_id = tokenizer.bos_token_id + assert isinstance(bos_token_id, int) + + # Get image token ID - try multiple methods + image_token_id = None + + # Method 1: Check if tokenizer already has the token + if hasattr(tokenizer, 'vocab'): + image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN) + elif hasattr(tokenizer, 'get_vocab'): + vocab = tokenizer.get_vocab() + image_token_id = vocab.get(_IMAGE_TOKEN) + elif hasattr(tokenizer, 'convert_tokens_to_ids'): + try: + image_token_id = tokenizer.convert_tokens_to_ids(_IMAGE_TOKEN) + if image_token_id == tokenizer.unk_token_id: + image_token_id = None + except Exception: + pass + + # Method 2: Try to get from processor + if image_token_id is None: + try: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_token_id = getattr(hf_processor, 'image_token_id', None) + except Exception: + pass + + # Method 3: Add the token if it's missing + if image_token_id is None: + try: + # Add the image token as a special token + num_added = tokenizer.add_special_tokens({"additional_special_tokens": [_IMAGE_TOKEN]}) + if num_added > 0: + # Get the token ID after adding + if hasattr(tokenizer, 'vocab'): + image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN) + elif hasattr(tokenizer, 'get_vocab'): + vocab = tokenizer.get_vocab() + image_token_id = vocab.get(_IMAGE_TOKEN) + elif hasattr(tokenizer, 'convert_tokens_to_ids'): + image_token_id = tokenizer.convert_tokens_to_ids(_IMAGE_TOKEN) + except Exception as e: + # If adding fails, we'll raise an error below + pass + + # If image token is not found, use pad token as placeholder + # OpenVLA inserts image embeddings directly, so we don't strictly need + # the token in the vocabulary for the replacement mechanism + if image_token_id is None: + # Use pad token as placeholder - this is just for the replacement mechanism + # The actual image embeddings will be inserted by embed_multimodal + pad_token_id = tokenizer.pad_token_id + if pad_token_id is not None: + image_token_id = pad_token_id + else: + # Fallback to a high token ID that's unlikely to be used + # This is just for the prompt replacement mechanism + vocab_size = getattr(tokenizer, 'vocab_size', 32000) + image_token_id = vocab_size - 1 # Use last token as placeholder + + def get_insertion_openvla(item_idx: int): + images = mm_items.get_items( + "image", (ImageEmbeddingItems, ImageProcessorItems) + ) + + if isinstance(images, ImageEmbeddingItems): + num_image_tokens = images.get_feature_size(item_idx) + else: + image_size = images.get_image_size(item_idx) + num_image_tokens = self.info.get_num_image_tokens( + image_width=image_size.width, + image_height=image_size.height, + ) + + image_tokens = [image_token_id] * num_image_tokens + + # Return PromptUpdateDetails to specify which tokens should be replaced with embeddings + return PromptUpdateDetails.select_token_id( + image_tokens, + embed_token_id=image_token_id, + ) + + # OpenVLA inserts image tokens after BOS token (position 1) + # Use PromptInsertion to insert at the start + # start() works even with empty prompts and will insert before any tokens + 
return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=get_insertion_openvla, + ) + ] + + +# === Main Model Class === +@MULTIMODAL_REGISTRY.register_processor( + OpenVLAMultiModalProcessor, + info=OpenVLAProcessingInfo, + dummy_inputs=OpenVLADummyInputsBuilder, +) +class OpenVLAForActionPrediction(nn.Module, SupportsMultiModal, SupportsPP): + """OpenVLA model for action prediction.""" + + merge_by_field_config = True + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language.": "language_model.", + "vision_backbone.": "vision.", + } + ) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("image"): + return _IMAGE_TOKEN + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: OpenVLAConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + model_config = vllm_config.model_config + tokenizer = cached_tokenizer_from_config(model_config) + + # Get image token ID - try multiple methods + image_token_id = None + if hasattr(tokenizer, 'vocab'): + image_token_id = tokenizer.vocab.get(_IMAGE_TOKEN) + elif hasattr(tokenizer, 'get_vocab'): + vocab = tokenizer.get_vocab() + image_token_id = vocab.get(_IMAGE_TOKEN) + elif hasattr(tokenizer, 'convert_tokens_to_ids'): + try: + image_token_id = tokenizer.convert_tokens_to_ids(_IMAGE_TOKEN) + if image_token_id == tokenizer.unk_token_id: + image_token_id = None + except Exception: + pass + + # If not found, use pad token as placeholder + # The actual image embeddings will be inserted by embed_multimodal + if image_token_id is None: + pad_token_id = tokenizer.pad_token_id + if pad_token_id is not None: + image_token_id = pad_token_id + else: + # Fallback to a high token ID + vocab_size = getattr(tokenizer, 'vocab_size', 32000) + image_token_id = vocab_size - 1 + + self.image_token_id: int = image_token_id + + # Initialize vision backbone + self.vision = self._init_vision_module( + config, quant_config, maybe_prefix(prefix, "vision_backbone") + ) + + # Initialize projector + vision_dim = self.vision.embed_dim + llm_dim = config.text_config.hidden_size + self.projector = PrismaticProjector( + config.use_fused_vision_backbone, + vision_dim, + llm_dim, + ) + + # Initialize language model + # Determine the architecture name based on LLM backbone + # For Llama-2, use "LlamaForCausalLM" + llm_arch = "LlamaForCausalLM" # Default for Llama-2 + if "mistral" in config.llm_backbone_id.lower(): + llm_arch = "MistralForCausalLM" + elif "phi" in config.llm_backbone_id.lower(): + llm_arch = "Phi3ForCausalLM" + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=[llm_arch], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def _get_parent_and_attr(self, root: torch.nn.Module, dotted_name: str): + """Return (parent_module, final_attr_name) for a dotted module path.""" + names = dotted_name.split(".") + parent = root + for n in names[:-1]: + parent = getattr(parent, n) + return parent, names[-1] + + def patch_vit_for_tp(self, vit: torch.nn.Module, quant_config: QuantizationConfig): + """Patch ViT for tensor 
parallelism.""" + try: + import timm + except ImportError as e: + raise ImportError("Please install timm") from e + + for name, module in vit.named_modules(): + if isinstance(module, nn.Linear): + parent, attr_name = self._get_parent_and_attr(vit, name) + if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1": + new_linear = replace_linear_class( + module, "colwise", quant_config, prefix=name + ) + setattr(parent, attr_name, new_linear) + elif isinstance(parent, timm.layers.Mlp) and attr_name == "fc2": + new_linear = replace_linear_class( + module, "rowwise", quant_config, prefix=name + ) + setattr(parent, attr_name, new_linear) + + return vit + + def _init_vision_module( + self, + config: OpenVLAConfig, + quant_config: QuantizationConfig | None, + prefix: str = "", + ) -> nn.Module: + """Initialize vision backbone.""" + try: + import timm + except ImportError as e: + raise ImportError("Please install timm") from e + + with set_default_torch_dtype(torch.float16): + vision_backbone = PrismaticVisionBackbone( + use_fused_vision_backbone=config.use_fused_vision_backbone, + image_sizes=config.image_sizes, + timm_model_ids=config.timm_model_ids, + timm_override_act_layers=config.timm_override_act_layers, + ) + + if get_tensor_model_parallel_world_size() > 1: + vision_backbone.featurizer = self.patch_vit_for_tp( + vision_backbone.featurizer, quant_config + ) + if vision_backbone.use_fused_vision_backbone: + vision_backbone.fused_featurizer = self.patch_vit_for_tp( + vision_backbone.fused_featurizer, quant_config + ) + + vision_backbone = vision_backbone.to(dtype=torch.get_default_dtype()) + return vision_backbone + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> OpenVLAImageInputs | None: + """Parse and validate image input.""" + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + return OpenVLAImagePixelInputs( + type="pixel_values", + data=pixel_values, + ) + + if image_embeds is not None: + return OpenVLAImageEmbeddingInputs( + type="image_embeds", + data=image_embeds, + ) + + raise AssertionError("This line should be unreachable.") + + def _pixel_values_to_embedding( + self, + pixel_values: torch.Tensor, + ) -> list[torch.Tensor]: + """Convert pixel values to embeddings.""" + # Process through vision backbone + # pixel_values: [batch, channels, height, width] + # The vision backbone handles fused backbone internally + patches = self.vision(pixel_values) + # patches: [batch, num_patches, vision_dim] + + # Project to LLM hidden size + projected = self.projector(patches) + # projected: [batch, num_patches, llm_dim] + + # Return as list of tensors (one per image in batch) + return list(torch.unbind(projected, dim=0)) + + def _process_image_input( + self, image_input: OpenVLAImageInputs + ) -> list[torch.Tensor]: + """Process image input to embeddings.""" + if image_input["type"] == "image_embeds": + image_data = image_input["data"] + if is_list_of(image_data, torch.Tensor): + return image_data + if len(image_data.shape) == 3: + return list(torch.unbind(image_data, dim=0)) + raise ValueError( + "We expect batched 2D tensors; " + "this can be either a list of 2D tensors or a single 3D tensor." 
+ ) + + pixel_values = image_input["data"] + return self._pixel_values_to_embedding(pixel_values=pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + """Embed multimodal inputs (images).""" + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return [] + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ): + """Forward pass.""" + if intermediate_tensors is not None: + inputs_embeds = None + + hidden_states = self.language_model( + input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + """Compute logits from hidden states.""" + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights from checkpoint.""" + loader = AutoWeightsLoader(self) + autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + return autoloaded_weights + diff --git a/vllm/transformers_utils/configs/openvla.py b/vllm/transformers_utils/configs/openvla.py new file mode 100644 index 000000000000..6628f780bd03 --- /dev/null +++ b/vllm/transformers_utils/configs/openvla.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from OpenVLA/Prismatic configuration structure + +from typing import Any, Dict, List, Optional + +from transformers import PretrainedConfig +from transformers.models.auto import CONFIG_MAPPING + + +# === Utilities for Mapping Prismatic names to HF names === +# fmt: off +VISION_BACKBONE_TO_RESOLUTION: Dict[str, List[int]] = { + "clip-vit-l": [224], "siglip-vit-so400m": [224], "dinov2-vit-l": [224], "in1k-vit-l": [224], + "clip-vit-l-336px": [336], + "siglip-vit-so400m-384px": [384], + "dinoclip-vit-l-336px": [336, 336], + "dinosiglip-vit-so-224px": [224, 224], + "dinosiglip-vit-so-384px": [384, 384], +} +VISION_BACKBONE_TO_TIMM_ID: Dict[str, List[str]] = { + "clip-vit-l": ["vit_large_patch14_clip_224.openai"], + "clip-vit-l-336px": ["vit_large_patch14_clip_336.openai"], + "dinov2-vit-l": ["vit_large_patch14_reg4_dinov2.lvd142m"], + "in1k-vit-l": ["vit_large_patch16_224.augreg_in21k_ft_in1k"], + "siglip-vit-so400m": ["vit_so400m_patch14_siglip_224"], + "siglip-vit-so400m-384px": ["vit_so400m_patch14_siglip_384"], + "dinoclip-vit-l-336px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_large_patch14_clip_336.openai"], + "dinosiglip-vit-so-224px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_224"], + "dinosiglip-vit-so-384px": ["vit_large_patch14_reg4_dinov2.lvd142m", "vit_so400m_patch14_siglip_384"], +} +TIMM_OVERRIDE_ACT_LAYER: Dict[str, List[Optional[str]]] = { + "clip-vit-l": ["quick_gelu"], "clip-vit-l-336px": ["quick_gelu"], + "dinov2-vit-l": [None], "in1k-vit-l": [None], + "siglip-vit-so400m": [None], "siglip-vit-so400m-384px": [None], + "dinoclip-vit-l-336px": [None, "quick_gelu"], + "dinosiglip-vit-so-224px": [None, None], "dinosiglip-vit-so-384px": [None, None] +} + +LLM_BACKBONE_TO_HF_PATH = { + "llama2-7b-pure": 
"meta-llama/Llama-2-7b-hf", "llama2-13b-pure": "meta-llama/Llama-2-13b-hf", + "llama2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", "llama2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", + "vicuna-v15-7b": "lmsys/vicuna-7b-v1.5", "vicuna-v15-13b": "lmsys/vicuna-13b-v1.5", + "mistral-v0.1-7b-pure": "mistralai/Mistral-7B-v0.1", + "mistral-v0.1-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1", + "phi-2-3b": "microsoft/phi-2", +} +LLM_BACKBONE_TO_HF_METACLASS = { + "llama2-7b-pure": "llama", "llama2-13b-pure": "llama", "llama2-7b-chat": "llama", "llama2-13b-chat": "llama", + "vicuna-v15-7b": "llama", "vicuna-v15-13b": "llama", + "mistral-v0.1-7b-pure": "mistral", "mistral-v0.1-7b-instruct": "mistral", + "phi-2-3b": "phi", +} +# fmt: on + + +class OpenVLAConfig(PretrainedConfig): + """Configuration for OpenVLA model compatible with vLLM.""" + + model_type: str = "openvla" + is_composition: bool = False + + def __init__( + self, + vision_backbone_id: str = "siglip-vit-so400m", + llm_backbone_id: str = "llama2-7b-pure", + arch_specifier: str = "no-align+fused-gelu-mlp", + use_fused_vision_backbone: Optional[bool] = None, + image_resize_strategy: str = "resize-naive", + image_sizes: Optional[List[int]] = None, + text_config: Optional[Dict[str, Any]] = None, + llm_max_length: int = 2048, + pad_token_id: int = 32000, + pad_to_multiple_of: int = 64, + norm_stats: Optional[Dict[str, Dict[str, Dict[str, Dict[str, List[float]]]]]] = None, + n_action_bins: int = 256, + **kwargs: Any, + ) -> None: + # Set OpenVLA-specific fields + self.norm_stats = norm_stats + self.n_action_bins = n_action_bins + self.arch_specifier = arch_specifier + + # Determine vision backbone ID from image_sizes if not provided + if vision_backbone_id is None and image_sizes is not None: + # Try to infer from image_sizes + if image_sizes == [224, 224]: + vision_backbone_id = "siglip-vit-so400m" + elif image_sizes == [384, 384]: + vision_backbone_id = "siglip-vit-so400m-384px" + else: + vision_backbone_id = "siglip-vit-so400m" # default + + # Validate vision backbone + if vision_backbone_id not in VISION_BACKBONE_TO_RESOLUTION: + # Use default if not found + vision_backbone_id = "siglip-vit-so400m" + + # Validate LLM backbone + if llm_backbone_id not in LLM_BACKBONE_TO_HF_PATH: + llm_backbone_id = "llama2-7b-pure" # default + + self.vision_backbone_id = vision_backbone_id + self.llm_backbone_id = llm_backbone_id + + # Determine if using fused vision backbone + self.use_fused_vision_backbone = ( + use_fused_vision_backbone + if use_fused_vision_backbone is not None + else any(self.vision_backbone_id.startswith(v) for v in ["dinoclip", "dinosiglip"]) + ) + + # Set vision config fields + self.timm_model_ids = VISION_BACKBONE_TO_TIMM_ID.get(self.vision_backbone_id, ["vit_so400m_patch14_siglip_224"]) + self.timm_override_act_layers = TIMM_OVERRIDE_ACT_LAYER.get(self.vision_backbone_id, [None]) + if image_sizes is not None: + self.image_sizes = image_sizes + else: + self.image_sizes = VISION_BACKBONE_TO_RESOLUTION.get(self.vision_backbone_id, [224]) + self.image_resize_strategy = image_resize_strategy + + # Set LLM config fields + self.hf_llm_id = LLM_BACKBONE_TO_HF_PATH.get(self.llm_backbone_id, "meta-llama/Llama-2-7b-hf") + self.llm_max_length = llm_max_length + self.pad_token_id = pad_token_id + self.pad_to_multiple_of = pad_to_multiple_of + + # Create text_config (LLM backbone config) + # [IMPORTANT] HF Utilities actually look for a `text_config` field + llm_metaclass = LLM_BACKBONE_TO_HF_METACLASS.get(self.llm_backbone_id, "llama") + 
self.text_config = ( + CONFIG_MAPPING[llm_metaclass](**text_config) + if text_config is not None + else CONFIG_MAPPING[llm_metaclass]() + ) + + # Dispatch **kwargs to super() + super().__init__(pad_token_id=pad_token_id, **kwargs) + + + + diff --git a/vllm/transformers_utils/processors/openvla.py b/vllm/transformers_utils/processors/openvla.py new file mode 100644 index 000000000000..8699ea8701bd --- /dev/null +++ b/vllm/transformers_utils/processors/openvla.py @@ -0,0 +1,157 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from OpenVLA/Prismatic processor structure + +from typing import Any, ClassVar, List, Optional, Union + +import torch +from PIL import Image +from transformers import ( + BatchFeature, + PreTrainedTokenizerBase, + ProcessorMixin, + TensorType, +) +from transformers.processing_utils import PaddingStrategy, TruncationStrategy + +# Try to import PrismaticProcessor from the model folder +# This will be used when loading the model +try: + import sys + import os + # Add the model directory to path if needed + # The processor will be loaded dynamically from the model folder + pass +except ImportError: + pass + + +class OpenVLAProcessor(ProcessorMixin): + """ + Processor for OpenVLA model that wraps PrismaticProcessor. + This is a minimal wrapper that vLLM can use to interface with the model's processor. + """ + + attributes: ClassVar[List[str]] = ["image_processor", "tokenizer"] + image_processor_class: str = "AutoImageProcessor" + tokenizer_class: str = "AutoTokenizer" + + def __init__( + self, + image_processor: Optional[Any] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + **kwargs: Any, + ) -> None: + super().__init__(image_processor, tokenizer) + self.image_token = "" + # Get image_token_id from tokenizer if available + if tokenizer is not None: + self.image_token_id = tokenizer.vocab.get(self.image_token) + if self.image_token_id is None: + # Try to add it + try: + tokenizer.add_special_tokens({"additional_special_tokens": [self.image_token]}) + self.image_token_id = tokenizer.vocab.get(self.image_token) + except Exception: + pass + + def __call__( + self, + text: Union[str, List[str]], + images: Union[Image.Image, List[Image.Image]], + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Optional[Union[bool, str, TruncationStrategy]] = None, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + **kwargs: Any, + ) -> BatchFeature: + """ + Preprocess text and images for OpenVLA model. 
+ + Args: + text: Text input(s) to encode + images: Image(s) to preprocess + padding: Padding strategy + truncation: Truncation strategy + max_length: Maximum sequence length + return_tensors: Return tensor type + + Returns: + BatchFeature with input_ids, attention_mask, and pixel_values + """ + # Process images + if self.image_processor is not None: + pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"] + else: + raise ValueError("image_processor is required") + + # Process text + if self.tokenizer is not None: + text_inputs = self.tokenizer( + text, + return_tensors=return_tensors, + padding=padding, + truncation=truncation, + max_length=max_length, + **kwargs, + ) + else: + raise ValueError("tokenizer is required") + + # Validate batch sizes match + if pixel_values.shape[0] != text_inputs.input_ids.shape[0]: + raise ValueError( + "Batch is malformed; expected same number of images and text inputs!" + ) + + return BatchFeature(data={**text_inputs, "pixel_values": pixel_values}) + + def batch_decode( + self, + sequences: Union[List[int], List[List[int]], torch.Tensor, Any], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs: Any, + ) -> List[str]: + """Decode token sequences to text.""" + if self.tokenizer is None: + raise ValueError("tokenizer is required") + return self.tokenizer.batch_decode( + sequences=sequences, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def decode( + self, + token_ids: Union[int, List[int], torch.Tensor, Any], + skip_special_tokens: bool = False, + clean_up_tokenization_spaces: Optional[bool] = None, + **kwargs: Any, + ) -> str: + """Decode token sequence to text.""" + if self.tokenizer is None: + raise ValueError("tokenizer is required") + return self.tokenizer.decode( + token_ids=token_ids, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self) -> List[str]: + """Return model input names.""" + tokenizer_input_names = ( + self.tokenizer.model_input_names if self.tokenizer else [] + ) + image_processor_input_names = ( + self.image_processor.model_input_names if self.image_processor else [] + ) + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + + + From 1b19006883a8cbef356bf30d68dd0775ed56752d Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 2 Dec 2025 01:13:43 -0500 Subject: [PATCH 2/2] Add the registry and config files so that vllm can parse HF openvla/openvla-7b Signed-off-by: Luke --- vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 73a61f1148b5..ca9c928c4ccc 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -361,6 +361,7 @@ ), "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), "NVLM_D": ("nvlm_d", "NVLM_D_Model"), + "OpenVLAForActionPrediction": ("openvla", "OpenVLAForActionPrediction"), "Ovis": ("ovis", "Ovis"), "Ovis2_5": ("ovis2_5", "Ovis2_5"), "PaddleOCRVLForConditionalGeneration": ( diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 8f2cd3315ab9..220cf139330f 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ 
-83,6 +83,7 @@ def __getitem__(self, key): speculators="SpeculatorsConfig", nemotron="NemotronConfig", olmo3="Olmo3Config", + openvla="OpenVLAConfig", ovis="OvisConfig", ultravox="UltravoxConfig", step3_vl="Step3VLConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 109f2b698651..ea48b1bdd37b 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -39,6 +39,7 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig from vllm.transformers_utils.configs.olmo3 import Olmo3Config +from vllm.transformers_utils.configs.openvla import OpenVLAConfig from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig from vllm.transformers_utils.configs.radio import RadioConfig @@ -73,6 +74,7 @@ "NemotronConfig", "NemotronHConfig", "Olmo3Config", + "OpenVLAConfig", "OvisConfig", "RadioConfig", "SpeculatorsConfig",
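
Usage sketch (not part of the diff): with the registry, config, and processor entries above in place, the new OpenVLAForActionPrediction path can be exercised through vLLM's offline API roughly as follows. This is a minimal, assumed example: it uses the openvla/openvla-7b checkpoint named in the commit message, vLLM's standard multi_modal_data prompt format, and OpenVLA's "In: ... Out:" instruction template; the file name and instruction text are placeholders.

# Assumed smoke test for the newly registered model; not part of the patch.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="openvla/openvla-7b",
    trust_remote_code=True,  # lets HF resolve the checkpoint's own config/processor,
                             # which OpenVLAProcessingInfo.get_hf_config falls back to
    dtype="bfloat16",
)

image = Image.open("frame.png")  # hypothetical camera frame
prompt = "In: What action should the robot take to pick up the cup?\nOut:"

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.0, max_tokens=16),
)
print(outputs[0].outputs[0].text)  # discretized action tokens

Decoding the generated token ids back into continuous actions (via the norm_stats and n_action_bins fields carried by OpenVLAConfig) is left to the caller; this sketch only verifies that the multimodal prompt path produces output.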