diff --git a/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py b/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py
index 31e8c5ef..9c6b722c 100644
--- a/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py
+++ b/libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py
@@ -64,7 +64,7 @@ def __init__(
         self._sample_docs_in_coll_info = sample_docs_in_collection_info
         self._indexes_in_coll_info = indexes_in_collection_info
 
-        _append_client_metadata(self._client)
+        _append_client_metadata(self._client, DRIVER_METADATA)
 
     @classmethod
     def from_connection_string(
diff --git a/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py b/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py
index 6ebd96bd..2f9a71c1 100644
--- a/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py
+++ b/libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py
@@ -112,7 +112,7 @@ def __init__(
             if connection_string:
                 raise ValueError("Must provide connection_string or client, not both")
             self.client = client
-            _append_client_metadata(self.client)
+            _append_client_metadata(self.client, DRIVER_METADATA)
         elif connection_string:
             try:
                 self.client = MongoClient(
diff --git a/libs/langchain-mongodb/langchain_mongodb/docstores.py b/libs/langchain-mongodb/langchain_mongodb/docstores.py
index f24b1f5b..23cd61d4 100644
--- a/libs/langchain-mongodb/langchain_mongodb/docstores.py
+++ b/libs/langchain-mongodb/langchain_mongodb/docstores.py
@@ -37,7 +37,7 @@ def __init__(self, collection: Collection, text_key: str = "page_content") -> No
         self.collection = collection
         self._text_key = text_key
 
-        _append_client_metadata(self.collection.database.client)
+        _append_client_metadata(self.collection.database.client, DRIVER_METADATA)
 
     @classmethod
     def from_connection_string(
diff --git a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
index 52417444..ac283f9a 100644
--- a/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
+++ b/libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py
@@ -186,7 +186,7 @@ def __init__(
 
         self.collection = collection
         # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
-        _append_client_metadata(collection.database.client)
+        _append_client_metadata(collection.database.client, DRIVER_METADATA)
 
         self.entity_extraction_model = entity_extraction_model
         self.entity_prompt = (
diff --git a/libs/langchain-mongodb/langchain_mongodb/index.py b/libs/langchain-mongodb/langchain_mongodb/index.py
index fd53956c..a5a4d705 100644
--- a/libs/langchain-mongodb/langchain_mongodb/index.py
+++ b/libs/langchain-mongodb/langchain_mongodb/index.py
@@ -7,6 +7,13 @@
 from pymongo.collection import Collection
 from pymongo.operations import SearchIndexModel
 
+# Don't break imports for modules that expect these functions
+# to be in this module.
+from pymongo_search_utils import (  # noqa: F401
+    create_vector_search_index,
+    update_vector_search_index,
+)
+
 logger = logging.getLogger(__file__)
 
 
@@ -34,60 +41,6 @@ def _vector_search_index_definition(
     return definition
 
 
-def create_vector_search_index(
-    collection: Collection,
-    index_name: str,
-    dimensions: int,
-    path: str,
-    similarity: str,
-    filters: Optional[List[str]] = None,
-    *,
-    wait_until_complete: Optional[float] = None,
-    **kwargs: Any,
-) -> None:
-    """Experimental Utility function to create a vector search index
-
-    Args:
-        collection (Collection): MongoDB Collection
-        index_name (str): Name of Index
-        dimensions (int): Number of dimensions in embedding
-        path (str): field with vector embedding
-        similarity (str): The similarity score used for the index
-        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
-        wait_until_complete (Optional[float]): If provided, number of seconds to wait
-            until search index is ready.
-        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
-    """
-    logger.info("Creating Search Index %s on %s", index_name, collection.name)
-
-    if collection.name not in collection.database.list_collection_names(
-        authorizedCollections=True
-    ):
-        collection.database.create_collection(collection.name)
-
-    result = collection.create_search_index(
-        SearchIndexModel(
-            definition=_vector_search_index_definition(
-                dimensions=dimensions,
-                path=path,
-                similarity=similarity,
-                filters=filters,
-                **kwargs,
-            ),
-            name=index_name,
-            type="vectorSearch",
-        )
-    )
-
-    if wait_until_complete:
-        _wait_for_predicate(
-            predicate=lambda: _is_index_ready(collection, index_name),
-            err=f"{index_name=} did not complete in {wait_until_complete}!",
-            timeout=wait_until_complete,
-        )
-    logger.info(result)
-
-
 def drop_vector_search_index(
     collection: Collection,
     index_name: str,
@@ -115,54 +68,6 @@ def drop_vector_search_index(
     logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
 
 
-def update_vector_search_index(
-    collection: Collection,
-    index_name: str,
-    dimensions: int,
-    path: str,
-    similarity: str,
-    filters: Optional[List[str]] = None,
-    *,
-    wait_until_complete: Optional[float] = None,
-    **kwargs: Any,
-) -> None:
-    """Update a search index.
-
-    Replace the existing index definition with the provided definition.
-
-    Args:
-        collection (Collection): MongoDB Collection
-        index_name (str): Name of Index
-        dimensions (int): Number of dimensions in embedding
-        path (str): field with vector embedding
-        similarity (str): The similarity score used for the index.
-        filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
-        wait_until_complete (Optional[float]): If provided, number of seconds to wait
-            until search index is ready.
-        kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
- """ - logger.info( - "Updating Search Index %s from Collection: %s", index_name, collection.name - ) - collection.update_search_index( - name=index_name, - definition=_vector_search_index_definition( - dimensions=dimensions, - path=path, - similarity=similarity, - filters=filters, - **kwargs, - ), - ) - if wait_until_complete: - _wait_for_predicate( - predicate=lambda: _is_index_ready(collection, index_name), - err=f"Index {index_name} update did not complete in {wait_until_complete}!", - timeout=wait_until_complete, - ) - logger.info("Update succeeded") - - def _is_index_ready(collection: Collection, index_name: str) -> bool: """Check for the index name in the list of available search indexes to see if the specified index is of status READY diff --git a/libs/langchain-mongodb/langchain_mongodb/indexes.py b/libs/langchain-mongodb/langchain_mongodb/indexes.py index 7f5c00a0..8f86b8a4 100644 --- a/libs/langchain-mongodb/langchain_mongodb/indexes.py +++ b/libs/langchain-mongodb/langchain_mongodb/indexes.py @@ -36,7 +36,7 @@ def __init__(self, collection: Collection) -> None: super().__init__(namespace=namespace) self._collection = collection - _append_client_metadata(self._collection.database.client) + _append_client_metadata(self._collection.database.client, DRIVER_METADATA) @classmethod def from_connection_string( diff --git a/libs/langchain-mongodb/langchain_mongodb/loaders.py b/libs/langchain-mongodb/langchain_mongodb/loaders.py index 2fc42900..934ba8a6 100644 --- a/libs/langchain-mongodb/langchain_mongodb/loaders.py +++ b/libs/langchain-mongodb/langchain_mongodb/loaders.py @@ -54,7 +54,7 @@ def __init__( self.include_db_collection_in_metadata = include_db_collection_in_metadata # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions - _append_client_metadata(self.db.client) + _append_client_metadata(self.db.client, DRIVER_METADATA) @classmethod def from_connection_string( diff --git a/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py b/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py index 087a4cd0..e30b0ddd 100644 --- a/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py +++ b/libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py @@ -8,7 +8,11 @@ from pymongo.collection import Collection from langchain_mongodb.pipelines import text_search_stage -from langchain_mongodb.utils import _append_client_metadata, make_serializable +from langchain_mongodb.utils import ( + DRIVER_METADATA, + _append_client_metadata, + make_serializable, +) class MongoDBAtlasFullTextSearchRetriever(BaseRetriever): @@ -64,7 +68,7 @@ def _get_relevant_documents( ) if not self._added_metadata: - _append_client_metadata(self.collection.database.client) + _append_client_metadata(self.collection.database.client, DRIVER_METADATA) self._added_metadata = True # Execution diff --git a/libs/langchain-mongodb/langchain_mongodb/utils.py b/libs/langchain-mongodb/langchain_mongodb/utils.py index 826ce56e..847d411e 100644 --- a/libs/langchain-mongodb/langchain_mongodb/utils.py +++ b/libs/langchain-mongodb/langchain_mongodb/utils.py @@ -24,9 +24,14 @@ from typing import Any, Dict, List, Union import numpy as np -from pymongo import MongoClient from pymongo.driver_info import DriverInfo +# Don't break imports for modules that expect this function +# to be in this module. 
+from pymongo_search_utils import (
+    append_client_metadata as _append_client_metadata,  # noqa: F401
+)
+
 logger = logging.getLogger(__name__)
 
 Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
@@ -34,12 +39,6 @@
 DRIVER_METADATA = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
 
 
-def _append_client_metadata(client: MongoClient) -> None:
-    # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
-    if callable(client.append_metadata):
-        client.append_metadata(DRIVER_METADATA)
-
-
 def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
     """Row-wise cosine similarity between two equal-width matrices."""
     if len(X) == 0 or len(Y) == 0:
diff --git a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py
index a1adf7e5..acf21d7e 100644
--- a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py
+++ b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py
@@ -22,9 +22,10 @@
 from langchain_core.embeddings import Embeddings
 from langchain_core.runnables.config import run_in_executor
 from langchain_core.vectorstores import VectorStore
-from pymongo import MongoClient, ReplaceOne
+from pymongo import MongoClient
 from pymongo.collection import Collection
 from pymongo.errors import CollectionInvalid
+from pymongo_search_utils import bulk_embed_and_insert_texts
 
 from langchain_mongodb.index import (
     create_vector_search_index,
@@ -238,7 +239,7 @@ def __init__(
         self._relevance_score_fn = relevance_score_fn
 
         # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
-        _append_client_metadata(self._collection.database.client)
+        _append_client_metadata(self._collection.database.client, DRIVER_METADATA)
 
         if auto_create_index is False:
             return
@@ -362,12 +363,23 @@ def add_texts(
             metadatas_batch.append(metadata)
             if (j + 1) % batch_size == 0 or size >= 47_000_000:
                 if ids:
-                    batch_res = self.bulk_embed_and_insert_texts(
-                        texts_batch, metadatas_batch, ids[i : j + 1]
+                    batch_res = bulk_embed_and_insert_texts(
+                        embedding_func=self._embedding.embed_documents,
+                        collection=self._collection,
+                        text_key=self._text_key,
+                        embedding_key=self._embedding_key,
+                        texts=texts_batch,
+                        metadatas=metadatas_batch,
+                        ids=ids[i : j + 1],
                     )
                 else:
-                    batch_res = self.bulk_embed_and_insert_texts(
-                        texts_batch, metadatas_batch
+                    batch_res = bulk_embed_and_insert_texts(
+                        embedding_func=self._embedding.embed_documents,
+                        collection=self._collection,
+                        text_key=self._text_key,
+                        embedding_key=self._embedding_key,
+                        texts=texts_batch,
+                        metadatas=metadatas_batch,
                     )
                 result_ids.extend(batch_res)
                 texts_batch = []
@@ -376,12 +388,23 @@
                 i = j + 1
         if texts_batch:
             if ids:
-                batch_res = self.bulk_embed_and_insert_texts(
-                    texts_batch, metadatas_batch, ids[i : j + 1]
+                batch_res = bulk_embed_and_insert_texts(
+                    embedding_func=self._embedding.embed_documents,
+                    collection=self._collection,
+                    text_key=self._text_key,
+                    embedding_key=self._embedding_key,
+                    texts=texts_batch,
+                    metadatas=metadatas_batch,
+                    ids=ids[i : j + 1],
                 )
             else:
-                batch_res = self.bulk_embed_and_insert_texts(
-                    texts_batch, metadatas_batch
+                batch_res = bulk_embed_and_insert_texts(
+                    embedding_func=self._embedding.embed_documents,
+                    collection=self._collection,
+                    text_key=self._text_key,
+                    embedding_key=self._embedding_key,
+                    texts=texts_batch,
+                    metadatas=metadatas_batch,
                 )
             result_ids.extend(batch_res)
         return result_ids
@@ -419,39 +442,6 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
             docs.append(Document(page_content=text, id=oid_to_str(_id), metadata=doc))
         return docs
 
-    def bulk_embed_and_insert_texts(
-        self,
-        texts: Union[List[str], Iterable[str]],
-        metadatas: Union[List[dict], Generator[dict, Any, Any]],
-        ids: Optional[List[str]] = None,
-    ) -> List[str]:
-        """Bulk insert single batch of texts, embeddings, and optionally ids.
-
-        See add_texts for additional details.
-        """
-        if not texts:
-            return []
-        # Compute embedding vectors
-        embeddings = self._embedding.embed_documents(list(texts))
-        if not ids:
-            ids = [str(ObjectId()) for _ in range(len(list(texts)))]
-        docs = [
-            {
-                "_id": str_to_oid(i),
-                self._text_key: t,
-                self._embedding_key: embedding,
-                **m,
-            }
-            for i, t, m, embedding in zip(
-                ids, texts, metadatas, embeddings, strict=True
-            )
-        ]
-        operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
-        # insert the documents in MongoDB Atlas
-        result = self._collection.bulk_write(operations)
-        assert result.upserted_ids is not None
-        return [oid_to_str(_id) for _id in result.upserted_ids.values()]
-
     def add_documents(
         self,
         documents: List[Document],
@@ -484,8 +474,14 @@ def add_documents(
                 strict=True,
             )
             result_ids.extend(
-                self.bulk_embed_and_insert_texts(
-                    texts=texts, metadatas=metadatas, ids=ids[start:end]
+                bulk_embed_and_insert_texts(
+                    embedding_func=self._embedding.embed_documents,
+                    collection=self._collection,
+                    text_key=self._text_key,
+                    embedding_key=self._embedding_key,
+                    texts=texts,
+                    metadatas=metadatas,
+                    ids=ids[start:end],
                 )
             )
             start = end
diff --git a/libs/langchain-mongodb/pyproject.toml b/libs/langchain-mongodb/pyproject.toml
index 07ac4982..5414332e 100644
--- a/libs/langchain-mongodb/pyproject.toml
+++ b/libs/langchain-mongodb/pyproject.toml
@@ -42,6 +42,7 @@ dev = [
    "langchain-tests==0.3.22,<1.0",
    "pip>=25.0.1",
    "typing-extensions>=4.12.2",
+   "pymongo-search-utils@git+https://github.com/mongodb-labs/pymongo-search-utils.git",
 ]
 
 [tool.pytest.ini_options]
@@ -77,7 +78,7 @@ lint.select = [
    "B", # flake8-bugbear
    "I", # isort
 ]
-lint.ignore = ["E501", "B008", "UP007", "UP006", "UP035", "UP045"]
+lint.ignore = ["E501", "B008", "UP007", "UP006", "UP035", "UP038", "UP045"]
 
 [tool.coverage.run]
 omit = ["tests/*"]
diff --git a/libs/langchain-mongodb/tests/utils.py b/libs/langchain-mongodb/tests/utils.py
index 97482342..1a0ab51e 100644
--- a/libs/langchain-mongodb/tests/utils.py
+++ b/libs/langchain-mongodb/tests/utils.py
@@ -26,6 +26,7 @@
 from pymongo.driver_info import DriverInfo
 from pymongo.operations import SearchIndexModel
 from pymongo.results import BulkWriteResult, DeleteResult, InsertManyResult
+from pymongo_search_utils import bulk_embed_and_insert_texts
 
 from langchain_mongodb import MongoDBAtlasVectorSearch
 from langchain_mongodb.agent_toolkit.database import MongoDBDatabase
@@ -63,7 +64,9 @@ def bulk_embed_and_insert_texts(
         ids: Optional[List[str]] = None,
     ) -> List:
         """Patched insert_texts that waits for data to be indexed before returning"""
-        ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
+        ids_inserted = bulk_embed_and_insert_texts(
+            self.embedding, self.collection, self._embedding_field_config, texts, metadatas, ids
+        )
         n_docs = self.collection.count_documents({})
         start = monotonic()
         while monotonic() - start <= TIMEOUT:
diff --git a/libs/langchain-mongodb/uv.lock b/libs/langchain-mongodb/uv.lock
index f96e8612..e983bc25 100644
--- a/libs/langchain-mongodb/uv.lock
+++ b/libs/langchain-mongodb/uv.lock
@@ -792,6 +792,7 @@ dev = [
     { name = "mypy" },
     { name = "pip" },
     { name = "pre-commit" },
+    { name = "pymongo-search-utils" },
     { name = "pypdf" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
@@ -828,6 +829,7 @@ dev = [
     { name = "mypy", specifier = ">=1.10" },
     { name = "pip", specifier = ">=25.0.1" },
     { name = "pre-commit", specifier = ">=4.0" },
+    { name = "pymongo-search-utils", git = "https://github.com/mongodb-labs/pymongo-search-utils.git" },
     { name = "pypdf", specifier = ">=5.0.1" },
     { name = "pytest", specifier = ">=7.3.0" },
     { name = "pytest-asyncio", specifier = ">=0.21.1" },
@@ -1846,6 +1848,14 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/2d/fa/68b1555e62ed3ee87f8a2de99d5fb840cf045748da4488870b4dced44a95/pymongo-4.14.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e506af9b25aac77cc5c5ea4a72f81764e4f5ea90ca799aac43d665ab269f291d", size = 1011181, upload-time = "2025-08-06T13:40:48.641Z" },
 ]
 
+[[package]]
+name = "pymongo-search-utils"
+version = "0.1.0.dev0"
+source = { git = "https://github.com/mongodb-labs/pymongo-search-utils.git#6bb4c49e3c8e3dda5446b32b80d6c1fe17d4fcf4" }
+dependencies = [
+    { name = "pymongo" },
+]
+
 [[package]]
 name = "pypdf"
 version = "6.0.0"
diff --git a/uv.lock b/uv.lock
index 5ed8532e..c44abaf4 100644
--- a/uv.lock
+++ b/uv.lock
@@ -804,6 +804,7 @@ dev = [
     { name = "mypy", specifier = ">=1.10" },
     { name = "pip", specifier = ">=25.0.1" },
     { name = "pre-commit", specifier = ">=4.0" },
+    { name = "pymongo-search-utils", git = "https://github.com/mongodb-labs/pymongo-search-utils.git" },
     { name = "pypdf", specifier = ">=5.0.1" },
     { name = "pytest", specifier = ">=7.3.0" },
     { name = "pytest-asyncio", specifier = ">=0.21.1" },