Skip to content

Commit dcac415

Browse files
committed
[api][integration] Enrich interface for vector store.
1 parent 82a2c30 commit dcac415

File tree

3 files changed

+397
-65
lines changed

3 files changed

+397
-65
lines changed

python/flink_agents/api/vector_stores/vector_store.py

Lines changed: 158 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,17 @@ class VectorStoreQuery(BaseModel):
5959
namespaces, or other search configurations.
6060
"""
6161

62-
mode: VectorStoreQueryMode = Field(default=VectorStoreQueryMode.SEMANTIC, description="The type of query "
63-
"operation to perform.")
64-
query_text: str = Field(description="Text query to be converted to embedding for semantic search.")
62+
mode: VectorStoreQueryMode = Field(
63+
default=VectorStoreQueryMode.SEMANTIC,
64+
description="The type of query operation to perform.",
65+
)
66+
query_text: str = Field(
67+
description="Text query to be converted to embedding for semantic search."
68+
)
6569
limit: int = Field(default=10, description="Maximum number of results to return.")
66-
extra_args: Dict[str, Any] = Field(default_factory=dict, description="Vector store-specific parameters.")
70+
extra_args: Dict[str, Any] = Field(
71+
default_factory=dict, description="Vector store-specific parameters."
72+
)
6773

6874
def __str__(self) -> str:
6975
return f"{self.mode.value} query: '{self.query_text}' (limit={self.limit})"
@@ -85,11 +91,18 @@ class Document(BaseModel):
8591
"""
8692

8793
content: str = Field(description="The actual text content of the document.")
88-
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata such as source, author, timestamp, etc.")
89-
id: str | None = Field(default=None, description="Unique identifier of the document.")
94+
metadata: Dict[str, Any] = Field(
95+
default_factory=dict,
96+
description="Document metadata such as source, author, timestamp, etc.",
97+
)
98+
id: str | None = Field(
99+
default=None, description="Unique identifier of the document."
100+
)
90101

91102
def __str__(self) -> str:
92-
content_preview = self.content[:50] + "..." if len(self.content) > 50 else self.content
103+
content_preview = (
104+
self.content[:50] + "..." if len(self.content) > 50 else self.content
105+
)
93106
return f"Document: {content_preview}"
94107

95108

@@ -104,7 +117,9 @@ class VectorStoreQueryResult(BaseModel):
104117
List of documents retrieved from the vector store.
105118
"""
106119

107-
documents: List[Document] = Field(description="List of documents retrieved from the vector store.")
120+
documents: List[Document] = Field(
121+
description="List of documents retrieved from the vector store."
122+
)
108123

109124
def __str__(self) -> str:
110125
return f"QueryResult: {len(self.documents)} documents"
@@ -118,7 +133,9 @@ class BaseVectorStore(Resource, ABC):
118133
embedding generation internally.
119134
"""
120135

121-
embedding_model: str = Field(description="Name of the embedding model resource to use.")
136+
embedding_model: str = Field(
137+
description="Name of the embedding model resource to use."
138+
)
122139

123140
@classmethod
124141
@override
@@ -135,6 +152,40 @@ def store_kwargs(self) -> Dict[str, Any]:
135152
when performing vector search operations.
136153
"""
137154

155+
def add(
156+
self, documents: Document | List[Document], collection_name: str | None = None, **kwargs: Any
157+
) -> List[str]:
158+
"""Add documents to the vector store.
159+
160+
Converts document content to embeddings and stores them in the vector store.
161+
The implementation may generate IDs for documents that don't have one.
162+
163+
Args:
164+
documents: Single Document or list of Documents to add.
165+
collection_name: The collection name of the documents to add to. Optional.
166+
**kwargs: Vector store specific parameters.
167+
168+
Returns:
169+
List of document IDs that were added to the vector store
170+
"""
171+
# Normalize to list
172+
documents = maybe_cast_to_list(documents)
173+
174+
# Generate embeddings for all documents
175+
embedding_model = self.get_resource(
176+
self.embedding_model, ResourceType.EMBEDDING_MODEL
177+
)
178+
179+
# Generate embeddings for each document
180+
embeddings = [embedding_model.embed(doc.content) for doc in documents]
181+
182+
# Merge setup kwargs with add-specific args
183+
merged_kwargs = self.store_kwargs.copy()
184+
merged_kwargs.update(kwargs)
185+
186+
# Perform add operation using the abstract method
187+
return self.add_embedding(documents=documents, embeddings=embeddings, collection_name=collection_name, **merged_kwargs)
188+
138189
def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
139190
"""Perform vector search using structured query object.
140191
@@ -160,12 +211,35 @@ def query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
160211
documents = self.query_embedding(query_embedding, query.limit, **merged_kwargs)
161212

162213
# Return structured result
163-
return VectorStoreQueryResult(
164-
documents=documents
165-
)
214+
return VectorStoreQueryResult(documents=documents)
166215

