2 changes: 2 additions & 0 deletions nemo_retriever/pyproject.toml
@@ -76,6 +76,8 @@ dependencies = [
"soundfile>=0.12.0",
"scipy>=1.11.0",
"nvidia-ml-py",
"fsspec>=2023.1.0",
"universal-pathlib>=0.2.0",
"vllm==0.16.0",
]

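The two new dependencies are complementary: fsspec supplies the pluggable filesystem backends (local paths, s3://, gs://, and so on) while universal-pathlib wraps them in a pathlib-style UPath object. A minimal sketch of the assumed usage pattern (not part of this diff; bucket and file names are placeholders):

from upath import UPath

# UPath dispatches to the matching fsspec backend based on the URI scheme,
# so the same code handles local directories and object stores.
dest = UPath("s3://bucket/prefix") / "images" / "page-0.png"
dest.parent.mkdir(parents=True, exist_ok=True)  # effectively a no-op on object stores
dest.write_bytes(b"...decoded image bytes...")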
11 changes: 11 additions & 0 deletions nemo_retriever/src/nemo_retriever/examples/batch_pipeline.py
@@ -24,6 +24,7 @@
from nemo_retriever.params import CaptionParams
from nemo_retriever.params import DedupParams
from nemo_retriever.params import EmbedParams
from nemo_retriever.params import StoreParams
from nemo_retriever.params import ExtractParams
from nemo_retriever.params import TextChunkParams
from nemo_retriever.params.models import BatchTuningParams
@@ -329,6 +330,11 @@ def main(
table_structure_invoke_url: Optional[str] = typer.Option(
None, "--table-structure-invoke-url", help="Remote endpoint URL for table-structure model inference."
),
store_images_uri: Optional[str] = typer.Option(
None,
"--store-images-uri",
help="When set, store extracted images to this URI (local path or fsspec URI like s3://bucket/prefix).",
),
extract_text: bool = typer.Option(
True, "--extract-text/--no-extract-text", help="Enable text extraction from documents."
),
@@ -549,6 +555,9 @@ def main(
if enable_text_chunk:
ingestor = ingestor.split(text_chunk_params)

if store_images_uri:
ingestor = ingestor.store(StoreParams(storage_uri=store_images_uri))

enable_caption = caption or caption_invoke_url is not None
enable_dedup = dedup if dedup is not None else enable_caption
if enable_dedup:
@@ -614,6 +623,8 @@ def main(
handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite")
lancedb_write_time = time.perf_counter() - lancedb_write_start

del ingest_local_results # free driver heap before recall

from nemo_retriever.model import resolve_embed_model
from nemo_retriever.recall.beir import BeirConfig
from nemo_retriever.recall.core import RecallConfig, retrieve_and_score
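The new flag is additive; existing invocations are unchanged. A hypothetical example, assuming the script is run as a module (all other arguments elided):

python -m nemo_retriever.examples.batch_pipeline \
    --store-images-uri s3://my-bucket/extracted-images \
    ...

Per the help text, a plain local path such as /tmp/extracted-images is accepted as well as an fsspec URI.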
20 changes: 20 additions & 0 deletions nemo_retriever/src/nemo_retriever/ingest_modes/batch.py
@@ -54,6 +54,7 @@
from ..params import IngestExecuteParams
from ..params import PdfSplitParams
from ..params import TextChunkParams
from ..params import StoreParams
from ..params import CaptionParams
from ..params import VdbUploadParams

@@ -1037,6 +1038,25 @@ def embed(

return self

def store(self, params: StoreParams | None = None, **kwargs: Any) -> "BatchIngestor":
"""Store extracted images to disk or cloud storage via fsspec."""
if self._rd_dataset is None:
raise RuntimeError("No Ray Dataset to store from. Call .files(...) / .extract(...) first.")

from nemo_retriever.io.image_store import store_extracted_images

resolved = _coerce_params(params, StoreParams, kwargs)
store_kwargs = resolved.model_dump(mode="python")
self._tasks.append(("store", dict(store_kwargs)))

_store_fn = partial(store_extracted_images, **store_kwargs)
self._rd_dataset = self._rd_dataset.map_batches(
_store_fn,
batch_format="pandas",
num_cpus=1,
)
return self

def dedup(self, params: "DedupParams | None" = None, **kwargs: Any) -> "BatchIngestor":
"""Remove duplicate and overlapping images before captioning."""
if self._rd_dataset is None:
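The new stage keeps the fluent-chaining contract of the surrounding methods. A sketch of the assumed call pattern (the .files(...) and .extract(...) names are taken from the RuntimeError message above; their arguments are placeholders):

ingestor = (
    BatchIngestor()
    .files("docs/*.pdf")
    .extract()
    .store(StoreParams(storage_uri="s3://bucket/extracted-images"))
)

Since the stage is applied with map_batches(batch_format="pandas", num_cpus=1), the image writes execute on Ray workers as the dataset streams through, rather than on the driver.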
70 changes: 52 additions & 18 deletions nemo_retriever/src/nemo_retriever/ingest_modes/inprocess.py
@@ -60,7 +60,9 @@
from ..params import HtmlChunkParams
from ..params import IngestExecuteParams
from ..params import TextChunkParams
from ..params import StoreParams
from ..params import VdbUploadParams
from ..io.image_store import load_image_b64_from_uri
from ..pdf.extract import pdf_extraction
from ..pdf.split import _split_pdf_to_single_page_bytes, pdf_path_to_pages_df
from ..utils.remote_auth import resolve_remote_api_key
@@ -109,6 +111,41 @@ def _deep_copy_row(row_dict: Dict[str, Any]) -> Dict[str, Any]:
return out


def _b64_from_dict(d: dict) -> Optional[str]:
"""Resolve base64 from a dict: inline ``image_b64`` then ``stored_image_uri``."""
b64 = d.get("image_b64")
if isinstance(b64, str) and b64:
return b64
uri = d.get("stored_image_uri")
if isinstance(uri, str) and uri:
return load_image_b64_from_uri(uri)
return None


def _resolve_page_image_b64(page_image: Any) -> Optional[str]:
"""Return base64 for a page image, loading from stored URI if needed."""
if not isinstance(page_image, dict):
return None
return _b64_from_dict(page_image)


def _resolve_item_image_b64(item: dict, page_image_b64: Optional[str]) -> Optional[str]:
"""Return base64 for a structured content item, with fallback chain.

Priority: inline base64 > stored URI > bbox crop from page image > full page image.
"""
resolved = _b64_from_dict(item)
if resolved:
return resolved
if page_image_b64:
bbox = item.get("bbox_xyxy_norm")
if bbox and len(bbox) == 4:
cropped_b64, _ = _crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox)
return cropped_b64
return page_image_b64
return None


def explode_content_to_rows(
batch_df: Any,
*,
@@ -167,9 +204,7 @@ def explode_content_to_rows(
if not any(c in batch_df.columns for c in content_columns):
batch_df = batch_df.copy()
if text_mod in IMAGE_MODALITIES and "page_image" in batch_df.columns:
batch_df["_image_b64"] = batch_df["page_image"].apply(
lambda pi: pi.get("image_b64") if isinstance(pi, dict) else None
)
batch_df["_image_b64"] = batch_df["page_image"].apply(_resolve_page_image_b64)
batch_df["_embed_modality"] = text_mod
return batch_df

@@ -181,8 +216,8 @@
# Extract page-level image b64 once per source row.
page_image = row_dict.get("page_image")
page_image_b64: Optional[str] = None
- if any_images and isinstance(page_image, dict):
- page_image_b64 = page_image.get("image_b64")
+ if any_images:
+ page_image_b64 = _resolve_page_image_b64(page_image)

# Row for page text.
page_text = row_dict.get(text_column)
@@ -210,15 +245,8 @@
content_row[text_column] = t.strip()
content_row["_embed_modality"] = struct_mod
content_row["_content_type"] = col
- if struct_mod in IMAGE_MODALITIES and page_image_b64:
- bbox = item.get("bbox_xyxy_norm")
- if bbox and len(bbox) == 4:
- cropped_b64, _ = _crop_b64_image_by_norm_bbox(page_image_b64, bbox_xyxy_norm=bbox)
- content_row["_image_b64"] = cropped_b64
- else:
- content_row["_image_b64"] = page_image_b64
- elif struct_mod in IMAGE_MODALITIES:
- content_row["_image_b64"] = None
+ if struct_mod in IMAGE_MODALITIES:
+ content_row["_image_b64"] = _resolve_item_image_b64(item, page_image_b64)
new_rows.append(content_row)
exploded_any = True

@@ -271,9 +299,7 @@ def collapse_content_to_page_rows(
# Full page image (no cropping) for image modalities.
if modality in IMAGE_MODALITIES:
if "page_image" in batch_df.columns:
batch_df["_image_b64"] = batch_df["page_image"].apply(
lambda pi: pi.get("image_b64") if isinstance(pi, dict) else None
)
batch_df["_image_b64"] = batch_df["page_image"].apply(_resolve_page_image_b64)
else:
batch_df["_image_b64"] = None

@@ -1337,6 +1363,15 @@ def extract_audio(
self._tasks.append((apply_asr_to_df, {"asr_params": self._extract_audio_asr_kwargs}))
return self

def store(self, params: "StoreParams | None" = None, **kwargs: Any) -> "InProcessIngestor":
"""Store extracted images to disk or cloud storage via fsspec."""
from nemo_retriever.io.image_store import store_extracted_images

resolved = _coerce_params(params, StoreParams, kwargs)
store_kwargs = resolved.model_dump(mode="python")
self._tasks.append((store_extracted_images, store_kwargs))
return self

def dedup(self, params: "DedupParams | None" = None, **kwargs: Any) -> "InProcessIngestor":
"""Remove duplicate and overlapping images before captioning."""

@@ -1356,7 +1391,6 @@ def caption(self, params: "CaptionParams | None" = None, **kwargs: Any) -> "InProcessIngestor":
Otherwise a local ``NemotronVLMCaptioner`` is loaded from HF.
"""
from nemo_retriever.caption.caption import caption_images
- from nemo_retriever.params import CaptionParams

resolved = _coerce_params(params, CaptionParams, kwargs)
caption_kwargs = resolved.model_dump(mode="python")
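Taken together, the three helpers give every image consumer a single resolution order. A short illustration of _resolve_item_image_b64 with hypothetical values:

item = {
    "stored_image_uri": "s3://bucket/img-42.png",  # no inline "image_b64"
    "bbox_xyxy_norm": [0.1, 0.1, 0.5, 0.5],
}
# 1. inline "image_b64" missing   -> fall through
# 2. "stored_image_uri" present   -> load_image_b64_from_uri() wins here
# 3. else, with a 4-element bbox  -> crop from the page image
# 4. bbox missing or malformed    -> full page image, or None if absent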
4 changes: 3 additions & 1 deletion nemo_retriever/src/nemo_retriever/ingestor.py
@@ -28,6 +28,7 @@
from nemo_retriever.params import IngestExecuteParams
from nemo_retriever.params import IngestorCreateParams
from nemo_retriever.params import RunMode
from nemo_retriever.params import StoreParams
from nemo_retriever.params import VdbUploadParams


@@ -141,8 +142,9 @@ def split(self, params: TextChunkParams | None = None, **kwargs: Any) -> "ingestor":
_ = _merge_params(params, kwargs)
self._not_implemented("split")

- def store(self) -> "ingestor":
+ def store(self, params: StoreParams | None = None, **kwargs: Any) -> "ingestor":
"""Record a store task configuration."""
+ _ = _merge_params(params, kwargs)
self._not_implemented("store")

def store_embed(self) -> "ingestor":
3 changes: 3 additions & 0 deletions nemo_retriever/src/nemo_retriever/io/__init__.py
@@ -3,13 +3,16 @@
# SPDX-License-Identifier: Apache-2.0

from .dataframe import read_dataframe, validate_primitives_dataframe, write_dataframe
from .image_store import load_image_b64_from_uri, store_extracted_images
from .markdown import to_markdown, to_markdown_by_page
from .stage_files import build_stage_output_path, find_stage_inputs

__all__ = [
"build_stage_output_path",
"find_stage_inputs",
"load_image_b64_from_uri",
"read_dataframe",
"store_extracted_images",
"to_markdown",
"to_markdown_by_page",
"validate_primitives_dataframe",
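The io/image_store.py module itself is not included in this diff. For orientation, a minimal sketch of what the two exported helpers could look like, assuming StoreParams carries a storage_uri field (the only field visible in this PR) and that images travel as base64 inside dict columns such as page_image:

import base64
import uuid

import pandas as pd
from upath import UPath


def store_extracted_images(batch_df: pd.DataFrame, *, storage_uri: str, **_: object) -> pd.DataFrame:
    """Write inline base64 images under storage_uri, recording stored_image_uri."""
    root = UPath(storage_uri)
    root.mkdir(parents=True, exist_ok=True)

    def _offload(pi: object) -> object:
        if isinstance(pi, dict) and pi.get("image_b64"):
            dest = root / f"{uuid.uuid4().hex}.png"
            dest.write_bytes(base64.b64decode(pi["image_b64"]))
            # Drop the inline payload; readers fall back to the stored URI.
            return {**pi, "image_b64": None, "stored_image_uri": str(dest)}
        return pi

    if "page_image" in batch_df.columns:
        batch_df = batch_df.copy()
        batch_df["page_image"] = batch_df["page_image"].apply(_offload)
    return batch_df


def load_image_b64_from_uri(uri: str) -> str:
    """Read image bytes back from a local path or fsspec URI as base64."""
    return base64.b64encode(UPath(uri).read_bytes()).decode("ascii")

This mirrors the fallback chain in inprocess.py: rows keep a stored_image_uri pointer, and _b64_from_dict() rehydrates the base64 on demand.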