diff --git a/.github/workflows/04-android-build.yml b/.github/workflows/04-android-build.yml index 5ee6fe345..63c70a5b5 100644 --- a/.github/workflows/04-android-build.yml +++ b/.github/workflows/04-android-build.yml @@ -20,6 +20,8 @@ jobs: steps: - name: Checkout uses: actions/checkout@v6 + with: + submodules: recursive - name: Cache dependencies uses: actions/cache@v5 @@ -60,7 +62,7 @@ jobs: - name: Use host env to compile protoc shell: bash run: | - git submodule update --init + git submodule update --init --recursive if [ ! -d "build-host" ]; then export CCACHE_BASEDIR="$GITHUB_WORKSPACE" export CCACHE_NOHASHDIR=1 diff --git a/.github/workflows/scripts/run_vdb.sh b/.github/workflows/scripts/run_vdb.sh index f153a598f..b944fc5b2 100644 --- a/.github/workflows/scripts/run_vdb.sh +++ b/.github/workflows/scripts/run_vdb.sh @@ -15,7 +15,7 @@ echo "workspace: $GITHUB_WORKSPACE" DB_LABEL_PREFIX="Zvec16c64g-$COMMIT_ID" # install zvec -git submodule update --init +git submodule update --init --recursive # for debug #cd .. 
@@ -85,4 +85,4 @@ EOF cat prom_metrics.txt curl --data-binary @prom_metrics.txt "http://47.93.34.27:9091/metrics/job/benchmarks-${CASE_TYPE}/case_type/${CASE_TYPE}/quantize_type/${QUANTIZE_TYPE}" -v done -done \ No newline at end of file +done diff --git a/.gitmodules b/.gitmodules index 49ed1920b..0d3fe89b9 100644 --- a/.gitmodules +++ b/.gitmodules @@ -40,6 +40,9 @@ path = thirdparty/magic_enum/magic_enum-0.9.7 url = https://github.com/Neargye/magic_enum.git ignore = all +[submodule "thirdparty/omega/OMEGALib"] + path = thirdparty/omega/OMEGALib + url = https://github.com/driPyf/OMEGALib.git [submodule "thirdparty/RaBitQ-Library/RaBitQ-Library-0.1"] path = thirdparty/RaBitQ-Library/RaBitQ-Library-0.1 url = https://github.com/VectorDB-NTU/RaBitQ-Library.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 55e31591f..a368854d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,6 +65,14 @@ message(STATUS "BUILD_C_BINDINGS:${BUILD_C_BINDINGS}") option(BUILD_TOOLS "Build tools" ON) message(STATUS "BUILD_TOOLS:${BUILD_TOOLS}") +if(ANDROID) + option(ZVEC_ENABLE_OMEGA "Build OMEGA support" OFF) +else() + option(ZVEC_ENABLE_OMEGA "Build OMEGA support" ON) +endif() +message(STATUS "ZVEC_ENABLE_OMEGA:${ZVEC_ENABLE_OMEGA}") +add_compile_definitions(ZVEC_ENABLE_OMEGA=$) + option(RABITQ_ENABLE_AVX512 "Compile RaBitQ with AVX-512 support" OFF) if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64|AMD64" AND NOT ANDROID) diff --git a/cmake/bazel.cmake b/cmake/bazel.cmake index 2910cc166..457669a4a 100644 --- a/cmake/bazel.cmake +++ b/cmake/bazel.cmake @@ -453,7 +453,7 @@ if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) "$<$:-Wall;-Wextra;-Wshadow>" "$<$:-Wall;-Wextra;-Wshadow>" "$<$:-Wall;-Wextra;-Wshadow-local;-Wno-misleading-indentation>" - "$<$:/W4>" + "$<$:/W4;/utf-8>" ${BAZEL_CC_ASAN_COMPILE_FLAGS} ${BAZEL_CC_COVERAGE_COMPILE_FLAGS} ) @@ -463,7 +463,7 @@ else() "$<$:-Wall;-Wextra;-Wshadow>" "$<$:-Wall;-Wextra;-Wshadow>" 
"$<$:-Wall;-Wextra;-Wshadow;-Wno-misleading-indentation>" - "$<$:/W4>" + "$<$:/W4;/utf-8>" ${BAZEL_CC_ASAN_COMPILE_FLAGS} ${BAZEL_CC_COVERAGE_COMPILE_FLAGS} ) @@ -484,7 +484,7 @@ set( "$<$:-Wall>" "$<$:-Wall>" "$<$:-Wall>" - "$<$:/W3>" + "$<$:/W3;/utf-8>" ${BAZEL_CC_ASAN_COMPILE_FLAGS} ${BAZEL_CC_COVERAGE_COMPILE_FLAGS} ) diff --git a/examples/c++/CMakeLists.txt b/examples/c++/CMakeLists.txt index 13b8d7c59..91c3bf548 100644 --- a/examples/c++/CMakeLists.txt +++ b/examples/c++/CMakeLists.txt @@ -16,6 +16,12 @@ set(ZVEC_INCLUDE_DIR ${CMAKE_BINARY_DIR}/../../../src/include) set(ZVEC_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/lib) set(ZVEC_DEPENDENCY_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/external/usr/local/lib) +if(ANDROID) + option(ZVEC_ENABLE_OMEGA "Link examples against OMEGA support from the host build" OFF) +else() + option(ZVEC_ENABLE_OMEGA "Link examples against OMEGA support from the host build" ON) +endif() + # Add include and library search paths include_directories(${ZVEC_INCLUDE_DIR}) link_directories(${ZVEC_LIB_DIR} ${ZVEC_DEPENDENCY_LIB_DIR}) @@ -31,13 +37,48 @@ else() set(PROTOBUF_LIB protobuf) endif() +if(ZVEC_ENABLE_OMEGA) + set(_zvec_omega_search_paths + ${ZVEC_DEPENDENCY_LIB_DIR} + ${ZVEC_DEPENDENCY_LIB_DIR}/Debug + ${ZVEC_DEPENDENCY_LIB_DIR}/Release + ${ZVEC_DEPENDENCY_LIB_DIR}/RelWithDebInfo + ${ZVEC_DEPENDENCY_LIB_DIR}/MinSizeRel + ) + find_library(ZVEC_OMEGA_LIB + NAMES omega + PATHS ${_zvec_omega_search_paths} + NO_DEFAULT_PATH + ) + find_library(ZVEC_LIGHTGBM_LIB + NAMES lib_lightgbm _lightgbm + PATHS ${_zvec_omega_search_paths} + NO_DEFAULT_PATH + ) + if(NOT ZVEC_OMEGA_LIB OR NOT ZVEC_LIGHTGBM_LIB) + message(FATAL_ERROR + "ZVEC_ENABLE_OMEGA=ON but failed to locate OMEGA host libraries under " + "${ZVEC_DEPENDENCY_LIB_DIR}" + ) + endif() +endif() + # --- Dependency groups --- find_package(Threads REQUIRED) +find_package(OpenMP QUIET) +set(zvec_openmp_deps) +if(OpenMP_FOUND AND NOT ANDROID) + list(APPEND 
zvec_openmp_deps OpenMP::OpenMP_CXX) +endif() set(zvec_core_deps zvec_turbo + ${zvec_openmp_deps} ) +if(ZVEC_ENABLE_OMEGA) + list(APPEND zvec_core_deps ${ZVEC_OMEGA_LIB} ${ZVEC_LIGHTGBM_LIB}) +endif() if (NOT WIN32) set(zvec_ailego_deps @@ -63,7 +104,11 @@ if (NOT WIN32) ${GFLAGS_LIB} ${PROTOBUF_LIB} lz4 + ${zvec_openmp_deps} ) + if(ZVEC_ENABLE_OMEGA) + list(APPEND zvec_db_deps ${ZVEC_OMEGA_LIB} ${ZVEC_LIGHTGBM_LIB}) + endif() else () # Windows static libraries use different naming conventions set(PROTOBUF_LIB libprotobuf) @@ -90,7 +135,11 @@ else () lz4 rpcrt4 shlwapi + ${zvec_openmp_deps} ) + if(ZVEC_ENABLE_OMEGA) + list(APPEND zvec_db_deps ${ZVEC_OMEGA_LIB} ${ZVEC_LIGHTGBM_LIB}) + endif() endif () # --- Create INTERFACE targets for Zvec components --- @@ -127,6 +176,11 @@ elseif(APPLE) zvec-ailego ${zvec_core_deps} ) + if(ZVEC_ENABLE_OMEGA) + target_link_libraries(zvec-core INTERFACE + -Wl,-force_load ${ZVEC_OMEGA_LIB} + ) + endif() elseif(ANDROID) target_link_libraries(zvec-core INTERFACE -Wl,--whole-archive @@ -161,6 +215,7 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Linux") ) elseif(APPLE) target_link_libraries(zvec-db INTERFACE + -Wl,-force_load ${ZVEC_LIB_DIR}/libzvec_db.a zvec_db zvec-core zvec-ailego @@ -208,6 +263,13 @@ target_link_libraries(ailego-example PRIVATE zvec-ailego ) +if(ZVEC_ENABLE_OMEGA) + add_executable(omega-example omega/main.cc) + target_link_libraries(omega-example PRIVATE + zvec-core + ) +endif() + # Strip symbols to reduce executable size if(CMAKE_BUILD_TYPE STREQUAL "Release" AND ANDROID) add_custom_command(TARGET db-example POST_BUILD @@ -219,6 +281,11 @@ if(CMAKE_BUILD_TYPE STREQUAL "Release" AND ANDROID) add_custom_command(TARGET ailego-example POST_BUILD COMMAND ${CMAKE_STRIP} "$" COMMENT "Stripping symbols from ailego-example") + if(ZVEC_ENABLE_OMEGA) + add_custom_command(TARGET omega-example POST_BUILD + COMMAND ${CMAKE_STRIP} "$" + COMMENT "Stripping symbols from omega-example") + endif() endif() # Optimize for size @@ -227,4 +294,10 @@ 
if(CMAKE_BUILD_TYPE STREQUAL "Release" AND ANDROID) PROPERTY COMPILE_FLAGS "-Os") set_property(TARGET db-example core-example ailego-example PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + if(ZVEC_ENABLE_OMEGA) + set_property(TARGET omega-example + PROPERTY COMPILE_FLAGS "-Os") + set_property(TARGET omega-example + PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + endif() endif() diff --git a/examples/c++/omega/main.cc b/examples/c++/omega/main.cc new file mode 100644 index 000000000..b15b5d331 --- /dev/null +++ b/examples/c++/omega/main.cc @@ -0,0 +1,92 @@ +#include +#include +#include +#include +#include +#include + +using namespace zvec::core_interface; + +namespace { + +constexpr uint32_t kDimension = 32; +const std::string kIndexPath = "omega_example.index"; + +BaseIndexParam::Pointer CreateOmegaParam() { + auto param = HNSWIndexParamBuilder() + .WithMetricType(MetricType::kInnerProduct) + .WithDataType(DataType::DT_FP32) + .WithDimension(kDimension) + .WithIsSparse(false) + .WithM(8) + .WithEFConstruction(64) + .Build(); + param->index_type = IndexType::kOMEGA; + return param; +} + +} // namespace + +int main() { + std::filesystem::remove_all(kIndexPath); + + auto index = IndexFactory::CreateAndInitIndex(*CreateOmegaParam()); + if (!index) { + std::cerr << "failed to create omega index" << std::endl; + return 1; + } + + if (index->Open(kIndexPath, StorageOptions{StorageOptions::StorageType::kMMAP, + true}) != 0) { + std::cerr << "failed to open omega index" << std::endl; + return 1; + } + + for (uint32_t doc_id = 0; doc_id < 6; ++doc_id) { + std::vector values(kDimension, static_cast(doc_id) / 10.0f); + values[0] = 1.0f + static_cast(doc_id); + VectorData vector_data; + vector_data.vector = DenseVector{values.data()}; + if (index->Add(vector_data, doc_id) != 0) { + std::cerr << "failed to add document " << doc_id << std::endl; + return 1; + } + } + + if (index->Train() != 0) { + std::cerr << "failed to train omega index" << std::endl; + return 1; + } + + 
std::vector query_values(kDimension, 0.0f); + query_values[0] = 1.0f; + VectorData query{DenseVector{query_values.data()}}; + + auto query_param = OmegaQueryParamBuilder() + .with_topk(3) + .with_fetch_vector(true) + .with_ef_search(32) + .with_target_recall(0.95f) + .build(); + + SearchResult result; + if (index->Search(query, query_param, &result) != 0) { + std::cerr << "failed to search omega index" << std::endl; + return 1; + } + + std::cout << "omega results: " << result.doc_list_.size() << std::endl; + if (result.doc_list_.empty()) { + std::cerr << "omega example returned no results" << std::endl; + return 1; + } + + std::cout << "top result key=" << result.doc_list_[0].key() + << " score=" << result.doc_list_[0].score() << std::endl; + if (index->Close() != 0) { + std::cerr << "failed to close omega index" << std::endl; + return 1; + } + + return 0; +} diff --git a/examples/c/CMakeLists.txt b/examples/c/CMakeLists.txt index 5edb881a6..353d2641a 100644 --- a/examples/c/CMakeLists.txt +++ b/examples/c/CMakeLists.txt @@ -29,14 +29,13 @@ endif() set(ZVEC_INCLUDE_DIR ${CMAKE_BINARY_DIR}/../../../src/include) set(ZVEC_GENERATED_INCLUDE_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/src/generated) +set(ZVEC_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/lib) +set(ZVEC_DEPENDENCY_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/external/usr/local/lib) -# On Windows, libraries are in Debug/Release subdirectories -if(WIN32) - set(ZVEC_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/lib/$) - set(ZVEC_DEPENDENCY_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/external/usr/local/lib/$) +if(ANDROID) + option(ZVEC_ENABLE_OMEGA "Link examples against OMEGA support from the host build" OFF) else() - set(ZVEC_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/lib) - set(ZVEC_DEPENDENCY_LIB_DIR ${CMAKE_BINARY_DIR}/../../../${HOST_BUILD_DIR}/external/usr/local/lib) + option(ZVEC_ENABLE_OMEGA "Link examples against OMEGA support from the host 
build" ON) endif() # Add include and library search paths @@ -45,6 +44,12 @@ link_directories(${ZVEC_LIB_DIR} ${ZVEC_DEPENDENCY_LIB_DIR}) # Find required packages find_package(Threads REQUIRED) +find_package(OpenMP QUIET) + +set(zvec_openmp_deps) +if(OpenMP_FOUND AND NOT ANDROID) + list(APPEND zvec_openmp_deps OpenMP::OpenMP_CXX) +endif() # --- Determine debug/release library names --- if(CMAKE_BUILD_TYPE STREQUAL "Debug") @@ -57,6 +62,32 @@ else() set(PROTOBUF_LIB protobuf) endif() +if(ZVEC_ENABLE_OMEGA) + set(_zvec_omega_search_paths + ${ZVEC_DEPENDENCY_LIB_DIR} + ${ZVEC_DEPENDENCY_LIB_DIR}/Debug + ${ZVEC_DEPENDENCY_LIB_DIR}/Release + ${ZVEC_DEPENDENCY_LIB_DIR}/RelWithDebInfo + ${ZVEC_DEPENDENCY_LIB_DIR}/MinSizeRel + ) + find_library(ZVEC_OMEGA_LIB + NAMES omega + PATHS ${_zvec_omega_search_paths} + NO_DEFAULT_PATH + ) + find_library(ZVEC_LIGHTGBM_LIB + NAMES lib_lightgbm _lightgbm + PATHS ${_zvec_omega_search_paths} + NO_DEFAULT_PATH + ) + if(NOT ZVEC_OMEGA_LIB OR NOT ZVEC_LIGHTGBM_LIB) + message(FATAL_ERROR + "ZVEC_ENABLE_OMEGA=ON but failed to locate OMEGA host libraries under " + "${ZVEC_DEPENDENCY_LIB_DIR}" + ) + endif() +endif() + # --- Dependency groups --- if(NOT WIN32) set(zvec_c_api_deps @@ -73,9 +104,13 @@ if(NOT WIN32) ${GFLAGS_LIB} ${PROTOBUF_LIB} lz4 + ${zvec_openmp_deps} ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS} ) + if(ZVEC_ENABLE_OMEGA) + list(APPEND zvec_c_api_deps ${ZVEC_OMEGA_LIB} ${ZVEC_LIGHTGBM_LIB}) + endif() else() # Windows static libraries use different naming conventions set(PROTOBUF_LIB libprotobuf) @@ -93,10 +128,14 @@ else() ${GFLAGS_LIB} ${PROTOBUF_LIB} lz4 + ${zvec_openmp_deps} ${CMAKE_THREAD_LIBS_INIT} rpcrt4 shlwapi ) + if(ZVEC_ENABLE_OMEGA) + list(APPEND zvec_c_api_deps ${ZVEC_OMEGA_LIB} ${ZVEC_LIGHTGBM_LIB}) + endif() endif() # Create INTERFACE target for zvec_c_api with platform-specific linking @@ -176,6 +215,13 @@ target_link_libraries(c_api_optimized_example PRIVATE zvec-c-api ) +if(ZVEC_ENABLE_OMEGA) + 
add_executable(c_api_omega_example omega_example.c) + target_link_libraries(c_api_omega_example PRIVATE + zvec-c-api + ) +endif() + # Strip symbols to reduce executable size if(CMAKE_BUILD_TYPE STREQUAL "Release" AND (ANDROID OR (CMAKE_SYSTEM_NAME STREQUAL "Linux"))) add_custom_command(TARGET c_api_basic_example POST_BUILD @@ -196,6 +242,11 @@ if(CMAKE_BUILD_TYPE STREQUAL "Release" AND (ANDROID OR (CMAKE_SYSTEM_NAME STREQU add_custom_command(TARGET c_api_optimized_example POST_BUILD COMMAND ${CMAKE_STRIP} "$" COMMENT "Stripping symbols from c_api_optimized_example") + if(ZVEC_ENABLE_OMEGA) + add_custom_command(TARGET c_api_omega_example POST_BUILD + COMMAND ${CMAKE_STRIP} "$" + COMMENT "Stripping symbols from c_api_omega_example") + endif() endif() # Optimize for size @@ -206,4 +257,10 @@ if(CMAKE_BUILD_TYPE STREQUAL "Release" AND ANDROID) set_property(TARGET c_api_basic_example c_api_collection_schema_example c_api_doc_example c_api_index_example c_api_field_schema_example c_api_optimized_example PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) -endif() \ No newline at end of file + if(ZVEC_ENABLE_OMEGA) + set_property(TARGET c_api_omega_example + PROPERTY COMPILE_FLAGS "-Os") + set_property(TARGET c_api_omega_example + PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE) + endif() +endif() diff --git a/examples/c/omega_example.c b/examples/c/omega_example.c new file mode 100644 index 000000000..caf1c5c5d --- /dev/null +++ b/examples/c/omega_example.c @@ -0,0 +1,171 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "zvec/c_api.h" + +static ZVecErrorCode check(ZVecErrorCode error, const char *context) { + if (error != ZVEC_OK) { + char *error_msg = NULL; + zvec_get_last_error(&error_msg); + fprintf(stderr, "%s failed: %s\n", context, + error_msg ? error_msg : "unknown error"); + zvec_free(error_msg); + } + return error; +} + +static ZVecErrorCode create_omega_collection(ZVecCollection **collection) { + ZVecCollectionSchema *schema = + zvec_collection_schema_create("omega_collection"); + ZVecCollectionOptions *options = NULL; + ZVecIndexParams *invert_params = NULL; + ZVecIndexParams *omega_params = NULL; + ZVecFieldSchema *id_field = NULL; + ZVecFieldSchema *embedding_field = NULL; + ZVecErrorCode error = ZVEC_OK; + + if (!schema) { + return ZVEC_ERROR_INTERNAL_ERROR; + } + + invert_params = zvec_index_params_create(ZVEC_INDEX_TYPE_INVERT); + omega_params = zvec_index_params_create(ZVEC_INDEX_TYPE_OMEGA); + options = zvec_collection_options_create(); + if (!invert_params || !omega_params || !options) { + error = ZVEC_ERROR_RESOURCE_EXHAUSTED; + goto cleanup; + } + + zvec_index_params_set_invert_params(invert_params, true, false); + zvec_index_params_set_metric_type(omega_params, ZVEC_METRIC_TYPE_IP); + zvec_index_params_set_hnsw_params(omega_params, 8, 64); + + id_field = zvec_field_schema_create("id", ZVEC_DATA_TYPE_STRING, false, 0); + embedding_field = zvec_field_schema_create( + "embedding", ZVEC_DATA_TYPE_VECTOR_FP32, false, 4); + if (!id_field || !embedding_field) { + error = ZVEC_ERROR_RESOURCE_EXHAUSTED; + goto cleanup; + } + + zvec_field_schema_set_index_params(id_field, invert_params); + zvec_field_schema_set_index_params(embedding_field, omega_params); + + error = zvec_collection_schema_add_field(schema, id_field); + if (error != ZVEC_OK) goto cleanup; + error = zvec_collection_schema_add_field(schema, embedding_field); 
+ if (error != ZVEC_OK) goto cleanup; + + error = zvec_collection_create_and_open("./omega_collection", schema, options, + collection); + +cleanup: + zvec_index_params_destroy(invert_params); + zvec_index_params_destroy(omega_params); + zvec_collection_options_destroy(options); + zvec_collection_schema_destroy(schema); + return error; +} + +int main(void) { + ZVecCollection *collection = NULL; + ZVecVectorQuery *query = NULL; + ZVecOmegaQueryParams *omega_query = NULL; + ZVecDoc *docs[2] = {NULL, NULL}; + ZVecDoc **results = NULL; + size_t result_count = 0; + int exit_code = 1; + float vector1[] = {1.0f, 0.1f, 0.1f, 0.1f}; + float vector2[] = {2.0f, 0.2f, 0.2f, 0.2f}; + + remove("./omega_collection"); + + if (check(create_omega_collection(&collection), "create_omega_collection") != + ZVEC_OK) { + goto cleanup; + } + + for (int i = 0; i < 2; ++i) { + docs[i] = zvec_doc_create(); + if (!docs[i]) { + fprintf(stderr, "failed to create document %d\n", i); + goto cleanup; + } + } + + zvec_doc_set_pk(docs[0], "doc1"); + zvec_doc_add_field_by_value(docs[0], "id", ZVEC_DATA_TYPE_STRING, "doc1", 4); + zvec_doc_add_field_by_value(docs[0], "embedding", ZVEC_DATA_TYPE_VECTOR_FP32, + vector1, sizeof(vector1)); + + zvec_doc_set_pk(docs[1], "doc2"); + zvec_doc_add_field_by_value(docs[1], "id", ZVEC_DATA_TYPE_STRING, "doc2", 4); + zvec_doc_add_field_by_value(docs[1], "embedding", ZVEC_DATA_TYPE_VECTOR_FP32, + vector2, sizeof(vector2)); + + { + size_t success_count = 0; + size_t error_count = 0; + if (check(zvec_collection_insert(collection, (const ZVecDoc **)docs, 2, + &success_count, &error_count), + "zvec_collection_insert") != ZVEC_OK) { + goto cleanup; + } + } + + if (check(zvec_collection_flush(collection), "zvec_collection_flush") != + ZVEC_OK) { + goto cleanup; + } + + query = zvec_vector_query_create(); + omega_query = zvec_query_params_omega_create(32, 0.95f, 0.0f, false, false); + if (!query || !omega_query) { + fprintf(stderr, "failed to create omega query\n"); + goto 
cleanup; + } + + zvec_vector_query_set_field_name(query, "embedding"); + zvec_vector_query_set_query_vector(query, vector1, sizeof(vector1)); + zvec_vector_query_set_topk(query, 2); + zvec_vector_query_set_include_doc_id(query, true); + zvec_vector_query_set_omega_params(query, omega_query); + + if (check(zvec_collection_query(collection, query, &results, &result_count), + "zvec_collection_query") != ZVEC_OK) { + goto cleanup; + } + + printf("omega c example results: %zu\n", result_count); + if (result_count == 0) { + fprintf(stderr, "omega c example returned no results\n"); + goto cleanup; + } + printf("top result score: %.4f\n", zvec_doc_get_score(results[0])); + + exit_code = 0; + +cleanup: + zvec_docs_free(results, result_count); + zvec_vector_query_destroy(query); + if (collection) { + zvec_collection_destroy(collection); + } + zvec_doc_destroy(docs[0]); + zvec_doc_destroy(docs[1]); + return exit_code; +} diff --git a/python/tests/test_collection.py b/python/tests/test_collection.py index 7d021d6fd..6f37b5dfe 100644 --- a/python/tests/test_collection.py +++ b/python/tests/test_collection.py @@ -13,17 +13,25 @@ # limitations under the License. 
from __future__ import annotations +import os +import sys +from pathlib import Path import pytest import zvec from zvec import ( Collection, + CollectionSchema, CollectionOption, DataType, Doc, FieldSchema, HnswIndexParam, + HnswQueryParam, + OmegaIndexParam, + OmegaQueryParam, InvertIndexParam, + MetricType, LogLevel, LogType, VectorSchema, @@ -34,6 +42,24 @@ OptimizeOption, ) +IS_ANDROID = hasattr(sys, "getandroidapilevel") or "ANDROID_ROOT" in os.environ +OMEGA_ENABLED = os.environ.get("ZVEC_ENABLE_OMEGA", "1").lower() not in { + "0", + "off", + "false", + "no", +} +OMEGA_AVAILABLE = OMEGA_ENABLED and not IS_ANDROID +OMEGA_ANDROID_SKIP = pytest.mark.skipif( + not OMEGA_AVAILABLE, reason="OMEGA is disabled on this build/platform" +) + + +def _require_omega() -> None: + if not OMEGA_AVAILABLE: + pytest.skip("OMEGA is disabled on this build/platform") + + # ==================== Common ==================== @@ -129,6 +155,138 @@ def test_collection( print(f"Warning: failed to destroy collection: {e}") +@pytest.fixture(scope="session") +def omega_collection_schema(): + _require_omega() + return zvec.CollectionSchema( + name="omega_test_collection", + fields=[ + FieldSchema( + "id", + DataType.INT64, + nullable=False, + index_param=InvertIndexParam(enable_range_optimization=True), + ), + FieldSchema("name", DataType.STRING, nullable=False), + ], + vectors=[ + VectorSchema( + "dense", + DataType.VECTOR_FP32, + dimension=128, + index_param=OmegaIndexParam( + metric_type=MetricType.L2, + min_vector_threshold=100000, + num_training_queries=16, + ef_training=64, + window_size=32, + ), + ) + ], + ) + + +@pytest.fixture(scope="function") +def omega_test_collection( + tmp_path_factory, omega_collection_schema, collection_option +) -> Collection: + _require_omega() + temp_dir = tmp_path_factory.mktemp("zvec_omega") + collection_path = temp_dir / "omega_test_collection" + + coll = zvec.create_and_open( + path=str(collection_path), + schema=omega_collection_schema, + 
option=collection_option, + ) + + assert coll is not None + assert coll.path == str(collection_path) + assert coll.schema.vectors[0].index_param.type == IndexType.OMEGA + + try: + yield coll + finally: + if hasattr(coll, "destroy") and coll is not None: + try: + coll.destroy() + except Exception as e: + print(f"Warning: failed to destroy omega collection: {e}") + + +@pytest.fixture +def omega_multiple_docs(): + _require_omega() + return [ + Doc( + id=f"{id}", + fields={"id": id, "name": f"doc-{id}"}, + vectors={"dense": [id + 0.1] * 128}, + ) + for id in range(1, 65) + ] + + +@pytest.fixture +def omega_workflow_docs(): + _require_omega() + return [ + Doc( + id=f"{id}", + fields={"id": id, "name": f"workflow-doc-{id}"}, + vectors={"dense": [float(id) + 0.1] * 128}, + ) + for id in range(1, 129) + ] + + +def _create_vector_collection( + tmp_path_factory, + collection_option: CollectionOption, + collection_name: str, + vector_index_param, +) -> Collection: + temp_dir = tmp_path_factory.mktemp(collection_name) + collection_path = temp_dir / collection_name + schema = CollectionSchema( + name=collection_name, + fields=[ + FieldSchema( + "id", + DataType.INT64, + nullable=False, + index_param=InvertIndexParam(enable_range_optimization=True), + ), + FieldSchema("name", DataType.STRING, nullable=False), + ], + vectors=[ + VectorSchema( + "dense", + DataType.VECTOR_FP32, + dimension=128, + index_param=vector_index_param, + ) + ], + ) + coll = zvec.create_and_open( + path=str(collection_path), schema=schema, option=collection_option + ) + assert coll is not None + return coll + + +def _omega_model_files(collection: Collection) -> set[str]: + return { + path.name + for path in Path(collection.path).rglob("*") + if path.is_file() and path.parent.name == "omega_model" + } + + +def _result_ids(result) -> list[str]: + return [doc.id for doc in result] + + @pytest.fixture def collection_with_single_doc(test_collection: Collection, single_doc) -> Collection: # Setup: insert single 
doc @@ -969,6 +1127,141 @@ def test_collection_query_by_id( ) assert len(result) == 10 + @OMEGA_ANDROID_SKIP + def test_omega_collection_schema_uses_omega_index( + self, omega_test_collection: Collection + ): + vector_schema = omega_test_collection.schema.vector("dense") + assert vector_schema is not None + assert vector_schema.index_param.type == IndexType.OMEGA + + @OMEGA_ANDROID_SKIP + def test_omega_collection_query_by_id_with_omega_param( + self, omega_test_collection: Collection, omega_multiple_docs + ): + result = omega_test_collection.insert(omega_multiple_docs) + assert len(result) == len(omega_multiple_docs) + for item in result: + assert item.ok() + + query_result = omega_test_collection.query( + VectorQuery( + field_name="dense", + vector=omega_multiple_docs[0].vector("dense"), + param=OmegaQueryParam(ef=128, target_recall=0.91), + ), + topk=5, + ) + assert len(query_result) > 0 + assert query_result[0].id == omega_multiple_docs[0].id + + @OMEGA_ANDROID_SKIP + def test_omega_workflow_optimize_trains_model_and_query_runs( + self, tmp_path_factory, collection_option, omega_workflow_docs + ): + omega_collection = _create_vector_collection( + tmp_path_factory, + collection_option, + "omega_workflow_active", + OmegaIndexParam( + metric_type=MetricType.L2, + min_vector_threshold=32, + num_training_queries=16, + ef_training=64, + ef_groundtruth=128, + window_size=32, + ), + ) + try: + result = omega_collection.insert(omega_workflow_docs) + assert len(result) == len(omega_workflow_docs) + for item in result: + assert item.ok() + + omega_collection.optimize(option=OptimizeOption(concurrency=1)) + + model_files = _omega_model_files(omega_collection) + assert { + "model.txt", + "threshold_table.txt", + "interval_table.txt", + "gt_collected_table.txt", + "gt_cmps_all_table.txt", + }.issubset(model_files) + + query_result = omega_collection.query( + VectorQuery( + field_name="dense", + vector=omega_workflow_docs[31].vector("dense"), + param=OmegaQueryParam(ef=128, 
target_recall=0.91), + ), + topk=10, + ) + + assert len(query_result) == 10 + assert query_result[0].id == omega_workflow_docs[31].id + finally: + omega_collection.destroy() + + @OMEGA_ANDROID_SKIP + def test_omega_query_falls_back_to_hnsw_when_model_not_trained( + self, tmp_path_factory, collection_option, omega_workflow_docs + ): + hnsw_collection = _create_vector_collection( + tmp_path_factory, + collection_option, + "hnsw_workflow_baseline", + HnswIndexParam(metric_type=MetricType.L2), + ) + omega_collection = _create_vector_collection( + tmp_path_factory, + collection_option, + "omega_workflow_fallback", + OmegaIndexParam( + metric_type=MetricType.L2, + min_vector_threshold=100000, + num_training_queries=16, + ef_training=64, + ef_groundtruth=128, + window_size=32, + ), + ) + try: + hnsw_insert_result = hnsw_collection.insert(omega_workflow_docs) + omega_insert_result = omega_collection.insert(omega_workflow_docs) + assert len(hnsw_insert_result) == len(omega_workflow_docs) + assert len(omega_insert_result) == len(omega_workflow_docs) + + hnsw_collection.optimize(option=OptimizeOption(concurrency=1)) + omega_collection.optimize(option=OptimizeOption(concurrency=1)) + + assert "model.txt" not in _omega_model_files(omega_collection) + + # Use a non-document query vector to avoid equal-distance ties, + # so fallback-to-HNSW can be checked via exact result equality. 
+ query_vector = [64.3] * 128 + hnsw_result = hnsw_collection.query( + VectorQuery( + field_name="dense", + vector=query_vector, + param=HnswQueryParam(ef=128), + ), + topk=10, + ) + omega_result = omega_collection.query( + VectorQuery( + field_name="dense", + vector=query_vector, + param=OmegaQueryParam(ef=128, target_recall=0.91), + ), + topk=10, + ) + + assert _result_ids(omega_result) == _result_ids(hnsw_result) + finally: + omega_collection.destroy() + hnsw_collection.destroy() + def test_collection_query_multi_vector_with_same_field( self, collection_with_multiple_docs: Collection, multiple_docs ): diff --git a/python/tests/test_params.py b/python/tests/test_params.py index 0a85a7a38..24be0b21d 100644 --- a/python/tests/test_params.py +++ b/python/tests/test_params.py @@ -13,6 +13,8 @@ # limitations under the License. from __future__ import annotations +import os +import pickle import sys import time @@ -25,11 +27,13 @@ CollectionOption, FlatIndexParam, HnswIndexParam, + OmegaIndexParam, IndexOption, InvertIndexParam, IVFIndexParam, OptimizeOption, HnswQueryParam, + OmegaQueryParam, IVFQueryParam, VectorQuery, IndexType, @@ -41,6 +45,18 @@ from _zvec.param import _VectorQuery +IS_ANDROID = hasattr(sys, "getandroidapilevel") or "ANDROID_ROOT" in os.environ +OMEGA_ENABLED = os.environ.get("ZVEC_ENABLE_OMEGA", "1").lower() not in { + "0", + "off", + "false", + "no", +} +OMEGA_AVAILABLE = OMEGA_ENABLED and not IS_ANDROID +OMEGA_ANDROID_SKIP = pytest.mark.skipif( + not OMEGA_AVAILABLE, reason="OMEGA is disabled on this build/platform" +) + # ---------------------------- # Invert Index Param Test Case # ---------------------------- @@ -177,6 +193,70 @@ def test_readonly_attributes(self, attr): setattr(param, attr, getattr(param, attr)) +# ---------------------------- +# OMEGA Index Param Test Case +# ---------------------------- +@OMEGA_ANDROID_SKIP +class TestOmegaIndexParam: + def test_default(self): + param = OmegaIndexParam() + assert param.type == 
IndexType.OMEGA + assert param.metric_type == MetricType.IP + assert param.m == 50 + assert param.ef_construction == 500 + assert param.quantize_type == QuantizeType.UNDEFINED + assert param.min_vector_threshold == 100000 + assert param.num_training_queries == 1000 + assert param.ef_training == 1000 + assert param.window_size == 100 + assert param.ef_groundtruth == 0 + assert param.k_train == 1 + + def test_custom(self): + param = OmegaIndexParam( + metric_type=MetricType.COSINE, + m=32, + ef_construction=700, + quantize_type=QuantizeType.INT8, + min_vector_threshold=1234, + num_training_queries=567, + ef_training=890, + window_size=42, + ef_groundtruth=321, + k_train=3, + ) + assert param.metric_type == MetricType.COSINE + assert param.m == 32 + assert param.ef_construction == 700 + assert param.quantize_type == QuantizeType.INT8 + assert param.min_vector_threshold == 1234 + assert param.num_training_queries == 567 + assert param.ef_training == 890 + assert param.window_size == 42 + assert param.ef_groundtruth == 321 + assert param.k_train == 3 + + def test_pickle_round_trip(self): + param = OmegaIndexParam( + metric_type=MetricType.COSINE, + m=24, + ef_construction=320, + quantize_type=QuantizeType.INT8, + min_vector_threshold=2048, + num_training_queries=256, + ef_training=640, + window_size=48, + ef_groundtruth=96, + k_train=2, + ) + restored = pickle.loads(pickle.dumps(param)) + assert restored.type == IndexType.OMEGA + assert restored.metric_type == MetricType.COSINE + assert restored.m == 24 + assert restored.ef_training == 640 + assert restored.k_train == 2 + + # ---------------------------- # CollectionOption Test Case # ---------------------------- @@ -345,6 +425,52 @@ def test_readonly_attributes(self): param.is_linear = True +# ---------------------------- +# OMEGA Query Param Test Case +# ---------------------------- +@OMEGA_ANDROID_SKIP +class TestOmegaQueryParam: + def test_default(self): + param = OmegaQueryParam() + assert param.type == 
IndexType.OMEGA + assert param.ef == 300 + assert param.target_recall == pytest.approx(0.95) + assert param.radius == pytest.approx(0.0) + assert param.is_linear is False + assert param.is_using_refiner is False + + def test_custom(self): + param = OmegaQueryParam( + ef=480, + target_recall=0.92, + radius=1.5, + is_linear=True, + is_using_refiner=True, + ) + assert param.type == IndexType.OMEGA + assert param.ef == 480 + assert param.target_recall == pytest.approx(0.92) + assert param.radius == pytest.approx(1.5) + assert param.is_linear is True + assert param.is_using_refiner is True + + def test_pickle_round_trip(self): + param = OmegaQueryParam( + ef=384, + target_recall=0.91, + radius=0.25, + is_linear=True, + is_using_refiner=True, + ) + restored = pickle.loads(pickle.dumps(param)) + assert restored.type == IndexType.OMEGA + assert restored.ef == 384 + assert restored.target_recall == pytest.approx(0.91) + assert restored.radius == pytest.approx(0.25) + assert restored.is_linear is True + assert restored.is_using_refiner is True + + # # ---------------------------- # # IVFQueryParam Test Case # # ---------------------------- @@ -389,6 +515,16 @@ def test_init_with_valid_vector(self): assert vq.vector == vec assert vq.param == param + @OMEGA_ANDROID_SKIP + def test_init_with_valid_omega_param(self): + vec = [0.1, 0.2, 0.3] + param = OmegaQueryParam(ef=256, target_recall=0.91) + vq = VectorQuery(field_name="embedding", vector=vec, param=param) + assert vq.field_name == "embedding" + assert vq.vector == vec + assert vq.param == param + assert vq.param.target_recall == pytest.approx(0.91) + def test_init_both_id_and_vector_raises_error(self): with pytest.raises(ValueError): VectorQuery(field_name="embedding", id="doc123", vector=[0.1])._validate() @@ -413,3 +549,22 @@ def test_validate_fails_on_both_id_and_vector(self): vq = VectorQuery(field_name="test", id="doc123", vector=[0.1]) with pytest.raises(ValueError): vq._validate() + + +@OMEGA_ANDROID_SKIP +class 
TestVectorSchemaWithOmega: + def test_accepts_omega_index_param(self): + schema = VectorSchema( + name="dense", + data_type=DataType.VECTOR_FP32, + dimension=8, + index_param=OmegaIndexParam( + metric_type=MetricType.COSINE, + m=16, + ef_construction=300, + window_size=64, + ), + ) + assert schema.index_param.type == IndexType.OMEGA + assert schema.index_param.metric_type == MetricType.COSINE + assert schema.index_param.window_size == 64 diff --git a/python/zvec/__init__.py b/python/zvec/__init__.py index 895897869..817988bf8 100644 --- a/python/zvec/__init__.py +++ b/python/zvec/__init__.py @@ -21,13 +21,7 @@ from importlib.metadata import PackageNotFoundError -# ============================== -# Public API — grouped by category -# ============================== - from . import model as model - -# —— Extensions —— from .extension import ( BM25EmbeddingFunction, DefaultLocalDenseEmbedding, @@ -46,16 +40,10 @@ SparseEmbeddingFunction, WeightedReRanker, ) - -# —— Typing —— from .model import param as param from .model import schema as schema - -# —— Core data structures —— from .model.collection import Collection from .model.doc import Doc - -# —— Query & index parameters —— from .model.param import ( AddColumnOption, AlterColumnOption, @@ -69,14 +57,12 @@ InvertIndexParam, IVFIndexParam, IVFQueryParam, + OmegaIndexParam, + OmegaQueryParam, OptimizeOption, ) from .model.param.vector_query import VectorQuery - -# —— Schema & field definitions —— from .model.schema import CollectionSchema, CollectionStats, FieldSchema, VectorSchema - -# —— tools —— from .tool import require_module from .typing import ( DataType, @@ -87,13 +73,8 @@ StatusCode, ) from .typing.enum import LogLevel, LogType - -# —— lifecycle —— from .zvec import create_and_open, init, open -# ============================== -# Public interface declaration -# ============================== __all__ = [ # Zvec functions "create_and_open", @@ -115,6 +96,7 @@ "HnswRabitqQueryParam", "FlatIndexParam", 
"IVFIndexParam", + "OmegaIndexParam", "CollectionOption", "IndexOption", "OptimizeOption", @@ -122,6 +104,7 @@ "AlterColumnOption", "HnswQueryParam", "IVFQueryParam", + "OmegaQueryParam", # Extensions "DenseEmbeddingFunction", "SparseEmbeddingFunction", diff --git a/python/zvec/__init__.pyi b/python/zvec/__init__.pyi index efb1b2dfb..c5b32b8ad 100644 --- a/python/zvec/__init__.pyi +++ b/python/zvec/__init__.pyi @@ -23,6 +23,8 @@ from .model.param import ( InvertIndexParam, IVFIndexParam, IVFQueryParam, + OmegaIndexParam, + OmegaQueryParam, OptimizeOption, ) from .model.param.vector_query import VectorQuery @@ -62,6 +64,8 @@ __all__: list = [ "LogLevel", "LogType", "MetricType", + "OmegaIndexParam", + "OmegaQueryParam", "OptimizeOption", "QuantizeType", "ReRanker", diff --git a/python/zvec/model/param/__init__.py b/python/zvec/model/param/__init__.py index c613edf52..ecafafde1 100644 --- a/python/zvec/model/param/__init__.py +++ b/python/zvec/model/param/__init__.py @@ -26,6 +26,8 @@ InvertIndexParam, IVFIndexParam, IVFQueryParam, + OmegaIndexParam, + OmegaQueryParam, OptimizeOption, ) @@ -42,5 +44,7 @@ "IVFQueryParam", "IndexOption", "InvertIndexParam", + "OmegaIndexParam", + "OmegaQueryParam", "OptimizeOption", ] diff --git a/python/zvec/model/param/__init__.pyi b/python/zvec/model/param/__init__.pyi index cd1491efa..03a696f2b 100644 --- a/python/zvec/model/param/__init__.pyi +++ b/python/zvec/model/param/__init__.pyi @@ -23,6 +23,8 @@ __all__: list[str] = [ "IndexOption", "IndexParam", "InvertIndexParam", + "OmegaIndexParam", + "OmegaQueryParam", "OptimizeOption", "QueryParam", "SegmentOption", @@ -654,6 +656,182 @@ class InvertIndexParam(IndexParam): bool: Whether range optimization is enabled for this inverted index. """ +class OmegaIndexParam(VectorIndexParam): + """ + + Parameters for configuring an OMEGA index. + + OMEGA is an advanced graph-based index that can fall back to HNSW when omega + functionality is disabled. 
This class encapsulates its construction hyperparameters. + + Attributes: + metric_type (MetricType): Distance metric used for similarity computation. + Default is ``MetricType.IP`` (inner product). + m (int): Number of bi-directional links created for every new element + during construction. Higher values improve accuracy but increase + memory usage and construction time. Default is 100. + ef_construction (int): Size of the dynamic candidate list for nearest + neighbors during index construction. Larger values yield better + graph quality at the cost of slower build time. Default is 500. + quantize_type (QuantizeType): Optional quantization type for vector + compression (e.g., FP16, INT8). Default is `QuantizeType.UNDEFINED` to + disable quantization. + + Examples: + >>> from zvec.typing import MetricType, QuantizeType + >>> params = OmegaIndexParam( + ... metric_type=MetricType.COSINE, + ... m=16, + ... ef_construction=200, + ... quantize_type=QuantizeType.INT8 + ... ) + >>> print(params) + {'metric_type': 'IP', 'm': 16, 'ef_construction': 200, 'quantize_type': 'INT8'} + """ + + def __getstate__(self) -> tuple: ... + def __init__( + self, + metric_type: _zvec.typing.MetricType = ..., + m: typing.SupportsInt = 100, + ef_construction: typing.SupportsInt = 500, + quantize_type: _zvec.typing.QuantizeType = ..., + min_vector_threshold: typing.SupportsInt = 100000, + num_training_queries: typing.SupportsInt = 1000, + ef_training: typing.SupportsInt = 1000, + window_size: typing.SupportsInt = 100, + ef_groundtruth: typing.SupportsInt = 0, + k_train: typing.SupportsInt = 1, + ) -> None: + """ + Constructs an OmegaIndexParam instance. + + Args: + metric_type (MetricType, optional): Distance metric. Defaults to MetricType.IP. + m (int, optional): Number of bi-directional links. Defaults to 100. + ef_construction (int, optional): Candidate list size during construction. + Defaults to 500. + quantize_type (QuantizeType, optional): Vector quantization type. 
+ Defaults to QuantizeType.UNDEFINED. + min_vector_threshold (int, optional): Minimum doc count required to + enable OMEGA training. + num_training_queries (int, optional): Number of sampled queries for training. + ef_training (int, optional): Search ef used when collecting training records. + window_size (int, optional): Traversal window size. + ef_groundtruth (int, optional): ef used to compute training ground truth. + k_train (int, optional): Number of top GT results required for a + positive training label. + """ + + def __repr__(self) -> str: ... + def __setstate__(self, arg0: tuple) -> None: ... + def to_dict(self) -> dict: + """ + Convert to dictionary with all fields + """ + + @property + def ef_construction(self) -> int: + """ + int: Candidate list size during index construction. + """ + + @property + def m(self) -> int: + """ + int: Maximum number of neighbors per node in upper layers. + """ + + @property + def min_vector_threshold(self) -> int: + """ + int: Minimum vectors required to enable OMEGA optimization. + """ + + @property + def num_training_queries(self) -> int: + """ + int: Number of sampled queries used for OMEGA training. + """ + + @property + def ef_training(self) -> int: + """ + int: Search ef used when collecting training records. + """ + + @property + def window_size(self) -> int: + """ + int: Traversal window size used by OMEGA. + """ + + @property + def ef_groundtruth(self) -> int: + """ + int: ef used for ground truth computation (0 means brute force). + """ + + @property + def k_train(self) -> int: + """ + int: Number of top GT results required for a positive training label. + """ + +class OmegaQueryParam(HnswQueryParam): + """ + + Query parameters for OMEGA index with adaptive early stopping. + + OMEGA extends HNSW with machine learning-based early stopping that can + dynamically adjust search effort to meet a target recall. + + Attributes: + type (IndexType): Always ``IndexType.OMEGA``. 
+ ef (int): Size of the dynamic candidate list during search. + Larger values improve recall but slow down search. + Default is 300. + target_recall (float): Target recall for OMEGA early stopping. + OMEGA will stop searching when predicted recall meets this target. + Valid range: 0.0 to 1.0. Default is 0.95. + radius (float): Search radius for range queries. Default is 0.0. + is_linear (bool): Force linear search. Default is False. + is_using_refiner (bool): Whether to use refiner for the query. Default is False. + + Examples: + >>> params = OmegaQueryParam(ef=300, target_recall=0.98) + >>> print(params.target_recall) + 0.98 + """ + def __getstate__(self) -> tuple: ... + def __init__( + self, + ef: typing.SupportsInt = 300, + target_recall: typing.SupportsFloat = 0.95, + radius: typing.SupportsFloat = 0.0, + is_linear: bool = False, + is_using_refiner: bool = False, + ) -> None: + """ + Constructs an OmegaQueryParam instance. + + Args: + ef (int, optional): Search-time candidate list size. + Higher values improve accuracy. Defaults to 300. + target_recall (float, optional): Target recall for early stopping. + Valid range: 0.0 to 1.0. Defaults to 0.95. + radius (float, optional): Search radius for range queries. Default is 0.0. + is_linear (bool, optional): Force linear search. Default is False. + is_using_refiner (bool, optional): Whether to use refiner. Default is False. + """ + def __repr__(self) -> str: ... + def __setstate__(self, arg0: tuple) -> None: ... + @property + def target_recall(self) -> float: + """ + float: Target recall for OMEGA early stopping (0.0 to 1.0). + """ + class OptimizeOption: """ @@ -663,21 +841,29 @@ class OptimizeOption: concurrency (int): Number of threads to use during optimization. If 0, the system will choose an optimal value automatically. Default is 0. + retrain_only (bool): Reuse existing indexes and only retrain OMEGA + models. Default is False. 
Examples: - >>> opt = OptimizeOption(concurrency=2) + >>> opt = OptimizeOption(concurrency=2, retrain_only=True) >>> print(opt.concurrency) 2 """ def __getstate__(self) -> tuple: ... - def __init__(self, concurrency: typing.SupportsInt = 0) -> None: + def __init__( + self, + concurrency: typing.SupportsInt = 0, + retrain_only: bool = False, + ) -> None: """ Constructs an OptimizeOption instance. Args: concurrency (int, optional): Number of concurrent threads. 0 means auto-detect. Defaults to 0. + retrain_only (bool, optional): Reuse existing indexes and only + retrain OMEGA models. Defaults to False. """ def __setstate__(self, arg0: tuple) -> None: ... @@ -686,6 +872,11 @@ class OptimizeOption: """ int: Number of threads used for optimization (0 = auto). """ + @property + def retrain_only(self) -> bool: + """ + bool: Whether to reuse existing indexes and only retrain OMEGA. + """ class QueryParam: """ diff --git a/python/zvec/model/param/vector_query.py b/python/zvec/model/param/vector_query.py index 97d105af6..0e27211a1 100644 --- a/python/zvec/model/param/vector_query.py +++ b/python/zvec/model/param/vector_query.py @@ -17,7 +17,7 @@ from typing import Optional, Union from ...common import VectorType -from . import HnswQueryParam, IVFQueryParam +from . import HnswQueryParam, IVFQueryParam, OmegaQueryParam __all__ = ["VectorQuery"] @@ -28,7 +28,8 @@ class VectorQuery: A `VectorQuery` can be constructed using either a document ID (to look up its vector) or an explicit vector. It may optionally include index-specific - query parameters to control search behavior (e.g., `ef` for HNSW, `nprobe` for IVF). + query parameters to control search behavior (e.g., `ef` for HNSW, `nprobe` for IVF, + `target_recall` for OMEGA). Exactly one of `id` or `vector` should be provided. If both are given, behavior is implementation-defined (typically `id` takes precedence). @@ -37,25 +38,31 @@ class VectorQuery: field_name (str): Name of the vector field to query. 
id (Optional[str], optional): Document ID to fetch vector from. Default is None. vector (VectorType, optional): Explicit query vector. Default is None. - param (Optional[Union[HnswQueryParam, IVFQueryParam]], optional): + param (Optional[Union[HnswQueryParam, IVFQueryParam, OmegaQueryParam]], optional): Index-specific query parameters. Default is None. Examples: >>> import zvec >>> # Query by ID >>> q1 = zvec.VectorQuery(field_name="embedding", id="doc123") - >>> # Query by vector + >>> # Query by vector with HNSW params >>> q2 = zvec.VectorQuery( ... field_name="embedding", ... vector=[0.1, 0.2, 0.3], ... param=HnswQueryParam(ef=300) ... ) + >>> # Query with OMEGA params and target recall + >>> q3 = zvec.VectorQuery( + ... field_name="embedding", + ... vector=[0.1, 0.2, 0.3], + ... param=OmegaQueryParam(ef=300, target_recall=0.98) + ... ) """ field_name: str id: Optional[str] = None vector: VectorType = None - param: Optional[Union[HnswQueryParam, IVFQueryParam]] = None + param: Optional[Union[HnswQueryParam, IVFQueryParam, OmegaQueryParam]] = None def has_id(self) -> bool: """Check if the query is based on a document ID. diff --git a/python/zvec/model/schema/field_schema.py b/python/zvec/model/schema/field_schema.py index da193dd5c..1233d1599 100644 --- a/python/zvec/model/schema/field_schema.py +++ b/python/zvec/model/schema/field_schema.py @@ -23,6 +23,7 @@ HnswIndexParam, InvertIndexParam, IVFIndexParam, + OmegaIndexParam, ) from zvec.typing import DataType @@ -188,19 +189,31 @@ class VectorSchema: data_type (DataType): Vector data type (e.g., VECTOR_FP32, VECTOR_INT8). dimension (int, optional): Dimensionality of the vector. Must be > 0 for dense vectors; may be `None` for sparse vectors. - index_param (Union[HnswIndexParam, IVFIndexParam, FlatIndexParam], optional): + index_param (Union[HnswIndexParam, IVFIndexParam, FlatIndexParam, OmegaIndexParam], optional): Index configuration for this vector field. Defaults to - ``HnswIndexParam()``. 
+ ``FlatIndexParam()``. Examples: - >>> from zvec.typing import DataType - >>> from zvec.model.param import HnswIndexParam - >>> emb_field = VectorSchema( + >>> from zvec.typing import DataType, MetricType + >>> from zvec.model.param import HnswIndexParam, OmegaIndexParam + >>> # HNSW index + >>> hnsw_field = VectorSchema( ... name="embedding", ... data_type=DataType.VECTOR_FP32, ... dimension=128, ... index_param=HnswIndexParam(ef_construction=200, m=16) ... ) + >>> # OMEGA index (adaptive graph-based index with automatic training) + >>> omega_field = VectorSchema( + ... name="embedding", + ... data_type=DataType.VECTOR_FP32, + ... dimension=128, + ... index_param=OmegaIndexParam( + ... metric_type=MetricType.COSINE, + ... m=16, + ... ef_construction=200 + ... ) + ... ) """ def __init__( @@ -209,7 +222,7 @@ def __init__( data_type: DataType, dimension: Optional[int] = 0, index_param: Optional[ - Union[HnswIndexParam, FlatIndexParam, IVFIndexParam] + Union[HnswIndexParam, FlatIndexParam, IVFIndexParam, OmegaIndexParam] ] = None, ): if name is None or not isinstance(name, str): @@ -263,8 +276,10 @@ def dimension(self) -> int: return self._cpp_obj.dimension @property - def index_param(self) -> Union[HnswIndexParam, IVFIndexParam, FlatIndexParam]: - """Union[HnswIndexParam, IVFIndexParam, FlatIndexParam]: Index configuration for the vector.""" + def index_param( + self, + ) -> Union[HnswIndexParam, IVFIndexParam, FlatIndexParam, OmegaIndexParam]: + """Union[HnswIndexParam, IVFIndexParam, FlatIndexParam, OmegaIndexParam]: Index configuration for the vector.""" return self._cpp_obj.index_param def __dict__(self) -> dict[str, Any]: diff --git a/scripts/benchmark_hnsw_vs_omega.json b/scripts/benchmark_hnsw_vs_omega.json new file mode 100644 index 000000000..217c526d7 --- /dev/null +++ b/scripts/benchmark_hnsw_vs_omega.json @@ -0,0 +1,62 @@ +{ + "cohere_1m": { + "common": { + "case_type": "Performance768D1M", + "num_concurrency": "12,14,16,18,20", + "concurrency_duration": 
30, + "k": 100, + "m": 64, + "ef_search": 300, + "quantize_type": "int8" + }, + "hnsw": { + "path": "cohere_1m_hnsw", + "db_label": "16c64g-v0.1-hnsw-m64-ef300", + "args": {} + }, + "omega": { + "path": "cohere_1m_omega", + "db_label": "16c64g-v0.1-omega-m64-ef300", + "target_recalls": [ + 0.92 + ], + "args": { + "min_vector_threshold": 100000, + "num_training_queries": 4000, + "ef_training": 500, + "window_size": 100, + "ef_groundtruth": 1000 + } + } + }, + "cohere_10m": { + "common": { + "case_type": "Performance768D10M", + "num_concurrency": "12,14,16,18,20", + "concurrency_duration": 30, + "k": 100, + "m": 64, + "ef_search": 600, + "quantize_type": "int8" + }, + "hnsw": { + "path": "cohere_10m_hnsw", + "db_label": "16c64g-v0.1-hnsw-m64-ef300", + "args": {} + }, + "omega": { + "path": "cohere_10m_omega", + "db_label": "16c64g-v0.1-omega-m64-ef300", + "target_recalls": [ + 0.92 + ], + "args": { + "min_vector_threshold": 100000, + "num_training_queries": 4000, + "ef_training": 1000, + "window_size": 100, + "ef_groundtruth": 2000 + } + } + } +} diff --git a/scripts/benchmark_hnsw_vs_omega.py b/scripts/benchmark_hnsw_vs_omega.py new file mode 100644 index 000000000..31b9aee04 --- /dev/null +++ b/scripts/benchmark_hnsw_vs_omega.py @@ -0,0 +1,452 @@ +from __future__ import annotations + +import argparse +import sys +from pathlib import Path +from typing import Any + +from benchmark_lib import ( + BenchmarkResult, + build_index, + compute_recall_with_zvec, + discover_index_files, + emit, + get_offline_load_duration, + load_dataset_config, + must_get, + prepare_dataset_artifacts, + print_header, + resolve_core_tools, + resolve_dataset_spec, + resolve_index_path, + resolve_paths, + run_concurrency_benchmark, + write_grouped_online_summaries, + write_offline_summary, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark Zvec HNSW vs OMEGA") + parser.add_argument("--config", required=True, help="Path to benchmark JSON 
config") + parser.add_argument( + "--dataset", + required=True, + help="Dataset key to run from the top-level JSON config map", + ) + parser.add_argument( + "--target-recalls", + type=str, + default=None, + help="Optional comma-separated override for omega.target_recalls in the JSON config", + ) + parser.add_argument( + "--dry-run", action="store_true", help="Print actions without executing" + ) + parser.add_argument("--skip-hnsw", action="store_true", help="Skip HNSW benchmark") + parser.add_argument( + "--skip-omega", action="store_true", help="Skip OMEGA benchmark" + ) + parser.add_argument( + "--build-only", action="store_true", help="Only build index, skip search" + ) + parser.add_argument( + "--search-only", action="store_true", help="Only run search on existing index" + ) + parser.add_argument( + "--retrain-only", + action="store_true", + help="Reuse existing OMEGA index and only retrain the model during the build phase", + ) + parser.add_argument( + "--zvec-root", + type=str, + default=None, + help="Path to zvec repo root (default: auto-detect from this script)", + ) + parser.add_argument( + "--benchmark-dir", + type=str, + default=None, + help="Directory used to store benchmark artifacts", + ) + parser.add_argument( + "--dataset-root", + type=str, + default=None, + help="Root directory containing the raw dataset files", + ) + return parser.parse_args() + + +def run_hnsw( + *, + args: argparse.Namespace, + dataset_spec: dict[str, object], + dataset_artifacts: dict[str, object], + bench_bin: Path, + hnsw_path: Path, + hnsw_db_label: str, + common: dict[str, object], + hnsw_config: dict[str, object], +) -> BenchmarkResult: + print_header("HNSW Benchmark") + + hnsw_specific_args = hnsw_config.get("args", {}) + if not args.search_only: + emit("\n[Phase 1] Building HNSW index...") + offline_metrics = build_index( + index_kind="HNSW", + index_path=hnsw_path, + dataset_spec=dataset_spec, + dataset_artifacts=dataset_artifacts, + common_args=common, + 
specific_args=hnsw_specific_args, + retrain_only=False, + dry_run=args.dry_run, + ) + if not args.dry_run: + write_offline_summary(hnsw_path, hnsw_db_label, offline_metrics) + + if args.build_only: + return BenchmarkResult( + type="HNSW", + path=str(hnsw_path), + success=True, + target_recall=None, + load_duration=get_offline_load_duration(hnsw_path), + ) + + index_files = discover_index_files(hnsw_path) + recall = compute_recall_with_zvec( + index_kind="HNSW", + index_path=hnsw_path, + dataset_artifacts=dataset_artifacts, + common_args=common, + target_recall=None, + dry_run=args.dry_run, + ) + + benchmark = run_concurrency_benchmark( + bench_bin=bench_bin, + index_files=index_files, + dataset_artifacts=dataset_artifacts, + dataset_spec=dataset_spec, + common_args=common, + target_recall=None, + dry_run=args.dry_run, + ) + online = benchmark["summary"] + success = online.get("retcode", 0) == 0 + + return BenchmarkResult( + type="HNSW", + path=str(hnsw_path), + success=success, + target_recall=None, + load_duration=get_offline_load_duration(hnsw_path), + qps=online.get("qps"), + avg_latency_ms=online.get("avg_latency_ms"), + p50_latency_ms=online.get("p50_latency_ms"), + p90_latency_ms=online.get("p90_latency_ms"), + p95_latency_ms=online.get("p95_latency_ms"), + p99_latency_ms=online.get("p99_latency_ms"), + recall=recall, + ) + + +def run_omega( + *, + args: argparse.Namespace, + dataset_spec: dict[str, object], + dataset_artifacts: dict[str, object], + bench_bin: Path, + omega_path: Path, + omega_db_label: str, + common: dict[str, object], + omega_config: dict[str, object], + target_recalls: list[float], +) -> list[BenchmarkResult]: + print_header("OMEGA Benchmark") + + omega_specific_args = omega_config.get("args", {}) + if not args.search_only: + phase_message = ( + "\n[Phase 1] Retraining OMEGA model only (reusing existing index)..." + if args.retrain_only + else "\n[Phase 1] Building OMEGA index + training model..." 
+ ) + emit(phase_message) + offline_metrics = build_index( + index_kind="OMEGA", + index_path=omega_path, + dataset_spec=dataset_spec, + dataset_artifacts=dataset_artifacts, + common_args=common, + specific_args=omega_specific_args, + retrain_only=args.retrain_only, + dry_run=args.dry_run, + ) + if not args.dry_run: + write_offline_summary( + omega_path, + omega_db_label, + offline_metrics, + retrain_only=args.retrain_only, + ) + + if args.build_only: + return [ + BenchmarkResult( + type="OMEGA", + path=str(omega_path), + success=True, + target_recall=target_recall, + load_duration=get_offline_load_duration(omega_path), + ) + for target_recall in target_recalls + ] + + return _run_omega_searches( + args=args, + dataset_spec=dataset_spec, + dataset_artifacts=dataset_artifacts, + bench_bin=bench_bin, + omega_path=omega_path, + common=common, + target_recalls=target_recalls, + ) + + +def _run_omega_searches( + *, + args: argparse.Namespace, + dataset_spec: dict[str, object], + dataset_artifacts: dict[str, object], + bench_bin: Path, + omega_path: Path, + common: dict[str, object], + target_recalls: list[float], +) -> list[BenchmarkResult]: + results: list[BenchmarkResult] = [] + index_files = discover_index_files(omega_path) + omega_common = dict(common) + omega_common["k"] = int(common["k"]) + + for target_recall in target_recalls: + print_header(f"OMEGA Search Benchmark (target_recall={target_recall})") + recall = compute_recall_with_zvec( + index_kind="OMEGA", + index_path=omega_path, + dataset_artifacts=dataset_artifacts, + common_args=common, + target_recall=target_recall, + dry_run=args.dry_run, + ) + benchmark = run_concurrency_benchmark( + bench_bin=bench_bin, + index_files=index_files, + dataset_artifacts=dataset_artifacts, + dataset_spec=dataset_spec, + common_args=omega_common, + target_recall=target_recall, + dry_run=args.dry_run, + ) + online = benchmark["summary"] + results.append( + BenchmarkResult( + type="OMEGA", + path=str(omega_path), + 
success=online.get("retcode", 0) == 0, + target_recall=target_recall, + load_duration=get_offline_load_duration(omega_path), + qps=online.get("qps"), + avg_latency_ms=online.get("avg_latency_ms"), + p50_latency_ms=online.get("p50_latency_ms"), + p90_latency_ms=online.get("p90_latency_ms"), + p95_latency_ms=online.get("p95_latency_ms"), + p99_latency_ms=online.get("p99_latency_ms"), + recall=recall, + ) + ) + + return results + + +def _parse_target_recalls( + args: argparse.Namespace, omega_config: dict[str, object] +) -> list[float]: + target_recalls = omega_config.get("target_recalls", []) + if args.target_recalls: + target_recalls = [ + float(value) for value in args.target_recalls.split(",") if value + ] + if not target_recalls: + raise ValueError("omega.target_recalls must be a non-empty list") + return list(target_recalls) + + +def _emit_run_header( + *, + dataset_name: str, + config_path: Path, + zvec_root: Path, + benchmark_dir: Path, + dataset_spec: dict[str, object], + bench_bin: Path, + recall_bin: Path, + hnsw_path: Path, + omega_path: Path, + target_recalls: list[float], +) -> None: + emit("=" * 70) + emit(f"Zvec HNSW vs OMEGA ({dataset_name})") + emit(f"Config: {config_path}") + emit("=" * 70) + emit(f"zvec_root: {zvec_root}") + emit(f"benchmark_dir: {benchmark_dir}") + emit(f"dataset_dir: {dataset_spec['dataset_dir']}") + emit(f"bench_bin: {bench_bin}") + emit(f"recall_bin: {recall_bin}") + emit(f"hnsw_path: {hnsw_path}") + emit(f"omega_path: {omega_path}") + emit(f"target_recalls: {target_recalls}") + emit("=" * 70) + + +def _format_summary_row(result: BenchmarkResult) -> str: + tr = f"{result.target_recall:.2f}" if result.target_recall is not None else "N/A" + status = "OK" if result.success else "FAILED" + ld = f"{result.load_duration:.1f}" if result.load_duration is not None else "N/A" + qps = f"{result.qps:.1f}" if result.qps is not None else "N/A" + avg_latency = ( + f"{result.avg_latency_ms:.3f}" if result.avg_latency_ms is not None else "N/A" 
+ ) + p95_latency = ( + f"{result.p95_latency_ms:.3f}" if result.p95_latency_ms is not None else "N/A" + ) + recall = f"{result.recall:.4f}" if result.recall is not None else "N/A" + return ( + f"{result.type:<10} {tr:<15} {ld:<12} {qps:<8} " + f"{avg_latency:<16} {p95_latency:<16} {recall:<10} {status:<10}" + ) + + +def _emit_result_summary( + results: list[BenchmarkResult], summary_paths: list[Path] +) -> None: + emit(f"\n\n{'=' * 70}") + emit("Benchmark Summary") + emit("=" * 70) + emit( + f"{'Type':<10} {'target_recall':<15} {'load_dur(s)':<12} " + f"{'qps':<8} {'avg_latency(ms)':<16} {'p95_latency(ms)':<16} " + f"{'recall':<10} {'Status':<10}" + ) + emit("-" * 100) + for result in results: + emit(_format_summary_row(result)) + + emit() + for path in summary_paths: + emit(f"Summary JSON: {path}") + + +def _resolve_run_context(args: argparse.Namespace) -> dict[str, Any]: + config_path = Path(args.config).expanduser().resolve() + config = load_dataset_config(config_path, args.dataset) + zvec_root, benchmark_dir = resolve_paths( + Path(__file__).resolve(), + config, + args.zvec_root, + args.benchmark_dir, + ) + dataset_spec = resolve_dataset_spec(args.dataset, config, args.dataset_root) + dataset_artifacts = prepare_dataset_artifacts( + args.dataset, dataset_spec, benchmark_dir, dry_run=args.dry_run + ) + bench_bin, recall_bin = resolve_core_tools(zvec_root) + benchmark_dir.mkdir(parents=True, exist_ok=True) + + hnsw_config = must_get(config, "hnsw") + omega_config = must_get(config, "omega") + + return { + "config_path": config_path, + "dataset_name": args.dataset, + "dataset_spec": dataset_spec, + "dataset_artifacts": dataset_artifacts, + "zvec_root": zvec_root, + "benchmark_dir": benchmark_dir, + "bench_bin": bench_bin, + "recall_bin": recall_bin, + "common": must_get(config, "common"), + "hnsw_config": hnsw_config, + "omega_config": omega_config, + "hnsw_path": resolve_index_path(benchmark_dir, must_get(hnsw_config, "path")), + "omega_path": 
resolve_index_path(benchmark_dir, must_get(omega_config, "path")), + "hnsw_db_label": must_get(hnsw_config, "db_label"), + "omega_db_label": must_get(omega_config, "db_label"), + "target_recalls": _parse_target_recalls(args, omega_config), + } + + +def main() -> int: + args = parse_args() + context = _resolve_run_context(args) + _emit_run_header( + dataset_name=context["dataset_name"], + config_path=context["config_path"], + zvec_root=context["zvec_root"], + benchmark_dir=context["benchmark_dir"], + dataset_spec=context["dataset_spec"], + bench_bin=context["bench_bin"], + recall_bin=context["recall_bin"], + hnsw_path=context["hnsw_path"], + omega_path=context["omega_path"], + target_recalls=context["target_recalls"], + ) + + results: list[BenchmarkResult] = [] + if not args.skip_hnsw: + results.append( + run_hnsw( + args=args, + dataset_spec=context["dataset_spec"], + dataset_artifacts=context["dataset_artifacts"], + bench_bin=context["bench_bin"], + hnsw_path=context["hnsw_path"], + hnsw_db_label=context["hnsw_db_label"], + common=context["common"], + hnsw_config=context["hnsw_config"], + ) + ) + + if not args.skip_omega: + results.extend( + run_omega( + args=args, + dataset_spec=context["dataset_spec"], + dataset_artifacts=context["dataset_artifacts"], + bench_bin=context["bench_bin"], + omega_path=context["omega_path"], + omega_db_label=context["omega_db_label"], + common=context["common"], + omega_config=context["omega_config"], + target_recalls=context["target_recalls"], + ) + ) + + if results: + summary_paths = ( + write_grouped_online_summaries(context["dataset_name"], results) + if not args.dry_run + else [] + ) + _emit_result_summary(results, summary_paths) + + return 0 if all(result.success for result in results) else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/benchmark_lib.py b/scripts/benchmark_lib.py new file mode 100644 index 000000000..2499b7002 --- /dev/null +++ b/scripts/benchmark_lib.py @@ -0,0 +1,1222 @@ +from 
__future__ import annotations + +import contextlib +import json +import os +import re +import shutil +import subprocess +import sys +import time +import urllib.request +from dataclasses import dataclass +from datetime import datetime +from functools import lru_cache +from pathlib import Path +from typing import Any + +try: + import polars as pl +except ImportError: + pl = None + +try: + import zvec +except ImportError: + zvec = None + + +@dataclass +class BenchmarkResult: + type: str + path: str + success: bool + target_recall: float | None + load_duration: float | None = None + qps: float | None = None + recall: float | None = None + avg_latency_ms: float | None = None + p50_latency_ms: float | None = None + p90_latency_ms: float | None = None + p95_latency_ms: float | None = None + p99_latency_ms: float | None = None + + +KV_PATTERN = re.compile(r"([A-Za-z_]+)=([^\s,]+)") +AVG_LINE_PATTERN = re.compile(r"Avg latency:\s*([0-9.]+)ms qps:\s*([0-9.]+)") +PERCENTILE_PATTERN = re.compile(r"(\d+)\s+Percentile:\s*([0-9.]+)\s+ms") +PROCESS_LINE_PATTERN = re.compile( + r"Process query:\s*(\d+), total process time:\s*(\d+)ms, duration:\s*(\d+)ms" +) +RECALL_PATTERN = re.compile(r"Recall@(\d+):\s*([0-9.]+)") + +DATASET_SPECS: dict[str, dict[str, Any]] = { + "cohere_1m": { + "dataset_dirname": "cohere/cohere_medium_1m", + "remote_dirname": "cohere_medium_1m", + "dimension": 768, + "metric_type": "COSINE", + "train_files": ["shuffle_train.parquet"], + }, + "cohere_10m": { + "dataset_dirname": "cohere/cohere_large_10m", + "remote_dirname": "cohere_large_10m", + "dimension": 768, + "metric_type": "COSINE", + "train_files": [f"shuffle_train-{idx:02d}-of-10.parquet" for idx in range(10)], + }, +} + +DATASET_DOWNLOAD_BASE_URLS = { + "S3": "https://assets.zilliz.com/benchmark", + "ALIYUNOSS": "https://assets.zilliz.com.cn/benchmark", +} + + +def load_json(path: Path) -> dict[str, Any]: + with path.open() as f: + return json.load(f) + + +def load_dataset_config(path: Path, 
dataset_name: str) -> dict[str, Any]: + root = load_json(path) + if dataset_name not in root: + available = ", ".join(sorted(root.keys())) + raise ValueError( + f"Dataset '{dataset_name}' not found in {path}. Available datasets: {available}" + ) + dataset_config = root[dataset_name] + if not isinstance(dataset_config, dict): + raise ValueError(f"Dataset config for '{dataset_name}' must be a JSON object") + return dataset_config + + +def must_get(config: dict[str, Any], key: str) -> Any: + if key not in config: + raise KeyError(f"Missing required config key: {key}") + return config[key] + + +def print_header(title: str) -> None: + emit(f"\n{'=' * 70}") + emit(title) + emit("=" * 70) + + +def emit(message: str = "") -> None: + sys.stdout.write(f"{message}\n") + + +def resolve_paths( + script_path: Path, + config: dict[str, Any], + zvec_root_arg: str | None, + benchmark_dir_arg: str | None, +) -> tuple[Path, Path]: + zvec_root = ( + Path(zvec_root_arg).resolve() if zvec_root_arg else script_path.parent.parent + ) + + config_benchmark_dir = config.get("benchmark_dir") + if benchmark_dir_arg: + benchmark_dir = Path(benchmark_dir_arg).resolve() + elif config_benchmark_dir: + benchmark_dir = Path(config_benchmark_dir).expanduser().resolve() + else: + benchmark_dir = (zvec_root / "benchmark_results").resolve() + + return zvec_root, benchmark_dir + + +def resolve_index_path(benchmark_dir: Path, configured_path: str) -> Path: + path = Path(configured_path).expanduser() + return path.resolve() if path.is_absolute() else (benchmark_dir / path).resolve() + + +def parse_scalar(value: str) -> Any: + lower = value.lower() + if lower in {"true", "false"}: + return lower == "true" + try: + if any(ch in value for ch in [".", "e", "E"]): + return float(value) + return int(value) + except ValueError: + return value + + +def parse_key_values(line: str) -> dict[str, Any]: + return {key: parse_scalar(value) for key, value in KV_PATTERN.findall(line)} + + +def avg_metric(records: 
list[dict[str, Any]], key: str) -> float | None: + values = [float(record[key]) for record in records if key in record] + if not values: + return None + return sum(values) / len(values) + + +def percentile_metric( + records: list[dict[str, Any]], key: str, percentile: float +) -> float | None: + values = sorted(float(record[key]) for record in records if key in record) + if not values: + return None + if len(values) == 1: + return values[0] + rank = (len(values) - 1) * percentile / 100.0 + lower = int(rank) + upper = min(lower + 1, len(values) - 1) + if lower == upper: + return values[lower] + weight = rank - lower + return values[lower] * (1.0 - weight) + values[upper] * weight + + +def parse_query_records(output: str, prefix: str) -> list[dict[str, Any]]: + records = [] + for line in output.splitlines(): + if prefix in line: + records.append(parse_key_values(line)) + return records + + +def parse_bench_output(output: str) -> dict[str, Any]: + metrics: dict[str, Any] = {} + for line in output.splitlines(): + if (match := PROCESS_LINE_PATTERN.search(line)) is not None: + metrics["process_query_count"] = int(match.group(1)) + metrics["total_process_time_ms"] = int(match.group(2)) + metrics["duration_ms"] = int(match.group(3)) + elif (match := AVG_LINE_PATTERN.search(line)) is not None: + metrics["avg_latency_ms"] = float(match.group(1)) + metrics["qps"] = float(match.group(2)) + elif (match := PERCENTILE_PATTERN.search(line)) is not None: + metrics[f"p{match.group(1)}_latency_ms"] = float(match.group(2)) + return metrics + + +def parse_recall_output(output: str, topk: int) -> float | None: + recall_by_k: dict[int, float] = {} + for line in output.splitlines(): + match = RECALL_PATTERN.search(line) + if match is not None: + recall_by_k[int(match.group(1))] = float(match.group(2)) + if topk in recall_by_k: + return recall_by_k[topk] + if recall_by_k: + return recall_by_k[max(recall_by_k)] + return None + + +def online_summary_path(index_path: Path) -> Path: + return 
index_path / "online_benchmark_summary.json" + + +def write_online_summary(index_path: Path, payload: dict[str, Any]) -> None: + with online_summary_path(index_path).open("w") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +def write_grouped_online_summaries( + dataset: str, results: list[BenchmarkResult] +) -> list[Path]: + written_paths: list[Path] = [] + grouped: dict[str, list[BenchmarkResult]] = {} + for result in results: + grouped.setdefault(result.path, []).append(result) + + for path_str, grouped_results in grouped.items(): + index_path = Path(path_str) + write_online_summary( + index_path, + { + "generated_at": datetime.now().isoformat(), + "dataset": dataset, + "results": [ + { + "type": result.type, + "target_recall": result.target_recall, + "path": result.path, + "load_duration_s": result.load_duration, + "qps": result.qps, + "avg_latency_ms": result.avg_latency_ms, + "p50_latency_ms": result.p50_latency_ms, + "p90_latency_ms": result.p90_latency_ms, + "p95_latency_ms": result.p95_latency_ms, + "p99_latency_ms": result.p99_latency_ms, + "recall": result.recall, + } + for result in grouped_results + ], + }, + ) + written_paths.append(online_summary_path(index_path)) + + return written_paths + + +def offline_summary_path(index_path: Path) -> Path: + return index_path / "offline_benchmark_summary.json" + + +def read_json_if_exists(path: Path) -> dict[str, Any]: + if not path.exists(): + return {} + try: + with path.open() as f: + return json.load(f) + except Exception: + return {} + + +def find_omega_model_dir(index_path: Path) -> Path | None: + candidates = sorted(index_path.glob("*/omega_model")) + return candidates[0] if candidates else None + + +def sum_timing_ms(data: dict[str, Any]) -> int: + return sum(v for v in data.values() if isinstance(v, (int, float))) + + +def build_offline_summary( + index_path: Path, + db_label: str, + metrics: dict[str, Any], + retrain_only: bool = False, +) -> dict[str, Any]: + previous_summary = ( + 
read_json_if_exists(offline_summary_path(index_path)) if retrain_only else {} + ) + previous_offline = previous_summary.get("offline", {}) + previous_omega_training = previous_summary.get("omega_training", {}) + + insert_duration = metrics.get("insert_duration") + optimize_duration = metrics.get("optimize_duration") + load_duration = metrics.get("load_duration") + + omega_model_dir = find_omega_model_dir(index_path) + omega_training = {} + if omega_model_dir is not None: + omega_training = { + "collection_timing_ms": read_json_if_exists( + omega_model_dir / "training_collection_timing.json" + ), + "lightgbm_timing_ms": read_json_if_exists( + omega_model_dir / "lightgbm_training_timing.json" + ), + "lightgbm_training_metrics": read_json_if_exists( + omega_model_dir / "lightgbm_training_metrics.json" + ), + } + + if retrain_only: + insert_duration = previous_offline.get("insert_duration_s") + old_optimize_duration = previous_offline.get("optimize_duration_s") + old_training_s = ( + sum_timing_ms(previous_omega_training.get("collection_timing_ms", {})) + + sum_timing_ms(previous_omega_training.get("lightgbm_timing_ms", {})) + ) / 1000.0 + new_training_s = ( + sum_timing_ms(omega_training.get("collection_timing_ms", {})) + + sum_timing_ms(omega_training.get("lightgbm_timing_ms", {})) + ) / 1000.0 + if old_optimize_duration is not None: + optimize_duration = round( + old_optimize_duration - old_training_s + new_training_s, 4 + ) + load_duration = ( + round(insert_duration + optimize_duration, 4) + if insert_duration is not None and optimize_duration is not None + else None + ) + + summary = { + "db_label": db_label, + "index_path": str(index_path), + "generated_at": datetime.now().isoformat(), + "offline": { + "insert_duration_s": insert_duration, + "optimize_duration_s": optimize_duration, + "load_duration_s": load_duration, + }, + } + if omega_training: + summary["omega_training"] = omega_training + return summary + + +def write_offline_summary( + index_path: Path, 
db_label: str, metrics: dict[str, Any], retrain_only: bool = False +) -> Path: + summary = build_offline_summary( + index_path, db_label, metrics, retrain_only=retrain_only + ) + path = offline_summary_path(index_path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + json.dump(summary, f, indent=2, sort_keys=True) + return path + + +def get_offline_load_duration(index_path: Path) -> float | None: + return ( + read_json_if_exists(offline_summary_path(index_path)) + .get("offline", {}) + .get("load_duration_s") + ) + + +def resolve_dataset_spec( + dataset_name: str, config: dict[str, Any], dataset_root_arg: str | None +) -> dict[str, Any]: + default = DATASET_SPECS.get(dataset_name, {}) + dataset_root = None + if dataset_root_arg: + dataset_root = Path(dataset_root_arg).expanduser().resolve() + elif config.get("dataset_root"): + dataset_root = Path(config["dataset_root"]).expanduser().resolve() + elif os.environ.get("DATASET_LOCAL_DIR"): + dataset_root = Path(os.environ["DATASET_LOCAL_DIR"]).expanduser().resolve() + else: + dataset_root = Path("/tmp/zvec/dataset").resolve() + + dataset_dirname = config.get("dataset_dirname", default.get("dataset_dirname")) + if not dataset_dirname: + raise ValueError( + f"Dataset directory name is not configured for {dataset_name}." 
+ ) + + dimension = int(config.get("dimension", default.get("dimension", 0))) + metric_type = str( + config.get("metric_type", default.get("metric_type", "COSINE")) + ).upper() + remote_dirname = str( + config.get("remote_dirname", default.get("remote_dirname", "")) + ) + train_files = list(config.get("train_files", default.get("train_files", []))) + dataset_source = str( + config.get("dataset_source", os.environ.get("ZVEC_DATASET_SOURCE", "S3")) + ) + download_base_url = str( + config.get( + "dataset_base_url", + os.environ.get( + "ZVEC_DATASET_BASE_URL", + DATASET_DOWNLOAD_BASE_URLS.get( + dataset_source.upper(), DATASET_DOWNLOAD_BASE_URLS["S3"] + ), + ), + ) + ) + if dimension <= 0: + raise ValueError(f"Missing dataset dimension for {dataset_name}") + + dataset_dir = (dataset_root / dataset_dirname).resolve() + return { + "dataset_root": dataset_root, + "dataset_dir": dataset_dir, + "dimension": dimension, + "metric_type": metric_type, + "remote_dirname": remote_dirname, + "train_files": train_files, + "download_base_url": download_base_url.rstrip("/"), + } + + +def _require_polars(): + if pl is None: + raise RuntimeError( + "This script requires polars in the active Python environment." + ) + return pl + + +def _require_zvec(): + if zvec is None: + raise RuntimeError( + "This script requires zvec in the active Python environment." 
+ ) + return zvec + + +def _sorted_train_files(dataset_dir: Path) -> list[Path]: + candidates: list[Path] = [] + for pattern in [ + "shuffle_train-*.parquet", + "train-*.parquet", + "shuffle_train.parquet", + "train.parquet", + ]: + candidates.extend(sorted(dataset_dir.glob(pattern))) + unique: list[Path] = [] + seen: set[Path] = set() + for path in candidates: + if path not in seen: + seen.add(path) + unique.append(path) + return unique + + +def _dataset_required_files( + dataset_name: str, dataset_spec: dict[str, Any] +) -> list[str]: + required = list(dataset_spec.get("train_files", [])) + if not required: + raise ValueError( + f"Dataset {dataset_name} does not define train_files for auto-download" + ) + required.extend(["test.parquet", "neighbors.parquet"]) + return required + + +def _download_file(url: str, output_path: Path) -> None: + output_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = output_path.with_suffix(output_path.suffix + ".tmp") + try: + with urllib.request.urlopen(url) as response, tmp_path.open("wb") as out: + shutil.copyfileobj(response, out) + tmp_path.replace(output_path) + finally: + tmp_path.unlink(missing_ok=True) + + +def ensure_dataset_available( + dataset_name: str, dataset_spec: dict[str, Any], dry_run: bool +) -> None: + dataset_dir = dataset_spec["dataset_dir"] + required_files = _dataset_required_files(dataset_name, dataset_spec) + missing_files = [ + name for name in required_files if not (dataset_dir / name).exists() + ] + if not missing_files: + return + + remote_dirname = dataset_spec.get("remote_dirname") + if not remote_dirname: + raise FileNotFoundError( + f"Dataset directory is incomplete and auto-download is not configured: {dataset_dir}" + ) + + base_url = dataset_spec["download_base_url"] + emit( + f"Dataset files missing under {dataset_dir}, downloading from {base_url}/{remote_dirname} ..." 
+ ) + if dry_run: + for name in missing_files: + emit( + f"[Dry-run] download {base_url}/{remote_dirname}/{name} -> {dataset_dir / name}" + ) + return + + for name in missing_files: + url = f"{base_url}/{remote_dirname}/{name}" + output_path = dataset_dir / name + emit(f"Downloading {url}") + _download_file(url, output_path) + + +def prepare_dataset_artifacts( + dataset_name: str, + dataset_spec: dict[str, Any], + benchmark_dir: Path, + dry_run: bool = False, +) -> dict[str, Path]: + dataset_dir = dataset_spec["dataset_dir"] + query_parquet = dataset_dir / "test.parquet" + gt_parquet = dataset_dir / "neighbors.parquet" + ensure_dataset_available(dataset_name, dataset_spec, dry_run) + train_files = _sorted_train_files(dataset_dir) + if not dry_run: + if not dataset_dir.exists(): + raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}") + if not query_parquet.exists(): + raise FileNotFoundError(f"Missing query parquet: {query_parquet}") + if not gt_parquet.exists(): + raise FileNotFoundError(f"Missing ground-truth parquet: {gt_parquet}") + if not train_files: + raise FileNotFoundError( + f"No train parquet files found under: {dataset_dir}" + ) + + cache_dir = (benchmark_dir / "_dataset_cache" / dataset_name).resolve() + query_txt = cache_dir / "query.txt" + gt_txt = cache_dir / "groundtruth.txt" + cache_dir.mkdir(parents=True, exist_ok=True) + + if not dry_run: + refresh_query = ( + not query_txt.exists() + or query_txt.stat().st_mtime < query_parquet.stat().st_mtime + ) + refresh_gt = ( + not gt_txt.exists() + ) or gt_txt.stat().st_mtime < gt_parquet.stat().st_mtime + if refresh_query: + _write_query_text(query_parquet, query_txt) + if refresh_gt: + _write_groundtruth_text(gt_parquet, gt_txt) + + return { + "dataset_dir": dataset_dir, + "query_parquet": query_parquet, + "gt_parquet": gt_parquet, + "query_txt": query_txt, + "groundtruth_txt": gt_txt, + "train_files": train_files, + } + + +def _write_query_text(query_parquet: Path, output_path: Path) -> 
None: + polars = _require_polars() + frame = polars.read_parquet(query_parquet).sort("id") + with output_path.open("w") as f: + for row in frame.iter_rows(named=True): + vector = row["emb"] + vector_text = " ".join(str(round(float(v), 16)) for v in vector) + f.write(f"{int(row['id'])};{vector_text};\n") + + +def _write_groundtruth_text(gt_parquet: Path, output_path: Path) -> None: + polars = _require_polars() + frame = polars.read_parquet(gt_parquet).sort("id") + with output_path.open("w") as f: + for row in frame.iter_rows(named=True): + neighbors = " ".join(str(int(v)) for v in row["neighbors_id"]) + f.write(f"{int(row['id'])};{neighbors}\n") + + +@lru_cache(maxsize=1) +def _initialized_zvec(): + module = _require_zvec() + module.init(log_level=module.LogLevel.WARN) + return module + + +def _quantize_type_from_name(name: str): + module = _initialized_zvec() + normalized = str(name).upper() + mapping = { + "": module.QuantizeType.UNDEFINED, + "UNDEFINED": module.QuantizeType.UNDEFINED, + "FP16": module.QuantizeType.FP16, + "INT8": module.QuantizeType.INT8, + "INT4": module.QuantizeType.INT4, + } + if normalized not in mapping: + raise ValueError(f"Unsupported quantize type: {name}") + return mapping[normalized] + + +def _metric_type_from_name(name: str): + module = _initialized_zvec() + normalized = str(name).upper() + mapping = { + "COSINE": module.MetricType.COSINE, + "IP": module.MetricType.IP, + "L2": module.MetricType.L2, + } + if normalized not in mapping: + raise ValueError(f"Unsupported metric type: {name}") + return mapping[normalized] + + +def _maybe_destroy_collection(path: Path) -> None: + module = _initialized_zvec() + if not path.exists(): + return + try: + module.open(str(path)).destroy() + return + except Exception: + pass + shutil.rmtree(path, ignore_errors=True) + + +def _build_schema( + index_kind: str, + dimension: int, + metric_type: str, + common_args: dict[str, Any], + specific_args: dict[str, Any], +): + module = _initialized_zvec() + 
quantize_type = _quantize_type_from_name(common_args.get("quantize_type", "")) + metric = _metric_type_from_name(metric_type) + if index_kind == "OMEGA": + index_param = module.OmegaIndexParam( + metric_type=metric, + m=int(common_args["m"]), + ef_construction=int(specific_args.get("ef_construction", 500)), + quantize_type=quantize_type, + min_vector_threshold=int(specific_args["min_vector_threshold"]), + num_training_queries=int(specific_args["num_training_queries"]), + ef_training=int(specific_args["ef_training"]), + window_size=int(specific_args["window_size"]), + ef_groundtruth=int(specific_args["ef_groundtruth"]), + k_train=int(specific_args.get("k_train", 1)), + ) + else: + index_param = module.HnswIndexParam( + metric_type=metric, + m=int(common_args["m"]), + ef_construction=int(specific_args.get("ef_construction", 500)), + quantize_type=quantize_type, + ) + + return module.CollectionSchema( + name=f"{index_kind.lower()}_benchmark", + fields=[ + module.FieldSchema( + "id", + module.DataType.INT64, + nullable=False, + index_param=module.InvertIndexParam(enable_range_optimization=True), + ) + ], + vectors=[ + module.VectorSchema( + "dense", + module.DataType.VECTOR_FP32, + dimension=dimension, + index_param=index_param, + ) + ], + ) + + +def build_index( + *, + index_kind: str, + index_path: Path, + dataset_spec: dict[str, Any], + dataset_artifacts: dict[str, Any], + common_args: dict[str, Any], + specific_args: dict[str, Any], + retrain_only: bool, + dry_run: bool, +) -> dict[str, Any]: + if dry_run: + emit(f"[Dry-run] Build {index_kind} at {index_path}") + return { + "insert_duration": None, + "optimize_duration": None, + "load_duration": None, + } + + module = _initialized_zvec() + + if retrain_only: + collection = module.open( + str(index_path), module.CollectionOption(read_only=False, enable_mmap=True) + ) + insert_duration = None + else: + _maybe_destroy_collection(index_path) + schema = _build_schema( + index_kind, + dataset_spec["dimension"], + 
dataset_spec["metric_type"], + common_args, + specific_args, + ) + collection = module.create_and_open( + str(index_path), + schema, + module.CollectionOption(read_only=False, enable_mmap=True), + ) + insert_duration = _insert_training_data( + collection, dataset_artifacts["train_files"] + ) + + optimize_start = time.perf_counter() + collection.optimize(option=module.OptimizeOption(retrain_only=retrain_only)) + optimize_duration = time.perf_counter() - optimize_start + with contextlib.suppress(Exception): + collection.flush() + del collection + + load_duration = None + if insert_duration is not None: + load_duration = insert_duration + optimize_duration + elif optimize_duration is not None: + load_duration = optimize_duration + + return { + "insert_duration": round(insert_duration, 4) + if insert_duration is not None + else None, + "optimize_duration": round(optimize_duration, 4) + if optimize_duration is not None + else None, + "load_duration": round(load_duration, 4) if load_duration is not None else None, + } + + +def _insert_training_data( + collection, train_files: list[Path], batch_size: int = 1000 +) -> float: + module = _initialized_zvec() + polars = _require_polars() + start = time.perf_counter() + for train_file in train_files: + frame = polars.read_parquet(train_file) + for offset in range(0, frame.height, batch_size): + batch = frame.slice(offset, batch_size) + ids = batch["id"].to_list() + vectors = batch["emb"].to_list() + docs = [ + module.Doc( + id=str(int(doc_id)), + fields={"id": int(doc_id)}, + vectors={"dense": vector}, + ) + for doc_id, vector in zip(ids, vectors, strict=True) + ] + collection.insert(docs) + return time.perf_counter() - start + + +def compute_recall_with_zvec( + *, + index_kind: str, + index_path: Path, + dataset_artifacts: dict[str, Any], + common_args: dict[str, Any], + target_recall: float | None, + dry_run: bool, +) -> float | None: + if dry_run: + return None + + module = _initialized_zvec() + polars = _require_polars() + 
query_frame = polars.read_parquet(dataset_artifacts["query_parquet"]).sort("id") + gt_frame = polars.read_parquet(dataset_artifacts["gt_parquet"]).sort("id") + gt_map = { + int(row["id"]): [ + int(value) for value in row["neighbors_id"][: int(common_args["k"])] + ] + for row in gt_frame.iter_rows(named=True) + } + + option = module.CollectionOption(read_only=True, enable_mmap=True) + collection = module.open(str(index_path), option) + use_refiner = bool(common_args.get("is_using_refiner", False)) + if index_kind == "OMEGA": + query_param = module.OmegaQueryParam( + ef=int(common_args["ef_search"]), + target_recall=float(target_recall), + is_using_refiner=use_refiner, + ) + else: + query_param = module.HnswQueryParam( + ef=int(common_args["ef_search"]), + is_using_refiner=use_refiner, + ) + + recall_sum = 0.0 + query_count = 0 + topk = int(common_args["k"]) + for row in query_frame.iter_rows(named=True): + query_id = int(row["id"]) + gt = gt_map.get(query_id) + if not gt: + continue + results = collection.query( + vectors=module.VectorQuery( + field_name="dense", vector=row["emb"], param=query_param + ), + topk=topk, + output_fields=[], + ) + pred = [int(doc.id) for doc in results[:topk]] + recall_sum += len(set(pred) & set(gt)) / float(topk) + query_count += 1 + + del collection + if query_count == 0: + return None + return recall_sum / query_count + + +def resolve_core_tools(zvec_root: Path) -> tuple[Path, Path]: + bench_bin = (zvec_root / "build/bin/bench").resolve() + recall_bin = (zvec_root / "build/bin/recall").resolve() + if not bench_bin.exists(): + raise FileNotFoundError(f"bench binary not found: {bench_bin}") + if not recall_bin.exists(): + raise FileNotFoundError(f"recall binary not found: {recall_bin}") + return bench_bin, recall_bin + + +def _metric_type_name_for_core(metric_type: str) -> str: + mapping = { + "COSINE": "kCosine", + "IP": "kInnerProduct", + "L2": "kL2sq", + } + normalized = str(metric_type).upper() + if normalized not in mapping: + 
raise ValueError(f"Unsupported metric type: {metric_type}") + return mapping[normalized] + + +def _quantizer_json(quantize_type: str) -> dict[str, Any] | None: + normalized = str(quantize_type).upper() + if normalized in {"", "UNDEFINED"}: + return None + mapping = { + "FP16": "kFP16", + "INT8": "kInt8", + "INT4": "kInt4", + } + if normalized not in mapping: + raise ValueError(f"Unsupported quantize type: {quantize_type}") + return {"type": mapping[normalized]} + + +def build_core_index_config_json( + *, + index_type: str, + metric_type: str, + dimension: int, + m: int, + ef_construction: int, + quantize_type: str, +) -> str: + payload: dict[str, Any] = { + "index_type": index_type, + "metric_type": _metric_type_name_for_core(metric_type), + "dimension": int(dimension), + "version": 0, + "is_sparse": False, + "data_type": "DT_FP32", + "use_id_map": False, + "is_huge_page": False, + "m": int(m), + "ef_construction": int(ef_construction), + } + quantizer = _quantizer_json(quantize_type) + if quantizer is not None: + payload["quantizer_param"] = quantizer + return json.dumps(payload, separators=(",", ":")) + + +def build_core_query_param_json( + *, + index_type: str, + ef_search: int, + topk: int, + target_recall: float | None = None, +) -> str: + payload: dict[str, Any] = { + "index_type": index_type, + "topk": int(topk), + "fetch_vector": False, + "radius": 0.0, + "is_linear": False, + "ef_search": int(ef_search), + } + if target_recall is not None: + payload["target_recall"] = float(target_recall) + return json.dumps(payload, separators=(",", ":")) + + +def discover_index_files(index_path: Path) -> dict[str, Path | None]: + coarse_candidates = sorted(index_path.glob("*/dense.qindex.*.proxima")) + full_candidates = sorted(index_path.glob("*/dense.index.*.proxima")) + primary = ( + coarse_candidates[0] + if coarse_candidates + else (full_candidates[0] if full_candidates else None) + ) + reference = full_candidates[0] if full_candidates else None + if primary is None: 
+ raise FileNotFoundError(f"No core index file found under {index_path}") + return {"primary": primary, "reference": reference} + + +def _yaml_quote(value: str) -> str: + return "'" + value.replace("'", "''") + "'" + + +def _write_core_config( + *, + path: Path, + index_path: Path, + index_config_json: str, + query_param_json: str, + query_file: Path, + topk: int, + use_refiner: bool, + reference_index_path: Path | None, + metric_type: str, + dimension: int, + m: int, + ef_construction: int, + bench_thread_count: int | None = None, + bench_secs: int | None = None, + recall_thread_count: int | None = None, + groundtruth_file: Path | None = None, +) -> None: + lines = [ + "IndexCommon:", + f" IndexPath: {_yaml_quote(str(index_path))}", + f" IndexConfig: {_yaml_quote(index_config_json)}", + f" TopK: {_yaml_quote(str(topk))}", + f" QueryFile: {_yaml_quote(str(query_file))}", + " QueryType: 'float'", + " QueryFirstSep: ';'", + " QuerySecondSep: ' '", + ] + if bench_thread_count is not None: + lines.append(f" BenchThreadCount: {bench_thread_count}") + if bench_secs is not None: + lines.append(f" BenchSecs: {bench_secs}") + lines.append(" BenchIterCount: 1000000000") + if recall_thread_count is not None: + lines.append(f" RecallThreadCount: {recall_thread_count}") + lines.append(f" RecallGTCount: {topk}") + lines.append(" CompareById: true") + if groundtruth_file is not None: + lines.append(f" GroundTruthFile: {_yaml_quote(str(groundtruth_file))}") + lines.append(" GroundTruthFirstSep: ';'") + lines.append(" GroundTruthSecondSep: ' '") + + lines.extend( + [ + "QueryConfig:", + f" QueryParam: {_yaml_quote(query_param_json)}", + ] + ) + + if use_refiner: + if reference_index_path is None: + raise ValueError("Refiner requested but reference index is missing") + reference_config = build_core_index_config_json( + index_type="kHNSW", + metric_type=metric_type, + dimension=dimension, + m=m, + ef_construction=ef_construction, + quantize_type="UNDEFINED", + ) + lines.extend( + [ + 
" RefinerConfig:", + " ScaleFactor: 2", + " ReferenceIndex:", + f" Config: {_yaml_quote(reference_config)}", + f" Path: {_yaml_quote(str(reference_index_path))}", + ] + ) + + path.write_text("\n".join(lines) + "\n") + + +def run_command_capture( + cmd: list[str], + *, + cwd: Path | None = None, + dry_run: bool = False, + extra_env: dict[str, str] | None = None, +) -> tuple[int, str]: + printable = " ".join(str(token) for token in cmd) + emit(printable) + if dry_run: + return 0, "" + + env = os.environ.copy() + if extra_env: + env.update(extra_env) + completed = subprocess.run( + cmd, + cwd=str(cwd) if cwd else None, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + check=False, + ) + if completed.returncode != 0 and completed.stdout: + sys.stdout.write(completed.stdout) + if not completed.stdout.endswith("\n"): + sys.stdout.write("\n") + return completed.returncode, completed.stdout + + +def _temporary_config_path(prefix: str, parent: Path) -> Path: + parent.mkdir(parents=True, exist_ok=True) + return parent / f"{prefix}_{int(time.time() * 1000)}.yaml" + + +def run_bench( + *, + bench_bin: Path, + index_file: Path, + query_file: Path, + metric_type: str, + dimension: int, + m: int, + ef_construction: int, + quantize_type: str, + ef_search: int, + topk: int, + bench_thread_count: int, + bench_secs: int, + use_refiner: bool, + reference_index_path: Path | None, + target_recall: float | None = None, + dry_run: bool = False, + extra_env: dict[str, str] | None = None, +) -> tuple[int, str, dict[str, Any]]: + config_path = _temporary_config_path("bench", index_file.parent) + try: + _write_core_config( + path=config_path, + index_path=index_file, + index_config_json=build_core_index_config_json( + index_type="kOMEGA" if target_recall is not None else "kHNSW", + metric_type=metric_type, + dimension=dimension, + m=m, + ef_construction=ef_construction, + quantize_type=quantize_type, + ), + query_param_json=build_core_query_param_json( + 
index_type="kOMEGA" if target_recall is not None else "kHNSW", + ef_search=ef_search, + topk=topk, + target_recall=target_recall, + ), + query_file=query_file, + topk=topk, + use_refiner=use_refiner, + reference_index_path=reference_index_path, + metric_type=metric_type, + dimension=dimension, + m=m, + ef_construction=ef_construction, + bench_thread_count=bench_thread_count, + bench_secs=bench_secs, + ) + ret, output = run_command_capture( + [str(bench_bin), str(config_path)], + dry_run=dry_run, + extra_env=extra_env, + ) + return ret, output, parse_bench_output(output) + finally: + if config_path.exists(): + config_path.unlink() + + +def run_recall( + *, + recall_bin: Path, + index_file: Path, + query_file: Path, + groundtruth_file: Path, + metric_type: str, + dimension: int, + m: int, + ef_construction: int, + quantize_type: str, + ef_search: int, + topk: int, + use_refiner: bool, + reference_index_path: Path | None, + target_recall: float | None = None, + dry_run: bool = False, +) -> tuple[int, str, float | None]: + config_path = _temporary_config_path("recall", index_file.parent) + try: + _write_core_config( + path=config_path, + index_path=index_file, + index_config_json=build_core_index_config_json( + index_type="kOMEGA" if target_recall is not None else "kHNSW", + metric_type=metric_type, + dimension=dimension, + m=m, + ef_construction=ef_construction, + quantize_type=quantize_type, + ), + query_param_json=build_core_query_param_json( + index_type="kOMEGA" if target_recall is not None else "kHNSW", + ef_search=ef_search, + topk=topk, + target_recall=target_recall, + ), + query_file=query_file, + topk=topk, + use_refiner=use_refiner, + reference_index_path=reference_index_path, + metric_type=metric_type, + dimension=dimension, + m=m, + ef_construction=ef_construction, + recall_thread_count=1, + groundtruth_file=groundtruth_file, + ) + ret, output = run_command_capture( + [str(recall_bin), str(config_path)], dry_run=dry_run + ) + return ret, output, 
parse_recall_output(output, topk) + finally: + if config_path.exists(): + config_path.unlink() + + +def run_concurrency_benchmark( + *, + bench_bin: Path, + index_files: dict[str, Path | None], + dataset_artifacts: dict[str, Any], + dataset_spec: dict[str, Any], + common_args: dict[str, Any], + target_recall: float | None, + dry_run: bool, +) -> dict[str, Any]: + ef_search = int(common_args["ef_search"]) + topk = int(common_args["k"]) + m = int(common_args["m"]) + ef_construction = int(common_args.get("ef_construction", 500)) + quantize_type = str(common_args.get("quantize_type", "UNDEFINED")) + use_refiner = bool(common_args.get("is_using_refiner", False)) + duration = int(common_args["concurrency_duration"]) + thread_counts = [ + int(value) for value in str(common_args["num_concurrency"]).split(",") if value + ] + + best_summary: dict[str, Any] | None = None + best_output = "" + for thread_count in thread_counts: + ret, output, summary = run_bench( + bench_bin=bench_bin, + index_file=index_files["primary"], + query_file=dataset_artifacts["query_txt"], + metric_type=dataset_spec["metric_type"], + dimension=dataset_spec["dimension"], + m=m, + ef_construction=ef_construction, + quantize_type=quantize_type, + ef_search=ef_search, + topk=topk, + bench_thread_count=thread_count, + bench_secs=duration, + use_refiner=use_refiner, + reference_index_path=index_files["reference"], + target_recall=target_recall, + dry_run=dry_run, + ) + summary["thread_count"] = thread_count + summary["retcode"] = ret + if best_summary is None or (summary.get("qps") or 0.0) > ( + best_summary.get("qps") or 0.0 + ): + best_summary = summary + best_output = output + + return {"summary": best_summary or {}, "output": best_output} diff --git a/src/binding/c/c_api.cc b/src/binding/c/c_api.cc index 255b03a1e..8764265bf 100644 --- a/src/binding/c/c_api.cc +++ b/src/binding/c/c_api.cc @@ -1332,6 +1332,14 @@ ZVecIndexParams *zvec_index_params_create(ZVecIndexType index_type) { 
zvec::core_interface::kDefaultHnswEfConstruction, // ef_construction zvec::QuantizeType::UNDEFINED); break; + case ZVEC_INDEX_TYPE_OMEGA: + cpp_params = + new zvec::OmegaIndexParams( + zvec::MetricType::L2, // metric_type + zvec::core_interface::kDefaultHnswNeighborCnt, // m + zvec::core_interface::kDefaultHnswEfConstruction, // ef_construction + zvec::QuantizeType::UNDEFINED); + break; case ZVEC_INDEX_TYPE_IVF: cpp_params = new zvec::IVFIndexParams(zvec::MetricType::L2, // metric_type @@ -1490,15 +1498,21 @@ ZVecErrorCode zvec_index_params_set_hnsw_params(ZVecIndexParams *params, int m, return ZVEC_ERROR_INVALID_ARGUMENT; } auto *cpp_params = reinterpret_cast(params); - auto *hnsw_params = dynamic_cast(cpp_params); - if (!hnsw_params) { + if (auto *hnsw_params = dynamic_cast(cpp_params)) { + hnsw_params->set_m(m); + hnsw_params->set_ef_construction(ef_construction); + return ZVEC_OK; + } + if (auto *omega_params = dynamic_cast(cpp_params)) { + omega_params->set_m(m); + omega_params->set_ef_construction(ef_construction); + return ZVEC_OK; + } + { SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, - "Invalid params or not HNSW index type"); + "Invalid params or not HNSW/OMEGA index type"); return ZVEC_ERROR_INVALID_ARGUMENT; } - hnsw_params->set_m(m); - hnsw_params->set_ef_construction(ef_construction); - return ZVEC_OK; } /** @@ -1513,13 +1527,19 @@ int zvec_index_params_get_hnsw_m(const ZVecIndexParams *params) { return 0; } auto *cpp_params = reinterpret_cast(params); - auto *hnsw_params = dynamic_cast(cpp_params); - if (!hnsw_params) { + if (auto *hnsw_params = + dynamic_cast(cpp_params)) { + return hnsw_params->m(); + } + if (auto *omega_params = + dynamic_cast(cpp_params)) { + return omega_params->m(); + } + { SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, - "Invalid params or not HNSW index type"); + "Invalid params or not HNSW/OMEGA index type"); return 0; } - return hnsw_params->m(); } /** @@ -1534,13 +1554,19 @@ int zvec_index_params_get_hnsw_ef_construction(const 
ZVecIndexParams *params) { return 0; } auto *cpp_params = reinterpret_cast(params); - auto *hnsw_params = dynamic_cast(cpp_params); - if (!hnsw_params) { + if (auto *hnsw_params = + dynamic_cast(cpp_params)) { + return hnsw_params->ef_construction(); + } + if (auto *omega_params = + dynamic_cast(cpp_params)) { + return omega_params->ef_construction(); + } + { SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, - "Invalid params or not HNSW index type"); + "Invalid params or not HNSW/OMEGA index type"); return 0; } - return hnsw_params->ef_construction(); } /** @@ -4609,6 +4635,113 @@ ZVecHnswQueryParams *zvec_query_params_hnsw_create(int ef, float radius, return nullptr; } +ZVecOmegaQueryParams *zvec_query_params_omega_create(int ef, + float target_recall, + float radius, + bool is_linear, + bool is_using_refiner) { + ZVEC_TRY_RETURN_NULL( + "Failed to create OmegaQueryParams", + auto *params = new zvec::OmegaQueryParams(ef, target_recall, radius, + is_linear, is_using_refiner); + return reinterpret_cast(params);) + + return nullptr; +} + +void zvec_query_params_omega_destroy(ZVecOmegaQueryParams *params) { + if (params) { + delete reinterpret_cast(params); + } +} + +ZVecErrorCode zvec_query_params_omega_set_ef(ZVecOmegaQueryParams *params, + int ef) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(params); + ptr->set_ef(ef); + return ZVEC_OK; +} + +int zvec_query_params_omega_get_ef(const ZVecOmegaQueryParams *params) { + if (!params) return 0; + auto *ptr = reinterpret_cast(params); + return ptr->ef(); +} + +ZVecErrorCode zvec_query_params_omega_set_target_recall( + ZVecOmegaQueryParams *params, float target_recall) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(params); + ptr->set_target_recall(target_recall); + return ZVEC_OK; +} + +float 
zvec_query_params_omega_get_target_recall( + const ZVecOmegaQueryParams *params) { + if (!params) return 0.0f; + auto *ptr = reinterpret_cast(params); + return ptr->target_recall(); +} + +ZVecErrorCode zvec_query_params_omega_set_radius(ZVecOmegaQueryParams *params, + float radius) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(params); + ptr->set_radius(radius); + return ZVEC_OK; +} + +float zvec_query_params_omega_get_radius(const ZVecOmegaQueryParams *params) { + if (!params) return 0.0f; + auto *ptr = reinterpret_cast(params); + return ptr->radius(); +} + +ZVecErrorCode zvec_query_params_omega_set_is_linear( + ZVecOmegaQueryParams *params, bool is_linear) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(params); + ptr->set_is_linear(is_linear); + return ZVEC_OK; +} + +bool zvec_query_params_omega_get_is_linear(const ZVecOmegaQueryParams *params) { + if (!params) return false; + auto *ptr = reinterpret_cast(params); + return ptr->is_linear(); +} + +ZVecErrorCode zvec_query_params_omega_set_is_using_refiner( + ZVecOmegaQueryParams *params, bool is_using_refiner) { + if (!params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, "Params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + auto *ptr = reinterpret_cast(params); + ptr->set_is_using_refiner(is_using_refiner); + return ZVEC_OK; +} + +bool zvec_query_params_omega_get_is_using_refiner( + const ZVecOmegaQueryParams *params) { + if (!params) return false; + auto *ptr = reinterpret_cast(params); + return ptr->is_using_refiner(); +} + void zvec_query_params_hnsw_destroy(ZVecHnswQueryParams *params) { if (params) { delete reinterpret_cast(params); @@ -5102,6 +5235,20 @@ ZVecErrorCode zvec_vector_query_set_hnsw_params( return ZVEC_OK; } +ZVecErrorCode 
zvec_vector_query_set_omega_params( + ZVecVectorQuery *query, ZVecOmegaQueryParams *omega_params) { + if (!query || !omega_params) { + SET_LAST_ERROR(ZVEC_ERROR_INVALID_ARGUMENT, + "Query or OMEGA params pointer is null"); + return ZVEC_ERROR_INVALID_ARGUMENT; + } + + auto *query_ptr = reinterpret_cast(query); + auto *params_ptr = reinterpret_cast(omega_params); + query_ptr->query_params_.reset(params_ptr); + return ZVEC_OK; +} + ZVecErrorCode zvec_vector_query_set_ivf_params(ZVecVectorQuery *query, ZVecIVFQueryParams *ivf_params) { if (!query || !ivf_params) { diff --git a/src/binding/python/CMakeLists.txt b/src/binding/python/CMakeLists.txt index 7e5169176..3a176834d 100644 --- a/src/binding/python/CMakeLists.txt +++ b/src/binding/python/CMakeLists.txt @@ -19,8 +19,13 @@ set(SRC_LISTS pybind11_add_module(_zvec ${SRC_LISTS}) if (CMAKE_SYSTEM_NAME STREQUAL "Linux") + set(_zvec_force_link_libs) + if(ZVEC_ENABLE_OMEGA) + list(APPEND _zvec_force_link_libs $) + endif() target_link_libraries(_zvec PRIVATE -Wl,--whole-archive + ${_zvec_force_link_libs} $ $ $ @@ -39,7 +44,12 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Linux") "LINKER:--version-script=${CMAKE_CURRENT_SOURCE_DIR}/exports.map" ) elseif (APPLE) + set(_zvec_force_load_libs) + if(ZVEC_ENABLE_OMEGA) + list(APPEND _zvec_force_load_libs -Wl,-force_load,$) + endif() target_link_libraries(_zvec PRIVATE + ${_zvec_force_load_libs} -Wl,-force_load,$ -Wl,-force_load,$ -Wl,-force_load,$ @@ -69,6 +79,9 @@ elseif (MSVC) core_utility_static core_quantizer_static ) + if(ZVEC_ENABLE_OMEGA) + list(PREPEND _zvec_whole_archive_libs core_knn_omega_static) + endif() target_link_libraries(_zvec PRIVATE ${_zvec_whole_archive_libs} zvec_db @@ -80,4 +93,4 @@ elseif (MSVC) endforeach() endif () -target_include_directories(_zvec PRIVATE ${PYBIND11_INCLUDE_DIR} ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/binding/python/include) \ No newline at end of file +target_include_directories(_zvec PRIVATE ${PYBIND11_INCLUDE_DIR} 
${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/binding/python/include) diff --git a/src/binding/python/binding.cc b/src/binding/python/binding.cc index ed8d6918d..7a4d8acf6 100644 --- a/src/binding/python/binding.cc +++ b/src/binding/python/binding.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include "python_collection.h" #include "python_config.h" #include "python_doc.h" diff --git a/src/binding/python/model/param/python_param.cc b/src/binding/python/model/param/python_param.cc index 6ad7e1b58..b413814a4 100644 --- a/src/binding/python/model/param/python_param.cc +++ b/src/binding/python/model/param/python_param.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include "python_doc.h" namespace zvec { @@ -31,6 +32,8 @@ static std::string index_type_to_string(const IndexType type) { return "IVF"; case IndexType::HNSW: return "HNSW"; + case IndexType::OMEGA: + return "OMEGA"; case IndexType::HNSW_RABITQ: return "HNSW_RABITQ"; default: @@ -643,6 +646,163 @@ Constructs an IVFIndexParam instance. t[0].cast(), t[1].cast(), t[2].cast(), t[3].cast(), t[4].cast()); })); + + // OmegaIndexParams + py::class_> + omega_params(m, "OmegaIndexParam", R"pbdoc( +Parameters for configuring an OMEGA index. + +OMEGA is an advanced graph-based index that uses machine learning to optimize +search performance. It builds on HNSW and can automatically train a model to +predict when to stop searching. + +Attributes: + metric_type (MetricType): Distance metric used for similarity computation. + Default is ``MetricType.IP`` (inner product). + m (int): Number of bi-directional links created for every new element + during construction. Higher values improve accuracy but increase + memory usage and construction time. Default is 100. + ef_construction (int): Size of the dynamic candidate list for nearest + neighbors during index construction. 
Larger values yield better + graph quality at the cost of slower build time. Default is 500. + quantize_type (QuantizeType): Optional quantization type for vector + compression (e.g., FP16, INT8). Default is `QuantizeType.UNDEFINED` to + disable quantization. + min_vector_threshold (int): Minimum number of vectors required to enable + OMEGA optimization. Below this threshold, standard HNSW is used. + Default is 100000. + num_training_queries (int): Number of training queries to generate for + OMEGA model training. Default is 1000. + ef_training (int): Size of the candidate list (ef) used during training + searches. Larger values collect more training data but take longer. + Default is 1000. + window_size (int): Size of the sliding window for computing distance + statistics during search. Must be the same for training and inference. + Default is 100. + ef_groundtruth (int): Size of the candidate list (ef) used when computing + ground truth for training. If 0, brute force search is used (slower but exact). + If > 0, HNSW search with this ef is used (faster but approximate). + Default is 0 (brute force). + k_train (int): Number of top ground-truth results that must be present in + the current top-k before a training record is labeled positive. + Default is 1. + +Examples: + >>> from zvec.typing import MetricType, QuantizeType + >>> params = OmegaIndexParam( + ... metric_type=MetricType.L2, + ... m=16, + ... ef_construction=200, + ... min_vector_threshold=50000, + ... num_training_queries=500, + ... ef_training=800, + ... window_size=100, + ... ef_groundtruth=2000, # Use HNSW for faster ground truth computation + ... k_train=10 + ... 
) + >>> print(params.k_train) + 10 +)pbdoc"); + omega_params + .def(py::init(), + py::arg("metric_type") = MetricType::IP, + py::arg("m") = core_interface::kDefaultHnswNeighborCnt, + py::arg("ef_construction") = + core_interface::kDefaultHnswEfConstruction, + py::arg("quantize_type") = QuantizeType::UNDEFINED, + py::arg("min_vector_threshold") = 100000, + py::arg("num_training_queries") = 1000, + py::arg("ef_training") = 1000, py::arg("window_size") = 100, + py::arg("ef_groundtruth") = 0, py::arg("k_train") = 1) + .def_property_readonly( + "m", &OmegaIndexParams::m, + "int: Maximum number of neighbors per node in upper layers.") + .def_property_readonly( + "ef_construction", &OmegaIndexParams::ef_construction, + "int: Candidate list size during index construction.") + .def_property_readonly( + "min_vector_threshold", &OmegaIndexParams::min_vector_threshold, + "int: Minimum vectors required to enable OMEGA optimization.") + .def_property_readonly( + "num_training_queries", &OmegaIndexParams::num_training_queries, + "int: Number of training queries for OMEGA model training.") + .def_property_readonly( + "ef_training", &OmegaIndexParams::ef_training, + "int: Candidate list size (ef) used during training searches.") + .def_property_readonly( + "window_size", &OmegaIndexParams::window_size, + "int: Sliding window size for distance statistics.") + .def_property_readonly( + "ef_groundtruth", &OmegaIndexParams::ef_groundtruth, + "int: ef for ground truth computation (0=brute force, >0=HNSW).") + .def_property_readonly("k_train", &OmegaIndexParams::k_train, + "int: Number of top GT results required for a " + "positive training label.") + .def( + "to_dict", + [](const OmegaIndexParams &self) -> py::dict { + py::dict dict; + dict["type"] = index_type_to_string(self.type()); + dict["metric_type"] = metric_type_to_string(self.metric_type()); + dict["m"] = self.m(); + dict["ef_construction"] = self.ef_construction(); + dict["min_vector_threshold"] = self.min_vector_threshold(); 
+ dict["num_training_queries"] = self.num_training_queries(); + dict["ef_training"] = self.ef_training(); + dict["window_size"] = self.window_size(); + dict["ef_groundtruth"] = self.ef_groundtruth(); + dict["k_train"] = self.k_train(); + dict["quantize_type"] = + quantize_type_to_string(self.quantize_type()); + return dict; + }, + "Convert to dictionary with all fields") + .def("__repr__", + [](const OmegaIndexParams &self) -> std::string { + return "{" + "\"metric_type\":" + + metric_type_to_string(self.metric_type()) + + ", \"m\":" + std::to_string(self.m()) + + ", \"ef_construction\":" + + std::to_string(self.ef_construction()) + + ", \"min_vector_threshold\":" + + std::to_string(self.min_vector_threshold()) + + ", \"num_training_queries\":" + + std::to_string(self.num_training_queries()) + + ", \"ef_training\":" + std::to_string(self.ef_training()) + + ", \"window_size\":" + std::to_string(self.window_size()) + + ", \"ef_groundtruth\":" + + std::to_string(self.ef_groundtruth()) + + ", \"k_train\":" + std::to_string(self.k_train()) + + ", \"quantize_type\":" + + quantize_type_to_string(self.quantize_type()) + "}"; + }) + .def(py::pickle( + [](const OmegaIndexParams &self) { + return py::make_tuple( + self.metric_type(), self.m(), self.ef_construction(), + self.quantize_type(), self.min_vector_threshold(), + self.num_training_queries(), self.ef_training(), + self.window_size(), self.ef_groundtruth(), self.k_train()); + }, + [](py::tuple t) { + if (t.size() == 10) { + return std::make_shared( + t[0].cast(), t[1].cast(), t[2].cast(), + t[3].cast(), t[4].cast(), + t[5].cast(), t[6].cast(), t[7].cast(), + t[8].cast(), t[9].cast()); + } + if (t.size() != 9) + throw std::runtime_error("Invalid state for OmegaIndexParams"); + return std::make_shared( + t[0].cast(), t[1].cast(), t[2].cast(), + t[3].cast(), t[4].cast(), + t[5].cast(), t[6].cast(), t[7].cast(), + t[8].cast(), 1); + })); } void ZVecPyParams::bind_query_params(py::module_ &m) { @@ -760,6 +920,85 @@ 
Constructs an HnswQueryParam instance. return obj; })); + // binding omega query params + py::class_> + omega_query_params(m, "OmegaQueryParam", R"pbdoc( +Query parameters for OMEGA index with adaptive early stopping. + +OMEGA extends HNSW with machine learning-based early stopping that can +dynamically adjust search effort to meet a target recall. + +Attributes: + type (IndexType): Always ``IndexType.OMEGA``. + ef (int): Size of the dynamic candidate list during search. + Larger values improve recall but slow down search. + Default is 300. + target_recall (float): Target recall for OMEGA early stopping. + OMEGA will stop searching when predicted recall meets this target. + Valid range: 0.0 to 1.0. Default is 0.95. + radius (float): Search radius for range queries. Default is 0.0. + is_linear (bool): Force linear search. Default is False. + is_using_refiner (bool): Whether to use refiner for the query. Default is False. + +Examples: + >>> params = OmegaQueryParam(ef=300, target_recall=0.98) + >>> print(params.target_recall) + 0.98 +)pbdoc"); + omega_query_params + .def(py::init(), + py::arg("ef") = core_interface::kDefaultHnswEfSearch, + py::arg("target_recall") = 0.95f, py::arg("radius") = 0.0f, + py::arg("is_linear") = false, py::arg("is_using_refiner") = false, + R"pbdoc( +Constructs an OmegaQueryParam instance. + +Args: + ef (int, optional): Search-time candidate list size. + Higher values improve accuracy. Defaults to 300. + target_recall (float, optional): Target recall for early stopping. + Valid range: 0.0 to 1.0. Defaults to 0.95. + radius (float, optional): Search radius for range queries. Default is 0.0. + is_linear (bool, optional): Force linear search. Default is False. + is_using_refiner (bool, optional): Whether to use refiner. Default is False. 
+)pbdoc") + .def_property_readonly( + "target_recall", + [](const OmegaQueryParams &self) -> float { + return self.target_recall(); + }, + "float: Target recall for OMEGA early stopping (0.0 to 1.0).") + .def("__repr__", + [](const OmegaQueryParams &self) -> std::string { + return "{" + "\"type\":" + + index_type_to_string(self.type()) + + ", \"ef\":" + std::to_string(self.ef()) + + ", \"target_recall\":" + + std::to_string(self.target_recall()) + + ", \"radius\":" + std::to_string(self.radius()) + + ", \"is_linear\":" + std::to_string(self.is_linear()) + + ", \"is_using_refiner\":" + + std::to_string(self.is_using_refiner()) + "}"; + }) + .def(py::pickle( + [](const OmegaQueryParams &self) { + return py::make_tuple(self.ef(), self.target_recall(), + self.radius(), self.is_linear(), + self.is_using_refiner()); + }, + [](py::tuple t) { + if (t.size() != 5) + throw std::runtime_error("Invalid state for OmegaQueryParams"); + auto obj = std::make_shared(t[0].cast(), + t[1].cast()); + obj->set_radius(t[2].cast()); + obj->set_is_linear(t[3].cast()); + obj->set_is_using_refiner(t[4].cast()); + return obj; + })); + // binding ivf query params py::class_> ivf_params(m, "IVFQueryParam", R"pbdoc( @@ -1043,34 +1282,50 @@ Options for optimizing a collection (e.g., merging segments). concurrency (int): Number of threads to use during optimization. If 0, the system will choose an optimal value automatically. Default is 0. + retrain_only (bool): Reuse existing indexes and only retrain OMEGA models. + This skips index rebuild/merge work and is only meaningful for OMEGA. + Default is False. 
Examples: - >>> opt = OptimizeOption(concurrency=2) + >>> opt = OptimizeOption(concurrency=2, retrain_only=True) >>> print(opt.concurrency) 2 )pbdoc") - .def(py::init(), py::arg("concurrency") = 0, + .def(py::init([](int concurrency, bool retrain_only) { + OptimizeOptions obj{}; + obj.concurrency_ = concurrency; + obj.retrain_only_ = retrain_only; + return obj; + }), + py::arg("concurrency") = 0, py::arg("retrain_only") = false, R"pbdoc( Constructs an OptimizeOption instance. Args: concurrency (int, optional): Number of concurrent threads. 0 means auto-detect. Defaults to 0. + retrain_only (bool, optional): Reuse existing indexes and only retrain + OMEGA models. Defaults to False. )pbdoc") .def_property_readonly( "concurrency", [](const OptimizeOptions &self) { return self.concurrency_; }, "int: Number of threads used for optimization (0 = auto).") + .def_property_readonly( + "retrain_only", + [](const OptimizeOptions &self) { return self.retrain_only_; }, + "bool: Whether to reuse existing indexes and only retrain OMEGA.") .def(py::pickle( [](const OptimizeOptions &self) { - return py::make_tuple(self.concurrency_); + return py::make_tuple(self.concurrency_, self.retrain_only_); }, [](py::tuple t) { - if (t.size() != 1) + if (t.size() != 2) throw std::runtime_error( "Invalid pickle data for OptimizeOptions"); OptimizeOptions obj{}; obj.concurrency_ = t[0].cast(); + obj.retrain_only_ = t[1].cast(); return obj; })); @@ -1393,4 +1648,4 @@ void ZVecPyParams::bind_vector_query(py::module_ &m) { return obj; })); } -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/binding/python/model/python_collection.cc b/src/binding/python/model/python_collection.cc index 1ea32ac32..84c748cae 100644 --- a/src/binding/python/model/python_collection.cc +++ b/src/binding/python/model/python_collection.cc @@ -214,4 +214,4 @@ void ZVecPyCollection::bind_dql_methods( }); } -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git 
a/src/binding/python/typing/python_type.cc b/src/binding/python/typing/python_type.cc index bb5003463..e5760a622 100644 --- a/src/binding/python/typing/python_type.cc +++ b/src/binding/python/typing/python_type.cc @@ -99,7 +99,8 @@ Enumeration of supported index types in Zvec. .value("HNSW_RABITQ", IndexType::HNSW_RABITQ) .value("IVF", IndexType::IVF) .value("FLAT", IndexType::FLAT) - .value("INVERT", IndexType::INVERT); + .value("INVERT", IndexType::INVERT) + .value("OMEGA", IndexType::OMEGA); } void ZVecPyTyping::bind_metric_types(pybind11::module_ &m) { diff --git a/src/core/CMakeLists.txt b/src/core/CMakeLists.txt index 76bebd9f9..bdf1bbdf0 100644 --- a/src/core/CMakeLists.txt +++ b/src/core/CMakeLists.txt @@ -39,10 +39,21 @@ if(NOT RABITQ_SUPPORTED) list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/algorithm/hnsw_rabitq/.*") endif() +if(NOT ZVEC_ENABLE_OMEGA) + list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/algorithm/omega/.*") + list(FILTER ALL_CORE_SRCS EXCLUDE REGEX ".*/interface/indexes/omega_.*") +endif() + +set(CORE_LIBS zvec_ailego zvec_turbo sparsehash magic_enum rabitqlib) +if(ZVEC_ENABLE_OMEGA) + list(APPEND CORE_LIBS omega) +endif() + cc_library( NAME zvec_core STATIC STRICT PACKED SRCS ${ALL_CORE_SRCS} - LIBS zvec_ailego zvec_turbo sparsehash magic_enum rabitqlib + LIBS ${CORE_LIBS} + LIBS ${CORE_LIBS} INCS . 
${PROJECT_ROOT_DIR}/src/core VERSION "${GIT_SRCS_VER}" ) diff --git a/src/core/algorithm/CMakeLists.txt b/src/core/algorithm/CMakeLists.txt index 20a459052..7f8e4e0c7 100644 --- a/src/core/algorithm/CMakeLists.txt +++ b/src/core/algorithm/CMakeLists.txt @@ -7,6 +7,23 @@ cc_directory(flat_sparse) cc_directory(ivf) cc_directory(hnsw) cc_directory(hnsw_sparse) +if(ZVEC_ENABLE_OMEGA) + cc_directory(omega) +else() + file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/omega_stub.cc + "// Stub implementation for disabled OMEGA support\n" + "namespace zvec { namespace core { /* empty namespace for compatibility */ } }\n" + ) + + cc_library( + NAME core_knn_omega + STATIC SHARED STRICT ALWAYS_LINK + SRCS ${CMAKE_CURRENT_BINARY_DIR}/omega_stub.cc + LIBS core_framework + INCS . ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + VERSION "${PROXIMA_ZVEC_VERSION}" + ) +endif() if(RABITQ_SUPPORTED) message(STATUS "BUILD RABITQ") cc_directory(hnsw_rabitq) diff --git a/src/core/algorithm/hnsw/CMakeLists.txt b/src/core/algorithm/hnsw/CMakeLists.txt index f4a105402..ed6b8d96d 100644 --- a/src/core/algorithm/hnsw/CMakeLists.txt +++ b/src/core/algorithm/hnsw/CMakeLists.txt @@ -5,7 +5,7 @@ cc_library( NAME core_knn_hnsw STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc - LIBS core_framework sparsehash + LIBS core_framework core_utility sparsehash INCS . 
${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) diff --git a/src/core/algorithm/hnsw/hnsw_algorithm.cc b/src/core/algorithm/hnsw/hnsw_algorithm.cc index fa553f554..ce7c6ace6 100644 --- a/src/core/algorithm/hnsw/hnsw_algorithm.cc +++ b/src/core/algorithm/hnsw/hnsw_algorithm.cc @@ -14,6 +14,7 @@ #include "hnsw_algorithm.h" #include #include +#include #include #include @@ -82,31 +83,65 @@ int HnswAlgorithm::add_node(node_id_t id, level_t level, HnswContext *ctx) { } int HnswAlgorithm::search(HnswContext *ctx) const { - spin_lock_.lock(); - auto maxLevel = entity_.cur_max_level(); - auto entry_point = entity_.entry_point(); - spin_lock_.unlock(); + return search_internal(ctx, true, nullptr, nullptr); +} + +int HnswAlgorithm::fast_search(HnswContext *ctx) const { + return search_internal(ctx, false, nullptr, nullptr); +} + +int HnswAlgorithm::search_with_hooks(HnswContext *ctx, const SearchHooks *hooks, + bool *stopped_early) const { + return search_internal(ctx, true, hooks, stopped_early); +} + +int HnswAlgorithm::fast_search_with_hooks(HnswContext *ctx, + const SearchHooks *hooks, + bool *stopped_early) const { + return search_internal(ctx, false, hooks, stopped_early); +} +int HnswAlgorithm::search_internal(HnswContext *ctx, bool use_lock, + const SearchHooks *hooks, + bool *stopped_early) const { + auto load_entry_point = [&]() { + if (use_lock) { + spin_lock_.lock(); + auto max_level = entity_.cur_max_level(); + auto entry_point = entity_.entry_point(); + spin_lock_.unlock(); + return std::make_pair(max_level, entry_point); + } + return std::make_pair(entity_.cur_max_level(), entity_.entry_point()); + }; + + auto [max_level, entry_point] = load_entry_point(); if (ailego_unlikely(entry_point == kInvalidNodeId)) { + if (stopped_early != nullptr) { + *stopped_early = false; + } return 0; } dist_t dist = ctx->dist_calculator().dist(entry_point); - for (level_t cur_level = maxLevel; cur_level >= 1; --cur_level) { 
+ for (level_t cur_level = max_level; cur_level >= 1; --cur_level) { select_entry_point(cur_level, &entry_point, &dist, ctx); } auto &topk_heap = ctx->topk_heap(); topk_heap.clear(); - search_neighbors(0, &entry_point, &dist, topk_heap, ctx); + bool did_stop_early = + search_neighbors(0, &entry_point, &dist, topk_heap, ctx, hooks); + if (stopped_early != nullptr) { + *stopped_early = did_stop_early; + } - if (ctx->group_by_search()) { + if (!did_stop_early && ctx->group_by_search()) { expand_neighbors_by_group(topk_heap, ctx); } return 0; } - //! select_entry_point on hnsw level, ef = 1 void HnswAlgorithm::select_entry_point(level_t level, node_id_t *entry_point, dist_t *dist, HnswContext *ctx) const { @@ -178,9 +213,10 @@ void HnswAlgorithm::add_neighbors(node_id_t id, level_t level, return; } -void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, +bool HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, dist_t *dist, TopkHeap &topk, - HnswContext *ctx) const { + HnswContext *ctx, + const SearchHooks *hooks) const { const auto &entity = ctx->get_entity(); HnswDistCalculator &dc = ctx->dist_calculator(); VisitFilter &visit = ctx->visit_filter(); @@ -190,15 +226,33 @@ void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, filter = [&](node_id_t id) { return ctx->filter()(entity.get_key(id)); }; } + const uint32_t result_topk_limit = ctx->topk(); + const bool track_hook_result_topk = hooks != nullptr && + hooks->on_visit_candidate != nullptr && + result_topk_limit > 0; + TopkHeap hook_result_topk(result_topk_limit > 0 ? 
result_topk_limit : 1U); + candidates.clear(); visit.clear(); visit.set_visited(*entry_point); - if (!filter(*entry_point)) { + bool entry_inserted_to_topk = !filter(*entry_point); + if (entry_inserted_to_topk) { topk.emplace(*entry_point, *dist); + if (track_hook_result_topk) { + hook_result_topk.emplace(*entry_point, *dist); + } } candidates.emplace(*entry_point, *dist); + if (hooks != nullptr && hooks->on_level0_entry != nullptr) { + hooks->on_level0_entry(*entry_point, *dist, entry_inserted_to_topk, + hooks->user_data); + } while (!candidates.empty() && !ctx->reach_scan_limit()) { + if (hooks != nullptr && hooks->on_hop != nullptr) { + hooks->on_hop(hooks->user_data); + } + auto top = candidates.begin(); node_id_t main_node = top->first; dist_t main_dist = top->second; @@ -260,8 +314,11 @@ void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, for (uint32_t i = 0; i < size; ++i) { node_id_t node = neighbor_ids[i]; dist_t cur_dist = dists[i]; + bool should_consider_candidate = + (!topk.full()) || cur_dist < topk[0].second; + bool inserted_to_topk = false; - if ((!topk.full()) || cur_dist < topk[0].second) { + if (should_consider_candidate) { candidates.emplace(node, cur_dist); // update entry_point for next level scan if (cur_dist < *dist) { @@ -270,12 +327,27 @@ void HnswAlgorithm::search_neighbors(level_t level, node_id_t *entry_point, } if (!filter(node)) { topk.emplace(node, cur_dist); + if (track_hook_result_topk) { + inserted_to_topk = !hook_result_topk.full() || + cur_dist < hook_result_topk[0].second; + if (inserted_to_topk) { + hook_result_topk.emplace(node, cur_dist); + } + } + } + } + + if (hooks != nullptr && hooks->on_visit_candidate != nullptr) { + bool should_stop = hooks->on_visit_candidate( + node, cur_dist, inserted_to_topk, hooks->user_data); + if (should_stop) { + return true; } - } // end if + } } // end for } // while - return; + return false; } void HnswAlgorithm::expand_neighbors_by_group(TopkHeap &topk, diff --git 
a/src/core/algorithm/hnsw/hnsw_algorithm.h b/src/core/algorithm/hnsw/hnsw_algorithm.h index 886d870c6..b7127bb48 100644 --- a/src/core/algorithm/hnsw/hnsw_algorithm.h +++ b/src/core/algorithm/hnsw/hnsw_algorithm.h @@ -27,6 +27,27 @@ class HnswAlgorithm { public: typedef std::unique_ptr UPointer; + // SearchHooks is the integration seam used by OMEGA. Callbacks are invoked + // from the level-0 search loop in this order: + // 1. on_level0_entry once after the initial level-0 entry point is accepted + // 2. on_hop once per popped candidate expansion + // 3. on_visit_candidate once per candidate comparison at level 0 + // + // inserted_to_topk is computed by the HNSW loop before callbacks fire and + // tells the callback whether the candidate improved the current result heap. + // + // Returning true from on_visit_candidate requests early termination of the + // level-0 search. This is currently used by OMEGA adaptive stopping. + // + struct SearchHooks { + void *user_data{nullptr}; + void (*on_level0_entry)(node_id_t id, dist_t dist, bool inserted_to_topk, + void *user_data){nullptr}; + void (*on_hop)(void *user_data){nullptr}; + bool (*on_visit_candidate)(node_id_t id, dist_t dist, bool inserted_to_topk, + void *user_data){nullptr}; + }; + public: //! Constructor explicit HnswAlgorithm(HnswEntity &entity); @@ -47,6 +68,17 @@ class HnswAlgorithm { //! return 0 on success, or errCode in failure. results saved in ctx int search(HnswContext *ctx) const; + //! do knn search in graph without lock + //! return 0 on success, or errCode in failure. results saved in ctx + int fast_search(HnswContext *ctx) const; + + //! do knn search in graph with optional callbacks inserted on the hot path + int search_with_hooks(HnswContext *ctx, const SearchHooks *hooks, + bool *stopped_early = nullptr) const; + + //! 
do knn search in graph without lock with optional callbacks + int fast_search_with_hooks(HnswContext *ctx, const SearchHooks *hooks, + bool *stopped_early = nullptr) const; //! Initiate HnswAlgorithm int init() { level_probas_.clear(); @@ -80,6 +112,9 @@ class HnswAlgorithm { } private: + int search_internal(HnswContext *ctx, bool use_lock, const SearchHooks *hooks, + bool *stopped_early) const; + //! Select in upper layer to get entry point for next layer search void select_entry_point(level_t level, node_id_t *entry_point, dist_t *dist, HnswContext *ctx) const; @@ -91,8 +126,9 @@ class HnswAlgorithm { //! Given a node id and level, search the nearest neighbors in graph //! Note: the nearest neighbors result keeps in topk, and entry_point and //! dist will be updated to current level nearest node id and distance - void search_neighbors(level_t level, node_id_t *entry_point, dist_t *dist, - TopkHeap &topk, HnswContext *ctx) const; + bool search_neighbors(level_t level, node_id_t *entry_point, dist_t *dist, + TopkHeap &topk, HnswContext *ctx, + const SearchHooks *hooks = nullptr) const; //! Update the node's neighbors void update_neighbors(HnswDistCalculator &dc, node_id_t id, level_t level, @@ -127,4 +163,4 @@ class HnswAlgorithm { }; } // namespace core -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_context.h b/src/core/algorithm/hnsw/hnsw_context.h index e776b81a7..81d4852d7 100644 --- a/src/core/algorithm/hnsw/hnsw_context.h +++ b/src/core/algorithm/hnsw/hnsw_context.h @@ -96,11 +96,12 @@ class HnswContext : public IndexContext { //! 
Retrieve string of debug virtual std::string debug_string(void) const override { char buf[4096]; - size_t size = snprintf( - buf, sizeof(buf), - "scan_cnt=%zu,get_vector_cnt=%u,get_neighbors_cnt=%u,dup_node=%u", - get_scan_num(), stats_get_vector_cnt_, stats_get_neighbors_cnt_, - stats_visit_dup_cnt_); + size_t size = + snprintf(buf, sizeof(buf), + "scan_cnt=%zu,pairwise_dist_cnt=%zu,get_vector_cnt=%u,get_" + "neighbors_cnt=%u,dup_node=%u", + get_scan_num(), get_pairwise_dist_num(), stats_get_vector_cnt_, + stats_get_neighbors_cnt_, stats_visit_dup_cnt_); return std::string(buf, size); } @@ -404,6 +405,10 @@ class HnswContext : public IndexContext { return dc_.compare_cnt(); } + inline uint64_t get_pairwise_dist_num() const { + return dc_.pairwise_dist_cnt(); + } + inline uint64_t reach_scan_limit() const { return dc_.compare_cnt() >= max_scan_num_; } diff --git a/src/core/algorithm/hnsw/hnsw_dist_calculator.h b/src/core/algorithm/hnsw/hnsw_dist_calculator.h index caf6e6d15..6c410026c 100644 --- a/src/core/algorithm/hnsw/hnsw_dist_calculator.h +++ b/src/core/algorithm/hnsw/hnsw_dist_calculator.h @@ -40,7 +40,8 @@ class HnswDistCalculator { batch_distance_(metric->batch_distance()), query_(nullptr), dim_(dim), - compare_cnt_(0) {} + compare_cnt_(0), + pairwise_dist_cnt_(0) {} //! Constructor HnswDistCalculator(const HnswEntity *entity, @@ -51,7 +52,8 @@ class HnswDistCalculator { batch_distance_(metric->batch_distance()), query_(query), dim_(dim), - compare_cnt_(0) {} + compare_cnt_(0), + pairwise_dist_cnt_(0) {} //! Constructor HnswDistCalculator(const HnswEntity *entity, @@ -61,7 +63,8 @@ class HnswDistCalculator { batch_distance_(metric->batch_distance()), query_(nullptr), dim_(0), - compare_cnt_(0) {} + compare_cnt_(0), + pairwise_dist_cnt_(0) {} void update(const HnswEntity *entity, const IndexMetric::Pointer &metric) { entity_ = entity; @@ -108,6 +111,7 @@ class HnswDistCalculator { //! Returns distance between query and vec. 
inline dist_t dist(const void *vec) { compare_cnt_++; + pairwise_dist_cnt_++; return dist(vec, query_); } @@ -115,6 +119,7 @@ class HnswDistCalculator { //! Return distance between query and node id. inline dist_t dist(node_id_t id) { compare_cnt_++; + pairwise_dist_cnt_++; const void *feat = entity_->get_vector(id); if (ailego_unlikely(feat == nullptr)) { @@ -129,6 +134,7 @@ class HnswDistCalculator { //! Return dist node lhs between node rhs inline dist_t dist(node_id_t lhs, node_id_t rhs) { compare_cnt_++; + pairwise_dist_cnt_++; const void *feat = entity_->get_vector(lhs); const void *query = entity_->get_vector(rhs); @@ -155,12 +161,14 @@ class HnswDistCalculator { void batch_dist(const void **vecs, size_t num, dist_t *distances) { compare_cnt_++; + pairwise_dist_cnt_ += num; batch_distance_(vecs, query_, num, dim_, distances); } inline dist_t batch_dist(node_id_t id) { compare_cnt_++; + pairwise_dist_cnt_++; const void *feat = entity_->get_vector(id); if (ailego_unlikely(feat == nullptr)) { @@ -176,11 +184,13 @@ class HnswDistCalculator { inline void clear() { compare_cnt_ = 0; + pairwise_dist_cnt_ = 0; error_ = false; } inline void clear_compare_cnt() { compare_cnt_ = 0; + pairwise_dist_cnt_ = 0; } inline bool error() const { @@ -192,6 +202,10 @@ class HnswDistCalculator { return compare_cnt_; } + inline uint64_t pairwise_dist_cnt() const { + return pairwise_dist_cnt_; + } + inline uint32_t dimension() const { return dim_; } @@ -209,7 +223,8 @@ class HnswDistCalculator { const void *query_; uint32_t dim_; - uint32_t compare_cnt_; // record distance compute times + uint32_t compare_cnt_; // record distance compute times + uint64_t pairwise_dist_cnt_; // record actual pairwise distance work // uint32_t compare_cnt_batch_; // record batch distance compute time bool error_{false}; }; diff --git a/src/core/algorithm/hnsw/hnsw_searcher.cc b/src/core/algorithm/hnsw/hnsw_searcher.cc index c68a146d9..6334cb417 100644 --- a/src/core/algorithm/hnsw/hnsw_searcher.cc +++ 
b/src/core/algorithm/hnsw/hnsw_searcher.cc @@ -190,6 +190,16 @@ int HnswSearcher::update_context(HnswContext *ctx) const { entity, magic_); } +int HnswSearcher::fast_search(HnswContext *ctx) const { + return alg_->fast_search(ctx); +} + +int HnswSearcher::fast_search_with_hooks( + HnswContext *ctx, const HnswAlgorithm::SearchHooks *hooks, + bool *stopped_early) const { + return alg_->fast_search_with_hooks(ctx, hooks, stopped_early); +} + int HnswSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta, uint32_t count, Context::Pointer &context) const { if (ailego_unlikely(!query || !context)) { diff --git a/src/core/algorithm/hnsw/hnsw_searcher.h b/src/core/algorithm/hnsw/hnsw_searcher.h index d79526df8..397ba0180 100644 --- a/src/core/algorithm/hnsw/hnsw_searcher.h +++ b/src/core/algorithm/hnsw/hnsw_searcher.h @@ -111,16 +111,15 @@ class HnswSearcher : public IndexSearcher { //! current streamer/searcher int update_context(HnswContext *ctx) const; - private: - enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; + protected: + int fast_search(HnswContext *ctx) const; - HnswSearcherEntity entity_{}; - HnswAlgorithm::UPointer alg_; // impl graph algorithm + int fast_search_with_hooks(HnswContext *ctx, + const HnswAlgorithm::SearchHooks *hooks, + bool *stopped_early) const; + + enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_LOADED = 2 }; - IndexMetric::Pointer metric_{}; - IndexMeta meta_{}; - ailego::Params params_{}; - Stats stats_; uint32_t ef_{HnswEntity::kDefaultEf}; uint32_t max_scan_num_{0U}; uint32_t bruteforce_threshold_{HnswEntity::kDefaultBruteForceThreshold}; @@ -131,9 +130,17 @@ class HnswSearcher : public IndexSearcher { bool force_padding_topk_enabled_{false}; float bf_negative_probability_{HnswEntity::kDefaultBFNegativeProbability}; uint32_t magic_{0U}; - State state_{STATE_INIT}; + HnswSearcherEntity entity_{}; + IndexMetric::Pointer metric_{}; + IndexMeta meta_{}; + + private: + HnswAlgorithm::UPointer alg_; // impl 
graph algorithm + + ailego::Params params_{}; + Stats stats_; }; } // namespace core -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_streamer.cc b/src/core/algorithm/hnsw/hnsw_streamer.cc index 5804c7d04..35ac4a6ab 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer.cc +++ b/src/core/algorithm/hnsw/hnsw_streamer.cc @@ -12,11 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "hnsw_streamer.h" +#include +#include #include #include #include #include #include "utility/sparse_utility.h" +#include "utility/steady_clock_timer.h" #include "hnsw_algorithm.h" #include "hnsw_context.h" #include "hnsw_dist_calculator.h" @@ -25,6 +28,51 @@ namespace zvec { namespace core { +namespace { + +bool ShouldLogHnswQueryStats(uint64_t query_seq) { + static const bool enabled = []() { + const char *value = std::getenv("ZVEC_HNSW_LOG_QUERY_STATS"); + if (value == nullptr) { + return false; + } + return std::string(value) != "0"; + }(); + if (!enabled) { + return false; + } + + static const uint64_t limit = []() -> uint64_t { + const char *value = std::getenv("ZVEC_HNSW_LOG_QUERY_LIMIT"); + if (value == nullptr || *value == '\0') { + return std::numeric_limits::max(); + } + char *end = nullptr; + unsigned long long parsed = std::strtoull(value, &end, 10); + if (end == value) { + return std::numeric_limits::max(); + } + return static_cast(parsed); + }(); + + return query_seq < limit; +} + +bool UseEmptyHnswHooks() { + const char *value = std::getenv("ZVEC_HNSW_ENABLE_EMPTY_HOOKS"); + if (value == nullptr) { + return false; + } + return std::string(value) != "0"; +} + +std::atomic &HnswQueryStatsSequence() { + static std::atomic sequence{0}; + return sequence; +} + +} // namespace + HnswStreamer::HnswStreamer() : entity_(stats_) {} HnswStreamer::~HnswStreamer() { @@ -33,6 +81,16 @@ HnswStreamer::~HnswStreamer() { } } +int 
HnswStreamer::FastSearch(HnswContext *ctx) const { + return alg_->fast_search(ctx); +} + +int HnswStreamer::FastSearchWithHooks(HnswContext *ctx, + const HnswAlgorithm::SearchHooks *hooks, + bool *stopped_early) const { + return alg_->fast_search_with_hooks(ctx, hooks, stopped_early); +} + int HnswStreamer::init(const IndexMeta &imeta, const ailego::Params ¶ms) { meta_ = imeta; meta_.set_streamer("HnswStreamer", HnswEntity::kRevision, params); @@ -619,13 +677,38 @@ int HnswStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta, ctx->update_dist_caculator_distance(search_distance_, search_batch_distance_); ctx->resize_results(count); ctx->check_need_adjuct_ctx(entity_.doc_cnt()); + const bool use_empty_hooks = UseEmptyHnswHooks(); + HnswAlgorithm::SearchHooks empty_hooks; for (size_t q = 0; q < count; ++q) { + auto query_start = SteadyClockTimer::Now(); ctx->reset_query(query); - ret = alg_->search(ctx); + auto query_search_start = SteadyClockTimer::Now(); + if (use_empty_hooks) { + bool stopped_early = false; + ret = alg_->search_with_hooks(ctx, &empty_hooks, &stopped_early); + } else { + ret = alg_->search(ctx); + } if (ailego_unlikely(ret != 0)) { LOG_ERROR("Hnsw searcher fast search failed"); return ret; } + auto query_search_end = SteadyClockTimer::Now(); + auto query_search_time_ns = + SteadyClockTimer::ElapsedNs(query_search_start, query_search_end); + auto query_latency_ns = + SteadyClockTimer::ElapsedNs(query_start, SteadyClockTimer::Now()); + uint64_t query_seq = HnswQueryStatsSequence().fetch_add(1); + if (ShouldLogHnswQueryStats(query_seq)) { + LOG_INFO( + "HNSW query stats: query_seq=%llu hook_mode=%s cmps=%zu " + "pairwise_dist_cnt=%zu pure_search_ms=%.3f latency_ms=%.3f", + static_cast(query_seq), + use_empty_hooks ? 
"empty" : "none", ctx->get_scan_num(), + ctx->get_pairwise_dist_num(), + static_cast(query_search_time_ns) / 1e6, + static_cast(query_latency_ns) / 1e6); + } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } @@ -728,6 +811,7 @@ int HnswStreamer::search_bf_impl( auto &topk = ctx->topk_heap(); for (size_t q = 0; q < count; ++q) { + auto query_start = SteadyClockTimer::Now(); ctx->reset_query(query); topk.clear(); for (node_id_t id = 0; id < entity_.doc_cnt(); ++id) { @@ -740,6 +824,17 @@ int HnswStreamer::search_bf_impl( topk.emplace(id, dist); } } + auto query_latency_ns = + SteadyClockTimer::ElapsedNs(query_start, SteadyClockTimer::Now()); + uint64_t query_seq = HnswQueryStatsSequence().fetch_add(1); + if (ShouldLogHnswQueryStats(query_seq)) { + LOG_INFO( + "HNSW query stats: query_seq=%llu cmps=%zu pairwise_dist_cnt=%zu " + "latency_ms=%.3f", + static_cast(query_seq), ctx->get_scan_num(), + ctx->get_pairwise_dist_num(), + static_cast(query_latency_ns) / 1e6); + } ctx->topk_to_result(q); query = static_cast(query) + qmeta.element_size(); } @@ -849,4 +944,4 @@ int HnswStreamer::search_bf_by_p_keys_impl( INDEX_FACTORY_REGISTER_STREAMER(HnswStreamer); } // namespace core -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/core/algorithm/hnsw/hnsw_streamer.h b/src/core/algorithm/hnsw/hnsw_streamer.h index b81106daf..7c046d0ee 100644 --- a/src/core/algorithm/hnsw/hnsw_streamer.h +++ b/src/core/algorithm/hnsw/hnsw_streamer.h @@ -31,6 +31,12 @@ class HnswStreamer : public IndexStreamer { HnswStreamer(const HnswStreamer &streamer) = delete; HnswStreamer &operator=(const HnswStreamer &streamer) = delete; + int FastSearch(HnswContext *ctx) const; + + int FastSearchWithHooks(HnswContext *ctx, + const HnswAlgorithm::SearchHooks *hooks, + bool *stopped_early) const; + protected: //! 
Initialize Streamer virtual int init(const IndexMeta &imeta, @@ -158,12 +164,11 @@ class HnswStreamer : public IndexStreamer { return 0; } - private: + protected: //! To share ctx across streamer/searcher, we need to update the context for - //! current streamer/searcher + //! current streamer/searcher (moved to protected for OmegaStreamer) int update_context(HnswContext *ctx) const; - - private: + // Changed from private to protected to allow OmegaStreamer inheritance enum State { STATE_INIT = 0, STATE_INITED = 1, STATE_OPENED = 2 }; class Stats : public IndexStreamer::Stats { public: diff --git a/src/core/algorithm/hnsw_rabitq/CMakeLists.txt b/src/core/algorithm/hnsw_rabitq/CMakeLists.txt index ed547dc76..e3da57dbd 100644 --- a/src/core/algorithm/hnsw_rabitq/CMakeLists.txt +++ b/src/core/algorithm/hnsw_rabitq/CMakeLists.txt @@ -15,7 +15,7 @@ cc_library( NAME core_knn_hnsw_rabitq STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc - LIBS core_framework rabitqlib sparsehash + LIBS core_framework core_utility core_knn_cluster rabitqlib sparsehash INCS . ${PROJECT_ROOT_DIR}/src ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" - ) \ No newline at end of file + ) diff --git a/src/core/algorithm/hnsw_sparse/CMakeLists.txt b/src/core/algorithm/hnsw_sparse/CMakeLists.txt index fe26d10e1..38ba102a6 100644 --- a/src/core/algorithm/hnsw_sparse/CMakeLists.txt +++ b/src/core/algorithm/hnsw_sparse/CMakeLists.txt @@ -5,7 +5,7 @@ cc_library( NAME core_knn_hnsw_sparse STATIC SHARED STRICT ALWAYS_LINK SRCS *.cc - LIBS core_framework sparsehash + LIBS core_framework core_utility sparsehash INCS . 
${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm VERSION "${PROXIMA_ZVEC_VERSION}" ) diff --git a/src/core/algorithm/omega/CMakeLists.txt b/src/core/algorithm/omega/CMakeLists.txt new file mode 100644 index 000000000..cf7548228 --- /dev/null +++ b/src/core/algorithm/omega/CMakeLists.txt @@ -0,0 +1,11 @@ +include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) +include(${PROJECT_ROOT_DIR}/cmake/option.cmake) + +cc_library( + NAME core_knn_omega + STATIC SHARED STRICT ALWAYS_LINK + SRCS *.cc + LIBS core_framework core_knn_hnsw omega + INCS . ${PROJECT_ROOT_DIR}/src/include ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm ${PROJECT_ROOT_DIR}/thirdparty/omega/OMEGALib/include + VERSION "${PROXIMA_ZVEC_VERSION}" + ) diff --git a/src/core/algorithm/omega/omega_context.h b/src/core/algorithm/omega/omega_context.h new file mode 100644 index 000000000..943096c1b --- /dev/null +++ b/src/core/algorithm/omega/omega_context.h @@ -0,0 +1,127 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "omega_params.h" +#include "../hnsw/hnsw_context.h" + +namespace zvec { +namespace core { + +/** + * OmegaContext extends HnswContext to support OMEGA-specific parameters + * like target_recall that can be set per-query. + * + * Training records are stored per-context (no shared state, no locks needed). + */ +class OmegaContext : public HnswContext { + public: + //! 
Constructor + OmegaContext(size_t dimension, const IndexMetric::Pointer &metric, + const HnswEntity::Pointer &entity) + : HnswContext(dimension, metric, entity), + target_recall_(0.95f), + training_query_id_(-1) {} + + //! Constructor + OmegaContext(const IndexMetric::Pointer &metric, + const HnswEntity::Pointer &entity) + : HnswContext(metric, entity), + target_recall_(0.95f), + training_query_id_(-1) {} + + //! Destructor + virtual ~OmegaContext() = default; + + //! Get target recall for this query + float target_recall() const { + return target_recall_; + } + + //! Get training query ID for this query (-1 means not set, use global) + int training_query_id() const { + return training_query_id_; + } + + //! Move training records out (override base class virtual method) + std::vector take_training_records() override { + return std::move(training_records_); + } + + //! Add a training record + void add_training_record(core_interface::TrainingRecord record) { + training_records_.push_back(std::move(record)); + } + + //! Clear training records (override base class virtual method) + //! Called before each search when context is reused from pool + void clear_training_records() override { + training_records_.clear(); + gt_cmps_per_rank_.clear(); + total_cmps_ = 0; + } + + //! Set gt_cmps data for this query + void set_gt_cmps(const std::vector >_cmps, int total_cmps) { + gt_cmps_per_rank_ = gt_cmps; + total_cmps_ = total_cmps; + } + + //! Take gt_cmps data (override base class virtual method) + std::vector take_gt_cmps() override { + return std::move(gt_cmps_per_rank_); + } + + //! Get total comparisons (override base class virtual method) + int get_total_cmps() const override { + return total_cmps_; + } + + //! Get training query ID (override base class virtual method) + int get_training_query_id() const override { + return training_query_id_; + } + + //! 
Update context parameters (overrides HnswContext::update) + int update(const ailego::Params ¶ms) override { + // First call parent to update HNSW parameters + int ret = HnswContext::update(params); + if (ret != 0) { + return ret; + } + + // Extract OMEGA-specific parameters + if (params.has(PARAM_OMEGA_SEARCHER_TARGET_RECALL)) { + params.get(PARAM_OMEGA_SEARCHER_TARGET_RECALL, &target_recall_); + } + if (params.has(PARAM_OMEGA_SEARCHER_TRAINING_QUERY_ID)) { + params.get(PARAM_OMEGA_SEARCHER_TRAINING_QUERY_ID, &training_query_id_); + } + + return 0; + } + + private: + float target_recall_; // Per-query target recall + int training_query_id_; // Per-query training query ID for parallel training + std::vector + training_records_; // Per-query training records + std::vector gt_cmps_per_rank_; // cmps value when each GT rank was found + int total_cmps_ = 0; // Total cmps for this search +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/omega/omega_hook_utils.h b/src/core/algorithm/omega/omega_hook_utils.h new file mode 100644 index 000000000..16fd11e12 --- /dev/null +++ b/src/core/algorithm/omega/omega_hook_utils.h @@ -0,0 +1,164 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include "../hnsw/hnsw_entity.h" + +namespace zvec::core { + +inline bool DisableOmegaModelPrediction() { + const char *value = std::getenv("ZVEC_OMEGA_DISABLE_MODEL_PREDICTION"); + if (value == nullptr) { + return false; + } + return std::string(value) != "0"; +} + +struct OmegaHookState { + struct PendingVisitBuffer { + std::vector storage; + int head{0}; + int count{0}; + + void Reset(int capacity) { + head = 0; + count = 0; + storage.resize(std::max(1, capacity)); + } + + bool Empty() const { + return count == 0; + } + + int Capacity() const { + return static_cast(storage.size()); + } + + void Push(const omega::SearchContext::VisitCandidate &candidate) { + storage[(head + count) % Capacity()] = candidate; + ++count; + } + + const omega::SearchContext::VisitCandidate *Data() const { + return storage.data() + head; + } + + void Clear() { + head = 0; + count = 0; + } + }; + + omega::SearchContext *search_ctx{nullptr}; + bool enable_early_stopping{false}; + bool per_cmp_reporting{false}; + PendingVisitBuffer pending_candidates; + int batch_min_interval{1}; +}; + +inline void ResetOmegaHookState(OmegaHookState *state) { + if (state->search_ctx != nullptr) { + state->batch_min_interval = + state->search_ctx->GetPredictionBatchMinInterval(); + } else { + state->batch_min_interval = 1; + } + state->pending_candidates.Reset(state->batch_min_interval); +} + +inline bool ShouldFlushOmegaPendingCandidates(const OmegaHookState &state) { + if (state.pending_candidates.Empty()) { + return false; + } + if (state.pending_candidates.count >= state.batch_min_interval) { + return true; + } + if (state.search_ctx == nullptr) { + return false; + } + return state.search_ctx->GetTotalCmps() + state.pending_candidates.count >= + state.search_ctx->GetNextPredictionCmps(); +} + +inline bool FlushOmegaPendingCandidates(OmegaHookState *state, + int flush_count) { + if (state->search_ctx == nullptr || flush_count <= 0 || + 
state->pending_candidates.Empty()) { + return false; + } + + flush_count = std::min(flush_count, state->pending_candidates.count); + bool should_predict = false; + should_predict = state->search_ctx->ReportVisitCandidates( + state->pending_candidates.Data(), static_cast(flush_count)); + state->pending_candidates.Clear(); + if (!state->enable_early_stopping || !should_predict) { + return false; + } + + bool should_stop = false; + should_stop = state->search_ctx->ShouldStopEarly(); + return should_stop; +} + +inline bool MaybeFlushOmegaPendingCandidates(OmegaHookState *state) { + if (!ShouldFlushOmegaPendingCandidates(*state)) { + return false; + } + return FlushOmegaPendingCandidates(state, state->pending_candidates.count); +} + +inline void OnOmegaLevel0Entry(node_id_t id, dist_t dist, + bool /*inserted_to_topk*/, void *user_data) { + auto &state = *static_cast(user_data); + if (state.per_cmp_reporting) { + state.search_ctx->SetDistStart(dist); + state.search_ctx->ReportVisitCandidate(id, dist, true); + return; + } + state.search_ctx->SetDistStart(dist); + state.pending_candidates.Push({static_cast(id), dist, true}); + MaybeFlushOmegaPendingCandidates(&state); +} + +inline void OnOmegaHop(void *user_data) { + auto &state = *static_cast(user_data); + state.search_ctx->ReportHop(); +} + +inline bool OnOmegaVisitCandidate(node_id_t id, dist_t dist, + bool inserted_to_topk, void *user_data) { + auto &state = *static_cast(user_data); + if (state.per_cmp_reporting) { + bool should_predict = false; + should_predict = + state.search_ctx->ReportVisitCandidate(id, dist, inserted_to_topk); + if (!state.enable_early_stopping || !should_predict) { + return false; + } + bool should_stop = false; + should_stop = state.search_ctx->ShouldStopEarly(); + return should_stop; + } + state.pending_candidates.Push({static_cast(id), dist, inserted_to_topk}); + return MaybeFlushOmegaPendingCandidates(&state); +} + +} // namespace zvec::core diff --git 
a/src/core/algorithm/omega/omega_params.h b/src/core/algorithm/omega/omega_params.h new file mode 100644 index 000000000..1ad91cc84 --- /dev/null +++ b/src/core/algorithm/omega/omega_params.h @@ -0,0 +1,33 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace zvec::core { + +// OMEGA searcher parameters (used at query time) +static const std::string PARAM_OMEGA_SEARCHER_TARGET_RECALL( + "proxima.omega.searcher.target_recall"); + +// Training query ID for parallel training searches +static const std::string PARAM_OMEGA_SEARCHER_TRAINING_QUERY_ID( + "proxima.omega.searcher.training_query_id"); + +// OMEGA streamer parameters (used at index time) +static const std::string PARAM_OMEGA_STREAMER_TARGET_RECALL( + "proxima.omega.streamer.target_recall"); + +} // namespace zvec::core diff --git a/src/core/algorithm/omega/omega_searcher.cc b/src/core/algorithm/omega/omega_searcher.cc new file mode 100644 index 000000000..a942c5aff --- /dev/null +++ b/src/core/algorithm/omega/omega_searcher.cc @@ -0,0 +1,293 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "omega_searcher.h" +#include +#include +#include +#include +#include +#include "omega_context.h" +#include "omega_hook_utils.h" +#include "omega_params.h" +#include "../hnsw/hnsw_context.h" + +namespace zvec { +namespace core { + +OmegaSearcher::OmegaSearcher(void) + : HnswSearcher(), + omega_model_(nullptr), + omega_enabled_(false), + use_omega_mode_(false), + min_vector_threshold_(100000), + current_vector_count_(0), + window_size_(100) {} + +OmegaSearcher::~OmegaSearcher(void) { + this->cleanup(); +} + +bool OmegaSearcher::should_use_omega() const { + if (DisableOmegaModelPrediction()) { + return true; + } + return omega_enabled_ && use_omega_mode_ && omega_model_ != nullptr && + omega_model_is_loaded(omega_model_); +} + +int OmegaSearcher::init(const ailego::Params ¶ms) { + // Get OMEGA-specific parameters + omega_enabled_ = + params.has("omega.enabled") ? params.get_as_bool("omega.enabled") : false; + min_vector_threshold_ = + params.has("omega.min_vector_threshold") + ? params.get_as_uint32("omega.min_vector_threshold") + : 100000; + window_size_ = params.has("omega.window_size") + ? 
params.get_as_int32("omega.window_size") + : 100; + + // Call parent class init + int ret = HnswSearcher::init(params); + if (ret != 0) { + LOG_ERROR("Failed to initialize HNSW searcher"); + return ret; + } + + LOG_INFO( + "OmegaSearcher initialized (omega_enabled=%d, min_threshold=%u, " + "window_size=%d)", + omega_enabled_, min_vector_threshold_, window_size_); + return 0; +} + +int OmegaSearcher::cleanup(void) { + // Cleanup OMEGA model + if (omega_model_ != nullptr) { + omega_model_destroy(omega_model_); + omega_model_ = nullptr; + } + + // Call parent class cleanup + return HnswSearcher::cleanup(); +} + +int OmegaSearcher::load(IndexStorage::Pointer container, + IndexMetric::Pointer metric) { + // Load HNSW index using parent class + int ret = HnswSearcher::load(container, metric); + if (ret != 0) { + LOG_ERROR("Failed to load HNSW index"); + return ret; + } + + // Get vector count from HNSW stats + current_vector_count_ = stats().loaded_count(); + + // Try to load OMEGA model if enabled and threshold met + use_omega_mode_ = false; + if (omega_enabled_ && current_vector_count_ >= min_vector_threshold_) { + // Load the model colocated with the persisted index file. 
+ std::string effective_model_dir; + if (container) { + std::string index_path = container->file_path(); + if (!index_path.empty()) { + size_t last_slash = index_path.rfind('/'); + if (last_slash != std::string::npos) { + effective_model_dir = + index_path.substr(0, last_slash) + "/omega_model"; + } + } + } + + if (!effective_model_dir.empty()) { + omega_model_ = omega_model_create(); + if (omega_model_ != nullptr) { + ret = omega_model_load(omega_model_, effective_model_dir.c_str()); + if (ret == 0 && omega_model_is_loaded(omega_model_)) { + use_omega_mode_ = true; + LOG_INFO("OMEGA model loaded successfully from %s", + effective_model_dir.c_str()); + } else { + LOG_WARN("Failed to load OMEGA model from %s, falling back to HNSW", + effective_model_dir.c_str()); + omega_model_destroy(omega_model_); + omega_model_ = nullptr; + } + } + } else { + LOG_WARN( + "OMEGA enabled but cannot derive omega_model path from index " + "storage, falling back to HNSW"); + } + } else { + if (omega_enabled_) { + LOG_INFO("Vector count (%zu) below threshold (%u), using standard HNSW", + current_vector_count_, min_vector_threshold_); + } + } + + return 0; +} + +int OmegaSearcher::unload(void) { + // Unload OMEGA model + if (omega_model_ != nullptr) { + omega_model_destroy(omega_model_); + omega_model_ = nullptr; + } + use_omega_mode_ = false; + + // Call parent class unload + return HnswSearcher::unload(); +} + +IndexSearcher::Context::Pointer OmegaSearcher::create_context() const { + if (ailego_unlikely(state_ != STATE_LOADED)) { + LOG_ERROR("Load the index first before create context"); + return Context::Pointer(); + } + const HnswEntity::Pointer search_ctx_entity = entity_.clone(); + if (!search_ctx_entity) { + LOG_ERROR("Failed to create search context entity"); + return Context::Pointer(); + } + + // Create OmegaContext instead of HnswContext + OmegaContext *ctx = new (std::nothrow) + OmegaContext(meta_.dimension(), metric_, search_ctx_entity); + if (ailego_unlikely(ctx == 
nullptr)) { + LOG_ERROR("Failed to new OmegaContext"); + return Context::Pointer(); + } + + // Initialize context with HNSW parameters + ctx->set_ef(ef_); + ctx->set_max_scan_num(max_scan_num_); + uint32_t filter_mode = + bf_enabled_ ? VisitFilter::BloomFilter : VisitFilter::ByteMap; + ctx->set_filter_mode(filter_mode); + ctx->set_filter_negative_probability(bf_negative_probability_); + ctx->set_magic(magic_); + ctx->set_force_padding_topk(force_padding_topk_enabled_); + ctx->set_bruteforce_threshold(bruteforce_threshold_); + + if (ailego_unlikely(ctx->init(HnswContext::kSearcherContext)) != 0) { + LOG_ERROR("Init OmegaContext failed"); + delete ctx; + return Context::Pointer(); + } + + return Context::Pointer(ctx); +} + +int OmegaSearcher::search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, ContextPointer &context) const { + // If OMEGA mode is not active, delegate to parent HNSW + if (!should_use_omega()) { + return HnswSearcher::search_impl(query, qmeta, count, context); + } + + // Use OMEGA adaptive search + return adaptive_search(query, qmeta, count, context); +} + +int OmegaSearcher::adaptive_search(const void *query, + const IndexQueryMeta &qmeta, uint32_t count, + ContextPointer &context) const { + // Cast context to OmegaContext to access OMEGA-specific features + auto *omega_ctx = dynamic_cast(context.get()); + if (omega_ctx == nullptr) { + LOG_ERROR("Context is not OmegaContext"); + return IndexError_InvalidArgument; + } + + // Read target_recall from context (per-query parameter) + float target_recall = omega_ctx->target_recall(); + + // Create OMEGA search context with parameters (stateful interface). + // OMEGA's k is the requested top-k, not the batch/query count. + int omega_topk = static_cast(omega_ctx->topk()); + if (omega_topk <= 0) { + omega_topk = static_cast(count); + } + + const bool disable_model_prediction = DisableOmegaModelPrediction(); + OmegaModelHandle model_to_use = + disable_model_prediction ? 
nullptr : omega_model_; + + OmegaSearchHandle omega_search = omega_search_create_with_params( + model_to_use, target_recall, omega_topk, window_size_); + + if (omega_search == nullptr) { + LOG_WARN("Failed to create OMEGA search context, falling back to HNSW"); + return HnswSearcher::search_impl(query, qmeta, count, context); + } + omega::SearchContext *omega_search_ctx = + omega_search_get_cpp_context(omega_search); + if (omega_search_ctx == nullptr) { + omega_search_destroy(omega_search); + LOG_WARN("Failed to get OMEGA search context, falling back to HNSW"); + return HnswSearcher::search_impl(query, qmeta, count, context); + } + + omega_ctx->clear(); + omega_ctx->resize_results(count); + bool early_stop_hit = false; + + for (size_t q = 0; q < count; ++q) { + omega_ctx->reset_query(query); + OmegaHookState hook_state; + hook_state.search_ctx = omega_search_ctx; + hook_state.enable_early_stopping = !disable_model_prediction; + hook_state.per_cmp_reporting = false; + ResetOmegaHookState(&hook_state); + HnswAlgorithm::SearchHooks hooks; + hooks.user_data = &hook_state; + hooks.on_level0_entry = OnOmegaLevel0Entry; + hooks.on_hop = OnOmegaHop; + hooks.on_visit_candidate = OnOmegaVisitCandidate; + + int ret = fast_search_with_hooks(omega_ctx, &hooks, &early_stop_hit); + if (ret != 0) { + omega_search_destroy(omega_search); + LOG_WARN("OMEGA adaptive search failed, falling back to HNSW"); + return HnswSearcher::search_impl(query, qmeta, count, context); + } + MaybeFlushOmegaPendingCandidates(&hook_state); + + omega_ctx->topk_to_result(q); + if (early_stop_hit) { + break; + } + query = static_cast(query) + qmeta.element_size(); + } + + // Get final statistics + int hops, cmps, collected_gt; + omega_search_ctx->GetStats(&hops, &cmps, &collected_gt); + LOG_DEBUG("OMEGA search completed: cmps=%d, hops=%d, results=%zu", cmps, hops, + omega_ctx->topk_heap().size()); + + // Cleanup + omega_search_destroy(omega_search); + + return 0; +} + 
+INDEX_FACTORY_REGISTER_SEARCHER(OmegaSearcher); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/omega/omega_searcher.h b/src/core/algorithm/omega/omega_searcher.h new file mode 100644 index 000000000..29cbb93cf --- /dev/null +++ b/src/core/algorithm/omega/omega_searcher.h @@ -0,0 +1,85 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include "../hnsw/hnsw_searcher.h" + +namespace zvec { +namespace core { + +// OmegaSearcher owns the loaded OMEGA runtime on the searcher side: +// model loading, mode selection, HNSW-hook wiring, and per-search context +// creation. It reuses the HNSW search loop instead of defining an independent +// graph-search implementation. +class OmegaSearcher : public HnswSearcher { + public: + using ContextPointer = IndexSearcher::Context::Pointer; + + public: + OmegaSearcher(void); + ~OmegaSearcher(void); + + OmegaSearcher(const OmegaSearcher &) = delete; + OmegaSearcher &operator=(const OmegaSearcher &) = delete; + + protected: + //! Initialize Searcher + virtual int init(const ailego::Params ¶ms) override; + + //! Cleanup Searcher + virtual int cleanup(void) override; + + //! Load Index from storage + virtual int load(IndexStorage::Pointer container, + IndexMetric::Pointer metric) override; + + //! Unload index from storage + virtual int unload(void) override; + + //! 
KNN Search + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + ContextPointer &context) const override { + return search_impl(query, qmeta, 1, context); + } + + //! KNN Search with OMEGA adaptive search + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + ContextPointer &context) const override; + + //! Create a searcher context (creates OmegaContext instead of HnswContext) + virtual ContextPointer create_context() const override; + + private: + //! Check if OMEGA mode should be used + bool should_use_omega() const; + + //! Adaptive search with OMEGA predictions + int adaptive_search(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, ContextPointer &context) const; + + private: + // OMEGA components + OmegaModelHandle omega_model_; + bool omega_enabled_; + bool use_omega_mode_; + uint32_t min_vector_threshold_; + size_t current_vector_count_; + int window_size_; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/omega/omega_streamer.cc b/src/core/algorithm/omega/omega_streamer.cc new file mode 100644 index 000000000..9d2ad6087 --- /dev/null +++ b/src/core/algorithm/omega/omega_streamer.cc @@ -0,0 +1,435 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "omega_streamer.h" +#include +#include +#include +#include +#include +#include +#include "omega_context.h" +#include "omega_hook_utils.h" +#include "omega_params.h" +#include "../hnsw/hnsw_context.h" +#include "../hnsw/hnsw_entity.h" + +namespace zvec { +namespace core { + +namespace { + +struct OmegaHookSetup { + OmegaHookState state; + HnswAlgorithm::SearchHooks hooks; +}; + +OmegaHookSetup CreateOmegaHookSetup(omega::SearchContext *omega_search_ctx, + bool enable_early_stopping, + bool per_cmp_reporting) { + OmegaHookSetup setup; + setup.state.search_ctx = omega_search_ctx; + setup.state.enable_early_stopping = enable_early_stopping; + setup.state.per_cmp_reporting = per_cmp_reporting; + ResetOmegaHookState(&setup.state); + + setup.hooks.user_data = &setup.state; + setup.hooks.on_level0_entry = OnOmegaLevel0Entry; + setup.hooks.on_hop = OnOmegaHop; + setup.hooks.on_visit_candidate = OnOmegaVisitCandidate; + return setup; +} + +void EnableOmegaTrainingIfNeeded( + OmegaSearchHandle omega_search, int query_id, bool training_mode_enabled, + const std::vector> &training_ground_truth, + int training_k_train) { + if (!training_mode_enabled) { + return; + } + + std::vector gt_for_query; + if (query_id >= 0 && + static_cast(query_id) < training_ground_truth.size()) { + const auto > = training_ground_truth[query_id]; + gt_for_query.reserve(gt.size()); + for (uint64_t node_id : gt) { + gt_for_query.push_back(static_cast(node_id)); + } + } + + omega_search_enable_training(omega_search, query_id, gt_for_query.data(), + gt_for_query.size(), training_k_train); + LOG_DEBUG("Training mode enabled for query_id=%d with %zu GT nodes", query_id, + gt_for_query.size()); +} + +void CollectOmegaTrainingOutputs(OmegaSearchHandle omega_search, + OmegaContext *omega_ctx, int query_id) { + if (omega_ctx == nullptr) { + return; + } + + size_t record_count = omega_search_get_training_records_count(omega_search); + if (record_count > 0) { + const void *records_ptr = 
omega_search_get_training_records(omega_search); + const auto *records_vec = + static_cast *>(records_ptr); + + for (size_t i = 0; i < record_count; ++i) { + const auto &omega_record = (*records_vec)[i]; + core_interface::TrainingRecord record; + record.query_id = omega_record.query_id; + record.hops_visited = omega_record.hops_visited; + record.cmps_visited = omega_record.cmps_visited; + record.dist_1st = omega_record.dist_1st; + record.dist_start = omega_record.dist_start; + + if (omega_record.traversal_window_stats.size() == 7) { + std::copy(omega_record.traversal_window_stats.begin(), + omega_record.traversal_window_stats.end(), + record.traversal_window_stats.begin()); + } + + record.label = omega_record.label; + omega_ctx->add_training_record(std::move(record)); + } + + LOG_DEBUG("Collected %zu training records for query_id=%d", record_count, + query_id); + } + + size_t gt_cmps_count = omega_search_get_gt_cmps_count(omega_search); + if (gt_cmps_count == 0) { + return; + } + + const int *gt_cmps_ptr = omega_search_get_gt_cmps(omega_search); + int total_cmps = omega_search_get_total_cmps(omega_search); + if (gt_cmps_ptr == nullptr) { + return; + } + + std::vector gt_cmps_vec(gt_cmps_ptr, gt_cmps_ptr + gt_cmps_count); + for (auto &v : gt_cmps_vec) { + if (v < 0) { + v = total_cmps; + } + } + omega_ctx->set_gt_cmps(gt_cmps_vec, total_cmps); +} + +} // namespace + +bool OmegaStreamer::LoadModel(const std::string &model_dir) { + std::lock_guard lock(model_mutex_); + + if (omega_model_ != nullptr) { + omega_model_destroy(omega_model_); + omega_model_ = nullptr; + } + + omega_model_ = omega_model_create(); + if (omega_model_ == nullptr) { + LOG_ERROR("Failed to create OMEGA model manager"); + return false; + } + + if (omega_model_load(omega_model_, model_dir.c_str()) != 0) { + LOG_ERROR("Failed to load OMEGA model from %s", model_dir.c_str()); + omega_model_destroy(omega_model_); + omega_model_ = nullptr; + return false; + } + + LOG_INFO("OMEGA model loaded 
successfully from %s", model_dir.c_str()); + return true; +} + +bool OmegaStreamer::IsModelLoaded() const { + std::lock_guard lock(model_mutex_); + return omega_model_ != nullptr && omega_model_is_loaded(omega_model_); +} + +int OmegaStreamer::open(IndexStorage::Pointer stg) { + std::string index_path = stg ? stg->file_path() : ""; + + int ret = HnswStreamer::open(std::move(stg)); + if (ret != 0) { + return ret; + } + + { + std::lock_guard lock(model_mutex_); + if (omega_model_ != nullptr) { + omega_model_destroy(omega_model_); + omega_model_ = nullptr; + } + } + + if (index_path.empty()) { + LOG_WARN( + "OmegaStreamer open: storage file path is empty, using HNSW fallback"); + return 0; + } + + size_t last_slash = index_path.rfind('/'); + if (last_slash == std::string::npos) { + LOG_WARN( + "OmegaStreamer open: cannot derive omega_model path from index path %s", + index_path.c_str()); + return 0; + } + + std::string model_dir = index_path.substr(0, last_slash) + "/omega_model"; + std::string model_path = model_dir + "/model.txt"; + if (!ailego::File::IsExist(model_path)) { + LOG_INFO( + "OmegaStreamer open: no OMEGA model found at %s, using HNSW fallback", + model_dir.c_str()); + return 0; + } + + if (!LoadModel(model_dir)) { + LOG_WARN( + "OmegaStreamer open: failed to load OMEGA model from %s, using HNSW " + "fallback", + model_dir.c_str()); + } + + return 0; +} + +int OmegaStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta, + Context::Pointer &context) const { + return search_impl(query, qmeta, 1, context); +} + +int OmegaStreamer::search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + Context::Pointer &context) const { + // Determine mode: training (no early stopping) vs inference (with early + // stopping) + bool enable_early_stopping = !training_mode_enabled_ && IsModelLoaded() && + !DisableOmegaModelPrediction(); + + if (training_mode_enabled_) { + LOG_DEBUG("OmegaStreamer: training mode, early stopping DISABLED"); 
+ } else if (enable_early_stopping) { + LOG_DEBUG("OmegaStreamer: inference mode with OMEGA model"); + } else { + LOG_DEBUG("OmegaStreamer: OMEGA hooks mode without model prediction"); + } + + return omega_search_impl(query, qmeta, count, context, enable_early_stopping); +} + +int OmegaStreamer::omega_search_impl(const void *query, + const IndexQueryMeta &qmeta, + uint32_t count, Context::Pointer &context, + bool enable_early_stopping) const { + (void)qmeta; + + // Cast context to OmegaContext to access training_query_id + auto *omega_ctx = dynamic_cast(context.get()); + int query_id = current_query_id_; // Default to member variable + if (omega_ctx != nullptr && omega_ctx->training_query_id() >= 0) { + query_id = omega_ctx->training_query_id(); + } + + // Get target recall from context if available + float target_recall = target_recall_; + if (omega_ctx != nullptr) { + target_recall = omega_ctx->target_recall(); + } + + // Cast context to HnswContext to access HNSW-specific features + auto *hnsw_ctx = dynamic_cast(context.get()); + if (hnsw_ctx == nullptr) { + LOG_ERROR("Context is not HnswContext"); + return IndexError_InvalidArgument; + } + + // Create OMEGA search context. + // OMEGA's k is the search top-k, not the batch/query count. + int omega_topk = static_cast(hnsw_ctx->topk()); + if (omega_topk <= 0) { + omega_topk = static_cast(count); + } + + // In training mode: model=nullptr (collect features only) + // In inference mode: model=omega_model_ (use for early stopping) + OmegaModelHandle model_to_use = + enable_early_stopping ? 
omega_model_ : nullptr; + + OmegaSearchHandle omega_search = omega_search_create_with_params( + model_to_use, target_recall, omega_topk, window_size_); + + if (omega_search == nullptr) { + LOG_ERROR("Failed to create OMEGA search context"); + return IndexError_Runtime; + } + omega::SearchContext *omega_search_ctx = + omega_search_get_cpp_context(omega_search); + if (omega_search_ctx == nullptr) { + omega_search_destroy(omega_search); + LOG_ERROR("Failed to get OMEGA search context"); + return IndexError_Runtime; + } + + // Training state is attached before the shared HNSW loop starts so label + // collection sees the full query trajectory. + EnableOmegaTrainingIfNeeded(omega_search, query_id, training_mode_enabled_, + training_ground_truth_, training_k_train_); + + // Rebind the context if it originated from a different searcher/streamer + // instance so the HNSW state matches this streamer before search begins. + if (hnsw_ctx->magic() != magic_) { + int ret = update_context(hnsw_ctx); + if (ret != 0) { + omega_search_destroy(omega_search); + return ret; + } + } + + // Initialize context for search + hnsw_ctx->clear(); + hnsw_ctx->update_dist_caculator_distance(search_distance_, + search_batch_distance_); + hnsw_ctx->resize_results(count); + hnsw_ctx->check_need_adjuct_ctx(entity_.doc_cnt()); + hnsw_ctx->reset_query(query); + OmegaHookSetup hook_setup = CreateOmegaHookSetup( + omega_search_ctx, enable_early_stopping, training_mode_enabled_); + bool early_stop_hit = false; + int ret = + alg_->search_with_hooks(hnsw_ctx, &hook_setup.hooks, &early_stop_hit); + if (ret != 0) { + omega_search_destroy(omega_search); + LOG_ERROR("OMEGA search failed"); + return ret; + } + MaybeFlushOmegaPendingCandidates(&hook_setup.state); + int hops = 0; + int cmps = 0; + omega_search_ctx->GetStats(&hops, &cmps, nullptr); + LOG_DEBUG( + "OMEGA search completed: cmps=%d, hops=%d, results=%zu, early_stop=%d", + cmps, hops, hnsw_ctx->topk_heap().size(), + (early_stop_hit || 
omega_search_ctx->EarlyStopHit()) ? 1 : 0); + + // Match HNSW timing semantics: result materialization is outside the + // search-core timer and happens after logging. + hnsw_ctx->topk_to_result(); + + if (training_mode_enabled_) { + CollectOmegaTrainingOutputs(omega_search, omega_ctx, query_id); + } + + omega_search_destroy(omega_search); + return 0; +} + +IndexStreamer::Context::Pointer OmegaStreamer::create_context(void) const { + if (ailego_unlikely(state_ != STATE_OPENED)) { + LOG_ERROR("Create OmegaContext failed, open storage first!"); + return Context::Pointer(); + } + + HnswEntity::Pointer entity = entity_.clone(); + if (ailego_unlikely(!entity)) { + LOG_ERROR("OmegaContext clone entity failed"); + return Context::Pointer(); + } + + OmegaContext *ctx = + new (std::nothrow) OmegaContext(meta_.dimension(), metric_, entity); + if (ailego_unlikely(ctx == nullptr)) { + LOG_ERROR("Failed to new OmegaContext"); + return Context::Pointer(); + } + + ctx->set_ef(ef_); + ctx->set_max_scan_limit(max_scan_limit_); + ctx->set_min_scan_limit(min_scan_limit_); + ctx->set_max_scan_ratio(max_scan_ratio_); + ctx->set_filter_mode(bf_enabled_ ? VisitFilter::BloomFilter + : VisitFilter::ByteMap); + ctx->set_filter_negative_probability(bf_negative_prob_); + ctx->set_magic(magic_); + ctx->set_force_padding_topk(force_padding_topk_enabled_); + ctx->set_bruteforce_threshold(bruteforce_threshold_); + + if (ailego_unlikely(ctx->init(HnswContext::kStreamerContext)) != 0) { + LOG_ERROR("Init OmegaContext failed"); + delete ctx; + return Context::Pointer(); + } + return Context::Pointer(ctx); +} + +int OmegaStreamer::dump(const IndexDumper::Pointer &dumper) { + LOG_INFO("OmegaStreamer dump"); + + shared_mutex_.lock(); + AILEGO_DEFER([&]() { shared_mutex_.unlock(); }); + + // Persist the OMEGA searcher params alongside the dumped index metadata so a + // reopened index reconstructs the same searcher-side behavior. 
+ ailego::Params searcher_params; + const auto &streamer_params = meta_.streamer_params(); + + // Copy the omega.* params needed by OmegaSearcher::init(). + if (streamer_params.has("omega.enabled")) { + searcher_params.insert("omega.enabled", + streamer_params.get_as_bool("omega.enabled")); + } + if (streamer_params.has("omega.min_vector_threshold")) { + searcher_params.insert( + "omega.min_vector_threshold", + streamer_params.get_as_uint32("omega.min_vector_threshold")); + } + if (streamer_params.has("omega.window_size")) { + searcher_params.insert("omega.window_size", + streamer_params.get_as_int32("omega.window_size")); + } + + LOG_INFO( + "OmegaStreamer::dump: passing omega params to searcher " + "(enabled=%d, min_threshold=%u, window_size=%d)", + searcher_params.has("omega.enabled") + ? searcher_params.get_as_bool("omega.enabled") + : false, + searcher_params.has("omega.min_vector_threshold") + ? searcher_params.get_as_uint32("omega.min_vector_threshold") + : 0, + searcher_params.has("omega.window_size") + ? searcher_params.get_as_int32("omega.window_size") + : 0); + + meta_.set_searcher("OmegaSearcher", HnswEntity::kRevision, searcher_params); + + int ret = IndexHelper::SerializeToDumper(meta_, dumper.get()); + if (ret != 0) { + LOG_ERROR("Failed to serialize meta into dumper."); + return ret; + } + + return entity_.dump(dumper); +} + +INDEX_FACTORY_REGISTER_STREAMER(OmegaStreamer); + +} // namespace core +} // namespace zvec diff --git a/src/core/algorithm/omega/omega_streamer.h b/src/core/algorithm/omega/omega_streamer.h new file mode 100644 index 000000000..4c4775b0c --- /dev/null +++ b/src/core/algorithm/omega/omega_streamer.h @@ -0,0 +1,116 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include +#include +#include "omega_context.h" +#include "../hnsw/hnsw_streamer.h" + +namespace zvec { +namespace core { + +/** + * @brief OMEGA-aware HNSW streamer. + * + * Ownership boundary: + * - OmegaStreamer owns persisted-streamer concerns such as open/dump and the + * streamer-side search entry point used by zvec's index framework. + * - It carries training-mode/search-mode configuration that must travel with + * the loaded streamer instance. + * - It does not define the adaptive-stop policy; that lives in OMEGALib + * through OmegaSearcher/OmegaContext. + */ +class OmegaStreamer : public HnswStreamer { + public: + OmegaStreamer(void) : HnswStreamer(), omega_model_(nullptr) {} + virtual ~OmegaStreamer(void) { + if (omega_model_ != nullptr) { + omega_model_destroy(omega_model_); + omega_model_ = nullptr; + } + } + + OmegaStreamer(const OmegaStreamer &streamer) = delete; + OmegaStreamer &operator=(const OmegaStreamer &streamer) = delete; + + // Training-mode configuration forwarded into per-search contexts. 
+ void EnableTrainingMode(bool enable) { + training_mode_enabled_ = enable; + } + void SetCurrentQueryId(int query_id) { + current_query_id_ = query_id; + } + void SetTrainingGroundTruth( + const std::vector> &ground_truth, int k_train = 1) { + training_ground_truth_ = ground_truth; + training_k_train_ = k_train; + } + + protected: + /** + * @brief Override search to use OMEGA adaptive search + * + * In training mode: collects features without early stopping + * In inference mode: uses OMEGA model for adaptive early stopping + */ + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + Context::Pointer &context) const override; + + virtual int search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, + Context::Pointer &context) const override; + + /** + * @brief Override open to auto-load omega_model from the index directory. + */ + virtual int open(IndexStorage::Pointer stg) override; + + /** + * @brief Override create_context to return OmegaContext + */ + virtual Context::Pointer create_context() const override; + + /** + * @brief Override dump to set "OmegaSearcher" instead of "HnswSearcher" + */ + virtual int dump(const IndexDumper::Pointer &dumper) override; + + private: + // Search-mode configuration shared across searches for this streamer. 
+ bool LoadModel(const std::string &model_dir); + bool IsModelLoaded() const; + + // Perform OMEGA adaptive search (shared between training and inference mode) + int omega_search_impl(const void *query, const IndexQueryMeta &qmeta, + uint32_t count, Context::Pointer &context, + bool enable_early_stopping) const; + + // Training mode state + mutable bool training_mode_enabled_{false}; + mutable int current_query_id_{0}; + std::vector> training_ground_truth_; + int training_k_train_{1}; + + // Inference mode state + mutable OmegaModelHandle omega_model_{nullptr}; + mutable std::mutex model_mutex_; + float target_recall_{0.95f}; + int window_size_{100}; +}; + +} // namespace core +} // namespace zvec diff --git a/src/core/interface/CMakeLists.txt b/src/core/interface/CMakeLists.txt index 9e1a2f6b6..dd3b0d350 100644 --- a/src/core/interface/CMakeLists.txt +++ b/src/core/interface/CMakeLists.txt @@ -1,10 +1,37 @@ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) +file(GLOB CORE_INTERFACE_SRCS *.cc indexes/*.cc) +if(NOT ZVEC_ENABLE_OMEGA) + list(FILTER CORE_INTERFACE_SRCS EXCLUDE REGEX ".*/indexes/omega_index\\.cc$") + list(FILTER CORE_INTERFACE_SRCS EXCLUDE REGEX ".*/indexes/omega_training_session\\.cc$") +endif() + +set(CORE_INTERFACE_INCS + . + ${PROJECT_ROOT_DIR}/src/include + ${PROJECT_ROOT_DIR}/src/ + ${PROJECT_ROOT_DIR}/src/core) +if(ZVEC_ENABLE_OMEGA) + list(APPEND CORE_INTERFACE_INCS ${PROJECT_ROOT_DIR}/thirdparty/omega/OMEGALib/include) +endif() + +set(CORE_INTERFACE_LIBS + zvec_ailego + core_framework + core_mix_reducer + core_knn_hnsw_rabitq + sparsehash + magic_enum + rabitqlib) +if(ZVEC_ENABLE_OMEGA) + list(APPEND CORE_INTERFACE_LIBS core_knn_omega) +endif() + cc_library( NAME core_interface STATIC STRICT ALWAYS_LINK - SRCS *.cc indexes/*.cc - INCS . 
${PROJECT_ROOT_DIR}/src/ ${PROJECT_ROOT_DIR}/src/core - LIBS zvec_ailego core_framework sparsehash magic_enum rabitqlib + SRCS ${CORE_INTERFACE_SRCS} + INCS ${CORE_INTERFACE_INCS} + LIBS ${CORE_INTERFACE_LIBS} VERSION "${PROXIMA_ZVEC_VERSION}" ) diff --git a/src/core/interface/index.cc b/src/core/interface/index.cc index d482f1292..f21e202f7 100644 --- a/src/core/interface/index.cc +++ b/src/core/interface/index.cc @@ -16,7 +16,8 @@ #include #include #include -#include "mixed_reducer/mixed_reducer_params.h" +#include "../mixed_reducer/mixed_reducer_params.h" +#include "../mixed_reducer/mixed_streamer_reducer.h" namespace zvec::core_interface { @@ -244,6 +245,9 @@ int Index::Init(const BaseIndexParam ¶m) { int Index::Open(const std::string &file_path, StorageOptions storage_options) { + // Store the file path for later use (e.g., in Merge for dump/reload) + file_path_ = file_path; + ailego::Params storage_params; // storage_params.set("proxima.mmap_file.storage.memory_warmup", true); // storage_params.set("proxima.mmap_file.storage.segment_meta_capacity", @@ -289,11 +293,14 @@ int Index::Open(const std::string &file_path, StorageOptions storage_options) { core::IndexError::What(ret)); return core::IndexError_Runtime; } + + if (streamer_ == nullptr || streamer_->open(storage_) != 0) { LOG_ERROR("Failed to open streamer, path: %s", file_path.c_str()); return core::IndexError_Runtime; } + // converter/reformer/metric are created in IndexFactory::CreateIndex // TODO: init @@ -411,6 +418,9 @@ int Index::Search(const VectorData &vector_data, return core::IndexError_Runtime; } + // Clear training records before search (context may be reused from pool) + context->clear_training_records(); + if (_prepare_for_search(vector_data, search_param, context) != 0) { LOG_ERROR("Failed to prepare for search"); context->reset(); @@ -673,6 +683,14 @@ int Index::_dense_search(const VectorData &vector_data, } } + // Extract training records from context (for OMEGA training mode) + 
result->training_records_ = context->take_training_records(); + + // Extract gt_cmps data from context (for OMEGA training mode) + result->gt_cmps_per_rank_ = context->take_gt_cmps(); + result->total_cmps_ = context->get_total_cmps(); + result->training_query_id_ = context->get_training_query_id(); + return 0; } @@ -794,6 +812,15 @@ int Index::Merge(const std::vector &indexes, LOG_ERROR("Failed to init reducer"); return core::IndexError_Runtime; } + + + // Set storage and file path for dump/reload operations + auto *mixed_reducer = + dynamic_cast(reducer.get()); + if (mixed_reducer != nullptr) { + mixed_reducer->set_storage(storage_, file_path_); + } + if (reducer->set_target_streamer_wiht_info(builder_, streamer_, converter_, reformer_, input_vector_meta_) != 0) { @@ -813,6 +840,16 @@ int Index::Merge(const std::vector &indexes, return core::IndexError_Runtime; } is_trained_ = true; + + // Generic training support: Check if this index supports training capability + // The actual training orchestration happens at the db layer (Segment level) + auto *training_capable = this->GetTrainingCapability(); + if (training_capable != nullptr) { + LOG_INFO( + "Index merge completed for trainable index, training can now be " + "performed"); + } + return 0; } diff --git a/src/core/interface/index_factory.cc b/src/core/interface/index_factory.cc index 0d8157282..88cccf6f1 100644 --- a/src/core/interface/index_factory.cc +++ b/src/core/interface/index_factory.cc @@ -43,6 +43,13 @@ Index::Pointer IndexFactory::CreateAndInitIndex(const BaseIndexParam ¶m) { ptr = std::make_shared(); } else if (param.index_type == IndexType::kHNSW) { ptr = std::make_shared(); + } else if (param.index_type == IndexType::kOMEGA) { +#if ZVEC_ENABLE_OMEGA + ptr = std::make_shared(); +#else + LOG_ERROR("OMEGA is not supported on this platform"); + return nullptr; +#endif } else if (param.index_type == IndexType::kIVF) { ptr = std::make_shared(); } else if (param.index_type == IndexType::kHNSWRabitq) { @@ 
-98,6 +105,46 @@ BaseIndexParam::Pointer IndexFactory::DeserializeIndexParamFromJson( } return param; } + case IndexType::kOMEGA: { + HNSWIndexParam::Pointer param = std::make_shared(); + auto deserialize_quantizer = [&](const ailego::JsonObject &obj) -> bool { + ailego::JsonValue quantizer_json_value; + if (obj.has("quantizer_param")) { + if (obj.get("quantizer_param", &quantizer_json_value); + quantizer_json_value.is_object()) { + if (!param->quantizer_param.DeserializeFromJson( + quantizer_json_value.as_json_string().as_stl_string())) { + return false; + } + } + } + return true; + }; + + if (!extract_enum_from_json( + json_obj, "metric_type", param->metric_type, tmp_json_value) || + !extract_enum_from_json(json_obj, "data_type", + param->data_type, tmp_json_value) || + !extract_value_from_json(json_obj, "dimension", param->dimension, + tmp_json_value) || + !extract_value_from_json(json_obj, "version", param->version, + tmp_json_value) || + !extract_value_from_json(json_obj, "is_sparse", param->is_sparse, + tmp_json_value) || + !extract_value_from_json(json_obj, "use_id_map", param->use_id_map, + tmp_json_value) || + !extract_value_from_json(json_obj, "is_huge_page", + param->is_huge_page, tmp_json_value) || + !extract_value_from_json(json_obj, "m", param->m, tmp_json_value) || + !extract_value_from_json(json_obj, "ef_construction", + param->ef_construction, tmp_json_value) || + !deserialize_quantizer(json_obj)) { + LOG_ERROR("Failed to deserialize omega index param"); + return nullptr; + } + param->index_type = IndexType::kOMEGA; + return param; + } case IndexType::kIVF: { IVFIndexParam::Pointer param = std::make_shared(); if (!param->DeserializeFromJson(json_str)) { @@ -153,6 +200,14 @@ std::string IndexFactory::QueryParamSerializeToJson(const QueryParamType ¶m, json_obj.set("ef_search", ailego::JsonValue(param.ef_search)); } index_type = IndexType::kHNSW; + } else if constexpr (std::is_same_v) { + if (!omit_empty_value || param.ef_search != 0) { + 
json_obj.set("ef_search", ailego::JsonValue(param.ef_search)); + } + if (!omit_empty_value || param.target_recall != 0.0f) { + json_obj.set("target_recall", ailego::JsonValue(param.target_recall)); + } + index_type = IndexType::kOMEGA; } else if constexpr (std::is_same_v) { if (!omit_empty_value || param.nprobe != 0) { json_obj.set("nprobe", ailego::JsonValue(param.nprobe)); @@ -182,6 +237,8 @@ template std::string IndexFactory::QueryParamSerializeToJson( const FlatQueryParam ¶m, bool omit_empty_value); template std::string IndexFactory::QueryParamSerializeToJson( const HNSWQueryParam ¶m, bool omit_empty_value); +template std::string IndexFactory::QueryParamSerializeToJson( + const OmegaQueryParam ¶m, bool omit_empty_value); template std::string IndexFactory::QueryParamSerializeToJson( const IVFQueryParam ¶m, bool omit_empty_value); @@ -250,6 +307,22 @@ typename QueryParamType::Pointer IndexFactory::QueryParamDeserializeFromJson( return nullptr; } return param; + } else if (index_type == IndexType::kOMEGA) { + auto param = std::make_shared(); + if (!parse_common_fields(param)) { + return nullptr; + } + if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, + tmp_json_value)) { + LOG_ERROR("Failed to deserialize ef_search"); + return nullptr; + } + if (!extract_value_from_json(json_obj, "target_recall", + param->target_recall, tmp_json_value)) { + LOG_ERROR("Failed to deserialize target_recall"); + return nullptr; + } + return param; } else if (index_type == IndexType::kIVF) { auto param = std::make_shared(); if (!parse_common_fields(param)) { @@ -289,6 +362,17 @@ typename QueryParamType::Pointer IndexFactory::QueryParamDeserializeFromJson( LOG_ERROR("Failed to deserialize ef_search"); return nullptr; } + } else if constexpr (std::is_same_v) { + if (!extract_value_from_json(json_obj, "ef_search", param->ef_search, + tmp_json_value)) { + LOG_ERROR("Failed to deserialize ef_search"); + return nullptr; + } + if (!extract_value_from_json(json_obj, 
"target_recall", + param->target_recall, tmp_json_value)) { + LOG_ERROR("Failed to deserialize target_recall"); + return nullptr; + } } else if constexpr (std::is_same_v) { if (!extract_value_from_json(json_obj, "nprobe", param->nprobe, tmp_json_value)) { @@ -317,6 +401,8 @@ template FlatQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson< FlatQueryParam>(const std::string &json_str); template HNSWQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson< HNSWQueryParam>(const std::string &json_str); +template OmegaQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson< + OmegaQueryParam>(const std::string &json_str); template IVFQueryParam::Pointer IndexFactory::QueryParamDeserializeFromJson< IVFQueryParam>(const std::string &json_str); diff --git a/src/core/interface/indexes/hnsw_index.cc b/src/core/interface/indexes/hnsw_index.cc index 921db8ffa..360b56c13 100644 --- a/src/core/interface/indexes/hnsw_index.cc +++ b/src/core/interface/indexes/hnsw_index.cc @@ -17,6 +17,9 @@ #include #include "algorithm/hnsw/hnsw_params.h" #include "algorithm/hnsw_sparse/hnsw_sparse_params.h" +#if ZVEC_ENABLE_OMEGA +#include "algorithm/omega/omega_params.h" +#endif namespace zvec::core_interface { @@ -104,6 +107,14 @@ int HNSWIndex::_prepare_for_search( const int real_search_ef = std::max(1u, std::min(2048u, hnsw_search_param->ef_search)); params.set(core::PARAM_HNSW_STREAMER_EF, real_search_ef); + + if (hnsw_search_param->training_query_id >= 0) { +#if ZVEC_ENABLE_OMEGA + params.set(core::PARAM_OMEGA_SEARCHER_TRAINING_QUERY_ID, + hnsw_search_param->training_query_id); +#endif + } + context->update(params); return 0; } @@ -118,4 +129,4 @@ int HNSWIndex::_get_coarse_search_topk( return ret; } -} // namespace zvec::core_interface \ No newline at end of file +} // namespace zvec::core_interface diff --git a/src/core/interface/indexes/omega_index.cc b/src/core/interface/indexes/omega_index.cc new file mode 100644 index 000000000..3ec172065 --- /dev/null +++ 
b/src/core/interface/indexes/omega_index.cc @@ -0,0 +1,101 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "algorithm/hnsw/hnsw_params.h" +#include "algorithm/omega/omega_params.h" +#include "algorithm/omega/omega_streamer.h" +#include "omega_training_session.h" + +namespace zvec::core_interface { + +// OmegaIndex owns the framework-facing index lifecycle and delegates OMEGA- +// specific runtime behavior to OmegaStreamer/OmegaSearcher. It is responsible +// for creating the correct streamer and injecting OMEGA query params into the +// search context. It does not own the adaptive-search algorithm itself. +int OmegaIndex::CreateAndInitStreamer(const BaseIndexParam ¶m) { + // Reuse HNSWIndex setup so the HNSW-compatible on-disk/index metadata is + // initialized consistently before swapping in the OMEGA-aware streamer. + int ret = HNSWIndex::CreateAndInitStreamer(param); + if (ret != core::IndexError_Success) { + return ret; + } + + // OMEGA build/merge still happens through the streamer path. Keeping the + // HNSW builder path untouched avoids changing merge semantics. 
+ core::IndexMeta saved_meta = proxima_index_meta_; + ailego::Params saved_params = proxima_index_params_; + + streamer_ = core::IndexFactory::CreateStreamer("OmegaStreamer"); + if (ailego_unlikely(!streamer_)) { + LOG_ERROR("Failed to create OmegaStreamer"); + return core::IndexError_Runtime; + } + + if (ailego_unlikely(streamer_->init(saved_meta, saved_params) != 0)) { + LOG_ERROR("Failed to init OmegaStreamer"); + return core::IndexError_Runtime; + } + + // Persist the OMEGA-aware searcher type so reopened indices route searches + // through OmegaSearcher instead of the plain HNSW searcher. + proxima_index_meta_.set_searcher("OmegaSearcher", 0, ailego::Params()); + + return core::IndexError_Success; +} + + +ITrainingSession::Pointer OmegaIndex::CreateTrainingSession() { + if (auto *omega_streamer = + streamer_ ? dynamic_cast(streamer_.get()) + : nullptr) { + return std::make_shared(omega_streamer); + } + return nullptr; +} + +int OmegaIndex::_prepare_for_search( + const VectorData &vector_data, + const BaseIndexQueryParam::Pointer &search_param, + core::IndexContext::Pointer &context) { + // First call parent for base HNSW parameter handling (ef_search, etc.) 
+ int ret = HNSWIndex::_prepare_for_search(vector_data, search_param, context); + if (ret != 0) { + return ret; + } + + ailego::Params params; + + // Extract OMEGA-specific parameter (target_recall) + const auto &omega_search_param = + std::dynamic_pointer_cast(search_param); + if (omega_search_param) { + params.set(core::PARAM_OMEGA_SEARCHER_TARGET_RECALL, + omega_search_param->target_recall); + // Pass training_query_id for parallel training searches + if (omega_search_param->training_query_id >= 0) { + params.set(core::PARAM_OMEGA_SEARCHER_TRAINING_QUERY_ID, + omega_search_param->training_query_id); + } + } + + if (!params.empty()) { + context->update(params); + } + + return 0; +} + +} // namespace zvec::core_interface diff --git a/src/core/interface/indexes/omega_training_session.cc b/src/core/interface/indexes/omega_training_session.cc new file mode 100644 index 000000000..f2ec7297f --- /dev/null +++ b/src/core/interface/indexes/omega_training_session.cc @@ -0,0 +1,114 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "omega_training_session.h" +#include "algorithm/omega/omega_streamer.h" + +namespace zvec::core_interface { + +zvec::Status OmegaTrainingSession::Start(const TrainingSessionConfig &config) { + std::lock_guard lock(mutex_); + if (streamer_ == nullptr) { + return zvec::Status::InvalidArgument("Omega streamer is not available"); + } + + ResetArtifactsLocked(); + topk_ = config.topk; + num_queries_ = config.ground_truth.size(); + streamer_->SetTrainingGroundTruth(config.ground_truth, config.k_train); + streamer_->EnableTrainingMode(true); + active_ = true; + return zvec::Status::OK(); +} + +void OmegaTrainingSession::BeginQuery(int query_id) { + if (streamer_ != nullptr) { + streamer_->SetCurrentQueryId(query_id); + } +} + +void OmegaTrainingSession::CollectQueryArtifacts( + QueryTrainingArtifacts &&artifacts) { + std::lock_guard lock(mutex_); + if (!artifacts.records.empty()) { + records_.insert(records_.end(), + std::make_move_iterator(artifacts.records.begin()), + std::make_move_iterator(artifacts.records.end())); + } + if (!artifacts.gt_cmps_per_rank.empty() && artifacts.training_query_id >= 0) { + gt_cmps_map_[artifacts.training_query_id] = { + std::move(artifacts.gt_cmps_per_rank), artifacts.total_cmps}; + } +} + +TrainingArtifacts OmegaTrainingSession::ConsumeArtifacts() { + std::lock_guard lock(mutex_); + + TrainingArtifacts artifacts; + artifacts.records = std::move(records_); + + if (!gt_cmps_map_.empty()) { + size_t num_queries = num_queries_; + if (num_queries == 0) { + num_queries = static_cast(gt_cmps_map_.rbegin()->first) + 1; + } + size_t topk = topk_; + if (topk == 0) { + for (const auto &entry : gt_cmps_map_) { + if (!entry.second.first.empty()) { + topk = entry.second.first.size(); + break; + } + } + } + + artifacts.gt_cmps_data.num_queries = num_queries; + artifacts.gt_cmps_data.topk = topk; + artifacts.gt_cmps_data.gt_cmps.resize(num_queries); + artifacts.gt_cmps_data.total_cmps.resize(num_queries, 0); + for (size_t q = 0; q < 
num_queries; ++q) {
+      artifacts.gt_cmps_data.gt_cmps[q].resize(topk, 0);
+    }
+    for (const auto &entry : gt_cmps_map_) {
+      size_t query_id = static_cast<size_t>(entry.first);
+      if (query_id >= num_queries) {
+        continue;
+      }
+      const auto &[gt_cmps_per_rank, total_cmps] = entry.second;
+      artifacts.gt_cmps_data.total_cmps[query_id] = total_cmps;
+      for (size_t rank = 0; rank < gt_cmps_per_rank.size() && rank < topk;
+           ++rank) {
+        artifacts.gt_cmps_data.gt_cmps[query_id][rank] = gt_cmps_per_rank[rank];
+      }
+    }
+  }
+
+  gt_cmps_map_.clear();
+  return artifacts;
+}
+
+void OmegaTrainingSession::Finish() {
+  std::lock_guard<std::mutex> lock(mutex_);
+  if (active_ && streamer_ != nullptr) {
+    streamer_->EnableTrainingMode(false);
+  }
+  active_ = false;
+}
+
+void OmegaTrainingSession::ResetArtifactsLocked() {
+  records_.clear();
+  gt_cmps_map_.clear();
+}
+
+} // namespace zvec::core_interface
diff --git a/src/core/interface/indexes/omega_training_session.h b/src/core/interface/indexes/omega_training_session.h
new file mode 100644
index 000000000..8915ad551
--- /dev/null
+++ b/src/core/interface/indexes/omega_training_session.h
@@ -0,0 +1,56 @@
+// Copyright 2025-present the zvec project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once + +#include +#include +#include +#include + +namespace zvec { +namespace core { +class OmegaStreamer; +} // namespace core +namespace core_interface { + +class OmegaTrainingSession : public ITrainingSession { + public: + explicit OmegaTrainingSession(core::OmegaStreamer *streamer) + : streamer_(streamer) {} + + zvec::Status Start(const TrainingSessionConfig &config) override; + + void BeginQuery(int query_id) override; + + void CollectQueryArtifacts(QueryTrainingArtifacts &&artifacts) override; + + TrainingArtifacts ConsumeArtifacts() override; + + void Finish() override; + + private: + void ResetArtifactsLocked(); + + core::OmegaStreamer *streamer_{nullptr}; + std::mutex mutex_; + size_t topk_{0}; + size_t num_queries_{0}; + bool active_{false}; + std::vector records_; + std::map, int>> gt_cmps_map_; +}; + +} // namespace core_interface +} // namespace zvec diff --git a/src/core/mixed_reducer/mixed_streamer_reducer.cc b/src/core/mixed_reducer/mixed_streamer_reducer.cc index bb84e3d6d..19b9e3fa2 100644 --- a/src/core/mixed_reducer/mixed_streamer_reducer.cc +++ b/src/core/mixed_reducer/mixed_streamer_reducer.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "mixed_streamer_reducer.h" +#include #include #include #include @@ -141,7 +142,8 @@ int MixedStreamerReducer::reduce(const IndexFilter &filter) { ailego::ElapsedTime timer; - std::vector add_results(num_of_add_threads_, -1); + const size_t add_thread_count = enable_pk_rewrite_ ? 
1 : num_of_add_threads_; + std::vector add_results(add_thread_count, -1); auto add_group = thread_pool_->make_group(); std::vector read_results(streamers_.size(), -1); @@ -149,7 +151,7 @@ int MixedStreamerReducer::reduce(const IndexFilter &filter) { uint32_t id_offset = 0, next_id = 0; if (is_sparse_) { - for (size_t i = 0; i < num_of_add_threads_; i++) { + for (size_t i = 0; i < add_thread_count; i++) { add_group->submit(ailego::Closure::New( this, &MixedStreamerReducer::add_sparse_vec, &add_results[i])); } @@ -162,7 +164,7 @@ int MixedStreamerReducer::reduce(const IndexFilter &filter) { sparse_mt_list_.done(); } else { - for (size_t i = 0; i < num_of_add_threads_; i++) { + for (size_t i = 0; i < add_thread_count; i++) { add_group->submit(ailego::Closure::New( this, &MixedStreamerReducer::add_vec, &add_results[i])); // add_vec(&add_results[i]); @@ -194,8 +196,67 @@ int MixedStreamerReducer::reduce(const IndexFilter &filter) { stats_.set_reduced_costtime(timer.seconds()); state_ = STATE_REDUCE; + + if (target_builder_ != nullptr) { IndexBuild(); + + // Best-effort persistence hook for builder-backed indexes. Some newer + // flows want the built graph persisted immediately after reduce(), but + // legacy paths such as IVF already perform their own dump/reopen later in + // the merge flow. Missing storage context must not break those existing + // paths. + + if (target_storage_ == nullptr) { + LOG_WARN("target_storage_ is null, skip dump/reload hook"); + LOG_INFO("End brute force reduce. cost time: [%zu]s", + (size_t)timer.seconds()); + return 0; + } + + if (target_file_path_.empty()) { + LOG_WARN("target_file_path_ is empty, skip dump/reload hook"); + LOG_INFO("End brute force reduce. 
cost time: [%zu]s", + (size_t)timer.seconds()); + return 0; + } + + + // Create a FileDumper that writes to the file + auto dumper = IndexFactory::CreateDumper("FileDumper"); + if (dumper == nullptr) { + LOG_ERROR("Failed to create dumper"); + return IndexError_Runtime; + } + + // Initialize the dumper with the file path + int ret = dumper->create(target_file_path_); + if (ret != 0) { + LOG_ERROR("Failed to create dumper at path=%s, ret=%d", + target_file_path_.c_str(), ret); + return ret; + } + + // Dump the builder's entity to the file + ret = target_builder_->dump(dumper); + if (ret != 0) { + LOG_ERROR("Failed to dump builder, ret=%d", ret); + return ret; + } + + // Close the dumper to flush data + ret = dumper->close(); + if (ret != 0) { + LOG_ERROR("Failed to close dumper, ret=%d", ret); + return ret; + } + + + // NOTE: We cannot safely reload the streamer here (close/open causes + // crashes). The streamer will properly load data when the collection is + // reopened. For now, auto-training will need to handle the case where + // streamer doc_count=0. + } else { } LOG_INFO("End brute force reduce. cost time: [%zu]s", @@ -245,6 +306,7 @@ int MixedStreamerReducer::read_vec(size_t source_streamer_index, IndexProvider::Pointer provider = streamer->create_provider(); IndexProvider::Iterator::Pointer iterator = provider->create_iterator(); + std::vector>> pending_items; while (iterator->is_valid()) { if (stop_flag_ != nullptr && stop_flag_->load(std::memory_order_relaxed)) { @@ -273,13 +335,18 @@ int MixedStreamerReducer::read_vec(size_t source_streamer_index, memcpy(bytes.data(), iterator->data(), bytes.size()); } - // TODO: use id instead of key - if (!mt_list_.produce(VectorItem((*next_id)++, std::move(bytes)))) { - LOG_ERROR("Produce vector to queue failed. 
key[%lu]", - (size_t)iterator->key()); + pending_items.emplace_back(iterator->key() + id_offset, std::move(bytes)); + iterator->next(); + } + + std::sort( + pending_items.begin(), pending_items.end(), + [](const auto &lhs, const auto &rhs) { return lhs.first < rhs.first; }); + for (auto &item : pending_items) { + if (!mt_list_.produce(VectorItem((*next_id)++, std::move(item.second)))) { + LOG_ERROR("Produce vector to queue failed. key[%u]", item.first); return IndexError_Runtime; } - iterator->next(); } return 0; } @@ -449,6 +516,7 @@ int MixedStreamerReducer::read_sparse_vec(size_t source_streamer_index, streamer->create_sparse_provider(); IndexStreamer::SparseProvider::Iterator::Pointer iterator = provider->create_iterator(); + std::vector pending_items; while (iterator->is_valid()) { if (stop_flag_ != nullptr && stop_flag_->load(std::memory_order_relaxed)) { @@ -488,15 +556,24 @@ int MixedStreamerReducer::read_sparse_vec(size_t source_streamer_index, memcpy(sparse_indices.data(), iterator->sparse_indices(), sparse_indices.size() * sizeof(uint32_t)); - // TODO: use id instead of key - if (!sparse_mt_list_.produce(SparseVectorItem((*next_id)++, - std::move(sparse_indices), - std::move(sparse_values)))) { + pending_items.emplace_back(iterator->key() + id_offset, + std::move(sparse_indices), + std::move(sparse_values)); + iterator->next(); + } + + std::sort(pending_items.begin(), pending_items.end(), + [](const SparseVectorItem &lhs, const SparseVectorItem &rhs) { + return lhs.pkey_ < rhs.pkey_; + }); + for (auto &item : pending_items) { + if (!sparse_mt_list_.produce( + SparseVectorItem((*next_id)++, std::move(item.sparse_indices_), + std::move(item.sparse_values_)))) { LOG_ERROR("Produce vector to queue failed. 
key[%lu]", - (size_t)iterator->key()); + static_cast(item.pkey_)); return IndexError_Runtime; } - iterator->next(); } return 0; } @@ -569,8 +646,12 @@ int MixedStreamerReducer::IndexBuild() { target_holder); target_holder = target_builder_converter_->result(); } + target_builder_->train(target_holder); + target_builder_->build(target_holder); + + return 0; } diff --git a/src/core/mixed_reducer/mixed_streamer_reducer.h b/src/core/mixed_reducer/mixed_streamer_reducer.h index ec4c62406..aba9e5c14 100644 --- a/src/core/mixed_reducer/mixed_streamer_reducer.h +++ b/src/core/mixed_reducer/mixed_streamer_reducer.h @@ -51,6 +51,13 @@ class MixedStreamerReducer : public IndexStreamerReducer { const IndexConverter::Pointer converter, const IndexReformer::Pointer reformer, const IndexQueryMeta &original_query_meta) override; + + // Set the storage and file path for dump/reload operations + void set_storage(const IndexStorage::Pointer &storage, + const std::string &file_path) { + target_storage_ = storage; + target_file_path_ = file_path; + } // feed_streamer int feed_streamer_with_reformer( IndexStreamer::Pointer streamer, @@ -105,6 +112,8 @@ class MixedStreamerReducer : public IndexStreamerReducer { IndexBuilder::Pointer target_builder_{nullptr}; IndexConverter::Pointer target_builder_converter_{nullptr}; + IndexStorage::Pointer target_storage_{nullptr}; + std::string target_file_path_; std::mutex mutex_{}; std::vector> doc_cache_; const uint64_t kInvalidKey = std::numeric_limits::max(); diff --git a/src/core/utility/buffer_storage.cc b/src/core/utility/buffer_storage.cc index a20a03160..ba82bc764 100644 --- a/src/core/utility/buffer_storage.cc +++ b/src/core/utility/buffer_storage.cc @@ -413,6 +413,15 @@ class BufferStorage : public IndexStorage { return header_.magic; } + //! Retrieve file path of storage + std::string file_path(void) const override { + return file_name_; + } + + uint32_t get_context_offset() { + return header_.content_offset; + } + protected: //! 
Initialize index version segment int init_version_segment(void) { diff --git a/src/core/utility/mmap_file_storage.cc b/src/core/utility/mmap_file_storage.cc index 9a1261f4f..cf18dc2a3 100644 --- a/src/core/utility/mmap_file_storage.cc +++ b/src/core/utility/mmap_file_storage.cc @@ -172,6 +172,7 @@ class MMapFileStorage : public IndexStorage { //! Open storage int open(const std::string &path, bool create_if_missing) override { + file_path_ = path; if (!ailego::File::IsExist(path) && create_if_missing) { size_t last_slash = path.rfind('/'); if (last_slash != std::string::npos) { @@ -231,6 +232,11 @@ class MMapFileStorage : public IndexStorage { return mapping_.magic(); } + //! Retrieve file path of storage + std::string file_path(void) const override { + return file_path_; + } + protected: //! Initialize index version segment int init_version_segment(void) { @@ -328,6 +334,7 @@ class MMapFileStorage : public IndexStorage { } private: + std::string file_path_{}; // Store the file path for retrieval uint32_t segment_meta_capacity_{1024 * 1024}; bool copy_on_write_{false}; bool force_flush_{false}; diff --git a/src/core/utility/steady_clock_timer.cc b/src/core/utility/steady_clock_timer.cc new file mode 100644 index 000000000..7232e6166 --- /dev/null +++ b/src/core/utility/steady_clock_timer.cc @@ -0,0 +1,31 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "utility/steady_clock_timer.h" + +namespace zvec { +namespace core { + +SteadyClockTimer::tick_t SteadyClockTimer::Now() { + const auto now = std::chrono::steady_clock::now().time_since_epoch(); + return static_cast( + std::chrono::duration_cast(now).count()); +} + +uint64_t SteadyClockTimer::ElapsedNs(tick_t start, tick_t end) { + return end > start ? (end - start) : 0; +} + +} // namespace core +} // namespace zvec diff --git a/src/core/utility/steady_clock_timer.h b/src/core/utility/steady_clock_timer.h new file mode 100644 index 000000000..96da46b0b --- /dev/null +++ b/src/core/utility/steady_clock_timer.h @@ -0,0 +1,35 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef ZVEC_CORE_UTILITY_STEADY_CLOCK_TIMER_H_
+#define ZVEC_CORE_UTILITY_STEADY_CLOCK_TIMER_H_
+
+#include <chrono>
+#include <cstdint>
+
+namespace zvec {
+namespace core {
+
+class SteadyClockTimer {
+ public:
+  using tick_t = uint64_t;
+
+  static tick_t Now();
+  static uint64_t ElapsedNs(tick_t start, tick_t end);
+};
+
+} // namespace core
+} // namespace zvec
+
+#endif  // ZVEC_CORE_UTILITY_STEADY_CLOCK_TIMER_H_
diff --git a/src/db/CMakeLists.txt b/src/db/CMakeLists.txt
index b2689278a..c2d541435 100644
--- a/src/db/CMakeLists.txt
+++ b/src/db/CMakeLists.txt
@@ -12,11 +12,20 @@ cc_directory(index)
 cc_directory(sqlengine)
 
 file(GLOB_RECURSE ALL_DB_SRCS *.cc *.c *.h)
+if(NOT ZVEC_ENABLE_OMEGA)
+  list(FILTER ALL_DB_SRCS EXCLUDE REGEX ".*/training/omega_model_trainer\\.cc$")
+  list(FILTER ALL_DB_SRCS EXCLUDE REGEX ".*/training/omega_training_coordinator\\.cc$")
+endif()
+
+set(ZVEC_DB_INCS . ${CMAKE_CURRENT_BINARY_DIR})
+if(ZVEC_ENABLE_OMEGA)
+  list(APPEND ZVEC_DB_INCS ${PROJECT_ROOT_DIR}/thirdparty/omega/OMEGALib/include)
+endif()
 cc_library(
   NAME zvec_db STATIC STRICT SRCS_NO_GLOB PACKED
   SRCS ${ALL_DB_SRCS} ${CMAKE_CURRENT_BINARY_DIR}/proto/zvec.pb.cc
-  INCS .
${CMAKE_CURRENT_BINARY_DIR} + INCS ${ZVEC_DB_INCS} PUBINCS ${PROJECT_ROOT_DIR}/src/include LIBS zvec_ailego @@ -32,4 +41,4 @@ cc_library( Arrow::arrow_acero DEPS zvec_proto VERSION "${PROXIMA_ZVEC_VERSION}" -) \ No newline at end of file +) diff --git a/src/db/collection.cc b/src/db/collection.cc index 13df9ca31..792c1cda6 100644 --- a/src/db/collection.cc +++ b/src/db/collection.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -43,6 +44,8 @@ #include "db/index/segment/segment_helper.h" #include "db/index/segment/segment_manager.h" #include "db/sqlengine/sqlengine.h" +#include "db/training/omega_model_trainer.h" +#include "db/training/training_data_collector.h" namespace zvec { @@ -820,7 +823,21 @@ Status CollectionImpl::Optimize(const OptimizeOptions &options) { return Status::OK(); } - // build segment compact task + if (options.retrain_only_) { + LOG_WARN( + "Optimize running in OMEGA retrain-only mode on %zu persisted segments", + persist_segments.size()); + for (auto &segment : persist_segments) { + auto s = segment->retrain_omega_model(); + CHECK_RETURN_STATUS(s); + } + return Status::OK(); + } + + // Build optimize tasks once so compacted segments are merged directly from + // their current per-segment sources. Pre-building filtered vector indexes for + // every persisted segment would shift source row ids before compaction and + // break alignment with scalar row-id filters. 
auto delete_store_clone = delete_store_->clone(); auto tasks = build_compact_task(schema_, persist_segments, options.concurrency_, diff --git a/src/db/index/CMakeLists.txt b/src/db/index/CMakeLists.txt index 4420050e6..d08b54dec 100644 --- a/src/db/index/CMakeLists.txt +++ b/src/db/index/CMakeLists.txt @@ -1,16 +1,37 @@ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) -cc_library( - NAME zvec_index STATIC STRICT - SRCS *.cc segment/*.cc column/vector_column/*.cc column/inverted_column/*.cc storage/*.cc storage/wal/*.cc common/*.cc - LIBS zvec_common +file(GLOB ZVEC_INDEX_SRCS + *.cc + segment/*.cc + column/vector_column/*.cc + column/inverted_column/*.cc + storage/*.cc + storage/wal/*.cc + common/*.cc + ../training/*.cc) + +if(NOT ZVEC_ENABLE_OMEGA) + list(FILTER ZVEC_INDEX_SRCS EXCLUDE REGEX ".*/training/omega_model_trainer\\.cc$") + list(FILTER ZVEC_INDEX_SRCS EXCLUDE REGEX ".*/training/omega_training_coordinator\\.cc$") +endif() + +set(ZVEC_INDEX_LIBS + zvec_common zvec_proto rocksdb core_interface Arrow::arrow_static Arrow::arrow_compute - Arrow::arrow_dataset + Arrow::arrow_dataset) +if(ZVEC_ENABLE_OMEGA) + list(APPEND ZVEC_INDEX_LIBS omega) +endif() + +cc_library( + NAME zvec_index STATIC STRICT + SRCS ${ZVEC_INDEX_SRCS} + LIBS ${ZVEC_INDEX_LIBS} INCS . 
${PROJECT_ROOT_DIR}/src VERSION "${PROXIMA_ZVEC_VERSION}" ) diff --git a/src/db/index/column/vector_column/engine_helper.hpp b/src/db/index/column/vector_column/engine_helper.hpp index e7e323f0c..09063b32b 100644 --- a/src/db/index/column/vector_column/engine_helper.hpp +++ b/src/db/index/column/vector_column/engine_helper.hpp @@ -162,10 +162,36 @@ class ProximaEngineHelper { auto db_hnsw_query_params = dynamic_cast( query_params.query_params.get()); hnsw_query_param->ef_search = db_hnsw_query_params->ef(); + hnsw_query_param->training_query_id = + db_hnsw_query_params->training_query_id(); } return std::move(hnsw_query_param); } + case IndexType::OMEGA: { + auto omega_query_param_result = + _build_common_query_param( + query_params); + if (!omega_query_param_result.has_value()) { + return tl::make_unexpected(Status::InvalidArgument( + "failed to build query param: " + + omega_query_param_result.error().message())); + } + auto &omega_query_param = omega_query_param_result.value(); + if (query_params.query_params) { + if (auto *db_omega_query_params = + dynamic_cast( + query_params.query_params.get())) { + omega_query_param->ef_search = db_omega_query_params->ef(); + omega_query_param->target_recall = + db_omega_query_params->target_recall(); + omega_query_param->training_query_id = + db_omega_query_params->training_query_id(); + } + } + return std::move(omega_query_param); + } + case IndexType::HNSW_RABITQ: { auto hnsw_query_param_result = _build_common_query_param( @@ -350,6 +376,36 @@ class ProximaEngineHelper { return index_param_builder->Build(); } + case IndexType::OMEGA: { + // OMEGA uses its own index type at core_interface level + auto index_param_builder_result = + _build_common_index_param( + field_schema); + if (!index_param_builder_result.has_value()) { + return tl::make_unexpected(Status::InvalidArgument( + "failed to build index param: " + + index_param_builder_result.error().message())); + } + auto index_param_builder = 
index_param_builder_result.value(); + + auto db_index_params = dynamic_cast( + field_schema.index_params().get()); + index_param_builder->WithM(db_index_params->m()); + index_param_builder->WithEFConstruction( + db_index_params->ef_construction()); + + // Override index_type to kOMEGA and attach OMEGA-specific config. + auto hnsw_param = index_param_builder->Build(); + hnsw_param->index_type = core_interface::IndexType::kOMEGA; + hnsw_param->params.insert("omega.enabled", true); + hnsw_param->params.insert("omega.min_vector_threshold", + db_index_params->min_vector_threshold()); + hnsw_param->params.insert("omega.window_size", + db_index_params->window_size()); + return hnsw_param; + } + case IndexType::HNSW_RABITQ: { auto index_param_builder_result = _build_common_index_param< HnswRabitqIndexParams, core_interface::HNSWRabitqIndexParamBuilder>( @@ -400,4 +456,4 @@ class ProximaEngineHelper { } } }; -}; // namespace zvec \ No newline at end of file +}; // namespace zvec diff --git a/src/db/index/column/vector_column/vector_column_indexer.cc b/src/db/index/column/vector_column/vector_column_indexer.cc index 1859e2490..7d284a262 100644 --- a/src/db/index/column/vector_column/vector_column_indexer.cc +++ b/src/db/index/column/vector_column/vector_column_indexer.cc @@ -43,6 +43,7 @@ Status VectorColumnIndexer::CreateProximaIndex( } auto &index_param = index_param_result.value(); + // Use IndexFactory for all index types (including OMEGA) index = core_interface::IndexFactory::CreateAndInitIndex(*index_param); if (index == nullptr) { return Status::InternalError("Failed to create index"); @@ -130,6 +131,7 @@ Status VectorColumnIndexer::Merge( {merge_options.write_concurrency, merge_options.pool})) { return Status::InternalError("Failed to merge index"); } + return Status::OK(); } @@ -198,6 +200,30 @@ Result VectorColumnIndexer::Search( Status::InternalError("Failed to search vector")); } + core_interface::ITrainingSession::Pointer training_session; + { + std::lock_guard 
lock(training_mutex_); + training_session = training_session_; + } + + if (training_session != nullptr) { + LOG_INFO( + "VectorColumnIndexer training search: query_id=%d records=%zu " + "gt_cmps=%zu total_cmps=%d", + search_result.training_query_id_, + search_result.training_records_.size(), + search_result.gt_cmps_per_rank_.size(), search_result.total_cmps_); + } + + if (training_session != nullptr) { + core_interface::QueryTrainingArtifacts artifacts; + artifacts.records = std::move(search_result.training_records_); + artifacts.gt_cmps_per_rank = std::move(search_result.gt_cmps_per_rank_); + artifacts.total_cmps = search_result.total_cmps_; + artifacts.training_query_id = search_result.training_query_id_; + training_session->CollectQueryArtifacts(std::move(artifacts)); + } + auto result = std::make_shared( is_sparse_, std::move(search_result.doc_list_), std::move(search_result.reverted_vector_list_), @@ -205,4 +231,33 @@ Result VectorColumnIndexer::Search( return result; } +core_interface::ITrainingCapable *VectorColumnIndexer::GetTrainingCapability() + const { + if (index != nullptr) { + return index->GetTrainingCapability(); + } + return nullptr; +} + +core_interface::ITrainingSession::Pointer +VectorColumnIndexer::CreateTrainingSession() const { + if (index != nullptr) { + if (auto *training_capable = index->GetTrainingCapability()) { + return training_capable->CreateTrainingSession(); + } + } + return nullptr; +} + +void VectorColumnIndexer::SetTrainingSession( + const core_interface::ITrainingSession::Pointer &session) { + std::lock_guard lock(training_mutex_); + training_session_ = session; +} + +void VectorColumnIndexer::ClearTrainingSession() { + std::lock_guard lock(training_mutex_); + training_session_.reset(); +} + } // namespace zvec diff --git a/src/db/index/column/vector_column/vector_column_indexer.h b/src/db/index/column/vector_column/vector_column_indexer.h index 0006080ec..9ec2ebab9 100644 --- 
a/src/db/index/column/vector_column/vector_column_indexer.h +++ b/src/db/index/column/vector_column/vector_column_indexer.h @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once +#include +#include #include #include #include @@ -20,6 +22,8 @@ #include #include #include +#include +#include #include #include #include "db/common/constants.h" @@ -89,6 +93,23 @@ class VectorColumnIndexer { // Result BatchFetch(const std::vector &doc_ids) // const; + public: + // OMEGA Training Mode Support + /** + * @brief Check if the underlying index supports training capability. + * + * @return Pointer to ITrainingCapable interface if supported, nullptr + * otherwise + */ + core_interface::ITrainingCapable *GetTrainingCapability() const; + + core_interface::ITrainingSession::Pointer CreateTrainingSession() const; + + void SetTrainingSession( + const core_interface::ITrainingSession::Pointer &session); + + void ClearTrainingSession(); + core::IndexProvider::Pointer create_index_provider() const { return index->create_index_provider(); } @@ -98,6 +119,10 @@ class VectorColumnIndexer { return index_file_path_; } + core_interface::Index::Pointer core_index() const { + return index; + } + size_t doc_count() const { if (index == nullptr) { return -1; @@ -105,6 +130,18 @@ class VectorColumnIndexer { return index->GetDocCount(); } + MetricType metric_type() const { + auto index_params = field_schema_.index_params(); + if (index_params) { + auto vector_params = + std::dynamic_pointer_cast(index_params); + if (vector_params) { + return vector_params->metric_type(); + } + } + return MetricType::IP; // default + } + // for ut protected: VectorColumnIndexer() = default; @@ -128,6 +165,9 @@ class VectorColumnIndexer { std::string engine_name_ = "proxima"; bool is_sparse_{false}; // TODO: eliminate the dynamic flag and make it // static/template/seperate class + + mutable std::mutex training_mutex_; + 
core_interface::ITrainingSession::Pointer training_session_; }; diff --git a/src/db/index/column/vector_column/vector_index_results.h b/src/db/index/column/vector_column/vector_index_results.h index 154ff6e42..f824b651d 100644 --- a/src/db/index/column/vector_column/vector_index_results.h +++ b/src/db/index/column/vector_column/vector_index_results.h @@ -90,13 +90,6 @@ class VectorIndexResults : public IndexResults { friend class VectorIterator; public: - // VectorIndexResults(core::IndexDocumentList &&doc_list) - // : docs_(std::move(doc_list)) {} - // - // VectorIndexResults(core::IndexDocumentList &&doc_list, - // std::vector &&reverted_vector_list) - // : docs_(std::move(doc_list)), - // reverted_vector_list_(std::move(reverted_vector_list)) {} VectorIndexResults(bool is_sparse, core::IndexDocumentList &&doc_list, std::vector &&reverted_vector_list, std::vector &&reverted_sparse_values_list) diff --git a/src/db/index/common/proto_converter.cc b/src/db/index/common/proto_converter.cc index 46eb93f5a..b6ac9c100 100644 --- a/src/db/index/common/proto_converter.cc +++ b/src/db/index/common/proto_converter.cc @@ -101,6 +101,37 @@ proto::IVFIndexParams ProtoConverter::ToPb(const IVFIndexParams *params) { return params_pb; } +// OmegaIndexParams +OmegaIndexParams::OPtr ProtoConverter::FromPb( + const proto::OmegaIndexParams ¶ms_pb) { + auto params = std::make_shared( + MetricTypeCodeBook::Get(params_pb.base().metric_type()), params_pb.m(), + params_pb.ef_construction(), + QuantizeTypeCodeBook::Get(params_pb.base().quantize_type()), + params_pb.min_vector_threshold(), params_pb.num_training_queries(), + params_pb.ef_training(), params_pb.window_size(), + params_pb.ef_groundtruth(), params_pb.k_train()); + + return params; +} + +proto::OmegaIndexParams ProtoConverter::ToPb(const OmegaIndexParams *params) { + proto::OmegaIndexParams params_pb; + params_pb.mutable_base()->set_metric_type( + MetricTypeCodeBook::Get(params->metric_type())); + 
params_pb.mutable_base()->set_quantize_type( + QuantizeTypeCodeBook::Get(params->quantize_type())); + params_pb.set_ef_construction(params->ef_construction()); + params_pb.set_m(params->m()); + params_pb.set_min_vector_threshold(params->min_vector_threshold()); + params_pb.set_num_training_queries(params->num_training_queries()); + params_pb.set_ef_training(params->ef_training()); + params_pb.set_window_size(params->window_size()); + params_pb.set_ef_groundtruth(params->ef_groundtruth()); + params_pb.set_k_train(params->k_train()); + return params_pb; +} + // InvertIndexParams InvertIndexParams::OPtr ProtoConverter::FromPb( const proto::InvertIndexParams ¶ms_pb) { @@ -183,6 +214,8 @@ IndexParams::Ptr ProtoConverter::FromPb(const proto::IndexParams ¶ms_pb) { return ProtoConverter::FromPb(params_pb.ivf()); } else if (params_pb.has_flat()) { return ProtoConverter::FromPb(params_pb.flat()); + } else if (params_pb.has_omega()) { + return ProtoConverter::FromPb(params_pb.omega()); } else if (params_pb.has_hnsw_rabitq()) { return ProtoConverter::FromPb(params_pb.hnsw_rabitq()); } @@ -239,6 +272,13 @@ proto::IndexParams ProtoConverter::ToPb(const IndexParams *params) { } break; } + case IndexType::OMEGA: { + auto omega_params = dynamic_cast(params); + if (omega_params) { + params_pb.mutable_omega()->CopyFrom(ProtoConverter::ToPb(omega_params)); + } + break; + } case IndexType::HNSW_RABITQ: { auto hnsw_rabitq_params = dynamic_cast(params); @@ -246,6 +286,7 @@ proto::IndexParams ProtoConverter::ToPb(const IndexParams *params) { params_pb.mutable_hnsw_rabitq()->CopyFrom( ProtoConverter::ToPb(hnsw_rabitq_params)); } + break; } default: break; @@ -315,4 +356,4 @@ proto::SegmentMeta ProtoConverter::ToPb(const SegmentMeta &meta) { return meta_pb; } -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/db/index/common/proto_converter.h b/src/db/index/common/proto_converter.h index ad96007a4..f6859487d 100644 --- 
a/src/db/index/common/proto_converter.h +++ b/src/db/index/common/proto_converter.h @@ -38,6 +38,11 @@ struct ProtoConverter { static IVFIndexParams::OPtr FromPb(const proto::IVFIndexParams ¶ms_pb); static proto::IVFIndexParams ToPb(const IVFIndexParams *params); + // OmegaIndexParams + static OmegaIndexParams::OPtr FromPb( + const proto::OmegaIndexParams ¶ms_pb); + static proto::OmegaIndexParams ToPb(const OmegaIndexParams *params); + // InvertIndexParams static InvertIndexParams::OPtr FromPb( const proto::InvertIndexParams ¶ms_pb); diff --git a/src/db/index/common/schema.cc b/src/db/index/common/schema.cc index 971de61d6..2b1f01209 100644 --- a/src/db/index/common/schema.cc +++ b/src/db/index/common/schema.cc @@ -54,7 +54,8 @@ std::unordered_set support_sparse_vector_type = { }; std::unordered_set support_dense_vector_index = { - IndexType::FLAT, IndexType::HNSW, IndexType::HNSW_RABITQ, IndexType::IVF}; + IndexType::FLAT, IndexType::HNSW, IndexType::HNSW_RABITQ, IndexType::IVF, + IndexType::OMEGA}; std::unordered_set support_sparse_vector_index = {IndexType::FLAT, IndexType::HNSW}; @@ -169,6 +170,12 @@ Status FieldSchema::validate() const { } } +#if !ZVEC_ENABLE_OMEGA + if (index_params_->type() == IndexType::OMEGA) { + return Status::NotSupported("OMEGA is not supported on Android"); + } +#endif + if (vector_index_params->quantize_type() != QuantizeType::UNDEFINED) { auto iter = quantize_type_map.find(data_type_); @@ -601,4 +608,4 @@ bool CollectionSchema::has_index(const std::string &column) const { return false; } -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/db/index/segment/segment.cc b/src/db/index/segment/segment.cc index 821d236e3..00a2c56ed 100644 --- a/src/db/index/segment/segment.cc +++ b/src/db/index/segment/segment.cc @@ -14,11 +14,13 @@ #include "segment.h" #include +#include #include #include #include #include #include +#include #include #include #include @@ -57,6 +59,9 @@ #include 
"db/index/storage/mmap_forward_store.h" #include "db/index/storage/store_helper.h" #include "db/index/storage/wal/wal_file.h" +#include "db/training/omega_model_trainer.h" +#include "db/training/omega_training_coordinator.h" +#include "db/training/training_data_collector.h" #include "zvec/ailego/container/params.h" #include "zvec/core/framework/index_factory.h" #include "zvec/core/framework/index_meta.h" @@ -67,6 +72,8 @@ namespace zvec { +namespace {} // namespace + void global_init() { static std::once_flag once; // run once @@ -207,6 +214,8 @@ class SegmentImpl : public Segment, Status flush() override; + Status retrain_omega_model() override; + Status destroy() override; TablePtr fetch(const std::vector &columns, @@ -281,6 +290,11 @@ class SegmentImpl : public Segment, Status internal_upsert(Doc &doc); Status internal_delete(const Doc &doc); + // Auto-training for OMEGA index (called after Merge completes) + Status auto_train_omega_index_internal( + const std::string &field_name, + const std::vector &indexers); + Status recover(); Status open_wal_file(); Status append_wal(const Doc &doc); @@ -1447,11 +1461,13 @@ CombinedVectorColumnIndexer::Ptr SegmentImpl::get_combined_vector_indexer( auto m_iter = memory_vector_indexers_.find(field_name); if (m_iter != memory_vector_indexers_.end()) { indexers.push_back(m_iter->second); + } else { } auto field = collection_schema_->get_field(field_name); auto vector_index_params = std::dynamic_pointer_cast(field->index_params()); + MetricType metric_type = vector_index_params->metric_type(); auto blocks = get_persist_block_metas(BlockType::VECTOR_INDEX, field_name); @@ -1566,6 +1582,14 @@ Status SegmentImpl::create_all_vector_index( new_segment_meta->set_indexed_vector_fields(vector_field_names); *segment_meta = new_segment_meta; + for (const auto &field_name : vector_field_names) { + } + + // Note: OMEGA training is now performed in merge_vector_indexer() immediately + // after the index is built via Merge(). 
This is the recommended approach per + // Alibaba expert advice - train right after reduce() completes when data is + // still in memory. + return Status::OK(); } @@ -1579,8 +1603,10 @@ Result SegmentImpl::merge_vector_indexer( auto s = vector_indexer->Open(options); CHECK_RETURN_STATUS_EXPECTED(s); + std::vector to_merge_indexers = vector_indexers_[column]; + vector_column_params::MergeOptions merge_options; if (concurrency == 0) { merge_options.pool = GlobalResource::Instance().optimize_thread_pool(); @@ -1591,9 +1617,82 @@ Result SegmentImpl::merge_vector_indexer( } s = vector_indexer->Merge(to_merge_indexers, filter_, merge_options); CHECK_RETURN_STATUS_EXPECTED(s); + + // Check if this is a trainable index (OMEGA) + auto *training_capable = vector_indexer->GetTrainingCapability(); + bool needs_training = false; + std::string model_output_dir; + OmegaTrainingParams omega_training_params; + + if (training_capable != nullptr) { + omega_training_params = ResolveOmegaTrainingParams(field.index_params()); + + size_t doc_count = vector_indexer->doc_count(); + if (doc_count >= omega_training_params.min_vector_threshold) { + needs_training = true; + LOG_INFO( + "Trainable index detected after merge for field '%s' in segment %d " + "(doc_count=%zu >= min_vector_threshold=%u)", + column.c_str(), id(), doc_count, + omega_training_params.min_vector_threshold); + } else { + LOG_INFO( + "Skipping OMEGA training for field '%s': doc_count=%zu < " + "min_vector_threshold=%u", + column.c_str(), doc_count, + omega_training_params.min_vector_threshold); + } + } + + // OPTIMIZATION: Collect training data BEFORE Flush() while the in-memory + // graph still exists. This avoids the expensive disk reload (~2 minutes for + // 1M vectors) that was previously needed. The model training itself doesn't + // need the graph, only the collected records. 
+ std::optional training_result_opt; + + if (needs_training) { + // Compute model output directory + std::string segment_dir = + index_file_path.substr(0, index_file_path.rfind('/')); + model_output_dir = segment_dir + "/omega_model"; + + LOG_INFO( + "Starting OMEGA training data collection for field '%s' (using " + "in-memory graph before flush)", + column.c_str()); + LOG_INFO( + "Using OMEGA index params: num_training_queries=%zu, ef_training=%d, " + "ef_groundtruth=%d, k_train=%d", + omega_training_params.num_training_queries, + omega_training_params.ef_training, omega_training_params.ef_groundtruth, + omega_training_params.k_train); + + auto training_result = CollectOmegaTrainingDataBeforeFlush( + shared_from_this(), column, vector_indexer, omega_training_params, + model_output_dir); + if (training_result.has_value()) { + training_result_opt = std::move(training_result.value()); + LOG_INFO("Collected %zu training records (before flush)", + training_result_opt->records.size()); + } else { + LOG_WARN("Failed to collect training data: %s", + training_result.error().message().c_str()); + } + } + + // Now flush to persist the data (this clears the in-memory graph) s = vector_indexer->Flush(); CHECK_RETURN_STATUS_EXPECTED(s); + // Train the model using the previously collected data (doesn't need the + // graph) + if (needs_training && training_result_opt.has_value()) { + auto s_train = TrainOmegaModelAfterBuild(training_result_opt.value(), + model_output_dir); + CHECK_RETURN_STATUS_EXPECTED(s_train); + } + + return vector_indexer; } @@ -2313,6 +2412,105 @@ Status SegmentImpl::cleanup() { return Status::OK(); } +Status SegmentImpl::auto_train_omega_index_internal( + const std::string &field_name, + const std::vector &indexers) { + LOG_WARN("Starting auto-training for OMEGA index on field '%s' in segment %d", + field_name.c_str(), id()); + + OmegaTrainingParams omega_training_params; + auto field = collection_schema_->get_field(field_name); + if (field && 
field->index_params()) { + omega_training_params = ResolveOmegaTrainingParams(field->index_params()); + LOG_INFO( + "Using OMEGA index params: num_training_queries=%zu, ef_training=%d, " + "ef_groundtruth=%d, min_vector_threshold=%u, k_train=%d", + omega_training_params.num_training_queries, + omega_training_params.ef_training, omega_training_params.ef_groundtruth, + omega_training_params.min_vector_threshold, + omega_training_params.k_train); + } + + // Check if we have enough vectors to justify training + size_t total_doc_count = 0; + for (const auto &indexer : indexers) { + total_doc_count += indexer->doc_count(); + } + + if (total_doc_count < omega_training_params.min_vector_threshold) { + LOG_INFO( + "Skipping OMEGA training for field '%s': doc_count=%zu < " + "min_vector_threshold=%u", + field_name.c_str(), total_doc_count, + omega_training_params.min_vector_threshold); + return Status::OK(); + } + + LOG_INFO( + "Proceeding with OMEGA training: doc_count=%zu >= " + "min_vector_threshold=%u", + total_doc_count, omega_training_params.min_vector_threshold); + + // Step 1: Collect training data using the provided indexers + LOG_WARN( + "OMEGA retrain step 1/2: start collecting training data for field '%s' " + "in segment %d", + field_name.c_str(), id()); + const std::string model_output_dir = + FileHelper::MakeSegmentPath(path_, id()) + "/omega_model"; + auto training_records_result = + CollectOmegaRetrainingData(shared_from_this(), field_name, indexers, + omega_training_params, model_output_dir); + + if (!training_records_result.has_value()) { + return Status::InternalError("Failed to collect training data: " + + training_records_result.error().message()); + } + + LOG_WARN( + "OMEGA retrain step 1/2: finished collecting training data for field " + "'%s' in segment %d", + field_name.c_str(), id()); + + auto &training_result = training_records_result.value(); + LOG_INFO("Collected %zu training records for segment %d", + training_result.records.size(), id()); + + 
return TrainOmegaModelAfterRetrainCollect(training_result, model_output_dir, + id(), field_name); +} + +Status SegmentImpl::retrain_omega_model() { + for (const auto &field : collection_schema_->vector_fields()) { + if (!field->index_params()) { + continue; + } + auto omega_params = + std::dynamic_pointer_cast(field->index_params()); + if (!omega_params) { + continue; + } + + auto indexers = get_vector_indexer(field->name()); + if (indexers.empty()) { + LOG_INFO( + "Skipping OMEGA retraining for field '%s' in segment %d: no vector " + "indexers loaded", + field->name().c_str(), id()); + continue; + } + + LOG_WARN( + "Retraining OMEGA model for field '%s' in segment %d using existing " + "index", + field->name().c_str(), id()); + auto s = auto_train_omega_index_internal(field->name(), indexers); + CHECK_RETURN_STATUS(s); + } + + return Status::OK(); +} + bool SegmentImpl::validate(const std::vector &columns) const { if (columns.empty()) { LOG_ERROR("Empty columns"); @@ -2431,7 +2629,7 @@ TablePtr SegmentImpl::fetch_normal( const auto &block_offsets = get_persist_block_offsets(BlockType::SCALAR); const auto &block_metas = get_persist_block_metas(BlockType::SCALAR); - // Phase 1: Map each (doc_id, column) to its block and local row + // Map each (doc_id, column) to its block and local row. for (int output_row = 0; output_row < static_cast(indices.size()); ++output_row) { int doc_id = indices[output_row]; @@ -2475,7 +2673,7 @@ TablePtr SegmentImpl::fetch_normal( } } - // Phase 2: Execute batched fetch per block + // Execute batched fetches per block. for (const auto &[block_index, col_to_rows] : block_request_map) { std::vector fetch_columns; std::vector fetch_local_rows; @@ -2535,7 +2733,7 @@ TablePtr SegmentImpl::fetch_normal( } } - // Phase 3: Construct result arrays + // Construct the output arrays. 
std::vector> result_arrays(columns.size()); bool need_local_doc_id = false; @@ -3984,22 +4182,28 @@ Status SegmentImpl::load_scalar_index_blocks(bool create) { } Status SegmentImpl::load_vector_index_blocks() { + int block_index = 0; for (const auto &block : segment_meta_->persisted_blocks()) { if (block.type() == BlockType::VECTOR_INDEX || block.type() == BlockType::VECTOR_INDEX_QUANTIZE) { // vector block only contained 1 column auto column = block.columns()[0]; + FieldSchema new_field_params = *collection_schema_->get_vector_field(column); + auto vector_index_params = std::dynamic_pointer_cast( new_field_params.index_params()); + + if (block.type_ == BlockType::VECTOR_INDEX) { if (vector_index_params->quantize_type() != QuantizeType::UNDEFINED || !segment_meta_->vector_indexed(column)) { new_field_params.set_index_params( MakeDefaultVectorIndexParams(vector_index_params->metric_type())); + } else { } } else { if (!segment_meta_->vector_indexed(column)) { @@ -4009,6 +4213,7 @@ Status SegmentImpl::load_vector_index_blocks() { } } + std::string index_path; if (block.type_ == BlockType::VECTOR_INDEX) { index_path = FileHelper::MakeVectorIndexPath( @@ -4462,4 +4667,4 @@ Result Segment::Open(const std::string &path, return segment; } -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/db/index/segment/segment.h b/src/db/index/segment/segment.h index 263463ea0..fd43c7110 100644 --- a/src/db/index/segment/segment.h +++ b/src/db/index/segment/segment.h @@ -174,9 +174,10 @@ class Segment { // for others virtual Status flush() = 0; virtual Status dump() = 0; + virtual Status retrain_omega_model() = 0; // only mark need_destroyed virtual Status destroy() = 0; }; -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/db/proto/zvec.proto b/src/db/proto/zvec.proto index 3c9d33319..07d4e7491 100644 --- a/src/db/proto/zvec.proto +++ b/src/db/proto/zvec.proto @@ -60,6 +60,8 @@ enum IndexType { IT_HNSW_RABITQ = 4; 
// Invert Index IT_INVERT = 10; + // OMEGA Index (HNSW with learned early stopping) + IT_OMEGA = 11; }; enum QuantizeType { @@ -112,13 +114,27 @@ message IVFIndexParams { bool use_soar = 4; } +message OmegaIndexParams { + BaseIndexParams base = 1; + int32 m = 2; + int32 ef_construction = 3; + uint32 min_vector_threshold = 4; + string model_dir = 5; + uint32 num_training_queries = 6; + int32 ef_training = 7; + int32 window_size = 8; + int32 ef_groundtruth = 9; // 0 = brute force, >0 = HNSW with this ef + int32 k_train = 10; +} + message IndexParams { oneof params { InvertIndexParams invert = 1; HnswIndexParams hnsw = 2; FlatIndexParams flat = 3; IVFIndexParams ivf = 4; - HnswRabitqIndexParams hnsw_rabitq = 5; + OmegaIndexParams omega = 5; + HnswRabitqIndexParams hnsw_rabitq = 6; }; }; @@ -186,4 +202,4 @@ message Manifest { uint32 delete_snapshot_path_suffix = 7; uint32 next_segment_id = 8; -}; \ No newline at end of file +}; diff --git a/src/db/training/omega_model_trainer.cc b/src/db/training/omega_model_trainer.cc new file mode 100644 index 000000000..176050deb --- /dev/null +++ b/src/db/training/omega_model_trainer.cc @@ -0,0 +1,127 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "omega_model_trainer.h" +#include +#include +#include +#include + +namespace zvec { + +namespace { + +// Convert zvec TrainingRecord to omega TrainingRecord +omega::TrainingRecord ConvertRecord(const core_interface::TrainingRecord &src) { + omega::TrainingRecord dst; + dst.query_id = src.query_id; + dst.hops_visited = src.hops_visited; + dst.cmps_visited = src.cmps_visited; + dst.dist_1st = src.dist_1st; + dst.dist_start = src.dist_start; + // Convert std::array to std::vector + dst.traversal_window_stats.assign(src.traversal_window_stats.begin(), + src.traversal_window_stats.end()); + dst.label = src.label; // Already computed in real-time during search + return dst; +} + +// Convert zvec GtCmpsData to omega GtCmpsData +omega::GtCmpsData ConvertGtCmpsData(const core_interface::GtCmpsData &src) { + omega::GtCmpsData dst; + dst.num_queries = src.num_queries; + dst.topk = src.topk; + dst.gt_cmps = src.gt_cmps; + dst.total_cmps = src.total_cmps; + return dst; +} + +} // namespace + +Status OmegaModelTrainer::TrainModelWithGtCmps( + const std::vector &training_records, + const core_interface::GtCmpsData >_cmps_data, + const OmegaModelTrainerOptions &options) { + if (training_records.empty()) { + return Status::InvalidArgument("Training records are empty"); + } + + if (options.output_dir.empty()) { + return Status::InvalidArgument("Output directory is empty"); + } + + auto total_start = std::chrono::high_resolution_clock::now(); + + LOG_INFO("Training OMEGA model using C++ LightGBM API (%zu records)", + training_records.size()); + + // Convert training records + std::vector omega_records; + omega_records.reserve(training_records.size()); + for (const auto &r : training_records) { + omega_records.push_back(ConvertRecord(r)); + } + std::sort( + omega_records.begin(), omega_records.end(), + [](const omega::TrainingRecord &lhs, const omega::TrainingRecord &rhs) { + if (lhs.query_id != rhs.query_id) { + return lhs.query_id < rhs.query_id; + } + if 
(lhs.cmps_visited != rhs.cmps_visited) { + return lhs.cmps_visited < rhs.cmps_visited; + } + if (lhs.hops_visited != rhs.hops_visited) { + return lhs.hops_visited < rhs.hops_visited; + } + return lhs.label < rhs.label; + }); + + // Convert gt_cmps data + omega::GtCmpsData omega_gt_cmps = ConvertGtCmpsData(gt_cmps_data); + + // Setup trainer options + omega::OmegaTrainerOptions trainer_options; + trainer_options.output_dir = options.output_dir; + trainer_options.num_iterations = options.num_iterations; + trainer_options.num_leaves = options.num_leaves; + trainer_options.learning_rate = options.learning_rate; + trainer_options.num_threads = options.num_threads; + trainer_options.seed = options.seed; + trainer_options.deterministic = options.deterministic; + trainer_options.verbose = options.verbose; + trainer_options.topk = gt_cmps_data.topk > 0 ? gt_cmps_data.topk : 100; + + // Train model + int ret = omega::OmegaTrainer::TrainModel(omega_records, omega_gt_cmps, + trainer_options); + + auto total_end = std::chrono::high_resolution_clock::now(); + auto total_ms = std::chrono::duration_cast( + total_end - total_start) + .count(); + + if (ret != 0) { + LOG_ERROR("OMEGA model training failed (return code: %d)", ret); + return Status::InternalError("OMEGA model training failed"); + } + + LOG_INFO("[TIMING] TrainModelWithGtCmps (C++ LightGBM) TOTAL: %ld ms", + total_ms); + LOG_INFO("Successfully trained OMEGA model, output: %s", + options.output_dir.c_str()); + + return Status::OK(); +} + +} // namespace zvec diff --git a/src/db/training/omega_model_trainer.h b/src/db/training/omega_model_trainer.h new file mode 100644 index 000000000..266e00b06 --- /dev/null +++ b/src/db/training/omega_model_trainer.h @@ -0,0 +1,78 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace zvec { + +inline int DefaultOmegaTrainerThreads() { + const unsigned int hc = std::thread::hardware_concurrency(); + if (hc == 0) { + return 8; + } + return static_cast(std::max(1u, hc / 2)); +} + +/** + * @brief Configuration options for OMEGA model training + */ +struct OmegaModelTrainerOptions { + // Output directory for trained model files + std::string output_dir; + + // LightGBM training parameters + int num_iterations = 100; + int num_leaves = 31; + double learning_rate = 0.1; + int num_threads = DefaultOmegaTrainerThreads(); + int seed = 42; + bool deterministic = true; + + // Enable verbose logging during training + bool verbose = false; +}; + +/** + * @brief OMEGA model trainer using LightGBM C API + * + * This class trains a LightGBM binary classifier directly in C++, + * eliminating the need for Python subprocess and CSV serialization. + */ +class OmegaModelTrainer { + public: + /** + * @brief Train OMEGA model with gt_cmps data for table generation + * + * This is the extended version that also generates gt_collected_table + * and gt_cmps_all_table from gt_cmps data. 
+ * + * @param training_records Training data collected from searches + * @param gt_cmps_data Ground truth cmps data for table generation + * @param options Training configuration + * @return Status indicating success or failure + */ + static Status TrainModelWithGtCmps( + const std::vector &training_records, + const core_interface::GtCmpsData >_cmps_data, + const OmegaModelTrainerOptions &options); +}; + +} // namespace zvec diff --git a/src/db/training/omega_training_coordinator.cc b/src/db/training/omega_training_coordinator.cc new file mode 100644 index 000000000..2a0aad3dc --- /dev/null +++ b/src/db/training/omega_training_coordinator.cc @@ -0,0 +1,312 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "db/training/omega_training_coordinator.h" +#include +#include +#include +#include +#include +#include "db/common/file_helper.h" +#include "db/training/omega_model_trainer.h" + +namespace zvec { + +namespace { + +constexpr uint32_t kOmegaQueryCacheMagic = 0x4F514359; // OQCY +constexpr uint32_t kOmegaQueryCacheVersion = 1; + +} // namespace + +static void WriteOmegaTimingStatsJson( + const std::string &output_path, + const std::vector> &stats) { + std::ofstream ofs(output_path); + if (!ofs.is_open()) { + return; + } + ofs << "{\n"; + for (size_t i = 0; i < stats.size(); ++i) { + ofs << " \"" << stats[i].first << "\": " << stats[i].second; + if (i + 1 < stats.size()) { + ofs << ","; + } + ofs << "\n"; + } + ofs << "}\n"; +} + +static std::string OmegaQueryCachePath(const std::string &model_output_dir) { + return model_output_dir + "/training_queries.bin"; +} + +static bool SaveOmegaTrainingQueryCache( + const std::string &model_output_dir, + const std::vector> &queries, + const std::vector &query_doc_ids) { + if (queries.empty() || queries.size() != query_doc_ids.size()) { + return false; + } + const uint32_t dim = static_cast(queries[0].size()); + for (const auto &query : queries) { + if (query.size() != dim) { + return false; + } + } + + std::ofstream ofs(OmegaQueryCachePath(model_output_dir), std::ios::binary); + if (!ofs.is_open()) { + return false; + } + + const uint64_t num_queries = queries.size(); + ofs.write(reinterpret_cast(&kOmegaQueryCacheMagic), + sizeof(kOmegaQueryCacheMagic)); + ofs.write(reinterpret_cast(&kOmegaQueryCacheVersion), + sizeof(kOmegaQueryCacheVersion)); + ofs.write(reinterpret_cast(&num_queries), sizeof(num_queries)); + ofs.write(reinterpret_cast(&dim), sizeof(dim)); + for (size_t i = 0; i < queries.size(); ++i) { + ofs.write(reinterpret_cast(&query_doc_ids[i]), + sizeof(query_doc_ids[i])); + ofs.write(reinterpret_cast(queries[i].data()), + static_cast(dim * sizeof(float))); + } + return ofs.good(); +} + +static bool 
LoadOmegaTrainingQueryCache( + const std::string &model_output_dir, + std::vector> *queries, + std::vector *query_doc_ids) { + std::ifstream ifs(OmegaQueryCachePath(model_output_dir), std::ios::binary); + if (!ifs.is_open()) { + return false; + } + + uint32_t magic = 0; + uint32_t version = 0; + uint64_t num_queries = 0; + uint32_t dim = 0; + ifs.read(reinterpret_cast(&magic), sizeof(magic)); + ifs.read(reinterpret_cast(&version), sizeof(version)); + ifs.read(reinterpret_cast(&num_queries), sizeof(num_queries)); + ifs.read(reinterpret_cast(&dim), sizeof(dim)); + if (!ifs.good() || magic != kOmegaQueryCacheMagic || + version != kOmegaQueryCacheVersion || num_queries == 0 || dim == 0) { + return false; + } + + queries->assign(num_queries, std::vector(dim)); + query_doc_ids->assign(num_queries, 0); + for (size_t i = 0; i < num_queries; ++i) { + ifs.read(reinterpret_cast(&(*query_doc_ids)[i]), sizeof(uint64_t)); + ifs.read(reinterpret_cast((*queries)[i].data()), + static_cast(dim * sizeof(float))); + if (!ifs.good()) { + queries->clear(); + query_doc_ids->clear(); + return false; + } + } + return true; +} + +OmegaTrainingParams ResolveOmegaTrainingParams( + const IndexParams::Ptr &index_params) { + OmegaTrainingParams params; + auto omega_params = std::dynamic_pointer_cast(index_params); + if (!omega_params) { + return params; + } + + params.num_training_queries = omega_params->num_training_queries(); + params.ef_training = omega_params->ef_training(); + params.ef_groundtruth = omega_params->ef_groundtruth(); + params.min_vector_threshold = omega_params->min_vector_threshold(); + params.k_train = omega_params->k_train(); + return params; +} + +Result CollectOmegaTrainingDataBeforeFlush( + const Segment::Ptr &segment, const std::string &field_name, + const VectorColumnIndexer::Ptr &vector_indexer, + const OmegaTrainingParams ¶ms, const std::string &model_output_dir) { + TrainingDataCollectorOptions collector_opts; + const size_t doc_count = vector_indexer->doc_count(); 
+ collector_opts.num_training_queries = + std::min(doc_count, params.num_training_queries); + collector_opts.ef_training = params.ef_training; + collector_opts.ef_groundtruth = params.ef_groundtruth; + collector_opts.topk = 100; + collector_opts.k_train = params.k_train; + + std::vector training_indexers = {vector_indexer}; + auto training_result = TrainingDataCollector::CollectTrainingDataWithGtCmps( + segment, field_name, collector_opts, training_indexers); + if (!training_result.has_value()) { + return training_result; + } + + if (!FileHelper::DirectoryExists(model_output_dir)) { + FileHelper::CreateDirectory(model_output_dir); + } + if (!SaveOmegaTrainingQueryCache(model_output_dir, + training_result->training_queries, + training_result->query_doc_ids)) { + LOG_WARN("Failed to persist OMEGA training query cache: %s", + OmegaQueryCachePath(model_output_dir).c_str()); + } + WriteOmegaTimingStatsJson( + model_output_dir + "/training_collection_timing.json", + TrainingDataCollector::ConsumeTimingStats()); + return training_result; +} + +Result CollectOmegaRetrainingData( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector &indexers, + const OmegaTrainingParams ¶ms, const std::string &model_output_dir) { + TrainingDataCollectorOptions collector_options; + collector_options.num_training_queries = params.num_training_queries; + collector_options.ef_training = params.ef_training; + collector_options.ef_groundtruth = params.ef_groundtruth; + collector_options.topk = 100; + collector_options.k_train = params.k_train; + + std::vector> cached_queries; + std::vector cached_query_doc_ids; + if (LoadOmegaTrainingQueryCache(model_output_dir, &cached_queries, + &cached_query_doc_ids)) { + LOG_WARN("Loaded %zu cached held-out queries for OMEGA retraining from %s", + cached_queries.size(), + OmegaQueryCachePath(model_output_dir).c_str()); + return TrainingDataCollector::CollectTrainingDataWithGtCmpsFromQueries( + segment, field_name, cached_queries, 
cached_query_doc_ids, + collector_options, indexers); + } + + LOG_WARN( + "OMEGA retrain query cache not found, falling back to sampling held-out " + "queries from persisted segment"); + return TrainingDataCollector::CollectTrainingDataWithGtCmps( + segment, field_name, collector_options, indexers); +} + +Status TrainOmegaModelAfterBuild( + const TrainingDataCollectorResult &training_result, + const std::string &model_output_dir) { + if (training_result.records.size() < 100) { + LOG_INFO( + "Skipping model training: only %zu records collected (need >= 100)", + training_result.records.size()); + return Status::OK(); + } + + OmegaModelTrainerOptions trainer_opts; + trainer_opts.output_dir = model_output_dir; + trainer_opts.verbose = true; + + if (!FileHelper::DirectoryExists(model_output_dir) && + !FileHelper::CreateDirectory(model_output_dir)) { + LOG_WARN("Failed to create model output directory: %s", + model_output_dir.c_str()); + } + + auto train_status = OmegaModelTrainer::TrainModelWithGtCmps( + training_result.records, training_result.gt_cmps_data, trainer_opts); + if (train_status.ok()) { + LOG_INFO("OMEGA model training completed successfully: %s", + trainer_opts.output_dir.c_str()); + } else { + LOG_WARN("OMEGA model training failed: %s", train_status.message().c_str()); + } + + return Status::OK(); +} + +Status TrainOmegaModelAfterRetrainCollect( + const TrainingDataCollectorResult &training_result, + const std::string &model_output_dir, SegmentID segment_id, + const std::string &field_name) { + const auto &training_records = training_result.records; + if (training_records.empty()) { + LOG_WARN("No training records collected, skipping model training"); + return Status::OK(); + } + + size_t positive_count = 0; + size_t negative_count = 0; + for (const auto &record : training_records) { + if (record.label == 1) { + positive_count++; + } else { + negative_count++; + } + } + + if (positive_count == 0 || negative_count == 0) { + LOG_WARN( + "Insufficient 
training samples: %zu positive, %zu negative. Need both " + "> 0. Skipping training.", + positive_count, negative_count); + return Status::OK(); + } + + if (positive_count < 50 || negative_count < 50) { + LOG_WARN( + "Too few training samples: %zu positive, %zu negative. Need at least " + "50 of each. Skipping training.", + positive_count, negative_count); + return Status::OK(); + } + + LOG_INFO("Training data stats: %zu positive, %zu negative samples", + positive_count, negative_count); + + LOG_WARN( + "OMEGA retrain step 2/2: start model training for field '%s' in segment " + "%d", + field_name.c_str(), segment_id); + OmegaModelTrainerOptions trainer_options; + trainer_options.output_dir = model_output_dir; + trainer_options.verbose = true; + + if (!FileHelper::DirectoryExists(trainer_options.output_dir) && + !FileHelper::CreateDirectory(trainer_options.output_dir)) { + return Status::InternalError("Failed to create model output directory: " + + trainer_options.output_dir); + } + + WriteOmegaTimingStatsJson( + trainer_options.output_dir + "/training_collection_timing.json", + TrainingDataCollector::ConsumeTimingStats()); + + auto train_status = OmegaModelTrainer::TrainModelWithGtCmps( + training_records, training_result.gt_cmps_data, trainer_options); + if (!train_status.ok()) { + return Status::InternalError("Failed to train OMEGA model: " + + train_status.message()); + } + + LOG_WARN( + "OMEGA retrain step 2/2: finished model training for segment %d, output: " + "%s", + segment_id, trainer_options.output_dir.c_str()); + + return Status::OK(); +} + +} // namespace zvec diff --git a/src/db/training/omega_training_coordinator.h b/src/db/training/omega_training_coordinator.h new file mode 100644 index 000000000..d62ff4584 --- /dev/null +++ b/src/db/training/omega_training_coordinator.h @@ -0,0 +1,97 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "db/index/column/vector_column/vector_column_indexer.h" +#include "db/index/segment/segment.h" +#include "db/training/training_data_collector.h" + +namespace zvec { + +struct OmegaTrainingParams { + size_t num_training_queries = 1000; + int ef_training = 1000; + int ef_groundtruth = 0; + uint32_t min_vector_threshold = 100000; + int k_train = 1; +}; + +OmegaTrainingParams ResolveOmegaTrainingParams( + const IndexParams::Ptr &index_params); + +#if ZVEC_ENABLE_OMEGA +Result CollectOmegaTrainingDataBeforeFlush( + const Segment::Ptr &segment, const std::string &field_name, + const VectorColumnIndexer::Ptr &vector_indexer, + const OmegaTrainingParams ¶ms, const std::string &model_output_dir); + +Result CollectOmegaRetrainingData( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector &indexers, + const OmegaTrainingParams ¶ms, const std::string &model_output_dir); + +Status TrainOmegaModelAfterBuild( + const TrainingDataCollectorResult &training_result, + const std::string &model_output_dir); + +Status TrainOmegaModelAfterRetrainCollect( + const TrainingDataCollectorResult &training_result, + const std::string &model_output_dir, SegmentID segment_id, + const std::string &field_name); +#else +inline OmegaTrainingParams ResolveOmegaTrainingParams( + const IndexParams::Ptr & /*index_params*/) { + return {}; +} + +inline Result CollectOmegaTrainingDataBeforeFlush( + const Segment::Ptr & /*segment*/, const 
std::string & /*field_name*/, + const VectorColumnIndexer::Ptr & /*vector_indexer*/, + const OmegaTrainingParams & /*params*/, + const std::string & /*model_output_dir*/) { + return tl::make_unexpected( + Status::NotSupported("OMEGA is disabled on Android")); +} + +inline Result CollectOmegaRetrainingData( + const Segment::Ptr & /*segment*/, const std::string & /*field_name*/, + const std::vector & /*indexers*/, + const OmegaTrainingParams & /*params*/, + const std::string & /*model_output_dir*/) { + return tl::make_unexpected( + Status::NotSupported("OMEGA is disabled on Android")); +} + +inline Status TrainOmegaModelAfterBuild( + const TrainingDataCollectorResult & /*training_result*/, + const std::string & /*model_output_dir*/) { + return Status::NotSupported("OMEGA is disabled on Android"); +} + +inline Status TrainOmegaModelAfterRetrainCollect( + const TrainingDataCollectorResult & /*training_result*/, + const std::string & /*model_output_dir*/, SegmentID /*segment_id*/, + const std::string & /*field_name*/) { + return Status::NotSupported("OMEGA is disabled on Android"); +} +#endif + +} // namespace zvec diff --git a/src/db/training/query_generator.cc b/src/db/training/query_generator.cc new file mode 100644 index 000000000..99c5f2263 --- /dev/null +++ b/src/db/training/query_generator.cc @@ -0,0 +1,99 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "query_generator.h" +#include +#include + +namespace zvec { + +SampledVectors TrainingQueryGenerator::SampleBaseVectorsWithIds( + const Segment::Ptr &segment, const std::string &field_name, + size_t num_samples, uint64_t seed) { + SampledVectors result; + + // Get total document count + uint64_t doc_count = segment->doc_count(); + if (doc_count == 0) { + LOG_WARN("Segment has no documents, cannot sample base vectors"); + return result; + } + + // Adjust num_samples if it exceeds doc_count + size_t actual_samples = std::min(num_samples, static_cast(doc_count)); + if (actual_samples < num_samples) { + LOG_WARN("Requested %zu samples but segment only has %zu documents", + num_samples, static_cast(doc_count)); + } + + // Random number generator for sampling indices + std::mt19937_64 rng(seed); + std::uniform_int_distribution dist(0, doc_count - 1); + + result.vectors.reserve(actual_samples); + result.doc_ids.reserve(actual_samples); + + // Sample vectors + for (size_t i = 0; i < actual_samples; ++i) { + // Get random document index + uint64_t doc_idx = dist(rng); + + // Fetch document + auto doc = segment->Fetch(doc_idx); + if (!doc) { + LOG_WARN("Failed to fetch document at index %zu, skipping", + static_cast(doc_idx)); + continue; + } + + // Extract vector field + auto vector_opt = doc->get>(field_name); + if (!vector_opt.has_value()) { + LOG_WARN("Document at index %zu does not have field '%s', skipping", + static_cast(doc_idx), field_name.c_str()); + continue; + } + + result.vectors.push_back(vector_opt.value()); + result.doc_ids.push_back(doc_idx); + } + + LOG_INFO("Successfully sampled %zu/%zu vectors with doc_ids from segment", + result.vectors.size(), actual_samples); + + return result; +} + +SampledVectors TrainingQueryGenerator::GenerateHeldOutQueries( + const Segment::Ptr &segment, const std::string &field_name, + size_t num_queries, uint64_t seed) { + // Sample vectors directly from the index - no noise added + // These vectors will be used as 
queries, with their doc_ids excluded from + // ground truth + auto result = + SampleBaseVectorsWithIds(segment, field_name, num_queries, seed); + + if (result.vectors.empty()) { + LOG_ERROR("Failed to sample vectors from segment for held-out queries"); + return result; + } + + LOG_INFO( + "Generated %zu held-out queries (vectors sampled directly from index)", + result.vectors.size()); + + return result; +} + +} // namespace zvec diff --git a/src/db/training/query_generator.h b/src/db/training/query_generator.h new file mode 100644 index 000000000..077020529 --- /dev/null +++ b/src/db/training/query_generator.h @@ -0,0 +1,76 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include "db/index/segment/segment.h" + +namespace zvec { + +/** + * @brief Result of sampling base vectors, includes both vectors and doc_ids + */ +struct SampledVectors { + std::vector> vectors; + std::vector + doc_ids; // doc_id of each sampled vector (for exclusion in GT) +}; + +/** + * @brief Training query generator for OMEGA model training + * + * This class provides utilities to generate training queries by: + * 1. Sampling base vectors from a persisted segment (held-out approach) + * 2. Using sampled vectors directly as queries (no noise) + * 3. 
Ground truth computation excludes the query vector itself + */ +class TrainingQueryGenerator { + public: + /** + * @brief Sample base vectors from a segment with doc_ids + * + * @param segment The segment to sample from (must be persisted) + * @param field_name The vector field name to sample + * @param num_samples Number of vectors to sample + * @param seed Random seed for reproducibility + * @return SampledVectors with vectors and their doc_ids + */ + static SampledVectors SampleBaseVectorsWithIds(const Segment::Ptr &segment, + const std::string &field_name, + size_t num_samples, + uint64_t seed = 42); + + /** + * @brief Generate training queries using held-out approach + * + * Samples vectors directly from index and uses them as queries. + * Returns doc_ids so ground truth can exclude self-matches. + * + * @param segment The segment to sample from + * @param field_name The vector field name + * @param num_queries Number of training queries to generate + * @param seed Random seed for reproducibility + * @return SampledVectors with query vectors and their doc_ids + */ + static SampledVectors GenerateHeldOutQueries(const Segment::Ptr &segment, + const std::string &field_name, + size_t num_queries, + uint64_t seed = 42); +}; + +} // namespace zvec diff --git a/src/db/training/training_data_collector.cc b/src/db/training/training_data_collector.cc new file mode 100644 index 000000000..50e431310 --- /dev/null +++ b/src/db/training/training_data_collector.cc @@ -0,0 +1,898 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "training_data_collector.h" +#include +#include +#include +#include +#include +#include +#include +#include +#if ZVEC_ENABLE_OMEGA +#include +#endif +#include +#include +#include "db/index/column/vector_column/vector_column_params.h" +#include "query_generator.h" + +namespace zvec { + +namespace { + +#if !ZVEC_ENABLE_OMEGA +std::vector> ComputeGroundTruthFallbackBruteForce( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector> &queries, size_t topk, + const std::vector &query_doc_ids, MetricType metric_type) { + std::vector> ground_truth(queries.size()); + const bool held_out_mode = + !query_doc_ids.empty() && query_doc_ids.size() == queries.size(); + const uint64_t doc_count = segment->doc_count(); + if (queries.empty() || doc_count == 0) { + return ground_truth; + } + + for (size_t q = 0; q < queries.size(); ++q) { + std::vector> scored; + scored.reserve(static_cast(doc_count)); + + for (uint64_t doc_id = 0; doc_id < doc_count; ++doc_id) { + if (held_out_mode && doc_id == query_doc_ids[q]) { + continue; + } + auto doc = segment->Fetch(doc_id); + if (!doc) { + continue; + } + auto vector_opt = doc->get>(field_name); + if (!vector_opt.has_value()) { + continue; + } + const auto &base = vector_opt.value(); + if (base.size() != queries[q].size()) { + continue; + } + + float score = 0.0f; + if (metric_type == MetricType::IP || metric_type == MetricType::COSINE) { + for (size_t d = 0; d < base.size(); ++d) { + score += queries[q][d] * base[d]; + } + } else { + for (size_t d = 0; d < base.size(); ++d) { + const float diff = queries[q][d] - base[d]; + score += diff * diff; + } + } + scored.emplace_back(score, doc_id); + } + + const auto limit = std::min(topk, scored.size()); + if (metric_type == MetricType::L2) { + std::partial_sort(scored.begin(), scored.begin() + limit, scored.end(), + [](const auto &lhs, const auto &rhs) { 
+ return lhs.first < rhs.first; + }); + } else { + std::partial_sort(scored.begin(), scored.begin() + limit, scored.end(), + [](const auto &lhs, const auto &rhs) { + return lhs.first > rhs.first; + }); + } + + auto &result = ground_truth[q]; + result.reserve(limit); + for (size_t i = 0; i < limit; ++i) { + result.push_back(scored[i].second); + } + } + + return ground_truth; +} +#endif + +} // namespace + +// ============ DEBUG TIMING UTILITIES ============ +namespace { +struct TimingStatsState { + std::mutex mu; + std::vector> ordered_stats; + std::unordered_map index_by_name; +}; + +TimingStatsState &GetTimingStatsState() { + static TimingStatsState state; + return state; +} + +void RecordTimingStat(const std::string &name, int64_t duration_ms) { + auto &state = GetTimingStatsState(); + std::lock_guard lock(state.mu); + auto it = state.index_by_name.find(name); + if (it == state.index_by_name.end()) { + state.index_by_name[name] = state.ordered_stats.size(); + state.ordered_stats.emplace_back(name, duration_ms); + } else { + state.ordered_stats[it->second].second = duration_ms; + } +} + +class ScopedTimer { + public: + explicit ScopedTimer(const std::string &name) : name_(name) { + start_ = std::chrono::high_resolution_clock::now(); + } + ~ScopedTimer() { + auto end = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end - start_) + .count(); + RecordTimingStat(name_, duration); + } + + private: + std::string name_; + std::chrono::high_resolution_clock::time_point start_; +}; + +std::vector ResolveTrainingIndexers( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector &provided_indexers) { + if (!provided_indexers.empty()) { + return provided_indexers; + } + return segment->get_vector_indexer(field_name); +} + +std::vector StartTrainingSessions( + const std::vector &indexers, + const std::vector> &ground_truth, size_t topk, + int k_train) { + std::vector sessions; + 
sessions.reserve(indexers.size()); + + core_interface::TrainingSessionConfig config; + config.ground_truth = ground_truth; + config.topk = topk; + config.k_train = k_train; + + for (auto &indexer : indexers) { + auto session = indexer->CreateTrainingSession(); + if (session == nullptr) { + LOG_WARN("Indexer does not expose a training session"); + sessions.emplace_back(); + continue; + } + auto status = session->Start(config); + if (!status.ok()) { + LOG_WARN("Failed to start training session on indexer: %s", + status.message().c_str()); + sessions.emplace_back(); + continue; + } + indexer->SetTrainingSession(session); + sessions.push_back(std::move(session)); + } + + return sessions; +} + +core_interface::TrainingArtifacts ConsumeTrainingArtifacts( + const std::vector &sessions) { + core_interface::TrainingArtifacts merged; + for (const auto &session : sessions) { + if (session == nullptr) { + continue; + } + auto artifacts = session->ConsumeArtifacts(); + merged.records.insert(merged.records.end(), + std::make_move_iterator(artifacts.records.begin()), + std::make_move_iterator(artifacts.records.end())); + if (merged.gt_cmps_data.gt_cmps.empty() && + !artifacts.gt_cmps_data.gt_cmps.empty()) { + merged.gt_cmps_data = std::move(artifacts.gt_cmps_data); + } + } + return merged; +} + +void FinishTrainingSessions( + const std::vector &indexers, + const std::vector &sessions) { + for (size_t i = 0; i < indexers.size(); ++i) { + if (i < sessions.size() && sessions[i] != nullptr) { + sessions[i]->Finish(); + } + indexers[i]->ClearTrainingSession(); + } +} +} // namespace + +void TrainingDataCollector::ResetTimingStats() { + auto &state = GetTimingStatsState(); + std::lock_guard lock(state.mu); + state.ordered_stats.clear(); + state.index_by_name.clear(); +} + +TrainingDataCollector::TimingStats TrainingDataCollector::ConsumeTimingStats() { + auto &state = GetTimingStatsState(); + std::lock_guard lock(state.mu); + TimingStats timings = std::move(state.ordered_stats); + 
state.ordered_stats.clear(); + state.index_by_name.clear(); + return timings; +} + +Result +TrainingDataCollector::CollectTrainingDataFromQueriesImpl( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector> &training_queries, + const std::vector> &provided_ground_truth, + const TrainingDataCollectorOptions &options, + const std::vector &query_doc_ids, + const std::vector &provided_indexers) { + std::vector indexers = + ResolveTrainingIndexers(segment, field_name, provided_indexers); + + if (indexers.empty()) { + return tl::make_unexpected(Status::InternalError( + "No vector indexers found for field: " + field_name)); + } + + if (training_queries.empty()) { + return tl::make_unexpected( + Status::InvalidArgument("Training queries are empty")); + } + + MetricType metric_type = indexers[0]->metric_type(); + + std::vector> ground_truth = provided_ground_truth; + if (ground_truth.empty()) { + LOG_INFO("Computing ground truth (topk=%zu, ef_groundtruth=%d)", + options.topk, options.ef_groundtruth); + ScopedTimer timer("Step2: ComputeGroundTruth"); + ground_truth = TrainingDataCollector::ComputeGroundTruth( + segment, field_name, training_queries, options.topk, + options.num_threads, query_doc_ids, options.ef_groundtruth, metric_type, + indexers); + } else if (ground_truth.size() != training_queries.size()) { + return tl::make_unexpected(Status::InvalidArgument( + "Ground truth size does not match query count")); + } + + if (ground_truth.empty()) { + return tl::make_unexpected( + Status::InternalError("Failed to obtain ground truth")); + } + + LOG_INFO("Starting training sessions for %zu queries on %zu indexers", + ground_truth.size(), indexers.size()); + std::vector training_sessions; + { + ScopedTimer timer("Step3: EnableTrainingMode"); + training_sessions = StartTrainingSessions(indexers, ground_truth, + options.topk, options.k_train); + } + + LOG_INFO("Performing training searches with ef=%d", options.ef_training); + std::vector> 
search_results; + search_results.reserve(training_queries.size()); + + { + ScopedTimer timer("Step4: TrainingSearches"); + + size_t actual_threads = options.num_threads; + if (actual_threads == 0) { + actual_threads = std::thread::hardware_concurrency(); + } + actual_threads = std::min(actual_threads, training_queries.size()); + + search_results.resize(training_queries.size()); + + auto search_start = std::chrono::high_resolution_clock::now(); + + auto worker = [&](size_t start_idx, size_t end_idx) { + for (size_t query_idx = start_idx; query_idx < end_idx; ++query_idx) { + const auto &query_vector = training_queries[query_idx]; + + vector_column_params::VectorData vector_data; + vector_data.vector = vector_column_params::DenseVector{ + const_cast(static_cast(query_vector.data()))}; + + vector_column_params::QueryParams query_params; + query_params.topk = options.topk; + query_params.fetch_vector = false; + query_params.filter = segment->get_filter().get(); + + auto omega_params = std::make_shared(); + omega_params->set_ef(options.ef_training); + omega_params->set_training_query_id(static_cast(query_idx)); + query_params.query_params = omega_params; + + if (indexers.size() != 1 && query_idx == start_idx) { + LOG_WARN("Expected 1 indexer but found %zu, using first one only", + indexers.size()); + } + + // Persisted OMEGA collections currently do not propagate per-query + // training_query_id through the search context reliably. In the + // single-threaded calibration path, fall back to the existing global + // query-id setter to preserve correct labels without races. 
+ if (actual_threads == 1 && !training_sessions.empty() && + training_sessions[0] != nullptr) { + training_sessions[0]->BeginQuery(static_cast(query_idx)); + } + + auto search_result = indexers[0]->Search(vector_data, query_params); + if (!search_result.has_value()) { + LOG_WARN("Search failed for query %zu: %s", query_idx, + search_result.error().message().c_str()); + continue; + } + + auto &results = search_result.value(); + std::vector result_ids; + result_ids.reserve(results->count()); + auto iter = results->create_iterator(); + while (iter->valid()) { + result_ids.push_back(iter->doc_id()); + iter->next(); + } + + search_results[query_idx] = std::move(result_ids); + } + }; + + std::vector threads; + size_t queries_per_thread = + (training_queries.size() + actual_threads - 1) / actual_threads; + for (size_t t = 0; t < actual_threads; ++t) { + size_t start_idx = t * queries_per_thread; + size_t end_idx = + std::min(start_idx + queries_per_thread, training_queries.size()); + if (start_idx < end_idx) { + threads.emplace_back(worker, start_idx, end_idx); + } + } + + for (auto &thread : threads) { + thread.join(); + } + + auto search_end = std::chrono::high_resolution_clock::now(); + auto total_ms = std::chrono::duration_cast( + search_end - search_start) + .count(); + LOG_INFO("Training searches completed in %zu ms (%zu threads)", total_ms, + actual_threads); + } + + LOG_INFO("Collecting training records from indexers"); + core_interface::TrainingArtifacts training_artifacts; + { + ScopedTimer timer("Step5: CollectTrainingRecords"); + training_artifacts = ConsumeTrainingArtifacts(training_sessions); + LOG_INFO("Collected %zu records from training sessions", + training_artifacts.records.size()); + } + + auto &all_records = training_artifacts.records; + + if (all_records.empty()) { + LOG_WARN("No training records collected from any indexer"); + } + + size_t positive_count = 0; + size_t negative_count = 0; + for (const auto &record : all_records) { + if (record.label 
> 0) { + ++positive_count; + } else { + ++negative_count; + } + } + LOG_INFO( + "Collected %zu records: %zu positive, %zu negative (labels computed in " + "real-time)", + all_records.size(), positive_count, negative_count); + + LOG_INFO("Collecting gt_cmps data from indexers"); + core_interface::GtCmpsData gt_cmps_data = + std::move(training_artifacts.gt_cmps_data); + { + ScopedTimer timer("Step6: GetGtCmpsData"); + if (gt_cmps_data.gt_cmps.empty()) { + LOG_WARN( + "No actual gt_cmps data collected, falling back to approximation"); + gt_cmps_data = TrainingDataCollector::ComputeGtCmps( + all_records, ground_truth, options.topk); + } else { + LOG_INFO("Got actual gt_cmps data for %zu queries, topk=%zu", + gt_cmps_data.num_queries, gt_cmps_data.topk); + } + } + + { + ScopedTimer timer("Step7: DisableTrainingMode"); + FinishTrainingSessions(indexers, training_sessions); + } + + TrainingDataCollectorResult result; + result.records = std::move(all_records); + result.gt_cmps_data = std::move(gt_cmps_data); + result.training_queries = training_queries; + result.query_doc_ids = query_doc_ids; + return result; +} +// ============ END DEBUG TIMING UTILITIES ============ + +std::vector> TrainingDataCollector::ComputeGroundTruth( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector> &queries, size_t topk, + size_t num_threads, const std::vector &query_doc_ids, + int ef_groundtruth, MetricType metric_type, + const std::vector &provided_indexers) { + std::vector> ground_truth(queries.size()); + + if (queries.empty()) { + return ground_truth; + } + + // Check if we have query doc_ids for self-exclusion (held-out mode) + bool held_out_mode = + !query_doc_ids.empty() && query_doc_ids.size() == queries.size(); + if (held_out_mode) { + LOG_INFO( + "Computing ground truth in held-out mode (excluding self-matches)"); + } + + // Get total document count + uint64_t doc_count = segment->doc_count(); + size_t dim = queries[0].size(); + + LOG_INFO( + "Computing 
ground truth: %zu queries, %zu base vectors, dim=%zu, " + "topk=%zu, metric=%s, ef_groundtruth=%d", + queries.size(), static_cast(doc_count), dim, topk, + MetricTypeCodeBook::AsString(metric_type).c_str(), ef_groundtruth); + + auto start_time = std::chrono::high_resolution_clock::now(); + + // ============================================================ + // Branch 1: HNSW search (ef_groundtruth > 0) + // Faster for large datasets, approximate results + // ============================================================ + if (ef_groundtruth > 0) { + // Use provided indexers if available, otherwise get from segment + // IMPORTANT: We must use provided_indexers when available because after + // Flush, segment->get_vector_indexer() returns stale indexers with cleared + // in-memory data + std::vector indexers; + if (!provided_indexers.empty()) { + indexers = provided_indexers; + } else { + indexers = segment->get_vector_indexer(field_name); + } + + if (indexers.empty()) { + LOG_ERROR( + "No vector indexers found for field '%s', falling back to brute " + "force", + field_name.c_str()); + ef_groundtruth = 0; // Fall back to brute force + } else { + // For held-out mode, we need topk+1 to exclude self + size_t actual_topk = held_out_mode ? topk + 1 : topk; + + // ======================================================== + // WARM-UP STEP: Pre-load index data to avoid cold-start penalty + // Without this, the first searches are very slow due to lazy loading. + // We use parallel warmup with all threads to load data faster. + // ======================================================== + { + auto warmup_start = std::chrono::high_resolution_clock::now(); + + // Warmup count: use a fraction of queries spread across threads + // This helps load different parts of the index in parallel + size_t actual_threads = + num_threads > 0 ? 
num_threads : std::thread::hardware_concurrency(); + size_t warmup_per_thread = 5; // Each thread does 5 warmup queries + size_t warmup_total = + std::min(actual_threads * warmup_per_thread, queries.size()); + + // Parallel warmup using std::thread + std::vector warmup_threads; + auto warmup_worker = [&](size_t start_idx, size_t count) { + for (size_t i = 0; i < count && (start_idx + i) < queries.size(); + ++i) { + size_t q_idx = start_idx + i; + vector_column_params::VectorData vector_data; + vector_data.vector = + vector_column_params::DenseVector{const_cast( + static_cast(queries[q_idx].data()))}; + + vector_column_params::QueryParams query_params; + query_params.topk = actual_topk; + query_params.fetch_vector = false; + query_params.filter = segment->get_filter().get(); + + auto omega_params = std::make_shared(); + omega_params->set_ef(ef_groundtruth); + omega_params->set_training_query_id( + -1); // Warmup, don't collect training data + query_params.query_params = omega_params; + + static_cast(indexers[0]->Search(vector_data, query_params)); + } + }; + + // Launch warmup threads + size_t queries_per_warmup_thread = warmup_total / actual_threads; + for (size_t t = 0; + t < actual_threads && t * queries_per_warmup_thread < warmup_total; + ++t) { + size_t start = t * queries_per_warmup_thread; + size_t count = + std::min(queries_per_warmup_thread, warmup_total - start); + if (count > 0) { + warmup_threads.emplace_back(warmup_worker, start, count); + } + } + + for (auto &t : warmup_threads) { + t.join(); + } + + auto warmup_end = std::chrono::high_resolution_clock::now(); + auto warmup_ms = std::chrono::duration_cast( + warmup_end - warmup_start) + .count(); + + // Note: If warmup takes very long (>60s), recommend using + // ef_groundtruth=0 (Eigen brute force) + if (warmup_ms > 60000) { + LOG_INFO( + "HNSW warmup took %zu ms. 
For cold indexes, consider using " + "ef_groundtruth=0 (Eigen brute force)", + warmup_ms); + } + } + + // Pre-allocate ground_truth for thread-safe access + ground_truth.resize(queries.size()); + + // Use std::thread instead of OpenMP (same as training searches) + size_t actual_threads = + num_threads > 0 ? num_threads : std::thread::hardware_concurrency(); + actual_threads = std::min(actual_threads, queries.size()); + + auto worker = [&](size_t start_idx, size_t end_idx) { + for (size_t q = start_idx; q < end_idx; ++q) { + // Prepare query parameters (exactly same as training searches) + vector_column_params::VectorData vector_data; + vector_data.vector = vector_column_params::DenseVector{ + const_cast(static_cast(queries[q].data()))}; + + vector_column_params::QueryParams query_params; + query_params.topk = actual_topk; + query_params.fetch_vector = false; + query_params.filter = segment->get_filter().get(); + + // Create OmegaQueryParams (same as training searches) + auto omega_params = std::make_shared(); + omega_params->set_ef(ef_groundtruth); + omega_params->set_training_query_id(static_cast(q)); + query_params.query_params = omega_params; + + // Search on first indexer (same as training searches) + auto search_result = indexers[0]->Search(vector_data, query_params); + if (search_result.has_value()) { + auto &results = search_result.value(); + std::vector result_ids; + result_ids.reserve(results->count()); + auto iter = results->create_iterator(); + while (iter->valid()) { + uint64_t doc_id = iter->doc_id(); + // Skip self in held-out mode + if (held_out_mode && doc_id == query_doc_ids[q]) { + iter->next(); + continue; + } + result_ids.push_back(doc_id); + if (result_ids.size() >= topk) break; + iter->next(); + } + ground_truth[q] = std::move(result_ids); + } + } + }; + + // Launch threads + std::vector threads; + size_t queries_per_thread = + (queries.size() + actual_threads - 1) / actual_threads; + for (size_t t = 0; t < actual_threads; ++t) { + size_t 
start_idx = t * queries_per_thread; + size_t end_idx = + std::min(start_idx + queries_per_thread, queries.size()); + if (start_idx < end_idx) { + threads.emplace_back(worker, start_idx, end_idx); + } + } + + // Wait for all threads + for (auto &thread : threads) { + thread.join(); + } + + auto end_time = std::chrono::high_resolution_clock::now(); + auto total_ms = std::chrono::duration_cast( + end_time - start_time) + .count(); + LOG_INFO("Computed ground truth (HNSW ef=%d) for %zu queries in %zu ms", + ef_groundtruth, queries.size(), total_ms); + return ground_truth; + } + } + + // ============================================================ + // Branch 2: Eigen brute force (ef_groundtruth == 0) + // Exact results, uses batch matrix multiplication + // ============================================================ + // Convert zvec MetricType to omega MetricType +#if ZVEC_ENABLE_OMEGA + omega::MetricType omega_metric; + switch (metric_type) { + case MetricType::L2: + omega_metric = omega::MetricType::L2; + break; + case MetricType::COSINE: + omega_metric = omega::MetricType::COSINE; + break; + case MetricType::IP: + default: + omega_metric = omega::MetricType::IP; + break; + } +#endif + + // Step 1: Load all base vectors into memory + auto load_start = std::chrono::high_resolution_clock::now(); + + std::vector base_vectors(doc_count * dim); + std::atomic load_error{false}; + + // Load vectors in parallel + size_t actual_threads = + num_threads > 0 ? 
num_threads : std::thread::hardware_concurrency(); + actual_threads = std::min(actual_threads, static_cast(doc_count)); + + auto load_worker = [&](size_t start_idx, size_t end_idx) { + for (size_t doc_idx = start_idx; doc_idx < end_idx && !load_error; + ++doc_idx) { + auto doc = segment->Fetch(doc_idx); + if (!doc) { + LOG_WARN("Failed to fetch document at index %zu", doc_idx); + load_error = true; + continue; + } + + auto vector_opt = doc->get>(field_name); + if (!vector_opt.has_value()) { + LOG_WARN("Document at index %zu does not have field '%s'", doc_idx, + field_name.c_str()); + load_error = true; + continue; + } + + const auto &vec = vector_opt.value(); + if (vec.size() != dim) { + LOG_WARN("Vector at index %zu has wrong dimension: %zu vs %zu", doc_idx, + vec.size(), dim); + load_error = true; + continue; + } + + std::memcpy(base_vectors.data() + doc_idx * dim, vec.data(), + dim * sizeof(float)); + } + }; + + std::vector load_threads; + size_t docs_per_thread = (doc_count + actual_threads - 1) / actual_threads; + + for (size_t t = 0; t < actual_threads; ++t) { + size_t start_idx = t * docs_per_thread; + size_t end_idx = + std::min(start_idx + docs_per_thread, static_cast(doc_count)); + if (start_idx < end_idx) { + load_threads.emplace_back(load_worker, start_idx, end_idx); + } + } + + for (auto &thread : load_threads) { + thread.join(); + } + + auto load_end = std::chrono::high_resolution_clock::now(); + auto load_ms = std::chrono::duration_cast( + load_end - load_start) + .count(); + + if (load_error) { + LOG_ERROR("Failed to load all base vectors, cannot compute ground truth"); + return ground_truth; + } + + // Step 2: Flatten query vectors + std::vector query_flat(queries.size() * dim); + for (size_t i = 0; i < queries.size(); ++i) { + std::memcpy(query_flat.data() + i * dim, queries[i].data(), + dim * sizeof(float)); + } + + // Step 3: Call OmegaLib's fast ground truth computation (Eigen) + auto compute_start = std::chrono::high_resolution_clock::now(); + 
+#if ZVEC_ENABLE_OMEGA + ground_truth = omega::ComputeGroundTruth( + base_vectors.data(), query_flat.data(), doc_count, queries.size(), dim, + topk, omega_metric, held_out_mode, + query_doc_ids); // Pass query-to-base mapping for correct self-exclusion +#else + ground_truth = ComputeGroundTruthFallbackBruteForce( + segment, field_name, queries, topk, query_doc_ids, metric_type); +#endif + + auto compute_end = std::chrono::high_resolution_clock::now(); + auto compute_ms = std::chrono::duration_cast( + compute_end - compute_start) + .count(); + + auto total_end = std::chrono::high_resolution_clock::now(); + auto total_ms = std::chrono::duration_cast( + total_end - start_time) + .count(); + + LOG_INFO( + "Computed ground truth (Eigen brute force) for %zu queries in %zu ms " + "(load: %zu ms, compute: %zu ms)", + queries.size(), total_ms, load_ms, compute_ms); + + return ground_truth; +} + +core_interface::GtCmpsData TrainingDataCollector::ComputeGtCmps( + const std::vector &records, + const std::vector> &ground_truth, size_t topk) { + // NOTE: This is a FALLBACK approximation method. + // The preferred method is to collect actual gt_cmps during search via + // VectorColumnIndexer::GetGtCmpsData(), which tracks the exact cmps value + // when each GT rank first enters the topk during HNSW traversal. + // + // This approximation uses record.cmps_visited as a simple heuristic. 
+ + core_interface::GtCmpsData result; + result.topk = topk; + result.num_queries = ground_truth.size(); + + if (records.empty() || ground_truth.empty()) { + LOG_WARN("Empty records or ground_truth in ComputeGtCmps"); + return result; + } + + // Initialize gt_cmps and total_cmps + result.gt_cmps.resize(ground_truth.size()); + result.total_cmps.resize(ground_truth.size(), 0); + + for (size_t q = 0; q < ground_truth.size(); ++q) { + result.gt_cmps[q].resize(topk, 0); + } + + // Group records by query_id and track max cmps + // For gt_cmps, we use a simple heuristic: + // - If label=1 (GT found), use cmps_visited as gt_cmps for all GT ranks + // - If label=0 (GT not found), use the max cmps for that query + std::unordered_map query_max_cmps; + std::unordered_map query_first_found_cmps; + + for (const auto &record : records) { + int query_id = record.query_id; + if (query_id < 0 || query_id >= static_cast(ground_truth.size())) { + continue; + } + + // Track max cmps for each query + auto &max_cmps = query_max_cmps[query_id]; + max_cmps = std::max(max_cmps, record.cmps_visited); + + // Track first cmps where label became 1 + if (record.label > 0) { + auto it = query_first_found_cmps.find(query_id); + if (it == query_first_found_cmps.end()) { + query_first_found_cmps[query_id] = record.cmps_visited; + } else { + it->second = std::min(it->second, record.cmps_visited); + } + } + } + + // Fill in the result + for (size_t q = 0; q < ground_truth.size(); ++q) { + int query_id = static_cast(q); + result.total_cmps[q] = query_max_cmps[query_id]; + + // Use first_found_cmps if available, otherwise use total_cmps + auto it = query_first_found_cmps.find(query_id); + int gt_cmps_value = (it != query_first_found_cmps.end()) + ? 
it->second + : result.total_cmps[q]; + + for (size_t r = 0; r < result.gt_cmps[q].size(); ++r) { + result.gt_cmps[q][r] = gt_cmps_value; + } + } + + LOG_INFO("Computed gt_cmps (approximation) for %zu queries, topk=%zu", + result.num_queries, result.topk); + return result; +} + +Result +TrainingDataCollector::CollectTrainingDataWithGtCmps( + const Segment::Ptr &segment, const std::string &field_name, + const TrainingDataCollectorOptions &options, + const std::vector &provided_indexers) { + ResetTimingStats(); + ScopedTimer total_timer("CollectTrainingDataWithGtCmps [TOTAL]"); + LOG_INFO("Generating %zu held-out training queries for field '%s'", + options.num_training_queries, field_name.c_str()); + + std::vector> training_queries; + std::vector query_doc_ids; + { + ScopedTimer timer("Step1: GenerateHeldOutQueries"); + auto sampled = TrainingQueryGenerator::GenerateHeldOutQueries( + segment, field_name, options.num_training_queries, options.seed); + training_queries = std::move(sampled.vectors); + query_doc_ids = std::move(sampled.doc_ids); + } + + return CollectTrainingDataFromQueriesImpl(segment, field_name, + training_queries, {}, options, + query_doc_ids, provided_indexers); +} + +Result +TrainingDataCollector::CollectTrainingDataWithGtCmpsFromQueries( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector> &training_queries, + const std::vector &query_doc_ids, + const TrainingDataCollectorOptions &options, + const std::vector &provided_indexers) { + ResetTimingStats(); + ScopedTimer total_timer("CollectTrainingDataWithGtCmps [TOTAL]"); + LOG_INFO("Reusing %zu cached held-out training queries for field '%s'", + training_queries.size(), field_name.c_str()); + return CollectTrainingDataFromQueriesImpl(segment, field_name, + training_queries, {}, options, + query_doc_ids, provided_indexers); +} + +} // namespace zvec diff --git a/src/db/training/training_data_collector.h b/src/db/training/training_data_collector.h new file mode 100644 
index 000000000..3ef3c205c --- /dev/null +++ b/src/db/training/training_data_collector.h @@ -0,0 +1,161 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "db/index/segment/segment.h" + +namespace zvec { + +/** + * @brief Configuration options for training data collection + */ +struct TrainingDataCollectorOptions { + // Number of training queries to generate + size_t num_training_queries = 1000; + + // ef parameter for training searches (large value for recall ≈ 1) + int ef_training = 1000; + + // ef parameter for ground truth computation (0 = brute force, >0 = HNSW with + // this ef) Using HNSW with large ef is much faster than brute force while + // maintaining high accuracy + int ef_groundtruth = 0; + + // Top-K results to retrieve per query + size_t topk = 100; + + // K_train: number of ground truth results that must be collected for label=1 + // Label=1 iff the top K_train ground truth nodes are all in + // collected_node_ids Typically set to 1 (i.e., label=1 when the 1st ground + // truth is found) + size_t k_train = 1; + + // Random seed for reproducibility + uint64_t seed = 42; + + // Number of threads for parallel operations (0 = hardware_concurrency) + size_t num_threads = 0; +}; + +/** + * @brief Result of training data collection, includes both records and gt_cmps + */ +struct 
TrainingDataCollectorResult { + std::vector records; + core_interface::GtCmpsData gt_cmps_data; + std::vector> training_queries; + std::vector query_doc_ids; +}; + +/** + * @brief Collector for OMEGA training data + * + * This class collects training data by: + * 1. Generating training queries from base vectors + * 2. Computing ground truth with brute force search + * 3. Performing searches in training mode with large ef + * 4. Labeling collected records based on ground truth + */ +class TrainingDataCollector { + public: + using TimingStats = std::vector>; + + static void ResetTimingStats(); + + static TimingStats ConsumeTimingStats(); + + /** + * @brief Collect training data with gt_cmps information for table generation + * + * This is the extended version that also computes gt_cmps data needed for + * generating gt_collected_table and gt_cmps_all_table. + * + * @param segment The segment to collect data from (must be persisted) + * @param field_name Vector field name to train on + * @param options Collection options + * @param indexers Optional specific indexers to use + * @return TrainingDataCollectorResult with records and gt_cmps_data + */ + static Result CollectTrainingDataWithGtCmps( + const Segment::Ptr &segment, const std::string &field_name, + const TrainingDataCollectorOptions &options, + const std::vector &indexers = {}); + + static Result + CollectTrainingDataWithGtCmpsFromQueries( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector> &training_queries, + const std::vector &query_doc_ids, + const TrainingDataCollectorOptions &options, + const std::vector &indexers = {}); + + private: + /** + * @brief Compute ground truth using brute force or HNSW search + * + * @param segment The segment to search + * @param field_name Vector field name + * @param queries Training query vectors + * @param topk Number of top results to retrieve + * @param num_threads Number of threads (0 = hardware_concurrency) + * @param query_doc_ids 
Optional doc_ids of query vectors (for self-exclusion + * in held-out mode) + * @param ef_groundtruth ef value for HNSW search (0 = brute force, >0 = HNSW) + * @param metric_type Distance metric type (L2, IP, COSINE) + * @param indexers Optional pre-opened indexers (for HNSW GT, avoids using + * stale indexers from segment) + * @return Ground truth doc IDs for each query + */ + static std::vector> ComputeGroundTruth( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector> &queries, size_t topk, + size_t num_threads, const std::vector &query_doc_ids = {}, + int ef_groundtruth = 0, MetricType metric_type = MetricType::IP, + const std::vector &indexers = {}); + + /** + * @brief Compute gt_cmps data from training records and ground truth + * + * For each query and each GT rank, find the cmps value when that GT was first + * collected. This data is used to generate gt_collected_table and + * gt_cmps_all_table. + * + * @param records Training records (must be sorted by query_id, then by cmps) + * @param ground_truth Ground truth doc IDs per query + * @param topk Number of top results per query + * @return GtCmpsData structure with computed gt_cmps + */ + static core_interface::GtCmpsData ComputeGtCmps( + const std::vector &records, + const std::vector> &ground_truth, size_t topk); + + static Result CollectTrainingDataFromQueriesImpl( + const Segment::Ptr &segment, const std::string &field_name, + const std::vector> &training_queries, + const std::vector> &provided_ground_truth, + const TrainingDataCollectorOptions &options, + const std::vector &query_doc_ids, + const std::vector &provided_indexers); +}; + +} // namespace zvec diff --git a/src/include/zvec/c_api.h b/src/include/zvec/c_api.h index 75eba707b..62d5f137f 100644 --- a/src/include/zvec/c_api.h +++ b/src/include/zvec/c_api.h @@ -767,6 +767,7 @@ typedef uint32_t ZVecIndexType; #define ZVEC_INDEX_TYPE_HNSW 1 #define ZVEC_INDEX_TYPE_IVF 2 #define ZVEC_INDEX_TYPE_FLAT 3 +#define 
ZVEC_INDEX_TYPE_OMEGA 11 #define ZVEC_INDEX_TYPE_INVERT 10 /** @@ -983,6 +984,16 @@ ZVEC_EXPORT ZVecErrorCode ZVEC_CALL zvec_index_params_set_invert_params( */ typedef struct ZVecHnswQueryParams ZVecHnswQueryParams; +/** + * @brief OMEGA query parameters handle (opaque pointer) + * + * Internally maps to zvec::OmegaQueryParams* (raw pointer). + * Created by zvec_query_params_omega_create() and destroyed by + * zvec_query_params_omega_destroy(). Caller owns the pointer and must + * explicitly destroy it. + */ +typedef struct ZVecOmegaQueryParams ZVecOmegaQueryParams; + /** * @brief IVF query parameters handle (opaque pointer) * @@ -1120,6 +1131,117 @@ ZVEC_EXPORT ZVecErrorCode ZVEC_CALL zvec_query_params_hnsw_set_is_using_refiner( ZVEC_EXPORT bool ZVEC_CALL zvec_query_params_hnsw_get_is_using_refiner(const ZVecHnswQueryParams *params); +// ----------------------------------------------------------------------------- +// ZVecOmegaQueryParams (OMEGA Query Parameters) +// ----------------------------------------------------------------------------- + +/** + * @brief Create OMEGA query parameters + * @param ef Exploration factor during search + * @param target_recall Target recall used by OMEGA early stopping + * @param radius Search radius + * @param is_linear Whether linear search + * @param is_using_refiner Whether using refiner + * @return ZVecOmegaQueryParams* Pointer to the newly created OMEGA query + * parameters + */ +ZVEC_EXPORT ZVecOmegaQueryParams *ZVEC_CALL +zvec_query_params_omega_create(int ef, float target_recall, float radius, + bool is_linear, bool is_using_refiner); + +/** + * @brief Destroy OMEGA query parameters + * @param params OMEGA query parameters pointer + */ +ZVEC_EXPORT void ZVEC_CALL +zvec_query_params_omega_destroy(ZVecOmegaQueryParams *params); + +/** + * @brief Set exploration factor + * @param params OMEGA query parameters pointer + * @param ef Exploration factor + * @return ZVecErrorCode Error code + */ +ZVEC_EXPORT ZVecErrorCode 
ZVEC_CALL +zvec_query_params_omega_set_ef(ZVecOmegaQueryParams *params, int ef); + +/** + * @brief Get exploration factor + * @param params OMEGA query parameters pointer + * @return int Exploration factor + */ +ZVEC_EXPORT int ZVEC_CALL +zvec_query_params_omega_get_ef(const ZVecOmegaQueryParams *params); + +/** + * @brief Set target recall + * @param params OMEGA query parameters pointer + * @param target_recall Target recall + * @return ZVecErrorCode Error code + */ +ZVEC_EXPORT ZVecErrorCode ZVEC_CALL zvec_query_params_omega_set_target_recall( + ZVecOmegaQueryParams *params, float target_recall); + +/** + * @brief Get target recall + * @param params OMEGA query parameters pointer + * @return float Target recall + */ +ZVEC_EXPORT float ZVEC_CALL +zvec_query_params_omega_get_target_recall(const ZVecOmegaQueryParams *params); + +/** + * @brief Set search radius + * @param params OMEGA query parameters pointer + * @param radius Search radius + * @return ZVecErrorCode Error code + */ +ZVEC_EXPORT ZVecErrorCode ZVEC_CALL +zvec_query_params_omega_set_radius(ZVecOmegaQueryParams *params, float radius); + +/** + * @brief Get search radius + * @param params OMEGA query parameters pointer + * @return float Search radius + */ +ZVEC_EXPORT float ZVEC_CALL +zvec_query_params_omega_get_radius(const ZVecOmegaQueryParams *params); + +/** + * @brief Set linear search mode + * @param params OMEGA query parameters pointer + * @param is_linear Whether linear search + * @return ZVecErrorCode Error code + */ +ZVEC_EXPORT ZVecErrorCode ZVEC_CALL zvec_query_params_omega_set_is_linear( + ZVecOmegaQueryParams *params, bool is_linear); + +/** + * @brief Get linear search mode + * @param params OMEGA query parameters pointer + * @return bool Whether linear search + */ +ZVEC_EXPORT bool ZVEC_CALL +zvec_query_params_omega_get_is_linear(const ZVecOmegaQueryParams *params); + +/** + * @brief Set whether to use refiner + * @param params OMEGA query parameters pointer + * @param is_using_refiner 
Whether to use refiner + * @return ZVecErrorCode Error code + */ +ZVEC_EXPORT ZVecErrorCode ZVEC_CALL +zvec_query_params_omega_set_is_using_refiner(ZVecOmegaQueryParams *params, + bool is_using_refiner); + +/** + * @brief Get whether to use refiner + * @param params OMEGA query parameters pointer + * @return bool Whether to use refiner + */ +ZVEC_EXPORT bool ZVEC_CALL zvec_query_params_omega_get_is_using_refiner( + const ZVecOmegaQueryParams *params); + // ----------------------------------------------------------------------------- // ZVecIVFQueryParams (IVF Query Parameters) // ----------------------------------------------------------------------------- @@ -1469,6 +1591,15 @@ zvec_vector_query_set_query_params(ZVecVectorQuery *query, void *params); ZVEC_EXPORT ZVecErrorCode ZVEC_CALL zvec_vector_query_set_hnsw_params( ZVecVectorQuery *query, ZVecHnswQueryParams *hnsw_params); +/** + * @brief Set OMEGA query parameters (takes ownership) + * @param query Vector query pointer + * @param omega_params OMEGA query parameters pointer + * @return ZVecErrorCode Error code + */ +ZVEC_EXPORT ZVecErrorCode ZVEC_CALL zvec_vector_query_set_omega_params( + ZVecVectorQuery *query, ZVecOmegaQueryParams *omega_params); + /** * @brief Set IVF query parameters (takes ownership) * @param query Vector query pointer diff --git a/src/include/zvec/core/framework/index_context.h b/src/include/zvec/core/framework/index_context.h index c77fcf42e..7d185ce0f 100644 --- a/src/include/zvec/core/framework/index_context.h +++ b/src/include/zvec/core/framework/index_context.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace zvec { namespace core { @@ -248,6 +249,32 @@ class IndexContext { return profiler_; } + //! Get training records collected during search (for OMEGA training mode) + //! Default implementation returns empty vector. Override in OmegaContext. + virtual std::vector take_training_records() { + return {}; + } + + //! 
Clear training records (call before each search if context is reused) + virtual void clear_training_records() {} + + //! Get gt_cmps data (cmps when each GT rank was found) for OMEGA training + //! Returns vector where gt_cmps[rank] = cmps when GT[rank] first entered topk + //! Default implementation returns empty vector. Override in OmegaContext. + virtual std::vector take_gt_cmps() { + return {}; + } + + //! Get total comparisons for this search (OMEGA training) + virtual int get_total_cmps() const { + return 0; + } + + //! Get training query ID for this search (-1 means not set) + virtual int get_training_query_id() const { + return -1; + } + private: //! Members IndexFilter filter_{}; diff --git a/src/include/zvec/core/framework/index_storage.h b/src/include/zvec/core/framework/index_storage.h index 8273004a3..cccc82425 100644 --- a/src/include/zvec/core/framework/index_storage.h +++ b/src/include/zvec/core/framework/index_storage.h @@ -264,6 +264,11 @@ class IndexStorage : public IndexModule { virtual bool isHugePage(void) const { return false; } + + //! 
Retrieve file path of storage (for OMEGA model loading) + virtual std::string file_path(void) const { + return ""; // Default: empty (not all storages have file paths) + } }; } // namespace core diff --git a/src/include/zvec/core/interface/index.h b/src/include/zvec/core/interface/index.h index 8634e3904..ac484872c 100644 --- a/src/include/zvec/core/interface/index.h +++ b/src/include/zvec/core/interface/index.h @@ -31,6 +31,10 @@ #include #include #include +#include +#include +#include +#include #include "zvec/core/framework/index_provider.h" namespace zvec::core_interface { @@ -97,6 +101,14 @@ struct SearchResult { // use string to manage memory std::vector reverted_vector_list_{}; std::vector reverted_sparse_values_list_{}; + // Training records collected during search (for OMEGA training mode) + std::vector training_records_{}; + // GT cmps data: cmps value when each GT rank was found (for OMEGA training) + // gt_cmps_per_rank_[rank] = cmps when GT[rank] first entered topk (-1 if not + // found) + std::vector gt_cmps_per_rank_{}; + int total_cmps_{0}; // Total comparisons in this search + int training_query_id_{-1}; // Query ID for this search (-1 if not training) }; class Index { @@ -132,6 +144,21 @@ class Index { const BaseIndexQueryParam::Pointer &search_param, SearchResult *result); + // Capability Pattern: Query optional capabilities + /** + * @brief Get training capability interface if supported. + * + * This method allows indexes to optionally provide training functionality + * without polluting the base Index class. Follows the Capability Pattern. 
+ * + * @return Pointer to ITrainingCapable interface if supported, nullptr + * otherwise + * + */ + virtual class ITrainingCapable *GetTrainingCapability() { + return nullptr; // Default: capability not supported + } + virtual BaseIndexParam::Pointer GetParam() const { return std::make_shared(param_); } @@ -221,6 +248,7 @@ class Index { size_t context_index_; core::IndexStorage::Pointer storage_{}; + std::string file_path_; // Storage file path bool is_open_{false}; bool is_sparse_{false}; @@ -317,4 +345,32 @@ class HNSWRabitqIndex : public Index { }; +//! OMEGA Index - HNSW with learned early stopping +/** + * OmegaIndex is a specialized HNSW index that supports training mode for + * collecting features to train the OMEGA early stopping model. + * + * It implements the ITrainingCapable interface to provide training + * functionality without modifying the generic HNSWIndex class. + */ +class OmegaIndex : public HNSWIndex, public ITrainingCapable { + public: + OmegaIndex() = default; + + // Override GetTrainingCapability to return this + ITrainingCapable *GetTrainingCapability() override { + return this; + } + + ITrainingSession::Pointer CreateTrainingSession() override; + + protected: + virtual int CreateAndInitStreamer(const BaseIndexParam ¶m) override; + + virtual int _prepare_for_search( + const VectorData &query, const BaseIndexQueryParam::Pointer &search_param, + core::IndexContext::Pointer &context) override; +}; + + } // namespace zvec::core_interface diff --git a/src/include/zvec/core/interface/index_param.h b/src/include/zvec/core/interface/index_param.h index 0d7bf3017..4fc4246f5 100644 --- a/src/include/zvec/core/interface/index_param.h +++ b/src/include/zvec/core/interface/index_param.h @@ -62,6 +62,7 @@ enum class IndexType { kFlat, kIVF, // it's actual a two-layer index kHNSW, + kOMEGA, // HNSW with learned early stopping kHNSWRabitq, }; @@ -183,12 +184,24 @@ struct HNSWQueryParam : public BaseIndexQueryParam { using Pointer = std::shared_ptr; 
uint32_t ef_search = kDefaultHnswEfSearch; + int training_query_id = + -1; // For parallel training searches, -1 means use global BaseIndexQueryParam::Pointer Clone() const override { return std::make_shared(*this); } }; +struct OmegaQueryParam : public HNSWQueryParam { + using Pointer = std::shared_ptr; + + float target_recall = 0.95f; + + BaseIndexQueryParam::Pointer Clone() const override { + return std::make_shared(*this); + } +}; + struct HNSWRabitqQueryParam : public BaseIndexQueryParam { using Pointer = std::shared_ptr; @@ -362,4 +375,4 @@ struct HNSWRabitqIndexParam : public BaseIndexParam { bool omit_empty_value = false) const override; }; -} // namespace zvec::core_interface \ No newline at end of file +} // namespace zvec::core_interface diff --git a/src/include/zvec/core/interface/index_param_builders.h b/src/include/zvec/core/interface/index_param_builders.h index e22ecb392..49b2a8c36 100644 --- a/src/include/zvec/core/interface/index_param_builders.h +++ b/src/include/zvec/core/interface/index_param_builders.h @@ -299,6 +299,30 @@ class HNSWQueryParamBuilder } }; +class OmegaQueryParamBuilder + : public BaseIndexQueryParamBuilder { + public: + OmegaQueryParamBuilder &with_ef_search(int ef_search) { + m_param.ef_search = ef_search; + return *this; + } + + OmegaQueryParamBuilder &with_training_query_id(int training_query_id) { + m_param.training_query_id = training_query_id; + return *this; + } + + OmegaQueryParamBuilder &with_target_recall(float target_recall) { + m_param.target_recall = target_recall; + return *this; + } + + OmegaQueryParam::Pointer build() { + return std::make_shared(std::move(m_param)); + } +}; + // Example Usage: // HNSWQueryParam::Pointer hnsw_config = HNSWQueryParamBuilder() // .with_topk(5) @@ -407,4 +431,4 @@ class SCANNIndexParamBuilder { }; } // namespace predefined -} // namespace zvec::core_interface \ No newline at end of file +} // namespace zvec::core_interface diff --git a/src/include/zvec/core/interface/training.h 
b/src/include/zvec/core/interface/training.h new file mode 100644 index 000000000..0ac939142 --- /dev/null +++ b/src/include/zvec/core/interface/training.h @@ -0,0 +1,83 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace zvec::core_interface { + +/** + * @brief Training record structure for OMEGA adaptive search. + * + * This structure captures features collected during a single comparison + * in the HNSW/OMEGA search process. Multiple records are collected per query + * to train models for early stopping prediction. 
+ * + * Features (11 dimensions total): + * - query_id: Unique identifier for the query + * - hops_visited: Number of hops in the search graph + * - cmps_visited: Number of node comparisons performed + * - dist_1st: Distance to the first (best) result found so far + * - dist_start: Distance to the starting node + * - traversal_window_stats: 7 statistical features of the traversal window + * (avg, var, min, max, median, percentile25, percentile75) + * - label: Binary label (1 if collected enough GT results, 0 otherwise) + * Computed in real-time during search (memory optimized) + */ +struct TrainingRecord { + int query_id; + int hops_visited; + int cmps_visited; + float dist_1st; + float dist_start; + std::array traversal_window_stats; + int label; // Computed in real-time during search + + TrainingRecord() + : query_id(0), + hops_visited(0), + cmps_visited(0), + dist_1st(0.0f), + dist_start(0.0f), + traversal_window_stats{}, + label(0) {} +}; + +/** + * @brief Ground truth cmps data for OMEGA table generation. + * + * For each query, stores the cmps value when each ground truth result was + * found. This data is used to generate gt_collected_table and + * gt_cmps_all_table. 
+ * + * gt_cmps[query_id][rank] = cmps value when GT[rank] was collected + * = total_cmps if GT[rank] was never found + */ +struct GtCmpsData { + // gt_cmps[query_id][rank] = cmps when GT of rank was found + std::vector> gt_cmps; + + // total_cmps[query_id] = total comparisons for this query + std::vector total_cmps; + + // topk value used during training + size_t topk = 0; + + // Number of queries + size_t num_queries = 0; +}; + +} // namespace zvec::core_interface diff --git a/src/include/zvec/core/interface/training_capable.h b/src/include/zvec/core/interface/training_capable.h new file mode 100644 index 000000000..dc7452fe7 --- /dev/null +++ b/src/include/zvec/core/interface/training_capable.h @@ -0,0 +1,37 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace zvec { +namespace core_interface { + +/** + * @brief Training capability interface for indexes that support post-build + * training. + * + * This interface follows the Capability Pattern, allowing indexes to optionally + * provide training functionality without polluting the base Index class. 
+ */ +class ITrainingCapable { + public: + virtual ~ITrainingCapable() = default; + + virtual ITrainingSession::Pointer CreateTrainingSession() = 0; +}; + +} // namespace core_interface +} // namespace zvec diff --git a/src/include/zvec/core/interface/training_session.h b/src/include/zvec/core/interface/training_session.h new file mode 100644 index 000000000..6fb31e9bd --- /dev/null +++ b/src/include/zvec/core/interface/training_session.h @@ -0,0 +1,60 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include +#include +#include + +namespace zvec::core_interface { + +struct TrainingSessionConfig { + std::vector> ground_truth; + size_t topk = 0; + int k_train = 1; +}; + +struct QueryTrainingArtifacts { + std::vector records; + std::vector gt_cmps_per_rank; + int total_cmps = 0; + int training_query_id = -1; +}; + +struct TrainingArtifacts { + std::vector records; + GtCmpsData gt_cmps_data; +}; + +class ITrainingSession { + public: + using Pointer = std::shared_ptr; + + virtual ~ITrainingSession() = default; + + virtual zvec::Status Start(const TrainingSessionConfig &config) = 0; + + virtual void BeginQuery(int query_id) = 0; + + virtual void CollectQueryArtifacts(QueryTrainingArtifacts &&artifacts) = 0; + + virtual TrainingArtifacts ConsumeArtifacts() = 0; + + virtual void Finish() = 0; +}; + +} // namespace zvec::core_interface diff --git a/src/include/zvec/db/collection.h b/src/include/zvec/db/collection.h index 57c231921..4cc5ff0a1 100644 --- a/src/include/zvec/db/collection.h +++ b/src/include/zvec/db/collection.h @@ -72,7 +72,7 @@ class Collection { virtual Status DropIndex(const std::string &column_name) = 0; virtual Status Optimize(const OptimizeOptions &options = OptimizeOptions{ - 0}) = 0; + 0, false}) = 0; virtual Status AddColumn(const FieldSchema::Ptr &column_schema, const std::string &expression, @@ -105,4 +105,4 @@ class Collection { const std::vector &pks) const = 0; }; -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/include/zvec/db/index_params.h b/src/include/zvec/db/index_params.h index fcccf080d..61a37238d 100644 --- a/src/include/zvec/db/index_params.h +++ b/src/include/zvec/db/index_params.h @@ -46,7 +46,8 @@ class IndexParams { bool is_vector_index_type() const { return type_ == IndexType::FLAT || type_ == IndexType::HNSW || - type_ == IndexType::HNSW_RABITQ || type_ == IndexType::IVF; + type_ == IndexType::HNSW_RABITQ || type_ == IndexType::IVF || + type_ == 
IndexType::OMEGA; } IndexType type() const { @@ -428,4 +429,133 @@ class IVFIndexParams : public VectorIndexParams { bool use_soar_; }; -} // namespace zvec \ No newline at end of file +/* + * Vector: Omega index params + */ +class OmegaIndexParams : public VectorIndexParams { + public: + OmegaIndexParams( + MetricType metric_type, int m = core_interface::kDefaultHnswNeighborCnt, + int ef_construction = core_interface::kDefaultHnswEfConstruction, + QuantizeType quantize_type = QuantizeType::UNDEFINED, + uint32_t min_vector_threshold = 100000, + size_t num_training_queries = 1000, int ef_training = 1000, + int window_size = 100, int ef_groundtruth = 0, + int k_train = + 1) // 0 means use brute force, >0 means use HNSW with this ef + : VectorIndexParams(IndexType::OMEGA, metric_type, quantize_type), + m_(m), + ef_construction_(ef_construction), + min_vector_threshold_(min_vector_threshold), + num_training_queries_(num_training_queries), + ef_training_(ef_training), + window_size_(window_size), + ef_groundtruth_(ef_groundtruth), + k_train_(k_train) {} + + using OPtr = std::shared_ptr; + + public: + Ptr clone() const override { + return std::make_shared( + metric_type_, m_, ef_construction_, quantize_type_, + min_vector_threshold_, num_training_queries_, ef_training_, + window_size_, ef_groundtruth_, k_train_); + } + + std::string to_string() const override { + auto base_str = vector_index_params_to_string("OmegaIndexParams", + metric_type_, quantize_type_); + std::ostringstream oss; + oss << base_str << ",m:" << m_ << ",ef_construction:" << ef_construction_ + << ",min_vector_threshold:" << min_vector_threshold_ + << ",num_training_queries:" << num_training_queries_ + << ",ef_training:" << ef_training_ << ",window_size:" << window_size_ + << ",ef_groundtruth:" << ef_groundtruth_ << ",k_train:" << k_train_ + << "}"; + return oss.str(); + } + + bool operator==(const IndexParams &other) const override { + return type() == other.type() && + metric_type() == + 
static_cast<const OmegaIndexParams &>(other).metric_type() && + m_ == static_cast<const OmegaIndexParams &>(other).m_ && + ef_construction_ == + static_cast<const OmegaIndexParams &>(other).ef_construction_ && + min_vector_threshold_ == static_cast<const OmegaIndexParams &>(other) + .min_vector_threshold_ && + num_training_queries_ == static_cast<const OmegaIndexParams &>(other) + .num_training_queries_ && + ef_training_ == + static_cast<const OmegaIndexParams &>(other).ef_training_ && + window_size_ == + static_cast<const OmegaIndexParams &>(other).window_size_ && + ef_groundtruth_ == + static_cast<const OmegaIndexParams &>(other).ef_groundtruth_ && + k_train_ == static_cast<const OmegaIndexParams &>(other).k_train_ && + quantize_type() == + static_cast<const OmegaIndexParams &>(other).quantize_type(); + } + + void set_m(int m) { + m_ = m; + } + int m() const { + return m_; + } + void set_ef_construction(int ef_construction) { + ef_construction_ = ef_construction; + } + int ef_construction() const { + return ef_construction_; + } + void set_min_vector_threshold(uint32_t min_vector_threshold) { + min_vector_threshold_ = min_vector_threshold; + } + uint32_t min_vector_threshold() const { + return min_vector_threshold_; + } + void set_num_training_queries(size_t num_training_queries) { + num_training_queries_ = num_training_queries; + } + size_t num_training_queries() const { + return num_training_queries_; + } + void set_ef_training(int ef_training) { + ef_training_ = ef_training; + } + int ef_training() const { + return ef_training_; + } + void set_window_size(int window_size) { + window_size_ = window_size; + } + int window_size() const { + return window_size_; + } + void set_ef_groundtruth(int ef_groundtruth) { + ef_groundtruth_ = ef_groundtruth; + } + int ef_groundtruth() const { + return ef_groundtruth_; + } + void set_k_train(int k_train) { + k_train_ = k_train; + } + int k_train() const { + return k_train_; + } + + private: + int m_; + int ef_construction_; + uint32_t min_vector_threshold_; + size_t num_training_queries_; + int ef_training_; + int window_size_; + int ef_groundtruth_; // 0 = brute force, >0 = use HNSW with this ef + int k_train_; +}; + +} // namespace zvec diff --git a/src/include/zvec/db/options.h
b/src/include/zvec/db/options.h index 1f2a9cbf2..8fc00ae78 100644 --- a/src/include/zvec/db/options.h +++ b/src/include/zvec/db/options.h @@ -13,7 +13,9 @@ // limitations under the License. #pragma once +#include #include +#include namespace zvec { @@ -56,6 +58,7 @@ struct CreateIndexOptions { struct OptimizeOptions { int concurrency_{0}; + bool retrain_only_{false}; }; struct AddColumnOptions { @@ -66,4 +69,4 @@ struct AlterColumnOptions { int concurrency_{0}; }; -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/include/zvec/db/query_params.h b/src/include/zvec/db/query_params.h index ba62dab9c..03b38ed68 100644 --- a/src/include/zvec/db/query_params.h +++ b/src/include/zvec/db/query_params.h @@ -73,7 +73,7 @@ class HnswQueryParams : public QueryParams { HnswQueryParams(int ef = core_interface::kDefaultHnswEfSearch, float radius = 0.0f, bool is_linear = false, bool is_using_refiner = false) - : QueryParams(IndexType::HNSW), ef_(ef) { + : QueryParams(IndexType::HNSW), ef_(ef), training_query_id_(-1) { set_radius(radius); set_is_linear(is_linear); set_is_using_refiner(is_using_refiner); @@ -89,8 +89,43 @@ class HnswQueryParams : public QueryParams { ef_ = ef; } + // Training query ID for parallel training searches + // -1 means not set (use global current_query_id from indexer) + int training_query_id() const { + return training_query_id_; + } + + void set_training_query_id(int query_id) { + training_query_id_ = query_id; + } + private: int ef_; + int training_query_id_; +}; + +class OmegaQueryParams : public HnswQueryParams { + public: + OmegaQueryParams(int ef = core_interface::kDefaultHnswEfSearch, + float target_recall = 0.95f, float radius = 0.0f, + bool is_linear = false, bool is_using_refiner = false) + : HnswQueryParams(ef, radius, is_linear, is_using_refiner), + target_recall_(target_recall) { + set_type(IndexType::OMEGA); + } + + virtual ~OmegaQueryParams() = default; + + float target_recall() const { + return 
target_recall_; + } + + void set_target_recall(float target_recall) { + target_recall_ = target_recall; + } + + private: + float target_recall_; }; class IVFQueryParams : public QueryParams { @@ -172,4 +207,4 @@ class FlatQueryParams : public QueryParams { float scale_factor_{10}; }; -} // namespace zvec \ No newline at end of file +} // namespace zvec diff --git a/src/include/zvec/db/type.h b/src/include/zvec/db/type.h index 1578f81d8..ccf596f5d 100644 --- a/src/include/zvec/db/type.h +++ b/src/include/zvec/db/type.h @@ -27,6 +27,7 @@ enum class IndexType : uint32_t { FLAT = 3, HNSW_RABITQ = 4, INVERT = 10, + OMEGA = 11, }; /* diff --git a/tests/core/algorithm/CMakeLists.txt b/tests/core/algorithm/CMakeLists.txt index 9ef1ec2a0..c5f19e6c3 100644 --- a/tests/core/algorithm/CMakeLists.txt +++ b/tests/core/algorithm/CMakeLists.txt @@ -10,3 +10,6 @@ cc_directories(hnsw_sparse) if(RABITQ_SUPPORTED) cc_directories(hnsw_rabitq) endif() +if(ZVEC_ENABLE_OMEGA) +cc_directories(omega) +endif() diff --git a/tests/core/algorithm/omega/CMakeLists.txt b/tests/core/algorithm/omega/CMakeLists.txt new file mode 100644 index 000000000..53d5756e2 --- /dev/null +++ b/tests/core/algorithm/omega/CMakeLists.txt @@ -0,0 +1,17 @@ +include(${CMAKE_SOURCE_DIR}/cmake/bazel.cmake) + +file(GLOB_RECURSE ALL_TEST_SRCS *_test.cc) + +foreach(CC_SRCS ${ALL_TEST_SRCS}) + get_filename_component(CC_TARGET ${CC_SRCS} NAME_WE) + cc_gtest( + NAME ${CC_TARGET} + STRICT + LIBS zvec_ailego core_framework core_utility core_metric core_quantizer + core_knn_hnsw core_knn_hnsw_rabitq core_knn_flat core_knn_omega core_interface + core_mix_reducer omega + SRCS ${CC_SRCS} + INCS . 
${CMAKE_SOURCE_DIR}/src/core ${CMAKE_SOURCE_DIR}/src/core/algorithm + ${CMAKE_SOURCE_DIR}/thirdparty/omega/OMEGALib/include + ) +endforeach() diff --git a/tests/core/algorithm/omega/omega_search_context_test.cc b/tests/core/algorithm/omega/omega_search_context_test.cc new file mode 100644 index 000000000..eb73c0735 --- /dev/null +++ b/tests/core/algorithm/omega/omega_search_context_test.cc @@ -0,0 +1,106 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include + +namespace omega { + +TEST(OmegaSearchContextTest, DefaultStateWithoutModelDoesNotEarlyStop) { + SearchContext ctx(nullptr, nullptr, 0.95f, 3, 5); + + int hops = -1; + int comparisons = -1; + int collected_gt = -1; + ctx.GetStats(&hops, &comparisons, &collected_gt); + + EXPECT_EQ(hops, 0); + EXPECT_EQ(comparisons, 0); + EXPECT_EQ(collected_gt, 0); + EXPECT_EQ(ctx.GetK(), 3); + EXPECT_EQ(ctx.GetNextPredictionCmps(), 50); + EXPECT_EQ(ctx.GetPredictionBatchMinInterval(), 10); + EXPECT_FALSE(ctx.ShouldTrackTraversalWindow()); + EXPECT_FALSE(ctx.ShouldStopEarly()); + EXPECT_FALSE(ctx.EarlyStopHit()); +} + +TEST(OmegaSearchContextTest, TrainingModeCollectsRecordsAndGtCmps) { + SearchContext ctx(nullptr, nullptr, 0.95f, 2, 4); + ctx.EnableTrainingMode(7, std::vector{11, 22}, 2); + ctx.SetDistStart(0.5f); + + EXPECT_TRUE(ctx.ShouldTrackTraversalWindow()); + + ctx.ReportHop(); + EXPECT_FALSE(ctx.ReportVisitCandidate(11, 0.1f, true)); + ctx.ReportHop(); + EXPECT_FALSE(ctx.ReportVisitCandidate(22, 0.2f, true)); + + const auto& records = ctx.GetTrainingRecords(); + ASSERT_EQ(records.size(), 2U); + EXPECT_EQ(records[0].query_id, 7); + EXPECT_EQ(records[0].cmps_visited, 1); + EXPECT_EQ(records[0].hops_visited, 1); + EXPECT_EQ(records[0].label, 0); + ASSERT_EQ(records[0].traversal_window_stats.size(), 7U); + + EXPECT_EQ(records[1].query_id, 7); + EXPECT_EQ(records[1].cmps_visited, 2); + EXPECT_EQ(records[1].hops_visited, 2); + EXPECT_EQ(records[1].label, 1); + ASSERT_EQ(records[1].traversal_window_stats.size(), 7U); + + const auto& gt_cmps = ctx.GetGtCmpsPerRank(); + ASSERT_EQ(gt_cmps.size(), 2U); + EXPECT_EQ(gt_cmps[0], 1); + EXPECT_EQ(gt_cmps[1], 2); + EXPECT_EQ(ctx.GetTotalCmps(), 2); +} + +TEST(OmegaSearchContextTest, ReportVisitCandidatesReturnsPredictionPointAtBoundary) { + SearchContext ctx(nullptr, nullptr, 0.95f, 2, 3); + + std::vector warmup; + warmup.reserve(49); + for (int i = 0; i < 49; ++i) { + 
warmup.push_back({100 + i, 10.0f + static_cast(i), i < 2}); + } + + EXPECT_FALSE(ctx.ReportVisitCandidates(warmup.data(), warmup.size())); + EXPECT_EQ(ctx.GetTotalCmps(), 49); + EXPECT_EQ(ctx.GetTopCandidateCountForHook(), 2); + + EXPECT_TRUE(ctx.ReportVisitCandidate(999, 0.01f, true)); + EXPECT_EQ(ctx.GetTotalCmps(), 50); + EXPECT_EQ(ctx.GetTopCandidateCountForHook(), 2); + + ctx.Reset(); + EXPECT_EQ(ctx.GetTotalCmps(), 0); + EXPECT_EQ(ctx.GetTopCandidateCountForHook(), 0); + EXPECT_EQ(ctx.GetNextPredictionCmps(), 50); +} + +TEST(OmegaModelManagerTest, MissingModelFileFailsClearly) { + ModelManager manager; + EXPECT_FALSE(manager.LoadModel("this/path/should/not/exist")); + EXPECT_FALSE(manager.IsLoaded()); + EXPECT_EQ(manager.GetModel(), nullptr); +} + +} // namespace omega diff --git a/tests/core/interface/CMakeLists.txt b/tests/core/interface/CMakeLists.txt index 62e7bc933..4323b5722 100644 --- a/tests/core/interface/CMakeLists.txt +++ b/tests/core/interface/CMakeLists.txt @@ -1,15 +1,47 @@ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) file(GLOB_RECURSE ALL_TEST_SRCS *_test.cc) +if(NOT ZVEC_ENABLE_OMEGA) + list(FILTER ALL_TEST_SRCS EXCLUDE REGEX ".*/omega_.*_test\\.cc$") +endif() + +set(ZVEC_TEST_CORE_INTERFACE_LIBS + zvec_ailego + core_framework + core_metric + core_interface + core_knn_flat + core_utility + core_quantizer + sparsehash + core_knn_hnsw + core_mix_reducer + core_knn_flat_sparse + core_knn_hnsw_sparse + core_knn_ivf + core_knn_hnsw_rabitq +) +if(ZVEC_ENABLE_OMEGA) + list(APPEND ZVEC_TEST_CORE_INTERFACE_LIBS core_knn_omega omega) +endif() + +set(ZVEC_TEST_CORE_INTERFACE_INCS + . 
+ ${PROJECT_ROOT_DIR}/src + ${PROJECT_ROOT_DIR}/src/core + ${PROJECT_ROOT_DIR}/src/core/algorithm) +if(ZVEC_ENABLE_OMEGA) + list(APPEND ZVEC_TEST_CORE_INTERFACE_INCS + ${PROJECT_ROOT_DIR}/thirdparty/omega/OMEGALib/include) +endif() foreach(CC_SRCS ${ALL_TEST_SRCS}) get_filename_component(CC_TARGET ${CC_SRCS} NAME_WE) cc_gtest( NAME ${CC_TARGET} STRICT - LIBS zvec_ailego core_framework core_metric core_interface core_knn_flat core_utility core_quantizer sparsehash core_knn_hnsw core_mix_reducer - core_knn_flat_sparse core_knn_hnsw_sparse core_knn_ivf core_knn_hnsw_rabitq + LIBS ${ZVEC_TEST_CORE_INTERFACE_LIBS} SRCS ${CC_SRCS} - INCS . ${PROJECT_ROOT_DIR}/src/core ${PROJECT_ROOT_DIR}/src/core/algorithm + INCS ${ZVEC_TEST_CORE_INTERFACE_INCS} ) -endforeach() \ No newline at end of file +endforeach() diff --git a/tests/core/interface/omega_index_integration_test.cc b/tests/core/interface/omega_index_integration_test.cc new file mode 100644 index 000000000..fc27b2639 --- /dev/null +++ b/tests/core/interface/omega_index_integration_test.cc @@ -0,0 +1,158 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include "tests/test_util.h" +#include "zvec/core/interface/index.h" +#include "zvec/core/interface/index_factory.h" +#include "zvec/core/interface/index_param_builders.h" + +namespace zvec::core_interface { +namespace { + +constexpr uint32_t kDimension = 16; +const std::string kIndexPath = "OmegaIndexIntegrationTest/test.index"; + +BaseIndexParam::Pointer CreateOmegaIndexParam() { + auto param = HNSWIndexParamBuilder() + .WithMetricType(MetricType::kInnerProduct) + .WithDataType(DataType::DT_FP32) + .WithDimension(kDimension) + .WithIsSparse(false) + .WithM(8) + .WithEFConstruction(64) + .Build(); + param->index_type = IndexType::kOMEGA; + return param; +} + +void PopulateIndex(const Index::Pointer &index, uint32_t doc_count) { + for (uint32_t doc_id = 0; doc_id < doc_count; ++doc_id) { + std::vector values(kDimension, static_cast(doc_id) / 10.0f); + values[0] = 1.0f + static_cast(doc_id); + + VectorData vector_data; + vector_data.vector = DenseVector{values.data()}; + ASSERT_EQ(index->Add(vector_data, doc_id), 0); + } +} + +VectorData MakeQuery(float base) { + static std::vector values; + values.assign(kDimension, base); + values[0] = 1.0f; + return VectorData{DenseVector{values.data()}}; +} + +class OmegaIndexIntegrationTest : public ::testing::Test { + protected: + void SetUp() override { + zvec::test_util::RemoveTestPath("OmegaIndexIntegrationTest"); + } + + void TearDown() override { + zvec::test_util::RemoveTestPath("OmegaIndexIntegrationTest"); + } +}; + +TEST_F(OmegaIndexIntegrationTest, + SearchFallsBackWithoutModelAndReturnsResults) { + auto index = IndexFactory::CreateAndInitIndex(*CreateOmegaIndexParam()); + ASSERT_NE(index, nullptr); + ASSERT_EQ(index->Open(kIndexPath, {StorageOptions::StorageType::kMMAP, true}), + 0); + + PopulateIndex(index, 8); + ASSERT_EQ(index->Train(), 0); + + auto query_param = OmegaQueryParamBuilder() + .with_topk(3) + .with_fetch_vector(true) + .with_ef_search(32) + 
.with_target_recall(0.90f) + .build(); + + auto query = MakeQuery(0.0f); + SearchResult result; + ASSERT_EQ(index->Search(query, query_param, &result), 0); + ASSERT_EQ(result.doc_list_.size(), 3U); + EXPECT_EQ(result.doc_list_[0].key(), 7U); + EXPECT_TRUE(result.training_records_.empty()); + EXPECT_TRUE(result.gt_cmps_per_rank_.empty()); + + ASSERT_EQ(index->Close(), 0); +} + +TEST_F(OmegaIndexIntegrationTest, + TrainingSessionCollectsArtifactsThroughIndexSearch) { + auto index = IndexFactory::CreateAndInitIndex(*CreateOmegaIndexParam()); + ASSERT_NE(index, nullptr); + ASSERT_EQ(index->Open(kIndexPath, {StorageOptions::StorageType::kMMAP, true}), + 0); + + PopulateIndex(index, 8); + ASSERT_EQ(index->Train(), 0); + + auto *training_capable = index->GetTrainingCapability(); + ASSERT_NE(training_capable, nullptr); + + auto session = training_capable->CreateTrainingSession(); + ASSERT_NE(session, nullptr); + + TrainingSessionConfig config; + config.topk = 1; + config.k_train = 1; + config.ground_truth = {{0}}; + ASSERT_TRUE(session->Start(config).ok()); + session->BeginQuery(0); + + auto query_param = OmegaQueryParamBuilder() + .with_topk(3) + .with_ef_search(32) + .with_training_query_id(0) + .with_target_recall(0.95f) + .build(); + + auto query = MakeQuery(0.0f); + SearchResult result; + ASSERT_EQ(index->Search(query, query_param, &result), 0); + EXPECT_EQ(result.training_query_id_, 0); + EXPECT_FALSE(result.training_records_.empty()); + ASSERT_EQ(result.gt_cmps_per_rank_.size(), 1U); + EXPECT_GT(result.total_cmps_, 0); + + QueryTrainingArtifacts artifacts; + artifacts.records = result.training_records_; + artifacts.gt_cmps_per_rank = result.gt_cmps_per_rank_; + artifacts.total_cmps = result.total_cmps_; + artifacts.training_query_id = result.training_query_id_; + session->CollectQueryArtifacts(std::move(artifacts)); + + TrainingArtifacts consumed = session->ConsumeArtifacts(); + ASSERT_FALSE(consumed.records.empty()); + ASSERT_EQ(consumed.gt_cmps_data.num_queries, 
1U); + ASSERT_EQ(consumed.gt_cmps_data.topk, 1U); + ASSERT_EQ(consumed.gt_cmps_data.gt_cmps.size(), 1U); + ASSERT_EQ(consumed.gt_cmps_data.total_cmps.size(), 1U); + EXPECT_EQ(consumed.gt_cmps_data.total_cmps[0], result.total_cmps_); + + session->Finish(); + ASSERT_EQ(index->Close(), 0); +} + +} // namespace +} // namespace zvec::core_interface diff --git a/tests/core/interface/omega_query_param_test.cc b/tests/core/interface/omega_query_param_test.cc new file mode 100644 index 000000000..58807eb85 --- /dev/null +++ b/tests/core/interface/omega_query_param_test.cc @@ -0,0 +1,79 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include "zvec/core/interface/index_factory.h" +#include "zvec/core/interface/index_param.h" + +namespace zvec::core_interface { + +TEST(OmegaQueryParamTest, ClonePreservesOmegaFields) { + OmegaQueryParam param; + param.topk = 20; + param.fetch_vector = true; + param.radius = 1.25f; + param.is_linear = true; + param.ef_search = 432; + param.training_query_id = 7; + param.target_recall = 0.91f; + + auto cloned_base = param.Clone(); + auto cloned = std::dynamic_pointer_cast(cloned_base); + ASSERT_NE(cloned, nullptr); + EXPECT_EQ(cloned->topk, 20U); + EXPECT_TRUE(cloned->fetch_vector); + EXPECT_FLOAT_EQ(cloned->radius, 1.25f); + EXPECT_TRUE(cloned->is_linear); + EXPECT_EQ(cloned->ef_search, 432U); + EXPECT_EQ(cloned->training_query_id, 7); + EXPECT_FLOAT_EQ(cloned->target_recall, 0.91f); +} + +TEST(OmegaQueryParamTest, JsonRoundTripPreservesTargetRecall) { + OmegaQueryParam param; + param.topk = 15; + param.fetch_vector = true; + param.ef_search = 512; + param.target_recall = 0.92f; + + const std::string json = + IndexFactory::QueryParamSerializeToJson(param, false); + auto restored = + IndexFactory::QueryParamDeserializeFromJson(json); + + ASSERT_NE(restored, nullptr); + EXPECT_EQ(restored->topk, 15U); + EXPECT_TRUE(restored->fetch_vector); + EXPECT_EQ(restored->ef_search, 512U); + EXPECT_FLOAT_EQ(restored->target_recall, 0.92f); +} + +TEST(OmegaQueryParamTest, BaseDeserializerReturnsOmegaType) { + OmegaQueryParam param; + param.ef_search = 256; + param.target_recall = 0.9f; + + const std::string json = + IndexFactory::QueryParamSerializeToJson(param, false); + auto restored = + IndexFactory::QueryParamDeserializeFromJson(json); + + ASSERT_NE(restored, nullptr); + auto* omega = dynamic_cast(restored.get()); + ASSERT_NE(omega, nullptr); + EXPECT_EQ(omega->ef_search, 256U); + EXPECT_FLOAT_EQ(omega->target_recall, 0.9f); +} + +} // namespace zvec::core_interface diff --git a/tests/core/interface/omega_training_session_test.cc 
b/tests/core/interface/omega_training_session_test.cc new file mode 100644 index 000000000..fceaa2889 --- /dev/null +++ b/tests/core/interface/omega_training_session_test.cc @@ -0,0 +1,107 @@ +// Copyright 2025-present the zvec project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "core/interface/indexes/omega_training_session.h" +#include + +namespace zvec::core_interface { + +TEST(OmegaTrainingSessionTest, StartFailsWithoutStreamer) { + OmegaTrainingSession session(nullptr); + TrainingSessionConfig config; + config.topk = 3; + config.ground_truth = {{1, 2, 3}}; + + auto status = session.Start(config); + EXPECT_FALSE(status.ok()); +} + +TEST(OmegaTrainingSessionTest, ConsumeArtifactsAggregatesRecordsAndGtCmps) { + OmegaTrainingSession session(nullptr); + + QueryTrainingArtifacts first; + first.training_query_id = 0; + first.total_cmps = 13; + first.gt_cmps_per_rank = {3, 7, 11}; + TrainingRecord first_record; + first_record.query_id = 0; + first_record.hops_visited = 1; + first_record.cmps_visited = 3; + first_record.dist_1st = 0.1f; + first_record.dist_start = 0.2f; + first_record.traversal_window_stats = {1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, 1.0f}; + first_record.label = 1; + first.records.push_back(first_record); + + QueryTrainingArtifacts second; + second.training_query_id = 2; + second.total_cmps = 21; + second.gt_cmps_per_rank = {5, 9, 15}; + TrainingRecord second_record; + second_record.query_id = 2; + second_record.hops_visited = 4; 
+ second_record.cmps_visited = 8; + second_record.dist_1st = 0.3f; + second_record.dist_start = 0.4f; + second_record.traversal_window_stats = {2.0f, 2.0f, 2.0f, 2.0f, + 2.0f, 2.0f, 2.0f}; + second_record.label = 0; + second.records.push_back(second_record); + + session.CollectQueryArtifacts(std::move(first)); + session.CollectQueryArtifacts(std::move(second)); + + TrainingArtifacts artifacts = session.ConsumeArtifacts(); + ASSERT_EQ(artifacts.records.size(), 2U); + EXPECT_EQ(artifacts.records[0].query_id, 0); + EXPECT_EQ(artifacts.records[1].query_id, 2); + + ASSERT_EQ(artifacts.gt_cmps_data.topk, 3U); + ASSERT_EQ(artifacts.gt_cmps_data.num_queries, 3U); + ASSERT_EQ(artifacts.gt_cmps_data.gt_cmps.size(), 3U); + ASSERT_EQ(artifacts.gt_cmps_data.total_cmps.size(), 3U); + + EXPECT_EQ(artifacts.gt_cmps_data.gt_cmps[0][0], 3); + EXPECT_EQ(artifacts.gt_cmps_data.gt_cmps[0][2], 11); + EXPECT_EQ(artifacts.gt_cmps_data.total_cmps[0], 13); + + EXPECT_EQ(artifacts.gt_cmps_data.gt_cmps[1][0], 0); + EXPECT_EQ(artifacts.gt_cmps_data.total_cmps[1], 0); + + EXPECT_EQ(artifacts.gt_cmps_data.gt_cmps[2][1], 9); + EXPECT_EQ(artifacts.gt_cmps_data.total_cmps[2], 21); + + TrainingArtifacts drained = session.ConsumeArtifacts(); + EXPECT_TRUE(drained.records.empty()); + EXPECT_TRUE(drained.gt_cmps_data.gt_cmps.empty()); +} + +TEST(OmegaTrainingSessionTest, + ConsumeArtifactsUsesConfiguredShapeWhenAvailable) { + OmegaTrainingSession session(nullptr); + + QueryTrainingArtifacts only; + only.training_query_id = 1; + only.total_cmps = 8; + only.gt_cmps_per_rank = {2, 4}; + session.CollectQueryArtifacts(std::move(only)); + + TrainingArtifacts inferred = session.ConsumeArtifacts(); + ASSERT_EQ(inferred.gt_cmps_data.num_queries, 2U); + ASSERT_EQ(inferred.gt_cmps_data.topk, 2U); + EXPECT_EQ(inferred.gt_cmps_data.total_cmps[1], 8); +} + +} // namespace zvec::core_interface diff --git a/tests/db/sqlengine/mock_segment.h b/tests/db/sqlengine/mock_segment.h index 6b46e2385..87e5368ab 100644 --- 
a/tests/db/sqlengine/mock_segment.h +++ b/tests/db/sqlengine/mock_segment.h @@ -504,6 +504,10 @@ class MockSegment : public Segment { return Status::OK(); } + Status retrain_omega_model() override { + return Status::OK(); + } + Status destroy() override { return Status::OK(); } diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt index 1c552fa0b..83f93585f 100644 --- a/thirdparty/CMakeLists.txt +++ b/thirdparty/CMakeLists.txt @@ -26,4 +26,9 @@ add_subdirectory(CRoaring CRoaring EXCLUDE_FROM_ALL) add_subdirectory(arrow arrow EXCLUDE_FROM_ALL) add_subdirectory(magic_enum magic_enum EXCLUDE_FROM_ALL) add_subdirectory(RaBitQ-Library RaBitQ-Library EXCLUDE_FROM_ALL) - +if(ZVEC_ENABLE_OMEGA) + message(STATUS "ZVEC: Building omega library with LightGBM support") + add_subdirectory(omega omega EXCLUDE_FROM_ALL) +else() + message(STATUS "ZVEC: OMEGA support disabled") +endif() diff --git a/thirdparty/omega/CMakeLists.txt b/thirdparty/omega/CMakeLists.txt new file mode 100644 index 000000000..1ffc9fcc8 --- /dev/null +++ b/thirdparty/omega/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(OMEGALib) diff --git a/thirdparty/omega/OMEGALib b/thirdparty/omega/OMEGALib new file mode 160000 index 000000000..4c530aa16 --- /dev/null +++ b/thirdparty/omega/OMEGALib @@ -0,0 +1 @@ +Subproject commit 4c530aa16fc232ed9c52cb9c85ffeb67b30b6c87 diff --git a/tools/core/CMakeLists.txt b/tools/core/CMakeLists.txt index c36b26409..791f577df 100644 --- a/tools/core/CMakeLists.txt +++ b/tools/core/CMakeLists.txt @@ -1,6 +1,26 @@ include(${PROJECT_ROOT_DIR}/cmake/bazel.cmake) include(${PROJECT_ROOT_DIR}/cmake/option.cmake) +set(ZVEC_TOOL_CORE_INTERFACE_LIBS + core_framework + core_metric + core_quantizer + core_utility + core_knn_flat + core_knn_flat_sparse + core_knn_hnsw + core_knn_hnsw_sparse + core_knn_hnsw_rabitq + core_knn_cluster + core_knn_ivf + core_interface +) + +set(ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS core_mix_reducer) +if(ZVEC_ENABLE_OMEGA) + list(APPEND 
ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS core_knn_omega) +endif() + cc_binary( NAME txt2vecs STRICT PACKED @@ -14,7 +34,7 @@ cc_binary( STRICT PACKED SRCS local_builder.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface + LIBS gflags yaml-cpp magic_enum ${ZVEC_TOOL_CORE_INTERFACE_LIBS} ${ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS} ) cc_binary( @@ -22,7 +42,7 @@ cc_binary( STRICT PACKED SRCS recall.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum roaring ${ZVEC_TOOL_CORE_INTERFACE_LIBS} ${ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS} ) cc_binary( @@ -30,16 +50,15 @@ cc_binary( STRICT PACKED SRCS bench.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum roaring ${ZVEC_TOOL_CORE_INTERFACE_LIBS} ${ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS} ) - cc_binary( NAME recall_original STRICT PACKED SRCS recall_original.cc flow.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum roaring ${ZVEC_TOOL_CORE_INTERFACE_LIBS} ${ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS} ) cc_binary( @@ -47,7 +66,7 @@ cc_binary( STRICT PACKED SRCS bench_original.cc flow.cc INCS 
${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf roaring core_interface + LIBS gflags yaml-cpp magic_enum roaring ${ZVEC_TOOL_CORE_INTERFACE_LIBS} ${ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS} ) cc_binary( @@ -55,5 +74,5 @@ cc_binary( STRICT PACKED SRCS local_builder_original.cc INCS ${PROJECT_ROOT_DIR}/src/core/ - LIBS gflags yaml-cpp magic_enum core_framework core_metric core_quantizer core_utility core_knn_flat core_knn_flat_sparse core_knn_hnsw core_knn_hnsw_sparse core_knn_hnsw_rabitq core_knn_cluster core_knn_ivf core_interface + LIBS gflags yaml-cpp magic_enum ${ZVEC_TOOL_CORE_INTERFACE_LIBS} ${ZVEC_TOOL_CORE_INTERFACE_IMPL_LIBS} )