Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bindings/python/tests/test_ivf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
test_groundtruth_cosine, \
test_ivf_reference, \
test_ivf_clustering, \
test_number_of_clusters, \
test_number_of_vectors, \
test_dimensions, \
timed, \
get_test_set, \
Expand Down Expand Up @@ -167,7 +167,7 @@ def _test_basic_inner(
test_get_distance(ivf, svs.DistanceType.L2, data)

# Data interface
self.assertEqual(ivf.size, test_number_of_clusters)
self.assertEqual(ivf.size, test_number_of_vectors)

# The dimensionality exposed by the index should always match the original
# dataset dimensions.
Expand Down
97 changes: 95 additions & 2 deletions include/svs/index/ivf/dynamic_ivf.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@

namespace svs::index::ivf {

// Forward declaration of BatchIterator (already declared in index.h, but redeclaring for
// clarity)
template <typename Index, typename QueryType> class BatchIterator;

///
/// Metadata tracking the state of a particular data index for DynamicIVFIndex.
/// The following states have the given meaning for their corresponding slot:
Expand Down Expand Up @@ -79,6 +83,14 @@ class DynamicIVFIndex {
using InterQueryThreadPool = threads::ThreadPoolHandle;
using IntraQueryThreadPool = threads::DefaultThreadPool;

// Reuse scratchspace types from static IVF
using buffer_centroids_type = SortedBuffer<Idx, compare>;
using buffer_leaves_type = std::vector<SortedBuffer<Idx, compare>>;
using inner_scratch_type =
svs::tag_t<extensions::per_thread_batch_search_setup>::result_t<Data, Dist>;
using scratchspace_type =
ivf::IVFScratchspace<buffer_centroids_type, buffer_leaves_type, inner_scratch_type>;

private:
// Core IVF components (same structure as static IVF)
centroids_type centroids_;
Expand All @@ -98,7 +110,7 @@ class DynamicIVFIndex {
// Threading infrastructure (same as static IVF)
InterQueryThreadPool inter_query_threadpool_;
const size_t intra_query_thread_count_;
std::vector<IntraQueryThreadPool> intra_query_threadpools_;
mutable std::vector<IntraQueryThreadPool> intra_query_threadpools_;

// Search infrastructure (same as static IVF)
std::vector<data::SimpleData<float>> matmul_results_;
Expand Down Expand Up @@ -337,6 +349,87 @@ class DynamicIVFIndex {
);
}

/// @brief Build the scratch space a caller needs to drive single-query searches itself.
///
/// The returned object bundles a centroid buffer sized from ``sp.n_probes_``, the
/// per-thread leaf buffers, and the inner scratch produced by
/// ``extensions::per_thread_batch_search_setup``.
///
/// @param sp Search parameters used to size the buffers.
/// @param num_neighbors Number of neighbors to return (default: 10).
/// @return A freshly constructed ``scratchspace_type``.
scratchspace_type
scratchspace(const search_parameters_type& sp, size_t num_neighbors = 10) const {
    // Leaf buffers hold ``num_neighbors`` scaled by ``sp.k_reorder_``; the conversion
    // back to ``size_t`` truncates toward zero.
    const auto scaled_neighbors = sp.k_reorder_ * static_cast<float>(num_neighbors);
    const auto leaf_capacity = static_cast<size_t>(scaled_neighbors);

    // NOTE(review): the first cluster's data is handed to the setup routine, so this
    // presumably requires at least one cluster to exist - confirm against callers.
    return scratchspace_type{
        create_centroid_buffer(sp.n_probes_),
        create_leaf_buffers(leaf_capacity),
        extensions::per_thread_batch_search_setup(clusters_[0].data_, distance_)};
}

/// @brief Build scratch space from the search parameters currently stored on the index.
///
/// Forwards ``search_parameters_`` to the parameterized overload and leaves the neighbor
/// count at that overload's default.
///
/// @return A freshly constructed ``scratchspace_type``.
scratchspace_type scratchspace() const {
    return scratchspace(search_parameters_);
}

/// @brief Run a nearest neighbor search for one query using caller-provided scratch
/// space.
///
/// Steps performed:
/// * Wrap the query in a one-row view and compute its centroid distances
/// * Search the centroids to pick the n_probes nearest clusters
/// * Search inside the selected clusters for the k nearest neighbors
///
/// On return the results live in ``scratch.buffer_leaves[0]``; extracting and
/// post-processing them is the caller's job.
/// Results will contain internal IDs - use translate_to_external() to convert to
/// external IDs.
///
/// **Note**: The caller must make sure the scratch space was initialized so that it can
/// hold the requested number of neighbors.
///
/// @tparam Query Query type exposing ``data()`` and ``size()``.
/// @param query The query vector.
/// @param scratch Scratch space that receives the results.
///
template <typename Query> void search(const Query& query, scratchspace_type& scratch) {
    // Present the query as a dataset with exactly one row so the centroid-distance
    // routine can consume it.
    // NOTE(review): this writes into the shared ``matmul_results_`` member - confirm
    // whether concurrent calls on one index are expected to be safe.
    auto single_query = data::ConstSimpleDataView<float>(query.data(), 1, query.size());
    compute_centroid_distances(
        single_query, centroids_, matmul_results_, inter_query_threadpool_
    );

    // Adapters that hide the trailing index argument of the index closures by always
    // passing 0 for it.
    auto leaves_adapter =
        [&](const auto& q, auto& dist, const auto& probed, auto& leaves) {
            search_leaves_closure()(q, dist, probed, leaves, 0);
        };
    auto centroids_adapter = [&](const auto& q, auto& centroid_buffer) {
        search_centroids_closure()(q, centroid_buffer, 0);
    };

    // Dispatch through the customization point, keyed on the first cluster's data.
    extensions::single_search(
        clusters_[0].data_,
        *this,
        scratch.buffer_centroids,
        scratch.buffer_leaves,
        scratch.scratch,
        query,
        centroids_adapter,
        leaves_adapter
    );
}

///// Batch Iterator /////

/// @brief Create a batch iterator for retrieving neighbors in batches.
///
/// The iterator allows incremental retrieval of neighbors, expanding the search
/// space on each call to `next()`. This is useful for applications that need
/// to process neighbors in batches or implement early termination.
///
/// The iterator is constructed from `*this`, so it presumably refers back to this
/// index; NOTE(review): confirm the index must outlive the returned iterator.
///
/// @tparam QueryType The element type of the query vector.
/// @param query The query vector as a span.
/// @param extra_search_buffer_capacity Additional buffer capacity for the search
///        (defaults to 0).
/// @return A BatchIterator for the given query.
///
template <typename QueryType>
auto make_batch_iterator(
std::span<const QueryType> query, size_t extra_search_buffer_capacity = 0
) {
// The BatchIterator template arguments are deduced from the constructor arguments.
return BatchIterator(*this, query, extra_search_buffer_capacity);
}

/// @brief Iterate over all external IDs
template <typename F> void on_ids(F&& f) const {
for (size_t i = 0; i < status_.size(); ++i) {
Expand Down Expand Up @@ -860,7 +953,7 @@ class DynamicIVFIndex {
}

/// @brief Create closure for searching clusters/leaves
auto search_leaves_closure() {
auto search_leaves_closure() const {
return [this](
const auto& query,
auto& distance,
Expand Down
89 changes: 89 additions & 0 deletions include/svs/index/ivf/extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,95 @@ Distance svs_invoke(
return threads::shallow_copy(distance);
}

///
/// @brief Customization point for single query search.
///
/// Calling the ``single_search`` object forwards every argument, preceded by the
/// customization point object itself, to ``svs::svs_invoke``. Overloads of
/// ``svs_invoke`` taking ``svs::tag_t<single_search>`` as their first parameter provide
/// the actual behavior.
///
struct IVFSingleSearchType {
    /// @param data Dataset forwarded to the selected ``svs_invoke`` overload.
    /// @param cluster Object forwarded as the cluster argument.
    /// @param buffer_centroids Buffer for the centroid search.
    /// @param buffer_leaves Buffers for the leaf search.
    /// @param scratch Additional scratch state forwarded unchanged.
    /// @param query The query being searched.
    /// @param search_centroids Callable performing the centroid search.
    /// @param search_leaves Callable performing the leaf search.
    template <
        typename Data,
        typename Cluster,
        typename BufferCentroids,
        typename BufferLeaves,
        typename Scratch,
        typename Query,
        typename SearchCentroids,
        typename SearchLeaves>
    void operator()(
        const Data& data,
        const Cluster& cluster,
        BufferCentroids& buffer_centroids,
        BufferLeaves& buffer_leaves,
        Scratch& scratch,
        const Query& query,
        const SearchCentroids& search_centroids,
        const SearchLeaves& search_leaves
    ) const {
        // Pure forwarding: ``*this`` acts as the tag that selects the overload.
        svs::svs_invoke(
            *this,
            data,
            cluster,
            buffer_centroids,
            buffer_leaves,
            scratch,
            query,
            search_centroids,
            search_leaves
        );
    }
};

/// The customization point object used at call sites.
inline constexpr IVFSingleSearchType single_search{};

// Default implementation for single query search.
//
// Runs the centroid search, then the leaf search, folds the leaf buffers of all
// intra-query threads into ``buffer_leaves[0]``, sorts that buffer, and finally rewrites
// each entry's id from (cluster_id, local_id) to the global id reported by ``cluster``.
// The ``data`` argument is unused here.
template <
    typename Data,
    typename Cluster,
    typename BufferCentroids,
    typename BufferLeaves,
    typename Distance,
    typename Query,
    typename SearchCentroids,
    typename SearchLeaves>
void svs_invoke(
    svs::tag_t<single_search>,
    const Data& SVS_UNUSED(data),
    const Cluster& cluster,
    BufferCentroids& buffer_centroids,
    BufferLeaves& buffer_leaves,
    Distance& distance,
    const Query& query,
    const SearchCentroids& search_centroids,
    const SearchLeaves& search_leaves
) {
    // Both extents are captured before any searching takes place.
    const size_t num_leaf_buffers = buffer_leaves.size();
    const size_t leaf_capacity = buffer_leaves[0].capacity();

    // Stage 1: find the nearest clusters.
    search_centroids(query, buffer_centroids);

    // Stage 2: scan the selected clusters.
    search_leaves(query, distance, buffer_centroids, buffer_leaves);

    // Stage 3: merge the buffers of the remaining intra-query threads into the first.
    // NOTE(review): every slot up to the capacity is read, not just a valid-entry
    // count - confirm the buffer type guarantees those slots are meaningful.
    auto& merged = buffer_leaves[0];
    for (size_t tid = 1; tid < num_leaf_buffers; ++tid) {
        auto& partial = buffer_leaves[tid];
        for (size_t slot = 0; slot < leaf_capacity; ++slot) {
            merged.insert(partial[slot]);
        }
    }

    // Stage 4: sort so the merged entries come out in order.
    merged.sort();

    // Stage 5: translate (cluster_id, local_id) pairs into global ids, in place.
    for (size_t slot = 0; slot < leaf_capacity; ++slot) {
        auto& entry = merged[slot];
        auto cluster_id = entry.id();
        auto local_id = entry.get_local_id();
        auto global_id = cluster.get_global_id(cluster_id, local_id);
        entry.set_id(global_id);
    }
}

///
/// @brief Customization point for working with a batch of threads.
///
Expand Down
Loading
Loading