Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 47 additions & 30 deletions src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,12 @@

import numpy as np

from .image_utils import (
ChannelDimension,
ImageInput,
get_channel_dimension_axis,
get_image_size,
infer_channel_dimension_format,
)
from .image_utils import (ChannelDimension, ImageInput,
get_channel_dimension_axis, get_image_size,
infer_channel_dimension_format)
from .utils import ExplicitEnum, TensorType, is_torch_tensor
from .utils.import_utils import (
is_torch_available,
is_vision_available,
requires_backends,
)

from .utils.import_utils import (is_torch_available, is_vision_available,
requires_backends)

if is_vision_available():
import PIL
Expand Down Expand Up @@ -698,39 +690,64 @@ def pad(
if input_data_format is None:
input_data_format = infer_channel_dimension_format(image)

# Fast path: If padding is 0, return image unchanged for performance
def is_zero_pad(padval):
    """Return True if `padval` describes zero padding everywhere.

    Accepts a numeric scalar or an arbitrarily nested iterable of scalars
    (e.g. ``((0, 0), (0, 0))``). Anything non-numeric and non-iterable is
    treated as non-zero, conservatively forcing the real padding path.
    """
    if isinstance(padval, (int, float)):
        return padval == 0
    # One Iterable check covers tuple/list and any other iterable; exclude
    # str/bytes, whose elements are themselves iterable (infinite recursion).
    if isinstance(padval, Iterable) and not isinstance(padval, (str, bytes)):
        return all(is_zero_pad(p) for p in padval)
    return False

if is_zero_pad(padding):
if (mode == PaddingMode.CONSTANT and is_zero_pad(constant_values)) or mode != PaddingMode.CONSTANT:
# No actual padding will be applied
return (
image
if (data_format is None or input_data_format == data_format)
else to_channel_dimension_format(image, data_format, input_data_format)
)

def _expand_for_data_format(values):
    """
    Convert `values` into the per-axis ``((before, after), ...)`` tuple format
    expected by `np.pad`, prepending/appending ``(0, 0)`` for the channel axis
    (position depends on `input_data_format`) and for a leading batch axis.
    """
    if isinstance(values, (int, float)):
        # Single scalar: same amount on every side of both spatial axes.
        pad_tuple = ((values, values), (values, values))
    elif isinstance(values, tuple):
        if len(values) == 1:
            pad_tuple = ((values[0], values[0]), (values[0], values[0]))
        elif len(values) == 2 and isinstance(values[0], int):
            # (before, after) applied identically to both spatial axes.
            pad_tuple = (values, values)
        elif len(values) == 2 and isinstance(values[0], tuple):
            # Already in ((before_h, after_h), (before_w, after_w)) form.
            pad_tuple = values
        else:
            raise ValueError(f"Unsupported format: {values}")
    else:
        raise ValueError(f"Unsupported format: {values}")

    # add 0 for channel dimension
    if input_data_format == ChannelDimension.FIRST:
        pad_tuple = ((0, 0),) + pad_tuple
    else:
        pad_tuple = pad_tuple + ((0, 0),)

    # Add additional padding if there's a batch dimension
    if image.ndim == 4:
        pad_tuple = ((0, 0),) + pad_tuple
    return pad_tuple

# Expand user-supplied padding (and, for CONSTANT mode, the fill values)
# into np.pad's per-axis ((before, after), ...) format exactly once.
padding_tuple = _expand_for_data_format(padding)

if mode == PaddingMode.CONSTANT:
    constant_values_tuple = _expand_for_data_format(constant_values)
    image = np.pad(image, padding_tuple, mode="constant", constant_values=constant_values_tuple)
elif mode == PaddingMode.REFLECT:
    image = np.pad(image, padding_tuple, mode="reflect")
elif mode == PaddingMode.REPLICATE:
    image = np.pad(image, padding_tuple, mode="edge")
elif mode == PaddingMode.SYMMETRIC:
    image = np.pad(image, padding_tuple, mode="symmetric")
else:
    raise ValueError(f"Invalid padding mode: {mode}")

Expand Down
84 changes: 42 additions & 42 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import base64
import os
from collections import deque
from collections.abc import Iterable
from dataclasses import dataclass
from io import BytesIO
Expand All @@ -22,26 +23,13 @@
import httpx
import numpy as np

from .utils import (
ExplicitEnum,
is_numpy_array,
is_torch_available,
is_torch_tensor,
is_torchvision_available,
is_vision_available,
logging,
requires_backends,
to_numpy,
)
from .utils.constants import ( # noqa: F401
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD,
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
)

from .utils import (ExplicitEnum, is_numpy_array, is_torch_available,
is_torch_tensor, is_torchvision_available,
is_vision_available, logging, requires_backends, to_numpy)
from .utils.constants import (IMAGENET_DEFAULT_MEAN, # noqa: F401
IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN,
IMAGENET_STANDARD_STD, OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD)

if is_vision_available():
import PIL.Image
Expand Down Expand Up @@ -132,14 +120,14 @@ def concatenate_list(input_list):


def valid_images(imgs):
    """Return True if `imgs` is a valid image or an arbitrarily nested list/tuple of valid images.

    Uses an explicit stack instead of recursion so deeply nested inputs cannot
    hit the interpreter recursion limit.
    """
    stack = deque([imgs])
    while stack:
        item = stack.pop()
        if isinstance(item, (list, tuple)):
            # Defer the elements; order of validation does not matter.
            stack.extend(item)
        elif not is_valid_image(item):
            return False
    return True


Expand Down Expand Up @@ -213,6 +201,11 @@ def make_flat_list_of_images(
Returns:
list: A list of images or a 4d array of images.
"""
# If the input is a nested list of images, we flatten it
# Fast path for None or empty input
if not images or (isinstance(images, (list, tuple)) and len(images) == 0):
raise ValueError(f"Could not make a flat list of images from {images}")

# If the input is a nested list of images, we flatten it
if (
isinstance(images, (list, tuple))
Expand All @@ -222,15 +215,16 @@ def make_flat_list_of_images(
return [img for img_list in images for img in img_list]

if isinstance(images, (list, tuple)) and is_valid_list_of_images(images):
if is_pil_image(images[0]) or images[0].ndim == expected_ndims:
first_img = images[0]
if is_pil_image(first_img) or getattr(first_img, "ndim", None) == expected_ndims:
return images
if images[0].ndim == expected_ndims + 1:
if getattr(first_img, "ndim", None) == expected_ndims + 1:
return [img for img_list in images for img in img_list]

if is_valid_image(images):
if is_pil_image(images) or images.ndim == expected_ndims:
if is_pil_image(images) or getattr(images, "ndim", None) == expected_ndims:
return [images]
if images.ndim == expected_ndims + 1:
if getattr(images, "ndim", None) == expected_ndims + 1:
return list(images)

raise ValueError(f"Could not make a flat list of images from {images}")
Expand Down Expand Up @@ -299,26 +293,32 @@ def infer_channel_dimension_format(
Returns:
The channel dimension of the image.
"""
num_channels = num_channels if num_channels is not None else (1, 3)
num_channels = (num_channels,) if isinstance(num_channels, int) else num_channels
num_channels = (
(num_channels,) if isinstance(num_channels, int) else (1, 3) if num_channels is None else num_channels
)

if image.ndim == 3:
ndim = image.ndim
if ndim == 3:
first_dim, last_dim = 0, 2
elif image.ndim == 4:
elif ndim == 4:
first_dim, last_dim = 1, 3
elif image.ndim == 5:
elif ndim == 5:
first_dim, last_dim = 2, 4
else:
raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")
raise ValueError(f"Unsupported number of image dimensions: {ndim}")

shape = image.shape
first_in_channels = shape[first_dim] in num_channels
last_in_channels = shape[last_dim] in num_channels

if image.shape[first_dim] in num_channels and image.shape[last_dim] in num_channels:
if first_in_channels and last_in_channels:
logger.warning(
f"The channel dimension is ambiguous. Got image shape {image.shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension."
f"The channel dimension is ambiguous. Got image shape {shape}. Assuming channels are the first dimension. Use the [input_data_format](https://huggingface.co/docs/transformers/main/internal/image_processing_utils#transformers.image_transforms.rescale.input_data_format) parameter to assign the channel dimension."
)
return ChannelDimension.FIRST
elif image.shape[first_dim] in num_channels:
elif first_in_channels:
return ChannelDimension.FIRST
elif image.shape[last_dim] in num_channels:
elif last_in_channels:
return ChannelDimension.LAST
raise ValueError("Unable to infer channel dimension format")

Expand Down
Loading