diff --git a/.gitignore b/.gitignore index 7eb2e3b..4e64135 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ /data/ .TODO /databases +/crates/benchmark/Datasets .env diff --git a/Cargo.lock b/Cargo.lock index e19cfa0..748c98f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -32,6 +32,56 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + [[package]] name = "anyhow" version = "1.0.100" @@ -42,6 +92,7 @@ checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" name = "api" version = "0.1.0" dependencies = [ + "clap", "defs", "index", "snafu", @@ -140,6 +191,17 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "benchmark" +version = "0.1.0" +dependencies = [ + "clap", + "defs", + "index", + "rayon", + "uuid", +] + [[package]] name = "bincode" version = "1.3.3" @@ -275,6 +337,46 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "4.5.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75ca66430e33a14957acc24c5077b503e7d374151b2b4b3a10c83b4ceb4be0e" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793207c7fa6300a0608d1080b858e5fdbe713cdc1c8db9fb17777d8a13e63df0" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" + [[package]] name = "color-eyre" version = "0.6.5" @@ -302,6 +404,12 @@ dependencies = [ "tracing-error", ] +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + [[package]] name = "compact_str" version = "0.7.1" @@ -331,6 +439,31 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crossterm" version = "0.27.0" @@ -360,6 +493,7 @@ dependencies = [ name = "defs" version = "0.1.0" dependencies = [ + "clap", "serde", "uuid", ] @@ -923,6 +1057,12 @@ dependencies = [ "serde", ] +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + [[package]] name = "itertools" version = "0.12.1" @@ -1200,6 +1340,12 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + [[package]] name = "openssl" version = "0.10.75" @@ -1514,6 +1660,26 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -1946,6 +2112,12 @@ dependencies = [ "uuid", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "strum" version = "0.26.3" @@ -2399,6 +2571,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + [[package]] name = "uuid" version = "1.19.0" diff --git a/Cargo.toml b/Cargo.toml index 76c1c7f..2acb0e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/http", "crates/tui", "crates/grpc", + "crates/benchmark", ] [workspace.package] @@ -20,6 +21,7 @@ axum = "0.8" axum-test = "18.1.0" bincode = "1.3.3" chrono = { version = "0.4", features = ["serde"] } +clap = { version = "4.5.54", features = ["derive"] } color-eyre = "0.6.5" crossterm = "0.27" dotenv = "0.15.0" @@ -53,3 +55,4 @@ index = { path = "crates/index" } server = { path = "crates/server" } storage = { path = "crates/storage" } tui = { path = "crates/tui" } +#benchmark = { path = "crates/benchmark"} diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 287f94c..25a50fc 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -7,6 +7,7 @@ edition.workspace = true license.workspace = true [dependencies] +clap.workspace = true defs.workspace = true index.workspace = true snafu.workspace = true diff --git a/crates/benchmark/Cargo.toml b/crates/benchmark/Cargo.toml new file mode 100644 index 0000000..87ca42f --- /dev/null +++ b/crates/benchmark/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "benchmark" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +clap = { version = "4.5.54", features = ["derive"] } +defs.workspace = true +index.workspace = true +rayon = "1.11.0" +uuid = { version = "1.19.0", features = ["v4"] } diff --git a/crates/benchmark/Scripts/download_dataset.sh b/crates/benchmark/Scripts/download_dataset.sh new file mode 100755 index 0000000..561a2e1 --- /dev/null +++ b/crates/benchmark/Scripts/download_dataset.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +set -e + +if ! CARGO_TOML_PATH=$(cargo locate-project --message-format plain 2>/dev/null); then + echo "āŒ Error: Could not find Cargo.toml. Are you running this inside a Rust project?" >&2 + exit 1 +fi + +CRATE_ROOT=$(dirname "$CARGO_TOML_PATH") + +ROOT_DIR="$CRATE_ROOT/Datasets" +BASE_URL="ftp://ftp.irisa.fr/local/texmex/corpus" + +if [ -d "$ROOT_DIR" ]; then + echo "āœ… '$ROOT_DIR' already exists. Skipping download." + exit 0 +fi + +echo "šŸ“‚ Creating directory at: $ROOT_DIR" +mkdir -p "$ROOT_DIR" + +DATASETS=("siftsmall.tar.gz" "sift.tar.gz") + +for FILENAME in "${DATASETS[@]}"; do + URL="$BASE_URL/$FILENAME" + TEMP_FILE_PATH="$ROOT_DIR/$FILENAME" + EXPECTED_FOLDER="${FILENAME%%.tar.gz}" + + echo "ā¬‡ļø Downloading $FILENAME using curl..." + + if ! curl -# -L -o "$TEMP_FILE_PATH" "$URL"; then + echo "āŒ Failed to download $URL" >&2 + exit 1 + fi + + echo "šŸ“¦ Extracting $FILENAME..." + + tar -xzf "$TEMP_FILE_PATH" -C "$ROOT_DIR" + + rm "$TEMP_FILE_PATH" + + echo "✨ Extracted to $ROOT_DIR/$EXPECTED_FOLDER" +done + +echo -e "\nšŸŽ‰ Setup complete! Your file tree is ready." \ No newline at end of file diff --git a/crates/benchmark/src/load_dataset.rs b/crates/benchmark/src/load_dataset.rs new file mode 100644 index 0000000..df22542 --- /dev/null +++ b/crates/benchmark/src/load_dataset.rs @@ -0,0 +1,113 @@ +use crate::types::{Dataset, DatasetType, GroundTruth}; +use defs::IndexedVector; +use std::fs; +use std::path::PathBuf; +use uuid::Uuid; + +const MAX_DATA_POINTS: u128 = 50000; + +pub fn load_ground_truth(set: &mut Dataset, dataset: String, dataset_type: DatasetType) { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let base_path = PathBuf::from(manifest_dir).join("Datasets"); + + let path: PathBuf = match (&dataset_type, dataset.as_str()) { + (DatasetType::DataSet, "1M") => base_path.join("sift/sift_base.fvecs"), + (DatasetType::DataSet, _) => base_path.join("siftsmall/siftsmall_base.fvecs"), + (DatasetType::TestQueries, "1M") => base_path.join("sift/sift_query.fvecs"), + (DatasetType::TestQueries, _) => base_path.join("siftsmall/siftsmall_query.fvecs"), + (DatasetType::GroundTruth, "1M") => base_path.join("sift/sift_groundtruth.ivecs"), + (DatasetType::GroundTruth, _) => base_path.join("siftsmall/siftsmall_groundtruth.ivecs"), + }; + let data = fs::read(path).unwrap(); + + let mut data_iter = data.iter(); + + let mut point_id: u128 = 0; + loop { + if point_id > MAX_DATA_POINTS { + break; + } + + let Some(first) = data_iter.next() else { + break; + }; + + let dim: u32 = (*first as u32) + | ((*data_iter.next().unwrap() as u32) << 8) + | ((*data_iter.next().unwrap() as u32) << 16) + | ((*data_iter.next().unwrap() as u32) << 24); + + let mut one_vector: GroundTruth = GroundTruth { + id: Uuid::from_bytes(point_id.to_be_bytes()), + vector: Vec::with_capacity(set.dimension), + }; + for _i in 0..dim { + let int_val: u128 = (*data_iter.next().unwrap() as u128) + | ((*data_iter.next().unwrap() as u128) << 8) + | ((*data_iter.next().unwrap() as u128) << 16) + | ((*data_iter.next().unwrap() as u128) << 24); + one_vector + .vector + .push(Uuid::from_bytes(int_val.to_be_bytes())); + } + if let DatasetType::GroundTruth = dataset_type { + set.ground_truth.push(one_vector); + } + point_id += 1; + } +} + +pub fn load_dataset_and_test_query(set: &mut Dataset, dataset: String, dataset_type: DatasetType) { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let base_path = PathBuf::from(manifest_dir).join("Datasets"); + + let path: PathBuf = match (&dataset_type, dataset.as_str()) { + (DatasetType::DataSet, "1M") => base_path.join("sift/sift_base.fvecs"), + (DatasetType::DataSet, _) => base_path.join("siftsmall/siftsmall_base.fvecs"), + (DatasetType::TestQueries, "1M") => base_path.join("sift/sift_query.fvecs"), + (DatasetType::TestQueries, _) => base_path.join("siftsmall/siftsmall_query.fvecs"), + (DatasetType::GroundTruth, "1M") => base_path.join("sift/sift_groundtruth.ivecs"), + (DatasetType::GroundTruth, _) => base_path.join("siftsmall/siftsmall_groundtruth.ivecs"), + }; + + let data = fs::read(path).unwrap(); + + let mut data_iter = data.iter(); + + let mut point_id: u128 = 0; + loop { + if point_id > MAX_DATA_POINTS { + break; + } + let Some(first) = data_iter.next() else { + break; + }; + + let dim: u32 = (*first as u32) + | ((*data_iter.next().unwrap() as u32) << 8) + | ((*data_iter.next().unwrap() as u32) << 16) + | ((*data_iter.next().unwrap() as u32) << 24); + + let mut one_vector: IndexedVector = IndexedVector { + id: Uuid::from_bytes(point_id.to_be_bytes()), + vector: Vec::with_capacity(set.dimension), + }; + for _i in 0..dim { + let int_val: u32 = (*data_iter.next().unwrap() as u32) + | ((*data_iter.next().unwrap() as u32) << 8) + | ((*data_iter.next().unwrap() as u32) << 16) + | ((*data_iter.next().unwrap() as u32) << 24); + one_vector.vector.push(f32::from_bits(int_val)); + } + match dataset_type { + DatasetType::DataSet => { + set.data.push(one_vector); + } + DatasetType::TestQueries => { + set.test_queries.push(one_vector); + } + _ => (), + } + point_id += 1; + } +} diff --git a/crates/benchmark/src/main.rs b/crates/benchmark/src/main.rs new file mode 100644 index 0000000..1f26039 --- /dev/null +++ b/crates/benchmark/src/main.rs @@ -0,0 +1,84 @@ +use std::time::Instant; +pub mod load_dataset; +pub mod test_methods; +mod types; + +use crate::load_dataset::{load_dataset_and_test_query, load_ground_truth}; +use crate::test_methods::{build_indexer, query_latency, test_accuracy, test_throughput}; +use crate::types::DatasetType::{DataSet, GroundTruth, TestQueries}; +use crate::types::{Args, BenchIndexer, Dataset, DatasetType, IndexerType}; +use clap::Parser; +use index::{flat, hnsw}; + +fn main() { + let args: Args = Args::parse(); + + // Loading the dataset. + let mut set: Dataset = Dataset::_new(); + let set_type: DatasetType = DataSet; + + // Data + println!("Loading dataset..."); + let dataset = args.dataset; + load_dataset_and_test_query(&mut set, dataset.clone(), set_type); + + //Test query + println!("Loading test queries..."); + let set_type: DatasetType = TestQueries; + load_dataset_and_test_query(&mut set, dataset.clone(), set_type); + + //Ground Truth + println!("Loading ground truth...\n"); + let set_type: DatasetType = GroundTruth; + load_ground_truth(&mut set, dataset, set_type); + + let a = set.data.len(); + let b = set.test_queries.len(); + let c = set.ground_truth.len(); + println!("Dataset size: {:?}", a); + println!("Size of test queries: {:?}", b); + println!( + "Size of Ground truth: {:?} (Done for sanity check of dataset)\n", + c + ); + + // Create indexer + let mut indexer: BenchIndexer; + + match args.indexer { + IndexerType::Hnsw => { + let index: hnsw::HnswIndex = hnsw::HnswIndex::new(args.similarity, set.dimension); + indexer = BenchIndexer::HnswIndex(index); + } + IndexerType::Flat => { + let index = flat::FlatIndex::new(); + indexer = BenchIndexer::FlatIndex(index); + } + + IndexerType::KdTree => { + let index: hnsw::HnswIndex = hnsw::HnswIndex::new(args.similarity, set.dimension); + indexer = BenchIndexer::HnswIndex(index); + } + } + + println!("Building dataset ",); + let start = Instant::now(); + build_indexer(set.data, &mut indexer); + let duration = start.elapsed(); + println!("Building took {:?} \n", duration); + + println!("Testing accuracy:"); + test_accuracy( + set.test_queries.clone(), + set.ground_truth, + args.similarity, + &indexer, + args.k, + ); + + println!("Benchmarking Query Latency"); + query_latency(set.test_queries.clone(), &indexer, args.similarity, args.k); + + println!("Test throughput "); + test_throughput(set.test_queries, &indexer, args.similarity, args.k); +} diff --git a/crates/benchmark/src/test_methods.rs b/crates/benchmark/src/test_methods.rs new file mode 100644 index 0000000..85143b3 --- /dev/null +++ b/crates/benchmark/src/test_methods.rs @@ -0,0 +1,159 @@ +use crate::types::{BenchIndexer, GroundTruth}; +use defs::{IndexedVector, Similarity}; +use index::VectorIndex; +use index::flat::FlatIndex; +use index::hnsw::HnswIndex; +use index::kd_tree::index::KDTree; +use rayon::prelude::*; +use std::time::{Duration, Instant}; + +fn build_flat(dataset: Vec, indexer: &mut FlatIndex) { + for i in dataset { + indexer.insert(i).unwrap(); + } +} +fn build_hnsw(dataset: Vec, indexer: &mut HnswIndex) { + let mut dataset = dataset.into_iter(); + for _i in 0..10000 { + indexer.insert(dataset.next().unwrap()).unwrap(); + } +} + +fn build_kd_tree(dataset: Vec, indexer: &mut KDTree) { + for i in dataset { + indexer.insert(i).unwrap(); + } +} +pub fn build_indexer(dataset: Vec, indexer: &mut BenchIndexer) { + match indexer { + BenchIndexer::HnswIndex(indexer) => { + build_hnsw(dataset, indexer); + } + + BenchIndexer::FlatIndex(indexer) => { + build_flat(dataset, indexer); + } + + BenchIndexer::KdTree(indexer) => { + build_kd_tree(dataset, indexer); + } + } +} + +pub fn ttest_accuracy( + queries: Vec, + ground_truth: Vec, + similarity: Similarity, + indexer: &BenchIndexer, + k: usize, +) { + let length = queries.len(); + let mut results: Vec = Vec::with_capacity(length); + for query in queries.into_iter() { + let result = indexer.search(query.vector, similarity, k).unwrap(); + let one_unit = GroundTruth { + id: query.id, + vector: result, + }; + results.push(one_unit); + } + + let mut result_iter = results.into_iter(); + let mut ground_iter = ground_truth.into_iter(); + for _i in 0..length { + let a = result_iter.next().unwrap(); + let b = ground_iter.next().unwrap(); + let c = &b.vector[..k.min(b.vector.len())]; + + let mut count: i32 = 0; + for j in 0..k { + if a.vector[j] != b.vector[j] { + count += 1; + } + } + + println!("Point id: "); + println!("{:?},\n{:?}\n", a.id, b.id); + + println!("Vectors: "); + println!("{:?},\n{:?}\n", a.vector, c); + + println!("Incorrect matches: {count}"); + } +} +pub fn test_accuracy( + queries: Vec, + ground_truth: Vec, + similarity: Similarity, + indexer: &BenchIndexer, + k: usize, +) { + let length = ground_truth.len(); + let mut count: i32 = 0; + queries + .into_iter() + .zip(ground_truth) + .for_each(|(query, truth)| { + let result_vector = indexer.search(query.vector, similarity, k).unwrap(); + + let truth_slice = &truth.vector[..k.min(truth.vector.len())]; + + for j in 0..k { + if j < result_vector.len() + && j < truth_slice.len() + && result_vector[j] != truth_slice[j] + { + count += 1; + } + } + }); + println!("{:?} wrong out of {:?} \n ", count, length); +} +pub fn query_latency( + dataset: Vec, + indexer_enum: &BenchIndexer, + similarity: Similarity, + k: usize, +) { + let mut query_time: Vec = Vec::with_capacity(dataset.len()); + + for query in dataset { + let start = Instant::now(); + let _ = indexer_enum.search(query.vector, similarity, k); + let duration = start.elapsed(); + query_time.push(duration); + } + + query_time.sort_unstable(); + println!("Query time median: {:?}", query_time[query_time.len() / 2]); + println!( + "Query time 99 percentiles: {:?}\n", + query_time[query_time.len() - 1] + ); +} + +pub fn test_throughput( + queries: Vec, + indexer: &BenchIndexer, + similarity: Similarity, + k: usize, +) { + let num_queries = queries.len(); + println!("Starting throughput test with {} queries...", num_queries); + + let start = Instant::now(); + + queries.into_par_iter().for_each(|query| { + let _result = indexer.search(query.vector, similarity, k).unwrap(); + }); + + let duration = start.elapsed(); + let total_seconds = duration.as_secs_f64(); + let qps = num_queries as f64 / total_seconds; + + println!("-----------------------------------"); + println!("Throughput Test Results:"); + println!("Total Time: {:.4} s", total_seconds); + println!("Throughput: {:.2} QPS (Queries Per Second)", qps); + println!("-----------------------------------"); +} diff --git a/crates/benchmark/src/types.rs b/crates/benchmark/src/types.rs new file mode 100644 index 0000000..5a338bd --- /dev/null +++ b/crates/benchmark/src/types.rs @@ -0,0 +1,84 @@ +use clap::{Parser, ValueEnum}; +use defs::{DenseVector, IndexedVector, PointId, Similarity}; +use index::{VectorIndex, error::Result, flat, hnsw, kd_tree}; +#[derive(Debug)] +pub enum DatasetType { + DataSet, + TestQueries, + GroundTruth, +} + +#[derive(Debug)] +pub struct GroundTruth { + pub id: PointId, + pub vector: Vec, +} + +pub enum BenchIndexer { + FlatIndex(flat::FlatIndex), + HnswIndex(hnsw::HnswIndex), + KdTree(kd_tree::index::KDTree), +} + +impl BenchIndexer { + pub fn search( + &self, + vector: DenseVector, + similarity: Similarity, + k: usize, + ) -> Result> { + match self { + // Forward the call to the inner FlatIndex + BenchIndexer::FlatIndex(inner) => inner.search(vector, similarity, k), + // Forward the call to the inner HnswIndex + BenchIndexer::HnswIndex(inner) => inner.search(vector, similarity, k), + + BenchIndexer::KdTree(inner) => inner.search(vector, similarity, k), + } + } +} + +pub struct Dataset { + pub dimension: usize, + pub data: Vec, + pub test_queries: Vec, + pub ground_truth: Vec, +} + +impl Dataset { + pub fn _new() -> Self { + Dataset { + dimension: 128, + data: Vec::new(), + test_queries: Vec::new(), + ground_truth: Vec::new(), + } + } +} + +#[derive(ValueEnum, Clone, Debug)] +#[clap(rename_all = "lower")] +pub enum IndexerType { + Flat, + KdTree, + Hnsw, +} + +#[derive(Parser, Debug)] +#[command(author, version, about = "Vector Index Benchmark")] +pub struct Args { + /// The dataset to use for benchmarking + #[clap(short, long, default_value = "1M")] + pub dataset: String, + + /// Type of indexer to benchmark + #[arg(short, long, value_enum, default_value_t = IndexerType::Flat)] + pub indexer: IndexerType, + + /// Number of nearest neighbors to retrieve + #[arg(short, long, default_value_t = 10)] + pub k: usize, + + #[arg(short, long, value_enum, default_value_t = Similarity::Euclidean)] + pub similarity: Similarity, +} diff --git a/crates/defs/Cargo.toml b/crates/defs/Cargo.toml index 600b80c..9dd14e4 100644 --- a/crates/defs/Cargo.toml +++ b/crates/defs/Cargo.toml @@ -7,5 +7,6 @@ edition.workspace = true license.workspace = true [dependencies] +clap.workspace = true serde.workspace = true uuid.workspace = true diff --git a/crates/defs/src/types.rs b/crates/defs/src/types.rs index 65c861a..be544ce 100644 --- a/crates/defs/src/types.rs +++ b/crates/defs/src/types.rs @@ -1,3 +1,4 @@ +use clap::ValueEnum; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use uuid::Uuid; @@ -45,7 +46,7 @@ pub struct IndexedVector { pub vector: DenseVector, } -#[derive(Debug, Deserialize, Copy, Clone)] +#[derive(Debug, Deserialize, Copy, Clone, ValueEnum)] pub enum Similarity { Euclidean, Manhattan,