Skip to content

Commit 10309e4

Browse files
authored
Jegao/multiple filter support (#680)
* change filtered search interface * fix streaming interface * fix compile issue * support integer filter label * fix ssd index building * expose integer label get/set * clean up * Fix ssd medoid file loading * sort labels before searching * resolve comments
1 parent bf1e48e commit 10309e4

15 files changed

+1076
-269
lines changed

include/abstract_index.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class AbstractIndex
5353
#ifdef EXEC_ENV_OLS
5454
virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0;
5555
#else
56-
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false) = 0;
56+
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String) = 0;
5757
#endif
5858

5959
// For FastL2 search on optimized layout
@@ -63,8 +63,8 @@ class AbstractIndex
6363
// Initialize space for res_vectors before calling.
6464
template <typename data_type, typename tag_type>
6565
size_t search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags,
66-
float *distances, std::vector<data_type *> &res_vectors, bool use_filters = false,
67-
const std::string filter_label = "");
66+
float *distances, std::vector<data_type *> &res_vectors, bool use_filters,
67+
const std::vector<std::string>& filter_labels);
6868

6969
// Added search overload that takes L as parameter, so that we
7070
// can customize L on a per-query basis without tampering with "Parameters"
@@ -80,7 +80,7 @@ class AbstractIndex
8080
// Filter support search
8181
// IndexType is either uint32_t or uint64_t
8282
template <typename IndexType>
83-
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
83+
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::vector<std::string> &raw_labels,
8484
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
8585
IndexType *indices,
8686
float *distances);
@@ -112,6 +112,9 @@ class AbstractIndex
112112

113113
template <typename label_type> void set_universal_label(const label_type universal_label);
114114

115+
virtual void enable_integer_label() = 0;
116+
virtual bool integer_label_enabled() const = 0;
117+
115118
virtual bool is_label_valid(const std::string &raw_label) const = 0;
116119
virtual bool is_set_universal_label() const = 0;
117120
virtual TableStats get_table_stats() const = 0;
@@ -122,7 +125,7 @@ class AbstractIndex
122125
std::any &indices, float *distances = nullptr) = 0;
123126
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
124127
std::any& indices, float* distances = nullptr) = 0;
125-
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::string &filter_label,
128+
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query, const std::vector<std::string> &filter_labels,
126129
const size_t K, const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
127130
float *distances) = 0;
128131
virtual int _insert_point(const DataType &data_point, const TagType tag, const std::vector<std::string> &labels) = 0;
@@ -133,8 +136,8 @@ class AbstractIndex
133136
virtual void _set_start_points_at_random(DataType radius, uint32_t random_seed = 0) = 0;
134137
virtual int _get_vector_by_tag(TagType &tag, DataType &vec) = 0;
135138
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
136-
float *distances, DataVector &res_vectors, bool use_filters = false,
137-
const std::string filter_label = "") = 0;
139+
float *distances, DataVector &res_vectors, bool use_filters,
140+
const std::vector<std::string>& filter_labels) = 0;
138141
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) = 0;
139142
virtual void _set_universal_label(const LabelType universal_label) = 0;
140143
};

