From ee62d8db9f54e5111c0d3ae2777c8d89a4db5fc2 Mon Sep 17 00:00:00 2001 From: NingyuanChen Date: Mon, 8 Jan 2024 13:20:46 -0800 Subject: [PATCH 01/38] Initial change for save/load as one file. (#504) Co-authored-by: REDMOND\ninchen --- include/abstract_data_store.h | 3 +- include/abstract_graph_store.h | 7 +- include/distance.h | 1 + include/in_mem_data_store.h | 9 +- include/in_mem_graph_store.h | 17 ++- include/index.h | 33 +++-- include/index_config.h | 42 +++++- include/utils.h | 43 ++++-- src/distance.cpp | 3 + src/in_mem_data_store.cpp | 22 +-- src/in_mem_graph_store.cpp | 61 ++++---- src/index.cpp | 245 +++++++++++++++++++++++++++++---- 12 files changed, 382 insertions(+), 104 deletions(-) diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index d858c8eef..60eeb6c03 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -21,13 +21,14 @@ template class AbstractDataStore virtual ~AbstractDataStore() = default; // Return number of points returned - virtual location_t load(const std::string &filename) = 0; + virtual location_t load(const std::string &filename, size_t offset) = 0; // Why does store take num_pts? Since store only has capacity, but we allow // resizing we can end up in a situation where the store has spare capacity. // To optimize disk utilization, we pass the number of points that are "true" // points, so that the store can discard the empty locations before saving. virtual size_t save(const std::string &filename, const location_t num_pts) = 0; + virtual size_t save(std::ofstream &writer, const location_t num_pts, size_t offset) = 0; DISKANN_DLLEXPORT virtual location_t capacity() const; diff --git a/include/abstract_graph_store.h b/include/abstract_graph_store.h index 4d6906ca4..750fec727 100644 --- a/include/abstract_graph_store.h +++ b/include/abstract_graph_store.h @@ -21,11 +21,14 @@ class AbstractGraphStore virtual ~AbstractGraphStore() = default; // returns tuple of - virtual std::tuple load(const std::string &index_path_prefix, - const size_t num_points) = 0; + virtual std::tuple load(const std::string &index_path_prefix, const size_t num_points, + size_t offset) = 0; virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_fz_points, const uint32_t start) = 0; + virtual int store(std::ofstream &writer, const size_t num_points, const size_t num_fz_points, const uint32_t start, + size_t offset) = 0; + // not synchronised, user should use lock when necvessary. virtual const std::vector &get_neighbours(const location_t i) const = 0; virtual void add_neighbour(const location_t i, location_t neighbour_id) = 0; diff --git a/include/distance.h b/include/distance.h index 8b20e586b..065b38231 100644 --- a/include/distance.h +++ b/include/distance.h @@ -1,5 +1,6 @@ #pragma once #include "windows_customizations.h" +#include #include namespace diskann diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 9b6968b03..b610bb2dd 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -24,8 +24,9 @@ template class InMemDataStore : public AbstractDataStore> distance_fn); virtual ~InMemDataStore(); - virtual location_t load(const std::string &filename) override; - virtual size_t save(const std::string &filename, const location_t num_points) override; + virtual location_t load(const std::string &filename, size_t offset = 0) override; + virtual size_t save(const std::string &filename, const location_t num_pts) override; + virtual size_t save(std::ofstream &writer, const location_t num_pts, size_t offset) override; virtual size_t get_aligned_dim() const override; @@ -59,9 +60,9 @@ template class InMemDataStore : public AbstractDataStore - virtual std::tuple load(const std::string &index_path_prefix, - const size_t num_points) override; + virtual std::tuple load(const std::string &index_path_prefix, const size_t num_points, + size_t offset) override; virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points, const uint32_t start) override; - + virtual int store(std::ofstream &writer, const size_t num_points, const size_t num_fz_points, const uint32_t start, + size_t offset) override; virtual const std::vector &get_neighbours(const location_t i) const override; virtual void add_neighbour(const location_t i, location_t neighbour_id) override; virtual void clear_neighbours(const location_t i) override; @@ -33,13 +34,15 @@ class InMemGraphStore : public AbstractGraphStore virtual uint32_t get_max_observed_degree() override; protected: - virtual std::tuple load_impl(const std::string &filename, size_t expected_num_points); + virtual std::tuple load_impl(const std::string &filename, size_t expected_num_points, + size_t offset); #ifdef EXEC_ENV_OLS - virtual std::tuple load_impl(AlignedFileReader &reader, size_t expected_num_points); + virtual std::tuple load_impl(AlignedFileReader &reader, size_t expected_num_points, + size_t offset); #endif - int save_graph(const std::string &index_path_prefix, const size_t active_points, const size_t num_frozen_points, - const uint32_t start); + int save_graph(std::ofstream &writer, const size_t active_points, const size_t num_frozen_points, + const uint32_t start, size_t offset); private: size_t _max_range_of_graph = 0; diff --git a/include/index.h b/include/index.h index e7966461c..387a9ac07 100644 --- a/include/index.h +++ b/include/index.h @@ -28,6 +28,14 @@ namespace diskann { +// This struct is used for storing metadata for save_as_one_file version 1. +struct SaveLoadMetaDataV1 +{ + uint64_t data_offset; + uint64_t delete_list_offset; + uint64_t tags_offset; + uint64_t graph_offset; +}; inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, uint32_t degree) { @@ -57,7 +65,9 @@ template clas const size_t num_frozen_pts = 0, const bool dynamic_index = false, const bool enable_tags = false, const bool concurrent_consolidate = false, const bool pq_dist_build = false, const size_t num_pq_chunks = 0, - const bool use_opq = false, const bool filtered_index = false); + const bool use_opq = false, const bool filtered_index = false, + bool save_as_one_file = false, uint64_t save_as_one_file_version = 1, + bool load_from_one_file = false, uint64_t load_from_one_file_version = 1); DISKANN_DLLEXPORT Index(const IndexConfig &index_config, std::unique_ptr> data_store, std::unique_ptr graph_store); @@ -313,15 +323,15 @@ template clas DISKANN_DLLEXPORT size_t save_tags(std::string filename); DISKANN_DLLEXPORT size_t save_delete_list(const std::string &filename); #ifdef EXEC_ENV_OLS - DISKANN_DLLEXPORT size_t load_graph(AlignedFileReader &reader, size_t expected_num_points); - DISKANN_DLLEXPORT size_t load_data(AlignedFileReader &reader); - DISKANN_DLLEXPORT size_t load_tags(AlignedFileReader &reader); - DISKANN_DLLEXPORT size_t load_delete_set(AlignedFileReader &reader); + DISKANN_DLLEXPORT size_t load_graph(AlignedFileReader &reader, size_t expected_num_points, size_t offset = 0); + DISKANN_DLLEXPORT size_t load_data(AlignedFileReader &reader, size_t offset = 0); + DISKANN_DLLEXPORT size_t load_tags(AlignedFileReader &reader, size_t offset = 0); + DISKANN_DLLEXPORT size_t load_delete_set(AlignedFileReader &reader, size_t offset = 0); #else - DISKANN_DLLEXPORT size_t load_graph(const std::string filename, size_t expected_num_points); - DISKANN_DLLEXPORT size_t load_data(std::string filename0); - DISKANN_DLLEXPORT size_t load_tags(const std::string tag_file_name); - DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename); + DISKANN_DLLEXPORT size_t load_graph(const std::string filename, size_t expected_num_points, size_t offset = 0); + DISKANN_DLLEXPORT size_t load_data(std::string filename, size_t offset = 0); + DISKANN_DLLEXPORT size_t load_tags(const std::string &filename, size_t offset = 0); + DISKANN_DLLEXPORT size_t load_delete_set(const std::string &filename, size_t offset = 0); #endif private: @@ -360,7 +370,10 @@ template clas bool _has_built = false; bool _saturate_graph = false; - bool _save_as_one_file = false; // plan to support in next version + bool _save_as_one_file; // plan to support filtered index in next version. + uint64_t _save_as_one_file_version; // Version used for save index to single file. + bool _load_from_one_file; // Whether to load index from single file. + uint64_t _load_from_one_file_version; // Version used for save index to single file. bool _dynamic_index = false; bool _enable_tags = false; bool _normalize_vecs = false; // Using normalied L2 for cosine. diff --git a/include/index_config.h b/include/index_config.h index 452498b01..6ada17d07 100644 --- a/include/index_config.h +++ b/include/index_config.h @@ -28,6 +28,10 @@ struct IndexConfig bool concurrent_consolidate; bool use_opq; bool filtered_index; + bool save_as_one_file; + uint64_t save_as_one_file_version; + bool load_from_one_file; + uint64_t load_from_one_file_version; size_t num_pq_chunks; size_t num_frozen_pts; @@ -45,12 +49,15 @@ struct IndexConfig IndexConfig(DataStoreStrategy data_strategy, GraphStoreStrategy graph_strategy, Metric metric, size_t dimension, size_t max_points, size_t num_pq_chunks, size_t num_frozen_points, bool dynamic_index, bool enable_tags, bool pq_dist_build, bool concurrent_consolidate, bool use_opq, bool filtered_index, - std::string &data_type, const std::string &tag_type, const std::string &label_type, - std::shared_ptr index_write_params, + bool save_as_one_file, uint64_t save_as_one_file_version, bool load_from_one_file, + uint64_t load_from_one_file_version, std::string &data_type, const std::string &tag_type, + const std::string &label_type, std::shared_ptr index_write_params, std::shared_ptr index_search_params) : data_strategy(data_strategy), graph_strategy(graph_strategy), metric(metric), dimension(dimension), max_points(max_points), dynamic_index(dynamic_index), enable_tags(enable_tags), pq_dist_build(pq_dist_build), concurrent_consolidate(concurrent_consolidate), use_opq(use_opq), filtered_index(filtered_index), + save_as_one_file(save_as_one_file), save_as_one_file_version(save_as_one_file_version), + load_from_one_file(load_from_one_file), load_from_one_file_version(load_from_one_file_version), num_pq_chunks(num_pq_chunks), num_frozen_pts(num_frozen_points), label_type(label_type), tag_type(tag_type), data_type(data_type), index_write_params(index_write_params), index_search_params(index_search_params) { @@ -194,6 +201,30 @@ class IndexConfigBuilder return *this; } + IndexConfigBuilder &with_save_as_single_file(bool save_as_one_file) + { + this->_save_as_one_file = save_as_one_file; + return *this; + } + + IndexConfigBuilder &with_save_as_single_file_version(uint64_t save_as_one_file_version) + { + this->_save_as_one_file_version = save_as_one_file_version; + return *this; + } + + IndexConfigBuilder &with_load_from_single_file(bool load_from_one_file) + { + this->_load_from_one_file = load_from_one_file; + return *this; + } + + IndexConfigBuilder &with_load_from_single_file_version(uint64_t load_from_one_file_version) + { + this->_save_as_one_file_version = load_from_one_file_version; + return *this; + } + IndexConfig build() { if (_data_type == "" || _data_type.empty()) @@ -219,7 +250,8 @@ class IndexConfigBuilder return IndexConfig(_data_strategy, _graph_strategy, _metric, _dimension, _max_points, _num_pq_chunks, _num_frozen_pts, _dynamic_index, _enable_tags, _pq_dist_build, _concurrent_consolidate, - _use_opq, _filtered_index, _data_type, _tag_type, _label_type, _index_write_params, + _use_opq, _filtered_index, _save_as_one_file, _save_as_one_file_version, _load_from_one_file, + _load_from_one_file_version, _data_type, _tag_type, _label_type, _index_write_params, _index_search_params); } @@ -240,6 +272,10 @@ class IndexConfigBuilder bool _concurrent_consolidate = false; bool _use_opq = false; bool _filtered_index{defaults::HAS_LABELS}; + bool _save_as_one_file; + uint64_t _save_as_one_file_version; + bool _load_from_one_file; + uint64_t _load_from_one_file_version; size_t _num_pq_chunks = 0; size_t _num_frozen_pts{defaults::NUM_FROZEN_POINTS_STATIC}; diff --git a/include/utils.h b/include/utils.h index bb03d13f1..9011634f5 100644 --- a/include/utils.h +++ b/include/utils.h @@ -714,13 +714,8 @@ inline void open_file_to_write(std::ofstream &writer, const std::string &filenam } } -template -inline size_t save_bin(const std::string &filename, T *data, size_t npts, size_t ndims, size_t offset = 0) +template inline size_t save_bin(std::ofstream &writer, T *data, size_t npts, size_t ndims, size_t offset) { - std::ofstream writer; - open_file_to_write(writer, filename); - - diskann::cout << "Writing bin: " << filename.c_str() << std::endl; writer.seekp(offset, writer.beg); int npts_i32 = (int)npts, ndims_i32 = (int)ndims; size_t bytes_written = npts * ndims * sizeof(T) + 2 * sizeof(uint32_t); @@ -730,11 +725,22 @@ inline size_t save_bin(const std::string &filename, T *data, size_t npts, size_t << std::endl; writer.write((char *)data, npts * ndims * sizeof(T)); - writer.close(); diskann::cout << "Finished writing bin." << std::endl; return bytes_written; } +template +inline size_t save_bin(const std::string &filename, T *data, size_t npts, size_t ndims, size_t offset = 0) +{ + std::ofstream writer; + open_file_to_write(writer, filename); + diskann::cout << "Writing bin file: " << filename.c_str() << std::endl; + size_t bytes_written = save_bin(writer, data, npts, ndims, offset); + writer.close(); + diskann::cout << "Close file " << filename << "." << std::endl; + return bytes_written; +} + inline void print_progress(double percentage) { int val = (int)(percentage * 100); @@ -938,12 +944,11 @@ template void save_Tvecs(const char *filename, T *data, size_t npts writer.write((char *)cur_pt, ndims * sizeof(T)); } } + template -inline size_t save_data_in_base_dimensions(const std::string &filename, T *data, size_t npts, size_t ndims, - size_t aligned_dim, size_t offset = 0) +inline size_t save_data_in_base_dimensions(std::ofstream &writer, T *data, size_t npts, size_t ndims, + size_t aligned_dim, size_t offset) { - std::ofstream writer; //(filename, std::ios::binary | std::ios::out); - open_file_to_write(writer, filename); int npts_i32 = (int)npts, ndims_i32 = (int)ndims; size_t bytes_written = 2 * sizeof(uint32_t) + npts * ndims * sizeof(T); writer.seekp(offset, writer.beg); @@ -953,10 +958,21 @@ inline size_t save_data_in_base_dimensions(const std::string &filename, T *data, { writer.write((char *)(data + i * aligned_dim), ndims * sizeof(T)); } - writer.close(); return bytes_written; } +template +inline size_t save_data_in_base_dimensions(const std::string &filename, T *data, size_t npts, size_t ndims, + size_t aligned_dim, size_t offset = 0) +{ + std::ofstream writer; //(filename, std::ios::binary | std::ios::out); + open_file_to_write(writer, filename); + size_t file_size = save_data_in_base_dimensions(writer, data, npts, ndims, aligned_dim, offset); + writer.close(); + + return file_size; +} + template inline void copy_aligned_data_from_file(const char *bin_file, T *&data, size_t &npts, size_t &dim, const size_t &rounded_dim, size_t offset = 0) @@ -968,11 +984,12 @@ inline void copy_aligned_data_from_file(const char *bin_file, T *&data, size_t & throw diskann::ANNException("Null pointer passed to copy_aligned_data_from_file function", -1, __FUNCSIG__, __FILE__, __LINE__); } + std::ifstream reader; reader.exceptions(std::ios::badbit | std::ios::failbit); reader.open(bin_file, std::ios::binary); - reader.seekg(offset, reader.beg); + reader.seekg(offset, reader.beg); int npts_i32, dim_i32; reader.read((char *)&npts_i32, sizeof(int)); reader.read((char *)&dim_i32, sizeof(int)); diff --git a/src/distance.cpp b/src/distance.cpp index 31ab9d3ff..f1c1a317a 100644 --- a/src/distance.cpp +++ b/src/distance.cpp @@ -730,4 +730,7 @@ template DISKANN_DLLEXPORT class SlowDistanceL2; template DISKANN_DLLEXPORT class SlowDistanceL2; template DISKANN_DLLEXPORT class SlowDistanceL2; +template DISKANN_DLLEXPORT Distance *get_distance_function(Metric m); +template DISKANN_DLLEXPORT Distance *get_distance_function(Metric m); +template DISKANN_DLLEXPORT Distance *get_distance_function(Metric m); } // namespace diskann diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index 7d02bba17..8e842d159 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -37,13 +37,13 @@ template size_t InMemDataStore::get_alignment_factor() return _distance_fn->get_required_alignment(); } -template location_t InMemDataStore::load(const std::string &filename) +template location_t InMemDataStore::load(const std::string &filename, size_t offset) { - return load_impl(filename); + return load_impl(filename, offset); } #ifdef EXEC_ENV_OLS -template location_t InMemDataStore::load_impl(AlignedFileReader &reader) +template location_t InMemDataStore::load_impl(AlignedFileReader &reader, size_t offset) { size_t file_dim, file_num_points; @@ -69,7 +69,7 @@ template location_t InMemDataStore::load_impl(AlignedF } #endif -template location_t InMemDataStore::load_impl(const std::string &filename) +template location_t InMemDataStore::load_impl(const std::string &filename, size_t offset) { size_t file_dim, file_num_points; if (!file_exists(filename)) @@ -80,7 +80,7 @@ template location_t InMemDataStore::load_impl(const st aligned_free(_data); throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - diskann::get_bin_metadata(filename, file_num_points, file_dim); + diskann::get_bin_metadata(filename, file_num_points, file_dim, offset); if (file_dim != this->_dim) { @@ -97,14 +97,20 @@ template location_t InMemDataStore::load_impl(const st this->resize((location_t)file_num_points); } - copy_aligned_data_from_file(filename.c_str(), _data, file_num_points, file_dim, _aligned_dim); + copy_aligned_data_from_file(filename.c_str(), _data, file_num_points, file_dim, _aligned_dim, offset); return (location_t)file_num_points; } -template size_t InMemDataStore::save(const std::string &filename, const location_t num_points) +template size_t InMemDataStore::save(const std::string &filename, const location_t num_pts) { - return save_data_in_base_dimensions(filename, _data, num_points, this->get_dims(), this->get_aligned_dim(), 0U); + return save_data_in_base_dimensions(filename, _data, num_pts, this->get_dims(), this->get_aligned_dim(), 0U); +} + +template +size_t InMemDataStore::save(std::ofstream &writer, const location_t num_pts, size_t offset) +{ + return save_data_in_base_dimensions(writer, _data, num_pts, this->get_dims(), this->get_aligned_dim(), offset); } template void InMemDataStore::populate_data(const data_t *vectors, const location_t num_pts) diff --git a/src/in_mem_graph_store.cpp b/src/in_mem_graph_store.cpp index c12b2514e..fe14c8a0d 100644 --- a/src/in_mem_graph_store.cpp +++ b/src/in_mem_graph_store.cpp @@ -17,15 +17,27 @@ InMemGraphStore::InMemGraphStore(const size_t total_pts, const size_t reserve_gr } std::tuple InMemGraphStore::load(const std::string &index_path_prefix, - const size_t num_points) + const size_t num_points, size_t offset) { - return load_impl(index_path_prefix, num_points); + return load_impl(index_path_prefix, num_points, offset); } int InMemGraphStore::store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points, const uint32_t start) { - return save_graph(index_path_prefix, num_points, num_frozen_points, start); + std::ofstream writer; + open_file_to_write(writer, index_path_prefix); + int file_size = store(writer, num_points, num_frozen_points, start, 0U); + writer.close(); + + return file_size; +} + +int InMemGraphStore::store(std::ofstream &writer, const size_t num_points, const size_t num_frozen_points, + const uint32_t start, size_t offset) +{ + return save_graph(writer, num_points, num_frozen_points, start, offset); } + const std::vector &InMemGraphStore::get_neighbours(const location_t i) const { return _graph.at(i); @@ -71,7 +83,8 @@ void InMemGraphStore::clear_graph() } #ifdef EXEC_ENV_OLS -std::tuple InMemGraphStore::load_impl(AlignedFileReader &reader, size_t expected_num_points) +std::tuple InMemGraphStore::load_impl(AlignedFileReader &reader, size_t expected_num_points, + size_t offset) { size_t expected_file_size; size_t file_frozen_pts; @@ -80,7 +93,7 @@ std::tuple InMemGraphStore::load_impl(AlignedFileRea auto max_points = get_max_points(); int header_size = 2 * sizeof(size_t) + 2 * sizeof(uint32_t); std::unique_ptr header = std::make_unique(header_size); - read_array(reader, header.get(), header_size); + read_array(reader, header.get(), header_size, offset); expected_file_size = *((size_t *)header.get()); _max_observed_degree = *((uint32_t *)(header.get() + sizeof(size_t))); @@ -103,7 +116,7 @@ std::tuple InMemGraphStore::load_impl(AlignedFileRea uint32_t nodes_read = 0; size_t cc = 0; - size_t graph_offset = header_size; + size_t graph_offset = header_size + offset; while (nodes_read < expected_num_points) { uint32_t k; @@ -133,17 +146,16 @@ std::tuple InMemGraphStore::load_impl(AlignedFileRea #endif std::tuple InMemGraphStore::load_impl(const std::string &filename, - size_t expected_num_points) + size_t expected_num_points, size_t offset) { size_t expected_file_size; size_t file_frozen_pts; uint32_t start; - size_t file_offset = 0; // will need this for single file format support std::ifstream in; in.exceptions(std::ios::badbit | std::ios::failbit); in.open(filename, std::ios::binary); - in.seekg(file_offset, in.beg); + in.seekg(offset, in.beg); in.read((char *)&expected_file_size, sizeof(size_t)); in.read((char *)&_max_observed_degree, sizeof(uint32_t)); in.read((char *)&start, sizeof(uint32_t)); @@ -197,35 +209,32 @@ std::tuple InMemGraphStore::load_impl(const std::str return std::make_tuple(nodes_read, start, file_frozen_pts); } -int InMemGraphStore::save_graph(const std::string &index_path_prefix, const size_t num_points, - const size_t num_frozen_points, const uint32_t start) +int InMemGraphStore::save_graph(std::ofstream &writer, const size_t num_points, const size_t num_frozen_points, + const uint32_t start, size_t offset) { - std::ofstream out; - open_file_to_write(out, index_path_prefix); - - size_t file_offset = 0; - out.seekp(file_offset, out.beg); + writer.seekp(offset, writer.beg); size_t index_size = 24; uint32_t max_degree = 0; - out.write((char *)&index_size, sizeof(uint64_t)); - out.write((char *)&_max_observed_degree, sizeof(uint32_t)); + writer.write((char *)&index_size, sizeof(uint64_t)); + writer.write((char *)&_max_observed_degree, sizeof(uint32_t)); uint32_t ep_u32 = start; - out.write((char *)&ep_u32, sizeof(uint32_t)); - out.write((char *)&num_frozen_points, sizeof(size_t)); + writer.write((char *)&ep_u32, sizeof(uint32_t)); + writer.write((char *)&num_frozen_points, sizeof(size_t)); // Note: num_points = _nd + _num_frozen_points for (uint32_t i = 0; i < num_points; i++) { uint32_t GK = (uint32_t)_graph[i].size(); - out.write((char *)&GK, sizeof(uint32_t)); - out.write((char *)_graph[i].data(), GK * sizeof(uint32_t)); + writer.write((char *)&GK, sizeof(uint32_t)); + writer.write((char *)_graph[i].data(), GK * sizeof(uint32_t)); max_degree = _graph[i].size() > max_degree ? (uint32_t)_graph[i].size() : max_degree; index_size += (size_t)(sizeof(uint32_t) * (GK + 1)); } - out.seekp(file_offset, out.beg); - out.write((char *)&index_size, sizeof(uint64_t)); - out.write((char *)&max_degree, sizeof(uint32_t)); - out.close(); + + writer.seekp(offset, writer.beg); + writer.write((char *)&index_size, sizeof(uint64_t)); + writer.write((char *)&max_degree, sizeof(uint32_t)); + return (int)index_size; } diff --git a/src/index.cpp b/src/index.cpp index 3de3a3b7f..10db4dab6 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -34,7 +34,10 @@ Index::Index(const IndexConfig &index_config, std::unique_ptr), _conc_consolidate(index_config.concurrent_consolidate) { if (_dynamic_index && !_enable_tags) @@ -125,7 +128,8 @@ Index::Index(Metric m, const size_t dim, const size_t max_point const std::shared_ptr index_search_params, const size_t num_frozen_pts, const bool dynamic_index, const bool enable_tags, const bool concurrent_consolidate, const bool pq_dist_build, const size_t num_pq_chunks, const bool use_opq, - const bool filtered_index) + const bool filtered_index, bool save_as_one_file, uint64_t save_as_one_file_version, + bool load_from_one_file, uint64_t load_from_one_file_version) : Index(IndexConfigBuilder() .with_metric(m) .with_dimension(dim) @@ -141,6 +145,10 @@ Index::Index(Metric m, const size_t dim, const size_t max_point .is_use_opq(use_opq) .is_filtered(filtered_index) .with_data_type(diskann_type_to_name()) + .with_save_as_single_file(save_as_one_file) + .with_save_as_single_file_version(save_as_one_file_version) + .with_load_from_single_file(load_from_one_file) + .with_load_from_single_file_version(load_from_one_file_version) .build(), IndexFactory::construct_datastore( DataStoreStrategy::MEMORY, @@ -379,9 +387,100 @@ void Index::save(const char *filename, bool compact_before_save } else { - diskann::cout << "Save index in a single file currently not supported. " - "Not saving the index." - << std::endl; + if (_filtered_index) + { + diskann::cout << "Save index in a single file currently not supported for filtered index. " + "Not saving the index." + << std::endl; + } + else + { + if (_save_as_one_file_version == 1) + { + std::ofstream writer; + open_file_to_write(writer, filename); + + // Save version. + writer.write((char *)&_save_as_one_file_version, sizeof(uint64_t)); + size_t curr_pos = sizeof(uint64_t); + + // Placeholder for metadata. + // This will be filled at end; + SaveLoadMetaDataV1 metadata; + const size_t meta_data_start = curr_pos; + curr_pos += sizeof(SaveLoadMetaDataV1); + + // Save data. + metadata.data_offset = static_cast(curr_pos); + curr_pos += _data_store->save(writer, (location_t)(_nd + _num_frozen_pts), curr_pos); + + // Save delete list. + { + if (_delete_set->size() == 0) + { + metadata.delete_list_offset = static_cast(curr_pos); + } + else + { + std::unique_ptr delete_list = std::make_unique(_delete_set->size()); + uint32_t i = 0; + for (auto &del : *_delete_set) + { + delete_list[i++] = del; + } + curr_pos += save_bin(writer, delete_list.get(), _delete_set->size(), 1, curr_pos); + } + } + + // Save tags. + { + if (!_enable_tags) + { + diskann::cout << "Not saving tags as they are not enabled." << std::endl; + metadata.tags_offset = static_cast(curr_pos); + } + else + { + TagT *tag_data = new TagT[_nd + _num_frozen_pts]; + for (uint32_t i = 0; i < _nd; i++) + { + TagT tag; + if (_location_to_tag.try_get(i, tag)) + { + tag_data[i] = tag; + } + else + { + // catering to future when tagT can be any type. + std::memset((char *)&tag_data[i], 0, sizeof(TagT)); + } + } + if (_num_frozen_pts > 0) + { + std::memset((char *)&tag_data[_start], 0, sizeof(TagT) * _num_frozen_pts); + } + + curr_pos += save_bin(writer, tag_data, _nd + _num_frozen_pts, 1, curr_pos); + delete[] tag_data; + } + } + + // Save graph. + metadata.graph_offset = static_cast(curr_pos); + curr_pos += _graph_store->store(writer, _nd + _num_frozen_pts, _num_frozen_pts, _start, curr_pos); + + // Save metadata. + writer.seekp(meta_data_start, writer.beg); + writer.write((char *)&metadata, sizeof(SaveLoadMetaDataV1)); + writer.close(); + } + else + { + diskann::cout << "Save index in a single file currently only support _save_as_one_file_version = 1. " + "Not saving the index." + << std::endl; + } + } } // If frozen points were temporarily compacted to _nd, move back to @@ -393,17 +492,16 @@ void Index::save(const char *filename, bool compact_before_save #ifdef EXEC_ENV_OLS template -size_t Index::load_tags(AlignedFileReader &reader) +size_t Index::load_tags(AlignedFileReader &reader, size_t offset) { #else template -size_t Index::load_tags(const std::string tag_filename) +size_t Index::load_tags(const std::string &filename, size_t offset) { - if (_enable_tags && !file_exists(tag_filename)) + if (_enable_tags && !file_exists(filename)) { - diskann::cerr << "Tag file " << tag_filename << " does not exist!" << std::endl; - throw diskann::ANNException("Tag file " + tag_filename + " does not exist!", -1, __FUNCSIG__, __FILE__, - __LINE__); + diskann::cerr << "Tag file " << filename << " does not exist!" << std::endl; + throw diskann::ANNException("Tag file " + filename + " does not exist!", -1, __FUNCSIG__, __FILE__, __LINE__); } #endif if (!_enable_tags) @@ -415,9 +513,9 @@ size_t Index::load_tags(const std::string tag_filename) size_t file_dim, file_num_points; TagT *tag_data; #ifdef EXEC_ENV_OLS - load_bin(reader, tag_data, file_num_points, file_dim); + load_bin(reader, tag_data, file_num_points, file_dim, offset); #else - load_bin(std::string(tag_filename), tag_data, file_num_points, file_dim); + load_bin(std::string(filename), tag_data, file_num_points, file_dim, offset); #endif if (file_dim != 1) @@ -449,15 +547,15 @@ size_t Index::load_tags(const std::string tag_filename) template #ifdef EXEC_ENV_OLS -size_t Index::load_data(AlignedFileReader &reader) +size_t Index::load_data(AlignedFileReader &reader, size_t offset) { #else -size_t Index::load_data(std::string filename) +size_t Index::load_data(std::string filename, size_t offset) { #endif size_t file_dim, file_num_points; #ifdef EXEC_ENV_OLS - diskann::get_bin_metadata(reader, file_num_points, file_dim); + diskann::get_bin_metadata(reader, file_num_points, file_dim, offset); #else if (!file_exists(filename)) { @@ -466,7 +564,7 @@ size_t Index::load_data(std::string filename) diskann::cerr << stream.str() << std::endl; throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - diskann::get_bin_metadata(filename, file_num_points, file_dim); + diskann::get_bin_metadata(filename, file_num_points, file_dim, offset); #endif // since we are loading a new dataset, _empty_slots must be cleared @@ -490,29 +588,29 @@ size_t Index::load_data(std::string filename) #ifdef EXEC_ENV_OLS // REFACTOR TODO: Must figure out how to support aligned reader in a clean // manner. - copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _data_store->get_aligned_dim()); + copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _data_store->get_aligned_dim(), offset); #else - _data_store->load(filename); // offset == 0. + _data_store->load(filename, offset); // offset == 0. #endif return file_num_points; } #ifdef EXEC_ENV_OLS template -size_t Index::load_delete_set(AlignedFileReader &reader) +size_t Index::load_delete_set(AlignedFileReader &reader, size_t offset) { #else template -size_t Index::load_delete_set(const std::string &filename) +size_t Index::load_delete_set(const std::string &filename, size_t offset) { #endif std::unique_ptr delete_list; size_t npts, ndim; #ifdef EXEC_ENV_OLS - diskann::load_bin(reader, delete_list, npts, ndim); + diskann::load_bin(reader, delete_list, npts, ndim, offset); #else - diskann::load_bin(filename, delete_list, npts, ndim); + diskann::load_bin(filename, delete_list, npts, ndim, offset); #endif assert(ndim == 1); for (uint32_t i = 0; i < npts; i++) @@ -528,6 +626,7 @@ template #ifdef EXEC_ENV_OLS void Index::load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) { + IOContext &ctx = reader.get_ctx(); #else void Index::load(const char *filename, uint32_t num_threads, uint32_t search_l) { @@ -546,7 +645,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt"; std::string labels_map_file = mem_index_file + "_labels_map.txt"; #endif - if (!_save_as_one_file) + if (!_load_from_one_file) { // For DLVS Store, we will not support saving the index in multiple // files. @@ -569,9 +668,95 @@ void Index::load(const char *filename, uint32_t num_threads, ui } else { - diskann::cout << "Single index file saving/loading support not yet " - "enabled. Not loading the index." - << std::endl; + if (_filtered_index) + { + diskann::cout << "Single index file saving/loading support for filtered index is not yet " + "enabled. Not loading the index." + << std::endl; + } + else + { + uint64_t version; + +#ifdef EXEC_ENV_OLS + std::vector readReqs; + AlignedRead readReq; + uint64_t buf[1]; + + readReq.buf = buf; + readReq.offset = 0; + readReq.len = sizeof(uint64_t); + readReqs.push_back(readReq); + reader.read(readReqs, ctx); // synchronous + if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS) + { + version = buf[0]; + } +#else + std::ifstream reader(filename, std::ios::binary); + reader.read((char *)&version, sizeof(uint64_t)); +#endif + + if (version == _load_from_one_file_version) + { + SaveLoadMetaDataV1 metadata; + +#ifdef EXEC_ENV_OLS + std::vector metadata_readReqs; + AlignedRead metadata_readReq; + uint64_t metadata_buf[1]; + + metadata_readReq.buf = metadata_buf; + metadata_readReq.offset = sizeof(uint64_t); + metadata_readReq.len = sizeof(SaveLoadMetaDataV1); + metadata_readReq.push_back(readReq); + reader.read(metadata_readReqs, ctx); // synchronous + if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS) + { + memcpy((void *)&metadata, (void *)buf, sizeof(SaveLoadMetaDataV1)); + } +#else + reader.read((char *)&metadata, sizeof(SaveLoadMetaDataV1)); +#endif + // Load data +#ifdef EXEC_ENV_OLS + load_data(reader, metadata.data_offset) +#else + load_data(filename, metadata.data_offset); +#endif + + // Load delete list when presents. + if (metadata.data_offset != metadata.delete_list_offset) + { +#ifdef EXEC_ENV_OLS + load_delete_set(reader, metadata.delete_list_offset); +#else + load_delete_set(filename, metadata.delete_list_offset); +#endif + } + // Load tags when presents. + if (metadata.delete_list_offset != metadata.tags_offset) + { +#ifdef EXEC_ENV_OLS + load_tags(reader, metadata.tags_offset); +#else + load_tags(filename, metadata.tags_offset); +#endif + } + // Load graph +#ifdef EXEC_ENV_OLS + load_graph(reader, metadata.graph_offset); +#else + load_graph(filename, metadata.graph_offset); +#endif + } + else + { + diskann::cout << "load index from a single file currently only support _save_as_one_file_version = 1. " + "Not loading the index." + << std::endl; + } + } return; } @@ -679,15 +864,15 @@ size_t Index::get_graph_num_frozen_points(const std::string &gr #ifdef EXEC_ENV_OLS template -size_t Index::load_graph(AlignedFileReader &reader, size_t expected_num_points) +size_t Index::load_graph(AlignedFileReader &reader, size_t expected_num_points, size_t offset) { #else template -size_t Index::load_graph(std::string filename, size_t expected_num_points) +size_t Index::load_graph(std::string filename, size_t expected_num_points, size_t offset) { #endif - auto res = _graph_store->load(filename, expected_num_points); + auto res = _graph_store->load(filename, expected_num_points, offset); _start = std::get<1>(res); _num_frozen_pts = std::get<2>(res); return std::get<0>(res); From df84a6d5bcbc494a88afdc732c70b88443c0f696 Mon Sep 17 00:00:00 2001 From: "REDMOND\\ninchen" Date: Fri, 12 Jan 2024 10:05:18 -0800 Subject: [PATCH 02/38] BANN single file Save and Load. --- include/abstract_graph_store.h | 8 +++ include/defaults.h | 2 +- include/in_mem_graph_store.h | 12 +++-- include/index.h | 2 + include/parameters.h | 16 +++--- src/in_mem_graph_store.cpp | 15 +++++- src/index.cpp | 97 +++++++++++++++++++++++----------- src/pq_flash_index.cpp | 2 +- 8 files changed, 108 insertions(+), 46 deletions(-) diff --git a/include/abstract_graph_store.h b/include/abstract_graph_store.h index 750fec727..5c239da7e 100644 --- a/include/abstract_graph_store.h +++ b/include/abstract_graph_store.h @@ -7,6 +7,8 @@ #include #include "types.h" +class AlignedFileReader; + namespace diskann { @@ -21,8 +23,14 @@ class AbstractGraphStore virtual ~AbstractGraphStore() = default; // returns tuple of +#ifdef EXEC_ENV_OLS + virtual std::tuple load(AlignedFileReader &reader, const size_t num_points, + size_t offset) = 0; +#else virtual std::tuple load(const std::string &index_path_prefix, const size_t num_points, size_t offset) = 0; +#endif + virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_fz_points, const uint32_t start) = 0; diff --git a/include/defaults.h b/include/defaults.h index 5ea5af495..ef1750fcf 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -17,7 +17,7 @@ const uint32_t NUM_FROZEN_POINTS_STATIC = 0; const uint32_t NUM_FROZEN_POINTS_DYNAMIC = 1; // In-mem index related limits -const float GRAPH_SLACK_FACTOR = 1.3; +const float GRAPH_SLACK_FACTOR = 1.3f; // SSD Index related limits const uint64_t MAX_GRAPH_DEGREE = 512; diff --git a/include/in_mem_graph_store.h b/include/in_mem_graph_store.h index 95e4dbcce..543a0cca7 100644 --- a/include/in_mem_graph_store.h +++ b/include/in_mem_graph_store.h @@ -14,8 +14,13 @@ class InMemGraphStore : public AbstractGraphStore InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree); // returns tuple of - virtual std::tuple load(const std::string &index_path_prefix, const size_t num_points, +#ifdef EXEC_ENV_OLS + virtual std::tuple load(AlignedFileReader &reader, const size_t num_points, size_t offset) override; +#else + virtual std::tuple load(const std::string &filename, size_t expected_num_points, + size_t offset); +#endif virtual int store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points, const uint32_t start) override; virtual int store(std::ofstream &writer, const size_t num_points, const size_t num_fz_points, const uint32_t start, @@ -34,11 +39,12 @@ class InMemGraphStore : public AbstractGraphStore virtual uint32_t get_max_observed_degree() override; protected: - virtual std::tuple load_impl(const std::string &filename, size_t expected_num_points, - size_t offset); #ifdef EXEC_ENV_OLS virtual std::tuple load_impl(AlignedFileReader &reader, size_t expected_num_points, size_t offset); +#else + virtual std::tuple load_impl(const std::string &filename, size_t expected_num_points, + size_t offset); #endif int save_graph(std::ofstream &writer, const size_t active_points, const size_t num_frozen_points, diff --git a/include/index.h b/include/index.h index 387a9ac07..fc42d6e0d 100644 --- a/include/index.h +++ b/include/index.h @@ -35,6 +35,8 @@ struct SaveLoadMetaDataV1 uint64_t delete_list_offset; uint64_t tags_offset; uint64_t graph_offset; + + SaveLoadMetaDataV1(); }; inline double estimate_ram_usage(size_t size, uint32_t dim, uint32_t datasize, uint32_t degree) diff --git a/include/parameters.h b/include/parameters.h index 2bba9aeca..3c771a730 100644 --- a/include/parameters.h +++ b/include/parameters.h @@ -16,15 +16,7 @@ class IndexWriteParameters { public: - const uint32_t search_list_size; // L - const uint32_t max_degree; // R - const bool saturate_graph; - const uint32_t max_occlusion_size; // C - const float alpha; - const uint32_t num_threads; - const uint32_t filter_list_size; // Lf - private: IndexWriteParameters(const uint32_t search_list_size, const uint32_t max_degree, const bool saturate_graph, const uint32_t max_occlusion_size, const float alpha, const uint32_t num_threads, const uint32_t filter_list_size) @@ -34,6 +26,14 @@ class IndexWriteParameters { } + const uint32_t search_list_size; // L + const uint32_t max_degree; // R + const bool saturate_graph; + const uint32_t max_occlusion_size; // C + const float alpha; + const uint32_t num_threads; + const uint32_t filter_list_size; // Lf + friend class IndexWriteParametersBuilder; }; diff --git a/src/in_mem_graph_store.cpp b/src/in_mem_graph_store.cpp index fe14c8a0d..fae35ced0 100644 --- a/src/in_mem_graph_store.cpp +++ b/src/in_mem_graph_store.cpp @@ -4,6 +4,7 @@ #include "in_mem_graph_store.h" #include "utils.h" + namespace diskann { InMemGraphStore::InMemGraphStore(const size_t total_pts, const size_t reserve_graph_degree) @@ -16,11 +17,21 @@ InMemGraphStore::InMemGraphStore(const size_t total_pts, const size_t reserve_gr } } +#ifdef EXEC_ENV_OLS +std::tuple InMemGraphStore::load(AlignedFileReader &reader, + const size_t num_points, size_t offset) +{ + + return load_impl(reader, num_points, offset); +} +#else std::tuple InMemGraphStore::load(const std::string &index_path_prefix, const size_t num_points, size_t offset) { + return load_impl(index_path_prefix, num_points, offset); } +#endif int InMemGraphStore::store(const std::string &index_path_prefix, const size_t num_points, const size_t num_frozen_points, const uint32_t start) { @@ -90,7 +101,6 @@ std::tuple InMemGraphStore::load_impl(AlignedFileRea size_t file_frozen_pts; uint32_t start; - auto max_points = get_max_points(); int header_size = 2 * sizeof(size_t) + 2 * sizeof(uint32_t); std::unique_ptr header = std::make_unique(header_size); read_array(reader, header.get(), header_size, offset); @@ -143,8 +153,8 @@ std::tuple InMemGraphStore::load_impl(AlignedFileRea << std::endl; return std::make_tuple(nodes_read, start, file_frozen_pts); } -#endif +#else std::tuple InMemGraphStore::load_impl(const std::string &filename, size_t expected_num_points, size_t offset) { @@ -208,6 +218,7 @@ std::tuple InMemGraphStore::load_impl(const std::str << std::endl; return std::make_tuple(nodes_read, start, file_frozen_pts); } +#endif int InMemGraphStore::save_graph(std::ofstream &writer, const size_t num_points, const size_t num_frozen_points, const uint32_t start, size_t offset) diff --git a/src/index.cpp b/src/index.cpp index 10db4dab6..7130b3d3d 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -25,6 +25,11 @@ namespace diskann { +SaveLoadMetaDataV1::SaveLoadMetaDataV1() : data_offset(0), delete_list_offset(0), tags_offset(0), graph_offset(0) +{ +} + + // Initialize an index with metric m, load the data of type T with filename // (bin), and initialize max_points template @@ -411,16 +416,16 @@ void Index::save(const char *filename, bool compact_before_save curr_pos += sizeof(SaveLoadMetaDataV1); // Save data. - metadata.data_offset = static_cast(curr_pos); - curr_pos += _data_store->save(writer, (location_t)(_nd + _num_frozen_pts), curr_pos); + { + metadata.data_offset = static_cast(curr_pos); + curr_pos += _data_store->save(writer, (location_t)(_nd + _num_frozen_pts), curr_pos); + } // Save delete list. { - if (_delete_set->size() == 0) - { - metadata.delete_list_offset = static_cast(curr_pos); - } - else + metadata.delete_list_offset = static_cast(curr_pos); + + if (_delete_set->size() != 0) { std::unique_ptr delete_list = std::make_unique(_delete_set->size()); uint32_t i = 0; @@ -434,12 +439,9 @@ void Index::save(const char *filename, bool compact_before_save // Save tags. { - if (!_enable_tags) - { - diskann::cout << "Not saving tags as they are not enabled." << std::endl; - metadata.tags_offset = static_cast(curr_pos); - } - else + metadata.tags_offset = static_cast(curr_pos); + + if (_enable_tags) { TagT *tag_data = new TagT[_nd + _num_frozen_pts]; for (uint32_t i = 0; i < _nd; i++) @@ -466,17 +468,24 @@ void Index::save(const char *filename, bool compact_before_save } // Save graph. - metadata.graph_offset = static_cast(curr_pos); - curr_pos += _graph_store->store(writer, _nd + _num_frozen_pts, _num_frozen_pts, _start, curr_pos); + { + metadata.graph_offset = static_cast(curr_pos); + _graph_store->store(writer, _nd + _num_frozen_pts, _num_frozen_pts, _start, curr_pos); - // Save metadata. - writer.seekp(meta_data_start, writer.beg); - writer.write((char *)&metadata, sizeof(SaveLoadMetaDataV1)); - writer.close(); + // Save metadata. + writer.seekp(meta_data_start, writer.beg); + writer.write((char *)&metadata, sizeof(SaveLoadMetaDataV1)); + writer.close(); + } + + std::cout << "Metadata Saved. data_offset: " << std::to_string(metadata.data_offset) + << " delete_list_offset: " << std::to_string(metadata.delete_list_offset) + << " tag_offset: " << std::to_string(metadata.tags_offset) + << " graph_offset: " << std::to_string(metadata.graph_offset) << std::endl; } else { - diskann::cout << "Save index in a single file currently only support _save_as_one_file_version = 1. " + std::cout << "Save index in a single file currently only support _save_as_one_file_version = 1. " "Not saving the index." << std::endl; } @@ -487,7 +496,7 @@ void Index::save(const char *filename, bool compact_before_save // _max_points. reposition_frozen_point_to_end(); - diskann::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl; + std::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl; } #ifdef EXEC_ENV_OLS @@ -647,6 +656,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui #endif if (!_load_from_one_file) { + std::cout << "DLVS should not load multiple files." << std::endl; // For DLVS Store, we will not support saving the index in multiple // files. #ifndef EXEC_ENV_OLS @@ -670,15 +680,18 @@ void Index::load(const char *filename, uint32_t num_threads, ui { if (_filtered_index) { - diskann::cout << "Single index file saving/loading support for filtered index is not yet " + std::cout << "Single index file saving/loading support for filtered index is not yet " "enabled. Not loading the index." << std::endl; } else { - uint64_t version; + std::cout << "Start loading index from one file." << std::endl; + uint64_t version = 0; #ifdef EXEC_ENV_OLS + std::cout << "Start Version Check." << std::endl; + std::vector readReqs; AlignedRead readReq; uint64_t buf[1]; @@ -687,11 +700,24 @@ void Index::load(const char *filename, uint32_t num_threads, ui readReq.offset = 0; readReq.len = sizeof(uint64_t); readReqs.push_back(readReq); + std::cout << "Load Version request is ready." << std::endl; + reader.read(readReqs, ctx); // synchronous - if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS) + std::cout << "Load Version processed." << std::endl; + + if ((*(ctx.m_pRequestsStatus.get()))[0] == IOContext::READ_SUCCESS) { version = buf[0]; + std::cout << "Load Version is " << std::to_string(version) << "." << std::endl; } + else + { + std::stringstream str; + str << "Could not read binary metadata from index file at offset: 0." << std::endl; + std::cout << str.str() << std::endl; + throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + #else std::ifstream reader(filename, std::ios::binary); reader.read((char *)&version, sizeof(uint64_t)); @@ -699,34 +725,42 @@ void Index::load(const char *filename, uint32_t num_threads, ui if (version == _load_from_one_file_version) { + std::cout << "Version Check passed, start loading meta data." << std::endl; SaveLoadMetaDataV1 metadata; #ifdef EXEC_ENV_OLS std::vector metadata_readReqs; AlignedRead metadata_readReq; - uint64_t metadata_buf[1]; + uint64_t metadata_buf[sizeof(SaveLoadMetaDataV1)]; metadata_readReq.buf = metadata_buf; metadata_readReq.offset = sizeof(uint64_t); metadata_readReq.len = sizeof(SaveLoadMetaDataV1); - metadata_readReq.push_back(readReq); + metadata_readReqs.push_back(metadata_readReq); reader.read(metadata_readReqs, ctx); // synchronous if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS) { memcpy((void *)&metadata, (void *)buf, sizeof(SaveLoadMetaDataV1)); } + + std::cout << "Metadata loaded. data_offset: " << std::to_string(metadata.data_offset) + << " delete_list_offset: " << std::to_string(metadata.delete_list_offset) + << " tag_offset: " << std::to_string(metadata.tags_offset) + << " graph_offset: " << std::to_string(metadata.graph_offset) + << std::endl; + #else reader.read((char *)&metadata, sizeof(SaveLoadMetaDataV1)); #endif // Load data #ifdef EXEC_ENV_OLS - load_data(reader, metadata.data_offset) + load_data(reader, metadata.data_offset); #else load_data(filename, metadata.data_offset); #endif // Load delete list when presents. - if (metadata.data_offset != metadata.delete_list_offset) + if (metadata.data_offset != metadata.delete_list_offset) { #ifdef EXEC_ENV_OLS load_delete_set(reader, metadata.delete_list_offset); @@ -752,12 +786,11 @@ void Index::load(const char *filename, uint32_t num_threads, ui } else { - diskann::cout << "load index from a single file currently only support _save_as_one_file_version = 1. " + std::cout << "load index from a single file currently only support _save_as_one_file_version = 1. " "Not loading the index." << std::endl; } } - return; } if (data_file_num_pts != graph_num_pts || (data_file_num_pts != tags_file_num_pts && _enable_tags)) @@ -866,13 +899,15 @@ size_t Index::get_graph_num_frozen_points(const std::string &gr template size_t Index::load_graph(AlignedFileReader &reader, size_t expected_num_points, size_t offset) { + auto res = _graph_store->load(reader, expected_num_points, offset); + #else template size_t Index::load_graph(std::string filename, size_t expected_num_points, size_t offset) { -#endif auto res = _graph_store->load(filename, expected_num_points, offset); +#endif _start = std::get<1>(res); _num_frozen_pts = std::get<2>(res); return std::get<0>(res); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index c9b2c0ebb..33867d4be 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1123,7 +1123,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons { uint64_t dumr, dumc; float *norm_val; - diskann::load_bin(files, norm_val, dumr, dumc); + diskann::load_bin(files, norm_file, norm_val, dumr, dumc); #else if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) { From 446727227912e86c970b74400af6ec806485af13 Mon Sep 17 00:00:00 2001 From: "REDMOND\\ninchen" Date: Tue, 16 Jan 2024 22:24:13 -0800 Subject: [PATCH 03/38] Fix Save One file bugs. --- include/abstract_data_store.h | 3 ++ include/in_mem_data_store.h | 1 + include/index_config.h | 2 +- src/in_mem_data_store.cpp | 11 ++-- src/in_mem_graph_store.cpp | 5 +- src/index.cpp | 95 ++++++++++++++++++----------------- src/utils.cpp | 49 +++++++++--------- 7 files changed, 90 insertions(+), 76 deletions(-) diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index 60eeb6c03..e0c9a99e8 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -9,6 +9,8 @@ #include "types.h" #include "windows_customizations.h" #include "distance.h" +#include "aligned_file_reader.h" + namespace diskann { @@ -22,6 +24,7 @@ template class AbstractDataStore // Return number of points returned virtual location_t load(const std::string &filename, size_t offset) = 0; + virtual location_t load(AlignedFileReader &reader, size_t offset) = 0; // Why does store take num_pts? Since store only has capacity, but we allow // resizing we can end up in a situation where the store has spare capacity. diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index b610bb2dd..6feb09199 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -25,6 +25,7 @@ template class InMemDataStore : public AbstractDataStore_save_as_one_file_version = load_from_one_file_version; + this->_load_from_one_file_version = load_from_one_file_version; return *this; } diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index 8e842d159..e168d96fa 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -42,12 +42,16 @@ template location_t InMemDataStore::load(const std::st return load_impl(filename, offset); } +template location_t InMemDataStore::load(AlignedFileReader &reader, size_t offset) +{ + return load_impl(reader, offset); +} + #ifdef EXEC_ENV_OLS template location_t InMemDataStore::load_impl(AlignedFileReader &reader, size_t offset) { size_t file_dim, file_num_points; - - diskann::get_bin_metadata(reader, file_num_points, file_dim); + diskann::get_bin_metadata(reader, file_num_points, file_dim, offset); if (file_dim != this->_dim) { @@ -63,7 +67,8 @@ template location_t InMemDataStore::load_impl(AlignedF { this->resize((location_t)file_num_points); } - copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _aligned_dim); + + copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _aligned_dim, offset); return (location_t)file_num_points; } diff --git a/src/in_mem_graph_store.cpp b/src/in_mem_graph_store.cpp index fae35ced0..5378f72c6 100644 --- a/src/in_mem_graph_store.cpp +++ b/src/in_mem_graph_store.cpp @@ -94,7 +94,8 @@ void InMemGraphStore::clear_graph() } #ifdef EXEC_ENV_OLS -std::tuple InMemGraphStore::load_impl(AlignedFileReader &reader, size_t expected_num_points, +std::tuple InMemGraphStore::load_impl(AlignedFileReader &reader, + size_t expected_num_points, size_t offset) { size_t expected_file_size; @@ -114,7 +115,7 @@ std::tuple InMemGraphStore::load_impl(AlignedFileRea << ", _max_observed_degree: " << _max_observed_degree << ", _start: " << start << ", file_frozen_pts: " << file_frozen_pts << std::endl; - diskann::cout << "Loading vamana graph from reader..." << std::flush; + diskann::cout << "Loading vamana graph from reader..." << std::endl << std::flush; // If user provides more points than max_points // resize the _graph to the larger size. diff --git a/src/index.cpp b/src/index.cpp index 7130b3d3d..80daf15d2 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -471,21 +471,24 @@ void Index::save(const char *filename, bool compact_before_save { metadata.graph_offset = static_cast(curr_pos); _graph_store->store(writer, _nd + _num_frozen_pts, _num_frozen_pts, _start, curr_pos); + } - // Save metadata. + // Save metadata. + { writer.seekp(meta_data_start, writer.beg); writer.write((char *)&metadata, sizeof(SaveLoadMetaDataV1)); - writer.close(); } - std::cout << "Metadata Saved. data_offset: " << std::to_string(metadata.data_offset) - << " delete_list_offset: " << std::to_string(metadata.delete_list_offset) - << " tag_offset: " << std::to_string(metadata.tags_offset) - << " graph_offset: " << std::to_string(metadata.graph_offset) << std::endl; + writer.close(); + + diskann::cout << "Metadata Saved. data_offset: " << std::to_string(metadata.data_offset) + << " delete_list_offset: " << std::to_string(metadata.delete_list_offset) + << " tag_offset: " << std::to_string(metadata.tags_offset) + << " graph_offset: " << std::to_string(metadata.graph_offset) << std::endl; } else { - std::cout << "Save index in a single file currently only support _save_as_one_file_version = 1. " + diskann::cout << "Save index in a single file currently only support _save_as_one_file_version = 1. " "Not saving the index." << std::endl; } @@ -496,7 +499,7 @@ void Index::save(const char *filename, bool compact_before_save // _max_points. reposition_frozen_point_to_end(); - std::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl; + diskann::cout << "Time taken for save: " << timer.elapsed() / 1000000.0 << "s." << std::endl; } #ifdef EXEC_ENV_OLS @@ -587,7 +590,6 @@ size_t Index::load_data(std::string filename, size_t offset) diskann::cerr << stream.str() << std::endl; throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } - if (file_num_points > _max_points + _num_frozen_pts) { // update and tag lock acquired in load() before calling load_data @@ -597,7 +599,7 @@ size_t Index::load_data(std::string filename, size_t offset) #ifdef EXEC_ENV_OLS // REFACTOR TODO: Must figure out how to support aligned reader in a clean // manner. - copy_aligned_data_from_file(reader, _data, file_num_points, file_dim, _data_store->get_aligned_dim(), offset); + _data_store->load(reader, offset); // offset == 0. #else _data_store->load(filename, offset); // offset == 0. #endif @@ -656,7 +658,6 @@ void Index::load(const char *filename, uint32_t num_threads, ui #endif if (!_load_from_one_file) { - std::cout << "DLVS should not load multiple files." << std::endl; // For DLVS Store, we will not support saving the index in multiple // files. #ifndef EXEC_ENV_OLS @@ -680,41 +681,33 @@ void Index::load(const char *filename, uint32_t num_threads, ui { if (_filtered_index) { - std::cout << "Single index file saving/loading support for filtered index is not yet " + diskann::cout << "Single index file saving/loading support for filtered index is not yet " "enabled. Not loading the index." << std::endl; } else { - std::cout << "Start loading index from one file." << std::endl; uint64_t version = 0; #ifdef EXEC_ENV_OLS - std::cout << "Start Version Check." << std::endl; - std::vector readReqs; AlignedRead readReq; - uint64_t buf[1]; + uint8_t buf[sizeof(uint64_t)] = {}; - readReq.buf = buf; + readReq.buf = (void *) buf; readReq.offset = 0; readReq.len = sizeof(uint64_t); readReqs.push_back(readReq); - std::cout << "Load Version request is ready." << std::endl; - reader.read(readReqs, ctx); // synchronous - std::cout << "Load Version processed." << std::endl; if ((*(ctx.m_pRequestsStatus.get()))[0] == IOContext::READ_SUCCESS) { - version = buf[0]; - std::cout << "Load Version is " << std::to_string(version) << "." << std::endl; + memcpy((void *)&version, (void *)buf, sizeof(uint64_t)); } else { std::stringstream str; str << "Could not read binary metadata from index file at offset: 0." << std::endl; - std::cout << str.str() << std::endl; throw diskann::ANNException(str.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } @@ -725,42 +718,41 @@ void Index::load(const char *filename, uint32_t num_threads, ui if (version == _load_from_one_file_version) { - std::cout << "Version Check passed, start loading meta data." << std::endl; SaveLoadMetaDataV1 metadata; - #ifdef EXEC_ENV_OLS std::vector metadata_readReqs; AlignedRead metadata_readReq; - uint64_t metadata_buf[sizeof(SaveLoadMetaDataV1)]; + uint8_t metadata_buf[sizeof(SaveLoadMetaDataV1)] = {}; - metadata_readReq.buf = metadata_buf; - metadata_readReq.offset = sizeof(uint64_t); + metadata_readReq.buf = (void*) metadata_buf; + metadata_readReq.offset = sizeof(version); metadata_readReq.len = sizeof(SaveLoadMetaDataV1); metadata_readReqs.push_back(metadata_readReq); reader.read(metadata_readReqs, ctx); // synchronous if ((*(ctx.m_pRequestsStatus))[0] == IOContext::READ_SUCCESS) { - memcpy((void *)&metadata, (void *)buf, sizeof(SaveLoadMetaDataV1)); + memcpy((void *)&metadata, (void *)metadata_buf, sizeof(SaveLoadMetaDataV1)); } - std::cout << "Metadata loaded. data_offset: " << std::to_string(metadata.data_offset) - << " delete_list_offset: " << std::to_string(metadata.delete_list_offset) - << " tag_offset: " << std::to_string(metadata.tags_offset) - << " graph_offset: " << std::to_string(metadata.graph_offset) - << std::endl; + diskann::cout << "Metadata loaded. data_offset: " << std::to_string(metadata.data_offset) + << " delete_list_offset: " << std::to_string(metadata.delete_list_offset) + << " tag_offset: " << std::to_string(metadata.tags_offset) + << " graph_offset: " << std::to_string(metadata.graph_offset) + << std::endl; #else reader.read((char *)&metadata, sizeof(SaveLoadMetaDataV1)); #endif // Load data #ifdef EXEC_ENV_OLS - load_data(reader, metadata.data_offset); + data_file_num_pts = load_data(reader, metadata.data_offset); + #else - load_data(filename, metadata.data_offset); + data_file_num_pts = load_data(filename, metadata.data_offset); #endif - // Load delete list when presents. - if (metadata.data_offset != metadata.delete_list_offset) + // Load delete list when presents. + if (metadata.delete_list_offset != metadata.tags_offset) { #ifdef EXEC_ENV_OLS load_delete_set(reader, metadata.delete_list_offset); @@ -768,27 +760,33 @@ void Index::load(const char *filename, uint32_t num_threads, ui load_delete_set(filename, metadata.delete_list_offset); #endif } + // Load tags when presents. - if (metadata.delete_list_offset != metadata.tags_offset) + if (metadata.tags_offset != metadata.graph_offset) { #ifdef EXEC_ENV_OLS - load_tags(reader, metadata.tags_offset); + tags_file_num_pts = load_tags(reader, metadata.tags_offset); #else - load_tags(filename, metadata.tags_offset); + tags_file_num_pts = load_tags(filename, metadata.tags_offset); #endif } // Load graph #ifdef EXEC_ENV_OLS - load_graph(reader, metadata.graph_offset); + + graph_num_pts = load_graph(reader, data_file_num_pts, metadata.graph_offset); #else - load_graph(filename, metadata.graph_offset); + graph_num_pts = load_graph(filename, data_file_num_pts, metadata.graph_offset); #endif } else { - std::cout << "load index from a single file currently only support _save_as_one_file_version = 1. " - "Not loading the index." - << std::endl; + std::stringstream stream; + stream << "load index from a single file currently only support _save_as_one_file_version = 1 and _save_as_one_file_version = 1. " + << "Not loading the index." + << std::endl; + diskann::cerr << stream.str() << std::endl; + + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } } } @@ -802,6 +800,8 @@ void Index::load(const char *filename, uint32_t num_threads, ui diskann::cerr << stream.str() << std::endl; throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } + + #ifndef EXEC_ENV_OLS if (file_exists(labels_file)) { @@ -849,6 +849,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui } } #endif + _nd = data_file_num_pts - _num_frozen_pts; _empty_slots.clear(); _empty_slots.reserve(_max_points); @@ -2013,7 +2014,7 @@ void Index::build(const std::string &data_file, const size_t nu this->build_filtered_index(data_file.c_str(), labels_file_to_use, points_to_load); } std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; - std::cout << "Indexing time: " << diff.count() << "\n"; + diskann::cout << "Indexing time: " << diff.count() << "\n"; } template diff --git a/src/utils.cpp b/src/utils.cpp index b675e656d..ab36c42a8 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -285,6 +285,7 @@ template void load_bin(AlignedFileReader &reader, T *&data, size_t { // Code assumes that the reader is already setup correctly. get_bin_metadata(reader, npts, ndim, offset); + data = new T[npts * ndim]; size_t data_size = npts * ndim * sizeof(T); @@ -333,7 +334,7 @@ void copy_aligned_data_from_file(AlignedFileReader &reader, T *&data, size_t &np { if (data == nullptr) { - diskann::cerr << "Memory was not allocated for " << data << " before calling the load function. Exiting..." + diskann::cout << "Memory was not allocated for " << data << " before calling the load function. Exiting..." << std::endl; throw diskann::ANNException("Null pointer passed to copy_aligned_data_from_file()", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -391,29 +392,31 @@ template void read_array(AlignedFileReader &reader, T *data, size_t if (data == nullptr) { throw diskann::ANNException("read_array requires an allocated buffer.", -1); - if (size * sizeof(T) > MAX_REQUEST_SIZE) - { - std::stringstream ss; - ss << "Cannot read more than " << MAX_REQUEST_SIZE - << " bytes. Current request size: " << std::to_string(size) << " sizeof(T): " << sizeof(T) << std::endl; - throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } - std::vector read_requests; - AlignedRead read_req; - read_req.buf = data; - read_req.len = size * sizeof(T); - read_req.offset = offset; - read_requests.push_back(read_req); - IOContext &ctx = reader.get_ctx(); - reader.read(read_requests, ctx); + } - if ((*(ctx.m_pRequestsStatus))[0] != IOContext::READ_SUCCESS) - { - std::stringstream ss; - ss << "Failed to read_array() of size: " << size * sizeof(T) << " at offset: " << offset << " from reader. " - << std::endl; - throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); - } + if (size * sizeof(T) > MAX_REQUEST_SIZE) + { + std::stringstream ss; + ss << "Cannot read more than " << MAX_REQUEST_SIZE + << " bytes. Current request size: " << std::to_string(size) << " sizeof(T): " << sizeof(T) << std::endl; + throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + std::vector read_requests; + AlignedRead read_req; + read_req.buf = data; + read_req.len = size * sizeof(T); + read_req.offset = offset; + read_requests.push_back(read_req); + IOContext &ctx = reader.get_ctx(); + reader.read(read_requests, ctx); + + if ((*(ctx.m_pRequestsStatus))[0] != IOContext::READ_SUCCESS) + { + std::stringstream ss; + ss << "Failed to read_array() of size: " << size * sizeof(T) << " at offset: " << offset << " from reader. " + << std::endl; + throw diskann::ANNException(ss.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } } From dbd702b499adf43b2bc1a06cb9de336864982b43 Mon Sep 17 00:00:00 2001 From: Li Tan Date: Wed, 14 Feb 2024 05:59:15 -0800 Subject: [PATCH 04/38] add a get_num_deleted_points method --- include/index.h | 1 + src/index.cpp | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/include/index.h b/include/index.h index fc42d6e0d..574887232 100644 --- a/include/index.h +++ b/include/index.h @@ -92,6 +92,7 @@ template clas // get some private variables DISKANN_DLLEXPORT size_t get_num_points(); DISKANN_DLLEXPORT size_t get_max_points(); + DISKANN_DLLEXPORT size_t get_num_deleted_points(); DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation, const std::vector &incoming_labels); diff --git a/src/index.cpp b/src/index.cpp index 80daf15d2..a5b251b5f 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2496,6 +2496,12 @@ template size_t Index size_t Index::get_num_deleted_points() +{ + std::shared_lock dl(_delete_lock); + return _delete_set->size(); +} + template void Index::generate_frozen_point() { if (_num_frozen_pts == 0) From 8e4d10d857119697af9ffca0feb53dc70b7d2b78 Mon Sep 17 00:00:00 2001 From: Li Tan Date: Fri, 16 Feb 2024 04:48:03 -0800 Subject: [PATCH 05/38] in-mem graph loading: skip reading neighbor list when out edge count is 0 --- src/in_mem_graph_store.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/in_mem_graph_store.cpp b/src/in_mem_graph_store.cpp index 5378f72c6..8d8e83cc4 100644 --- a/src/in_mem_graph_store.cpp +++ b/src/in_mem_graph_store.cpp @@ -135,9 +135,12 @@ std::tuple InMemGraphStore::load_impl(AlignedFileRea graph_offset += sizeof(uint32_t); std::vector tmp(k); tmp.reserve(k); - read_array(reader, tmp.data(), k, graph_offset); - graph_offset += k * sizeof(uint32_t); - cc += k; + if (k > 0) + { + read_array(reader, tmp.data(), k, graph_offset); + graph_offset += k * sizeof(uint32_t); + cc += k; + } _graph[nodes_read].swap(tmp); nodes_read++; if (nodes_read % 1000000 == 0) From e426e8e32bda1e7d5502aade5b5db197c246ee1f Mon Sep 17 00:00:00 2001 From: Huisheng Liu Date: Fri, 16 Feb 2024 04:48:03 -0800 Subject: [PATCH 06/38] add wait() method to AlignedFileReader --- include/aligned_file_reader.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/aligned_file_reader.h b/include/aligned_file_reader.h index f5e2af5c3..f39d5da39 100644 --- a/include/aligned_file_reader.h +++ b/include/aligned_file_reader.h @@ -117,4 +117,9 @@ class AlignedFileReader // process batch of aligned requests in parallel // NOTE :: blocking call virtual void read(std::vector &read_reqs, IOContext &ctx, bool async = false) = 0; + +#ifdef USE_BING_INFRA + // wait for completion of one request in a batch of requests + virtual void wait(IOContext &ctx, int &completedIndex) = 0; +#endif }; From 1de7ac41e4f7a8473349b4bf68f2580669d9faac Mon Sep 17 00:00:00 2001 From: NingyuanChen Date: Thu, 29 Feb 2024 01:22:11 -0800 Subject: [PATCH 07/38] Fix the wrong behavior when saving tags without calling any lazy_delete. (#519) * Fix the wrong behavior when saving tags without calling any lazy_delete. * Add option to force compact_data. --------- Co-authored-by: REDMOND\ninchen --- include/index.h | 2 +- src/index.cpp | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/index.h b/include/index.h index 574887232..60c218776 100644 --- a/include/index.h +++ b/include/index.h @@ -307,7 +307,7 @@ template clas // Renumber nodes, update tag and location maps and compact the // graph, mode = _consolidated_order in case of lazy deletion and // _compacted_order in case of eager deletion - DISKANN_DLLEXPORT void compact_data(); + DISKANN_DLLEXPORT void compact_data(bool forced = false); DISKANN_DLLEXPORT void compact_frozen_point(); // Remove deleted nodes from adjacency list of node loc diff --git a/src/index.cpp b/src/index.cpp index a5b251b5f..5e7bc60db 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -286,7 +286,7 @@ void Index::save(const char *filename, bool compact_before_save if (compact_before_save) { - compact_data(); + compact_data(true); compact_frozen_point(); } else @@ -2755,7 +2755,7 @@ template void Index void Index::compact_data() +template void Index::compact_data(bool forced) { if (!_dynamic_index) throw ANNException("Can not compact a non-dynamic index", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -2763,7 +2763,11 @@ template void Indexsize() > 0) From bc7568e7b42a842497ac9badd6a33bcc1190ecaf Mon Sep 17 00:00:00 2001 From: Li Tan Date: Wed, 6 Mar 2024 04:32:30 -0800 Subject: [PATCH 08/38] Remove unnecessary tag 0 check when insertion. Fix max_points during initialization (same a PR 523) --- src/index.cpp | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index 5e7bc60db..2fca21ad3 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -157,10 +157,10 @@ Index::Index(Metric m, const size_t dim, const size_t max_point .build(), IndexFactory::construct_datastore( DataStoreStrategy::MEMORY, - max_points + (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts), dim, m), + (max_points == 0? (size_t)1 : max_points) + (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts), dim, m), IndexFactory::construct_graphstore( GraphStoreStrategy::MEMORY, - max_points + (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts), + (max_points == 0? (size_t)1 : max_points) + (dynamic_index && num_frozen_pts == 0 ? (size_t)1 : num_frozen_pts), (size_t)((index_parameters == nullptr ? 0 : index_parameters->max_degree) * defaults::GRAPH_SLACK_FACTOR * 1.05))) { @@ -781,7 +781,7 @@ void Index::load(const char *filename, uint32_t num_threads, ui else { std::stringstream stream; - stream << "load index from a single file currently only support _save_as_one_file_version = 1 and _save_as_one_file_version = 1. " + stream << "load index from a single file currently only support _save_as_one_file_version = 1 and _load_as_one_file_version = 1. " << "Not loading the index." << std::endl; diskann::cerr << stream.str() << std::endl; @@ -3129,15 +3129,7 @@ int Index::insert_point(const T *point, const TagT tag) template int Index::insert_point(const T *point, const TagT tag, const std::vector &labels) { - assert(_has_built); - if (tag == static_cast(0)) - { - throw diskann::ANNException("Do not insert point with tag 0. That is " - "reserved for points hidden " - "from the user.", - -1, __FUNCSIG__, __FILE__, __LINE__); - } std::shared_lock shared_ul(_update_lock); std::unique_lock tl(_tag_lock); From 27af2ddda2786a0fe608279754712e70b6ad4d04 Mon Sep 17 00:00:00 2001 From: Huisheng Liu Date: Wed, 28 Feb 2024 12:47:34 -0800 Subject: [PATCH 09/38] replace callback with Wait() method --- src/pq_flash_index.cpp | 37 +++++-------------------------------- 1 file changed, 5 insertions(+), 32 deletions(-) diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 33867d4be..3ec70b163 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1140,37 +1140,10 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(const IOContext &ctx, size_t size, int &completedIndex) -{ - bool waitsRemaining = false; - long completeCount = ctx.m_completeCount; - do - { - for (int i = 0; i < size; i++) - { - auto ithStatus = (*ctx.m_pRequestsStatus)[i]; - if (ithStatus == IOContext::Status::READ_SUCCESS) - { - completedIndex = i; - return true; - } - else if (ithStatus == IOContext::Status::READ_WAIT) - { - waitsRemaining = true; - } - } - - // if we didn't find one in READ_SUCCESS, wait for one to complete. - if (waitsRemaining) - { - WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); - // this assumes the knowledge of the reader behavior (implicit - // contract). need better factoring? - } - } while (waitsRemaining); - - completedIndex = -1; - return false; +bool getNextCompletedRequest(std::shared_ptr &reader, + IOContext &ctx, int &completedIndex) { + reader->wait(ctx, completedIndex); + return completedIndex != -1; } #endif @@ -1476,7 +1449,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t long requestCount = static_cast(frontier_read_reqs.size()); // If we issued read requests and if a read is complete or there are // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(ctx, requestCount, completedIndex)) + while (requestCount > 0 && getNextCompletedRequest(reader, ctx, completedIndex)) { assert(completedIndex >= 0); auto &frontier_nhood = frontier_nhoods[completedIndex]; From 79f83c16068c05ce8243807ba3c286360bc349d9 Mon Sep 17 00:00:00 2001 From: Huisheng Liu Date: Wed, 28 Feb 2024 12:47:34 -0800 Subject: [PATCH 10/38] replace callback driven wait with new Wait() method --- src/pq_flash_index.cpp | 35 ++++------------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 33867d4be..17c919e5f 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1140,37 +1140,10 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(const IOContext &ctx, size_t size, int &completedIndex) +bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, int &completedIndex) { - bool waitsRemaining = false; - long completeCount = ctx.m_completeCount; - do - { - for (int i = 0; i < size; i++) - { - auto ithStatus = (*ctx.m_pRequestsStatus)[i]; - if (ithStatus == IOContext::Status::READ_SUCCESS) - { - completedIndex = i; - return true; - } - else if (ithStatus == IOContext::Status::READ_WAIT) - { - waitsRemaining = true; - } - } - - // if we didn't find one in READ_SUCCESS, wait for one to complete. - if (waitsRemaining) - { - WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); - // this assumes the knowledge of the reader behavior (implicit - // contract). need better factoring? - } - } while (waitsRemaining); - - completedIndex = -1; - return false; + reader->wait(ctx, completedIndex); + return completedIndex != -1; } #endif @@ -1476,7 +1449,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t long requestCount = static_cast(frontier_read_reqs.size()); // If we issued read requests and if a read is complete or there are // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(ctx, requestCount, completedIndex)) + while (requestCount > 0 && getNextCompletedRequest(reader, ctx, completedIndex)) { assert(completedIndex >= 0); auto &frontier_nhood = frontier_nhoods[completedIndex]; From abefd07a7e1072440f1b9014c21b71075c13ea13 Mon Sep 17 00:00:00 2001 From: Li Tan Date: Thu, 14 Mar 2024 05:28:58 -0700 Subject: [PATCH 11/38] allow build as static lib --- include/windows_customizations.h | 5 ++++- src/dll/CMakeLists.txt | 29 ++++++++++++++++++++++------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/include/windows_customizations.h b/include/windows_customizations.h index e6c58466a..16bd0070e 100644 --- a/include/windows_customizations.h +++ b/include/windows_customizations.h @@ -4,12 +4,15 @@ #pragma once #ifdef _WINDOWS - +#ifdef _NODLL +#define DISKANN_DLLEXPORT +#else #ifdef _WINDLL #define DISKANN_DLLEXPORT __declspec(dllexport) #else #define DISKANN_DLLEXPORT __declspec(dllimport) #endif +#endif #else #define DISKANN_DLLEXPORT diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index d00cfeb95..1e4a39264 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -1,20 +1,35 @@ #Copyright(c) Microsoft Corporation.All rights reserved. #Licensed under the MIT license. -add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp - ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp - ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp - ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) +if (DISKANN_USE_STATIC_LIB) + add_library(${PROJECT_NAME} STATIC dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp + ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp + ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) +else() + add_library(${PROJECT_NAME} SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp + ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp + ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) +endif() set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") -set(DISKANN_DLL_IMPLIB "${TARGET_DIR}/${PROJECT_NAME}.lib") +# set(DISKANN_DLL_IMPLIB "${TARGET_DIR}/${PROJECT_NAME}.lib") + +if (DISKANN_USE_STATIC_LIB) + target_compile_definitions(${PROJECT_NAME} PRIVATE _NODLL) +else() + target_compile_definitions(${PROJECT_NAME} PRIVATE _USRDLL _WINDLL) +endif() -target_compile_definitions(${PROJECT_NAME} PRIVATE _USRDLL _WINDLL) target_compile_options(${PROJECT_NAME} PRIVATE /GL) target_include_directories(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) -target_link_options(${PROJECT_NAME} PRIVATE /DLL /IMPLIB:${DISKANN_DLL_IMPLIB} /LTCG) +if (NOT DEFINED DISKANN_USE_STATIC_LIB) + target_link_options(${PROJECT_NAME} PRIVATE /DLL /IMPLIB:${DISKANN_DLL_IMPLIB} /LTCG) +endif() + target_link_libraries(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_LINK_LIBRARIES}) target_link_libraries(${PROJECT_NAME} PRIVATE synchronization.lib) From 987db1989f4c66054967067d4f7478aacbf90892 Mon Sep 17 00:00:00 2001 From: Li Tan Date: Thu, 14 Mar 2024 05:30:01 -0700 Subject: [PATCH 12/38] add a size check in is_in_set method to avoid assert error in debug build --- src/natural_number_set.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/natural_number_set.cpp b/src/natural_number_set.cpp index b36cb5298..d0796cd5c 100644 --- a/src/natural_number_set.cpp +++ b/src/natural_number_set.cpp @@ -62,7 +62,7 @@ template size_t natural_number_set::size() const template bool natural_number_set::is_in_set(T id) const { - return _values_bitset->test(id); + return _values_bitset->size() > id && _values_bitset->test(id); } // Instantiate used templates. From 9ca8ac9a0fd9e216f6afa4b3b588671457a4b670 Mon Sep 17 00:00:00 2001 From: Li Tan Date: Thu, 14 Mar 2024 05:30:37 -0700 Subject: [PATCH 13/38] DLVS only: allow update vector for tag and record deleted tags --- include/index.h | 22 +++++++++++++-- src/index.cpp | 73 +++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 84 insertions(+), 11 deletions(-) diff --git a/include/index.h b/include/index.h index 60c218776..5a0b75009 100644 --- a/include/index.h +++ b/include/index.h @@ -92,7 +92,11 @@ template clas // get some private variables DISKANN_DLLEXPORT size_t get_num_points(); DISKANN_DLLEXPORT size_t get_max_points(); - DISKANN_DLLEXPORT size_t get_num_deleted_points(); + +#ifdef EXEC_ENV_OLS + DISKANN_DLLEXPORT size_t get_num_tags(); // including both active and deleted tags. + DISKANN_DLLEXPORT size_t get_num_deleted_tags(); +#endif DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation, const std::vector &incoming_labels); @@ -151,10 +155,10 @@ template clas const size_t K, const uint32_t L, IndexType *indices, float *distances); - // Will fail if tag already in the index or if tag=0. + // Will fail if tag already in the index. DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag); - // Will fail if tag already in the index or if tag=0. + // Will fail if tag already in the index. DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag, const std::vector &label); // call this before issuing deletions to sets relevant flags @@ -196,6 +200,9 @@ template clas DISKANN_DLLEXPORT void count_nodes_at_bfs_levels(); + // Increase the max points to the new value only if it's higher. + DISKANN_DLLEXPORT void increase_size(size_t new_max_points); + // This variable MUST be updated if the number of entries in the metadata // change. DISKANN_DLLEXPORT static const int METADATA_ROWS = 5; @@ -431,6 +438,15 @@ template clas // slots to _empty_slots. natural_number_set _empty_slots; std::unique_ptr> _delete_set; +#ifdef EXEC_ENV_OLS + // Set of tags that have been deleted. + // This is to differentiate a tag that has been deleted or never existed in a R/W index instance + // for checking searched results from prior R/O index instnaces have been deleted later in the R/W index + // (if this tag exists in the _deleted_tags). + // When the R/W index is saved, deletes will be consolidated. And when loaded back as a R/O instance, + // which will always contain active tags only, hence this set doesn't need to be saved. + std::unique_ptr> _deleted_tags; +#endif bool _data_compacted = true; // true if data has been compacted bool _is_saved = false; // Checking if the index is already saved. diff --git a/src/index.cpp b/src/index.cpp index 2fca21ad3..e199d7ea5 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -45,6 +45,9 @@ Index::Index(const IndexConfig &index_config, std::unique_ptr), _conc_consolidate(index_config.concurrent_consolidate) { +#ifdef EXEC_ENV_OLS + _deleted_tags = std::make_unique>(); +#endif if (_dynamic_index && !_enable_tags) { throw ANNException("ERROR: Dynamic Indexing must have tags enabled.", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -938,8 +941,21 @@ template int Index std::shared_lock lock(_tag_lock); if (_tag_to_location.find(tag) == _tag_to_location.end()) { +#ifdef EXEC_ENV_OLS + if (_deleted_tags->find(tag) != _deleted_tags->end()) + { + diskann::cout << "Tag " << tag << " has been deleted" << std::endl; + return 1; + } + else + { + diskann::cout << "Tag " << tag << " has not existed" << std::endl; + return -1; + } +#else diskann::cout << "Tag " << tag << " does not exist" << std::endl; return -1; +#endif } location_t location = _tag_to_location[tag]; @@ -2496,11 +2512,19 @@ template size_t Index size_t Index::get_num_deleted_points() +#ifdef EXEC_ENV_OLS +template size_t Index::get_num_tags() { - std::shared_lock dl(_delete_lock); - return _delete_set->size(); + std::shared_lock dl(_tag_lock); + return _tag_to_location.size() + _deleted_tags->size(); +} + +template size_t Index::get_num_deleted_tags() +{ + std::shared_lock dl(_tag_lock); + return _deleted_tags->size(); } +#endif template void Index::generate_frozen_point() { @@ -2674,7 +2698,7 @@ consolidation_report Index::consolidate_deletes(const IndexWrit return consolidation_report(diskann::consolidation_report::status_code::LOCK_FAIL, 0, 0, 0, 0, 0, 0, 0); } - diskann::cout << "Starting consolidate_deletes... "; + diskann::cout << "Starting consolidate_deletes... " << std::endl; std::unique_ptr> old_delete_set(new tsl::robin_set); { @@ -3084,6 +3108,18 @@ template void Index(stop - start).count() << "s" << std::endl; } + +template void Index::increase_size(size_t new_max_points) +{ + std::shared_lock shared_ul(_update_lock); + std::unique_lock tl(_tag_lock); + + if (new_max_points > _max_points) + { + resize(new_max_points); + } +} + template int Index::_insert_point(const DataType &point, const TagType tag) { @@ -3156,7 +3192,7 @@ int Index::insert_point(const T *point, const TagT tag, const s if (_frozen_pts_used >= _num_frozen_pts) { throw ANNException( - "Error: For dynamic filtered index, the number of frozen points should be atleast equal " + "Error: For dynamic filtered index, the number of frozen points should be at least equal " "to number of unique labels.", -1); } @@ -3209,21 +3245,39 @@ int Index::insert_point(const T *point, const TagT tag, const s return -1; #endif } // cant insert as active pts >= max_pts +#ifndef EXEC_ENV_OLS dl.unlock(); +#endif // Insert tag and mapping to location if (_enable_tags) { - // if tags are enabled and tag is already inserted. so we can't reuse that tag. + // if tags are enabled and tag is already inserted. if (_tag_to_location.find(tag) != _tag_to_location.end()) { - release_location(location); +#ifdef EXEC_ENV_OLS + // Allow update the point for the existing tag, delete the existing location first. + const auto location_to_delete = _tag_to_location[tag]; + _delete_set->insert(location_to_delete); + _location_to_tag.erase(location_to_delete); +#else + // we can't reuse that tag. return -1; +#endif } - +#ifdef EXEC_ENV_OLS + if (_deleted_tags->find(tag) != _deleted_tags->end()) + { + // If the tag was deleted, reactivate it. + _deleted_tags->erase(tag); + } +#endif _tag_to_location[tag] = location; _location_to_tag.set(location, tag); } +#ifdef EXEC_ENV_OLS + dl.unlock(); +#endif tl.unlock(); _data_store->set_vector(location, point); // update datastore @@ -3318,6 +3372,9 @@ template int Index _delete_set->insert(location); _location_to_tag.erase(location); _tag_to_location.erase(tag); +#ifdef EXEC_ENV_OLS + _deleted_tags->insert(tag); +#endif return 0; } From b20d6688abd9e0a91058de2e1613850fff73d7ce Mon Sep 17 00:00:00 2001 From: Li Tan Date: Thu, 14 Mar 2024 10:34:05 -0700 Subject: [PATCH 14/38] fix one uncommented line in CMakeLists.txt --- src/dll/CMakeLists.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 1e4a39264..1cccf9021 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -15,11 +15,10 @@ endif() set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") -# set(DISKANN_DLL_IMPLIB "${TARGET_DIR}/${PROJECT_NAME}.lib") - if (DISKANN_USE_STATIC_LIB) target_compile_definitions(${PROJECT_NAME} PRIVATE _NODLL) else() + set(DISKANN_DLL_IMPLIB "${TARGET_DIR}/${PROJECT_NAME}.lib") target_compile_definitions(${PROJECT_NAME} PRIVATE _USRDLL _WINDLL) endif() From 270dfd8f2a34b1717635be50c93c67ad32179698 Mon Sep 17 00:00:00 2001 From: Li Tan Date: Fri, 15 Mar 2024 09:50:34 -0700 Subject: [PATCH 15/38] remove dllmain.cpp from building --- src/dll/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 1cccf9021..d10fc951a 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -2,7 +2,7 @@ #Licensed under the MIT license. if (DISKANN_USE_STATIC_LIB) - add_library(${PROJECT_NAME} STATIC dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp + add_library(${PROJECT_NAME} STATIC ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) From 502eb04af6ceb7a8dabadbd33962131ba868155e Mon Sep 17 00:00:00 2001 From: Li Tan Date: Fri, 15 Mar 2024 09:50:59 -0700 Subject: [PATCH 16/38] enable capacity expanding --- include/index.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/index.h b/include/index.h index 5a0b75009..4a9cbf999 100644 --- a/include/index.h +++ b/include/index.h @@ -23,7 +23,11 @@ #include "abstract_index.h" #define OVERHEAD_FACTOR 1.1 +#ifdef EXEC_ENV_OLS +#define EXPAND_IF_FULL 1 +#else #define EXPAND_IF_FULL 0 +#endif #define DEFAULT_MAXC 750 namespace diskann From facfc281062c7a31db47ad40ecb8d90b66d99da3 Mon Sep 17 00:00:00 2001 From: Huisheng Liu Date: Sun, 24 Mar 2024 18:43:47 -0700 Subject: [PATCH 17/38] wait on completeCount if callback is used --- src/pq_flash_index.cpp | 44 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 4 deletions(-) diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 17c919e5f..cb6239f9c 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1140,10 +1140,46 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, int &completedIndex) +bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, size_t size, + int &completedIndex) { - reader->wait(ctx, completedIndex); - return completedIndex != -1; + if ((*ctx.m_pRequests)[0].m_callback) + { + bool waitsRemaining = false; + long completeCount = ctx.m_completeCount; + do + { + for (int i = 0; i < size; i++) + { + auto ithStatus = (*ctx.m_pRequestsStatus)[i]; + if (ithStatus == IOContext::Status::READ_SUCCESS) + { + completedIndex = i; + return true; + } + else if (ithStatus == IOContext::Status::READ_WAIT) + { + waitsRemaining = true; + } + } + + // if we didn't find one in READ_SUCCESS, wait for one to complete. + if (waitsRemaining) + { + WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); + // this assumes the knowledge of the reader behavior (implicit + // contract). need better factoring? + } + } while (waitsRemaining); + + completedIndex = -1; + return false; + } + else + { + reader->wait(ctx, completedIndex); + return completedIndex != -1; + } } #endif @@ -1449,7 +1485,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t long requestCount = static_cast(frontier_read_reqs.size()); // If we issued read requests and if a read is complete or there are // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(reader, ctx, completedIndex)) + while (requestCount > 0 && getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) { assert(completedIndex >= 0); auto &frontier_nhood = frontier_nhoods[completedIndex]; From 9c8e88d2da7554a6709571da3f5d38b7abc2c090 Mon Sep 17 00:00:00 2001 From: Renan S Date: Thu, 25 Apr 2024 13:39:25 -0700 Subject: [PATCH 18/38] merging multifilter for bann (#543) --- include/abstract_scratch.h | 35 ++ include/index.h | 2 + include/neighbor.h | 5 + include/pq.h | 50 +-- include/pq_common.h | 30 ++ include/pq_flash_index.h | 18 +- include/pq_scratch.h | 22 ++ include/scratch.h | 28 +- src/index.cpp | 2 +- src/pq_flash_index.cpp | 640 ++++++++++++++++++++++++------------- src/scratch.cpp | 55 +++- 11 files changed, 593 insertions(+), 294 deletions(-) create mode 100644 include/abstract_scratch.h create mode 100644 include/pq_common.h create mode 100644 include/pq_scratch.h diff --git a/include/abstract_scratch.h b/include/abstract_scratch.h new file mode 100644 index 000000000..b42a836f6 --- /dev/null +++ b/include/abstract_scratch.h @@ -0,0 +1,35 @@ +#pragma once +namespace diskann +{ + +template class PQScratch; + +// By somewhat more than a coincidence, it seems that both InMemQueryScratch +// and SSDQueryScratch have the aligned query and PQScratch objects. So we +// can put them in a neat hierarchy and keep PQScratch as a standalone class. +template class AbstractScratch +{ + public: + AbstractScratch() = default; + // This class does not take any responsibilty for memory management of + // its members. It is the responsibility of the derived classes to do so. + virtual ~AbstractScratch() = default; + + // Scratch objects should not be copied + AbstractScratch(const AbstractScratch &) = delete; + AbstractScratch &operator=(const AbstractScratch &) = delete; + + data_t *aligned_query_T() + { + return _aligned_query_T; + } + PQScratch *pq_scratch() + { + return _pq_scratch; + } + + protected: + data_t *_aligned_query_T = nullptr; + PQScratch *_pq_scratch = nullptr; +}; +} // namespace diskann diff --git a/include/index.h b/include/index.h index 4a9cbf999..fd5db1488 100644 --- a/include/index.h +++ b/include/index.h @@ -21,6 +21,8 @@ #include "in_mem_data_store.h" #include "in_mem_graph_store.h" #include "abstract_index.h" +#include "pq_scratch.h" +#include "pq.h" #define OVERHEAD_FACTOR 1.1 #ifdef EXEC_ENV_OLS diff --git a/include/neighbor.h b/include/neighbor.h index d7c0c25ed..7e6b58a65 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -109,6 +109,11 @@ class NeighborPriorityQueue return _cur < _size; } + void sort() + { + std::sort(_data.begin(), _data.begin() + _size); + } + size_t size() const { return _size; diff --git a/include/pq.h b/include/pq.h index acfa1b30a..3e6119f22 100644 --- a/include/pq.h +++ b/include/pq.h @@ -4,13 +4,7 @@ #pragma once #include "utils.h" - -#define NUM_PQ_BITS 8 -#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS) -#define MAX_OPQ_ITERS 20 -#define NUM_KMEANS_REPS_PQ 12 -#define MAX_PQ_TRAINING_SET_SIZE 256000 -#define MAX_PQ_CHUNKS 512 +#include "pq_common.h" namespace diskann { @@ -53,40 +47,6 @@ class FixedChunkPQTable void populate_chunk_inner_products(const float *query_vec, float *dist_vec); }; -template struct PQScratch -{ - float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS] - float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE - uint8_t *aligned_pq_coord_scratch = nullptr; // MUST BE AT LEAST [N_CHUNKS * MAX_DEGREE] - float *rotated_query = nullptr; - float *aligned_query_float = nullptr; - - PQScratch(size_t graph_degree, size_t aligned_dim) - { - diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, - (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256); - diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), - 256); - diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256); - diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float)); - diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float)); - - memset(aligned_query_float, 0, aligned_dim * sizeof(float)); - memset(rotated_query, 0, aligned_dim * sizeof(float)); - } - - void set(size_t dim, T *query, const float norm = 1.0f) - { - for (size_t d = 0; d < dim; ++d) - { - if (norm != 1.0f) - rotated_query[d] = aligned_query_float[d] = static_cast(query[d]) / norm; - else - rotated_query[d] = aligned_query_float[d] = static_cast(query[d]); - } - } -}; - void aggregate_coords(const std::vector &ids, const uint8_t *all_coords, const uint64_t ndims, uint8_t *out); void pq_dist_lookup(const uint8_t *pq_ids, const size_t n_pts, const size_t pq_nchunks, const float *pq_dists, @@ -107,11 +67,19 @@ DISKANN_DLLEXPORT int generate_opq_pivots(const float *train_data, size_t num_tr unsigned num_pq_chunks, std::string opq_pivots_path, bool make_zero_mean = false); +DISKANN_DLLEXPORT int generate_pq_pivots_simplified(const float *train_data, size_t num_train, size_t dim, + size_t num_pq_chunks, std::vector &pivot_data_vector); + template int generate_pq_data_from_pivots(const std::string &data_file, unsigned num_centers, unsigned num_pq_chunks, const std::string &pq_pivots_path, const std::string &pq_compressed_vectors_path, bool use_opq = false); +DISKANN_DLLEXPORT int generate_pq_data_from_pivots_simplified(const float *data, const size_t num, + const float *pivot_data, const size_t pivots_num, + const size_t dim, const size_t num_pq_chunks, + std::vector &pq); + template void generate_disk_quantized_data(const std::string &data_file_to_use, const std::string &disk_pq_pivots_path, const std::string &disk_pq_compressed_vectors_path, diff --git a/include/pq_common.h b/include/pq_common.h new file mode 100644 index 000000000..c6a3a5739 --- /dev/null +++ b/include/pq_common.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + +#define NUM_PQ_BITS 8 +#define NUM_PQ_CENTROIDS (1 << NUM_PQ_BITS) +#define MAX_OPQ_ITERS 20 +#define NUM_KMEANS_REPS_PQ 12 +#define MAX_PQ_TRAINING_SET_SIZE 256000 +#define MAX_PQ_CHUNKS 512 + +namespace diskann +{ +inline std::string get_quantized_vectors_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks) +{ + return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_compressed.bin"; +} + +inline std::string get_pivot_data_filename(const std::string &prefix, bool use_opq, uint32_t num_chunks) +{ + return prefix + (use_opq ? "_opq" : "pq") + std::to_string(num_chunks) + "_pivots.bin"; +} + +inline std::string get_rotation_matrix_suffix(const std::string &pivot_data_filename) +{ + return pivot_data_filename + "_rotation_matrix.bin"; +} + +} // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 49a504a07..b1ec6db87 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -2,6 +2,7 @@ // Licensed under the MIT license. #pragma once +#include #include "common_includes.h" #include "aligned_file_reader.h" @@ -35,6 +36,15 @@ template class PQFlashIndex DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix); #endif +#ifdef EXEC_ENV_OLS + DISKANN_DLLEXPORT void load_labels(MemoryMappedFiles &files, const std::string &disk_index_file); +#else + DISKANN_DLLEXPORT void load_labels(const std::string& disk_index_filepath); +#endif + DISKANN_DLLEXPORT void load_label_medoid_map( + const std::string &labels_to_medoids_filepath, std::istream &medoid_stream); + DISKANN_DLLEXPORT void load_dummy_map(const std::string& dummy_map_filepath, std::istream &dummy_map_stream); + #ifdef EXEC_ENV_OLS DISKANN_DLLEXPORT int load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_filepath, const char *pivots_filepath, @@ -77,7 +87,7 @@ template class PQFlashIndex DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const std::vector &filter_labels, const uint32_t io_limit, const bool use_reorder_data = false, QueryStats *stats = nullptr); @@ -116,9 +126,11 @@ template class PQFlashIndex private: DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); - std::unordered_map load_label_map(std::basic_istream &infile); + DISKANN_DLLEXPORT inline bool point_has_any_label(uint32_t point_id, const std::vector &label_ids); + void load_label_map(std::basic_istream &map_reader, + std::unordered_map &string_to_int_map); DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); - DISKANN_DLLEXPORT void get_label_file_metadata(std::basic_istream &infile, uint32_t &num_pts, + DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); diff --git a/include/pq_scratch.h b/include/pq_scratch.h new file mode 100644 index 000000000..2aa90dbe1 --- /dev/null +++ b/include/pq_scratch.h @@ -0,0 +1,22 @@ +#pragma once +#include +#include "pq_common.h" +#include "utils.h" + +namespace diskann +{ + +template class PQScratch +{ + public: + float *aligned_pqtable_dist_scratch = nullptr; // MUST BE AT LEAST [256 * NCHUNKS] + float *aligned_dist_scratch = nullptr; // MUST BE AT LEAST diskann MAX_DEGREE + uint8_t *aligned_pq_coord_scratch = nullptr; // AT LEAST [N_CHUNKS * MAX_DEGREE] + float *rotated_query = nullptr; + float *aligned_query_float = nullptr; + + PQScratch(size_t graph_degree, size_t aligned_dim); + void initialize(size_t dim, const T *query, const float norm = 1.0f); +}; + +} // namespace diskann \ No newline at end of file diff --git a/include/scratch.h b/include/scratch.h index f685b36d9..2f43e3365 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -12,22 +12,22 @@ #include "tsl/sparse_map.h" #include "aligned_file_reader.h" -#include "concurrent_queue.h" -#include "defaults.h" +#include "abstract_scratch.h" #include "neighbor.h" -#include "pq.h" +#include "defaults.h" +#include "concurrent_queue.h" namespace diskann { +template class PQScratch; // -// Scratch space for in-memory index based search +// AbstractScratch space for in-memory index based search // -template class InMemQueryScratch +template class InMemQueryScratch : public AbstractScratch { public: ~InMemQueryScratch(); - // REFACTOR TODO: move all parameters to a new class. InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim, size_t alignment_factor, bool init_pq_scratch = false); void resize_for_new_L(uint32_t new_search_l); @@ -47,11 +47,11 @@ template class InMemQueryScratch } inline T *aligned_query() { - return _aligned_query; + return this->_aligned_query_T; } inline PQScratch *pq_scratch() { - return _pq_scratch; + return this->_pq_scratch; } inline std::vector &pool() { @@ -99,10 +99,6 @@ template class InMemQueryScratch uint32_t _R; uint32_t _maxc; - T *_aligned_query = nullptr; - - PQScratch *_pq_scratch = nullptr; - // _pool stores all neighbors explored from best_L_nodes. // Usually around L+R, but could be higher. // Initialized to 3L+R for some slack, expands as needed. @@ -139,10 +135,10 @@ template class InMemQueryScratch }; // -// Scratch space for SSD index based search +// AbstractScratch space for SSD index based search // -template class SSDQueryScratch +template class SSDQueryScratch : public AbstractScratch { public: T *coord_scratch = nullptr; // MUST BE AT LEAST [sizeof(T) * data_dim] @@ -150,10 +146,6 @@ template class SSDQueryScratch char *sector_scratch = nullptr; // MUST BE AT LEAST [MAX_N_SECTOR_READS * SECTOR_LEN] size_t sector_idx = 0; // index of next [SECTOR_LEN] scratch to use - T *aligned_query_T = nullptr; - - PQScratch *_pq_scratch; - tsl::robin_set visited; NeighborPriorityQueue retset; std::vector full_retset; diff --git a/src/index.cpp b/src/index.cpp index e199d7ea5..6d61aa45e 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1063,7 +1063,7 @@ std::pair Index::iterate_to_fixed_point( { query_float[d] = (float)aligned_query[d]; } - pq_query_scratch->set(_dim, aligned_query); + pq_query_scratch->initialize(_dim, aligned_query); // center the query and rotate if we have a rotation matrix _pq_table.preprocess_query(query_rotated); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index cb6239f9c..d5fbbff20 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -4,6 +4,8 @@ #include "common_includes.h" #include "timer.h" +#include "pq.h" +#include "pq_scratch.h" #include "pq_flash_index.h" #include "cosine_similarity.h" @@ -25,19 +27,20 @@ namespace diskann { - template PQFlashIndex::PQFlashIndex(std::shared_ptr &fileReader, diskann::Metric m) : reader(fileReader), metric(m), _thread_data(nullptr) { + diskann::Metric metric_to_invoke = m; if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { if (std::is_floating_point::value) { - diskann::cout << "Cosine metric chosen for (normalized) float data." - "Changing distance to L2 to boost accuracy." + diskann::cout << "Since data is floating point, we assume that it has been appropriately pre-processed " + "(normalization for cosine, and convert-to-l2 by adding extra dimension for MIPS). So we " + "shall invoke an l2 distance function." << std::endl; - metric = diskann::Metric::L2; + metric_to_invoke = diskann::Metric::L2; } else { @@ -47,8 +50,8 @@ PQFlashIndex::PQFlashIndex(std::shared_ptr &fileRe } } - this->_dist_cmp.reset(diskann::get_distance_function(metric)); - this->_dist_cmp_float.reset(diskann::get_distance_function(metric)); + this->_dist_cmp.reset(diskann::get_distance_function(metric_to_invoke)); + this->_dist_cmp_float.reset(diskann::get_distance_function(metric_to_invoke)); } template PQFlashIndex::~PQFlashIndex() @@ -567,9 +570,8 @@ void PQFlashIndex::generate_random_labels(std::vector &labels } template -std::unordered_map PQFlashIndex::load_label_map(std::basic_istream &map_reader) +void PQFlashIndex::load_label_map(std::basic_istream &map_reader, std::unordered_map& string_to_int_map) { - std::unordered_map string_to_int_mp; std::string line, token; LabelT token_as_num; std::string label_str; @@ -580,9 +582,8 @@ std::unordered_map PQFlashIndex::load_label_map( label_str = token; getline(iss, token, '\t'); token_as_num = (LabelT)std::stoul(token); - string_to_int_mp[label_str] = token_as_num; + string_to_int_map[label_str] = token_as_num; } - return string_to_int_mp; } template @@ -610,28 +611,47 @@ void PQFlashIndex::reset_stream_for_reading(std::basic_istream } template -void PQFlashIndex::get_label_file_metadata(std::basic_istream &infile, uint32_t &num_pts, +void PQFlashIndex::get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels) { - std::string line, token; num_pts = 0; num_total_labels = 0; - while (std::getline(infile, line)) + size_t file_size = fileContent.length(); + + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - while (getline(iss, token, ',')) + next_pos = fileContent.find('\n', cur_pos); + if (next_pos == std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + break; + } + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) + { + next_lbl_pos = fileContent.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label + { + next_lbl_pos = next_pos; + } + num_total_labels++; + + lbl_pos = next_lbl_pos + 1; } + + cur_pos = next_pos + 1; + num_pts++; } diskann::cout << "Labels file metadata: num_points: " << num_pts << ", #total_labels: " << num_total_labels << std::endl; - reset_stream_for_reading(infile); } template @@ -651,47 +671,103 @@ inline bool PQFlashIndex::point_has_label(uint32_t point_id, LabelT l return ret_val; } +template +bool PQFlashIndex::point_has_any_label(uint32_t point_id, const std::vector &label_ids) +{ + uint32_t start_vec = _pts_to_label_offsets[point_id]; + uint32_t num_lbls = _pts_to_label_counts[start_vec]; + bool ret_val = false; + for (auto &cur_lbl : label_ids) + { + if (point_has_label(point_id, cur_lbl)) + { + ret_val = true; + break; + } + } + return ret_val; +} + + template void PQFlashIndex::parse_label_file(std::basic_istream &infile, size_t &num_points_labels) { - std::string line, token; + infile.seekg(0, std::ios::end); + size_t file_size = infile.tellg(); + + std::string buffer(file_size, ' '); + + infile.seekg(0, std::ios::beg); + infile.read(&buffer[0], file_size); + + std::string line; uint32_t line_cnt = 0; uint32_t num_pts_in_label_file; uint32_t num_total_labels; - get_label_file_metadata(infile, num_pts_in_label_file, num_total_labels); + get_label_file_metadata(buffer, num_pts_in_label_file, num_total_labels); _pts_to_label_offsets = new uint32_t[num_pts_in_label_file]; _pts_to_label_counts = new uint32_t[num_pts_in_label_file]; _pts_to_labels = new LabelT[num_total_labels]; uint32_t labels_seen_so_far = 0; - while (std::getline(infile, line)) + std::string label_str; + size_t cur_pos = 0; + size_t next_pos = 0; + while (cur_pos < file_size && cur_pos != std::string::npos) { - std::istringstream iss(line); - std::vector lbls(0); + next_pos = buffer.find('\n', cur_pos); + if (next_pos == std::string::npos) + { + break; + } _pts_to_label_offsets[line_cnt] = labels_seen_so_far; uint32_t &num_lbls_in_cur_pt = _pts_to_label_counts[line_cnt]; num_lbls_in_cur_pt = 0; - getline(iss, token, '\t'); - std::istringstream new_iss(token); - while (getline(new_iss, token, ',')) + + size_t lbl_pos = cur_pos; + size_t next_lbl_pos = 0; + while (lbl_pos < next_pos && lbl_pos != std::string::npos) { - token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); - token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); - LabelT token_as_num = (LabelT)std::stoul(token); + next_lbl_pos = buffer.find(',', lbl_pos); + if (next_lbl_pos == std::string::npos) // the last label in the whole file + { + next_lbl_pos = next_pos; + } + + if (next_lbl_pos > next_pos) // the last label in one line, just read to the end + { + next_lbl_pos = next_pos; + } + + label_str.assign(buffer.c_str() + lbl_pos, next_lbl_pos - lbl_pos); + if (label_str[label_str.length() - 1] == '\t') // '\t' won't exist in label file? + { + label_str.erase(label_str.length() - 1); + } + + LabelT token_as_num = (LabelT)std::stoul(label_str); _pts_to_labels[labels_seen_so_far++] = (LabelT)token_as_num; num_lbls_in_cur_pt++; + + // move to next label + lbl_pos = next_lbl_pos + 1; } + // move to next line + cur_pos = next_pos + 1; + if (num_lbls_in_cur_pt == 0) { diskann::cout << "No label found for point " << line_cnt << std::endl; exit(-1); } + line_cnt++; } + num_points_labels = line_cnt; reset_stream_for_reading(infile); } @@ -702,80 +778,85 @@ template void PQFlashIndex::set_univers _universal_filter_label = label; } -#ifdef EXEC_ENV_OLS template -int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) +void PQFlashIndex::load_label_medoid_map(const std::string& labels_to_medoids_filepath, std::istream& medoid_stream) { -#else -template int PQFlashIndex::load(uint32_t num_threads, const char *index_prefix) -{ -#endif - std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; - std::string pq_compressed_vectors = std::string(index_prefix) + "_pq_compressed.bin"; - std::string _disk_index_file = std::string(index_prefix) + "_disk.index"; -#ifdef EXEC_ENV_OLS - return load_from_separate_paths(files, num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); -#else - return load_from_separate_paths(num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), - pq_compressed_vectors.c_str()); -#endif -} + std::string line, token; -#ifdef EXEC_ENV_OLS -template -int PQFlashIndex::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, - const char *index_filepath, const char *pivots_filepath, - const char *compressed_filepath) -{ -#else + _filter_to_medoid_ids.clear(); + try + { + while (std::getline(medoid_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + std::vector medoids; + LabelT label; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + label = (LabelT)std::stoul(token); + else + medoids.push_back((uint32_t)stoul(token)); + cnt++; + } + _filter_to_medoid_ids[label].swap(medoids); + } + } + catch (std::system_error &e) + { + throw FileException(labels_to_medoids_filepath, e, __FUNCSIG__, __FILE__, __LINE__); + } +} template -int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, const char *index_filepath, - const char *pivots_filepath, const char *compressed_filepath) +void PQFlashIndex::load_dummy_map(const std::string &dummy_map_filepath, std::istream &dummy_map_stream) { -#endif - std::string pq_table_bin = pivots_filepath; - std::string pq_compressed_vectors = compressed_filepath; - std::string _disk_index_file = index_filepath; - std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; - std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; - - std::string labels_file = std ::string(_disk_index_file) + "_labels.txt"; - std::string labels_to_medoids = std ::string(_disk_index_file) + "_labels_to_medoids.txt"; - std::string dummy_map_file = std ::string(_disk_index_file) + "_dummy_map.txt"; - std::string labels_map_file = std ::string(_disk_index_file) + "_labels_map.txt"; - size_t num_pts_in_label_file = 0; + std::string line, token; - size_t pq_file_dim, pq_file_num_centroids; -#ifdef EXEC_ENV_OLS - get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); -#else - get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); -#endif + try + { + while (std::getline(dummy_map_stream, line)) + { + std::istringstream iss(line); + uint32_t cnt = 0; + uint32_t dummy_id; + uint32_t real_id; + while (std::getline(iss, token, ',')) + { + if (cnt == 0) + dummy_id = (uint32_t)stoul(token); + else + real_id = (uint32_t)stoul(token); + cnt++; + } + _dummy_pts.insert(dummy_id); + _has_dummy_pts.insert(real_id); + _dummy_to_real_map[dummy_id] = real_id; - this->_disk_index_file = _disk_index_file; + if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) + _real_to_dummy_map[real_id] = std::vector(); - if (pq_file_num_centroids != 256) + _real_to_dummy_map[real_id].emplace_back(dummy_id); + } + } + catch (std::system_error &e) { - diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." << std::endl; - return -1; + throw FileException (dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__); } - - this->_data_dim = pq_file_dim; - // will change later if we use PQ on disk or if we are using - // inner product without PQ - this->_disk_bytes_per_point = this->_data_dim * sizeof(T); - this->_aligned_dim = ROUND_UP(pq_file_dim, 8); - - size_t npts_u64, nchunks_u64; +} #ifdef EXEC_ENV_OLS - diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +template +void PQFlashIndex::load_labels(MemoryMappedFiles &files, const std::string &disk_index_file) #else - diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +template void PQFlashIndex::load_labels(const std::string &disk_index_file) #endif +{ + std::string labels_file = _disk_index_file + "_labels.txt"; + std::string labels_to_medoids = _disk_index_file + "_labels_to_medoids.txt"; + std::string dummy_map_file = _disk_index_file + "_dummy_map.txt"; + std::string labels_map_file = _disk_index_file + "_labels_map.txt"; + size_t num_pts_in_label_file = 0; - this->_num_points = npts_u64; - this->_n_chunks = nchunks_u64; #ifdef EXEC_ENV_OLS if (files.fileExists(labels_file)) { @@ -784,7 +865,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #else if (file_exists(labels_file)) { - std::ifstream infile(labels_file); + std::ifstream infile(labels_file, std::ios::binary); if (infile.fail()) { throw diskann::ANNException(std::string("Failed to open file ") + labels_file, -1); @@ -803,7 +884,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons #else std::ifstream map_reader(labels_map_file); #endif - _label_map = load_label_map(map_reader); + load_label_map(map_reader, _label_map); #ifndef EXEC_ENV_OLS map_reader.close(); @@ -821,32 +902,7 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::ifstream medoid_stream(labels_to_medoids); assert(medoid_stream.is_open()); #endif - std::string line, token; - - _filter_to_medoid_ids.clear(); - try - { - while (std::getline(medoid_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - std::vector medoids; - LabelT label; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - label = (LabelT)std::stoul(token); - else - medoids.push_back((uint32_t)stoul(token)); - cnt++; - } - _filter_to_medoid_ids[label].swap(medoids); - } - } - catch (std::system_error &e) - { - throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, __LINE__); - } + load_label_medoid_map(labels_to_medoids, medoid_stream); } std::string univ_label_file = std ::string(_disk_index_file) + "_universal_label.txt"; @@ -883,37 +939,87 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons std::ifstream dummy_map_stream(dummy_map_file); assert(dummy_map_stream.is_open()); #endif - std::string line, token; - - while (std::getline(dummy_map_stream, line)) - { - std::istringstream iss(line); - uint32_t cnt = 0; - uint32_t dummy_id; - uint32_t real_id; - while (std::getline(iss, token, ',')) - { - if (cnt == 0) - dummy_id = (uint32_t)stoul(token); - else - real_id = (uint32_t)stoul(token); - cnt++; - } - _dummy_pts.insert(dummy_id); - _has_dummy_pts.insert(real_id); - _dummy_to_real_map[dummy_id] = real_id; - - if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) - _real_to_dummy_map[real_id] = std::vector(); - - _real_to_dummy_map[real_id].emplace_back(dummy_id); - } + load_dummy_map(dummy_map_file, dummy_map_stream); #ifndef EXEC_ENV_OLS dummy_map_stream.close(); #endif diskann::cout << "Loaded dummy map" << std::endl; } } + else + { + diskann::cout << "Index built without filter support." << std::endl; + } +} + +#ifdef EXEC_ENV_OLS +template +int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, const char *index_prefix) +{ +#else +template int PQFlashIndex::load(uint32_t num_threads, const char *index_prefix) +{ +#endif + std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; + std::string pq_compressed_vectors = std::string(index_prefix) + "_pq_compressed.bin"; + std::string _disk_index_file = std::string(index_prefix) + "_disk.index"; +#ifdef EXEC_ENV_OLS + return load_from_separate_paths(files, num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + pq_compressed_vectors.c_str()); +#else + return load_from_separate_paths(num_threads, _disk_index_file.c_str(), pq_table_bin.c_str(), + pq_compressed_vectors.c_str()); +#endif +} + +#ifdef EXEC_ENV_OLS +template +int PQFlashIndex::load_from_separate_paths(diskann::MemoryMappedFiles &files, uint32_t num_threads, + const char *index_filepath, const char *pivots_filepath, + const char *compressed_filepath) +{ +#else +template +int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, const char *index_filepath, + const char *pivots_filepath, const char *compressed_filepath) +{ +#endif + std::string pq_table_bin = pivots_filepath; + std::string pq_compressed_vectors = compressed_filepath; + std::string _disk_index_file = index_filepath; + std::string medoids_file = std::string(_disk_index_file) + "_medoids.bin"; + std::string centroids_file = std::string(_disk_index_file) + "_centroids.bin"; + + size_t pq_file_dim, pq_file_num_centroids; +#ifdef EXEC_ENV_OLS + get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); +#else + get_bin_metadata(pq_table_bin, pq_file_num_centroids, pq_file_dim, METADATA_SIZE); +#endif + + this->_disk_index_file = _disk_index_file; + + if (pq_file_num_centroids != 256) + { + diskann::cout << "Error. Number of PQ centroids is not 256. Exiting." << std::endl; + return -1; + } + + this->_data_dim = pq_file_dim; + // will change later if we use PQ on disk or if we are using + // inner product without PQ + this->_disk_bytes_per_point = this->_data_dim * sizeof(T); + this->_aligned_dim = ROUND_UP(pq_file_dim, 8); + + size_t npts_u64, nchunks_u64; +#ifdef EXEC_ENV_OLS + diskann::load_bin(files, pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +#else + diskann::load_bin(pq_compressed_vectors, this->data, npts_u64, nchunks_u64); +#endif + + this->_num_points = npts_u64; + this->_n_chunks = nchunks_u64; #ifdef EXEC_ENV_OLS _pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); @@ -1034,6 +1140,12 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons READ_U64(index_metadata, this->_nvecs_per_sector); } + #ifdef EXEC_ENV_OLS + load_labels(files, _disk_index_file); + #else + load_labels(_disk_index_file); + #endif + diskann::cout << "Disk-Index File Meta-data: "; diskann::cout << "# nodes per sector: " << _nnodes_per_sector; diskann::cout << ", max node len (bytes): " << _max_node_len; @@ -1135,51 +1247,43 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons diskann::cout << "Setting re-scaling factor of base vectors to " << this->_max_base_norm << std::endl; delete[] norm_val; } + diskann::cout << "done.." << std::endl; return 0; } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, size_t size, - int &completedIndex) +bool getNextCompletedRequest(const IOContext &ctx, size_t size, int &completedIndex) { - if ((*ctx.m_pRequests)[0].m_callback) + bool waitsRemaining = false; + long completeCount = ctx.m_completeCount; + do { - bool waitsRemaining = false; - long completeCount = ctx.m_completeCount; - do + for (int i = 0; i < size; i++) { - for (int i = 0; i < size; i++) + auto ithStatus = (*ctx.m_pRequestsStatus)[i]; + if (ithStatus == IOContext::Status::READ_SUCCESS) { - auto ithStatus = (*ctx.m_pRequestsStatus)[i]; - if (ithStatus == IOContext::Status::READ_SUCCESS) - { - completedIndex = i; - return true; - } - else if (ithStatus == IOContext::Status::READ_WAIT) - { - waitsRemaining = true; - } + completedIndex = i; + return true; } - - // if we didn't find one in READ_SUCCESS, wait for one to complete. - if (waitsRemaining) + else if (ithStatus == IOContext::Status::READ_WAIT) { - WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); - // this assumes the knowledge of the reader behavior (implicit - // contract). need better factoring? + waitsRemaining = true; } - } while (waitsRemaining); + } - completedIndex = -1; - return false; - } - else - { - reader->wait(ctx, completedIndex); - return completedIndex != -1; - } + // if we didn't find one in READ_SUCCESS, wait for one to complete. + if (waitsRemaining) + { + WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); + // this assumes the knowledge of the reader behavior (implicit + // contract). need better factoring? + } + } while (waitsRemaining); + + completedIndex = -1; + return false; } #endif @@ -1198,7 +1302,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const bool use_filter, const LabelT &filter_label, const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, + std::vector filters(1); + filters.push_back(filter_label); + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filters, std::numeric_limits::max(), use_reorder_data, stats); } @@ -1208,15 +1314,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { - LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + std::vector dummy_filters(0); + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filters, io_limit, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filters, const std::vector &filter_labels, const uint32_t io_limit, const bool use_reorder_data, QueryStats *stats) { @@ -1230,7 +1336,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t auto data = manager.scratch_space(); IOContext &ctx = data->ctx; auto query_scratch = &(data->scratch); - auto pq_query_scratch = query_scratch->_pq_scratch; + auto pq_query_scratch = query_scratch->pq_scratch(); // reset query scratch query_scratch->reset(); @@ -1238,28 +1344,33 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t // copy query to thread specific aligned and allocated memory (for distance // calculations we need aligned data) float query_norm = 0; - T *aligned_query_T = query_scratch->aligned_query_T; + T *aligned_query_T = query_scratch->aligned_query_T(); float *query_float = pq_query_scratch->aligned_query_float; float *query_rotated = pq_query_scratch->rotated_query; - // if inner product, we laso normalize the query and set the last coordinate - // to 0 (this is the extra coordindate used to convert MIPS to L2 search) - if (metric == diskann::Metric::INNER_PRODUCT) + uint32_t filter_label_count = (uint32_t)filter_labels.size(); + + // normalization step. for cosine, we simply normalize the query + // for mips, we normalize the first d-1 dims, and add a 0 for last dim, since an extra coordinate was used to + // convert MIPS to L2 search + if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE) { - for (size_t i = 0; i < this->_data_dim - 1; i++) + uint64_t inherent_dim = (metric == diskann::Metric::COSINE) ? this->_data_dim : (uint64_t)(this->_data_dim - 1); + for (size_t i = 0; i < inherent_dim; i++) { aligned_query_T[i] = query1[i]; query_norm += query1[i] * query1[i]; } - aligned_query_T[this->_data_dim - 1] = 0; + if (metric == diskann::Metric::INNER_PRODUCT) + aligned_query_T[this->_data_dim - 1] = 0; query_norm = std::sqrt(query_norm); - for (size_t i = 0; i < this->_data_dim - 1; i++) + for (size_t i = 0; i < inherent_dim; i++) { aligned_query_T[i] = (T)(aligned_query_T[i] / query_norm); } - pq_query_scratch->set(this->_data_dim, aligned_query_T); + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); } else { @@ -1267,7 +1378,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t { aligned_query_T[i] = query1[i]; } - pq_query_scratch->set(this->_data_dim, aligned_query_T); + pq_query_scratch->initialize(this->_data_dim, aligned_query_T); } // pointers to buffers for data @@ -1300,12 +1411,22 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t tsl::robin_set &visited = query_scratch->visited; NeighborPriorityQueue &retset = query_scratch->retset; - retset.reserve(l_search); std::vector &full_retset = query_scratch->full_retset; + tsl::robin_set full_retset_ids; + if (use_filters) { + uint64_t size_to_reserve = std::max(l_search, (std::min((uint64_t)filter_label_count, this->_max_degree) + 1)); + retset.reserve(size_to_reserve); + full_retset.reserve(4096); + full_retset_ids.reserve(4096); + } else { + retset.reserve(l_search + 1); + } + uint32_t best_medoid = 0; + uint32_t cur_list_size = 0; float best_dist = (std::numeric_limits::max)(); - if (!use_filter) + if (!use_filters) { for (uint64_t cur_m = 0; cur_m < _num_medoids; cur_m++) { @@ -1317,35 +1438,36 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t best_dist = cur_expanded_dist; } } - } - else - { - if (_filter_to_medoid_ids.find(filter_label) != _filter_to_medoid_ids.end()) + compute_dists(&best_medoid, 1, dist_scratch); + retset.insert(Neighbor(best_medoid, dist_scratch[0])); + visited.insert(best_medoid); + cur_list_size = 1; + } else { + std::vector filter_specific_medoids; + filter_specific_medoids.reserve(filter_label_count); + location_t ctr = 0; + for (; ctr < filter_label_count && ctr < this->_max_degree; ctr++) { - const auto &medoid_ids = _filter_to_medoid_ids[filter_label]; - for (uint64_t cur_m = 0; cur_m < medoid_ids.size(); cur_m++) + if (filter_labels[ctr] != -1) { - // for filtered index, we dont store global centroid data as for unfiltered index, so we use PQ distance - // as approximation to decide closest medoid matching the query filter. - compute_dists(&medoid_ids[cur_m], 1, dist_scratch); - float cur_expanded_dist = dist_scratch[0]; - if (cur_expanded_dist < best_dist) + for (auto id : this->_filter_to_medoid_ids[filter_labels[ctr]]) { - best_medoid = medoid_ids[cur_m]; - best_dist = cur_expanded_dist; + filter_specific_medoids.push_back(id); } } } - else + compute_dists(filter_specific_medoids.data(), filter_specific_medoids.size(), dist_scratch); + for (ctr = 0; ctr < filter_specific_medoids.size(); ctr++) { - throw ANNException("Cannot find medoid for specified filter.", -1, __FUNCSIG__, __FILE__, __LINE__); + retset.insert(Neighbor(filter_specific_medoids[ctr], dist_scratch[ctr])); + //retset[ctr].id = filter_specific_medoids[ctr]; + //retset[ctr].distance = dist_scratch[ctr]; + //retset[ctr].expanded = false; + visited.insert(filter_specific_medoids[ctr]); } + cur_list_size = (uint32_t) filter_specific_medoids.size(); } - compute_dists(&best_medoid, 1, dist_scratch); - retset.insert(Neighbor(best_medoid, dist_scratch[0])); - visited.insert(best_medoid); - uint32_t cmps = 0; uint32_t hops = 0; uint32_t num_ios = 0; @@ -1360,7 +1482,15 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::vector>> cached_nhoods; cached_nhoods.reserve(2 * beam_width); - while (retset.has_unexpanded_node() && num_ios < io_limit) + //if we are doing multi-filter search we don't want to restrict the number of IOs + //at present. Must revisit this decision later. + uint32_t max_ios_for_query = use_filters || (io_limit == 0) ? std::numeric_limits::max() : io_limit; + const std::vector& label_ids = filter_labels; //avoid renaming. + std::vector lbl_vec; + + retset.sort(); + + while (retset.has_unexpanded_node() && num_ios < max_ios_for_query) { // clear iteration state frontier.clear(); @@ -1370,6 +1500,45 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t sector_scratch_idx = 0; // find new beam uint32_t num_seen = 0; + + + for (const auto &lbl : label_ids) + { // assuming that number of OR labels is + // less than max frontier size allowed + uint32_t lbl_marker = 0; + while (lbl_marker < cur_list_size) + { + lbl_vec.clear(); + lbl_vec.emplace_back(lbl); + + if (!retset[lbl_marker].expanded && point_has_any_label(retset[lbl_marker].id, lbl_vec)) + { + num_seen++; + auto iter = _nhood_cache.find(retset[lbl_marker].id); + if (iter != _nhood_cache.end()) + { + cached_nhoods.push_back(std::make_pair(retset[lbl_marker].id, iter->second)); + if (stats != nullptr) + { + stats->n_cache_hits++; + } + } + else + { + frontier.push_back(retset[lbl_marker].id); + } + retset[lbl_marker].expanded = true; + if (this->_count_visited_nodes) + { + reinterpret_cast &>(this->_node_visit_counter[retset[lbl_marker].id].second) + .fetch_add(1); + } + break; + } + lbl_marker++; + } + } + while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) { auto nbr = retset.closest_unexpanded(); @@ -1446,7 +1615,24 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t cur_expanded_dist = _disk_pq_table.l2_distance( // disk_pq does not support OPQ yet query_float, (uint8_t *)node_fp_coords_copy); } - full_retset.push_back(Neighbor((uint32_t)cached_nhood.first, cur_expanded_dist)); + if (use_filters) + { + location_t real_id = cached_nhood.first; + if (_dummy_pts.find(real_id) != _dummy_pts.end()) + { + real_id = _dummy_to_real_map[real_id]; + } + if (full_retset_ids.find(real_id) == full_retset_ids.end()) + { + full_retset.push_back(Neighbor((uint32_t)real_id, cur_expanded_dist)); + full_retset_ids.insert(real_id); + } + } + else + { + full_retset.push_back(Neighbor((unsigned)cached_nhood.first, cur_expanded_dist)); + } + uint64_t nnbrs = cached_nhood.second.first; uint32_t *node_nbrs = cached_nhood.second.second; @@ -1466,10 +1652,10 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filters && _dummy_pts.find(id) != _dummy_pts.end()) continue; - if (use_filter && !(point_has_label(id, filter_label)) && + if (use_filters && !(point_has_any_label(id, label_ids)) && (!_use_universal_label || !point_has_label(id, _universal_filter_label))) continue; cmps++; @@ -1485,7 +1671,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t long requestCount = static_cast(frontier_read_reqs.size()); // If we issued read requests and if a read is complete or there are // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) + while (requestCount > 0 && getNextCompletedRequest(ctx, requestCount, completedIndex)) { assert(completedIndex >= 0); auto &frontier_nhood = frontier_nhoods[completedIndex]; @@ -1511,7 +1697,25 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t else cur_expanded_dist = _disk_pq_table.l2_distance(query_float, (uint8_t *)data_buf); } - full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist)); + if (use_filters) + { + location_t real_id = frontier_nhood.first; + if (_dummy_pts.find(real_id) != _dummy_pts.end()) + { + real_id = _dummy_to_real_map[real_id]; + } + + if (full_retset_ids.find(real_id) == full_retset_ids.end()) + { + full_retset.push_back(Neighbor(real_id, cur_expanded_dist)); + full_retset_ids.insert(real_id); + } + } + else + { + full_retset.push_back(Neighbor(frontier_nhood.first, cur_expanded_dist)); + } + uint32_t *node_nbrs = (node_buf + 1); // compute node_nbrs <-> query dist in PQ space cpu_timer.reset(); @@ -1529,10 +1733,10 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t uint32_t id = node_nbrs[m]; if (visited.insert(id).second) { - if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + if (!use_filters && _dummy_pts.find(id) != _dummy_pts.end()) continue; - if (use_filter && !(point_has_label(id, filter_label)) && + if (use_filters && !(point_has_any_label(id, label_ids)) && (!_use_universal_label || !point_has_label(id, _universal_filter_label))) continue; cmps++; @@ -1552,10 +1756,8 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t stats->cpu_us += (float)cpu_timer.elapsed(); } } - hops++; } - // re-sort by distance std::sort(full_retset.begin(), full_retset.end()); diff --git a/src/scratch.cpp b/src/scratch.cpp index 112c65d28..c3836ccf1 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -5,6 +5,7 @@ #include #include "scratch.h" +#include "pq_scratch.h" namespace diskann { @@ -24,13 +25,13 @@ InMemQueryScratch::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, throw diskann::ANNException(ss.str(), -1); } - alloc_aligned(((void **)&_aligned_query), aligned_dim * sizeof(T), alignment_factor * sizeof(T)); - memset(_aligned_query, 0, aligned_dim * sizeof(T)); + alloc_aligned(((void **)&this->_aligned_query_T), aligned_dim * sizeof(T), alignment_factor * sizeof(T)); + memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); if (init_pq_scratch) - _pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); + this->_pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); else - _pq_scratch = nullptr; + this->_pq_scratch = nullptr; _occlude_factor.reserve(maxc); _inserted_into_pool_bs = new boost::dynamic_bitset<>(); @@ -71,12 +72,13 @@ template void InMemQueryScratch::resize_for_new_L(uint32_t new_l template InMemQueryScratch::~InMemQueryScratch() { - if (_aligned_query != nullptr) + if (this->_aligned_query_T != nullptr) { - aligned_free(_aligned_query); + aligned_free(this->_aligned_query_T); + this->_aligned_query_T = nullptr; } - delete _pq_scratch; + delete this->_pq_scratch; delete _inserted_into_pool_bs; } @@ -98,12 +100,12 @@ template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, si diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256); diskann::alloc_aligned((void **)§or_scratch, defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN, defaults::SECTOR_LEN); - diskann::alloc_aligned((void **)&aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T)); + diskann::alloc_aligned((void **)&this->_aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T)); - _pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); + this->_pq_scratch = new PQScratch(defaults::MAX_GRAPH_DEGREE, aligned_dim); memset(coord_scratch, 0, coord_alloc_size); - memset(aligned_query_T, 0, aligned_dim * sizeof(T)); + memset(this->_aligned_query_T, 0, aligned_dim * sizeof(T)); visited.reserve(visited_reserve); full_retset.reserve(visited_reserve); @@ -113,9 +115,9 @@ template SSDQueryScratch::~SSDQueryScratch() { diskann::aligned_free((void *)coord_scratch); diskann::aligned_free((void *)sector_scratch); - diskann::aligned_free((void *)aligned_query_T); + diskann::aligned_free((void *)this->_aligned_query_T); - delete[] _pq_scratch; + delete[] this->_pq_scratch; } template @@ -128,6 +130,30 @@ template void SSDThreadData::clear() scratch.reset(); } +template PQScratch::PQScratch(size_t graph_degree, size_t aligned_dim) +{ + diskann::alloc_aligned((void **)&aligned_pq_coord_scratch, + (size_t)graph_degree * (size_t)MAX_PQ_CHUNKS * sizeof(uint8_t), 256); + diskann::alloc_aligned((void **)&aligned_pqtable_dist_scratch, 256 * (size_t)MAX_PQ_CHUNKS * sizeof(float), 256); + diskann::alloc_aligned((void **)&aligned_dist_scratch, (size_t)graph_degree * sizeof(float), 256); + diskann::alloc_aligned((void **)&aligned_query_float, aligned_dim * sizeof(float), 8 * sizeof(float)); + diskann::alloc_aligned((void **)&rotated_query, aligned_dim * sizeof(float), 8 * sizeof(float)); + + memset(aligned_query_float, 0, aligned_dim * sizeof(float)); + memset(rotated_query, 0, aligned_dim * sizeof(float)); +} + +template void PQScratch::initialize(size_t dim, const T *query, const float norm) +{ + for (size_t d = 0; d < dim; ++d) + { + if (norm != 1.0f) + rotated_query[d] = aligned_query_float[d] = static_cast(query[d]) / norm; + else + rotated_query[d] = aligned_query_float[d] = static_cast(query[d]); + } +} + template DISKANN_DLLEXPORT class InMemQueryScratch; template DISKANN_DLLEXPORT class InMemQueryScratch; template DISKANN_DLLEXPORT class InMemQueryScratch; @@ -136,7 +162,12 @@ template DISKANN_DLLEXPORT class SSDQueryScratch; template DISKANN_DLLEXPORT class SSDQueryScratch; template DISKANN_DLLEXPORT class SSDQueryScratch; +template DISKANN_DLLEXPORT class PQScratch; +template DISKANN_DLLEXPORT class PQScratch; +template DISKANN_DLLEXPORT class PQScratch; + template DISKANN_DLLEXPORT class SSDThreadData; template DISKANN_DLLEXPORT class SSDThreadData; template DISKANN_DLLEXPORT class SSDThreadData; + } // namespace diskann From 39a20050a8bc482debbcfb25db08d71dde2c0a38 Mon Sep 17 00:00:00 2001 From: Renan S Date: Mon, 6 May 2024 09:55:02 -0700 Subject: [PATCH 19/38] Fix for getNextCompletedRequest (#548) * readd skip wait if callback is not present * Delete AnyBuildLogs/latest.txt * Update pq_flash_index.cpp --- src/pq_flash_index.cpp | 57 ++++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index d5fbbff20..265f266bd 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -1253,37 +1253,46 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons } #ifdef USE_BING_INFRA -bool getNextCompletedRequest(const IOContext &ctx, size_t size, int &completedIndex) +bool getNextCompletedRequest(std::shared_ptr &reader, IOContext &ctx, size_t size, + int &completedIndex) { - bool waitsRemaining = false; - long completeCount = ctx.m_completeCount; - do + if ((*ctx.m_pRequests)[0].m_callback) { - for (int i = 0; i < size; i++) + bool waitsRemaining = false; + long completeCount = ctx.m_completeCount; + do { - auto ithStatus = (*ctx.m_pRequestsStatus)[i]; - if (ithStatus == IOContext::Status::READ_SUCCESS) + for (int i = 0; i < size; i++) { - completedIndex = i; - return true; + auto ithStatus = (*ctx.m_pRequestsStatus)[i]; + if (ithStatus == IOContext::Status::READ_SUCCESS) + { + completedIndex = i; + return true; + } + else if (ithStatus == IOContext::Status::READ_WAIT) + { + waitsRemaining = true; + } } - else if (ithStatus == IOContext::Status::READ_WAIT) + + // if we didn't find one in READ_SUCCESS, wait for one to complete. + if (waitsRemaining) { - waitsRemaining = true; + WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); + // this assumes the knowledge of the reader behavior (implicit + // contract). need better factoring? } - } + } while (waitsRemaining); - // if we didn't find one in READ_SUCCESS, wait for one to complete. - if (waitsRemaining) - { - WaitOnAddress(&ctx.m_completeCount, &completeCount, sizeof(completeCount), 100); - // this assumes the knowledge of the reader behavior (implicit - // contract). need better factoring? - } - } while (waitsRemaining); - - completedIndex = -1; - return false; + completedIndex = -1; + return false; + } + else + { + reader->wait(ctx, completedIndex); + return completedIndex != -1; + } } #endif @@ -1671,7 +1680,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t long requestCount = static_cast(frontier_read_reqs.size()); // If we issued read requests and if a read is complete or there are // reads in wait state, then enter the while loop. - while (requestCount > 0 && getNextCompletedRequest(ctx, requestCount, completedIndex)) + while (requestCount > 0 && getNextCompletedRequest(reader, ctx, requestCount, completedIndex)) { assert(completedIndex >= 0); auto &frontier_nhood = frontier_nhoods[completedIndex]; From 7e9c3f4465c80ae3ffa4f603d22564f649de7770 Mon Sep 17 00:00:00 2001 From: Renan S Date: Wed, 12 Jun 2024 12:09:45 -0700 Subject: [PATCH 20/38] 560 bug remove unused line (#561) * remove unused line * remove old code * remove other unused line --- src/pq_flash_index.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 265f266bd..491614a42 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -674,8 +674,6 @@ inline bool PQFlashIndex::point_has_label(uint32_t point_id, LabelT l template bool PQFlashIndex::point_has_any_label(uint32_t point_id, const std::vector &label_ids) { - uint32_t start_vec = _pts_to_label_offsets[point_id]; - uint32_t num_lbls = _pts_to_label_counts[start_vec]; bool ret_val = false; for (auto &cur_lbl : label_ids) { From 54736568aa61bc1eb98a891239ff95e8456a400d Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Mon, 17 Jun 2024 20:57:08 +0530 Subject: [PATCH 21/38] Fixes to utility functions and apps to support multi-filter queries --- apps/search_disk_index.cpp | 68 ++++++++++++++++++++++++++++------- include/abstract_data_store.h | 2 +- include/in_mem_data_store.h | 4 +-- include/utils.h | 5 ++- src/in_mem_data_store.cpp | 8 ++--- src/index.cpp | 2 +- src/utils.cpp | 16 +++++++++ 7 files changed, 83 insertions(+), 22 deletions(-) diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 7e2a7ac6d..73e79f3e7 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -4,6 +4,7 @@ #include "common_includes.h" #include +#include "utils.h" #include "index.h" #include "disk_utils.h" #include "math_utils.h" @@ -47,6 +48,44 @@ void print_stats(std::string category, std::vector percentiles, std::vect diskann::cout << std::endl; } +template +void parse_labels_of_query(const std::string &filters_for_query, + std::unique_ptr> &pFlashIndex, + std::vector &label_ids_for_query) +{ + std::vector label_strs_for_query; + diskann::split_string(filters_for_query, FILTER_OR_SEPARATOR, label_strs_for_query); + for (auto &label_str_for_query : label_strs_for_query) + { + label_ids_for_query.push_back(pFlashIndex->get_converted_label(label_str_for_query)); + } +} + +template +void populate_label_ids(const std::vector &filters_of_queries, + std::unique_ptr> &pFlashIndex, + std::vector> &label_ids_of_queries, bool apply_one_to_all, uint32_t query_count) +{ + if (apply_one_to_all) + { + std::vector label_ids_of_query; + parse_labels_of_query(filters_of_queries[0], pFlashIndex, label_ids_of_query); + for (uint32_t i = 0; i < query_count; i++) + { + label_ids_of_queries.push_back(label_ids_of_query); + } + } + else + { + for (auto &filters_of_query : filters_of_queries) + { + std::vector label_ids_of_query; + parse_labels_of_query(filters_of_query, pFlashIndex, label_ids_of_query); + label_ids_of_queries.push_back(label_ids_of_query); + } + } +} + template int search_disk_index(diskann::Metric &metric, const std::string &index_path_prefix, const std::string &result_output_prefix, const std::string &query_file, std::string >_file, @@ -173,6 +212,14 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre diskann::cout << "..done" << std::endl; } + std::vector> per_query_label_ids; + if (filtered_search) + { + populate_label_ids(query_filters, _pFlashIndex, per_query_label_ids, (query_filters.size() == 1), query_num ); + } + + + diskann::cout.setf(std::ios_base::fixed, std::ios_base::floatfield); diskann::cout.precision(2); @@ -236,19 +283,10 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } else { - LabelT label_for_search; - if (query_filters.size() == 1) - { // one label for all queries - label_for_search = _pFlashIndex->get_converted_label(query_filters[0]); - } - else - { // one label for each query - label_for_search = _pFlashIndex->get_converted_label(query_filters[i]); - } _pFlashIndex->cached_beam_search( query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, - use_reorder_data, stats + i); + query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, per_query_label_ids[i], + search_io_limit, use_reorder_data, stats + i); } } auto e = std::chrono::high_resolution_clock::now(); @@ -270,6 +308,9 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre auto mean_cpuus = diskann::get_mean_stats(stats, query_num, [](const diskann::QueryStats &stats) { return stats.cpu_us; }); + auto mean_hops = diskann::get_mean_stats( + stats, query_num, [](const diskann::QueryStats &stats) { return stats.n_hops; }); + double recall = 0; if (calc_recall_flag) { @@ -283,10 +324,12 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre << std::setw(16) << mean_cpuus; if (calc_recall_flag) { - diskann::cout << std::setw(16) << recall << std::endl; + diskann::cout << std::setw(16) << recall << std::endl ; } else + { diskann::cout << std::endl; + } delete[] stats; } @@ -443,7 +486,6 @@ int main(int argc, char **argv) { query_filters = read_file_to_vector_of_strings(query_filters_file); } - try { if (!query_filters.empty() && label_type == "ushort") diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index e0c9a99e8..2b67fa9e4 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -24,7 +24,7 @@ template class AbstractDataStore // Return number of points returned virtual location_t load(const std::string &filename, size_t offset) = 0; - virtual location_t load(AlignedFileReader &reader, size_t offset) = 0; + //virtual location_t load(AlignedFileReader &reader, size_t offset) = 0; // Why does store take num_pts? Since store only has capacity, but we allow // resizing we can end up in a situation where the store has spare capacity. diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 6feb09199..4bde7b483 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -25,7 +25,7 @@ template class InMemDataStore : public AbstractDataStore class InMemDataStore : public AbstractDataStore> &groundtruth, std::vector> &our_results); +DISKANN_DLLEXPORT void split_string(const std::string &string_to_split, const std::string &delimiter, + std::vector &pieces); template inline void load_bin(const std::string &bin_file, std::unique_ptr &data, size_t &npts, size_t &dim, diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index e168d96fa..1e5b93a61 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -42,10 +42,10 @@ template location_t InMemDataStore::load(const std::st return load_impl(filename, offset); } -template location_t InMemDataStore::load(AlignedFileReader &reader, size_t offset) -{ - return load_impl(reader, offset); -} +//template location_t InMemDataStore::load(AlignedFileReader &reader, size_t offset) +//{ +// return load_impl(reader, offset); +//} #ifdef EXEC_ENV_OLS template location_t InMemDataStore::load_impl(AlignedFileReader &reader, size_t offset) diff --git a/src/index.cpp b/src/index.cpp index 6d61aa45e..92007fbb0 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2065,7 +2065,7 @@ LabelT Index::get_converted_label(const std::string &raw_label) return _universal_label; } std::stringstream stream; - stream << "Unable to find label in the Label Map"; + stream << "Unable to find label" << raw_label << "in the label map "; diskann::cerr << stream.str() << std::endl; throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } diff --git a/src/utils.cpp b/src/utils.cpp index ab36c42a8..402f2c37c 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -252,6 +252,22 @@ double calculate_range_search_recall(uint32_t num_queries, std::vector &pieces) +{ + size_t start = 0; + size_t end; + while ((end = string_to_split.find(delimiter, start)) != std::string::npos) + { + pieces.push_back(string_to_split.substr(start, end - start)); + start = end + delimiter.length(); + } + if (start != string_to_split.length()) + { + pieces.push_back(string_to_split.substr(start, string_to_split.length() - start)); + } +} + + #ifdef EXEC_ENV_OLS void get_bin_metadata(AlignedFileReader &reader, size_t &npts, size_t &ndim, size_t offset) { From e94b9a8fcb3302558b5d78bd603da7a5cc8ca32f Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Thu, 20 Jun 2024 16:45:53 +0530 Subject: [PATCH 22/38] Set sector scratch to max of (maxdegree,max_filters_per_query,max_sector_reads) --- include/defaults.h | 4 ++++ include/pq_flash_index.h | 4 +++- include/scratch.h | 4 ++-- src/pq_flash_index.cpp | 7 ++++--- src/scratch.cpp | 10 +++++++--- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/include/defaults.h b/include/defaults.h index ef1750fcf..8aa31fd50 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -30,5 +30,9 @@ const uint32_t MAX_DEGREE = 64; const uint32_t BUILD_LIST_SIZE = 100; const uint32_t SATURATE_GRAPH = false; const uint32_t SEARCH_LIST_SIZE = 100; + +const size_t VISITED_RESERVE = 4096; +const size_t MAX_FILTERS_PER_QUERY = 4096; + } // namespace defaults } // namespace diskann diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index b1ec6db87..21381b3fe 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -120,7 +120,9 @@ template class PQFlashIndex protected: DISKANN_DLLEXPORT void use_medoids_data_as_centroids(); - DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = 4096); + DISKANN_DLLEXPORT void setup_thread_data(uint64_t nthreads, uint64_t visited_reserve = defaults::VISITED_RESERVE, + uint64_t max_degree = defaults::MAX_DEGREE, + uint64_t max_filters_per_query = defaults::MAX_FILTERS_PER_QUERY); DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); diff --git a/include/scratch.h b/include/scratch.h index 2f43e3365..8d8105c3b 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -150,7 +150,7 @@ template class SSDQueryScratch : public AbstractScratch NeighborPriorityQueue retset; std::vector full_retset; - SSDQueryScratch(size_t aligned_dim, size_t visited_reserve); + SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query); ~SSDQueryScratch(); void reset(); @@ -162,7 +162,7 @@ template class SSDThreadData SSDQueryScratch scratch; IOContext ctx; - SSDThreadData(size_t aligned_dim, size_t visited_reserve); + SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query); void clear(); }; diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 491614a42..11985a16f 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -117,7 +117,8 @@ template inline T *PQFlashIndex::offset } template -void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve) +void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visited_reserve, uint64_t max_degree, + uint64_t max_filters_per_query) { diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl; // omp parallel for to generate unique thread IDs @@ -126,7 +127,7 @@ void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visi { #pragma omp critical { - SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve); + SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve, max_degree, max_filters_per_query); this->reader->register_thread(); data->ctx = this->reader->get_ctx(); this->_thread_data.push(data); @@ -1309,7 +1310,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t const bool use_filter, const LabelT &filter_label, const bool use_reorder_data, QueryStats *stats) { - std::vector filters(1); + std::vector filters; filters.push_back(filter_label); cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filters, std::numeric_limits::max(), use_reorder_data, stats); diff --git a/src/scratch.cpp b/src/scratch.cpp index c3836ccf1..05c9a1553 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -93,12 +93,14 @@ template void SSDQueryScratch::reset() full_retset.clear(); } -template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve) +template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, size_t max_degree, size_t max_filters_per_query) { size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256); + size_t sector_scratch_size = + std::max(max_degree, std::max(max_filters_per_query, (size_t)defaults::MAX_N_SECTOR_READS)); diskann::alloc_aligned((void **)&coord_scratch, coord_alloc_size, 256); - diskann::alloc_aligned((void **)§or_scratch, defaults::MAX_N_SECTOR_READS * defaults::SECTOR_LEN, + diskann::alloc_aligned((void **)§or_scratch, sector_scratch_size * defaults::SECTOR_LEN, defaults::SECTOR_LEN); diskann::alloc_aligned((void **)&this->_aligned_query_T, aligned_dim * sizeof(T), 8 * sizeof(T)); @@ -121,7 +123,9 @@ template SSDQueryScratch::~SSDQueryScratch() } template -SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve) +SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve, size_t max_degree, + size_t max_filters_per_query) + : scratch(aligned_dim, visited_reserve, max_degree, max_filters_per_query) { } From 07f21f360626607ef6379f585ff6ab389c24050e Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Thu, 20 Jun 2024 22:38:10 +0530 Subject: [PATCH 23/38] Adding correct #ifdef EXEC_ENV_OLS blocks --- include/abstract_data_store.h | 4 +++- include/in_mem_data_store.h | 4 +++- include/utils.h | 4 +++- src/in_mem_data_store.cpp | 10 +++++----- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/include/abstract_data_store.h b/include/abstract_data_store.h index 2b67fa9e4..5f816f4d4 100644 --- a/include/abstract_data_store.h +++ b/include/abstract_data_store.h @@ -24,7 +24,9 @@ template class AbstractDataStore // Return number of points returned virtual location_t load(const std::string &filename, size_t offset) = 0; - //virtual location_t load(AlignedFileReader &reader, size_t offset) = 0; +#ifdef EXEC_ENV_OLS + virtual location_t load(AlignedFileReader &reader, size_t offset) = 0; +#endif // Why does store take num_pts? Since store only has capacity, but we allow // resizing we can end up in a situation where the store has spare capacity. diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 4bde7b483..3eb4e0518 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -25,7 +25,9 @@ template class InMemDataStore : public AbstractDataStore location_t InMemDataStore::load(const std::st return load_impl(filename, offset); } -//template location_t InMemDataStore::load(AlignedFileReader &reader, size_t offset) -//{ -// return load_impl(reader, offset); -//} - #ifdef EXEC_ENV_OLS +template location_t InMemDataStore::load(AlignedFileReader &reader, size_t offset) +{ + return load_impl(reader, offset); +} + template location_t InMemDataStore::load_impl(AlignedFileReader &reader, size_t offset) { size_t file_dim, file_num_points; From b2b0942eb0771980339f2c206708eafba63f433d Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Thu, 20 Jun 2024 22:45:59 +0530 Subject: [PATCH 24/38] Fixing more EXEC_ENV_OLS --- include/in_mem_data_store.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/in_mem_data_store.h b/include/in_mem_data_store.h index 3eb4e0518..6fcfe8018 100644 --- a/include/in_mem_data_store.h +++ b/include/in_mem_data_store.h @@ -65,7 +65,7 @@ template class InMemDataStore : public AbstractDataStore Date: Fri, 5 Jul 2024 15:07:37 +0530 Subject: [PATCH 25/38] Adding cosine support in build_disk_index and ensuring that the dummy map file is written in the correct location --- apps/build_disk_index.cpp | 12 ++++++++++-- src/disk_utils.cpp | 4 +++- src/pq_flash_index.cpp | 20 ++++++++++++++++++++ 3 files changed, 33 insertions(+), 3 deletions(-) diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index b617a5f4a..759522ffe 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -103,13 +103,21 @@ int main(int argc, char **argv) bool use_filters = (label_file != "") ? true : false; diskann::Metric metric; - if (dist_fn == std::string("l2")) + if (dist_fn == std::string("l2")) + { metric = diskann::Metric::L2; + } else if (dist_fn == std::string("mips")) + { metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } else { - std::cout << "Error. Only l2 and mips distance functions are supported" << std::endl; + std::cout << "Error. Only l2, cosine, and mips distance functions are supported" << std::endl; return -1; } diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 297619b4a..26d739f36 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -1239,7 +1239,9 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const augmented_labels_file = index_prefix_path + "_augmented_labels.txt"; if (filter_threshold != 0) { - dummy_remap_file = index_prefix_path + "_dummy_remap.txt"; + //Changing this filename to "_disk.index_dummy_map.txt" from "_dummy_map.txt" to conform + //to the convention that index files all share the _disk.index prefix. + dummy_remap_file = index_prefix_path + "_disk.index_dummy_map.txt"; breakup_dense_points(data_file_to_use, labels_file_to_use, filter_threshold, augmented_data_file, augmented_labels_file, dummy_remap_file); // RKNOTE: This has large memory footprint, diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 11985a16f..1db76351b 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -876,6 +876,8 @@ template void PQFlashIndex::load_labels #ifndef EXEC_ENV_OLS infile.close(); #endif + diskann::cout << "Labels file: " << labels_file << " loaded with " << num_pts_in_label_file << " points" + << std::endl; #ifdef EXEC_ENV_OLS FileContent &content_labels_map = files.getContent(labels_map_file); @@ -889,6 +891,8 @@ template void PQFlashIndex::load_labels map_reader.close(); #endif + diskann::cout << "Labels map file: " << labels_map_file << " loaded." << std::endl; + #ifdef EXEC_ENV_OLS if (files.fileExists(labels_to_medoids)) { @@ -902,7 +906,16 @@ template void PQFlashIndex::load_labels assert(medoid_stream.is_open()); #endif load_label_medoid_map(labels_to_medoids, medoid_stream); + diskann::cout << "Loaded labels_to_medoids map from: " << labels_to_medoids << std::endl; + } + else + { + std::stringstream ss; + ss << "Filter support is enabled but " << labels_to_medoids << " file cannot be opened." << std::endl; + diskann::cerr << ss.str(); + throw diskann::ANNException(ss.str(), -1); } + std::string univ_label_file = std ::string(_disk_index_file) + "_universal_label.txt"; #ifdef EXEC_ENV_OLS @@ -944,6 +957,13 @@ template void PQFlashIndex::load_labels #endif diskann::cout << "Loaded dummy map" << std::endl; } + else + { + std::stringstream ss; + ss << "Note: Filter support is enabled but " << dummy_map_file << " file cannot be opened" << std::endl; + diskann::cerr << ss.str(); + } + } else { From 98141ce1a874959812eca2c5b7984f1cf676ff2a Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Sat, 20 Jul 2024 12:50:50 +0530 Subject: [PATCH 26/38] Creating version of DiskANN dll for build with dependency on tcmalloc --- src/dll/CMakeLists.txt | 25 +++++++++++++++++++++++++ src/pq.cpp | 5 +++++ 2 files changed, 30 insertions(+) diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index d10fc951a..1dc8bf980 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -11,6 +11,10 @@ else() ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) + add_library(${PROJECT_NAME}_build SHARED dllmain.cpp ../abstract_data_store.cpp ../partition.cpp ../pq.cpp ../pq_flash_index.cpp ../logger.cpp ../utils.cpp + ../windows_aligned_file_reader.cpp ../distance.cpp ../memory_mapper.cpp ../index.cpp + ../in_mem_data_store.cpp ../in_mem_graph_store.cpp ../math_utils.cpp ../disk_utils.cpp ../filter_utils.cpp + ../ann_exception.cpp ../natural_number_set.cpp ../natural_number_map.cpp ../scratch.cpp ../index_factory.cpp ../abstract_index.cpp) endif() set(TARGET_DIR "$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_DEBUG}>$<$:${CMAKE_LIBRARY_OUTPUT_DIRECTORY_RELEASE}>") @@ -20,17 +24,34 @@ if (DISKANN_USE_STATIC_LIB) else() set(DISKANN_DLL_IMPLIB "${TARGET_DIR}/${PROJECT_NAME}.lib") target_compile_definitions(${PROJECT_NAME} PRIVATE _USRDLL _WINDLL) + + set(DISKANN_DLL_IMPLIB_BUILD "${TARGET_DIR}/${PROJECT_NAME}_build.lib") + target_compile_definitions(${PROJECT_NAME}_build PRIVATE _USRDLL _WINDLL) + target_compile_definitions(${PROJECT_NAME}_build PRIVATE DISKANN_BUILD) endif() target_compile_options(${PROJECT_NAME} PRIVATE /GL) target_include_directories(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) +target_compile_options(${PROJECT_NAME}_build PRIVATE /GL) +target_include_directories(${PROJECT_NAME}_build PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) + + if (NOT DEFINED DISKANN_USE_STATIC_LIB) target_link_options(${PROJECT_NAME} PRIVATE /DLL /IMPLIB:${DISKANN_DLL_IMPLIB} /LTCG) + target_link_options(${PROJECT_NAME}_build PRIVATE /DLL /IMPLIB:${DISKANN_DLL_IMPLIB_BUILD} /LTCG) endif() target_link_libraries(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_LINK_LIBRARIES}) target_link_libraries(${PROJECT_NAME} PRIVATE synchronization.lib) +target_link_libraries(${PROJECT_NAME}_build PRIVATE ${DISKANN_MKL_LINK_LIBRARIES}) +target_link_libraries(${PROJECT_NAME}_build PRIVATE synchronization.lib) + +#This is the crux of the build dll +target_link_libraries(${PROJECT_NAME}_build PUBLIC ${DISKANN_DLL_TCMALLOC_LINK_OPTIONS}) +set_target_properties(${PROJECT_NAME}_build PROPERTIES LINK_FLAGS /INCLUDE:_tcmalloc) + + if (DISKANN_DLL_TCMALLOC_LINK_OPTIONS) target_link_libraries(${PROJECT_NAME} PUBLIC ${DISKANN_DLL_TCMALLOC_LINK_OPTIONS}) @@ -43,4 +64,8 @@ foreach(RUNTIME_FILE ${RUNTIME_FILES_TO_COPY}) add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "${RUNTIME_FILE}" "${TARGET_DIR}") + add_custom_command(TARGET ${PROJECT_NAME}_build + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy "${RUNTIME_FILE}" "${TARGET_DIR}") + endforeach() \ No newline at end of file diff --git a/src/pq.cpp b/src/pq.cpp index c59fc2dce..edb0c7ad5 100644 --- a/src/pq.cpp +++ b/src/pq.cpp @@ -8,6 +8,11 @@ #include "math_utils.h" #include "tsl/robin_map.h" +#if defined(RELEASE_UNUSED_TCMALLOC_MEMORY_AT_CHECKPOINTS) && defined(DISKANN_BUILD) +#include "gperftools/malloc_extension.h" +#endif + + // block size for reading/processing large files and matrices in blocks #define BLOCK_SIZE 5000000 From 96e1751a39d15004131c773f4d353c8a082ecac7 Mon Sep 17 00:00:00 2001 From: Gopal Srinivasa Date: Sun, 21 Jul 2024 00:08:59 +0530 Subject: [PATCH 27/38] Delete the mem index file and sample files --- src/disk_utils.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 26d739f36..3a26dd860 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -1169,7 +1169,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const std::string mem_univ_label_file = mem_index_path + "_universal_label.txt"; std::string disk_univ_label_file = disk_index_path + "_universal_label.txt"; std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt"; - std::string dummy_remap_file = disk_index_path + "_dummy_remap.txt"; // remap will be used if we break-up points of + std::string dummy_remap_file = disk_index_path + "_dummy_map.txt"; // remap will be used if we break-up points of // high label-density to create copies std::string sample_base_prefix = index_prefix_path + "_sample"; @@ -1239,9 +1239,6 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const augmented_labels_file = index_prefix_path + "_augmented_labels.txt"; if (filter_threshold != 0) { - //Changing this filename to "_disk.index_dummy_map.txt" from "_dummy_map.txt" to conform - //to the convention that index files all share the _disk.index prefix. - dummy_remap_file = index_prefix_path + "_disk.index_dummy_map.txt"; breakup_dense_points(data_file_to_use, labels_file_to_use, filter_threshold, augmented_data_file, augmented_labels_file, dummy_remap_file); // RKNOTE: This has large memory footprint, @@ -1311,11 +1308,11 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const } diskann::cout << timer.elapsed_seconds_for_step("generating disk layout") << std::endl; - double ten_percent_points = std::ceil(points_num * 0.1); - double num_sample_points = - ten_percent_points > MAX_SAMPLE_POINTS_FOR_WARMUP ? MAX_SAMPLE_POINTS_FOR_WARMUP : ten_percent_points; - double sample_sampling_rate = num_sample_points / points_num; - gen_random_slice(data_file_to_use.c_str(), sample_base_prefix, sample_sampling_rate); + //double ten_percent_points = std::ceil(points_num * 0.1); + //double num_sample_points = + // ten_percent_points > MAX_SAMPLE_POINTS_FOR_WARMUP ? MAX_SAMPLE_POINTS_FOR_WARMUP : ten_percent_points; + //double sample_sampling_rate = num_sample_points / points_num; + //gen_random_slice(data_file_to_use.c_str(), sample_base_prefix, sample_sampling_rate); if (use_filters) { copy_file(labels_file_to_use, disk_labels_file); @@ -1331,6 +1328,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const } std::remove(mem_index_path.c_str()); + std::remove((mem_index_path + ".data").c_str()); if (use_disk_pq) std::remove(disk_pq_compressed_vectors_path.c_str()); From da7416a82bd7a4cf55249dd8cb445ebe5b633224 Mon Sep 17 00:00:00 2001 From: litan1 <106347144+ltan1ms@users.noreply.github.com> Date: Fri, 28 Mar 2025 08:46:18 -0700 Subject: [PATCH 28/38] Use much smaller scratch space for non filter case (#639) * Use much smaller scratch space for non filter case * fix build error --- include/pq_flash_index.h | 6 ++++++ src/pq_flash_index.cpp | 29 +++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 21381b3fe..befde1a1a 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -36,6 +36,12 @@ template class PQFlashIndex DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix); #endif +#ifdef EXEC_ENV_OLS + DISKANN_DLLEXPORT bool use_filter_support(MemoryMappedFiles &files); +#else + DISKANN_DLLEXPORT bool use_filter_support(); +#endif + #ifdef EXEC_ENV_OLS DISKANN_DLLEXPORT void load_labels(MemoryMappedFiles &files, const std::string &disk_index_file); #else diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index 1db76351b..179e9a71e 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -843,6 +843,23 @@ void PQFlashIndex::load_dummy_map(const std::string &dummy_map_filepa throw FileException (dummy_map_filepath, e, __FUNCSIG__, __FILE__, __LINE__); } } + + +template +#ifdef EXEC_ENV_OLS +bool PQFlashIndex::use_filter_support(MemoryMappedFiles &files) +#else +bool PQFlashIndex::use_filter_support() +#endif +{ + std::string labels_file = _disk_index_file + "_labels.txt"; +#ifdef EXEC_ENV_OLS + return files.fileExists(labels_file); +#else + return file_exists(labels_file); +#endif +} + #ifdef EXEC_ENV_OLS template void PQFlashIndex::load_labels(MemoryMappedFiles &files, const std::string &disk_index_file) @@ -1090,7 +1107,11 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons // bytes are needed to store the header and read in that many using our // 'standard' aligned file reader approach. reader->open(_disk_index_file); - this->setup_thread_data(num_threads); + this->setup_thread_data( + num_threads, + defaults::VISITED_RESERVE, + defaults::MAX_GRAPH_DEGREE, + (use_filter_support(files)? defaults::MAX_FILTERS_PER_QUERY : 0)); this->_max_nthreads = num_threads; char *bytes = getHeaderBytes(); @@ -1180,7 +1201,11 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons // open AlignedFileReader handle to index_file std::string index_fname(_disk_index_file); reader->open(index_fname); - this->setup_thread_data(num_threads); + this->setup_thread_data( + num_threads, + defaults::VISITED_RESERVE, + defaults::MAX_GRAPH_DEGREE, + (use_filter_support()? defaults::MAX_FILTERS_PER_QUERY : 0)); this->_max_nthreads = num_threads; #endif From 69fab8841947434cc8de9751694ddcf40e74923e Mon Sep 17 00:00:00 2001 From: litan1 <106347144+ltan1ms@users.noreply.github.com> Date: Tue, 8 Apr 2025 05:27:26 -0700 Subject: [PATCH 29/38] Fix src\dll\CMakeLists.txt to have diskann_build target only for DLL case --- src/dll/CMakeLists.txt | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/dll/CMakeLists.txt b/src/dll/CMakeLists.txt index 1dc8bf980..090d3b096 100644 --- a/src/dll/CMakeLists.txt +++ b/src/dll/CMakeLists.txt @@ -33,25 +33,25 @@ endif() target_compile_options(${PROJECT_NAME} PRIVATE /GL) target_include_directories(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) -target_compile_options(${PROJECT_NAME}_build PRIVATE /GL) -target_include_directories(${PROJECT_NAME}_build PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) - - if (NOT DEFINED DISKANN_USE_STATIC_LIB) + target_compile_options(${PROJECT_NAME}_build PRIVATE /GL) + target_include_directories(${PROJECT_NAME}_build PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) + target_link_options(${PROJECT_NAME} PRIVATE /DLL /IMPLIB:${DISKANN_DLL_IMPLIB} /LTCG) target_link_options(${PROJECT_NAME}_build PRIVATE /DLL /IMPLIB:${DISKANN_DLL_IMPLIB_BUILD} /LTCG) endif() target_link_libraries(${PROJECT_NAME} PRIVATE ${DISKANN_MKL_LINK_LIBRARIES}) target_link_libraries(${PROJECT_NAME} PRIVATE synchronization.lib) -target_link_libraries(${PROJECT_NAME}_build PRIVATE ${DISKANN_MKL_LINK_LIBRARIES}) -target_link_libraries(${PROJECT_NAME}_build PRIVATE synchronization.lib) - -#This is the crux of the build dll -target_link_libraries(${PROJECT_NAME}_build PUBLIC ${DISKANN_DLL_TCMALLOC_LINK_OPTIONS}) -set_target_properties(${PROJECT_NAME}_build PROPERTIES LINK_FLAGS /INCLUDE:_tcmalloc) - +if (NOT DEFINED DISKANN_USE_STATIC_LIB) + target_link_libraries(${PROJECT_NAME}_build PRIVATE ${DISKANN_MKL_LINK_LIBRARIES}) + target_link_libraries(${PROJECT_NAME}_build PRIVATE synchronization.lib) + + #This is the crux of the build dll + target_link_libraries(${PROJECT_NAME}_build PUBLIC ${DISKANN_DLL_TCMALLOC_LINK_OPTIONS}) + set_target_properties(${PROJECT_NAME}_build PROPERTIES LINK_FLAGS /INCLUDE:_tcmalloc) +endif() if (DISKANN_DLL_TCMALLOC_LINK_OPTIONS) target_link_libraries(${PROJECT_NAME} PUBLIC ${DISKANN_DLL_TCMALLOC_LINK_OPTIONS}) @@ -64,8 +64,10 @@ foreach(RUNTIME_FILE ${RUNTIME_FILES_TO_COPY}) add_custom_command(TARGET ${PROJECT_NAME} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "${RUNTIME_FILE}" "${TARGET_DIR}") +if (NOT DEFINED DISKANN_USE_STATIC_LIB) add_custom_command(TARGET ${PROJECT_NAME}_build POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy "${RUNTIME_FILE}" "${TARGET_DIR}") +endif() -endforeach() \ No newline at end of file +endforeach() From 963918f0362114b3515a74c6af77d71c09e0e086 Mon Sep 17 00:00:00 2001 From: Amr Hisham Said Morsey Date: Tue, 8 Apr 2025 18:10:55 +0200 Subject: [PATCH 30/38] New Allocator (#621) for DLVS code path, replace windows API memory allocation to allow custom allocator. --- include/utils.h | 23 ++++++++++++++++++++--- src/in_mem_data_store.cpp | 6 +++--- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/include/utils.h b/include/utils.h index abce21f62..8e5ddaf56 100644 --- a/include/utils.h +++ b/include/utils.h @@ -263,7 +263,11 @@ inline void alloc_aligned(void **ptr, size_t size, size_t align) #ifndef _WINDOWS *ptr = ::aligned_alloc(align, size); #else - *ptr = ::_aligned_malloc(size, align); // note the swapped arguments! + #ifdef EXEC_ENV_OLS + *ptr = operator new(size, std::align_val_t(align)); + #else + *ptr = ::_aligned_malloc(size, align); // note the swapped arguments! + #endif #endif if (*ptr == nullptr) report_memory_allocation_failure(); @@ -274,7 +278,15 @@ inline void realloc_aligned(void **ptr, size_t size, size_t align) if (IS_ALIGNED(size, align) == 0) report_misalignment_of_requested_size(align); #ifdef _WINDOWS - *ptr = ::_aligned_realloc(*ptr, size, align); + #ifdef EXEC_ENV_OLS + void *newptr; + alloc_aligned(&newptr, size, align); + std::memcpy(newptr, *ptr, size); + operator delete(*ptr, std::align_val_t(1)); + *ptr = newptr; + #else + *ptr = ::_aligned_realloc(*ptr, size, align); + #endif #else diskann::cerr << "No aligned realloc on GCC. Must malloc and mem_align, " "left it out for now." @@ -302,7 +314,12 @@ inline void aligned_free(void *ptr) #ifndef _WINDOWS free(ptr); #else - ::_aligned_free(ptr); + #ifdef EXEC_ENV_OLS + operator delete(ptr, std::align_val_t(1)); + ptr = nullptr; + #else + ::_aligned_free(ptr); + #endif #endif } diff --git a/src/in_mem_data_store.cpp b/src/in_mem_data_store.cpp index e3ed84f3b..fde09febd 100644 --- a/src/in_mem_data_store.cpp +++ b/src/in_mem_data_store.cpp @@ -222,7 +222,7 @@ template location_t InMemDataStore::expand(const locat << this->capacity() << ")" << std::endl; throw diskann::ANNException(ss.str(), -1); } -#ifndef _WINDOWS +#if !defined(_WINDOWS) || defined(EXEC_ENV_OLS) data_t *new_data; alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); memcpy(new_data, _data, this->capacity() * _aligned_dim * sizeof(data_t)); @@ -248,7 +248,7 @@ template location_t InMemDataStore::shrink(const locat << this->capacity() << ")" << std::endl; throw diskann::ANNException(ss.str(), -1); } -#ifndef _WINDOWS +#if !defined(_WINDOWS) || defined(EXEC_ENV_OLS) data_t *new_data; alloc_aligned((void **)&new_data, new_size * _aligned_dim * sizeof(data_t), 8 * sizeof(data_t)); memcpy(new_data, _data, new_size * _aligned_dim * sizeof(data_t)); @@ -378,4 +378,4 @@ template DISKANN_DLLEXPORT class InMemDataStore; template DISKANN_DLLEXPORT class InMemDataStore; template DISKANN_DLLEXPORT class InMemDataStore; -} // namespace diskann \ No newline at end of file +} // namespace diskann From f1394f2811a8ccfdba0bf538e41b24a9206b612e Mon Sep 17 00:00:00 2001 From: jinweizhang Date: Fri, 26 Sep 2025 00:31:09 -0700 Subject: [PATCH 31/38] filtering support --- apps/search_memory_index.cpp | 20 ++++- include/abstract_index.h | 15 ++++ include/index.h | 14 ++++ src/abstract_index.cpp | 149 +++++++++++++++++++++++++++++++++++ src/index.cpp | 88 +++++++++++++++++++++ 5 files changed, 285 insertions(+), 1 deletion(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1bb02c9bc..cf4e5ecea 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -143,6 +143,11 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } double best_recall = 0.0; + std::uint32_t value = 2; + std::function callback_func = [value](const uint32_t &id, float &reRankScore) -> bool { + diskann::cout << "check values for ID: " << id << std::endl; + return id != value; + }; for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) { @@ -186,6 +191,16 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; } } + else if (callback_func) + { + + index->search_with_callback(query + i * query_aligned_dim, recall_at, L, + query_result_tags.data() + i * recall_at, nullptr, res, callback_func); + for (int64_t r = 0; r < (int64_t)recall_at; r++) + { + query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; + } + } else { cmp_stats[i] = index @@ -320,6 +335,9 @@ int main(int argc, char **argv) optional_configs.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), program_options_utils::FAIL_IF_RECALL_BELOW); + //optional_configs.add_options()("callback_func", + // po::value>(&callback_func)->default_value(nullptr), + // program_options_utils::FAIL_IF_CALLBACK_FAIL); // Output controls po::options_description output_controls("Output controls"); @@ -441,7 +459,7 @@ int main(int argc, char **argv) else if (data_type == std::string("uint8")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); } else if (data_type == std::string("float")) diff --git a/include/abstract_index.h b/include/abstract_index.h index 12feec663..031381209 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -64,6 +64,16 @@ class AbstractIndex size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors); + // Initialize space for res_vectors before calling. + template + size_t search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, const std::function callback); + + + //// Initialize space for res_vectors before calling. + //template + //size_t search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, + // float *distances, std::vector &res_vectors, const std::function &callback); + // Added search overload that takes L as parameter, so that we // can customize L on a per-query basis without tampering with "Parameters" // IDtype is either uint32_t or uint64_t @@ -121,6 +131,11 @@ class AbstractIndex virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0; virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors) = 0; + virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + float *distances, DataVector &res_vectors, const std::function callback) = 0; + + //virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + // float *distances, DataVector &res_vectors, const std::function &callback) = 0; virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0; virtual void _set_universal_label(const LabelType universal_label) = 0; }; diff --git a/include/index.h b/include/index.h index fd5db1488..f2addfa67 100644 --- a/include/index.h +++ b/include/index.h @@ -154,6 +154,15 @@ template clas // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors); + // Initialize space for res_vectors before calling. + DISKANN_DLLEXPORT size_t search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, + float *distances, std::vector &res_vectors, + const std::function callback); + + //// Initialize space for res_vectors before calling. + //DISKANN_DLLEXPORT size_t search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, + // float *distances, std::vector &res_vectors, + // const std::function &callback); // Filter support search template @@ -247,6 +256,11 @@ template clas virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors) override; + virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + float *distances, DataVector &res_vectors, const std::function callback) override; + + //virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, + // float *distances, DataVector &res_vectors, const std::function &callback) override; virtual void _set_universal_label(const LabelType universal_label) override; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index a7a5986cc..2dc3eb595 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -32,6 +32,28 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, return this->_search_with_tags(any_query, K, L, any_tags, distances, any_res_vectors); } +template +size_t AbstractIndex::search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, + float *distances, std::vector &res_vectors, + const std::function callback) +{ + auto any_query = std::any(query); + auto any_tags = std::any(tags); + auto any_res_vectors = DataVector(res_vectors); + return this->_search_with_callback(any_query, K, L, any_tags, distances, any_res_vectors, callback); +} + +//template +//size_t AbstractIndex::search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, +// float *distances, std::vector &res_vectors, +// const std::function &callback) +//{ +// auto any_query = std::any(query); +// auto any_tags = std::any(tags); +// auto any_res_vectors = DataVector(res_vectors); +// return this->_search_with_callback(any_query, K, L, any_tags, distances, any_res_vectors, callback); +//} + template std::pair AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, IndexType *indices, @@ -218,6 +240,133 @@ template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags &res_vectors); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, + const uint32_t L, int32_t *tags, + float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const uint8_t *query, const uint64_t K, const uint32_t L, + int32_t *tags, float *distances, std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const int8_t *query, + const uint64_t K, const uint32_t L, + int32_t *tags, float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const float *query, const uint64_t K, + const uint32_t L, uint32_t *tags, + float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const int8_t *query, + const uint64_t K, const uint32_t L, + uint32_t *tags, float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t +AbstractIndex::search_with_callback(const float *query, const uint64_t K, + const uint32_t L, int64_t *tags, + float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const uint8_t *query, const uint64_t K, const uint32_t L, + int64_t *tags, float *distances, std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const int8_t *query, + const uint64_t K, const uint32_t L, + int64_t *tags, float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const float *query, const uint64_t K, + const uint32_t L, uint64_t *tags, + float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, + std::vector &res_vectors, const std::function callback); + +template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( + const int8_t *query, + const uint64_t K, const uint32_t L, + uint64_t *tags, float *distances, + std::vector &res_vectors, const std::function callback); + + + +////-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, +// const uint32_t L, int32_t *tags, +// float *distances, +// std::vector &res_vectors,const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( +// const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, +// std::vector &res_vectors, const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, +// const uint64_t K, const uint32_t L, +// int32_t *tags, float *distances, +// std::vector &res_vectors,const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, +// const uint32_t L, uint32_t *tags, +// float *distances, +// std::vector &res_vectors,const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( +// const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, +// std::vector &res_vectors,const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, +// const uint64_t K, const uint32_t L, +// uint32_t *tags, float *distances, +// std::vector &res_vectors,const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, +// const uint32_t L, int64_t *tags, +// float *distances, +// std::vector &res_vectors,const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t +//AbstractIndex::search_with_callback(const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, +// std::vector &res_vectors, const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, +// const uint64_t K, const uint32_t L, +// int64_t *tags, float *distances, +// std::vector &res_vectors, const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, +// const uint32_t L, uint64_t *tags, +// float *distances, +// std::vector &res_vectors, const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( +// const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, +// std::vector &res_vectors, const std::function &callback); +// +//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, +// const uint64_t K, const uint32_t L, +// uint64_t *tags, float *distances, +// std::vector &res_vectors, const std::function &callback); + +//-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - + + template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, size_t L, uint32_t *indices); template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const uint8_t *query, size_t K, diff --git a/src/index.cpp b/src/index.cpp index 92007fbb0..9ccef7f14 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2500,6 +2500,94 @@ size_t Index::search_with_tags(const T *query, const uint64_t K return pos; } +template +size_t Index::_search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, + const TagType &tags, float *distances, DataVector &res_vectors, + const std::function callback) +{ + try + { + return this->search_with_callback(std::any_cast(query), K, L, std::any_cast(tags), distances, + res_vectors.get>(), callback); + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while performing _search_with_tags() " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + +template +size_t Index::search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, + float *distances, std::vector &res_vectors, + const std::function callback) +{ + if (K > (uint64_t)L) + { + throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__); + } + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + + if (L > scratch->get_L()) + { + diskann::cout << "Attempting to expand query scratch_space. Was created " + << "with Lsize: " << scratch->get_L() << " but search L is: " << L << std::endl; + scratch->resize_for_new_L(L); + diskann::cout << "Resize completed. New scratch->L is " << scratch->get_L() << std::endl; + } + + std::shared_lock ul(_update_lock); + + const std::vector init_ids = get_init_ids(); + const std::vector unused_filter_label; + + //_distance->preprocess_query(query, _data_store->get_dims(), + // scratch->aligned_query()); + _data_store->get_dist_fn()->preprocess_query(query, _data_store->get_dims(), scratch->aligned_query()); + iterate_to_fixed_point(scratch->aligned_query(), L, init_ids, scratch, false, unused_filter_label, true); + + NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); + assert(best_L_nodes.size() <= L); + + std::shared_lock tl(_tag_lock); + + size_t pos = 0; + for (size_t i = 0; i < best_L_nodes.size(); ++i) + { + auto node = best_L_nodes[i]; + + TagT tag; + if (_location_to_tag.try_get(node.id, tag)) + { + tags[pos] = tag; + + if (res_vectors.size() > 0) + { + _data_store->get_vector(node.id, res_vectors[pos]); + } + + if (distances != nullptr) + { +#ifdef EXEC_ENV_OLS + distances[pos] = node.distance; // DLVS expects negative distances +#else + distances[pos] = _dist_metric == INNER_PRODUCT ? -1 * node.distance : node.distance; +#endif + } + pos++; + // If res_vectors.size() < k, clip at the value. + if (pos == K || pos == res_vectors.size()) + break; + } + } + + return pos; +} + template size_t Index::get_num_points() { std::shared_lock tl(_tag_lock); From bf5b03dd5cf4f8ab2cf5faa5ef94ebf2618850a1 Mon Sep 17 00:00:00 2001 From: jinwei14 Date: Fri, 26 Sep 2025 01:10:44 -0700 Subject: [PATCH 32/38] switch git account, switch from uint32 to int64 --- apps/search_memory_index.cpp | 8 ++- include/abstract_index.h | 10 +--- include/index.h | 12 +---- src/abstract_index.cpp | 100 +++++------------------------------ src/index.cpp | 4 +- 5 files changed, 22 insertions(+), 112 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index cf4e5ecea..e0e35792f 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -143,8 +143,9 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } double best_recall = 0.0; - std::uint32_t value = 2; - std::function callback_func = [value](const uint32_t &id, float &reRankScore) -> bool { + std::int64_t value = 2; + std::function callback_func = [value](const int64_t &id, + float &reRankScore) -> bool { diskann::cout << "check values for ID: " << id << std::endl; return id != value; }; @@ -335,9 +336,6 @@ int main(int argc, char **argv) optional_configs.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), program_options_utils::FAIL_IF_RECALL_BELOW); - //optional_configs.add_options()("callback_func", - // po::value>(&callback_func)->default_value(nullptr), - // program_options_utils::FAIL_IF_CALLBACK_FAIL); // Output controls po::options_description output_controls("Output controls"); diff --git a/include/abstract_index.h b/include/abstract_index.h index 031381209..496eac65e 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -66,13 +66,7 @@ class AbstractIndex // Initialize space for res_vectors before calling. template - size_t search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, const std::function callback); - - - //// Initialize space for res_vectors before calling. - //template - //size_t search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, - // float *distances, std::vector &res_vectors, const std::function &callback); + size_t search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, const std::function callback); // Added search overload that takes L as parameter, so that we // can customize L on a per-query basis without tampering with "Parameters" @@ -132,7 +126,7 @@ class AbstractIndex virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors) = 0; virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, const std::function callback) = 0; + float *distances, DataVector &res_vectors, const std::function callback) = 0; //virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, // float *distances, DataVector &res_vectors, const std::function &callback) = 0; diff --git a/include/index.h b/include/index.h index f2addfa67..6b8fd2797 100644 --- a/include/index.h +++ b/include/index.h @@ -157,12 +157,7 @@ template clas // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors, - const std::function callback); - - //// Initialize space for res_vectors before calling. - //DISKANN_DLLEXPORT size_t search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, - // float *distances, std::vector &res_vectors, - // const std::function &callback); + const std::function callback); // Filter support search template @@ -257,10 +252,7 @@ template clas virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors) override; virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, const std::function callback) override; - - //virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - // float *distances, DataVector &res_vectors, const std::function &callback) override; + float *distances, DataVector &res_vectors, const std::function callback) override; virtual void _set_universal_label(const LabelType universal_label) override; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 2dc3eb595..a0b6f2570 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -35,25 +35,13 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, template size_t AbstractIndex::search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, - const std::function callback) + const std::function callback) { auto any_query = std::any(query); auto any_tags = std::any(tags); auto any_res_vectors = DataVector(res_vectors); return this->_search_with_callback(any_query, K, L, any_tags, distances, any_res_vectors, callback); } - -//template -//size_t AbstractIndex::search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, -// float *distances, std::vector &res_vectors, -// const std::function &callback) -//{ -// auto any_query = std::any(query); -// auto any_tags = std::any(tags); -// auto any_res_vectors = DataVector(res_vectors); -// return this->_search_with_callback(any_query, K, L, any_tags, distances, any_res_vectors, callback); -//} - template std::pair AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, IndexType *indices, @@ -244,127 +232,65 @@ template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, - int32_t *tags, float *distances, std::vector &res_vectors, const std::function callback); + int32_t *tags, float *distances, std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, - int64_t *tags, float *distances, std::vector &res_vectors, const std::function callback); + int64_t *tags, float *distances, std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); - - - -////-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, -// const uint32_t L, int32_t *tags, -// float *distances, -// std::vector &res_vectors,const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( -// const uint8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, -// std::vector &res_vectors, const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, -// const uint64_t K, const uint32_t L, -// int32_t *tags, float *distances, -// std::vector &res_vectors,const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, -// const uint32_t L, uint32_t *tags, -// float *distances, -// std::vector &res_vectors,const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( -// const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, -// std::vector &res_vectors,const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, -// const uint64_t K, const uint32_t L, -// uint32_t *tags, float *distances, -// std::vector &res_vectors,const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, -// const uint32_t L, int64_t *tags, -// float *distances, -// std::vector &res_vectors,const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t -//AbstractIndex::search_with_callback(const uint8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, -// std::vector &res_vectors, const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, -// const uint64_t K, const uint32_t L, -// int64_t *tags, float *distances, -// std::vector &res_vectors, const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, -// const uint32_t L, uint64_t *tags, -// float *distances, -// std::vector &res_vectors, const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( -// const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, -// std::vector &res_vectors, const std::function &callback); -// -//template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const int8_t *query, -// const uint64_t K, const uint32_t L, -// uint64_t *tags, float *distances, -// std::vector &res_vectors, const std::function &callback); - -//-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- - + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, diff --git a/src/index.cpp b/src/index.cpp index 9ccef7f14..7ac6b5e39 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2503,7 +2503,7 @@ size_t Index::search_with_tags(const T *query, const uint64_t K template size_t Index::_search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors, - const std::function callback) + const std::function callback) { try { @@ -2523,7 +2523,7 @@ size_t Index::_search_with_callback(const DataType &query, cons template size_t Index::search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors, - const std::function callback) + const std::function callback) { if (K > (uint64_t)L) { From 9d44eda2f930d4c40316abc82143bc6daebe34f3 Mon Sep 17 00:00:00 2001 From: jinwei14 Date: Sat, 27 Sep 2025 20:25:49 -0700 Subject: [PATCH 33/38] add early stop. init --- apps/search_memory_index.cpp | 5 ++--- include/abstract_index.h | 4 ++-- include/index.h | 4 ++-- src/abstract_index.cpp | 26 +++++++++++++------------- src/index.cpp | 4 ++-- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index e0e35792f..46b44600a 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT license. #include @@ -144,8 +144,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, double best_recall = 0.0; std::int64_t value = 2; - std::function callback_func = [value](const int64_t &id, - float &reRankScore) -> bool { + std::function callback_func = [value](const int64_t &id, float &reRankScore, bool &earlystop) -> bool { diskann::cout << "check values for ID: " << id << std::endl; return id != value; }; diff --git a/include/abstract_index.h b/include/abstract_index.h index 496eac65e..637347708 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -66,7 +66,7 @@ class AbstractIndex // Initialize space for res_vectors before calling. template - size_t search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, const std::function callback); + size_t search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, const std::function callback); // Added search overload that takes L as parameter, so that we // can customize L on a per-query basis without tampering with "Parameters" @@ -126,7 +126,7 @@ class AbstractIndex virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors) = 0; virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, const std::function callback) = 0; + float *distances, DataVector &res_vectors, const std::function callback) = 0; //virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, // float *distances, DataVector &res_vectors, const std::function &callback) = 0; diff --git a/include/index.h b/include/index.h index 6b8fd2797..791680ff6 100644 --- a/include/index.h +++ b/include/index.h @@ -157,7 +157,7 @@ template clas // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors, - const std::function callback); + const std::function callback); // Filter support search template @@ -252,7 +252,7 @@ template clas virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors) override; virtual size_t _search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, - float *distances, DataVector &res_vectors, const std::function callback) override; + float *distances, DataVector &res_vectors, const std::function callback) override; virtual void _set_universal_label(const LabelType universal_label) override; diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index a0b6f2570..e6686ec85 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -35,7 +35,7 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, template size_t AbstractIndex::search_with_callback(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, - const std::function callback) + const std::function callback) { auto any_query = std::any(query); auto any_tags = std::any(tags); @@ -232,65 +232,65 @@ template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_tags(const float *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, - int32_t *tags, float *distances, std::vector &res_vectors, const std::function callback); + int32_t *tags, float *distances, std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, int32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const float *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, uint32_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback(const float *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, - int64_t *tags, float *distances, std::vector &res_vectors, const std::function callback); + int64_t *tags, float *distances, std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, int64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const float *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const uint8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT size_t AbstractIndex::search_with_callback( const int8_t *query, const uint64_t K, const uint32_t L, uint64_t *tags, float *distances, - std::vector &res_vectors, const std::function callback); + std::vector &res_vectors, const std::function callback); template DISKANN_DLLEXPORT void AbstractIndex::search_with_optimized_layout(const float *query, size_t K, diff --git a/src/index.cpp b/src/index.cpp index 7ac6b5e39..d489eb225 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -2503,7 +2503,7 @@ size_t Index::search_with_tags(const T *query, const uint64_t K template size_t Index::_search_with_callback(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags, float *distances, DataVector &res_vectors, - const std::function callback) + const std::function callback) { try { @@ -2523,7 +2523,7 @@ size_t Index::_search_with_callback(const DataType &query, cons template size_t Index::search_with_callback(const T *query, const uint64_t K, const uint32_t L, TagT *tags, float *distances, std::vector &res_vectors, - const std::function callback) + const std::function callback) { if (K > (uint64_t)L) { From adcb6c8272f7b9c246ad9c0f629e96d834053824 Mon Sep 17 00:00:00 2001 From: jinwei14 Date: Sun, 28 Sep 2025 00:18:44 -0700 Subject: [PATCH 34/38] callback retrival id logic for internal id --- include/index.h | 13 +++ src/index.cpp | 227 +++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 238 insertions(+), 2 deletions(-) diff --git a/include/index.h b/include/index.h index 791680ff6..72e9f30f7 100644 --- a/include/index.h +++ b/include/index.h @@ -284,6 +284,19 @@ template clas InMemQueryScratch *scratch, bool use_filter, const std::vector &filters, bool search_invocation); + // Callback variant: callback(doc_id (int64_t), distance (in/out), early_terminate (out)) + // Return value of callback == true -> accept node, false -> skip adding node. + // If early_terminate is set to true by callback the search stops early. + std::pair iterate_to_fixed_point_callback( + const T *query, + uint32_t Lsize, + const std::vector &init_ids, + InMemQueryScratch *scratch, + bool use_filter, + const std::vector &filter_labels, + bool search_invocation, + const std::function &callback); + void search_for_point_and_prune(int location, uint32_t Lindex, std::vector &pruned_list, InMemQueryScratch *scratch, bool use_filter = false, uint32_t filteredLindex = 0); diff --git a/src/index.cpp b/src/index.cpp index d489eb225..0ec3cbd46 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1244,6 +1244,229 @@ std::pair Index::iterate_to_fixed_point( return std::make_pair(hops, cmps); } + +// ADD IMPLEMENTATION FOR CALLBACK BASED ITERATE TO FIXED POINT + +template +std::pair Index::iterate_to_fixed_point_callback( + const T *query, uint32_t Lsize, const std::vector &init_ids, InMemQueryScratch *scratch, + bool use_filter, const std::vector &filter_labels, bool search_invocation, + const std::function &callback) +{ + if (!callback) + { + return iterate_to_fixed_point(query, Lsize, init_ids, scratch, use_filter, filter_labels, search_invocation); + } + + std::vector &expanded_nodes = scratch->pool(); + NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); + best_L_nodes.reserve(Lsize); + tsl::robin_set &inserted_into_pool_rs = scratch->inserted_into_pool_rs(); + boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs(); + std::vector &id_scratch = scratch->id_scratch(); + std::vector &dist_scratch = scratch->dist_scratch(); + assert(id_scratch.empty()); + + T *aligned_query = scratch->aligned_query(); + + float *query_float = nullptr; + float *query_rotated = nullptr; + float *pq_dists = nullptr; + uint8_t *pq_coord_scratch = nullptr; + if (_pq_dist) + { + PQScratch *pq_query_scratch = scratch->pq_scratch(); + query_float = pq_query_scratch->aligned_query_float; + query_rotated = pq_query_scratch->rotated_query; + pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch; + for (size_t d = 0; d < _dim; d++) + query_float[d] = (float)aligned_query[d]; + pq_query_scratch->initialize(_dim, aligned_query); + _pq_table.preprocess_query(query_rotated); + _pq_table.populate_chunk_distances(query_rotated, pq_dists); + pq_coord_scratch = pq_query_scratch->aligned_pq_coord_scratch; + } + + if (!expanded_nodes.empty() || !id_scratch.empty()) + throw ANNException("ERROR: Clear scratch space before passing.", -1, __FUNCSIG__, __FILE__, __LINE__); + + auto total_num_points = _max_points + _num_frozen_pts; + bool fast_iterate = total_num_points <= MAX_POINTS_FOR_USING_BITSET; + if (fast_iterate && inserted_into_pool_bs.size() < total_num_points) + { + auto resize_size = + 2 * total_num_points > MAX_POINTS_FOR_USING_BITSET ? MAX_POINTS_FOR_USING_BITSET : 2 * total_num_points; + inserted_into_pool_bs.resize(resize_size); + } + + auto is_not_visited = [fast_iterate, &inserted_into_pool_bs, &inserted_into_pool_rs](uint32_t id) { + return fast_iterate ? inserted_into_pool_bs[id] == 0 + : inserted_into_pool_rs.find(id) == inserted_into_pool_rs.end(); + }; + + auto compute_dists = [this, pq_coord_scratch, pq_dists](const std::vector &ids, + std::vector &dists_out) { + diskann::aggregate_coords(ids, this->_pq_data, this->_num_pq_chunks, pq_coord_scratch); + diskann::pq_dist_lookup(pq_coord_scratch, ids.size(), this->_num_pq_chunks, pq_dists, dists_out); + }; + + auto add_candidate = [&](uint32_t id) -> bool { + if (id >= _max_points + _num_frozen_pts) + return false; + if (use_filter && !detect_common_filters(id, search_invocation, filter_labels)) + return false; + + if (!is_not_visited(id)) + return false; + + if (fast_iterate) + inserted_into_pool_bs[id] = 1; + else + inserted_into_pool_rs.insert(id); + + float distance; + if (_pq_dist) + pq_dist_lookup(pq_coord_scratch, 1, this->_num_pq_chunks, pq_dists, &distance); + else + distance = _data_store->get_distance(aligned_query, id); + + // Map node id to doc id (if tags enabled use tag value, else use id) + int64_t doc_id; + if (_enable_tags) + { + TagT tag_val; + if (_location_to_tag.try_get(id, tag_val)) + doc_id = static_cast(tag_val); + else + doc_id = static_cast(id); + } + else + { + doc_id = static_cast(id); + } + + bool early_terminate = false; + float dist_ref = distance; + bool accept = callback(doc_id, dist_ref, early_terminate); + if (early_terminate) + { + if (accept) + { + Neighbor nn{id, dist_ref}; + best_L_nodes.insert(nn); + } + return false; // signal + } + if (accept) + { + Neighbor nn{id, dist_ref}; + best_L_nodes.insert(nn); + } + return true; + }; + + // Seed + for (auto id : init_ids) + { + if (add_candidate(id)) + return std::make_pair(0u, 0u); + } + + uint32_t hops = 0; + uint32_t cmps = 0; + bool terminate = false; + + while (!terminate && best_L_nodes.has_unexpanded_node()) + { + auto nbr = best_L_nodes.closest_unexpanded(); + auto n = nbr.id; + + if (!search_invocation) + { + if (!use_filter) + expanded_nodes.emplace_back(nbr); + else if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == expanded_nodes.end()) + expanded_nodes.emplace_back(nbr); + } + + id_scratch.clear(); + dist_scratch.clear(); + { + if (_dynamic_index) + _locks[n].lock(); + for (auto id : _graph_store->get_neighbours(n)) + { + if (use_filter && !detect_common_filters(id, search_invocation, filter_labels)) + continue; + if (is_not_visited(id)) + id_scratch.push_back(id); + } + if (_dynamic_index) + _locks[n].unlock(); + } + + for (auto id : id_scratch) + { + if (fast_iterate) + inserted_into_pool_bs[id] = 1; + else + inserted_into_pool_rs.insert(id); + } + + if (_pq_dist) + { + assert(dist_scratch.capacity() >= id_scratch.size()); + compute_dists(id_scratch, dist_scratch); + } + else + { + for (size_t m = 0; m < id_scratch.size(); ++m) + { + uint32_t id = id_scratch[m]; + if (m + 1 < id_scratch.size()) + _data_store->prefetch_vector(id_scratch[m + 1]); + dist_scratch.push_back(_data_store->get_distance(aligned_query, id)); + } + } + cmps += (uint32_t)id_scratch.size(); + + for (size_t m = 0; m < id_scratch.size(); ++m) + { + uint32_t id = id_scratch[m]; + + // compute / override distance if PQ or normal already computed + float &dist_ref = _pq_dist ? dist_scratch[m] : dist_scratch[m]; + + int64_t doc_id; + if (_enable_tags) + { + TagT tag_val; + if (_location_to_tag.try_get(id, tag_val)) + doc_id = static_cast(tag_val); + else + doc_id = static_cast(id); + } + else + { + doc_id = static_cast(id); + } + + bool early_terminate = false; + bool accept = callback(doc_id, dist_ref, early_terminate); + if (accept) + best_L_nodes.insert(Neighbor(id, dist_ref)); + if (early_terminate) + { + terminate = true; + break; + } + } + ++hops; + } + + return std::make_pair(hops, cmps); +} + template void Index::search_for_point_and_prune(int location, uint32_t Lindex, std::vector &pruned_list, @@ -2548,8 +2771,8 @@ size_t Index::search_with_callback(const T *query, const uint64 //_distance->preprocess_query(query, _data_store->get_dims(), // scratch->aligned_query()); _data_store->get_dist_fn()->preprocess_query(query, _data_store->get_dims(), scratch->aligned_query()); - iterate_to_fixed_point(scratch->aligned_query(), L, init_ids, scratch, false, unused_filter_label, true); - + iterate_to_fixed_point_callback(scratch->aligned_query(), L, init_ids, scratch, false, unused_filter_label, true, + callback); NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); assert(best_L_nodes.size() <= L); From 6d0f6463dca929da039b411fc341993246f6cb46 Mon Sep 17 00:00:00 2001 From: jinwei14 Date: Mon, 13 Oct 2025 18:37:02 -0700 Subject: [PATCH 35/38] added test case for dist/id filtering (case:%2==0) --- apps/search_memory_index.cpp | 18 ++++++------------ apps/test_streaming_scenario.cpp | 2 +- src/index.cpp | 3 +-- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 46b44600a..6b2e36a62 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -145,8 +145,8 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, double best_recall = 0.0; std::int64_t value = 2; std::function callback_func = [value](const int64_t &id, float &reRankScore, bool &earlystop) -> bool { - diskann::cout << "check values for ID: " << id << std::endl; - return id != value; + //diskann::cout << "check values for ID: " << id << std::endl; + return id % value == 0; }; for (uint32_t test_id = 0; test_id < Lvec.size(); test_id++) @@ -184,18 +184,12 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } else if (tags) { - index->search_with_tags(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res); - for (int64_t r = 0; r < (int64_t)recall_at; r++) - { - query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; + if (callback_func){ + index->search_with_callback(query + i * query_aligned_dim, recall_at, L, query_result_tags.data() + i * recall_at, nullptr, res, callback_func); + }else{ + index->search_with_tags(query + i * query_aligned_dim, recall_at, L, query_result_tags.data() + i * recall_at, nullptr, res); } - } - else if (callback_func) - { - index->search_with_callback(query + i * query_aligned_dim, recall_at, L, - query_result_tags.data() + i * recall_at, nullptr, res, callback_func); for (int64_t r = 0; r < (int64_t)recall_at; r++) { query_result_ids[test_id][recall_at * i + r] = query_result_tags[recall_at * i + r]; diff --git a/apps/test_streaming_scenario.cpp b/apps/test_streaming_scenario.cpp index 5a43a69f3..09291190e 100644 --- a/apps/test_streaming_scenario.cpp +++ b/apps/test_streaming_scenario.cpp @@ -393,7 +393,7 @@ int main(int argc, char **argv) "with each line corresponding to a graph node"); optional_configs.add_options()("universal_label", po::value(&universal_label)->default_value(""), "Universal label, if using it, only in conjunction with labels_file"); - optional_configs.add_options()("FilteredLbuild,Lf", po::value(&Lf)->default_value(0), + optional_configs.add_options()("FilteredLbuild", po::value(&Lf)->default_value(0), "Build complexity for filtered points, higher value " "results in better graphs"); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), diff --git a/src/index.cpp b/src/index.cpp index 0ec3cbd46..4937d77e4 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1368,8 +1368,7 @@ std::pair Index::iterate_to_fixed_point_cal // Seed for (auto id : init_ids) { - if (add_candidate(id)) - return std::make_pair(0u, 0u); + if (!add_candidate(id)) break; } uint32_t hops = 0; From 3538ede9b26a0d991aaacc4c0d6c8051481d032d Mon Sep 17 00:00:00 2001 From: jinwei14 Date: Wed, 15 Oct 2025 20:56:49 -0700 Subject: [PATCH 36/38] add fallback: if no seed accepted --- apps/search_memory_index.cpp | 2 +- src/index.cpp | 34 ++++++++++++++++++---------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 6b2e36a62..13f8b9641 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -143,7 +143,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } double best_recall = 0.0; - std::int64_t value = 2; + std::int64_t value = 3; std::function callback_func = [value](const int64_t &id, float &reRankScore, bool &earlystop) -> bool { //diskann::cout << "check values for ID: " << id << std::endl; return id % value == 0; diff --git a/src/index.cpp b/src/index.cpp index 4937d77e4..59f502f28 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1310,13 +1310,13 @@ std::pair Index::iterate_to_fixed_point_cal diskann::pq_dist_lookup(pq_coord_scratch, ids.size(), this->_num_pq_chunks, pq_dists, dists_out); }; - auto add_candidate = [&](uint32_t id) -> bool { + // Add candidate (used both for initial seeds and neighbor expansions) + auto add_candidate = [&](uint32_t id, bool force_insert_for_navigation = false) -> bool { if (id >= _max_points + _num_frozen_pts) return false; - if (use_filter && !detect_common_filters(id, search_invocation, filter_labels)) + if (use_filter && !detect_common_filters(id, search_invocation, filter_labels) && !force_insert_for_navigation) return false; - - if (!is_not_visited(id)) + if (!is_not_visited(id) && !force_insert_for_navigation) return false; if (fast_iterate) @@ -1347,30 +1347,32 @@ std::pair Index::iterate_to_fixed_point_cal bool early_terminate = false; float dist_ref = distance; - bool accept = callback(doc_id, dist_ref, early_terminate); - if (early_terminate) - { - if (accept) - { - Neighbor nn{id, dist_ref}; - best_L_nodes.insert(nn); - } - return false; // signal - } - if (accept) + bool accept = force_insert_for_navigation ? true : callback(doc_id, dist_ref, early_terminate); + + // Even if not accepted, we may need the node for navigation to avoid empty frontier. + if (accept || force_insert_for_navigation) { Neighbor nn{id, dist_ref}; best_L_nodes.insert(nn); } + + if (early_terminate) + return false; return true; }; - // Seed + // Seed initial ids via callback for (auto id : init_ids) { if (!add_candidate(id)) break; } + // Fallback: if no seed accepted, forcibly insert first init id to allow traversal + if (best_L_nodes.size() == 0 && !init_ids.empty()) + { + add_candidate(init_ids[0], true); + } + uint32_t hops = 0; uint32_t cmps = 0; bool terminate = false; From d4b6ef93891e24f6e2e4777c49c7ec5fd7fe4ad9 Mon Sep 17 00:00:00 2001 From: "jinwei.zhang" Date: Tue, 11 Nov 2025 19:38:39 -0800 Subject: [PATCH 37/38] extract callback at filter level instead of neighbour --- include/index.h | 4 + src/index.cpp | 232 +++++++++++++++++++++++++++++------------------- 2 files changed, 145 insertions(+), 91 deletions(-) diff --git a/include/index.h b/include/index.h index 72e9f30f7..c4f6a8cb4 100644 --- a/include/index.h +++ b/include/index.h @@ -107,6 +107,10 @@ template clas DISKANN_DLLEXPORT bool detect_common_filters(uint32_t point_id, bool search_invocation, const std::vector &incoming_labels); + DISKANN_DLLEXPORT bool detect_common_callback( + uint32_t point_id, float &distance, + const std::function &callback, bool &terminate_flag); + // Batch build from a file. Optionally pass tags vector. DISKANN_DLLEXPORT void build(const char *filename, const size_t num_points_to_load, const std::vector &tags = std::vector()); diff --git a/src/index.cpp b/src/index.cpp index 59f502f28..ccac488ae 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1029,6 +1029,37 @@ bool Index::detect_common_filters(uint32_t point_id, bool searc return (common_filters.size() > 0); } +template +bool Index::detect_common_callback( + uint32_t point_id, float &distance, const std::function &callback, + bool &terminate_flag) +{ + if (!callback) + { + return true; + } + + int64_t doc_id = static_cast(point_id); + if (_enable_tags) + { + TagT tag_val; + if (_location_to_tag.try_get(point_id, tag_val)) + { + doc_id = static_cast(tag_val); + } + } + + float distance_copy = distance; + bool early_terminate = false; + bool accept = callback(doc_id, distance_copy, early_terminate); + distance = distance_copy; + if (early_terminate) + { + terminate_flag = true; + } + return accept; +} + template std::pair Index::iterate_to_fixed_point( const T *query, const uint32_t Lsize, const std::vector &init_ids, InMemQueryScratch *scratch, @@ -1246,14 +1277,14 @@ std::pair Index::iterate_to_fixed_point( // ADD IMPLEMENTATION FOR CALLBACK BASED ITERATE TO FIXED POINT - template std::pair Index::iterate_to_fixed_point_callback( const T *query, uint32_t Lsize, const std::vector &init_ids, InMemQueryScratch *scratch, bool use_filter, const std::vector &filter_labels, bool search_invocation, const std::function &callback) { - if (!callback) + const bool use_callback = static_cast(callback); + if (!use_callback) { return iterate_to_fixed_point(query, Lsize, init_ids, scratch, use_filter, filter_labels, search_invocation); } @@ -1280,7 +1311,9 @@ std::pair Index::iterate_to_fixed_point_cal query_rotated = pq_query_scratch->rotated_query; pq_dists = pq_query_scratch->aligned_pqtable_dist_scratch; for (size_t d = 0; d < _dim; d++) + { query_float[d] = (float)aligned_query[d]; + } pq_query_scratch->initialize(_dim, aligned_query); _pq_table.preprocess_query(query_rotated); _pq_table.populate_chunk_distances(query_rotated, pq_dists); @@ -1288,15 +1321,21 @@ std::pair Index::iterate_to_fixed_point_cal } if (!expanded_nodes.empty() || !id_scratch.empty()) + { throw ANNException("ERROR: Clear scratch space before passing.", -1, __FUNCSIG__, __FILE__, __LINE__); + } auto total_num_points = _max_points + _num_frozen_pts; bool fast_iterate = total_num_points <= MAX_POINTS_FOR_USING_BITSET; - if (fast_iterate && inserted_into_pool_bs.size() < total_num_points) + + if (fast_iterate) { - auto resize_size = - 2 * total_num_points > MAX_POINTS_FOR_USING_BITSET ? MAX_POINTS_FOR_USING_BITSET : 2 * total_num_points; - inserted_into_pool_bs.resize(resize_size); + if (inserted_into_pool_bs.size() < total_num_points) + { + auto resize_size = + 2 * total_num_points > MAX_POINTS_FOR_USING_BITSET ? MAX_POINTS_FOR_USING_BITSET : 2 * total_num_points; + inserted_into_pool_bs.resize(resize_size); + } } auto is_not_visited = [fast_iterate, &inserted_into_pool_bs, &inserted_into_pool_rs](uint32_t id) { @@ -1310,84 +1349,96 @@ std::pair Index::iterate_to_fixed_point_cal diskann::pq_dist_lookup(pq_coord_scratch, ids.size(), this->_num_pq_chunks, pq_dists, dists_out); }; - // Add candidate (used both for initial seeds and neighbor expansions) - auto add_candidate = [&](uint32_t id, bool force_insert_for_navigation = false) -> bool { - if (id >= _max_points + _num_frozen_pts) - return false; - if (use_filter && !detect_common_filters(id, search_invocation, filter_labels) && !force_insert_for_navigation) - return false; - if (!is_not_visited(id) && !force_insert_for_navigation) - return false; - - if (fast_iterate) - inserted_into_pool_bs[id] = 1; - else - inserted_into_pool_rs.insert(id); - - float distance; - if (_pq_dist) - pq_dist_lookup(pq_coord_scratch, 1, this->_num_pq_chunks, pq_dists, &distance); - else - distance = _data_store->get_distance(aligned_query, id); + bool callback_terminate = false; - // Map node id to doc id (if tags enabled use tag value, else use id) - int64_t doc_id; - if (_enable_tags) + for (auto id : init_ids) + { + if (id >= _max_points + _num_frozen_pts) { - TagT tag_val; - if (_location_to_tag.try_get(id, tag_val)) - doc_id = static_cast(tag_val); - else - doc_id = static_cast(id); + diskann::cerr << "Out of range loc found as an edge : " << id << std::endl; + throw diskann::ANNException(std::string("Wrong loc") + std::to_string(id), -1, __FUNCSIG__, __FILE__, + __LINE__); } - else + + if (use_filter) { - doc_id = static_cast(id); + if (!detect_common_filters(id, search_invocation, filter_labels)) + continue; } - bool early_terminate = false; - float dist_ref = distance; - bool accept = force_insert_for_navigation ? true : callback(doc_id, dist_ref, early_terminate); - - // Even if not accepted, we may need the node for navigation to avoid empty frontier. - if (accept || force_insert_for_navigation) + if (use_callback) { - Neighbor nn{id, dist_ref}; - best_L_nodes.insert(nn); + if (float temp_distance = _data_store->get_distance(aligned_query, id); + !detect_common_callback(id, temp_distance, callback, callback_terminate)) + { + continue; + } } - if (early_terminate) - return false; - return true; - }; + if (is_not_visited(id)) + { + if (fast_iterate) + { + inserted_into_pool_bs[id] = 1; + } + else + { + inserted_into_pool_rs.insert(id); + } - // Seed initial ids via callback - for (auto id : init_ids) - { - if (!add_candidate(id)) break; + float distance; + if (_pq_dist) + { + pq_dist_lookup(pq_coord_scratch, 1, this->_num_pq_chunks, pq_dists, &distance); + } + else + { + distance = _data_store->get_distance(aligned_query, id); + } + Neighbor nn = Neighbor(id, distance); + best_L_nodes.insert(nn); + } } // Fallback: if no seed accepted, forcibly insert first init id to allow traversal if (best_L_nodes.size() == 0 && !init_ids.empty()) { - add_candidate(init_ids[0], true); + float distance; + if (_pq_dist) + { + pq_dist_lookup(pq_coord_scratch, 1, this->_num_pq_chunks, pq_dists, &distance); + } + else + { + distance = _data_store->get_distance(aligned_query, init_ids[0]); + } + Neighbor nn = Neighbor(init_ids[0], distance); + best_L_nodes.insert(nn); } - uint32_t hops = 0; uint32_t cmps = 0; - bool terminate = false; - while (!terminate && best_L_nodes.has_unexpanded_node()) + while (!callback_terminate && best_L_nodes.has_unexpanded_node()) { auto nbr = best_L_nodes.closest_unexpanded(); auto n = nbr.id; if (!search_invocation) { - if (!use_filter) - expanded_nodes.emplace_back(nbr); - else if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == expanded_nodes.end()) + if (!use_filter && !use_callback) + { expanded_nodes.emplace_back(nbr); + } + else + { + // in filter based indexing, the same point might invoke + // multiple iterate_to_fixed_points, so need to be careful + // not to add the same item to pool multiple times + if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == expanded_nodes.end()) + { + expanded_nodes.emplace_back(nbr); + } + } } id_scratch.clear(); @@ -1397,23 +1448,44 @@ std::pair Index::iterate_to_fixed_point_cal _locks[n].lock(); for (auto id : _graph_store->get_neighbours(n)) { - if (use_filter && !detect_common_filters(id, search_invocation, filter_labels)) - continue; + assert(id < _max_points + _num_frozen_pts); + + if (use_filter) + { + // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. + if (!detect_common_filters(id, search_invocation, filter_labels)) + continue; + } + if (use_callback) + { + if (float temp_distance = _data_store->get_distance(aligned_query, id); + !detect_common_callback(id, temp_distance, callback, callback_terminate)) + continue; + } + if (is_not_visited(id)) + { id_scratch.push_back(id); + } } if (_dynamic_index) _locks[n].unlock(); } + // Mark nodes visited for (auto id : id_scratch) { if (fast_iterate) + { inserted_into_pool_bs[id] = 1; + } else + { inserted_into_pool_rs.insert(id); + } } + // Compute distances to unvisited nodes in the expansion if (_pq_dist) { assert(dist_scratch.capacity() >= id_scratch.size()); @@ -1421,50 +1493,28 @@ std::pair Index::iterate_to_fixed_point_cal } else { + assert(dist_scratch.size() == 0); for (size_t m = 0; m < id_scratch.size(); ++m) { uint32_t id = id_scratch[m]; + if (m + 1 < id_scratch.size()) - _data_store->prefetch_vector(id_scratch[m + 1]); + { + auto nextn = id_scratch[m + 1]; + _data_store->prefetch_vector(nextn); + } + dist_scratch.push_back(_data_store->get_distance(aligned_query, id)); } } cmps += (uint32_t)id_scratch.size(); + // Insert pairs into the pool of candidates for (size_t m = 0; m < id_scratch.size(); ++m) { - uint32_t id = id_scratch[m]; - - // compute / override distance if PQ or normal already computed - float &dist_ref = _pq_dist ? dist_scratch[m] : dist_scratch[m]; - - int64_t doc_id; - if (_enable_tags) - { - TagT tag_val; - if (_location_to_tag.try_get(id, tag_val)) - doc_id = static_cast(tag_val); - else - doc_id = static_cast(id); - } - else - { - doc_id = static_cast(id); - } - - bool early_terminate = false; - bool accept = callback(doc_id, dist_ref, early_terminate); - if (accept) - best_L_nodes.insert(Neighbor(id, dist_ref)); - if (early_terminate) - { - terminate = true; - break; - } + best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); } - ++hops; } - return std::make_pair(hops, cmps); } From 76a92c27f3fcfb6b043a290ca7dd4f9ddbb706ec Mon Sep 17 00:00:00 2001 From: "jinwei.zhang" Date: Tue, 11 Nov 2025 21:20:20 -0800 Subject: [PATCH 38/38] clean up comments and brackets --- src/index.cpp | 32 +++++--------------------------- 1 file changed, 5 insertions(+), 27 deletions(-) diff --git a/src/index.cpp b/src/index.cpp index ccac488ae..80b822785 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -1426,18 +1426,9 @@ std::pair Index::iterate_to_fixed_point_cal if (!search_invocation) { if (!use_filter && !use_callback) - { expanded_nodes.emplace_back(nbr); - } - else - { - // in filter based indexing, the same point might invoke - // multiple iterate_to_fixed_points, so need to be careful - // not to add the same item to pool multiple times - if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == expanded_nodes.end()) - { - expanded_nodes.emplace_back(nbr); - } + else if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == expanded_nodes.end()){ + expanded_nodes.emplace_back(nbr); } } @@ -1452,40 +1443,31 @@ std::pair Index::iterate_to_fixed_point_cal if (use_filter) { - // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. if (!detect_common_filters(id, search_invocation, filter_labels)) continue; } if (use_callback) { - if (float temp_distance = _data_store->get_distance(aligned_query, id); - !detect_common_callback(id, temp_distance, callback, callback_terminate)) + float temp_distance = _data_store->get_distance(aligned_query, id); + if(!detect_common_callback(id, temp_distance, callback, callback_terminate)) continue; } if (is_not_visited(id)) - { id_scratch.push_back(id); - } } if (_dynamic_index) _locks[n].unlock(); } - // Mark nodes visited for (auto id : id_scratch) { if (fast_iterate) - { inserted_into_pool_bs[id] = 1; - } else - { inserted_into_pool_rs.insert(id); - } } - // Compute distances to unvisited nodes in the expansion if (_pq_dist) { assert(dist_scratch.capacity() >= id_scratch.size()); @@ -1497,12 +1479,8 @@ std::pair Index::iterate_to_fixed_point_cal for (size_t m = 0; m < id_scratch.size(); ++m) { uint32_t id = id_scratch[m]; - if (m + 1 < id_scratch.size()) - { - auto nextn = id_scratch[m + 1]; - _data_store->prefetch_vector(nextn); - } + _data_store->prefetch_vector(id_scratch[m + 1]); dist_scratch.push_back(_data_store->get_distance(aligned_query, id)); }