Skip to content
9 changes: 8 additions & 1 deletion grimoire/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,18 @@
from wizard_common.config import OpenAIConfig


class WeaviateConfig(BaseModel):
    # Settings for a Weaviate endpoint: host, port and API key.
    # `host` and `api_key` are optional and default to None;
    # `port` defaults to 8080.
    host: str | None = None
    port: int = 8080
    api_key: str | None = None


class VectorConfig(BaseModel):
embedding: OpenAIConfig
host: str
port: int = Field(default=8000)
meili_api_key: str = Field(default=None)
meili_api_key: str | None = Field(default=None)
weaviate: WeaviateConfig = Field(default_factory=WeaviateConfig)
batch_size: int = Field(default=1)
max_results: int = Field(default=10)
wait_timeout: int = Field(default=0)
Expand Down
6 changes: 2 additions & 4 deletions grimoire/entity/chunk.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import time
from datetime import datetime
from enum import Enum
from typing import Optional, Literal
from typing import Literal, Optional

import shortuuid
from pydantic import Field

from wizard_common.grimoire.entity.retrieval import (
BaseRetrieval,
Citation,
to_prompt,
PromptContext,
to_prompt,
)
from wizard_common.grimoire.entity.tools import PrivateSearchResourceType

Expand All @@ -36,7 +35,6 @@ class Chunk(PromptContext):
text: str | None = Field(default=None, description="Chunk content")
chunk_type: ChunkType = Field(description="Chunk type")

user_id: str
parent_id: str

chunk_id: str = Field(description="ID of chunk", default_factory=shortuuid.uuid)
Expand Down
1 change: 0 additions & 1 deletion grimoire/entity/index_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class IndexRecordType(str, Enum):


class IndexRecord(BaseModel):
id: str
type: IndexRecordType
namespace_id: str
user_id: str | None = None
Expand Down
63 changes: 62 additions & 1 deletion grimoire/entity/tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from enum import Enum
from functools import partial
from typing import List, Literal, Callable, TypedDict, Awaitable, Union, get_args, cast
from typing import List, Literal, Callable, TypedDict, Awaitable, Union, get_args, cast, Any

from opentelemetry import trace
from pydantic import BaseModel, Field
import weaviate.classes as wvc

tracer = trace.get_tracer("grimoire.entity.tools")
ToolName = Literal["private_search", "web_search"]
Expand All @@ -12,6 +13,8 @@

class Condition(BaseModel):
namespace_id: str
user_id: str | None = Field(default=None)
record_type: str | None = Field(default=None)
resource_ids: list[str] | None = Field(default=None)
parent_ids: list[str] | None = Field(default=None)
created_at: tuple[float, float] | None = Field(default=None)
Expand All @@ -21,6 +24,18 @@ def to_meili_where(self) -> List[str | List[str]]:
and_clause: List[str | List[str]] = [
'namespace_id = "{}"'.format(self.namespace_id)
]

if self.user_id:
and_clause.append(
[
'user_id IS NULL',
'user_id = "{}"'.format(self.user_id)
]
)

if self.record_type:
and_clause.append('type = "{}"'.format(self.record_type))

or_clause: List[str] = []
if self.resource_ids:
or_clause.append(
Expand All @@ -46,6 +61,52 @@ def to_meili_where(self) -> List[str | List[str]]:

return and_clause

def to_weaviate_filters(self) -> Any:
    """Translate this condition into a single Weaviate filter expression.

    The result always contains an equality test on ``namespace_id``. Every
    optional field that is set contributes one more clause, and all clauses
    are combined with ``&`` in this order:

    * ``user_id``: matches records whose ``user_id`` is none or equals it.
    * ``record_type``: equality on the ``type`` property.
    * ``resource_ids`` / ``parent_ids``: per-id equality tests on
      ``chunk_resource_id`` / ``chunk_parent_id``, joined with ``|``.
    * ``created_at`` / ``updated_at``: inclusive lower and upper bounds on
      ``chunk_created_at`` / ``chunk_updated_at``.

    Returns:
        The combined filter object produced by ``wvc.query.Filter``.
    """
    by_property = wvc.query.Filter.by_property

    def matches_any(prop_name, values):
        # Left-fold the per-value equality tests with `|`.
        combined = None
        for value in values:
            candidate = by_property(prop_name).equal(value)
            combined = candidate if combined is None else (combined | candidate)
        return combined

    clauses = [by_property("namespace_id").equal(self.namespace_id)]

    if self.user_id:
        clauses.append(
            by_property("user_id").is_none(True)
            | by_property("user_id").equal(self.user_id)
        )

    if self.record_type:
        clauses.append(by_property("type").equal(self.record_type))

    # NOTE(review): the two id groups below are ANDed with each other, while
    # to_meili_where gathers the same two conditions in a list named
    # `or_clause` -- confirm which combination is intended.
    id_groups = (
        ("chunk_resource_id", self.resource_ids),
        ("chunk_parent_id", self.parent_ids),
    )
    for prop_name, ids in id_groups:
        if not ids:
            continue
        id_filter = matches_any(prop_name, ids)
        if id_filter is not None:
            clauses.append(id_filter)

    time_ranges = (
        ("chunk_created_at", self.created_at),
        ("chunk_updated_at", self.updated_at),
    )
    for prop_name, bounds in time_ranges:
        if bounds is None:
            continue
        clauses.append(by_property(prop_name).greater_or_equal(bounds[0]))
        clauses.append(by_property(prop_name).less_or_equal(bounds[1]))

    # Fold left to right so the nesting matches repeated `where = where & x`.
    where = clauses[0]
    for clause in clauses[1:]:
        where = where & clause
    return where


class ToolExecutorConfig(TypedDict):
name: str
Expand Down
10 changes: 6 additions & 4 deletions grimoire/retriever/meili_vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,16 @@ async def insert_chunks(
records = []
for chunk, embed in zip(batch, embeddings.data):
record = IndexRecord(
id="chunk_{}".format(chunk.chunk_id),
type=IndexRecordType.chunk,
namespace_id=namespace_id,
chunk=chunk,
_vectors={
self.embedder_name: embed.embedding,
},
)
records.append(record.model_dump(by_alias=True))
record_dict = record.model_dump(by_alias=True)
record_dict["id"] = "chunk_{}".format(chunk.chunk_id)
records.append(record_dict)
tasks.append(await index.add_documents(records, primary_key="id"))

@tracer.start_as_current_span("MeiliVectorDB.upsert_message")
Expand All @@ -211,7 +212,6 @@ async def upsert_message(
extra_headers=headers,
)
record = IndexRecord(
id=record_id,
type=IndexRecordType.message,
namespace_id=namespace_id,
user_id=user_id,
Expand All @@ -220,8 +220,10 @@ async def upsert_message(
self.embedder_name: embedding.data[0].embedding,
},
)
record_dict = record.model_dump(by_alias=True)
record_dict["id"] = record_id
task = await index.add_documents(
[record.model_dump(by_alias=True)], primary_key="id"
[record_dict], primary_key="id"
)
tasks.append(task)

Expand Down
Loading