Skip to content

Commit d303023

Browse files
author
Ananya Sutradhar
committed
multithread support - fixing Q Planning stats
1 parent 2a5ca2a commit d303023

File tree

5 files changed

+102
-74
lines changed

5 files changed

+102
-74
lines changed

apps/search_memory_index.cpp

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
197197
//query_num = 1;
198198
double best_recall = 0.0;
199199

200+
// Declare per-test method counters
201+
std::vector<uint32_t> test_num_brutes(Lvec.size(), 0);
202+
std::vector<uint32_t> test_num_graphs(Lvec.size(), 0);
203+
200204
// std::cout << "[PERF] About to start search_with_filters. Press Enter to continue..." << std::endl;
201205
// std::cin.get(); // Wait for user input
202206

@@ -227,7 +231,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
227231
// std::cin.get(); // Wait for user input
228232
//}
229233

230-
// Variables for OpenMP reductions (latency, distance comparisons, recall, and timing)
234+
// Variables for OpenMP reductions (latency, distance comparisons, recall, timing, and method counts)
231235
float local_brute_lat = 0.0f;
232236
float local_graph_lat = 0.0f;
233237
float local_brute_dist_cmp = 0.0f;
@@ -236,6 +240,8 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
236240
float local_graph_recall = 0.0f;
237241
double local_intersection_time = 0.0;
238242
double local_jaccard_time = 0.0;
243+
uint32_t local_num_brutes = 0;
244+
uint32_t local_num_graphs = 0;
239245

240246
auto s = std::chrono::high_resolution_clock::now();
241247
omp_set_num_threads(num_threads);
@@ -244,7 +250,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
244250
std::vector<std::vector<uint32_t>> temp_result_ids(query_num, std::vector<uint32_t>(K_or_L));
245251
std::vector<std::vector<float>> temp_result_dists(query_num, std::vector<float>(K_or_L));
246252

247-
#pragma omp parallel for schedule(dynamic, 1) reduction(+:local_brute_lat,local_graph_lat,local_brute_dist_cmp,local_graph_dist_cmp,local_brute_recall,local_graph_recall,local_intersection_time,local_jaccard_time)
253+
#pragma omp parallel for schedule(dynamic, 1) reduction(+:local_brute_lat,local_graph_lat,local_brute_dist_cmp,local_graph_dist_cmp,local_brute_recall,local_graph_recall,local_intersection_time,local_jaccard_time,local_num_brutes,local_num_graphs)
248254
for (int64_t i = 0; i < (int64_t)query_num; i++)
249255
{
250256
int method_used = 0; // 0 = brute force, 1 = graph search
@@ -266,42 +272,24 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
266272

267273
if (filtered_search && !tags)
268274
{
269-
uint32_t old_b, old_g, old_c;
270-
old_b = diskann::num_brutes;
271-
old_g = diskann::num_graphs;
272-
method_used = 0;
273275
std::vector<std::vector<std::string>> raw_filter = query_filters.size() == 1 ? query_filters[0] : query_filters[i];
274276

275277
// Reset global timing variables before search (thread-unsafe but better than accumulation)
276278
// Note: This is not fully thread-safe but provides more accurate per-query measurements
277279
diskann::curr_intersection_time = 0.0;
278280
diskann::curr_jaccard_time = 0.0;
279281

280-
auto retval = index->search_with_filters(query + i * query_aligned_dim, raw_filter, K_or_L, L,
282+
auto [num_candidates, distance_comparisons, method_type] = index->search_with_filters(query + i * query_aligned_dim, raw_filter, K_or_L, L,
281283
temp_result_ids[i].data(),
282284
temp_result_dists[i].data());
283285

284286
// Capture timing for this query (now represents only this query's timing)
285287
delta_intersection = diskann::curr_intersection_time;
286288
delta_jaccard = diskann::curr_jaccard_time;
287289

288-
//// Debug: print num_graphs after search
289-
//if (num_threads == 1 && i % 100 == 0) {
290-
// std::cout << "[DEBUG] Query " << i << ": search_with_filters completed, num_graphs = "<< diskann::num_graphs << std::endl;
291-
//}
292-
////std::cout << "[DEBUG] After search_with_filters: num_graphs = " << diskann::num_graphs << std::endl;
293-
294-
if (diskann::num_graphs > old_g) {
295-
method_used = 1; // Graph search
296-
// Debug the classifications
297-
// if (i < 5) std::cout << "[DEBUG] Query " << i << ": classified as GRAPH search" << std::endl;
298-
}
299-
else {
300-
method_used = 0; // Brute force
301-
// Debug the classifications
302-
// if (i < 5) std::cout << "[DEBUG] Query " << i << ": classified as BRUTE FORCE search" << std::endl;
303-
}
304-
cmp_stats[i] = retval.second;
290+
// Method information is now directly available from the search function - no race condition!
291+
method_used = method_type;
292+
cmp_stats[i] = distance_comparisons;
305293
// filter_match_time[i] = time_to_get_valid*1000000;
306294
// dist_cmp_time[i] = time_to_compare*1000000;
307295
}
@@ -385,12 +373,20 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
385373
local_brute_lat += latency_stats[i];
386374
local_brute_dist_cmp += (float)cmp_stats[i];
387375
local_brute_recall += cur_recall;
376+
local_num_brutes++;
377+
if (i < 10) { // Only debug first 10 queries to avoid spam
378+
379+
}
388380
break;
389381
case 1:
390382
query_result_class[test_id][i] = 1;
391383
local_graph_lat += latency_stats[i];
392384
local_graph_dist_cmp += (float)cmp_stats[i];
393385
local_graph_recall += cur_recall;
386+
local_num_graphs++;
387+
if (i < 10) { // Only debug first 10 queries to avoid spam
388+
389+
}
394390
break;
395391
}
396392
}
@@ -412,6 +408,18 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
412408
brute_recalls[test_id] = local_brute_recall;
413409
graph_recalls[test_id] = local_graph_recall;
414410

