Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 71 additions & 38 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import base64
import os
from collections import deque
from collections.abc import Iterable
from dataclasses import dataclass
from io import BytesIO
Expand All @@ -22,26 +23,13 @@
import httpx
import numpy as np

from .utils import (
ExplicitEnum,
is_numpy_array,
is_torch_available,
is_torch_tensor,
is_torchvision_available,
is_vision_available,
logging,
requires_backends,
to_numpy,
)
from .utils.constants import ( # noqa: F401
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
)

from .utils import (ExplicitEnum, is_numpy_array, is_torch_available,
is_torch_tensor, is_torchvision_available,
is_vision_available, logging, requires_backends, to_numpy)
from .utils.constants import (IMAGENET_DEFAULT_MEAN, # noqa: F401
IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD, OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD)

if is_vision_available():
import PIL.Image
Expand All @@ -67,6 +55,8 @@
if is_torch_available():
import torch

# NOTE(review): module-level dict introduced alongside the other changes here; nothing in the
# visible code reads from or writes to it. Confirm it is actually used elsewhere, otherwise it
# is dead state and should be removed.
_infer_channel_dim_cache = {}


# Module-level logger named after this module.
logger = logging.get_logger(__name__)

Expand Down Expand Up @@ -132,14 +122,14 @@ def concatenate_list(input_list):


def valid_images(imgs):
    """
    Check that `imgs` is a valid image or an arbitrarily nested list/tuple of valid images.

    Args:
        imgs: A single image, a batched tensor of images, or a (possibly nested) list or tuple
            of those.

    Returns:
        `bool`: `False` as soon as one non-list, non-tuple element is rejected by
        `is_valid_image`, `True` otherwise. Empty lists/tuples contain nothing to reject and
        are therefore reported as valid.
    """
    # Walk the nesting with an explicit stack rather than recursion, so that very deeply
    # nested inputs cannot exceed the interpreter's recursion limit.
    pending = [imgs]
    while pending:
        item = pending.pop()
        if isinstance(item, (list, tuple)):
            # Containers are never validated themselves; only their elements are.
            pending.extend(item)
        elif not is_valid_image(item):
            return False
    return True


Expand Down Expand Up @@ -213,6 +203,11 @@ def make_flat_list_of_images(
Returns:
list: A list of images or a 4d array of images.
"""
# Reject None or empty input up front with the same error raised at the end of this function
if not images or (isinstance(images, (list, tuple)) and len(images) == 0):
raise ValueError(f"Could not make a flat list of images from {images}")

# If the input is a nested list of images, we flatten it
if (
isinstance(images, (list, tuple))
Expand All @@ -222,15 +217,16 @@ def make_flat_list_of_images(
return [img for img_list in images for img in img_list]

if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
first_img = images[0]
if is_pil_image(first_img) or getattr(first_img, "ndim", None) == expected_ndims:
return images
if images[0].ndim == expected_ndims + 1:
if getattr(first_img, "ndim", None) == expected_ndims + 1:
return [img for img_list in images for img in img_list]

if is_valid_image(images):
if is_pil_image(images) or images.ndim == expected_ndims:
if is_pil_image(images) or getattr(images, "ndim", None) == expected_ndims:
return [images]
if images.ndim == expected_ndims + 1:
if getattr(images, "ndim", None) == expected_ndims + 1:
return list(images)

raise ValueError(f"Could not make a flat list of images from {images}")
Expand Down Expand Up @@ -360,15 +356,52 @@ def get_image_size(image: np.ndarray, channel_dim: Optional[ChannelDimension] =
Returns:
A tuple of the image's height and width.
"""
# Fast-path: infer channel without calling infer_channel_dimension_format if shape gives exact match
shape = image.shape
ndim = image.ndim

if channel_dim is None:
channel_dim = infer_channel_dimension_format(image)
# Heuristic: try most common cases without full inference
if ndim == 3:
# image.shape e.g., (3, H, W) or (H, W, 3)
if shape[0] in (1, 3): # Channel FIRST
channel_dim_val = "FIRST"
elif shape[2] in (1, 3): # Channel LAST
channel_dim_val = "LAST"
else:
channel_dim = infer_channel_dimension_format(image)
channel_dim_val = channel_dim
elif ndim == 4:
# image.shape e.g., (N, 3, H, W) or (N, H, W, 3)
if shape[1] in (1, 3):
channel_dim_val = "FIRST"
elif shape[3] in (1, 3):
channel_dim_val = "LAST"
else:
channel_dim = infer_channel_dimension_format(image)
channel_dim_val = channel_dim
elif ndim == 5:
# image.shape e.g., (N, T, 3, H, W) or (N, T, H, W, 3)
if shape[2] in (1, 3):
channel_dim_val = "FIRST"
elif shape[4] in (1, 3):
channel_dim_val = "LAST"
else:
channel_dim = infer_channel_dimension_format(image)
channel_dim_val = channel_dim
else:
# Fallback for non-standard dims
channel_dim = infer_channel_dimension_format(image)
channel_dim_val = channel_dim
else:
channel_dim_val = channel_dim

if channel_dim == ChannelDimension.FIRST:
return image.shape[-2], image.shape[-1]
elif channel_dim == ChannelDimension.LAST:
return image.shape[-3], image.shape[-2]
if channel_dim_val == "FIRST" or channel_dim_val == ChannelDimension.FIRST:
return shape[-2], shape[-1]
elif channel_dim_val == "LAST" or channel_dim_val == ChannelDimension.LAST:
return shape[-3], shape[-2]
else:
raise ValueError(f"Unsupported data format: {channel_dim}")
raise ValueError(f"Unsupported data format: {channel_dim_val}")


def get_image_size_for_max_height_width(
Expand Down
Loading