include/disk_utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann::
7979
uint32_t R, double sampling_rate, double ram_budget,
8080
std::string mem_index_path, std::string medoids_file,
8181
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
82-
uint32_t num_threads, bool use_filters = false,
82+
uint32_t num_threads, bool use_filters = false, bool use_integer_labels = false,
8383
const std::string &label_file = std::string(""),
8484
const std::string &labels_to_medoids_file = std::string(""),
8585
const std::string &universal_label = "", const uint32_t Lf = 0);
@@ -95,7 +95,7 @@ DISKANN_DLLEXPORT int build_disk_index(
9595
const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters,
9696
diskann::Metric _compareMetric, bool use_opq = false,
9797
const std::string &codebook_prefix = "", // default is empty for no codebook pass in
98-
bool use_filters = false,
98+
bool use_filters = false, bool use_integer_labels = false,
9999
const std::string &label_file = std::string(""), // default is empty string for no label_file
100100
const std::string &universal_label = "", const uint32_t filter_threshold = 0,
101101
const uint32_t Lf = 0,

include/filter_match_proxy.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#pragma once
2+
#include "label_bitmask.h"
3+
#include "integer_label_vector.h"
4+
5+
namespace diskann
6+
{
7+
8+
// Abstract interface for deciding whether a point matches the labels of the
// current filtered query. Concrete matchers derive from this class and
// implement contain_filtered_label().
class filter_match_proxy
{
public:
    // Polymorphic base: the destructor is virtual so that a derived matcher
    // can be destroyed safely through a filter_match_proxy pointer/reference.
    virtual ~filter_match_proxy() = default;

    // Returns true when the point identified by `id` matches the query filter.
    // The exact matching rule is defined by the derived class.
    virtual bool contain_filtered_label(uint32_t id) = 0;
};
13+
14+
// filter_match_proxy implementation declared for labels held in a
// simple_bitmask_buf. Only declarations are visible here; the member
// definitions live elsewhere.
template <typename LabelT>
15+
class bitmask_filter_match : public filter_match_proxy
16+
{
17+
public:
18+
// Takes the bitmask buffer, a query bitmask buffer, the filter labels of the
// query and the universal label value.
// NOTE(review): presumably the query bitmask is built from filter_labels and
// unv_label inside the constructor -- confirm against the definition.
bitmask_filter_match(simple_bitmask_buf& bitmask_filters,
19+
std::vector<std::uint64_t>& query_bitmask_buf,
20+
const std::vector<LabelT>& filter_labels,
21+
LabelT unv_label);
22+
23+
// Reports whether point `id` matches the query filter.
virtual bool contain_filtered_label(uint32_t id) override;
24+
25+
private:
26+
// Reference members: the referenced objects must outlive this matcher.
simple_bitmask_buf& _bitmask_filters;
27+
std::vector<std::uint64_t>& _query_bitmask_buf;
28+
// Held by value (unlike the two references above).
simple_bitmask_full_val _bitmask_full_val;
29+
};
30+
31+
// filter_match_proxy implementation declared for labels held in an
// integer_label_vector. Only declarations are visible here; the member
// definitions live elsewhere.
template <typename LabelT>
32+
class integer_label_filter_match : public filter_match_proxy
33+
{
34+
public:
35+
// Takes the per-point label store, the filter labels of the query and the
// universal label value.
integer_label_filter_match(integer_label_vector& label_vector,
36+
const std::vector<LabelT>& filter_labels,
37+
LabelT unv_label);
38+
39+
// Reports whether point `id` matches the query filter.
virtual bool contain_filtered_label(uint32_t id) override;
40+
41+
private:
42+
// Reference members: the referenced objects must outlive this matcher.
integer_label_vector& _label_vector;
43+
// NOTE(review): const-reference member; if it is bound to the constructor
// argument, passing a temporary vector would leave it dangling -- confirm
// that callers keep the label vector alive for the whole search.
const std::vector<LabelT>& _filter_labels;
44+
// Universal label value, stored by value.
LabelT _unv_label;
45+
};
46+
47+
// filter_match_proxy that owns one bitmask_filter_match and one
// integer_label_filter_match by value, plus a flag naming the label format.
// NOTE(review): presumably contain_filtered_label() forwards to one of the
// two matchers depending on _use_integer_labels -- the definition is not
// visible here, confirm.
template <typename LabelT>
48+
class label_filter_match_holder : public filter_match_proxy
49+
{
50+
public:
51+
// Accepts the union of the arguments needed by both contained matchers,
// followed by the flag selecting the integer-label representation.
label_filter_match_holder(simple_bitmask_buf& bitmask_filters,
52+
std::vector<std::uint64_t>& query_bitmask_buf,
53+
integer_label_vector& label_vector,
54+
const std::vector<LabelT>& filter_labels,
55+
LabelT unv_label,
56+
bool use_integer_labels);
57+
58+
// Reports whether point `id` matches the query filter.
virtual bool contain_filtered_label(uint32_t id) override;
59+
60+
private:
61+
// Both matchers are members by value, so both are constructed for every
// holder regardless of the flag below.
bitmask_filter_match<LabelT> _bitmask_filter_match;
62+
integer_label_filter_match<LabelT> _integer_label_filter_match;
63+
bool _use_integer_labels;
64+
};
65+
66+
}

include/index.h

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "percentile_stats.h"
2525
#include <bitset>
2626
#include "label_bitmask.h"
27+
#include "integer_label_vector.h"
2728

2829
#include "quantized_distance.h"
2930
#include "pq_data_store.h"
@@ -80,7 +81,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
8081
#ifdef EXEC_ENV_OLS
8182
DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l);
8283
#else
83-
DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false);
84+
DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String);
8485
#endif
8586

