Skip to content

Commit 595f3a0

Browse files
committed
support integer filter label
1 parent 506718d commit 595f3a0

12 files changed

+855
-135
lines changed

include/abstract_index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class AbstractIndex
5353
#ifdef EXEC_ENV_OLS
5454
virtual void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l) = 0;
5555
#else
56-
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false) = 0;
56+
virtual void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String) = 0;
5757
#endif
5858

5959
// For FastL2 search on optimized layout

include/filter_match_proxy.h

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#pragma once
2+
#include "label_bitmask.h"
3+
#include "integer_label_vector.h"
4+
5+
namespace diskann
6+
{
7+
8+
class filter_match_proxy
9+
{
10+
public:
11+
virtual bool contain_filtered_label(uint32_t id) = 0;
12+
};
13+
14+
template <typename LabelT>
15+
class bitmask_filter_match : public filter_match_proxy
16+
{
17+
public:
18+
bitmask_filter_match(simple_bitmask_buf& bitmask_filters,
19+
std::vector<std::uint64_t>& query_bitmask_buf,
20+
const std::vector<LabelT>& filter_labels,
21+
LabelT unv_label);
22+
23+
virtual bool contain_filtered_label(uint32_t id) override;
24+
25+
private:
26+
simple_bitmask_buf& _bitmask_filters;
27+
std::vector<std::uint64_t>& _query_bitmask_buf;
28+
simple_bitmask_full_val _bitmask_full_val;
29+
};
30+
31+
template <typename LabelT>
32+
class integer_label_filter_match : public filter_match_proxy
33+
{
34+
public:
35+
integer_label_filter_match(integer_label_vector& label_vector,
36+
const std::vector<LabelT>& filter_labels,
37+
LabelT unv_label);
38+
39+
virtual bool contain_filtered_label(uint32_t id) override;
40+
41+
private:
42+
integer_label_vector& _label_vector;
43+
const std::vector<LabelT>& _filter_labels;
44+
LabelT _unv_label;
45+
};
46+
47+
template <typename LabelT>
48+
class label_filter_match_holder : public filter_match_proxy
49+
{
50+
public:
51+
label_filter_match_holder(simple_bitmask_buf& bitmask_filters,
52+
std::vector<std::uint64_t>& query_bitmask_buf,
53+
integer_label_vector& label_vector,
54+
const std::vector<LabelT>& filter_labels,
55+
LabelT unv_label,
56+
bool use_integer_labels);
57+
58+
virtual bool contain_filtered_label(uint32_t id) override;
59+
60+
private:
61+
bitmask_filter_match<LabelT> _bitmask_filter_match;
62+
integer_label_filter_match<LabelT> _integer_label_filter_match;
63+
bool _use_integer_labels;
64+
};
65+
66+
}

include/index.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "percentile_stats.h"
2525
#include <bitset>
2626
#include "label_bitmask.h"
27+
#include "integer_label_vector.h"
2728

2829
#include "quantized_distance.h"
2930
#include "pq_data_store.h"
@@ -80,7 +81,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
8081
#ifdef EXEC_ENV_OLS
8182
DISKANN_DLLEXPORT void load(AlignedFileReader &reader, uint32_t num_threads, uint32_t search_l);
8283
#else
83-
DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, bool loadBitmaskLabelFile = false);
84+
DISKANN_DLLEXPORT void load(const char *index_file, uint32_t num_threads, uint32_t search_l, LabelFormatType label_format_type = LabelFormatType::String);
8485
#endif
8586

8687
// get some private variables
@@ -118,6 +119,10 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
118119

119120
DISKANN_DLLEXPORT bool is_set_universal_label() const override;
120121

122+
DISKANN_DLLEXPORT void enable_integer_label();
123+
124+
DISKANN_DLLEXPORT bool integer_label_enabled();
125+
121126
// Set starting point of an index before inserting any points incrementally.
122127
// The data count should be equal to _num_frozen_pts * _aligned_dim.
123128
DISKANN_DLLEXPORT void set_start_points(const T *data, size_t data_count);
@@ -253,11 +258,18 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
253258
// determines navigating node of the graph by calculating medoid of datafopt
254259
uint32_t calculate_entry_point();
255260

256-
void parse_label_file(const std::string &label_file, size_t &num_pts_labels);
261+
void parse_label_file(const std::string &label_file, size_t &num_pts_labels, size_t& total_labels);
257262
void parse_seller_file(const std::string& label_file, size_t& num_pts_labels);
258263

