diff --git a/nemo_retriever/pyproject.toml b/nemo_retriever/pyproject.toml index 5a94917ab..0d561e2bf 100644 --- a/nemo_retriever/pyproject.toml +++ b/nemo_retriever/pyproject.toml @@ -76,6 +76,8 @@ dependencies = [ "soundfile>=0.12.0", "scipy>=1.11.0", "nvidia-ml-py", + "fsspec>=2023.1.0", + "universal-pathlib>=0.2.0", "vllm==0.16.0", ] diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index e65be2eda..7a2a54af5 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -24,6 +24,7 @@ from nemo_retriever.params import CaptionParams from nemo_retriever.params import DedupParams from nemo_retriever.params import EmbedParams +from nemo_retriever.params import StoreParams from nemo_retriever.params import ExtractParams from nemo_retriever.params import TextChunkParams from nemo_retriever.params.models import BatchTuningParams @@ -329,6 +330,11 @@ def main( table_structure_invoke_url: Optional[str] = typer.Option( None, "--table-structure-invoke-url", help="Remote endpoint URL for table-structure model inference." ), + store_images_uri: Optional[str] = typer.Option( + None, + "--store-images-uri", + help="When set, store extracted images to this URI (local path or fsspec URI like s3://bucket/prefix).", + ), extract_text: bool = typer.Option( True, "--extract-text/--no-extract-text", help="Enable text extraction from documents." 
), @@ -549,6 +555,9 @@ def main( if enable_text_chunk: ingestor = ingestor.split(text_chunk_params) + if store_images_uri: + ingestor = ingestor.store(StoreParams(storage_uri=store_images_uri)) + enable_caption = caption or caption_invoke_url is not None enable_dedup = dedup if dedup is not None else enable_caption if enable_dedup: @@ -614,6 +623,8 @@ def main( handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite") lancedb_write_time = time.perf_counter() - lancedb_write_start + del ingest_local_results # free driver heap before recall + from nemo_retriever.model import resolve_embed_model from nemo_retriever.recall.beir import BeirConfig from nemo_retriever.recall.core import RecallConfig, retrieve_and_score diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index 61d1a1937..a49ab98d9 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -54,6 +54,7 @@ from ..params import IngestExecuteParams from ..params import PdfSplitParams from ..params import TextChunkParams +from ..params import StoreParams from ..params import CaptionParams from ..params import VdbUploadParams @@ -1037,6 +1038,25 @@ def embed( return self + def store(self, params: StoreParams | None = None, **kwargs: Any) -> "BatchIngestor": + """Store extracted images to disk or cloud storage via fsspec.""" + if self._rd_dataset is None: + raise RuntimeError("No Ray Dataset to store from. Call .files(...) / .extract(...) 
first.") + + from nemo_retriever.io.image_store import store_extracted_images + + resolved = _coerce_params(params, StoreParams, kwargs) + store_kwargs = resolved.model_dump(mode="python") + self._tasks.append(("store", dict(store_kwargs))) + + _store_fn = partial(store_extracted_images, **store_kwargs) + self._rd_dataset = self._rd_dataset.map_batches( + _store_fn, + batch_format="pandas", + num_cpus=1, + ) + return self + def dedup(self, params: "DedupParams | None" = None, **kwargs: Any) -> "BatchIngestor": """Remove duplicate and overlapping images before captioning.""" if self._rd_dataset is None: diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index ec89f292a..9295a0417 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -60,7 +60,9 @@ from ..params import HtmlChunkParams from ..params import IngestExecuteParams from ..params import TextChunkParams +from ..params import StoreParams from ..params import VdbUploadParams +from ..io.image_store import load_image_b64_from_uri from ..pdf.extract import pdf_extraction from ..pdf.split import _split_pdf_to_single_page_bytes, pdf_path_to_pages_df from ..utils.remote_auth import resolve_remote_api_key @@ -109,6 +111,41 @@ def _deep_copy_row(row_dict: Dict[str, Any]) -> Dict[str, Any]: return out +def _b64_from_dict(d: dict) -> Optional[str]: + """Resolve base64 from a dict: inline ``image_b64`` then ``stored_image_uri``.""" + b64 = d.get("image_b64") + if isinstance(b64, str) and b64: + return b64 + uri = d.get("stored_image_uri") + if isinstance(uri, str) and uri: + return load_image_b64_from_uri(uri) + return None + + +def _resolve_page_image_b64(page_image: Any) -> Optional[str]: + """Return base64 for a page image, loading from stored URI if needed.""" + if not isinstance(page_image, dict): + return None + return _b64_from_dict(page_image) + 
+ +def _resolve_item_image_b64(item: dict, page_image_b64: Optional[str]) -> Optional[str]: + """Return base64 for a structured content item, with fallback chain. + + Priority: inline base64 > stored URI > crop from page image. + """ + resolved = _b64_from_dict(item) + if resolved: + return resolved + if page_image_b64: + bbox = item.get("bbox_xyxy_norm") + if bbox and len(bbox) == 4: + cropped_b64, _ = _crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox) + return cropped_b64 + return page_image_b64 + return None + + def explode_content_to_rows( batch_df: Any, *, @@ -167,9 +204,7 @@ def explode_content_to_rows( if not any(c in batch_df.columns for c in content_columns): batch_df = batch_df.copy() if text_mod in IMAGE_MODALITIES and "page_image" in batch_df.columns: - batch_df["_image_b64"] = batch_df["page_image"].apply( - lambda pi: pi.get("image_b64") if isinstance(pi, dict) else None - ) + batch_df["_image_b64"] = batch_df["page_image"].apply(_resolve_page_image_b64) batch_df["_embed_modality"] = text_mod return batch_df @@ -181,8 +216,8 @@ def explode_content_to_rows( # Extract page-level image b64 once per source row. page_image = row_dict.get("page_image") page_image_b64: Optional[str] = None - if any_images and isinstance(page_image, dict): - page_image_b64 = page_image.get("image_b64") + if any_images: + page_image_b64 = _resolve_page_image_b64(page_image) # Row for page text. 
page_text = row_dict.get(text_column) @@ -210,15 +245,8 @@ def explode_content_to_rows( content_row[text_column] = t.strip() content_row["_embed_modality"] = struct_mod content_row["_content_type"] = col - if struct_mod in IMAGE_MODALITIES and page_image_b64: - bbox = item.get("bbox_xyxy_norm") - if bbox and len(bbox) == 4: - cropped_b64, _ = _crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox) - content_row["_image_b64"] = cropped_b64 - else: - content_row["_image_b64"] = page_image_b64 - elif struct_mod in IMAGE_MODALITIES: - content_row["_image_b64"] = None + if struct_mod in IMAGE_MODALITIES: + content_row["_image_b64"] = _resolve_item_image_b64(item, page_image_b64) new_rows.append(content_row) exploded_any = True @@ -271,9 +299,7 @@ def collapse_content_to_page_rows( # Full page image (no cropping) for image modalities. if modality in IMAGE_MODALITIES: if "page_image" in batch_df.columns: - batch_df["_image_b64"] = batch_df["page_image"].apply( - lambda pi: pi.get("image_b64") if isinstance(pi, dict) else None - ) + batch_df["_image_b64"] = batch_df["page_image"].apply(_resolve_page_image_b64) else: batch_df["_image_b64"] = None @@ -1337,6 +1363,15 @@ def extract_audio( self._tasks.append((apply_asr_to_df, {"asr_params": self._extract_audio_asr_kwargs})) return self + def store(self, params: "StoreParams | None" = None, **kwargs: Any) -> "InProcessIngestor": + """Store extracted images to disk or cloud storage via fsspec.""" + from nemo_retriever.io.image_store import store_extracted_images + + resolved = _coerce_params(params, StoreParams, kwargs) + store_kwargs = resolved.model_dump(mode="python") + self._tasks.append((store_extracted_images, store_kwargs)) + return self + def dedup(self, params: "DedupParams | None" = None, **kwargs: Any) -> "InProcessIngestor": """Remove duplicate and overlapping images before captioning.""" @@ -1356,7 +1391,6 @@ def caption(self, params: "CaptionParams | None" = None, **kwargs: Any) -> "InPr Otherwise a 
local ``NemotronVLMCaptioner`` is loaded from HF. """ from nemo_retriever.caption.caption import caption_images - from nemo_retriever.params import CaptionParams resolved = _coerce_params(params, CaptionParams, kwargs) caption_kwargs = resolved.model_dump(mode="python") diff --git a/nemo_retriever/src/nemo_retriever/ingestor.py b/nemo_retriever/src/nemo_retriever/ingestor.py index 0d0acd66b..f17683d13 100644 --- a/nemo_retriever/src/nemo_retriever/ingestor.py +++ b/nemo_retriever/src/nemo_retriever/ingestor.py @@ -28,6 +28,7 @@ from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import RunMode +from nemo_retriever.params import StoreParams from nemo_retriever.params import VdbUploadParams @@ -141,8 +142,9 @@ def split(self, params: TextChunkParams | None = None, **kwargs: Any) -> "ingest _ = _merge_params(params, kwargs) self._not_implemented("split") - def store(self) -> "ingestor": + def store(self, params: StoreParams | None = None, **kwargs: Any) -> "ingestor": """Record a store task configuration.""" + _ = _merge_params(params, kwargs) self._not_implemented("store") def store_embed(self) -> "ingestor": diff --git a/nemo_retriever/src/nemo_retriever/io/__init__.py b/nemo_retriever/src/nemo_retriever/io/__init__.py index 8f80f45d6..8ebe1198c 100644 --- a/nemo_retriever/src/nemo_retriever/io/__init__.py +++ b/nemo_retriever/src/nemo_retriever/io/__init__.py @@ -3,13 +3,16 @@ # SPDX-License-Identifier: Apache-2.0 from .dataframe import read_dataframe, validate_primitives_dataframe, write_dataframe +from .image_store import load_image_b64_from_uri, store_extracted_images from .markdown import to_markdown, to_markdown_by_page from .stage_files import build_stage_output_path, find_stage_inputs __all__ = [ "build_stage_output_path", "find_stage_inputs", + "load_image_b64_from_uri", "read_dataframe", + "store_extracted_images", "to_markdown", "to_markdown_by_page", 
"validate_primitives_dataframe", diff --git a/nemo_retriever/src/nemo_retriever/io/image_store.py b/nemo_retriever/src/nemo_retriever/io/image_store.py new file mode 100644 index 000000000..1109b7894 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/io/image_store.py @@ -0,0 +1,349 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Store extracted images to disk or cloud storage via fsspec.""" + +from __future__ import annotations + +import base64 +import io +import logging +import os +import re +from typing import Any, Dict, Optional, Sequence + +import pandas as pd +from PIL import Image +from upath import UPath + +logger = logging.getLogger(__name__) + +_DIRECT_IMAGE_FORMATS = {"png", "jpeg"} +_FORMAT_ALIASES = {"jpg": "jpeg"} + +# Known limitation: _safe_stem derives the output subdirectory from the +# filename alone (e.g. "report.pdf" → "report/"). Two source files with +# the same basename but different parent directories will write to the same +# subdirectory and may overwrite each other. This matches the legacy +# nv-ingest store behaviour. A future PR should incorporate a short hash +# of the full source path to eliminate collisions. + + +def _safe_stem(name: str) -> str: + """Derive a filesystem-safe stem from a source path.""" + s = str(name or "").strip() or "document" + s = os.path.splitext(os.path.basename(s))[0] or "document" + s = re.sub(r"[^A-Za-z0-9._-]+", "_", s) + return s[:160] if len(s) > 160 else s + + +def _normalize_image_format(image_format: str) -> str: + fmt = str(image_format or "png").strip().lower() + fmt = _FORMAT_ALIASES.get(fmt, fmt) + if fmt not in _DIRECT_IMAGE_FORMATS: + raise ValueError(f"Unsupported image_format: {image_format!r}. 
Supported formats: png, jpeg") + return fmt + + +def _decode_image_bytes(image_b64: str) -> bytes | None: + try: + return base64.b64decode(image_b64) + except Exception as exc: + logger.warning("Failed to decode image payload: %s", exc) + return None + + +def _write_bytes(dest: UPath, raw: bytes) -> None: + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as f: + f.write(raw) + + +def _normalized_encoding(value: Any) -> str | None: + if not isinstance(value, str): + return None + enc = value.strip().lower() + enc = _FORMAT_ALIASES.get(enc, enc) + if enc in _DIRECT_IMAGE_FORMATS: + return enc + return None + + +def _sniff_image_encoding(raw: bytes) -> str | None: + if raw.startswith(b"\x89PNG\r\n\x1a\n"): + return "png" + if raw.startswith(b"\xff\xd8\xff"): + return "jpeg" + return None + + +def _resolve_direct_write_encoding( + payload: dict[str, Any] | None, + raw: bytes, + fallback: str, +) -> str: + sniffed = _sniff_image_encoding(raw) + if sniffed: + return sniffed + + declared = _normalized_encoding(payload.get("encoding")) if isinstance(payload, dict) else None + if declared: + return declared + + return fallback + + +def _crop_and_write( + dest: UPath, + page_image: Image.Image, + bbox_xyxy_norm: Sequence[float], + image_format: str = "png", +) -> bool: + """Crop a region from an already-decoded page image and write to *dest*. + + Returns ``True`` on success, ``False`` on skip/failure. 
+ """ + try: + w, h = page_image.size + if w <= 1 or h <= 1: + return False + + x1n, y1n, x2n, y2n = (float(v) for v in bbox_xyxy_norm) + x1 = int(min(max(x1n * w, 0), w)) + y1 = int(min(max(y1n * h, 0), h)) + x2 = int(min(max(x2n * w, 0), w)) + y2 = int(min(max(y2n * h, 0), h)) + if x2 <= x1 or y2 <= y1: + return False + + crop = page_image.crop((x1, y1, x2, y2)) + buf = io.BytesIO() + crop.save(buf, format=image_format.upper()) + + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as f: + f.write(buf.getvalue()) + return True + except Exception as exc: + logger.warning("Failed to crop and write %s: %s", dest, exc) + return False + + +def _decode_page_image(page_image_b64: str) -> Image.Image | None: + """Decode a base64-encoded page image into a PIL Image (once per row).""" + try: + raw = base64.b64decode(page_image_b64) + return Image.open(io.BytesIO(raw)).convert("RGB") + except Exception as exc: + logger.warning("Failed to decode page image: %s", exc) + return None + + +def _build_uri_info( + dest: UPath, + storage_root: UPath, + public_base_url: Optional[str], +) -> Dict[str, Optional[str]]: + """Build a dict with ``stored_image_uri`` and optionally ``stored_image_url``.""" + relative_key = dest.relative_to(storage_root).as_posix() + info: Dict[str, Optional[str]] = {"stored_image_uri": dest.as_uri()} + if public_base_url: + info["stored_image_url"] = f"{public_base_url.rstrip('/')}/{relative_key}" + return info + + +def load_image_b64_from_uri(uri: str) -> Optional[str]: + """Read an image from a stored URI and return its base64 encoding. + + Accepts any fsspec-compatible URI (``file://``, ``s3://``, etc.). + Returns ``None`` on failure so callers can fall back gracefully. 
+ """ + try: + raw = UPath(uri).read_bytes() + return base64.b64encode(raw).decode("ascii") + except Exception as exc: + logger.warning("Failed to load image from %s: %s", uri, exc) + return None + + +def store_extracted_images( + df: pd.DataFrame, + *, + storage_uri: str = "stored_images", + storage_options: dict[str, Any] | None = None, + public_base_url: str | None = None, + store_page_images: bool = True, + store_tables: bool = True, + store_charts: bool = True, + store_infographics: bool = True, + store_images: bool = True, + image_format: str = "png", + strip_base64: bool = True, +) -> pd.DataFrame: + """Pipeline task: store extracted images to disk or cloud storage. + + For each row in the DataFrame: + + * Writes the full page image (from ``page_image["image_b64"]``) when + *store_page_images* is ``True``. + * Crops and writes sub-page images for tables / charts / infographics + using ``bbox_xyxy_norm`` from the page image. + * Writes natural sub-page images from the ``images`` column. + * Updates the DataFrame in-place with ``stored_image_uri`` (and an + optional ``stored_image_url`` when *public_base_url* is set). + + Parameters + ---------- + df : pd.DataFrame + Primitives DataFrame produced by the extraction pipeline. + storage_uri : str + Base URI for storage. Local path (``"./output"``) or + fsspec-compatible URI (``"s3://bucket/prefix"``). + storage_options : dict | None + Extra options forwarded to fsspec / UPath (auth keys, endpoint, etc.). + public_base_url : str | None + When set, each stored item also receives a ``stored_image_url`` + built as ``{public_base_url}/{relative_key}``. + store_page_images : bool + Save full page images. + store_tables : bool + Save table crops. + store_charts : bool + Save chart crops. + store_infographics : bool + Save infographic crops. + store_images : bool + Save natural sub-page images from the ``images`` column. + image_format : str + Output image format for generated crops (default ``"png"``). 
+ Direct-write payloads preserve their source encoding and file extension. + strip_base64 : bool + When ``True`` (the default), clear ``image_b64`` after successful writes + to reduce memory pressure. The embed stage loads images from the stored + URIs when base64 is absent. Set to ``False`` only if downstream code + requires inline base64 for a reason other than embedding. + + Returns + ------- + pd.DataFrame + The (mutated) input DataFrame with storage URIs added. + """ + if not isinstance(df, pd.DataFrame) or df.empty: + return df + + logger.info("Storing extracted images to %s", storage_uri) + storage_root = UPath(storage_uri, **(storage_options or {})).resolve() + ext = _normalize_image_format(image_format) + if strip_base64: + logger.debug("strip_base64=True: image payloads will be cleared after writing.") + + col_flags: dict[str, str] = {} + if store_tables: + col_flags["table"] = "table" + if store_charts: + col_flags["chart"] = "chart" + if store_infographics: + col_flags["infographic"] = "infographic" + + for idx, row in df.iterrows(): + try: + source_path = row.get("path") or "" + stem = _safe_stem(source_path) + page_num = row.get("page_number", 1) + + page_image = row.get("page_image") + page_image_b64: str | None = None + if isinstance(page_image, dict): + page_image_b64 = page_image.get("image_b64") + + # Decode the page image lazily and reuse for all crops in this row. 
+ page_pil: Image.Image | None = None + page_pil_decode_attempted = False + + def _get_page_pil() -> Image.Image | None: + nonlocal page_pil, page_pil_decode_attempted + if page_pil_decode_attempted: + return page_pil + page_pil_decode_attempted = True + if isinstance(page_image_b64, str) and page_image_b64: + page_pil = _decode_page_image(page_image_b64) + return page_pil + + # -- Full page image -- + if store_page_images and isinstance(page_image_b64, str) and page_image_b64: + raw = _decode_image_bytes(page_image_b64) + if raw is not None: + direct_ext = _resolve_direct_write_encoding(page_image, raw, ext) + dest = storage_root / stem / f"page_{page_num}.{direct_ext}" + _write_bytes(dest, raw) + uri_info = _build_uri_info(dest, storage_root, public_base_url) + page_image.update(uri_info) + page_image["encoding"] = direct_ext + if strip_base64: + page_image["image_b64"] = None + df.at[idx, "page_image"] = page_image + + # -- Structured content (tables / charts / infographics) -- + for col_name, type_label in col_flags.items(): + content_list = row.get(col_name) + if not isinstance(content_list, list): + continue + for item_idx, item in enumerate(content_list): + if not isinstance(item, dict): + continue + item_b64 = item.get("image_b64") + if isinstance(item_b64, str) and item_b64: + raw = _decode_image_bytes(item_b64) + if raw is not None: + direct_ext = _resolve_direct_write_encoding(item, raw, ext) + dest = storage_root / stem / f"page_{page_num}_{type_label}_{item_idx}.{direct_ext}" + _write_bytes(dest, raw) + item.update(_build_uri_info(dest, storage_root, public_base_url)) + item["encoding"] = direct_ext + if strip_base64: + item["image_b64"] = None + else: + page_pil = _get_page_pil() + if page_pil is not None and not (isinstance(item_b64, str) and item_b64): + bbox = item.get("bbox_xyxy_norm") + if bbox and len(bbox) == 4: + dest = storage_root / stem / f"page_{page_num}_{type_label}_{item_idx}.{ext}" + if _crop_and_write(dest, page_pil, bbox, 
image_format=ext): + item.update(_build_uri_info(dest, storage_root, public_base_url)) + item["encoding"] = ext + df.at[idx, col_name] = content_list + + # -- Natural sub-page images -- + if store_images: + images_list = row.get("images") + if isinstance(images_list, list): + for img_idx, img_item in enumerate(images_list): + if not isinstance(img_item, dict): + continue + img_b64 = img_item.get("image_b64") + if isinstance(img_b64, str) and img_b64: + raw = _decode_image_bytes(img_b64) + if raw is not None: + direct_ext = _resolve_direct_write_encoding(img_item, raw, ext) + dest = storage_root / stem / f"page_{page_num}_image_{img_idx}.{direct_ext}" + _write_bytes(dest, raw) + img_item.update(_build_uri_info(dest, storage_root, public_base_url)) + img_item["encoding"] = direct_ext + if strip_base64: + img_item["image_b64"] = None + else: + page_pil = _get_page_pil() + if page_pil is not None and not (isinstance(img_b64, str) and img_b64): + bbox = img_item.get("bbox_xyxy_norm") + if bbox and len(bbox) == 4: + dest = storage_root / stem / f"page_{page_num}_image_{img_idx}.{ext}" + if _crop_and_write(dest, page_pil, bbox, image_format=ext): + img_item.update(_build_uri_info(dest, storage_root, public_base_url)) + img_item["encoding"] = ext + df.at[idx, "images"] = images_list + + except Exception as exc: + logger.exception("Failed to store images for row %s: %s", idx, exc) + + return df diff --git a/nemo_retriever/src/nemo_retriever/params/__init__.py b/nemo_retriever/src/nemo_retriever/params/__init__.py index e4f5b0a90..5fdc64012 100644 --- a/nemo_retriever/src/nemo_retriever/params/__init__.py +++ b/nemo_retriever/src/nemo_retriever/params/__init__.py @@ -24,6 +24,7 @@ from .models import RemoteInvokeParams from .models import RemoteRetryParams from .models import RunMode +from .models import StoreParams from .models import TableParams from .models import TextChunkParams from .models import VdbUploadParams @@ -51,6 +52,7 @@ "RemoteInvokeParams", 
"RemoteRetryParams", "RunMode", + "StoreParams", "TableParams", "TextChunkParams", "VdbUploadParams", diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index 5d2d5efd0..7b8e574ac 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Literal, Optional, Sequence, Tuple +from typing import Any, Literal, Optional, Sequence, Tuple import warnings @@ -270,6 +270,19 @@ class VdbUploadParams(_ParamsModel): lancedb: LanceDbParams = Field(default_factory=LanceDbParams) +class StoreParams(_ParamsModel): + storage_uri: str = "stored_images" + storage_options: dict[str, Any] = Field(default_factory=dict) + public_base_url: Optional[str] = None + store_page_images: bool = True + store_tables: bool = True + store_charts: bool = True + store_infographics: bool = True + store_images: bool = True + image_format: str = "png" + strip_base64: bool = True + + class PageElementsParams(_ParamsModel): remote: RemoteInvokeParams = Field(default_factory=RemoteInvokeParams) remote_retry: RemoteRetryParams = Field(default_factory=RemoteRetryParams) diff --git a/nemo_retriever/tests/test_io_image_store.py b/nemo_retriever/tests/test_io_image_store.py new file mode 100644 index 000000000..4b379d864 --- /dev/null +++ b/nemo_retriever/tests/test_io_image_store.py @@ -0,0 +1,534 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for nemo_retriever.io.image_store.""" + +from __future__ import annotations + +import base64 +import io +from pathlib import Path + +import pandas as pd + +from nemo_retriever.io.image_store import _safe_stem, load_image_b64_from_uri, store_extracted_images +from nemo_retriever.params import StoreParams + + +def _make_tiny_png_b64(width: int = 4, height: int = 4, color=(255, 0, 0)) -> str: + """Create a minimal PNG image encoded as base64.""" + from PIL import Image + + buf = io.BytesIO() + img = Image.new("RGB", (width, height), color=color) + img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + +def _make_tiny_jpeg_b64(width: int = 4, height: int = 4, color=(255, 0, 0)) -> str: + """Create a minimal JPEG image encoded as base64.""" + from PIL import Image + + buf = io.BytesIO() + img = Image.new("RGB", (width, height), color=color) + img.save(buf, format="JPEG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + +# --------------------------------------------------------------------------- +# _safe_stem +# --------------------------------------------------------------------------- + + +class TestSafeStem: + def test_normal_path(self): + assert _safe_stem("/data/docs/report.pdf") == "report" + + def test_special_characters(self): + assert _safe_stem("my file (copy).pdf") == "my_file_copy_" + + def test_empty_string(self): + assert _safe_stem("") == "document" + + def test_none_value(self): + assert _safe_stem(None) == "document" + + def test_long_name_truncated(self): + long_name = "a" * 200 + ".pdf" + result = _safe_stem(long_name) + assert len(result) <= 160 + + def test_slashes_only(self): + assert _safe_stem("///") == "document" + + +# --------------------------------------------------------------------------- +# store_extracted_images — page images +# --------------------------------------------------------------------------- + + +class TestStorePageImages: + 
def test_writes_file_and_updates_uri(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64, "encoding": "png"}, + "table": [], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + + result = store_extracted_images(df, storage_uri=str(tmp_path)) + expected_file = tmp_path / "test" / "page_1.png" + assert expected_file.exists() + assert expected_file.stat().st_size > 0 + + page_img = result.iloc[0]["page_image"] + assert "stored_image_uri" in page_img + assert page_img["stored_image_uri"].startswith("file://") + + def test_skips_when_no_page_image(self, tmp_path: Path): + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": None, + "table": [], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path)) + assert not any(tmp_path.rglob("*.png")) + assert result.iloc[0]["page_image"] is None + + def test_disabled_flag(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64}, + "table": [], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + store_extracted_images(df, storage_uri=str(tmp_path), store_page_images=False) + assert not any(tmp_path.rglob("page_*.png")) + + +# --------------------------------------------------------------------------- +# store_extracted_images — structured content crops +# --------------------------------------------------------------------------- + + +class TestStoreStructuredContent: + def test_crop_table_from_page_image(self, tmp_path: Path): + b64 = _make_tiny_png_b64(width=100, height=100) + df = pd.DataFrame( + [ + { + "path": "/docs/report.pdf", + "page_number": 2, + "page_image": {"image_b64": b64, "encoding": "png"}, + "table": [ + 
{"text": "col1|col2", "bbox_xyxy_norm": [0.1, 0.1, 0.9, 0.9]}, + ], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + + result = store_extracted_images(df, storage_uri=str(tmp_path)) + expected_file = tmp_path / "report" / "page_2_table_0.png" + assert expected_file.exists() + + table_item = result.iloc[0]["table"][0] + assert "stored_image_uri" in table_item + + def test_direct_image_b64_preferred(self, tmp_path: Path): + page_b64 = _make_tiny_png_b64(color=(255, 0, 0)) + item_b64 = _make_tiny_png_b64(color=(0, 255, 0)) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": page_b64}, + "table": [ + {"text": "data", "image_b64": item_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}, + ], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path)) + table_item = result.iloc[0]["table"][0] + assert "stored_image_uri" in table_item + + # Verify the file was written from item_b64, not cropped from page + written_bytes = (tmp_path / "test" / "page_1_table_0.png").read_bytes() + assert written_bytes == base64.b64decode(item_b64) + + def test_selective_flags_skip_tables(self, tmp_path: Path): + b64 = _make_tiny_png_b64(width=100, height=100) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64}, + "table": [{"text": "t", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], + "chart": [{"text": "c", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + store_extracted_images(df, storage_uri=str(tmp_path), store_tables=False) + files = list(tmp_path.rglob("*.png")) + names = [f.name for f in files] + assert not any("table" in n for n in names) + assert any("chart" in n for n in names) + + +# --------------------------------------------------------------------------- +# store_extracted_images — natural sub-page 
# ---------------------------------------------------------------------------
# store_extracted_images — natural sub-page images
# ---------------------------------------------------------------------------


class TestStoreNaturalImages:
    """Persisting natural (picture) elements carried in the ``images`` column."""

    def test_writes_from_image_b64(self, tmp_path: Path):
        """A natural image with its own payload is written even with no page render."""
        img_b64 = _make_tiny_png_b64(color=(0, 0, 255))
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": None,
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [{"image_b64": img_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}],
                    "metadata": {},
                }
            ]
        )
        result = store_extracted_images(df, storage_uri=str(tmp_path))
        expected_file = tmp_path / "test" / "page_1_image_0.png"
        assert expected_file.exists()
        assert result.iloc[0]["images"][0].get("stored_image_uri") is not None


# ---------------------------------------------------------------------------
# store_extracted_images — format consistency
# ---------------------------------------------------------------------------


class TestFormatConsistency:
    """Stored file extensions must match the bytes actually written."""

    def test_page_image_keeps_source_encoding_extension(self, tmp_path: Path):
        """Pass-through page payloads keep their source encoding, not image_format."""
        jpeg_b64 = _make_tiny_jpeg_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": {"image_b64": jpeg_b64, "encoding": "jpeg"},
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )

        result = store_extracted_images(df, storage_uri=str(tmp_path), image_format="png")
        expected_file = tmp_path / "test" / "page_1.jpeg"
        assert expected_file.exists()
        assert not (tmp_path / "test" / "page_1.png").exists()
        assert expected_file.read_bytes() == base64.b64decode(jpeg_b64)
        assert result.iloc[0]["page_image"]["stored_image_uri"].endswith("/test/page_1.jpeg")

    def test_direct_content_b64_keeps_payload_extension(self, tmp_path: Path):
        """Element payloads written as-is keep the payload's format extension."""
        item_b64 = _make_tiny_png_b64(color=(0, 255, 0))
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": None,
                    "table": [{"image_b64": item_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )

        result = store_extracted_images(df, storage_uri=str(tmp_path), image_format="jpeg")
        expected_file = tmp_path / "test" / "page_1_table_0.png"
        assert expected_file.exists()
        assert not (tmp_path / "test" / "page_1_table_0.jpeg").exists()
        assert expected_file.read_bytes() == base64.b64decode(item_b64)
        assert result.iloc[0]["table"][0]["stored_image_uri"].endswith("/test/page_1_table_0.png")

    def test_crops_use_requested_output_format(self, tmp_path: Path):
        """Re-encoded crops honor image_format (JPEG magic bytes verify it)."""
        page_b64 = _make_tiny_png_b64(width=100, height=100)
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/report.pdf",
                    "page_number": 2,
                    "page_image": {"image_b64": page_b64, "encoding": "png"},
                    "table": [{"bbox_xyxy_norm": [0.2, 0.2, 0.8, 0.8]}],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )

        result = store_extracted_images(df, storage_uri=str(tmp_path), image_format="jpeg")
        expected_file = tmp_path / "report" / "page_2_table_0.jpeg"
        assert expected_file.exists()
        assert expected_file.read_bytes().startswith(b"\xff\xd8\xff")
        assert result.iloc[0]["table"][0]["stored_image_uri"].endswith("/report/page_2_table_0.jpeg")