411+
// DEBUG: Print final values after OpenMP loop
412+
413+
414+
415+
416+
417+
418+
419+
420+
// Store method counts for this specific test
421+
test_num_brutes[test_id] = local_num_brutes;
422+
test_num_graphs[test_id] = local_num_graphs;
415423

416424
std::chrono::duration<double> diff = std::chrono::high_resolution_clock::now() - s;
417425

@@ -504,19 +512,29 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
504512
float avg_intersection_time_us = (float)((local_intersection_time * 1000000.0) / (double)query_num);
505513
float avg_jaccard_time_us = (float)((local_jaccard_time * 1000000.0) / (double)query_num);
506514

515+
// DEBUG: Print recall calculation values
516+
517+
518+
519+
520+
521+
522+
523+
524+
507525
std::cout << std::setw(4) << L << std::setw(4) << recall_at << std::setw(8) << displayed_qps << std::setw(18) << avg_cmps
508526
<< std::setw(20) << (float)mean_latency
509527
<< std::setw(20) << (float)latency_999
510528
<< std::setw(20) << (float)latency_99
511529
<< std::setw(20) << (float)latency_95
512530
<< std::setw(10) << (float)recalls[0]
513-
<< std::setw(22) << (float)(brute_dist_cmp[test_id] * 1.0) / (diskann::num_brutes.load() * 1.0)
514-
<< std::setw(22) << (float)(brute_lat[test_id] * 1.0) / (diskann::num_brutes.load() * 1.0)
515-
<< std::setw(20) << (float)(brute_recalls[test_id] * 100.0) / (diskann::num_brutes.load() * recall_at * 1.0)
531+
<< std::setw(22) << (float)(brute_dist_cmp[test_id] * 1.0) / (test_num_brutes[test_id] * 1.0)
532+
<< std::setw(22) << (float)(brute_lat[test_id] * 1.0) / (test_num_brutes[test_id] * 1.0)
533+
<< std::setw(20) << (float)(brute_recalls[test_id] * 100.0) / (test_num_brutes[test_id] * recall_at * 1.0)
516534
<< std::setw(22)
517-
<< (float)(graph_dist_cmp[test_id] * 1.0) / (diskann::num_graphs.load() * 1.0)
518-
<< std::setw(22) << (float)(graph_lat[test_id] * 1.0) / (diskann::num_graphs.load() * 1.0)
519-
<< std::setw(20) << (float)(graph_recalls[test_id] * 100.0) / (diskann::num_graphs.load() * recall_at * 1.0)
535+
<< (float)(graph_dist_cmp[test_id] * 1.0) / (test_num_graphs[test_id] * 1.0)
536+
<< std::setw(22) << (float)(graph_lat[test_id] * 1.0) / (test_num_graphs[test_id] * 1.0)
537+
<< std::setw(20) << (float)(graph_recalls[test_id] * 100.0) / (test_num_graphs[test_id] * recall_at * 1.0)
520538
// << std::setw(18) << filter_eval_time_us
521539
// << std::setw(18) << penalty_detection_time_us
522540
// << std::setw(18) << core_algo_time_us
@@ -525,8 +543,9 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
525543
<< std::endl;
526544
}
527545
}
528-
std::cout << "num_graphs " << diskann::num_graphs.load() << std::endl;
529-
std::cout << "num_brutes " << diskann::num_brutes.load() << std::endl;
546+
// Show per-L method counts (should be same for each L)
547+
std::cout << "num_graphs " << test_num_graphs[0] << std::endl;
548+
std::cout << "num_brutes " << test_num_brutes[0] << std::endl;
530549