259264
void convert_pts_label_to_bitmask(std::vector<std::vector<LabelT>>& pts_to_labels, simple_bitmask_buf& bitmask_buf, size_t num_labels);
260265

266+
void convert_pts_label_to_integer_vector(std::vector<std::vector<LabelT>> &pts_to_labels,
267+
integer_label_vector &int_label_vector, size_t total_labels);
268+
269+
void aggregate_points_by_bitmask_label(std::unordered_map<LabelT, std::vector<uint32_t>>& label_to_points, size_t num_points_to_load);
270+
271+
void aggregate_points_by_integer_label(std::unordered_map<LabelT, std::vector<uint32_t>>& label_to_points, size_t num_points_to_load);
272+
261273
std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
262274

263275
// Returns the locations of start point and frozen points suitable for use
@@ -463,6 +475,9 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
463475

464476
simple_bitmask_buf _bitmask_buf;
465477

478+
bool _use_integer_labels = false;
479+
integer_label_vector _label_vector;
480+
466481
TableStats _table_stats;
467482

468483
static const float INDEX_GROWTH_FACTOR;

include/integer_label_vector.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#pragma once
2+
#include <vector>
3+
#include <string>
4+
5+
namespace diskann
6+
{
7+
8+
class integer_label_vector
9+
{
10+
public:
11+
bool initialize(size_t numpoints, size_t total_labels);
12+
13+
bool initialize_from_file(const std::string &label_file, size_t &numpoints);
14+
15+
bool write_to_file(const std::string &label_file) const;
16+
17+
template <typename LabelT>
18+
bool add_labels(uint32_t point_id, std::vector<LabelT> &labels);
19+
20+
bool check_label_exists(uint32_t point_id, uint32_t label);
21+
22+
template <typename LabelT>
23+
bool check_label_exists(uint32_t point_id, const std::vector<LabelT> &labels);
24+
25+
bool check_label_full_contain(uint32_t point_id, const std::vector<uint32_t> &labels);
26+
27+
bool check_label_full_contain(uint32_t source_point, uint32_t target_point);
28+
29+
const std::vector<size_t> &get_offset_vector() const;
30+
31+
const std::vector<uint32_t> &get_data_vector() const;
32+
33+
size_t get_memory_usage() const;
34+
35+
private:
36+
bool binary_search(size_t start, size_t end, uint32_t label, size_t& last_check);
37+
38+
private:
39+
std::vector<size_t> _offset;
40+
std::vector<uint32_t> _data;
41+
};
42+
43+
}

include/label_helper.h

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22
#include "label_bitmask.h"
3+
#include "integer_label_vector.h"
34
#include "percentile_stats.h"
5+
#include "tsl/robin_set.h"
46
#include <string>
57

68
namespace diskann
@@ -20,6 +22,113 @@ class label_helper
2022

