From a8a9eef7649f3606e8c103d58778c081ecb2225e Mon Sep 17 00:00:00 2001 From: t-sanagrawal Date: Wed, 10 Sep 2025 17:26:19 +0530 Subject: [PATCH 1/2] Please enter the commit message for your changes. Lines starting --- apps/build_stitched_index.cpp | 3 ++- apps/search_memory_index.cpp | 15 +++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/apps/build_stitched_index.cpp b/apps/build_stitched_index.cpp index 60e38c1be..ca5e96916 100644 --- a/apps/build_stitched_index.cpp +++ b/apps/build_stitched_index.cpp @@ -343,8 +343,9 @@ int main(int argc, char **argv) path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; path labels_map_file = final_index_path_prefix + "_labels_map.txt"; + uint32_t universal_label_id = 0; - convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label); + convert_labels_string_to_int(label_data_path, labels_file_to_use, labels_map_file, universal_label, universal_label_id); // 2. parse label file and create necessary data structures std::vector point_ids_to_labels; diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1a9acc285..4db970584 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -166,10 +166,17 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, if (filtered_search && !tags) { std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - - auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L, - query_result_ids[test_id].data() + i * recall_at, - query_result_dists[test_id].data() + i * recall_at); + uint32_t maxi = 0; + std::pair retval = + index->search_with_filters( + query + i * query_aligned_dim, + raw_filter, + recall_at, + L, + maxi, + query_result_ids[test_id].data() + i * recall_at, + query_result_dists[test_id].data() + i * recall_at + ); cmp_stats[i] = retval.second; } else if (metric == diskann::FAST_L2) From 2ec87450449f93da115c8e1c41cd9ab878a53dfc Mon Sep 17 00:00:00 2001 From: t-sanagrawal Date: Thu, 11 Sep 2025 19:30:47 +0530 Subject: [PATCH 2/2] Fixed files --- apps/build_memory_index.cpp | 21 +- apps/search_memory_index.cpp | 536 +++++++++++++++--------------- include/program_options_utils.hpp | 3 + 3 files changed, 297 insertions(+), 263 deletions(-) diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 544e42dee..212dfc375 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -24,10 +24,11 @@ namespace po = boost::program_options; int main(int argc, char **argv) { - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; - uint32_t num_threads, R, L, Lf, build_PQ_bytes; + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type, seller_file; + uint32_t num_threads, R, L, Lf, build_PQ_bytes, num_diverse_build; float alpha; bool use_pq_build, use_opq; + bool diverse_index=false; po::options_description desc{ program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")}; @@ -70,6 +71,12 @@ int main(int argc, char **argv) program_options_utils::FILTERED_LBUILD); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("seller_file", po::value(&seller_file)->default_value(""), + program_options_utils::DIVERSITY_FILE); + optional_configs.add_options()("NumDiverse", po::value(&num_diverse_build)->default_value(1), + program_options_utils::NUM_DIVERSE); + + // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -112,6 +119,9 @@ int main(int argc, char **argv) return -1; } + if(seller_file != "") + diverse_index = true; + try { diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha @@ -120,11 +130,16 @@ int main(int argc, char **argv) size_t data_num, data_dim; diskann::get_bin_metadata(data_path, data_num, data_dim); + std::cout<<"Num diverse build: " << num_diverse_build << std::endl; + auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) .with_filter_list_size(Lf) .with_alpha(alpha) .with_saturate_graph(false) .with_num_threads(num_threads) + .with_diverse_index(diverse_index) + .with_seller_file(seller_file) + .with_num_diverse_build(num_diverse_build) .build(); auto filter_params = diskann::IndexFilterParamsBuilder() @@ -161,4 +176,4 @@ int main(int argc, char **argv) diskann::cerr << "Index build failed." << std::endl; return -1; } -} +} \ No newline at end of file diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 4db970584..203079e8e 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -1,60 +1,79 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. -#include -#include -#include -#include -#include -#include -#include +#include "common_includes.h" #include +#include "index.h" +#include "disk_utils.h" +#include "math_utils.h" +#include "memory_mapper.h" +#include "partition.h" +#include "pq_flash_index.h" +#include "timer.h" +#include "percentile_stats.h" +#include "program_options_utils.hpp" + #ifndef _WINDOWS #include #include -#include #include +#include "linux_aligned_file_reader.h" +#else +#ifdef USE_BING_INFRA +#include "bing_aligned_file_reader.h" +#else +#include "windows_aligned_file_reader.h" +#endif #endif -#include "index.h" -#include "memory_mapper.h" -#include "utils.h" -#include "program_options_utils.hpp" -#include "index_factory.h" +#define WARMUP false namespace po = boost::program_options; +void print_stats(std::string category, std::vector percentiles, std::vector results) +{ + diskann::cout << std::setw(20) << category << ": " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) + { + diskann::cout << std::setw(8) << percentiles[s] << "%"; + } + diskann::cout << std::endl; + diskann::cout << std::setw(22) << " " << std::flush; + for (uint32_t s = 0; s < percentiles.size(); s++) + { + diskann::cout << std::setw(9) << results[s]; + } + diskann::cout << std::endl; +} + template -int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix, - const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, - const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, - const bool dynamic, const bool tags, const bool show_qps_per_thread, - const std::vector &query_filters, const float fail_if_recall_below) +int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, + const std::string &result_output_prefix, const std::string &query_file, std::string >_file, + const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth, + const uint32_t num_nodes_to_cache, const uint32_t search_io_limit, + const std::vector &Lvec, const float fail_if_recall_below, + const std::vector &query_filters, const bool use_reorder_data = false, const uint32_t max_K_per_seller = std::numeric_limits::max()) { - using TagT = uint32_t; - // Load the query file + diskann::cout << "Search parameters: #threads: " << num_threads << ", "; + if (beamwidth <= 0) + diskann::cout << "beamwidth to be optimized for each L value" << std::flush; + else + diskann::cout << " beamwidth: " << beamwidth << std::flush; + if (search_io_limit == std::numeric_limits::max()) + diskann::cout << "." << std::endl; + else + diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl; + + std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; + + // load query bin T *query = nullptr; uint32_t *gt_ids = nullptr; float *gt_dists = nullptr; size_t query_num, query_dim, query_aligned_dim, gt_num, gt_dim; diskann::load_aligned_bin(query_file, query, query_num, query_dim, query_aligned_dim); - bool calc_recall_flag = false; - if (truthset_file != std::string("null") && file_exists(truthset_file)) - { - diskann::load_truthset(truthset_file, gt_ids, gt_dists, gt_num, gt_dim); - if (gt_num != query_num) - { - std::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl; - } - calc_recall_flag = true; - } - else - { - diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." << std::endl; - } - bool filtered_search = false; if (!query_filters.empty()) { @@ -68,234 +87,244 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } } - const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); - - auto config = diskann::IndexConfigBuilder() - .with_metric(metric) - .with_dimension(query_dim) - .with_max_points(0) - .with_data_load_store_strategy(diskann::DataStoreStrategy::MEMORY) - .with_graph_load_store_strategy(diskann::GraphStoreStrategy::MEMORY) - .with_data_type(diskann_type_to_name()) - .with_label_type(diskann_type_to_name()) - .with_tag_type(diskann_type_to_name()) - .is_dynamic_index(dynamic) - .is_enable_tags(tags) - .is_concurrent_consolidate(false) - .is_pq_dist_build(false) - .is_use_opq(false) - .with_num_pq_chunks(0) - .with_num_frozen_pts(num_frozen_pts) - .build(); - - auto index_factory = diskann::IndexFactory(config); - auto index = index_factory.create_instance(); - index->load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); - std::cout << "Index loaded" << std::endl; - - if (metric == diskann::FAST_L2) - index->optimize_index_layout(); - - std::cout << "Using " << num_threads << " threads to search" << std::endl; - std::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); - std::cout.precision(2); - const std::string qps_title = show_qps_per_thread ? "QPS/thread" : "QPS"; - uint32_t table_width = 0; - if (tags) + bool calc_recall_flag = false; + if (gt_file != std::string("null") && gt_file != std::string("NULL") && file_exists(gt_file)) { - std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(20) << "Mean Latency (mus)" - << std::setw(15) << "99.9 Latency"; - table_width += 4 + 12 + 20 + 15; + diskann::load_truthset(gt_file, gt_ids, gt_dists, gt_num, gt_dim); + if (gt_num != query_num) + { + diskann::cout << "Error. Mismatch in number of queries and ground truth data" << std::endl; + } + calc_recall_flag = true; } - else + + std::shared_ptr reader = nullptr; +#ifdef _WINDOWS +#ifndef USE_BING_INFRA + reader.reset(new WindowsAlignedFileReader()); +#else + reader.reset(new diskann::BingAlignedFileReader()); +#endif +#else + reader.reset(new LinuxAlignedFileReader()); +#endif + + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); + + int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); + + if (res != 0) { - std::cout << std::setw(4) << "Ls" << std::setw(12) << qps_title << std::setw(18) << "Avg dist cmps" - << std::setw(20) << "Mean Latency (mus)" << std::setw(15) << "99.9 Latency"; - table_width += 4 + 12 + 18 + 20 + 15; + return res; } - uint32_t recalls_to_print = 0; - const uint32_t first_recall = print_all_recalls ? 1 : recall_at; - if (calc_recall_flag) + + std::vector node_list; + diskann::cout << "Caching " << num_nodes_to_cache << " nodes around medoid(s)" << std::endl; + _pFlashIndex->cache_bfs_levels(num_nodes_to_cache, node_list); + // if (num_nodes_to_cache > 0) + // _pFlashIndex->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6, num_nodes_to_cache, + // num_threads, node_list); + _pFlashIndex->load_cache_list(node_list); + node_list.clear(); + node_list.shrink_to_fit(); + + omp_set_num_threads(num_threads); + + uint64_t warmup_L = 20; + uint64_t warmup_num = 0, warmup_dim = 0, warmup_aligned_dim = 0; + T *warmup = nullptr; + + if (WARMUP) { - for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) + if (file_exists(warmup_query_file)) { - std::cout << std::setw(12) << ("Recall@" + std::to_string(curr_recall)); + diskann::load_aligned_bin(warmup_query_file, warmup, warmup_num, warmup_dim, warmup_aligned_dim); } - recalls_to_print = recall_at + 1 - first_recall; - table_width += recalls_to_print * 12; - } - std::cout << std::endl; - std::cout << std::string(table_width, '=') << std::endl; + else + { + warmup_num = (std::min)((uint32_t)150000, (uint32_t)15000 * num_threads); + warmup_dim = query_dim; + warmup_aligned_dim = query_aligned_dim; + diskann::alloc_aligned(((void **)&warmup), warmup_num * warmup_aligned_dim * sizeof(T), 8 * sizeof(T)); + std::memset(warmup, 0, warmup_num * warmup_aligned_dim * sizeof(T)); + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(-128, 127); + for (uint32_t i = 0; i < warmup_num; i++) + { + for (uint32_t d = 0; d < warmup_dim; d++) + { + warmup[i * warmup_aligned_dim + d] = (T)dis(gen); + } + } + } + diskann::cout << "Warming up index... " << std::flush; + std::vector warmup_result_ids_64(warmup_num, 0); + std::vector warmup_result_dists(warmup_num, 0); - std::vector> query_result_ids(Lvec.size()); - std::vector> query_result_dists(Lvec.size()); - std::vector latency_stats(query_num, 0); - std::vector cmp_stats; - if (not tags || filtered_search) - { - cmp_stats = std::vector(query_num, 0); +#pragma omp parallel for schedule(dynamic, 1) + for (int64_t i = 0; i < (int64_t)warmup_num; i++) + { + _pFlashIndex->cached_beam_search(warmup + (i * warmup_aligned_dim), 1, warmup_L, + warmup_result_ids_64.data() + (i * 1), + warmup_result_dists.data() + (i * 1), 4); + } + diskann::cout << "..done" << std::endl; } - std::vector query_result_tags; - if (tags) + diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); + diskann::cout.precision(2); + + std::string recall_string = "Recall@" + std::to_string(recall_at); + diskann::cout << std::setw(6) << "L" << std::setw(12) << "Beamwidth" << std::setw(16) << "QPS" << std::setw(16) + << "Mean Latency" << std::setw(16) << "99.9 Latency" << std::setw(16) << "Mean IOs" << std::setw(16) + << "CPU (s)"; + if (calc_recall_flag) { - query_result_tags.resize(recall_at * query_num); + diskann::cout << std::setw(16) << recall_string << std::endl; } + else + diskann::cout << std::endl; + diskann::cout << "===============================================================" + "=======================================================" + << std::endl; + + std::vector> query_result_ids(Lvec.size()); + std::vector> query_result_dists(Lvec.size()); + + uint32_t optimized_beamwidth = 2; double best_recall = 0.0; for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { uint32_t L = Lvec[test_id]; + if (L < recall_at) { diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; continue; } + if (beamwidth <= 0) + { + diskann::cout << "Tuning beamwidth.." << std::endl; + optimized_beamwidth = + optimize_beamwidth(_pFlashIndex, warmup, warmup_num, warmup_aligned_dim, L, optimized_beamwidth); + } + else + optimized_beamwidth = beamwidth; + query_result_ids[test_id].resize(recall_at * query_num); query_result_dists[test_id].resize(recall_at * query_num); - std::vector res = std::vector(); + auto stats = new diskann::QueryStats[query_num]; + + std::vector query_result_ids_64(recall_at * query_num); auto s = std::chrono::high_resolution_clock::now(); - omp_set_num_threads(num_threads); + #pragma omp parallel for schedule(dynamic, 1) for (int64_t i = 0; i < (int64_t)query_num; i++) { - auto qs = std::chrono::high_resolution_clock::now(); - if (filtered_search && !tags) + if (!filtered_search) { - std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - uint32_t maxi = 0; - std::pair retval = - index->search_with_filters( - query + i * query_aligned_dim, - raw_filter, - recall_at, - L, - maxi, - query_result_ids[test_id].data() + i * recall_at, - query_result_dists[test_id].data() + i * recall_at - ); - cmp_stats[i] = retval.second; + _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L, + query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), + optimized_beamwidth, max_K_per_seller, use_reorder_data, stats + i); } - else if (metric == diskann::FAST_L2) - { - index->search_with_optimized_layout(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at); - } - else if (tags) + else { - if (!filtered_search) - { - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res); + LabelT label_for_search; + if (query_filters.size() == 1) + { // one label for all queries + label_for_search = _pFlashIndex->get_converted_label(query_filters[0]); } else - { - std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i]; - - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res, true, raw_filter); - } - - for (int64_t r = 0; r < (int64_t)recall_at; r++) - { - query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; + { // one label for each query + label_for_search = _pFlashIndex->get_converted_label(query_filters[i]); } + _pFlashIndex->cached_beam_search( + query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, std::numeric_limits::max(), + use_reorder_data, stats + i); } - else - { - cmp_stats[i] = index - ->search(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at) - .second; - } - auto qe = std::chrono::high_resolution_clock::now(); - std::chrono::duration diff = qe - qs; - latency_stats[i] = (float)(diff.count() * 1000000); } - std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; + auto e = std::chrono::high_resolution_clock::now(); + std::chrono::duration diff = e - s; + double qps = (1.0 * query_num) / (1.0 * diff.count()); - double displayed_qps = query_num / diff.count(); + diskann::convert_types(query_result_ids_64.data(), query_result_ids[test_id].data(), + query_num, recall_at); - if (show_qps_per_thread) - displayed_qps /= num_threads; + auto mean_latency = diskann::get_mean_stats( + stats, query_num, [](const diskann::QueryStats &stats) { return stats.total_us; }); - std::vector recalls; - if (calc_recall_flag) - { - recalls.reserve(recalls_to_print); - for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) - { - recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, - query_result_ids[test_id].data(), recall_at, curr_recall)); - } - } + auto latency_999 = diskann::get_percentile_stats( + stats, query_num, 0.999, [](const diskann::QueryStats &stats) { return stats.total_us; }); - std::sort(latency_stats.begin(), latency_stats.end()); - double mean_latency = - std::accumulate(latency_stats.begin(), latency_stats.end(), 0.0) / static_cast(query_num); + auto mean_ios = diskann::get_mean_stats(stats, query_num, + [](const diskann::QueryStats &stats) { return stats.n_ios; }); - float avg_cmps = (float)std::accumulate(cmp_stats.begin(), cmp_stats.end(), 0) / (float)query_num; + auto mean_cpuus = diskann::get_mean_stats(stats, query_num, + [](const diskann::QueryStats &stats) { return stats.cpu_us; }); - if (tags && !filtered_search) - { - std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(20) << (float)mean_latency - << std::setw(15) << (float)latency_stats[(uint64_t)(0.999 * query_num)]; - } - else + double recall = 0; + if (calc_recall_flag) { - std::cout << std::setw(4) << L << std::setw(12) << displayed_qps << std::setw(18) << avg_cmps - << std::setw(20) << (float)mean_latency << std::setw(15) - << (float)latency_stats[(uint64_t)(0.999 * query_num)]; + recall = diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, + query_result_ids[test_id].data(), recall_at, recall_at); + best_recall = std::max(recall, best_recall); } - for (double recall : recalls) + + diskann::cout << std::setw(6) << L << std::setw(12) << optimized_beamwidth << std::setw(16) << qps + << std::setw(16) << mean_latency << std::setw(16) << latency_999 << std::setw(16) << mean_ios + << std::setw(16) << mean_cpuus; + if (calc_recall_flag) { - std::cout << std::setw(12) << recall; - best_recall = std::max(recall, best_recall); + diskann::cout << std::setw(16) << recall << std::endl; } - std::cout << std::endl; + else + diskann::cout << std::endl; + delete[] stats; } - std::cout << "Done searching. Now saving results " << std::endl; + diskann::cout << "Done searching. Now saving results " << std::endl; uint64_t test_id = 0; for (auto L : Lvec) { if (L < recall_at) - { - diskann::cout << "Ignoring search with L:" << L << " since it's smaller than K:" << recall_at << std::endl; continue; - } - std::string cur_result_path_prefix = result_path_prefix + "_" + std::to_string(L); - std::string cur_result_path = cur_result_path_prefix + "_idx_uint32.bin"; + std::string cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_idx_uint32.bin"; diskann::save_bin(cur_result_path, query_result_ids[test_id].data(), query_num, recall_at); - cur_result_path = cur_result_path_prefix + "_dists_float.bin"; - diskann::save_bin(cur_result_path, query_result_dists[test_id].data(), query_num, recall_at); - - test_id++; + cur_result_path = result_output_prefix + "_" + std::to_string(L) + "_dists_float.bin"; + diskann::save_bin(cur_result_path, query_result_dists[test_id++].data(), query_num, recall_at); } diskann::aligned_free(query); + if (warmup != nullptr) + diskann::aligned_free(warmup); return best_recall >= fail_if_recall_below ? 0 : -1; } int main(int argc, char **argv) { - std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, - query_filters_file; - uint32_t num_threads, K; + std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label, + label_type, query_filters_file; + uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + uint32_t max_K_per_seller = std::numeric_limits::max(); std::vector Lvec; - bool print_all_recalls, dynamic, tags, show_qps_per_thread; + bool use_reorder_data = false; float fail_if_recall_below = 0.0f; po::options_description desc{ - program_options_utils::make_program_description("search_memory_index", "Searches in-memory DiskANN indexes")}; + program_options_utils::make_program_description("search_disk_index", "Searches on-disk DiskANN indexes")}; try { - desc.add_options()("help,h", "Print this information on arguments"); + desc.add_options()("help,h", "Print information on arguments"); // Required parameters po::options_description required_configs("Required"); @@ -305,7 +334,7 @@ int main(int argc, char **argv) program_options_utils::DISTANCE_FUNCTION_DESCRIPTION); required_configs.add_options()("index_path_prefix", po::value(&index_path_prefix)->required(), program_options_utils::INDEX_PATH_PREFIX_DESCRIPTION); - required_configs.add_options()("result_path", po::value(&result_path)->required(), + required_configs.add_options()("result_path", po::value(&result_path_prefix)->required(), program_options_utils::RESULT_PATH_DESCRIPTION); required_configs.add_options()("query_file", po::value(&query_file)->required(), program_options_utils::QUERY_FILE_DESCRIPTION); @@ -317,6 +346,22 @@ int main(int argc, char **argv) // Optional parameters po::options_description optional_configs("Optional"); + optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), + program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); + optional_configs.add_options()("beamwidth,W", po::value(&W)->default_value(2), + program_options_utils::BEAMWIDTH); + optional_configs.add_options()("num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), + program_options_utils::NUMBER_OF_NODES_TO_CACHE); + optional_configs.add_options()( + "search_io_limit", + po::value(&search_io_limit)->default_value(std::numeric_limits::max()), + "Max #IOs for search. Default value: uint32::max()"); + optional_configs.add_options()("num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("use_reorder_data", po::bool_switch()->default_value(false), + "Include full precision data in the index. Use only in " + "conjuction with compressed data on SSD. Default value: false"); optional_configs.add_options()("filter_label", po::value(&filter_label)->default_value(std::string("")), program_options_utils::FILTER_LABEL_DESCRIPTION); @@ -325,31 +370,15 @@ int main(int argc, char **argv) program_options_utils::FILTERS_FILE_DESCRIPTION); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); - optional_configs.add_options()("gt_file", po::value(>_file)->default_value(std::string("null")), - program_options_utils::GROUND_TRUTH_FILE_DESCRIPTION); - optional_configs.add_options()("num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - program_options_utils::NUMBER_THREADS_DESCRIPTION); - optional_configs.add_options()( - "dynamic", po::value(&dynamic)->default_value(false), - "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); - optional_configs.add_options()("tags", po::value(&tags)->default_value(false), - "Whether to search with external identifiers (tags). Default false."); optional_configs.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), program_options_utils::FAIL_IF_RECALL_BELOW); + optional_configs.add_options()("max_K_per_seller", po::value(&max_K_per_seller)->default_value(std::numeric_limits::max()), + "Diverse search, max number of results per seller"); - // Output controls - po::options_description output_controls("Output controls"); - output_controls.add_options()("print_all_recalls", po::bool_switch(&print_all_recalls), - "Print recalls at all positions, from 1 up to specified " - "recall_at value"); - output_controls.add_options()("print_qps_per_thread", po::bool_switch(&show_qps_per_thread), - "Print overall QPS divided by the number of threads in " - "the output table"); // Merge required and optional parameters - desc.add(required_configs).add(optional_configs).add(output_controls); + desc.add(required_configs).add(optional_configs); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -359,6 +388,8 @@ int main(int argc, char **argv) return 0; } po::notify(vm); + if (vm["use_reorder_data"].as()) + use_reorder_data = true; } catch (const std::exception &ex) { @@ -367,7 +398,7 @@ int main(int argc, char **argv) } diskann::Metric metric; - if ((dist_fn == std::string("mips")) && (data_type == std::string("float"))) + if (dist_fn == std::string("mips")) { metric = diskann::Metric::INNER_PRODUCT; } @@ -379,28 +410,25 @@ int main(int argc, char **argv) { metric = diskann::Metric::COSINE; } - else if ((dist_fn == std::string("fast_l2")) && (data_type == std::string("float"))) - { - metric = diskann::Metric::FAST_L2; - } else { - std::cout << "Unsupported distance function. Currently only l2/ cosine are " - "supported in general, and mips/fast_l2 only for floating " - "point data." + std::cout << "Unsupported distance function. Currently only L2/ Inner " + "Product/Cosine are supported." << std::endl; return -1; } - if (dynamic && not tags) + if ((data_type != std::string("float")) && (metric == diskann::Metric::INNER_PRODUCT)) { - std::cerr << "Tags must be enabled while searching dynamically built indices" << std::endl; + std::cout << "Currently support only floating point data for Inner Product." << std::endl; return -1; } - if (fail_if_recall_below < 0.0 || fail_if_recall_below >= 100.0) + if (use_reorder_data && data_type != std::string("float")) { - std::cerr << "fail_if_recall_below parameter must be between 0 and 100%" << std::endl; + std::cout << "Error: Reorder data for reordering currently only " + "supported for float data type." + << std::endl; return -1; } @@ -424,61 +452,49 @@ int main(int argc, char **argv) { if (!query_filters.empty() && label_type == "ushort") { - if (data_type == std::string("int8")) - { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); - } + if (data_type == std::string("float")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); + else if (data_type == std::string("int8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("uint8")) - { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); - } - else if (data_type == std::string("float")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; return -1; } } else { - if (data_type == std::string("int8")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } + if (data_type == std::string("float")) + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); + else if (data_type == std::string("int8")) + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("uint8")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } - else if (data_type == std::string("float")) - { - return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); - } + return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; return -1; } } } - catch (std::exception &e) + catch (const std::exception &e) { std::cout << std::string(e.what()) << std::endl; diskann::cerr << "Index search failed." << std::endl; return -1; } -} +} \ No newline at end of file diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp index 2be60595b..9d622557b 100644 --- a/include/program_options_utils.hpp +++ b/include/program_options_utils.hpp @@ -77,5 +77,8 @@ const char *UNIVERSAL_LABEL = "in the labels file instead of listing all labels for a node. DiskANN will not automatically assign a " "universal label to a node."; const char *FILTERED_LBUILD = "Build complexity for filtered points, higher value results in better graphs"; +const char *DIVERSITY_FILE = "Seller diversity file for diverse index"; +const char *NUM_DIVERSE = "Number of diverse edges needed per node in each local region"; + } // namespace program_options_utils