Skip to content

Commit 699f568

Browse files
committed
Fix errors
1 parent 90973d1 commit 699f568

File tree

6 files changed

+111
-42
lines changed

6 files changed

+111
-42
lines changed

apps/build_memory_index.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ namespace po = boost::program_options;
2424

2525
int main(int argc, char **argv)
2626
{
27-
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
28-
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
27+
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type, seller_file, attribute_file;
28+
uint32_t num_threads, R, L, Lf, build_PQ_bytes, num_diverse_build;
2929
float alpha;
30-
bool use_pq_build, use_opq;
30+
bool use_pq_build, use_opq, diverse_index = false, attribute_diversity = false;
31+
3132

3233
po::options_description desc{
3334
program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
@@ -70,6 +71,12 @@ int main(int argc, char **argv)
7071
program_options_utils::FILTERED_LBUILD);
7172
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
7273
program_options_utils::LABEL_TYPE_DESCRIPTION);
74+
optional_configs.add_options()("seller_file", po::value<std::string>(&seller_file)->default_value(""),
75+
program_options_utils::DIVERSITY_FILE);
76+
optional_configs.add_options()("attribute_file", po::value<std::string>(&attribute_file)->default_value(""),
77+
program_options_utils::ATTRIBUTE_DIVERSITY_FILE);
78+
optional_configs.add_options()("NumDiverse", po::value<uint32_t>(&num_diverse_build)->default_value(1),
79+
program_options_utils::NUM_DIVERSE);
7380

7481
// Merge required and optional parameters
7582
desc.add(required_configs).add(optional_configs);
@@ -112,19 +119,33 @@ int main(int argc, char **argv)
112119
return -1;
113120
}
114121

122+
if(seller_file != "")
123+
diverse_index = true;
124+
125+
126+
if(attribute_file != ""){
127+
diverse_index = true;
128+
attribute_diversity = true;
129+
}
115130
try
116131
{
117132
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
118133
<< " #threads: " << num_threads << std::endl;
119134

120135
size_t data_num, data_dim;
121136
diskann::get_bin_metadata(data_path, data_num, data_dim);
137+
std::cout<<"Num diverse build: " << num_diverse_build << std::endl;
122138

123139
auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
124140
.with_filter_list_size(Lf)
125141
.with_alpha(alpha)
126142
.with_saturate_graph(false)
127143
.with_num_threads(num_threads)
144+
.with_diverse_index(diverse_index)
145+
.with_seller_file(seller_file)
146+
.with_num_diverse_build(num_diverse_build)
147+
.with_attribute_diversity(attribute_diversity)
148+
.with_attribute_file(attribute_file)
128149
.build();
129150

130151
auto filter_params = diskann::IndexFilterParamsBuilder()

apps/search_memory_index.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,9 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
153153
continue;
154154
}
155155

156-
query_result_ids[test_id].resize(recall_at * query_num);
157-
query_result_dists[test_id].resize(recall_at * query_num);
156+
query_result_ids[test_id].resize(recall_at * query_num, std::numeric_limits<uint32_t>::max());
157+
query_result_dists[test_id].resize(recall_at * query_num, std::numeric_limits<float>::max());
158+
158159
std::vector<T *> res = std::vector<T *>();
159160

160161
auto s = std::chrono::high_resolution_clock::now();
@@ -165,12 +166,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
165166
auto qs = std::chrono::high_resolution_clock::now();
166167
if (filtered_search && !tags)
167168
{
168-
std::string raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
169-
170-
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, recall_at, L,
171-
query_result_ids[test_id].data() + i * recall_at,
172-
query_result_dists[test_id].data() + i * recall_at);
173-
cmp_stats[i] = retval.second;
169+
;
174170
}
175171
else if (metric == diskann::FAST_L2)
176172
{
@@ -198,11 +194,17 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
198194
}
199195
}
200196
else
201-
{
197+
{ std::vector<uint32_t> results(L,std::numeric_limits<uint32_t>::max());
198+
std::vector<float> dists(L,std::numeric_limits<float>::max());
199+
202200
cmp_stats[i] = index
203201
->search(query + i * query_aligned_dim, recall_at, L,
204-
query_result_ids[test_id].data() + i * recall_at)
202+
results.data(), dists.data())
205203
.second;
204+
for (uint32_t ctr = 0; ctr < std::min(results.size(),(uint64_t)recall_at); ctr++) {
205+
query_result_ids[test_id][recall_at * i + ctr] = results[ctr];
206+
query_result_dists[test_id][recall_at * i + ctr] = dists[ctr];
207+
}
206208
}
207209
auto qe = std::chrono::high_resolution_clock::now();
208210
std::chrono::duration<double> diff = qe - qs;

apps/utils/tsv_to_bin.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,21 @@ void block_convert_uint8(std::ifstream &reader, std::ofstream &writer, size_t np
6464
delete[] read_buf;
6565
}
6666