2123
bool read_bitmask_from_file(const std::string &bitmask_label_file, simple_bitmask_buf &bitmask_buf,
2224
size_t& num_points);
25+
26+
bool parse_label_file_in_integer(
27+
const std::string& label_file,
28+
size_t& num_points,
29+
integer_label_vector& integer_vector,
30+
tsl::robin_set<uint32_t>& labels, TableStats &table_stats);
31+
32+
template <typename LabelT>
33+
bool load_label_map(
34+
const std::string& label_map_file,
35+
std::unordered_map<std::string, LabelT>& label_map)
36+
{
37+
std::ifstream infile(label_map_file, std::ios::binary);
38+
if (infile.fail())
39+
{
40+
throw diskann::ANNException(std::string("Failed to open file ") + label_map_file, -1);
41+
}
42+
infile.seekg(0, std::ios::end);
43+
size_t file_size = infile.tellg();
44+
45+
std::string buffer(file_size, ' ');
46+
47+
infile.seekg(0, std::ios::beg);
48+
infile.read(&buffer[0], file_size);
49+
infile.close();
50+
51+
unsigned line_cnt = 0;
52+
53+
size_t cur_pos = 0;
54+
size_t next_pos = 0;
55+
size_t lbl_pos = 0;
56+
std::string token;
57+
std::string labe_str;
58+
while (cur_pos < file_size && cur_pos != std::string::npos)
59+
{
60+
next_pos = buffer.find('\n', cur_pos);
61+
if (next_pos == std::string::npos)
62+
{
63+
break;
64+
}
65+
66+
lbl_pos = search_string_range(buffer, '\t', cur_pos, next_pos);
67+
labe_str.assign(buffer.c_str() + cur_pos, lbl_pos - cur_pos);
68+
69+
token.assign(buffer.c_str() + lbl_pos + 1, next_pos - lbl_pos - 1);
70+
LabelT label_num = (LabelT)std::stoul(token);
71+
72+
label_map[labe_str] = label_num;
73+
74+
cur_pos = next_pos + 1;
75+
76+
line_cnt++;
77+
}
78+
79+
return true;
80+
}
81+
82+
template <typename LabelT>
83+
bool load_label_medoids(
84+
const std::string& label_medoids_file,
85+
std::unordered_map<LabelT, uint32_t>& label_to_start_id)
86+
{
87+
std::ifstream infile(label_medoids_file, std::ios::binary);
88+
if (infile.fail())
89+
{
90+
throw diskann::ANNException(std::string("Failed to open file ") + label_medoids_file, -1);
91+
}
92+
infile.seekg(0, std::ios::end);
93+
size_t file_size = infile.tellg();
94+
95+
std::string buffer(file_size, ' ');
96+
97+
infile.seekg(0, std::ios::beg);
98+
infile.read(&buffer[0], file_size);
99+
infile.close();
100+
101+
unsigned line_cnt = 0;
102+
103+
size_t cur_pos = 0;
104+
size_t next_pos = 0;
105+
size_t lbl_pos = 0;
106+
std::string token;
107+
while (cur_pos < file_size && cur_pos != std::string::npos)
108+
{
109+
next_pos = buffer.find('\n', cur_pos);
110+
if (next_pos == std::string::npos)
111+
{
112+
break;
113+
}
114+
115+
lbl_pos = search_string_range(buffer, ',', cur_pos, next_pos);
116+
token.assign(buffer.c_str() + cur_pos, lbl_pos - cur_pos);
117+
LabelT label_num = (LabelT)std::stoul(token);
118+
119+
token.assign(buffer.c_str() + lbl_pos + 1, next_pos - lbl_pos - 1);
120+
uint32_t medoid = (uint32_t)std::stoul(token);
121+
122+
label_to_start_id[label_num] = medoid;
123+
124+
cur_pos = next_pos + 1;
125+
126+
line_cnt++;
127+
}
128+
129+
return true;
130+
}
131+
23132
private:
24133
size_t search_string_range(const std::string& str, char ch, size_t start, size_t end);
25134
};

include/parameters.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212
namespace diskann
1313
{
1414

15+
enum class LabelFormatType :uint8_t
16+
{
17+
String = 0,
18+
BitMask = 1,
19+
Integer = 2
20+
};
21+
1522
class IndexWriteParameters
1623

1724
{

include/pq_flash_index.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "tsl/robin_map.h"
1717
#include "tsl/robin_set.h"
1818
#include "label_bitmask.h"
19+
#include "integer_label_vector.h"
1920

2021
#define FULL_PRECISION_REORDER_MULTIPLIER 3
2122

@@ -47,7 +48,7 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
4748
const char* labels_filepath, const char* labels_to_medoids_filepath,
4849
const char* labels_map_filepath, const char* unv_label_filepath,
4950
const char* seller_filepath,
50-
bool load_bitmask_label = false);
51+
LabelFormatType label_format_type = LabelFormatType::String);
5152
#endif
5253

5354
DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);
@@ -234,9 +235,12 @@ template <typename T, typename LabelT = uint32_t> class PQFlashIndex
234235
// filter support
235236
simple_bitmask_buf _bitmask_buf;
236237

238+
bool _use_integer_labels = false;
239+
integer_label_vector _label_vector;
240+
237241
std::unordered_map<LabelT, std::vector<uint32_t>> _filter_to_medoid_ids;
238242
bool _use_universal_label = false;
239-
LabelT _universal_filter_label;
243+
LabelT _universal_filter_label = 0;
240244
tsl::robin_set<uint32_t> _dummy_pts;
241245
tsl::robin_set<uint32_t> _has_dummy_pts;
242246
tsl::robin_map<uint32_t, uint32_t> _dummy_to_real_map;

0 commit comments

Comments
 (0)