Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Reference documentation for the RedisVL API.

schema
searchindex
vector
query
filter
vectorizer
Expand Down
14 changes: 14 additions & 0 deletions docs/api/query.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,17 @@ CountQuery
:inherited-members:
:show-inheritance:
:exclude-members: add_filter,get_args,highlight,return_field,summarize



MultiVectorQuery
==========

.. currentmodule:: redisvl.query


.. autoclass:: MultiVectorQuery
:members:
:inherited-members:
:show-inheritance:
:exclude-members: add_filter,get_args,highlight,return_field,summarize
17 changes: 17 additions & 0 deletions docs/api/vector.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

*****
Vector
*****

The Vector class in RedisVL is a container that encapsulates a numerical vector, it's datatype, corresponding index field name, and optional importance weight. It is used when constructing multi-vector queries using the MultiVectorQuery class.


Vector
===========

.. currentmodule:: redisvl.query


.. autoclass:: Vector
:members:
:exclude-members:
9 changes: 8 additions & 1 deletion redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from redisvl.query.aggregate import AggregationQuery, HybridQuery
from redisvl.query.aggregate import (
AggregationQuery,
HybridQuery,
MultiVectorQuery,
Vector,
)
from redisvl.query.query import (
BaseQuery,
BaseVectorQuery,
Expand All @@ -21,4 +26,6 @@
"TextQuery",
"AggregationQuery",
"HybridQuery",
"MultiVectorQuery",
"Vector",
]
171 changes: 171 additions & 0 deletions redisvl/query/aggregate.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,41 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from pydantic import BaseModel, field_validator
from redis.commands.search.aggregation import AggregateRequest, Desc

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType
from redisvl.utils.token_escaper import TokenEscaper
from redisvl.utils.utils import lazy_import

nltk = lazy_import("nltk")
nltk_stopwords = lazy_import("nltk.corpus.stopwords")


class Vector(BaseModel):
"""
Simple object containing the necessary arguments to perform a multi vector query.
"""

vector: Union[List[float], bytes]
field_name: str
dtype: str = "float32"
weight: float = 1.0

@field_validator("dtype")
@classmethod
def validate_dtype(cls, dtype: str) -> str:
try:
VectorDataType(dtype.upper())
except ValueError:
raise ValueError(
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
)

return dtype


class AggregationQuery(AggregateRequest):
"""
Base class for aggregation queries used to create aggregation queries for Redis.
Expand Down Expand Up @@ -227,3 +252,149 @@ def _build_query_string(self) -> str:
def __str__(self) -> str:
"""Return the string representation of the query."""
return " ".join([str(x) for x in self.build_args()])


class MultiVectorQuery(AggregationQuery):
"""
MultiVectorQuery allows for search over multiple vector fields in a document simulateously.
The final score will be a weighted combination of the individual vector similarity scores
following the formula:

score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... )

Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric.

.. code-block:: python

from redisvl.query import MultiVectorQuery, Vector
from redisvl.index import SearchIndex

index = SearchIndex.from_yaml("path/to/index.yaml")

vector_1 = Vector(
vector=[0.1, 0.2, 0.3],
field_name="text_vector",
dtype="float32",
weight=0.7,
)
vector_2 = Vector(
vector=[0.5, 0.5],
field_name="image_vector",
dtype="bfloat16",
weight=0.2,
)
vector_3 = Vector(
vector=[0.1, 0.2, 0.3],
field_name="text_vector",
dtype="float64",
weight=0.5,
)

query = MultiVectorQuery(
vectors=[vector_1, vector_2, vector_3],
filter_expression=None,
num_results=10,
return_fields=["field1", "field2"],
dialect=2,
)

results = index.query(query)
"""

_vectors: List[Vector]

def __init__(
self,
vectors: Union[Vector, List[Vector]],
return_fields: Optional[List[str]] = None,
filter_expression: Optional[Union[str, FilterExpression]] = None,
num_results: int = 10,
dialect: int = 2,
):
"""
Instantiates a MultiVectorQuery object.

Args:
vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search.
return_fields (Optional[List[str]], optional): The fields to return. Defaults to None.
filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use.
Defaults to None.
num_results (int, optional): The number of results to return. Defaults to 10.
dialect (int, optional): The Redis dialect version. Defaults to 2.
"""

self._filter_expression = filter_expression
self._num_results = num_results

if isinstance(vectors, Vector):
self._vectors = [vectors]
else:
self._vectors = vectors # type: ignore

if not all([isinstance(v, Vector) for v in self._vectors]):
raise TypeError(
"vector argument must be a Vector object or list of Vector objects."
)

query_string = self._build_query_string()
super().__init__(query_string)

# calculate the respective vector similarities
for i in range(len(self._vectors)):
self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"})

# construct the scoring string based on the vector similarity scores and weights
combined_scores = []
for i, w in enumerate([v.weight for v in self._vectors]):
combined_scores.append(f"@score_{i} * {w}")
combined_score_string = " + ".join(combined_scores)

self.apply(combined_score=combined_score_string)

self.sort_by(Desc("@combined_score"), max=num_results) # type: ignore
self.dialect(dialect)
if return_fields:
self.load(*return_fields) # type: ignore[arg-type]

@property
def params(self) -> Dict[str, Any]:
"""Return the parameters for the aggregation.

