Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 29 additions & 32 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,23 @@

import base64
import os
from collections import deque
from collections.abc import Iterable
from dataclasses import dataclass
from io import BytesIO
from typing import Optional, Union

import httpx
import numpy as np
import PIL.Image

from .utils import (
ExplicitEnum,
is_numpy_array,
is_torch_available,
is_torch_tensor,
is_torchvision_available,
is_vision_available,
logging,
requires_backends,
to_numpy,
)
from .utils.constants import ( # noqa: F401
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
)

from .utils import (ExplicitEnum, is_numpy_array, is_torch_available,
is_torch_tensor, is_torchvision_available,
is_vision_available, logging, requires_backends, to_numpy)
from .utils.constants import (IMAGENET_DEFAULT_MEAN, # noqa: F401
IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD, OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD)

if is_vision_available():
import PIL.Image
Expand Down Expand Up @@ -132,14 +121,14 @@ def concatenate_list(input_list):


def valid_images(imgs):
    """
    Check that every image in a (possibly nested) list/tuple structure is valid.

    Args:
        imgs: A single image or batched tensor of images, or an arbitrarily nested
            list/tuple of them.

    Returns:
        `bool`: `False` as soon as one non-list/tuple element fails `is_valid_image`,
        `True` otherwise. Empty lists/tuples contain nothing to reject and therefore
        yield `True`.
    """
    # Walk the nesting with an explicit stack rather than recursion, so deeply nested
    # inputs cannot hit the interpreter's recursion limit.
    pending = [imgs]
    while pending:
        item = pending.pop()
        if isinstance(item, (list, tuple)):
            # A list or tuple is a container: schedule each of its elements for checking.
            pending.extend(item)
        elif not is_valid_image(item):
            # Not a list or tuple, so this is a single image or a batched tensor of images.
            return False
    return True


Expand Down Expand Up @@ -213,6 +202,11 @@ def make_flat_list_of_images(
Returns:
list: A list of images or a 4d array of images.
"""
# If the input is a nested list of images, we flatten it
# Fast path for None or empty input
if not images or (isinstance(images, (list, tuple)) and len(images) == 0):
raise ValueError(f"Could not make a flat list of images from {images}")

# If the input is a nested list of images, we flatten it
if (
isinstance(images, (list, tuple))
Expand All @@ -222,15 +216,16 @@ def make_flat_list_of_images(
return [img for img_list in images for img in img_list]

if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
first_img = images[0]
if is_pil_image(first_img) or getattr(first_img, "ndim", None) == expected_ndims:
return images
if images[0].ndim == expected_ndims + 1:
if getattr(first_img, "ndim", None) == expected_ndims + 1:
return [img for img_list in images for img in img_list]

if is_valid_image(images):
if is_pil_image(images) or images.ndim == expected_ndims:
if is_pil_image(images) or getattr(images, "ndim", None) == expected_ndims:
return [images]
if images.ndim == expected_ndims + 1:
if getattr(images, "ndim", None) == expected_ndims + 1:
return list(images)

raise ValueError(f"Could not make a flat list of images from {images}")
Expand Down Expand Up @@ -276,6 +271,8 @@ def make_nested_list_of_images(


def to_numpy_array(img) -> np.ndarray:
if isinstance(img, np.ndarray) and is_valid_image(img):
return img
if not is_valid_image(img):
raise ValueError(f"Invalid image type: {type(img)}")

Expand Down
Loading