|
14 | 14 |
|
15 | 15 |
|
16 | 16 | import base64 |
17 | | -from typing import Any, Callable, Union |
| 17 | +import os |
| 18 | +from io import BytesIO |
| 19 | +from typing import Callable, Union |
18 | 20 |
|
19 | | -import cv2 |
20 | 21 | import numpy as np |
21 | 22 | import requests |
| 23 | +from PIL import Image as PILImage |
22 | 24 | from PIL.Image import Image |
23 | 25 |
|
24 | 26 |
|
25 | 27 | def preprocess_image( |
26 | | - image: Union[Image, str, bytes, np.ndarray[Any, np.dtype[np.uint8 | np.float_]]], |
27 | | - encoding_function: Callable[[Any], str] = lambda x: base64.b64encode(x).decode( |
| 28 | + image: Union[Image, str, bytes, np.ndarray], |
| 29 | + encoding_function: Callable[[bytes], str] = lambda b: base64.b64encode(b).decode( |
28 | 30 | "utf-8" |
29 | 31 | ), |
30 | 32 | ) -> str: |
31 | | - if isinstance(image, Image): |
32 | | - image = np.array(image) |
33 | | - _, image_data = cv2.imencode(".png", image) |
34 | | - encoding_function = lambda x: base64.b64encode(x).decode("utf-8") |
35 | | - elif isinstance(image, str) and image.startswith(("http://", "https://")): |
36 | | - response = requests.get(image) |
37 | | - response.raise_for_status() |
38 | | - image_data = response.content |
39 | | - elif isinstance(image, str): |
40 | | - with open(image, "rb") as image_file: |
41 | | - image_data = image_file.read() |
42 | | - elif isinstance(image, bytes): |
43 | | - image_data = image |
44 | | - encoding_function = lambda x: x.decode("utf-8") |
45 | | - elif isinstance(image, np.ndarray): # type: ignore |
46 | | - if image.dtype == np.float32 or image.dtype == np.float64: |
47 | | - image = (image * 255).astype(np.uint8) |
48 | | - _, image_data = cv2.imencode(".png", image) |
49 | | - encoding_function = lambda x: base64.b64encode(x).decode("utf-8") |
50 | | - else: |
51 | | - image_data = image |
52 | | - |
53 | | - return encoding_function(image_data) |
| 33 | + """Convert various image inputs into a base64-encoded PNG string. |
| 34 | +
|
| 35 | + Parameters |
| 36 | + ---------- |
| 37 | + image : PIL.Image.Image or str or bytes or numpy.ndarray |
| 38 | + Supported inputs: |
| 39 | + - PIL Image |
| 40 | + - Path to a file, ``file://`` URL, or HTTP(S) URL |
| 41 | + - Raw bytes containing image data |
| 42 | + - ``numpy.ndarray`` with dtype ``uint8`` or ``float32``/``float64``; |
| 43 | + grayscale or 3/4-channel arrays are supported. |
| 44 | + encoding_function : callable, optional |
| 45 | + Function that converts PNG bytes to the final string representation. |
| 46 | + By default, returns base64-encoded UTF-8 string. |
| 47 | +
|
| 48 | + Returns |
| 49 | + ------- |
| 50 | + str |
| 51 | + Base64-encoded PNG string. |
| 52 | +
|
| 53 | + Raises |
| 54 | + ------ |
| 55 | + FileNotFoundError |
| 56 | + If a file path (or ``file://`` URL) does not exist. |
| 57 | + TypeError |
| 58 | + If the input type is not supported. |
| 59 | + requests.HTTPError |
| 60 | + If fetching an HTTP(S) URL fails with a non-2xx response. |
| 61 | + requests.RequestException |
| 62 | + If a network error occurs while fetching an HTTP(S) URL. |
| 63 | + OSError |
| 64 | + If the input cannot be decoded as an image by Pillow. |
| 65 | +
|
| 66 | + Notes |
| 67 | + ----- |
| 68 | + - All inputs are decoded and re-encoded to PNG to guarantee consistent output. |
| 69 | + - Float arrays are assumed to be in [0, 1] and are scaled to ``uint8``. |
| 70 | + - Network requests use a timeout of ``(5, 15)`` seconds (connect, read). |
| 71 | +
|
| 72 | + Examples |
| 73 | + -------- |
| 74 | + >>> b64_png = preprocess_image("path/to/image.jpg") |
| 75 | + >>> import numpy as np |
| 76 | + >>> arr = np.random.rand(64, 64, 3).astype(np.float32) |
| 77 | + >>> b64_png = preprocess_image(arr) |
| 78 | + """ |
| 79 | + |
| 80 | + def _to_pil_from_ndarray(arr: np.ndarray) -> Image: |
| 81 | + a = arr |
| 82 | + if a.dtype in (np.float32, np.float64): |
| 83 | + a = np.clip(a, 0.0, 1.0) |
| 84 | + a = (a * 255.0).round().astype(np.uint8) |
| 85 | + a = np.ascontiguousarray(a) |
| 86 | + return PILImage.fromarray(a) |
| 87 | + |
| 88 | + def _ensure_pil(img: Union[Image, str, bytes, np.ndarray]) -> Image: |
| 89 | + if isinstance(img, Image): |
| 90 | + return img |
| 91 | + if isinstance(img, np.ndarray): # type: ignore |
| 92 | + return _to_pil_from_ndarray(img) |
| 93 | + if isinstance(img, str): |
| 94 | + if img.startswith(("http://", "https://")): |
| 95 | + response = requests.get(img, timeout=(5, 15)) |
| 96 | + response.raise_for_status() |
| 97 | + return PILImage.open(BytesIO(response.content)) |
| 98 | + if img.startswith("file://"): |
| 99 | + file_path = img[len("file://") :] |
| 100 | + else: |
| 101 | + # fallback to file path if not marked with file:// |
| 102 | + file_path = img |
| 103 | + if not os.path.exists(file_path): |
| 104 | + raise FileNotFoundError(f"File not found: {file_path}") |
| 105 | + return PILImage.open(file_path) |
| 106 | + if isinstance(img, bytes): |
| 107 | + return PILImage.open(BytesIO(img)) |
| 108 | + raise TypeError(f"Unsupported image type: {type(img).__name__}") |
| 109 | + |
| 110 | + pil_image = _ensure_pil(image) |
| 111 | + |
| 112 | + # Normalize to PNG bytes |
| 113 | + with BytesIO() as buffer: |
| 114 | + pil_image.save(buffer, format="PNG") |
| 115 | + png_bytes = buffer.getvalue() |
| 116 | + |
| 117 | + return encoding_function(png_bytes) |
0 commit comments