531550
// Print detailed timing breakdown summary
532551
if (filtered_search) {
@@ -569,8 +588,8 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path,
569588
// std::cout << "Total filter overhead: " << total_filter_overhead << " μs/query" << std::endl;
570589

571590
std::cout << "Breakdown percentage:" << std::endl;
572-
std::cout << " Graph searches: " << (100.0 * diskann::num_graphs.load()) / query_num << "%" << std::endl;
573-
std::cout << " Brute force searches: " << (100.0 * diskann::num_brutes.load()) / query_num << "%" << std::endl;
591+
std::cout << " Graph searches: " << (100.0 * test_num_graphs[0]) / query_num << "%" << std::endl;
592+
std::cout << " Brute force searches: " << (100.0 * test_num_brutes[0]) / query_num << "%" << std::endl;
574593
std::cout << "=================================" << std::endl;
575594
}
576595

include/abstract_index.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ class AbstractIndex
7575
// Filter support search
7676
// IndexType is either uint32_t or uint64_t
7777
template <typename IndexType>
78-
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
78+
std::tuple<uint32_t, uint32_t, uint32_t> search_with_filters(const DataType &query, const std::string &raw_label,
7979
const size_t K, const uint32_t L, IndexType *indices,
8080
float *distances);
8181
template <typename IndexType>
82-
std::pair<uint32_t, uint32_t> search_with_filters(const DataType &query, const std::vector<std::vector<std::string>> &raw_label,
82+
std::tuple<uint32_t, uint32_t, uint32_t> search_with_filters(const DataType &query, const std::vector<std::vector<std::string>> &raw_label,
8383
const size_t K, const uint32_t L, IndexType *indices,
8484
float *distances);
8585

@@ -114,7 +114,7 @@ class AbstractIndex
114114
virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0;
115115
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
116116
std::any &indices, float *distances = nullptr) = 0;
117-
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
117+
virtual std::tuple<uint32_t, uint32_t, uint32_t> _search_with_filters(const DataType &query,
118118
const std::vector<std::vector<std::string>> &filter_label,
119119
const size_t K, const uint32_t L, std::any &indices,
120120
float *distances) = 0;

include/index.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
190190
const std::string filter_label = "");
191191

192192
template <typename IndexType>
193-
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(const T *query,
193+
DISKANN_DLLEXPORT std::tuple<uint32_t, uint32_t, uint32_t> search_with_filters(const T *query,
194194
const std::vector<std::vector<LabelT>> &filter_label,
195195
const size_t K, const uint32_t L,
196196
IndexType *indices, float *distances);
@@ -256,7 +256,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas
256256

257257
virtual std::pair<uint32_t, uint32_t> _search(const DataType &query, const size_t K, const uint32_t L,
258258
std::any &indices, float *distances = nullptr) override;
259-
virtual std::pair<uint32_t, uint32_t> _search_with_filters(const DataType &query,
259+
virtual std::tuple<uint32_t, uint32_t, uint32_t> _search_with_filters(const DataType &query,
260260
const std::vector<std::vector<std::string>> &filter_label_raw,
261261
const size_t K, const uint32_t L, std::any &indices,
262262
float *distances) override;

src/abstract_index.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K,
3434
}
3535

3636
template <typename IndexType>
37-
std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label,
37+
std::tuple<uint32_t, uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query, const std::string &raw_label,
3838
const size_t K, const uint32_t L, IndexType *indices,
3939
float *distances)
4040
{
@@ -47,7 +47,7 @@ std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType
4747
}
4848

4949
template <typename IndexType>
50-
std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query,
50+
std::tuple<uint32_t, uint32_t, uint32_t> AbstractIndex::search_with_filters(const DataType &query,
5151
const std::vector<std::vector<std::string>> &raw_label,
5252
const size_t K, const uint32_t L, IndexType *indices,
5353
float *distances)
@@ -169,18 +169,18 @@ template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<u
169169
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search<int8_t, uint64_t>(
170170
const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances);
171171

172-
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
172+
template DISKANN_DLLEXPORT std::tuple<uint32_t, uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
173173
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices,
174174
float *distances);
175175

176-
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
176+
template DISKANN_DLLEXPORT std::tuple<uint32_t, uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
177177
const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint64_t *indices,
178178
float *distances);
179-
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
179+
template DISKANN_DLLEXPORT std::tuple<uint32_t, uint32_t, uint32_t> AbstractIndex::search_with_filters<uint32_t>(
180180
const DataType &query, const std::vector<std::vector<std::string>> &raw_label, const size_t K, const uint32_t L,
181181
uint32_t *indices, float *distances);
182182

183-
template DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
183+
template DISKANN_DLLEXPORT std::tuple<uint32_t, uint32_t, uint32_t> AbstractIndex::search_with_filters<uint64_t>(
184184
const DataType &query, const std::vector<std::vector<std::string>> &raw_label, const size_t K, const uint32_t L,
185185
uint64_t *indices, float *distances);
186186

0 commit comments

Comments
 (0)