Skip to content

Commit 86bc8cd

Browse files
committed
refactor weaviate filters
1 parent 251b945 commit 86bc8cd

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

packages/ragbits-core/src/ragbits/core/vector_stores/weaviate_vector.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from functools import reduce
2-
from operator import and_
31
from typing import TypeVar, cast
42
from uuid import UUID
53

@@ -9,7 +7,7 @@
97
from weaviate import WeaviateAsyncClient
108
from weaviate.classes.config import Configure, VectorDistances
119
from weaviate.classes.query import Filter, MetadataQuery
12-
from weaviate.collections.classes.filters import _Filters
10+
from weaviate.collections.classes.filters import FilterReturn
1311

1412
from ragbits.core.audit.traces import trace
1513
from ragbits.core.embeddings import Embedder
@@ -261,7 +259,7 @@ async def remove(self, ids: list[UUID]) -> None:
261259
await index.data.delete_many(where=Filter.by_id().contains_any(ids))
262260

263261
@staticmethod
264-
def _create_weaviate_filter(where: WhereQuery, separator: str) -> _Filters:
262+
def _create_weaviate_filter(where: WhereQuery, separator: str) -> FilterReturn:
265263
"""
266264
Creates the Filter from the given WhereQuery.
267265
@@ -274,13 +272,16 @@ def _create_weaviate_filter(where: WhereQuery, separator: str) -> _Filters:
274272
"""
275273
where = flatten_dict(where) # type: ignore
276274

277-
filters = (
278-
Filter.by_property(f"metadata{separator}{key.replace('.', separator)}").equal(cast(str | int | bool, value))
279-
for key, value in where.items()
275+
filters = Filter.all_of(
276+
[
277+
Filter.by_property(f"metadata{separator}{key.replace('.', separator)}").equal(
278+
cast(str | int | bool, value)
279+
)
280+
for key, value in where.items()
281+
]
280282
)
281283

282-
filters = reduce(and_, filters) # type: ignore
283-
return filters # type: ignore
284+
return filters
284285

285286
async def list(
286287
self,

packages/ragbits-core/tests/unit/vector_stores/test_weaviate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import weaviate.classes as wvc
66
from weaviate.classes.query import Filter
7-
from weaviate.collections.classes.filters import _FilterAnd
7+
from weaviate.collections.classes.filters import _FilterAnd, _FilterValue
88
from weaviate.collections.classes.internal import MetadataReturn, Object
99

1010
from ragbits.core.embeddings.dense import NoopEmbedder
@@ -472,13 +472,13 @@ async def test_create_weaviate_filter():
472472
where = {"a": "A", "b": 3, "c": True}
473473
weaviate_filter = WeaviateVectorStore._create_weaviate_filter(where, separator="___") # type: ignore
474474
assert isinstance(weaviate_filter, _FilterAnd)
475-
assert isinstance(weaviate_filter.filters[0], _FilterAnd)
476-
assert weaviate_filter.filters[0].filters[0].target == "metadata___a" # type: ignore
477-
assert weaviate_filter.filters[0].filters[0].value == "A" # type: ignore
478-
assert weaviate_filter.filters[0].filters[1].target == "metadata___b" # type: ignore
479-
assert weaviate_filter.filters[0].filters[1].value == 3 # type: ignore
480-
assert weaviate_filter.filters[1].target == "metadata___c" # type: ignore
481-
assert weaviate_filter.filters[1].value # type: ignore
475+
assert isinstance(weaviate_filter.filters[0], _FilterValue)
476+
assert weaviate_filter.filters[0].target == "metadata___a" # type: ignore
477+
assert weaviate_filter.filters[0].value == "A" # type: ignore
478+
assert weaviate_filter.filters[1].target == "metadata___b" # type: ignore
479+
assert weaviate_filter.filters[1].value == 3 # type: ignore
480+
assert weaviate_filter.filters[2].target == "metadata___c" # type: ignore
481+
assert weaviate_filter.filters[2].value # type: ignore
482482

483483

484484
@pytest.mark.asyncio

0 commit comments

Comments
 (0)