167216
@abstractmethod
168-
def query_embedding(self, embedding: List[float], limit: int = 10, **kwargs: Any) -> List[Document]:
217+
def get(self, ids: str | List[str] | None = None, collection_name: str | None = None, **kwargs: Any) -> List[Document]:
218+
"""Retrieve a document from the vector store by its ID.
219+
220+
Args:
221+
ids: Unique identifier of the documents to retrieve
222+
collection_name: The collection name of the documents to retrieve. Optional.
223+
**kwargs: Vector store specific parameters (offset, limit, filter etc.)
224+
225+
Returns:
226+
Document object if found, None otherwise
227+
"""
228+
229+
@abstractmethod
230+
def delete(self, ids: str | List[str] | None = None, collection_name: str | None = None, **kwargs: Any) -> None:
231+
"""Delete documents in the vector store by its IDs.
232+
233+
Args:
234+
ids: Unique identifier of the documents to delete
235+
collection_name: The collection name of the documents belong to. Optional.
236+
**kwargs: Vector store specific parameters (filter etc.)
237+
"""
238+
239+
@abstractmethod
240+
def query_embedding(
241+
self, embedding: List[float], limit: int = 10, **kwargs: Any
242+
) -> List[Document]:
169243
"""Perform vector search using pre-computed embedding.
170244
171245
Args:
@@ -176,3 +250,74 @@ def query_embedding(self, embedding: List[float], limit: int = 10, **kwargs: Any
176250
Returns:
177251
List of documents matching the search criteria
178252
"""
253+
254+
@abstractmethod
255+
def add_embedding(
256+
self,
257+
*,
258+
documents: List[Document],
259+
embeddings: List[List[float]],
260+
collection_name: str | None = None,
261+
**kwargs: Any,
262+
) -> List[str]:
263+
"""Add documents with pre-computed embeddings to the vector store.
264+
265+
Args:
266+
documents: Documents to add to the vector store
267+
embeddings: Pre-computed embedding vector for each document
268+
collection_name: The collection name of the documents to add. Optional.
269+
**kwargs: Vector store-specific parameters (collection, namespace, etc.)
270+
271+
Returns:
272+
List of document IDs that were added to the vector store
273+
"""
274+
275+
276+
class Collection(BaseModel):
277+
"""Represents a collection of documents."""
278+
name: str
279+
size: int
280+
metadata: Dict[str, Any] | None = None
281+
282+
283+
class CollectionManageableVectorStore(BaseVectorStore, ABC):
284+
"""Base abstract class for vector store which support collection management."""
285+
286+
@abstractmethod
287+
def get_or_create_collection(
288+
self, name: str, metadata: Dict[str, Any]
289+
) -> Collection:
290+
"""Get a collection, or create it if it doesn't exist.
291+
292+
Args:
293+
name: Name of the collection
294+
metadata: Metadata of the collection
295+
Returns:
296+
The retrieved or created collection
297+
"""
298+
299+
@abstractmethod
300+
def get_collection(self, name: str) -> Collection:
301+
"""Get a collection, raise an exception if it doesn't exist.
302+
303+
Args:
304+
name: Name of the collection
305+
Returns:
306+
The retrieved collection
307+
"""
308+
309+
@abstractmethod
310+
def delete_collection(self, name: str) -> Collection:
311+
"""Delete a collection.
312+
313+
Args:
314+
name: Name of the collection
315+
Returns:
316+
The deleted collection
317+
"""
318+
319+
def maybe_cast_to_list(value: Any | List[Any]) -> List[Any] | None:
320+
"""Cast T to List[T] if T is not None."""
321+
if value is None:
322+
return None
323+
return [value] if not isinstance(value, list) else value

python/flink_agents/integrations/vector_stores/chroma/chroma_vector_store.py

Lines changed: 117 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,26 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
################################################################################
18+
import uuid
1819
from typing import Any, Dict, List
1920

2021
import chromadb
2122
from chromadb import ClientAPI as ChromaClient
2223
from chromadb import CloudClient
2324
from chromadb.config import Settings
2425
from pydantic import Field
26+
from typing_extensions import override
2527

2628
from flink_agents.api.vector_stores.vector_store import (
27-
BaseVectorStore,
29+
Collection,
30+
CollectionManageableVectorStore,
2831
Document,
2932
)
3033

3134
DEFAULT_COLLECTION = "flink_agents_chroma_collection"
3235

3336

