From 2bc884027b8bfd01967647d795fffcff069d98d0 Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Tue, 24 Mar 2026 19:21:34 +0800 Subject: [PATCH 1/9] feat(grimoire): Integrate Weaviate support into vector retrieval --- grimoire/agent/agent.py | 6 +- grimoire/config.py | 7 + grimoire/entity/tools.py | 51 +++- grimoire/retriever/weaviate_vector_db.py | 367 +++++++++++++++++++++++ 4 files changed, 427 insertions(+), 4 deletions(-) create mode 100644 grimoire/retriever/weaviate_vector_db.py diff --git a/grimoire/agent/agent.py b/grimoire/agent/agent.py index d68dfc6..8ec202b 100644 --- a/grimoire/agent/agent.py +++ b/grimoire/agent/agent.py @@ -40,8 +40,8 @@ PrivateSearchResourceType, ) from wizard_common.grimoire.retriever.base import BaseRetriever -from wizard_common.grimoire.retriever.meili_vector_db import ( - MeiliVectorRetriever, +from wizard_common.grimoire.retriever.weaviate_vector_db import ( + WeaviateVectorRetriever, ) from wizard_common.grimoire.retriever.reranker import ( get_tool_executor_config, @@ -248,7 +248,7 @@ def message_dtos_to_openai_messages( class BaseSearchableAgent(BaseStreamable, ABC): def __init__(self, config: GrimoireAgentConfig): - self.knowledge_database_retriever = MeiliVectorRetriever(config=config.vector) + self.knowledge_database_retriever = WeaviateVectorRetriever(config=config.vector) self.web_search_retriever = SearXNG( base_url=config.tools.searxng.base_url, engines=config.tools.searxng.engines ) diff --git a/grimoire/config.py b/grimoire/config.py index 2249423..fb39612 100644 --- a/grimoire/config.py +++ b/grimoire/config.py @@ -5,11 +5,18 @@ from wizard_common.config import OpenAIConfig +class WeaviateConfig(BaseModel): + host: str | None = Field(default=None) + port: int = Field(default=8080) + api_key: str | None = Field(default=None) + + class VectorConfig(BaseModel): embedding: OpenAIConfig host: str port: int = Field(default=8000) meili_api_key: str = Field(default=None) + weaviate: WeaviateConfig = Field(default_factory=WeaviateConfig) batch_size: int = Field(default=1) max_results: int = Field(default=10) wait_timeout: int = Field(default=0) diff --git a/grimoire/entity/tools.py b/grimoire/entity/tools.py index f937d7e..821fbe8 100644 --- a/grimoire/entity/tools.py +++ b/grimoire/entity/tools.py @@ -1,9 +1,10 @@ from enum import Enum from functools import partial -from typing import List, Literal, Callable, TypedDict, Awaitable, Union, get_args, cast +from typing import List, Literal, Callable, TypedDict, Awaitable, Union, get_args, cast, Any from opentelemetry import trace from pydantic import BaseModel, Field +import weaviate.classes as wvc tracer = trace.get_tracer("grimoire.entity.tools") ToolName = Literal["private_search", "web_search"] @@ -12,6 +13,8 @@ class Condition(BaseModel): namespace_id: str + user_id: str | None = Field(default=None) + record_type: str | None = Field(default=None) resource_ids: list[str] | None = Field(default=None) parent_ids: list[str] | None = Field(default=None) created_at: tuple[float, float] | None = Field(default=None) @@ -46,6 +49,52 @@ def to_meili_where(self) -> List[str | List[str]]: return and_clause + def to_weaviate_filters(self) -> Any: + where = wvc.query.Filter.by_property("namespace_id").equal(self.namespace_id) + + if self.user_id: + where = where & ( + wvc.query.Filter.by_property("user_id").is_none(True) + | wvc.query.Filter.by_property("user_id").equal(self.user_id) + ) + + if self.record_type: + where = where & wvc.query.Filter.by_property("type").equal(self.record_type) + + if self.resource_ids: + resource_filter = None + for rid in self.resource_ids: + each = wvc.query.Filter.by_property("chunk.resource_id").equal(rid) + resource_filter = each if resource_filter is None else (resource_filter | each) + if resource_filter is not None: + where = where & resource_filter + + if self.parent_ids: + parent_filter = None + for pid in self.parent_ids: + each = wvc.query.Filter.by_property("chunk.parent_id").equal(pid) + parent_filter = each if parent_filter is None else (parent_filter | each) + if parent_filter is not None: + where = where & parent_filter + + if self.created_at is not None: + where = where & wvc.query.Filter.by_property("chunk.created_at").greater_or_equal( + self.created_at[0] + ) + where = where & wvc.query.Filter.by_property("chunk.created_at").less_or_equal( + self.created_at[1] + ) + + if self.updated_at is not None: + where = where & wvc.query.Filter.by_property("chunk.updated_at").greater_or_equal( + self.updated_at[0] + ) + where = where & wvc.query.Filter.by_property("chunk.updated_at").less_or_equal( + self.updated_at[1] + ) + + return where + class ToolExecutorConfig(TypedDict): name: str diff --git a/grimoire/retriever/weaviate_vector_db.py b/grimoire/retriever/weaviate_vector_db.py new file mode 100644 index 0000000..8d3480f --- /dev/null +++ b/grimoire/retriever/weaviate_vector_db.py @@ -0,0 +1,367 @@ +import asyncio +from functools import partial +from typing import Any, List, Tuple + +import weaviate +import weaviate.classes as wvc +from openai import AsyncOpenAI +from opentelemetry import propagate, trace +from weaviate.exceptions import UnexpectedStatusCodeError + +from common.trace_info import TraceInfo +from wizard_common.grimoire.config import VectorConfig +from wizard_common.grimoire.entity.chunk import Chunk, ResourceChunkRetrieval +from wizard_common.grimoire.entity.index_record import IndexRecord, IndexRecordType +from wizard_common.grimoire.entity.message import Message +from wizard_common.grimoire.entity.retrieval import Score +from wizard_common.grimoire.entity.tools import ( + Condition, + PrivateSearchResourceType, + PrivateSearchTool, + Resource, +) +from wizard_common.grimoire.retriever.base import BaseRetriever, SearchFunction + +tracer = trace.get_tracer(__name__) +COLLECTION_NAME = "omnibox_index" + + +class WeaviateVectorDB: + def __init__(self, config: VectorConfig): + self.config: VectorConfig = config + self.batch_size: int = config.batch_size + self.openai = AsyncOpenAI( + api_key=config.embedding.api_key, base_url=config.embedding.base_url + ) + self.client: weaviate.WeaviateAsyncClient = ... + self._init_lock = asyncio.Lock() + self.dimension = config.dimension + + async def _ensure_client(self) -> None: + if self.client is not ...: + return + async with self._init_lock: + if self.client is not ...: + return + + connect_kwargs: dict[str, Any] = {"port": self.config.weaviate.port} + if self.config.weaviate.api_key: + connect_kwargs["auth_credentials"] = wvc.init.Auth.api_key( + self.config.weaviate.api_key + ) + if self.config.weaviate.host: + connect_kwargs["host"] = self.config.weaviate.host + client = weaviate.use_async_with_local(**connect_kwargs) + await client.connect() + + if not await client.collections.exists(COLLECTION_NAME): + try: + await client.collections.create( + name=COLLECTION_NAME, + vector_config=wvc.config.Configure.Vectors.self_provided(), + multi_tenancy_config=wvc.config.Configure.multi_tenancy( + enabled=True, auto_tenant_creation=True + ), + properties=[ + wvc.config.Property( + name="id", data_type=wvc.config.DataType.TEXT + ), + wvc.config.Property( + name="type", data_type=wvc.config.DataType.TEXT + ), + wvc.config.Property( + name="namespace_id", data_type=wvc.config.DataType.TEXT + ), + wvc.config.Property( + name="user_id", data_type=wvc.config.DataType.TEXT + ), + wvc.config.Property( + name="chunk", data_type=wvc.config.DataType.OBJECT + ), + wvc.config.Property( + name="message", data_type=wvc.config.DataType.OBJECT + ), + ], + ) + except UnexpectedStatusCodeError as e: + # Concurrent creator already created the collection. + if e.status_code != 422: + raise + self.client = client + return + + async def _get_shard(self, namespace_id: str): + if not namespace_id: + raise ValueError("namespace_id is required") + await self._ensure_client() + collection = self.client.collections.get(COLLECTION_NAME) + return collection.with_tenant(namespace_id) + + async def _embed(self, input_: str | list[str]) -> list[list[float]]: + headers = {} + propagate.inject(headers) + embeddings = await self.openai.embeddings.create( + model=self.config.embedding.model, input=input_, extra_headers=headers + ) + return [item.embedding for item in embeddings.data] + + async def _hybrid_query( + self, + namespace_id: str, + query: str, + condition: Condition, + limit: int, + offset: int = 0, + ) -> List[Tuple[dict, float]]: + collection = await self._get_shard(namespace_id) + vector = (await self._embed(query))[0] if query else None + + search_limit = limit + offset + try: + response = await collection.query.hybrid( + query=query or "", + vector=vector, + alpha=0.5 if query else 1.0, + filters=condition.to_weaviate_filters(), + limit=search_limit, + return_metadata=wvc.query.MetadataQuery.full(), + ) + except UnexpectedStatusCodeError as e: + # Tenant not found -> no data yet. + if e.status_code == 422: + return [] + raise + + hits: list[Tuple[dict, float]] = [] + for obj in response.objects: + properties = obj.properties or {} + score = 0.0 + if obj.metadata and obj.metadata.score is not None: + score = float(obj.metadata.score) + elif obj.metadata and obj.metadata.certainty is not None: + score = float(obj.metadata.certainty) + hits.append((properties, score)) + + hits.sort(key=lambda x: x[1], reverse=True) + return hits[offset : offset + limit] + + @tracer.start_as_current_span("WeaviateVectorDB.insert_chunks") + async def insert_chunks(self, namespace_id: str, chunk_list: List[Chunk]): + collection = await self._get_shard(namespace_id) + + for i in range(0, len(chunk_list), self.batch_size): + raw_batch = chunk_list[i : i + self.batch_size] + batch: List[Chunk] = [] + prompts: list[str] = [] + for x in raw_batch: + prompt = x.to_prompt() + if prompt: + batch.append(x) + prompts.append(prompt) + if not batch: + continue + + vectors = await self._embed(prompts) + objects = [] + for chunk, vector in zip(batch, vectors): + record = IndexRecord( + id=f"chunk_{chunk.chunk_id}", + type=IndexRecordType.chunk, + namespace_id=namespace_id, + chunk=chunk, + ) + objects.append( + wvc.data.DataObject( + properties=record.model_dump(exclude_none=True), + vector=vector, + ) + ) + await collection.data.insert_many(objects) + + @tracer.start_as_current_span("WeaviateVectorDB.upsert_message") + async def upsert_message(self, namespace_id: str, user_id: str, message: Message): + collection = await self._get_shard(namespace_id) + record_id = f"message_{message.message_id}" + + if not message.message.content.strip(): + await collection.data.delete_many( + where=wvc.query.Filter.by_property("id").equal(record_id) + ) + return + + vector = (await self._embed(message.message.content or ""))[0] + record = IndexRecord( + id=record_id, + type=IndexRecordType.message, + namespace_id=namespace_id, + user_id=user_id, + message=message, + ) + + # Upsert via delete-then-insert to keep the API behavior simple. + await collection.data.delete_many( + where=wvc.query.Filter.by_property("id").equal(record_id) + ) + + await collection.data.insert( + properties=record.model_dump(exclude_none=True), vector=vector + ) + + @tracer.start_as_current_span("WeaviateVectorDB.remove_conversation") + async def remove_conversation(self, namespace_id: str, conversation_id: str): + collection = await self._get_shard(namespace_id) + try: + ret = await collection.data.delete_many( + where=wvc.query.Filter.by_property("type").equal( + IndexRecordType.message.value + ) + & wvc.query.Filter.by_property("namespace_id").equal(namespace_id) + & wvc.query.Filter.by_property("message.conversation_id").equal( + conversation_id + ) + ) + except UnexpectedStatusCodeError as e: + if e.status_code == 422: + return + raise + + @tracer.start_as_current_span("WeaviateVectorDB.remove_chunks") + async def remove_chunks(self, namespace_id: str, resource_id: str): + collection = await self._get_shard(namespace_id) + try: + ret = await collection.data.delete_many( + where=wvc.query.Filter.by_property("type").equal( + IndexRecordType.chunk.value + ) + & wvc.query.Filter.by_property("namespace_id").equal(namespace_id) + & wvc.query.Filter.by_property("chunk.resource_id").equal(resource_id) + ) + except UnexpectedStatusCodeError as e: + if e.status_code == 422: + return + raise + + @tracer.start_as_current_span("WeaviateVectorDB.search") + async def search( + self, + query: str, + namespace_id: str, + user_id: str | None, + record_type: IndexRecordType | None, + offset: int, + limit: int, + ) -> List[IndexRecord]: + condition = Condition( + namespace_id=namespace_id, + user_id=user_id, + record_type=record_type.value if record_type else None, + ) + + hits = await self._hybrid_query( + namespace_id=namespace_id, + query=query, + condition=condition, + limit=limit, + offset=offset, + ) + return [IndexRecord(**hit) for hit, _ in hits] + + @tracer.start_as_current_span("WeaviateVectorDB.query_chunks") + async def query_chunks( + self, + namespace_id: str, + query: str, + k: int, + condition: Condition, + ) -> List[Tuple[Chunk, float]]: + combined_condition = condition.model_copy( + update={"record_type": IndexRecordType.chunk.value} + ) + hits = await self._hybrid_query( + namespace_id=namespace_id, + query=query, + condition=combined_condition, + limit=k, + ) + output: List[Tuple[Chunk, float]] = [] + for hit, score in hits: + chunk_data = hit.get("chunk") + if chunk_data: + output.append((Chunk(**chunk_data), score)) + return output + + +class WeaviateVectorRetriever(BaseRetriever): + def __init__(self, config: VectorConfig): + self.vector_db = WeaviateVectorDB(config) + + @staticmethod + def get_folder(resource_id: str, resources: list[Resource]) -> str | None: + for resource in resources: + if ( + resource.type == PrivateSearchResourceType.FOLDER + and resource.child_ids + and resource_id in resource.child_ids + ): + return resource.name + return None + + @staticmethod + def get_type( + resource_id: str, resources: list[Resource] + ) -> PrivateSearchResourceType | None: + for resource in resources: + if resource.id == resource_id: + return resource.type + return None + + def get_function( + self, private_search_tool: PrivateSearchTool, **kwargs + ) -> SearchFunction: + return partial( + self.query, private_search_tool=private_search_tool, k=40, **kwargs + ) + + @classmethod + def get_schema(cls) -> dict: + return cls.generate_schema( + "private_search", + 'Search for user\'s private & personal resources. Return in format.', + display_name={"zh": "知识库搜索", "en": "Knowledge Base Search"}, + ) + + @tracer.start_as_current_span("WeaviateVectorRetriever.query") + async def query( + self, + query: str, + k: int, + *, + private_search_tool: PrivateSearchTool, + trace_info: TraceInfo | None = None, + ) -> list[ResourceChunkRetrieval]: + condition: Condition = private_search_tool.to_condition() + recall_result_list = await self.vector_db.query_chunks( + private_search_tool.namespace_id, query, k, condition + ) + retrievals: List[ResourceChunkRetrieval] = [ + ResourceChunkRetrieval( + chunk=chunk, + folder=self.get_folder( + chunk.resource_id, private_search_tool.resources or [] + ), + type=self.get_type( + chunk.resource_id, private_search_tool.visible_resources or [] + ), + namespace_id=private_search_tool.namespace_id, + score=Score(recall=score, rerank=0), + ) + for chunk, score in recall_result_list + ] + trace_info and trace_info.debug( + { + "where": condition.to_weaviate_filters(), + "condition": condition.model_dump() if condition else condition, + "len(retrievals)": len(retrievals), + } + ) + return retrievals From 67dde0a4967d1feea01c6eed6cdaa98670bd60f7 Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Wed, 25 Mar 2026 17:05:39 +0800 Subject: [PATCH 2/9] refactor(weaviate): update weaviate_vector_db --- grimoire/config.py | 1 - grimoire/entity/index_record.py | 1 - grimoire/entity/tools.py | 29 -- grimoire/retriever/meili_vector_db.py | 484 ----------------------- grimoire/retriever/weaviate_vector_db.py | 28 +- 5 files changed, 14 insertions(+), 529 deletions(-) delete mode 100644 grimoire/retriever/meili_vector_db.py diff --git a/grimoire/config.py b/grimoire/config.py index fb39612..bea4dff 100644 --- a/grimoire/config.py +++ b/grimoire/config.py @@ -15,7 +15,6 @@ class VectorConfig(BaseModel): embedding: OpenAIConfig host: str port: int = Field(default=8000) - meili_api_key: str = Field(default=None) weaviate: WeaviateConfig = Field(default_factory=WeaviateConfig) batch_size: int = Field(default=1) max_results: int = Field(default=10) diff --git a/grimoire/entity/index_record.py b/grimoire/entity/index_record.py index da4f1c1..fb902bc 100644 --- a/grimoire/entity/index_record.py +++ b/grimoire/entity/index_record.py @@ -11,7 +11,6 @@ class IndexRecordType(str, Enum): class IndexRecord(BaseModel): - id: str type: IndexRecordType namespace_id: str user_id: str | None = None diff --git a/grimoire/entity/tools.py b/grimoire/entity/tools.py index 821fbe8..2fee3aa 100644 --- a/grimoire/entity/tools.py +++ b/grimoire/entity/tools.py @@ -20,35 +20,6 @@ class Condition(BaseModel): created_at: tuple[float, float] | None = Field(default=None) updated_at: tuple[float, float] | None = Field(default=None) - def to_meili_where(self) -> List[str | List[str]]: - and_clause: List[str | List[str]] = [ - 'namespace_id = "{}"'.format(self.namespace_id) - ] - or_clause: List[str] = [] - if self.resource_ids: - or_clause.append( - "chunk.resource_id IN [{}]".format( - ", ".join('"{}"'.format(rid) for rid in self.resource_ids) - ) - ) - if self.parent_ids: - or_clause.append( - "chunk.parent_id IN [{}]".format( - ", ".join('"{}"'.format(pid) for pid in self.parent_ids) - ) - ) - if or_clause: - and_clause.append(or_clause) - - if self.created_at is not None: - and_clause.append("chunk.created_at >= {}".format(self.created_at[0])) - and_clause.append("chunk.created_at <= {}".format(self.created_at[1])) - if self.updated_at is not None: - and_clause.append("chunk.updated_at >= {}".format(self.updated_at[0])) - and_clause.append("chunk.updated_at <= {}".format(self.updated_at[1])) - - return and_clause - def to_weaviate_filters(self) -> Any: where = wvc.query.Filter.by_property("namespace_id").equal(self.namespace_id) diff --git a/grimoire/retriever/meili_vector_db.py b/grimoire/retriever/meili_vector_db.py deleted file mode 100644 index fd016f1..0000000 --- a/grimoire/retriever/meili_vector_db.py +++ /dev/null @@ -1,484 +0,0 @@ -import asyncio -from functools import partial -from hashlib import md5 -from typing import List, Tuple - -from meilisearch_python_sdk import AsyncClient -from meilisearch_python_sdk.errors import MeilisearchApiError -from meilisearch_python_sdk.models.search import Hybrid -from meilisearch_python_sdk.models.settings import ( - Embedders, - Filter, - FilterableAttributeFeatures, - FilterableAttributes, - UserProvidedEmbedder, -) -from meilisearch_python_sdk.models.task import TaskInfo -from openai import AsyncOpenAI -from opentelemetry import propagate, trace - -from common.trace_info import TraceInfo -from wizard_common.grimoire.config import VectorConfig -from wizard_common.grimoire.entity.chunk import Chunk, ResourceChunkRetrieval -from wizard_common.grimoire.entity.index_record import ( - IndexRecord, - IndexRecordType, -) -from wizard_common.grimoire.entity.message import Message -from wizard_common.grimoire.entity.retrieval import Score -from wizard_common.grimoire.entity.tools import ( - Condition, - PrivateSearchResourceType, - PrivateSearchTool, - Resource, -) -from wizard_common.grimoire.retriever.base import BaseRetriever, SearchFunction - -tracer = trace.get_tracer(__name__) - - -def to_filterable_attributes( - filter_: str, comparison: bool = False -) -> FilterableAttributes: - """Convert a string filter to FilterableAttributes.""" - return FilterableAttributes( - attribute_patterns=[filter_], - features=FilterableAttributeFeatures( - facet_search=False, - filter=Filter(equality=True, comparison=comparison), - ), - ) - - -def sharded_index_uid(idx: int) -> str: - return f"omnibox-index-{idx}" - - -class MeiliVectorDB: - def __init__(self, config: VectorConfig): - self.config: VectorConfig = config - self.batch_size: int = config.batch_size - self.openai = AsyncOpenAI( - api_key=config.embedding.api_key, base_url=config.embedding.base_url - ) - self.meili: AsyncClient = ... - self.index_uid = "omniboxIndex" - self.num_shards = 20 - self.embedder_name = "omniboxEmbed" - self.dimension = config.dimension - self.has_old_index = False - - async def get_or_init_client(self) -> AsyncClient: - """Get the initialized MeiliSearch client.""" - if self.meili is ...: - client = AsyncClient(self.config.host, self.config.meili_api_key) - try: - await client.get_index(self.index_uid) - self.has_old_index = True - except MeilisearchApiError as e: - if e.status_code == 404: - self.has_old_index = False - else: - raise - for i in range(self.num_shards): - await self.init_shard_index(client, sharded_index_uid(i)) - self.meili = client - return self.meili - - def get_shard(self, namespace_id: str): - if not namespace_id: - raise ValueError("namespace_id is required") - h = md5(namespace_id.encode("utf-8")).digest() - idx = int.from_bytes(h[:4], byteorder="big") - idx %= self.num_shards - return sharded_index_uid(idx) - - async def init_shard_index(self, client: AsyncClient, index_uid: str): - """Initialize a single shard index with proper settings.""" - index = await client.get_or_create_index(index_uid) - - cur_filters: List[FilterableAttributes] = [] - for f in await index.get_filterable_attributes() or []: - if isinstance(f, FilterableAttributes): - cur_filters.append(f) - elif isinstance(f, str): - cur_filters.append(to_filterable_attributes(f)) - else: - raise ValueError( - f"Unexpected filterable attribute type: {type(f)}. Expected str or FilterableAttributes." - ) - - expected_filters = [ - "namespace_id", - "user_id", - "type", - "chunk.resource_id", - "chunk.parent_id", - "chunk.created_at", - "chunk.updated_at", - "message.conversation_id", - ] - comparison_filters = [ - "chunk.created_at", - "chunk.updated_at", - ] - missing_filters: List[FilterableAttributes] = [] - for expected_filter in expected_filters: - found = False - for cur_filter in cur_filters: - if expected_filter in cur_filter.attribute_patterns: - found = True - break - if not found: - missing_filters.append( - to_filterable_attributes( - expected_filter, - comparison=expected_filter in comparison_filters, - ) - ) - - if missing_filters: - new_filters = cur_filters + missing_filters - await index.update_filterable_attributes(new_filters) - - embedders = await index.get_embedders() - if not embedders or self.embedder_name not in embedders.embedders: - await index.update_embedders( - Embedders( - embedders={ - self.embedder_name: UserProvidedEmbedder( - dimensions=self.dimension - ) - } - ) - ) - - @tracer.start_as_current_span("MeiliVectorDB.insert_chunks") - async def insert_chunks( - self, namespace_id: str, chunk_list: List[Chunk], tasks: List[TaskInfo] - ): - client = await self.get_or_init_client() - index = client.index(self.get_shard(namespace_id)) - for i in range(0, len(chunk_list), self.batch_size): - raw_batch = chunk_list[i : i + self.batch_size] - - batch: List[Chunk] = [] - prompts: list[str] = [] - for x in raw_batch: - prompt: str = x.to_prompt() - if prompt: - batch.append(x) - prompts.append(prompt) - - headers = {} - propagate.inject(headers) - - embeddings = await self.openai.embeddings.create( - model=self.config.embedding.model, input=prompts, extra_headers=headers - ) - records = [] - for chunk, embed in zip(batch, embeddings.data): - record = IndexRecord( - id="chunk_{}".format(chunk.chunk_id), - type=IndexRecordType.chunk, - namespace_id=namespace_id, - chunk=chunk, - _vectors={ - self.embedder_name: embed.embedding, - }, - ) - records.append(record.model_dump(by_alias=True)) - tasks.append(await index.add_documents(records, primary_key="id")) - - @tracer.start_as_current_span("MeiliVectorDB.upsert_message") - async def upsert_message( - self, namespace_id: str, user_id: str, message: Message, tasks: List[TaskInfo] - ): - client = await self.get_or_init_client() - index = client.index(self.get_shard(namespace_id)) - record_id = "message_{}".format(message.message_id) - - if not message.message.content.strip(): - await index.delete_document(record_id) - return - - headers = {} - propagate.inject(headers) - - embedding = await self.openai.embeddings.create( - model=self.config.embedding.model, - input=message.message.content or "", - extra_headers=headers, - ) - record = IndexRecord( - id=record_id, - type=IndexRecordType.message, - namespace_id=namespace_id, - user_id=user_id, - message=message, - _vectors={ - self.embedder_name: embedding.data[0].embedding, - }, - ) - task = await index.add_documents( - [record.model_dump(by_alias=True)], primary_key="id" - ) - tasks.append(task) - - @tracer.start_as_current_span("MeiliVectorDB.remove_conversation") - async def remove_conversation( - self, namespace_id: str, conversation_id: str, tasks: List[TaskInfo] - ): - await self.delete_from_both_indexes( - namespace_id, - filter_=[ - "type = {}".format(IndexRecordType.message.value), - "namespace_id = {}".format(namespace_id), - "message.conversation_id = {}".format(conversation_id), - ], - tasks=tasks, - ) - - @tracer.start_as_current_span("MeiliVectorDB.remove_chunks") - async def remove_chunks( - self, namespace_id: str, resource_id: str, tasks: List[TaskInfo] - ): - await self.delete_from_both_indexes( - namespace_id, - filter_=[ - "type = {}".format(IndexRecordType.chunk.value), - "namespace_id = {}".format(namespace_id), - "chunk.resource_id = {}".format(resource_id), - ], - tasks=tasks, - ) - - @tracer.start_as_current_span("MeiliVectorDB.vector_params") - async def vector_params(self, query: str) -> dict: - if query: - headers = {} - propagate.inject(headers) - - embedding = await self.openai.embeddings.create( - model=self.config.embedding.model, input=query, extra_headers=headers - ) - vector = embedding.data[0].embedding - hybrid = Hybrid(embedder=self.embedder_name, semantic_ratio=0.5) - return { - "vector": vector, - "hybrid": hybrid, - } - return {} - - @tracer.start_as_current_span("MeiliVectorDB.query_both_indexes") - async def query_both_indexes( - self, - namespace_id: str, - query: str, - filter_: List[str | List[str]], - limit: int, - vector_params: dict, - **search_kwargs, - ) -> List[dict]: - client = await self.get_or_init_client() - - hits = [] - search_tasks = [] - - if self.has_old_index: - old_index = client.index(self.index_uid) - task = old_index.search( - query, - filter=filter_, - limit=limit, - **vector_params, - **search_kwargs, - show_ranking_score=True, - ) - search_tasks.append(task) - - index = client.index(self.get_shard(namespace_id)) - task = index.search( - query, - filter=filter_, - limit=limit, - **vector_params, - **search_kwargs, - show_ranking_score=True, - ) - search_tasks.append(task) - - results = await asyncio.gather(*search_tasks) - for result in results: - hits.extend(result.hits) - - hits.sort(key=lambda x: x.get("_rankingScore", 0), reverse=True) - return hits[:limit] - - @tracer.start_as_current_span("MeiliVectorDB.delete_from_both_indexes") - async def delete_from_both_indexes( - self, namespace_id: str, filter_: List[str | List[str]], tasks: List[TaskInfo] - ): - client = await self.get_or_init_client() - - if self.has_old_index: - old_index = client.index(self.index_uid) - tasks.append(await old_index.delete_documents_by_filter(filter=filter_)) - - index = client.index(self.get_shard(namespace_id)) - tasks.append(await index.delete_documents_by_filter(filter=filter_)) - - @tracer.start_as_current_span("MeiliVectorDB.search") - async def search( - self, - query: str, - namespace_id: str, - user_id: str | None, - record_type: IndexRecordType | None, - offset: int, - limit: int, - ) -> List[IndexRecord]: - filter_: List[str | List[str]] = [] - filter_.append("namespace_id = {}".format(namespace_id)) - if user_id: - filter_.append( - "user_id NOT EXISTS OR user_id IS NULL OR user_id = {}".format(user_id) - ) - if record_type: - filter_.append("type = {}".format(record_type.value)) - vector_params: dict = await self.vector_params(query) - - hits = await self.query_both_indexes( - namespace_id, - query, - filter_, - offset + limit, - vector_params, - ) - return [IndexRecord(**hit) for hit in hits[offset:]] - - @tracer.start_as_current_span("MeiliVectorDB.query_chunks") - async def query_chunks( - self, - namespace_id: str, - query: str, - k: int, - filter_: List[str | List[str]], - ) -> List[Tuple[Chunk, float]]: - combined_filters = filter_ + ["type = {}".format(IndexRecordType.chunk.value)] - vector_params: dict = await self.vector_params(query) - hits = await self.query_both_indexes( - namespace_id, - query, - combined_filters, - k, - vector_params, - ) - output: List[Tuple[Chunk, float]] = [] - for hit in hits: - chunk_data = hit["chunk"] - score = hit["_rankingScore"] - if chunk_data: - chunk = Chunk(**chunk_data) - output.append((chunk, score)) - return output - - async def wait_for_tasks(self, tasks: List[TaskInfo]): - if self.config.wait_timeout > 0: - client = await self.get_or_init_client() - await asyncio.gather( - *[ - client.wait_for_task( - task.task_uid, - timeout_in_ms=self.config.wait_timeout, - interval_in_ms=500, - ) - for task in tasks - ] - ) - - -class MeiliVectorRetriever(BaseRetriever): - def __init__(self, config: VectorConfig): - self.vector_db = MeiliVectorDB(config) - - @staticmethod - def get_folder(resource_id: str, resources: list[Resource]) -> str | None: - for resource in resources: - if ( - resource.type == PrivateSearchResourceType.FOLDER - and resource_id in resource.child_ids - ): - return resource.name - return None - - @staticmethod - def get_type( - resource_id: str, resources: list[Resource] - ) -> PrivateSearchResourceType | None: - for resource in resources: - if resource.id == resource_id: - return resource.type - return None - - def get_function( - self, private_search_tool: PrivateSearchTool, **kwargs - ) -> SearchFunction: - return partial( - self.query, private_search_tool=private_search_tool, k=40, **kwargs - ) - - @classmethod - def get_schema(cls) -> dict: - return cls.generate_schema( - "private_search", - 'Search for user\'s private & personal resources. Return in format.', - display_name={"zh": "知识库搜索", "en": "Knowledge Base Search"}, - ) - - @tracer.start_as_current_span("MeiliVectorRetriever.query") - async def query( - self, - query: str, - k: int, - *, - private_search_tool: PrivateSearchTool, - trace_info: TraceInfo | None = None, - ) -> list[ResourceChunkRetrieval]: - condition: Condition = private_search_tool.to_condition() - where = condition.to_meili_where() - if len(where) == 0: - trace_info and trace_info.warning( - { - "warning": "empty_where", - "where": where, - "condition": condition.model_dump() if condition else condition, - } - ) - return [] - - recall_result_list = await self.vector_db.query_chunks( - private_search_tool.namespace_id, query, k, where - ) - retrievals: List[ResourceChunkRetrieval] = [ - ResourceChunkRetrieval( - chunk=chunk, - folder=self.get_folder( - chunk.resource_id, private_search_tool.resources or [] - ), - type=self.get_type( - chunk.resource_id, private_search_tool.visible_resources or [] - ), - namespace_id=private_search_tool.namespace_id, - score=Score(recall=score, rerank=0), - ) - for chunk, score in recall_result_list - ] - trace_info and trace_info.debug( - { - "where": where, - "condition": condition.model_dump() if condition else condition, - "len(retrievals)": len(retrievals), - } - ) - return retrievals diff --git a/grimoire/retriever/weaviate_vector_db.py b/grimoire/retriever/weaviate_vector_db.py index 8d3480f..e6ec5ee 100644 --- a/grimoire/retriever/weaviate_vector_db.py +++ b/grimoire/retriever/weaviate_vector_db.py @@ -63,9 +63,6 @@ async def _ensure_client(self) -> None: enabled=True, auto_tenant_creation=True ), properties=[ - wvc.config.Property( - name="id", data_type=wvc.config.DataType.TEXT - ), wvc.config.Property( name="type", data_type=wvc.config.DataType.TEXT ), @@ -165,7 +162,6 @@ async def insert_chunks(self, namespace_id: str, chunk_list: List[Chunk]): objects = [] for chunk, vector in zip(batch, vectors): record = IndexRecord( - id=f"chunk_{chunk.chunk_id}", type=IndexRecordType.chunk, namespace_id=namespace_id, chunk=chunk, @@ -181,28 +177,30 @@ async def insert_chunks(self, namespace_id: str, chunk_list: List[Chunk]): @tracer.start_as_current_span("WeaviateVectorDB.upsert_message") async def upsert_message(self, namespace_id: str, user_id: str, message: Message): collection = await self._get_shard(namespace_id) - record_id = f"message_{message.message_id}" - if not message.message.content.strip(): + try: await collection.data.delete_many( - where=wvc.query.Filter.by_property("id").equal(record_id) + where=wvc.query.Filter.by_property("message.message_id").equal( + message.message_id + ) ) + except UnexpectedStatusCodeError as e: + # 422: Tenant not found (no data yet for this namespace) + if e.status_code != 422: + raise + + message_content = message.message.content.strip() + if not message_content: return - vector = (await self._embed(message.message.content or ""))[0] + vector = (await self._embed(message_content))[0] record = IndexRecord( - id=record_id, type=IndexRecordType.message, namespace_id=namespace_id, user_id=user_id, message=message, ) - # Upsert via delete-then-insert to keep the API behavior simple. - await collection.data.delete_many( - where=wvc.query.Filter.by_property("id").equal(record_id) - ) - await collection.data.insert( properties=record.model_dump(exclude_none=True), vector=vector ) @@ -221,6 +219,7 @@ async def remove_conversation(self, namespace_id: str, conversation_id: str): ) ) except UnexpectedStatusCodeError as e: + # 422: Tenant not found (no data yet for this namespace) if e.status_code == 422: return raise @@ -237,6 +236,7 @@ async def remove_chunks(self, namespace_id: str, resource_id: str): & wvc.query.Filter.by_property("chunk.resource_id").equal(resource_id) ) except UnexpectedStatusCodeError as e: + # 422: Tenant not found (no data yet for this namespace) if e.status_code == 422: return raise From 6f0a1547b993cc446c3207f9536238a108dc0203 Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Wed, 25 Mar 2026 17:28:52 +0800 Subject: [PATCH 3/9] refactor(weaviate): update fields --- grimoire/retriever/weaviate_vector_db.py | 54 +++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/grimoire/retriever/weaviate_vector_db.py b/grimoire/retriever/weaviate_vector_db.py index e6ec5ee..3a82c59 100644 --- a/grimoire/retriever/weaviate_vector_db.py +++ b/grimoire/retriever/weaviate_vector_db.py @@ -64,19 +64,63 @@ async def _ensure_client(self) -> None: ), properties=[ wvc.config.Property( - name="type", data_type=wvc.config.DataType.TEXT + name="type", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, ), wvc.config.Property( - name="namespace_id", data_type=wvc.config.DataType.TEXT + name="namespace_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, ), wvc.config.Property( - name="user_id", data_type=wvc.config.DataType.TEXT + name="user_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, ), wvc.config.Property( - name="chunk", data_type=wvc.config.DataType.OBJECT + name="chunk", + data_type=wvc.config.DataType.OBJECT, + nested_properties=[ + wvc.config.Property( + name="resource_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="parent_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="created_at", + data_type=wvc.config.DataType.NUMBER, + index_filterable=True, + index_range_filters=True, + ), + wvc.config.Property( + name="updated_at", + data_type=wvc.config.DataType.NUMBER, + index_filterable=True, + index_range_filters=True, + ), + ], ), wvc.config.Property( - name="message", data_type=wvc.config.DataType.OBJECT + name="message", + data_type=wvc.config.DataType.OBJECT, + nested_properties=[ + wvc.config.Property( + name="message_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="conversation_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + ], ), ], ) From 82279febebc62ece46030029b3abb91283223763 Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Wed, 25 Mar 2026 18:13:19 +0800 Subject: [PATCH 4/9] refactor(weaviate): update weaviate_vector_db --- grimoire/retriever/weaviate_vector_db.py | 141 +++++++++++------------ 1 file changed, 66 insertions(+), 75 deletions(-) diff --git a/grimoire/retriever/weaviate_vector_db.py b/grimoire/retriever/weaviate_vector_db.py index 3a82c59..a85647b 100644 --- a/grimoire/retriever/weaviate_vector_db.py +++ b/grimoire/retriever/weaviate_vector_db.py @@ -6,7 +6,7 @@ import weaviate.classes as wvc from openai import AsyncOpenAI from opentelemetry import propagate, trace -from weaviate.exceptions import UnexpectedStatusCodeError +from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateDeleteManyError from common.trace_info import TraceInfo from wizard_common.grimoire.config import VectorConfig @@ -53,83 +53,79 @@ async def _ensure_client(self) -> None: connect_kwargs["host"] = self.config.weaviate.host client = weaviate.use_async_with_local(**connect_kwargs) await client.connect() + self.client = client + + if await client.collections.exists(COLLECTION_NAME): + return - if not await client.collections.exists(COLLECTION_NAME): - try: - await client.collections.create( - name=COLLECTION_NAME, - vector_config=wvc.config.Configure.Vectors.self_provided(), - multi_tenancy_config=wvc.config.Configure.multi_tenancy( - enabled=True, auto_tenant_creation=True - ), - properties=[ + await self.client.collections.create( + name=COLLECTION_NAME, + vector_config=wvc.config.Configure.Vectors.self_provided(), + multi_tenancy_config=wvc.config.Configure.multi_tenancy( + enabled=True, auto_tenant_creation=True + ), + properties=[ + wvc.config.Property( + name="type", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="namespace_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="user_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="chunk", + data_type=wvc.config.DataType.OBJECT, + nested_properties=[ wvc.config.Property( - name="type", + name="resource_id", data_type=wvc.config.DataType.TEXT, index_filterable=True, ), wvc.config.Property( - name="namespace_id", + name="parent_id", data_type=wvc.config.DataType.TEXT, index_filterable=True, ), wvc.config.Property( - name="user_id", - data_type=wvc.config.DataType.TEXT, + name="created_at", + data_type=wvc.config.DataType.NUMBER, index_filterable=True, + index_range_filters=True, ), wvc.config.Property( - name="chunk", - data_type=wvc.config.DataType.OBJECT, - nested_properties=[ - wvc.config.Property( - name="resource_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - wvc.config.Property( - name="parent_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - wvc.config.Property( - name="created_at", - data_type=wvc.config.DataType.NUMBER, - index_filterable=True, - index_range_filters=True, - ), - wvc.config.Property( - name="updated_at", - data_type=wvc.config.DataType.NUMBER, - index_filterable=True, - index_range_filters=True, - ), - ], + name="updated_at", + data_type=wvc.config.DataType.NUMBER, + index_filterable=True, + index_range_filters=True, ), + ], + ), + wvc.config.Property( + name="message", + data_type=wvc.config.DataType.OBJECT, + nested_properties=[ wvc.config.Property( - name="message", - data_type=wvc.config.DataType.OBJECT, - nested_properties=[ - wvc.config.Property( - name="message_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - wvc.config.Property( - name="conversation_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - ], + name="message_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="conversation_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, ), ], - ) - except UnexpectedStatusCodeError as e: - # Concurrent creator already created the collection. - if e.status_code != 422: - raise - self.client = client - return + ), + ], + ) async def _get_shard(self, namespace_id: str): if not namespace_id: @@ -228,10 +224,9 @@ async def upsert_message(self, namespace_id: str, user_id: str, message: Message message.message_id ) ) - except UnexpectedStatusCodeError as e: - # 422: Tenant not found (no data yet for this namespace) - if e.status_code != 422: - raise + except WeaviateDeleteManyError: + # Tenant not found (no data yet for this namespace) + pass message_content = message.message.content.strip() if not message_content: @@ -262,11 +257,9 @@ async def remove_conversation(self, namespace_id: str, conversation_id: str): conversation_id ) ) - except UnexpectedStatusCodeError as e: - # 422: Tenant not found (no data yet for this namespace) - if e.status_code == 422: - return - raise + except WeaviateDeleteManyError: + # Tenant not found (no data yet for this namespace) + pass @tracer.start_as_current_span("WeaviateVectorDB.remove_chunks") async def remove_chunks(self, namespace_id: str, resource_id: str): @@ -279,11 +272,9 @@ async def remove_chunks(self, namespace_id: str, resource_id: str): & wvc.query.Filter.by_property("namespace_id").equal(namespace_id) & wvc.query.Filter.by_property("chunk.resource_id").equal(resource_id) ) - except UnexpectedStatusCodeError as e: - # 422: Tenant not found (no data yet for this namespace) - if e.status_code == 422: - return - raise + except WeaviateDeleteManyError: + # Tenant not found (no data yet for this namespace) + pass @tracer.start_as_current_span("WeaviateVectorDB.search") async def search( From 9977abc71af236dd060bac24cc828c2e72586943 Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Wed, 25 Mar 2026 18:32:44 +0800 Subject: [PATCH 5/9] refactor(weaviate): update weaviate config --- grimoire/retriever/weaviate_vector_db.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/grimoire/retriever/weaviate_vector_db.py b/grimoire/retriever/weaviate_vector_db.py index a85647b..4ced219 100644 --- a/grimoire/retriever/weaviate_vector_db.py +++ b/grimoire/retriever/weaviate_vector_db.py @@ -64,6 +64,9 @@ async def _ensure_client(self) -> None: multi_tenancy_config=wvc.config.Configure.multi_tenancy( enabled=True, auto_tenant_creation=True ), + inverted_index_config=wvc.config.Configure.inverted_index( + index_null_state=True + ), properties=[ wvc.config.Property( name="type", From 596a3e52b105c242c22eb5abff7ca068d2b03b5e Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Thu, 26 Mar 2026 17:52:48 +0800 Subject: [PATCH 6/9] refactor(chunk): remove user_id field from Chunk entity --- grimoire/entity/chunk.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/grimoire/entity/chunk.py b/grimoire/entity/chunk.py index 5084c3e..41626aa 100644 --- a/grimoire/entity/chunk.py +++ b/grimoire/entity/chunk.py @@ -1,16 +1,15 @@ import time from datetime import datetime from enum import Enum -from typing import Optional, Literal +from typing import Literal, Optional import shortuuid from pydantic import Field - from wizard_common.grimoire.entity.retrieval import ( BaseRetrieval, Citation, - to_prompt, PromptContext, + to_prompt, ) from wizard_common.grimoire.entity.tools import PrivateSearchResourceType @@ -36,7 +35,6 @@ class Chunk(PromptContext): text: str | None = Field(default=None, description="Chunk content") chunk_type: ChunkType = Field(description="Chunk type") - user_id: str parent_id: str chunk_id: str = Field(description="ID of chunk", default_factory=shortuuid.uuid) From 80cd4a191913fde023e372bc2c58c45d65593553 Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Thu, 26 Mar 2026 20:55:38 +0800 Subject: [PATCH 7/9] feat(retriever): restore meili_vector_db and update for new index_record structure --- grimoire/config.py | 1 + grimoire/entity/tools.py | 41 +++ grimoire/retriever/meili_vector_db.py | 486 ++++++++++++++++++++++++++ 3 files changed, 528 insertions(+) create mode 100644 grimoire/retriever/meili_vector_db.py diff --git a/grimoire/config.py b/grimoire/config.py index bea4dff..3dc1346 100644 --- a/grimoire/config.py +++ b/grimoire/config.py @@ -15,6 +15,7 @@ class VectorConfig(BaseModel): embedding: OpenAIConfig host: str port: int = Field(default=8000) + meili_api_key: str | None = Field(default=None) weaviate: WeaviateConfig = Field(default_factory=WeaviateConfig) batch_size: int = Field(default=1) max_results: int = Field(default=10) diff --git a/grimoire/entity/tools.py b/grimoire/entity/tools.py index 2fee3aa..33b6f96 100644 --- a/grimoire/entity/tools.py +++ b/grimoire/entity/tools.py @@ -20,6 +20,47 @@ class Condition(BaseModel): created_at: tuple[float, float] | None = Field(default=None) updated_at: tuple[float, float] | None = Field(default=None) + def to_meili_where(self) -> List[str | List[str]]: + and_clause: List[str | List[str]] = [ + 'namespace_id = "{}"'.format(self.namespace_id) + ] + + if self.user_id: + and_clause.append( + [ + 'user_id IS NULL', + 'user_id = "{}"'.format(self.user_id) + ] + ) + + if self.record_type: + and_clause.append('type = "{}"'.format(self.record_type)) + + or_clause: List[str] = [] + if self.resource_ids: + or_clause.append( + "chunk.resource_id IN [{}]".format( + ", ".join('"{}"'.format(rid) for rid in self.resource_ids) + ) + ) + if self.parent_ids: + or_clause.append( + "chunk.parent_id IN [{}]".format( + ", ".join('"{}"'.format(pid) for pid in self.parent_ids) + ) + ) + if or_clause: + and_clause.append(or_clause) + + if self.created_at is not None: + and_clause.append("chunk.created_at >= {}".format(self.created_at[0])) + and_clause.append("chunk.created_at <= {}".format(self.created_at[1])) + if self.updated_at is not None: + and_clause.append("chunk.updated_at >= {}".format(self.updated_at[0])) + and_clause.append("chunk.updated_at <= {}".format(self.updated_at[1])) + + return and_clause + def to_weaviate_filters(self) -> Any: where = wvc.query.Filter.by_property("namespace_id").equal(self.namespace_id) diff --git a/grimoire/retriever/meili_vector_db.py b/grimoire/retriever/meili_vector_db.py new file mode 100644 index 0000000..e7c45aa --- /dev/null +++ b/grimoire/retriever/meili_vector_db.py @@ -0,0 +1,486 @@ +import asyncio +from functools import partial +from hashlib import md5 +from typing import List, Tuple + +from meilisearch_python_sdk import AsyncClient +from meilisearch_python_sdk.errors import MeilisearchApiError +from meilisearch_python_sdk.models.search import Hybrid +from meilisearch_python_sdk.models.settings import ( + Embedders, + Filter, + FilterableAttributeFeatures, + FilterableAttributes, + UserProvidedEmbedder, +) +from meilisearch_python_sdk.models.task import TaskInfo +from openai import AsyncOpenAI +from opentelemetry import propagate, trace + +from common.trace_info import TraceInfo +from wizard_common.grimoire.config import VectorConfig +from wizard_common.grimoire.entity.chunk import Chunk, ResourceChunkRetrieval +from wizard_common.grimoire.entity.index_record import ( + IndexRecord, + IndexRecordType, +) +from wizard_common.grimoire.entity.message import Message +from wizard_common.grimoire.entity.retrieval import Score +from wizard_common.grimoire.entity.tools import ( + Condition, + PrivateSearchResourceType, + PrivateSearchTool, + Resource, +) +from wizard_common.grimoire.retriever.base import BaseRetriever, SearchFunction + +tracer = trace.get_tracer(__name__) + + +def to_filterable_attributes( + filter_: str, comparison: bool = False +) -> FilterableAttributes: + """Convert a string filter to FilterableAttributes.""" + return FilterableAttributes( + attribute_patterns=[filter_], + features=FilterableAttributeFeatures( + facet_search=False, + filter=Filter(equality=True, comparison=comparison), + ), + ) + + +def sharded_index_uid(idx: int) -> str: + return f"omnibox-index-{idx}" + + +class MeiliVectorDB: + def __init__(self, config: VectorConfig): + self.config: VectorConfig = config + self.batch_size: int = config.batch_size + self.openai = AsyncOpenAI( + api_key=config.embedding.api_key, base_url=config.embedding.base_url + ) + self.meili: AsyncClient = ... + self.index_uid = "omniboxIndex" + self.num_shards = 20 + self.embedder_name = "omniboxEmbed" + self.dimension = config.dimension + self.has_old_index = False + + async def get_or_init_client(self) -> AsyncClient: + """Get the initialized MeiliSearch client.""" + if self.meili is ...: + client = AsyncClient(self.config.host, self.config.meili_api_key) + try: + await client.get_index(self.index_uid) + self.has_old_index = True + except MeilisearchApiError as e: + if e.status_code == 404: + self.has_old_index = False + else: + raise + for i in range(self.num_shards): + await self.init_shard_index(client, sharded_index_uid(i)) + self.meili = client + return self.meili + + def get_shard(self, namespace_id: str): + if not namespace_id: + raise ValueError("namespace_id is required") + h = md5(namespace_id.encode("utf-8")).digest() + idx = int.from_bytes(h[:4], byteorder="big") + idx %= self.num_shards + return sharded_index_uid(idx) + + async def init_shard_index(self, client: AsyncClient, index_uid: str): + """Initialize a single shard index with proper settings.""" + index = await client.get_or_create_index(index_uid) + + cur_filters: List[FilterableAttributes] = [] + for f in await index.get_filterable_attributes() or []: + if isinstance(f, FilterableAttributes): + cur_filters.append(f) + elif isinstance(f, str): + cur_filters.append(to_filterable_attributes(f)) + else: + raise ValueError( + f"Unexpected filterable attribute type: {type(f)}. Expected str or FilterableAttributes." + ) + + expected_filters = [ + "namespace_id", + "user_id", + "type", + "chunk.resource_id", + "chunk.parent_id", + "chunk.created_at", + "chunk.updated_at", + "message.conversation_id", + ] + comparison_filters = [ + "chunk.created_at", + "chunk.updated_at", + ] + missing_filters: List[FilterableAttributes] = [] + for expected_filter in expected_filters: + found = False + for cur_filter in cur_filters: + if expected_filter in cur_filter.attribute_patterns: + found = True + break + if not found: + missing_filters.append( + to_filterable_attributes( + expected_filter, + comparison=expected_filter in comparison_filters, + ) + ) + + if missing_filters: + new_filters = cur_filters + missing_filters + await index.update_filterable_attributes(new_filters) + + embedders = await index.get_embedders() + if not embedders or self.embedder_name not in embedders.embedders: + await index.update_embedders( + Embedders( + embedders={ + self.embedder_name: UserProvidedEmbedder( + dimensions=self.dimension + ) + } + ) + ) + + @tracer.start_as_current_span("MeiliVectorDB.insert_chunks") + async def insert_chunks( + self, namespace_id: str, chunk_list: List[Chunk], tasks: List[TaskInfo] + ): + client = await self.get_or_init_client() + index = client.index(self.get_shard(namespace_id)) + for i in range(0, len(chunk_list), self.batch_size): + raw_batch = chunk_list[i : i + self.batch_size] + + batch: List[Chunk] = [] + prompts: list[str] = [] + for x in raw_batch: + prompt: str = x.to_prompt() + if prompt: + batch.append(x) + prompts.append(prompt) + + headers = {} + propagate.inject(headers) + + embeddings = await self.openai.embeddings.create( + model=self.config.embedding.model, input=prompts, extra_headers=headers + ) + records = [] + for chunk, embed in zip(batch, embeddings.data): + record = IndexRecord( + type=IndexRecordType.chunk, + namespace_id=namespace_id, + chunk=chunk, + _vectors={ + self.embedder_name: embed.embedding, + }, + ) + record_dict = record.model_dump(by_alias=True) + record_dict["id"] = "chunk_{}".format(chunk.chunk_id) + records.append(record_dict) + tasks.append(await index.add_documents(records, primary_key="id")) + + @tracer.start_as_current_span("MeiliVectorDB.upsert_message") + async def upsert_message( + self, namespace_id: str, user_id: str, message: Message, tasks: List[TaskInfo] + ): + client = await self.get_or_init_client() + index = client.index(self.get_shard(namespace_id)) + record_id = "message_{}".format(message.message_id) + + if not message.message.content.strip(): + await index.delete_document(record_id) + return + + headers = {} + propagate.inject(headers) + + embedding = await self.openai.embeddings.create( + model=self.config.embedding.model, + input=message.message.content or "", + extra_headers=headers, + ) + record = IndexRecord( + type=IndexRecordType.message, + namespace_id=namespace_id, + user_id=user_id, + message=message, + _vectors={ + self.embedder_name: embedding.data[0].embedding, + }, + ) + record_dict = record.model_dump(by_alias=True) + record_dict["id"] = record_id + task = await index.add_documents( + [record_dict], primary_key="id" + ) + tasks.append(task) + + @tracer.start_as_current_span("MeiliVectorDB.remove_conversation") + async def remove_conversation( + self, namespace_id: str, conversation_id: str, tasks: List[TaskInfo] + ): + await self.delete_from_both_indexes( + namespace_id, + filter_=[ + "type = {}".format(IndexRecordType.message.value), + "namespace_id = {}".format(namespace_id), + "message.conversation_id = {}".format(conversation_id), + ], + tasks=tasks, + ) + + @tracer.start_as_current_span("MeiliVectorDB.remove_chunks") + async def remove_chunks( + self, namespace_id: str, resource_id: str, tasks: List[TaskInfo] + ): + await self.delete_from_both_indexes( + namespace_id, + filter_=[ + "type = {}".format(IndexRecordType.chunk.value), + "namespace_id = {}".format(namespace_id), + "chunk.resource_id = {}".format(resource_id), + ], + tasks=tasks, + ) + + @tracer.start_as_current_span("MeiliVectorDB.vector_params") + async def vector_params(self, query: str) -> dict: + if query: + headers = {} + propagate.inject(headers) + + embedding = await self.openai.embeddings.create( + model=self.config.embedding.model, input=query, extra_headers=headers + ) + vector = embedding.data[0].embedding + hybrid = Hybrid(embedder=self.embedder_name, semantic_ratio=0.5) + return { + "vector": vector, + "hybrid": hybrid, + } + return {} + + @tracer.start_as_current_span("MeiliVectorDB.query_both_indexes") + async def query_both_indexes( + self, + namespace_id: str, + query: str, + filter_: List[str | List[str]], + limit: int, + vector_params: dict, + **search_kwargs, + ) -> List[dict]: + client = await self.get_or_init_client() + + hits = [] + search_tasks = [] + + if self.has_old_index: + old_index = client.index(self.index_uid) + task = old_index.search( + query, + filter=filter_, + limit=limit, + **vector_params, + **search_kwargs, + show_ranking_score=True, + ) + search_tasks.append(task) + + index = client.index(self.get_shard(namespace_id)) + task = index.search( + query, + filter=filter_, + limit=limit, + **vector_params, + **search_kwargs, + show_ranking_score=True, + ) + search_tasks.append(task) + + results = await asyncio.gather(*search_tasks) + for result in results: + hits.extend(result.hits) + + hits.sort(key=lambda x: x.get("_rankingScore", 0), reverse=True) + return hits[:limit] + + @tracer.start_as_current_span("MeiliVectorDB.delete_from_both_indexes") + async def delete_from_both_indexes( + self, namespace_id: str, filter_: List[str | List[str]], tasks: List[TaskInfo] + ): + client = await self.get_or_init_client() + + if self.has_old_index: + old_index = client.index(self.index_uid) + tasks.append(await old_index.delete_documents_by_filter(filter=filter_)) + + index = client.index(self.get_shard(namespace_id)) + tasks.append(await index.delete_documents_by_filter(filter=filter_)) + + @tracer.start_as_current_span("MeiliVectorDB.search") + async def search( + self, + query: str, + namespace_id: str, + user_id: str | None, + record_type: IndexRecordType | None, + offset: int, + limit: int, + ) -> List[IndexRecord]: + filter_: List[str | List[str]] = [] + filter_.append("namespace_id = {}".format(namespace_id)) + if user_id: + filter_.append( + "user_id NOT EXISTS OR user_id IS NULL OR user_id = {}".format(user_id) + ) + if record_type: + filter_.append("type = {}".format(record_type.value)) + vector_params: dict = await self.vector_params(query) + + hits = await self.query_both_indexes( + namespace_id, + query, + filter_, + offset + limit, + vector_params, + ) + return [IndexRecord(**hit) for hit in hits[offset:]] + + @tracer.start_as_current_span("MeiliVectorDB.query_chunks") + async def query_chunks( + self, + namespace_id: str, + query: str, + k: int, + filter_: List[str | List[str]], + ) -> List[Tuple[Chunk, float]]: + combined_filters = filter_ + ["type = {}".format(IndexRecordType.chunk.value)] + vector_params: dict = await self.vector_params(query) + hits = await self.query_both_indexes( + namespace_id, + query, + combined_filters, + k, + vector_params, + ) + output: List[Tuple[Chunk, float]] = [] + for hit in hits: + chunk_data = hit["chunk"] + score = hit["_rankingScore"] + if chunk_data: + chunk = Chunk(**chunk_data) + output.append((chunk, score)) + return output + + async def wait_for_tasks(self, tasks: List[TaskInfo]): + if self.config.wait_timeout > 0: + client = await self.get_or_init_client() + await asyncio.gather( + *[ + client.wait_for_task( + task.task_uid, + timeout_in_ms=self.config.wait_timeout, + interval_in_ms=500, + ) + for task in tasks + ] + ) + + +class MeiliVectorRetriever(BaseRetriever): + def __init__(self, config: VectorConfig): + self.vector_db = MeiliVectorDB(config) + + @staticmethod + def get_folder(resource_id: str, resources: list[Resource]) -> str | None: + for resource in resources: + if ( + resource.type == PrivateSearchResourceType.FOLDER + and resource_id in resource.child_ids + ): + return resource.name + return None + + @staticmethod + def get_type( + resource_id: str, resources: list[Resource] + ) -> PrivateSearchResourceType | None: + for resource in resources: + if resource.id == resource_id: + return resource.type + return None + + def get_function( + self, private_search_tool: PrivateSearchTool, **kwargs + ) -> SearchFunction: + return partial( + self.query, private_search_tool=private_search_tool, k=40, **kwargs + ) + + @classmethod + def get_schema(cls) -> dict: + return cls.generate_schema( + "private_search", + 'Search for user\'s private & personal resources. Return in format.', + display_name={"zh": "知识库搜索", "en": "Knowledge Base Search"}, + ) + + @tracer.start_as_current_span("MeiliVectorRetriever.query") + async def query( + self, + query: str, + k: int, + *, + private_search_tool: PrivateSearchTool, + trace_info: TraceInfo | None = None, + ) -> list[ResourceChunkRetrieval]: + condition: Condition = private_search_tool.to_condition() + where = condition.to_meili_where() + if len(where) == 0: + trace_info and trace_info.warning( + { + "warning": "empty_where", + "where": where, + "condition": condition.model_dump() if condition else condition, + } + ) + return [] + + recall_result_list = await self.vector_db.query_chunks( + private_search_tool.namespace_id, query, k, where + ) + retrievals: List[ResourceChunkRetrieval] = [ + ResourceChunkRetrieval( + chunk=chunk, + folder=self.get_folder( + chunk.resource_id, private_search_tool.resources or [] + ), + type=self.get_type( + chunk.resource_id, private_search_tool.visible_resources or [] + ), + namespace_id=private_search_tool.namespace_id, + score=Score(recall=score, rerank=0), + ) + for chunk, score in recall_result_list + ] + trace_info and trace_info.debug( + { + "where": where, + "condition": condition.model_dump() if condition else condition, + "len(retrievals)": len(retrievals), + } + ) + return retrievals From 927d5ccbfe174c27b6f61f2e16e11607028ca2be Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Thu, 26 Mar 2026 21:07:27 +0800 Subject: [PATCH 8/9] feat(agent): switch back to MeiliVectorRetriever --- grimoire/agent/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grimoire/agent/agent.py b/grimoire/agent/agent.py index 8ec202b..d68dfc6 100644 --- a/grimoire/agent/agent.py +++ b/grimoire/agent/agent.py @@ -40,8 +40,8 @@ PrivateSearchResourceType, ) from wizard_common.grimoire.retriever.base import BaseRetriever -from wizard_common.grimoire.retriever.weaviate_vector_db import ( - WeaviateVectorRetriever, +from wizard_common.grimoire.retriever.meili_vector_db import ( + MeiliVectorRetriever, ) from wizard_common.grimoire.retriever.reranker import ( get_tool_executor_config, @@ -248,7 +248,7 @@ def message_dtos_to_openai_messages( class BaseSearchableAgent(BaseStreamable, ABC): def __init__(self, config: GrimoireAgentConfig): - self.knowledge_database_retriever = WeaviateVectorRetriever(config=config.vector) + self.knowledge_database_retriever = MeiliVectorRetriever(config=config.vector) self.web_search_retriever = SearXNG( base_url=config.tools.searxng.base_url, engines=config.tools.searxng.engines ) From ef1de8fc45119c30590bedc59393d950a539a70c Mon Sep 17 00:00:00 2001 From: Yichen Zhao Date: Fri, 27 Mar 2026 18:57:26 +0800 Subject: [PATCH 9/9] refactor(weaviate): update property names and structure in WeaviateVectorDB --- grimoire/entity/tools.py | 12 +- grimoire/retriever/weaviate_vector_db.py | 179 ++++++++++++++--------- 2 files changed, 116 insertions(+), 75 deletions(-) diff --git a/grimoire/entity/tools.py b/grimoire/entity/tools.py index 33b6f96..3360cb0 100644 --- a/grimoire/entity/tools.py +++ b/grimoire/entity/tools.py @@ -76,7 +76,7 @@ def to_weaviate_filters(self) -> Any: if self.resource_ids: resource_filter = None for rid in self.resource_ids: - each = wvc.query.Filter.by_property("chunk.resource_id").equal(rid) + each = wvc.query.Filter.by_property("chunk_resource_id").equal(rid) resource_filter = each if resource_filter is None else (resource_filter | each) if resource_filter is not None: where = where & resource_filter @@ -84,24 +84,24 @@ def to_weaviate_filters(self) -> Any: if self.parent_ids: parent_filter = None for pid in self.parent_ids: - each = wvc.query.Filter.by_property("chunk.parent_id").equal(pid) + each = wvc.query.Filter.by_property("chunk_parent_id").equal(pid) parent_filter = each if parent_filter is None else (parent_filter | each) if parent_filter is not None: where = where & parent_filter if self.created_at is not None: - where = where & wvc.query.Filter.by_property("chunk.created_at").greater_or_equal( + where = where & wvc.query.Filter.by_property("chunk_created_at").greater_or_equal( self.created_at[0] ) - where = where & wvc.query.Filter.by_property("chunk.created_at").less_or_equal( + where = where & wvc.query.Filter.by_property("chunk_created_at").less_or_equal( self.created_at[1] ) if self.updated_at is not None: - where = where & wvc.query.Filter.by_property("chunk.updated_at").greater_or_equal( + where = where & wvc.query.Filter.by_property("chunk_updated_at").greater_or_equal( self.updated_at[0] ) - where = where & wvc.query.Filter.by_property("chunk.updated_at").less_or_equal( + where = where & wvc.query.Filter.by_property("chunk_updated_at").less_or_equal( self.updated_at[1] ) diff --git a/grimoire/retriever/weaviate_vector_db.py b/grimoire/retriever/weaviate_vector_db.py index 4ced219..7ff1d59 100644 --- a/grimoire/retriever/weaviate_vector_db.py +++ b/grimoire/retriever/weaviate_vector_db.py @@ -2,13 +2,9 @@ from functools import partial from typing import Any, List, Tuple -import weaviate -import weaviate.classes as wvc +from common.trace_info import TraceInfo from openai import AsyncOpenAI from opentelemetry import propagate, trace -from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateDeleteManyError - -from common.trace_info import TraceInfo from wizard_common.grimoire.config import VectorConfig from wizard_common.grimoire.entity.chunk import Chunk, ResourceChunkRetrieval from wizard_common.grimoire.entity.index_record import IndexRecord, IndexRecordType @@ -22,6 +18,10 @@ ) from wizard_common.grimoire.retriever.base import BaseRetriever, SearchFunction +import weaviate +import weaviate.classes as wvc +from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateDeleteManyError + tracer = trace.get_tracer(__name__) COLLECTION_NAME = "omnibox_index" @@ -84,48 +84,71 @@ async def _ensure_client(self) -> None: index_filterable=True, ), wvc.config.Property( - name="chunk", - data_type=wvc.config.DataType.OBJECT, - nested_properties=[ - wvc.config.Property( - name="resource_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - wvc.config.Property( - name="parent_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - wvc.config.Property( - name="created_at", - data_type=wvc.config.DataType.NUMBER, - index_filterable=True, - index_range_filters=True, - ), - wvc.config.Property( - name="updated_at", - data_type=wvc.config.DataType.NUMBER, - index_filterable=True, - index_range_filters=True, - ), - ], + name="chunk_title", + data_type=wvc.config.DataType.TEXT, + index_searchable=True, + ), + wvc.config.Property( + name="chunk_text", + data_type=wvc.config.DataType.TEXT, + index_searchable=True, + ), + wvc.config.Property( + name="chunk_resource_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="chunk_parent_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="chunk_type", + data_type=wvc.config.DataType.TEXT, + ), + wvc.config.Property( + name="chunk_id", + data_type=wvc.config.DataType.TEXT, + ), + wvc.config.Property( + name="chunk_start_index", + data_type=wvc.config.DataType.INT, + ), + wvc.config.Property( + name="chunk_end_index", + data_type=wvc.config.DataType.INT, + ), + wvc.config.Property( + name="chunk_created_at", + data_type=wvc.config.DataType.NUMBER, + index_filterable=True, + index_range_filters=True, + ), + wvc.config.Property( + name="chunk_updated_at", + data_type=wvc.config.DataType.NUMBER, + index_filterable=True, + index_range_filters=True, ), wvc.config.Property( - name="message", - data_type=wvc.config.DataType.OBJECT, - nested_properties=[ - wvc.config.Property( - name="message_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - wvc.config.Property( - name="conversation_id", - data_type=wvc.config.DataType.TEXT, - index_filterable=True, - ), - ], + name="message_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="conversation_id", + data_type=wvc.config.DataType.TEXT, + index_filterable=True, + ), + wvc.config.Property( + name="message_role", + data_type=wvc.config.DataType.TEXT, + ), + wvc.config.Property( + name="message_content", + data_type=wvc.config.DataType.TEXT, + index_searchable=True, ), ], ) @@ -204,14 +227,23 @@ async def insert_chunks(self, namespace_id: str, chunk_list: List[Chunk]): vectors = await self._embed(prompts) objects = [] for chunk, vector in zip(batch, vectors): - record = IndexRecord( - type=IndexRecordType.chunk, - namespace_id=namespace_id, - chunk=chunk, - ) + properties = { + "type": IndexRecordType.chunk.value, + "namespace_id": namespace_id, + } + properties["chunk_title"] = chunk.title + properties["chunk_text"] = chunk.text + properties["chunk_resource_id"] = chunk.resource_id + properties["chunk_parent_id"] = chunk.parent_id + properties["chunk_type"] = chunk.chunk_type.value + properties["chunk_id"] = chunk.chunk_id + properties["chunk_created_at"] = chunk.created_at + properties["chunk_updated_at"] = chunk.updated_at + properties["chunk_start_index"] = chunk.start_index + properties["chunk_end_index"] = chunk.end_index objects.append( wvc.data.DataObject( - properties=record.model_dump(exclude_none=True), + properties=properties, vector=vector, ) ) @@ -223,7 +255,7 @@ async def upsert_message(self, namespace_id: str, user_id: str, message: Message try: await collection.data.delete_many( - where=wvc.query.Filter.by_property("message.message_id").equal( + where=wvc.query.Filter.by_property("message_id").equal( message.message_id ) ) @@ -236,16 +268,17 @@ async def upsert_message(self, namespace_id: str, user_id: str, message: Message return vector = (await self._embed(message_content))[0] - record = IndexRecord( - type=IndexRecordType.message, - namespace_id=namespace_id, - user_id=user_id, - message=message, - ) - - await collection.data.insert( - properties=record.model_dump(exclude_none=True), vector=vector - ) + properties = { + "type": IndexRecordType.message.value, + "namespace_id": namespace_id, + "user_id": user_id, + } + properties["message_id"] = message.message_id + properties["conversation_id"] = message.conversation_id + properties["message_role"] = message.message.role + properties["message_content"] = message_content + + await collection.data.insert(properties=properties, vector=vector) @tracer.start_as_current_span("WeaviateVectorDB.remove_conversation") async def remove_conversation(self, namespace_id: str, conversation_id: str): @@ -256,9 +289,7 @@ async def remove_conversation(self, namespace_id: str, conversation_id: str): IndexRecordType.message.value ) & wvc.query.Filter.by_property("namespace_id").equal(namespace_id) - & wvc.query.Filter.by_property("message.conversation_id").equal( - conversation_id - ) + & wvc.query.Filter.by_property("conversation_id").equal(conversation_id) ) except WeaviateDeleteManyError: # Tenant not found (no data yet for this namespace) @@ -273,7 +304,7 @@ async def remove_chunks(self, namespace_id: str, resource_id: str): IndexRecordType.chunk.value ) & wvc.query.Filter.by_property("namespace_id").equal(namespace_id) - & wvc.query.Filter.by_property("chunk.resource_id").equal(resource_id) + & wvc.query.Filter.by_property("chunk_resource_id").equal(resource_id) ) except WeaviateDeleteManyError: # Tenant not found (no data yet for this namespace) @@ -323,9 +354,19 @@ async def query_chunks( ) output: List[Tuple[Chunk, float]] = [] for hit, score in hits: - chunk_data = hit.get("chunk") - if chunk_data: - output.append((Chunk(**chunk_data), score)) + chunk = Chunk( + title=hit.get("chunk_title"), + resource_id=hit["chunk_resource_id"], + text=hit.get("chunk_text"), + chunk_type=hit["chunk_type"], + parent_id=hit["chunk_parent_id"], + chunk_id=hit["chunk_id"], + created_at=hit["chunk_created_at"], + updated_at=hit["chunk_updated_at"], + start_index=hit.get("chunk_start_index"), + end_index=hit.get("chunk_end_index"), + ) + output.append((chunk, score)) return output