Returns:
Dict[str, Any]: The parameters for the aggregation.
"""
params = {}
for i, (vector, dtype) in enumerate(
[(v.vector, v.dtype) for v in self._vectors]
):
if isinstance(vector, list):
vector = array_to_buffer(vector, dtype=dtype) # type: ignore
params[f"vector_{i}"] = vector
return params

def _build_query_string(self) -> str:
"""Build the full query string for text search with optional filtering."""

# base KNN query
range_queries = []
for i, (vector, field) in enumerate(
[(v.vector, v.field_name) for v in self._vectors]
):
range_queries.append(
f"@{field}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}"
)

range_query = " | ".join(range_queries)

filter_expression = self._filter_expression
if isinstance(self._filter_expression, FilterExpression):
filter_expression = str(self._filter_expression)

if filter_expression:
return f"({range_query}) AND ({filter_expression})"
else:
return f"{range_query}"

def __str__(self) -> str:
"""Return the string representation of the query."""
return " ".join([str(x) for x in self.build_args()])
90 changes: 90 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,96 @@ def sample_data(sample_datetimes):
]


@pytest.fixture
def multi_vector_data(sample_datetimes):
return [
{
"user": "john",
"age": 18,
"job": "engineer",
"description": "engineers conduct trains that ride on train tracks",
"last_updated": sample_datetimes["low"].timestamp(),
"credit_score": "high",
"location": "-122.4194,37.7749",
"user_embedding": [0.1, 0.1, 0.5],
"image_embedding": [0.1, 0.1, 0.1, 0.1, 0.1],
"audio_embedding": [34, 18.5, -6.0, -12, 115, 96.5],
},
{
"user": "mary",
"age": 14,
"job": "doctor",
"description": "a medical professional who treats diseases and helps people stay healthy",
"last_updated": sample_datetimes["low"].timestamp(),
"credit_score": "low",
"location": "-122.4194,37.7749",
"user_embedding": [0.1, 0.1, 0.5],
"image_embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
"audio_embedding": [0.0, -1.06, 4.55, -1.93, 0.0, 1.53],
},
{
"user": "nancy",
"age": 94,
"job": "doctor",
"description": "a research scientist specializing in cancers and diseases of the lungs",
"last_updated": sample_datetimes["mid"].timestamp(),
"credit_score": "high",
"location": "-122.4194,37.7749",
"user_embedding": [0.7, 0.1, 0.5],
"image_embedding": [0.1, 0.1, 0.3, 0.3, 0.5],
"audio_embedding": [2.75, -0.33, -3.01, -0.52, 5.59, -2.30],
},
{
"user": "tyler",
"age": 100,
"job": "engineer",
"description": "a software developer with expertise in mathematics and computer science",
"last_updated": sample_datetimes["mid"].timestamp(),
"credit_score": "high",
"location": "-110.0839,37.3861",
"user_embedding": [0.1, 0.4, 0.5],
"image_embedding": [-0.1, -0.2, -0.3, -0.4, -0.5],
"audio_embedding": [1.11, -6.73, 5.41, 1.04, 3.92, 0.73],
},
{
"user": "tim",
"age": 12,
"job": "dermatologist",
"description": "a medical professional specializing in diseases of the skin",
"last_updated": sample_datetimes["mid"].timestamp(),
"credit_score": "high",
"location": "-110.0839,37.3861",
"user_embedding": [0.4, 0.4, 0.5],
"image_embedding": [-0.1, 0.0, 0.6, 0.0, -0.9],
"audio_embedding": [0.03, -2.67, -2.08, 4.57, -2.33, 0.0],
},
{
"user": "taimur",
"age": 15,
"job": "CEO",
"description": "high stress, but financially rewarding position at the head of a company",
"last_updated": sample_datetimes["high"].timestamp(),
"credit_score": "low",
"location": "-110.0839,37.3861",
"user_embedding": [0.6, 0.1, 0.5],
"image_embedding": [1.1, 1.2, -0.3, -4.1, 5.0],
"audio_embedding": [0.68, 0.26, 2.08, 2.96, 0.01, 5.13],
},
{
"user": "joe",
"age": 35,
"job": "dentist",
"description": "like the tooth fairy because they'll take your teeth, but you have to pay them!",
"last_updated": sample_datetimes["high"].timestamp(),
"credit_score": "medium",
"location": "-110.0839,37.3861",
"user_embedding": [-0.1, -0.1, -0.5],
"image_embedding": [-0.8, 2.0, 3.1, 1.5, -1.6],
"audio_embedding": [0.91, 7.10, -2.14, -0.52, -6.08, -5.53],
},
]


def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption(
"--run-api-tests",
Expand Down
Loading