67+
// Returns the number of lines in the text file `filename`, as delimited by
// std::getline: every '\n'-terminated line counts once, and a final line
// without a trailing newline counts as well. Empty lines are counted too.
// If the file cannot be opened, the stream starts in a failed state, the
// first read fails, and the function returns 0.
size_t count_lines(const std::string &filename)
{
    std::ifstream input(filename);
    size_t total = 0;
    // The extracted text is discarded; only successful extractions matter.
    for (std::string scratch; std::getline(input, scratch);)
    {
        total += 1;
    }
    return total;
}
78+
6779
int main(int argc, char **argv)
6880
{
69-
if (argc != 6)
81+
if (argc != 5)
7082
{
7183
std::cout << argv[0]
7284
<< "<float/int8/uint8> input_filename.tsv output_filename.bin "
@@ -82,7 +94,7 @@ int main(int argc, char **argv)
8294
}
8395

8496
size_t ndims = atoi(argv[4]);
85-
size_t npts = atoi(argv[5]);
97+
size_t npts = count_lines(argv[2]);
8698

8799
std::ifstream reader(argv[2], std::ios::binary | std::ios::ate);
88100
// size_t fsize = reader.tellg();

include/index.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
402402
uint32_t _num_diverse_build = 1;
403403
uint32_t _max_L_per_seller = 0;
404404
std::vector<uint32_t> _location_to_seller;
405-
std::vector<std::vector<uint32_t>> _location_to_attributes;
405+
std::vector<std::vector<std::string>> _location_to_attributes;
406406
uint32_t _num_unique_sellers = 0;
407407
std::string _seller_file;
408408
std::string _attribute_file;

include/program_options_utils.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,11 @@ const char *UNIVERSAL_LABEL =
7777
"in the labels file instead of listing all labels for a node. DiskANN will not automatically assign a "
7878
"universal label to a node.";
7979
const char *FILTERED_LBUILD = "Build complexity for filtered points, higher value results in better graphs";
80+
const char *DIVERSE_INDEX = "Build Diverse Index";
81+
const char *ATTRIBUTE_DIVERSE_INDEX = "Build Attribute Diverse Index";
82+
const char *NUM_DIVERSE = "Number of diverse edges needed per node in each local region";
83+
const char *DIVERSITY_FILE = "Seller diversity file for diverse index";
84+
const char *ATTRIBUTE_DIVERSITY_FILE = "Attribute diversity file for diverse index";
85+
8086

8187
} // namespace program_options_utils

src/index.cpp

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
866866

867867
std::vector<Neighbor> &expanded_nodes = scratch->pool();
868868

