2424#include " percentile_stats.h"
2525#include < bitset>
2626#include " label_bitmask.h"
27+ #include " integer_label_vector.h"
2728
2829#include " quantized_distance.h"
2930#include " pq_data_store.h"
@@ -80,7 +81,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
8081#ifdef EXEC_ENV_OLS
8182 DISKANN_DLLEXPORT void load (AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l);
8283#else
83- DISKANN_DLLEXPORT void load (const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false );
84+ DISKANN_DLLEXPORT void load (const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String );
8485#endif
8586
8687 // get some private variables
@@ -118,6 +119,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
118119
119120 DISKANN_DLLEXPORT bool is_set_universal_label () const override ;
120121
122+ DISKANN_DLLEXPORT void enable_integer_label () override ;
123+
124+ DISKANN_DLLEXPORT bool integer_label_enabled () const override ;
125+
121126 // Set starting point of an index before inserting any points incrementally.
122127 // The data count should be equal to _num_frozen_pts * _aligned_dim.
123128 DISKANN_DLLEXPORT void set_start_points (const T *data, size_t data_count);
@@ -144,15 +149,15 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
144149
145150 // Initialize space for res_vectors before calling.
146151 DISKANN_DLLEXPORT size_t search_with_tags (const T *query, const uint64_t K, const uint32_t L, TagT *tags,
147- float *distances, std::vector<T *> &res_vectors, bool use_filters = false ,
148- const std::string filter_label = " " );
152+ float *distances, std::vector<T *> &res_vectors, bool use_filters,
153+ const std::vector<std:: string>& filter_labels );
149154
150155 virtual std::pair<uint32_t , uint32_t > _diverse_search (const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
151156 std::any& indices, float * distances = nullptr ) override ;
152157
153158 // Filter support search
154159 template <typename IndexType>
155- DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > search_with_filters (const T *query, const LabelT &filter_label ,
160+ DISKANN_DLLEXPORT std::pair<uint32_t , uint32_t > search_with_filters (const T *query, const std::vector< LabelT> &filter_labels ,
156161 const size_t K, const uint32_t L, const uint32_t maxLperSeller,
157162 IndexType *indices, float *distances);
158163
@@ -217,7 +222,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
217222 virtual std::pair<uint32_t , uint32_t > _search (const DataType &query, const size_t K, const uint32_t L,
218223 std::any &indices, float *distances = nullptr ) override ;
219224 virtual std::pair<uint32_t , uint32_t > _search_with_filters (const DataType &query,
220- const std::string &filter_label_raw , const size_t K,
225+ const std::vector<std:: string> &filter_labels_raw , const size_t K,
221226 const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
222227 float *distances) override ;
223228
@@ -237,8 +242,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
237242 virtual void _search_with_optimized_layout (const DataType &query, size_t K, size_t L, uint32_t *indices) override ;
238243
239244 virtual size_t _search_with_tags (const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
240- float *distances, DataVector &res_vectors, bool use_filters = false ,
241- const std::string filter_label = " " ) override ;
245+ float *distances, DataVector &res_vectors, bool use_filters,
246+ const std::vector<std:: string>& filter_labels ) override ;
242247
243248 virtual void _set_universal_label (const LabelType universal_label) override ;
244249
@@ -253,11 +258,18 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
253258 // determines navigating node of the graph by calculating medoid of datafopt
254259 uint32_t calculate_entry_point ();
255260
256- void parse_label_file (const std::string &label_file, size_t &num_pts_labels);
261+ void parse_label_file (const std::string &label_file, size_t &num_pts_labels, size_t & total_labels );
257262 void parse_seller_file (const std::string& label_file, size_t & num_pts_labels);
258263
259264 void convert_pts_label_to_bitmask (std::vector<std::vector<LabelT>>& pts_to_labels, simple_bitmask_buf& bitmask_buf, size_t num_labels);
260265
266+ void convert_pts_label_to_integer_vector (std::vector<std::vector<LabelT>> &pts_to_labels,
267+ integer_label_vector &int_label_vector, size_t total_labels);
268+
269+ void aggregate_points_by_bitmask_label (std::unordered_map<LabelT, std::vector<uint32_t >>& label_to_points, size_t num_points_to_load);
270+
271+ void aggregate_points_by_integer_label (std::unordered_map<LabelT, std::vector<uint32_t >>& label_to_points, size_t num_points_to_load);
272+
261273 std::unordered_map<std::string, LabelT> load_label_map (const std::string &map_file);
262274
263275 // Returns the locations of start point and frozen points suitable for use
@@ -463,6 +475,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
463475
464476 simple_bitmask_buf _bitmask_buf;
465477
478+ bool _use_integer_labels = false ;
479+ integer_label_vector _label_vector;
480+
466481 TableStats _table_stats;
467482
468483 static const float INDEX_GROWTH_FACTOR;
0 commit comments