@@ -866,7 +866,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::iterate_to_fixed_point(
866866
867867 std::vector<Neighbor> &expanded_nodes = scratch->pool ();
868868
869- if (!_diverse_index)
869+ if (!_diverse_index || _attribute_diversity )
870870 {
871871 NeighborPriorityQueue& best_L_nodes_ref = scratch->best_l_nodes ();
872872 best_L_nodes_ref.reserve (Lsize);
@@ -1233,16 +1233,30 @@ void Index<T, TagT, LabelT>::prune_search_result(int location, std::vector<uint3
12331233// #include <execution>
12341234
12351235
1236- double attribute_distance (std::vector<uint32_t > &a, std::vector <uint32_t > &b)
1237- {
1238- std::vector<int > result (a.size ());
1236+ // double attribute_distance(std::vector<std::string > &a, std::vector <std::string > &b)
1237+ // {
1238+ // std::vector<int> result(a.size());
12391239
1240- std::transform (a.begin (), a.end (), b.begin (), result.begin (),
1241- [](int x, int y) { return x == y ? 1 : 0 ; });
1240+ // std::transform(a.begin(), a.end(), b.begin(), result.begin(),
1241+ // [](std::string x, std::string y) { return x == y ? 1 : 0; });
1242+
1243+ // int sum = std::accumulate(result.begin(), result.end(), 0);
1244+ // double distance = sum * 1.0 / result.size();
1245+ // return 1 - distance;
1246+ // }
/// Normalised positional mismatch ratio between two attribute lists.
///
/// Element i of `a` is compared with element i of `b`. Positions that exist in
/// only one of the lists (the lists have different lengths) are counted as
/// mismatches rather than ignored, so a point with missing attributes is never
/// reported as identical to a point that has them.
///
/// @param a  attribute values of the first point
/// @param b  attribute values of the second point
/// @return   a value in [0.0, 1.0]: 0.0 when every position matches (or both
///           lists are empty), 1.0 when no position matches.
double attribute_distance(const std::vector<std::string> &a, const std::vector<std::string> &b)
{
    // Normalise by the longer list so unpaired trailing attributes add distance.
    const size_t longest = std::max(a.size(), b.size());
    if (longest == 0)
        return 0.0; // both lists empty: nothing differs, and avoids dividing by zero

    // Only the common prefix can be indexed safely in both lists.
    const size_t shared = std::min(a.size(), b.size());
    size_t matches = 0;
    for (size_t i = 0; i < shared; ++i)
    {
        if (a[i] == b[i])
            ++matches;
    }

    const double similarity = static_cast<double>(matches) / static_cast<double>(longest);
    return 1.0 - similarity;
}
12471261
12481262template <typename T, typename TagT, typename LabelT>
@@ -1295,10 +1309,13 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
12951309 {
12961310 continue ;
12971311 }
1298- auto iter_seller = _location_to_seller[iter->id ];
1299- if (blockers[cur_index].find (iter_seller) != blockers[cur_index].end () && !_attribute_diversity)
1300- {
1301- continue ;
1312+
1313+ if (!_attribute_diversity){
1314+ auto iter_seller = _location_to_seller[iter->id ];
1315+ if (blockers[cur_index].find (iter_seller) != blockers[cur_index].end () && !_attribute_diversity)
1316+ {
1317+ continue ;
1318+ }
13021319 }
13031320
13041321 if (blockers[cur_index].find ((uint32_t )cur_index) != blockers[cur_index].end () && _attribute_diversity)
@@ -1333,11 +1350,13 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
13331350 {
13341351 continue ;
13351352 }
1336-
1337- auto iter2_seller = _location_to_seller[iter2->id ];
1338- if (blockers[t].find (iter2_seller) != blockers[t].end ())
1339- {
1340- continue ;
1353+
1354+ if (!_attribute_diversity){
1355+ auto iter2_seller = _location_to_seller[iter2->id ];
1356+ if (blockers[t].find (iter2_seller) != blockers[t].end ())
1357+ {
1358+ continue ;
1359+ }
13411360 }
13421361 }
13431362
@@ -1369,6 +1388,7 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
13691388 blockers[t].insert (iter_seller);
13701389 }
13711390 else {
1391+ // double attr_dist = 0.6;
13721392 double attr_dist = attribute_distance (_location_to_attributes[iter2->id ], _location_to_attributes[iter->id ]);
13731393 if (attr_dist > 0.5 ){ // attribute distance threshold
13741394 blockers[t].insert (iter->id );
@@ -1388,7 +1408,7 @@ void Index<T, TagT, LabelT>::occlude_list(const uint32_t location, std::vector<N
13881408 if (y > cur_alpha * x)
13891409 {
13901410 occlude_factor[t] = std::max (occlude_factor[t], eps);
1391- if (_diverse_index)
1411+ if (_diverse_index && !_attribute_diversity )
13921412 {
13931413 auto iter_seller = _location_to_seller[iter->id ];
13941414 blockers[t].insert (iter_seller);
@@ -1553,7 +1573,9 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
15531573
15541574 diskann::Timer link_timer;
15551575
1576+ try {
15561577#pragma omp parallel for schedule(dynamic, 2048)
1578+
15571579 for (int64_t node_ctr = 0 ; node_ctr < (int64_t )(visit_order.size ()); node_ctr++)
15581580 {
15591581 auto node = visit_order[node_ctr];
@@ -1567,8 +1589,12 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
15671589 search_for_point_and_prune (node, _indexingQueueSize, pruned_list, scratch, true , _filterIndexingQueueSize);
15681590 }
15691591 else
1570- {
1571- search_for_point_and_prune (node, _indexingQueueSize, pruned_list, scratch);
1592+ { try {
1593+ search_for_point_and_prune (node, _indexingQueueSize, pruned_list, scratch);
1594+
1595+ } catch (const std::exception &e) {
1596+ diskann::cerr << " Exception caught during link for node " << node << " : " << e.what () << std::endl;
1597+ }
15721598 }
15731599 assert (pruned_list.size () > 0 );
15741600
@@ -1586,6 +1612,8 @@ template <typename T, typename TagT, typename LabelT> void Index<T, TagT, LabelT
15861612 diskann::cout << " \r " << (100.0 * node_ctr) / (visit_order.size ()) << " % of index build completed."
15871613 << std::flush;
15881614 }
1615+ }} catch (const std::exception &e) {
1616+ diskann::cerr << " Exception caught during link: " << e.what () << std::endl;
15891617 }
15901618
15911619 if (_nd > 0 )
@@ -2272,14 +2300,14 @@ void Index<T, TagT, LabelT>::parse_attribute_file(const std::string &label_file,
22722300 std::istringstream iss (line);
22732301 getline (iss, token, ' \t ' );
22742302 std::istringstream new_iss (token);
2275- std::vector<uint32_t > attributes;
2303+ std::vector<std::string > attributes;
22762304
22772305 while (getline (new_iss, token, ' ,' ))
22782306 {
22792307 token.erase (std::remove (token.begin (), token.end (), ' \n ' ), token.end ());
22802308 token.erase (std::remove (token.begin (), token.end (), ' \r ' ), token.end ());
2281- uint32_t token_as_num = (uint32_t )std::stoul (token);
2282- attributes.push_back (token_as_num );
2309+ // uint32_t token_as_num = (uint32_t)std::stoul(token);
2310+ attributes.push_back (token );
22832311 }
22842312
22852313 _location_to_attributes.push_back (attributes);
@@ -2465,7 +2493,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search(const T *query, con
24652493 auto retval = iterate_to_fixed_point (scratch, L, init_ids, false , unused_filter_label, true , maxLperSeller);
24662494
24672495 NeighborPriorityQueueBase* best_L_nodes;
2468- if (!_diverse_index || maxLperSeller == 0 ) {
2496+ if (!_diverse_index || maxLperSeller == 0 || _attribute_diversity ) {
24692497 best_L_nodes = &(scratch->best_l_nodes ());
24702498 }
24712499 else {
@@ -2575,7 +2603,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
25752603 auto retval = iterate_to_fixed_point (scratch, L, init_ids, true , filter_vec, true , maxLperSeller);
25762604
25772605 NeighborPriorityQueueBase* best_L_nodes;
2578- if (!_diverse_index) {
2606+ if (!_diverse_index || _attribute_diversity ) {
25792607 best_L_nodes = &(scratch->best_l_nodes ());
25802608 }
25812609 else {
0 commit comments