869-
if (!_diverse_index)
869+
if (!_diverse_index || _attribute_diversity)
870870
{
871871
NeighborPriorityQueue& best_L_nodes_ref = scratch->best_l_nodes();
872872
best_L_nodes_ref.reserve(Lsize);
@@ -1233,16 +1233,30 @@ void Index<T, TagT, LabelT>::prune_search_result(int location, std::vector<uint3
12331233
// #include <execution>
12341234

12351235

1236-
double attribute_distance(std::vector<uint32_t> &a, std::vector <uint32_t> &b)
1237-
{
1238-
std::vector<int> result(a.size());
1236+
// double attribute_distance(std::vector<std::string> &a, std::vector <std::string> &b)
1237+
// {
1238+
// std::vector<int> result(a.size());
12391239

1240-
std::transform(a.begin(), a.end(), b.begin(), result.begin(),
1241-
[](int x, int y) { return x == y ? 1 : 0; });
1240+
// std::transform(a.begin(), a.end(), b.begin(), result.begin(),
1241+
// [](std::string x, std::string y) { return x == y ? 1 : 0; });
1242+
1243+
// int sum = std::accumulate(result.begin(), result.end(), 0);
1244+
// double distance = sum * 1.0 / result.size();
1245+
// return 1 - distance;
1246+
// }
1247+
// Positional dissimilarity between two attribute lists.
//
// Compares a[i] with b[i] over the first min(a.size(), b.size()) positions
// and returns 1 - (matching positions / compared positions), i.e. a value in
// [0, 1] where 0 means every compared position is equal. Elements beyond the
// shorter list are ignored. If either list is empty there is nothing to
// compare and the result is 0.0.
double attribute_distance(const std::vector<std::string> &a, const std::vector<std::string> &b)
{
    // Restrict the walk to the common prefix so neither list is read out of bounds.
    const size_t compared = std::min(a.size(), b.size());
    if (compared == 0)
    {
        return 0.0;
    }

    int equal_positions = 0;
    auto rhs = b.begin();
    for (auto lhs = a.begin(); lhs != a.begin() + compared; ++lhs, ++rhs)
    {
        equal_positions += (*lhs == *rhs) ? 1 : 0;
    }

    return 1.0 - static_cast<double>(equal_positions) / compared;
}
12471261

12481262
template <typename T, typename TagT, typename LabelT>
@@ -1295,10 +1309,13 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
12951309
{
12961310
continue;
12971311
}
1298-
auto iter_seller = _location_to_seller[iter->id];
1299-
if (blockers[cur_index].find(iter_seller) != blockers[cur_index].end() && !_attribute_diversity)
1300-
{
1301-
continue;
1312+
1313+
if(!_attribute_diversity){
1314+
auto iter_seller = _location_to_seller[iter->id];
1315+
if (blockers[cur_index].find(iter_seller) != blockers[cur_index].end() && !_attribute_diversity)
1316+
{
1317+
continue;
1318+
}
13021319
}
13031320

13041321
if (blockers[cur_index].find((uint32_t)cur_index) != blockers[cur_index].end() && _attribute_diversity)
@@ -1333,11 +1350,13 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
13331350
{
13341351
continue;
13351352
}
1336-
1337-
auto iter2_seller = _location_to_seller[iter2->id];
1338-
if (blockers[t].find(iter2_seller) != blockers[t].end())
1339-
{
1340-
continue;
1353+
1354+
if(!_attribute_diversity){
1355+
auto iter2_seller = _location_to_seller[iter2->id];
1356+
if (blockers[t].find(iter2_seller) != blockers[t].end())
1357+
{
1358+
continue;
1359+
}
13411360
}
13421361
}
13431362

@@ -1369,6 +1388,7 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
13691388
blockers[t].insert(iter_seller);
13701389
}
13711390
else{
1391+
//double attr_dist = 0.6;
13721392
double attr_dist = attribute_distance(_location_to_attributes[iter2->id], _location_to_attributes[iter->id]);
13731393
if(attr_dist > 0.5){ // attribute distance threshold
13741394
blockers[t].insert(iter->id);
@@ -1388,7 +1408,7 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
13881408
if (y > cur_alpha * x)
13891409
{
13901410
occlude_factor[t] = std::max(occlude_factor[t], eps);
1391-
if (_diverse_index)
1411+
if (_diverse_index && !_attribute_diversity)
13921412
{
13931413
auto iter_seller = _location_to_seller[iter->id];
13941414
blockers[t].insert(iter_seller);
@@ -1553,7 +1573,9 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
15531573

15541574
diskann::Timer link_timer;
15551575

1576+
try{
15561577
#pragma omp parallel for schedule(dynamic, 2048)
1578+
15571579
for (int64_t node_ctr = 0; node_ctr < (int64_t)(visit_order.size()); node_ctr++)
15581580
{
15591581
auto node = visit_order[node_ctr];
@@ -1567,8 +1589,12 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
15671589
search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch, true, _filterIndexingQueueSize);
15681590
}
15691591
else
1570-
{
1571-
search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch);
1592+
{ try{
1593+
search_for_point_and_prune(node, _indexingQueueSize, pruned_list, scratch);
1594+
1595+
} catch (const std::exception &e) {
1596+
diskann::cerr << "Exception caught during link for node " << node << " : " << e.what() << std::endl;
1597+
}
15721598
}
15731599
assert(pruned_list.size() > 0);
15741600

@@ -1586,6 +1612,8 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
15861612
diskann::cout << "\r" << (100.0 * node_ctr) / (visit_order.size()) << "% of index build completed."
15871613
<< std::flush;
15881614
}
1615+
}} catch (const std::exception &e) {
1616+
diskann::cerr << "Exception caught during link: " << e.what() << std::endl;
15891617
}
15901618

15911619
if (_nd > 0)
@@ -2272,14 +2300,14 @@ void Index<T, TagT, LabelT>::parse_attribute_file(const std::string &label_file,
22722300
std::istringstream iss(line);
22732301
getline(iss, token, '\t');
22742302
std::istringstream new_iss(token);
2275-
std::vector<uint32_t> attributes;
2303+
std::vector<std::string> attributes;
22762304

22772305
while (getline(new_iss, token, ','))
22782306
{
22792307
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
22802308
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
2281-
uint32_t token_as_num = (uint32_t)std::stoul(token);
2282-
attributes.push_back(token_as_num);
2309+
//uint32_t token_as_num = (uint32_t)std::stoul(token);
2310+
attributes.push_back(token);
22832311
}
22842312

22852313
_location_to_attributes.push_back(attributes);
@@ -2465,7 +2493,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
24652493
auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true, maxLperSeller);
24662494

24672495
NeighborPriorityQueueBase* best_L_nodes;
2468-
if (!_diverse_index || maxLperSeller == 0) {
2496+
if (!_diverse_index || maxLperSeller == 0 || _attribute_diversity) {
24692497
best_L_nodes = &(scratch->best_l_nodes());
24702498
}
24712499
else {
@@ -2575,7 +2603,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
25752603
auto retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true, maxLperSeller);
25762604

25772605
NeighborPriorityQueueBase* best_L_nodes;
2578-
if (!_diverse_index) {
2606+
if (!_diverse_index || _attribute_diversity) {
25792607
best_L_nodes = &(scratch->best_l_nodes());
25802608
}
25812609
else {

0 commit comments

Comments
 (0)