Output from generate is a tuple, rather than a string #334

jrp2014 opened this issue May 2, 2025 · 7 comments
Labels
bug Something isn't working

Comments


jrp2014 commented May 2, 2025

I must be doing something wrong, but the output that I get back from generate is a tuple containing both the model output string and some stats. I always set the verbose parameter. When it is set to False, I get, e.g.:

2025-05-02 22:32:09,653 - INFO - Processing '20250405-140223_DSC03171.jpg' with model: HuggingFaceTB/SmolVLM-Instruct
Fetching 12 files: 100%|█████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 53544.31it/s]
Fetching 12 files: 100%|████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 238538.62it/s]
Model output: (' A stone bridge over a river with a building in the background.', {'input_tokens': 1271, 'output_tokens': 14, 'total_tokens': 1285, 'prompt_tps': 1105.750074387695, 'generation_tps': 123.6502897460356, 'peak_memory': 5.442698462})

and with a bit more debugging output:

output is of type <class 'tuple'>
output is (' A stone bridge over a river with a building in the background.', {'input_tokens': 1271, 'output_tokens': 14, 'total_tokens': 1285, 'prompt_tps': 1102.7072144956928, 'generation_tps': 123.4124421865043, 'peak_memory': 5.442698464})

which checks out.

Setting the verbose parameter to True:

2025-05-02 22:34:00,701 - INFO - Processing '20250405-140223_DSC03171.jpg' with model: HuggingFaceTB/SmolVLM-Instruct
Fetching 12 files: 100%|█████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 28695.35it/s]
Fetching 12 files: 100%|█████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 27206.30it/s]
==========
Files: /Users/jrp/Pictures/Processed/20250405-140223_DSC03171.jpg 

Prompt: <|im_start|>User:<image>Provide a factual caption, description, and keywords suitable for cataloguing, or searching for, the image. taken around 2025:04:05 14:02:23 near GPS 51.320833, -2.211517. Focus on visual content. Avoid repeating the context unless it is visible.<end_of_utterance>
Assistant:
 A stone bridge over a river with a building in the background.
==========
Prompt: 1271 tokens, 1102.124 tokens-per-sec
Generation: 14 tokens, 123.637 tokens-per-sec
Peak memory: 5.443 GB
Model output: (' A stone bridge over a river with a building in the background.', {'input_tokens': 1271, 'output_tokens': 14, 'total_tokens': 1285, 'prompt_tps': 1102.1237777966808, 'generation_tps': 123.63677533913753, 'peak_memory': 5.442698464})

(' A stone bridge over a river with a building in the background.', {'input_tokens': 1271, 'output_tokens': 14, 'total_tokens': 1285, 'prompt_tps': 1102.1237777966808, 'generation_tps': 123.63677533913753, 'peak_memory': 5.442698464})
[SUCCESS] SmolVLM-Instruct

Output:
(' A stone bridge over a river with a building in the background.', {'input_tokens': 1271, 'output_tokens': 14, 'total_tokens': 1285, 'prompt_tps': 1102.1237777966808, 'generation_tps': 123.63677533913753, 'peak_memory': 5.442698464})
Time taken: 2.85s

I am obviously making a schoolboy error of some sort, but I can't see what it is. Suggestions ... ?
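For reference, the tuple can be unpacked defensively before use. This is only a sketch built on the literal output shape shown in the logs above (no real `generate` call); the `isinstance` branch is an assumption to stay compatible with a plain-string return:

```python
# Hypothetical: stand-in for the value returned by generate, copied from
# the logs above rather than produced by a real model call.
output = (' A stone bridge over a river with a building in the background.',
          {'input_tokens': 1271, 'output_tokens': 14, 'total_tokens': 1285})

if isinstance(output, tuple):   # (text, stats) shape seen in the logs
    text, stats = output
else:                           # fall back if a plain string is returned
    text, stats = output, {}

caption = text.strip()
```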

#!/usr/bin/env python3
"""Image analysis and caption generation using MLX Vision Language Models."""

# Standard library imports
import argparse
import contextlib
import html
import logging
import platform
import re  # For ANSI code stripping
import signal
import subprocess
import sys
import time
import traceback
import types # Added for TracebackType
from dataclasses import dataclass, field
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import (
    Any, Dict, Final, List, NamedTuple, NoReturn,
    Optional, TextIO, Tuple, Union, Callable,
    TypeVar, Pattern # Added Pattern
)

import os # Needed for os.path.getmtime

from huggingface_hub import HFCacheInfo, scan_cache_dir
from huggingface_hub import __version__ as hf_version
from huggingface_hub.errors import HFValidationError

# Third-party imports
try:
    import mlx.core as mx
except ImportError:
    print("Core dependency missing: mlx. Please install it.", file=sys.stderr)
    sys.exit(1)

try:
    from PIL import Image, UnidentifiedImageError, ExifTags
    from PIL.ExifTags import GPSTAGS, TAGS
    pillow_version = Image.__version__ if hasattr(Image, '__version__') else 'N/A'
except ImportError:
    print("Error: Pillow not found. Please install it (`pip install Pillow`).", file=sys.stderr)
    pillow_version = "N/A"
    sys.exit(1)

# Local application/library specific imports
try:
    from mlx_vlm import (__version__ as vlm_version, generate, load)
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config
except ImportError:
    print("Error: mlx-vlm not found. Please install it (`pip install mlx-vlm`).", file=sys.stderr)
    sys.exit(1)

# Optional imports for version reporting
try:
    # Import the module first
    from mlx_lm import _version as mlx_lm_version_module
    # Then try to get its __version__ attribute, ensuring it's a string
    mlx_lm_version = str(getattr(mlx_lm_version_module, '__version__', 'N/A'))
except ImportError:
    # If the module itself cannot be imported
    mlx_lm_version = "N/A"
except AttributeError:
    # If the module is imported but lacks a __version__ attribute
    mlx_lm_version = "N/A (module found, no version attr)"
try:
    import transformers
    transformers_version = transformers.__version__
    # Import specific tokenizer types
    from transformers.tokenization_utils import PreTrainedTokenizer
    from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
except ImportError:
    transformers_version = "N/A"
    # Define dummy types if transformers is not installed to avoid NameErrors later
    # Although the script exits earlier if transformers is missing for mlx_vlm
    PreTrainedTokenizer = type("PreTrainedTokenizer", (), {})
    PreTrainedTokenizerFast = type("PreTrainedTokenizerFast", (), {})


# Custom timeout context manager (for Python < 3.11)
# Note: This implementation relies on signal.SIGALRM and will not work on Windows.
class timeout_manager(contextlib.ContextDecorator):
    def __init__(self, seconds: float) -> None:
        self.seconds: float = seconds
        # Accommodate signal.SIG_DFL, signal.SIG_IGN (integers)
        self.timer: Union[Callable[[int, Optional[types.FrameType]], Any], int, None] = None

    def _timeout_handler(self, signum: int, frame: Optional[types.FrameType]) -> NoReturn: # Use FrameType
        raise TimeoutError(f"Operation timed out after {self.seconds} seconds")

    def __enter__(self) -> 'timeout_manager':
        # Check if SIGALRM is available (won't be on Windows)
        if hasattr(signal, 'SIGALRM'):
            if self.seconds > 0:
                try:
                    self.timer = signal.signal(signal.SIGALRM, self._timeout_handler)
                    signal.alarm(int(self.seconds))
                except ValueError as e:
                    # Running in a thread or environment where signals are restricted
                    logger.warning(f"Could not set SIGALRM for timeout: {e}. Timeout disabled.")
                    self.seconds = 0 # Disable timeout functionality
        else:
            if self.seconds > 0:
                logger.warning("Timeout functionality requires signal.SIGALRM, not available on this platform (e.g., Windows). Timeout disabled.")
                self.seconds = 0 # Disable timeout functionality
        return self

    def __exit__(self, exc_type: Optional[type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[types.TracebackType]) -> None: # Use TracebackType
        # Only try to reset the alarm if it was successfully set
        if hasattr(signal, 'SIGALRM') and self.seconds > 0 and self.timer is not None:
            signal.alarm(0)
            signal.signal(signal.SIGALRM, self.timer)
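A stripped-down sketch of the SIGALRM pattern the class above wraps (POSIX only, and guarded so it degrades gracefully on platforms without `SIGALRM`; the two-second limit and the summation workload are illustrative):

```python
import signal

def _on_alarm(signum, frame):
    # Handler invoked when the alarm fires before the work finishes.
    raise TimeoutError("operation timed out")

result = None
if hasattr(signal, "SIGALRM"):  # not available on Windows
    previous = signal.signal(signal.SIGALRM, _on_alarm)
    signal.alarm(2)  # raise TimeoutError if the work takes longer than 2s
    try:
        result = sum(range(1_000))  # fast work; completes before the alarm
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous)  # restore the old handler
else:
    result = sum(range(1_000))  # no timeout protection on this platform
```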


# Configure logging
logger = logging.getLogger(__name__)
# BasicConfig called in main()

# Constants
MB_CONVERSION: Final[float] = 1024 * 1024

# --- Utility Functions ---
def _pad_text(text: str, width: int, left: bool = True) -> str:
    """Pads text to a specific visual width, accounting for ANSI codes."""
    pad_len = max(0, width - Colors.visual_len(text))
    return f"{text}{' ' * pad_len}" if left else f"{' ' * pad_len}{text}"

# --- Version Info ---
def get_library_versions() -> Dict[str, str]:
    """Collect versions of key libraries."""
    versions = {
        'mlx': getattr(mx, '__version__', 'N/A'),
        'mlx-vlm': vlm_version if 'vlm_version' in globals() else 'N/A',
        'mlx-lm': mlx_lm_version,
        'huggingface-hub': hf_version,
        'transformers': transformers_version,
        'Pillow': pillow_version
    }
    return versions

def print_version_info(versions: Dict[str, str]) -> None:
    """Print collected library versions to the console with date."""
    print("\n--- Library Versions ---")
    max_len = max(len(k) for k in versions) + 1 if versions else 10
    for name, ver in sorted(versions.items()):
        status_color = Colors.GREEN if ver != "N/A" else Colors.YELLOW
        name_padded = name.ljust(max_len)
        print(f"{name_padded}: {Colors.colored(ver, status_color)}")
    print(f"\nGenerated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")


# --- ANSI Color Codes for Console Output ---
class Colors:
    """ANSI color codes for terminal output"""
    RESET: Final[str] = "\033[0m"
    BOLD: Final[str] = "\033[1m"
    RED: Final[str] = "\033[91m"
    GREEN: Final[str] = "\033[92m"
    YELLOW: Final[str] = "\033[93m"
    BLUE: Final[str] = "\033[94m"
    MAGENTA: Final[str] = "\033[95m"
    CYAN: Final[str] = "\033[96m"
    WHITE: Final[str] = "\033[97m"
    GRAY: Final[str] = "\033[90m"
    _enabled: bool = sys.stderr.isatty()
    _ansi_escape_re: Pattern[str] = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')

    @staticmethod
    def colored(text: Any, color: str) -> str: # Use Any for text input
        if not Colors._enabled:
            return str(text) # Ensure return is string
        return f"{color}{str(text)}{Colors.RESET}" # Ensure text is string

    @staticmethod
    def visual_len(text: Any) -> int: # Use Any for text input
        if not isinstance(text, str):
             text = str(text)
        return len(Colors._ansi_escape_re.sub('', text))
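The ANSI-aware length measurement can be checked in isolation; a minimal sketch using the same escape-sequence regex as `Colors._ansi_escape_re`:

```python
import re

# Matches CSI sequences (e.g. "\x1b[92m") and two-byte escapes.
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')

def visual_len(text: str) -> int:
    """Length of text as rendered on screen, ignoring ANSI colour codes."""
    return len(ansi_escape.sub('', text))

green_ok = "\033[92mOK\033[0m"  # 'OK' wrapped in green + reset codes
```

Here `len(green_ok)` is 11, but the string occupies only 2 columns on screen, which is what the padding helper needs.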


# Type aliases and definitions
T = TypeVar('T')
ExifValue = Any
ExifDict = Dict[Union[str, int], ExifValue]
MetadataDict = Dict[str, str]
PathLike = Union[str, Path]
GPSTupleElement = Union[int, float]
GPSTuple = Tuple[GPSTupleElement, GPSTupleElement, GPSTupleElement]

# Constants - Defaults
DEFAULT_MAX_TOKENS: Final[int] = 500
DEFAULT_FOLDER: Final[Path] = Path.home() / "Pictures" / "Processed"
DEFAULT_HTML_OUTPUT: Final[Path] = Path("results.html")
DEFAULT_TEMPERATURE: Final[float] = 0.1
DEFAULT_TIMEOUT: Final[float] = 300.0  # Default timeout in seconds

# Constants - EXIF
EXIF_IMAGE_DESCRIPTION_TAG: Final[int] = 270  # Standard EXIF tag ID for ImageDescription
IMPORTANT_EXIF_TAGS: Final[frozenset[str]] = frozenset({
    "DateTimeOriginal", "ImageDescription", "CreateDate", "Make", "Model",
    "LensModel", "ExposureTime", "FNumber", "ISOSpeedRatings",
    "FocalLength", "ExposureProgram",
})
DATE_FORMATS: Final[Tuple[str, ...]] = ("%Y:%m:%d %H:%M:%S", "%Y-%m-%d %H:%M:%S", "%Y%m%d")
EXIF_DATE_TAGS: Final[Tuple[str, ...]] = ("DateTimeOriginal", "CreateDate", "DateTime")
GPS_LAT_REF_TAG: Final[int] = 1
GPS_LAT_TAG: Final[int] = 2
GPS_LON_REF_TAG: Final[int] = 3
GPS_LON_TAG: Final[int] = 4
GPS_INFO_TAG_ID: Final[int] = 34853  # Standard EXIF tag ID for GPS IFD

# Type definitions
class MemoryStats(NamedTuple):
    """Memory statistics container (values represent deltas or peak)."""
    active: float
    cached: float
    peak: float
    time: float
    @staticmethod
    def zero() -> 'MemoryStats': return MemoryStats(0.0, 0.0, 0.0, 0.0)

@dataclass(frozen=True)
class ModelResult:
    """Container for model processing results, including failures."""
    model_name: str
    success: bool
    output: Optional[str] = None
    stats: MemoryStats = field(default_factory=MemoryStats.zero)
    error_stage: Optional[str] = None
    error_message: Optional[str] = None
    captured_output_on_fail: Optional[str] = None


# --- File Handling ---
def find_most_recent_file(folder: Path) -> Optional[Path]:
    """Return the Path of the most recently modified file in the folder."""
    if not folder.is_dir():
        logger.error(Colors.colored(f"Provided path is not a directory: {folder}", Colors.RED))
        return None
    try:
        files: List[Path] = [
            f for f in folder.iterdir() if f.is_file() and not f.name.startswith(".")
        ]
        if not files:
            logger.warning(Colors.colored(f"No non-hidden files found in: {folder}", Colors.YELLOW))
            return None
        files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
        most_recent: Path = files[0]
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Most recent file found: {most_recent}")
        return most_recent
    except PermissionError:
        logger.error(Colors.colored(f"Permission denied accessing folder: {folder}", Colors.RED))
        return None
    except OSError as e:
        logger.error(Colors.colored(f"OS error scanning folder {folder}: {e}", Colors.RED))
        return None
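The core of the selection above is a sort of non-hidden files by modification time; a self-contained sketch with throwaway temp files (the filenames are illustrative):

```python
import os
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as d:
    older = Path(d) / "older.txt"
    newer = Path(d) / "newer.txt"
    older.write_text("x")
    newer.write_text("y")
    os.utime(older, (1, 1))  # force an ancient modification time on one file

    # Same filtering and ordering as find_most_recent_file above.
    files = [f for f in Path(d).iterdir()
             if f.is_file() and not f.name.startswith(".")]
    files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
    most_recent = files[0].name
```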

def print_image_dimensions(image_path: Path) -> None:
    """Print the dimensions and megapixel count of the image."""
    try:
        with Image.open(image_path) as img:
            width, height = img.size
            mpx: float = (width * height) / 1_000_000
            print(f"Image dimensions: {Colors.colored(f'{width}x{height}', Colors.CYAN)} ({Colors.colored(f'{mpx:.1f}', Colors.CYAN)} MPixels)")
    except FileNotFoundError:
        logger.error(Colors.colored(f"Image file not found: {image_path}", Colors.RED))
    except UnidentifiedImageError:
        logger.error(Colors.colored(f"Cannot identify image file (may be corrupt or wrong format): {image_path}", Colors.RED))
    except Exception as e:
        logger.error(Colors.colored(f"Error reading image dimensions for {image_path}: {e}", Colors.RED))


# --- EXIF & Metadata Handling ---
@lru_cache(maxsize=128)
def get_exif_data(image_path: Path) -> Optional[ExifDict]:
    """Extract EXIF data from an image file, including decoding GPS IFD."""
    # GPS_INFO_TAG_ID moved to global constants

    try:
        with Image.open(image_path) as img:
            exif_raw: Any = img.getexif()
            if not exif_raw:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"No EXIF data found in {image_path}")
                return None

            exif = exif_raw.get_ifd(ExifTags.IFD.Exif)

            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Raw EXIF data for {image_path}: {exif_raw}")
                logger.debug(f"Exif IFD: {exif}")

            # First pass: decode main EXIF tags
            exif_decoded: ExifDict = {}
            
            # Decode remaining EXIF tags
            for tag_id, value in exif.items():
                if tag_id == ExifTags.Base.GPSInfo:
                    continue
                tag_name = TAGS.get(tag_id, str(tag_id))
                exif_decoded[tag_name] = value

            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"EXIF data decoded for {image_path}: {exif_decoded}")

            # Second pass: handle GPS IFD specifically
            gps_ifd = exif_raw.get_ifd(ExifTags.IFD.GPSInfo)

            if isinstance(gps_ifd, dict) and gps_ifd:
                gps_decoded = {}
                for gps_tag_id, gps_value in gps_ifd.items():
                    gps_tag_name = GPSTAGS.get(gps_tag_id, str(gps_tag_id))
                    gps_decoded[gps_tag_name] = gps_value
                exif_decoded["GPSInfo"] = gps_decoded

                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"EXIF GPS decoded for {image_path}: {exif_decoded}")


            return exif_decoded

    except (FileNotFoundError, UnidentifiedImageError) as e:
        logger.error(f"Error reading image file: {e}")
    except Exception as e:
        logger.error(f"Unexpected error reading EXIF: {e}")
    return None

def _format_exif_date(date_str_input: Any) -> Optional[str]:
    """Return the EXIF date value as a string, without parsing."""
    if isinstance(date_str_input, str):
        return date_str_input.strip()
    try:
        # Attempt to convert non-string types to string
        return str(date_str_input).strip()
    except Exception:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Could not convert potential date value '{date_str_input}' to string.")
        return None

def _convert_gps_coordinate(ref: Optional[Union[str, bytes]], coord: Any) -> Optional[float]:
    """Convert various GPS coordinate formats to decimal degrees, robustly handling Ratio types and malformed data."""
    if not ref or not coord:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Missing GPS reference or coordinate.")
        return None

    try:
        # --- Reference Handling ---
        ref_str: str
        if isinstance(ref, bytes):
            try:
                ref_str = ref.decode('ascii')
            except UnicodeDecodeError:
                logger.warning(f"Invalid GPS reference encoding: {ref!r}")
                return None
        elif isinstance(ref, str):
            ref_str = ref
        else:
            logger.warning(f"Unexpected GPS reference type: {type(ref).__name__}")
            return None

        ref_upper = ref_str.strip().upper()
        if ref_upper not in ['N', 'S', 'E', 'W']:
            logger.warning(f"Unexpected GPS reference value: {ref_str}")
            return None

        # --- Coordinate Value Handling ---
        degrees: Optional[float] = None
        minutes: Optional[float] = None
        seconds: Optional[float] = 0.0 # Default seconds to 0

        def to_float(val: Any) -> Optional[float]:
            """Safely convert EXIF value (potentially Ratio) to float."""
            if hasattr(val, 'numerator') and hasattr(val, 'denominator'): # Check for Ratio-like object
                try:
                    # Ensure numerator/denominator are numeric before division
                    num = float(val.numerator)
                    den = float(val.denominator)
                    if den == 0:
                        logger.warning(f"Invalid Ratio in GPS (denominator is zero): {val}")
                        return None
                    return num / den
                except (ValueError, TypeError, AttributeError, ZeroDivisionError) as e:
                    logger.warning(f"Malformed Ratio in GPS: {val} ({e})")
                    return None
            try:
                # Handle direct numeric types or strings representing numbers
                return float(val)
            except (ValueError, TypeError):
                logger.warning(f"Could not convert GPS value to float: {val!r} (type: {type(val).__name__})")
                return None

        # Check if coord is a sequence (tuple or list)
        if isinstance(coord, (tuple, list)):
            if len(coord) == 3: # Assume Degrees, Minutes, Seconds (DMS)
                degrees = to_float(coord[0])
                minutes = to_float(coord[1])
                seconds = to_float(coord[2])
            elif len(coord) == 2: # Assume Degrees, Decimal Minutes (DM)
                degrees = to_float(coord[0])
                minutes = to_float(coord[1])
            elif len(coord) == 1: # Assume Decimal Degrees in a sequence
                degrees = to_float(coord[0])
                minutes = 0.0 # Set minutes explicitly
            else:
                logger.warning(f"Unexpected GPS coordinate sequence length: {len(coord)} for {coord}")
                return None
        else:
            # Assume direct Decimal Degrees if not a sequence
            degrees = to_float(coord)
            minutes = 0.0 # Set minutes explicitly

        # --- Validation and Calculation ---
        if degrees is None or minutes is None or seconds is None:
            logger.warning(f"Failed to extract valid numeric values from GPS coordinate: {coord}")
            return None

        # Validate ranges
        # Allow slightly wider range for degrees initially, sign applied later
        if not (0 <= abs(degrees) <= 180 and 0 <= minutes < 60 and 0 <= seconds < 60):
            logger.warning(f"GPS values out of range: Deg={degrees}, Min={minutes}, Sec={seconds}")
            return None

        # Calculate decimal degrees
        decimal = abs(degrees) + (minutes / 60.0) + (seconds / 3600.0)

        # Apply sign based on reference
        return -decimal if ref_upper in ['S', 'W'] else decimal

    except Exception as e:
        # Catch-all for unexpected errors during conversion
        logger.error(f"Unexpected GPS conversion error for coord={coord!r}, ref={ref!r}: {type(e).__name__}: {e}", exc_info=logger.level <= logging.DEBUG)
        return None
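Stripped of the defensive handling, the arithmetic is degrees + minutes/60 + seconds/3600, with the sign taken from the hemisphere reference. A toy version, using `Fraction` as a stand-in for the rational values EXIF stores (the DMS values are illustrative, chosen to land near the latitude in the prompt above):

```python
from fractions import Fraction

def dms_to_decimal(ref: str, dms) -> float:
    """Convert a (degrees, minutes, seconds) triple to signed decimal degrees."""
    deg, minutes, seconds = (float(v) for v in dms)
    decimal = deg + minutes / 60.0 + seconds / 3600.0
    return -decimal if ref.upper() in ('S', 'W') else decimal

# Fraction stands in for PIL's IFDRational here.
lat = dms_to_decimal('N', (Fraction(51), Fraction(19), Fraction(15)))  # ~51.3208
lon = dms_to_decimal('W', (Fraction(2), Fraction(12), Fraction(42)))   # negative: west
```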

def _extract_exif_date(exif_data: Optional[ExifDict]) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Helper to extract the best date string from EXIF data."""
    if not exif_data:
        return None, None, None

    # Prioritize DateTimeOriginal
    dt_original = exif_data.get("DateTimeOriginal")
    if dt_original:
        raw_date_str = _format_exif_date(dt_original) # Get raw string
        if raw_date_str:
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Using raw DateTimeOriginal from EXIF: '{raw_date_str}'")
            return raw_date_str, "EXIF (DateTimeOriginal)", "DateTimeOriginal"

    # Fallback to other date tags
    for tag in EXIF_DATE_TAGS:
        if tag == "DateTimeOriginal":
             continue # Already checked
        if tag in exif_data:
            raw_date_str = _format_exif_date(exif_data[tag]) # Get raw string
            if raw_date_str:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.debug(f"Using raw date from {tag}: '{raw_date_str}'")
                return raw_date_str, f"EXIF ({tag})", tag

    return None, None, None # No valid EXIF date found

def _extract_exif_description(exif_data: Optional[ExifDict]) -> Optional[str]:
    """Helper to extract the description from EXIF data."""
    if not exif_data:
        return None

    description = exif_data.get('ImageDescription', None)
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Raw EXIF ImageDescription: '{description}'")

    if description is not None:
        try:
            desc_str = str(description).strip()
            return desc_str if desc_str else None # Return None if empty after stripping
        except Exception as desc_err:
            logger.warning(f"Could not convert ImageDescription value '{description}' to string: {desc_err}")
            return None
    return None

def _extract_gps_coordinates(exif_data: Optional[ExifDict], image_path_name: str) -> Optional[str]:
    """Helper to extract and format GPS coordinates from EXIF data."""
    if not exif_data or "GPSInfo" not in exif_data or not isinstance(exif_data["GPSInfo"], dict):
        return None

    gps_info = exif_data["GPSInfo"]
    lat = gps_info.get("GPSLatitude")
    lat_ref = gps_info.get("GPSLatitudeRef")
    lon = gps_info.get("GPSLongitude")
    lon_ref = gps_info.get("GPSLongitudeRef")

    if not (lat and lat_ref and lon and lon_ref):
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Incomplete GPS tags found for {image_path_name}")
        return None

    try:
        latitude = _convert_gps_coordinate(lat_ref, lat)
        longitude = _convert_gps_coordinate(lon_ref, lon)

        if latitude is not None and longitude is not None:
            gps_str = f"{latitude:.6f}, {longitude:.6f}"
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Extracted GPS {gps_str} for {image_path_name}")
            return gps_str
        else:
            logger.warning(f"Failed to convert GPS coordinates for {image_path_name}. Lat: {lat}, Lon: {lon}")
            return None
    except Exception as e:
        logger.warning(f"Error processing GPS coordinates for {image_path_name}: {e}")
        return None

def extract_image_metadata(image_path: Path, debug: bool = False) -> MetadataDict:
    """Extract key metadata: date string, GPS, and selected EXIF tags."""
    metadata: MetadataDict = {}
    exif_data: Optional[ExifDict] = get_exif_data(image_path)

    if debug:
        logger.debug(f"Debug mode enabled for metadata extraction of {image_path.name}")
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Raw EXIF data for {image_path.name}: {exif_data}")

    # 1. Extract Date String (using helper)
    raw_exif_date, date_source, date_tag = _extract_exif_date(exif_data)
    if raw_exif_date:
        metadata['date'] = raw_exif_date # Store the raw string
        metadata['date_source'] = date_source
        metadata['date_tag'] = date_tag
    else:
        # Fallback to file modification time (formatted as ISO 8601 string)
        try:
            mtime = os.path.getmtime(image_path)
            # Format as YYYY-MM-DD HH:MM:SS for consistency
            metadata['date'] = datetime.fromtimestamp(mtime).strftime("%Y-%m-%d %H:%M:%S")
            metadata['date_source'] = "File Modification Time"
            metadata['date_tag'] = "mtime"
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Using file modification time: {metadata['date']}")
        except OSError as e:
            logger.warning(f"Could not get modification time for {image_path.name}: {e}")
            metadata['date'] = "Unknown"
            metadata['date_source'] = "Unavailable"
            metadata['date_tag'] = "None"

    # 2. Extract Description (using helper)
    description = _extract_exif_description(exif_data)
    metadata['description'] = description if description is not None else 'N/A'

    # 3. Extract GPS (using helper)
    gps_coords = _extract_gps_coordinates(exif_data, image_path.name)
    if gps_coords:
        metadata['gps'] = gps_coords
    # No need for an else clause, if gps_coords is None, 'gps' key won't be in metadata

    # 4. Add other important EXIF tags
    if exif_data:
        for tag in IMPORTANT_EXIF_TAGS:
            # Avoid overwriting the date/desc we just determined
            if tag not in EXIF_DATE_TAGS and tag != 'ImageDescription' and tag in exif_data:
                value = exif_data[tag]
                # Simple conversion for common types, avoid complex objects
                if isinstance(value, (str, int, float)):
                     metadata[tag] = str(value).strip()
                elif isinstance(value, bytes):
                     # Attempt to decode bytes, fallback to repr
                     try:
                         metadata[tag] = value.decode('utf-8', errors='replace').strip()
                     except UnicodeDecodeError:
                         metadata[tag] = repr(value)
                elif isinstance(value, tuple):
                     metadata[tag] = ", ".join(map(str, value))
                # Add other simple types if needed, but avoid deep structures

    return metadata
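The mtime fallback branch (used when no EXIF date tag is present, formatted to match the EXIF-derived strings) is small enough to check on its own; the timestamp below is arbitrary:

```python
import os
import tempfile
from datetime import datetime

with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp:
    # Pin the file's modification time (epoch 1714687200 = 2024-05-02T22:00:00Z).
    os.utime(tmp.name, (1714687200, 1714687200))
    mtime = os.path.getmtime(tmp.name)
    # Same "YYYY-MM-DD HH:MM:SS" format as the EXIF path, for consistency.
    date_str = datetime.fromtimestamp(mtime).strftime("%Y-%m-%d %H:%M:%S")
```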

def pretty_print_exif(exif: ExifDict, verbose: bool = False) -> None:
    """Pretty print key EXIF data in a formatted table, using colors."""
    # (Implementation remains the same as previous correct version)
    if not exif:
        print("No EXIF data available.")
        return

    print(f"\n--- {Colors.colored('Key EXIF Data', Colors.CYAN)} ---")
    tags_to_print: List[Tuple[str, str, bool]] = []
    for tag, value in exif.items():
        tag_str = str(tag)
        if tag_str == "GPSInfo" and isinstance(value, dict):
            continue
        if isinstance(value, dict):
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug(f"Skipping dictionary value for EXIF tag '{tag_str}' in pretty print.")
            continue
        value_str: str
        if isinstance(value, bytes):
             try:
                 decoded_str = value.decode('utf-8', errors='replace')
                 value_str = decoded_str[:57] + "..." if len(decoded_str) > 60 else decoded_str
             except Exception:
                 value_str = f"<bytes len={len(value)}>"
        elif isinstance(value, tuple) and len(value) > 10:
            value_str = f"<tuple len={len(value)}>"
        elif isinstance(value, bytearray):
            value_str = f"<bytearray len={len(value)}>"
        else:
             try:
                  value_str = str(value)
                  if len(value_str) > 60:
                      value_str = value_str[:57] + "..."
             except Exception as str_err:
                 # Log the specific error during string conversion
                 if logger.isEnabledFor(logging.DEBUG):
                     logger.debug(f"Could not convert EXIF value for tag '{tag_str}' to string: {str_err}")
                 value_str = f"<unrepresentable type: {type(value).__name__}>"

        is_important = tag_str in IMPORTANT_EXIF_TAGS
        if verbose or is_important:
            tags_to_print.append((tag_str, value_str, is_important))

    if not tags_to_print:
        print("No relevant EXIF tags found to display.")
        return
    tags_to_print.sort(key=lambda x: x[0])
    max_tag_len = max(Colors.visual_len(t[0]) for t in tags_to_print) if tags_to_print else 20
    max_val_len = max(Colors.visual_len(t[1]) for t in tags_to_print) if tags_to_print else 40
    min_width = 10
    max_tag_len = max(max_tag_len, min_width)
    max_val_len = max(max_val_len, min_width + 5)
    header_color = Colors.BLUE
    border_color = Colors.BLUE
    important_color = Colors.YELLOW
    # Use the extracted helper function
    pad = _pad_text
    print(Colors.colored(f"╔{'═' * (max_tag_len + 2)}{'═' * (max_val_len + 2)}╗", border_color))
    print(f"{Colors.colored('║', border_color)} {pad(Colors.colored('Tag', header_color), max_tag_len)} {Colors.colored('│', border_color)} {pad(Colors.colored('Value', header_color), max_val_len)} {Colors.colored('║', border_color)}")
    print(Colors.colored(f"╠{'═' * (max_tag_len + 2)}{'═' * (max_val_len + 2)}╣", border_color))
    for tag_name, value_display, is_important_tag in tags_to_print:
        tag_display = Colors.colored(tag_name, Colors.BOLD + important_color) if is_important_tag else tag_name
        print(f"{Colors.colored('║', border_color)} {pad(tag_display, max_tag_len)} {Colors.colored('│', border_color)} {pad(value_display, max_val_len)} {Colors.colored('║', border_color)}")
    print(Colors.colored(f"╚{'═' * (max_tag_len + 2)}{'═' * (max_val_len + 2)}╝", border_color))


# --- Model Handling ---
def get_cached_model_ids() -> List[str]:
    """Get list of model repo IDs from the huggingface cache."""
    if scan_cache_dir is None:
        logger.error(Colors.colored("huggingface_hub library not found. Cannot scan Hugging Face cache.", Colors.RED))
        return []
    try:
        # Use CacheInfo (public class)
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Scanning Hugging Face cache directory...")
        cache_info: HFCacheInfo = scan_cache_dir()
        model_ids = sorted([repo.repo_id for repo in cache_info.repos])
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Found {len(model_ids)} potential models in Hugging Face cache: {model_ids}")
        return model_ids
    except HFValidationError:
        logger.error(Colors.colored("Hugging Face cache directory invalid.", Colors.RED))
        return []
    except FileNotFoundError:
        logger.error(Colors.colored("Hugging Face cache directory not found.", Colors.RED))
        return []
    except Exception as e:
        logger.error(Colors.colored(f"Unexpected error scanning Hugging Face cache: {type(e).__name__}: {e}", Colors.RED), exc_info=logger.level <= logging.DEBUG)
        return []


def print_model_stats(results: List[ModelResult]) -> None:
    """Print a table summarizing model performance statistics to the console, including failures."""
    if not results:
        logger.info("No model results to display.")
        return

    # Sort results: failures last, then by time
    results.sort(key=lambda x: (not x.success, x.stats.time if x.success else float('inf')))

    # --- Constants and Configuration ---
    BASE_NAME_MAX_WIDTH = 45
    COL_WIDTH = 12
    MIN_NAME_COL_WIDTH = len("Model")

    # --- Colors ---
    COLORS = types.SimpleNamespace(
        HEADER=Colors.BLUE,
        BORDER=Colors.BLUE,
        SUMMARY=Colors.YELLOW,
        FAIL=Colors.RED,
        FAIL_TEXT=Colors.GRAY,
        SUCCESS=Colors.GREEN
    )

    # Calculate display lengths and format model names
    def format_model_name(result: ModelResult) -> Tuple[str, int]:
        base_name = result.model_name.split('/')[-1]
        display_name = base_name[:BASE_NAME_MAX_WIDTH] + ("..." if len(base_name) > BASE_NAME_MAX_WIDTH else "")
        
        if not result.success:
            fail_suffix = f" [FAIL: {result.error_stage or '?'}]"
            display_name = Colors.colored(display_name + fail_suffix, COLORS.FAIL)
        else:
            display_name = Colors.colored(display_name, COLORS.SUCCESS)
            
        return display_name, Colors.visual_len(display_name)

    # Calculate maximum display length
    name_displays = [format_model_name(r) for r in results]
    max_display_len = max((length for _, length in name_displays), default=MIN_NAME_COL_WIDTH)
    name_col_width = max(max_display_len, MIN_NAME_COL_WIDTH)

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Column widths - Name: {name_col_width}, Data: {COL_WIDTH}")

    # Helper for the top/bottom border lines, with a separator at each column break
    def h_line(char: str) -> str:
        seg = '═' * (COL_WIDTH + 2)
        if char == '═':  # top border
            line = f"╔{'═' * (name_col_width + 2)}╤{seg}╤{seg}╤{seg}╤{seg}╗"
        else:  # bottom border
            line = f"╚{'═' * (name_col_width + 2)}╧{seg}╧{seg}╧{seg}╧{seg}╝"
        return Colors.colored(line, COLORS.BORDER)

    # Print table header
    print("\n" + h_line('═'))
    headers = ["Model", "Active Δ", "Cache Δ", "Peak Mem", "Time"]
    header_row = Colors.colored(
        f"║ {_pad_text(Colors.colored(headers[0], COLORS.HEADER), name_col_width)} │ "
        + " │ ".join(_pad_text(Colors.colored(h, COLORS.HEADER), COL_WIDTH, False) for h in headers[1:])
        + " ║", COLORS.BORDER
    )
    print(header_row)
    print(Colors.colored(f"╠{'═' * (name_col_width + 2)}╪"
                        + "╪".join('═' * (COL_WIDTH + 2) for _ in range(4)) + "╣",
                        COLORS.BORDER))

    # Print data rows
    successful_results = []
    for result, (display_name, _) in zip(results, name_displays):
        if result.success:
            successful_results.append(result)
            stats = [
                f"{result.stats.active:,.0f} MB",
                f"{result.stats.cached:,.0f} MB",
                f"{result.stats.peak:,.0f} MB",
                f"{result.stats.time:.2f} s"
            ]
        else:
            stats = [Colors.colored("-", COLORS.FAIL_TEXT)] * 4

        row = Colors.colored(
            f"║ {_pad_text(display_name, name_col_width)} │ "
            + " │ ".join(_pad_text(stat, COL_WIDTH, False) for stat in stats)
            + " ║", COLORS.BORDER
        )
        print(row)

    # Print summary if there are successful results
    if successful_results:
        print(Colors.colored(f"╠{'═' * (name_col_width + 2)}╪"
                           + "╪".join('═' * (COL_WIDTH + 2) for _ in range(4)) + "╣",
                           COLORS.BORDER))
        
        avg_stats = [
            sum(r.stats.active for r in successful_results) / len(successful_results),
            sum(r.stats.cached for r in successful_results) / len(successful_results),
            max(r.stats.peak for r in successful_results),
            sum(r.stats.time for r in successful_results) / len(successful_results)
        ]
        
        summary_stats = [
            f"{avg_stats[0]:,.0f} MB",
            f"{avg_stats[1]:,.0f} MB",
            f"{avg_stats[2]:,.0f} MB",
            f"{avg_stats[3]:.2f} s"
        ]
        
        summary_title = Colors.colored(f"AVG/PEAK ({len(successful_results)} Success)", COLORS.SUMMARY)
        summary_row = Colors.colored(
            f"║ {_pad_text(summary_title, name_col_width)} │ "
            + " │ ".join(_pad_text(Colors.colored(stat, COLORS.SUMMARY), COL_WIDTH, False) for stat in summary_stats)
            + " ║", COLORS.BORDER
        )
        print(summary_row)

    print(h_line('╝'))
    
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Displayed stats for {len(results)} models ({len(successful_results)} successful)")


# --- HTML Report Generation ---
def generate_html_report(results: List[ModelResult], filename: Path, versions: Dict[str, str]) -> None:
    """Generates an HTML file with model stats, output/errors, failures, and versions."""
    if not results:
        logger.warning("No results to generate HTML report.")
        return

    results.sort(key=lambda x: (not x.success, x.stats.time if x.success else 0))

    html_start = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Model Performance Results</title>
    <style>
        body { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; margin: 20px; background-color: #f8f9fa; color: #212529; line-height: 1.6; }
        h1 { text-align: center; color: #343a40; border-bottom: 2px solid #dee2e6; padding-bottom: 10px; margin-bottom: 30px; }
        table { border-collapse: collapse; width: 95%; margin: 30px auto; box-shadow: 0 4px 8px rgba(0,0,0,0.1); background-color: #ffffff; }
        th, td { border: 1px solid #dee2e6; padding: 12px 15px; text-align: left; vertical-align: top; }
        th { background-color: #e9ecef; font-weight: 600; color: #495057; position: sticky; top: 0; z-index: 1; }
        tr:nth-child(even):not(.failed-row) { background-color: #f8f9fa; }
        tr:not(.failed-row):hover { background-color: #e2e6ea; }
        td.numeric, th.numeric { text-align: right; font-variant-numeric: tabular-nums; }
        .summary td { font-weight: bold; background-color: #d1ecf1; color: #0c5460; border-color: #bee5eb; }
        caption { caption-side: bottom; padding: 15px; font-style: italic; color: #6c757d; text-align: center; }
        .model-name { font-family: 'Courier New', Courier, monospace; font-weight: 500; }
        .model-output, .captured-output pre { white-space: pre-wrap; word-wrap: break-word; max-width: 600px; font-size: 0.9em; }
        .captured-output { background-color: #fff3cd; border-left: 3px solid #ffeeba; padding: 8px; margin-top: 10px; color: #856404; }
        .captured-output strong { color: #856404; }
        .captured-output pre { margin: 0; padding: 5px; background-color: transparent; border: none; color: #856404; }
        tr.failed-row { background-color: #f8d7da !important; color: #721c24; }
        tr.failed-row:hover { background-color: #f5c6cb !important; }
        tr.failed-row .model-output { font-style: italic; color: #721c24; }
        tr.failed-row td.numeric { color: #721c24; font-style: italic; }
        .error-message { font-weight: bold; display: block; color: #721c24; }
        footer { margin-top: 30px; text-align: center; font-size: 0.85em; color: #6c757d; }
        footer h2 { font-size: 1.1em; color: #495057; margin-bottom: 10px;}
        footer ul { list-style: none; padding: 0; margin: 0 0 10px 0; }
        footer li { display: inline-block; margin: 0 10px; }
        footer code { background-color: #e9ecef; padding: 2px 4px; border-radius: 3px; }
        footer p { margin-top: 5px; }
    </style>
</head>
<body>
    <h1>Model Performance Summary</h1>
    <table>
        <caption>Performance metrics and output/errors for Vision Language Model processing. Generated on """ + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + """. Failures shown but excluded from averages.</caption>
        <thead>
            <tr>
                <th>Model</th>
                <th class="numeric">Active Δ (MB)</th> <th class="numeric">Cache Δ (MB)</th>
                <th class="numeric">Peak Mem (MB)</th> <th class="numeric">Time (s)</th>
                <th>Result / Error / Captured Output</th>
            </tr>
        </thead>
        <tbody>
"""
    html_rows: str = ""
    successful_results: List[ModelResult] = [r for r in results if r.success]

    for result in results:
        model_disp_name = html.escape(result.model_name)
        row_class = ""
        result_content = ""
        stats_cells = ""

        if result.success:
            escaped_output = html.escape(result.output or "")
            # Highlight model output in HTML
            result_content = f'<div class="model-output"><strong>{escaped_output}</strong></div>'
            stats_cells = f"""
                <td class="numeric">{result.stats.active:,.0f}</td>
                <td class="numeric">{result.stats.cached:,.0f}</td>
                <td class="numeric">{result.stats.peak:,.0f}</td>
                <td class="numeric">{result.stats.time:.2f}</td>
            """
        else:
            row_class = ' class="failed-row"'
            error_message = html.escape(result.error_message or "Unknown error")
            result_content = f'<span class="error-message">{error_message}</span>'
            if result.captured_output_on_fail:
                captured_output = html.escape(result.captured_output_on_fail)
                result_content += f'<div class="captured-output"><strong>Captured Output:</strong><pre>{captured_output}</pre></div>'
            stats_cells = """
                <td class="numeric">-</td>
                <td class="numeric">-</td>
                <td class="numeric">-</td>
                <td class="numeric">-</td>
            """

        html_rows += f"""
            <!-- Data row for model: {model_disp_name} -->
            <tr{row_class}>
                <td class="model-name">{model_disp_name}</td>
                {stats_cells}
                <td>{result_content}</td>
            </tr>
"""

    html_summary_row: str = ""
    if successful_results:
        avg_active = sum(r.stats.active for r in successful_results) / len(successful_results)
        avg_cache = sum(r.stats.cached for r in successful_results) / len(successful_results)
        max_peak = max(r.stats.peak for r in successful_results)
        avg_time = sum(r.stats.time for r in successful_results) / len(successful_results)
        summary_title = f"AVG/PEAK ({len(successful_results)} Success)"
        # Format summary stats as comma-separated integers
        html_summary_row = f"""
            <!-- Summary Row (Based on Successful Runs) -->
            <tr class="summary">
                <td>{summary_title}</td>
                <td class="numeric">{avg_active:,.0f}</td>
                <td class="numeric">{avg_cache:,.0f}</td>
                <td class="numeric">{max_peak:,.0f}</td>
                <td class="numeric">{avg_time:.2f}</td>
                <td></td>
            </tr>
"""

    # --- Add Version Info Footer ---
    html_footer: str = "<footer>\n<h2>Library Versions</h2>\n<ul>\n"
    # Use sorted items for consistent order in HTML
    for name, ver in sorted(versions.items()):
        status_style = 'color: green;' if ver != "N/A" else 'color: orange;'
        html_footer += f'<li>{html.escape(name)}: <code style="{status_style}">{html.escape(ver)}</code></li>\n'
    html_footer += "</ul>\n"
    # Add date to footer
    html_footer += f"<p>Report generated on: {datetime.now().strftime('%Y-%m-%d')}</p>\n</footer>"
    # -----------------------------

    html_end = f"""
        </tbody>
    </table>
    <!-- End of Table -->
    {html_footer}
</body>
</html>
"""
    html_content: str = html_start + html_rows + html_summary_row + html_end

    try:
        f: TextIO
        with open(filename, "w", encoding="utf-8") as f:
            f.write(html_content)
        logger.info(f"HTML report saved to: {Colors.colored(str(filename.resolve()), Colors.GREEN)}")
    except IOError as e:
        logger.error(Colors.colored(f"Failed to write HTML report to {filename}: {e}", Colors.RED))
    except Exception as e:
         logger.error(Colors.colored(f"An unexpected error occurred while writing HTML report: {type(e).__name__}: {e}", Colors.RED), exc_info=logger.level <= logging.DEBUG)


def get_system_info() -> Tuple[str, str]:
    """Get system architecture and GPU information."""
    arch: str = platform.machine()
    gpu_info: str = "Unknown"
    try:
        # Try to get GPU info on macOS
        if platform.system() == "Darwin":
            result: subprocess.CompletedProcess[str] = subprocess.run(['system_profiler', 'SPDisplaysDataType'],
                                 capture_output=True, text=True, timeout=2)
            if result.returncode == 0:
                # Extract GPU info from system_profiler output
                gpu_lines: List[str] = [line for line in result.stdout.split('\n')
                           if "Chipset Model:" in line]
                if gpu_lines:
                    gpu_info = gpu_lines[0].split("Chipset Model:")[-1].strip()
    except (subprocess.SubprocessError, OSError):  # TimeoutExpired is a SubprocessError; OSError covers a missing binary
        pass
    return arch, gpu_info

# --- Model Processing Core ---
def validate_inputs(image_path: PathLike, model_path: str, temperature: float = 0.0) -> None:
    """Validate input paths and parameters."""
    img_path: Path = Path(image_path)
    if not img_path.exists():
        raise FileNotFoundError(f"Image not found: {img_path}")
    if not img_path.is_file():
        raise ValueError(f"Not a file: {img_path}")
    if img_path.suffix.lower() not in {'.jpg', '.jpeg', '.png', '.webp'}:
        raise ValueError(f"Unsupported image format: {img_path.suffix}")

    validate_temperature(temperature)

def validate_temperature(temp: float) -> None:
    """Validate temperature parameter is within acceptable range."""
    if not isinstance(temp, (int, float)):
        raise ValueError(f"Temperature must be a number, got {type(temp)}")
    if not 0.0 <= temp <= 1.0:
        raise ValueError(f"Temperature must be between 0 and 1, got {temp}")

def validate_image_accessible(image_path: Path) -> None:
    """Validate image file is accessible and supported."""
    try:
        with timeout_manager(5):  # 5 second timeout for read test
            with Image.open(image_path) as img:
                img.verify()
    except TimeoutError as e:
        raise IOError(f"Timeout while reading image: {image_path}") from e
    except UnidentifiedImageError as e:
        raise ValueError(f"File is not a recognized image format: {image_path}") from e
    except Exception as e:
        raise IOError(f"Error accessing image {image_path}: {e}") from e

def _run_model_generation(
    model_identifier: str,
    image_path: Path,
    prompt: str,
    max_tokens: int,
    verbose: bool,
    trust_remote_code: bool,
    temperature: float
) -> Tuple[str, Any, Any]: # Returns (output, model, tokenizer) - model/tokenizer needed for cleanup
    """Loads model, formats prompt, and runs generation. Raises exceptions on failure."""
    model = tokenizer = None # Ensure they are defined in this scope
    try:
        # Load model and tokenizer
        model, tokenizer = load(model_identifier, trust_remote_code=trust_remote_code)
        config: Dict[str, Any] = load_config(model_identifier, trust_remote_code=trust_remote_code)

        # Prepare prompt
        formatted_prompt: str = apply_chat_template(tokenizer, config, prompt, num_images=1)

        # Generate output
        output: Optional[str] = generate(
            model=model,
            processor=tokenizer,  # Type checking handled by function signature
            prompt=formatted_prompt,
            image=image_path.as_posix(),
            max_tokens=max_tokens,
            verbose=verbose,
            temp=temperature
        )

        # Ensure all computations involving the model are done before measuring memory/time
        mx.eval(model.parameters()) # Evaluate model parameters if needed after generation

        return output if output is not None else "[No model output]", model, tokenizer

    except Exception:
        # If loading failed, model might be None. If generation failed, model is likely loaded.
        # Re-raise the exception to be caught by the outer function, which handles ModelResult creation.
        # We pass model/tokenizer back mainly for the finally block in the caller.
        raise # Re-raise the original exception

def process_image_with_model(
    model_identifier: str,
    image_path: Path,
    prompt: str,
    max_tokens: int,
    verbose: bool = False,
    trust_remote_code: bool = False,
    temperature: float = DEFAULT_TEMPERATURE,
    timeout: float = DEFAULT_TIMEOUT,
) -> ModelResult:
    """Process an image with a Vision Language Model, managing stats and errors."""
    logger.info(f"Processing '{image_path.name}' with model: {Colors.colored(model_identifier, Colors.MAGENTA)}")

    model = tokenizer = None # Initialize here for the finally block
    arch, gpu_info = get_system_info()
    start_time: float = 0.0
    initial_mem: float = 0.0
    initial_cache: float = 0.0
    output: Optional[str] = None
    error_stage: str = "initialization"

    try:
        validate_temperature(temperature)
        validate_image_accessible(image_path)
        error_stage = "validation"

        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"System: {arch}, GPU: {gpu_info}")

        # --- Capture initial state BEFORE model operations ---
        initial_mem = mx.get_active_memory() / MB_CONVERSION
        initial_cache = mx.get_cache_memory() / MB_CONVERSION
        start_time = time.perf_counter()
        # -----------------------------------------------------

        with timeout_manager(timeout):
            error_stage = "load/generate" # Stage if timeout occurs here
            # Call the internal generation function
            output, model, tokenizer = _run_model_generation(
                model_identifier=model_identifier,
                image_path=image_path,
                prompt=prompt,
                max_tokens=max_tokens,
                verbose=verbose,
                trust_remote_code=trust_remote_code,
                temperature=temperature
            )
            error_stage = "post-generate" # Stage after successful generation

        # --- Capture final state AFTER model operations ---
        end_time = time.perf_counter()
        final_active_mem = mx.get_active_memory() / MB_CONVERSION
        final_cache_mem = mx.get_cache_memory() / MB_CONVERSION
        peak_mem = mx.get_peak_memory() / MB_CONVERSION
        # --------------------------------------------------

        final_stats = MemoryStats(
            active=final_active_mem - initial_mem,
            cached=final_cache_mem - initial_cache,
            peak=peak_mem,
            time=end_time - start_time
        )

        return ModelResult(
            model_name=model_identifier,
            success=True,
            output=output, # Use the output from _run_model_generation
            stats=final_stats
        )

    except TimeoutError:
        # Timeout specifically handled
        logger.error(f"Timeout ({timeout}s) during '{error_stage}' for model {model_identifier}")
        return ModelResult(
            model_name=model_identifier,
            success=False,
            error_stage="timeout",
            error_message=f"Operation timed out after {timeout} seconds during {error_stage}"
        )
    except Exception as e:
        # Determine stage more accurately based on where the exception occurred
        # If model is None here, it likely failed during load within _run_model_generation
        if model is None and error_stage == "load/generate":
            error_stage = "model_load"

        logger.error(f"Failed during '{error_stage}' for model {model_identifier}: {type(e).__name__}: {e}")
        if logger.isEnabledFor(logging.DEBUG):
            traceback.print_exc()

        return ModelResult(
            model_name=model_identifier,
            success=False,
            error_stage=error_stage,
            error_message=str(e)
            # captured_output_on_fail might be added here if needed, but requires more complex handling
        )
    finally:
        # Ensure cleanup happens regardless of success/failure
        if model is not None:
            del model
        if tokenizer is not None:
            del tokenizer
        # Clear cache and reset peak memory after each model run
        mx.clear_cache()
        mx.reset_peak_memory()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Cleaned up resources for model {model_identifier}")

# --- Main Execution Helper Functions ---

def print_cli_header(title: str) -> None:
    print(Colors.colored(f"\n{'=' * 80}", Colors.BLUE))
    print(Colors.colored(f"{title.center(80)}", Colors.CYAN + Colors.BOLD))
    print(Colors.colored(f"{'=' * 80}\n", Colors.BLUE))

def print_cli_section(title: str) -> None:
    print(Colors.colored(f"\n--- {title} ---", Colors.MAGENTA))

def print_cli_error(msg: str) -> None:
    print(Colors.colored(f"Error: {msg}", Colors.RED), file=sys.stderr)

def setup_environment(args: argparse.Namespace) -> Dict[str, str]:
    """Configure logging, collect versions, print warnings."""
    log_level: int = logging.DEBUG if args.debug else (logging.INFO if args.verbose else logging.WARNING)
    logging.basicConfig(level=log_level,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[logging.StreamHandler(sys.stderr)], force=True)
    if args.debug:
        logger.debug("Debug mode enabled.")
    elif args.verbose:
        logger.info("Verbose mode enabled.")

    library_versions: Dict[str, str] = get_library_versions()
    if args.debug: # Only print versions in debug mode initially
        print_version_info(library_versions)

    if args.trust_remote_code:
        logger.warning(Colors.colored("--- SECURITY WARNING ---", Colors.YELLOW + Colors.BOLD))
        logger.warning(Colors.colored("`--trust-remote-code` is enabled.", Colors.YELLOW))
        logger.warning(Colors.colored("-----------------------", Colors.YELLOW + Colors.BOLD))

    return library_versions

def find_and_validate_image(args: argparse.Namespace) -> Path:
    folder_path: Path = args.folder.resolve()
    print_cli_section(f"Scanning folder: {folder_path}")
    if args.folder == DEFAULT_FOLDER and not DEFAULT_FOLDER.is_dir():
        print_cli_error(f"Default folder '{DEFAULT_FOLDER}' does not exist. Exiting.")
        sys.exit(1)
    image_path: Optional[Path] = find_most_recent_file(folder_path)
    if not image_path:
        print_cli_error(f"Could not find a suitable image file in {folder_path}. Exiting.")
        sys.exit(1)
    resolved_image_path: Path = image_path.resolve()
    print_cli_section(f"Processing file: {resolved_image_path.name}")
    print(f"Located at: {Colors.colored(resolved_image_path, Colors.BLUE)}")
    try:
        with Image.open(resolved_image_path) as img:
            img.verify()
        print_image_dimensions(resolved_image_path)
        return resolved_image_path
    except Exception as img_err:
        print_cli_error(f"Cannot open or verify image {resolved_image_path}: {img_err}. Exiting.")
        sys.exit(1)

def handle_metadata(image_path: Path, args: argparse.Namespace) -> MetadataDict:
    """Extract, print, and return image metadata."""
    metadata: MetadataDict = extract_image_metadata(image_path, debug=args.debug)
    print(f"  Date: {Colors.colored(metadata.get('date', 'N/A'), Colors.CYAN)}")
    print(f"  Desc: {Colors.colored(metadata.get('description', 'N/A'), Colors.CYAN)}")
    print(f"  GPS:  {Colors.colored(metadata.get('gps', 'N/A'), Colors.CYAN)}")

    if args.verbose or args.debug:
         exif_data: Optional[ExifDict] = get_exif_data(image_path)
         if exif_data:
             pretty_print_exif(exif_data, verbose=True)
         else:
             print("\nNo detailed EXIF data could be extracted.")
    return metadata

def prepare_prompt(args: argparse.Namespace, metadata: MetadataDict) -> str:
    """Prepare the prompt for the VLM, using user input or generating from metadata."""
    prompt: str
    if args.prompt:
        prompt = args.prompt
        logger.info("Using user-provided prompt.")
    else:
        logger.info("Generating default prompt based on image metadata.")
        prompt_parts: List[str] = [
            "Provide a factual caption, description, and keywords suitable for cataloguing, or searching for, the image.",
            (f"Context: Relates to '{metadata.get('description', '')}'"
             if metadata.get('description') and metadata['description'] != "N/A" else ""),
            (f"taken around {metadata.get('date', '')}"
             if metadata.get('date') and metadata['date'] != "Unknown date" else ""),
            (f"near GPS {metadata.get('gps', '')}."
             if metadata.get('gps') and metadata['gps'] != "Unknown location" else ""),
            "Focus on visual content. Avoid repeating the context unless it is visible."
        ]
        prompt = " ".join(filter(None, prompt_parts)).strip()
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("Using generated prompt based on metadata.")
    print(f"\n{Colors.colored('--- Using Prompt ---', Colors.CYAN)}\n{prompt}\n{Colors.colored('-'*40, Colors.BLUE)}")
    return prompt

def process_models(
    args: argparse.Namespace,
    image_path: Path,
    prompt: str
) -> List[ModelResult]:
    model_identifiers: List[str]
    if args.models:
        model_identifiers = args.models
        print_cli_section(f"Processing specified models: {', '.join(model_identifiers)}")
    else:
        print_cli_section("Scanning cache for models to process...")
        model_identifiers = get_cached_model_ids()
    results: List[ModelResult] = []
    if not model_identifiers:
        print_cli_error("No models specified or found in cache.")
        if not args.models:
            print("Ensure models are downloaded and cache is accessible.", file=sys.stderr)
    else:
        print_cli_section(f"Processing {len(model_identifiers)} model(s)...")
        separator: str = Colors.colored(f"\n{'-' * 40}\n", Colors.BLUE)
        for model_id in model_identifiers:
            print(separator)
            is_vlm_verbose: bool = args.verbose or args.debug
            result: ModelResult = process_image_with_model(
                model_identifier=model_id,
                image_path=image_path,
                prompt=prompt, max_tokens=args.max_tokens,
                verbose=is_vlm_verbose,
                trust_remote_code=args.trust_remote_code,
                temperature=args.temperature,
                timeout=args.timeout
            )
            results.append(result)
            model_short_name: str = model_id.split('/')[-1]
            if result.success:
                print(Colors.colored(f"[SUCCESS] {model_short_name}", Colors.GREEN))
                if result.output:
                    # Highlight model output
                    print(f"\n{Colors.colored('Output:', Colors.CYAN)}\n{Colors.colored(result.output, Colors.CYAN + Colors.BOLD)}")
                if args.verbose or args.debug:
                    print(f"Time taken: {result.stats.time:.2f}s")
            else:
                print(Colors.colored(f"[FAIL] {model_short_name} (Stage: {result.error_stage})", Colors.RED))
                print(f"  {Colors.colored('ERROR', Colors.RED)}: Model {model_short_name} failed during '{result.error_stage}'.")
                if args.verbose or args.debug:
                    print(f"  Reason: {result.error_message}")
    return results

def finalize_execution(
    args: argparse.Namespace,
    results: List[ModelResult],
    library_versions: Dict[str, str],
    overall_start_time: float
) -> None:
    """Print summary stats, generate report, print versions, and total time."""
    # --- 5. Print Summary Statistics ---
    if results:
        print(Colors.colored(f"\n{'=' * 80}\n", Colors.BLUE)) # Separator
        print_model_stats(results) # Function call
    else:
        print(Colors.colored("\nNo models processed. No performance summary generated.", Colors.YELLOW))

    # --- 6. Generate HTML Report ---
    html_output_path: Path = args.output_html.resolve()
    if results:
        # Pass collected versions to the report generator
        generate_html_report(results, html_output_path, library_versions)
    else:
        # Use logger.info for consistency
        logger.info(f"Skipping HTML report generation to {html_output_path} as no models were processed.")

    # --- 7. Print Version Info to Console ---
    # Print versions after all processing and reporting is done
    print_version_info(library_versions)

    # --- Calculate and Print Total Time ---
    overall_time: float = time.perf_counter() - overall_start_time
    print(f"\nTotal execution time: {Colors.colored(f'{overall_time:.2f} seconds', Colors.GREEN)}.")

# --- Main Execution ---
def main(args: argparse.Namespace) -> None:
    """Main function to orchestrate image analysis."""
    overall_start_time: float = time.perf_counter()
    print_cli_header("MLX Vision Language Model Image Analysis")
    library_versions: Dict[str, str] = setup_environment(args)
    resolved_image_path: Path = find_and_validate_image(args)
    metadata: MetadataDict = handle_metadata(resolved_image_path, args)
    prompt: str = prepare_prompt(args, metadata)
    results: List[ModelResult] = process_models(args, resolved_image_path, prompt)
    finalize_execution(args, results, library_versions, overall_start_time)

if __name__ == "__main__":
    
    # Setup Argument Parser
    parser: argparse.ArgumentParser = argparse.ArgumentParser(
        description="Analyze image with MLX VLMs.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # Add arguments (separated for clarity)
    parser.add_argument("-f", "--folder", type=Path, default=DEFAULT_FOLDER, help="Folder to scan.")
    parser.add_argument("--output-html", type=Path, default=DEFAULT_HTML_OUTPUT, help="Output HTML report file.")
    parser.add_argument("--models", nargs='+', type=str, default=None, help="Specify models by ID/path. Overrides cache scan.")
    parser.add_argument("--trust-remote-code", action=argparse.BooleanOptionalAction, default=True, help="Allow custom code from Hub models (SECURITY RISK).")
    parser.add_argument("-p", "--prompt", type=str, default=None, help="Custom prompt.")
    parser.add_argument("-m", "--max-tokens", type=int, default=DEFAULT_MAX_TOKENS, help="Max new tokens to generate.")
    parser.add_argument("-t", "--temperature", type=float, default=DEFAULT_TEMPERATURE, help=f"Sampling temperature (default: {DEFAULT_TEMPERATURE}).")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output (INFO logging).")
    parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging (DEBUG logging).")
    parser.add_argument("--timeout", type=float, default=DEFAULT_TIMEOUT, help=f"Timeout in seconds for model operations (default: {DEFAULT_TIMEOUT}).")

    # Parse arguments
    parsed_args: argparse.Namespace = parser.parse_args()

    # --- Main Execution ---
    try:
        main(parsed_args)
    except Exception as main_err:
        # Log final unhandled exceptions with color
        # Use logger.critical for severe errors causing exit
        logger.critical(Colors.colored(f"An unexpected error occurred during main execution: {main_err}", Colors.RED), exc_info=True)
        sys.exit(1) # Exit with error status
@Blaizzy (Owner) commented May 2, 2025

I noticed it today as well

I will take a look

@Blaizzy Blaizzy added the bug Something isn't working label May 2, 2025
@jrp2014 (Author) commented May 3, 2025

The second part of the Tuple contains useful information, so no problem with keeping it and adjusting the various type annotations accordingly. Some examples would need alignment too, I suppose.

@jrp2014 changed the title from "Is output from generate is a tuple, rather than a string" to "Output from generate is a tuple, rather than a string" on May 4, 2025
@jrp2014 (Author) commented May 4, 2025

Further investigation shows that generate does indeed return

    return text, usage_stats

while generate's own signature declares that it returns a string:

def generate(
    model: nn.Module,
    processor: PreTrainedTokenizer,
    prompt: str,
    image: Union[str, List[str]] = None,
    verbose: bool = False,
    **kwargs,
) -> str:

This seems to have been introduced by the FastAPI PR 321.

So a simple fix would be to correct this return type.
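A minimal sketch of that fix, using a stand-in for `generate` so the annotation can be checked in isolation (the text and token counts are illustrative, copied from the log above):

```python
from typing import Any, Dict, Tuple

def generate_stub(prompt: str) -> Tuple[str, Dict[str, Any]]:
    """Stand-in for mlx_vlm.generate with the return type the code actually honours."""
    text = " A stone bridge over a river with a building in the background."
    usage_stats = {"input_tokens": 1271, "output_tokens": 14}
    return text, usage_stats

# Callers then unpack explicitly rather than treating the result as a str:
text, stats = generate_stub("Describe this image.")
print(text)
print(stats["output_tokens"])
```

With the annotation corrected, a type checker would flag any caller that still passes the raw result where a `str` is expected.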

However, there are other options. usage_stats seems to be neither fish nor fowl. There is already a useful class that you might think would be a preferable return type, rather than the tuple:

@dataclass
class GenerationResult:
    text: str
    token: Optional[int]
    logprobs: Optional[List[float]]
    prompt_tokens: int
    generation_tokens: int
    prompt_tps: float
    generation_tps: float
    peak_memory: float

After all, it is what stream_generate returns. This could take a little more thinking about. stream_generate produces a sequence of GenerationResults; generate produces the concatenation of the text fields in that sequence, together with the last value in the sequence for each of the other parameters. So is peak_memory the peak memory for the whole sequence so far, or just for the last/latest chunk of the stream? It would be useful to have a way of combining the cumulative generation result with the latest one for the parameters beyond text.
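One way that combining step could look, as a sketch: concatenate text, take the running maximum of peak_memory, and keep the latest chunk's counters (which this sketch assumes stream_generate reports cumulatively; that assumption is worth verifying against the implementation):

```python
from dataclasses import dataclass, replace
from typing import List, Optional

@dataclass
class GenerationResult:
    text: str
    token: Optional[int]
    logprobs: Optional[List[float]]
    prompt_tokens: int
    generation_tokens: int
    prompt_tps: float
    generation_tps: float
    peak_memory: float

def combine(acc: GenerationResult, latest: GenerationResult) -> GenerationResult:
    # Concatenate the text; take the max of peak_memory so the figure covers
    # the whole stream; keep the latest chunk's token counts and throughputs.
    return replace(
        latest,
        text=acc.text + latest.text,
        peak_memory=max(acc.peak_memory, latest.peak_memory),
    )
```

generate could then fold this over the stream, so its final result is well defined for every field, not just text.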

Finally, while we get peak_memory out, it could also be useful to get the active memory and cache memory stats (and perhaps anything else that mlx offers).

If you decide to return an enhanced type (and I think that that would be useful), the documentation and examples would need to follow.

@jrp2014 (Author) commented May 9, 2025

Gemini 2.5 Pro identifies the mismatch!

The `mlx_vlm.generate` function, when `verbose=True` (or by default in some configurations), can return a tuple where the first element is the generated text string and subsequent elements might be statistics or other information (like tokens per second, prompt evaluation time, etc.).

In your `_run_model_generation` function, the variable `output` was assigned this entire tuple. Then, this tuple was being returned as the first element of the 3-tuple `(output, model, tokenizer)`. Consequently, when `process_image_with_model` constructed `ModelResult`, the `output` field of `ModelResult` was assigned this tuple, not just the string you expected.

I will adjust `_run_model_generation` to specifically extract the string part from the `generate` function's result, ensuring that `ModelResult.output` correctly stores only the textual output.

Additionally, during the review, I noticed that the `get_exif_data` function primarily extracts tags from the Exif SubIFD and GPS SubIFD. It might miss some common tags (like "Make", "Model", "ImageDescription") that are typically located in the main Image File Directory (IFD0). This is a separate point for potential improvement if you need a more comprehensive set of EXIF tags. For now, I'll address the immediate question about the model output format.

I'll modify `_run_model_generation` to correctly process the output from the `generate` function. 
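A defensive version of that extraction might look like the sketch below; the helper name is illustrative. It tolerates both the old plain-string return and the new tuple, so ModelResult.output only ever receives the text:

```python
from typing import Any

def extract_text(output: Any) -> str:
    """Normalise generate()'s result to a plain string.

    generate may return either a str or a (text, usage_stats) tuple,
    depending on the library version, so take the first element when
    a tuple comes back.
    """
    if isinstance(output, tuple):
        return output[0]
    return output
```

This keeps the calling code working across library versions until the return type is settled upstream.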

More generally, I don't know what ruff settings you use, but there are a good few bits of cruft such as unused or duplicate imports and functions, odd references to self, etc., that would be fairly low effort to action. But it would probably take someone capable of validating updates to do that; I can't see the benefit of doing it independently and having a reviewer review machine-generated fixes.

@Blaizzy (Owner) commented May 9, 2025

Got it, thanks this is really helpful!

Will fix it over the weekend

@Blaizzy (Owner) commented May 9, 2025

Yes, it was introduced here:
http://github.com/Blaizzy/mlx-vlm/pull/321/files#diff-83ac80a02189338eeaf681e559f8111ce51a227f730a8b2d690be5911cf1febcR1330-R1339

I'm thinking probably best to return a dataclass instead of a tuple.

Where users will need to index into .text to access the output.

What do you think?
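From the caller's side, that proposal could look like the sketch below. The class name and the fields beyond .text are illustrative, not the library's actual API; the existing GenerationResult could equally be reused:

```python
from dataclasses import dataclass

@dataclass
class GenerateOutput:  # hypothetical name; GenerationResult could serve instead
    text: str
    input_tokens: int
    output_tokens: int
    peak_memory: float

result = GenerateOutput(
    text=" A stone bridge over a river with a building in the background.",
    input_tokens=1271,
    output_tokens=14,
    peak_memory=5.44,
)
print(result.text)  # users index into .text for the generated string
```

An attribute access like result.text is self-documenting, and fails loudly (AttributeError) if a caller guesses a field name wrong, unlike positional tuple indexing.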

@jrp2014 (Author) commented May 10, 2025

Yes, that would be good. There is already a GenerationResult class that could be suitable, particularly if it were augmented with the other memory-use metrics that mlx offers.