34-
class ChromaVectorStore(BaseVectorStore):
37+
class ChromaVectorStore(CollectionManageableVectorStore):
3538
"""ChromaDB vector store that handles connection and semantic search.
3639
3740
Visit https://docs.trychroma.com/ for ChromaDB documentation.
@@ -194,7 +197,100 @@ def store_kwargs(self) -> Dict[str, Any]:
194197
"create_collection_if_not_exists": self.create_collection_if_not_exists,
195198
}
196199

197-
def query_embedding(self, embedding: List[float], limit: int = 10, **kwargs: Any) -> List[Document]:
200+
@override
201+
def get_or_create_collection(
202+
self, name: str, metadata: Dict[str, Any] | None = None
203+
) -> Collection:
204+
collection = self.client.get_or_create_collection(name=name, metadata=metadata)
205+
return Collection(
206+
name=collection.name, size=collection.count(), metadata=collection.metadata
207+
)
208+
209+
@override
210+
def get_collection(self, name: str) -> Collection:
211+
collection = self.client.get_collection(name=name)
212+
return Collection(
213+
name=collection.name, size=collection.count(), metadata=collection.metadata
214+
)
215+
216+
@override
217+
def delete_collection(self, name: str) -> Collection:
218+
collection = self.get_collection(name=name)
219+
self.client.delete_collection(name=collection.name)
220+
return collection
221+
222+
@override
223+
def get(
224+
self,
225+
ids: str | List[str] | None = None,
226+
collection_name: str | None = None,
227+
**kwargs: Any,
228+
) -> List[Document]:
229+
where = kwargs.get("where")
230+
limit = kwargs.get("limit")
231+
offset = kwargs.get("offset")
232+
where_document = kwargs.get("where_document")
233+
results = self.client.get_collection(
234+
name=collection_name or self.collection
235+
).get(
236+
ids=ids,
237+
where=where,
238+
limit=limit,
239+
offset=offset,
240+
where_document=where_document,
241+
)
242+
243+
ids = results["ids"]
244+
documents = results["documents"]
245+
metadatas = results["metadatas"]
246+
247+
return [
248+
Document(id=id, content=document, metadata=dict(metadata))
249+
for id, document, metadata in zip(ids, documents, metadatas, strict=False)
250+
]
251+
252+
@override
253+
def delete(
254+
self,
255+
ids: str | List[str] | None = None,
256+
collection_name: str | None = None,
257+
**kwargs: Any,
258+
) -> None:
259+
where = kwargs.get("where")
260+
where_document = kwargs.get("where_document")
261+
if ids is None and where is None and where_document is None:
262+
ids = (
263+
self.client.get_collection(collection_name or self.collection)
264+
.get(include=[])
265+
.get("ids")
266+
)
267+
# collection is empty
268+
if len(ids) == 0:
269+
return
270+
self.client.get_collection(name=collection_name or self.collection).delete(
271+
ids=ids, where=where, where_document=where_document
272+
)
273+
274+
@override
275+
def add_embedding(
276+
self,
277+
*,
278+
documents: List[Document],
279+
embeddings: List[List[float]],
280+
collection_name: str | None = None,
281+
**kwargs: Any,
282+
) -> List[str]:
283+
ids = [doc.id or str(uuid.uuid4()) for doc in documents]
284+
docs = [doc.content for doc in documents]
285+
metadatas = [doc.metadata for doc in documents]
286+
self.client.get_collection(name=collection_name or self.collection).add(
287+
ids=ids, documents=docs, embeddings=embeddings, metadatas=metadatas
288+
)
289+
return ids
290+
291+
def query_embedding(
292+
self, embedding: List[float], limit: int = 10, **kwargs: Any
293+
) -> List[Document]:
198294
"""Perform vector search using pre-computed embedding.
199295
200296
Args:
@@ -207,8 +303,12 @@ def query_embedding(self, embedding: List[float], limit: int = 10, **kwargs: Any
207303
"""
208304
# Extract ChromaDB-specific parameters
209305
collection_name = kwargs.get("collection", self.collection)
210-
collection_metadata = kwargs.get("collection_metadata", self.collection_metadata)
211-
create_collection_if_not_exists = kwargs.get("create_collection_if_not_exists", self.create_collection_if_not_exists)
306+
collection_metadata = kwargs.get(
307+
"collection_metadata", self.collection_metadata
308+
)
309+
create_collection_if_not_exists = kwargs.get(
310+
"create_collection_if_not_exists", self.create_collection_if_not_exists
311+
)
212312
where = kwargs.get("where") # Metadata filters
213313

214314
# Get or create collection based on configuration
@@ -235,13 +335,18 @@ def query_embedding(self, embedding: List[float], limit: int = 10, **kwargs: Any
235335
if results["documents"] and results["documents"][0]:
236336
for i, doc_content in enumerate(results["documents"][0]):
237337
doc_id = results["ids"][0][i] if results["ids"] else None
238-
metadata = results["metadatas"][0][i] if results["metadatas"] and results["metadatas"][0] else {}
338+
metadata = (
339+
results["metadatas"][0][i]
340+
if results["metadatas"] and results["metadatas"][0]
341+
else {}
342+
)
239343

240-
documents.append(Document(
241-
content=doc_content,
242-
id=doc_id,
243-
metadata=metadata,
244-
))
344+
documents.append(
345+
Document(
346+
content=doc_content,
347+
id=doc_id,
348+
metadata=metadata,
349+
)
350+
)
245351

246352
return documents
247-

0 commit comments

Comments
 (0)