From dcb2d91d751d93d5e43bc0444ca56a2297ad2e4e Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Tue, 24 Mar 2026 18:09:20 +0000 Subject: [PATCH 1/4] (retriever) Add .store() task for persisting extracted images (#1675) - Add store_extracted_images() with fsspec/UPath support for local and cloud storage - Wire .store() into InProcessIngestor, BatchIngestor, and batch_pipeline CLI - Add StoreParams model, unit tests, and fsspec/universal-pathlib deps Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Jacob Ioffe --- nemo_retriever/pyproject.toml | 2 + .../nemo_retriever/examples/batch_pipeline.py | 10 + .../src/nemo_retriever/ingest_modes/batch.py | 20 ++ .../nemo_retriever/ingest_modes/inprocess.py | 10 + nemo_retriever/src/nemo_retriever/ingestor.py | 4 +- .../src/nemo_retriever/io/__init__.py | 2 + .../src/nemo_retriever/io/image_store.py | 241 ++++++++++++++ .../src/nemo_retriever/params/__init__.py | 2 + .../src/nemo_retriever/params/models.py | 14 +- nemo_retriever/tests/test_io_image_store.py | 312 ++++++++++++++++++ 10 files changed, 615 insertions(+), 2 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/io/image_store.py create mode 100644 nemo_retriever/tests/test_io_image_store.py diff --git a/nemo_retriever/pyproject.toml b/nemo_retriever/pyproject.toml index ea22099fb..1005d4004 100644 --- a/nemo_retriever/pyproject.toml +++ b/nemo_retriever/pyproject.toml @@ -71,6 +71,8 @@ dependencies = [ "soundfile>=0.12.0", "scipy>=1.11.0", "nvidia-ml-py", + "fsspec>=2023.1.0", + "universal-pathlib>=0.2.0", ] [project.optional-dependencies] diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index b7e96ac93..ed8147ef8 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -462,6 +462,11 @@ def main( "(used when --table-output-format=markdown)." ), ), + store_images_uri: Optional[str] = typer.Option( + None, + "--store-images-uri", + help="When set, store extracted images to this URI (local path or fsspec URI like s3://bucket/prefix).", + ), text_chunk: bool = typer.Option( False, "--text-chunk", @@ -683,6 +688,11 @@ def _extract_params(batch_tuning: dict, **overrides: Any) -> ExtractParams: if enable_text_chunk: ingestor = ingestor.split(_text_chunk_params) + if store_images_uri: + from nemo_retriever.params import StoreParams + + ingestor = ingestor.store(StoreParams(storage_uri=store_images_uri)) + ingestor = ingestor.embed(embed_params) logger.info("Running extraction...") diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py index 84c13fe5f..5af809201 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/batch.py @@ -47,6 +47,7 @@ from ..params import IngestExecuteParams from ..params import PdfSplitParams from ..params import TextChunkParams +from ..params import StoreParams from ..params import VdbUploadParams logger = logging.getLogger(__name__) @@ -861,6 +862,25 @@ def embed( return self + def store(self, params: StoreParams | None = None, **kwargs: Any) -> "BatchIngestor": + """Store extracted images to disk or cloud storage via fsspec.""" + if self._rd_dataset is None: + raise RuntimeError("No Ray Dataset to store from. Call .files(...) / .extract(...) first.") + + from nemo_retriever.io.image_store import store_extracted_images + + resolved = _coerce_params(params, StoreParams, kwargs) + store_kwargs = resolved.model_dump(mode="python") + self._tasks.append(("store", dict(store_kwargs))) + + _store_fn = partial(store_extracted_images, **store_kwargs) + self._rd_dataset = self._rd_dataset.map_batches( + _store_fn, + batch_format="pandas", + num_cpus=1, + ) + return self + def vdb_upload(self, params: VdbUploadParams | None = None, **kwargs: Any) -> "BatchIngestor": """ Add a streaming LanceDB upload stage to the batch pipeline. diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index 34eaf7ed5..a4d93d04e 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -1328,6 +1328,16 @@ def extract_audio( self._tasks.append((apply_asr_to_df, {"asr_params": self._extract_audio_asr_kwargs})) return self + def store(self, params: "StoreParams | None" = None, **kwargs: Any) -> "InProcessIngestor": + """Store extracted images to disk or cloud storage via fsspec.""" + from nemo_retriever.io.image_store import store_extracted_images + from nemo_retriever.params import StoreParams + + resolved = _coerce_params(params, StoreParams, kwargs) + store_kwargs = resolved.model_dump(mode="python") + self._tasks.append((store_extracted_images, store_kwargs)) + return self + def embed(self, params: EmbedParams | None = None, **kwargs: Any) -> "InProcessIngestor": """ Configure embedding for in-process execution. diff --git a/nemo_retriever/src/nemo_retriever/ingestor.py b/nemo_retriever/src/nemo_retriever/ingestor.py index 7bbc19486..32f44110c 100644 --- a/nemo_retriever/src/nemo_retriever/ingestor.py +++ b/nemo_retriever/src/nemo_retriever/ingestor.py @@ -26,6 +26,7 @@ from nemo_retriever.params import IngestExecuteParams from nemo_retriever.params import IngestorCreateParams from nemo_retriever.params import RunMode +from nemo_retriever.params import StoreParams from nemo_retriever.params import VdbUploadParams @@ -138,8 +139,9 @@ def split(self, params: TextChunkParams | None = None, **kwargs: Any) -> "ingest _ = _merge_params(params, kwargs) self._not_implemented("split") - def store(self) -> "ingestor": + def store(self, params: StoreParams | None = None, **kwargs: Any) -> "ingestor": """Record a store task configuration.""" + _ = _merge_params(params, kwargs) self._not_implemented("store") def store_embed(self) -> "ingestor": diff --git a/nemo_retriever/src/nemo_retriever/io/__init__.py b/nemo_retriever/src/nemo_retriever/io/__init__.py index 8f80f45d6..0dcde6331 100644 --- a/nemo_retriever/src/nemo_retriever/io/__init__.py +++ b/nemo_retriever/src/nemo_retriever/io/__init__.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from .dataframe import read_dataframe, validate_primitives_dataframe, write_dataframe +from .image_store import store_extracted_images from .markdown import to_markdown, to_markdown_by_page from .stage_files import build_stage_output_path, find_stage_inputs @@ -10,6 +11,7 @@ "build_stage_output_path", "find_stage_inputs", "read_dataframe", + "store_extracted_images", "to_markdown", "to_markdown_by_page", "validate_primitives_dataframe", diff --git a/nemo_retriever/src/nemo_retriever/io/image_store.py b/nemo_retriever/src/nemo_retriever/io/image_store.py new file mode 100644 index 000000000..ed1cbcb49 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/io/image_store.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Store extracted images to disk or cloud storage via fsspec.""" + +from __future__ import annotations + +import base64 +import io +import logging +import os +import re +from typing import Any, Dict, Optional, Sequence + +import pandas as pd +from PIL import Image +from upath import UPath + +logger = logging.getLogger(__name__) + +# Known limitation: _safe_stem derives the output subdirectory from the +# filename alone (e.g. "report.pdf" → "report/"). Two source files with +# the same basename but different parent directories will write to the same +# subdirectory and may overwrite each other. This matches the legacy +# nv-ingest store behaviour. A future PR should incorporate a short hash +# of the full source path to eliminate collisions. + + +def _safe_stem(name: str) -> str: + """Derive a filesystem-safe stem from a source path.""" + s = str(name or "").strip() or "document" + s = os.path.splitext(os.path.basename(s))[0] or "document" + s = re.sub(r"[^A-Za-z0-9._-]+", "_", s) + return s[:160] if len(s) > 160 else s + + +def _decode_and_write(dest: UPath, image_b64: str) -> None: + """Decode a base64 image and write raw bytes to *dest* via UPath.""" + raw = base64.b64decode(image_b64) + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as f: + f.write(raw) + + +def _crop_and_write( + dest: UPath, + page_image: Image.Image, + bbox_xyxy_norm: Sequence[float], + image_format: str = "png", +) -> bool: + """Crop a region from an already-decoded page image and write to *dest*. + + Returns ``True`` on success, ``False`` on skip/failure. + """ + try: + w, h = page_image.size + if w <= 1 or h <= 1: + return False + + x1n, y1n, x2n, y2n = (float(v) for v in bbox_xyxy_norm) + x1 = int(min(max(x1n * w, 0), w)) + y1 = int(min(max(y1n * h, 0), h)) + x2 = int(min(max(x2n * w, 0), w)) + y2 = int(min(max(y2n * h, 0), h)) + if x2 <= x1 or y2 <= y1: + return False + + crop = page_image.crop((x1, y1, x2, y2)) + buf = io.BytesIO() + crop.save(buf, format=image_format.upper()) + + dest.parent.mkdir(parents=True, exist_ok=True) + with dest.open("wb") as f: + f.write(buf.getvalue()) + return True + except Exception as exc: + logger.warning("Failed to crop and write %s: %s", dest, exc) + return False + + +def _decode_page_image(page_image_b64: str) -> Image.Image | None: + """Decode a base64-encoded page image into a PIL Image (once per row).""" + try: + raw = base64.b64decode(page_image_b64) + return Image.open(io.BytesIO(raw)).convert("RGB") + except Exception as exc: + logger.warning("Failed to decode page image: %s", exc) + return None + + +def _build_uri_info( + dest: UPath, + storage_root: UPath, + public_base_url: Optional[str], +) -> Dict[str, Optional[str]]: + """Build a dict with ``stored_image_uri`` and optionally ``stored_image_url``.""" + relative_key = dest.relative_to(storage_root).as_posix() + info: Dict[str, Optional[str]] = {"stored_image_uri": dest.as_uri()} + if public_base_url: + info["stored_image_url"] = f"{public_base_url.rstrip('/')}/{relative_key}" + return info + + +def store_extracted_images( + df: pd.DataFrame, + *, + storage_uri: str = "stored_images", + storage_options: dict[str, Any] | None = None, + public_base_url: str | None = None, + store_page_images: bool = True, + store_tables: bool = True, + store_charts: bool = True, + store_infographics: bool = True, + store_images: bool = True, + image_format: str = "png", +) -> pd.DataFrame: + """Pipeline task: store extracted images to disk or cloud storage. + + For each row in the DataFrame: + + * Writes the full page image (from ``page_image["image_b64"]``) when + *store_page_images* is ``True``. + * Crops and writes sub-page images for tables / charts / infographics + using ``bbox_xyxy_norm`` from the page image. + * Writes natural sub-page images from the ``images`` column. + * Updates the DataFrame in-place with ``stored_image_uri`` (and an + optional ``stored_image_url`` when *public_base_url* is set). + + Parameters + ---------- + df : pd.DataFrame + Primitives DataFrame produced by the extraction pipeline. + storage_uri : str + Base URI for storage. Local path (``"./output"``) or + fsspec-compatible URI (``"s3://bucket/prefix"``). + storage_options : dict | None + Extra options forwarded to fsspec / UPath (auth keys, endpoint, etc.). + public_base_url : str | None + When set, each stored item also receives a ``stored_image_url`` + built as ``{public_base_url}/{relative_key}``. + store_page_images : bool + Save full page images. + store_tables : bool + Save table crops. + store_charts : bool + Save chart crops. + store_infographics : bool + Save infographic crops. + store_images : bool + Save natural sub-page images from the ``images`` column. + image_format : str + Output image format (default ``"png"``). + + Returns + ------- + pd.DataFrame + The (mutated) input DataFrame with storage URIs added. + """ + if not isinstance(df, pd.DataFrame) or df.empty: + return df + + storage_root = UPath(storage_uri, **(storage_options or {})) + ext = image_format.lower() + + col_flags: dict[str, str] = {} + if store_tables: + col_flags["tables"] = "table" + if store_charts: + col_flags["charts"] = "chart" + if store_infographics: + col_flags["infographics"] = "infographic" + + for idx, row in df.iterrows(): + try: + source_path = row.get("path") or "" + stem = _safe_stem(source_path) + page_num = row.get("page_number", 1) + + page_image = row.get("page_image") + page_image_b64: str | None = None + if isinstance(page_image, dict): + page_image_b64 = page_image.get("image_b64") + + # Decode the page image once for this row; reused by all crops. + page_pil: Image.Image | None = None + needs_crops = bool(col_flags) or store_images + if needs_crops and isinstance(page_image_b64, str) and page_image_b64: + page_pil = _decode_page_image(page_image_b64) + + # -- Full page image -- + if store_page_images and isinstance(page_image_b64, str) and page_image_b64: + dest = storage_root / stem / f"page_{page_num}.{ext}" + _decode_and_write(dest, page_image_b64) + uri_info = _build_uri_info(dest, storage_root, public_base_url) + page_image.update(uri_info) + df.at[idx, "page_image"] = page_image + + # -- Structured content (tables / charts / infographics) -- + for col_name, type_label in col_flags.items(): + content_list = row.get(col_name) + if not isinstance(content_list, list): + continue + for item_idx, item in enumerate(content_list): + if not isinstance(item, dict): + continue + dest = storage_root / stem / f"page_{page_num}_{type_label}_{item_idx}.{ext}" + item_b64 = item.get("image_b64") + if isinstance(item_b64, str) and item_b64: + _decode_and_write(dest, item_b64) + item.update(_build_uri_info(dest, storage_root, public_base_url)) + elif page_pil is not None: + bbox = item.get("bbox_xyxy_norm") + if bbox and len(bbox) == 4: + if _crop_and_write(dest, page_pil, bbox, image_format=ext): + item.update(_build_uri_info(dest, storage_root, public_base_url)) + df.at[idx, col_name] = content_list + + # -- Natural sub-page images -- + if store_images: + images_list = row.get("images") + if isinstance(images_list, list): + for img_idx, img_item in enumerate(images_list): + if not isinstance(img_item, dict): + continue + dest = storage_root / stem / f"page_{page_num}_image_{img_idx}.{ext}" + img_b64 = img_item.get("image_b64") + if isinstance(img_b64, str) and img_b64: + _decode_and_write(dest, img_b64) + img_item.update(_build_uri_info(dest, storage_root, public_base_url)) + elif page_pil is not None: + bbox = img_item.get("bbox_xyxy_norm") + if bbox and len(bbox) == 4: + if _crop_and_write(dest, page_pil, bbox, image_format=ext): + img_item.update(_build_uri_info(dest, storage_root, public_base_url)) + df.at[idx, "images"] = images_list + + except Exception as exc: + logger.exception("Failed to store images for row %s: %s", idx, exc) + + return df diff --git a/nemo_retriever/src/nemo_retriever/params/__init__.py b/nemo_retriever/src/nemo_retriever/params/__init__.py index 5f4eef723..9da4d0dd3 100644 --- a/nemo_retriever/src/nemo_retriever/params/__init__.py +++ b/nemo_retriever/src/nemo_retriever/params/__init__.py @@ -22,6 +22,7 @@ from .models import RemoteInvokeParams from .models import RemoteRetryParams from .models import RunMode +from .models import StoreParams from .models import TableParams from .models import TextChunkParams from .models import VdbUploadParams @@ -47,6 +48,7 @@ "RemoteInvokeParams", "RemoteRetryParams", "RunMode", + "StoreParams", "TableParams", "TextChunkParams", "VdbUploadParams", diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index 66e925162..fbc6fca1b 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Literal, Optional, Sequence, Tuple +from typing import Any, Literal, Optional, Sequence, Tuple import warnings @@ -262,6 +262,18 @@ class VdbUploadParams(_ParamsModel): lancedb: LanceDbParams = Field(default_factory=LanceDbParams) +class StoreParams(_ParamsModel): + storage_uri: str = "stored_images" + storage_options: dict[str, Any] = Field(default_factory=dict) + public_base_url: Optional[str] = None + store_page_images: bool = True + store_tables: bool = True + store_charts: bool = True + store_infographics: bool = True + store_images: bool = True + image_format: str = "png" + + class PageElementsParams(_ParamsModel): remote: RemoteInvokeParams = Field(default_factory=RemoteInvokeParams) remote_retry: RemoteRetryParams = Field(default_factory=RemoteRetryParams) diff --git a/nemo_retriever/tests/test_io_image_store.py b/nemo_retriever/tests/test_io_image_store.py new file mode 100644 index 000000000..c098e5d00 --- /dev/null +++ b/nemo_retriever/tests/test_io_image_store.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for nemo_retriever.io.image_store.""" + +from __future__ import annotations + +import base64 +import io +from pathlib import Path + +import pandas as pd +import pytest + +from nemo_retriever.io.image_store import _safe_stem, store_extracted_images +from nemo_retriever.params import StoreParams + + +def _make_tiny_png_b64(width: int = 4, height: int = 4, color=(255, 0, 0)) -> str: + """Create a minimal PNG image encoded as base64.""" + from PIL import Image + + buf = io.BytesIO() + img = Image.new("RGB", (width, height), color=color) + img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + +# --------------------------------------------------------------------------- +# _safe_stem +# --------------------------------------------------------------------------- + + +class TestSafeStem: + def test_normal_path(self): + assert _safe_stem("/data/docs/report.pdf") == "report" + + def test_special_characters(self): + assert _safe_stem("my file (copy).pdf") == "my_file_copy_" + + def test_empty_string(self): + assert _safe_stem("") == "document" + + def test_none_value(self): + assert _safe_stem(None) == "document" + + def test_long_name_truncated(self): + long_name = "a" * 200 + ".pdf" + result = _safe_stem(long_name) + assert len(result) <= 160 + + def test_slashes_only(self): + assert _safe_stem("///") == "document" + + +# --------------------------------------------------------------------------- +# store_extracted_images — page images +# --------------------------------------------------------------------------- + + +class TestStorePageImages: + def test_writes_file_and_updates_uri(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64, "encoding": "png"}, + "tables": [], + "charts": [], + "infographics": [], + "images": [], + "metadata": {}, + } + ] + ) + + result = store_extracted_images(df, storage_uri=str(tmp_path)) + expected_file = tmp_path / "test" / "page_1.png" + assert expected_file.exists() + assert expected_file.stat().st_size > 0 + + page_img = result.iloc[0]["page_image"] + assert "stored_image_uri" in page_img + assert page_img["stored_image_uri"].startswith("file://") + + def test_skips_when_no_page_image(self, tmp_path: Path): + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": None, + "tables": [], + "charts": [], + "infographics": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path)) + assert not any(tmp_path.rglob("*.png")) + assert result.iloc[0]["page_image"] is None + + def test_disabled_flag(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64}, + "tables": [], + "charts": [], + "infographics": [], + "images": [], + "metadata": {}, + } + ] + ) + store_extracted_images(df, storage_uri=str(tmp_path), store_page_images=False) + assert not any(tmp_path.rglob("page_*.png")) + + +# --------------------------------------------------------------------------- +# store_extracted_images — structured content crops +# --------------------------------------------------------------------------- + + +class TestStoreStructuredContent: + def test_crop_table_from_page_image(self, tmp_path: Path): + b64 = _make_tiny_png_b64(width=100, height=100) + df = pd.DataFrame( + [ + { + "path": "/docs/report.pdf", + "page_number": 2, + "page_image": {"image_b64": b64, "encoding": "png"}, + "tables": [ + {"text": "col1|col2", "bbox_xyxy_norm": [0.1, 0.1, 0.9, 0.9]}, + ], + "charts": [], + "infographics": [], + "images": [], + "metadata": {}, + } + ] + ) + + result = store_extracted_images(df, storage_uri=str(tmp_path)) + expected_file = tmp_path / "report" / "page_2_table_0.png" + assert expected_file.exists() + + table_item = result.iloc[0]["tables"][0] + assert "stored_image_uri" in table_item + + def test_direct_image_b64_preferred(self, tmp_path: Path): + page_b64 = _make_tiny_png_b64(color=(255, 0, 0)) + item_b64 = _make_tiny_png_b64(color=(0, 255, 0)) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": page_b64}, + "tables": [ + {"text": "data", "image_b64": item_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}, + ], + "charts": [], + "infographics": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path)) + table_item = result.iloc[0]["tables"][0] + assert "stored_image_uri" in table_item + + # Verify the file was written from item_b64, not cropped from page + written_bytes = (tmp_path / "test" / "page_1_table_0.png").read_bytes() + assert written_bytes == base64.b64decode(item_b64) + + def test_selective_flags_skip_tables(self, tmp_path: Path): + b64 = _make_tiny_png_b64(width=100, height=100) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64}, + "tables": [{"text": "t", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], + "charts": [{"text": "c", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], + "infographics": [], + "images": [], + "metadata": {}, + } + ] + ) + store_extracted_images(df, storage_uri=str(tmp_path), store_tables=False) + files = list(tmp_path.rglob("*.png")) + names = [f.name for f in files] + assert not any("table" in n for n in names) + assert any("chart" in n for n in names) + + +# --------------------------------------------------------------------------- +# store_extracted_images — natural sub-page images +# --------------------------------------------------------------------------- + + +class TestStoreNaturalImages: + def test_writes_from_image_b64(self, tmp_path: Path): + img_b64 = _make_tiny_png_b64(color=(0, 0, 255)) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": None, + "tables": [], + "charts": [], + "infographics": [], + "images": [{"image_b64": img_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path)) + expected_file = tmp_path / "test" / "page_1_image_0.png" + assert expected_file.exists() + assert result.iloc[0]["images"][0].get("stored_image_uri") is not None + + +# --------------------------------------------------------------------------- +# store_extracted_images — edge cases +# --------------------------------------------------------------------------- + + +class TestStoreEdgeCases: + def test_empty_dataframe(self, tmp_path: Path): + df = pd.DataFrame() + result = store_extracted_images(df, storage_uri=str(tmp_path)) + assert result.empty + + def test_public_base_url(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64}, + "tables": [], + "charts": [], + "infographics": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images( + df, + storage_uri=str(tmp_path), + public_base_url="https://cdn.example.com/assets", + ) + page_img = result.iloc[0]["page_image"] + assert page_img["stored_image_url"] == "https://cdn.example.com/assets/test/page_1.png" + + def test_multiple_pages(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": i, + "page_image": {"image_b64": b64}, + "tables": [], + "charts": [], + "infographics": [], + "images": [], + "metadata": {}, + } + for i in range(1, 4) + ] + ) + store_extracted_images(df, storage_uri=str(tmp_path)) + for i in range(1, 4): + assert (tmp_path / "test" / f"page_{i}.png").exists() + + +# --------------------------------------------------------------------------- +# StoreParams model +# --------------------------------------------------------------------------- + + +class TestStoreParams: + def test_defaults(self): + p = StoreParams() + assert p.storage_uri == "stored_images" + assert p.store_page_images is True + assert p.store_tables is True + assert p.image_format == "png" + + def test_overrides(self): + p = StoreParams(storage_uri="s3://bucket/prefix", store_tables=False, image_format="jpeg") + assert p.storage_uri == "s3://bucket/prefix" + assert p.store_tables is False + assert p.image_format == "jpeg" From 698ae6aa27c8dc9d110ef41ce750ff9807bc68b5 Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Tue, 24 Mar 2026 20:18:50 +0000 Subject: [PATCH 2/4] =?UTF-8?q?(retriever)=20Harden=20.store()=20=E2=80=94?= =?UTF-8?q?=20format=20sniffing,=20opt-in=20stripping,=20column=20name=20f?= =?UTF-8?q?ix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix column names (table/chart/infographic not plural) to match OCR output - Add magic byte sniffing so file extension matches actual image encoding - Make base64 stripping opt-in (strip_base64=False) to preserve multimodal compat - Add multimodal embed interaction tests and format consistency tests Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Jacob Ioffe --- .../src/nemo_retriever/io/image_store.py | 146 +++++++++-- .../src/nemo_retriever/params/models.py | 1 + nemo_retriever/tests/test_io_image_store.py | 241 +++++++++++++++--- nemo_retriever/tests/test_multimodal_embed.py | 54 +++- 4 files changed, 385 insertions(+), 57 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/io/image_store.py b/nemo_retriever/src/nemo_retriever/io/image_store.py index ed1cbcb49..45991a6dc 100644 --- a/nemo_retriever/src/nemo_retriever/io/image_store.py +++ b/nemo_retriever/src/nemo_retriever/io/image_store.py @@ -19,6 +19,9 @@ logger = logging.getLogger(__name__) +_DIRECT_IMAGE_FORMATS = {"png", "jpeg"} +_FORMAT_ALIASES = {"jpg": "jpeg"} + # Known limitation: _safe_stem derives the output subdirectory from the # filename alone (e.g. "report.pdf" → "report/"). Two source files with # the same basename but different parent directories will write to the same @@ -35,14 +38,62 @@ def _safe_stem(name: str) -> str: return s[:160] if len(s) > 160 else s -def _decode_and_write(dest: UPath, image_b64: str) -> None: - """Decode a base64 image and write raw bytes to *dest* via UPath.""" - raw = base64.b64decode(image_b64) +def _normalize_image_format(image_format: str) -> str: + fmt = str(image_format or "png").strip().lower() + fmt = _FORMAT_ALIASES.get(fmt, fmt) + if fmt not in _DIRECT_IMAGE_FORMATS: + raise ValueError(f"Unsupported image_format: {image_format!r}. Supported formats: png, jpeg") + return fmt + + +def _decode_image_bytes(image_b64: str) -> bytes | None: + try: + return base64.b64decode(image_b64) + except Exception as exc: + logger.warning("Failed to decode image payload: %s", exc) + return None + + +def _write_bytes(dest: UPath, raw: bytes) -> None: dest.parent.mkdir(parents=True, exist_ok=True) with dest.open("wb") as f: f.write(raw) +def _normalized_encoding(value: Any) -> str | None: + if not isinstance(value, str): + return None + enc = value.strip().lower() + enc = _FORMAT_ALIASES.get(enc, enc) + if enc in _DIRECT_IMAGE_FORMATS: + return enc + return None + + +def _sniff_image_encoding(raw: bytes) -> str | None: + if raw.startswith(b"\x89PNG\r\n\x1a\n"): + return "png" + if raw.startswith(b"\xff\xd8\xff"): + return "jpeg" + return None + + +def _resolve_direct_write_encoding( + payload: dict[str, Any] | None, + raw: bytes, + fallback: str, +) -> str: + sniffed = _sniff_image_encoding(raw) + if sniffed: + return sniffed + + declared = _normalized_encoding(payload.get("encoding")) if isinstance(payload, dict) else None + if declared: + return declared + + return fallback + + def _crop_and_write( dest: UPath, page_image: Image.Image, @@ -114,6 +165,7 @@ def store_extracted_images( store_infographics: bool = True, store_images: bool = True, image_format: str = "png", + strip_base64: bool = False, ) -> pd.DataFrame: """Pipeline task: store extracted images to disk or cloud storage. @@ -150,7 +202,12 @@ def store_extracted_images( store_images : bool Save natural sub-page images from the ``images`` column. image_format : str - Output image format (default ``"png"``). + Output image format for generated crops (default ``"png"``). + Direct-write payloads preserve their source encoding and file extension. + strip_base64 : bool + When ``True``, clear ``image_b64`` after successful writes to reduce memory + pressure. Keep this ``False`` when embedding with image-based modalities + after ``.store()``. Returns ------- @@ -160,16 +217,22 @@ def store_extracted_images( if not isinstance(df, pd.DataFrame) or df.empty: return df + logger.info("Storing extracted images to %s", storage_uri) storage_root = UPath(storage_uri, **(storage_options or {})) - ext = image_format.lower() + ext = _normalize_image_format(image_format) + if strip_base64: + logger.warning( + "strip_base64=True removes image payloads. If .store() runs before .embed() " + "with image/text_image modalities, embeddings will lose image input." + ) col_flags: dict[str, str] = {} if store_tables: - col_flags["tables"] = "table" + col_flags["table"] = "table" if store_charts: - col_flags["charts"] = "chart" + col_flags["chart"] = "chart" if store_infographics: - col_flags["infographics"] = "infographic" + col_flags["infographic"] = "infographic" for idx, row in df.iterrows(): try: @@ -182,19 +245,32 @@ def store_extracted_images( if isinstance(page_image, dict): page_image_b64 = page_image.get("image_b64") - # Decode the page image once for this row; reused by all crops. + # Decode the page image lazily and reuse for all crops in this row. page_pil: Image.Image | None = None - needs_crops = bool(col_flags) or store_images - if needs_crops and isinstance(page_image_b64, str) and page_image_b64: - page_pil = _decode_page_image(page_image_b64) + page_pil_decode_attempted = False + + def _get_page_pil() -> Image.Image | None: + nonlocal page_pil, page_pil_decode_attempted + if page_pil_decode_attempted: + return page_pil + page_pil_decode_attempted = True + if isinstance(page_image_b64, str) and page_image_b64: + page_pil = _decode_page_image(page_image_b64) + return page_pil # -- Full page image -- if store_page_images and isinstance(page_image_b64, str) and page_image_b64: - dest = storage_root / stem / f"page_{page_num}.{ext}" - _decode_and_write(dest, page_image_b64) - uri_info = _build_uri_info(dest, storage_root, public_base_url) - page_image.update(uri_info) - df.at[idx, "page_image"] = page_image + raw = _decode_image_bytes(page_image_b64) + if raw is not None: + direct_ext = _resolve_direct_write_encoding(page_image, raw, ext) + dest = storage_root / stem / f"page_{page_num}.{direct_ext}" + _write_bytes(dest, raw) + uri_info = _build_uri_info(dest, storage_root, public_base_url) + page_image.update(uri_info) + page_image["encoding"] = direct_ext + if strip_base64: + page_image["image_b64"] = None + df.at[idx, "page_image"] = page_image # -- Structured content (tables / charts / infographics) -- for col_name, type_label in col_flags.items(): @@ -204,16 +280,26 @@ def store_extracted_images( for item_idx, item in enumerate(content_list): if not isinstance(item, dict): continue - dest = storage_root / stem / f"page_{page_num}_{type_label}_{item_idx}.{ext}" item_b64 = item.get("image_b64") if isinstance(item_b64, str) and item_b64: - _decode_and_write(dest, item_b64) - item.update(_build_uri_info(dest, storage_root, public_base_url)) - elif page_pil is not None: + raw = _decode_image_bytes(item_b64) + if raw is not None: + direct_ext = _resolve_direct_write_encoding(item, raw, ext) + dest = storage_root / stem / f"page_{page_num}_{type_label}_{item_idx}.{direct_ext}" + _write_bytes(dest, raw) + item.update(_build_uri_info(dest, storage_root, public_base_url)) + item["encoding"] = direct_ext + if strip_base64: + item["image_b64"] = None + else: + page_pil = _get_page_pil() + if page_pil is not None and not (isinstance(item_b64, str) and item_b64): bbox = item.get("bbox_xyxy_norm") if bbox and len(bbox) == 4: + dest = storage_root / stem / f"page_{page_num}_{type_label}_{item_idx}.{ext}" if _crop_and_write(dest, page_pil, bbox, image_format=ext): item.update(_build_uri_info(dest, storage_root, public_base_url)) + item["encoding"] = ext df.at[idx, col_name] = content_list # -- Natural sub-page images -- @@ -223,16 +309,26 @@ def store_extracted_images( for img_idx, img_item in enumerate(images_list): if not isinstance(img_item, dict): continue - dest = storage_root / stem / f"page_{page_num}_image_{img_idx}.{ext}" img_b64 = img_item.get("image_b64") if isinstance(img_b64, str) and img_b64: - _decode_and_write(dest, img_b64) - img_item.update(_build_uri_info(dest, storage_root, public_base_url)) - elif page_pil is not None: + raw = _decode_image_bytes(img_b64) + if raw is not None: + direct_ext = _resolve_direct_write_encoding(img_item, raw, ext) + dest = storage_root / stem / f"page_{page_num}_image_{img_idx}.{direct_ext}" + _write_bytes(dest, raw) + img_item.update(_build_uri_info(dest, storage_root, public_base_url)) + img_item["encoding"] = direct_ext + if strip_base64: + img_item["image_b64"] = None + else: + page_pil = _get_page_pil() + if page_pil is not None and not (isinstance(img_b64, str) and img_b64): bbox = img_item.get("bbox_xyxy_norm") if bbox and len(bbox) == 4: + dest = storage_root / stem / f"page_{page_num}_image_{img_idx}.{ext}" if _crop_and_write(dest, page_pil, bbox, image_format=ext): img_item.update(_build_uri_info(dest, storage_root, public_base_url)) + img_item["encoding"] = ext df.at[idx, "images"] = images_list except Exception as exc: diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index fbc6fca1b..cc4111637 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -272,6 +272,7 @@ class StoreParams(_ParamsModel): store_infographics: bool = True store_images: bool = True image_format: str = "png" + strip_base64: bool = False class PageElementsParams(_ParamsModel): diff --git a/nemo_retriever/tests/test_io_image_store.py b/nemo_retriever/tests/test_io_image_store.py index c098e5d00..787f58ec0 100644 --- a/nemo_retriever/tests/test_io_image_store.py +++ b/nemo_retriever/tests/test_io_image_store.py @@ -11,7 +11,6 @@ from pathlib import Path import pandas as pd -import pytest from nemo_retriever.io.image_store import _safe_stem, store_extracted_images from nemo_retriever.params import StoreParams @@ -27,6 +26,16 @@ def _make_tiny_png_b64(width: int = 4, height: int = 4, color=(255, 0, 0)) -> st return base64.b64encode(buf.getvalue()).decode("ascii") +def _make_tiny_jpeg_b64(width: int = 4, height: int = 4, color=(255, 0, 0)) -> str: + """Create a minimal JPEG image encoded as base64.""" + from PIL import Image + + buf = io.BytesIO() + img = Image.new("RGB", (width, height), color=color) + img.save(buf, format="JPEG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + # --------------------------------------------------------------------------- # _safe_stem # --------------------------------------------------------------------------- @@ -68,9 +77,9 @@ def test_writes_file_and_updates_uri(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": 1, "page_image": {"image_b64": b64, "encoding": "png"}, - "tables": [], - "charts": [], - "infographics": [], + "table": [], + "chart": [], + "infographic": [], "images": [], "metadata": {}, } @@ -93,9 +102,9 @@ def test_skips_when_no_page_image(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": 1, "page_image": None, - "tables": [], - "charts": [], - "infographics": [], + "table": [], + "chart": [], + "infographic": [], "images": [], "metadata": {}, } @@ -113,9 +122,9 @@ def test_disabled_flag(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": 1, "page_image": {"image_b64": b64}, - "tables": [], - "charts": [], - "infographics": [], + "table": [], + "chart": [], + "infographic": [], "images": [], "metadata": {}, } @@ -139,11 +148,11 @@ def test_crop_table_from_page_image(self, tmp_path: Path): "path": "/docs/report.pdf", "page_number": 2, "page_image": {"image_b64": b64, "encoding": "png"}, - "tables": [ + "table": [ {"text": "col1|col2", "bbox_xyxy_norm": [0.1, 0.1, 0.9, 0.9]}, ], - "charts": [], - "infographics": [], + "chart": [], + "infographic": [], "images": [], "metadata": {}, } @@ -154,7 +163,7 @@ def test_crop_table_from_page_image(self, tmp_path: Path): expected_file = tmp_path / "report" / "page_2_table_0.png" assert expected_file.exists() - table_item = result.iloc[0]["tables"][0] + table_item = result.iloc[0]["table"][0] assert "stored_image_uri" in table_item def test_direct_image_b64_preferred(self, tmp_path: Path): @@ -166,18 +175,18 @@ def test_direct_image_b64_preferred(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": 1, "page_image": {"image_b64": page_b64}, - "tables": [ + "table": [ {"text": "data", "image_b64": item_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}, ], - "charts": [], - "infographics": [], + "chart": [], + "infographic": [], "images": [], "metadata": {}, } ] ) result = store_extracted_images(df, storage_uri=str(tmp_path)) - table_item = result.iloc[0]["tables"][0] + table_item = result.iloc[0]["table"][0] assert "stored_image_uri" in table_item # Verify the file was written from item_b64, not cropped from page @@ -192,9 +201,9 @@ def test_selective_flags_skip_tables(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": 1, "page_image": {"image_b64": b64}, - "tables": [{"text": "t", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], - "charts": [{"text": "c", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], - "infographics": [], + "table": [{"text": "t", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], + "chart": [{"text": "c", "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5]}], + "infographic": [], "images": [], "metadata": {}, } @@ -221,9 +230,9 @@ def test_writes_from_image_b64(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": 1, "page_image": None, - "tables": [], - "charts": [], - "infographics": [], + "table": [], + "chart": [], + "infographic": [], "images": [{"image_b64": img_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}], "metadata": {}, } @@ -235,6 +244,174 @@ def test_writes_from_image_b64(self, tmp_path: Path): assert result.iloc[0]["images"][0].get("stored_image_uri") is not None +# --------------------------------------------------------------------------- +# store_extracted_images — format consistency +# --------------------------------------------------------------------------- + + +class TestFormatConsistency: + def test_page_image_keeps_source_encoding_extension(self, tmp_path: Path): + jpeg_b64 = _make_tiny_jpeg_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": jpeg_b64, "encoding": "jpeg"}, + "table": [], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + + result = store_extracted_images(df, storage_uri=str(tmp_path), image_format="png") + expected_file = tmp_path / "test" / "page_1.jpeg" + assert expected_file.exists() + assert not (tmp_path / "test" / "page_1.png").exists() + assert expected_file.read_bytes() == base64.b64decode(jpeg_b64) + assert result.iloc[0]["page_image"]["stored_image_uri"].endswith("/test/page_1.jpeg") + + def test_direct_content_b64_keeps_payload_extension(self, tmp_path: Path): + item_b64 = _make_tiny_png_b64(color=(0, 255, 0)) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": None, + "table": [{"image_b64": item_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + + result = store_extracted_images(df, storage_uri=str(tmp_path), image_format="jpeg") + expected_file = tmp_path / "test" / "page_1_table_0.png" + assert expected_file.exists() + assert not (tmp_path / "test" / "page_1_table_0.jpeg").exists() + assert expected_file.read_bytes() == base64.b64decode(item_b64) + assert result.iloc[0]["table"][0]["stored_image_uri"].endswith("/test/page_1_table_0.png") + + def test_crops_use_requested_output_format(self, tmp_path: Path): + page_b64 = _make_tiny_png_b64(width=100, height=100) + df = pd.DataFrame( + [ + { + "path": "/docs/report.pdf", + "page_number": 2, + "page_image": {"image_b64": page_b64, "encoding": "png"}, + "table": [{"bbox_xyxy_norm": [0.2, 0.2, 0.8, 0.8]}], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + + result = store_extracted_images(df, storage_uri=str(tmp_path), image_format="jpeg") + expected_file = tmp_path / "report" / "page_2_table_0.jpeg" + assert expected_file.exists() + assert expected_file.read_bytes().startswith(b"\xff\xd8\xff") + assert result.iloc[0]["table"][0]["stored_image_uri"].endswith("/report/page_2_table_0.jpeg") + + +# --------------------------------------------------------------------------- +# store_extracted_images — base64 stripping +# --------------------------------------------------------------------------- + + +class TestBase64Stripping: + def test_page_image_b64_preserved_by_default(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64, "encoding": "png"}, + "table": [], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path)) + page_img = result.iloc[0]["page_image"] + assert page_img["image_b64"] == b64 + assert "stored_image_uri" in page_img + + def test_page_image_b64_stripped_when_enabled(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64, "encoding": "png"}, + "table": [], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True) + page_img = result.iloc[0]["page_image"] + assert page_img["image_b64"] is None + assert "stored_image_uri" in page_img + + def test_structured_content_b64_stripped(self, tmp_path: Path): + page_b64 = _make_tiny_png_b64(width=100, height=100) + item_b64 = _make_tiny_png_b64(color=(0, 255, 0)) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": page_b64}, + "table": [{"text": "data", "image_b64": item_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True) + assert result.iloc[0]["table"][0]["image_b64"] is None + assert "stored_image_uri" in result.iloc[0]["table"][0] + + def test_natural_image_b64_stripped(self, tmp_path: Path): + img_b64 = _make_tiny_png_b64(color=(0, 0, 255)) + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": None, + "table": [], + "chart": [], + "infographic": [], + "images": [{"image_b64": img_b64, "bbox_xyxy_norm": [0, 0, 1, 1]}], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True) + assert result.iloc[0]["images"][0]["image_b64"] is None + assert "stored_image_uri" in result.iloc[0]["images"][0] + + # --------------------------------------------------------------------------- # store_extracted_images — edge cases # --------------------------------------------------------------------------- @@ -254,9 +431,9 @@ def test_public_base_url(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": 1, "page_image": {"image_b64": b64}, - "tables": [], - "charts": [], - "infographics": [], + "table": [], + "chart": [], + "infographic": [], "images": [], "metadata": {}, } @@ -278,9 +455,9 @@ def test_multiple_pages(self, tmp_path: Path): "path": "/docs/test.pdf", "page_number": i, "page_image": {"image_b64": b64}, - "tables": [], - "charts": [], - "infographics": [], + "table": [], + "chart": [], + "infographic": [], "images": [], "metadata": {}, } @@ -304,9 +481,11 @@ def test_defaults(self): assert p.store_page_images is True assert p.store_tables is True assert p.image_format == "png" + assert p.strip_base64 is False def test_overrides(self): - p = StoreParams(storage_uri="s3://bucket/prefix", store_tables=False, image_format="jpeg") + p = StoreParams(storage_uri="s3://bucket/prefix", store_tables=False, image_format="jpeg", strip_base64=True) assert p.storage_uri == "s3://bucket/prefix" assert p.store_tables is False assert p.image_format == "jpeg" + assert p.strip_base64 is True diff --git a/nemo_retriever/tests/test_multimodal_embed.py b/nemo_retriever/tests/test_multimodal_embed.py index f357193ef..692575143 100644 --- a/nemo_retriever/tests/test_multimodal_embed.py +++ b/nemo_retriever/tests/test_multimodal_embed.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: Apache-2.0 @@ -8,6 +8,8 @@ from __future__ import annotations +import base64 +import io import sys from unittest.mock import MagicMock, patch @@ -23,6 +25,7 @@ _image_from_row, _multimodal_callable_runner, ) +from nemo_retriever.io.image_store import store_extracted_images # --------------------------------------------------------------------------- # Stub heavy internal modules so ``from nemo_retriever.ingest_modes.inprocess`` @@ -94,6 +97,15 @@ del _injected +def _make_tiny_png_b64(width: int = 8, height: int = 8, color=(255, 0, 0)) -> str: + from PIL import Image + + buf = io.BytesIO() + img = Image.new("RGB", (width, height), color=color) + img.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("ascii") + + # =================================================================== # Pure helpers # =================================================================== @@ -237,6 +249,46 @@ def test_text_image_carries_image(self, mock_crop): ) +class TestStoreThenExplodeMultimodal: + def test_store_keeps_image_payloads_by_default(self, tmp_path): + page_b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "text": "hello", + "page_image": {"image_b64": page_b64, "encoding": "png"}, + } + ] + ) + + stored = store_extracted_images(df, storage_uri=str(tmp_path)) + exploded = explode_content_to_rows(stored, modality="text_image") + + assert exploded["_image_b64"].iloc[0] == page_b64 + assert exploded["_embed_modality"].iloc[0] == "text_image" + + def test_store_strip_base64_removes_multimodal_image_inputs(self, tmp_path): + page_b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "text": "hello", + "page_image": {"image_b64": page_b64, "encoding": "png"}, + } + ] + ) + + stored = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True) + exploded = explode_content_to_rows(stored, modality="text_image") + + assert stored.iloc[0]["page_image"]["image_b64"] is None + assert exploded["_image_b64"].iloc[0] is None + + # =================================================================== # collapse_content_to_page_rows # =================================================================== From a36a1eaaa216b2ca84f8e347b95e912c8e89ff16 Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Wed, 25 Mar 2026 18:33:44 +0000 Subject: [PATCH 3/4] WIP: checkpoint before merge upstream/main (OOM-related fixes) Signed-off-by: Jacob Ioffe --- .../nemo_retriever/examples/batch_pipeline.py | 2 + .../nemo_retriever/ingest_modes/inprocess.py | 61 +++++++++++----- .../src/nemo_retriever/io/__init__.py | 3 +- .../src/nemo_retriever/io/image_store.py | 28 +++++--- .../src/nemo_retriever/params/models.py | 2 +- nemo_retriever/tests/test_io_image_store.py | 49 ++++++++++++- nemo_retriever/tests/test_multimodal_embed.py | 70 +++++++++++++++++-- 7 files changed, 178 insertions(+), 37 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py index ed8147ef8..244e2371e 100644 --- a/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py @@ -721,6 +721,8 @@ def _extract_params(batch_tuning: dict, **overrides: Any) -> ExtractParams: handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite") lancedb_write_time = time.perf_counter() - lancedb_write_start + del ingest_local_results # free driver heap before error check + if isinstance(ingestor, BatchIngestor): error_rows = ingestor.get_error_rows(dataset=ingest_results).materialize() error_count = int(error_rows.count()) diff --git a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py index a4d93d04e..dd49f0fa7 100644 --- a/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py +++ b/nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py @@ -55,7 +55,9 @@ from ..params import HtmlChunkParams from ..params import IngestExecuteParams from ..params import TextChunkParams +from ..params import StoreParams from ..params import VdbUploadParams +from ..io.image_store import load_image_b64_from_uri from ..pdf.extract import pdf_extraction from ..pdf.split import _split_pdf_to_single_page_bytes, pdf_path_to_pages_df from ..utils.remote_auth import resolve_remote_api_key @@ -104,6 +106,41 @@ def _deep_copy_row(row_dict: Dict[str, Any]) -> Dict[str, Any]: return out +def _b64_from_dict(d: dict) -> Optional[str]: + """Resolve base64 from a dict: inline ``image_b64`` then ``stored_image_uri``.""" + b64 = d.get("image_b64") + if isinstance(b64, str) and b64: + return b64 + uri = d.get("stored_image_uri") + if isinstance(uri, str) and uri: + return load_image_b64_from_uri(uri) + return None + + +def _resolve_page_image_b64(page_image: Any) -> Optional[str]: + """Return base64 for a page image, loading from stored URI if needed.""" + if not isinstance(page_image, dict): + return None + return _b64_from_dict(page_image) + + +def _resolve_item_image_b64(item: dict, page_image_b64: Optional[str]) -> Optional[str]: + """Return base64 for a structured content item, with fallback chain. + + Priority: inline base64 > stored URI > crop from page image. + """ + resolved = _b64_from_dict(item) + if resolved: + return resolved + if page_image_b64: + bbox = item.get("bbox_xyxy_norm") + if bbox and len(bbox) == 4: + cropped_b64, _ = _crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox) + return cropped_b64 + return page_image_b64 + return None + + def explode_content_to_rows( batch_df: Any, *, @@ -162,9 +199,7 @@ def explode_content_to_rows( if not any(c in batch_df.columns for c in content_columns): batch_df = batch_df.copy() if text_mod in IMAGE_MODALITIES and "page_image" in batch_df.columns: - batch_df["_image_b64"] = batch_df["page_image"].apply( - lambda pi: pi.get("image_b64") if isinstance(pi, dict) else None - ) + batch_df["_image_b64"] = batch_df["page_image"].apply(_resolve_page_image_b64) batch_df["_embed_modality"] = text_mod return batch_df @@ -176,8 +211,8 @@ def explode_content_to_rows( # Extract page-level image b64 once per source row. page_image = row_dict.get("page_image") page_image_b64: Optional[str] = None - if any_images and isinstance(page_image, dict): - page_image_b64 = page_image.get("image_b64") + if any_images: + page_image_b64 = _resolve_page_image_b64(page_image) # Row for page text. page_text = row_dict.get(text_column) @@ -205,15 +240,8 @@ def explode_content_to_rows( content_row[text_column] = t.strip() content_row["_embed_modality"] = struct_mod content_row["_content_type"] = col - if struct_mod in IMAGE_MODALITIES and page_image_b64: - bbox = item.get("bbox_xyxy_norm") - if bbox and len(bbox) == 4: - cropped_b64, _ = _crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox) - content_row["_image_b64"] = cropped_b64 - else: - content_row["_image_b64"] = page_image_b64 - elif struct_mod in IMAGE_MODALITIES: - content_row["_image_b64"] = None + if struct_mod in IMAGE_MODALITIES: + content_row["_image_b64"] = _resolve_item_image_b64(item, page_image_b64) new_rows.append(content_row) exploded_any = True @@ -266,9 +294,7 @@ def collapse_content_to_page_rows( # Full page image (no cropping) for image modalities. if modality in IMAGE_MODALITIES: if "page_image" in batch_df.columns: - batch_df["_image_b64"] = batch_df["page_image"].apply( - lambda pi: pi.get("image_b64") if isinstance(pi, dict) else None - ) + batch_df["_image_b64"] = batch_df["page_image"].apply(_resolve_page_image_b64) else: batch_df["_image_b64"] = None @@ -1331,7 +1357,6 @@ def extract_audio( def store(self, params: "StoreParams | None" = None, **kwargs: Any) -> "InProcessIngestor": """Store extracted images to disk or cloud storage via fsspec.""" from nemo_retriever.io.image_store import store_extracted_images - from nemo_retriever.params import StoreParams resolved = _coerce_params(params, StoreParams, kwargs) store_kwargs = resolved.model_dump(mode="python") diff --git a/nemo_retriever/src/nemo_retriever/io/__init__.py b/nemo_retriever/src/nemo_retriever/io/__init__.py index 0dcde6331..8ebe1198c 100644 --- a/nemo_retriever/src/nemo_retriever/io/__init__.py +++ b/nemo_retriever/src/nemo_retriever/io/__init__.py @@ -3,13 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 from .dataframe import read_dataframe, validate_primitives_dataframe, write_dataframe -from .image_store import store_extracted_images +from .image_store import load_image_b64_from_uri, store_extracted_images from .markdown import to_markdown, to_markdown_by_page from .stage_files import build_stage_output_path, find_stage_inputs __all__ = [ "build_stage_output_path", "find_stage_inputs", + "load_image_b64_from_uri", "read_dataframe", "store_extracted_images", "to_markdown", diff --git a/nemo_retriever/src/nemo_retriever/io/image_store.py b/nemo_retriever/src/nemo_retriever/io/image_store.py index 45991a6dc..3d19d4d15 100644 --- a/nemo_retriever/src/nemo_retriever/io/image_store.py +++ b/nemo_retriever/src/nemo_retriever/io/image_store.py @@ -153,6 +153,20 @@ def _build_uri_info( return info +def load_image_b64_from_uri(uri: str) -> Optional[str]: + """Read an image from a stored URI and return its base64 encoding. + + Accepts any fsspec-compatible URI (``file://``, ``s3://``, etc.). + Returns ``None`` on failure so callers can fall back gracefully. + """ + try: + raw = UPath(uri).read_bytes() + return base64.b64encode(raw).decode("ascii") + except Exception as exc: + logger.warning("Failed to load image from %s: %s", uri, exc) + return None + + def store_extracted_images( df: pd.DataFrame, *, @@ -165,7 +179,7 @@ def store_extracted_images( store_infographics: bool = True, store_images: bool = True, image_format: str = "png", - strip_base64: bool = False, + strip_base64: bool = True, ) -> pd.DataFrame: """Pipeline task: store extracted images to disk or cloud storage. @@ -205,9 +219,10 @@ def store_extracted_images( Output image format for generated crops (default ``"png"``). Direct-write payloads preserve their source encoding and file extension. strip_base64 : bool - When ``True``, clear ``image_b64`` after successful writes to reduce memory - pressure. Keep this ``False`` when embedding with image-based modalities - after ``.store()``. + When ``True`` (the default), clear ``image_b64`` after successful writes + to reduce memory pressure. The embed stage loads images from the stored + URIs when base64 is absent. Set to ``False`` only if downstream code + requires inline base64 for a reason other than embedding. Returns ------- @@ -221,10 +236,7 @@ def store_extracted_images( storage_root = UPath(storage_uri, **(storage_options or {})) ext = _normalize_image_format(image_format) if strip_base64: - logger.warning( - "strip_base64=True removes image payloads. If .store() runs before .embed() " - "with image/text_image modalities, embeddings will lose image input." - ) + logger.debug("strip_base64=True: image payloads will be cleared after writing.") col_flags: dict[str, str] = {} if store_tables: diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index cc4111637..4c887ae4f 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -272,7 +272,7 @@ class StoreParams(_ParamsModel): store_infographics: bool = True store_images: bool = True image_format: str = "png" - strip_base64: bool = False + strip_base64: bool = True class PageElementsParams(_ParamsModel): diff --git a/nemo_retriever/tests/test_io_image_store.py b/nemo_retriever/tests/test_io_image_store.py index 787f58ec0..4b379d864 100644 --- a/nemo_retriever/tests/test_io_image_store.py +++ b/nemo_retriever/tests/test_io_image_store.py @@ -12,7 +12,7 @@ import pandas as pd -from nemo_retriever.io.image_store import _safe_stem, store_extracted_images +from nemo_retriever.io.image_store import _safe_stem, load_image_b64_from_uri, store_extracted_images from nemo_retriever.params import StoreParams @@ -328,7 +328,7 @@ def test_crops_use_requested_output_format(self, tmp_path: Path): class TestBase64Stripping: - def test_page_image_b64_preserved_by_default(self, tmp_path: Path): + def test_page_image_b64_stripped_by_default(self, tmp_path: Path): b64 = _make_tiny_png_b64() df = pd.DataFrame( [ @@ -346,6 +346,27 @@ def test_page_image_b64_preserved_by_default(self, tmp_path: Path): ) result = store_extracted_images(df, storage_uri=str(tmp_path)) page_img = result.iloc[0]["page_image"] + assert page_img["image_b64"] is None + assert "stored_image_uri" in page_img + + def test_page_image_b64_preserved_when_strip_disabled(self, tmp_path: Path): + b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "page_image": {"image_b64": b64, "encoding": "png"}, + "table": [], + "chart": [], + "infographic": [], + "images": [], + "metadata": {}, + } + ] + ) + result = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=False) + page_img = result.iloc[0]["page_image"] assert page_img["image_b64"] == b64 assert "stored_image_uri" in page_img @@ -469,6 +490,28 @@ def test_multiple_pages(self, tmp_path: Path): assert (tmp_path / "test" / f"page_{i}.png").exists() +# --------------------------------------------------------------------------- +# load_image_b64_from_uri +# --------------------------------------------------------------------------- + + +class TestLoadImageB64FromUri: + def test_round_trip(self, tmp_path: Path): + from PIL import Image + + img = Image.new("RGB", (4, 4), (0, 255, 0)) + dest = tmp_path / "green.png" + img.save(dest, format="PNG") + result = load_image_b64_from_uri(dest.as_uri()) + assert result is not None + raw = base64.b64decode(result) + assert raw.startswith(b"\x89PNG") + + def test_missing_file_returns_none(self): + result = load_image_b64_from_uri("file:///nonexistent/path/image.png") + assert result is None + + # --------------------------------------------------------------------------- # StoreParams model # --------------------------------------------------------------------------- @@ -481,7 +524,7 @@ def test_defaults(self): assert p.store_page_images is True assert p.store_tables is True assert p.image_format == "png" - assert p.strip_base64 is False + assert p.strip_base64 is True def test_overrides(self): p = StoreParams(storage_uri="s3://bucket/prefix", store_tables=False, image_format="jpeg", strip_base64=True) diff --git a/nemo_retriever/tests/test_multimodal_embed.py b/nemo_retriever/tests/test_multimodal_embed.py index 692575143..491056f88 100644 --- a/nemo_retriever/tests/test_multimodal_embed.py +++ b/nemo_retriever/tests/test_multimodal_embed.py @@ -250,7 +250,7 @@ def test_text_image_carries_image(self, mock_crop): class TestStoreThenExplodeMultimodal: - def test_store_keeps_image_payloads_by_default(self, tmp_path): + def test_store_preserves_b64_when_strip_disabled(self, tmp_path): page_b64 = _make_tiny_png_b64() df = pd.DataFrame( [ @@ -263,13 +263,13 @@ def test_store_keeps_image_payloads_by_default(self, tmp_path): ] ) - stored = store_extracted_images(df, storage_uri=str(tmp_path)) + stored = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=False) exploded = explode_content_to_rows(stored, modality="text_image") assert exploded["_image_b64"].iloc[0] == page_b64 assert exploded["_embed_modality"].iloc[0] == "text_image" - def test_store_strip_base64_removes_multimodal_image_inputs(self, tmp_path): + def test_store_strip_then_explode_loads_from_uri(self, tmp_path): page_b64 = _make_tiny_png_b64() df = pd.DataFrame( [ @@ -282,11 +282,46 @@ def test_store_strip_base64_removes_multimodal_image_inputs(self, tmp_path): ] ) - stored = store_extracted_images(df, storage_uri=str(tmp_path), strip_base64=True) + stored = store_extracted_images(df, storage_uri=str(tmp_path)) + assert stored.iloc[0]["page_image"]["image_b64"] is None + assert stored.iloc[0]["page_image"]["stored_image_uri"] is not None + exploded = explode_content_to_rows(stored, modality="text_image") + loaded_b64 = exploded["_image_b64"].iloc[0] + assert isinstance(loaded_b64, str) and len(loaded_b64) > 0 - assert stored.iloc[0]["page_image"]["image_b64"] is None - assert exploded["_image_b64"].iloc[0] is None + def test_store_strip_then_explode_structured_loads_from_uri(self, tmp_path): + page_b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "text": "hello", + "page_image": {"image_b64": page_b64, "encoding": "png"}, + "table": [ + { + "text": "table data", + "image_b64": page_b64, + "encoding": "png", + } + ], + "chart": [], + "infographic": [], + } + ] + ) + + stored = store_extracted_images(df, storage_uri=str(tmp_path)) + table_item = stored.iloc[0]["table"][0] + assert table_item["image_b64"] is None + assert table_item["stored_image_uri"] is not None + + exploded = explode_content_to_rows(stored, modality="text", structured_elements_modality="text_image") + struct_rows = exploded[exploded["_content_type"] == "table"] + assert len(struct_rows) == 1 + loaded_b64 = struct_rows.iloc[0]["_image_b64"] + assert isinstance(loaded_b64, str) and len(loaded_b64) > 0 # =================================================================== @@ -382,3 +417,26 @@ def test_non_dataframe_passthrough(self): """Non-DataFrame input is returned as-is.""" result = collapse_content_to_page_rows(None) assert result is None + + def test_collapse_resolves_uri_when_b64_stripped(self, tmp_path): + page_b64 = _make_tiny_png_b64() + df = pd.DataFrame( + [ + { + "path": "/docs/test.pdf", + "page_number": 1, + "text": "hello", + "page_image": {"image_b64": page_b64, "encoding": "png"}, + "table": [], + "chart": [], + "infographic": [], + } + ] + ) + + stored = store_extracted_images(df, storage_uri=str(tmp_path)) + assert stored.iloc[0]["page_image"]["image_b64"] is None + + result = collapse_content_to_page_rows(stored, modality="text_image") + loaded_b64 = result["_image_b64"].iloc[0] + assert isinstance(loaded_b64, str) and len(loaded_b64) > 0 From fcfc67f16c959a99e7de22a2f2bc0e056f69a01a Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Thu, 26 Mar 2026 15:48:14 +0000 Subject: [PATCH 4/4] =?UTF-8?q?(retriever)=20Fix=20relative=20path=20URI?= =?UTF-8?q?=20error=20in=20.store()=20=E2=80=94=20resolve=20=20=20storage?= =?UTF-8?q?=5Furi=20to=20absolute?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Jacob Ioffe --- nemo_retriever/src/nemo_retriever/io/image_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_retriever/src/nemo_retriever/io/image_store.py b/nemo_retriever/src/nemo_retriever/io/image_store.py index 3d19d4d15..1109b7894 100644 --- a/nemo_retriever/src/nemo_retriever/io/image_store.py +++ b/nemo_retriever/src/nemo_retriever/io/image_store.py @@ -233,7 +233,7 @@ def store_extracted_images( return df logger.info("Storing extracted images to %s", storage_uri) - storage_root = UPath(storage_uri, **(storage_options or {})) + storage_root = UPath(storage_uri, **(storage_options or {})).resolve() ext = _normalize_image_format(image_format) if strip_base64: logger.debug("strip_base64=True: image payloads will be cleared after writing.")