# ---------------------------------------------------------------------------
# store_extracted_images — base64 stripping
# ---------------------------------------------------------------------------


class TestBase64Stripping:
    """strip_base64 drops inline payloads once a stored URI exists."""

    def test_page_image_b64_stripped_by_default(self, tmp_path: Path):
        b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": {"image_b64": b64, "encoding": "png"},
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )
        result = store_extracted_images(df, storage_uri=str(tmp_path))
        page_img = result.iloc[0]["page_image"]
        assert page_img["image_b64"] is None
        assert "stored_image_uri" in page_img

    def test_page_image_b64_preserved_when_strip_disabled(self, tmp_path: Path):
        b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": {"image_b64": b64, "encoding": "png"},
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )
        result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=False)
        page_img = result.iloc[0]["page_image"]
        assert page_img["image_b64"] == b64
        assert "stored_image_uri" in page_img

    def test_page_image_b64_stripped_when_enabled(self, tmp_path: Path):
        b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": {"image_b64": b64, "encoding": "png"},
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )
        result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True)
        page_img = result.iloc[0]["page_image"]
        assert page_img["image_b64"] is None
        assert "stored_image_uri" in page_img

    def test_structured_content_b64_stripped(self, tmp_path: Path):
        page_b64 = _make_tiny_png_b64(width=100, height=100)
        item_b64 = _make_tiny_png_b64(color=(0, 255, 0))
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": {"image_b64": page_b64},
                    "table": [{"text": "data", "image_b64": item_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )
        result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True)
        assert result.iloc[0]["table"][0]["image_b64"] is None
        assert "stored_image_uri" in result.iloc[0]["table"][0]

    def test_natural_image_b64_stripped(self, tmp_path: Path):
        img_b64 = _make_tiny_png_b64(color=(0, 0, 255))
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": None,
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [{"image_b64": img_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}],
                    "metadata": {},
                }
            ]
        )
        result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True)
        assert result.iloc[0]["images"][0]["image_b64"] is None
        assert "stored_image_uri" in result.iloc[0]["images"][0]


