Skip to content

Commit 010f9a7

Browse files
committed
INTPYTHON-752 Integrate pymongo-vectorsearch-utils
1 parent 90424ed commit 010f9a7

File tree

17 files changed

+157
-1041
lines changed

17 files changed

+157
-1041
lines changed

libs/langchain-mongodb/langchain_mongodb/agent_toolkit/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
self._sample_docs_in_coll_info = sample_docs_in_collection_info
6565
self._indexes_in_coll_info = indexes_in_collection_info
6666

67-
_append_client_metadata(self._client)
67+
_append_client_metadata(self._client, DRIVER_METADATA)
6868

6969
@classmethod
7070
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/chat_message_histories.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(
112112
if connection_string:
113113
raise ValueError("Must provide connection_string or client, not both")
114114
self.client = client
115-
_append_client_metadata(self.client)
115+
_append_client_metadata(self.client, DRIVER_METADATA)
116116
elif connection_string:
117117
try:
118118
self.client = MongoClient(

libs/langchain-mongodb/langchain_mongodb/docstores.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, collection: Collection, text_key: str = "page_content") -> No
3737
self.collection = collection
3838
self._text_key = text_key
3939

40-
_append_client_metadata(self.collection.database.client)
40+
_append_client_metadata(self.collection.database.client, DRIVER_METADATA)
4141

4242
@classmethod
4343
def from_connection_string(
@@ -99,12 +99,13 @@ def mset(
9999
batch_size: Number of documents to insert at a time.
100100
Tuning this may help with performance and sidestep MongoDB limits.
101101
"""
102-
keys, docs = zip(*key_value_pairs)
102+
keys, docs = zip(*key_value_pairs, strict=False)
103103
n_docs = len(docs)
104104
start = 0
105105
for end in range(batch_size, n_docs + batch_size, batch_size):
106106
texts, metadatas = zip(
107-
*[(doc.page_content, doc.metadata) for doc in docs[start:end]]
107+
*[(doc.page_content, doc.metadata) for doc in docs[start:end]],
108+
strict=False,
108109
)
109110
self.insert_many(texts=texts, metadatas=metadatas, ids=keys[start:end]) # type: ignore
110111
start = end
@@ -149,6 +150,7 @@ def insert_many(
149150
in the batch that do not have conflicting _ids will still be inserted.
150151
"""
151152
to_insert = [
152-
{"_id": i, self._text_key: t, **m} for i, t, m in zip(ids, texts, metadatas)
153+
{"_id": i, self._text_key: t, **m}
154+
for i, t, m in zip(ids, texts, metadatas, strict=False)
153155
]
154156
self.collection.insert_many(to_insert) # type: ignore

libs/langchain-mongodb/langchain_mongodb/graphrag/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def __init__(
186186
self.collection = collection
187187

188188
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
189-
_append_client_metadata(collection.database.client)
189+
_append_client_metadata(collection.database.client, DRIVER_METADATA)
190190

191191
self.entity_extraction_model = entity_extraction_model
192192
self.entity_prompt = (

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 7 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
from pymongo.collection import Collection
88
from pymongo.operations import SearchIndexModel
99

10+
# Don't break imports for modules that expect these functions
11+
# to be in this module.
12+
from pymongo_search_utils import ( # noqa: F401
13+
create_vector_search_index,
14+
update_vector_search_index,
15+
)
16+
1017
logger = logging.getLogger(__file__)
1118

1219

@@ -34,60 +41,6 @@ def _vector_search_index_definition(
3441
return definition
3542

3643

37-
def create_vector_search_index(
38-
collection: Collection,
39-
index_name: str,
40-
dimensions: int,
41-
path: str,
42-
similarity: str,
43-
filters: Optional[List[str]] = None,
44-
*,
45-
wait_until_complete: Optional[float] = None,
46-
**kwargs: Any,
47-
) -> None:
48-
"""Experimental Utility function to create a vector search index
49-
50-
Args:
51-
collection (Collection): MongoDB Collection
52-
index_name (str): Name of Index
53-
dimensions (int): Number of dimensions in embedding
54-
path (str): field with vector embedding
55-
similarity (str): The similarity score used for the index
56-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
57-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
58-
until search index is ready.
59-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
60-
"""
61-
logger.info("Creating Search Index %s on %s", index_name, collection.name)
62-
63-
if collection.name not in collection.database.list_collection_names(
64-
authorizedCollections=True
65-
):
66-
collection.database.create_collection(collection.name)
67-
68-
result = collection.create_search_index(
69-
SearchIndexModel(
70-
definition=_vector_search_index_definition(
71-
dimensions=dimensions,
72-
path=path,
73-
similarity=similarity,
74-
filters=filters,
75-
**kwargs,
76-
),
77-
name=index_name,
78-
type="vectorSearch",
79-
)
80-
)
81-
82-
if wait_until_complete:
83-
_wait_for_predicate(
84-
predicate=lambda: _is_index_ready(collection, index_name),
85-
err=f"{index_name=} did not complete in {wait_until_complete}!",
86-
timeout=wait_until_complete,
87-
)
88-
logger.info(result)
89-
90-
9144
def drop_vector_search_index(
9245
collection: Collection,
9346
index_name: str,
@@ -115,54 +68,6 @@ def drop_vector_search_index(
11568
logger.info("Vector Search index %s.%s dropped", collection.name, index_name)
11669

11770

118-
def update_vector_search_index(
119-
collection: Collection,
120-
index_name: str,
121-
dimensions: int,
122-
path: str,
123-
similarity: str,
124-
filters: Optional[List[str]] = None,
125-
*,
126-
wait_until_complete: Optional[float] = None,
127-
**kwargs: Any,
128-
) -> None:
129-
"""Update a search index.
130-
131-
Replace the existing index definition with the provided definition.
132-
133-
Args:
134-
collection (Collection): MongoDB Collection
135-
index_name (str): Name of Index
136-
dimensions (int): Number of dimensions in embedding
137-
path (str): field with vector embedding
138-
similarity (str): The similarity score used for the index.
139-
filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch
140-
wait_until_complete (Optional[float]): If provided, number of seconds to wait
141-
until search index is ready.
142-
kwargs: Keyword arguments supplying any additional options to SearchIndexModel.
143-
"""
144-
logger.info(
145-
"Updating Search Index %s from Collection: %s", index_name, collection.name
146-
)
147-
collection.update_search_index(
148-
name=index_name,
149-
definition=_vector_search_index_definition(
150-
dimensions=dimensions,
151-
path=path,
152-
similarity=similarity,
153-
filters=filters,
154-
**kwargs,
155-
),
156-
)
157-
if wait_until_complete:
158-
_wait_for_predicate(
159-
predicate=lambda: _is_index_ready(collection, index_name),
160-
err=f"Index {index_name} update did not complete in {wait_until_complete}!",
161-
timeout=wait_until_complete,
162-
)
163-
logger.info("Update succeeded")
164-
165-
16671
def _is_index_ready(collection: Collection, index_name: str) -> bool:
16772
"""Check for the index name in the list of available search indexes to see if the
16873
specified index is of status READY

libs/langchain-mongodb/langchain_mongodb/indexes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, collection: Collection) -> None:
3636
super().__init__(namespace=namespace)
3737
self._collection = collection
3838

39-
_append_client_metadata(self._collection.database.client)
39+
_append_client_metadata(self._collection.database.client, DRIVER_METADATA)
4040

4141
@classmethod
4242
def from_connection_string(
@@ -85,7 +85,7 @@ def update(
8585
if len(keys) != len(group_ids):
8686
raise ValueError("Number of keys does not match number of group_ids")
8787

88-
for key, group_id in zip(keys, group_ids):
88+
for key, group_id in zip(keys, group_ids, strict=False):
8989
self._collection.find_one_and_update(
9090
{"namespace": self.namespace, "key": key},
9191
{"$set": {"group_id": group_id, "updated_at": self.get_time()}},

libs/langchain-mongodb/langchain_mongodb/loaders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(
5454
self.include_db_collection_in_metadata = include_db_collection_in_metadata
5555

5656
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
57-
_append_client_metadata(self.db.client)
57+
_append_client_metadata(self.db.client, DRIVER_METADATA)
5858

5959
@classmethod
6060
def from_connection_string(

libs/langchain-mongodb/langchain_mongodb/retrievers/full_text_search.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from pymongo.collection import Collection
99

1010
from langchain_mongodb.pipelines import text_search_stage
11-
from langchain_mongodb.utils import _append_client_metadata, make_serializable
11+
from langchain_mongodb.utils import (
12+
DRIVER_METADATA,
13+
_append_client_metadata,
14+
make_serializable,
15+
)
1216

1317

1418
class MongoDBAtlasFullTextSearchRetriever(BaseRetriever):
@@ -64,7 +68,7 @@ def _get_relevant_documents(
6468
)
6569

6670
if not self._added_metadata:
67-
_append_client_metadata(self.collection.database.client)
71+
_append_client_metadata(self.collection.database.client, DRIVER_METADATA)
6872
self._added_metadata = True
6973

7074
# Execution

libs/langchain-mongodb/langchain_mongodb/utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,21 @@
2424
from typing import Any, Dict, List, Union
2525

2626
import numpy as np
27-
from pymongo import MongoClient
2827
from pymongo.driver_info import DriverInfo
2928

29+
# Don't break imports for modules that expect this function
30+
# to be in this module.
31+
from pymongo_search_utils import (
32+
append_client_metadata as _append_client_metadata, # noqa: F401
33+
)
34+
3035
logger = logging.getLogger(__name__)
3136

3237
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
3338

3439
DRIVER_METADATA = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
3540

3641

37-
def _append_client_metadata(client: MongoClient) -> None:
38-
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
39-
if callable(client.append_metadata):
40-
client.append_metadata(DRIVER_METADATA)
41-
42-
4342
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
4443
"""Row-wise cosine similarity between two equal-width matrices."""
4544
if len(X) == 0 or len(Y) == 0:

0 commit comments

Comments
 (0)