diff --git a/bindings/python/CMakeLists.txt b/bindings/python/CMakeLists.txt
index 2733d32ae..ca37f12b0 100644
--- a/bindings/python/CMakeLists.txt
+++ b/bindings/python/CMakeLists.txt
@@ -43,6 +43,7 @@ set(CPP_FILES
 # ivf
 if (SVS_EXPERIMENTAL_ENABLE_IVF)
     list(APPEND CPP_FILES
+        src/dynamic_ivf.cpp
         src/ivf.cpp
     )
 endif()
diff --git a/bindings/python/include/svs/python/dynamic_ivf.h b/bindings/python/include/svs/python/dynamic_ivf.h
new file mode 100644
index 000000000..48486d7a6
--- /dev/null
+++ b/bindings/python/include/svs/python/dynamic_ivf.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2025 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+// svs python bindings
+#include "svs/python/core.h"
+
+#include <pybind11/pybind11.h>
+
+namespace svs::python::dynamic_ivf {
+
+// Specializations
+template <typename F> void for_standard_specializations(F&& f) {
+#define X(Q, T, N) f.template operator()<Q, T, N>()
+    X(float, float, Dynamic);
+    X(float, float, Dynamic);
+    X(float, svs::Float16, Dynamic);
+    X(float, svs::Float16, Dynamic);
+    X(float, svs::BFloat16, Dynamic);
+    X(float, svs::BFloat16, Dynamic);
+#undef X
+}
+
+void wrap(pybind11::module& m);
+} // namespace svs::python::dynamic_ivf
diff --git a/bindings/python/include/svs/python/ivf.h b/bindings/python/include/svs/python/ivf.h
index 936c715bf..ab9be6449 100644
--- a/bindings/python/include/svs/python/ivf.h
+++ b/bindings/python/include/svs/python/ivf.h
@@ -20,7 +20,9 @@
 #include "svs/python/common.h"
 #include "svs/python/core.h"
 
+#include "svs/core/data/simple.h"
 #include "svs/core/distance.h"
+#include "svs/index/ivf/clustering.h"
 #include "svs/lib/bfloat16.h"
 #include "svs/lib/datatype.h"
 #include "svs/lib/float16.h"
@@ -30,6 +32,8 @@
 #include <pybind11/pybind11.h>
 #include <pybind11/numpy.h>
 
+#include <variant>
+
 namespace svs::python {
 namespace ivf_specializations {
 ///
@@ -61,6 +65,18 @@ template <typename F> void for_standard_specializations(F&& f) {
 } // namespace ivf_specializations
 
 namespace ivf {
+
+// The build process in IVF uses Kmeans to get centroids and assignments of data.
+// This sparse clustering can be saved with centroids stored as float datatype.
+// While assembling, the sparse clustering is used to create DenseClusters and
+// centroids datatype can be changed as per the search specializations.
+// Support both BFloat16 and Float16 centroids to match data types and leverage AMX.
+using ClusteringBF16 =
+    svs::index::ivf::Clustering<svs::data::SimpleData<svs::BFloat16>, uint32_t>;
+using ClusteringF16 =
+    svs::index::ivf::Clustering<svs::data::SimpleData<svs::Float16>, uint32_t>;
+using Clustering = std::variant<ClusteringBF16, ClusteringF16>;
+
 template <typename Manager> void add_interface(pybind11::class_<Manager>& manager) {
     manager.def_property_readonly(
         "experimental_backend_string",
@@ -82,6 +98,28 @@ template <typename Manager> void add_interface(pybind11::class_<Manager>& manage
         See also: `svs.IVFSearchParameters`.)"
     );
+
+    manager.def(
+        "get_distance",
+        [](const Manager& index, size_t id, const py_contiguous_array_t<float>& query) {
+            return index.get_distance(id, as_span(query));
+        },
+        pybind11::arg("id"),
+        pybind11::arg("query"),
+        R"(
+            Compute the distance between a query vector and a vector in the index.
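+
+            For example (illustrative): if ``index`` is an assembled IVF index and
+            ``queries`` is a float32 numpy matrix, ``index.get_distance(0, queries[0])``
+            returns the distance between the first query row and the indexed vector
+            with ID 0, computed with the index's configured distance metric.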
+ + Args: + id: The ID of the vector in the index. + query: The query vector as a numpy array. + + Returns: + The distance between the query and the indexed vector. + + Raises: + RuntimeError: If the ID doesn't exist or dimensions don't match. + )" + ); } void wrap(pybind11::module& m); diff --git a/bindings/python/src/dynamic_ivf.cpp b/bindings/python/src/dynamic_ivf.cpp new file mode 100644 index 000000000..52950ad38 --- /dev/null +++ b/bindings/python/src/dynamic_ivf.cpp @@ -0,0 +1,535 @@ +/* + * Copyright 2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// svs python bindings +#include "svs/python/dynamic_ivf.h" +#include "svs/python/common.h" +#include "svs/python/core.h" +#include "svs/python/ivf.h" +#include "svs/python/manager.h" + +// svs +#include "svs/lib/dispatcher.h" +#include "svs/orchestrators/dynamic_ivf.h" + +// pybind +#include +#include +#include + +// fmt +#include + +// stl +#include + +///// +///// DynamicIVF +///// + +namespace py = pybind11; +namespace svs::python::dynamic_ivf { + +// Reuse the Clustering type from static IVF since clustering is the same +using Clustering = svs::python::ivf::Clustering; + +using IVFAssembleTypes = + std::variant; + +///// +///// Dispatch Invocation +///// + +///// +///// Assembly from Clustering +///// + +template +svs::DynamicIVF assemble_uncompressed( + Clustering clustering, + svs::VectorDataLoader> data, + std::span ids, + svs::DistanceType distance_type, + size_t num_threads, + size_t intra_query_threads = 1 +) { + // Use std::visit to handle the variant clustering type + return std::visit( + [&](auto&& actual_clustering) { + return svs::DynamicIVF::assemble_from_clustering( + std::move(actual_clustering), + std::move(data), + ids, + distance_type, + num_threads, + intra_query_threads + ); + }, + std::move(clustering) + ); +} + +template +void register_uncompressed_ivf_assemble(Dispatcher& dispatcher) { + for_standard_specializations([&dispatcher]() { + auto method = &assemble_uncompressed; + dispatcher.register_target(svs::lib::dispatcher_build_docs, method); + }); +} + +template void register_ivf_assembly(Dispatcher& dispatcher) { + register_uncompressed_ivf_assemble(dispatcher); +} + +///// +///// Assembly from File +///// +template +svs::DynamicIVF assemble_from_file_uncompressed( + const std::filesystem::path& cluster_path, + svs::VectorDataLoader> data, + std::span ids, + svs::DistanceType distance_type, + size_t num_threads, + size_t intra_query_threads = 1 +) { + return svs::DynamicIVF::assemble_from_file( + cluster_path, std::move(data), ids, distance_type, num_threads, intra_query_threads + ); +} + +template +void register_uncompressed_ivf_assemble_from_file(Dispatcher& dispatcher) { + for_standard_specializations([&dispatcher]() { + auto method = &assemble_from_file_uncompressed; + dispatcher.register_target(svs::lib::dispatcher_build_docs, method); + }); +} + +template +void register_ivf_assembly_from_file(Dispatcher& dispatcher) { + 
+    register_uncompressed_ivf_assemble_from_file(dispatcher);
+}
+
+using IVFAssembleTypes =
+    std::variant<UnspecializedVectorDataLoader>;
+
+/////
+///// Dispatch Invocation
+/////
+
+using AssemblyDispatcher = svs::lib::Dispatcher<
+    svs::DynamicIVF,
+    Clustering,
+    IVFAssembleTypes,
+    std::span<const size_t>,
+    svs::DistanceType,
+    size_t,
+    size_t>;
+
+AssemblyDispatcher assembly_dispatcher() {
+    auto dispatcher = AssemblyDispatcher{};
+
+    // Register available backend methods.
+    register_ivf_assembly(dispatcher);
+    return dispatcher;
+}
+
+// Assemble
+svs::DynamicIVF assemble_from_clustering(
+    Clustering clustering,
+    IVFAssembleTypes data_kind,
+    const py_contiguous_array_t<size_t>& py_ids,
+    svs::DistanceType distance_type,
+    svs::DataType SVS_UNUSED(query_type),
+    bool SVS_UNUSED(enforce_dims),
+    size_t num_threads,
+    size_t intra_query_threads = 1
+) {
+    auto ids = std::span(py_ids.data(), py_ids.size());
+    return assembly_dispatcher().invoke(
+        std::move(clustering),
+        std::move(data_kind),
+        ids,
+        distance_type,
+        num_threads,
+        intra_query_threads
+    );
+}
+
+using AssemblyFromFileDispatcher = svs::lib::Dispatcher<
+    svs::DynamicIVF,
+    const std::filesystem::path&,
+    IVFAssembleTypes,
+    std::span<const size_t>,
+    svs::DistanceType,
+    size_t,
+    size_t>;
+
+AssemblyFromFileDispatcher assembly_from_file_dispatcher() {
+    auto dispatcher = AssemblyFromFileDispatcher{};
+
+    // Register available backend methods.
+    register_ivf_assembly_from_file(dispatcher);
+    return dispatcher;
+}
+
+// Assemble from file
+svs::DynamicIVF assemble_from_file(
+    const std::string& cluster_path,
+    IVFAssembleTypes data_kind,
+    const py_contiguous_array_t<size_t>& py_ids,
+    svs::DistanceType distance_type,
+    svs::DataType SVS_UNUSED(query_type),
+    bool SVS_UNUSED(enforce_dims),
+    size_t num_threads,
+    size_t intra_query_threads = 1
+) {
+    auto ids = std::span(py_ids.data(), py_ids.size());
+    return assembly_from_file_dispatcher().invoke(
+        cluster_path,
+        std::move(data_kind),
+        ids,
+        distance_type,
+        num_threads,
+        intra_query_threads
+    );
+}
+
+constexpr std::string_view ASSEMBLE_DOCSTRING_PROTO = R"(
+Assemble a searchable IVF index from the provided clustering and data.
+
+Args:
+    clustering_path/clustering: Path to the directory where the clustering was generated,
+        OR the loaded Clustering provided directly.
+    data_loader: The loader for the dataset. See the comment below for accepted types.
+    ids: External IDs for the vectors. Must match the dataset length and contain unique values.
+    distance: The distance function to use.
+    query_type: The data type of the queries.
+    enforce_dims: Require that the compiled dimensionality of the returned index matches
+        the dimensionality provided in the ``data_loader`` argument. If a match is not
+        found, an exception is thrown.
+
+        This is meant to ensure that specialized dimensionality is provided without falling
+        back to generic implementations. Leaving ``dims`` out when constructing the
+        ``data_loader`` with ``enforce_dims = True`` will always attempt to use a generic
+        implementation.
+    num_threads: The number of threads to use for queries (can't be changed after loading).
+    intra_query_threads: (default: 1) The number of threads that cooperate on a single query.
+        Total number of threads required = ``query_batch_size`` * ``intra_query_threads``,
+        where ``query_batch_size`` is the number of queries processed in parallel.
+        Use this parameter only when ``query_batch_size`` is small, and ensure your
+        system has sufficient threads available.
Set ``num_threads`` = ``query_batch_size`` + +The top level type is an abstract type backed by various specialized backends that will +be instantiated based on their applicability to the particular problem instance. + +The arguments upon which specialization is conducted are: + +* `data_loader`: Both kind (type of loader) and inner aspects of the loader like data type, + quantization type, and number of dimensions. +* `distance`: The distance measure being used. + +Specializations compiled into the binary are listed below. + +{} +)"; + +///// +///// Add points +///// + +template +void add_points( + svs::DynamicIVF& index, + const py_contiguous_array_t& py_data, + const py_contiguous_array_t& ids, + bool reuse_empty = false +) { + if (py_data.ndim() != 2) { + throw ANNEXCEPTION("Expected points to have 2 dimensions!"); + } + if (ids.ndim() != 1) { + throw ANNEXCEPTION("Expected ids to have 1 dimension!"); + } + if (py_data.shape(0) != ids.shape(0)) { + throw ANNEXCEPTION( + "Expected IDs to be the same length as the number of rows in points!" + ); + } + index.add_points(data_view(py_data), std::span(ids.data(), ids.size()), reuse_empty); +} + +const char* ADD_POINTS_DOCSTRING = R"( +Add every point in ``points`` to the index, assigning the element-wise corresponding ID to +each point. + +Args: + points: A matrix of data whose rows, corresponding to points in R^n, will be added to + the index. + ids: Vector of ids to assign to each row in ``points``. Must have the same number of + elements as ``points`` has rows. + reuse_empty: A flag that determines whether to reuse empty entries that may exist after deletion and consolidation. When enabled, + scan from the beginning to find and fill these empty entries when adding new points. + +Furthermore, all entries in ``ids`` must be unique and not already exist in the index. +If either of these does not hold, an exception will be thrown without mutating the +underlying index. + +When ``delete`` is called, a soft deletion is performed, marking the entries as ``deleted``. +When ``consolidate`` is called, the state of these deleted entries becomes ``empty``. +When ``add_points`` is called with the ``reuse_empty`` flag enabled, the memory is scanned from the beginning to locate and fill these empty entries with new points. +)"; + +template +void add_points_specialization(py::class_& index) { + index.def( + "add", + &add_points, + py::arg("points"), + py::arg("ids"), + py::arg("reuse_empty") = false, + ADD_POINTS_DOCSTRING + ); +} + +const char* CONSOLIDATE_DOCSTRING = R"( +No-op method for compatibility with dynamic index interface. +For the IVF index, deletion marks entries as Empty and they are excluded from searches. +Empty slots can be reused when adding new points. +)"; + +const char* COMPACT_DOCSTRING = R"( +Remove any holes created in the data by renumbering internal IDs. +Shrink the underlying data structures. +This can potentially reduce the memory footprint of the index +if a sufficient number of points were deleted. +)"; + +const char* DELETE_DOCSTRING = R"( +Soft delete the IDs from the index. Soft deletion does not remove the IDs from the index, +but prevents them from being returned from future searches. + +Args: + ids: The IDs to delete. + +Each element in IDs must be unique and must correspond to a valid ID stored in the index. +Otherwise, an exception will be thrown. If an exception is thrown for this reason, the +index will be left unchanged from before the function call. 
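+
+Example (illustrative; assumes an assembled ``DynamicIVF`` index ``index`` holding
+IDs 0..999, with ``numpy`` imported as ``np``):
+
+    >>> num_deleted = index.delete(np.arange(50).astype('uint64'))
+    >>> index.has_id(25)
+    False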
+)"; + +const char* ALL_IDS_DOCSTRING = R"( +Return a Numpy vector of all IDs currently in the index. +)"; + +// Index saving. +void save_index( + svs::DynamicIVF& index, const std::string& config_path, const std::string& data_dir +) { + index.save(config_path, data_dir); +} + +void wrap(py::module& m) { + std::string name = "DynamicIVF"; + py::class_ dynamic_ivf( + m, name.c_str(), "Top level class for the dynamic IVF index." + ); + + add_search_specialization(dynamic_ivf); + add_threading_interface(dynamic_ivf); + add_data_interface(dynamic_ivf); + + // IVF specific extensions. + ivf::add_interface(dynamic_ivf); + + // Dynamic interface. + dynamic_ivf.def("consolidate", &svs::DynamicIVF::consolidate, CONSOLIDATE_DOCSTRING); + dynamic_ivf.def( + "compact", + &svs::DynamicIVF::compact, + py::arg("batchsize") = 1'000'000, + COMPACT_DOCSTRING + ); + + // Assemble interface + { + auto dispatcher = assembly_dispatcher(); + // Procedurally generate the dispatch string. + auto dynamic = std::string{}; + for (size_t i = 0; i < dispatcher.size(); ++i) { + fmt::format_to( + std::back_inserter(dynamic), + R"( +Method {}: + - data_loader: {} + - distance: {} +)", + i, + dispatcher.description(i, 2), + dispatcher.description(i, 3) + ); + } + + dynamic_ivf.def_static( + "assemble_from_clustering", + [](Clustering clustering, + IVFAssembleTypes data_loader, + const py_contiguous_array_t& py_ids, + svs::DistanceType distance, + svs::DataType query_type, + bool enforce_dims, + size_t num_threads, + size_t intra_query_threads) { + return assemble_from_clustering( + std::move(clustering), + std::move(data_loader), + py_ids, + distance, + query_type, + enforce_dims, + num_threads, + intra_query_threads + ); + }, + py::arg("clustering"), + py::arg("data_loader"), + py::arg("ids"), + py::arg("distance") = svs::L2, + py::arg("query_type") = svs::DataType::float32, + py::arg("enforce_dims") = false, + py::arg("num_threads") = 1, + py::arg("intra_query_threads") = 1, + fmt::format(ASSEMBLE_DOCSTRING_PROTO, dynamic).c_str() + ); + dynamic_ivf.def_static( + "assemble_from_file", + [](const std::string& clustering_path, + IVFAssembleTypes data_loader, + const py_contiguous_array_t& py_ids, + svs::DistanceType distance, + svs::DataType query_type, + bool enforce_dims, + size_t num_threads, + size_t intra_query_threads) { + return assemble_from_file( + clustering_path, + std::move(data_loader), + py_ids, + distance, + query_type, + enforce_dims, + num_threads, + intra_query_threads + ); + }, + py::arg("clustering_path"), + py::arg("data_loader"), + py::arg("ids"), + py::arg("distance") = svs::L2, + py::arg("query_type") = svs::DataType::float32, + py::arg("enforce_dims") = false, + py::arg("num_threads") = 1, + py::arg("intra_query_threads") = 1, + fmt::format(ASSEMBLE_DOCSTRING_PROTO, dynamic).c_str() + ); + } + + // Index modification. + add_points_specialization(dynamic_ivf); + + // Note: DynamicIVFIndex doesn't support reconstruct_at, so we don't add reconstruct + // interface + + // Index Deletion. + dynamic_ivf.def( + "delete", + [](svs::DynamicIVF& index, const py_contiguous_array_t& ids) { + return index.delete_points(as_span(ids)); + }, + py::arg("ids"), + DELETE_DOCSTRING + ); + + // ID inspection + dynamic_ivf.def( + "has_id", + &svs::DynamicIVF::has_id, + py::arg("id"), + "Return whether the ID exists in the index." 
+ ); + + dynamic_ivf.def( + "all_ids", + [](const svs::DynamicIVF& index) { + const auto& v = index.all_ids(); + // Populate a numpy-set + auto npv = numpy_vector(v.size()); + std::copy(v.begin(), v.end(), npv.mutable_unchecked().mutable_data()); + return npv; + }, + ALL_IDS_DOCSTRING + ); + + // Distance calculation + dynamic_ivf.def( + "get_distance", + [](const svs::DynamicIVF& index, + size_t id, + const py_contiguous_array_t& query) { + return index.get_distance(id, as_span(query)); + }, + py::arg("id"), + py::arg("query"), + R"( + Compute the distance between a query vector and a vector in the index. + + Args: + id: The external ID of the vector in the index. + query: The query vector as a numpy array. + + Returns: + The distance between the query and the indexed vector. + + Raises: + RuntimeError: If the ID doesn't exist or dimensions don't match. + )" + ); + + // Saving + dynamic_ivf.def( + "save", + &save_index, + py::arg("config_directory"), + py::arg("data_directory"), + R"( +Save a constructed index to disk (useful following index construction). + +Args: + config_directory: Directory where index configuration information will be saved. + data_directory: Directory where the dataset will be saved. + +Note: All directories should be separate to avoid accidental name collision with any +auxiliary files that are needed when saving the various components of the index. + +If the directory does not exist, it will be created if its parent exists. + +It is the caller's responsibility to ensure that no existing data will be +overwritten when saving the index to this directory. + )" + ); +} + +} // namespace svs::python::dynamic_ivf diff --git a/bindings/python/src/ivf.cpp b/bindings/python/src/ivf.cpp index 06a651fe7..7d231c998 100644 --- a/bindings/python/src/ivf.cpp +++ b/bindings/python/src/ivf.cpp @@ -21,6 +21,9 @@ #include "svs/python/dispatch.h" #include "svs/python/manager.h" +// pybind11 +#include // For std::variant support + // svs #include "svs/core/data/simple.h" #include "svs/core/distance.h" @@ -50,14 +53,6 @@ namespace py = pybind11; using namespace svs::python::ivf_specializations; namespace svs::python::ivf { -// The build process in IVF uses Kmeans to get centroids and assignments of data. -// This sparse clustering can be saved with centroids stored as float datatype. -// While assembling, the sparse clustering is used to create DenseClusters and -// centroids datatype can be changed as per the search specializations. 
-// By default, BFloat16 centroids are used to take advantage of AMX -// template -using Clustering = - svs::index::ivf::Clustering, uint32_t>; namespace detail { @@ -73,12 +68,18 @@ svs::IVF assemble_uncompressed( size_t num_threads, size_t intra_query_threads = 1 ) { - return svs::IVF::assemble_from_clustering( - std::move(clustering), - std::move(data), - distance_type, - num_threads, - intra_query_threads + // Use std::visit to handle the variant clustering type + return std::visit( + [&](auto&& actual_clustering) { + return svs::IVF::assemble_from_clustering( + std::move(actual_clustering), + std::move(data), + distance_type, + num_threads, + intra_query_threads + ); + }, + std::move(clustering) ); } @@ -145,9 +146,21 @@ Clustering build_uncompressed( svs::DistanceType distance_type, size_t num_threads ) { - return svs::IVF::build_clustering( + // Choose build type for clustering to leverage AMX instructions: + // - Float32 data -> BFloat16 (AMX supports BFloat16) + // - Float16 data -> Float16 (AMX supports Float16) + // - BFloat16 data -> BFloat16 (already optimal) + using BuildType = std::conditional_t, svs::BFloat16, T>; + auto clustering = svs::IVF::build_clustering( parameters, std::move(data), distance_type, num_threads ); + + // Return as variant - Float16 or BFloat16 based on BuildType + if constexpr (std::is_same_v) { + return Clustering(std::in_place_index<1>, std::move(clustering)); + } else { + return Clustering(std::in_place_index<0>, std::move(clustering)); + } } template @@ -182,9 +195,21 @@ Clustering uncompressed_build_from_array( auto data = svs::data::SimpleData>(view.size(), view.dimensions()); svs::data::copy(view, data); - return svs::IVF::build_clustering( + // Choose build type for clustering to leverage AMX instructions: + // - Float32 data -> BFloat16 (AMX supports BFloat16) + // - Float16 data -> Float16 (AMX supports Float16) + // - BFloat16 data -> BFloat16 (already optimal) + using BuildType = std::conditional_t, svs::BFloat16, T>; + auto clustering = svs::IVF::build_clustering( parameters, std::move(data), distance_type, num_threads ); + + // Return as variant - Float16 or BFloat16 based on BuildType + if constexpr (std::is_same_v) { + return Clustering(std::in_place_index<1>, std::move(clustering)); + } else { + return Clustering(std::in_place_index<0>, std::move(clustering)); + } } template void register_ivf_build_from_array(Dispatcher& dispatcher) { @@ -480,13 +505,31 @@ void wrap_build_from_file(py::class_& clustering) { // Save the sparse clustering to a directory void save_clustering(Clustering& clustering, const std::string& clustering_path) { - svs::lib::save_to_disk(clustering, clustering_path); + std::visit( + [&](auto&& actual_clustering) { + svs::lib::save_to_disk(actual_clustering, clustering_path); + }, + clustering + ); } -// Save the sparse clustering to a directory +// Load the sparse clustering from a directory +// Try loading as BFloat16 first, then Float16 if that fails auto load_clustering(const std::string& clustering_path, size_t num_threads = 1) { auto threadpool = threads::as_threadpool(num_threads); - return svs::lib::load_from_disk(clustering_path, threadpool); + try { + auto bf16_clustering = svs::lib::load_from_disk< + svs::index::ivf::Clustering, uint32_t>>( + clustering_path, threadpool + ); + return Clustering(std::in_place_index<0>, std::move(bf16_clustering)); + } catch (...) 
{ + auto f16_clustering = svs::lib::load_from_disk< + svs::index::ivf::Clustering, uint32_t>>( + clustering_path, threadpool + ); + return Clustering(std::in_place_index<1>, std::move(f16_clustering)); + } } } // namespace detail @@ -590,6 +633,30 @@ void wrap(py::module& m) { // Reconstruction. // add_reconstruct_interface(ivf); + // Register both clustering types that make up the variant + name = "ClusteringBFloat16"; + py::class_ clustering_bf16(m, name.c_str()); + clustering_bf16.def( + "save", + [](ClusteringBF16& clustering, const std::string& clustering_path) { + svs::lib::save_to_disk(clustering, clustering_path); + }, + py::arg("clustering_directory"), + "Save a constructed IVF clustering to disk." + ); + + name = "ClusteringFloat16"; + py::class_ clustering_f16(m, name.c_str()); + clustering_f16.def( + "save", + [](ClusteringF16& clustering, const std::string& clustering_path) { + svs::lib::save_to_disk(clustering, clustering_path); + }, + py::arg("clustering_directory"), + "Save a constructed IVF clustering to disk." + ); + + // Register the variant type as the main Clustering class name = "Clustering"; py::class_ clustering( m, name.c_str(), "Top level class for sparse IVF clustering" diff --git a/bindings/python/src/python_bindings.cpp b/bindings/python/src/python_bindings.cpp index 67baf9049..e7c14bf6f 100644 --- a/bindings/python/src/python_bindings.cpp +++ b/bindings/python/src/python_bindings.cpp @@ -24,6 +24,7 @@ SVS_VALIDATE_BOOL_ENV(SVS_ENABLE_IVF) #if SVS_ENABLE_IVF +#include "svs/python/dynamic_ivf.h" #include "svs/python/ivf.h" #endif // SVS_ENABLE_IVF @@ -255,5 +256,6 @@ Convert the `fvecs` file on disk with 32-bit floating point entries to a `fvecs` SVS_VALIDATE_BOOL_ENV(SVS_ENABLE_IVF) #if SVS_ENABLE_IVF svs::python::ivf::wrap(m); + svs::python::dynamic_ivf::wrap(m); #endif // SVS_ENABLE_IVF } diff --git a/bindings/python/tests/test_ivf.py b/bindings/python/tests/test_ivf.py index b5bdf7b2c..08968e607 100644 --- a/bindings/python/tests/test_ivf.py +++ b/bindings/python/tests/test_ivf.py @@ -39,7 +39,8 @@ test_number_of_clusters, \ test_dimensions, \ timed, \ - get_test_set + get_test_set, \ + test_get_distance from .dataset import UncompressedMatcher @@ -161,6 +162,10 @@ def _test_basic_inner( self.assertEqual(queries.shape, (1000, 128)) self.assertEqual(groundtruth.shape, (1000, 100)) + # Test get_distance + data = svs.read_vecs(test_data_vecs) + test_get_distance(ivf, svs.DistanceType.L2, data) + # Data interface self.assertEqual(ivf.size, test_number_of_clusters) diff --git a/examples/python/example_ivf.py b/examples/python/example_ivf.py new file mode 100644 index 000000000..63f80dd62 --- /dev/null +++ b/examples/python/example_ivf.py @@ -0,0 +1,195 @@ +# Copyright 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: Static IVF Index + +This example demonstrates how to: +1. Build clustering for IVF index +2. Assemble an IVF index from clustering +3. Search the index +4. Save and reload clustering +5. 
Load index from saved clustering +""" + +import os +import svs +import numpy as np + +def main(): + print("=" * 80) + print("Static IVF Index Example") + print("=" * 80) + + # [generate-dataset] + # Create a test dataset + test_data_dir = "./example_data_ivf" + print(f"\n1. Generating test dataset in '{test_data_dir}'...") + + svs.generate_test_dataset( + 10000, # Create 10,000 vectors in the dataset + 1000, # Generate 1,000 query vectors + 128, # Set vector dimensionality to 128 + test_data_dir, # Directory where results will be generated + data_seed = 1234, # Random seed for reproducibility + query_seed = 5678, # Random seed for reproducibility + num_threads = 4, # Number of threads to use + distance = svs.DistanceType.L2, # Distance metric + ) + print(" ✓ Dataset generated") + # [generate-dataset] + + # [build-parameters] + # Configure clustering parameters for IVF + print("\n2. Configuring build parameters...") + build_parameters = svs.IVFBuildParameters( + num_centroids = 50, # Number of clusters/centroids + minibatch_size = 2000, # Minibatch size for k-means + num_iterations = 20, # Number of k-means iterations + is_hierarchical = True, # Use hierarchical k-means + training_fraction = 0.5, # Fraction of data for training + seed = 0xc0ffee, # Random seed for clustering + ) + print(f" ✓ Configured {build_parameters.num_centroids} centroids") + # [build-parameters] + + # [load-data] + # Load the dataset + print("\n3. Loading dataset...") + data_path = os.path.join(test_data_dir, "data.fvecs") + data_loader = svs.VectorDataLoader( + data_path, + svs.DataType.float32, + dims = 128 + ) + print(f" ✓ Data loader created") + # [load-data] + + # [build-clustering] + # Build the clustering + print("\n4. Building clustering (k-means)...") + clustering = svs.Clustering.build( + build_parameters = build_parameters, + data_loader = data_loader, + distance = svs.DistanceType.L2, + num_threads = 4, + ) + print(f" ✓ Clustering built with {build_parameters.num_centroids} centroids") + # [build-clustering] + + # [assemble-index] + # Assemble the IVF index from clustering + print("\n5. Assembling IVF index from clustering...") + index = svs.IVF.assemble_from_clustering( + clustering = clustering, + data_loader = data_loader, + distance = svs.DistanceType.L2, + num_threads = 4, + intra_query_threads = 1, + ) + print(f" ✓ Index assembled with {index.size} vectors") + print(f" ✓ Index dimensions: {index.dimensions}") + # [assemble-index] + + # [configure-search] + # Configure search parameters + print("\n6. Configuring search parameters...") + search_params = svs.IVFSearchParameters( + n_probes = 10, # Number of clusters to search + k_reorder = 1.0 # Reorder factor (1.0 = no reordering) + ) + index.search_parameters = search_params + print(f" ✓ Search parameters: n_probes={search_params.n_probes}") + # [configure-search] + + # [search] + # Perform search + print("\n7. Searching the index...") + queries = svs.read_vecs(os.path.join(test_data_dir, "queries.fvecs")) + groundtruth = svs.read_vecs(os.path.join(test_data_dir, "groundtruth.ivecs")) + + num_neighbors = 10 + I, D = index.search(queries, num_neighbors) + recall = svs.k_recall_at(groundtruth, I, num_neighbors, num_neighbors) + print(f" ✓ Recall@{num_neighbors}: {recall:.4f}") + print(f" ✓ Result shape: {I.shape}") + # [search] + + # [save-clustering] + # Save the clustering for later use + print("\n8. 
Saving clustering...") + clustering_path = os.path.join(test_data_dir, "clustering") + clustering.save(clustering_path) + print(f" ✓ Clustering saved to '{clustering_path}'") + # [save-clustering] + + # [load-and-assemble] + # Load clustering and assemble a new index + print("\n9. Loading clustering and assembling new index...") + loaded_clustering = svs.Clustering.load_clustering(clustering_path) + + new_index = svs.IVF.assemble_from_clustering( + clustering = loaded_clustering, + data_loader = data_loader, + distance = svs.DistanceType.L2, + num_threads = 4, + intra_query_threads = 1, + ) + print(f" ✓ New index assembled with {new_index.size} vectors") + # [load-and-assemble] + + # [assemble-from-file] + # Or directly assemble from file + print("\n10. Assembling index directly from clustering file...") + index_from_file = svs.IVF.assemble_from_file( + clustering_path = clustering_path, + data_loader = data_loader, + distance = svs.DistanceType.L2, + num_threads = 4, + intra_query_threads = 1, + ) + print(f" ✓ Index assembled with {index_from_file.size} vectors") + # [assemble-from-file] + + # [search-verification] + # Verify both indices produce the same results + print("\n11. Verifying search results consistency...") + index_from_file.search_parameters = search_params + I2, D2 = index_from_file.search(queries, num_neighbors) + recall2 = svs.k_recall_at(groundtruth, I2, num_neighbors, num_neighbors) + print(f" ✓ Recall@{num_neighbors}: {recall2:.4f}") + + if np.allclose(D, D2): + print(" ✓ Both indices produce identical results") + else: + print(" ✗ Warning: Results differ slightly (expected due to floating point)") + # [search-verification] + + # [tune-search-parameters] + # Experiment with different search parameters + print("\n12. Tuning search parameters...") + for n_probes in [5, 10, 20]: + search_params.n_probes = n_probes + index.search_parameters = search_params + I_tuned, _ = index.search(queries, num_neighbors) + recall_tuned = svs.k_recall_at(groundtruth, I_tuned, num_neighbors, num_neighbors) + print(f" ✓ n_probes={n_probes:2d}: Recall@{num_neighbors} = {recall_tuned:.4f}") + # [tune-search-parameters] + + print("\n" + "=" * 80) + print("Example completed successfully!") + print("=" * 80) + +if __name__ == "__main__": + main() diff --git a/examples/python/example_ivf_dynamic.py b/examples/python/example_ivf_dynamic.py new file mode 100644 index 000000000..605ff9ecc --- /dev/null +++ b/examples/python/example_ivf_dynamic.py @@ -0,0 +1,234 @@ +# Copyright 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: Dynamic IVF Index + +This example demonstrates how to: +1. Build a dynamic IVF index from scratch +2. Add new vectors to the index +3. Remove vectors from the index +4. Search the index +5. Compute distances between queries and indexed vectors +6. 
Save and reload the index +""" + +import os +import svs +import numpy as np + +def main(): + print("=" * 80) + print("Dynamic IVF Index Example") + print("=" * 80) + + # [generate-dataset] + # Create a test dataset with 10,000 vectors + test_data_dir = "./example_data_ivf_dynamic" + print(f"\n1. Generating test dataset in '{test_data_dir}'...") + + svs.generate_test_dataset( + 1000, # Create 1000 vectors in the dataset + 100, # Generate 100 query vectors + 128, # Set vector dimensionality to 128 + test_data_dir, # Directory where results will be generated + data_seed = 1234, # Random seed for reproducibility + query_seed = 5678, # Random seed for reproducibility + num_threads = 4, # Number of threads to use + distance = svs.DistanceType.L2, # Distance metric + ) + print(" ✓ Dataset generated") + # [generate-dataset] + + # [build-parameters] + # Configure clustering parameters for IVF + print("\n2. Configuring build parameters...") + build_parameters = svs.IVFBuildParameters( + num_centroids = 20, # Number of clusters/centroids + minibatch_size = 1000, # Minibatch size for k-means + num_iterations = 10, # Number of k-means iterations + is_hierarchical = True, # Use hierarchical k-means + training_fraction = 0.1, # Fraction of data for training + seed = 0xc0ffee, # Random seed for clustering + ) + print(f" ✓ Configured {build_parameters.num_centroids} centroids") + # [build-parameters] + + # [build-clustering-and-assemble] + # Build clustering and then assemble the dynamic IVF index + print("\n3. Building clustering and assembling dynamic IVF index...") + + # Load all data + data = svs.read_vecs(os.path.join(test_data_dir, "data.fvecs")) + n_total = data.shape[0] # Total vectors (1000) + ids_all = np.arange(n_total).astype('uint64') + + # Build the clustering using all data + data_loader = svs.VectorDataLoader( + os.path.join(test_data_dir, "data.fvecs"), + svs.DataType.float32, + dims = 128 + ) + clustering = svs.Clustering.build( + build_parameters = build_parameters, + data_loader = data_loader, + distance = svs.DistanceType.L2, + num_threads = 4, + ) + print(f" ✓ Clustering built with {build_parameters.num_centroids} centroids") + + # Assemble the dynamic IVF index with all vectors + print(" Assembling dynamic IVF index from clustering...") + index = svs.DynamicIVF.assemble_from_clustering( + clustering = clustering, + data_loader = data_loader, + ids = ids_all, # Index all vectors + distance = svs.DistanceType.L2, + num_threads = 4, + intra_query_threads = 1, + ) + print(f" ✓ Index assembled with {index.size} vectors") + print(f" ✓ Index dimensions: {index.dimensions}") + # [build-clustering-and-assemble] + + # [demonstrate-dynamic-operations] + # Demonstrate add and delete operations (even though we already have all vectors) + print("\n4. Demonstrating dynamic operations...") + print(f" Initial index size: {index.size}") + + # Delete some vectors + print(" Deleting first 100 vectors...") + ids_to_delete = np.arange(100).astype('uint64') + index.delete(ids_to_delete) + print(f" After deletion: {index.size} vectors") + + # Add them back + print(" Adding 100 vectors back...") + index.add(data[:100], ids_to_delete) + print(f" After addition: {index.size} vectors") + # [demonstrate-dynamic-operations] + + # [search-before-delete] + # Search before deletion + print("\n5. 
Searching the index...") + queries = svs.read_vecs(os.path.join(test_data_dir, "queries.fvecs")) + groundtruth = svs.read_vecs(os.path.join(test_data_dir, "groundtruth.ivecs")) + + # Configure search parameters + search_params = svs.IVFSearchParameters( + n_probes = 10, # Number of clusters to search + k_reorder = 1.0 # Reorder factor + ) + index.search_parameters = search_params + + # Perform search + num_neighbors = 10 + I, D = index.search(queries, num_neighbors) + recall = svs.k_recall_at(groundtruth, I, num_neighbors, num_neighbors) + print(f" ✓ Recall@{num_neighbors}: {recall:.4f}") + # [search-before-delete] + + # [get-distance] + # Compute distance between a query and a specific indexed vector + print("\n6. Computing distances with get_distance()...") + query_vector = queries[0] + test_id = 100 + + if index.has_id(test_id): + distance = index.get_distance(test_id, query_vector) + print(f" ✓ Distance from query to vector {test_id}: {distance:.6f}") + else: + print(f" ✗ Vector {test_id} not found in index") + # [get-distance] + + # [remove-vectors] + # Remove vectors from the index + print("\n7. Removing the first 50 vectors...") + ids_to_delete = ids_all[:50] + num_deleted = index.delete(ids_to_delete) + print(f" ✓ Deleted {num_deleted} vectors") + print(f" ✓ Index size after deletion: {index.size}") + + # Verify vectors are deleted + if not index.has_id(25): + print(f" ✓ Verified: Vector ID 25 no longer in index") + # [remove-vectors] + + # [consolidate-index] + # Consolidate and compact the index + print("\n8. Consolidating and compacting the index...") + index.consolidate().compact(1000) + print(f" ✓ Index consolidated and compacted") + # [consolidate-index] + + # [search-after-modifications] + # Search after modifications + print("\n9. Searching after modifications...") + I, D = index.search(queries, num_neighbors) + recall = svs.k_recall_at(groundtruth, I, num_neighbors, num_neighbors) + print(f" ✓ Recall@{num_neighbors}: {recall:.4f}") + # [search-after-modifications] + + # [tune-search-parameters] + # Experiment with different search parameters + print("\n10. Tuning search parameters...") + for n_probes in [5, 10, 20, 30]: + search_params.n_probes = n_probes + index.search_parameters = search_params + + I, D = index.search(queries, num_neighbors) + recall = svs.k_recall_at(groundtruth, I, num_neighbors, num_neighbors) + print(f" n_probes={n_probes:2d} → Recall@{num_neighbors}: {recall:.4f}") + # [tune-search-parameters] + + # [save-index] + # Save the index to disk + print("\n11. Saving the index...") + config_dir = os.path.join(test_data_dir, "saved_config") + data_dir = os.path.join(test_data_dir, "saved_data") + + # Create directories if they don't exist + os.makedirs(config_dir, exist_ok=True) + os.makedirs(data_dir, exist_ok=True) + + index.save(config_dir, data_dir) + print(f" ✓ Index saved to:") + print(f" Config: {config_dir}") + print(f" Data: {data_dir}") + # [save-index] + + # [load-index] + # Note: DynamicIVF.load() is being implemented for easier reload + # For now, the index has been successfully saved and can be accessed at: + print("\n12. Index saved successfully!") + print(f" ✓ Config: {config_dir}") + print(f" ✓ Data: {data_dir}") + print(f" Note: load() API coming soon for simplified reload") + # [load-index] + + # [get-all-ids] + # Inspect final index state + print("\n13. 
Final index inspection...") + all_ids = index.all_ids() + print(f" ✓ Index contains {len(all_ids)} unique IDs") + print(f" ✓ ID range: [{np.min(all_ids)}, {np.max(all_ids)}]") + # [get-all-ids] + + print("\n" + "=" * 80) + print("Dynamic IVF Example Completed Successfully!") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/include/svs/core/data/simple.h b/include/svs/core/data/simple.h index df0a45c3f..0fcb31bbb 100644 --- a/include/svs/core/data/simple.h +++ b/include/svs/core/data/simple.h @@ -38,6 +38,9 @@ namespace svs { namespace data { +// Forward declaration for Blocked allocator +template class Blocked; + template bool check_dims(size_t m, size_t n) { if constexpr (M == Dynamic || N == Dynamic) { return m == n; @@ -247,6 +250,8 @@ class SimpleData { /// Data wrapped in the library allocator. using lib_alloc_data_type = SimpleData>; + /// Data wrapped in the library blocked allocator for dynamic IVF. + using lib_blocked_alloc_data_type = SimpleData>>; /// Return the underlying allocator. const allocator_type& get_allocator() const { return data_.get_allocator(); } @@ -607,6 +612,8 @@ class SimpleData> { using const_value_type = std::span; using lib_alloc_data_type = SimpleData>>; + /// Already blocked, so lib_blocked_alloc_data_type is the same as lib_alloc_data_type. + using lib_blocked_alloc_data_type = SimpleData>>; ///// Constructors SimpleData(size_t n_elements, size_t n_dimensions, const Blocked& alloc) diff --git a/include/svs/extensions/ivf/scalar.h b/include/svs/extensions/ivf/scalar.h index cf199a641..4732a3fa3 100644 --- a/include/svs/extensions/ivf/scalar.h +++ b/include/svs/extensions/ivf/scalar.h @@ -45,4 +45,21 @@ auto svs_invoke( return new_sqdata; } +// Specialization for blocked allocators (Dynamic IVF) +template +auto svs_invoke( + svs::tag_t, + const Data& original, + size_t new_size, + const data::Blocked& SVS_UNUSED(blocked_alloc) +) { + auto new_sqdata = + SQDataset>( + new_size, original.dimensions() + ); + new_sqdata.set_scale(original.get_scale()); + new_sqdata.set_bias(original.get_bias()); + return new_sqdata; +} + } // namespace svs::quantization::scalar diff --git a/include/svs/index/ivf/clustering.h b/include/svs/index/ivf/clustering.h index 8f0555171..93a36526c 100644 --- a/include/svs/index/ivf/clustering.h +++ b/include/svs/index/ivf/clustering.h @@ -254,6 +254,12 @@ template class Clustering { template struct DenseCluster { public: + using data_type = Data; + using index_type = I; + + // Default constructor for in-place initialization + DenseCluster() = default; + DenseCluster(Data data, std::vector ids) : data_{std::move(data)} , ids_{std::move(ids)} { @@ -264,6 +270,12 @@ template struct DenseCluster { size_t size() const { return data_.size(); } + // Support for dynamic operations + void resize(size_t new_size) { + data_.resize(new_size); + ids_.resize(new_size); + } + template void on_leaves(Callback&& f, size_t prefetch_offset) const { size_t p = 0; @@ -287,6 +299,7 @@ template struct DenseCluster { auto get_secondary(size_t id) const { return data_.get_secondary(id); } auto get_global_id(size_t local_id) const { return ids_[local_id]; } const Data& view_cluster() const { return data_; } + Data& view_cluster() { return data_; } public: Data data_; @@ -303,7 +316,7 @@ class DenseClusteredDataset { using index_type = I; using data_type = Data; - // Constructor + // Constructor from clustering (for building from existing data) template DenseClusteredDataset( const Clustering& clustering, @@ -329,12 +342,35 @@ class 
DenseClusteredDataset { ); } + // Constructor for empty clusters (for assembly/dynamic operations) + template + DenseClusteredDataset(size_t num_clusters, size_t dimensions, const Alloc& allocator) + : clusters_{} { + clusters_.reserve(num_clusters); + for (size_t i = 0; i < num_clusters; ++i) { + clusters_.emplace_back(Data(0, dimensions, allocator), std::vector()); + } + } + template void on_leaves(Callback&& f, size_t cluster) const { clusters_.at(cluster).on_leaves(SVS_FWD(f), prefetch_offset_); } size_t get_prefetch_offset() const { return prefetch_offset_; } void set_prefetch_offset(size_t offset) { prefetch_offset_ = offset; } + + // Cluster access (const) + const DenseCluster& operator[](size_t cluster) const { + return clusters_[cluster]; + } + + // Cluster access (mutable) - for dynamic IVF operations + DenseCluster& operator[](size_t cluster) { return clusters_[cluster]; } + + // Number of clusters + size_t size() const { return clusters_.size(); } + + // Datum access (const) auto get_datum(size_t cluster, size_t id) const { return clusters_.at(cluster).get_datum(id); } @@ -344,10 +380,15 @@ class DenseClusteredDataset { auto get_global_id(size_t cluster, size_t id) const { return clusters_.at(cluster).get_global_id(id); } + + // View cluster data (const) const Data& view_cluster(size_t cluster) const { return clusters_.at(cluster).view_cluster(); } + // View cluster data (mutable) - for dynamic IVF operations + Data& view_cluster(size_t cluster) { return clusters_[cluster].view_cluster(); } + private: std::vector> clusters_; size_t prefetch_offset_ = 8; diff --git a/include/svs/index/ivf/common.h b/include/svs/index/ivf/common.h index 28f12151b..e914778d6 100644 --- a/include/svs/index/ivf/common.h +++ b/include/svs/index/ivf/common.h @@ -47,6 +47,13 @@ namespace svs::index::ivf { // threshold for numerical stability in algorithms such as k-means clustering, where exact constexpr double EPSILON = 1.0 / 1024.0; +/// Minimum training sample multiplier for clustering algorithms. +/// When training data size is small relative to the number of clusters, we ensure +/// at least (num_clusters * MIN_TRAINING_SAMPLE_MULTIPLIER) samples are used for +/// training to maintain clustering quality. This prevents degenerate cases where +/// training_fraction would produce insufficient samples. +constexpr size_t MIN_TRAINING_SAMPLE_MULTIPLIER = 2; + /// @brief Parameters controlling the IVF build/k-means algortihm. struct IVFBuildParameters { public: @@ -224,56 +231,79 @@ template void compute_matmul( const T* data, const T* centroids, float* results, size_t m, size_t n, size_t k ) { + // Early return for zero dimensions. + // Calling Intel MKL functions with zero dimensions may result in undefined behavior + // or runtime errors. This check ensures we avoid such cases. 
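+    // Note: returning early leaves `results` unmodified; the caller must not
+    // assume the output buffer was zero-filled.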
+ if (m == 0 || n == 0 || k == 0) { + return; // Nothing to compute + } + + // Check for integer overflow when casting to int (MKL requirement) + constexpr size_t max_int = static_cast(std::numeric_limits::max()); + if (m > max_int || n > max_int || k > max_int) { + throw ANNEXCEPTION( + "Matrix dimensions too large for Intel MKL GEMM: m={}, n={}, k={}, max={}", + m, + n, + k, + max_int + ); + } + + // Cast size_t parameters to int for MKL GEMM functions + int m_int = static_cast(m); + int n_int = static_cast(n); + int k_int = static_cast(k); if constexpr (std::is_same_v) { cblas_sgemm( CblasRowMajor, // CBLAS_LAYOUT layout CblasNoTrans, // CBLAS_TRANSPOSE TransA CblasTrans, // CBLAS_TRANSPOSE TransB - m, // const int M - n, // const int N - k, // const int K - 1.0, // float alpha + m_int, // const int M + n_int, // const int N + k_int, // const int K + 1.0F, // float alpha data, // const float* A - k, // const int lda + k_int, // const int lda centroids, // const float* B - k, // const int ldb - 0.0, // const float beta + k_int, // const int ldb + 0.0F, // const float beta results, // float* c - n // const int ldc + n_int // const int ldc ); } else if constexpr (std::is_same_v) { cblas_gemm_bf16bf16f32( CblasRowMajor, // CBLAS_LAYOUT layout CblasNoTrans, // CBLAS_TRANSPOSE TransA CblasTrans, // CBLAS_TRANSPOSE TransB - m, // const int M - n, // const int N - k, // const int K - 1.0, // float alpha + m_int, // const int M + n_int, // const int N + k_int, // const int K + 1.0F, // float alpha (const uint16_t*)data, // const *uint16_t A - k, // const int lda + k_int, // const int lda (const uint16_t*)centroids, // const uint16_t* B - k, // const int ldb - 0.0, // const float beta + k_int, // const int ldb + 0.0F, // const float beta results, // float* c - n // const int ldc + n_int // const int ldc ); } else if constexpr (std::is_same_v) { cblas_gemm_f16f16f32( CblasRowMajor, // CBLAS_LAYOUT layout CblasNoTrans, // CBLAS_TRANSPOSE TransA CblasTrans, // CBLAS_TRANSPOSE TransB - m, // const int M - n, // const int N - k, // const int K - 1.0, // float alpha + m_int, // const int M + n_int, // const int N + k_int, // const int K + 1.0F, // float alpha (const uint16_t*)data, // const *uint16_t A - k, // const int lda + k_int, // const int lda (const uint16_t*)centroids, // const uint16_t* B - k, // const int ldb - 0.0, // const float beta + k_int, // const int ldb + 0.0F, // const float beta results, // float* c - n // const int ldc + n_int // const int ldc ); } else { throw ANNEXCEPTION("GEMM type not supported!"); @@ -310,7 +340,7 @@ void normalize_centroids( auto datum = centroids.get_datum(i); float norm = distance::norm(datum); if (norm != 0.0) { - float norm_inv = 1.0 / norm; + float norm_inv = 1.0F / norm; for (size_t j = 0; j < datum.size(); j++) { datum[j] = datum[j] * norm_inv; } @@ -327,7 +357,7 @@ template < typename Distance, threads::ThreadPool Pool> void centroid_assignment( - Data& data, + const Data& data, std::vector& data_norm, threads::UnitRange batch_range, Distance& SVS_UNUSED(distance), @@ -338,21 +368,43 @@ void centroid_assignment( Pool& threadpool, lib::Timer& timer ) { + using DataType = typename Data::element_type; + using CentroidType = T; + + // Convert data to match centroid type if necessary + data::SimpleData data_conv; + if constexpr (!std::is_same_v) { + data_conv = convert_data(data, threadpool); + } + auto generate_assignments = timer.push_back("generate assignments"); threads::parallel_for( threadpool, threads::StaticPartition{batch_range.size()}, [&](auto 
indices, auto /*tid*/) { auto range = threads::UnitRange(indices); - compute_matmul( - data.get_datum(range.start()).data(), - centroids.data(), - matmul_results.get_datum(range.start()).data(), - range.size(), - centroids.size(), - data.dimensions() - ); - if constexpr (std::is_same_v) { + if constexpr (!std::is_same_v) { + compute_matmul( + data_conv.get_datum(range.start()).data(), + centroids.data(), + matmul_results.get_datum(range.start()).data(), + range.size(), + centroids.size(), + data.dimensions() + ); + } else { + compute_matmul( + data.get_datum(range.start()).data(), + centroids.data(), + matmul_results.get_datum(range.start()).data(), + range.size(), + centroids.size(), + data.dimensions() + ); + } + if constexpr (std::is_same_v< + std::remove_cvref_t, + distance::DistanceIP>) { for (auto i : indices) { auto nearest = type_traits::sentinel_v, std::greater<>>; @@ -362,13 +414,15 @@ void centroid_assignment( } assignments[batch_range.start() + i] = nearest.id(); } - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v< + std::remove_cvref_t, + distance::DistanceL2>) { for (auto i : indices) { auto nearest = type_traits::sentinel_v, std::less<>>; auto dists = matmul_results.get_datum(i); for (size_t j = 0; j < centroids.size(); j++) { auto dist = data_norm[batch_range.start() + i] + centroids_norm[j] - - 2 * dists[j]; + (2 * dists[j]); nearest = std::min(nearest, Neighbor(j, dist)); } assignments[batch_range.start() + i] = nearest.id(); @@ -456,7 +510,7 @@ void centroid_split( if (counts.at(j) == 0) { continue; } - float p = counts.at(j) / float(num_data); + float p = static_cast(counts.at(j)) / static_cast(num_data); float r = distribution(rng); if (r < p) { break; @@ -511,13 +565,13 @@ auto kmeans_training( auto training_timer = timer.push_back("Kmeans training"); data::SimpleData centroids_fp32 = convert_data(centroids, threadpool); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v, distance::DistanceIP>) { normalize_centroids(centroids_fp32, threadpool, timer); } auto assignments = std::vector(data.size()); std::vector data_norm; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v, distance::DistanceL2>) { generate_norms(data, data_norm, threadpool); } std::vector centroids_norm; @@ -526,7 +580,7 @@ auto kmeans_training( auto iter_timer = timer.push_back("iteration"); auto batchsize = parameters.minibatch_size_; auto num_batches = lib::div_round_up(data.size(), batchsize); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v, distance::DistanceL2>) { generate_norms(centroids_fp32, centroids_norm, threadpool); } @@ -559,7 +613,7 @@ auto kmeans_training( centroid_split(data, centroids_fp32, counts, rng, threadpool, timer); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v, distance::DistanceIP>) { normalize_centroids(centroids_fp32, threadpool, timer); } } @@ -638,8 +692,9 @@ data::SimpleData make_training_set( threadpool, threads::StaticPartition{num_training}, [&](auto indices, auto /*tid*/) { - for (auto i : indices) + for (auto i : indices) { trainset.set_datum(i, data.get_datum(ids[i])); + } } ); return trainset; @@ -660,8 +715,9 @@ data::SimpleData init_centroids( threadpool, threads::StaticPartition{num_centroids}, [&](auto indices, auto) { - for (auto i : indices) + for (auto i : indices) { centroids.set_datum(i, trainset.get_datum(ids[i])); + } } ); return centroids; @@ -671,7 +727,7 @@ data::SimpleData init_centroids( template std::vector maybe_compute_norms(const Data& data, Pool& 
threadpool) { std::vector norms; - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v, distance::DistanceL2>) { generate_norms(data, norms, threadpool); } return norms; @@ -680,14 +736,112 @@ std::vector maybe_compute_norms(const Data& data, Pool& threadpool) { /// @brief Assign all points to clusters according to assignments template std::vector> group_assignments( - const std::vector& assignments, size_t num_clusters, const Data& data_train + const std::vector& assignments, size_t num_clusters, const Data& data ) { std::vector> clusters(num_clusters); - for (auto i : data_train.eachindex()) + for (auto i : data.eachindex()) { clusters[assignments[i]].push_back(i); + } return clusters; } +/// @brief Perform cluster assignment for data given pre-trained centroids +/// +/// @tparam BuildType The numeric type used for matrix operations (float, Float16, BFloat16) +/// @tparam Data The dataset type +/// @tparam Centroids The centroids dataset type +/// @tparam Distance The distance metric type (DistanceIP or DistanceL2) +/// @tparam Pool The thread pool type +/// @tparam I The integer type for cluster indices +/// +/// @param data The dataset to assign to clusters +/// @param centroids The pre-trained centroids +/// @param distance The distance metric +/// @param threadpool The thread pool for parallel execution +/// @param minibatch_size Size of each processing batch (default: 10000) +/// @param integer_type Type tag for cluster indices (default: uint32_t) +/// +/// @return A vector of vectors where each inner vector contains the indices of data +/// points assigned to that cluster +template < + typename BuildType, + data::ImmutableMemoryDataset Data, + data::ImmutableMemoryDataset Centroids, + typename Distance, + threads::ThreadPool Pool, + std::integral I = uint32_t> +auto cluster_assignment( + Data& data, + Centroids& centroids, + Distance& distance, + Pool& threadpool, + size_t minibatch_size = 10'000, + lib::Type SVS_UNUSED(integer_type) = {} +) { + size_t ndims = data.dimensions(); + size_t num_centroids = centroids.size(); + + if (data.dimensions() != centroids.dimensions()) { + throw ANNEXCEPTION( + "Data and centroids must have the same dimensions! 
Data dims: {}, Centroids " + "dims: {}", + data.dimensions(), + centroids.dimensions() + ); + } + + // Allocate memory for assignments and matmul results + auto assignments = std::vector(data.size()); + auto matmul_results = data::SimpleData{minibatch_size, num_centroids}; + + // Convert centroids to BuildType if necessary + using CentroidType = typename Centroids::element_type; + data::SimpleData centroids_build; + if constexpr (!std::is_same_v) { + centroids_build = convert_data(centroids, threadpool); + } else { + centroids_build = + data::SimpleData{centroids.size(), centroids.dimensions()}; + convert_data(centroids, centroids_build, threadpool); + } + + // Compute norms if using L2 distance + auto data_norm = maybe_compute_norms(data, threadpool); + auto centroids_norm = maybe_compute_norms(centroids_build, threadpool); + + // Process data in batches + size_t batchsize = minibatch_size; + size_t num_batches = lib::div_round_up(data.size(), batchsize); + + using Alloc = svs::HugepageAllocator; + auto data_batch = data::SimpleData{batchsize, ndims}; + + for (size_t batch = 0; batch < num_batches; ++batch) { + auto this_batch = threads::UnitRange{ + batch * batchsize, std::min((batch + 1) * batchsize, data.size())}; + auto data_batch_view = data::make_view(data, this_batch); + convert_data(data_batch_view, data_batch, threadpool); + + // Use the existing centroid_assignment function to compute assignments + auto timer = lib::Timer(); + centroid_assignment( + data_batch, + data_norm, + this_batch, + distance, + centroids_build, + centroids_norm, + assignments, + matmul_results, + threadpool, + timer + ); + } + + // Group assignments into clusters + return group_assignments(assignments, num_centroids, data); +} + template void search_centroids( const Query& query, @@ -700,7 +854,7 @@ void search_centroids( ) { unsigned int count = 0; buffer.clear(); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v, distance::DistanceIP>) { for (size_t j = 0; j < num_threads; j++) { auto distance = matmul_results[j].get_datum(query_id); for (size_t k = 0; k < distance.size(); k++) { @@ -708,12 +862,12 @@ void search_centroids( count++; } } - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v, distance::DistanceL2>) { float query_norm = distance::norm_square(query); for (size_t j = 0; j < num_threads; j++) { auto distance = matmul_results[j].get_datum(query_id); for (size_t k = 0; k < distance.size(); k++) { - float dist = query_norm + centroids_norm[count] - 2 * distance[k]; + float dist = query_norm + centroids_norm[count] - (2 * distance[k]); buffer.insert({count, dist}); count++; } diff --git a/include/svs/index/ivf/dynamic_ivf.h b/include/svs/index/ivf/dynamic_ivf.h new file mode 100644 index 000000000..966fc3b12 --- /dev/null +++ b/include/svs/index/ivf/dynamic_ivf.h @@ -0,0 +1,973 @@ +/* + * Copyright 2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +// Include the IVF index and clustering +#include "svs/index/ivf/clustering.h" +#include "svs/index/ivf/index.h" + +// svs +#include "svs/concepts/distance.h" +#include "svs/core/loading.h" +#include "svs/core/logging.h" +#include "svs/core/query_result.h" +#include "svs/core/translation.h" +#include "svs/lib/misc.h" +#include "svs/lib/threads.h" + +// stdlib +#include +#include + +namespace svs::index::ivf { + +/// +/// Metadata tracking the state of a particular data index for DynamicIVFIndex. +/// The following states have the given meaning for their corresponding slot: +/// +/// * Valid: Valid and present in the associated dataset. +/// * Empty: Available slot that can be used for new data or reclaimed after deletion. +/// +enum class IVFSlotMetadata : uint8_t { Empty = 0x00, Valid = 0x01 }; + +/// +/// @brief Dynamic IVF Index with insertion and deletion support +/// +/// @tparam Centroids The type of centroid storage +/// @tparam Cluster Type representing cluster storage (DenseCluster with BlockedData) +/// @tparam Dist The distance functor used to compare queries with the elements +/// @tparam ThreadPoolProto Thread pool prototype type +/// +/// An IVF index implementation that supports dynamic insertion and deletion of vectors +/// while maintaining the inverted file structure for efficient similarity search. +/// +template +class DynamicIVFIndex { + public: + // Traits + static constexpr bool supports_insertions = true; + static constexpr bool supports_deletions = true; + static constexpr bool supports_saving = true; + static constexpr bool needs_id_translation = true; + + // Type Aliases + using Idx = typename Cluster::index_type; + using Data = typename Cluster::data_type; + using internal_id_type = size_t; + using external_id_type = size_t; + using distance_type = Dist; + using centroids_type = Centroids; + using cluster_type = Cluster; + using search_parameters_type = IVFSearchParameters; + using compare = distance::compare_t; + + // Thread-related type aliases + using InterQueryThreadPool = threads::ThreadPoolHandle; + using IntraQueryThreadPool = threads::DefaultThreadPool; + + private: + // Core IVF components (same structure as static IVF) + centroids_type centroids_; + Cluster clusters_; // Cluster container + + // Metadata tracking for dynamic operations + std::vector status_; // Status of each global slot + std::vector id_to_cluster_; // Maps global ID to cluster index + std::vector id_in_cluster_; // Maps global ID to position in cluster + size_t first_empty_ = 0; + size_t prefetch_offset_ = 8; + + // Translation and distance + IDTranslator translator_; + distance_type distance_; + + // Threading infrastructure (same as static IVF) + InterQueryThreadPool inter_query_threadpool_; + const size_t intra_query_thread_count_; + std::vector intra_query_threadpools_; + + // Search infrastructure (same as static IVF) + std::vector> matmul_results_; + std::vector centroids_norm_; + search_parameters_type search_parameters_; + + // Logger + svs::logging::logger_ptr logger_; + + public: + /// @brief Construct a new Dynamic IVF Index + /// + /// @param centroids Centroid collection for space partitioning + /// @param clusters Cluster container + /// @param external_ids External IDs for all vectors + /// @param distance_function Distance metric for similarity computation + /// @param threadpool_proto Primary thread pool prototype + /// @param intra_query_thread_count Number of threads for intra-query parallelism + /// @param logger Logger for per-index 
logging customization + template + DynamicIVFIndex( + centroids_type centroids, + Cluster clusters, + const ExternalIds& external_ids, + Dist distance_function, + TP threadpool_proto, + const size_t intra_query_thread_count = 1, + svs::logging::logger_ptr logger = svs::logging::get() + ) + : centroids_{std::move(centroids)} + , clusters_{std::move(clusters)} + , first_empty_{0} + , prefetch_offset_{8} + , distance_{std::move(distance_function)} + , inter_query_threadpool_{threads::as_threadpool(std::move(threadpool_proto))} + , intra_query_thread_count_{intra_query_thread_count} + , logger_{std::move(logger)} { + // Initialize metadata structures + size_t total_size = 0; + for (size_t cluster_idx = 0; cluster_idx < clusters_.size(); ++cluster_idx) { + const auto& cluster = clusters_[cluster_idx]; + for (size_t pos = 0; pos < cluster.ids_.size(); ++pos) { + total_size = + std::max(total_size, static_cast(cluster.ids_[pos]) + 1); + } + } + + status_.resize(total_size, IVFSlotMetadata::Valid); + id_to_cluster_.resize(total_size); + id_in_cluster_.resize(total_size); + first_empty_ = total_size; + + // Build reverse mapping from global ID to cluster location + for (size_t cluster_idx = 0; cluster_idx < clusters_.size(); ++cluster_idx) { + const auto& cluster = clusters_[cluster_idx]; + for (size_t pos = 0; pos < cluster.ids_.size(); ++pos) { + Idx global_id = cluster.ids_[pos]; + id_to_cluster_[global_id] = cluster_idx; + id_in_cluster_[global_id] = pos; + } + } + + // Initialize ID translation + translator_.insert( + external_ids, threads::UnitRange(0, external_ids.size()) + ); + + // Initialize thread pools and search infrastructure + validate_thread_configuration(); + initialize_thread_pools(); + initialize_search_buffers(); + initialize_distance_metadata(); + } + + /// @brief Constructor with pre-existing translator (for loading from saved state) + template + DynamicIVFIndex( + centroids_type centroids, + Cluster clusters, + IDTranslator translator, + Dist distance_function, + TP threadpool_proto, + const size_t intra_query_thread_count = 1, + svs::logging::logger_ptr logger = svs::logging::get() + ) + : centroids_{std::move(centroids)} + , clusters_{std::move(clusters)} + , first_empty_{0} + , prefetch_offset_{8} + , translator_{std::move(translator)} + , distance_{std::move(distance_function)} + , inter_query_threadpool_{threads::as_threadpool(std::move(threadpool_proto))} + , intra_query_thread_count_{intra_query_thread_count} + , logger_{std::move(logger)} { + // Initialize metadata structures based on cluster contents + size_t total_size = 0; + for (const auto& cluster : clusters_) { + for (size_t pos = 0; pos < cluster.ids_.size(); ++pos) { + total_size = + std::max(total_size, static_cast(cluster.ids_[pos]) + 1); + } + } + + status_.resize(total_size, IVFSlotMetadata::Valid); + id_to_cluster_.resize(total_size); + id_in_cluster_.resize(total_size); + first_empty_ = total_size; + + // Build reverse mapping from global ID to cluster location + for (size_t cluster_idx = 0; cluster_idx < clusters_.size(); ++cluster_idx) { + const auto& cluster = clusters_[cluster_idx]; + for (size_t pos = 0; pos < cluster.ids_.size(); ++pos) { + Idx global_id = cluster.ids_[pos]; + id_to_cluster_[global_id] = cluster_idx; + id_in_cluster_[global_id] = pos; + } + } + + // Initialize thread pools and search infrastructure + validate_thread_configuration(); + initialize_thread_pools(); + initialize_search_buffers(); + initialize_distance_metadata(); + } + + ///// Basic Properties ///// + + /// @brief Get 
logger + svs::logging::logger_ptr get_logger() const { return logger_; } + + /// @brief Return the number of valid entries in the index + size_t size() const { return translator_.size(); } + + /// @brief Return the number of centroids/clusters + size_t num_clusters() const { return centroids_.size(); } + + /// @brief Return the logical number of dimensions + size_t dimensions() const { return centroids_.dimensions(); } + + /// @brief Get index name + std::string name() const { return "Dynamic IVF Index"; } + + ///// Search Parameters ///// + + /// @brief Get current search parameters + search_parameters_type get_search_parameters() const { return search_parameters_; } + + /// @brief Set search parameters + void set_search_parameters(const search_parameters_type& params) { + search_parameters_ = params; + } + + ///// Threading Configuration ///// + + /// @brief Get number of threads for inter-query parallelism + size_t get_num_threads() const { return inter_query_threadpool_.size(); } + + /// @brief Get number of threads for intra-query parallelism + size_t get_num_intra_query_threads() const { return intra_query_thread_count_; } + + /// @brief Set threadpool for inter-query parallelism + void set_threadpool(InterQueryThreadPool threadpool) { + if (threadpool.size() != inter_query_threadpool_.size()) { + throw std::runtime_error( + "Threadpool change not supported - thread count must remain constant" + ); + } + inter_query_threadpool_ = std::move(threadpool); + } + + /// @brief Get threadpool handle + InterQueryThreadPool& get_threadpool_handle() { return inter_query_threadpool_; } + + /// @brief Get const threadpool handle + const InterQueryThreadPool& get_threadpool_handle() const { + return inter_query_threadpool_; + } + + ///// Index Translation ///// + + /// @brief Translate external ID to internal ID + size_t translate_external_id(size_t e) const { return translator_.get_internal(e); } + + /// @brief Translate internal ID to external ID + size_t translate_internal_id(size_t i) const { return translator_.get_external(i); } + + /// @brief Check whether external ID exists + bool has_id(size_t e) const { return translator_.has_external(e); } + + /// @brief Get the raw data for external id + auto get_datum(size_t e) const { + size_t internal_id = translate_external_id(e); + size_t cluster_idx = id_to_cluster_[internal_id]; + size_t pos = id_in_cluster_[internal_id]; + return clusters_[cluster_idx].get_datum(pos); + } + + /// @brief Get raw data by cluster and local position (for extension compatibility) + auto get_datum(size_t cluster_idx, size_t local_pos) const { + return clusters_[cluster_idx].get_datum(local_pos); + } + + /// @brief Get secondary data by cluster and local position (for LeanVec) + auto get_secondary(size_t cluster_idx, size_t local_pos) const { + return clusters_[cluster_idx].data_.get_secondary(local_pos); + } + + ///// Distance + + /// @brief Compute the distance between an external vector and a vector in the index. + template double get_distance(size_t id, const Query& query) const { + // Check if id exists + if (!has_id(id)) { + throw ANNEXCEPTION("ID {} does not exist in the index!", id); + } + + // Verify dimensions match + const size_t query_size = query.size(); + const size_t index_vector_size = dimensions(); + if (query_size != index_vector_size) { + throw ANNEXCEPTION( + "Incompatible dimensions. 
Query has {} while the index expects {}.",
+            query_size,
+            index_vector_size
+        );
+    }
+
+    // Translate external ID to internal ID and get cluster location
+    size_t internal_id = translate_external_id(id);
+    size_t cluster_idx = id_to_cluster_[internal_id];
+    size_t pos = id_in_cluster_[internal_id];
+
+    // Call extension for distance computation
+    return svs::index::ivf::extensions::get_distance_ext(
+        clusters_, distance_, cluster_idx, pos, query
+    );
+    }
+
+    /// @brief Iterate over all external IDs
+    template <typename F> void on_ids(F&& f) const {
+        for (size_t i = 0; i < status_.size(); ++i) {
+            if (is_valid(i)) {
+                f(translator_.get_external(i));
+            }
+        }
+    }
+
+    /// @brief Get external IDs (compatibility method)
+    auto external_ids() const {
+        std::vector<size_t> ids;
+        ids.reserve(size());
+        on_ids([&ids](size_t id) { ids.push_back(id); });
+        return ids;
+    }
+
+    ///// Insertion /////
+
+    /// @brief Add points to the index
+    ///
+    /// New points are assigned to clusters based on the nearest centroid.
+    /// Empty slots from previous deletions can be reused if reuse_empty is enabled.
+    ///
+    /// @param points Dataset of points to add
+    /// @param external_ids External IDs for the points
+    /// @param reuse_empty Whether to reuse empty slots from deletions
+    /// @return Vector of internal IDs where the points were inserted
+    template <typename Points, typename ExternalIds>
+    std::vector<size_t> add_points(
+        const Points& points, const ExternalIds& external_ids, bool reuse_empty = false
+    ) {
+        const size_t num_points = points.size();
+        const size_t num_ids = external_ids.size();
+
+        if (num_points != num_ids) {
+            throw ANNEXCEPTION(
+                "Number of points ({}) not equal to number of external ids ({})!",
+                num_points,
+                num_ids
+            );
+        }
+
+        // Assign each point to its nearest centroid
+        std::vector<size_t> assigned_clusters(num_points);
+        assign_to_clusters(points, assigned_clusters);
+
+        // Allocate global IDs
+        std::vector<size_t> global_ids = allocate_ids(num_points, reuse_empty);
+
+        // Try to update ID translation
+        translator_.insert(external_ids, global_ids);
+
+        // Insert points into their assigned clusters
+        insert_into_clusters(points, global_ids, assigned_clusters);
+
+        return global_ids;
+    }
+
+    ///// Deletion /////
+
+    /// @brief Delete entries by external ID
+    ///
+    /// Entries are marked as Empty and can be reused immediately.
+    /// Call compact() periodically to reclaim memory and reorganize clusters.
+    ///
+    /// @param ids Container of external IDs to delete
+    /// @return Number of entries deleted
+    template <typename T> size_t delete_entries(const T& ids) {
+        translator_.check_external_exist(ids.begin(), ids.end());
+
+        for (auto external_id : ids) {
+            size_t internal_id = translator_.get_internal(external_id);
+            assert(internal_id < status_.size());
+            assert(status_[internal_id] == IVFSlotMetadata::Valid);
+            status_[internal_id] = IVFSlotMetadata::Empty;
+            first_empty_ = std::min(first_empty_, internal_id);
+        }
+
+        translator_.delete_external(ids);
+        return ids.size();
+    }
+
+    ///// Compaction /////
+
+    /// @brief Consolidate the data structure (no-op for IVF).
+    ///
+    /// In the IVF index implementation, deletion marks entries as Empty in metadata,
+    /// making them invalid for searches. These empty slots can be reused by add_points.
+    /// This method is a no-op for compatibility with the dynamic index interface.
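+    ///
+    /// A minimal sketch of the intended lifecycle (the IDs here are illustrative
+    /// and `index` is assumed to be a populated DynamicIVFIndex):
+    /// @code
+    /// index.delete_entries(std::vector<size_t>{3, 7}); // slots marked Empty
+    /// index.consolidate();                             // no-op for IVF
+    /// index.compact();                                 // reclaims the empty slots
+    /// @endcode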
+    ///
+    void consolidate() {
+        // No-op: Deleted entries are marked Empty and excluded from searches
+    }
+
+    /// @brief Compact the data structure
+    ///
+    /// Compaction removes all empty slots, rebuilding the index structure
+    /// for optimal memory usage and search performance.
+    ///
+    /// @param batch_size Granularity at which points are shuffled when compacting
+    ///        each cluster's underlying data
+    void compact(size_t batch_size = 1'000) {
+        // Step 1: Compute mapping from new to old indices
+        auto valid_indices = nonmissing_indices();
+
+        // Step 2: Group valid indices by cluster
+        std::vector<std::vector<std::pair<size_t, size_t>>> cluster_valid_indices(
+            clusters_.size()
+        );
+
+        // Collect all external ID mappings BEFORE modifying the translator
+        std::vector<size_t> external_ids;
+        std::vector<size_t> new_internal_ids;
+        external_ids.reserve(valid_indices.size());
+        new_internal_ids.reserve(valid_indices.size());
+
+        for (size_t new_id = 0; new_id < valid_indices.size(); ++new_id) {
+            size_t old_id = valid_indices[new_id];
+            size_t cluster_idx = id_to_cluster_[old_id];
+            cluster_valid_indices[cluster_idx].push_back({new_id, old_id});
+
+            auto external_id = translator_.get_external(old_id);
+            external_ids.push_back(external_id);
+            new_internal_ids.push_back(new_id);
+        }
+
+        // Step 3: Save old metadata before clearing
+        auto old_id_in_cluster = id_in_cluster_;
+        translator_ = IDTranslator();
+
+        // Step 4: Compact each cluster using data_.compact()
+        for (size_t cluster_idx = 0; cluster_idx < clusters_.size(); ++cluster_idx) {
+            const auto& indices = cluster_valid_indices[cluster_idx];
+            if (indices.empty()) {
+                clusters_[cluster_idx].data_.resize(0);
+                clusters_[cluster_idx].ids_.clear();
+                continue;
+            }
+
+            // Create a map from old position in cluster to new global id.
+            // Use std::map to automatically sort by old position.
+            std::map<size_t, size_t> old_pos_to_global_id;
+            std::vector<size_t> old_positions_sorted;
+            old_positions_sorted.reserve(indices.size());
+
+            for (const auto& [new_global_id, old_global_id] : indices) {
+                size_t old_pos = old_id_in_cluster[old_global_id];
+                old_pos_to_global_id[old_pos] = new_global_id;
+            }
+
+            // Extract sorted old positions (the map keeps them sorted by key)
+            for (const auto& [old_pos, _] : old_pos_to_global_id) {
+                old_positions_sorted.push_back(old_pos);
+            }
+
+            // Use the data's compact() method - this reorders data in place
+            clusters_[cluster_idx].data_.compact(
+                lib::as_const_span(old_positions_sorted),
+                inter_query_threadpool_,
+                batch_size
+            );
+            clusters_[cluster_idx].data_.resize(indices.size());
+
+            // After compact(), data is at positions [0, 1, 2, ...] corresponding to
+            // the sorted old positions. Build new IDs and metadata.
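+            // Worked example with hypothetical values: if the surviving points in
+            // this cluster sat at old positions {5, 2, 9} and received new global
+            // ids {17, 4, 23}, the map sorts to {2 -> 4, 5 -> 17, 9 -> 23}; after
+            // data_.compact() the datum from old position 2 is at position 0, so
+            // ids_ becomes {4, 17, 23} and id_in_cluster_[4] == 0, and so on.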
+ std::vector new_ids(indices.size()); + size_t compacted_pos = 0; + for (size_t old_pos : old_positions_sorted) { + size_t new_global_id = old_pos_to_global_id[old_pos]; + new_ids[compacted_pos] = static_cast(new_global_id); + id_to_cluster_[new_global_id] = cluster_idx; + id_in_cluster_[new_global_id] = compacted_pos; + compacted_pos++; + } + + clusters_[cluster_idx].ids_ = std::move(new_ids); + } + + // Step 5: Update global metadata + size_t new_size = valid_indices.size(); + status_.resize(new_size); + std::fill(status_.begin(), status_.end(), IVFSlotMetadata::Valid); + id_to_cluster_.resize(new_size); + id_in_cluster_.resize(new_size); + first_empty_ = new_size; + + // Step 6: Re-add all IDs to translator + translator_.insert(external_ids, new_internal_ids, false); + + svs::logging::info(logger_, "Compaction complete: {} valid entries", new_size); + } + + ///// Search ///// + + /// Translate internal IDs to external IDs in search results. + /// This method converts all IDs in the result view from internal (global) IDs + /// to external IDs using the ID map. + /// + /// @param ids Result indices to translate (2D array) + template + requires(std::tuple_size_v == 2) + void translate_to_external(DenseArray& ids) { + threads::parallel_for( + inter_query_threadpool_, + threads::StaticPartition{getsize<0>(ids)}, + [&](const auto is, auto /*tid*/) { + for (auto i : is) { + for (size_t j = 0, jmax = getsize<1>(ids); j < jmax; ++j) { + auto internal = lib::narrow_cast(ids.at(i, j)); + ids.at(i, j) = translate_internal_id(internal); + } + } + } + ); + } + + /// @brief Perform similarity search + /// + /// Search Process: + /// 1. Inter-query parallel: Distribute queries across primary threads + /// 2. For each query: Find n_probe nearest centroids + /// 3. Intra-query parallel: Explore identified clusters using inner threads + /// 4. 
Combine results from all explored clusters (skipping empty entries) + /// + /// @param results View for storing search results + /// @param queries Query vectors + /// @param search_parameters Search configuration + /// @param cancel Optional cancellation predicate + template + void search( + QueryResultView results, + const Queries& queries, + const search_parameters_type& search_parameters, + const lib::DefaultPredicate& SVS_UNUSED(cancel) = lib::Returns(lib::Const()) + ) { + validate_query_batch_size(queries.size()); + + size_t num_neighbors = results.n_neighbors(); + size_t buffer_leaves_size = static_cast( + search_parameters.k_reorder_ * static_cast(num_neighbors) + ); + + // Phase 1: Inter-query parallel - Compute distances to centroids + compute_centroid_distances( + queries, centroids_, matmul_results_, inter_query_threadpool_ + ); + + // Phase 2: Process queries in parallel + threads::parallel_for( + inter_query_threadpool_, + threads::StaticPartition(queries.size()), + [&](auto is, auto tid) { + // Initialize search buffers + auto buffer_centroids = create_centroid_buffer(search_parameters.n_probes_); + auto buffer_leaves = create_leaf_buffers(buffer_leaves_size); + + // Prepare cluster search scratch space (distance copy) + // Pass cluster data (not centroids) to support quantized datasets + auto scratch = extensions::per_thread_batch_search_setup( + clusters_[0].data_, distance_ + ); + + // Execute search with intra-query parallelism + // Pass cluster data as first parameter to enable dataset-specific overrides + extensions::per_thread_batch_search( + clusters_[0].data_, + *this, + buffer_centroids, + buffer_leaves, + scratch, + queries, + results, + threads::UnitRange{is}, + tid, + search_centroids_closure(), + search_leaves_closure() + ); + } + ); + + // Convert internal IDs to external IDs + this->translate_to_external(results.indices()); + } + + ///// Saving ///// + + static constexpr lib::Version save_version = lib::Version(0, 0, 0); + + void save( + const std::filesystem::path& config_directory, + const std::filesystem::path& data_directory + ) { + // Compact before saving to remove empty slots + compact(); + + // Save configuration + lib::save_to_disk( + lib::SaveOverride([&](const lib::SaveContext& ctx) { + return lib::SaveTable( + "dynamic_ivf_config", + save_version, + { + {"name", lib::save(name())}, + {"translation", lib::save(translator_, ctx)}, + {"num_clusters", lib::save(clusters_.size())}, + } + ); + }), + config_directory + ); + + // Save centroids and cluster data + lib::save_to_disk(centroids_, data_directory / "centroids"); + + for (size_t i = 0; i < clusters_.size(); ++i) { + auto cluster_path = data_directory / fmt::format("cluster_{}", i); + lib::save_to_disk(clusters_[i].data_, cluster_path); + + auto ids_path = data_directory / fmt::format("cluster_ids_{}", i); + lib::save_to_disk(clusters_[i].ids_, ids_path); + } + } + + private: + ///// Helper Methods ///// + + void validate_thread_configuration() { + if (intra_query_thread_count_ < 1) { + throw std::invalid_argument("Intra-query thread count must be at least 1"); + } + } + + void initialize_thread_pools() { + for (size_t i = 0; i < inter_query_threadpool_.size(); i++) { + intra_query_threadpools_.push_back( + threads::as_threadpool(intra_query_thread_count_) + ); + } + } + + void initialize_search_buffers() { + auto batches = + std::vector>(inter_query_threadpool_.size()); + + threads::parallel_for( + inter_query_threadpool_, + threads::StaticPartition(centroids_.size()), + [&](auto is, auto tid) { 
batches[tid] = threads::UnitRange{is}; } + ); + + for (size_t i = 0; i < inter_query_threadpool_.size(); i++) { + matmul_results_.emplace_back(MAX_QUERY_BATCH_SIZE, batches[i].size()); + } + } + + void initialize_distance_metadata() { + if constexpr (std::is_same_v, distance::DistanceL2>) { + centroids_norm_.reserve(centroids_.size()); + for (size_t i = 0; i < centroids_.size(); ++i) { + centroids_norm_.push_back(distance::norm_square(centroids_.get_datum(i))); + } + } + } + + void validate_query_batch_size(size_t query_size) const { + if (query_size > MAX_QUERY_BATCH_SIZE) { + throw std::runtime_error(fmt::format( + "Query batch size {} exceeds maximum allowed {}", + query_size, + MAX_QUERY_BATCH_SIZE + )); + } + } + + auto create_centroid_buffer(size_t n_probes) const { + return SortedBuffer(n_probes, distance::comparator(distance_)); + } + + auto create_leaf_buffers(size_t buffer_size) const { + std::vector> buffers; + buffers.reserve(intra_query_thread_count_); + for (size_t j = 0; j < intra_query_thread_count_; j++) { + buffers.push_back( + SortedBuffer(buffer_size, distance::comparator(distance_)) + ); + } + return buffers; + } + + bool is_empty(size_t i) const { return status_[i] == IVFSlotMetadata::Empty; } + + bool is_valid(size_t i) const { return status_[i] == IVFSlotMetadata::Valid; } + + std::vector nonmissing_indices() const { + std::vector indices; + indices.reserve(size()); + for (size_t i = 0; i < status_.size(); ++i) { + if (is_valid(i)) { + indices.push_back(i); + } + } + return indices; + } + + /// @brief Assign points to their nearest centroids using parallel processing + /// + /// Uses centroid_assignment with batching to handle matmul_results size constraints. + /// Processes points in batches for efficient parallel centroid assignment. 
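+    /// Example with hypothetical sizes: for 25,000 incoming points and a matmul
+    /// buffer that holds 10,000 rows, the batching loop below runs three batches
+    /// covering rows [0, 10000), [10000, 20000), and [20000, 25000).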
+ /// + /// @param points Dataset to assign to clusters + /// @param assignments Output vector for cluster assignments + template + void assign_to_clusters(const Points& points, std::vector& assignments) { + size_t num_points = points.size(); + size_t num_centroids = centroids_.size(); + + // Compute norms if using L2 distance + auto data_norm = maybe_compute_norms(points, inter_query_threadpool_); + + // Determine batch size based on matmul_results capacity + // matmul_results_ is sized for queries, reuse for point assignment + size_t batch_size = matmul_results_[0].size(); // Number of queries it can hold + size_t num_batches = lib::div_round_up(num_points, batch_size); + + // Create a local matmul buffer for assignments (batch_size x num_centroids) + auto matmul_buffer = data::SimpleData{batch_size, num_centroids}; + auto timer = lib::Timer(); + + // Process points in batches + for (size_t batch = 0; batch < num_batches; ++batch) { + auto batch_range = threads::UnitRange{ + batch * batch_size, std::min((batch + 1) * batch_size, num_points)}; + + // Use centroid_assignment to compute assignments for this batch + centroid_assignment( + points, + data_norm, + batch_range, + distance_, + centroids_, + centroids_norm_, + assignments, + matmul_buffer, + inter_query_threadpool_, + timer + ); + } + } + + std::vector allocate_ids(size_t count, bool reuse_empty) { + std::vector ids; + ids.reserve(count); + + // Try to find empty slots if reuse is enabled + if (reuse_empty) { + for (size_t i = 0; i < status_.size() && ids.size() < count; ++i) { + if (is_empty(i)) { + ids.push_back(i); + status_[i] = IVFSlotMetadata::Valid; // Mark as valid when reusing + } + } + } + + // Allocate new slots as needed + size_t current_size = status_.size(); + while (ids.size() < count) { + ids.push_back(current_size++); + } + + // Resize metadata if we added new slots + if (current_size > status_.size()) { + status_.resize(current_size, IVFSlotMetadata::Valid); + id_to_cluster_.resize(current_size); + id_in_cluster_.resize(current_size); + first_empty_ = current_size; + } + + return ids; + } + + template + void insert_into_clusters( + const Points& points, + const std::vector& global_ids, + const std::vector& assigned_clusters + ) { + for (size_t i = 0; i < points.size(); ++i) { + size_t global_id = global_ids[i]; + size_t cluster_idx = assigned_clusters[i]; + + // Add to cluster + auto& cluster = clusters_[cluster_idx]; + + size_t pos = cluster.size(); + cluster.resize(cluster.size() + 1); + cluster.data_.set_datum(pos, points.get_datum(i)); + cluster.ids_.push_back(static_cast(global_id)); + + // Update metadata + status_[global_id] = IVFSlotMetadata::Valid; + id_to_cluster_[global_id] = cluster_idx; + id_in_cluster_[global_id] = pos; + } + } + + ///// Search Closures ///// + + /// @brief Create closure for searching centroids + auto search_centroids_closure() const { + return [this](const auto& query, auto& buffer_centroids, size_t query_idx) { + search_centroids( + query, + distance_, + matmul_results_, + buffer_centroids, + query_idx, + centroids_norm_, + get_num_threads() + ); + }; + } + + /// @brief Create closure for searching clusters/leaves + auto search_leaves_closure() { + return [this]( + const auto& query, + auto& distance, + const auto& buffer_centroids, + auto& buffer_leaves, + size_t tid + ) { + // Use the common search_leaves function with *this as cluster accessor + // DynamicIVFIndex provides a custom on_leaves that filters invalid entries + search_leaves( + query, + distance, + *this, + 
buffer_centroids, + buffer_leaves, + intra_query_threadpools_[tid] + ); + }; + } + + public: + /// @brief Custom on_leaves that wraps DenseCluster::on_leaves with validity filtering + /// This ensures deleted entries are skipped during search + template void on_leaves(Callback&& f, size_t cluster_id) const { + clusters_[cluster_id].on_leaves( + [this, &f](const auto& datum, auto global_id, auto local_pos) { + // Only invoke callback for valid (non-deleted) entries + if (is_valid(global_id)) { + f(datum, global_id, local_pos); + } + }, + prefetch_offset_ + ); + } + + /// @brief Get global ID - delegates to DenseClusteredDataset + size_t get_global_id(size_t cluster_id, size_t local_pos) const { + return clusters_.get_global_id(cluster_id, local_pos); + } +}; + +/// @brief Assemble a DynamicIVFIndex from clustering and data prototype +/// +/// @param clustering The clustering result containing centroids and assignments +/// @param data_proto Data prototype (file path or data object) to load +/// @param ids External IDs for the data points (must match data size) +/// @param distance Distance function to use +/// @param threadpool_proto Thread pool for parallel operations +/// @param intra_query_thread_count Number of threads for intra-query parallelism +/// +template < + typename Clustering, + typename DataProto, + typename Distance, + typename ThreadpoolProto> +auto assemble_dynamic_from_clustering( + Clustering clustering, + const DataProto& data_proto, + std::span ids, + Distance distance, + ThreadpoolProto threadpool_proto, + const size_t intra_query_thread_count = 1 +) { + using I = uint32_t; + using centroids_type = data::SimpleData; + + // Load the data + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + auto data = svs::detail::dispatch_load(data_proto, threadpool); + + // Validate that ids size matches data size + if (ids.size() != data.size()) { + throw ANNEXCEPTION( + "IDs size (", ids.size(), ") does not match data size (", data.size(), ")" + ); + } + + // Use lib_blocked_alloc_data_type for Dynamic IVF + using blocked_data_type = typename decltype(data)::lib_blocked_alloc_data_type; + + // Use a small block size for IVF clusters (1MB instead of 1GB default) + auto blocking_params = data::BlockingParameters{ + .blocksize_bytes = lib::PowerOfTwo(20) // 2^20 = 1MB + }; + using allocator_type = typename blocked_data_type::allocator_type; + auto blocked_allocator = + allocator_type(blocking_params, typename allocator_type::allocator_type()); + + // Create clustered dataset - DenseClusteredDataset will use the extension system + // to create the appropriate data type with blocked allocator via create_dense_cluster + auto dense_clusters = DenseClusteredDataset( + clustering, data, threadpool, blocked_allocator + ); + + // Create the index + return DynamicIVFIndex< + centroids_type, + decltype(dense_clusters), + Distance, + decltype(threadpool)>( + std::move(clustering.centroids()), + std::move(dense_clusters), + ids, + std::move(distance), + std::move(threadpool), + intra_query_thread_count + ); +} + +} // namespace svs::index::ivf diff --git a/include/svs/index/ivf/extensions.h b/include/svs/index/ivf/extensions.h index 0ef83587b..79785f2a4 100644 --- a/include/svs/index/ivf/extensions.h +++ b/include/svs/index/ivf/extensions.h @@ -180,16 +180,32 @@ struct CreateDenseCluster { inline constexpr CreateDenseCluster create_dense_cluster{}; -template +// Specialization for default allocator (backward compatibility) +// When no specific allocator is provided, use 
default construction with same extent +template svs::data::SimpleData svs_invoke( svs::tag_t, - const svs::data::SimpleData& original, + const svs::data::SimpleData& original, size_t new_size, - const NewAlloc& SVS_UNUSED(allocator) + const svs::lib::Allocator& SVS_UNUSED(allocator) ) { return svs::data::SimpleData(new_size, original.dimensions()); } +// General implementation for Blocked allocators: Always use Dynamic extent for flexibility +// This enables dynamic resizing which is essential for dynamic IVF operations +template +svs::data::SimpleData> svs_invoke( + svs::tag_t, + const svs::data::SimpleData& original, + size_t new_size, + const svs::data::Blocked& allocator +) { + return svs::data::SimpleData>( + new_size, original.dimensions(), allocator + ); +} + struct SetDenseCluster { template void operator()( @@ -217,4 +233,48 @@ void svs_invoke( } } +///// +///// Distance Computation +///// + +struct ComputeDistanceType { + template + double operator()( + const Clusters& clusters, + const Distance& distance, + size_t cluster_idx, + size_t pos, + const Query& query + ) const { + return svs_invoke( + *this, clusters[cluster_idx].view_cluster(), distance, pos, query + ); + } +}; + +// CPO for distance computation +inline constexpr ComputeDistanceType get_distance_ext{}; + +// Default overload +template +double svs_invoke( + svs::tag_t, + const Data& data, + const Distance& distance, + size_t pos, + const Query& query +) { + // Get distance function + auto dist_f = per_thread_batch_search_setup(data, distance); + svs::distance::maybe_fix_argument(dist_f, query); + + // Get the vector + auto indexed_span = data.get_datum(pos); + + // Compute the distance + auto dist = svs::distance::compute(dist_f, query, indexed_span); + + return static_cast(dist); +} + } // namespace svs::index::ivf::extensions diff --git a/include/svs/index/ivf/hierarchical_kmeans.h b/include/svs/index/ivf/hierarchical_kmeans.h index 168540868..bf3d1e8c5 100644 --- a/include/svs/index/ivf/hierarchical_kmeans.h +++ b/include/svs/index/ivf/hierarchical_kmeans.h @@ -71,7 +71,8 @@ auto hierarchical_kmeans_clustering_impl( Distance& distance, Pool& threadpool, lib::Type SVS_UNUSED(integer_type) = {}, - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + bool train_only = false ) { auto timer = lib::Timer(); auto kmeans_timer = timer.push_back("Hierarchical kmeans clustering"); @@ -84,21 +85,30 @@ auto hierarchical_kmeans_clustering_impl( size_t num_level1_clusters = parameters.hierarchical_level1_clusters_; if (num_level1_clusters == 0) { - num_level1_clusters = std::sqrt(num_clusters); + num_level1_clusters = static_cast(std::sqrt(num_clusters)); } svs::logging::debug(logger, "Level1 clusters: {}\n", num_level1_clusters); // Step 1: Create training set - size_t num_training_data = - lib::narrow(std::ceil(data.size() * parameters.training_fraction_)); - if (num_training_data < num_clusters || num_training_data > data.size()) { + // Use at least MIN_TRAINING_SAMPLE_MULTIPLIER times the number of centroids, + // but no more than the dataset size. This ensures we have enough training data + // even for small datasets, without exceeding the available data. 
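+    // Worked example (values illustrative; MIN_TRAINING_SAMPLE_MULTIPLIER is
+    // assumed to be 100 here): with 1,000,000 points, 1,024 clusters, and
+    // training_fraction_ = 0.05, min_training_data = min(102'400, 1'000'000)
+    // = 102'400, the fraction term gives 50'000, and num_training_data =
+    // max(102'400, 50'000) = 102'400.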
+ size_t min_training_data = + std::min(num_clusters * MIN_TRAINING_SAMPLE_MULTIPLIER, data.size()); + size_t num_training_data = std::max( + min_training_data, + lib::narrow(std::ceil(data.size() * parameters.training_fraction_)) + ); + // Ensure we don't exceed the data size + num_training_data = std::min(num_training_data, data.size()); + + if (num_training_data < num_clusters) { throw ANNEXCEPTION( - "Invalid number of training data: {}, num_clusters: {}, total data size: " - "{}\n", - num_training_data, - num_clusters, - data.size() + "Insufficient data for clustering: {} datapoints, {} clusters required. " + "Need at least as many datapoints as clusters.\n", + data.size(), + num_clusters ); } auto rng = std::mt19937(parameters.seed_); @@ -157,41 +167,46 @@ auto hierarchical_kmeans_clustering_impl( auto clusters_level1 = group_assignments(assignments_level1, num_level1_clusters, data_train); - // Step 5: Assign all data to clusters + std::vector> clusters_level1_all; + + // Declare timer outside of block to avoid scope issues auto all_assignments_time = timer.push_back("level1 all assignments"); - auto all_assignments_alloc = timer.push_back("level1 all assignments alloc"); - auto assignments_level1_all = std::vector(data.size()); - all_assignments_alloc.finish(); - batchsize = parameters.minibatch_size_; - num_batches = lib::div_round_up(data.size(), batchsize); + if (!train_only) { + // Step 5: Assign all data to clusters + auto assignments_level1_all = std::vector(data.size()); - data_norm = maybe_compute_norms(data, threadpool); - auto data_batch = data::SimpleData{batchsize, ndims}; - for (size_t batch = 0; batch < num_batches; ++batch) { - auto this_batch = threads::UnitRange{ - batch * batchsize, std::min((batch + 1) * batchsize, data.size())}; - auto data_batch_view = data::make_view(data, this_batch); - auto all_assignments_convert = timer.push_back("level1 all assignments convert"); - convert_data(data_batch_view, data_batch, threadpool); - all_assignments_convert.finish(); - centroid_assignment( - data_batch, - data_norm, - this_batch, - distance, - centroids_level1, - centroids_level1_norm, - assignments_level1_all, - matmul_results_level1, - threadpool, - timer - ); + batchsize = parameters.minibatch_size_; + num_batches = lib::div_round_up(data.size(), batchsize); + + data_norm = maybe_compute_norms(data, threadpool); + auto data_batch = data::SimpleData{batchsize, ndims}; + for (size_t batch = 0; batch < num_batches; ++batch) { + auto this_batch = threads::UnitRange{ + batch * batchsize, std::min((batch + 1) * batchsize, data.size())}; + auto data_batch_view = data::make_view(data, this_batch); + convert_data(data_batch_view, data_batch, threadpool); + centroid_assignment( + data_batch, + data_norm, + this_batch, + distance, + centroids_level1, + centroids_level1_norm, + assignments_level1_all, + matmul_results_level1, + threadpool, + timer + ); + } + auto all_assignments_cluster = timer.push_back("level1 all assignments clusters"); + clusters_level1_all = + group_assignments(assignments_level1_all, num_level1_clusters, data); + all_assignments_cluster.finish(); + } else { + // For train_only, create empty clusters + clusters_level1_all.resize(num_level1_clusters); } - auto all_assignments_cluster = timer.push_back("level1 all assignments clusters"); - auto clusters_level1_all = - group_assignments(assignments_level1_all, num_level1_clusters, data); - all_assignments_cluster.finish(); all_assignments_time.finish(); level1_training_time.finish(); @@ -206,10 +221,20 @@ 
auto hierarchical_kmeans_clustering_impl( auto clusters_final = std::vector>(num_clusters); size_t max_data_per_cluster = 0; - for (size_t cluster = 0; cluster < num_level1_clusters; cluster++) { - max_data_per_cluster = clusters_level1_all[cluster].size() > max_data_per_cluster - ? clusters_level1_all[cluster].size() - : max_data_per_cluster; + if (!train_only) { + for (size_t cluster = 0; cluster < num_level1_clusters; cluster++) { + max_data_per_cluster = + clusters_level1_all[cluster].size() > max_data_per_cluster + ? clusters_level1_all[cluster].size() + : max_data_per_cluster; + } + } else { + // In train_only mode, use training clusters for Level 2 training + for (size_t cluster = 0; cluster < num_level1_clusters; cluster++) { + max_data_per_cluster = clusters_level1[cluster].size() > max_data_per_cluster + ? clusters_level1[cluster].size() + : max_data_per_cluster; + } } auto data_level2 = data::SimpleData{max_data_per_cluster, ndims}; @@ -219,7 +244,8 @@ auto hierarchical_kmeans_clustering_impl( for (size_t cluster = 0; cluster < num_level1_clusters; cluster++) { size_t num_clusters_l2 = num_level2_clusters[cluster]; size_t num_assignments_l2 = clusters_level1[cluster].size(); - size_t num_assignments_l2_all = clusters_level1_all[cluster].size(); + size_t num_assignments_l2_all = + train_only ? 0 : clusters_level1_all[cluster].size(); auto matmul_results_level2 = data::SimpleData{parameters.minibatch_size_, num_clusters_l2}; @@ -255,47 +281,51 @@ auto hierarchical_kmeans_clustering_impl( ); auto all_assignments_level2 = timer.push_back("level2 all assignments"); - threads::parallel_for( - threadpool, - threads::StaticPartition{num_assignments_l2_all}, - [&](auto indices, auto /*tid*/) { - for (auto i : indices) { - data_level2.set_datum( - i, data.get_datum(clusters_level1_all[cluster][i]) - ); - } - } - ); - batchsize = parameters.minibatch_size_; - num_batches = lib::div_round_up(num_assignments_l2_all, batchsize); - - data_norm = maybe_compute_norms(data_level2, threadpool); - auto centroids_level2_norm = - maybe_compute_norms(centroids_level2_fp32, threadpool); - for (size_t batch = 0; batch < num_batches; ++batch) { - auto this_batch = threads::UnitRange{ - batch * batchsize, - std::min((batch + 1) * batchsize, num_assignments_l2_all)}; - auto data_batch = data::make_view(data_level2, this_batch); - centroid_assignment( - data_batch, - data_norm, - this_batch, - distance, - centroids_level2, - centroids_level2_norm, - assignments_level2_all, - matmul_results_level2, + if (!train_only) { + // Only do Level 2 assignments if not in train_only mode + threads::parallel_for( threadpool, - timer + threads::StaticPartition{num_assignments_l2_all}, + [&](auto indices, auto /*tid*/) { + for (auto i : indices) { + data_level2.set_datum( + i, data.get_datum(clusters_level1_all[cluster][i]) + ); + } + } ); - } - for (size_t i = 0; i < num_assignments_l2_all; i++) { - clusters_final[cluster_start + assignments_level2_all[i]].push_back( - clusters_level1_all[cluster][i] - ); + batchsize = parameters.minibatch_size_; + num_batches = lib::div_round_up(num_assignments_l2_all, batchsize); + + data_norm = maybe_compute_norms(data_level2, threadpool); + auto centroids_level2_norm = + maybe_compute_norms(centroids_level2_fp32, threadpool); + for (size_t batch = 0; batch < num_batches; ++batch) { + auto this_batch = threads::UnitRange{ + batch * batchsize, + std::min((batch + 1) * batchsize, num_assignments_l2_all)}; + auto data_batch = data::make_view(data_level2, this_batch); + 
centroid_assignment( + data_batch, + data_norm, + this_batch, + distance, + centroids_level2, + centroids_level2_norm, + assignments_level2_all, + matmul_results_level2, + threadpool, + timer + ); + } + + for (size_t i = 0; i < num_assignments_l2_all; i++) { + clusters_final[cluster_start + assignments_level2_all[i]].push_back( + clusters_level1_all[cluster][i] + ); + } } threads::parallel_for( @@ -313,6 +343,7 @@ auto hierarchical_kmeans_clustering_impl( cluster_start += num_clusters_l2; all_assignments_level2.finish(); } + level2_training_time.finish(); kmeans_timer.finish(); @@ -338,10 +369,11 @@ auto hierarchical_kmeans_clustering( Distance& distance, Pool& threadpool, lib::Type integer_type = {}, - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + bool train_only = false ) { return hierarchical_kmeans_clustering_impl( - parameters, data, distance, threadpool, integer_type, std::move(logger) + parameters, data, distance, threadpool, integer_type, std::move(logger), train_only ); } diff --git a/include/svs/index/ivf/index.h b/include/svs/index/ivf/index.h index fb3f3c80c..dc11473f1 100644 --- a/include/svs/index/ivf/index.h +++ b/include/svs/index/ivf/index.h @@ -30,6 +30,8 @@ #include "fmt/core.h" // stl +#include +#include #include #include @@ -39,7 +41,7 @@ namespace svs::index::ivf { // performance. This value was chosen based on empirical testing to avoid excessive memory // allocation while supporting large batch operations typical in high-throughput // environments. -const size_t MAX_QUERY_BATCH_SIZE = 10000; +constexpr size_t MAX_QUERY_BATCH_SIZE = 10000; /// @brief IVF (Inverted File) Index implementation for efficient similarity search /// @@ -164,6 +166,50 @@ class IVFIndex { search_parameters_ = search_parameters; } + ///// ID Mapping ///// + + /// @brief Check if an ID exists in the index + bool has_id(size_t id) const { + return id < id_to_cluster_.size() && id_to_cluster_[id] != SIZE_MAX; + } + + ///// Distance Computation ///// + + /// @brief Compute the distance between a query vector and a vector in the index + template double get_distance(size_t id, const Query& query) const { + // Thread-safe lazy initialization of ID mapping + std::call_once(*id_mapping_init_flag_, [this]() { initialize_id_mapping(); }); + + // Check if id exists + if (!has_id(id)) { + throw ANNEXCEPTION("ID {} does not exist in the index!", id); + } + + // Verify dimensions match + const size_t query_size = query.size(); + const size_t index_vector_size = dimensions(); + if (query_size != index_vector_size) { + throw ANNEXCEPTION( + "Incompatible dimensions. 
Query has {} while the index expects {}.", + query_size, + index_vector_size + ); + } + + // Get cluster and position + size_t cluster_id = id_to_cluster_[id]; + size_t pos = id_in_cluster_[id]; + + // Fix distance argument if needed + auto distance_copy = distance_; + svs::distance::maybe_fix_argument(distance_copy, query); + + // Call extension for distance computation + return svs::index::ivf::extensions::get_distance_ext( + cluster_, distance_copy, cluster_id, pos, query + ); + } + ///// Search Implementation ///// /// @brief Search closure for centroid distance computation @@ -227,7 +273,9 @@ class IVFIndex { validate_query_batch_size(queries.size()); size_t num_neighbors = results.n_neighbors(); - size_t buffer_leaves_size = search_parameters.k_reorder_ * num_neighbors; + size_t buffer_leaves_size = static_cast( + search_parameters.k_reorder_ * static_cast(num_neighbors) + ); // Phase 1: Inter-query parallel - Compute distances to centroids compute_centroid_distances( @@ -272,6 +320,15 @@ class IVFIndex { Data cluster0_; Dist distance_; + ///// ID Mapping for get_distance ///// + // Maps ID -> cluster_id + mutable std::vector id_to_cluster_{}; + // Maps ID -> position within cluster + mutable std::vector id_in_cluster_{}; + // Thread-safe initialization flag for ID mapping (wrapped in unique_ptr for movability) + mutable std::unique_ptr id_mapping_init_flag_{ + std::make_unique()}; + ///// Threading Infrastructure ///// InterQueryThreadPool inter_query_threadpool_; // Handles parallelism across queries const size_t intra_query_thread_count_; // Number of threads per query processing @@ -281,7 +338,7 @@ class IVFIndex { ///// Search Data ///// std::vector> matmul_results_; std::vector centroids_norm_; - search_parameters_type search_parameters_{}; + search_parameters_type search_parameters_; // SVS logger for per index logging svs::logging::logger_ptr logger_; @@ -320,7 +377,7 @@ class IVFIndex { void initialize_distance_metadata() { // Precalculate centroid norms for L2 distance - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v, distance::DistanceL2>) { centroids_norm_.reserve(centroids_.size()); for (size_t i = 0; i < centroids_.size(); i++) { centroids_norm_.push_back(distance::norm_square(centroids_.get_datum(i))); @@ -328,6 +385,31 @@ class IVFIndex { } } + void initialize_id_mapping() const { + // Build ID-to-location mapping from cluster data + // Compute total size by summing all cluster sizes + size_t total_size = 0; + size_t num_clusters = centroids_.size(); + for (size_t cluster_id = 0; cluster_id < num_clusters; ++cluster_id) { + total_size += cluster_.view_cluster(cluster_id).size(); + } + + // Initialize mapping vectors with sentinel value + id_to_cluster_.resize(total_size, SIZE_MAX); + id_in_cluster_.resize(total_size, SIZE_MAX); + + // Populate mappings + for (size_t cluster_id = 0; cluster_id < num_clusters; ++cluster_id) { + auto cluster_view = cluster_.view_cluster(cluster_id); + size_t cluster_size = cluster_view.size(); + for (size_t pos = 0; pos < cluster_size; ++pos) { + size_t id = cluster_.get_global_id(cluster_id, pos); + id_to_cluster_[id] = cluster_id; + id_in_cluster_[id] = pos; + } + } + } + ///// Helper Methods ///// void validate_query_batch_size(size_t query_size) const { @@ -388,6 +470,7 @@ auto build_clustering( const DataProto& data_proto, Distance distance, ThreadpoolProto threadpool_proto, + bool train_only = false, svs::logging::logger_ptr logger = svs::logging::get() ) { auto threadpool = 
threads::as_threadpool(std::move(threadpool_proto)); @@ -402,11 +485,11 @@ auto build_clustering( // Choose clustering method based on parameters if (parameters.is_hierarchical_) { std::tie(centroids, clusters) = hierarchical_kmeans_clustering( - parameters, data, distance, threadpool, Idx{}, logger + parameters, data, distance, threadpool, Idx{}, logger, train_only ); } else { std::tie(centroids, clusters) = kmeans_clustering( - parameters, data, distance, threadpool, Idx{}, logger + parameters, data, distance, threadpool, Idx{}, logger, train_only ); } diff --git a/include/svs/index/ivf/kmeans.h b/include/svs/index/ivf/kmeans.h index c29d5c7fe..98bc94e27 100644 --- a/include/svs/index/ivf/kmeans.h +++ b/include/svs/index/ivf/kmeans.h @@ -32,7 +32,8 @@ auto kmeans_clustering_impl( Distance& distance, Pool& threadpool, lib::Type SVS_UNUSED(integer_type) = {}, - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + bool train_only = false ) { auto timer = lib::Timer(); auto kmeans_timer = timer.push_back("Non-hierarchical kmeans clustering"); @@ -44,15 +45,24 @@ auto kmeans_clustering_impl( auto num_centroids = parameters.num_centroids_; // Step 1: Create training set - size_t num_training_data = - lib::narrow(std::ceil(data.size() * parameters.training_fraction_)); - if (num_training_data < num_centroids || num_training_data > data.size()) { + // Use at least MIN_TRAINING_SAMPLE_MULTIPLIER times the number of centroids, + // but no more than the dataset size. This ensures we have enough training data + // even for small datasets, without exceeding the available data. + size_t min_training_data = + std::min(num_centroids * MIN_TRAINING_SAMPLE_MULTIPLIER, data.size()); + size_t num_training_data = std::max( + min_training_data, + lib::narrow(std::ceil(data.size() * parameters.training_fraction_)) + ); + // Ensure we don't exceed the data size + num_training_data = std::min(num_training_data, data.size()); + + if (num_training_data < num_centroids) { throw ANNEXCEPTION( - "Invalid number of training data: {}, num_centroids: {}, total data size: " - "{}\n", - num_training_data, - num_centroids, - data.size() + "Insufficient data for clustering: {} datapoints, {} centroids required. 
" + "Need at least as many datapoints as centroids.\n", + data.size(), + num_centroids ); } auto rng = std::mt19937(parameters.seed_); @@ -74,38 +84,45 @@ auto kmeans_clustering_impl( parameters, data_train, distance, centroids, matmul_results, rng, threadpool, timer ); - auto final_assignments_time = timer.push_back("final assignments"); - auto assignments = std::vector(data.size()); - auto batchsize = parameters.minibatch_size_; - auto num_batches = lib::div_round_up(data.size(), batchsize); + std::vector> clusters; - auto data_norm = maybe_compute_norms(data, threadpool); - auto centroids_norm = maybe_compute_norms(centroids_fp32, threadpool); + if (train_only) { + // Only train centroids, return empty clusters + clusters.resize(num_centroids); + } else { + // Step 4: Assign all data to clusters + auto final_assignments_time = timer.push_back("final assignments"); + auto assignments = std::vector(data.size()); + auto batchsize = parameters.minibatch_size_; + auto num_batches = lib::div_round_up(data.size(), batchsize); - // Step 4: Assign training data to clusters - auto data_batch = data::SimpleData{batchsize, ndims}; - for (size_t batch = 0; batch < num_batches; ++batch) { - auto this_batch = threads::UnitRange{ - batch * batchsize, std::min((batch + 1) * batchsize, data.size())}; - auto data_batch_view = data::make_view(data, this_batch); - convert_data(data_batch_view, data_batch, threadpool); - centroid_assignment( - data_batch, - data_norm, - this_batch, - distance, - centroids, - centroids_norm, - assignments, - matmul_results, - threadpool, - timer - ); - } + auto data_norm = maybe_compute_norms(data, threadpool); + auto centroids_norm = maybe_compute_norms(centroids_fp32, threadpool); + + auto data_batch = data::SimpleData{batchsize, ndims}; + for (size_t batch = 0; batch < num_batches; ++batch) { + auto this_batch = threads::UnitRange{ + batch * batchsize, std::min((batch + 1) * batchsize, data.size())}; + auto data_batch_view = data::make_view(data, this_batch); + convert_data(data_batch_view, data_batch, threadpool); + centroid_assignment( + data_batch, + data_norm, + this_batch, + distance, + centroids, + centroids_norm, + assignments, + matmul_results, + threadpool, + timer + ); + } - // Step 5: Assign all data to clusters - auto clusters = group_assignments(assignments, num_centroids, data); - final_assignments_time.finish(); + // Step 5: Group assignments into clusters + clusters = group_assignments(assignments, num_centroids, data); + final_assignments_time.finish(); + } kmeans_timer.finish(); svs::logging::debug(logger, "{}", timer); svs::logging::debug( @@ -126,10 +143,11 @@ auto kmeans_clustering( Distance& distance, Pool& threadpool, lib::Type integer_type = {}, - svs::logging::logger_ptr logger = svs::logging::get() + svs::logging::logger_ptr logger = svs::logging::get(), + bool train_only = false ) { return kmeans_clustering_impl( - parameters, data, distance, threadpool, integer_type, std::move(logger) + parameters, data, distance, threadpool, integer_type, std::move(logger), train_only ); } } // namespace svs::index::ivf diff --git a/include/svs/orchestrators/dynamic_ivf.h b/include/svs/orchestrators/dynamic_ivf.h new file mode 100644 index 000000000..d0e15ff1a --- /dev/null +++ b/include/svs/orchestrators/dynamic_ivf.h @@ -0,0 +1,298 @@ +/* + * Copyright 2025 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "svs/index/ivf/dynamic_ivf.h" +#include "svs/index/ivf/index.h" +#include "svs/lib/bfloat16.h" +#include "svs/orchestrators/ivf.h" +#include "svs/orchestrators/manager.h" + +namespace svs { + +/// +/// @brief Type-erased wrapper for DynamicIVF. +/// +/// Implementation details: The DynamicIVF implementation implements a superset of the +/// operations supported by the IVFInterface. +/// +class DynamicIVFInterface : public IVFInterface { + public: + // TODO: For now - only accept floating point entries. + virtual void add_points( + const float* data, + size_t dim0, + size_t dim1, + std::span ids, + bool reuse_empty = false + ) = 0; + + virtual size_t delete_points(std::span ids) = 0; + virtual void consolidate() = 0; + virtual void compact(size_t batchsize = 1'000'000) = 0; + + // ID inspection. + virtual bool has_id(size_t id) const = 0; + virtual void all_ids(std::vector& ids) const = 0; + + // Distance calculation + virtual double get_distance(size_t id, const AnonymousArray<1>& query) const = 0; + + // Saving + virtual void save( + const std::filesystem::path& config_directory, + const std::filesystem::path& data_directory + ) = 0; +}; + +template +class DynamicIVFImpl : public IVFImpl { + public: + using base_type = IVFImpl; + using base_type::impl; + + explicit DynamicIVFImpl(Impl impl) + : base_type{std::move(impl)} {} + + template + explicit DynamicIVFImpl(Args&&... args) + : base_type{std::forward(args)...} {} + + // Implement the interface. + void add_points( + const float* data, + size_t dim0, + size_t dim1, + std::span ids, + bool reuse_empty = false + ) override { + auto points = data::ConstSimpleDataView(data, dim0, dim1); + impl().add_points(points, ids, reuse_empty); + } + + size_t delete_points(std::span ids) override { + return impl().delete_entries(ids); + } + + void consolidate() override { impl().consolidate(); } + + void compact(size_t batchsize) override { impl().compact(batchsize); } + + // ID inspection. + bool has_id(size_t id) const override { return impl().has_id(id); } + + void all_ids(std::vector& ids) const override { + ids.clear(); + impl().on_ids([&ids](size_t id) { ids.push_back(id); }); + } + + ///// Distance + double get_distance(size_t id, const AnonymousArray<1>& query) const override { + return svs::lib::match( + QueryTypes{}, + query.type(), + [&](svs::lib::Type) { + auto query_span = std::span(get(query), query.size(0)); + return impl().get_distance(id, query_span); + } + ); + } + + ///// Saving + void save( + const std::filesystem::path& config_directory, + const std::filesystem::path& data_directory + ) override { + impl().save(config_directory, data_directory); + } +}; + +// Forward Declarations. +class DynamicIVF; + +template +DynamicIVF make_dynamic_ivf(Args&&... args); + +/// +/// DynamicIVF +/// +class DynamicIVF : public manager::IndexManager { + public: + using base_type = manager::IndexManager; + using IVFSearchParameters = index::ivf::IVFSearchParameters; + + struct AssembleTag {}; + + /// + /// @brief Construct a new DynamicIVF instance. 
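+    ///
+    /// Sketch of typical construction via the assembly helpers below (the names,
+    /// sizes, and the float query type are assumptions for illustration):
+    /// @code
+    /// std::vector<size_t> ids(data.size());
+    /// std::iota(ids.begin(), ids.end(), 0);
+    /// auto index = svs::DynamicIVF::assemble_from_clustering<float>(
+    ///     std::move(clustering), data, ids, svs::distance::DistanceL2(),
+    ///     svs::threads::as_threadpool(4));
+    /// index.add_points(points_view, new_ids); // points_view/new_ids: hypothetical
+    /// index.delete_points(stale_ids);         // stale_ids: hypothetical
+    /// index.compact();
+    /// @endcode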
+ /// + /// @param impl A pointer to a concrete implementation of the full + /// DynamicIVFInterface. + /// + explicit DynamicIVF(std::unique_ptr> impl + ) + : base_type{std::move(impl)} {} + + template + explicit DynamicIVF(AssembleTag SVS_UNUSED(tag), QueryTypes SVS_UNUSED(type), Impl impl) + : base_type{std::make_unique>(std::move(impl))} {} + + // Mutable Interface. + DynamicIVF& add_points( + data::ConstSimpleDataView points, + std::span ids, + bool reuse_empty = false + ) { + impl_->add_points( + points.data(), points.size(), points.dimensions(), ids, reuse_empty + ); + return *this; + } + + size_t delete_points(std::span ids) { return impl_->delete_points(ids); } + + DynamicIVF& consolidate() { + impl_->consolidate(); + return *this; + } + + DynamicIVF& compact(size_t batchsize = 1'000'000) { + impl_->compact(batchsize); + return *this; + } + + // Backend String + std::string experimental_backend_string() const { + return impl_->experimental_backend_string(); + } + + // ID Inspection + + /// + /// @brief Return whether ``id`` is in the index. + /// + bool has_id(size_t id) const { return impl_->has_id(id); } + + /// + /// @brief Return all ``ids`` currently in the index. + /// + /// Note: If the stored index is large, the returned container may result in a + /// significant memory allocation. + /// + /// If more precise handling is required, please work with the lower level C++ class + /// directly. + /// + std::vector all_ids() const { + auto v = std::vector(); + impl_->all_ids(v); + return v; + } + + void save( + const std::filesystem::path& config_directory, + const std::filesystem::path& data_directory + ) { + impl_->save(config_directory, data_directory); + } + + ///// Distance + template double get_distance(size_t id, const Query& query) const { + // Create AnonymousArray from the query + AnonymousArray<1> query_array{query.data(), query.size()}; + return impl_->get_distance(id, query_array); + } + + ///// Assembly - Assemble from clustering and data + template < + manager::QueryTypeDefinition QueryTypes, + typename Clustering, + typename Data, + typename Distance, + typename ThreadPoolProto> + static DynamicIVF assemble_from_clustering( + Clustering clustering, + Data data, + std::span ids, + Distance distance, + ThreadPoolProto threadpool_proto, + size_t intra_query_threads = 1 + ) { + auto threadpool = threads::as_threadpool(std::move(threadpool_proto)); + + if constexpr (std::is_same_v, DistanceType>) { + auto dispatcher = DistanceDispatcher(distance); + return dispatcher([&](auto distance_function) { + auto impl = index::ivf::assemble_dynamic_from_clustering( + std::move(clustering), + data, + ids, + std::move(distance_function), + std::move(threadpool), + intra_query_threads + ); + return DynamicIVF( + AssembleTag(), manager::as_typelist{}, std::move(impl) + ); + }); + } else { + auto impl = index::ivf::assemble_dynamic_from_clustering( + std::move(clustering), + data, + ids, + distance, + std::move(threadpool), + intra_query_threads + ); + return DynamicIVF( + AssembleTag(), manager::as_typelist{}, std::move(impl) + ); + } + } + + ///// Assembly - Assemble from file (load clustering from disk) + template < + manager::QueryTypeDefinition QueryTypes, + typename BuildType, + typename Data, + typename Distance, + typename ThreadPoolProto> + static DynamicIVF assemble_from_file( + const std::filesystem::path& cluster_path, + Data data, + std::span ids, + Distance distance, + ThreadPoolProto threadpool_proto, + size_t intra_query_threads = 1 + ) { + using centroids_type = 
data::SimpleData<BuildType>;
+        auto threadpool = threads::as_threadpool(std::move(threadpool_proto));
+        auto clustering =
+            lib::load_from_disk<index::ivf::Clustering<centroids_type, uint32_t>>(
+                cluster_path, threadpool
+            );
+        return assemble_from_clustering<QueryTypes>(
+            std::move(clustering),
+            data,
+            ids,
+            distance,
+            std::move(threadpool),
+            intra_query_threads
+        );
+    }
+};
+
+} // namespace svs
diff --git a/include/svs/orchestrators/ivf.h b/include/svs/orchestrators/ivf.h index 7c035f11c..f0f86f84b 100644 --- a/include/svs/orchestrators/ivf.h +++ b/include/svs/orchestrators/ivf.h @@ -27,6 +27,9 @@ class IVFInterface { ///// Backend information interface virtual std::string experimental_backend_string() const = 0;
+
+    ///// Distance calculation
+    virtual double get_distance(size_t id, const AnonymousArray<1>& query) const = 0;
}; template @@ -56,6 +59,19 @@ class IVFImpl : public manager::ManagerImpl { [[nodiscard]] std::string experimental_backend_string() const override { return std::string{typename_impl.begin(), typename_impl.end() - 1}; }
+
+    ///// Distance Calculation
+    [[nodiscard]] double
+    get_distance(size_t id, const AnonymousArray<1>& query) const override {
+        return svs::lib::match(
+            QueryTypes{},
+            query.type(),
+            [&]<typename T>(svs::lib::Type<T>) {
+                auto query_span = std::span(get<T>(query), query.size(0));
+                return impl().get_distance(id, query_span);
+            }
+        );
+    }
}; ///// @@ -81,6 +97,14 @@ class IVF : public manager::IndexManager { return impl_->experimental_backend_string(); }
+
+    ///// Distance Calculation
+    template <typename QueryType>
+    double get_distance(size_t id, const QueryType& query) const {
+        // Create AnonymousArray from the query
+        AnonymousArray<1> query_array{query.data(), query.size()};
+        return impl_->get_distance(id, query_array);
+    }
+
    ///// Assembling template < manager::QueryTypeDefinition QueryTypes,
diff --git a/include/svs/quantization/scalar/scalar.h b/include/svs/quantization/scalar/scalar.h index 7ddf1cb9d..a2244d1fd 100644 --- a/include/svs/quantization/scalar/scalar.h +++ b/include/svs/quantization/scalar/scalar.h @@ -374,6 +374,9 @@ class SQDataset { // Data wrapped in the library allocator. using lib_alloc_data_type = SQDataset>;
+    // Data wrapped in the blocked library allocator (for Dynamic IVF).
+    using lib_blocked_alloc_data_type =
+        SQDataset>>;
private: float scale_;
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 63c55a934..a023608de 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -178,6 +178,7 @@ if (SVS_EXPERIMENTAL_ENABLE_IVF) ${TEST_DIR}/integration/ivf/index_build.cpp ${TEST_DIR}/integration/ivf/index_search.cpp ${TEST_DIR}/integration/ivf/scalar_search.cpp
+        ${TEST_DIR}/integration/ivf/dynamic_scalar.cpp
    ) endif() @@ -207,6 +208,7 @@ if (SVS_EXPERIMENTAL_ENABLE_IVF) ${TEST_DIR}/svs/index/ivf/kmeans.cpp ${TEST_DIR}/svs/index/ivf/hierarchical_kmeans.cpp ${TEST_DIR}/svs/index/ivf/common.cpp
+        ${TEST_DIR}/svs/index/ivf/dynamic_ivf.cpp
    ) endif()
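For reference, a minimal sketch of how the new get_distance entry point is called from application code. It mirrors the integration test added below; the explicit float template argument and the file paths are assumptions, not confirmed by this patch:

    auto data = svs::data::SimpleData<float>::load("data.svs"); // placeholder path
    auto index = svs::IVF::assemble_from_file<float>(
        "clustering_dir", data, svs::distance::DistanceL2(),
        /*num_threads=*/2, /*intra_query_threads=*/1
    );
    auto query = data.get_datum(0);                 // any contiguous vector exposing data()/size()
    double d = index.get_distance(/*id=*/0, query); // throws svs::ANNException for an unknown ID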
diff --git a/tests/integration/ivf/dynamic_scalar.cpp b/tests/integration/ivf/dynamic_scalar.cpp new file mode 100644 index 000000000..df6a761fa --- /dev/null +++ b/tests/integration/ivf/dynamic_scalar.cpp @@ -0,0 +1,227 @@
+/*
+ * Copyright 2025 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// svs
+#include "svs/core/data.h"
+#include "svs/core/distance.h"
+#include "svs/core/recall.h"
+#include "svs/extensions/ivf/scalar.h"
+#include "svs/index/ivf/clustering.h"
+#include "svs/orchestrators/dynamic_ivf.h"
+#include "svs/quantization/scalar/scalar.h"
+
+// catch2
+#include "catch2/catch_test_macros.hpp"
+
+// tests
+#include "tests/utils/test_dataset.h"
+
+// fmt
+#include "fmt/core.h"
+
+// stl
+#include <numeric>
+#include <random>
+
+namespace sc = svs::quantization::scalar;
+
+namespace {
+
+constexpr size_t NUM_NEIGHBORS = 10;
+constexpr size_t NUM_CLUSTERS = 10;
+constexpr size_t EXTENT = 128;
+
+///
+/// Test Dynamic IVF with Scalar Quantization
+///
+template <typename T, typename Distance>
+void test_dynamic_ivf_scalar(const Distance& distance) {
+    size_t num_threads = 2;
+    size_t intra_query_threads = 2;
+
+    // Load test dataset
+    auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+    auto queries = test_dataset::queries();
+    auto gt = test_dataset::groundtruth_euclidean();
+
+    // Build clustering on UNCOMPRESSED data
+    auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false);
+    auto threadpool = svs::threads::SequentialThreadPool();
+    auto clustering = svs::index::ivf::build_clustering(
+        build_params, data, distance, threadpool, false
+    );
+
+    // Compress the data with Scalar Quantization
+    auto compressed_data = sc::SQDataset<T>::compress(data);
+
+    // Generate external IDs for the data
+    std::vector<size_t> ids(data.size());
+    std::iota(ids.begin(), ids.end(), 0);
+
+    auto index = svs::DynamicIVF::assemble_from_clustering<float>(
+        std::move(clustering),
+        compressed_data,
+        ids,
+        distance,
+        svs::threads::as_threadpool(num_threads),
+        intra_query_threads
+    );
+
+    // Search
+    auto search_params = svs::index::ivf::IVFSearchParameters(
+        NUM_CLUSTERS, // n_probes
+        NUM_NEIGHBORS // k_reorder
+    );
+
+    auto results = svs::QueryResult<size_t>(queries.size(), NUM_NEIGHBORS);
+    index.search(
+        results.view(),
+        svs::data::ConstSimpleDataView<float>{
+            queries.data(), queries.size(), queries.dimensions()},
+        search_params
+    );
+
+    // Check recall
+    auto recall = svs::k_recall_at_n(gt, results, NUM_NEIGHBORS, NUM_NEIGHBORS);
+
+    // With every cluster probed, recall should stay high despite scalar quantization.
+    CATCH_REQUIRE(recall > 0.9);
+}
+
+///
+/// Test Dynamic IVF with Scalar Quantization - Add/Delete/Compact stress test
+///
+template <typename T, typename Distance>
+void test_dynamic_ivf_scalar_stress(const Distance& distance) {
+    size_t num_threads = 2;
+    size_t intra_query_threads = 2;
+
+    // Load test dataset
+    auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+    auto queries = test_dataset::queries();
+    auto gt = test_dataset::groundtruth_euclidean();
+
+    // Start with half the data
+    size_t initial_size = data.size() / 2;
+    auto initial_data = svs::data::SimpleData<float>(initial_size, EXTENT);
+    for (size_t i = 0; i < initial_size; ++i) {
+        initial_data.set_datum(i, data.get_datum(i));
+    }
+
+    // Build clustering on initial data
+    auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false);
+    auto threadpool = svs::threads::SequentialThreadPool();
+    auto clustering =
svs::index::ivf::build_clustering(
+        build_params, initial_data, distance, threadpool, false
+    );
+
+    // Compress with Scalar Quantization
+    auto compressed_data = sc::SQDataset<T>::compress(initial_data);
+
+    // Generate external IDs
+    std::vector<size_t> ids(initial_size);
+    std::iota(ids.begin(), ids.end(), 0);
+
+    auto index = svs::DynamicIVF::assemble_from_clustering<float>(
+        std::move(clustering),
+        compressed_data,
+        ids,
+        distance,
+        svs::threads::as_threadpool(num_threads),
+        intra_query_threads
+    );
+
+    auto search_params = svs::index::ivf::IVFSearchParameters(NUM_CLUSTERS, NUM_NEIGHBORS);
+    auto results = svs::QueryResult<size_t>(queries.size(), NUM_NEIGHBORS);
+
+    // Perform add/delete/compact cycles
+    std::mt19937 rng(12345);
+    std::uniform_int_distribution<size_t> idx_dist(0, initial_size - 1);
+
+    for (size_t cycle = 0; cycle < 3; ++cycle) {
+        // Delete some entries
+        std::vector<size_t> to_delete;
+        for (size_t i = 0; i < 20 && i < ids.size(); ++i) {
+            size_t idx = idx_dist(rng) % ids.size();
+            to_delete.push_back(ids[idx]);
+        }
+        if (!to_delete.empty()) {
+            index.delete_points(to_delete);
+        }
+
+        // Add new entries (uncompressed - the index compresses them internally)
+        size_t num_to_add = 30;
+        auto new_data = svs::data::SimpleData<float>(num_to_add, EXTENT);
+        std::vector<size_t> new_ids;
+        size_t new_base_id = 100000 + cycle * 1000;
+
+        for (size_t i = 0; i < num_to_add; ++i) {
+            new_ids.push_back(new_base_id + i);
+            new_data.set_datum(i, data.get_datum(i % data.size()));
+        }
+
+        // Pass uncompressed data as ConstSimpleDataView - the index compresses it
+        auto new_data_view = svs::data::ConstSimpleDataView<float>{
+            new_data.data(), new_data.size(), new_data.dimensions()};
+        index.add_points(new_data_view, new_ids, false);
+
+        // Search after modifications
+        index.search(
+            results.view(),
+            svs::data::ConstSimpleDataView<float>{
+                queries.data(), queries.size(), queries.dimensions()},
+            search_params
+        );
+
+        // Verify no deleted IDs appear in results
+        for (size_t q = 0; q < queries.size(); ++q) {
+            for (size_t k = 0; k < NUM_NEIGHBORS; ++k) {
+                auto result_id = results.index(q, k);
+                for (auto deleted_id : to_delete) {
+                    CATCH_REQUIRE(result_id != deleted_id);
+                }
+            }
+        }
+
+        // Compact every cycle
+        index.compact(50);
+
+        // Search after compaction
+        index.search(
+            results.view(),
+            svs::data::ConstSimpleDataView<float>{
+                queries.data(), queries.size(), queries.dimensions()},
+            search_params
+        );
+
+        // Verify all results are valid
+        for (size_t q = 0; q < queries.size(); ++q) {
+            CATCH_REQUIRE(results.index(q, 0) != std::numeric_limits<size_t>::max());
+        }
+    }
+}
+
+} // anonymous namespace
+
+CATCH_TEST_CASE(
+    "Dynamic IVF with Scalar Quantization", "[integration][dynamic_ivf][scalar]"
+) {
+    auto distance = svs::DistanceL2();
+
+    CATCH_SECTION("int8 quantization") { test_dynamic_ivf_scalar<std::int8_t>(distance); }
+
+    CATCH_SECTION("int8 stress test") { test_dynamic_ivf_scalar_stress<std::int8_t>(distance); }
+}
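Together, these cycles cover the whole mutation surface of the new orchestrator. In application code the same loop condenses to the calls below, assuming `index` is an `svs::DynamicIVF` and the data views and ID vectors are prepared as in the stress test above:

    index.add_points(new_data_view, new_ids); // insert under caller-chosen external IDs
    index.delete_points(to_delete);           // deleted IDs stop appearing in results immediately
    index.consolidate();                      // reclaim slots held by deleted entries
                                              // (semantics assumed to mirror dynamic Vamana)
    index.compact(/*batchsize=*/50);          // rewrite storage in bounded batches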
diff --git a/tests/integration/ivf/index_build.cpp b/tests/integration/ivf/index_build.cpp index 36a2d6e50..3c6a11d90 100644 --- a/tests/integration/ivf/index_build.cpp +++ b/tests/integration/ivf/index_build.cpp @@ -17,6 +17,10 @@ // svs #include "svs/core/data/simple.h" #include "svs/core/recall.h"
+#include "svs/index/ivf/clustering.h"
+#include "svs/index/ivf/common.h"
+#include "svs/index/ivf/hierarchical_kmeans.h"
+#include "svs/lib/float16.h"
#include "svs/lib/timing.h" #include "svs/orchestrators/ivf.h" @@ -99,6 +103,96 @@ void test_build(const Distance& distance, size_t num_inner_threads = 1) { } }
+
+template <typename E, typename Distance>
+void test_build_train_only(const Distance& distance, size_t num_inner_threads = 1) {
+    const double epsilon = 0.06; // Wider tolerance for the train_only workflow
+    const auto queries = svs::data::SimpleData<float>::load(test_dataset::query_file());
+    CATCH_REQUIRE(svs_test::prepare_temp_directory());
+    size_t num_threads = 2;
+
+    auto expected_result = test_dataset::ivf::expected_build_results(
+        svs::distance_type_v<Distance>, svsbenchmark::Uncompressed(svs::datatype_v<E>)
+    );
+
+    // Load data
+    auto data = svs::data::SimpleData<E>::load(test_dataset::data_svs_file());
+    auto threadpool = svs::threads::as_threadpool(num_threads);
+    auto parameters = expected_result.build_parameters_.value();
+
+    // Step 1: Use train_only mode to get centroids
+    svs::data::SimpleData<float> centroids_train;
+    std::vector<std::vector<uint32_t>> clusters_train;
+    fmt::print(
+        "Starting Train-Only Mode Clustering with {} centroids\n", parameters.num_centroids_
+    );
+
+    if (parameters.is_hierarchical_) {
+        fmt::print("Using Hierarchical KMeans Clustering\n");
+        std::tie(centroids_train, clusters_train) =
+            svs::index::ivf::hierarchical_kmeans_clustering(
+                parameters,
+                data,
+                distance,
+                threadpool,
+                svs::lib::Type<float>(),
+                svs::logging::get(),
+                true // train_only = true
+            );
+    } else {
+        std::tie(centroids_train, clusters_train) = svs::index::ivf::kmeans_clustering(
+            parameters,
+            data,
+            distance,
+            threadpool,
+            svs::lib::Type<float>(),
+            svs::logging::get(),
+            true // train_only = true
+        );
+    }
+
+    fmt::print("Train-Only Mode - Obtained {} centroids\n", centroids_train.size());
+
+    // Step 2: Assign data to clusters using cluster_assignment
+    auto clusters = svs::index::ivf::cluster_assignment(
+        data,
+        centroids_train,
+        distance,
+        threadpool,
+        10'000, // minibatch_size
+        svs::lib::Type<uint32_t>()
+    );
+
+    // Step 3: Create clustering and assemble index
+    svs::index::ivf::Clustering clustering(std::move(centroids_train), std::move(clusters));
+
+    auto index = svs::IVF::assemble_from_clustering<float>(
+        std::move(clustering), std::move(data), distance, num_threads, num_inner_threads
+    );
+
+    // Test the index with the same expected results
+    auto groundtruth = test_dataset::load_groundtruth(svs::distance_type_v<Distance>);
+    for (const auto& expected : expected_result.config_and_recall_) {
+        auto these_queries = test_dataset::get_test_set(queries, expected.num_queries_);
+        auto these_groundtruth =
+            test_dataset::get_test_set(groundtruth, expected.num_queries_);
+        index.set_search_parameters(expected.search_parameters_);
+        auto results = index.search(these_queries, expected.num_neighbors_);
+        double recall = svs::k_recall_at_n(
+            these_groundtruth, results, expected.num_neighbors_, expected.recall_k_
+        );
+
+        fmt::print(
+            "Train-Only Mode - n_probes: {}, Expected Recall: {}, Actual Recall: {}\n",
+            index.get_search_parameters().n_probes_,
+            expected.recall_,
+            recall
+        );
+        // Just check that recall is reasonable (within the wider tolerance).
+        CATCH_REQUIRE(recall > expected.recall_ - epsilon);
+        CATCH_REQUIRE(recall < expected.recall_ + epsilon);
+    }
+}
+
} // namespace CATCH_TEST_CASE("IVF Build/Clustering", "[integration][build][ivf]") { @@ -113,3 +207,11 @@ CATCH_TEST_CASE("IVF Build/Clustering", "[integration][build][ivf]") { // test_build(svs::DistanceL2(), 4); // test_build(svs::DistanceIP(), 4); }
+
+CATCH_TEST_CASE("IVF Build/Clustering Train-Only", "[integration][build][ivf][train_only]") {
+    test_build_train_only<float>(svs::DistanceL2());
+    test_build_train_only<float>(svs::DistanceIP());
+
+    test_build_train_only<svs::Float16>(svs::DistanceL2());
+    test_build_train_only<svs::Float16>(svs::DistanceIP());
+}
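The train-only path separates centroid training from data assignment, so one trained set of centroids can back several indexes. A condensed sketch of the three steps, reusing the variables from test_build_train_only above; the svs::lib::Type tags and the explicit template arguments are assumptions rather than confirmed signatures:

    // Step 1: train centroids only; the returned cluster lists stay empty.
    auto [centroids, empty] = svs::index::ivf::kmeans_clustering(
        parameters, data, distance, threadpool,
        svs::lib::Type<float>(), svs::logging::get(), /*train_only=*/true
    );
    // Step 2: assign every point to its nearest centroid in minibatches.
    auto clusters = svs::index::ivf::cluster_assignment(
        data, centroids, distance, threadpool,
        /*minibatch_size=*/10'000, svs::lib::Type<uint32_t>()
    );
    // Step 3: combine both pieces and assemble a searchable index.
    svs::index::ivf::Clustering clustering(std::move(centroids), std::move(clusters));
    auto index = svs::IVF::assemble_from_clustering<float>(
        std::move(clustering), std::move(data), distance, num_threads
    );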
diff --git a/tests/integration/ivf/index_search.cpp b/tests/integration/ivf/index_search.cpp index cec5fcdad..7de26e2e7 100644 --- a/tests/integration/ivf/index_search.cpp +++ b/tests/integration/ivf/index_search.cpp @@ -15,11 +15,13 @@ */ // stl
+#include <atomic>
#include #include #include #include #include
+#include <thread>
#include // svs @@ -139,3 +141,131 @@ CATCH_TEST_CASE("IVF Search", "[integration][search][ivf]") { test_search(data_f16, dist_ip, queries, gt_ip); test_search(data_f16, dist_ip, queries, gt_ip, 2); }
+
+CATCH_TEST_CASE("IVF get_distance", "[integration][ivf][get_distance]") {
+    auto datafile = test_dataset::data_svs_file();
+    auto queries = test_dataset::queries();
+    auto dist_l2 = svs::distance::DistanceL2();
+
+    auto data = svs::data::SimpleData<float>::load(datafile);
+
+    size_t num_threads = 2;
+    auto index = svs::IVF::assemble_from_file<float>(
+        test_dataset::clustering_directory(), data, dist_l2, num_threads, 1
+    );
+
+    // Test get_distance functionality with strict tolerance
+    constexpr double TOLERANCE = 1e-2; // 1% tolerance
+
+    // Test with a few different IDs
+    std::vector<size_t> test_ids = {0, 10, 50};
+    if (data.size() > 100) {
+        test_ids.push_back(100);
+    }
+
+    for (size_t test_id : test_ids) {
+        if (test_id >= data.size()) {
+            continue;
+        }
+
+        // Get a query vector
+        size_t query_id = std::min<size_t>(5, queries.size() - 1);
+        auto query = queries.get_datum(query_id);
+
+        // Get distance from index
+        double index_distance = index.get_distance(test_id, query);
+
+        // Compute expected distance from original data
+        auto datum = data.get_datum(test_id);
+        svs::distance::DistanceL2 dist_copy;
+        svs::distance::maybe_fix_argument(dist_copy, query);
+        double expected_distance = svs::distance::compute(dist_copy, query, datum);
+
+        // Verify the distance is correct
+        double relative_diff =
+            std::abs((index_distance - expected_distance) / expected_distance);
+        CATCH_REQUIRE(relative_diff < TOLERANCE);
+    }
+
+    // Test with an out-of-bounds ID - should throw
+    CATCH_REQUIRE_THROWS_AS(
+        index.get_distance(data.size() + 1000, queries.get_datum(0)), svs::ANNException
+    );
+}
+
+CATCH_TEST_CASE(
+    "IVF get_distance thread safety", "[integration][ivf][get_distance][thread_safety]"
+) {
+    auto datafile = test_dataset::data_svs_file();
+    auto queries = test_dataset::queries();
+    auto dist_l2 = svs::distance::DistanceL2();
+
+    auto data = svs::data::SimpleData<float>::load(datafile);
+
+    size_t num_threads = 2;
+    auto index = svs::IVF::assemble_from_file<float>(
+        test_dataset::clustering_directory(), data, dist_l2, num_threads, 1
+    );
+
+    // Test thread safety of get_distance with concurrent calls.
+    // The lazy initialization of the ID mapping should be thread-safe via std::call_once.
+    constexpr size_t NUM_TEST_THREADS = 8;
+    constexpr size_t CALLS_PER_THREAD = 100;
+    constexpr double TOLERANCE = 1e-2;
+
+    // Prepare test data
+    std::vector<size_t> test_ids;
+    for (size_t i = 0; i < std::min<size_t>(10, data.size()); ++i) {
+        test_ids.push_back(i * (data.size() / 10));
+    }
+
+    // Pre-compute expected distances for verification
+    std::vector<std::vector<double>> expected_distances(test_ids.size());
+    for (size_t i = 0; i < test_ids.size(); ++i) {
+        expected_distances[i].resize(queries.size());
+        auto datum = data.get_datum(test_ids[i]);
+        for (size_t q = 0; q < queries.size(); ++q) {
+            auto query = queries.get_datum(q);
+            svs::distance::DistanceL2 dist_copy;
+            svs::distance::maybe_fix_argument(dist_copy, query);
+            expected_distances[i][q] = svs::distance::compute(dist_copy, query, datum);
+        }
+    }
+
+    // Track results and errors from threads
+    std::atomic<size_t> success_count{0};
+    std::atomic<size_t> error_count{0};
+    std::vector<std::thread>
threads;
+    threads.reserve(NUM_TEST_THREADS);
+
+    // Launch multiple threads that concurrently call get_distance
+    for (size_t t = 0; t < NUM_TEST_THREADS; ++t) {
+        threads.emplace_back([&, t]() {
+            for (size_t call = 0; call < CALLS_PER_THREAD; ++call) {
+                size_t id_idx = (t + call) % test_ids.size();
+                size_t query_idx = (t * CALLS_PER_THREAD + call) % queries.size();
+                size_t test_id = test_ids[id_idx];
+
+                auto query = queries.get_datum(query_idx);
+                double index_distance = index.get_distance(test_id, query);
+                double expected = expected_distances[id_idx][query_idx];
+
+                double relative_diff = std::abs((index_distance - expected) / expected);
+                if (relative_diff < TOLERANCE) {
+                    ++success_count;
+                } else {
+                    ++error_count;
+                }
+            }
+        });
+    }
+
+    // Wait for all threads to complete
+    for (auto& thread : threads) {
+        thread.join();
+    }
+
+    // Verify all calls succeeded
+    CATCH_REQUIRE(error_count == 0);
+    CATCH_REQUIRE(success_count == NUM_TEST_THREADS * CALLS_PER_THREAD);
+}
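The thread-safety test above targets the lazily built ID mapping that, per the comment, is guarded by std::call_once. The pattern under test looks roughly like the following; the type and member names here are hypothetical, not taken from this patch:

    #include <mutex>

    struct LazyIdMap {
        mutable std::once_flag built_; // hypothetical member
        void ensure_built() const {
            // Concurrent callers block until exactly one of them builds the map;
            // afterwards every call is a cheap no-op.
            std::call_once(built_, [] { /* populate id -> (cluster, offset) */ });
        }
    };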
diff --git a/tests/svs/index/ivf/common.cpp b/tests/svs/index/ivf/common.cpp index 39df8503a..a23efc4bc 100644 --- a/tests/svs/index/ivf/common.cpp +++ b/tests/svs/index/ivf/common.cpp @@ -24,6 +24,19 @@ // catch #include "catch2/catch_test_macros.hpp"
+// svs
+#include "svs/core/data.h"
+#include "svs/core/distance.h"
+#include "svs/index/ivf/hierarchical_kmeans.h"
+#include "svs/index/ivf/kmeans.h"
+#include "svs/lib/threads.h"
+
+// stl
+#include
+#include
+#include
+#include
+
CATCH_TEST_CASE("Kmeans Clustering", "[ivf][parameters]") { namespace ivf = svs::index::ivf; CATCH_SECTION("IVF Build Parameters") { @@ -69,3 +82,642 @@ CATCH_TEST_CASE("Kmeans Clustering", "[ivf][parameters]") { CATCH_REQUIRE(svs::lib::test_self_save_load(p, dir)); } }
+
+CATCH_TEST_CASE("Common Utility Functions", "[ivf][common][core]") {
+    namespace ivf = svs::index::ivf;
+
+    CATCH_SECTION("compute_matmul - All Data Types") {
+        // Test matrix multiplication for different data types
+        constexpr size_t m = 10; // number of data points
+        constexpr size_t n = 5;  // number of centroids
+        constexpr size_t k = 8;  // dimensions
+
+        auto test_matmul = [&]<typename T>() {
+            // Create test data
+            auto data = svs::data::SimpleData<T>(m, k);
+            auto centroids = svs::data::SimpleData<T>(n, k);
+            auto results = svs::data::SimpleData<float>(m, n);
+
+            // Fill with test values
+            for (size_t i = 0; i < m; ++i) {
+                auto datum = data.get_datum(i);
+                for (size_t j = 0; j < k; ++j) {
+                    datum[j] = static_cast<T>(i + j * 0.1);
+                }
+            }
+
+            for (size_t i = 0; i < n; ++i) {
+                auto centroid = centroids.get_datum(i);
+                for (size_t j = 0; j < k; ++j) {
+                    centroid[j] = static_cast<T>(i * 0.5 + j);
+                }
+            }
+
+            // Compute matrix multiplication
+            ivf::compute_matmul(data.data(), centroids.data(), results.data(), m, n, k);
+
+            // Verify results are valid (not NaN or Inf)
+            for (size_t i = 0; i < m; ++i) {
+                for (size_t j = 0; j < n; ++j) {
+                    float val = results.get_datum(i)[j];
+                    CATCH_REQUIRE(std::isfinite(val));
+                }
+            }
+
+            // Verify dimensions match expected output
+            CATCH_REQUIRE(results.size() == m);
+            CATCH_REQUIRE(results.dimensions() == n);
+        };
+
+        // Test all data types
+        test_matmul.operator()<float>();
+        test_matmul.operator()<svs::Float16>();
+        test_matmul.operator()<svs::BFloat16>();
+    }
+
+    CATCH_SECTION("compute_matmul - Edge Cases") {
+        // Test with zero dimensions (should return without error)
+        auto results = svs::data::SimpleData<float>(0, 0);
+        auto data = svs::data::SimpleData<float>(0, 0);
+        auto centroids = svs::data::SimpleData<float>(0, 0);
+
+        // Should not crash with zero dimensions
+        ivf::compute_matmul(data.data(), centroids.data(), results.data(), 0, 0, 0);
+
+        // Test with a single point and a single centroid
+        auto data_single = svs::data::SimpleData<float>(1, 4);
+        auto centroid_single = svs::data::SimpleData<float>(1, 4);
+        auto result_single = svs::data::SimpleData<float>(1, 1);
+
+        auto datum = data_single.get_datum(0);
+        auto centroid = centroid_single.get_datum(0);
+        for (size_t i = 0; i < 4; ++i) {
+            datum[i] = static_cast<float>(i);
+            centroid[i] = static_cast<float>(i + 1);
+        }
+
+        ivf::compute_matmul(
+            data_single.data(), centroid_single.data(), result_single.data(), 1, 1, 4
+        );
+
+        CATCH_REQUIRE(std::isfinite(result_single.get_datum(0)[0]));
+    }
+
+    CATCH_SECTION("convert_data - Type Conversions") {
+        auto threadpool = svs::threads::as_threadpool(4);
+
+        // Test float to Float16 conversion
+        auto data_float = svs::data::SimpleData<float>(10, 8);
+        for (size_t i = 0; i < data_float.size(); ++i) {
+            auto datum = data_float.get_datum(i);
+            for (size_t j = 0; j < data_float.dimensions(); ++j) {
+                datum[j] = static_cast<float>(i * 10 + j);
+            }
+        }
+
+        auto data_fp16 = ivf::convert_data<svs::Float16>(data_float, threadpool);
+        CATCH_REQUIRE(data_fp16.size() == data_float.size());
+        CATCH_REQUIRE(data_fp16.dimensions() == data_float.dimensions());
+
+        // Test float to BFloat16 conversion
+        auto data_bf16 = ivf::convert_data<svs::BFloat16>(data_float, threadpool);
+        CATCH_REQUIRE(data_bf16.size() == data_float.size());
+        CATCH_REQUIRE(data_bf16.dimensions() == data_float.dimensions());
+
+        // Test Float16 to float conversion
+        auto data_back = ivf::convert_data<float>(data_fp16, threadpool);
+        CATCH_REQUIRE(data_back.size() == data_fp16.size());
+        CATCH_REQUIRE(data_back.dimensions() == data_fp16.dimensions());
+    }
+
+    CATCH_SECTION("generate_norms") {
+        auto threadpool = svs::threads::as_threadpool(4);
+
+        // Create test data
+        auto data = svs::data::SimpleData<float>(20, 10);
+        for (size_t i = 0; i < data.size(); ++i) {
+            auto datum = data.get_datum(i);
+            for (size_t j = 0; j < data.dimensions(); ++j) {
+                datum[j] = static_cast<float>(i + j);
+            }
+        }
+
+        std::vector<float> norms(data.size());
+        ivf::generate_norms(data, norms, threadpool);
+
+        // Verify norms are computed
+        CATCH_REQUIRE(norms.size() == data.size());
+        for (const auto& norm : norms) {
+            CATCH_REQUIRE(norm >= 0.0f);
+            CATCH_REQUIRE(std::isfinite(norm));
+        }
+    }
+
+    CATCH_SECTION("maybe_compute_norms") {
+        auto threadpool = svs::threads::as_threadpool(4);
+        auto data = svs::data::SimpleData<float>(15, 8);
+
+        for (size_t i = 0; i < data.size(); ++i) {
+            auto datum = data.get_datum(i);
+            for (size_t j = 0; j < data.dimensions(); ++j) {
+                datum[j] = static_cast<float>(i + j * 0.5);
+            }
+        }
+
+        // For L2 distance, norms should be computed
+        auto norms_l2 = ivf::maybe_compute_norms<svs::DistanceL2>(data, threadpool);
+        CATCH_REQUIRE(norms_l2.size() == data.size());
+        for (const auto& norm : norms_l2) {
+            CATCH_REQUIRE(norm >= 0.0f);
+        }
+
+        // For IP distance, norms should be empty
+        auto norms_ip = ivf::maybe_compute_norms<svs::DistanceIP>(data, threadpool);
+        CATCH_REQUIRE(norms_ip.empty());
+    }
+
+    CATCH_SECTION("group_assignments") {
+        // Test grouping assignments
+        size_t num_centroids = 5;
+        size_t data_size = 50;
+
+        // Create assignments (each point assigned to a centroid)
+        std::vector<uint32_t> assignments(data_size);
+        for (size_t i = 0; i < data_size; ++i) {
+            assignments[i] = i % num_centroids;
+        }
+
+        auto data = svs::data::SimpleData<float>(data_size, 8);
+        auto groups = ivf::group_assignments(assignments, num_centroids, data);
+
+        CATCH_REQUIRE(groups.size() == num_centroids);
+
+        // Verify all points are assigned
+        size_t total_assigned = 0;
+        for (const auto& group : groups) {
total_assigned += group.size(); + } + CATCH_REQUIRE(total_assigned == data_size); + + // Verify each group has expected size + for (const auto& group : groups) { + CATCH_REQUIRE(group.size() == data_size / num_centroids); + } + } + + CATCH_SECTION("make_training_set") { + auto threadpool = svs::threads::as_threadpool(4); + auto rng = std::mt19937(12345); + + // Create full dataset + size_t full_size = 100; + size_t training_size = 30; + auto data = svs::data::SimpleData(full_size, 16); + + for (size_t i = 0; i < data.size(); ++i) { + auto datum = data.get_datum(i); + for (size_t j = 0; j < data.dimensions(); ++j) { + datum[j] = static_cast(i * 10 + j); + } + } + + std::vector ids(training_size); + auto training_set = + ivf::make_training_set>( + data, ids, training_size, rng, threadpool + ); + + CATCH_REQUIRE(training_set.size() == training_size); + CATCH_REQUIRE(training_set.dimensions() == data.dimensions()); + CATCH_REQUIRE(ids.size() == training_size); + + // Verify IDs are valid and unique + std::unordered_set unique_ids(ids.begin(), ids.end()); + CATCH_REQUIRE(unique_ids.size() == training_size); + for (const auto& id : ids) { + CATCH_REQUIRE(id < full_size); + } + } + + CATCH_SECTION("init_centroids") { + auto threadpool = svs::threads::as_threadpool(4); + auto rng = std::mt19937(54321); + + // Create training data + size_t training_size = 50; + size_t num_centroids = 10; + auto trainset = svs::data::SimpleData(training_size, 12); + + for (size_t i = 0; i < trainset.size(); ++i) { + auto datum = trainset.get_datum(i); + for (size_t j = 0; j < trainset.dimensions(); ++j) { + datum[j] = static_cast(i + j * 0.3); + } + } + + std::vector ids(num_centroids); + auto centroids = + ivf::init_centroids(trainset, ids, num_centroids, rng, threadpool); + + CATCH_REQUIRE(centroids.size() == num_centroids); + CATCH_REQUIRE(centroids.dimensions() == trainset.dimensions()); + + // Verify centroids are from training set + for (size_t i = 0; i < num_centroids; ++i) { + auto centroid = centroids.get_datum(i); + bool found = false; + for (size_t j = 0; j < trainset.size(); ++j) { + auto train_point = trainset.get_datum(j); + bool matches = true; + for (size_t k = 0; k < trainset.dimensions(); ++k) { + if (std::abs(centroid[k] - train_point[k]) > 1e-6f) { + matches = false; + break; + } + } + if (matches) { + found = true; + break; + } + } + CATCH_REQUIRE(found); + } + } + + CATCH_SECTION("normalize_centroids") { + auto threadpool = svs::threads::as_threadpool(4); + auto timer = svs::lib::Timer(); + + // Create centroids with non-unit norms + auto centroids = svs::data::SimpleData(8, 10); + for (size_t i = 0; i < centroids.size(); ++i) { + auto centroid = centroids.get_datum(i); + for (size_t j = 0; j < centroids.dimensions(); ++j) { + centroid[j] = static_cast((i + 1) * (j + 1)); + } + } + + ivf::normalize_centroids(centroids, threadpool, timer); + + // Verify centroids are normalized (L2 norm = 1) + for (size_t i = 0; i < centroids.size(); ++i) { + auto centroid = centroids.get_datum(i); + float norm_sq = 0.0f; + for (size_t j = 0; j < centroids.dimensions(); ++j) { + norm_sq += centroid[j] * centroid[j]; + } + float norm = std::sqrt(norm_sq); + CATCH_REQUIRE(std::abs(norm - 1.0f) < 1e-5f); + } + } +} + +CATCH_TEST_CASE("Cluster Assignment Utility", "[ivf][common][cluster_assignment]") { + namespace ivf = svs::index::ivf; + + auto test_cluster_assignment = + [&]() { + auto threadpool = svs::threads::as_threadpool(4); + + // Create test data + size_t num_points = 1000; + size_t num_centroids = 10; + 
size_t dims = 128; + + auto data = svs::data::SimpleData(num_points, dims); + auto centroids = svs::data::SimpleData(num_centroids, dims); + + // Initialize data with structured patterns + for (size_t i = 0; i < num_points; ++i) { + auto datum = data.get_datum(i); + size_t cluster_id = i % num_centroids; + for (size_t j = 0; j < dims; ++j) { + // Create data that naturally clusters around centroids + datum[j] = static_cast( + cluster_id * 10.0f + j * 0.1f + (i % 10) * 0.01f + ); + } + } + + // Initialize centroids to match cluster centers + for (size_t i = 0; i < num_centroids; ++i) { + auto centroid = centroids.get_datum(i); + for (size_t j = 0; j < dims; ++j) { + centroid[j] = static_cast(i * 10.0f + j * 0.1f); + } + } + + // Normalize for IP distance if needed + if constexpr (std::is_same_v) { + auto timer = svs::lib::Timer(); + ivf::normalize_centroids(centroids, threadpool, timer); + + // Normalize data as well for IP + for (size_t i = 0; i < num_points; ++i) { + auto datum = data.get_datum(i); + float norm = 0.0f; + for (size_t j = 0; j < dims; ++j) { + norm += static_cast(datum[j]) * static_cast(datum[j]); + } + norm = std::sqrt(norm); + if (norm > 0.0f) { + for (size_t j = 0; j < dims; ++j) { + datum[j] = + static_cast(static_cast(datum[j]) / norm); + } + } + } + } + + auto distance = Distance(); + + // Call cluster_assignment utility + auto clusters = ivf::cluster_assignment( + data, centroids, distance, threadpool, 10'000, svs::lib::Type() + ); + + // Verify results + CATCH_REQUIRE(clusters.size() == num_centroids); + + // Count total assigned points + size_t total_assigned = 0; + for (const auto& cluster : clusters) { + total_assigned += cluster.size(); + } + CATCH_REQUIRE(total_assigned == num_points); + + // Verify no cluster is empty (with our structured data) + size_t empty_clusters = 0; + for (const auto& cluster : clusters) { + if (cluster.empty()) { + empty_clusters++; + } + } + // With structured data, we expect most clusters to have points + // but allow a few empty clusters due to random initialization + CATCH_REQUIRE(empty_clusters <= 2); + }; + + CATCH_SECTION("Float32 with L2 Distance") { + test_cluster_assignment.operator()(); + } + + CATCH_SECTION("Float32 with IP Distance") { + test_cluster_assignment.operator()(); + } + + CATCH_SECTION("Float16 with L2 Distance") { + test_cluster_assignment.operator()(); + } + + CATCH_SECTION("Float16 with IP Distance") { + test_cluster_assignment.operator()(); + } + + CATCH_SECTION("BFloat16 with L2 Distance") { + test_cluster_assignment.operator()(); + } + + CATCH_SECTION("BFloat16 with IP Distance") { + test_cluster_assignment.operator()(); + } +} + +CATCH_TEST_CASE( + "IVF Train-Only and Cluster Assignment", "[ivf][common][train_only][cluster_assignment]" +) { + namespace ivf = svs::index::ivf; + auto threadpool = svs::threads::as_threadpool(4); + auto data = test_dataset::data_f32(); + + auto parameters = ivf::IVFBuildParameters() + .num_centroids(50) + .minibatch_size(500) + .num_iterations(10) + .is_hierarchical(false) + .training_fraction(0.5) + .seed(12345); + + CATCH_SECTION("Flat K-means: train_only + cluster_assignment vs full clustering") { + auto distance_l2 = svs::DistanceL2(); + + // Method 1: Full clustering (without train_only) + auto [centroids_full, clusters_full] = ivf::kmeans_clustering( + parameters, + data, + distance_l2, + threadpool, + svs::lib::Type(), + svs::logging::get(), + false // train_only = false + ); + + // Method 2: Train-only + cluster_assignment + auto [centroids_train, clusters_train] = 
ivf::kmeans_clustering( + parameters, + data, + distance_l2, + threadpool, + svs::lib::Type(), + svs::logging::get(), + true // train_only = true + ); + + // Verify train_only returns empty clusters + CATCH_REQUIRE(clusters_train.size() == parameters.num_centroids_); + for (const auto& cluster : clusters_train) { + CATCH_REQUIRE(cluster.empty()); + } + + // Now assign data using the cluster_assignment utility + auto clusters_assigned = ivf::cluster_assignment( + data, + centroids_train, + distance_l2, + threadpool, + 500, // minibatch_size + svs::lib::Type() + ); + + // Verify centroids match (within tolerance) + CATCH_REQUIRE(centroids_train.size() == centroids_full.size()); + CATCH_REQUIRE(centroids_train.dimensions() == centroids_full.dimensions()); + + for (size_t i = 0; i < centroids_train.size(); ++i) { + auto c1 = centroids_train.get_datum(i); + auto c2 = centroids_full.get_datum(i); + for (size_t j = 0; j < centroids_train.dimensions(); ++j) { + CATCH_REQUIRE(std::abs(c1[j] - c2[j]) < 1e-5f); + } + } + + // Verify cluster assignments match + CATCH_REQUIRE(clusters_assigned.size() == clusters_full.size()); + for (size_t i = 0; i < clusters_assigned.size(); ++i) { + CATCH_REQUIRE(clusters_assigned[i].size() == clusters_full[i].size()); + + // Sort both to compare + auto a = clusters_assigned[i]; + auto b = clusters_full[i]; + std::sort(a.begin(), a.end()); + std::sort(b.begin(), b.end()); + CATCH_REQUIRE(a == b); + } + + // Verify all points are assigned + size_t total_assigned = 0; + for (const auto& cluster : clusters_assigned) { + total_assigned += cluster.size(); + } + CATCH_REQUIRE(total_assigned == data.size()); + } + + CATCH_SECTION("Hierarchical K-means: train_only + cluster_assignment vs full clustering" + ) { + auto distance_ip = svs::DistanceIP(); + + // Use hierarchical k-means + auto hier_params = + parameters.is_hierarchical(true).hierarchical_level1_clusters(10); + + // Method 1: Full clustering (without train_only) + auto [centroids_full, clusters_full] = ivf::hierarchical_kmeans_clustering( + hier_params, + data, + distance_ip, + threadpool, + svs::lib::Type(), + svs::logging::get(), + false // train_only = false + ); + + // Method 2: Train-only + cluster_assignment + auto [centroids_train, clusters_train] = ivf::hierarchical_kmeans_clustering( + hier_params, + data, + distance_ip, + threadpool, + svs::lib::Type(), + svs::logging::get(), + true // train_only = true + ); + + // Verify train_only returns empty clusters + CATCH_REQUIRE(clusters_train.size() == hier_params.num_centroids_); + for (const auto& cluster : clusters_train) { + CATCH_REQUIRE(cluster.empty()); + } + + // Now assign data using the cluster_assignment utility + auto clusters_assigned = ivf::cluster_assignment( + data, + centroids_train, + distance_ip, + threadpool, + 500, // minibatch_size + svs::lib::Type() + ); + + // Verify centroids match (within tolerance) + CATCH_REQUIRE(centroids_train.size() == centroids_full.size()); + CATCH_REQUIRE(centroids_train.dimensions() == centroids_full.dimensions()); + + for (size_t i = 0; i < centroids_train.size(); ++i) { + auto c1 = centroids_train.get_datum(i); + auto c2 = centroids_full.get_datum(i); + for (size_t j = 0; j < centroids_train.dimensions(); ++j) { + CATCH_REQUIRE(std::abs(c1[j] - c2[j]) < 1e-5f); + } + } + + // Verify cluster structure is reasonable + CATCH_REQUIRE(clusters_assigned.size() == clusters_full.size()); + + // Verify all points are assigned in both methods + size_t total_assigned = 0; + size_t total_full = 0; + for (size_t i = 0; i 
< clusters_assigned.size(); ++i) {
+            total_assigned += clusters_assigned[i].size();
+            total_full += clusters_full[i].size();
+        }
+        CATCH_REQUIRE(total_assigned == data.size());
+        CATCH_REQUIRE(total_full == data.size());
+
+        // For hierarchical k-means, assignments may differ slightly due to
+        // precision differences in the two-level clustering process.
+        // The important thing is that both methods produce valid clusterings.
+        // We verify this by checking that the distribution of cluster sizes
+        // is reasonable and similar.
+
+        // Check no cluster is excessively large (> 50% of data)
+        for (const auto& cluster : clusters_assigned) {
+            CATCH_REQUIRE(cluster.size() <= data.size() / 2);
+        }
+        for (const auto& cluster : clusters_full) {
+            CATCH_REQUIRE(cluster.size() <= data.size() / 2);
+        }
+
+        // Count non-empty clusters in both
+        size_t non_empty_assigned = 0;
+        size_t non_empty_full = 0;
+        for (size_t i = 0; i < clusters_assigned.size(); ++i) {
+            if (!clusters_assigned[i].empty())
+                non_empty_assigned++;
+            if (!clusters_full[i].empty())
+                non_empty_full++;
+        }
+
+        // Both should have a similar number of non-empty clusters (within 20%)
+        double ratio = static_cast<double>(non_empty_assigned) / non_empty_full;
+        CATCH_REQUIRE(ratio >= 0.8);
+        CATCH_REQUIRE(ratio <= 1.2);
+    }
+
+    CATCH_SECTION("Different data types with train_only workflow") {
+        auto distance_l2 = svs::DistanceL2();
+
+        // Test with Float16
+        auto [centroids_fp16, clusters_empty_fp16] = ivf::kmeans_clustering(
+            parameters,
+            data,
+            distance_l2,
+            threadpool,
+            svs::lib::Type<svs::Float16>(),
+            svs::logging::get(),
+            true // train_only = true
+        );
+
+        auto clusters_fp16 = ivf::cluster_assignment(
+            data, centroids_fp16, distance_l2, threadpool, 500, svs::lib::Type<uint32_t>()
+        );
+
+        CATCH_REQUIRE(clusters_fp16.size() == parameters.num_centroids_);
+        size_t total_fp16 = 0;
+        for (const auto& cluster : clusters_fp16) {
+            total_fp16 += cluster.size();
+        }
+        CATCH_REQUIRE(total_fp16 == data.size());
+
+        // Test with BFloat16
+        auto [centroids_bf16, clusters_empty_bf16] = ivf::kmeans_clustering(
+            parameters,
+            data,
+            distance_l2,
+            threadpool,
+            svs::lib::Type<svs::BFloat16>(),
+            svs::logging::get(),
+            true // train_only = true
+        );
+
+        auto clusters_bf16 = ivf::cluster_assignment(
+            data, centroids_bf16, distance_l2, threadpool, 500, svs::lib::Type<uint32_t>()
+        );
+
+        CATCH_REQUIRE(clusters_bf16.size() == parameters.num_centroids_);
+        size_t total_bf16 = 0;
+        for (const auto& cluster : clusters_bf16) {
+            total_bf16 += cluster.size();
+        }
+        CATCH_REQUIRE(total_bf16 == data.size());
+    }
+}
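The unit tests in the next file construct DynamicIVFIndex by hand from a DenseClusteredDataset; the assemble_dynamic_from_clustering helper exercised in the BlockedData test below wraps the same steps behind one call. A sketch of the helper path, with clustering, data, and Distance prepared as in those tests:

    std::vector<size_t> ids(data.size());
    std::iota(ids.begin(), ids.end(), 0); // external IDs chosen by the caller
    auto index = svs::index::ivf::assemble_dynamic_from_clustering(
        std::move(clustering), data, ids, Distance(),
        svs::threads::as_threadpool(4), /*intra_query_threads=*/1
    );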
diff --git a/tests/svs/index/ivf/dynamic_ivf.cpp b/tests/svs/index/ivf/dynamic_ivf.cpp new file mode 100644 index 000000000..690213b61 --- /dev/null +++ b/tests/svs/index/ivf/dynamic_ivf.cpp @@ -0,0 +1,988 @@
+/*
+ * Copyright 2025 Intel Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// svs
+#include "svs/index/ivf/dynamic_ivf.h"
+#include "svs/core/data.h"
+#include "svs/core/distance.h"
+#include "svs/core/query_result.h"
+#include "svs/core/recall.h"
+#include "svs/index/ivf/clustering.h"
+#include "svs/lib/preprocessor.h"
+#include "svs/lib/threads.h"
+#include "svs/lib/timing.h"
+#include "svs/misc/dynamic_helper.h"
+
+// tests
+#include "tests/utils/test_dataset.h"
+#include "tests/utils/utils.h"
+
+// catch
+#include "catch2/catch_test_macros.hpp"
+
+// stl
+#include
+#include
+#include
+#include
+#include
+
+using Idx = uint32_t;
+using Eltype = float;
+using QueryEltype = float;
+using Distance = svs::distance::DistanceL2;
+const size_t N = 128;
+
+const size_t NUM_NEIGHBORS = 10;
+const size_t NUM_CLUSTERS = 10;
+
+///
+/// Utility Methods
+///
+
+template <std::integral I> I div(I i, float fraction) {
+    return svs::lib::narrow<I>(std::floor(svs::lib::narrow<float>(i) * fraction));
+}
+
+template <typename... Args> std::string stringify(Args&&... args) {
+    std::ostringstream stream{};
+    ((stream << args), ...);
+    return stream.str();
+}
+
+///
+/// Main Loop.
+///
+
+template <typename MutableIndex, typename Queries>
+void do_check(
+    MutableIndex& index,
+    svs::misc::ReferenceDataset<Idx, Eltype, N, Distance>& reference,
+    const Queries& queries,
+    double operation_time,
+    std::string message
+) {
+    // Compute groundtruth
+    auto tic = svs::lib::now();
+    auto gt = reference.groundtruth();
+    CATCH_REQUIRE(gt.n_neighbors() == NUM_NEIGHBORS);
+    CATCH_REQUIRE(gt.n_queries() == queries.size());
+
+    double groundtruth_time = svs::lib::time_difference(tic);
+
+    // Run search
+    tic = svs::lib::now();
+    auto results = svs::QueryResult<size_t>(gt.n_queries(), NUM_NEIGHBORS);
+    auto search_parameters = svs::index::ivf::IVFSearchParameters(
+        NUM_CLUSTERS, // n_probes - search all clusters for accuracy
+        NUM_NEIGHBORS // k_reorder
+    );
+
+    index.search(
+        results.view(),
+        svs::data::ConstSimpleDataView<QueryEltype>{
+            queries.data(), queries.size(), queries.dimensions()},
+        search_parameters
+    );
+    double search_time = svs::lib::time_difference(tic);
+
+    // Extra ID checks
+    reference.check_ids(results);
+    reference.check_equal_ids(index);
+
+    // compute recall
+    double recall = svs::k_recall_at_n(gt, results, NUM_NEIGHBORS, NUM_NEIGHBORS);
+
+    std::cout << "[" << message << "] -- {"
+              << "operation: " << operation_time << ", groundtruth: " << groundtruth_time
+              << ", search: " << search_time << ", recall: " << recall << "}\n";
+}
+
+template <typename MutableIndex, typename Queries>
+void test_loop(
+    MutableIndex& index,
+    svs::misc::ReferenceDataset<Idx, Eltype, N, Distance>& reference,
+    const Queries& queries,
+    size_t num_points,
+    size_t consolidate_every,
+    size_t iterations
+) {
+    size_t consolidate_count = 0;
+    for (size_t i = 0; i < iterations; ++i) {
+        // Add Points
+        {
+            auto [points, time] = reference.add_points(index, num_points);
+            CATCH_REQUIRE(points <= num_points);
+            CATCH_REQUIRE(points > num_points - reference.bucket_size());
+            do_check(index, reference, queries, time, stringify("add ", points, " points"));
+        }
+
+        // Delete Points
+        {
+            auto [points, time] = reference.delete_points(index, num_points);
+            CATCH_REQUIRE(points <= num_points);
+            CATCH_REQUIRE(points > num_points - reference.bucket_size());
+            do_check(
+                index, reference, queries, time, stringify("delete ", points, " points")
+            );
+        }
+
+        // Maybe compact.
+        ++consolidate_count;
+        if (consolidate_count == consolidate_every) {
+            auto tic = svs::lib::now();
+            // Use a batchsize smaller than the whole dataset to ensure that the compaction
+            // algorithm correctly handles this case.
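+            // With a batchsize of reference.valid() / 10, compaction proceeds in roughly
+            // ten batches, so the cross-batch bookkeeping is exercised rather than a
+            // single whole-dataset copy.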
+ index.compact(reference.valid() / 10); + double diff = svs::lib::time_difference(tic); + do_check(index, reference, queries, diff, "compact"); + consolidate_count = 0; + } + } +} + +CATCH_TEST_CASE("Testing Dynamic IVF Index", "[dynamic_ivf]") { +#if defined(NDEBUG) + const float initial_fraction = 0.25; + const float modify_fraction = 0.05; +#else + const float initial_fraction = 0.05; + const float modify_fraction = 0.005; +#endif + const size_t num_threads = 10; + + // Load the base dataset and queries. + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto num_points = data.size(); + auto queries = test_dataset::queries(); + + auto reference = svs::misc::ReferenceDataset( + std::move(data), + Distance(), + num_threads, + div(num_points, 0.5 * modify_fraction), + NUM_NEIGHBORS, + queries, + 0x12345678 + ); + + auto num_indices_to_add = div(reference.size(), initial_fraction); + + // Generate initial vectors and indices + std::vector initial_indices{}; + auto initial_data = svs::data::SimpleData(num_indices_to_add, N); + { + auto [vectors, indices] = reference.generate(num_indices_to_add); + auto num_points_added = indices.size(); + CATCH_REQUIRE(vectors.size() == num_points_added); + CATCH_REQUIRE(num_points_added <= num_indices_to_add); + CATCH_REQUIRE(num_points_added > num_indices_to_add - reference.bucket_size()); + + initial_indices = indices; + if (vectors.size() != num_indices_to_add || indices.size() != num_indices_to_add) { + throw ANNEXCEPTION("Something went horribly wrong!"); + } + + for (size_t i = 0; i < num_indices_to_add; ++i) { + initial_data.set_datum(i, vectors.get_datum(i)); + } + } + + // Build IVF clustering + auto build_params = svs::index::ivf::IVFBuildParameters( + NUM_CLUSTERS, + /* max_iters */ 10, + /* is_hierarchical */ false + ); + + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, + svs::lib::Lazy([&initial_data]() { return initial_data; }), + Distance(), + threadpool, + /* train_only */ false + ); + + // Create dynamic clusters using DenseClusteredDataset + auto centroids = clustering.centroids(); + using DataType = svs::data::SimpleData; + auto dense_clusters = + svs::index::ivf::DenseClusteredDataset( + clustering, initial_data, threadpool, svs::lib::Allocator() + ); + + // Create the dynamic IVF index + auto threadpool_for_index = svs::threads::as_threadpool(num_threads); + using IndexType = svs::index::ivf::DynamicIVFIndex< + decltype(centroids), + decltype(dense_clusters), + Distance, + decltype(threadpool_for_index)>; + + auto index = IndexType( + std::move(centroids), + std::move(dense_clusters), + initial_indices, + Distance(), + std::move(threadpool_for_index), + 1 // intra_query_threads + ); + + reference.configure_extra_checks(true); + CATCH_REQUIRE(reference.extra_checks_enabled()); + + test_loop(index, reference, queries, div(reference.size(), modify_fraction), 2, 6); +} + +CATCH_TEST_CASE("Testing Dynamic IVF Index with BlockedData", "[dynamic_ivf]") { + // This test verifies that BlockedData allocator works correctly for dynamic operations + const size_t num_threads = 4; + + // Load data + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + // Build clustering + auto build_params = svs::index::ivf::IVFBuildParameters(10, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, data, 
Distance(), threadpool, false + ); + + // Use assemble_dynamic_from_clustering with external IDs + std::vector ids(data.size()); + std::iota(ids.begin(), ids.end(), 0); + + auto index = svs::index::ivf::assemble_dynamic_from_clustering( + std::move(clustering), + data, + ids, + Distance(), + svs::threads::as_threadpool(num_threads), + 1 + ); + + // Test 1: Initial search works + auto params = svs::index::ivf::IVFSearchParameters(10, NUM_NEIGHBORS); + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + index.search( + results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + params + ); + + // Verify we got results + size_t valid_results = 0; + for (size_t i = 0; i < results.n_queries(); ++i) { + if (results.index(i, 0) != std::numeric_limits::max()) { + valid_results++; + } + } + CATCH_REQUIRE(valid_results > 0); + + // Test 2: Add points (BlockedData's resize capability) + constexpr size_t num_add = 100; + std::vector new_ids; + auto new_data = svs::data::SimpleData(num_add, N); + for (size_t i = 0; i < num_add; ++i) { + new_ids.push_back(data.size() + i); + new_data.set_datum(i, data.get_datum(i % data.size())); + } + + size_t size_before = index.size(); + index.add_points(new_data, new_ids, false); + CATCH_REQUIRE(index.size() == size_before + num_add); + + // Test 3: Search still works after adding + index.search( + results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + params + ); + + valid_results = 0; + for (size_t i = 0; i < results.n_queries(); ++i) { + if (results.index(i, 0) != std::numeric_limits::max()) { + valid_results++; + } + } + CATCH_REQUIRE(valid_results > 0); + + // Test 4: Delete some points + std::vector to_delete; + for (size_t i = 0; i < 50; ++i) { + to_delete.push_back(i); + } + size_t deleted = index.delete_entries(to_delete); + CATCH_REQUIRE(deleted == to_delete.size()); + CATCH_REQUIRE(index.size() == size_before + num_add - deleted); + + // Test 5: Compact works with BlockedData + index.compact(1000); + CATCH_REQUIRE(index.size() == size_before + num_add - deleted); + + // Test 6: Search after compaction + index.search( + results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + params + ); + + valid_results = 0; + for (size_t i = 0; i < results.n_queries(); ++i) { + if (results.index(i, 0) != std::numeric_limits::max()) { + valid_results++; + } + } + CATCH_REQUIRE(valid_results > 0); +} + +CATCH_TEST_CASE("Dynamic IVF - Edge Cases", "[dynamic_ivf]") { + const size_t num_threads = 4; + const size_t num_points = 100; + + // Create a small dataset + auto data = svs::data::SimpleData(num_points, N); + std::mt19937 rng(42); + std::uniform_real_distribution dist(0.0f, 1.0f); + for (size_t i = 0; i < num_points; ++i) { + std::vector vec(N); + for (size_t j = 0; j < N; ++j) { + vec[j] = dist(rng); + } + data.set_datum(i, vec); + } + + // Build clustering with more clusters than points to test empty clusters + // With the fix, this should now work by using all 100 datapoints for training + auto build_params = svs::index::ivf::IVFBuildParameters( + 50, // More clusters than 10% of data (which would be 10 points) + 10, // max_iters + false // is_hierarchical + ); + + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, + svs::lib::Lazy([&data]() { return data; }), + Distance(), + threadpool, + false + ); + + // Create dynamic 
clusters using DenseClusteredDataset + std::vector initial_indices; + for (size_t c = 0; c < clustering.size(); ++c) { + for (auto idx : clustering.cluster(c)) { + initial_indices.push_back(idx); + } + } + + auto centroids = clustering.centroids(); + using DataType = svs::data::SimpleData; + auto dense_clusters = + svs::index::ivf::DenseClusteredDataset( + clustering, data, threadpool, svs::lib::Allocator() + ); + + auto threadpool_for_index = svs::threads::as_threadpool(num_threads); + using IndexType = svs::index::ivf::DynamicIVFIndex< + decltype(centroids), + decltype(dense_clusters), + Distance, + decltype(threadpool_for_index)>; + + auto index = IndexType( + std::move(centroids), + std::move(dense_clusters), + initial_indices, + Distance(), + std::move(threadpool_for_index), + 1 + ); + + // Test 1: Search with sparse/empty clusters (should not crash) + auto query = svs::data::SimpleData(1, N); + std::vector query_vec(N); + for (size_t j = 0; j < N; ++j) { + query_vec[j] = dist(rng); + } + query.set_datum(0, query_vec); + + auto results = svs::QueryResult(1, NUM_NEIGHBORS); + auto search_params = svs::index::ivf::IVFSearchParameters(50, NUM_NEIGHBORS); + + index.search( + results.view(), + svs::data::ConstSimpleDataView{query.data(), 1, N}, + search_params + ); + + // Verify results are valid (not all max values) + bool found_valid = false; + for (size_t i = 0; i < NUM_NEIGHBORS; ++i) { + if (results.index(0, i) != std::numeric_limits::max()) { + found_valid = true; + break; + } + } + CATCH_REQUIRE(found_valid); + + // Test 2: Delete and compact + std::vector to_delete; + for (size_t i = 0; i < 20 && i < initial_indices.size(); ++i) { + to_delete.push_back(initial_indices[i]); + } + + index.delete_entries(to_delete); + + index.compact(10); + + // Search after compaction + index.search( + results.view(), + svs::data::ConstSimpleDataView{query.data(), 1, N}, + search_params + ); + + CATCH_REQUIRE(results.index(0, 0) != std::numeric_limits::max()); +} + +CATCH_TEST_CASE("Dynamic IVF - Search Parameters Variations", "[dynamic_ivf]") { + const size_t num_threads = 4; + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + // Build with standard parameters + auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, + svs::lib::Lazy([&data]() { return data; }), + Distance(), + threadpool, + false + ); + + // Create dynamic clusters using DenseClusteredDataset + std::vector indices; + for (size_t c = 0; c < clustering.size(); ++c) { + for (auto idx : clustering.cluster(c)) { + indices.push_back(idx); + } + } + + auto centroids = clustering.centroids(); + using DataType = svs::data::SimpleData; + auto dense_clusters = + svs::index::ivf::DenseClusteredDataset( + clustering, data, threadpool, svs::lib::Allocator() + ); + + auto threadpool_for_index = svs::threads::as_threadpool(num_threads); + using IndexType = svs::index::ivf::DynamicIVFIndex< + decltype(centroids), + decltype(dense_clusters), + Distance, + decltype(threadpool_for_index)>; + + auto index = IndexType( + std::move(centroids), + std::move(dense_clusters), + indices, + Distance(), + std::move(threadpool_for_index), + 1 + ); + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + + // Test with different n_probes values + std::vector probe_counts = {1, 3, 5, NUM_CLUSTERS}; + std::vector recalls; + + for 
(auto n_probes : probe_counts) { + auto params = svs::index::ivf::IVFSearchParameters(n_probes, NUM_NEIGHBORS); + index.search( + results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + params + ); + + // Verify all results are valid + for (size_t i = 0; i < queries.size(); ++i) { + for (size_t j = 0; j < NUM_NEIGHBORS; ++j) { + auto idx = results.index(i, j); + CATCH_REQUIRE( + (idx < data.size() || idx == std::numeric_limits::max()) + ); + } + } + } +} + +CATCH_TEST_CASE("Dynamic IVF - Threading Configurations", "[dynamic_ivf]") { + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + auto clustering = svs::index::ivf::build_clustering( + build_params, + svs::lib::Lazy([&data]() { return data; }), + Distance(), + threadpool, + false + ); + + // Test with different thread configurations + std::vector thread_configs = {1, 2, 4, 8}; + std::vector intra_query_configs = {1, 2}; + + for (auto num_threads : thread_configs) { + for (auto intra_threads : intra_query_configs) { + std::vector indices; + for (size_t c = 0; c < clustering.size(); ++c) { + for (auto idx : clustering.cluster(c)) { + indices.push_back(idx); + } + } + + auto centroids_copy = clustering.centroids(); + using DataType = svs::data::SimpleData; + auto dense_clusters = svs::index::ivf:: + DenseClusteredDataset( + clustering, data, threadpool, svs::lib::Allocator() + ); + + auto threadpool_for_index = svs::threads::as_threadpool(num_threads); + using IndexType = svs::index::ivf::DynamicIVFIndex< + decltype(centroids_copy), + decltype(dense_clusters), + Distance, + decltype(threadpool_for_index)>; + + auto index = IndexType( + std::move(centroids_copy), + std::move(dense_clusters), + indices, + Distance(), + std::move(threadpool_for_index), + intra_threads + ); + + auto results = svs::QueryResult(queries.size(), NUM_NEIGHBORS); + auto params = svs::index::ivf::IVFSearchParameters(NUM_CLUSTERS, NUM_NEIGHBORS); + + index.search( + results.view(), + svs::data::ConstSimpleDataView{ + queries.data(), queries.size(), queries.dimensions()}, + params + ); + + // Verify results are consistent + for (size_t i = 0; i < queries.size(); ++i) { + for (size_t j = 0; j < NUM_NEIGHBORS; ++j) { + auto idx = results.index(i, j); + CATCH_REQUIRE( + (idx < data.size() || idx == std::numeric_limits::max()) + ); + } + } + } + } +} + +CATCH_TEST_CASE("Dynamic IVF - Add/Delete Stress Test", "[dynamic_ivf]") { + const size_t num_threads = 4; + auto data = svs::data::SimpleData::load(test_dataset::data_svs_file()); + auto queries = test_dataset::queries(); + + auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false); + auto threadpool = svs::threads::SequentialThreadPool(); + + // Start with half the data + size_t initial_size = data.size() / 2; + auto initial_data = svs::data::SimpleData(initial_size, N); + for (size_t i = 0; i < initial_size; ++i) { + initial_data.set_datum(i, data.get_datum(i)); + } + + auto clustering = svs::index::ivf::build_clustering( + build_params, + svs::lib::Lazy([&initial_data]() { return initial_data; }), + Distance(), + threadpool, + false + ); + + // Create dynamic clusters using DenseClusteredDataset + std::vector indices; + for (size_t c = 0; c < clustering.size(); ++c) { + for (auto idx : clustering.cluster(c)) { + indices.push_back(idx); + } 
+    }
+
+    auto centroids = clustering.centroids();
+    using DataType = svs::data::SimpleData<float>;
+    auto dense_clusters =
+        svs::index::ivf::DenseClusteredDataset<DataType>(
+            clustering, initial_data, threadpool, svs::lib::Allocator<float>()
+        );
+
+    auto threadpool_for_index = svs::threads::as_threadpool(num_threads);
+    using IndexType = svs::index::ivf::DynamicIVFIndex<
+        decltype(centroids),
+        decltype(dense_clusters),
+        Distance,
+        decltype(threadpool_for_index)>;
+
+    auto index = IndexType(
+        std::move(centroids),
+        std::move(dense_clusters),
+        indices,
+        Distance(),
+        std::move(threadpool_for_index),
+        1
+    );
+
+    auto results = svs::QueryResult<size_t>(queries.size(), NUM_NEIGHBORS);
+    auto params = svs::index::ivf::IVFSearchParameters(NUM_CLUSTERS, NUM_NEIGHBORS);
+
+    // Test: Rapid add/delete cycles
+    std::mt19937 rng(12345);
+    std::uniform_int_distribution<size_t> idx_dist(0, indices.size() - 1);
+    std::vector<Idx> all_deleted; // IDs removed in earlier cycles; never delete twice.
+    auto contains = [](const std::vector<Idx>& v, Idx id) {
+        for (auto existing : v) {
+            if (existing == id) {
+                return true;
+            }
+        }
+        return false;
+    };
+
+    for (size_t cycle = 0; cycle < 5; ++cycle) {
+        // Delete up to 10 distinct random entries that are still present.
+        // Draws are deduplicated so the same ID is never deleted twice.
+        std::vector<Idx> deleted;
+        while (deleted.size() < 10 &&
+               all_deleted.size() + deleted.size() < indices.size()) {
+            Idx candidate = indices[idx_dist(rng)];
+            if (!contains(all_deleted, candidate) && !contains(deleted, candidate)) {
+                deleted.push_back(candidate);
+            }
+        }
+        if (!deleted.empty()) {
+            index.delete_entries(deleted);
+            all_deleted.insert(all_deleted.end(), deleted.begin(), deleted.end());
+        }
+
+        // Search after deletion
+        index.search(
+            results.view(),
+            svs::data::ConstSimpleDataView<float>{
+                queries.data(), queries.size(), queries.dimensions()},
+            params
+        );
+
+        // Verify deleted IDs don't appear in results
+        for (size_t q = 0; q < queries.size(); ++q) {
+            for (size_t k = 0; k < NUM_NEIGHBORS; ++k) {
+                auto result_id = results.index(q, k);
+                for (auto deleted_id : deleted) {
+                    CATCH_REQUIRE(result_id != deleted_id);
+                }
+            }
+        }
+
+        // Add new entries
+        std::vector<Idx> new_ids;
+        auto new_data = svs::data::SimpleData<float>(10, N);
+        Idx new_base_id = 10000 + cycle * 100;
+        for (size_t i = 0; i < 10; ++i) {
+            new_ids.push_back(new_base_id + i);
+            new_data.set_datum(i, data.get_datum(i % data.size()));
+        }
+        index.add_points(new_data, new_ids, false);
+
+        // Search after addition
+        index.search(
+            results.view(),
+            svs::data::ConstSimpleDataView<float>{
+                queries.data(), queries.size(), queries.dimensions()},
+            params
+        );
+
+        // All results should be valid
+        for (size_t q = 0; q < queries.size(); ++q) {
+            CATCH_REQUIRE(results.index(q, 0) != std::numeric_limits<size_t>::max());
+        }
+
+        // Compact periodically
+        if (cycle % 2 == 1) {
+            index.compact(50);
+        }
+    }
+}
+
+CATCH_TEST_CASE("Dynamic IVF - Single Query Search", "[dynamic_ivf]") {
+    const size_t num_threads = 2;
+    auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+    auto queries = test_dataset::queries();
+
+    auto build_params = svs::index::ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false);
+    auto threadpool = svs::threads::SequentialThreadPool();
+    auto clustering = svs::index::ivf::build_clustering(
+        build_params,
+        svs::lib::Lazy([&data]() { return data; }),
+        Distance(),
+        threadpool,
+        false
+    );
+
+    // Create dynamic clusters using DenseClusteredDataset
+    std::vector<Idx> indices;
+    for (size_t c = 0; c < clustering.size(); ++c) {
+        for (auto idx : clustering.cluster(c)) {
+            indices.push_back(idx);
+        }
+    }
+
+    auto centroids = clustering.centroids();
+    using DataType = svs::data::SimpleData<float>;
+    auto dense_clusters =
+        svs::index::ivf::DenseClusteredDataset<DataType>(
+            clustering, data, threadpool, svs::lib::Allocator<float>()
+        );
+
+    auto threadpool_for_index = svs::threads::as_threadpool(num_threads);
+    using IndexType = svs::index::ivf::DynamicIVFIndex<
+        decltype(centroids),
+        decltype(dense_clusters),
+        Distance,
+        decltype(threadpool_for_index)>;
+
+    auto index = IndexType(
+        std::move(centroids),
+        std::move(dense_clusters),
+        indices,
+        Distance(),
+        std::move(threadpool_for_index),
+        1
+    );
+
+    // Test single query search
+    auto single_query = svs::data::SimpleData<float>(1, N);
+    single_query.set_datum(0, queries.get_datum(0));
+
+    auto results = svs::QueryResult<size_t>(1, NUM_NEIGHBORS);
+    auto params = svs::index::ivf::IVFSearchParameters(NUM_CLUSTERS, NUM_NEIGHBORS);
+
+    index.search(
+        results.view(),
+        svs::data::ConstSimpleDataView<float>{single_query.data(), 1, N},
+        params
+    );
+
+    // Verify we got valid results
+    CATCH_REQUIRE(results.index(0, 0) != std::numeric_limits<size_t>::max());
+
+    // Verify distances are in ascending order
+    for (size_t k = 1; k < NUM_NEIGHBORS; ++k) {
+        if (results.index(0, k) != std::numeric_limits<size_t>::max()) {
+            CATCH_REQUIRE(results.distance(0, k) >= results.distance(0, k - 1));
+        }
+    }
+}
+
+CATCH_TEST_CASE("Dynamic IVF Get Distance", "[index][ivf][dynamic_ivf]") {
+    const size_t num_threads = 2;
+    const size_t num_points = 200;
+
+    // Create test dataset
+    auto data = svs::data::SimpleData<float>(num_points, N);
+    std::mt19937 rng(42);
+    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
+    for (size_t i = 0; i < num_points; ++i) {
+        std::vector<float> vec(N);
+        for (size_t j = 0; j < N; ++j) {
+            vec[j] = dist(rng);
+        }
+        data.set_datum(i, vec);
+    }
+
+    // Create queries
+    const size_t num_queries = 20;
+    auto queries = svs::data::SimpleData<float>(num_queries, N);
+    for (size_t i = 0; i < num_queries; ++i) {
+        std::vector<float> vec(N);
+        for (size_t j = 0; j < N; ++j) {
+            vec[j] = dist(rng);
+        }
+        queries.set_datum(i, vec);
+    }
+
+    // Build IVF clustering
+    auto build_params = svs::index::ivf::IVFBuildParameters(
+        NUM_CLUSTERS,
+        /* max_iters */ 10,
+        /* is_hierarchical */ false
+    );
+
+    auto threadpool = svs::threads::SequentialThreadPool();
+    auto clustering = svs::index::ivf::build_clustering(
+        build_params,
+        svs::lib::Lazy([&data]() { return data; }),
+        Distance(),
+        threadpool,
+        /* train_only */ false
+    );
+
+    // Create dynamic clusters using DenseClusteredDataset
+    std::vector<Idx> initial_indices; // External IDs in order
+
+    for (size_t c = 0; c < clustering.size(); ++c) {
+        const auto& cluster_indices = clustering.cluster(c);
+        for (size_t i = 0; i < cluster_indices.size(); ++i) {
+            Idx external_id = cluster_indices[i]; // Use clustering index as external ID
+            initial_indices.push_back(external_id);
+        }
+    }
+
+    auto centroids = clustering.centroids();
+    using DataType = svs::data::SimpleData<float>;
+    auto dense_clusters =
+        svs::index::ivf::DenseClusteredDataset<DataType>(
+            clustering, data, threadpool, svs::lib::Allocator<float>()
+        );
+
+    // Need to update cluster IDs to use sequential internal IDs
+    for (size_t c = 0, global_idx = 0; c < dense_clusters.size(); ++c) {
+        auto& cluster = dense_clusters[c];
+        for (size_t i = 0; i < cluster.ids_.size(); ++i) {
+            cluster.ids_[i] = global_idx++;
+        }
+    }
+
+    // Create the dynamic IVF index
+    auto threadpool_for_index = svs::threads::as_threadpool(num_threads);
+    using IndexType = svs::index::ivf::DynamicIVFIndex<
+        decltype(centroids),
+        decltype(dense_clusters),
+        Distance,
+        decltype(threadpool_for_index)>;
+
+    auto index = IndexType(
+        std::move(centroids),
+        std::move(dense_clusters),
+        initial_indices,
+        Distance(),
+        std::move(threadpool_for_index),
+        1 // intra_query_threads
+    );
+
+    // Test get_distance functionality
+    CATCH_SECTION("Get Distance Test") {
+        // Test with strict tolerance to verify correctness
+        constexpr double TOLERANCE = 1e-2; // 1% tolerance, same as flat index
+
+        // Test with a few different IDs
+        std::vector<size_t> test_ids = {0, 10, 50};
+        if (index.size() > 100) {
+            test_ids.push_back(100);
+        }
+
+        for (size_t test_id : test_ids) {
+            if (test_id >= index.size()) {
+                continue;
+            }
+
+            // Get a query vector
+            size_t query_id = std::min<size_t>(5, queries.size() - 1);
+            auto query = queries.get_datum(query_id);
+
+            // Get distance from index
+            double index_distance = index.get_distance(test_id, query);
+
+            // Compute expected distance from original data.
+            // test_id is the external ID, which maps to data[test_id].
+            auto datum = data.get_datum(test_id);
+            Distance dist_copy = Distance();
+            svs::distance::maybe_fix_argument(dist_copy, query);
+            double expected_distance = svs::distance::compute(dist_copy, query, datum);
+
+            // Verify the distance is correct
+            double relative_diff =
+                std::abs((index_distance - expected_distance) / expected_distance);
+            CATCH_REQUIRE(relative_diff < TOLERANCE);
+        }
+
+        // Test with an out-of-bounds ID - should throw
+        CATCH_REQUIRE_THROWS_AS(
+            index.get_distance(index.size() + 1000, queries.get_datum(0)),
+            svs::ANNException
+        );
+    }
+
+    // Test get_distance after adding and removing points
+    CATCH_SECTION("Get Distance After Modifications") {
+        // Test with strict tolerance to verify correctness
+        constexpr double TOLERANCE = 1e-2; // 1% tolerance, same as flat index
+
+        // Add some new points
+        std::vector<Idx> new_ids = {10000, 10001, 10002};
+
+        // Prepare data for batch insertion
+        auto new_data = svs::data::SimpleData<float>(new_ids.size(), N);
+        for (size_t i = 0; i < new_ids.size(); ++i) {
+            new_data.set_datum(i, data.get_datum(i));
+        }
+
+        // Add points in batch
+        index.add_points(new_data, new_ids);
+
+        // Test get_distance for newly added points
+        for (size_t i = 0; i < new_ids.size(); ++i) {
+            size_t query_id = std::min<size_t>(7, queries.size() - 1);
+            auto query = queries.get_datum(query_id);
+
+            double index_distance = index.get_distance(new_ids[i], query);
+
+            // Compute expected distance from the original data we added
+            auto datum = data.get_datum(i);
+            Distance dist_copy = Distance();
+            svs::distance::maybe_fix_argument(dist_copy, query);
+            double expected_distance = svs::distance::compute(dist_copy, query, datum);
+
+            double relative_diff =
+                std::abs((index_distance - expected_distance) / expected_distance);
+            CATCH_REQUIRE(relative_diff < TOLERANCE);
+        }
+
+        // Delete a point
+        std::vector<Idx> ids_to_delete = {new_ids[0]};
+        index.delete_entries(ids_to_delete);
+
+        // Verify the deleted point throws an exception
+        CATCH_REQUIRE_THROWS_AS(
+            index.get_distance(new_ids[0], queries.get_datum(0)), svs::ANNException
+        );
+
+        // Verify other points still work
+        for (size_t i = 1; i < new_ids.size(); ++i) {
+            size_t query_id = std::min<size_t>(8, queries.size() - 1);
+            auto query = queries.get_datum(query_id);
+
+            // Should not throw
+            double distance = index.get_distance(new_ids[i], query);
+            CATCH_REQUIRE(distance >= 0.0);
+        }
+    }
+}
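The three tests above repeat the same construction recipe. For reference, here is that recipe condensed into one helper. This is a sketch based only on the calls in the tests: it assumes the test file's `N`, `Idx`, `Distance`, and `NUM_CLUSTERS` aliases, and the exact template parameters of `DenseClusteredDataset`, `svs::lib::Allocator`, and `DynamicIVFIndex` may differ from the real headers.

template <typename Data, typename Pool>
auto make_dynamic_ivf_index(const Data& data, Pool& threadpool, size_t num_threads) {
    namespace ivf = svs::index::ivf;

    // 1. Cluster the data (flat k-means, 10 iterations).
    auto build_params = ivf::IVFBuildParameters(NUM_CLUSTERS, 10, false);
    auto clustering = ivf::build_clustering(
        build_params, svs::lib::Lazy([&data]() { return data; }), Distance(),
        threadpool, false
    );

    // 2. Flatten the per-cluster contents into one external-ID vector.
    std::vector<Idx> ids;
    for (size_t c = 0; c < clustering.size(); ++c) {
        for (auto idx : clustering.cluster(c)) {
            ids.push_back(idx);
        }
    }

    // 3. Densify the clusters and hand everything to the dynamic index.
    auto centroids = clustering.centroids();
    auto dense = ivf::DenseClusteredDataset<Data>(
        clustering, data, threadpool, svs::lib::Allocator<float>()
    );
    auto pool = svs::threads::as_threadpool(num_threads);
    return ivf::DynamicIVFIndex<
        decltype(centroids), decltype(dense), Distance, decltype(pool)>(
        std::move(centroids), std::move(dense), ids, Distance(), std::move(pool),
        /* intra_query_threads */ 1
    );
}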
diff --git a/tests/svs/index/ivf/hierarchical_kmeans.cpp b/tests/svs/index/ivf/hierarchical_kmeans.cpp
index db15940c3..e448e53c0 100644
--- a/tests/svs/index/ivf/hierarchical_kmeans.cpp
+++ b/tests/svs/index/ivf/hierarchical_kmeans.cpp
@@ -17,6 +17,10 @@
 // header under test
 #include "svs/index/ivf/hierarchical_kmeans.h"
 
+// additional headers for train_only test
+#include "svs/index/ivf/index.h"
+#include "svs/index/ivf/kmeans.h"
+
 // tests
 #include "tests/utils/test_dataset.h"
 #include "tests/utils/utils.h"
@@ -25,6 +29,7 @@
 #include "catch2/catch_test_macros.hpp"
 
 // stl
+#include <unordered_set>
 #include <vector>
 
 namespace {
@@ -62,14 +67,477 @@ void test_hierarchical_kmeans_clustering(const Data& data, Distance distance) {
     }
 }
 
+template <typename ElementType, typename Data, typename Distance>
+void test_train_only_centroids_match(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    // Test both flat and hierarchical k-means with different modes
+    for (bool is_hierarchical : {false, true}) {
+        for (size_t n_centroids : {25}) {
+            for (size_t minibatch : {25}) {
+                for (size_t iters : {3}) {
+                    for (float training_fraction : {0.6f}) {
+                        auto params = ivf::IVFBuildParameters()
+                                          .num_centroids(n_centroids)
+                                          .minibatch_size(minibatch)
+                                          .num_iterations(iters)
+                                          .is_hierarchical(is_hierarchical)
+                                          .training_fraction(training_fraction)
+                                          .seed(12345); // Fixed seed for reproducibility
+
+                        if (is_hierarchical) {
+                            params.hierarchical_level1_clusters(5);
+                        }
+
+                        size_t num_threads = 4;
+
+                        // Run with train_only = false (normal mode)
+                        auto [centroids_normal, clusters_normal] =
+                            ivf::build_clustering<ElementType>(
+                                params, data, distance, num_threads, false
+                            );
+
+                        // Run with train_only = true
+                        auto [centroids_train_only, clusters_train_only] =
+                            ivf::build_clustering<ElementType>(
+                                params, data, distance, num_threads, true
+                            );
+
+                        // Verify centroids are identical
+                        CATCH_REQUIRE(
+                            centroids_normal.size() == centroids_train_only.size()
+                        );
+                        CATCH_REQUIRE(
+                            centroids_normal.dimensions() ==
+                            centroids_train_only.dimensions()
+                        );
+
+                        constexpr float tolerance = 1e-6f;
+                        for (size_t i = 0; i < centroids_normal.size(); ++i) {
+                            auto datum_normal = centroids_normal.get_datum(i);
+                            auto datum_train_only = centroids_train_only.get_datum(i);
+
+                            for (size_t j = 0; j < centroids_normal.dimensions(); ++j) {
+                                float diff =
+                                    std::abs(datum_normal[j] - datum_train_only[j]);
+                                CATCH_REQUIRE(diff < tolerance);
+                            }
+                        }
+
+                        // Verify train_only clusters are empty (as expected)
+                        for (const auto& cluster : clusters_train_only) {
+                            CATCH_REQUIRE(cluster.empty());
+                        }
+
+                        // Verify normal mode has at least some non-empty clusters
+                        bool has_non_empty_cluster = false;
+                        for (const auto& cluster : clusters_normal) {
+                            if (!cluster.empty()) {
+                                has_non_empty_cluster = true;
+                                break;
+                            }
+                        }
+                        CATCH_REQUIRE(has_non_empty_cluster);
+                    }
+                }
+            }
+        }
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_hierarchical_kmeans_level1_clusters(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    // Test different Level 1 cluster configurations
+    for (size_t n_centroids : {64, 100}) {
+        for (size_t l1_clusters : {0, 4, 8, 16}) { // 0 means auto-calculate
+            auto params = ivf::IVFBuildParameters()
+                              .num_centroids(n_centroids)
+                              .minibatch_size(25)
+                              .num_iterations(3)
+                              .is_hierarchical(true)
+                              .training_fraction(0.6f)
+                              .hierarchical_level1_clusters(l1_clusters);
+
+            auto threadpool = svs::threads::as_threadpool(4);
+            auto [centroids, clusters] = hierarchical_kmeans_clustering<ElementType>(
+                params, data, distance, threadpool
+            );
+
+            CATCH_REQUIRE(centroids.size() == n_centroids);
+            CATCH_REQUIRE(centroids.dimensions() == data.dimensions());
+            CATCH_REQUIRE(clusters.size() == n_centroids);
+
+            // Verify all data points are assigned
+            std::unordered_set<size_t> assigned_points;
+            for (const auto& cluster : clusters) {
+                for (auto point_id : cluster) {
+                    CATCH_REQUIRE(point_id < data.size());
+                    assigned_points.insert(point_id);
+                }
+            }
+            CATCH_REQUIRE(assigned_points.size() == data.size());
+        }
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_hierarchical_kmeans_reproducibility(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    const size_t seed = 98765;
+    const size_t n_centroids = 50;
+    const size_t l1_clusters = 7;
+
+    auto params1 = ivf::IVFBuildParameters()
+                       .num_centroids(n_centroids)
+                       .minibatch_size(25)
+                       .num_iterations(4)
+                       .is_hierarchical(true)
+                       .training_fraction(0.7f)
+                       .hierarchical_level1_clusters(l1_clusters)
+                       .seed(seed);
+
+    auto params2 = ivf::IVFBuildParameters()
+                       .num_centroids(n_centroids)
+                       .minibatch_size(25)
+                       .num_iterations(4)
+                       .is_hierarchical(true)
+                       .training_fraction(0.7f)
+                       .hierarchical_level1_clusters(l1_clusters)
+                       .seed(seed);
+
+    auto threadpool = svs::threads::as_threadpool(4);
+
+    auto [centroids1, clusters1] =
+        hierarchical_kmeans_clustering<ElementType>(params1, data, distance, threadpool);
+
+    auto [centroids2, clusters2] =
+        hierarchical_kmeans_clustering<ElementType>(params2, data, distance, threadpool);
+
+    // Verify centroids are identical
+    CATCH_REQUIRE(centroids1.size() == centroids2.size());
+    constexpr float tolerance = 1e-6f;
+
+    for (size_t i = 0; i < centroids1.size(); ++i) {
+        auto centroid1 = centroids1.get_datum(i);
+        auto centroid2 = centroids2.get_datum(i);
+
+        for (size_t j = 0; j < centroids1.dimensions(); ++j) {
+            float diff = std::abs(centroid1[j] - centroid2[j]);
+            CATCH_REQUIRE(diff < tolerance);
+        }
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_hierarchical_vs_flat_kmeans(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    const size_t n_centroids = 36;
+
+    // Flat k-means
+    auto flat_params = ivf::IVFBuildParameters()
+                           .num_centroids(n_centroids)
+                           .minibatch_size(25)
+                           .num_iterations(3)
+                           .is_hierarchical(false)
+                           .training_fraction(0.6f)
+                           .seed(555);
+
+    // Hierarchical k-means
+    auto hierarchical_params = ivf::IVFBuildParameters()
+                                   .num_centroids(n_centroids)
+                                   .minibatch_size(25)
+                                   .num_iterations(3)
+                                   .is_hierarchical(true)
+                                   .training_fraction(0.6f)
+                                   .hierarchical_level1_clusters(6)
+                                   .seed(555);
+
+    auto threadpool = svs::threads::as_threadpool(4);
+
+    auto [flat_centroids, flat_clusters] =
+        ivf::kmeans_clustering<ElementType>(flat_params, data, distance, threadpool);
+
+    auto [hierarchical_centroids, hierarchical_clusters] =
+        hierarchical_kmeans_clustering<ElementType>(
+            hierarchical_params, data, distance, threadpool
+        );
+
+    // Both should produce the same number of centroids and clusters
+    CATCH_REQUIRE(flat_centroids.size() == n_centroids);
+    CATCH_REQUIRE(hierarchical_centroids.size() == n_centroids);
+    CATCH_REQUIRE(flat_clusters.size() == n_centroids);
+    CATCH_REQUIRE(hierarchical_clusters.size() == n_centroids);
+
+    // Both should assign all points
+    std::unordered_set<size_t> flat_points, hierarchical_points;
+
+    for (const auto& cluster : flat_clusters) {
+        for (auto point_id : cluster) {
+            flat_points.insert(point_id);
+        }
+    }
+
+    for (const auto& cluster : hierarchical_clusters) {
+        for (auto point_id : cluster) {
+            hierarchical_points.insert(point_id);
+        }
+    }
+
+    CATCH_REQUIRE(flat_points.size() == data.size());
+    CATCH_REQUIRE(hierarchical_points.size() == data.size());
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_hierarchical_kmeans_edge_cases(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+    auto threadpool = svs::threads::as_threadpool(4);
+
+    // Test with Level 1 clusters equal to total centroids (degenerate case)
+    {
+        const size_t n_centroids = 16;
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(n_centroids)
+                          .minibatch_size(20)
+                          .num_iterations(2)
+                          .is_hierarchical(true)
+                          .training_fraction(0.5f)
+                          .hierarchical_level1_clusters(n_centroids);
+
+        auto [centroids, clusters] = hierarchical_kmeans_clustering<ElementType>(
+            params, data, distance, threadpool
+        );
+
+        CATCH_REQUIRE(centroids.size() == n_centroids);
+        CATCH_REQUIRE(clusters.size() == n_centroids);
+    }
+
+    // Test with very few Level 1 clusters
+    {
+        const size_t n_centroids = 60;
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(n_centroids)
+                          .minibatch_size(25)
+                          .num_iterations(3)
+                          .is_hierarchical(true)
+                          .training_fraction(0.6f)
+                          .hierarchical_level1_clusters(2);
+
+        auto [centroids, clusters] = hierarchical_kmeans_clustering<ElementType>(
+            params, data, distance, threadpool
+        );
+
+        CATCH_REQUIRE(centroids.size() == n_centroids);
+        CATCH_REQUIRE(clusters.size() == n_centroids);
+    }
+
+    // Test with different training fractions
+    for (float training_fraction : {0.3f, 0.5f, 0.8f, 1.0f}) {
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(24)
+                          .minibatch_size(20)
+                          .num_iterations(2)
+                          .is_hierarchical(true)
+                          .training_fraction(training_fraction)
+                          .hierarchical_level1_clusters(4);
+
+        auto [centroids, clusters] = hierarchical_kmeans_clustering<ElementType>(
+            params, data, distance, threadpool
+        );
+
+        CATCH_REQUIRE(centroids.size() == 24);
+        CATCH_REQUIRE(clusters.size() == 24);
+
+        // Verify centroids are valid
+        for (size_t i = 0; i < centroids.size(); ++i) {
+            auto centroid = centroids.get_datum(i);
+            for (size_t j = 0; j < centroids.dimensions(); ++j) {
+                CATCH_REQUIRE(std::isfinite(centroid[j]));
+            }
+        }
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_hierarchical_kmeans_cluster_distribution(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    // Test that Level 2 clusters are reasonably distributed across Level 1 clusters
+    const size_t n_centroids = 48;
+    const size_t l1_clusters = 6;
+
+    auto params = ivf::IVFBuildParameters()
+                      .num_centroids(n_centroids)
+                      .minibatch_size(25)
+                      .num_iterations(4)
+                      .is_hierarchical(true)
+                      .training_fraction(0.7f)
+                      .hierarchical_level1_clusters(l1_clusters)
+                      .seed(777);
+
+    auto threadpool = svs::threads::as_threadpool(4);
+    auto [centroids, clusters] =
+        hierarchical_kmeans_clustering<ElementType>(params, data, distance, threadpool);
+
+    CATCH_REQUIRE(centroids.size() == n_centroids);
+    CATCH_REQUIRE(clusters.size() == n_centroids);
+
+    // Verify a reasonable distribution of cluster sizes
+    size_t empty_clusters = 0;
+    size_t total_assigned = 0;
+
+    for (const auto& cluster : clusters) {
+        if (cluster.empty()) {
+            empty_clusters++;
+        }
+        total_assigned += cluster.size();
+    }
+
+    CATCH_REQUIRE(total_assigned == data.size());
+    // Allow some empty clusters, but fewer than half
+    CATCH_REQUIRE(empty_clusters < n_centroids / 2);
+}
+
 } // namespace
 
 CATCH_TEST_CASE("Hierarchical Kmeans Param Check", "[ivf][hierarchial_parameter_check]") {
-    CATCH_SECTION("Uncompressed Data") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
         auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
         test_hierarchical_kmeans_clustering<float>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_clustering<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
         test_hierarchical_kmeans_clustering<svs::Float16>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_clustering<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
         test_hierarchical_kmeans_clustering<svs::BFloat16>(data, svs::DistanceIP());
         test_hierarchical_kmeans_clustering<svs::BFloat16>(data, svs::DistanceL2());
     }
 }
+
+CATCH_TEST_CASE(
+    "Hierarchical Kmeans Level1 Clusters", "[ivf][hierarchical_kmeans][level1]"
+) {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_hierarchical_kmeans_level1_clusters<float>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_level1_clusters<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_hierarchical_kmeans_level1_clusters<svs::Float16>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_level1_clusters<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_hierarchical_kmeans_level1_clusters<svs::BFloat16>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_level1_clusters<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
+
+CATCH_TEST_CASE(
+    "Hierarchical Kmeans Reproducibility", "[ivf][hierarchical_kmeans][reproducibility]"
+) {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_hierarchical_kmeans_reproducibility<float>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_reproducibility<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_hierarchical_kmeans_reproducibility<svs::Float16>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_reproducibility<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_hierarchical_kmeans_reproducibility<svs::BFloat16>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_reproducibility<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
+
+CATCH_TEST_CASE("Hierarchical vs Flat Kmeans", "[ivf][hierarchical_kmeans][comparison]") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_hierarchical_vs_flat_kmeans<float>(data, svs::DistanceIP());
+        test_hierarchical_vs_flat_kmeans<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_hierarchical_vs_flat_kmeans<svs::Float16>(data, svs::DistanceIP());
+        test_hierarchical_vs_flat_kmeans<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_hierarchical_vs_flat_kmeans<svs::BFloat16>(data, svs::DistanceIP());
+        test_hierarchical_vs_flat_kmeans<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
+
+CATCH_TEST_CASE(
+    "Hierarchical Kmeans Edge Cases", "[ivf][hierarchical_kmeans][edge_cases]"
+) {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_hierarchical_kmeans_edge_cases<float>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_edge_cases<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_hierarchical_kmeans_edge_cases<svs::Float16>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_edge_cases<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_hierarchical_kmeans_edge_cases<svs::BFloat16>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_edge_cases<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
+
+CATCH_TEST_CASE(
+    "Hierarchical Kmeans Cluster Distribution", "[ivf][hierarchical_kmeans][distribution]"
+) {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_hierarchical_kmeans_cluster_distribution<float>(data, svs::DistanceIP());
+        test_hierarchical_kmeans_cluster_distribution<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_hierarchical_kmeans_cluster_distribution<svs::Float16>(
+            data, svs::DistanceIP()
+        );
+        test_hierarchical_kmeans_cluster_distribution<svs::Float16>(
+            data, svs::DistanceL2()
+        );
+
+        // Test BFloat16 (bf16)
+        test_hierarchical_kmeans_cluster_distribution<svs::BFloat16>(
+            data, svs::DistanceIP()
+        );
+        test_hierarchical_kmeans_cluster_distribution<svs::BFloat16>(
+            data, svs::DistanceL2()
+        );
+    }
+}
+
+CATCH_TEST_CASE("Train Only Centroids Match", "[ivf][kmeans][train_only]") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_train_only_centroids_match<float>(data, svs::DistanceIP());
+        test_train_only_centroids_match<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_train_only_centroids_match<svs::Float16>(data, svs::DistanceIP());
+        test_train_only_centroids_match<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_train_only_centroids_match<svs::BFloat16>(data, svs::DistanceIP());
+        test_train_only_centroids_match<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
diff --git a/tests/svs/index/ivf/kmeans.cpp b/tests/svs/index/ivf/kmeans.cpp
index 49e20a10f..245b6f501 100644
--- a/tests/svs/index/ivf/kmeans.cpp
+++ b/tests/svs/index/ivf/kmeans.cpp
@@ -57,14 +57,400 @@ void test_kmeans_clustering(const Data& data, Distance distance) {
     }
 }
 
+template <typename ElementType, typename Data, typename Distance>
+void test_kmeans_train_only_functionality(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    // Test train_only functionality
+    for (size_t n_centroids : {25, 50}) {
+        for (size_t minibatch : {25}) {
+            for (size_t iters : {3}) {
+                for (float training_fraction : {0.6f}) {
+                    auto params = ivf::IVFBuildParameters()
+                                      .num_centroids(n_centroids)
+                                      .minibatch_size(minibatch)
+                                      .num_iterations(iters)
+                                      .is_hierarchical(false)
+                                      .training_fraction(training_fraction)
+                                      .seed(42); // Fixed seed for reproducibility
+
+                    auto threadpool = svs::threads::as_threadpool(4);
+
+                    // Test train_only = false (normal mode)
+                    auto [centroids_normal, clusters_normal] =
+                        ivf::kmeans_clustering<ElementType>(
+                            params, data, distance, threadpool, false
+                        );
+
+                    // Test train_only = true
+                    auto [centroids_train_only, clusters_train_only] =
+                        ivf::kmeans_clustering<ElementType>(
+                            params, data, distance, threadpool, true
+                        );
+
+                    // Verify basic structure
+                    CATCH_REQUIRE(centroids_normal.size() == n_centroids);
+                    CATCH_REQUIRE(centroids_train_only.size() == n_centroids);
+                    CATCH_REQUIRE(centroids_normal.dimensions() == data.dimensions());
+                    CATCH_REQUIRE(centroids_train_only.dimensions() == data.dimensions());
+
+                    CATCH_REQUIRE(clusters_normal.size() == n_centroids);
+                    CATCH_REQUIRE(clusters_train_only.size() == n_centroids);
+
+                    // Verify train_only produces empty clusters
+                    for (const auto& cluster : clusters_train_only) {
+                        CATCH_REQUIRE(cluster.empty());
+                    }
+
+                    // Verify normal mode has at least some non-empty clusters
+                    bool has_non_empty = false;
+                    for (const auto& cluster : clusters_normal) {
+                        if (!cluster.empty()) {
+                            has_non_empty = true;
+                            break;
+                        }
+                    }
+                    CATCH_REQUIRE(has_non_empty);
+
+                    // Verify centroids are identical (using the same seed)
+                    constexpr float tolerance = 1e-6f;
+                    for (size_t i = 0; i < n_centroids; ++i) {
+                        auto normal_centroid = centroids_normal.get_datum(i);
+                        auto train_only_centroid = centroids_train_only.get_datum(i);
+
+                        for (size_t j = 0; j < data.dimensions(); ++j) {
+                            float diff =
+                                std::abs(normal_centroid[j] - train_only_centroid[j]);
+                            CATCH_REQUIRE(diff < tolerance);
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_kmeans_train_only_performance(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    // Exercise both modes back to back and record wall-clock timings. The
+    // timings are informational only: train_only skips the assignment pass and
+    // should generally be faster, but wall-clock comparisons are too noisy to
+    // assert on in CI; that is better verified with dedicated benchmarks.
+    size_t n_centroids = 50;
+    auto params = ivf::IVFBuildParameters()
+                      .num_centroids(n_centroids)
+                      .minibatch_size(25)
+                      .num_iterations(3)
+                      .is_hierarchical(false)
+                      .training_fraction(0.5f)
+                      .seed(123);
+
+    auto threadpool = svs::threads::as_threadpool(4);
+
+    // Time normal mode
+    auto start_normal = std::chrono::high_resolution_clock::now();
+    auto [centroids_normal, clusters_normal] =
+        ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool, false);
+    auto end_normal = std::chrono::high_resolution_clock::now();
+
+    // Time train_only mode
+    auto start_train_only = std::chrono::high_resolution_clock::now();
+    auto [centroids_train_only, clusters_train_only] =
+        ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool, true);
+    auto end_train_only = std::chrono::high_resolution_clock::now();
+
+    auto normal_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
+        end_normal - start_normal
+    );
+    auto train_only_duration = std::chrono::duration_cast<std::chrono::milliseconds>(
+        end_train_only - start_train_only
+    );
+
+    // Informational only (shown on failure); see the note above about CI noise.
+    CATCH_INFO(
+        "kmeans timing (ms): normal=" << normal_duration.count()
+                                      << ", train_only=" << train_only_duration.count()
+    );
+
+    // Verify results are still valid
+    CATCH_REQUIRE(centroids_train_only.size() == n_centroids);
+    for (const auto& cluster : clusters_train_only) {
+        CATCH_REQUIRE(cluster.empty());
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_kmeans_edge_cases(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    // Test with minimum centroids
+    {
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(1)
+                          .minibatch_size(10)
+                          .num_iterations(2)
+                          .is_hierarchical(false)
+                          .training_fraction(0.5f);
+        auto threadpool = svs::threads::as_threadpool(2);
+        auto [centroids, clusters] =
+            ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool);
+
+        CATCH_REQUIRE(centroids.size() == 1);
+        CATCH_REQUIRE(clusters.size() == 1);
+        // The single cluster must contain every point.
+        CATCH_REQUIRE(clusters[0].size() == data.size());
+    }
+
+    // Test with a large number of centroids (but fewer than data points)
+    if (data.size() > 100) {
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(std::min(data.size() - 1, size_t(100)))
+                          .minibatch_size(20)
+                          .num_iterations(3)
+                          .is_hierarchical(false)
+                          .training_fraction(0.7f);
+        auto threadpool = svs::threads::as_threadpool(4);
+        auto [centroids, clusters] =
+            ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool);
+
+        CATCH_REQUIRE(centroids.size() == std::min(data.size() - 1, size_t(100)));
+        CATCH_REQUIRE(clusters.size() == std::min(data.size() - 1, size_t(100)));
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_kmeans_reproducibility(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    // Test that the same seed produces the same results
+    const size_t seed = 12345;
+    const size_t n_centroids = 25;
+
+    auto params1 = ivf::IVFBuildParameters()
+                       .num_centroids(n_centroids)
+                       .minibatch_size(25)
+                       .num_iterations(3)
+                       .is_hierarchical(false)
+                       .training_fraction(0.6f)
+                       .seed(seed);
+
+    auto params2 = ivf::IVFBuildParameters()
+                       .num_centroids(n_centroids)
+                       .minibatch_size(25)
+                       .num_iterations(3)
+                       .is_hierarchical(false)
+                       .training_fraction(0.6f)
+                       .seed(seed);
+
+    auto threadpool = svs::threads::as_threadpool(4);
+
+    auto [centroids1, clusters1] =
+        ivf::kmeans_clustering<ElementType>(params1, data, distance, threadpool);
+
+    auto [centroids2, clusters2] =
+        ivf::kmeans_clustering<ElementType>(params2, data, distance, threadpool);
+
+    // Verify centroids are identical
+    CATCH_REQUIRE(centroids1.size() == centroids2.size());
+    constexpr float tolerance = 1e-6f;
+
+    for (size_t i = 0; i < centroids1.size(); ++i) {
+        auto centroid1 = centroids1.get_datum(i);
+        auto centroid2 = centroids2.get_datum(i);
+
+        for (size_t j = 0; j < centroids1.dimensions(); ++j) {
+            float diff = std::abs(centroid1[j] - centroid2[j]);
+            CATCH_REQUIRE(diff < tolerance);
+        }
+    }
+
+    // Verify cluster assignments match in size. Only sizes are compared here:
+    // with identical centroids the membership is the same, but the iteration
+    // order within a cluster is not guaranteed.
+    CATCH_REQUIRE(clusters1.size() == clusters2.size());
+    for (size_t i = 0; i < clusters1.size(); ++i) {
+        CATCH_REQUIRE(clusters1[i].size() == clusters2[i].size());
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_kmeans_cluster_assignment_validity(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+
+    auto params = ivf::IVFBuildParameters()
+                      .num_centroids(20)
+                      .minibatch_size(25)
+                      .num_iterations(5)
+                      .is_hierarchical(false)
+                      .training_fraction(0.8f);
+
+    auto threadpool = svs::threads::as_threadpool(4);
+    auto [centroids, clusters] =
+        ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool);
+
+    // Verify all data points are assigned to exactly one cluster
+    std::unordered_set<size_t> assigned_points;
+    for (size_t i = 0; i < clusters.size(); ++i) {
+        for (auto point_id : clusters[i]) {
+            CATCH_REQUIRE(point_id < data.size()); // Valid point index
+            CATCH_REQUIRE(
+                assigned_points.find(point_id) == assigned_points.end()
+            ); // Not already assigned
+            assigned_points.insert(point_id);
+        }
+    }
+
+    CATCH_REQUIRE(assigned_points.size() == data.size()); // All points assigned
+
+    // Verify centroids have valid values (no NaN or infinity)
+    for (size_t i = 0; i < centroids.size(); ++i) {
+        auto centroid = centroids.get_datum(i);
+        for (size_t j = 0; j < centroids.dimensions(); ++j) {
+            CATCH_REQUIRE(std::isfinite(centroid[j]));
+        }
+    }
+}
+
+template <typename ElementType, typename Data, typename Distance>
+void test_kmeans_parameter_variations(const Data& data, Distance distance) {
+    namespace ivf = svs::index::ivf;
+    auto threadpool = svs::threads::as_threadpool(4);
+
+    // Test different minibatch sizes
+    for (size_t minibatch : {10, 25, 50}) {
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(15)
+                          .minibatch_size(minibatch)
+                          .num_iterations(3)
+                          .is_hierarchical(false)
+                          .training_fraction(0.6f);
+
+        auto [centroids, clusters] =
+            ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool);
+
+        CATCH_REQUIRE(centroids.size() == 15);
+        CATCH_REQUIRE(clusters.size() == 15);
+    }
+
+    // Test different iteration counts
+    for (size_t iters : {1, 3, 5, 10}) {
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(10)
+                          .minibatch_size(25)
+                          .num_iterations(iters)
+                          .is_hierarchical(false)
+                          .training_fraction(0.6f);
+
+        auto [centroids, clusters] =
+            ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool);
+
+        CATCH_REQUIRE(centroids.size() == 10);
+        CATCH_REQUIRE(clusters.size() == 10);
+    }
+
+    // Test different training fractions
+    for (float training_fraction : {0.3f, 0.5f, 0.7f, 0.9f}) {
+        auto params = ivf::IVFBuildParameters()
+                          .num_centroids(12)
+                          .minibatch_size(25)
+                          .num_iterations(3)
+                          .is_hierarchical(false)
+                          .training_fraction(training_fraction);
+
+        auto [centroids, clusters] =
+            ivf::kmeans_clustering<ElementType>(params, data, distance, threadpool);
+
+        CATCH_REQUIRE(centroids.size() == 12);
+        CATCH_REQUIRE(clusters.size() == 12);
+    }
+}
+
 } // namespace
 
 CATCH_TEST_CASE("Build Kmeans Param Check", "[ivf][parameter_check]") {
-    CATCH_SECTION("Uncompressed Data") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
         auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
         test_kmeans_clustering<float>(data, svs::DistanceIP());
+        test_kmeans_clustering<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
         test_kmeans_clustering<svs::Float16>(data, svs::DistanceIP());
+        test_kmeans_clustering<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
         test_kmeans_clustering<svs::BFloat16>(data, svs::DistanceIP());
         test_kmeans_clustering<svs::BFloat16>(data, svs::DistanceL2());
     }
 }
+
+CATCH_TEST_CASE("Kmeans Edge Cases", "[ivf][kmeans][edge_cases]") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_kmeans_edge_cases<float>(data, svs::DistanceIP());
+        test_kmeans_edge_cases<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_kmeans_edge_cases<svs::Float16>(data, svs::DistanceIP());
+        test_kmeans_edge_cases<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_kmeans_edge_cases<svs::BFloat16>(data, svs::DistanceIP());
+        test_kmeans_edge_cases<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
+
+CATCH_TEST_CASE("Kmeans Reproducibility", "[ivf][kmeans][reproducibility]") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_kmeans_reproducibility<float>(data, svs::DistanceIP());
+        test_kmeans_reproducibility<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_kmeans_reproducibility<svs::Float16>(data, svs::DistanceIP());
+        test_kmeans_reproducibility<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_kmeans_reproducibility<svs::BFloat16>(data, svs::DistanceIP());
+        test_kmeans_reproducibility<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
+
+CATCH_TEST_CASE("Kmeans Cluster Assignment Validity", "[ivf][kmeans][cluster_validity]") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_kmeans_cluster_assignment_validity<float>(data, svs::DistanceIP());
+        test_kmeans_cluster_assignment_validity<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_kmeans_cluster_assignment_validity<svs::Float16>(data, svs::DistanceIP());
+        test_kmeans_cluster_assignment_validity<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_kmeans_cluster_assignment_validity<svs::BFloat16>(data, svs::DistanceIP());
+        test_kmeans_cluster_assignment_validity<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
+
+CATCH_TEST_CASE("Kmeans Parameter Variations", "[ivf][kmeans][parameters]") {
+    CATCH_SECTION("Uncompressed Data - All Data Types") {
+        auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
+
+        // Test float32
+        test_kmeans_parameter_variations<float>(data, svs::DistanceIP());
+        test_kmeans_parameter_variations<float>(data, svs::DistanceL2());
+
+        // Test Float16 (fp16)
+        test_kmeans_parameter_variations<svs::Float16>(data, svs::DistanceIP());
+        test_kmeans_parameter_variations<svs::Float16>(data, svs::DistanceL2());
+
+        // Test BFloat16 (bf16)
+        test_kmeans_parameter_variations<svs::BFloat16>(data, svs::DistanceIP());
+        test_kmeans_parameter_variations<svs::BFloat16>(data, svs::DistanceL2());
+    }
+}
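
One usage pattern worth calling out from the new k-means tests: with a fixed seed, train_only = true returns the same centroids as a full run while leaving the clusters empty, so the cheaper training pass can run once and the centroids be reused. A minimal sketch, using the same call shape as the tests (the explicit `float` element type and the parameter values are illustrative):

template <typename Data, typename Distance>
auto train_centroids_only(const Data& data, Distance distance) {
    namespace ivf = svs::index::ivf;
    auto params = ivf::IVFBuildParameters()
                      .num_centroids(50)
                      .minibatch_size(25)
                      .num_iterations(3)
                      .is_hierarchical(false)
                      .training_fraction(0.6f)
                      .seed(42); // Same seed => same centroids as a full run.
    auto threadpool = svs::threads::as_threadpool(4);
    // train_only = true: clusters come back empty; only the centroids matter.
    auto [centroids, clusters] =
        ivf::kmeans_clustering<float>(params, data, distance, threadpool, true);
    return centroids;
}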