# ---------------------------------------------------------------------------
# store_extracted_images — edge cases
# ---------------------------------------------------------------------------


class TestStoreEdgeCases:
    def test_empty_dataframe(self, tmp_path: Path):
        """An empty frame passes through untouched."""
        df = pd.DataFrame()
        result = store_extracted_images(df, storage_uri=str(tmp_path))
        assert result.empty

    def test_public_base_url(self, tmp_path: Path):
        """public_base_url yields a CDN-style stored_image_url per image."""
        b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "page_image": {"image_b64": b64},
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
            ]
        )
        result = store_extracted_images(
            df,
            storage_uri=str(tmp_path),
            public_base_url="https://cdn.example.com/assets",
        )
        page_img = result.iloc[0]["page_image"]
        assert page_img["stored_image_url"] == "https://cdn.example.com/assets/test/page_1.png"

    def test_multiple_pages(self, tmp_path: Path):
        """Each page of a document gets its own file."""
        b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": i,
                    "page_image": {"image_b64": b64},
                    "table": [],
                    "chart": [],
                    "infographic": [],
                    "images": [],
                    "metadata": {},
                }
                for i in range(1, 4)
            ]
        )
        store_extracted_images(df, storage_uri=str(tmp_path))
        for i in range(1, 4):
            assert (tmp_path / "test" / f"page_{i}.png").exists()