8687
// get some private variables
@@ -118,6 +119,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
118119

119120
DISKANN_DLLEXPORT bool is_set_universal_label() const override;
120121

122+
DISKANN_DLLEXPORT void enable_integer_label() override;
123+
124+
DISKANN_DLLEXPORT bool integer_label_enabled() const override;
125+
121126
// Set starting point of an index before inserting any points incrementally.
122127
// The data count should be equal to _num_frozen_pts * _aligned_dim.
123128
DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count);
@@ -144,15 +149,15 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
144149

145150
// Initialize space for res_vectors before calling.
146151
DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags,
147-
float *distances, std::vector<T *> &res_vectors, bool use_filters = false,
148-
const std::string filter_label = "");
152+
float *distances, std::vector<T *> &res_vectors, bool use_filters,
153+
const std::vector<std::string>& filter_labels);
149154

150155
virtual std::pair<uint32_t, uint32_t> _diverse_search(const DataType& query, const size_t K, const uint32_t L, const uint32_t maxLperSeller,
151156
std::any& indices, float* distances = nullptr) override;
152157

153158
// Filter support search
154159
template <typename IndexType>
155-
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const LabelT &filter_label,
160+
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query, const std::vector<LabelT> &filter_labels,
156161
const size_t K, const uint32_t L, const uint32_t maxLperSeller,
157162
IndexType *indices, float *distances);
158163

@@ -217,7 +222,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
217222
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
218223
std::any &indices, float *distances = nullptr) override;
219224
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
220-
const std::string &filter_label_raw, const size_t K,
225+
const std::vector<std::string> &filter_labels_raw, const size_t K,
221226
const uint32_t L, const uint32_t maxLperSeller, std::any &indices,
222227
float *distances) override;
223228

@@ -237,8 +242,8 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
237242
virtual void _search_with_optimized_layout(const DataType &query, size_t K, size_t L, uint32_t *indices) override;
238243

239244
virtual size_t _search_with_tags(const DataType &query, const uint64_t K, const uint32_t L, const TagType &tags,
240-
float *distances, DataVector &res_vectors, bool use_filters = false,
241-
const std::string filter_label = "") override;
245+
float *distances, DataVector &res_vectors, bool use_filters,
246+
const std::vector<std::string>& filter_labels) override;
242247

243248
virtual void _set_universal_label(const LabelType universal_label) override;
244249

@@ -253,11 +258,18 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
253258
// determines navigating node of the graph by calculating medoid of data
254259
uint32_t calculate_entry_point();
255260

256-
void parse_label_file(const std::string &label_file, size_t &num_pts_labels);
261+
void parse_label_file(const std::string &label_file, size_t &num_pts_labels, size_t& total_labels);
257262
void parse_seller_file(const std::string& label_file, size_t& num_pts_labels);
258263

