From 8effb88809e14fba6caff81ddb378de039bc372c Mon Sep 17 00:00:00 2001 From: Arshdeep54 Date: Tue, 24 Mar 2026 03:18:16 +0530 Subject: [PATCH] feat: add batch insert and search Signed-off-by: Arshdeep54 --- Cargo.lock | 2 + crates/api/src/lib.rs | 87 ++++++++++++++++++++++------ crates/defs/Cargo.toml | 2 + crates/defs/src/error.rs | 22 ++++++- crates/defs/src/types.rs | 43 ++++++++++++-- crates/grpc/proto/vector-db.proto | 18 ++++++ crates/grpc/src/service.rs | 95 +++++++++++++++++++++++++++++-- crates/http/src/handler.rs | 54 ++++++++++++------ crates/http/src/lib.rs | 6 +- crates/index/src/lib.rs | 7 +++ crates/storage/src/rocks_db.rs | 10 ++-- crates/tui/src/app/events.rs | 13 ++++- 12 files changed, 303 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e19cfa0..24e106d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -360,7 +360,9 @@ dependencies = [ name = "defs" version = "0.1.0" dependencies = [ + "axum", "serde", + "snafu", "uuid", ] diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 16e9dab..f759614 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -1,4 +1,4 @@ -use defs::{Dimension, IndexedVector, Similarity}; +use defs::{Dimension, IndexedVector, SearchQueryInput, Similarity}; use defs::{DenseVector, Payload, Point, PointId}; use index::hnsw::HnswIndex; @@ -68,6 +68,28 @@ impl VectorDb { Ok(point_id) } + pub fn insert_batch(&self, points: Vec) -> Result> { + let mut ids = Vec::with_capacity(points.len()); + + for point in points { + let id = point.id.unwrap_or_else(Uuid::new_v4); + let vector = point.vector; + let payload = point.payload; + + self.storage.insert_point(id, vector.clone(), payload)?; + + if let Some(v) = vector { + let indexed = IndexedVector { id, vector: v }; + let mut index = self.index.write().map_err(|_| ApiError::LockError)?; + index.insert(indexed)?; + } + + ids.push(id); + } + + Ok(ids) + } + //TODO: Make this an atomic operation pub fn delete(&self, id: PointId) -> Result { // Remove from storage @@ -84,7 +106,7 @@ impl VectorDb { let vector = self.storage.get_vector(id)?; if payload.is_some() || vector.is_some() { Ok(Some(Point { - id, + id: Some(id), payload, vector, })) @@ -93,22 +115,17 @@ impl VectorDb { } } - pub fn search( - &self, - query: DenseVector, - similarity: Similarity, - limit: usize, - ) -> Result> { + pub fn search(&self, query: SearchQueryInput) -> Result> { // Validate search limit - if limit == 0 { - return Err(ApiError::InvalidSearchLimit { limit }); + if query.limit == 0 { + return Err(ApiError::InvalidSearchLimit { limit: query.limit }); } // Validate query dimension - if query.len() != self.dimension { + if query.vector.len() != self.dimension { return Err(ApiError::DimensionMismatch { expected: self.dimension, - got: query.len(), + got: query.vector.len(), }); } @@ -116,11 +133,23 @@ impl VectorDb { let index = self.index.read().map_err(|_| ApiError::LockError)?; //TODO: Add feat of returning similarity scores in the search - let vectors = index.search(query, similarity, limit)?; + let vectors = index.search(query.vector, query.similarity, query.limit)?; Ok(vectors) } + pub fn search_batch(&self, queries: Vec) -> Result>> { + let mut results = Vec::with_capacity(queries.len()); + let index = self.index.read().unwrap(); + + for query in queries { + let found = index.search(query.vector, query.similarity, query.limit)?; + results.push(found); + } + + Ok(results) + } + pub fn list(&self, offset: PointId, limit: usize) -> Result> { let page = self.storage.list_vectors(offset, limit)?; Ok(page) @@ -224,7 +253,7 @@ mod tests { // Test get let point = db.get(id).unwrap().unwrap(); - assert_eq!(point.id, id); + assert_eq!(point.id, Some(id)); assert_eq!(point.vector.as_ref().unwrap(), &vector); assert_eq!(point.payload.as_ref().unwrap(), &payload); assert_eq!( @@ -305,7 +334,13 @@ mod tests { // Search for the closest vector to [1.0, 0.1, 0.1] let query = vec![1.0, 0.1, 0.1]; - let results = db.search(query, Similarity::Cosine, 1).unwrap(); + let results = db + .search(SearchQueryInput { + vector: query, + similarity: Similarity::Cosine, + limit: 1, + }) + .unwrap(); assert_eq!(results.len(), 1); assert_eq!(results[0], ids[0]); // The first vector should be closest @@ -333,7 +368,13 @@ mod tests { // Search with limit 3 let query = vec![0.0, 0.0, 0.0]; - let results = db.search(query, Similarity::Euclidean, 3).unwrap(); + let results = db + .search(SearchQueryInput { + vector: query, + similarity: Similarity::Euclidean, + limit: 3, + }) + .unwrap(); assert_eq!(results.len(), 3); } @@ -343,7 +384,11 @@ mod tests { let (db, _temp_dir) = create_test_db(); let query = vec![1.0, 2.0, 3.0]; - let result = db.search(query, Similarity::Cosine, 0); + let result = db.search(SearchQueryInput { + vector: query, + similarity: Similarity::Cosine, + limit: 0, + }); assert!(result.is_err()); match result.unwrap_err() { @@ -362,7 +407,13 @@ mod tests { assert!(db.get(Uuid::new_v4()).unwrap().is_none()); let query = vec![1.0, 2.0, 3.0]; - let results = db.search(query, Similarity::Cosine, 10).unwrap(); + let results = db + .search(SearchQueryInput { + vector: query, + similarity: Similarity::Cosine, + limit: 10, + }) + .unwrap(); assert_eq!(results.len(), 0); } diff --git a/crates/defs/Cargo.toml b/crates/defs/Cargo.toml index 600b80c..81d8dd9 100644 --- a/crates/defs/Cargo.toml +++ b/crates/defs/Cargo.toml @@ -7,5 +7,7 @@ edition.workspace = true license.workspace = true [dependencies] +axum.workspace = true serde.workspace = true +snafu.workspace = true uuid.workspace = true diff --git a/crates/defs/src/error.rs b/crates/defs/src/error.rs index 13f0b60..112b4f6 100644 --- a/crates/defs/src/error.rs +++ b/crates/defs/src/error.rs @@ -1,15 +1,31 @@ use std::io; -#[derive(Debug)] +use axum::{http::StatusCode, response::IntoResponse}; +use snafu::Snafu; + +#[derive(Debug, Snafu)] pub enum ServerError { - Bind(io::Error), - Serve(io::Error), + #[snafu(display("Failed to bind: {source}"))] + Bind { source: io::Error }, + + #[snafu(display("Failed to serve: {source}"))] + Serve { source: io::Error }, } #[derive(Debug)] pub enum AppError { ServerError(ServerError), + Api(String), } +impl IntoResponse for AppError { + fn into_response(self) -> axum::response::Response { + let (status, message) = match self { + AppError::Api(msg) => (StatusCode::BAD_REQUEST, msg), + AppError::ServerError(e) => (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()), + }; + (status, message).into_response() + } +} // Error type for server pub type BoxError = Box; diff --git a/crates/defs/src/types.rs b/crates/defs/src/types.rs index 65c861a..8ba46b0 100644 --- a/crates/defs/src/types.rs +++ b/crates/defs/src/types.rs @@ -19,13 +19,14 @@ pub enum StoredVector { Dense(DenseVector), } -#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Clone, Copy, Debug, Default, PartialEq)] pub enum ContentType { + #[default] Text, Image, } -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Clone, Debug, Default, PartialEq)] pub struct Payload { pub content_type: ContentType, pub content: String, @@ -33,7 +34,7 @@ pub struct Payload { #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub struct Point { - pub id: PointId, + pub id: Option, pub vector: Option, pub payload: Option, } @@ -45,7 +46,7 @@ pub struct IndexedVector { pub vector: DenseVector, } -#[derive(Debug, Deserialize, Copy, Clone)] +#[derive(Debug, Serialize, Deserialize, Copy, Clone)] pub enum Similarity { Euclidean, Manhattan, @@ -53,6 +54,40 @@ pub enum Similarity { Cosine, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BatchInsertRequest { + pub points: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BatchInsertResponse { + pub inserted: usize, + pub ids: Vec, +} + +// For batch search +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BatchSearchRequest { + pub queries: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchQueryInput { + pub vector: DenseVector, + pub similarity: Similarity, + pub limit: usize, +} + +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct SearchResponse { + pub results: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BatchSearchResponse { + pub results: Vec, +} + // Struct which stores the distance between a vector and query vector and implements ordering traits #[derive(Copy, Clone)] pub struct DistanceOrderedVector<'q> { diff --git a/crates/grpc/proto/vector-db.proto b/crates/grpc/proto/vector-db.proto index b3834e9..d5d556b 100644 --- a/crates/grpc/proto/vector-db.proto +++ b/crates/grpc/proto/vector-db.proto @@ -20,6 +20,9 @@ service VectorDB { //Search for the k nearest vectors to a target vector given a distance function rpc SearchPoints(SearchRequest) returns (SearchResponse) {} + + rpc InsertVectorsBatch(InsertVectorsBatchRequest) returns (InsertVectorsBatchResponse) {} +rpc SearchPointsBatch(SearchPointsBatchRequest) returns (SearchPointsBatchResponse) {} } @@ -71,3 +74,18 @@ message Payload { string content = 2; } +message InsertVectorsBatchRequest { + repeated InsertVectorRequest vectors = 1; +} + +message InsertVectorsBatchResponse { + repeated PointID ids = 1; +} + +message SearchPointsBatchRequest { + repeated SearchRequest queries = 1; +} + +message SearchPointsBatchResponse { + repeated SearchResponse results = 1; +} diff --git a/crates/grpc/src/service.rs b/crates/grpc/src/service.rs index 6fcb0ad..2a5a3ae 100644 --- a/crates/grpc/src/service.rs +++ b/crates/grpc/src/service.rs @@ -1,15 +1,18 @@ use std::str::FromStr; use std::sync::Arc; +use crate::error::GrpcError; use crate::interceptors; use crate::service::vectordb::{ContentType, Uuid}; use crate::utils::log_rpc; use crate::{constants::SIMILARITY_PROTOBUFF_MAP, utils::ServerEndpoint}; +use defs::SearchQueryInput; use tonic::{Request, Response, Status, service::InterceptorLayer, transport::Server}; use tracing::{Level, event}; use uuid::Uuid as UuidCrate; use vectordb::{ - DenseVector, InsertVectorRequest, Point, PointId, SearchRequest, SearchResponse, + DenseVector, InsertVectorRequest, InsertVectorsBatchRequest, InsertVectorsBatchResponse, Point, + PointId, SearchPointsBatchRequest, SearchPointsBatchResponse, SearchRequest, SearchResponse, vector_db_server::{VectorDb, VectorDbServer}, }; @@ -96,7 +99,13 @@ impl VectorDb for VectorDBService { Ok(Response::new(Point { id: Some(PointId { id: Some(Uuid { - value: point.id.to_string(), + value: point + .id + .ok_or(GrpcError::Internal { + message: "point has no id".into(), + }) + .map_err(Status::from)? + .to_string(), }), }), vector: Some(DenseVector { @@ -129,7 +138,11 @@ impl VectorDb for VectorDBService { let result_point_ids = self .vector_db - .search(query_vect.values, *similarity, limit as usize) + .search(SearchQueryInput { + vector: query_vect.values, + similarity: *similarity, + limit: limit as usize, + }) .map_err(|e| Status::from(crate::error::GrpcError::from(e)))?; // create a mapped vector of PointIds @@ -166,8 +179,82 @@ impl VectorDb for VectorDBService { Err(e) => Err(Status::from(crate::error::GrpcError::from(e))), } } -} + async fn insert_vectors_batch( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let mut ids = Vec::with_capacity(req.vectors.len()); + + for vec in req.vectors { + let payload = vec.payload.map(|p| defs::Payload { + content_type: match ContentType::try_from(p.content_type) + .unwrap_or(ContentType::Text) + { + ContentType::Text => defs::ContentType::Text, + ContentType::Image => defs::ContentType::Image, + }, + content: p.content, + }); + + let id = self + .vector_db + .insert( + vec.vector.unwrap_or_default().values, + payload.unwrap_or_default(), + ) + .map_err(|e| tonic::Status::internal(e.to_string()))?; + + ids.push(PointId { + id: Some(Uuid { + value: id.to_string(), + }), + }); + } + + Ok(tonic::Response::new(InsertVectorsBatchResponse { ids })) + } + + async fn search_points_batch( + &self, + request: tonic::Request, + ) -> Result, tonic::Status> { + let req = request.into_inner(); + let mut results = Vec::with_capacity(req.queries.len()); + + for query in req.queries { + let similarity = SIMILARITY_PROTOBUFF_MAP + .get(query.similarity as usize) + .ok_or(tonic::Status::invalid_argument("Invalid similarity"))?; + + let ids = self + .vector_db + .search(SearchQueryInput { + vector: query + .query_vector + .ok_or(tonic::Status::invalid_argument("missing query_vector"))? + .values, + similarity: *similarity, + limit: query.limit as usize, + }) + .map_err(|e| tonic::Status::from(GrpcError::from(e)))?; + + results.push(SearchResponse { + result_point_ids: ids + .into_iter() + .map(|id| PointId { + id: Some(Uuid { + value: id.to_string(), + }), + }) + .collect(), + }); + } + + Ok(tonic::Response::new(SearchPointsBatchResponse { results })) + } +} pub async fn run_server( vector_db_service: VectorDBService, endpoint: ServerEndpoint, diff --git a/crates/http/src/handler.rs b/crates/http/src/handler.rs index ee49e46..76ba5a2 100644 --- a/crates/http/src/handler.rs +++ b/crates/http/src/handler.rs @@ -4,7 +4,10 @@ use axum::{ extract::{Path, State}, http::StatusCode, }; -use defs::{DenseVector, Payload, Point, PointId, Similarity}; +use defs::{ + AppError, BatchInsertRequest, BatchInsertResponse, BatchSearchRequest, BatchSearchResponse, + DenseVector, Payload, Point, PointId, SearchQueryInput, SearchResponse, +}; use index::error::IndexError; use serde::{Deserialize, Serialize}; use storage::error::StorageError; @@ -47,6 +50,21 @@ pub async fn insert_point_handler( } } +pub async fn batch_insert_handler( + State(state): State, + Json(request): Json, +) -> Result, AppError> { + let ids = state + .db + .insert_batch(request.points) + .map_err(|e| AppError::Api(e.to_string()))?; + + Ok(Json(BatchInsertResponse { + inserted: ids.len(), + ids, + })) +} + pub async fn get_point_handler( Path(point_id): Path, State(app_state): State, @@ -74,26 +92,11 @@ pub async fn delete_point_handler( } } -#[derive(Deserialize)] -pub struct SearchRequest { - pub vector: DenseVector, - pub similarity: Similarity, - pub limit: usize, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct SearchResponse { - pub results: Vec, -} - pub async fn search_points_handler( State(app_state): State, - Json(request): Json, + Json(request): Json, ) -> Result, (StatusCode, String)> { - match app_state - .db - .search(request.vector, request.similarity, request.limit) - { + match app_state.db.search(request) { Ok(results) => { let response = SearchResponse { results }; Ok(Json(response)) @@ -105,6 +108,21 @@ pub async fn search_points_handler( } } +pub async fn batch_search_handler( + State(state): State, + Json(request): Json, +) -> Result, AppError> { + let results = state + .db + .search_batch(request.queries) + .map_err(|e| AppError::Api(e.to_string()))? + .into_iter() + .map(|ids| SearchResponse { results: ids }) + .collect(); + + Ok(Json(BatchSearchResponse { results })) +} + /// Map `ApiError` into an HTTP `(StatusCode, String)` response. fn api_error_to_response(err: &ApiError) -> (StatusCode, String) { match err { diff --git a/crates/http/src/lib.rs b/crates/http/src/lib.rs index 8554157..544a28a 100644 --- a/crates/http/src/lib.rs +++ b/crates/http/src/lib.rs @@ -12,8 +12,8 @@ use tokio::net::TcpListener; use tracing::info; use handler::{ - delete_point_handler, get_point_handler, health_handler, insert_point_handler, root_handler, - search_points_handler, + batch_insert_handler, batch_search_handler, delete_point_handler, get_point_handler, + health_handler, insert_point_handler, root_handler, search_points_handler, }; #[derive(Clone)] @@ -33,6 +33,8 @@ pub fn create_router(db: Arc) -> Router { get(get_point_handler).delete(delete_point_handler), ) .route("/points/search", post(search_points_handler)) + .route("/points/batch", post(batch_insert_handler)) + .route("/points/search/batch", post(batch_search_handler)) .with_state(app_state) } diff --git a/crates/index/src/lib.rs b/crates/index/src/lib.rs index 3725573..44b528a 100644 --- a/crates/index/src/lib.rs +++ b/crates/index/src/lib.rs @@ -18,6 +18,13 @@ pub trait VectorIndex: Send + Sync { similarity: Similarity, k: usize, ) -> Result>; // Return a Vec of ids of closest vectors (length max k) + + fn insert_batch(&mut self, vectors: Vec) -> Result<()> { + for v in vectors { + self.insert(v)?; + } + Ok(()) + } } /// Distance function to get the distance between two vectors (taken from old version) diff --git a/crates/storage/src/rocks_db.rs b/crates/storage/src/rocks_db.rs index 4c68290..6f8d586 100644 --- a/crates/storage/src/rocks_db.rs +++ b/crates/storage/src/rocks_db.rs @@ -53,7 +53,7 @@ impl StorageEngine for RocksDbStorage { ) -> Result<(), StorageError> { let key = id.to_string(); let point = Point { - id, + id: Some(id), vector, payload, }; @@ -137,13 +137,15 @@ impl StorageEngine for RocksDbStorage { let point: Point = deserialize(&v).context(error::DeserializationSnafu { id: offset })?; - if point.id <= offset { + let Some(id) = point.id else { continue }; + + if id <= offset { continue; } if let Some(vec) = point.vector { - last_id = point.id; - result.push((point.id, vec)); + last_id = id; + result.push((id, vec)); if result.len() == limit { break; } diff --git a/crates/tui/src/app/events.rs b/crates/tui/src/app/events.rs index e8f373c..2fb390c 100644 --- a/crates/tui/src/app/events.rs +++ b/crates/tui/src/app/events.rs @@ -1,6 +1,6 @@ use super::{App, AppState, ModalType, VectorListItem}; use crossterm::event::{Event, KeyCode, KeyEvent}; -use defs::{ContentType, Payload, Similarity}; +use defs::{ContentType, Payload, SearchQueryInput, Similarity}; use std::io; use std::path::PathBuf; use uuid::Uuid; @@ -315,7 +315,13 @@ fn execute_modal_action(app: &mut App) -> io::Result<()> { } }; - let ids = db.search(query, Similarity::Cosine, k).map_err(to_io)?; + let ids = db + .search(SearchQueryInput { + vector: query, + similarity: Similarity::Cosine, + limit: k, + }) + .map_err(to_io)?; app.vector_list_items.clear(); app.vector_detail = None; @@ -326,9 +332,10 @@ fn execute_modal_action(app: &mut App) -> io::Result<()> { for id in ids.into_iter() { if let Some(point) = db.get(id).map_err(to_io)? && let Some(vector) = point.vector + && let Some(point_id) = point.id { app.vector_list_items.push(VectorListItem { - id: point.id, + id: point_id, vector, payload: point.payload, });