# ---------------------------------------------------------------------------
# load_image_b64_from_uri
# ---------------------------------------------------------------------------


class TestLoadImageB64FromUri:
    def test_round_trip(self, tmp_path: Path):
        """Bytes loaded back from a file:// URI are the PNG that was written."""
        from PIL import Image

        img = Image.new("RGB", (4, 4), (0, 255, 0))
        dest = tmp_path / "green.png"
        img.save(dest, format="PNG")
        result = load_image_b64_from_uri(dest.as_uri())
        assert result is not None
        raw = base64.b64decode(result)
        assert raw.startswith(b"\x89PNG")

    def test_missing_file_returns_none(self):
        """Unreadable URIs return None rather than raising."""
        result = load_image_b64_from_uri("file:///nonexistent/path/image.png")
        assert result is None


# ---------------------------------------------------------------------------
# StoreParams model
# ---------------------------------------------------------------------------


class TestStoreParams:
    def test_defaults(self):
        p = StoreParams()
        assert p.storage_uri == "stored_images"
        assert p.store_page_images is True
        assert p.store_tables is True
        assert p.image_format == "png"
        assert p.strip_base64 is True

    def test_overrides(self):
        # strip_base64=False here (the default is True) so the override is
        # actually exercised rather than re-asserting the default value.
        p = StoreParams(storage_uri="s3://bucket/prefix", store_tables=False, image_format="jpeg", strip_base64=False)
        assert p.storage_uri == "s3://bucket/prefix"
        assert p.store_tables is False
        assert p.image_format == "jpeg"
        assert p.strip_base64 is False