259264
void convert_pts_label_to_bitmask(std::vector<std::vector<LabelT>>& pts_to_labels, simple_bitmask_buf& bitmask_buf, size_t num_labels);
260265

266+
void convert_pts_label_to_integer_vector(std::vector<std::vector<LabelT>> &pts_to_labels,
267+
integer_label_vector &int_label_vector, size_t total_labels);
268+
269+
void aggregate_points_by_bitmask_label(std::unordered_map<LabelT, std::vector<uint32_t>>& label_to_points, size_t num_points_to_load);
270+
271+
void aggregate_points_by_integer_label(std::unordered_map<LabelT, std::vector<uint32_t>>& label_to_points, size_t num_points_to_load);
272+
261273
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
262274

263275
// Returns the locations of start point and frozen points suitable for use
@@ -463,6 +475,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
463475

464476
simple_bitmask_buf _bitmask_buf;
465477

478+
bool _use_integer_labels = false;
479+
integer_label_vector _label_vector;
480+
466481
TableStats _table_stats;
467482

468483
static const float INDEX_GROWTH_FACTOR;

include/integer_label_vector.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#pragma once
2+
#include <vector>
3+
#include <string>
4+
5+
namespace diskann
6+
{
7+
8+
// Container for per-point integer labels, backed by an offset vector and a
// flat uint32_t data vector. Only declarations are visible here; the member
// definitions live elsewhere.
// NOTE(review): the _offset/_data pair looks like a CSR-style layout where
// _offset[i] marks where point i's labels start in _data -- confirm against
// the implementation.
class integer_label_vector
9+
{
10+
public:
11+
// Prepares the container for `numpoints` points and `total_labels` labels.
// The bool return presumably signals success -- TODO confirm.
bool initialize(size_t numpoints, size_t total_labels);
12+
13+
// Loads the container from `label_file`; `numpoints` is a non-const reference
// and is presumably an output holding the number of points read -- confirm.
bool initialize_from_file(const std::string &label_file, size_t &numpoints);
14+
15+
// Writes the container to `label_file`; does not modify the object (const).
bool write_to_file(const std::string &label_file) const;
16+
17+
// Adds `labels` for `point_id`. The vector is taken by non-const reference.
// NOTE(review): it may be reordered in place (the private binary_search
// helper suggests sorted storage) -- confirm.
template <typename LabelT>
18+
bool add_labels(uint32_t point_id, std::vector<LabelT> &labels);
19+
20+
// Single-label membership query for `point_id`.
bool check_label_exists(uint32_t point_id, uint32_t label);
21+
22+
// Multi-label overload. NOTE(review): whether this means "any of" or
// "all of" `labels` cannot be determined from the declaration -- confirm.
template <typename LabelT>
23+
bool check_label_exists(uint32_t point_id, const std::vector<LabelT> &labels);
24+
25+
// NOTE(review): presumably tests whether one point's label set fully contains
// the other's; the direction (source vs. target) is not visible here.
bool check_label_full_contain(uint32_t source_point, uint32_t target_point);
26+
27+
// Read-only access to the offset vector.
const std::vector<size_t> &get_offset_vector() const;
28+
29+
// Read-only access to the flat label data vector.
const std::vector<uint32_t> &get_data_vector() const;
30+
31+
// Memory usage of the container; units are presumably bytes -- TODO confirm.
size_t get_memory_usage() const;
32+
33+
private:
34+
// Search helper over a [start, end) or [start, end] range (inclusivity not
// visible here) for `label`; `last_check` is an in/out position.
// NOTE(review): assumes the range is sorted -- confirm.
bool binary_search(size_t start, size_t end, uint32_t label, size_t& last_check);
35+
36+
private:
37+
std::vector<size_t> _offset;
38+
std::vector<uint32_t> _data;
39+
};
40+
41+
}

0 commit comments

Comments
 (0)