42 changes: 42 additions & 0 deletions tests/models/multimodal/processing/test_gemma3.py
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.multimodal import MULTIMODAL_REGISTRY

from ....conftest import ImageTestAssets
from ...utils import build_model_context


@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
def test_get_image_size_with_most_features(
    image_assets: ImageTestAssets, model_id: str
):
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs={"do_pan_and_scan": True},
        limit_mm_per_prompt={"image": 1},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

    hf_processor_mm_kwargs: dict[str, object] = {}
    hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)

    max_image_size = processor.info.get_image_size_with_most_features()
    max_tokens = processor.info.get_num_image_tokens(
        image_width=max_image_size.width,
        image_height=max_image_size.height,
        processor=hf_processor,
    )

    prompt = "<start_of_image>"
    image_seq_length = hf_processor.image_seq_length

    for asset in image_assets:
        mm_data = {"image": [asset.pil_image]}
        processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
        mm_kwargs_data = processed_inputs["mm_kwargs"].get_data()
        num_patches_tensor = mm_kwargs_data["num_patches"]
        tokens = int(num_patches_tensor.item()) * image_seq_length
        assert tokens <= max_tokens
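The invariant under test is that a real image's token count, `num_patches * image_seq_length`, never exceeds the processor's advertised maximum. A minimal sketch of that arithmetic, assuming Gemma 3 defaults of `image_seq_length = 256` and that pan-and-scan yields the base image plus up to 4 crops (both values are assumptions here, not read from a real config):

```python
# Sketch of the invariant checked above, under assumed Gemma 3 defaults.
image_seq_length = 256  # tokens contributed per processed image/crop (assumed)
max_num_crops = 4       # pan_and_scan_max_num_crops (assumed default)

# Worst case: the base image plus the maximum number of pan-and-scan crops.
max_tokens = (1 + max_num_crops) * image_seq_length

# Any concrete image yields num_patches in [1, 1 + max_num_crops],
# so its token count can never exceed max_tokens.
for num_patches in range(1, max_num_crops + 2):
    assert num_patches * image_seq_length <= max_tokens
```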
35 changes: 35 additions & 0 deletions tests/models/multimodal/processing/test_qwen2_vl.py
@@ -53,3 +53,38 @@ def test_processor_override(
    assert img_tok_count == expected_toks_per_img * num_imgs
    assert pixel_shape[0] == expected_pixels_shape[0] * num_imgs
    assert pixel_shape[1] == expected_pixels_shape[1]


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
@pytest.mark.parametrize("max_pixels", [1280 * 28 * 28, 1283 * 28 * 28])
def test_get_image_size_with_most_features(
    image_assets: ImageTestAssets,
    model_id: str,
    max_pixels: int,
):
    ctx = build_model_context(
        model_id,
        mm_processor_kwargs={"max_pixels": max_pixels},
        limit_mm_per_prompt={"image": 1},
    )
    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)

    hf_processor_mm_kwargs: dict[str, object] = {}
    hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
    merge_size = processor.info.get_hf_config().vision_config.spatial_merge_size

    max_image_size = processor.info.get_image_size_with_most_features()
    max_tokens = processor.info.get_num_image_tokens(
        image_width=max_image_size.width,
        image_height=max_image_size.height,
        image_processor=hf_processor.image_processor,
    )

    prompt = "<|vision_start|><|image_pad|><|vision_end|>"
    for asset in image_assets:
        mm_data = {"image": [asset.pil_image]}
        processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
        grid_thw = processed_inputs["mm_kwargs"].get_data()["image_grid_thw"].tolist()
        t, h, w = grid_thw[0]
        tokens = (t * h * w) // (merge_size**2)
        assert tokens < max_tokens
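The token count here is derived from the patch grid: `image_grid_thw` gives `(t, h, w)` in units of vision patches, and spatial merging collapses `merge_size**2` patches into one token. A standalone sketch of that arithmetic, assuming Qwen2-VL's usual `patch_size = 14` and `spatial_merge_size = 2`:

```python
# Sketch of the grid_thw -> token arithmetic used in the test, assuming
# Qwen2-VL defaults: patch_size = 14, spatial_merge_size = 2.
patch_size = 14
merge_size = 2

def image_tokens(width: int, height: int, t: int = 1) -> int:
    # Assumes the processor has already resized both sides to multiples
    # of patch_size * merge_size (28 here).
    h = height // patch_size  # patch-grid height
    w = width // patch_size   # patch-grid width
    return (t * h * w) // (merge_size**2)

# A 1120x896 image -> an 80x64 patch grid -> 5120 patches -> 1280 tokens.
assert image_tokens(1120, 896) == 1280
```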
5 changes: 3 additions & 2 deletions vllm/model_executor/models/gemma3_mm.py
@@ -237,8 +237,9 @@ def get_image_size_with_most_features(self) -> ImageSize:
        )
        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]

        # Result in the max possible feature size (h:w = max_num_crops:1)
        return ImageSize(height=50 * max_num_crops, width=50)
        vision_config = self.get_hf_config().vision_config
        native_size = vision_config.image_size
        return ImageSize(height=native_size * max_num_crops, width=native_size)


class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
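For scale, a worked example of the sizing change, assuming Gemma 3's SigLIP tower uses `image_size = 896` and the default `pan_and_scan_max_num_crops = 4` (both assumed here, not read from a real config):

```python
# Old vs. new reported size, under the assumed values above. The h:w ratio
# of max_num_crops:1 still forces the maximum number of pan-and-scan crops,
# but the dimensions now match the vision tower's native resolution.
native_size = 896
max_num_crops = 4

old_size = (50 * max_num_crops, 50)                    # (200, 50)
new_size = (native_size * max_num_crops, native_size)  # (3584, 896)
print(old_size, new_size)
```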
44 changes: 37 additions & 7 deletions vllm/model_executor/models/qwen2_vl.py
@@ -25,6 +25,7 @@
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""

import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Literal, TypeAlias
@@ -959,13 +960,42 @@ def get_num_video_tokens(
        return num_video_tokens

    def get_image_size_with_most_features(self) -> ImageSize:
        max_image_size, _ = self._get_vision_info(
            image_width=9999999,
            image_height=9999999,
            num_frames=1,
            image_processor=None,
        )
        return max_image_size
        # NOTE: Simply processing a huge size with _get_vision_info might not give a
        # size that maximizes the number of features, i.e., the number of (merged)
        # patches. This is because the number of patches constrains the allowed
        # aspect ratios. For example, suppose the maximum number of patches is 1280.
        # A square image cannot be broken down into exactly 1280 patches, so feeding
        # a giant square image into _get_vision_info will not yield a size that
        # maximizes the number of patches. Therefore, we directly factorize the
        # maximum number of patches into height and width. The tricky part is to
        # avoid extreme aspect ratios (>200 for qwen2-vl). If we can't find a
        # suitable aspect ratio, we decrease the number of patches and retry. This is
        # safe because the processor does not accept extreme aspect ratios, so there
        # is no valid post-resize image whose patch count only admits extreme aspect
        # ratios.

        hf_config = self.get_hf_config()
        vision_config = hf_config.vision_config
        patch_size = vision_config.patch_size
        merge_size = vision_config.spatial_merge_size
        image_processor = self.get_image_processor()
        max_pixels = image_processor.max_pixels or image_processor.size["longest_edge"]
        unit = patch_size * merge_size
        max_seq_len = max_pixels // (unit * unit)

        def closest_factor_pair(n: int) -> tuple[int, int]:
            # Largest divisor d <= sqrt(n), so (d, n // d) with d <= n // d
            # is the squarest possible factor pair.
            for d in range(math.isqrt(n), 0, -1):
                if n % d == 0:
                    return d, n // d
            return 1, n

        height_factor, width_factor = 1, max_seq_len
        for seq_len in range(max_seq_len, 0, -1):
            height_factor, width_factor = closest_factor_pair(seq_len)
            if width_factor / height_factor <= 200:
                break

        return ImageSize(width=unit * width_factor, height=unit * height_factor)

    def get_max_image_tokens(self) -> int:
        target_width, target_height = self.get_image_size_with_most_features()
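To make the factorization concrete, a worked trace of the loop above, assuming Qwen2-VL's defaults (`patch_size = 14`, `spatial_merge_size = 2`) and `max_pixels = 1280 * 28 * 28` as in the test:

```python
import math

# Trace of the factorization approach, under the assumed defaults above.
patch_size, merge_size = 14, 2
max_pixels = 1280 * 28 * 28

unit = patch_size * merge_size             # 28 pixels per merged-patch side
max_seq_len = max_pixels // (unit * unit)  # at most 1280 merged patches

def closest_factor_pair(n: int) -> tuple[int, int]:
    # Largest divisor d <= sqrt(n); (d, n // d) is the squarest factor pair.
    for d in range(math.isqrt(n), 0, -1):
        if n % d == 0:
            return d, n // d
    return 1, n

height_factor, width_factor = closest_factor_pair(max_seq_len)  # (32, 40)
assert width_factor / height_factor <= 200  # aspect ratio within limits
# 1120 x 896 pixels -> exactly 1280 merged patches, the true maximum.
print(unit * width_factor, unit * height_factor)
```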