# NOTE(review): this hunk also adds ``import base64``, ``import io`` and
# ``from nemo_retriever.io.image_store import store_extracted_images`` to the
# top of test_multimodal_embed.py (context lines omitted here).


def _make_tiny_png_b64(width: int = 8, height: int = 8, color=(255, 0, 0)) -> str:
    """Create a minimal solid-color PNG encoded as base64 (test fixture)."""
    from PIL import Image

    buf = io.BytesIO()
    img = Image.new("RGB", (width, height), color=color)
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


class TestStoreThenExplodeMultimodal:
    """store_extracted_images followed by explode_content_to_rows keeps images usable."""

    def test_store_preserves_b64_when_strip_disabled(self, tmp_path):
        """With strip_base64=False the inline payload survives storage and explode."""
        page_b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "text": "hello",
                    "page_image": {"image_b64": page_b64, "encoding": "png"},
                }
            ]
        )

        stored = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=False)
        exploded = explode_content_to_rows(stored, modality="text_image")

        assert exploded["_image_b64"].iloc[0] == page_b64
        assert exploded["_embed_modality"].iloc[0] == "text_image"

    def test_store_strip_then_explode_loads_from_uri(self, tmp_path):
        """After stripping, explode re-loads the page image from its stored URI."""
        page_b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "text": "hello",
                    "page_image": {"image_b64": page_b64, "encoding": "png"},
                }
            ]
        )

        stored = store_extracted_images(df, storage_uri=str(tmp_path))
        # Default strip_base64=True: inline payload gone, URI recorded.
        assert stored.iloc[0]["page_image"]["image_b64"] is None
        assert stored.iloc[0]["page_image"]["stored_image_uri"] is not None

        exploded = explode_content_to_rows(stored, modality="text_image")
        loaded_b64 = exploded["_image_b64"].iloc[0]
        assert isinstance(loaded_b64, str) and len(loaded_b64) > 0

    def test_store_strip_then_explode_structured_loads_from_uri(self, tmp_path):
        """Structured elements (tables) are also re-loaded from their stored URI."""
        page_b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "text": "hello",
                    "page_image": {"image_b64": page_b64, "encoding": "png"},
                    "table": [
                        {
                            "text": "table data",
                            "image_b64": page_b64,
                            "encoding": "png",
                        }
                    ],
                    "chart": [],
                    "infographic": [],
                }
            ]
        )

        stored = store_extracted_images(df, storage_uri=str(tmp_path))
        table_item = stored.iloc[0]["table"][0]
        assert table_item["image_b64"] is None
        assert table_item["stored_image_uri"] is not None

        exploded = explode_content_to_rows(stored, modality="text", structured_elements_modality="text_image")
        struct_rows = exploded[exploded["_content_type"] == "table"]
        assert len(struct_rows) == 1
        loaded_b64 = struct_rows.iloc[0]["_image_b64"]
        assert isinstance(loaded_b64, str) and len(loaded_b64) > 0


    # NOTE(review): the method below is added to the existing
    # collapse_content_to_page_rows test class, whose header lies outside this
    # hunk — hence the method-level indentation.
    def test_collapse_resolves_uri_when_b64_stripped(self, tmp_path):
        """collapse_content_to_page_rows re-loads a stripped page image by URI."""
        page_b64 = _make_tiny_png_b64()
        df = pd.DataFrame(
            [
                {
                    "path": "/docs/test.pdf",
                    "page_number": 1,
                    "text": "hello",
                    "page_image": {"image_b64": page_b64, "encoding": "png"},
                    "table": [],
                    "chart": [],
                    "infographic": [],
                }
            ]
        )

        stored = store_extracted_images(df, storage_uri=str(tmp_path))
        assert stored.iloc[0]["page_image"]["image_b64"] is None

        result = collapse_content_to_page_rows(stored, modality="text_image")
        loaded_b64 = result["_image_b64"].iloc[0]
        assert isinstance(loaded_b64, str) and len(loaded_b64) > 0