Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

87 changes: 69 additions & 18 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use defs::{Dimension, IndexedVector, Similarity};
use defs::{Dimension, IndexedVector, SearchQueryInput, Similarity};

use defs::{DenseVector, Payload, Point, PointId};
use index::hnsw::HnswIndex;
Expand Down Expand Up @@ -68,6 +68,28 @@ impl VectorDb {
Ok(point_id)
}

pub fn insert_batch(&self, points: Vec<Point>) -> Result<Vec<PointId>> {
let mut ids = Vec::with_capacity(points.len());

for point in points {
let id = point.id.unwrap_or_else(Uuid::new_v4);
let vector = point.vector;
let payload = point.payload;

self.storage.insert_point(id, vector.clone(), payload)?;

if let Some(v) = vector {
let indexed = IndexedVector { id, vector: v };
let mut index = self.index.write().map_err(|_| ApiError::LockError)?;
index.insert(indexed)?;
}

ids.push(id);
}

Ok(ids)
}

//TODO: Make this an atomic operation
pub fn delete(&self, id: PointId) -> Result<bool> {
// Remove from storage
Expand All @@ -84,7 +106,7 @@ impl VectorDb {
let vector = self.storage.get_vector(id)?;
if payload.is_some() || vector.is_some() {
Ok(Some(Point {
id,
id: Some(id),
payload,
vector,
}))
Expand All @@ -93,34 +115,41 @@ impl VectorDb {
}
}

pub fn search(
&self,
query: DenseVector,
similarity: Similarity,
limit: usize,
) -> Result<Vec<PointId>> {
pub fn search(&self, query: SearchQueryInput) -> Result<Vec<PointId>> {
// Validate search limit
if limit == 0 {
return Err(ApiError::InvalidSearchLimit { limit });
if query.limit == 0 {
return Err(ApiError::InvalidSearchLimit { limit: query.limit });
}

// Validate query dimension
if query.len() != self.dimension {
if query.vector.len() != self.dimension {
return Err(ApiError::DimensionMismatch {
expected: self.dimension,
got: query.len(),
got: query.vector.len(),
});
}

// Use vector index to find similar vectors
let index = self.index.read().map_err(|_| ApiError::LockError)?;

//TODO: Add feat of returning similarity scores in the search
let vectors = index.search(query, similarity, limit)?;
let vectors = index.search(query.vector, query.similarity, query.limit)?;

Ok(vectors)
}

pub fn search_batch(&self, queries: Vec<SearchQueryInput>) -> Result<Vec<Vec<PointId>>> {
let mut results = Vec::with_capacity(queries.len());
let index = self.index.read().unwrap();

for query in queries {
let found = index.search(query.vector, query.similarity, query.limit)?;
results.push(found);
}

Ok(results)
}

pub fn list(&self, offset: PointId, limit: usize) -> Result<Option<VectorPage>> {
let page = self.storage.list_vectors(offset, limit)?;
Ok(page)
Expand Down Expand Up @@ -224,7 +253,7 @@ mod tests {

// Test get
let point = db.get(id).unwrap().unwrap();
assert_eq!(point.id, id);
assert_eq!(point.id, Some(id));
assert_eq!(point.vector.as_ref().unwrap(), &vector);
assert_eq!(point.payload.as_ref().unwrap(), &payload);
assert_eq!(
Expand Down Expand Up @@ -305,7 +334,13 @@ mod tests {

// Search for the closest vector to [1.0, 0.1, 0.1]
let query = vec![1.0, 0.1, 0.1];
let results = db.search(query, Similarity::Cosine, 1).unwrap();
let results = db
.search(SearchQueryInput {
vector: query,
similarity: Similarity::Cosine,
limit: 1,
})
.unwrap();

assert_eq!(results.len(), 1);
assert_eq!(results[0], ids[0]); // The first vector should be closest
Expand Down Expand Up @@ -333,7 +368,13 @@ mod tests {

// Search with limit 3
let query = vec![0.0, 0.0, 0.0];
let results = db.search(query, Similarity::Euclidean, 3).unwrap();
let results = db
.search(SearchQueryInput {
vector: query,
similarity: Similarity::Euclidean,
limit: 3,
})
.unwrap();

assert_eq!(results.len(), 3);
}
Expand All @@ -343,7 +384,11 @@ mod tests {
let (db, _temp_dir) = create_test_db();

let query = vec![1.0, 2.0, 3.0];
let result = db.search(query, Similarity::Cosine, 0);
let result = db.search(SearchQueryInput {
vector: query,
similarity: Similarity::Cosine,
limit: 0,
});

assert!(result.is_err());
match result.unwrap_err() {
Expand All @@ -362,7 +407,13 @@ mod tests {
assert!(db.get(Uuid::new_v4()).unwrap().is_none());

let query = vec![1.0, 2.0, 3.0];
let results = db.search(query, Similarity::Cosine, 10).unwrap();
let results = db
.search(SearchQueryInput {
vector: query,
similarity: Similarity::Cosine,
limit: 10,
})
.unwrap();
assert_eq!(results.len(), 0);
}

Expand Down
2 changes: 2 additions & 0 deletions crates/defs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ edition.workspace = true
license.workspace = true

[dependencies]
axum.workspace = true
serde.workspace = true
snafu.workspace = true
uuid.workspace = true
22 changes: 19 additions & 3 deletions crates/defs/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
use std::io;

#[derive(Debug)]
use axum::{http::StatusCode, response::IntoResponse};
use snafu::Snafu;

#[derive(Debug, Snafu)]
pub enum ServerError {
Bind(io::Error),
Serve(io::Error),
#[snafu(display("Failed to bind: {source}"))]
Bind { source: io::Error },

#[snafu(display("Failed to serve: {source}"))]
Serve { source: io::Error },
}

/// Application-level error; each variant is mapped to an HTTP status by the
/// `IntoResponse` implementation for this type.
#[derive(Debug)]
pub enum AppError {
/// A failure of the server itself (see [`ServerError`]); answered with 500.
ServerError(ServerError),
/// An API-level failure carrying its message text; answered with 400.
Api(String),
}

/// Turns an [`AppError`] into an HTTP response consisting of a status code and
/// a message body.
impl IntoResponse for AppError {
    fn into_response(self) -> axum::response::Response {
        match self {
            // API failures are reported as 400 with their message unchanged.
            AppError::Api(message) => (StatusCode::BAD_REQUEST, message).into_response(),
            // Server failures are reported as 500 with the error's `Display` text.
            AppError::ServerError(err) => {
                (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()).into_response()
            }
        }
    }
}
// Error type for server
pub type BoxError = Box<dyn std::error::Error + Send + Sync>;
43 changes: 39 additions & 4 deletions crates/defs/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,22 @@ pub enum StoredVector {
Dense(DenseVector),
}

#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq)]
#[derive(Serialize, Deserialize, Clone, Copy, Debug, Default, PartialEq)]
pub enum ContentType {
#[default]
Text,
Image,
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)]
pub struct Payload {
pub content_type: ContentType,
pub content: String,
}

#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub struct Point {
pub id: PointId,
pub id: Option<PointId>,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't `id` be required in a `Point`? What does a point with a `None` id signify?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While inserting (mostly in batch inserts), the user may not provide an id. In that case, the server generates an `id` for the point.

pub vector: Option<DenseVector>,
pub payload: Option<Payload>,
}
Expand All @@ -45,14 +46,48 @@ pub struct IndexedVector {
pub vector: DenseVector,
}

#[derive(Debug, Deserialize, Copy, Clone)]
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub enum Similarity {
Euclidean,
Manhattan,
Hamming,
Cosine,
}

/// Payload of a batch-insert request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchInsertRequest {
/// Points to insert. A point's `id` is optional (`Option<PointId>`).
pub points: Vec<Point>,
}

/// Payload of a batch-insert response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchInsertResponse {
/// Number of points inserted.
/// NOTE(review): presumably equal to `ids.len()` — confirm against the handler that builds this.
pub inserted: usize,
/// Ids of the inserted points.
pub ids: Vec<PointId>,
}

/// Payload of a batch-search request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSearchRequest {
/// One entry per search to run.
pub queries: Vec<SearchQueryInput>,
}

/// A single nearest-neighbour search query.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQueryInput {
/// Query vector; `VectorDb::search` rejects it when its length differs from
/// the database dimension.
pub vector: DenseVector,
/// Similarity / distance function used to compare vectors.
pub similarity: Similarity,
/// Maximum number of results to return; `VectorDb::search` rejects `0`.
pub limit: usize,
}

/// Payload of a single-search response.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct SearchResponse {
/// Ids of the points found for the query.
/// NOTE(review): presumably ordered closest-first — confirm against the index implementation.
pub results: Vec<PointId>,
}

/// Payload of a batch-search response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchSearchResponse {
/// One [`SearchResponse`] per query.
/// NOTE(review): presumably in the same order as the request's `queries` — confirm against the handler.
pub results: Vec<SearchResponse>,
}

// Struct which stores the distance between a vector and query vector and implements ordering traits
#[derive(Copy, Clone)]
pub struct DistanceOrderedVector<'q> {
Expand Down
18 changes: 18 additions & 0 deletions crates/grpc/proto/vector-db.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ service VectorDB {

//Search for the k nearest vectors to a target vector given a distance function
rpc SearchPoints(SearchRequest) returns (SearchResponse) {}

// Insert several vectors in one request; the response carries a list of point ids
rpc InsertVectorsBatch(InsertVectorsBatchRequest) returns (InsertVectorsBatchResponse) {}
// Run several searches in one request; the response carries a list of search responses
rpc SearchPointsBatch(SearchPointsBatchRequest) returns (SearchPointsBatchResponse) {}
}


Expand Down Expand Up @@ -71,3 +74,18 @@ message Payload {
string content = 2;
}

// Request for InsertVectorsBatch: a list of single-insert requests.
message InsertVectorsBatchRequest {
repeated InsertVectorRequest vectors = 1;
}

// Response for InsertVectorsBatch: a list of point ids.
// NOTE(review): presumably one id per inserted vector, in request order - confirm against the server implementation.
message InsertVectorsBatchResponse {
repeated PointID ids = 1;
}

// Request for SearchPointsBatch: a list of single-search requests.
message SearchPointsBatchRequest {
repeated SearchRequest queries = 1;
}

// Response for SearchPointsBatch: a list of single-search responses.
// NOTE(review): presumably one entry per query, in request order - confirm against the server implementation.
message SearchPointsBatchResponse {
repeated SearchResponse results = 1;
}
Loading
Loading