diff --git a/.github/workflows/main-ci-tests.yml b/.github/workflows/main-ci-tests.yml index ad0ca6c..572a403 100644 --- a/.github/workflows/main-ci-tests.yml +++ b/.github/workflows/main-ci-tests.yml @@ -2,7 +2,7 @@ name: CMake - Build and Test on: pull_request: - branches: [ main, join_pre_experiment ] + branches: [ main, main-dev ] push: branches: [ main ] @@ -16,7 +16,7 @@ jobs: runs-on: ${{ matrix.os }} env: # 限制 CI 运行时日志级别,避免 DEBUG 级别日志过多 - CANDY_LOG_LEVEL: info + SAGEFLOW_LOG_LEVEL: info strategy: fail-fast: false matrix: @@ -48,8 +48,8 @@ jobs: cmake -S . -B build \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ -DBUILD_TESTING=ON \ - -DCANDY_ENABLE_METRICS=ON \ - -DCANDY_BUILD_PYBIND=OFF \ + -DSAGEFLOW_ENABLE_METRICS=ON \ + -DSAGEFLOW_BUILD_PYBIND=OFF \ -DCMAKE_POSITION_INDEPENDENT_CODE=ON - name: Build diff --git a/.gitignore b/.gitignore index 4af6d3e..1cdacbb 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,11 @@ cmake-build-*/ dist/ *.egg-info/ +# Experiment outputs +*.png +*.tsv +*.otf + # Compiled Object files *.slo *.lo @@ -91,3 +96,4 @@ docs/_build/ # Uncomment if needed # data/generated/ # examples/output/ +install/* \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 09d9ac3..c826b35 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,14 +1,14 @@ cmake_minimum_required(VERSION 3.20) -project(CANDY CXX) +project(sageFlow CXX) set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) # 是否启用运行时指标采集代码(通过编译宏控制)。默认开启以保持现有行为。 -option(CANDY_ENABLE_METRICS "Enable metrics instrumentation in join/operator and tests" ON) -message(STATUS "CANDY_ENABLE_METRICS: ${CANDY_ENABLE_METRICS}") +option(SAGEFLOW_ENABLE_METRICS "Enable metrics instrumentation in join/operator and tests" ON) +message(STATUS "SAGEFLOW_ENABLE_METRICS: ${SAGEFLOW_ENABLE_METRICS}") # 启用测试选项(CLion 识别 gtest 必需) option(BUILD_TESTING "Build tests" ON) @@ -16,7 +16,6 @@ if(BUILD_TESTING) enable_testing() endif() - 
set(_sage_flow_shared_deps FALSE) if(DEFINED SAGE_COMMON_DEPS_FILE AND EXISTS "${SAGE_COMMON_DEPS_FILE}") include("${SAGE_COMMON_DEPS_FILE}") @@ -79,33 +78,3 @@ add_subdirectory(src) add_subdirectory(test) add_subdirectory(examples) -# Python bindings -if(NOT _sage_flow_shared_deps) - include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/pybind11_dependency.cmake) -endif() - -pybind11_add_module(_sage_flow python/bindings.cpp) -target_link_libraries(_sage_flow PRIVATE - candy - externalRuntimeLibs -) -target_include_directories(_sage_flow PRIVATE include) -if(DEFINED SAGE_COMMON_COMPILE_DEFINITIONS) - target_compile_definitions(_sage_flow PRIVATE ${SAGE_COMMON_COMPILE_DEFINITIONS}) -else() - target_compile_definitions(_sage_flow PRIVATE PYBIND11_INTERNALS_ID="sage_pybind11_shared") -endif() - -# Reduce exported symbol surface to minimize potential cross-module clashes -if(DEFINED SAGE_COMMON_COMPILE_OPTIONS) - target_compile_options(_sage_flow PRIVATE ${SAGE_COMMON_COMPILE_OPTIONS}) -else() - target_compile_options(_sage_flow PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden) -endif() - -if(ENABLE_GPERFTOOLS AND DEFINED SAGE_GPERFTOOLS_LIBS AND SAGE_GPERFTOOLS_LIBS) - target_link_libraries(_sage_flow PRIVATE ${SAGE_GPERFTOOLS_LIBS}) -endif() -set_target_properties(_sage_flow PROPERTIES CXX_VISIBILITY_PRESET hidden VISIBILITY_INLINES_HIDDEN YES) - - diff --git a/batch_translate.py b/batch_translate.py new file mode 100755 index 0000000..1313a61 --- /dev/null +++ b/batch_translate.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 + +# Helper function to write translated content +def write_translation(filename, content): + with open(filename, 'w', encoding='utf-8') as f: + f.write(content) + print(f"Translated {filename}") + +# Translate IMPLEMENTATION_SUMMARY.md - keeping existing Chinese, just ensure consistency +implementation_summary = """# 数据源框架实现总结 + +## 问题描述 + +原始需求:帮我把test文件夹里面数据生成的部分作为模块抽离出来。数据源除了目前的随机生成,新加上从数据集中直接获取数据(数据集目前在data文件夹下)。在生成给算子用的数据中,可以通过不同的类来区分我想要的数据源。 + 
+## 实现方案 + +### 架构设计 + +创建了一个模块化的数据源框架,包含三个主要组件: + +1. **DataSourceBase** - 抽象基类,定义统一接口 +2. **RandomDataSource** - 随机数据生成器(从原TestDataGenerator提取) +3. **DatasetDataSource** - 数据集加载器(读取fvecs格式文件) +4. **VectorListSource** - 内存向量包装器(可复用组件) + +### 文件结构 + +``` +test/test_utils/data_source/ +├── data_source_base.h # 基类接口 +├── random_data_source.h/cpp # 随机数据源实现 +├── dataset_data_source.h/cpp # 数据集数据源实现 +├── vector_list_source.h # 内存向量包装器 +└── README.md # 完整文档 + +test/UnitTest/ +├── test_data_source.cpp # 单元测试 +├── test_data_persistence.cpp # 持久化测试 +└── test_join_data_source.cpp # Join数据源测试 + +test/examples/ +├── test_data_source_example.cpp # 使用示例 +└── data_persistence_example.cpp # 持久化示例 +``` + +### 关键特性 + +1. **模块化设计** - 数据生成逻辑独立,易于扩展 +2. **统一接口** - 所有数据源实现相同的接口 +3. **向后兼容** - 现有测试代码无需修改即可运行 +4. **灵活配置** - 支持多种数据源和配置选项 +5. **易于扩展** - 添加新数据源只需继承基类 + +## 使用方法 + +### 1. 使用随机数据源 + +```cpp +// 配置随机数据源 +RandomDataSource::Config config; +config.vector_dim = 128; +config.seed = 42; +auto data_source = std::make_shared(config); + +// 与TestDataGenerator一起使用 +TestDataGenerator::Config gen_config; +gen_config.positive_pairs = 100; +TestDataGenerator generator(gen_config, data_source); +auto [records, matches] = generator.generateData(); +``` + +### 2. 使用数据集数据源 + +```cpp +// 配置数据集数据源 +DatasetDataSource::Config config; +config.file_path = PROJECT_DIR "/data/siftsmall/siftsmall_query.fvecs"; +config.expected_dim = 128; +config.loop = true; // 循环使用 +auto data_source = std::make_shared(config); + +// 与TestDataGenerator一起使用 +TestDataGenerator generator(gen_config, data_source); +auto [records, matches] = generator.generateData(); +``` + +### 3. 
向后兼容用法 + +```cpp +// 原有代码无需修改,仍然正常工作 +TestDataGenerator::Config config; +config.vector_dim = 128; +TestDataGenerator generator(config); // 自动使用随机数据源 +auto [records, matches] = generator.generateData(); +``` + +## 测试验证 + +### 单元测试 +- `test_data_source.cpp` - 包含5个测试用例 + - RandomDataSourceBasic - 测试随机数据源 + - DatasetDataSourceBasic - 测试数据集数据源 + - TestDataGeneratorWithRandomDataSource - 测试生成器+随机源 + - TestDataGeneratorWithDatasetDataSource - 测试生成器+数据集源 + - BackwardCompatibility - 测试向后兼容性 + +- `test_data_persistence.cpp` - 包含5个测试用例 + - 测试保存为FVECS格式 + - 测试保存为JSON格式 + - 测试FVECS往返(保存后加载) + - 测试JSON往返(保存后加载) + - 测试从保存的数据生成 + +- `test_join_data_source.cpp` - 包含8个测试用例 + - 测试Duplicate模式 + - 测试Separate模式 + - 测试Generator集成 + - 测试向后兼容性 + +### 测试结果 +```bash +cd build +ctest -L UNIT +# 18/18 tests passed (100%) +``` + +所有现有测试仍然通过,证明完全向后兼容。 + +### 示例程序 +```bash +cd build +./bin/test_data_source_example +# 运行4个示例,展示不同使用场景 + +./bin/data_persistence_example +# 演示数据持久化功能 +``` + +## 文档 + +- **test/test_utils/data_source/README.md** - 完整的框架文档 + - 架构说明 + - 使用指南 + - 配置选项 + - 扩展方法 + +- **test/test_utils/data_writer/README.md** - 数据写入器文档 + - FvecsWriter使用说明 + - JsonWriter使用说明 + - 配置选项 + +- **test/test_utils/JOIN_DATA_SOURCE_GUIDE.md** - Join数据源指南 + - Join框架说明 + - 使用示例 + - 配置选项 + +## 兼容性 + +✅ **完全向后兼容** - 所有现有测试无需修改 +✅ **现有测试通过** - 18个单元测试全部通过 +✅ **性能测试正常** - test_join_perf_scaling等构建正常 + +## 扩展性 + +添加新数据源非常简单: + +```cpp +class MyCustomDataSource : public DataSourceBase { +public: + // 实现接口方法 + std::vector getNextVector() override; + int getDimension() const override; + bool hasMore() const override; + void reset() override; +}; +``` + +## 技术细节 + +1. **内存管理** - 使用智能指针,自动管理生命周期 +2. **异常处理** - 数据集加载失败时抛出异常,带详细错误信息 +3. **线程安全** - 基础类不保证线程安全,由使用方控制 +4. **性能** - 数据集一次性加载到内存,访问快速 + +## 未来改进 + +可能的扩展方向: +1. 添加更多数据格式支持(如HDF5) +2. 支持流式加载大数据集 +3. 添加数据预处理功能 +4. 
支持数据增强 +""" +write_translation('IMPLEMENTATION_SUMMARY.md', implementation_summary) + +print("\\nAll key documentation files translated to Chinese!") +print("Files translated: CODE_REVIEW_IMPROVEMENTS.md, IMPLEMENTATION_SUMMARY.md") diff --git a/build.sh b/build.sh index eca5c10..92ceff9 100755 --- a/build.sh +++ b/build.sh @@ -5,6 +5,45 @@ BUILD_TYPE=${BUILD_TYPE:-Debug} echo "Building sageFlow with CMake (CMAKE_BUILD_TYPE=${BUILD_TYPE})..." +# Function to check and fix libstdc++ version issue in conda environment +check_libstdcxx() { + # Only check if we're in a conda environment + if [[ -z "${CONDA_PREFIX}" ]]; then + return 0 + fi + + # Check if conda libstdc++ needs update + local conda_libstdcxx="${CONDA_PREFIX}/lib/libstdc++.so.6" + if [[ ! -f "${conda_libstdcxx}" ]]; then + return 0 + fi + + # Check GCC version requirement + local gcc_version=$(gcc -dumpversion | cut -d. -f1) + if [[ ${gcc_version} -ge 11 ]]; then + # Check if conda libstdc++ has required GLIBCXX version + if ! strings "${conda_libstdcxx}" | grep -q "GLIBCXX_3.4.30"; then + echo "⚠️ 检测到conda环境中的libstdc++版本过低,正在更新..." 
+ echo " 这是C++20/GCC 11+编译所必需的" + + # Try to update libstdc++ in conda environment + if command -v conda &> /dev/null; then + conda install -c conda-forge libstdcxx-ng -y || { + echo "⚠️ 无法自动更新libstdc++,将使用系统版本" + # Set LD_LIBRARY_PATH to prefer system libstdc++ + if [[ -f "/usr/lib/x86_64-linux-gnu/libstdc++.so.6" ]]; then + export LD_LIBRARY_PATH="/usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH}" + echo " 已设置LD_LIBRARY_PATH优先使用系统libstdc++" + fi + } + fi + fi + fi +} + +# Check libstdc++ before building +check_libstdcxx + # Create build directory if not exists mkdir -p build diff --git a/config/perf_join_datasource_modes.toml b/config/perf_join_datasource_modes.toml new file mode 100644 index 0000000..00db08a --- /dev/null +++ b/config/perf_join_datasource_modes.toml @@ -0,0 +1,61 @@ +[[performance_test]] +# Test Mode 1: Generate-Save-Load (generate data, save to file, then load from file) +name = "perf_join_gen_save_load_random" +mode = "generate_save_load" +methods = ["bruteforce_eager", "ivf_eager"] +sizes = [2000] +records_count = 10 +vector_dim = 64 +parallelism = [1,2,4] +window_time_ms = [10000] +window_trigger_ms = 50 +time_interval = 10 +similarity_threshold = 0.8 +seed = 42 + +[performance_test.data_source] +type = "random" + +[performance_test.storage] +format = "fvecs" # Options: "fvecs", "json" +file_path = "test/data/generated_test_data.fvecs" + +[[performance_test]] +# Test Mode 2: Direct Load (load data directly from existing file) +name = "perf_join_direct_load_sift" +mode = "direct_load" +methods = ["bruteforce_eager", "ivf_eager"] +sizes = [1000] +records_count = 10 +vector_dim = 128 +parallelism = [1,2] +window_time_ms = [10000] +window_trigger_ms = 50 +time_interval = 10 +similarity_threshold = 0.8 + +[performance_test.data_source] +type = "dataset" +file_path = "data/siftsmall/siftsmall_query.fvecs" +expected_dim = 128 +loop = true + +[[performance_test]] +# Test Mode 3: Generate and Use Directly (no file I/O) +name = 
"perf_join_direct_use_random" +mode = "generate_direct_use" +methods = ["bruteforce_eager", "ivf_eager"] +sizes = [2000] +records_count = 10 +vector_dim = 64 +parallelism = [1,2,4] +window_time_ms = [10000] +window_trigger_ms = 50 +time_interval = 10 +similarity_threshold = 0.8 +seed = 42 + +[performance_test.data_source] +type = "random" + +log.level = "debug" diff --git a/config/perf_join_with_datasource.toml b/config/perf_join_with_datasource.toml new file mode 100644 index 0000000..12092b1 --- /dev/null +++ b/config/perf_join_with_datasource.toml @@ -0,0 +1,40 @@ +[[performance_test]] +# Test with random data source (default) +name = "perf_join_random" +methods = ["bruteforce_eager", "ivf_eager"] +sizes = [4000] +records_count = 10 +vector_dim = 64 +parallelism = [1,2,4,8] +window_time_ms = [10000] +window_trigger_ms = 50 +time_interval = 10 +similarity_threshold = 0.8 +seed = 42 + +# Data source configuration +[performance_test.data_source] +type = "random" # Options: "random", "dataset", "json" +# For random type, no additional config needed (uses seed from above) + +[[performance_test]] +# Test with SIFT dataset +name = "perf_join_sift" +methods = ["bruteforce_eager", "ivf_eager"] +sizes = [1000] # Number of records to generate from dataset +records_count = 10 +vector_dim = 128 # SIFT dimension +parallelism = [1,2,4] +window_time_ms = [10000] +window_trigger_ms = 50 +time_interval = 10 +similarity_threshold = 0.8 +seed = 42 + +[performance_test.data_source] +type = "dataset" +file_path = "data/siftsmall/siftsmall_query.fvecs" +expected_dim = 128 +loop = true # Allow reusing vectors if dataset is smaller than needed + +log.level = "debug" diff --git a/examples/Source/CMakeLists.txt b/examples/Source/CMakeLists.txt index 628e324..13f8654 100644 --- a/examples/Source/CMakeLists.txt +++ b/examples/Source/CMakeLists.txt @@ -7,9 +7,11 @@ add_executable( fvecs_to_vector_records.cpp ) target_link_libraries(generatorT - candy + PRIVATE + sageflow ) 
target_link_libraries(fvecs_to_vector_records - candy + PRIVATE + sageflow ) \ No newline at end of file diff --git a/examples/Source/fvecs_to_vector_records.cpp b/examples/Source/fvecs_to_vector_records.cpp index 712b799..be27872 100644 --- a/examples/Source/fvecs_to_vector_records.cpp +++ b/examples/Source/fvecs_to_vector_records.cpp @@ -10,7 +10,7 @@ // Assuming sageFlow headers are accessible via include path #include "common/data_types.h" // For DataType enum -using namespace candy; +using namespace sageFlow; // Function to read vectors from an fvecs file (remains the same) int read_fvecs(const std::string& filename, std::vector& data, int expected_dim = -1) { diff --git a/examples/Source/generate_vector_records.cpp b/examples/Source/generate_vector_records.cpp index dec2126..089c62e 100644 --- a/examples/Source/generate_vector_records.cpp +++ b/examples/Source/generate_vector_records.cpp @@ -8,8 +8,8 @@ static int cnt = 0; // Helper function to generate a random vector -auto GenerateRandomVectorData(std::mt19937& gen, int dim, candy::DataType data_type) -> candy::VectorData { - candy::VectorData vector_data(dim, data_type); +auto GenerateRandomVectorData(std::mt19937& gen, int dim, sageFlow::DataType data_type) -> sageFlow::VectorData { + sageFlow::VectorData vector_data(dim, data_type); float begin = 0; float end = 1.0F; // Create distributions based on data type @@ -18,35 +18,35 @@ auto GenerateRandomVectorData(std::mt19937& gen, int dim, candy::DataType data_t std::uniform_real_distribution double_dist(-100.0, 100.0); // Fill the vector with random data based on its type - int element_size = candy::DATA_TYPE_SIZE[data_type]; + int element_size = sageFlow::DATA_TYPE_SIZE[data_type]; for (int i = 0; i < dim; ++i) { switch (data_type) { - case candy::Int8: { + case sageFlow::Int8: { auto value = static_cast(int_dist(gen) % 128); std::memcpy(vector_data.data_.get() + i * element_size, &value, element_size); break; } - case candy::Int16: { + case 
sageFlow::Int16: { auto value = static_cast(int_dist(gen)); std::memcpy(vector_data.data_.get() + i * element_size, &value, element_size); break; } - case candy::Int32: { + case sageFlow::Int32: { int32_t value = int_dist(gen); std::memcpy(vector_data.data_.get() + i * element_size, &value, element_size); break; } - case candy::Int64: { + case sageFlow::Int64: { auto value = static_cast(int_dist(gen)); std::memcpy(vector_data.data_.get() + i * element_size, &value, element_size); break; } - case candy::Float32: { + case sageFlow::Float32: { float value = float_dist(gen); std::memcpy(vector_data.data_.get() + i * element_size, &value, element_size); break; } - case candy::Float64: { + case sageFlow::Float64: { double value = double_dist(gen); std::memcpy(vector_data.data_.get() + i * element_size, &value, element_size); break; @@ -94,19 +94,19 @@ auto main(int argc, char* argv[]) -> int { // Write number of records as header int32_t record_count = num_records; output_file.write(reinterpret_cast(&record_count), sizeof(int32_t)); - candy::DataType type = candy::Float32; + sageFlow::DataType type = sageFlow::Float32; for (int i = 0; i < num_records; ++i) { // Generate random values for the vector record uint64_t uid = i; int64_t timestamp = base_timestamp + i; // Sequential timestamps // Generate random vector data - candy::VectorData vector_data = GenerateRandomVectorData(gen, dim, type); + sageFlow::VectorData vector_data = GenerateRandomVectorData(gen, dim, type); // Create vector record // Serialize and write to file - if (candy::VectorRecord record(uid, timestamp, std::move(vector_data)); !record.Serialize(output_file)) { + if (sageFlow::VectorRecord record(uid, timestamp, std::move(vector_data)); !record.Serialize(output_file)) { std::cerr << "Failed to serialize record " << i << '\n'; output_file.close(); return 1; diff --git a/examples/Streaming/CMakeLists.txt b/examples/Streaming/CMakeLists.txt index c503b60..062668b 100644 --- 
a/examples/Streaming/CMakeLists.txt +++ b/examples/Streaming/CMakeLists.txt @@ -18,25 +18,33 @@ add_executable( target_link_libraries( Streaming_example - candy + PRIVATE + sageflow + externalRuntimeLibs ) target_link_libraries( topk - candy + PRIVATE + sageflow + externalRuntimeLibs ) target_link_libraries( itopk - candy + PRIVATE + sageflow + externalRuntimeLibs ) target_link_libraries( aggregate - candy + PRIVATE + sageflow + externalRuntimeLibs ) macro(add_streaming_example name) add_executable(${name} ${name}.cpp) - target_link_libraries(${name} candy) + target_link_libraries(${name} PRIVATE sageflow externalRuntimeLibs) endmacro() add_streaming_example(test1) add_streaming_example(test2) diff --git a/examples/Streaming/aggregate.cpp b/examples/Streaming/aggregate.cpp index e6e00a9..7a03294 100644 --- a/examples/Streaming/aggregate.cpp +++ b/examples/Streaming/aggregate.cpp @@ -18,12 +18,12 @@ #include "stream/data_stream_source/simple_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -39,7 +39,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { ValidateConfiguration(conf); @@ -70,16 +70,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { 
- const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/examples/Streaming/buffer_test.cpp b/examples/Streaming/buffer_test.cpp index c3d8da6..84842ae 100644 --- a/examples/Streaming/buffer_test.cpp +++ b/examples/Streaming/buffer_test.cpp @@ -13,12 +13,12 @@ #include "stream/data_stream_source/simple_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -34,7 +34,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { ValidateConfiguration(conf); @@ -65,16 +65,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - 
config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/examples/Streaming/index_perf.cpp b/examples/Streaming/index_perf.cpp index 37cbe0d..609d5c9 100644 --- a/examples/Streaming/index_perf.cpp +++ b/examples/Streaming/index_perf.cpp @@ -15,12 +15,12 @@ #include "stream/data_stream_source/simple_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -36,7 +36,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { ValidateConfiguration(conf); @@ -69,16 +69,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/examples/Streaming/itopk.cpp b/examples/Streaming/itopk.cpp index ae10d00..5553578 100644 --- a/examples/Streaming/itopk.cpp +++ 
b/examples/Streaming/itopk.cpp @@ -19,12 +19,12 @@ #include "stream/data_stream_source/simple_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -40,7 +40,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { ValidateConfiguration(conf); @@ -86,16 +86,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/examples/Streaming/main.cpp b/examples/Streaming/main.cpp index 9551749..6d569b2 100644 --- a/examples/Streaming/main.cpp +++ b/examples/Streaming/main.cpp @@ -14,12 +14,12 @@ #include "stream/data_stream_source/file_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; 
#define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -35,7 +35,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { ValidateConfiguration(conf); @@ -82,16 +82,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/examples/Streaming/test1.cpp b/examples/Streaming/test1.cpp index 9e187a6..5f9ebce 100644 --- a/examples/Streaming/test1.cpp +++ b/examples/Streaming/test1.cpp @@ -15,12 +15,12 @@ #include "stream/data_stream_source/simple_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -36,7 +36,7 @@ void 
ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { ValidateConfiguration(conf); @@ -78,16 +78,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/examples/Streaming/test2.cpp b/examples/Streaming/test2.cpp index a328c6b..0314077 100644 --- a/examples/Streaming/test2.cpp +++ b/examples/Streaming/test2.cpp @@ -18,12 +18,12 @@ #include "stream/data_stream_source/simple_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -39,7 +39,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); 
try { ValidateConfiguration(conf); @@ -78,16 +78,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/examples/Streaming/topk.cpp b/examples/Streaming/topk.cpp index 31430d3..678ca3e 100644 --- a/examples/Streaming/topk.cpp +++ b/examples/Streaming/topk.cpp @@ -15,12 +15,12 @@ #include "stream/data_stream_source/simple_stream_source.h" using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { void ValidateConfiguration(const ConfigMap &conf) { if (!conf.exist("inputPath") || !conf.exist("outputPath")) { throw runtime_error("Missing required configuration keys: inputPath or outputPath."); @@ -36,7 +36,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { ValidateConfiguration(conf); @@ -72,16 +72,16 @@ void SetupAndRunPipeline(const std::string &config_file_path) { monitor.StopProfiling(); } -} // namespace candy +} // namespace sageFlow auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + 
"default_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "default_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/include/common/data_types.h b/include/common/data_types.h index 1041fd4..32405be 100644 --- a/include/common/data_types.h +++ b/include/common/data_types.h @@ -6,7 +6,7 @@ #include #include -namespace candy { +namespace sageFlow { enum DataType { // NOLINT None, Int8, @@ -123,4 +123,4 @@ struct UidAndDist { } }; -} // namespace candy +} // namespace sageFlow diff --git a/include/compute_engine/compute_engine.h b/include/compute_engine/compute_engine.h index 59d95e5..09bd01d 100644 --- a/include/compute_engine/compute_engine.h +++ b/include/compute_engine/compute_engine.h @@ -7,7 +7,7 @@ #include "common/data_types.h" -namespace candy { +namespace sageFlow { class ComputeEngine { public: @@ -30,4 +30,4 @@ class ComputeEngine { auto EuclideanDistanceImpl(const VectorData &vec1, const VectorData &vec2) -> double; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/concurrency/blank_controller.h b/include/concurrency/blank_controller.h index 66946a3..1910426 100644 --- a/include/concurrency/blank_controller.h +++ b/include/concurrency/blank_controller.h @@ -3,7 +3,7 @@ #include "concurrency/concurrency_controller.h" #include "index/index.h" -namespace candy { +namespace sageFlow { class BlankController final : public ConcurrencyController { public: BlankController(); @@ -26,4 +26,4 @@ class BlankController final : public ConcurrencyController { private: std::shared_ptr index_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/concurrency/concurrency_controller.h b/include/concurrency/concurrency_controller.h index 
ea72bfa..4fa5863 100644 --- a/include/concurrency/concurrency_controller.h +++ b/include/concurrency/concurrency_controller.h @@ -6,7 +6,7 @@ #include "common/data_types.h" #include "storage/storage_manager.h" -namespace candy { +namespace sageFlow { class ConcurrencyController { public: // Constructor @@ -28,4 +28,4 @@ class ConcurrencyController { std::shared_ptr storage_manager_ = nullptr; }; -}; // namespace candy +}; // namespace sageFlow diff --git a/include/concurrency/concurrency_manager.h b/include/concurrency/concurrency_manager.h index 0571113..b568f4e 100644 --- a/include/concurrency/concurrency_manager.h +++ b/include/concurrency/concurrency_manager.h @@ -7,7 +7,7 @@ #include "concurrency/concurrency_controller.h" #include "index/index.h" -namespace candy { +namespace sageFlow { struct IdWithType { int id_; IndexType index_type_; @@ -50,4 +50,4 @@ class ConcurrencyManager { std::atomic index_id_counter_ = 0; // atomic counter for index id }; -}; // namespace candy +}; // namespace sageFlow diff --git a/include/execution/blocking_queue.h b/include/execution/blocking_queue.h index d15433b..22a03b3 100644 --- a/include/execution/blocking_queue.h +++ b/include/execution/blocking_queue.h @@ -12,7 +12,7 @@ #include "execution/iqueue.h" #include "common/data_types.h" -namespace candy { +namespace sageFlow { /// 阻塞队列实现,适用于多生产者多消费者场景 class BlockingQueue final : public IQueue { public: @@ -53,4 +53,4 @@ class BlockingQueue final : public IQueue { std::atomic stopped_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/execution/collector.h b/include/execution/collector.h index 92f4cae..cd4a2f0 100644 --- a/include/execution/collector.h +++ b/include/execution/collector.h @@ -8,7 +8,7 @@ #include #include -namespace candy { +namespace sageFlow { class Collector { public: // 构造函数接收一个可以发射数据的 lambda 函数 @@ -52,4 +52,4 @@ class Collector { std::vector slots_; }; -} // namespace candy \ No 
newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/execution/execution_graph.h b/include/execution/execution_graph.h index d29d070..3e9ace0 100644 --- a/include/execution/execution_graph.h +++ b/include/execution/execution_graph.h @@ -9,7 +9,7 @@ #include "execution/partitioner.h" #include "operator/operator.h" -namespace candy { +namespace sageFlow { struct OperatorInfo { std::shared_ptr op; @@ -64,4 +64,4 @@ class ExecutionGraph { bool is_join_operator = false); }; -} // namespace candy +} // namespace sageFlow diff --git a/include/execution/execution_vertex.h b/include/execution/execution_vertex.h index 132fa2a..1b66ba1 100644 --- a/include/execution/execution_vertex.h +++ b/include/execution/execution_vertex.h @@ -11,7 +11,7 @@ #include #include -namespace candy { +namespace sageFlow { class ExecutionVertex { private: std::shared_ptr operator_; @@ -50,4 +50,4 @@ class ExecutionVertex { private: void run()const; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/execution/input_gate.h b/include/execution/input_gate.h index d47603f..cc67176 100644 --- a/include/execution/input_gate.h +++ b/include/execution/input_gate.h @@ -9,7 +9,7 @@ #include #include -namespace candy { +namespace sageFlow { class InputGate { private: std::vector input_queues_; diff --git a/include/execution/iqueue.h b/include/execution/iqueue.h index d3437ca..84a05a7 100644 --- a/include/execution/iqueue.h +++ b/include/execution/iqueue.h @@ -9,7 +9,7 @@ #include #include "common/data_types.h" -namespace candy { +namespace sageFlow { struct TaggedResponse { Response response; @@ -35,4 +35,4 @@ class IQueue { using QueuePtr = std::shared_ptr; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/execution/partitioner.h b/include/execution/partitioner.h index d409580..2e4eb4f 100644 --- a/include/execution/partitioner.h +++ b/include/execution/partitioner.h @@ 
-9,7 +9,7 @@ #include #include "common/data_types.h" -namespace candy { +namespace sageFlow { class IPartitioner { public: virtual ~IPartitioner() = default; diff --git a/include/execution/result_partition.h b/include/execution/result_partition.h index 676c33a..2e194dc 100644 --- a/include/execution/result_partition.h +++ b/include/execution/result_partition.h @@ -11,7 +11,7 @@ #include #include -namespace candy { +namespace sageFlow { class ResultPartition { private: std::unique_ptr partitioner_; diff --git a/include/execution/ring_buffer_queue.h b/include/execution/ring_buffer_queue.h index b41a39e..0bb2c04 100644 --- a/include/execution/ring_buffer_queue.h +++ b/include/execution/ring_buffer_queue.h @@ -10,7 +10,7 @@ #include "common/data_types.h" #include "execution/iqueue.h" -namespace candy { +namespace sageFlow { // 适用于单生产者单消费者场景,使用环形缓冲区实现 class RingBufferQueue final : public IQueue { public: diff --git a/include/function/aggregate_function.h b/include/function/aggregate_function.h index 91c885c..8e12e65 100644 --- a/include/function/aggregate_function.h +++ b/include/function/aggregate_function.h @@ -5,7 +5,7 @@ #include "common/data_types.h" #include "function/function.h" -namespace candy { +namespace sageFlow { enum class AggregateType { None, Avg }; @@ -19,4 +19,4 @@ class AggregateFunction final : public Function { private: AggregateType aggregate_type_ = AggregateType::None; }; -}; // namespace candy \ No newline at end of file +}; // namespace sageFlow \ No newline at end of file diff --git a/include/function/filter_function.h b/include/function/filter_function.h index c5ec4d1..6baa07f 100644 --- a/include/function/filter_function.h +++ b/include/function/filter_function.h @@ -3,7 +3,7 @@ #include "function/function.h" -namespace candy { +namespace sageFlow { using FilterFunc = std::function &)>; class FilterFunction final : public Function { @@ -19,4 +19,4 @@ class FilterFunction final : public Function { private: FilterFunc filter_func_; }; -}; // 
namespace candy \ No newline at end of file +}; // namespace sageFlow \ No newline at end of file diff --git a/include/function/function.h b/include/function/function.h index 035bab4..a942908 100644 --- a/include/function/function.h +++ b/include/function/function.h @@ -5,7 +5,7 @@ #include "common/data_types.h" -namespace candy { +namespace sageFlow { enum class FunctionType { // NOLINT None, Filter, @@ -41,4 +41,4 @@ class Function { std::string name_; FunctionType type_ = FunctionType::None; }; -}; // namespace candy \ No newline at end of file +}; // namespace sageFlow \ No newline at end of file diff --git a/include/function/itopk_function.h b/include/function/itopk_function.h index 6045bad..542c838 100644 --- a/include/function/itopk_function.h +++ b/include/function/itopk_function.h @@ -7,7 +7,7 @@ #include "common/data_types.h" #include "function/function.h" -namespace candy { +namespace sageFlow { class ITopkFunction final : public Function { public: @@ -24,4 +24,4 @@ class ITopkFunction final : public Function { int dim_ = 0; std::unique_ptr record_; }; -}; // namespace candy +}; // namespace sageFlow diff --git a/include/function/join_function.h b/include/function/join_function.h index 71d7506..cdd05f3 100644 --- a/include/function/join_function.h +++ b/include/function/join_function.h @@ -8,7 +8,7 @@ #include "function/function.h" #include "stream/stream.h" -namespace candy { +namespace sageFlow { using JoinFunc = std::function(std::unique_ptr &, std::unique_ptr &)>; // 线程安全的滑动窗口类 @@ -129,4 +129,4 @@ namespace candy { // TODO : 把Window逻辑扩展 // 现在的 window 是固定长度步长的滑动窗口 }; - }; // namespace candy \ No newline at end of file + }; // namespace sageFlow \ No newline at end of file diff --git a/include/function/map_function.h b/include/function/map_function.h index c505517..91bc895 100644 --- a/include/function/map_function.h +++ b/include/function/map_function.h @@ -3,7 +3,7 @@ #include "function/function.h" -namespace candy { +namespace sageFlow { using 
MapFunc = std::function &)>; class MapFunction final : public Function { @@ -19,4 +19,4 @@ class MapFunction final : public Function { private: MapFunc map_func_; }; -}; // namespace candy \ No newline at end of file +}; // namespace sageFlow \ No newline at end of file diff --git a/include/function/sink_function.h b/include/function/sink_function.h index 40ee873..c06f25a 100644 --- a/include/function/sink_function.h +++ b/include/function/sink_function.h @@ -3,7 +3,7 @@ #include "function/function.h" -namespace candy { +namespace sageFlow { using SinkFunc = std::function &)>; class SinkFunction final : public Function { @@ -19,4 +19,4 @@ class SinkFunction final : public Function { private: SinkFunc sink_func_; }; -}; // namespace candy \ No newline at end of file +}; // namespace sageFlow \ No newline at end of file diff --git a/include/function/topk_function.h b/include/function/topk_function.h index 65bbcd2..f693aef 100644 --- a/include/function/topk_function.h +++ b/include/function/topk_function.h @@ -5,7 +5,7 @@ #include "common/data_types.h" #include "function/function.h" -namespace candy { +namespace sageFlow { class TopkFunction final : public Function { public: @@ -20,4 +20,4 @@ class TopkFunction final : public Function { int k_ = 0; int index_id_ = 0; }; -}; // namespace candy \ No newline at end of file +}; // namespace sageFlow \ No newline at end of file diff --git a/include/function/window_function.h b/include/function/window_function.h index df55e6b..5e488ee 100644 --- a/include/function/window_function.h +++ b/include/function/window_function.h @@ -2,7 +2,7 @@ #include "function/function.h" -namespace candy { +namespace sageFlow { enum class WindowType { Sliding, Tumbling @@ -24,4 +24,4 @@ class WindowFunction final : public Function { int window_size_; int slide_size_; }; -}; // namespace candy \ No newline at end of file +}; // namespace sageFlow \ No newline at end of file diff --git a/include/index/hnsw.h b/include/index/hnsw.h index 
1818838..7d9aa10 100644 --- a/include/index/hnsw.h +++ b/include/index/hnsw.h @@ -3,7 +3,7 @@ #include "index/index.h" -namespace candy { +namespace sageFlow { class HNSW final : public Index { public: // HNSW() : HNSW(20, 100, 40) {} @@ -55,4 +55,4 @@ class HNSW final : public Index { auto random_level() -> int; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/index/index.h b/include/index/index.h index f97d9e6..1d77d72 100644 --- a/include/index/index.h +++ b/include/index/index.h @@ -9,7 +9,7 @@ #include "compute_engine/compute_engine.h" #include "storage/storage_manager.h" -namespace candy { +namespace sageFlow { enum class IndexType { // NOLINT None, HNSW, @@ -44,4 +44,4 @@ class GlobalIndex final : public Index { auto load(const std::string &path) -> bool; auto remove() -> bool; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/index/ivf.h b/include/index/ivf.h index 74cf294..52311ef 100644 --- a/include/index/ivf.h +++ b/include/index/ivf.h @@ -7,7 +7,7 @@ #include #include -namespace candy { +namespace sageFlow { class Ivf final : public Index { public: // Constructor @@ -61,4 +61,4 @@ class Ivf final : public Index { void rebuildIfNeeded(); }; -} // namespace candy +} // namespace sageFlow diff --git a/include/index/knn.h b/include/index/knn.h index 93f09d4..249a207 100644 --- a/include/index/knn.h +++ b/include/index/knn.h @@ -1,6 +1,6 @@ #include "index/index.h" -namespace candy { +namespace sageFlow { class Knn final : public Index { public: ~Knn() override; @@ -10,4 +10,4 @@ class Knn final : public Index { auto query_for_join(const VectorRecord &record, double join_similarity_threshold) -> std::vector override; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/index/vectraflow.h b/include/index/vectraflow.h index b3948b2..0b01d8a 100644 --- a/include/index/vectraflow.h +++ b/include/index/vectraflow.h @@ -1,6 +1,6 @@ #include "index/index.h" -namespace candy { +namespace sageFlow { class 
VectraFlow final : public Index { private: std::vector datas_; @@ -16,4 +16,4 @@ class VectraFlow final : public Index { return {}; } }; -} // namespace candy +} // namespace sageFlow diff --git a/include/operator/aggregate_operator.h b/include/operator/aggregate_operator.h index 29855ef..2f474c5 100644 --- a/include/operator/aggregate_operator.h +++ b/include/operator/aggregate_operator.h @@ -7,7 +7,7 @@ #include "function/function.h" #include "operator/operator.h" -namespace candy { +namespace sageFlow { class AggregateOperator final : public Operator { public: explicit AggregateOperator(std::unique_ptr &aggregate_func); @@ -19,4 +19,4 @@ class AggregateOperator final : public Operator { private: std::unique_ptr aggregate_func_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/filter_operator.h b/include/operator/filter_operator.h index b4ae580..232b52d 100644 --- a/include/operator/filter_operator.h +++ b/include/operator/filter_operator.h @@ -5,7 +5,7 @@ #include "common/data_types.h" #include "operator/operator.h" -namespace candy { +namespace sageFlow { class FilterOperator final : public Operator { public: explicit FilterOperator(std::unique_ptr &filter_func); @@ -17,4 +17,4 @@ class FilterOperator final : public Operator { private: std::unique_ptr filter_func_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/itopk_operator.h b/include/operator/itopk_operator.h index 86473d9..3e5cf0c 100644 --- a/include/operator/itopk_operator.h +++ b/include/operator/itopk_operator.h @@ -7,7 +7,7 @@ #include "concurrency/concurrency_manager.h" #include "operator/operator.h" -namespace candy { +namespace sageFlow { class ITopkOperator final : public Operator { public: explicit ITopkOperator(std::unique_ptr &func, @@ -30,4 +30,4 @@ class ITopkOperator final : public Operator { // 多线程改造:添加状态保护的互斥锁 
mutable std::mutex state_mutex_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/join_metrics.h b/include/operator/join_metrics.h index 2105dd6..7a38e04 100644 --- a/include/operator/join_metrics.h +++ b/include/operator/join_metrics.h @@ -7,7 +7,7 @@ #include #include -namespace candy { +namespace sageFlow { struct JoinMetrics { std::atomic window_insert_ns{0}; std::atomic index_insert_ns{0}; @@ -75,4 +75,4 @@ class ScopedAccumulateAtomic { uint64_t start_ns_; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/operator/join_operator.h b/include/operator/join_operator.h index abdf263..e49b800 100644 --- a/include/operator/join_operator.h +++ b/include/operator/join_operator.h @@ -12,7 +12,7 @@ #include "operator/join_operator_methods/base_method.h" #include "concurrency/concurrency_manager.h" -namespace candy { +namespace sageFlow { class JoinOperator final : public Operator { public: explicit JoinOperator(std::unique_ptr &join_func, @@ -97,4 +97,4 @@ namespace candy { int left_slot_id_ = 0; int right_slot_id_ = 1; }; - } // namespace candy + } // namespace sageFlow diff --git a/include/operator/join_operator_methods/base_method.h b/include/operator/join_operator_methods/base_method.h index 779ac54..12c9f6b 100644 --- a/include/operator/join_operator_methods/base_method.h +++ b/include/operator/join_operator_methods/base_method.h @@ -6,7 +6,7 @@ #include "function/join_function.h" -namespace candy { +namespace sageFlow { enum class JoinMethodType { BRUTEFORCE_EAGER, @@ -48,4 +48,4 @@ class BaseMethod { double join_similarity_threshold_; private: }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/join_operator_methods/bruteforce.h b/include/operator/join_operator_methods/bruteforce.h index 898c691..06060cf 100644 --- a/include/operator/join_operator_methods/bruteforce.h +++ 
b/include/operator/join_operator_methods/bruteforce.h @@ -6,7 +6,7 @@ #include "function/join_function.h" #include "concurrency/concurrency_manager.h" -namespace candy { +namespace sageFlow { class BruteForceJoinMethod final : public BaseMethod { public: BruteForceJoinMethod(int left_index_id, @@ -31,4 +31,4 @@ class BruteForceJoinMethod final : public BaseMethod { int right_index_id_ = -1; std::shared_ptr concurrency_manager_; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/operator/join_operator_methods/eager/bruteforce.h b/include/operator/join_operator_methods/eager/bruteforce.h index 1d1404e..a235f36 100644 --- a/include/operator/join_operator_methods/eager/bruteforce.h +++ b/include/operator/join_operator_methods/eager/bruteforce.h @@ -6,7 +6,7 @@ #include "function/join_function.h" #include "concurrency/concurrency_manager.h" -namespace candy { +namespace sageFlow { class BruteForceEager : public BaseMethod { public: // 更新构造函数,支持ConcurrencyManager @@ -50,4 +50,4 @@ class BruteForceEager : public BaseMethod { std::shared_ptr concurrency_manager_; bool using_knn_ = false; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/join_operator_methods/eager/ivf.h b/include/operator/join_operator_methods/eager/ivf.h index f5afc72..a7a49b1 100644 --- a/include/operator/join_operator_methods/eager/ivf.h +++ b/include/operator/join_operator_methods/eager/ivf.h @@ -8,7 +8,7 @@ #include "common/data_types.h" // Required for VectorRecord #include // Required for std::shared_ptr, std::unique_ptr -namespace candy { +namespace sageFlow { class IvfEager final : public BaseMethod { public: @@ -48,4 +48,4 @@ class IvfEager final : public BaseMethod { std::shared_ptr concurrency_manager_; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/operator/join_operator_methods/ivf.h b/include/operator/join_operator_methods/ivf.h index af7f711..8b2f16e 100644 --- 
a/include/operator/join_operator_methods/ivf.h +++ b/include/operator/join_operator_methods/ivf.h @@ -6,7 +6,7 @@ #include "function/join_function.h" #include "concurrency/concurrency_manager.h" -namespace candy { +namespace sageFlow { class IvfJoinMethod final : public BaseMethod { public: IvfJoinMethod(int left_index_id, @@ -31,4 +31,4 @@ class IvfJoinMethod final : public BaseMethod { int right_index_id_ = -1; std::shared_ptr concurrency_manager_; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/operator/join_operator_methods/lazy/bruteforce.h b/include/operator/join_operator_methods/lazy/bruteforce.h index 69eb885..962de78 100644 --- a/include/operator/join_operator_methods/lazy/bruteforce.h +++ b/include/operator/join_operator_methods/lazy/bruteforce.h @@ -6,7 +6,7 @@ #include "function/join_function.h" #include "concurrency/concurrency_manager.h" -namespace candy { +namespace sageFlow { class BruteForceLazy final : public BaseMethod { public: explicit BruteForceLazy(double join_similarity_threshold) : BaseMethod(join_similarity_threshold) {} @@ -45,4 +45,4 @@ class BruteForceLazy final : public BaseMethod { std::shared_ptr concurrency_manager_; bool using_knn_ = false; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/join_operator_methods/lazy/ivf.h b/include/operator/join_operator_methods/lazy/ivf.h index f1c3eec..1fe62b0 100644 --- a/include/operator/join_operator_methods/lazy/ivf.h +++ b/include/operator/join_operator_methods/lazy/ivf.h @@ -7,7 +7,7 @@ #include "index/ivf.h" #include "concurrency/concurrency_manager.h" -namespace candy { +namespace sageFlow { class IvfLazy final : public BaseMethod { public: @@ -45,4 +45,4 @@ class IvfLazy final : public BaseMethod { std::shared_ptr concurrency_manager_; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/operator/map_operator.h b/include/operator/map_operator.h index 
871e9f7..9dd1b86 100644 --- a/include/operator/map_operator.h +++ b/include/operator/map_operator.h @@ -6,7 +6,7 @@ #include "function/function.h" #include "operator/operator.h" -namespace candy { +namespace sageFlow { class MapOperator final : public Operator { public: explicit MapOperator(std::unique_ptr &map_func); @@ -18,4 +18,4 @@ class MapOperator final : public Operator { private: std::unique_ptr map_func_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/operator.h b/include/operator/operator.h index 46f7aa6..cee0546 100644 --- a/include/operator/operator.h +++ b/include/operator/operator.h @@ -11,7 +11,7 @@ #include "function/function_api.h" #include "execution/collector.h" -namespace candy { +namespace sageFlow { enum class OperatorType { NONE, OUTPUT, @@ -54,4 +54,4 @@ class Operator { std::string name = "Operator"; // 添加name字段用于标识算子 }; -} // namespace candy +} // namespace sageFlow diff --git a/include/operator/output_operator.h b/include/operator/output_operator.h index c5e3ed8..28681d7 100644 --- a/include/operator/output_operator.h +++ b/include/operator/output_operator.h @@ -5,7 +5,7 @@ #include "operator/operator.h" #include "stream/data_stream_source/data_stream_source.h" -namespace candy { +namespace sageFlow { enum class OutputChoice { NONE, Broadcast, Hash }; // NOLINT class OutputOperator final : public Operator { @@ -27,4 +27,4 @@ class OutputOperator final : public Operator { OutputChoice output_choice_ = OutputChoice::NONE; std::shared_ptr stream_ = nullptr; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/sink_operator.h b/include/operator/sink_operator.h index 15a3072..e5f6df1 100644 --- a/include/operator/sink_operator.h +++ b/include/operator/sink_operator.h @@ -7,7 +7,7 @@ #include "function/function.h" #include "operator/operator.h" -namespace candy { +namespace 
sageFlow { class SinkOperator final : public Operator { public: explicit SinkOperator(std::unique_ptr &sink_func); @@ -19,4 +19,4 @@ class SinkOperator final : public Operator { private: std::unique_ptr sink_func_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/topk_operator.h b/include/operator/topk_operator.h index 998aab7..395a387 100644 --- a/include/operator/topk_operator.h +++ b/include/operator/topk_operator.h @@ -7,7 +7,7 @@ #include "function/function.h" #include "operator/operator.h" -namespace candy { +namespace sageFlow { class TopkOperator final : public Operator { public: explicit TopkOperator(std::unique_ptr &topk_func, @@ -21,4 +21,4 @@ class TopkOperator final : public Operator { std::unique_ptr topk_func_; std::shared_ptr concurrency_manager_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/operator/window_operator.h b/include/operator/window_operator.h index c976c6c..f85edc3 100644 --- a/include/operator/window_operator.h +++ b/include/operator/window_operator.h @@ -7,7 +7,7 @@ #include "common/data_types.h" #include "operator/operator.h" -namespace candy { +namespace sageFlow { class WindowOperator : public Operator { public: explicit WindowOperator(std::unique_ptr &window_func); @@ -50,4 +50,4 @@ class SlidingWindowOperator final : public WindowOperator { std::list> window_buffer_; int slide_size_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/query/optimizer/planner.h b/include/query/optimizer/planner.h index f170dd4..00b3a82 100644 --- a/include/query/optimizer/planner.h +++ b/include/query/optimizer/planner.h @@ -8,7 +8,7 @@ #include "operator/operator_api.h" #include "execution/execution_graph.h" -namespace candy { +namespace sageFlow { class Planner { public: @@ -42,4 +42,4 @@ class Planner { 
size_t default_parallelism) const; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/storage/storage_manager.h b/include/storage/storage_manager.h index 787f57f..9fbcc02 100644 --- a/include/storage/storage_manager.h +++ b/include/storage/storage_manager.h @@ -12,7 +12,7 @@ #include "common/data_types.h" #include "compute_engine/compute_engine.h" -namespace candy { +namespace sageFlow { using idx_t = int32_t; class StorageManager { @@ -46,4 +46,4 @@ class StorageManager { mutable std::shared_mutex map_mutex_; int begin_ = 0; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/stream/data_stream_source/data_stream_source.h b/include/stream/data_stream_source/data_stream_source.h index c99be6c..a1661da 100644 --- a/include/stream/data_stream_source/data_stream_source.h +++ b/include/stream/data_stream_source/data_stream_source.h @@ -5,7 +5,7 @@ #include "common/data_types.h" #include "stream/stream.h" -namespace candy { +namespace sageFlow { enum class DataStreamSourceType { // NOLINT None, File, @@ -30,4 +30,4 @@ class DataStreamSource : public Stream { DataStreamSourceType type_ = DataStreamSourceType::None; size_t buffer_size_limit_ = (1<<20); // 1MB }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/stream/data_stream_source/file_stream_source.h b/include/stream/data_stream_source/file_stream_source.h index 9f1e2b4..0dbb083 100644 --- a/include/stream/data_stream_source/file_stream_source.h +++ b/include/stream/data_stream_source/file_stream_source.h @@ -8,7 +8,7 @@ #include "common/data_types.h" #include "stream/data_stream_source/data_stream_source.h" -namespace candy { +namespace sageFlow { class FileStreamSource : public DataStreamSource { public: explicit FileStreamSource(std::string name); @@ -26,4 +26,4 @@ class FileStreamSource : 
public DataStreamSource { uint64_t timeout_ms_{1000}; std::mutex mtx_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/stream/data_stream_source/sift_stream_source.h b/include/stream/data_stream_source/sift_stream_source.h index 4d2f45a..b64e51f 100644 --- a/include/stream/data_stream_source/sift_stream_source.h +++ b/include/stream/data_stream_source/sift_stream_source.h @@ -8,7 +8,7 @@ #include "common/data_types.h" #include "stream/data_stream_source/data_stream_source.h" -namespace candy { +namespace sageFlow { class SiftStreamSource final : public DataStreamSource { public: explicit SiftStreamSource(std::string name); @@ -23,4 +23,4 @@ class SiftStreamSource final : public DataStreamSource { std::string file_path_; std::vector> records_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/stream/data_stream_source/simple_stream_source.h b/include/stream/data_stream_source/simple_stream_source.h index e8e41e1..a4729ad 100644 --- a/include/stream/data_stream_source/simple_stream_source.h +++ b/include/stream/data_stream_source/simple_stream_source.h @@ -8,7 +8,7 @@ #include "common/data_types.h" #include "stream/data_stream_source/data_stream_source.h" -namespace candy { +namespace sageFlow { class SimpleStreamSource final : public DataStreamSource { public: explicit SimpleStreamSource(std::string name); @@ -32,4 +32,4 @@ class SimpleStreamSource final : public DataStreamSource { std::string file_path_; std::vector> records_; }; -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/include/stream/stream.h b/include/stream/stream.h index 9221f57..c7dc8c1 100644 --- a/include/stream/stream.h +++ b/include/stream/stream.h @@ -8,7 +8,7 @@ #include "function/function.h" -namespace candy { +namespace sageFlow { class Function; class FilterFunction; class 
MapFunction; @@ -90,4 +90,4 @@ class Stream { std::string join_method_ = "bruteforce_lazy"; double join_similarity_threshold_ = 0.8; }; -} // namespace candy +} // namespace sageFlow diff --git a/include/stream/stream_environment.h b/include/stream/stream_environment.h index 2b2486d..49282ca 100644 --- a/include/stream/stream_environment.h +++ b/include/stream/stream_environment.h @@ -10,7 +10,7 @@ #include "query/optimizer/planner.h" #include "execution/execution_graph.h" -namespace candy { +namespace sageFlow { class StreamEnvironment { public: // Constructor to initialize the environment @@ -69,6 +69,6 @@ class StreamEnvironment { ConfigMap config_{}; // 扁平 key(支持 log.level) }; -} // namespace candy +} // namespace sageFlow #endif // STREAM_ENVIRONMENT_HPP \ No newline at end of file diff --git a/include/utils/conf_map.h b/include/utils/conf_map.h index ebb150d..617cc4a 100644 --- a/include/utils/conf_map.h +++ b/include/utils/conf_map.h @@ -13,7 +13,7 @@ #include "toml++/toml.hpp" -namespace candy { +namespace sageFlow { enum class ConfigType { STRING, I64, DOUBLE }; using ConfigValue = std::variant; @@ -83,6 +83,6 @@ class ConfigMap { return true; } }; -} // namespace candy +} // namespace sageFlow #endif // CONF_MAP_H diff --git a/include/utils/error_codes.h b/include/utils/error_codes.h index c02b72c..6f2d159 100644 --- a/include/utils/error_codes.h +++ b/include/utils/error_codes.h @@ -3,7 +3,7 @@ #include -namespace candy { +namespace sageFlow { enum class ErrorCode { SUCCESS, @@ -22,6 +22,6 @@ std::string error_to_string(ErrorCode code) { } } -} // namespace candy +} // namespace sageFlow diff --git a/include/utils/log_config.h b/include/utils/log_config.h index 0ddb205..c431b1c 100644 --- a/include/utils/log_config.h +++ b/include/utils/log_config.h @@ -2,7 +2,7 @@ #include #include -namespace candy { +namespace sageFlow { struct LogConfig { spdlog::level::level_enum level = spdlog::level::info; }; @@ -24,6 +24,6 @@ inline spdlog::level::level_enum 
parse_log_level(const std::string &s) { // Apply log level to global logger void apply_log_level(spdlog::level::level_enum lvl); -// Load from env (CANDY_LOG_LEVEL) or passed string; env overrides. +// Load from env (SAGEFLOW_LOG_LEVEL) or passed string; env overrides. void init_log_level(const std::string &level_from_config); } diff --git a/include/utils/logger.h b/include/utils/logger.h index 08f10dc..ea154e4 100644 --- a/include/utils/logger.h +++ b/include/utils/logger.h @@ -10,7 +10,7 @@ // Windows 控制台 ANSI 支持:由用户外部启用,避免在头文件内直接调用 WinAPI 造成编译器解析问题。 -namespace candy { +namespace sageFlow { inline std::atomic g_log_seq{0}; @@ -22,14 +22,14 @@ inline std::shared_ptr get_logger() { sink->set_color_mode(spdlog::color_mode::always); // 使用默认的等级颜色(spdlog 已内置),仅强制启用颜色 - auto lg = std::make_shared("candy", sink); + auto lg = std::make_shared("sageFlow", sink); // Pattern:时间 线程 等级(带色) [PHASE] seq=N msg lg->set_pattern("[%H:%M:%S.%e] [tid=%t] [%^%l%$] %v"); - // 初始日志等级:从环境变量 CANDY_LOG_LEVEL 读取;若未设置则默认为 info。 + // 初始日志等级:从环境变量 SAGEFLOW_LOG_LEVEL 读取;若未设置则默认为 info。 // 注意:不调用 init_log_level()/apply_log_level 以避免静态初始化期间的递归。 spdlog::level::level_enum initial_lvl = spdlog::level::info; - if (const char* env = std::getenv("CANDY_LOG_LEVEL"); env && *env) { - initial_lvl = candy::parse_log_level(env); + if (const char* env = std::getenv("SAGEFLOW_LOG_LEVEL"); env && *env) { + initial_lvl = sageFlow::parse_log_level(env); } lg->set_level(initial_lvl); for (auto &s : lg->sinks()) { @@ -43,15 +43,15 @@ inline std::shared_ptr get_logger() { // 基础宏:带相位着色 + 递增序号 // 统一格式:[PHASE] seq=N message // 新增 DEBUG 级别,便于将高频诊断从 INFO 下沉 -#define CANDY_LOG_DEBUG(phase, fmt, ...) \ - get_logger()->debug("[{}] seq={} " fmt, phase, candy::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) -#define CANDY_LOG_INFO(phase, fmt, ...) 
\ - get_logger()->info("[{}] seq={} " fmt, phase, candy::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) +#define SAGEFLOW_LOG_DEBUG(phase, fmt, ...) \ + get_logger()->debug("[{}] seq={} " fmt, phase, sageFlow::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) +#define SAGEFLOW_LOG_INFO(phase, fmt, ...) \ + get_logger()->info("[{}] seq={} " fmt, phase, sageFlow::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) -#define CANDY_LOG_WARN(phase, fmt, ...) \ - get_logger()->warn("[{}] seq={} " fmt, phase, candy::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) +#define SAGEFLOW_LOG_WARN(phase, fmt, ...) \ + get_logger()->warn("[{}] seq={} " fmt, phase, sageFlow::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) -#define CANDY_LOG_ERROR(phase, fmt, ...) \ - get_logger()->error("[{}] seq={} " fmt, phase, candy::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) +#define SAGEFLOW_LOG_ERROR(phase, fmt, ...) \ + get_logger()->error("[{}] seq={} " fmt, phase, sageFlow::g_log_seq.fetch_add(1, std::memory_order_relaxed), ##__VA_ARGS__) -} // namespace candy +} // namespace sageFlow diff --git a/include/utils/monitoring.h b/include/utils/monitoring.h index a895888..b670ff3 100644 --- a/include/utils/monitoring.h +++ b/include/utils/monitoring.h @@ -4,7 +4,7 @@ #include #include -namespace candy { +namespace sageFlow { class PerformanceMonitor { public: @@ -29,4 +29,4 @@ class PerformanceMonitor { bool profiling_; }; -} // namespace candy +} // namespace sageFlow diff --git a/python/__init__.py b/python/__init__.py deleted file mode 100644 index 94a9469..0000000 --- a/python/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Python bindings and wrappers for SAGE-Flow live here. - -This package is intended to house all Python-side modules for the component. 
-""" diff --git a/python/bindings.cpp b/python/bindings.cpp deleted file mode 100644 index 9ca7866..0000000 --- a/python/bindings.cpp +++ /dev/null @@ -1,104 +0,0 @@ -#include -#include -#include -#include - -// C++ headers from candy (SAGE-Flow) -#include "common/data_types.h" -#include "function/sink_function.h" -#include "stream/stream.h" -#include "stream/stream_environment.h" -#include "stream/data_stream_source/simple_stream_source.h" - -namespace py = pybind11; -using namespace candy; // NOLINT - -PYBIND11_MODULE(_sage_flow, m) { - m.doc() = "SAGE Flow - Stream processing engine"; - - // Enums - py::enum_(m, "DataType") - .value("None", DataType::None) - .value("Int8", DataType::Int8) - .value("Int16", DataType::Int16) - .value("Int32", DataType::Int32) - .value("Int64", DataType::Int64) - .value("Float32", DataType::Float32) - .value("Float64", DataType::Float64); - - // VectorData - py::class_(m, "VectorData") - .def(py::init()) - .def(py::init([](int32_t dim, DataType type, py::array_t arr) { - auto buf = arr.request(); - if (buf.ndim != 1 || buf.shape[0] != dim) { - throw std::runtime_error("Array shape mismatch"); - } - auto bytes = static_cast(dim) * sizeof(float); - auto *data = new char[bytes]; - std::memcpy(data, buf.ptr, bytes); - return VectorData(dim, type, data); - })) - .def(py::init([](py::array_t arr) { - auto buf = arr.request(); - if (buf.ndim != 1) { - throw std::runtime_error("Array must be 1D"); - } - int32_t dim = static_cast(buf.shape[0]); - auto bytes = static_cast(dim) * sizeof(float); - auto *data = new char[bytes]; - std::memcpy(data, buf.ptr, bytes); - return VectorData(dim, DataType::Float32, data); - })); - - // VectorRecord - py::class_(m, "VectorRecord") - .def(py::init()) - .def_readonly("uid", &VectorRecord::uid_) - .def_readonly("timestamp", &VectorRecord::timestamp_) - .def_readonly("data", &VectorRecord::data_); - - // Stream - py::class_>(m, "Stream") - .def(py::init()) - // Minimal API: only bind a Python-friendly sink 
writer used by examples - .def("write_sink_py", [](Stream &self, const std::string &name, py::function cb) { - auto fn = SinkFunction(name, [cb](std::unique_ptr &rec) { - py::gil_scoped_acquire gil; - cb(rec->uid_, rec->timestamp_); - }); - auto fn_ptr = std::make_unique(std::move(fn)); - return self.writeSink(std::move(fn_ptr)); - }, py::arg("name"), py::arg("callback")); - - // SimpleStreamSource - py::class_, Stream>(m, "SimpleStreamSource") - .def(py::init()) - .def("addRecord", py::overload_cast(&SimpleStreamSource::addRecord)) - .def("addRecord", [](SimpleStreamSource &self, uint64_t uid, int64_t ts, py::array_t arr) { - auto buf = arr.request(); - if (buf.ndim != 1) { - throw std::runtime_error("Array must be 1D"); - } - int32_t dim = static_cast(buf.shape[0]); - auto bytes = static_cast(dim) * sizeof(float); - auto *data = new char[bytes]; - std::memcpy(data, buf.ptr, bytes); - VectorData vec(dim, DataType::Float32, data); - self.addRecord(uid, ts, std::move(vec)); - }) - .def("write_sink_py", [](SimpleStreamSource &self, const std::string &name, py::function cb) { - auto fn = SinkFunction(name, [cb](std::unique_ptr &rec) { - py::gil_scoped_acquire gil; - cb(rec->uid_, rec->timestamp_); - }); - auto fn_ptr = std::make_unique(std::move(fn)); - return self.writeSink(std::move(fn_ptr)); - }, py::arg("name"), py::arg("callback")); - - // StreamEnvironment - py::class_(m, "StreamEnvironment") - .def(py::init<>()) - .def("addStream", &StreamEnvironment::addStream) - .def("execute", &StreamEnvironment::execute); -} \ No newline at end of file diff --git a/python/micro_service/__init__.py b/python/micro_service/__init__.py deleted file mode 100644 index 33b70cc..0000000 --- a/python/micro_service/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Expose SageFlowService from submodule -from .sage_flow_service import SageFlowService - -__all__ = ["SageFlowService"] diff --git a/python/micro_service/sage_flow_service.py b/python/micro_service/sage_flow_service.py deleted 
file mode 100644 index 7998cd1..0000000 --- a/python/micro_service/sage_flow_service.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import queue -import threading -import time -from dataclasses import dataclass -from typing import Optional - -import numpy as np -from sage.middleware.components.sage_flow.python.sage_flow import ( - SimpleStreamSource, - StreamEnvironment, -) - - -@dataclass -class _Record: - uid: int - vec: np.ndarray - - -class SageFlowService: - """ - A minimal micro-service wrapper for SAGE-Flow used by examples. - - - push(uid, vec): enqueue vector for processing - - run(): drain queue, feed to flow, and execute once - """ - - def __init__(self, dim: int = 4, dtype: str = "Float32") -> None: - self.dim = dim - self.dtype = dtype - self._q: "queue.Queue[_Record]" = queue.Queue() - self._env = StreamEnvironment() - self._source = SimpleStreamSource("sage_flow_service_source") - self._lock = threading.Lock() - self._added_to_env = False - # Note: don't add to env yet; defer until a sink is attached - - # API expected by examples - def push(self, uid: int, vec: np.ndarray) -> None: - if not isinstance(vec, np.ndarray): - vec = np.asarray(vec, dtype=np.float32) - vec = vec.astype(np.float32, copy=False) - if vec.ndim != 1 or vec.shape[0] != self.dim: - raise ValueError(f"vector shape must be ({self.dim},)") - self._q.put(_Record(uid=int(uid), vec=vec)) - - def run(self) -> None: - # Drain queue into source, then execute once - drained = 0 - with self._lock: - while True: - try: - rec = self._q.get_nowait() - except queue.Empty: - break - ts = int(time.time() * 1000) - self._source.addRecord(rec.uid, ts, rec.vec) - drained += 1 - if drained: - # If user hasn't attached sinks, add source to env once so execution proceeds - if not self._added_to_env: - # Attach a default printing sink for visibility - self._source.write_sink_py( - "default_print_sink", - lambda uid, ts: print(f"[svc sink] uid={uid}, ts={ts}", flush=True), - ) - 
self._env.addStream(self._source) - self._added_to_env = True - self._env.execute() - - def set_sink(self, callback, name: str = "py_sink") -> None: - """Attach a Python sink callback for visible outputs. - - Args: - callback: Callable taking (uid: int, ts: int) - name: Sink name, defaults to 'py_sink'. - """ - self._source.write_sink_py(name, callback) - if not self._added_to_env: - self._env.addStream(self._source) - self._added_to_env = True - - # Optional: expose environment for advanced integrations - @property - def env(self) -> StreamEnvironment: - return self._env diff --git a/python/sage_flow.py b/python/sage_flow.py deleted file mode 100644 index 34cca05..0000000 --- a/python/sage_flow.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -SAGE Flow - High-performance vector stream processing engine (Python side) - -All Python-facing APIs for SAGE-Flow live under this module. -""" - -from typing import Any, Callable, Dict, Optional - -import numpy as np - -try: - from . import _sage_flow -except ImportError as e: - import glob - import importlib - import sys - from pathlib import Path - - here = Path(__file__).resolve().parent - candidate_paths = [ - here, # same directory as this file (editable install case) - here / "build" / "lib", # standard local build - here.parent / "build" / "lib", # component-level build - here.parent / "build", # build directory - here.parent / "install", # install directory - ] - - # Add paths to sys.path - for p in candidate_paths: - if p.exists() and str(p) not in sys.path: - sys.path.insert(0, str(p)) - - # Try to find the .so file directly - found_so = False - for p in candidate_paths: - if p.exists(): - # Look for _sage_flow.*.so files - so_files = list(p.glob("_sage_flow*.so")) - if so_files: - found_so = True - # Add this directory to sys.path if not already there - if str(p) not in sys.path: - sys.path.insert(0, str(p)) - break - - try: - _sage_flow = importlib.import_module("_sage_flow") - except Exception: - raise ImportError( - 
f"_sage_flow native module not found. Please build the extension by running 'sage extensions install sage_flow' or executing the build.sh under packages/sage-middleware/src/sage/middleware/components/sage_flow. " - f"Searched in: {[str(p) for p in candidate_paths if p.exists()]}, Found .so files: {found_so}" - ) from e - -DataType = _sage_flow.DataType -VectorData = _sage_flow.VectorData -VectorRecord = _sage_flow.VectorRecord -Stream = _sage_flow.Stream -StreamEnvironment = _sage_flow.StreamEnvironment -SimpleStreamSource = _sage_flow.SimpleStreamSource - - -class SageFlow: - def __init__(self, config: Optional[Dict[str, Any]] = None): - self.env = StreamEnvironment() - self.streams = [] - self.config = config or {} - - def create_stream(self, name: str): - return Stream(name) - - def create_simple_source(self, name: str): - return SimpleStreamSource(name) - - def add_vector_record(self, source, uid: int, timestamp: int, vector): - if isinstance(vector, np.ndarray): - vector = vector.astype(np.float32, copy=False) - else: - vector = np.asarray(vector, dtype=np.float32) - source.addRecord(uid, timestamp, vector) - - def add_stream(self, stream): - self.streams.append(stream) - self.env.addStream(stream) - - def execute(self): - self.env.execute() - - def get_stream_snapshot(self) -> Dict[str, Any]: - return { - "streams_count": len(self.streams), - "config": self.config, - "status": "active", - } - - -def create_stream_engine(config: Optional[Dict[str, Any]] = None) -> SageFlow: - return SageFlow(config) - - -def create_vector_stream(name: str): - return Stream(name) - - -def create_simple_data_source(name: str): - return SimpleStreamSource(name) - - -__all__ = [ - "SageFlow", - "create_stream_engine", - "create_vector_stream", - "create_simple_data_source", - "DataType", - "VectorData", - "VectorRecord", - "Stream", - "StreamEnvironment", - "SimpleStreamSource", -] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 33deac6..4e02ca9 100644 --- 
a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,23 +31,45 @@ add_subdirectory( add_subdirectory( stream ) -add_library( - candy - INTERFACE + +# 收集所有源文件生成动态库(像 sageDB 一样) +file(GLOB_RECURSE SAGEFLOW_SOURCES + utils/*.cpp + index/*.cpp + common/*.cpp + compute_engine/*.cpp + concurrency/*.cpp + execution/*.cpp + function/*.cpp + operator/*.cpp + query/*.cpp + storage/*.cpp + stream/*.cpp ) + +# 排除测试和示例文件 +list(FILTER SAGEFLOW_SOURCES EXCLUDE REGEX ".*test.*") +list(FILTER SAGEFLOW_SOURCES EXCLUDE REGEX ".*example.*") + +# 创建动态库 +add_library(sageflow SHARED ${SAGEFLOW_SOURCES}) + target_link_libraries( - candy - INTERFACE + sageflow + PUBLIC externalRuntimeLibs - execution - stream - query - operator - function - common - utils - concurrency - storage - index - compute_engine +) +if(ENABLE_GPERFTOOLS AND DEFINED SAGE_GPERFTOOLS_LIBS AND SAGE_GPERFTOOLS_LIBS) + target_link_libraries( + sageflow + PRIVATE + ${SAGE_GPERFTOOLS_LIBS} + ) +endif() + +# 安装动态库 +install(TARGETS sageflow + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin ) \ No newline at end of file diff --git a/src/common/data_types.cpp b/src/common/data_types.cpp index 1a7c034..d018511 100644 --- a/src/common/data_types.cpp +++ b/src/common/data_types.cpp @@ -6,7 +6,7 @@ #include #include -candy::VectorData::VectorData(const int32_t dim, const DataType type, char* data) +sageFlow::VectorData::VectorData(const int32_t dim, const DataType type, char* data) : dim_(dim), type_(type) { // Deep copy the incoming buffer to avoid taking ownership of external memory const auto bytes = static_cast(dim_) * DATA_TYPE_SIZE[type_]; @@ -19,24 +19,24 @@ candy::VectorData::VectorData(const int32_t dim, const DataType type, char* data } } -candy::VectorData::VectorData(const int32_t dim, const DataType type) +sageFlow::VectorData::VectorData(const int32_t dim, const DataType type) : dim_(dim), type_(type), data_(new char[dim * DATA_TYPE_SIZE[type]]) {} -candy::VectorData::VectorData(const 
VectorData& other) : dim_(other.dim_), type_(other.type_) { +sageFlow::VectorData::VectorData(const VectorData& other) : dim_(other.dim_), type_(other.type_) { data_ = std::make_unique(dim_ * DATA_TYPE_SIZE[type_]); // Allocate memory for data memcpy(data_.get(), other.data_.get(), dim_ * DATA_TYPE_SIZE[type_]); // Copy data } -auto candy::VectorData::operator==(const VectorData& other) const -> bool { +auto sageFlow::VectorData::operator==(const VectorData& other) const -> bool { if (dim_ != other.dim_) { return false; // Check dimension equality } return memcmp(data_.get(), other.data_.get(), dim_ * DATA_TYPE_SIZE[type_]) == 0; // Compare data } -auto candy::VectorData::operator!=(const VectorData& other) const -> bool { return !(*this == other); } +auto sageFlow::VectorData::operator!=(const VectorData& other) const -> bool { return !(*this == other); } -bool candy::VectorData::Serialize(std::ostream &out) const { +bool sageFlow::VectorData::Serialize(std::ostream &out) const { out.write(reinterpret_cast(&dim_), sizeof(dim_)); int typeInt = static_cast(type_); out.write(reinterpret_cast(&typeInt), sizeof(typeInt)); @@ -44,7 +44,7 @@ bool candy::VectorData::Serialize(std::ostream &out) const { return !out.fail(); } -bool candy::VectorData::Deserialize(std::istream &in) { +bool sageFlow::VectorData::Deserialize(std::istream &in) { in.read(reinterpret_cast(&dim_), sizeof(dim_)); int typeInt; in.read(reinterpret_cast(&typeInt), sizeof(typeInt)); @@ -54,7 +54,7 @@ bool candy::VectorData::Deserialize(std::istream &in) { return !in.fail(); } -void candy::VectorData::printData(std::ostream &os) const { +void sageFlow::VectorData::printData(std::ostream &os) const { if (!data_) { os << "Data is null." 
<< std::endl; return; @@ -108,31 +108,31 @@ void candy::VectorData::printData(std::ostream &os) const { os << "]" << std::endl; } -void candy::VectorRecord::printRecord(std::ostream &os) const { +void sageFlow::VectorRecord::printRecord(std::ostream &os) const { os << "VectorRecord (uid: " << uid_ << ", timestamp: " << timestamp_ << ")" << std::endl; data_.printData(os); } -candy::VectorRecord::VectorRecord(const uint64_t& uid, const int64_t& timestamp, VectorData&& data) +sageFlow::VectorRecord::VectorRecord(const uint64_t& uid, const int64_t& timestamp, VectorData&& data) : uid_(uid), timestamp_(timestamp), data_(std::move(data)) {} -candy::VectorRecord::VectorRecord(const uint64_t& uid, const int64_t& timestamp, const VectorData& data) +sageFlow::VectorRecord::VectorRecord(const uint64_t& uid, const int64_t& timestamp, const VectorData& data) : uid_(uid), timestamp_(timestamp), data_(data) {} -candy::VectorRecord::VectorRecord(const uint64_t& uid, const int64_t& timestamp, int32_t dim, DataType type, char* data) +sageFlow::VectorRecord::VectorRecord(const uint64_t& uid, const int64_t& timestamp, int32_t dim, DataType type, char* data) : uid_(uid), timestamp_(timestamp), data_(dim, type, data) {} -auto candy::VectorRecord::operator==(const VectorRecord& other) const -> bool { +auto sageFlow::VectorRecord::operator==(const VectorRecord& other) const -> bool { return uid_ == other.uid_ && timestamp_ == other.timestamp_ && data_ == other.data_; } -bool candy::VectorRecord::Serialize(std::ostream &out) const { +bool sageFlow::VectorRecord::Serialize(std::ostream &out) const { out.write(reinterpret_cast(&uid_), sizeof(uid_)); out.write(reinterpret_cast(×tamp_), sizeof(timestamp_)); return data_.Serialize(out); } -bool candy::VectorRecord::Deserialize(std::istream &in) { +bool sageFlow::VectorRecord::Deserialize(std::istream &in) { uint64_t uid; int64_t ts; in.read(reinterpret_cast(&uid), sizeof(uid)); diff --git a/src/compute_engine/compute_engine.cpp 
b/src/compute_engine/compute_engine.cpp index f4c4751..138cbf6 100644 --- a/src/compute_engine/compute_engine.cpp +++ b/src/compute_engine/compute_engine.cpp @@ -1,7 +1,7 @@ #include "compute_engine/compute_engine.h" #include -auto candy::ComputeEngine::Similarity(const VectorData& vec1, const VectorData& vec2, const double alpha) -> double { +auto sageFlow::ComputeEngine::Similarity(const VectorData& vec1, const VectorData& vec2, const double alpha) -> double { auto distance = EuclideanDistance(vec1, vec2); // Exponential Decay function to convert distance to similarity return std::exp(-alpha * distance); @@ -9,7 +9,7 @@ auto candy::ComputeEngine::Similarity(const VectorData& vec1, const VectorData& // 私有模板辅助函数 template -auto candy::ComputeEngine::EuclideanDistanceImpl(const VectorData& vec1, const VectorData& vec2) -> double { +auto sageFlow::ComputeEngine::EuclideanDistanceImpl(const VectorData& vec1, const VectorData& vec2) -> double { // 确保 T 是算术类型 static_assert(std::is_arithmetic::value, "Template parameter T must be an arithmetic type."); @@ -26,7 +26,7 @@ auto candy::ComputeEngine::EuclideanDistanceImpl(const VectorData& vec1, const V return std::sqrt(distance_sq); } -auto candy::ComputeEngine::EuclideanDistance(const VectorData& vec1, const VectorData& vec2) -> double { +auto sageFlow::ComputeEngine::EuclideanDistance(const VectorData& vec1, const VectorData& vec2) -> double { if (vec1.dim_ != vec2.dim_) { throw std::invalid_argument("Vectors must be of the same size"); } @@ -60,9 +60,9 @@ auto candy::ComputeEngine::EuclideanDistance(const VectorData& vec1, const Vecto return distance; } -auto candy::ComputeEngine::normalizeVector(const VectorData& vec) -> VectorData { return vec; } +auto sageFlow::ComputeEngine::normalizeVector(const VectorData& vec) -> VectorData { return vec; } -auto candy::ComputeEngine::getVectorSquareLength(const VectorData& vec) -> double { +auto sageFlow::ComputeEngine::getVectorSquareLength(const VectorData& vec) -> double { if 
(vec.dim_ == 0) { throw std::invalid_argument("Vector dimension cannot be zero"); } @@ -78,7 +78,7 @@ auto candy::ComputeEngine::getVectorSquareLength(const VectorData& vec) -> doubl return 0.0; } -auto candy::ComputeEngine::dotmultiply(const VectorData& vec1, const VectorData& vec2) -> double { +auto sageFlow::ComputeEngine::dotmultiply(const VectorData& vec1, const VectorData& vec2) -> double { if (vec1.dim_ != vec2.dim_) { throw std::invalid_argument("Vectors must be of the same size"); } @@ -95,4 +95,4 @@ auto candy::ComputeEngine::dotmultiply(const VectorData& vec1, const VectorData& return 0.0; } -candy::ComputeEngine::ComputeEngine() = default; \ No newline at end of file +sageFlow::ComputeEngine::ComputeEngine() = default; \ No newline at end of file diff --git a/src/concurrency/blank_controller.cpp b/src/concurrency/blank_controller.cpp index 7e1ebf8..b5f3005 100644 --- a/src/concurrency/blank_controller.cpp +++ b/src/concurrency/blank_controller.cpp @@ -3,9 +3,9 @@ // #include "concurrency/blank_controller.h" -candy::BlankController::BlankController() = default; +sageFlow::BlankController::BlankController() = default; -candy::BlankController::BlankController(std::shared_ptr index) { +sageFlow::BlankController::BlankController(std::shared_ptr index) { index_ = std::move(index); storage_manager_ = index_->storage_manager_; if (index_->index_type_ == IndexType::None) { @@ -13,9 +13,9 @@ candy::BlankController::BlankController(std::shared_ptr index) { } } -candy::BlankController::~BlankController() = default; +sageFlow::BlankController::~BlankController() = default; -auto candy::BlankController::insert(std::unique_ptr record) -> bool { +auto sageFlow::BlankController::insert(std::unique_ptr record) -> bool { if (!record) { return false; } @@ -25,22 +25,22 @@ auto candy::BlankController::insert(std::unique_ptr record) -> boo return index_->insert(uid);; } -auto candy::BlankController::erase(std::unique_ptr record) -> bool { return true; } +auto 
sageFlow::BlankController::erase(std::unique_ptr record) -> bool { return true; } -auto candy::BlankController::erase(const uint64_t uid) -> bool { +auto sageFlow::BlankController::erase(const uint64_t uid) -> bool { if (index_) { index_->erase(uid); } return storage_manager_->erase(uid); } -auto candy::BlankController::query(const VectorRecord& record, int k) +auto sageFlow::BlankController::query(const VectorRecord& record, int k) -> std::vector> { const auto uids = index_->query(record, k); return storage_manager_->getVectorsByUids(uids); } -auto candy::BlankController::query_for_join(const VectorRecord& record, +auto sageFlow::BlankController::query_for_join(const VectorRecord& record, double join_similarity_threshold) -> std::vector> { const auto uids = index_->query_for_join(record, join_similarity_threshold); return storage_manager_->getVectorsByUids(uids); diff --git a/src/concurrency/concurrency_manager.cpp b/src/concurrency/concurrency_manager.cpp index 06701a6..2343138 100644 --- a/src/concurrency/concurrency_manager.cpp +++ b/src/concurrency/concurrency_manager.cpp @@ -9,11 +9,11 @@ #include "index/knn.h" #include "index/vectraflow.h" -candy::ConcurrencyManager::ConcurrencyManager(std::shared_ptr storage) : storage_(std::move(storage)) {} +sageFlow::ConcurrencyManager::ConcurrencyManager(std::shared_ptr storage) : storage_(std::move(storage)) {} -candy::ConcurrencyManager::~ConcurrencyManager() = default; +sageFlow::ConcurrencyManager::~ConcurrencyManager() = default; -auto candy::ConcurrencyManager::create_index(const std::string& name, const IndexType& index_type, int dimension) +auto sageFlow::ConcurrencyManager::create_index(const std::string& name, const IndexType& index_type, int dimension) -> int { std::shared_ptr index = nullptr; switch (index_type) { @@ -47,13 +47,13 @@ auto candy::ConcurrencyManager::create_index(const std::string& name, const Inde return index->index_id_; } -auto candy::ConcurrencyManager::create_index(const std::string& 
name, int dimension) -> int { +auto sageFlow::ConcurrencyManager::create_index(const std::string& name, int dimension) -> int { return create_index(name, IndexType::BruteForce, dimension); } -auto candy::ConcurrencyManager::drop_index(const std::string& name) -> bool { return false; } +auto sageFlow::ConcurrencyManager::drop_index(const std::string& name) -> bool { return false; } -auto candy::ConcurrencyManager::insert(int index_id, std::unique_ptr record) -> bool { +auto sageFlow::ConcurrencyManager::insert(int index_id, std::unique_ptr record) -> bool { const auto it = controller_map_.find(index_id); if (it == controller_map_.end()) { return false; @@ -62,7 +62,7 @@ auto candy::ConcurrencyManager::insert(int index_id, std::unique_ptrinsert(std::move(record)); } -auto candy::ConcurrencyManager::erase(int index_id, std::unique_ptr record) -> bool { +auto sageFlow::ConcurrencyManager::erase(int index_id, std::unique_ptr record) -> bool { const auto it = controller_map_.find(index_id); if (it == controller_map_.end()) { return false; @@ -71,7 +71,7 @@ auto candy::ConcurrencyManager::erase(int index_id, std::unique_ptrerase(std::move(record)); } -auto candy::ConcurrencyManager::erase(int index_id, uint64_t uid) -> bool { +auto sageFlow::ConcurrencyManager::erase(int index_id, uint64_t uid) -> bool { const auto it = controller_map_.find(index_id); if (it == controller_map_.end()) { return false; @@ -80,7 +80,7 @@ auto candy::ConcurrencyManager::erase(int index_id, uint64_t uid) -> bool { return controller->erase(uid); } -auto candy::ConcurrencyManager::query(int index_id, const VectorRecord& record, int k) +auto sageFlow::ConcurrencyManager::query(int index_id, const VectorRecord& record, int k) -> std::vector> { const auto it = controller_map_.find(index_id); if (it == controller_map_.end()) { @@ -90,7 +90,7 @@ auto candy::ConcurrencyManager::query(int index_id, const VectorRecord& record, return controller->query(record, k); } -auto 
candy::ConcurrencyManager::query_for_join(int index_id, const VectorRecord& record, +auto sageFlow::ConcurrencyManager::query_for_join(int index_id, const VectorRecord& record, double join_similarity_threshold) -> std::vector> { const auto it = controller_map_.find(index_id); if (it == controller_map_.end()) { diff --git a/src/execution/blocking_queue.cpp b/src/execution/blocking_queue.cpp index 22a86c9..12e1f44 100644 --- a/src/execution/blocking_queue.cpp +++ b/src/execution/blocking_queue.cpp @@ -4,7 +4,7 @@ #include "execution/blocking_queue.h" -namespace candy { +namespace sageFlow { bool BlockingQueue::push(TaggedResponse&& value) { std::unique_lock lock(mutex_); diff --git a/src/execution/execution_graph.cpp b/src/execution/execution_graph.cpp index 8740bce..750e9fd 100644 --- a/src/execution/execution_graph.cpp +++ b/src/execution/execution_graph.cpp @@ -7,7 +7,7 @@ #include #include "utils/logger.h" -namespace candy { +namespace sageFlow { ExecutionGraph::~ExecutionGraph() { stop(); @@ -183,7 +183,7 @@ void ExecutionGraph::createConnections() { } void ExecutionGraph::start() { - CANDY_LOG_INFO("GRAPH", "Starting ExecutionGraph operators={} ", operators_.size()); + SAGEFLOW_LOG_INFO("GRAPH", "Starting ExecutionGraph operators={} ", operators_.size()); // 启动所有ExecutionVertex for (const auto& [op, info] : operator_infos_) { @@ -192,11 +192,11 @@ void ExecutionGraph::start() { } } - CANDY_LOG_INFO("GRAPH", "All ExecutionVertices started"); + SAGEFLOW_LOG_INFO("GRAPH", "All ExecutionVertices started"); } void ExecutionGraph::stop() { - CANDY_LOG_INFO("GRAPH", "Stopping ExecutionGraph..."); + SAGEFLOW_LOG_INFO("GRAPH", "Stopping ExecutionGraph..."); // 先尝试按拓扑顺序:优先停止 Source(OutputOperator) 以停止生产; // 再停止非 Source 以允许其排干剩余数据(ExecutionVertex 内部已有 drain 逻辑)。 std::vector> sources; @@ -223,7 +223,7 @@ void ExecutionGraph::stop() { for (auto &q : all_queues_) { if (q) q->stop(); } - CANDY_LOG_INFO("GRAPH", "All ExecutionVertices stopped"); + SAGEFLOW_LOG_INFO("GRAPH", 
"All ExecutionVertices stopped"); } void ExecutionGraph::join() { @@ -234,7 +234,7 @@ void ExecutionGraph::join() { } } - CANDY_LOG_INFO("GRAPH", "All ExecutionVertices finished"); + SAGEFLOW_LOG_INFO("GRAPH", "All ExecutionVertices finished"); } -} // namespace candy +} // namespace sageFlow diff --git a/src/execution/execution_vertex.cpp b/src/execution/execution_vertex.cpp index 4c8e003..0c1a908 100644 --- a/src/execution/execution_vertex.cpp +++ b/src/execution/execution_vertex.cpp @@ -9,7 +9,7 @@ #include #include "utils/logger.h" -namespace candy { +namespace sageFlow { ExecutionVertex::ExecutionVertex(const std::shared_ptr &op, const size_t index) : operator_(op), subtask_index_(index) { @@ -44,7 +44,7 @@ void ExecutionVertex::join() const { } void ExecutionVertex::run() const { - CANDY_LOG_DEBUG("VERTEX", "{} started thread={} ", name_, (size_t)std::hash{}(std::this_thread::get_id())); + SAGEFLOW_LOG_DEBUG("VERTEX", "{} started thread={} ", name_, (size_t)std::hash{}(std::this_thread::get_id())); auto source_op = dynamic_cast(operator_.get()); try { @@ -80,7 +80,7 @@ void ExecutionVertex::run() const { } catch (const std::exception& e) { int dim = (data.record_ ? data.record_->data_.dim_ : -1); uint64_t uid = (data.record_ ? data.record_->uid_ : 0); - CANDY_LOG_ERROR("APPLY", "operator={} slot={} dim={} uid={} what={} ", operator_->name, data_opt->slot, dim, uid, e.what()); + SAGEFLOW_LOG_ERROR("APPLY", "operator={} slot={} dim={} uid={} what={} ", operator_->name, data_opt->slot, dim, uid, e.what()); throw; } } @@ -95,19 +95,19 @@ void ExecutionVertex::run() const { } catch (const std::exception& e) { int dim = (data.record_ ? data.record_->data_.dim_ : -1); uint64_t uid = (data.record_ ? 
data.record_->uid_ : 0); - CANDY_LOG_ERROR("DRAIN", "operator={} slot={} dim={} uid={} what={} ", operator_->name, data_opt->slot, dim, uid, e.what()); + SAGEFLOW_LOG_ERROR("DRAIN", "operator={} slot={} dim={} uid={} what={} ", operator_->name, data_opt->slot, dim, uid, e.what()); break; // 排干阶段出现异常不再继续,防止无限重试 } } } } catch (const std::exception& e) { - CANDY_LOG_ERROR("VERTEX", "Exception name={} what={} ", name_, e.what()); + SAGEFLOW_LOG_ERROR("VERTEX", "Exception name={} what={} ", name_, e.what()); } // 关闭算子 operator_->close(); - CANDY_LOG_INFO("VERTEX", "{} finished", name_); + SAGEFLOW_LOG_INFO("VERTEX", "{} finished", name_); } } diff --git a/src/execution/input_gate.cpp b/src/execution/input_gate.cpp index 640f095..aa90b37 100644 --- a/src/execution/input_gate.cpp +++ b/src/execution/input_gate.cpp @@ -4,7 +4,7 @@ #include "execution/input_gate.h" -namespace candy { +namespace sageFlow { void InputGate::setup(const std::vector& queues) { input_queues_ = queues; diff --git a/src/execution/result_partition.cpp b/src/execution/result_partition.cpp index 04e4714..e25938c 100644 --- a/src/execution/result_partition.cpp +++ b/src/execution/result_partition.cpp @@ -4,7 +4,7 @@ #include "execution/result_partition.h" -namespace candy { +namespace sageFlow { void ResultPartition::setup(std::unique_ptr p, std::vector channels, int slot) { partitioner_ = std::move(p); channel_slot_map_[slot] = std::move(channels); diff --git a/src/execution/ring_buffer_queue.cpp b/src/execution/ring_buffer_queue.cpp index f235bdc..39073a3 100644 --- a/src/execution/ring_buffer_queue.cpp +++ b/src/execution/ring_buffer_queue.cpp @@ -4,7 +4,7 @@ #include "execution/ring_buffer_queue.h" -namespace candy { +namespace sageFlow { bool RingBufferQueue::push(TaggedResponse&& value) { const auto current_tail = tail_.load(std::memory_order_relaxed); const auto next_tail = (current_tail + 1) % size_; diff --git a/src/function/aggregate_function.cpp b/src/function/aggregate_function.cpp index 
72265e6..0c426b7 100644 --- a/src/function/aggregate_function.cpp +++ b/src/function/aggregate_function.cpp @@ -3,9 +3,9 @@ // #include "function/aggregate_function.h" -candy::AggregateFunction::AggregateFunction(const std::string& name) : Function(name, FunctionType::None) {} +sageFlow::AggregateFunction::AggregateFunction(const std::string& name) : Function(name, FunctionType::None) {} -candy::AggregateFunction::AggregateFunction(const std::string& name, AggregateType aggregate_type) +sageFlow::AggregateFunction::AggregateFunction(const std::string& name, AggregateType aggregate_type) : Function(name, FunctionType::Aggregate), aggregate_type_(aggregate_type) {} -auto candy::AggregateFunction::getAggregateType() const -> AggregateType { return aggregate_type_; } \ No newline at end of file +auto sageFlow::AggregateFunction::getAggregateType() const -> AggregateType { return aggregate_type_; } \ No newline at end of file diff --git a/src/function/filter_function.cpp b/src/function/filter_function.cpp index 3df3a5c..a461990 100644 --- a/src/function/filter_function.cpp +++ b/src/function/filter_function.cpp @@ -1,11 +1,11 @@ #include "function/filter_function.h" -candy::FilterFunction::FilterFunction(std::string name) : Function(std::move(name), FunctionType::Filter) {} +sageFlow::FilterFunction::FilterFunction(std::string name) : Function(std::move(name), FunctionType::Filter) {} -candy::FilterFunction::FilterFunction(std::string name, FilterFunc filter_func) +sageFlow::FilterFunction::FilterFunction(std::string name, FilterFunc filter_func) : Function(std::move(name), FunctionType::Filter), filter_func_(std::move(filter_func)) {} -candy::Response candy::FilterFunction::Execute(Response &resp) { +sageFlow::Response sageFlow::FilterFunction::Execute(Response &resp) { if (resp.type_ == ResponseType::Record) { if (auto record = std::move(resp.record_); filter_func_(record)) { return Response{ResponseType::Record, std::move(record)}; @@ -23,4 +23,4 @@ candy::Response 
candy::FilterFunction::Execute(Response &resp) { return {}; } -auto candy::FilterFunction::setFilterFunc(FilterFunc filter_func) -> void { filter_func_ = std::move(filter_func); } \ No newline at end of file +auto sageFlow::FilterFunction::setFilterFunc(FilterFunc filter_func) -> void { filter_func_ = std::move(filter_func); } \ No newline at end of file diff --git a/src/function/function.cpp b/src/function/function.cpp index 642c10d..1b5dc19 100644 --- a/src/function/function.cpp +++ b/src/function/function.cpp @@ -1,19 +1,19 @@ #include "function/function.h" -candy::Function::Function(std::string name, FunctionType type) : name_(std::move(name)), type_(type) {} +sageFlow::Function::Function(std::string name, FunctionType type) : name_(std::move(name)), type_(type) {} -candy::Function::~Function() = default; +sageFlow::Function::~Function() = default; -auto candy::Function::getName() const -> std::string { return name_; } +auto sageFlow::Function::getName() const -> std::string { return name_; } -auto candy::Function::getType() const -> FunctionType { return type_; } +auto sageFlow::Function::getType() const -> FunctionType { return type_; } -void candy::Function::setName(const std::string& name) { name_ = name; } +void sageFlow::Function::setName(const std::string& name) { name_ = name; } -void candy::Function::setType(const FunctionType type) { type_ = type; } +void sageFlow::Function::setType(const FunctionType type) { type_ = type; } -auto candy::Function::Execute(Response& resp) -> Response { return {}; } +auto sageFlow::Function::Execute(Response& resp) -> Response { return {}; } -auto candy::Function::Execute(Response& left, Response& right) -> Response { +auto sageFlow::Function::Execute(Response& left, Response& right) -> Response { return {}; } diff --git a/src/function/itopk_function.cpp b/src/function/itopk_function.cpp index d1ab927..232098a 100644 --- a/src/function/itopk_function.cpp +++ b/src/function/itopk_function.cpp @@ -3,13 +3,13 @@ // 
#include "function/itopk_function.h" -candy::ITopkFunction::ITopkFunction(const std::string& name) : Function(name, FunctionType::Topk) {} +sageFlow::ITopkFunction::ITopkFunction(const std::string& name) : Function(name, FunctionType::Topk) {} -candy::ITopkFunction::ITopkFunction(const std::string& name, int k, int dim, std::unique_ptr record) +sageFlow::ITopkFunction::ITopkFunction(const std::string& name, int k, int dim, std::unique_ptr record) : Function(name, FunctionType::ITopk), k_(k), dim_(dim), record_(std::move(record)) {} -auto candy::ITopkFunction::getK() const -> int { return k_; } +auto sageFlow::ITopkFunction::getK() const -> int { return k_; } -auto candy::ITopkFunction::getDim() const -> int { return dim_; } +auto sageFlow::ITopkFunction::getDim() const -> int { return dim_; } -auto candy::ITopkFunction::getRecord()-> std::unique_ptr { return std::move(record_); } \ No newline at end of file +auto sageFlow::ITopkFunction::getRecord()-> std::unique_ptr { return std::move(record_); } \ No newline at end of file diff --git a/src/function/join_function.cpp b/src/function/join_function.cpp index a0aa9ba..e636b3b 100644 --- a/src/function/join_function.cpp +++ b/src/function/join_function.cpp @@ -1,20 +1,20 @@ #include "function/join_function.h" #include "utils/logger.h" -candy::JoinFunction::JoinFunction(std::string name, int dim) : Function(std::move(name), FunctionType::Join), dim_(dim) {} +sageFlow::JoinFunction::JoinFunction(std::string name, int dim) : Function(std::move(name), FunctionType::Join), dim_(dim) {} -candy::JoinFunction::JoinFunction(std::string name, JoinFunc join_func, int dim) : +sageFlow::JoinFunction::JoinFunction(std::string name, JoinFunc join_func, int dim) : Function(std::move(name), FunctionType::Join), join_func_(std::move(join_func)), dim_(dim) {} // TODO : 确定这个滑动窗口的步长 // 目前是 window / 4 -candy::JoinFunction::JoinFunction(std::string name, JoinFunc join_func, int64_t time_window, int dim) 
+sageFlow::JoinFunction::JoinFunction(std::string name, JoinFunc join_func, int64_t time_window, int dim) : Function(std::move(name), FunctionType::Join), windowL (time_window, time_window / 4), windowR(time_window, time_window / 4), threadSafeWindowL(time_window, time_window / 4), threadSafeWindowR(time_window, time_window / 4), join_func_(std::move(join_func)), dim_(dim) {} -auto candy::JoinFunction::Execute(Response& left, Response& right) -> Response { +auto sageFlow::JoinFunction::Execute(Response& left, Response& right) -> Response { if (left.type_ == ResponseType::Record && right.type_ == ResponseType::Record) { auto left_record = std::move(left.record_); auto right_record = std::move(right.record_); @@ -24,7 +24,7 @@ auto candy::JoinFunction::Execute(Response& left, Response& right) -> Response { auto out = join_func_(left_record, right_record); if (out) return Response{ResponseType::Record, std::move(out)}; } catch (const std::exception& e) { - CANDY_LOG_ERROR("JOIN_FUNC", "left_dim={} right_dim={} left_uid={} right_uid={} what={} ", + SAGEFLOW_LOG_ERROR("JOIN_FUNC", "left_dim={} right_dim={} left_uid={} right_uid={} what={} ", (left_record ? left_record->data_.dim_ : -1), (right_record ? right_record->data_.dim_ : -1), (left_record ? 
left_record->uid_ : 0), @@ -36,17 +36,17 @@ auto candy::JoinFunction::Execute(Response& left, Response& right) -> Response { return {}; } -auto candy::JoinFunction::getDim() const -> int { return dim_; } +auto sageFlow::JoinFunction::getDim() const -> int { return dim_; } -auto candy::JoinFunction::setJoinFunc(JoinFunc join_func) -> void { join_func_ = std::move(join_func); } +auto sageFlow::JoinFunction::setJoinFunc(JoinFunc join_func) -> void { join_func_ = std::move(join_func); } -auto candy::JoinFunction::getOtherStream() -> std::shared_ptr& { return other_stream_; } +auto sageFlow::JoinFunction::getOtherStream() -> std::shared_ptr& { return other_stream_; } -auto candy::JoinFunction::setOtherStream(std::shared_ptr other_plan) -> void { +auto sageFlow::JoinFunction::setOtherStream(std::shared_ptr other_plan) -> void { other_stream_ = std::move(other_plan); } -auto candy::JoinFunction::setWindow(int64_t windowsize, int64_t stepsize) -> void { +auto sageFlow::JoinFunction::setWindow(int64_t windowsize, int64_t stepsize) -> void { windowL.setWindow(windowsize, stepsize); windowR.setWindow(windowsize, stepsize); threadSafeWindowL.setWindow(windowsize, stepsize); diff --git a/src/function/map_function.cpp b/src/function/map_function.cpp index 379bff5..95b3cd2 100644 --- a/src/function/map_function.cpp +++ b/src/function/map_function.cpp @@ -1,11 +1,11 @@ #include "function/map_function.h" -candy::MapFunction::MapFunction(std::string name) : Function(std::move(name), FunctionType::Map) {} +sageFlow::MapFunction::MapFunction(std::string name) : Function(std::move(name), FunctionType::Map) {} -candy::MapFunction::MapFunction(std::string name, MapFunc map_func) +sageFlow::MapFunction::MapFunction(std::string name, MapFunc map_func) : Function(std::move(name), FunctionType::Map), map_func_(std::move(map_func)) {} -candy::Response candy::MapFunction::Execute(Response &resp) { +sageFlow::Response sageFlow::MapFunction::Execute(Response &resp) { if (resp.type_ == 
ResponseType::Record) { auto record = std::move(resp.record_); map_func_(record); @@ -22,4 +22,4 @@ candy::Response candy::MapFunction::Execute(Response &resp) { return {}; } -auto candy::MapFunction::setMapFunc(MapFunc map_func) -> void { map_func_ = std::move(map_func); } \ No newline at end of file +auto sageFlow::MapFunction::setMapFunc(MapFunc map_func) -> void { map_func_ = std::move(map_func); } \ No newline at end of file diff --git a/src/function/sink_function.cpp b/src/function/sink_function.cpp index 0304b5a..3a7a192 100644 --- a/src/function/sink_function.cpp +++ b/src/function/sink_function.cpp @@ -1,11 +1,11 @@ #include "function/sink_function.h" -candy::SinkFunction::SinkFunction(std::string name) : Function(std::move(name), FunctionType::Sink) {} +sageFlow::SinkFunction::SinkFunction(std::string name) : Function(std::move(name), FunctionType::Sink) {} -candy::SinkFunction::SinkFunction(std::string name, SinkFunc sink_func) +sageFlow::SinkFunction::SinkFunction(std::string name, SinkFunc sink_func) : Function(std::move(name), FunctionType::Sink), sink_func_(std::move(sink_func)) {} -candy::Response candy::SinkFunction::Execute(Response &resp) { +sageFlow::Response sageFlow::SinkFunction::Execute(Response &resp) { if (resp.type_ == ResponseType::Record) { auto record = std::move(resp.record_); sink_func_(record); @@ -21,4 +21,4 @@ candy::Response candy::SinkFunction::Execute(Response &resp) { return {}; } -auto candy::SinkFunction::setSinkFunc(SinkFunc sink_func) -> void { sink_func_ = std::move(sink_func); } \ No newline at end of file +auto sageFlow::SinkFunction::setSinkFunc(SinkFunc sink_func) -> void { sink_func_ = std::move(sink_func); } \ No newline at end of file diff --git a/src/function/topk_function.cpp b/src/function/topk_function.cpp index b2d983c..57777de 100644 --- a/src/function/topk_function.cpp +++ b/src/function/topk_function.cpp @@ -2,11 +2,11 @@ #include -candy::TopkFunction::TopkFunction(const std::string& name) : Function(name, 
FunctionType::Topk) {} +sageFlow::TopkFunction::TopkFunction(const std::string& name) : Function(name, FunctionType::Topk) {} -candy::TopkFunction::TopkFunction(const std::string& name, int k, int index_id) +sageFlow::TopkFunction::TopkFunction(const std::string& name, int k, int index_id) : Function(name, FunctionType::Topk), k_(k), index_id_(index_id) {} -auto candy::TopkFunction::getK() const -> int { return k_; } +auto sageFlow::TopkFunction::getK() const -> int { return k_; } -auto candy::TopkFunction::getIndexId() const -> int { return index_id_; } +auto sageFlow::TopkFunction::getIndexId() const -> int { return index_id_; } diff --git a/src/function/window_function.cpp b/src/function/window_function.cpp index 353478d..841ae92 100644 --- a/src/function/window_function.cpp +++ b/src/function/window_function.cpp @@ -3,23 +3,23 @@ // #include "function/window_function.h" -candy::WindowFunction::WindowFunction(std::string name) +sageFlow::WindowFunction::WindowFunction(std::string name) : Function(std::move(name), FunctionType::Window), window_type_(WindowType::Tumbling), window_size_(0), slide_size_(0) {} -candy::WindowFunction::WindowFunction(std::string name, const int window_size, const int slide_size, +sageFlow::WindowFunction::WindowFunction(std::string name, const int window_size, const int slide_size, const WindowType window_type) : Function(std::move(name), FunctionType::Window), window_type_(window_type), window_size_(window_size), slide_size_(slide_size) {} -auto candy::WindowFunction::Execute(Response& resp) -> candy::Response { return Function::Execute(resp); } +auto sageFlow::WindowFunction::Execute(Response& resp) -> sageFlow::Response { return Function::Execute(resp); } -auto candy::WindowFunction::getWindowType() const -> WindowType { return window_type_; } +auto sageFlow::WindowFunction::getWindowType() const -> WindowType { return window_type_; } -auto candy::WindowFunction::getWindowSize() const -> int { return window_size_; } +auto 
sageFlow::WindowFunction::getWindowSize() const -> int { return window_size_; } -auto candy::WindowFunction::getSlideSize() const -> int { return slide_size_; } \ No newline at end of file +auto sageFlow::WindowFunction::getSlideSize() const -> int { return slide_size_; } \ No newline at end of file diff --git a/src/index/hnsw.cpp b/src/index/hnsw.cpp index f0ebbf4..6a80f08 100644 --- a/src/index/hnsw.cpp +++ b/src/index/hnsw.cpp @@ -19,7 +19,7 @@ -namespace candy { +namespace sageFlow { // --- Constructor --- // Based on the provided hnsw.h @@ -494,4 +494,4 @@ auto HNSW::query(const VectorRecord &record, int k) -> std::vector { } -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/src/index/ivf.cpp b/src/index/ivf.cpp index 6c074c5..a1741c9 100644 --- a/src/index/ivf.cpp +++ b/src/index/ivf.cpp @@ -13,13 +13,13 @@ #include "utils/logger.h" -namespace candy { +namespace sageFlow { void Ivf::debugDumpStateUnlocked() { // 调用方需已持有 global_mutex_ 锁 size_t total_in_lists = 0; for (auto &kv : inverted_lists_) total_in_lists += kv.second.size(); - CANDY_LOG_WARN("INDEX", "DEBUG_DUMP size_={} total_in_lists={} deleted_uids={} nlists={} attempts={} success={} missing={} miss_in_storage={} miss_not_in_storage={} underflow={} ", + SAGEFLOW_LOG_WARN("INDEX", "DEBUG_DUMP size_={} total_in_lists={} deleted_uids={} nlists={} attempts={} success={} missing={} miss_in_storage={} miss_not_in_storage={} underflow={} ", size_.load(), total_in_lists, deleted_uids_.size(), inverted_lists_.size(), erase_attempts_.load(), erase_success_.load(), erase_missing_.load(), erase_missing_in_storage_.load(), erase_missing_not_in_storage_.load(), erase_underflow_.load()); @@ -30,7 +30,7 @@ void Ivf::debugDumpStateUnlocked() { std::string sample; size_t limit = std::min(kv.second.size(), 5); for (size_t i = 0; i < limit; ++i) { sample += std::to_string(kv.second[i]); sample.push_back(','); } - CANDY_LOG_DEBUG("INDEX", "list_id={} 
size={} sample=[{}]", kv.first, kv.second.size(), sample); + SAGEFLOW_LOG_DEBUG("INDEX", "list_id={} size={} sample=[{}]", kv.first, kv.second.size(), sample); if (++printed >= 5) break; } } @@ -45,7 +45,7 @@ Ivf::Ivf(int nlist, double rebuild_threshold, int nprobes) try { inverted_lists_.reserve(nlist); } catch (const std::exception& e) { - CANDY_LOG_ERROR("INDEX", "Ivf ctor reserve inverted_lists nlist={} error={} ", nlist, e.what()); + SAGEFLOW_LOG_ERROR("INDEX", "Ivf ctor reserve inverted_lists nlist={} error={} ", nlist, e.what()); throw; // 继续抛出维持原有语义 } } @@ -128,13 +128,13 @@ void Ivf::rebuildClustersInternal() { try { int logical_size = size_.load(std::memory_order_relaxed); if (logical_size < 0) { - CANDY_LOG_ERROR("INDEX", "size_ negative={} forcing to 0 before reserve (possible erase of non-existent id)", logical_size); + SAGEFLOW_LOG_ERROR("INDEX", "size_ negative={} forcing to 0 before reserve (possible erase of non-existent id)", logical_size); logical_size = 0; // 防止转换为 size_t 后变成巨大值 } size_t target = static_cast(logical_size) + deleted_uids_.size(); all_uids_in_index.reserve(target); } catch (const std::exception& e) { - CANDY_LOG_ERROR("INDEX", "reserve all_uids_in_index target_size={} error={} ", + SAGEFLOW_LOG_ERROR("INDEX", "reserve all_uids_in_index target_size={} error={} ", size_.load(std::memory_order_relaxed) + deleted_uids_.size(), e.what()); throw; } @@ -150,7 +150,7 @@ void Ivf::rebuildClustersInternal() { try { live_records.reserve(all_uids_in_index.size()); } catch (const std::exception& e) { - CANDY_LOG_ERROR("INDEX", "reserve live_records target_size={} error={} ", all_uids_in_index.size(), e.what()); + SAGEFLOW_LOG_ERROR("INDEX", "reserve live_records target_size={} error={} ", all_uids_in_index.size(), e.what()); throw; } @@ -186,7 +186,7 @@ void Ivf::rebuildClustersInternal() { try { centroids_.reserve(actual_clusters); } catch (const std::exception& e) { - CANDY_LOG_ERROR("INDEX", "reserve centroids actual_clusters={} error={} ", 
actual_clusters, e.what()); + SAGEFLOW_LOG_ERROR("INDEX", "reserve centroids actual_clusters={} error={} ", actual_clusters, e.what()); throw; } @@ -254,7 +254,7 @@ void Ivf::rebuildClustersInternal() { try { new_centroids.reserve(actual_clusters); } catch (const std::exception& e) { - CANDY_LOG_ERROR("INDEX", "reserve new_centroids actual_clusters={} error={} ", actual_clusters, e.what()); + SAGEFLOW_LOG_ERROR("INDEX", "reserve new_centroids actual_clusters={} error={} ", actual_clusters, e.what()); throw; } for(int i = 0; i < actual_clusters; ++i) { @@ -308,18 +308,18 @@ void Ivf::rebuildClustersInternal() { } int logical_size = size_.load(std::memory_order_relaxed); if (logical_size < 0 || static_cast(logical_size) != actual_total) { - CANDY_LOG_WARN("INDEX", "post-rebuild size mismatch logical={} actual={} deleted_uids={} vectors_since_last_rebuild={} ", + SAGEFLOW_LOG_WARN("INDEX", "post-rebuild size mismatch logical={} actual={} deleted_uids={} vectors_since_last_rebuild={} ", logical_size, actual_total, deleted_uids_.size(), vectors_since_last_rebuild_.load()); } } auto Ivf::insert(uint64_t id) -> bool { rebuildIfNeeded(); - CANDY_LOG_DEBUG("INDEX", "start inserting id={} size_before={} ", id, size_.load(std::memory_order_relaxed)); + SAGEFLOW_LOG_DEBUG("INDEX", "start inserting id={} size_before={} ", id, size_.load(std::memory_order_relaxed)); // 先从存储取出记录 auto record = storage_manager_ ? 
storage_manager_->getVectorByUid(id) : nullptr; if (!record) { - CANDY_LOG_ERROR("INDEX", "insert id={} failed to fetch from storage ", id); + SAGEFLOW_LOG_ERROR("INDEX", "insert id={} failed to fetch from storage ", id); return false; } @@ -339,7 +339,7 @@ auto Ivf::insert(uint64_t id) -> bool { inverted_lists_[0].push_back(id); size_.fetch_add(1, std::memory_order_relaxed); vectors_since_last_rebuild_.fetch_add(1, std::memory_order_relaxed); - CANDY_LOG_DEBUG("INDEX", "initialized centroids with first id={} dim={} lists=1 size_now={} ", + SAGEFLOW_LOG_DEBUG("INDEX", "initialized centroids with first id={} dim={} lists=1 size_now={} ", id, record->data_.dim_, size_.load(std::memory_order_relaxed)); return true; } @@ -353,7 +353,7 @@ auto Ivf::insert(uint64_t id) -> bool { std::shared_lock rlock(global_mutex_); rebuild_cv_.wait(rlock, [this]{ return !is_rebuilding_.load(); }); cluster_idx = assignToCluster(record->data_); - CANDY_LOG_DEBUG("INDEX", "inserting id={} assigned_cluster={} size_before={} ", id, cluster_idx, size_.load(std::memory_order_relaxed)); + SAGEFLOW_LOG_DEBUG("INDEX", "inserting id={} assigned_cluster={} size_before={} ", id, cluster_idx, size_.load(std::memory_order_relaxed)); } if (cluster_idx < 0) { @@ -378,7 +378,7 @@ auto Ivf::erase(uint64_t id) -> bool { if (deleted_uids_.find(id) != deleted_uids_.end()) { return true; } - CANDY_LOG_DEBUG("INDEX", "erasing id={} size_before={} ", id, size_.load(std::memory_order_relaxed)); + SAGEFLOW_LOG_DEBUG("INDEX", "erasing id={} size_before={} ", id, size_.load(std::memory_order_relaxed)); // 无条件记录删除 deleted_uids_.insert(id); @@ -452,7 +452,7 @@ auto Ivf::query(const VectorRecord &record, int k) -> std::vector { try { final_ids.reserve(top_k_results.size()); } catch (const std::exception& e) { - CANDY_LOG_ERROR("INDEX", "reserve final_ids size_hint={} error={} ", top_k_results.size(), e.what()); + SAGEFLOW_LOG_ERROR("INDEX", "reserve final_ids size_hint={} error={} ", top_k_results.size(), e.what()); 
throw; } while (!top_k_results.empty()) { @@ -510,4 +510,4 @@ auto Ivf::query_for_join(const VectorRecord &record, double join_similarity_thre return results; } -} // namespace candy +} // namespace sageFlow diff --git a/src/index/knn.cpp b/src/index/knn.cpp index d3a13da..fe2f74b 100644 --- a/src/index/knn.cpp +++ b/src/index/knn.cpp @@ -4,18 +4,18 @@ #include "index/knn.h" #include -candy::Knn::~Knn() = default; +sageFlow::Knn::~Knn() = default; -auto candy::Knn::insert(uint64_t id) -> bool { return true; } +auto sageFlow::Knn::insert(uint64_t id) -> bool { return true; } -auto candy::Knn::erase(uint64_t id) -> bool { return true; } +auto sageFlow::Knn::erase(uint64_t id) -> bool { return true; } -auto candy::Knn::query(const VectorRecord &record, int k) -> std::vector { +auto sageFlow::Knn::query(const VectorRecord &record, int k) -> std::vector { auto idxes = storage_manager_->topk(record, k); return idxes; } -auto candy::Knn::query_for_join(const VectorRecord &record, +auto sageFlow::Knn::query_for_join(const VectorRecord &record, double join_similarity_threshold) -> std::vector { return storage_manager_->similarityJoinQuery(record, join_similarity_threshold); } \ No newline at end of file diff --git a/src/index/vectraflow.cpp b/src/index/vectraflow.cpp index 75c078c..5beefe3 100644 --- a/src/index/vectraflow.cpp +++ b/src/index/vectraflow.cpp @@ -2,19 +2,19 @@ #include #include -candy::VectraFlow::~VectraFlow() = default; +sageFlow::VectraFlow::~VectraFlow() = default; -auto candy::VectraFlow::insert(uint64_t id) -> bool { +auto sageFlow::VectraFlow::insert(uint64_t id) -> bool { datas_.push_back(id); return true; } /// VectraFlow 目前不支持删除 -auto candy::VectraFlow::erase(uint64_t id) -> bool { return true; } +auto sageFlow::VectraFlow::erase(uint64_t id) -> bool { return true; } // 并行没搞明白 先不鸟它了 -auto candy::VectraFlow::query(const VectorRecord &record, int k) -> std::vector { +auto sageFlow::VectraFlow::query(const VectorRecord &record, int k) -> std::vector { 
const auto rec = &record; std :: priority_queue> pq; diff --git a/src/operator/CMakeLists.txt b/src/operator/CMakeLists.txt index 6ebf231..6d584e7 100644 --- a/src/operator/CMakeLists.txt +++ b/src/operator/CMakeLists.txt @@ -21,10 +21,10 @@ target_link_libraries( ) # 打开 join_operator 的数据打桩开关,启用指标采集代码块(可通过顶层选项控制) -if(CANDY_ENABLE_METRICS) +if(SAGEFLOW_ENABLE_METRICS) target_compile_definitions( operator PRIVATE - CANDY_ENABLE_METRICS=1 + SAGEFLOW_ENABLE_METRICS=1 ) endif() \ No newline at end of file diff --git a/src/operator/aggregate_operator.cpp b/src/operator/aggregate_operator.cpp index d3179eb..3599a2f 100644 --- a/src/operator/aggregate_operator.cpp +++ b/src/operator/aggregate_operator.cpp @@ -6,13 +6,13 @@ #include #include "function/aggregate_function.h" -candy::AggregateOperator::AggregateOperator(std::unique_ptr& aggregate_func) +sageFlow::AggregateOperator::AggregateOperator(std::unique_ptr& aggregate_func) : Operator(OperatorType::AGGREGATE), aggregate_func_(std::move(aggregate_func)) {} -auto Sum(std::unique_ptr& record, std::unique_ptr& record2) -> void { +auto Sum(std::unique_ptr& record, std::unique_ptr& record2) -> void { const auto& data = record->data_; const auto& data2 = record2->data_; - if (data.type_ == candy::DataType::Float32) { + if (data.type_ == sageFlow::DataType::Float32) { const auto d1 = reinterpret_cast(data.data_.get()); const auto d2 = reinterpret_cast(data2.data_.get()); for (int i = 0; i < data.dim_; ++i) { @@ -21,8 +21,8 @@ auto Sum(std::unique_ptr& record, std::unique_ptr& record, int size) { - if (const auto& data = record->data_; data.type_ == candy::DataType::Float32) { +void Avg(const std::unique_ptr& record, int size) { + if (const auto& data = record->data_; data.type_ == sageFlow::DataType::Float32) { const auto d1 = reinterpret_cast(data.data_.get()); for (int i = 0; i < data.dim_; ++i) { d1[i] /= size; @@ -30,7 +30,7 @@ void Avg(const std::unique_ptr& record, int size) { } } -auto 
candy::AggregateOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::AggregateOperator::process(Response&data, int slot) -> std::optional { // TODO: 多线程改造 - 聚合算子的并发状态管理 // 在多线程环境中,需要考虑以下改造: // 1. 使用线程安全的累加器或状态管理 @@ -58,7 +58,7 @@ auto candy::AggregateOperator::process(Response&data, int slot) -> std::optional return std::nullopt; } -auto candy::AggregateOperator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::AggregateOperator::apply(Response&& record, int slot, Collector& collector) -> void { const auto aggregate_func = dynamic_cast(aggregate_func_.get()); if (record.type_ == ResponseType::List && record.records_) { const auto records = record.records_.get(); diff --git a/src/operator/filter_operator.cpp b/src/operator/filter_operator.cpp index e5e7941..ca7e32d 100644 --- a/src/operator/filter_operator.cpp +++ b/src/operator/filter_operator.cpp @@ -1,9 +1,9 @@ #include "operator/filter_operator.h" -candy::FilterOperator::FilterOperator(std::unique_ptr& filter_func) +sageFlow::FilterOperator::FilterOperator(std::unique_ptr& filter_func) : Operator(OperatorType::FILTER), filter_func_(std::move(filter_func)) {} -auto candy::FilterOperator::process(Response& data, int slot) -> std::optional { +auto sageFlow::FilterOperator::process(Response& data, int slot) -> std::optional { auto resp = filter_func_->Execute(data); if (resp.type_ != ResponseType::None) { return resp; @@ -11,7 +11,7 @@ auto candy::FilterOperator::process(Response& data, int slot) -> std::optional void { +auto sageFlow::FilterOperator::apply(Response&& record, int slot, Collector& collector) -> void { // 使用filter函数处理数据 auto resp = filter_func_->Execute(record); if (resp.type_ != ResponseType::None) { diff --git a/src/operator/itopk_operator.cpp b/src/operator/itopk_operator.cpp index 53ab4e8..c435bb9 100644 --- a/src/operator/itopk_operator.cpp +++ b/src/operator/itopk_operator.cpp @@ -8,7 +8,7 @@ #include "function/itopk_function.h" 
-candy::ITopkOperator::ITopkOperator(std::unique_ptr& func, +sageFlow::ITopkOperator::ITopkOperator(std::unique_ptr& func, const std::shared_ptr& concurrency_manager) : Operator(OperatorType::ITOPK), itopk_func_(std::move(func)), concurrency_manager_(concurrency_manager) { auto itopk_func = dynamic_cast(itopk_func_.get()); @@ -19,7 +19,7 @@ candy::ITopkOperator::ITopkOperator(std::unique_ptr& func, record_ = itopk_func->getRecord(); } -auto candy::ITopkOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::ITopkOperator::process(Response&data, int slot) -> std::optional { // TODO: 多线程改造 - ITopK算子的并发状态管理 // 在多线程环境中,需要考虑以下改造: // 1. uid集合(uids_)的并发访问保护,需要使用线程安全的容器或加锁 @@ -68,12 +68,12 @@ auto candy::ITopkOperator::process(Response&data, int slot) -> std::optional std::unique_ptr { +auto sageFlow::ITopkOperator::getRecord() const -> std::unique_ptr { std::lock_guard lock(state_mutex_); return std::make_unique(*record_); } -auto candy::ITopkOperator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::ITopkOperator::apply(Response&& record, int slot, Collector& collector) -> void { if (record.type_ == ResponseType::Record) { return; // ITopKOperator通常处理List类型的数据 } diff --git a/src/operator/join_operator.cpp b/src/operator/join_operator.cpp index a05c2bb..8d07424 100644 --- a/src/operator/join_operator.cpp +++ b/src/operator/join_operator.cpp @@ -14,7 +14,7 @@ #include "spdlog/fmt/bundled/chrono.h" -namespace candy { +namespace sageFlow { // 旧接口保留(如果未来需要 IVF 特有参数,可扩展重写) void JoinOperator::initializeIVFIndexes(int /*nlist*/, double /*rebuild_threshold*/, int /*nprobes*/) { @@ -112,11 +112,11 @@ auto JoinOperator::updateSideThreadSafe( if (use_index_ && concurrency_manager_ && index_id_for_cc != -1) { data_for_index_insert = std::make_unique(*data_ptr); } -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_lock = ScopedAccumulateAtomic::now_ns(); #endif std::unique_lock lock(records_mutex); 
-#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS { uint64_t waited = ScopedAccumulateAtomic::now_ns() - before_lock; // 锁等待:单独统计 + 计入窗口阶段(使 compute 覆盖锁等待) @@ -127,14 +127,14 @@ auto JoinOperator::updateSideThreadSafe( #endif // 窗口插入阶段(仅插入) { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS ScopedTimerAtomic t_window_ins(JoinMetrics::instance().window_insert_ns); #endif records.emplace_back(std::move(data_ptr)); } if (use_index_ && concurrency_manager_ && data_for_index_insert && index_id_for_cc != -1) { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS { ScopedTimerAtomic t_idx(JoinMetrics::instance().index_insert_ns); // lock.unlock(); @@ -144,7 +144,7 @@ auto JoinOperator::updateSideThreadSafe( #else // 解锁可能会导致竞态,索引内部的插入和删除顺序可能和窗口不一致 // lock.unlock(); - CANDY_LOG_DEBUG("JOIN", "Inserting to index id={} uid={} ", index_id_for_cc, data_for_index_insert->uid_); + SAGEFLOW_LOG_DEBUG("JOIN", "Inserting to index id={} uid={} ", index_id_for_cc, data_for_index_insert->uid_); concurrency_manager_->insert(index_id_for_cc, std::move(data_for_index_insert)); // lock.lock(); #endif @@ -154,19 +154,19 @@ auto JoinOperator::updateSideThreadSafe( int64_t timelimit = window.windowTimeLimit(now_time_stamp); // 窗口过期阶段(包含过期判定与容器维护;索引删除单独计时) - CANDY_LOG_DEBUG("JOIN", "Expiring records before timestamp {} now={} current_size={} ", timelimit, now_time_stamp, records.size()); + SAGEFLOW_LOG_DEBUG("JOIN", "Expiring records before timestamp {} now={} current_size={} ", timelimit, now_time_stamp, records.size()); try { // 过期阶段的容器维护开销:将每次 pop_front 计入 window_insert_ns,索引删除计入 index_insert_ns。 while (!records.empty() && records.front()->timestamp_ <= timelimit) { uint64_t expired_uid = records.front()->uid_; { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS ScopedTimerAtomic t_window_expire_unit(JoinMetrics::instance().window_insert_ns); #endif records.pop_front(); } if (use_index_ && concurrency_manager_ && index_id_for_cc != -1) { -#ifdef 
CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS ScopedTimerAtomic t_idx_del(JoinMetrics::instance().index_insert_ns); #endif // lock.unlock(); @@ -174,22 +174,22 @@ auto JoinOperator::updateSideThreadSafe( // lock.lock(); } } - CANDY_LOG_DEBUG("JOIN", "Expiration loop finished. current_size={} ", records.size()); + SAGEFLOW_LOG_DEBUG("JOIN", "Expiration loop finished. current_size={} ", records.size()); } catch (const std::exception& e) { - CANDY_LOG_ERROR("JOIN", "Exception during expiration: what={} ", e.what()); + SAGEFLOW_LOG_ERROR("JOIN", "Exception during expiration: what={} ", e.what()); } - CANDY_LOG_DEBUG("JOIN", "Before unlocking records mutex. size={} ", records.size()); + SAGEFLOW_LOG_DEBUG("JOIN", "Before unlocking records mutex. size={} ", records.size()); lock.unlock(); - CANDY_LOG_DEBUG("JOIN", "After unlocking records mutex; computing trigger."); + SAGEFLOW_LOG_DEBUG("JOIN", "After unlocking records mutex; computing trigger."); bool needTrigger = false; try { needTrigger = window.isNeedTrigger(now_time_stamp); } catch (const std::exception& e) { - CANDY_LOG_ERROR("JOIN", "Exception during isNeedTrigger: what={} ", e.what()); + SAGEFLOW_LOG_ERROR("JOIN", "Exception during isNeedTrigger: what={} ", e.what()); throw; } - CANDY_LOG_DEBUG("JOIN", "isNeedTrigger={} ", needTrigger ? 1 : 0); + SAGEFLOW_LOG_DEBUG("JOIN", "isNeedTrigger={} ", needTrigger ? 
1 : 0); return needTrigger; } @@ -228,7 +228,7 @@ auto JoinOperator::process(Response& input_data, int slot) -> std::optional> JoinOperator::getCandidates( const std::unique_ptr& data_ptr, int slot) { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS ScopedTimerAtomic t_fetch(JoinMetrics::instance().candidate_fetch_ns); #endif if (is_eager_) { @@ -237,11 +237,11 @@ std::vector> JoinOperator::getCandidates( std::deque> query_records_copy; // 改为 deque if (slot == left_slot_id_) { // 加锁等待计入 lock_wait -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_wait = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock lk(left_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait, std::memory_order_relaxed); #endif for (auto &p : left_records_) @@ -250,11 +250,11 @@ std::vector> JoinOperator::getCandidates( } } else { // 加锁等待计入 lock_wait -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_wait = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock lk(right_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait, std::memory_order_relaxed); #endif for (auto &p : right_records_) @@ -278,17 +278,17 @@ void JoinOperator::executeJoinForCandidates( const std::unique_ptr& data_ptr, int slot, std::vector>>& local_return_pool) { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS // 注:similarity_ns 仅用于粗粒度的候选比对阶段计时; ScopedTimerAtomic t_similarity(JoinMetrics::instance().similarity_ns); #endif if (slot == 0) { // 加锁等待计入 lock_wait -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_wait = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock rk(right_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS 
JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait, std::memory_order_relaxed); #endif for (auto &cand : candidates) { @@ -299,7 +299,7 @@ void JoinOperator::executeJoinForCandidates( uint64_t log_right_uid = right_copy->uid_; Response lhs{ResponseType::Record, std::move(left_copy)}; Response rhs{ResponseType::Record, std::move(right_copy)}; -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS { ScopedTimerAtomic t_joinF(JoinMetrics::instance().join_function_ns); #endif @@ -309,10 +309,10 @@ void JoinOperator::executeJoinForCandidates( if (res.record_) { local_return_pool.emplace_back(left_slot_id_, std::move(res.record_)); } - CANDY_LOG_DEBUG("JOIN_EXEC", "slot={} result_uid={} left_uid={} right_uid={} ", + SAGEFLOW_LOG_DEBUG("JOIN_EXEC", "slot={} result_uid={} left_uid={} right_uid={} ", slot, result_uid, log_left_uid, log_right_uid); } catch (const std::exception& e) { - CANDY_LOG_ERROR("JOIN_EXEC", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", + SAGEFLOW_LOG_ERROR("JOIN_EXEC", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", slot, (lhs.record_ ? lhs.record_->data_.dim_ : -1), (rhs.record_ ? 
rhs.record_->data_.dim_ : -1), @@ -321,18 +321,18 @@ void JoinOperator::executeJoinForCandidates( e.what()); throw; // 继续向上抛出以保持现有行为 } -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS } #endif } } } else { // 加锁等待计入 lock_wait -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_wait = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock lk(left_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait, std::memory_order_relaxed); #endif for (auto &cand : candidates) { @@ -343,7 +343,7 @@ void JoinOperator::executeJoinForCandidates( uint64_t log_right_uid = right_copy->uid_; Response lhs{ResponseType::Record, std::move(left_copy)}; Response rhs{ResponseType::Record, std::move(right_copy)}; -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS { ScopedTimerAtomic t_joinF(JoinMetrics::instance().join_function_ns); #endif @@ -353,10 +353,10 @@ void JoinOperator::executeJoinForCandidates( if (res.record_) { local_return_pool.emplace_back(left_slot_id_, std::move(res.record_)); } - CANDY_LOG_DEBUG("JOIN_EXEC", "slot={} result_uid={} left_uid={} right_uid={} ", + SAGEFLOW_LOG_DEBUG("JOIN_EXEC", "slot={} result_uid={} left_uid={} right_uid={} ", slot, result_uid, log_left_uid, log_right_uid); } catch (const std::exception& e) { - CANDY_LOG_ERROR("JOIN_EXEC", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", + SAGEFLOW_LOG_ERROR("JOIN_EXEC", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", slot, (lhs.record_ ? lhs.record_->data_.dim_ : -1), (rhs.record_ ? 
rhs.record_->data_.dim_ : -1), @@ -365,7 +365,7 @@ void JoinOperator::executeJoinForCandidates( e.what()); throw; } -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS } #endif } @@ -377,22 +377,22 @@ void JoinOperator::executeLazyJoin( const std::vector>& candidates, int slot, std::vector>>& local_return_pool) { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS // 统一计量 Lazy 路径的候选匹配阶段 ScopedTimerAtomic t_similarity(JoinMetrics::instance().similarity_ns); #endif if (slot == left_slot_id_) { // 两侧加锁等待计入 lock_wait -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_wait_r = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock rk(right_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait_r, std::memory_order_relaxed); uint64_t before_wait_l = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock lk(left_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait_l, std::memory_order_relaxed); #endif for (auto &l : left_records_) { @@ -404,13 +404,13 @@ void JoinOperator::executeLazyJoin( Response lhs{ResponseType::Record, std::move(left_copy)}; Response rhs{ResponseType::Record, std::move(right_copy)}; try { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS ScopedTimerAtomic t_joinF(JoinMetrics::instance().join_function_ns); #endif auto res = join_func_->Execute(lhs, rhs); if (res.record_) local_return_pool.emplace_back(left_slot_id_, std::move(res.record_)); } catch (const std::exception& e) { - CANDY_LOG_ERROR("JOIN_LAZY", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", + SAGEFLOW_LOG_ERROR("JOIN_LAZY", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", slot, (lhs.record_ ? lhs.record_->data_.dim_ : -1), (rhs.record_ ? 
rhs.record_->data_.dim_ : -1), @@ -424,16 +424,16 @@ void JoinOperator::executeLazyJoin( } } else { // 两侧加锁等待计入 lock_wait -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_wait_l = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock lk(left_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait_l, std::memory_order_relaxed); uint64_t before_wait_r = ScopedAccumulateAtomic::now_ns(); #endif std::shared_lock rk(right_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().lock_wait_ns.fetch_add(ScopedAccumulateAtomic::now_ns() - before_wait_r, std::memory_order_relaxed); #endif for (auto &r : right_records_) { @@ -445,13 +445,13 @@ void JoinOperator::executeLazyJoin( Response lhs{ResponseType::Record, std::move(left_copy)}; Response rhs{ResponseType::Record, std::move(right_copy)}; try { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS ScopedTimerAtomic t_joinF(JoinMetrics::instance().join_function_ns); #endif auto res = join_func_->Execute(lhs, rhs); if (res.record_) local_return_pool.emplace_back(left_slot_id_, std::move(res.record_)); } catch (const std::exception& e) { - CANDY_LOG_ERROR("JOIN_LAZY", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", + SAGEFLOW_LOG_ERROR("JOIN_LAZY", "slot={} left_dim={} right_dim={} left_uid={} right_uid={} what={} ", slot, (lhs.record_ ? lhs.record_->data_.dim_ : -1), (rhs.record_ ? 
rhs.record_->data_.dim_ : -1), @@ -467,7 +467,7 @@ void JoinOperator::executeLazyJoin( } auto JoinOperator::apply(Response&& record, int slot, Collector& collector) -> void { -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS // 统计 apply 处理总耗时(一次调用一次计数) JoinMetrics::instance().apply_processing_count.fetch_add(1, std::memory_order_relaxed); ScopedTimerAtomic t_apply(JoinMetrics::instance().apply_processing_ns); @@ -477,7 +477,7 @@ auto JoinOperator::apply(Response&& record, int slot, Collector& collector) -> v if (!record.record_) return; std::unique_ptr data_ptr = std::make_unique(*record.record_); int64_t now_time_stamp = data_ptr->timestamp_; - CANDY_LOG_DEBUG("JOIN_APPLY", "Apply called slot={} uid={} ts={} dim={} ", slot, data_ptr->uid_, now_time_stamp, data_ptr->data_.dim_); + SAGEFLOW_LOG_DEBUG("JOIN_APPLY", "Apply called slot={} uid={} ts={} dim={} ", slot, data_ptr->uid_, now_time_stamp, data_ptr->data_.dim_); // 重要:为窗口存储拷贝一份,避免 data_ptr 在 updateSideThreadSafe 中被移动导致后续 eager 路径解引用空指针 auto store_ptr = std::make_unique(*data_ptr); bool trigger_flag = (slot == left_slot_id_) @@ -497,7 +497,7 @@ auto JoinOperator::apply(Response&& record, int slot, Collector& collector) -> v std::shared_lock lkR(right_records_mutex_); right_sz = right_records_.size(); } - CANDY_LOG_DEBUG("JOIN_APPLY", "slot={} cand={} left_win={} right_win={} eager={} use_index={} ", + SAGEFLOW_LOG_DEBUG("JOIN_APPLY", "slot={} cand={} left_win={} right_win={} eager={} use_index={} ", slot, candidates.size(), left_sz, right_sz, (is_eager_?1:0), (use_index_?1:0)); std::vector>> local_return_pool; @@ -506,11 +506,11 @@ auto JoinOperator::apply(Response&& record, int slot, Collector& collector) -> v } else { executeLazyJoin(candidates, slot, local_return_pool); // 清理窗口前加锁等待计入 lock_wait 与 window_insert_ns(视为窗口阶段的一部分) -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS uint64_t before_wait_L = ScopedAccumulateAtomic::now_ns(); #endif std::unique_lock lkL(left_records_mutex_); 
-#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS { uint64_t waited = ScopedAccumulateAtomic::now_ns() - before_wait_L; JoinMetrics::instance().lock_wait_ns.fetch_add(waited, std::memory_order_relaxed); @@ -519,7 +519,7 @@ auto JoinOperator::apply(Response&& record, int slot, Collector& collector) -> v uint64_t before_wait_R = ScopedAccumulateAtomic::now_ns(); #endif std::unique_lock lkR(right_records_mutex_); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS { uint64_t waited = ScopedAccumulateAtomic::now_ns() - before_wait_R; JoinMetrics::instance().lock_wait_ns.fetch_add(waited, std::memory_order_relaxed); @@ -529,14 +529,14 @@ auto JoinOperator::apply(Response&& record, int slot, Collector& collector) -> v left_records_.clear(); right_records_.clear(); } -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS { ScopedTimerAtomic t_emit(JoinMetrics::instance().emit_ns); #endif for (auto &p : local_return_pool) { Response out{ResponseType::Record, std::move(p.second)}; collector.collect(std::make_unique(std::move(out)), p.first); -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS JoinMetrics::instance().total_emits.fetch_add(1,std::memory_order_relaxed); // 端到端延迟:从 apply 进入到对应结果发射的时长(按每条结果计) const uint64_t now_ns = ScopedAccumulateAtomic::now_ns(); @@ -544,9 +544,9 @@ auto JoinOperator::apply(Response&& record, int slot, Collector& collector) -> v JoinMetrics::instance().e2e_latency_count.fetch_add(1, std::memory_order_relaxed); #endif } -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS } #endif } -} // namespace candy +} // namespace sageFlow diff --git a/src/operator/join_operator_methods/base_method.cpp b/src/operator/join_operator_methods/base_method.cpp index 0807a39..741a18e 100644 --- a/src/operator/join_operator_methods/base_method.cpp +++ b/src/operator/join_operator_methods/base_method.cpp @@ -1,9 +1,9 @@ #include "operator/join_operator_methods/base_method.h" -namespace candy { +namespace sageFlow { void 
BaseMethod::Excute( std::vector>> &emit_pool, - std::unique_ptr &joinfuc, + std::unique_ptr &joinfuc, std::list> &left_records, std::list> &right_records) { @@ -11,7 +11,7 @@ void BaseMethod::Excute( void BaseMethod::Excute( std::vector>> &emit_pool, - std::unique_ptr &joinfuc, + std::unique_ptr &joinfuc, std::unique_ptr &data, std::list> &records, int slot){ diff --git a/src/operator/join_operator_methods/bruteforce.cpp b/src/operator/join_operator_methods/bruteforce.cpp index fe3a65a..b72593b 100644 --- a/src/operator/join_operator_methods/bruteforce.cpp +++ b/src/operator/join_operator_methods/bruteforce.cpp @@ -3,7 +3,7 @@ #include #include "spdlog/spdlog.h" -namespace candy { +namespace sageFlow { std::vector> BruteForceJoinMethod::ExecuteEager(const VectorRecord &query_record, int query_slot) { std::vector> results; @@ -39,4 +39,4 @@ std::vector> BruteForceJoinMethod::ExecuteLazy(con return all_results; } -} // namespace candy +} // namespace sageFlow diff --git a/src/operator/join_operator_methods/eager/bruteforce.cpp b/src/operator/join_operator_methods/eager/bruteforce.cpp index d3886b8..7bbef3a 100644 --- a/src/operator/join_operator_methods/eager/bruteforce.cpp +++ b/src/operator/join_operator_methods/eager/bruteforce.cpp @@ -2,7 +2,7 @@ #include "compute_engine/compute_engine.h" #include -namespace candy { +namespace sageFlow { // 新的构造函数,支持KNN索引 BruteForceEager::BruteForceEager(int left_knn_index_id, @@ -22,7 +22,7 @@ auto BruteForceEager::getOtherStreamKnnIndexId(int data_arrival_slot) const -> i void BruteForceEager::Excute( std::vector>> &emit_pool, - std::unique_ptr &joinfuc, + std::unique_ptr &joinfuc, std::list> &left_records, std::list> &right_records) { @@ -30,7 +30,7 @@ void BruteForceEager::Excute( void BruteForceEager::Excute( std::vector>> &emit_pool, - std::unique_ptr &joinfuc, + std::unique_ptr &joinfuc, std::unique_ptr &data, std::list> &records, int slot) { @@ -117,4 +117,4 @@ std::vector> BruteForceEager::ExecuteLazy( return 
all_results; } -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/src/operator/join_operator_methods/eager/ivf.cpp b/src/operator/join_operator_methods/eager/ivf.cpp index 6cf6261..e23256d 100644 --- a/src/operator/join_operator_methods/eager/ivf.cpp +++ b/src/operator/join_operator_methods/eager/ivf.cpp @@ -1,7 +1,7 @@ #include "operator/join_operator_methods/eager/ivf.h" #include // Required for std::vector -namespace candy { +namespace sageFlow { // Updated IvfEager constructor IvfEager::IvfEager(int left_ivf_index_id, @@ -142,4 +142,4 @@ std::vector> IvfEager::ExecuteLazy( return all_results; } -} // namespace candy +} // namespace sageFlow diff --git a/src/operator/join_operator_methods/ivf.cpp b/src/operator/join_operator_methods/ivf.cpp index 5a59935..56e2c68 100644 --- a/src/operator/join_operator_methods/ivf.cpp +++ b/src/operator/join_operator_methods/ivf.cpp @@ -2,7 +2,7 @@ #include "utils/logger.h" #include -namespace candy { +namespace sageFlow { std::vector> IvfJoinMethod::ExecuteEager(const VectorRecord &query_record, int query_slot) { std::vector> results; @@ -12,12 +12,12 @@ std::vector> IvfJoinMethod::ExecuteEager(const Vec return results; } auto candidates = concurrency_manager_->query_for_join(idx, query_record, join_similarity_threshold_); - CANDY_LOG_DEBUG("JOIN_IVF", "eager_query slot={} candidates={} ", query_slot, candidates.size()); + SAGEFLOW_LOG_DEBUG("JOIN_IVF", "eager_query slot={} candidates={} ", query_slot, candidates.size()); // LOG输出匹配上的向量和到达向量具体是什么 - CANDY_LOG_DEBUG("JOIN_IVF", "eager_query input uid={} ", query_record.uid_); + SAGEFLOW_LOG_DEBUG("JOIN_IVF", "eager_query input uid={} ", query_record.uid_); for (auto &c : candidates) { if (c) { - CANDY_LOG_DEBUG("JOIN_IVF", "eager_query matched candidate uid={} ", c->uid_); + SAGEFLOW_LOG_DEBUG("JOIN_IVF", "eager_query matched candidate uid={} ", c->uid_); } } results.reserve(candidates.size()); @@ -44,4 +44,4 @@ 
std::vector> IvfJoinMethod::ExecuteLazy(const std: return all_results; } -} // namespace candy +} // namespace sageFlow diff --git a/src/operator/join_operator_methods/lazy/bruteforce.cpp b/src/operator/join_operator_methods/lazy/bruteforce.cpp index 7b1dae6..f86411e 100644 --- a/src/operator/join_operator_methods/lazy/bruteforce.cpp +++ b/src/operator/join_operator_methods/lazy/bruteforce.cpp @@ -2,7 +2,7 @@ #include "compute_engine/compute_engine.h" #include -namespace candy { +namespace sageFlow { // 新的构造函数,支持KNN索引 BruteForceLazy::BruteForceLazy(int left_knn_index_id, @@ -18,7 +18,7 @@ BruteForceLazy::BruteForceLazy(int left_knn_index_id, void BruteForceLazy::Excute( std::vector>> &emit_pool, - std::unique_ptr &joinfuc, + std::unique_ptr &joinfuc, std::list> &left_records, std::list> &right_records) { @@ -89,7 +89,7 @@ void BruteForceLazy::Excute( void BruteForceLazy::Excute( std::vector>> &emit_pool, - std::unique_ptr &joinfuc, + std::unique_ptr &joinfuc, std::unique_ptr &data, std::list> &records, int slot) { @@ -158,4 +158,4 @@ std::vector> BruteForceLazy::ExecuteLazy( return all_results; } -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/src/operator/join_operator_methods/lazy/ivf.cpp b/src/operator/join_operator_methods/lazy/ivf.cpp index 6d27eea..c82d201 100644 --- a/src/operator/join_operator_methods/lazy/ivf.cpp +++ b/src/operator/join_operator_methods/lazy/ivf.cpp @@ -2,7 +2,7 @@ #include // Required for std::vector #include -namespace candy { +namespace sageFlow { IvfLazy::IvfLazy(int left_ivf_index_id, int right_ivf_index_id, @@ -143,4 +143,4 @@ std::vector> IvfLazy::ExecuteLazy( return all_results; } -} // namespace candy +} // namespace sageFlow diff --git a/src/operator/map_operator.cpp b/src/operator/map_operator.cpp index 8b90db6..945e5b3 100644 --- a/src/operator/map_operator.cpp +++ b/src/operator/map_operator.cpp @@ -1,14 +1,14 @@ #include "operator/map_operator.h" 
-candy::MapOperator::MapOperator(std::unique_ptr& map_func) +sageFlow::MapOperator::MapOperator(std::unique_ptr& map_func) : Operator(OperatorType::MAP), map_func_(std::move(map_func)) {} -auto candy::MapOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::MapOperator::process(Response&data, int slot) -> std::optional { auto result = map_func_->Execute(data); return result; } -auto candy::MapOperator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::MapOperator::apply(Response&& record, int slot, Collector& collector) -> void { // 使用map函数转换数据 auto result = map_func_->Execute(record); // 将转换后的数据发送给下游 diff --git a/src/operator/operator.cpp b/src/operator/operator.cpp index e647cbf..87c259a 100644 --- a/src/operator/operator.cpp +++ b/src/operator/operator.cpp @@ -1,7 +1,7 @@ #include "operator/operator.h" -candy::Operator::~Operator() = default; +sageFlow::Operator::~Operator() = default; -candy::Operator::Operator(OperatorType type, size_t parallelism) +sageFlow::Operator::Operator(OperatorType type, size_t parallelism) : type_(type) { set_parallelism(parallelism); // 根据算子类型设置默认名称 @@ -19,26 +19,26 @@ candy::Operator::Operator(OperatorType type, size_t parallelism) } } -auto candy::Operator::getType() const -> OperatorType { return type_; } +auto sageFlow::Operator::getType() const -> OperatorType { return type_; } -auto candy::Operator::open() -> void { is_open_ = true; } +auto sageFlow::Operator::open() -> void { is_open_ = true; } -auto candy::Operator::close() -> void { is_open_ = false; } +auto sageFlow::Operator::close() -> void { is_open_ = false; } -auto candy::Operator::process(Response&record, int slot) -> std::optional { +auto sageFlow::Operator::process(Response&record, int slot) -> std::optional { return std::nullopt; } -auto candy::Operator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::Operator::apply(Response&& record, int slot, Collector& collector) -> 
void { // 默认实现:直接将数据传递给下游 collector.collect(std::make_unique(std::move(record)), slot); } -void candy::Operator::set_parallelism(const size_t p) { +void sageFlow::Operator::set_parallelism(const size_t p) { if (p > 0) { parallelism_ = p; } } -auto candy::Operator::get_parallelism() const -> size_t { +auto sageFlow::Operator::get_parallelism() const -> size_t { return parallelism_; } \ No newline at end of file diff --git a/src/operator/output_operator.cpp b/src/operator/output_operator.cpp index 7426c09..058c49d 100644 --- a/src/operator/output_operator.cpp +++ b/src/operator/output_operator.cpp @@ -1,14 +1,14 @@ #include "operator/output_operator.h" -candy::OutputOperator::OutputOperator() : Operator(OperatorType::OUTPUT) {} +sageFlow::OutputOperator::OutputOperator() : Operator(OperatorType::OUTPUT) {} -candy::OutputOperator::OutputOperator(const OutputChoice output_choice, std::shared_ptr stream) +sageFlow::OutputOperator::OutputOperator(const OutputChoice output_choice, std::shared_ptr stream) : Operator(OperatorType::OUTPUT), output_choice_(output_choice), stream_(std::move(stream)) {} -candy::OutputOperator::OutputOperator(std::shared_ptr stream) +sageFlow::OutputOperator::OutputOperator(std::shared_ptr stream) : Operator(OperatorType::OUTPUT), stream_(std::move(stream)) {} -auto candy::OutputOperator::open() -> void { +auto sageFlow::OutputOperator::open() -> void { if (is_open_) { return; } @@ -26,7 +26,7 @@ auto candy::OutputOperator::open() -> void { } -auto candy::OutputOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::OutputOperator::process(Response&data, int slot) -> std::optional { // OutputOperator作为数据源,通常不需要处理输入数据 // 而是负责从数据流中读取数据并发射到下游 // 在新的collector模式下,OutputOperator主要用于数据生成 @@ -58,7 +58,7 @@ auto candy::OutputOperator::process(Response&data, int slot) -> std::optional void { +auto sageFlow::OutputOperator::run(Collector& collector) -> void { std::unique_ptr record = nullptr; while (stream_ && (record = 
stream_->Next())) { auto resp = Response{ResponseType::Record, std::move(record)}; @@ -67,7 +67,7 @@ auto candy::OutputOperator::run(Collector& collector) -> void { } } -auto candy::OutputOperator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::OutputOperator::apply(Response&& record, int slot, Collector& collector) -> void { if (record.type_ != ResponseType::None) { if (output_choice_ == OutputChoice::Broadcast) { // 广播模式:将数据发送到所有下游slot diff --git a/src/operator/sink_operator.cpp b/src/operator/sink_operator.cpp index a6f9692..d69638e 100644 --- a/src/operator/sink_operator.cpp +++ b/src/operator/sink_operator.cpp @@ -2,19 +2,19 @@ #include "utils/logger.h" -candy::SinkOperator::SinkOperator(std::unique_ptr& sink_func) +sageFlow::SinkOperator::SinkOperator(std::unique_ptr& sink_func) : Operator(OperatorType::SINK), sink_func_(std::move(sink_func)) {} -auto candy::SinkOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::SinkOperator::process(Response&data, int slot) -> std::optional { auto result = sink_func_->Execute(data); return result; } -auto candy::SinkOperator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::SinkOperator::apply(Response&& record, int slot, Collector& collector) -> void { // Sink算子通常是管道的终点,执行sink函数但不向下游发送数据 auto result = sink_func_->Execute(record); - CANDY_LOG_DEBUG("SINK", "slot={} processed record uid={} timestamp={} ", + SAGEFLOW_LOG_DEBUG("SINK", "slot={} processed record uid={} timestamp={} ", slot, (record.record_ ? record.record_->uid_ : 0), (record.record_ ? 
record.record_->timestamp_ : 0)); diff --git a/src/operator/topk_operator.cpp b/src/operator/topk_operator.cpp index d863d53..8db66f7 100644 --- a/src/operator/topk_operator.cpp +++ b/src/operator/topk_operator.cpp @@ -2,11 +2,11 @@ #include -candy::TopkOperator::TopkOperator(std::unique_ptr& topk_func, +sageFlow::TopkOperator::TopkOperator(std::unique_ptr& topk_func, const std::shared_ptr& concurrency_manager) : Operator(OperatorType::TOPK), topk_func_(std::move(topk_func)), concurrency_manager_(concurrency_manager) {} -auto candy::TopkOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::TopkOperator::process(Response&data, int slot) -> std::optional { // TODO: 多线程改造 - TopK算子的并发安全 // 在多线程环境中,需要考虑以下改造: // 1. 索引的并发访问保护(concurrency_manager已经处理) @@ -38,7 +38,7 @@ auto candy::TopkOperator::process(Response&data, int slot) -> std::optional void { +auto sageFlow::TopkOperator::apply(Response&& record, int slot, Collector& collector) -> void { auto topk = dynamic_cast(topk_func_.get()); if (record.type_ == ResponseType::Record && record.record_) { // 使用ConcurrencyManager进行线程安全的TopK查询 diff --git a/src/operator/window_operator.cpp b/src/operator/window_operator.cpp index 6664ad0..99633ed 100644 --- a/src/operator/window_operator.cpp +++ b/src/operator/window_operator.cpp @@ -8,19 +8,19 @@ #include "function/window_function.h" -candy::WindowOperator::WindowOperator(std::unique_ptr& window_func) : Operator(OperatorType::WINDOW) {} +sageFlow::WindowOperator::WindowOperator(std::unique_ptr& window_func) : Operator(OperatorType::WINDOW) {} -auto candy::WindowOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::WindowOperator::process(Response&data, int slot) -> std::optional { return std::nullopt; } -candy::TumblingWindowOperator::TumblingWindowOperator(std::unique_ptr& window_func) +sageFlow::TumblingWindowOperator::TumblingWindowOperator(std::unique_ptr& window_func) : WindowOperator(window_func) { auto window_func_ = 
dynamic_cast(window_func.get()); window_size_ = window_func_->getWindowSize(); } -auto candy::TumblingWindowOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::TumblingWindowOperator::process(Response&data, int slot) -> std::optional { // TODO: 多线程改造 - 滚动窗口的并发状态管理 // 在多线程环境中,需要考虑以下改造: // 1. 窗口状态(window_buffer_)的并发访问保护 @@ -49,14 +49,14 @@ auto candy::TumblingWindowOperator::process(Response&data, int slot) -> std::opt return std::nullopt; } -candy::SlidingWindowOperator::SlidingWindowOperator(std::unique_ptr& window_func) +sageFlow::SlidingWindowOperator::SlidingWindowOperator(std::unique_ptr& window_func) : WindowOperator(window_func) { auto window_func_ = dynamic_cast(window_func.get()); window_size_ = window_func_->getWindowSize(); slide_size_ = window_func_->getSlideSize(); } -auto candy::SlidingWindowOperator::process(Response&data, int slot) -> std::optional { +auto sageFlow::SlidingWindowOperator::process(Response&data, int slot) -> std::optional { std::lock_guard lock(window_mutex_); if (data.type_ == ResponseType::Record) { @@ -82,12 +82,12 @@ auto candy::SlidingWindowOperator::process(Response&data, int slot) -> std::opti return std::nullopt; } -auto candy::WindowOperator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::WindowOperator::apply(Response&& record, int slot, Collector& collector) -> void { // 基类默认实现,子类需要重写此方法 collector.collect(std::make_unique(std::move(record)), slot); } -auto candy::TumblingWindowOperator::apply(Response&& record, int slot, Collector& collector) -> void { +auto sageFlow::TumblingWindowOperator::apply(Response&& record, int slot, Collector& collector) -> void { std::lock_guard lock(window_mutex_); if (record.type_ == ResponseType::Record && record.record_) { @@ -109,7 +109,7 @@ auto candy::TumblingWindowOperator::apply(Response&& record, int slot, Collector } } -auto candy::SlidingWindowOperator::apply(Response&& record, int slot, Collector& collector) -> void { 
+auto sageFlow::SlidingWindowOperator::apply(Response&& record, int slot, Collector& collector) -> void { std::lock_guard lock(window_mutex_); if (record.type_ == ResponseType::Record && record.record_) { diff --git a/src/query/optimizer/planner.cpp b/src/query/optimizer/planner.cpp index d7cd16f..cca055b 100644 --- a/src/query/optimizer/planner.cpp +++ b/src/query/optimizer/planner.cpp @@ -10,7 +10,7 @@ #include "operator/itopk_operator.h" #include "operator/window_operator.h" -namespace candy { +namespace sageFlow { Planner::Planner(const std::shared_ptr& concurrency_manager) : concurrency_manager_(concurrency_manager) {} @@ -147,4 +147,4 @@ void Planner::configureOperatorParallelism(std::shared_ptr& op, } } -} // namespace candy +} // namespace sageFlow diff --git a/src/storage/storage_manager.cpp b/src/storage/storage_manager.cpp index 97221cd..62bcb5e 100644 --- a/src/storage/storage_manager.cpp +++ b/src/storage/storage_manager.cpp @@ -4,13 +4,13 @@ #include "utils/logger.h" -auto candy::StorageManager::insert(std::unique_ptr record) -> void { +auto sageFlow::StorageManager::insert(std::unique_ptr record) -> void { if (record == nullptr) { throw std::runtime_error("StorageManager::insert: Attempt to insert a null record."); } std::unique_lock lock(map_mutex_); const auto uid = record->uid_; - CANDY_LOG_DEBUG("STORAGE", "Inserting record uid={} current_size={} ", uid, records_.size()); + SAGEFLOW_LOG_DEBUG("STORAGE", "Inserting record uid={} current_size={} ", uid, records_.size()); if (map_.find(uid) != map_.end()) { return; // UID 已存在 } @@ -20,7 +20,7 @@ auto candy::StorageManager::insert(std::unique_ptr record) -> void map_.emplace(uid, idx); } -// auto candy::StorageManager::insert(std::shared_ptr record) -> void { +// auto sageFlow::StorageManager::insert(std::shared_ptr record) -> void { // if (record == nullptr) { // throw std::runtime_error("StorageManager::insert: Attempt to insert a null record."); // } @@ -34,7 +34,7 @@ auto 
candy::StorageManager::insert(std::unique_ptr record) -> void // map_.emplace(uid, idx); // } -auto candy::StorageManager::erase(const uint64_t vector_id) -> bool { +auto sageFlow::StorageManager::erase(const uint64_t vector_id) -> bool { std::unique_lock lock(map_mutex_); const auto it = map_.find(vector_id); if (it == map_.end()) { @@ -54,7 +54,7 @@ auto candy::StorageManager::erase(const uint64_t vector_id) -> bool { return true; } -auto candy::StorageManager::getVectorByUid(const uint64_t vector_id) -> std::shared_ptr { +auto sageFlow::StorageManager::getVectorByUid(const uint64_t vector_id) -> std::shared_ptr { std::shared_lock lock(map_mutex_); const auto it = map_.find(vector_id); if (it == map_.end()) { @@ -73,7 +73,7 @@ auto candy::StorageManager::getVectorByUid(const uint64_t vector_id) -> std::sha return records_[index]; } -auto candy::StorageManager::getVectorsByUids(const std::vector& vector_ids) +auto sageFlow::StorageManager::getVectorsByUids(const std::vector& vector_ids) -> std::vector> { std::vector> records; records.reserve(vector_ids.size()); @@ -95,7 +95,7 @@ auto candy::StorageManager::getVectorsByUids(const std::vector& vector return records; } -auto candy::StorageManager::topk(const VectorRecord& record, int k) const -> std::vector { +auto sageFlow::StorageManager::topk(const VectorRecord& record, int k) const -> std::vector { if (engine_ == nullptr) { throw std::runtime_error("StorageManager::topk: Compute engine is not set."); } @@ -144,7 +144,7 @@ auto candy::StorageManager::topk(const VectorRecord& record, int k) const -> std return final_ids; } -auto candy::StorageManager::similarityJoinQuery(const VectorRecord &record, double join_similarity_threshold) const -> std::vector { +auto sageFlow::StorageManager::similarityJoinQuery(const VectorRecord &record, double join_similarity_threshold) const -> std::vector { if (engine_ == nullptr) { throw std::runtime_error("StorageManager::similarityJoinQuery: Compute engine is not set."); } diff 
--git a/src/stream/data_stream_source/data_stream_source.cpp b/src/stream/data_stream_source/data_stream_source.cpp index b67baf9..62faa38 100644 --- a/src/stream/data_stream_source/data_stream_source.cpp +++ b/src/stream/data_stream_source/data_stream_source.cpp @@ -3,9 +3,9 @@ // #include "stream/data_stream_source/data_stream_source.h" -candy::DataStreamSource::DataStreamSource(std::string name, const DataStreamSourceType type) +sageFlow::DataStreamSource::DataStreamSource(std::string name, const DataStreamSourceType type) : Stream(std::move(name)), type_(type) {} -auto candy::DataStreamSource::getType() const -> DataStreamSourceType { return type_; } +auto sageFlow::DataStreamSource::getType() const -> DataStreamSourceType { return type_; } -void candy::DataStreamSource::setType(const DataStreamSourceType type) { type_ = type; } +void sageFlow::DataStreamSource::setType(const DataStreamSourceType type) { type_ = type; } diff --git a/src/stream/data_stream_source/file_stream_source.cpp b/src/stream/data_stream_source/file_stream_source.cpp index 896d9f2..614f570 100644 --- a/src/stream/data_stream_source/file_stream_source.cpp +++ b/src/stream/data_stream_source/file_stream_source.cpp @@ -12,18 +12,18 @@ #include "common/data_types.h" -candy::FileStreamSource::FileStreamSource(std::string name) +sageFlow::FileStreamSource::FileStreamSource(std::string name) : DataStreamSource(std::move(name), DataStreamSourceType::File) {} -candy::FileStreamSource::FileStreamSource(std::string name, std::string file_path) +sageFlow::FileStreamSource::FileStreamSource(std::string name, std::string file_path) : DataStreamSource(std::move(name), DataStreamSourceType::File), file_path_(std::move(file_path)) {} -void candy::FileStreamSource::Init() { +void sageFlow::FileStreamSource::Init() { running_ = true; std::thread([this]() { std::ifstream file(file_path_, std::ios::binary); if (!file.is_open()) { - CANDY_LOG_ERROR("SOURCE", "open_fail path={} ", file_path_); + 
SAGEFLOW_LOG_ERROR("SOURCE", "open_fail path={} ", file_path_); running_ = false; return; } @@ -76,7 +76,7 @@ void candy::FileStreamSource::Init() { std::chrono::duration_cast(std::chrono::steady_clock::now() - last_data_time) .count(); if (static_cast(elapsed) > timeout_ms_) { - CANDY_LOG_WARN("SOURCE", "timeout elapsed_ms={} path={} ", elapsed, file_path_); + SAGEFLOW_LOG_WARN("SOURCE", "timeout elapsed_ms={} path={} ", elapsed, file_path_); break; } } @@ -132,7 +132,7 @@ void candy::FileStreamSource::Init() { }).detach(); } -auto candy::FileStreamSource::Next() -> std::unique_ptr { +auto sageFlow::FileStreamSource::Next() -> std::unique_ptr { std::lock_guard lock(mtx_); if (records_.empty()) { return nullptr; diff --git a/src/stream/data_stream_source/sift_stream_source.cpp b/src/stream/data_stream_source/sift_stream_source.cpp index 884aed1..846592b 100644 --- a/src/stream/data_stream_source/sift_stream_source.cpp +++ b/src/stream/data_stream_source/sift_stream_source.cpp @@ -8,16 +8,16 @@ #include #include "utils/logger.h" -candy::SiftStreamSource::SiftStreamSource(std::string name) +sageFlow::SiftStreamSource::SiftStreamSource(std::string name) : DataStreamSource(std::move(name), DataStreamSourceType::None) {} -candy::SiftStreamSource::SiftStreamSource(std::string name, std::string file_path) +sageFlow::SiftStreamSource::SiftStreamSource(std::string name, std::string file_path) : DataStreamSource(std::move(name), DataStreamSourceType::None), file_path_(std::move(file_path)) {} -void candy::SiftStreamSource::Init() { +void sageFlow::SiftStreamSource::Init() { std::ifstream file(file_path_, std::ios::binary); if (!file.is_open()) { - CANDY_LOG_ERROR("SOURCE", "open_fail path={} ", file_path_); + SAGEFLOW_LOG_ERROR("SOURCE", "open_fail path={} ", file_path_); return; } @@ -38,7 +38,7 @@ void candy::SiftStreamSource::Init() { if (!file.good() && !file.eof()) { // Error reading the vector data delete[] vector_data; - CANDY_LOG_ERROR("SOURCE", "read_vector_fail 
path={} index={} ", file_path_, records_.size()); + SAGEFLOW_LOG_ERROR("SOURCE", "read_vector_fail path={} index={} ", file_path_, records_.size()); break; } @@ -57,10 +57,10 @@ void candy::SiftStreamSource::Init() { } file.close(); - CANDY_LOG_INFO("SOURCE", "sift_loaded count={} path={} ", records_.size(), file_path_); + SAGEFLOW_LOG_INFO("SOURCE", "sift_loaded count={} path={} ", records_.size(), file_path_); } -auto candy::SiftStreamSource::Next() -> std::unique_ptr { +auto sageFlow::SiftStreamSource::Next() -> std::unique_ptr { if (records_.empty()) { return nullptr; } diff --git a/src/stream/data_stream_source/simple_stream_source.cpp b/src/stream/data_stream_source/simple_stream_source.cpp index 87ea24b..94718ee 100644 --- a/src/stream/data_stream_source/simple_stream_source.cpp +++ b/src/stream/data_stream_source/simple_stream_source.cpp @@ -8,21 +8,21 @@ #include #include "utils/logger.h" -candy::SimpleStreamSource::SimpleStreamSource(std::string name) +sageFlow::SimpleStreamSource::SimpleStreamSource(std::string name) : DataStreamSource(std::move(name), DataStreamSourceType::None) {} -candy::SimpleStreamSource::SimpleStreamSource(std::string name, std::string file_path) +sageFlow::SimpleStreamSource::SimpleStreamSource(std::string name, std::string file_path) : DataStreamSource(std::move(name), DataStreamSourceType::None), file_path_(std::move(file_path)) {} -void candy::SimpleStreamSource::Init() { +void sageFlow::SimpleStreamSource::Init() { if (file_path_.empty()) { // 测试环境下可为空:不加载任何记录 - CANDY_LOG_INFO("SOURCE", "SimpleStreamSource empty path name={} ", name_); + SAGEFLOW_LOG_INFO("SOURCE", "SimpleStreamSource empty path name={} ", name_); return; } std::ifstream file(file_path_, std::ios::binary); if (!file.is_open()) { - CANDY_LOG_ERROR("SOURCE", "open_fail path={} ", file_path_); + SAGEFLOW_LOG_ERROR("SOURCE", "open_fail path={} ", file_path_); return; } auto record_cnt = 0; @@ -31,7 +31,7 @@ void candy::SimpleStreamSource::Init() { for (int i = 0; 
i < record_cnt; ++i) { auto record = std::make_unique(0, 0, 0, DataType::None, nullptr); if (!record->Deserialize(file)) { - CANDY_LOG_ERROR("SOURCE", "deserialize_fail index={} path={} ", i, file_path_); + SAGEFLOW_LOG_ERROR("SOURCE", "deserialize_fail index={} path={} ", i, file_path_); break; } records_.push_back(std::move(record)); @@ -39,7 +39,7 @@ void candy::SimpleStreamSource::Init() { file.close(); } -auto candy::SimpleStreamSource::Next() -> std::unique_ptr { +auto sageFlow::SimpleStreamSource::Next() -> std::unique_ptr { if (records_.empty()) { return nullptr; } diff --git a/src/stream/stream.cpp b/src/stream/stream.cpp index d39d440..1b825f1 100644 --- a/src/stream/stream.cpp +++ b/src/stream/stream.cpp @@ -2,7 +2,7 @@ #include "function/function_api.h" -namespace candy { +namespace sageFlow { auto Stream::filter(std::unique_ptr& filter_func, size_t parallelism) -> std::shared_ptr { return filter(std::move(filter_func), parallelism); @@ -119,4 +119,4 @@ auto Stream::writeSink(std::unique_ptr sink_func, size_t paralleli streams_.push_back(stream); return stream; } -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/src/stream/stream_environment.cpp b/src/stream/stream_environment.cpp index 6a86545..2db0932 100644 --- a/src/stream/stream_environment.cpp +++ b/src/stream/stream_environment.cpp @@ -7,7 +7,7 @@ #include #include -namespace candy { +namespace sageFlow { auto StreamEnvironment::loadConfiguration(const std::string &file_path) -> ConfigMap { ConfigMap config; @@ -24,11 +24,11 @@ auto StreamEnvironment::execute() -> void { } if (is_running_) { - CANDY_LOG_WARN("ENV", "StreamEnvironment already running"); + SAGEFLOW_LOG_WARN("ENV", "StreamEnvironment already running"); return; } - CANDY_LOG_INFO("ENV", "Building execution graph streams={} ", streams_.size()); + SAGEFLOW_LOG_INFO("ENV", "Building execution graph streams={} ", streams_.size()); // 先为每个根源流按顺序分配 slotId,并写入 Stream 上 int 
next_slot = 0; @@ -49,9 +49,9 @@ auto StreamEnvironment::execute() -> void { execution_graph_->start(); is_running_ = true; - CANDY_LOG_INFO("ENV", "StreamEnvironment started"); + SAGEFLOW_LOG_INFO("ENV", "StreamEnvironment started"); - // 若配置中存在扁平 key "log.level" 则应用;否则环境变量 CANDY_LOG_LEVEL 覆盖 + // 若配置中存在扁平 key "log.level" 则应用;否则环境变量 SAGEFLOW_LOG_LEVEL 覆盖 try { if (config_.exist("log.level")) { std::string lvl = std::get(config_.getValue("log.level")); @@ -60,7 +60,7 @@ auto StreamEnvironment::execute() -> void { init_log_level(""); } } catch(const std::exception &e) { - CANDY_LOG_WARN("ENV", "log_level_config_failed what={} ", e.what()); + SAGEFLOW_LOG_WARN("ENV", "log_level_config_failed what={} ", e.what()); } } @@ -69,7 +69,7 @@ auto StreamEnvironment::stop() -> void { return; } - CANDY_LOG_INFO("ENV", "Stopping StreamEnvironment..."); + SAGEFLOW_LOG_INFO("ENV", "Stopping StreamEnvironment..."); execution_graph_->stop(); is_running_ = false; } @@ -78,7 +78,7 @@ auto StreamEnvironment::awaitTermination() -> void { // 即使 is_running_ 已变为 false(stop 后),也需要 join,确保线程收敛 execution_graph_->join(); is_running_ = false; - CANDY_LOG_INFO("ENV", "StreamEnvironment terminated"); + SAGEFLOW_LOG_INFO("ENV", "StreamEnvironment terminated"); } auto StreamEnvironment::setParallelism(size_t parallelism) -> void { @@ -106,7 +106,7 @@ void StreamEnvironment::reset() { planner_ = std::make_shared(concurrency_manager_); execution_graph_ = std::make_unique(); is_running_ = false; - CANDY_LOG_INFO("ENV", "StreamEnvironment reset completed"); + SAGEFLOW_LOG_INFO("ENV", "StreamEnvironment reset completed"); } -} // namespace candy +} // namespace sageFlow diff --git a/src/utils/log_config.cpp b/src/utils/log_config.cpp index 09a2a96..3a59ae9 100644 --- a/src/utils/log_config.cpp +++ b/src/utils/log_config.cpp @@ -2,7 +2,7 @@ #include "utils/logger.h" #include -namespace candy { +namespace sageFlow { void apply_log_level(spdlog::level::level_enum lvl) { auto lg = get_logger(); @@ -11,14 
+11,14 @@ void apply_log_level(spdlog::level::level_enum lvl) { for (auto &s : lg->sinks()) { s->set_level(lvl); } - CANDY_LOG_INFO("LOG", "log_level_applied level={} ", spdlog::level::to_string_view(lvl)); + SAGEFLOW_LOG_INFO("LOG", "log_level_applied level={} ", spdlog::level::to_string_view(lvl)); } void init_log_level(const std::string &level_from_config) { - const char* env = std::getenv("CANDY_LOG_LEVEL"); + const char* env = std::getenv("SAGEFLOW_LOG_LEVEL"); std::string chosen = env && *env ? std::string(env) : level_from_config; auto lvl = parse_log_level(chosen); apply_log_level(lvl); } -} // namespace candy +} // namespace sageFlow diff --git a/src/utils/monitoring.cpp b/src/utils/monitoring.cpp index 9d07b14..3381167 100644 --- a/src/utils/monitoring.cpp +++ b/src/utils/monitoring.cpp @@ -6,7 +6,7 @@ #include #include "utils/logger.h" -namespace candy { +namespace sageFlow { PerformanceMonitor::PerformanceMonitor(std::string profile_output) : profile_output_file_(std::move(profile_output)), profiling_(false) {} @@ -22,9 +22,9 @@ void PerformanceMonitor::StartProfiling() { if (!profiling_) { ProfilerStart(profile_output_file_.c_str()); profiling_ = true; - CANDY_LOG_INFO("MONITOR", "profiling_started file={} ", profile_output_file_); + SAGEFLOW_LOG_INFO("MONITOR", "profiling_started file={} ", profile_output_file_); } else { - CANDY_LOG_WARN("MONITOR", "profiling_already_running file={} ", profile_output_file_); + SAGEFLOW_LOG_WARN("MONITOR", "profiling_already_running file={} ", profile_output_file_); } #else std::cerr << "Profiling not available: gperftools not found." 
<< '\n'; @@ -36,9 +36,9 @@ void PerformanceMonitor::StopProfiling() { if (profiling_) { ProfilerStop(); profiling_ = false; - CANDY_LOG_INFO("MONITOR", "profiling_stopped file={} ", profile_output_file_); + SAGEFLOW_LOG_INFO("MONITOR", "profiling_stopped file={} ", profile_output_file_); } else { - CANDY_LOG_WARN("MONITOR", "profiling_not_running file={} ", profile_output_file_); + SAGEFLOW_LOG_WARN("MONITOR", "profiling_not_running file={} ", profile_output_file_); } #else std::cerr << "Profiling not available: gperftools not found." << '\n'; @@ -47,13 +47,13 @@ void PerformanceMonitor::StopProfiling() { void PerformanceMonitor::StartTimer() { start_time_ = std::chrono::high_resolution_clock::now(); - CANDY_LOG_INFO("MONITOR", "timer_started"); + SAGEFLOW_LOG_INFO("MONITOR", "timer_started"); } void PerformanceMonitor::StopTimer(const std::string &task_name) { const auto end_time = std::chrono::high_resolution_clock::now(); const auto duration = std::chrono::duration_cast(end_time - start_time_).count(); - CANDY_LOG_INFO("MONITOR", "task_done name={} duration_ms={} ", task_name, duration); + SAGEFLOW_LOG_INFO("MONITOR", "task_done name={} duration_ms={} ", task_name, duration); } -} // namespace candy +} // namespace sageFlow diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1387695..e0630e8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -13,6 +13,13 @@ add_library(test_data_support test_utils/test_data_generator.cpp test_utils/test_config_manager.cpp test_utils/dynamic_config.cpp + test_utils/data_source/random_data_source.cpp + test_utils/data_source/dataset_data_source.cpp + test_utils/data_source/json_data_source.cpp + test_utils/data_writer/fvecs_writer.cpp + test_utils/data_writer/json_writer.cpp + test_utils/join_data_source.cpp + test_utils/join_test_helper.cpp ) # 头文件路径 @@ -23,21 +30,21 @@ target_include_directories(test_data_support ${CMAKE_CURRENT_SOURCE_DIR} ) -# 依赖主工程接口库 candy(其 transitively 链接全部内部组件) +# 依赖主工程接口库 sageflow(其 
transitively 链接全部内部组件) target_link_libraries(test_data_support PRIVATE - candy + sageflow ) -# 向所有测试目标注入的公共宏定义(基线:PROJECT_DIR;按选项追加 CANDY_ENABLE_METRICS) +# 向所有测试目标注入的公共宏定义(基线:PROJECT_DIR;按选项追加 SAGEFLOW_ENABLE_METRICS) set(TEST_COMMON_DEFINES PROJECT_DIR="${CMAKE_SOURCE_DIR}" ) -if(CANDY_ENABLE_METRICS) - list(APPEND TEST_COMMON_DEFINES CANDY_ENABLE_METRICS=1) +if(SAGEFLOW_ENABLE_METRICS) + list(APPEND TEST_COMMON_DEFINES SAGEFLOW_ENABLE_METRICS=1) endif() -# 让公共库也拥有相同的编译宏(例如 PROJECT_DIR、可选的 CANDY_ENABLE_METRICS) +# 让公共库也拥有相同的编译宏(例如 PROJECT_DIR、可选的 SAGEFLOW_ENABLE_METRICS) target_compile_definitions(test_data_support PRIVATE ${TEST_COMMON_DEFINES}) # ----------------------------------------------------------------------------- @@ -50,7 +57,7 @@ function(register_test TEST_NAME REL_PATH TIMEOUT_SECS GROUP_TAG) return() endif() add_gtest(${TEST_NAME} ${REL_PATH}) # 宏自动链接 gtest_main - target_link_libraries(${TEST_NAME} PRIVATE test_data_support candy) + target_link_libraries(${TEST_NAME} PRIVATE test_data_support sageflow) target_compile_definitions(${TEST_NAME} PRIVATE ${TEST_COMMON_DEFINES}) # 为每个测试目标补充包含路径:工程 include/ 与 test/ 根目录(包含 utils/、UnitTest/、Performance/ 等) target_include_directories(${TEST_NAME} PRIVATE ${CMAKE_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}) @@ -65,6 +72,9 @@ set(UNIT_TEST_SPECS test_join_ivf UnitTest/test_join_ivf.cpp 300 UNIT test_compute_engine UnitTest/test_compute_engine.cpp 180 UNIT test_file_stream_source UnitTest/test_file_stream_source.cpp 120 UNIT + test_data_source UnitTest/test_data_source.cpp 120 UNIT + test_data_persistence UnitTest/test_data_persistence.cpp 120 UNIT + test_join_data_source UnitTest/test_join_data_source.cpp 120 UNIT ) list(LENGTH UNIT_TEST_SPECS _ulen) @@ -85,6 +95,7 @@ set(PERF_TEST_SPECS test_join_perf_scaling Performance/test_join_perf_scaling.cpp 900 PERF test_join_performance_methods Performance/test_join_performance_methods.cpp 900 PERF test_window_pipeline Performance/test_window_pipeline.cpp 600 
PERF + test_join_datasource_modes Performance/test_join_datasource_modes.cpp 900 PERF ) list(LENGTH PERF_TEST_SPECS _plen) @@ -103,7 +114,7 @@ endforeach() # ----------------------------------------------------------------------------- if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/Performance/IndexTest.cpp") add_executable(IndexTest Performance/IndexTest.cpp) - target_link_libraries(IndexTest PRIVATE test_data_support candy) + target_link_libraries(IndexTest PRIVATE test_data_support sageflow) target_compile_definitions(IndexTest PRIVATE ${TEST_COMMON_DEFINES}) target_include_directories(IndexTest PRIVATE ${CMAKE_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}) add_test(NAME IndexTest COMMAND IndexTest) @@ -141,7 +152,7 @@ endif() # 聚合自定义目标 # ----------------------------------------------------------------------------- # 仅依赖已创建的测试可执行(未显式依赖 ctest 运行,方便 IDE 构建) -set(ALL_UNIT_TARGETS test_join_bruteforce test_join_ivf test_compute_engine test_file_stream_source) +set(ALL_UNIT_TARGETS test_join_bruteforce test_join_ivf test_compute_engine test_file_stream_source test_data_source test_data_persistence test_join_data_source) set(ALL_PERF_TARGETS test_join_perf_scaling test_join_performance_methods test_window_pipeline IndexTest) set(ALL_INTEG_TARGETS test_pipeline_basic test_pipeline test_pipeline_execution) @@ -181,8 +192,30 @@ add_custom_target(run_quick_tests # 输出说明 message(STATUS "Test targets registered: UNIT=${ALL_UNIT_TARGETS}; PERF=${ALL_PERF_TARGETS}; INTEGRATION=${ALL_INTEG_TARGETS}") -if(CANDY_ENABLE_METRICS) - message(STATUS "CANDY_ENABLE_METRICS enabled for tests to collect performance data") +if(SAGEFLOW_ENABLE_METRICS) + message(STATUS "SAGEFLOW_ENABLE_METRICS enabled for tests to collect performance data") else() - message(STATUS "CANDY_ENABLE_METRICS disabled for tests (metrics code excluded)") + message(STATUS "SAGEFLOW_ENABLE_METRICS disabled for tests (metrics code excluded)") +endif() + +# 
----------------------------------------------------------------------------- +# Optional: Data Source Example (demonstrates the data source framework) +# ----------------------------------------------------------------------------- +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/examples/test_data_source_example.cpp") + add_executable(test_data_source_example examples/test_data_source_example.cpp) + target_link_libraries(test_data_source_example PRIVATE test_data_support sageflow) + target_compile_definitions(test_data_source_example PRIVATE ${TEST_COMMON_DEFINES}) + target_include_directories(test_data_source_example PRIVATE ${CMAKE_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}) + message(STATUS "Added optional example: test_data_source_example") +endif() + +# ----------------------------------------------------------------------------- +# Optional: Data Persistence Example (demonstrates save/load functionality) +# ----------------------------------------------------------------------------- +if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/examples/data_persistence_example.cpp") + add_executable(data_persistence_example examples/data_persistence_example.cpp) + target_link_libraries(data_persistence_example PRIVATE test_data_support sageflow) + target_compile_definitions(data_persistence_example PRIVATE ${TEST_COMMON_DEFINES}) + target_include_directories(data_persistence_example PRIVATE ${CMAKE_SOURCE_DIR}/include ${CMAKE_CURRENT_SOURCE_DIR}) + message(STATUS "Added optional example: data_persistence_example") endif() \ No newline at end of file diff --git a/test/IntegrationTest/test_pipeline.cpp b/test/IntegrationTest/test_pipeline.cpp index e0fb558..f2f4874 100644 --- a/test/IntegrationTest/test_pipeline.cpp +++ b/test/IntegrationTest/test_pipeline.cpp @@ -23,7 +23,7 @@ #include "function/sink_function.h" #include "function/join_function.h" -namespace candy { +namespace sageFlow { namespace test { /** @@ -308,10 +308,10 @@ TEST_F(PipelineConstructionTest, 
JoinPipelineConstruction) { // 3. 构建Join流水线,指定并行度 // 从配置读取 Join 配置 - candy::test::PipelineConfig pipeline_cfg{}; + sageFlow::test::PipelineConfig pipeline_cfg{}; std::string method = "bruteforce_lazy"; double threshold = 0.8; - candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg); + sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg); method = pipeline_cfg.join_method; threshold = pipeline_cfg.similarity_threshold; EXPECT_NO_THROW( @@ -441,7 +441,7 @@ TEST_F(PipelineConstructionTest, BasicPipelineConstruction) { "BasicSink", [](std::unique_ptr& record) { // 简单的打印输出 - CANDY_LOG_INFO("TEST", "Processing record uid={} ", record->uid_); + SAGEFLOW_LOG_INFO("TEST", "Processing record uid={} ", record->uid_); } ), 1) // 明确指定Sink并行度为1 ); @@ -576,8 +576,8 @@ TEST_F(PipelineConstructionTest, CompleteJoinPipelineExecution) { // 6. 构建Join流水线,设置Join算子并行度为3 // 从配置读取 Join 配置 - candy::test::PipelineConfig pipeline_cfg2{}; - candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg2); + sageFlow::test::PipelineConfig pipeline_cfg2{}; + sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg2); auto join_result = filtered_left->join(filtered_right, std::move(join_func), pipeline_cfg2.join_method, pipeline_cfg2.similarity_threshold, 3); // 7. 
对Join结果进行后处理 @@ -614,7 +614,7 @@ TEST_F(PipelineConstructionTest, CompleteJoinPipelineExecution) { // Join结果应该包含满足条件的记录 EXPECT_GT(test_results_.size(), 0); - CANDY_LOG_INFO("TEST", "Join pipeline processed records={} ", test_results_.size()); + SAGEFLOW_LOG_INFO("TEST", "Join pipeline processed records={} ", test_results_.size()); // 验证每个结果都有正确的标记 for (const auto& record : test_results_) { @@ -627,7 +627,7 @@ TEST_F(PipelineConstructionTest, CompleteJoinPipelineExecution) { EXPECT_GE(float_data[i], 0.0f); // 向量元素应该为正数 } - CANDY_LOG_INFO("TEST", "Join result uid={} v0={} ", record->uid_, float_data[0]); + SAGEFLOW_LOG_INFO("TEST", "Join result uid={} v0={} ", record->uid_, float_data[0]); } } @@ -660,8 +660,8 @@ TEST_F(PipelineConstructionTest, JoinParallelismPerformanceTest) { auto start_time = std::chrono::high_resolution_clock::now(); // 从配置读取 Join 配置 - candy::test::PipelineConfig pipeline_cfg3{}; - candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg3); + sageFlow::test::PipelineConfig pipeline_cfg3{}; + sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg3); left_source->join(right_source, std::move(join_func), pipeline_cfg3.join_method, pipeline_cfg3.similarity_threshold, 4) // Join并行度为4 ->writeSink(std::make_unique( "PerformanceSink", @@ -683,7 +683,7 @@ TEST_F(PipelineConstructionTest, JoinParallelismPerformanceTest) { std::lock_guard lock(result_mutex_); - CANDY_LOG_INFO("TEST", "High parallelism join completed duration_ms={} results={} ", duration.count(), test_results_.size()); + SAGEFLOW_LOG_INFO("TEST", "High parallelism join completed duration_ms={} results={} ", duration.count(), test_results_.size()); // 验证Join效率:应该能找到匹配的记录 EXPECT_GT(test_results_.size(), 0); @@ -695,4 +695,4 @@ TEST_F(PipelineConstructionTest, JoinParallelismPerformanceTest) { } } // namespace test -} // namespace candy +} // namespace sageFlow diff --git 
a/test/IntegrationTest/test_pipeline_basic.cpp b/test/IntegrationTest/test_pipeline_basic.cpp index 141acd0..32590da 100644 --- a/test/IntegrationTest/test_pipeline_basic.cpp +++ b/test/IntegrationTest/test_pipeline_basic.cpp @@ -23,8 +23,9 @@ #include "test_utils/test_data_generator.h" #include "test_utils/test_config_manager.h" #include "test_utils/test_data_adapter.h" +#include "test_utils/join_test_helper.h" -namespace candy { +namespace sageFlow { namespace test { // 自适应等待:直到 Join 消费完预期输入并且输出在短时间内稳定 @@ -38,7 +39,7 @@ static void wait_until_processed_and_stable(size_t expected_left, uint64_t r = JoinMetrics::instance().total_records_right.load(); if (l >= expected_left && r >= expected_right) break; if (std::chrono::steady_clock::now() >= deadline) { - CANDY_LOG_WARN("TEST", "wait_for_processed timeout l={}/{} r={}/{}", l, expected_left, r, expected_right); + SAGEFLOW_LOG_WARN("TEST", "wait_for_processed timeout l={}/{} r={}/{}", l, expected_left, r, expected_right); break; } std::this_thread::sleep_for(5ms); @@ -191,8 +192,8 @@ TEST_F(MultiThreadPipelineTest, BasicPipelineConstruction) { }); // 从配置读取 Join 配置 - candy::test::PipelineConfig pipeline_cfg{}; - if (candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg)) { + sageFlow::test::PipelineConfig pipeline_cfg{}; + if (sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg)) { // 应用窗口配置 join_func_direct->setWindow(pipeline_cfg.window.time_ms, pipeline_cfg.window.trigger_interval_ms); // 通过 Stream API 构建链式算子链,并设置并行度(使用配置的 join 方法与阈值) @@ -206,7 +207,7 @@ TEST_F(MultiThreadPipelineTest, BasicPipelineConstruction) { wait_until_stable_only(std::chrono::milliseconds(100), std::chrono::seconds(10)); EXPECT_NO_THROW(env_->stop()) << "StreamEnvironment stop should not throw"; EXPECT_NO_THROW(env_->awaitTermination()) << "StreamEnvironment awaitTermination should not throw"; - CANDY_LOG_INFO("TEST", "pipeline construction 
success"); + SAGEFLOW_LOG_INFO("TEST", "pipeline construction success"); return; } // 回退:若配置加载失败,使用默认 join 参数 @@ -225,7 +226,7 @@ TEST_F(MultiThreadPipelineTest, BasicPipelineConstruction) { EXPECT_NO_THROW(env_->stop()) << "StreamEnvironment stop should not throw"; EXPECT_NO_THROW(env_->awaitTermination()) << "StreamEnvironment awaitTermination should not throw"; - CANDY_LOG_INFO("TEST", "pipeline construction success"); + SAGEFLOW_LOG_INFO("TEST", "pipeline construction success"); } TEST_F(MultiThreadPipelineTest, ParallelJoinConsistency) { @@ -245,7 +246,7 @@ TEST_F(MultiThreadPipelineTest, ParallelJoinConsistency) { auto [base_records, expected_matches] = generator.generateData(); // 打印所有记录信息,便于排查 - CANDY_LOG_INFO("TEST", "ParallelJoinConsistency dataset: records={} expected_matches_size={} dim={} ", + SAGEFLOW_LOG_INFO("TEST", "ParallelJoinConsistency dataset: records={} expected_matches_size={} dim={} ", base_records.size(), expected_matches.size(), config.vector_dim); for (size_t i = 0; i < base_records.size(); ++i) { const auto& r = base_records[i]; @@ -256,7 +257,7 @@ TEST_F(MultiThreadPipelineTest, ParallelJoinConsistency) { vals += std::to_string(vec[d]); if (d + 1 < vec.size()) vals += ","; } - CANDY_LOG_INFO("TEST", "rec#{} uid={} ts={} values=[{}] ", i, r->uid_, r->timestamp_, vals); + SAGEFLOW_LOG_INFO("TEST", "rec#{} uid={} ts={} values=[{}] ", i, r->uid_, r->timestamp_, vals); } std::vector, PairHash>> results_by_parallelism; @@ -267,22 +268,9 @@ TEST_F(MultiThreadPipelineTest, ParallelJoinConsistency) { // 确保每次循环使用全新执行环境,避免上一次执行残留的队列/线程/索引状态 if (env_) { env_->reset(); } else { env_ = std::make_shared(); } - // 为左右两侧分别复制一份数据 - std::vector> left_records; - left_records.reserve(base_records.size()); - for (const auto& rec : base_records) { - left_records.push_back(std::make_unique(*rec)); - } - std::vector> right_records; - right_records.reserve(base_records.size()); - // 给右侧流的 UID 加偏移,确保左右两侧不共享相同 UID; - // 偏移量保持在 <1e6 内,保证 left*1e6 + right 
的编码/解码逻辑仍然成立。 - constexpr uint64_t kRightUidOffset = 500000; - for (const auto& rec : base_records) { - uint64_t new_uid = rec->uid_ + kRightUidOffset; - // 复制数据与时间戳,但使用新的 UID - right_records.push_back(std::make_unique(new_uid, rec->timestamp_, rec->data_)); - } + // 使用 JoinTestHelper 生成左右流(替代手动复制逻辑) + auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromGenerator(generator, true); JoinMetrics::instance().reset(); @@ -320,10 +308,10 @@ TEST_F(MultiThreadPipelineTest, ParallelJoinConsistency) { }); // 从配置读取 Join 配置 - candy::test::PipelineConfig pipeline_cfg{}; + sageFlow::test::PipelineConfig pipeline_cfg{}; std::string method = "bruteforce_lazy"; double threshold = 0.8; - if (candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg)) { + if (sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg)) { method = pipeline_cfg.join_method; threshold = pipeline_cfg.similarity_threshold; join_func_direct->setWindow(pipeline_cfg.window.time_ms, pipeline_cfg.window.trigger_interval_ms); @@ -346,7 +334,7 @@ TEST_F(MultiThreadPipelineTest, ParallelJoinConsistency) { env_->awaitTermination(); results_by_parallelism.push_back(std::move(result_set)); - CANDY_LOG_INFO("TEST", "Parallelism {} matches={} ", parallelism, results_by_parallelism.back().size()); + SAGEFLOW_LOG_INFO("TEST", "Parallelism {} matches={} ", parallelism, results_by_parallelism.back().size()); } // 放宽一致性:对参考结果(parallelism=1)与其他并行度计算召回率,仅要求 recall >= 0.5 @@ -361,7 +349,7 @@ TEST_F(MultiThreadPipelineTest, ParallelJoinConsistency) { size_t hit = 0; for (const auto& p : ref) if (cur.count(p)) ++hit; double recall = static_cast(hit) / static_cast(ref.size()); - CANDY_LOG_INFO("TEST", "Parallelism={} recall={} ref_size={} cur_size={} ", parallelism_levels[i], recall, ref.size(), cur.size()); + SAGEFLOW_LOG_INFO("TEST", "Parallelism={} recall={} ref_size={} cur_size={} ", parallelism_levels[i], recall, 
ref.size(), cur.size()); EXPECT_GE(recall, 0.5) << "Recall below threshold for parallelism=" << parallelism_levels[i]; } } @@ -407,10 +395,10 @@ TEST_F(MultiThreadPipelineTest, StressTestMultipleRestarts) { [&](std::unique_ptr& rec) { sink_count.fetch_add(1, std::memory_order_relaxed); }); // 配置驱动 Join 方法 - candy::test::PipelineConfig pipeline_cfg_rs{}; + sageFlow::test::PipelineConfig pipeline_cfg_rs{}; std::string method_rs = "bruteforce_lazy"; double threshold_rs = 0.8; - if (candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg_rs)) { + if (sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg_rs)) { method_rs = pipeline_cfg_rs.join_method; threshold_rs = pipeline_cfg_rs.similarity_threshold; join_func_direct->setWindow(pipeline_cfg_rs.window.time_ms, pipeline_cfg_rs.window.trigger_interval_ms); @@ -429,10 +417,10 @@ TEST_F(MultiThreadPipelineTest, StressTestMultipleRestarts) { env_->stop(); env_->awaitTermination(); - CANDY_LOG_INFO("TEST", "Restart {} success", restart); + SAGEFLOW_LOG_INFO("TEST", "Restart {} success", restart); } - CANDY_LOG_INFO("TEST", "Stress test restarts completed"); + SAGEFLOW_LOG_INFO("TEST", "Stress test restarts completed"); } TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { @@ -450,19 +438,9 @@ TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { const int high_parallelism = 8; JoinMetrics::instance().reset(); - // 为左右两侧分别复制一份数据 - std::vector> left_records; - left_records.reserve(records.size()); - for (const auto& rec : records) { - left_records.push_back(std::make_unique(*rec)); - } - std::vector> right_records; - right_records.reserve(records.size()); - constexpr uint64_t kRightUidOffsetHC = 500000; - for (const auto& rec : records) { - uint64_t new_uid = rec->uid_ + kRightUidOffsetHC; - right_records.push_back(std::make_unique(new_uid, rec->timestamp_, rec->data_)); - } + // 使用 JoinTestHelper 生成左右流(替代手动复制逻辑) + 
auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromGenerator(generator, true); auto left_source = std::make_shared("HC_Left", std::move(left_records)); auto right_source = std::make_shared("HC_Right", std::move(right_records)); @@ -490,10 +468,10 @@ TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { // 构建并行 Join // 配置驱动 Join 方法 - candy::test::PipelineConfig pipeline_cfg_hc{}; + sageFlow::test::PipelineConfig pipeline_cfg_hc{}; std::string method_hc = "bruteforce_lazy"; double threshold_hc = 0.8; - if (candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg_hc)) { + if (sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg_hc)) { method_hc = pipeline_cfg_hc.join_method; threshold_hc = pipeline_cfg_hc.similarity_threshold; join_func_direct->setWindow(pipeline_cfg_hc.window.time_ms, pipeline_cfg_hc.window.trigger_interval_ms); @@ -509,7 +487,7 @@ TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { env_->stop(); env_->awaitTermination(); - CANDY_LOG_INFO("TEST", "High concurrency test completed matches={} lock_wait_ms={} ", + SAGEFLOW_LOG_INFO("TEST", "High concurrency test completed matches={} lock_wait_ms={} ", sink_count.load(), JoinMetrics::instance().lock_wait_ns.load() / 1000000); // 验证无死锁且有合理处理量(以产生结果为标志) @@ -568,7 +546,7 @@ TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { // TestDataGenerator generator(config); // auto [records, expected_matches] = generator.generateData(); // -// CANDY_LOG_INFO("TEST", "Testing configuration method={} Source={} Join={} Data={} ", +// SAGEFLOW_LOG_INFO("TEST", "Testing configuration method={} Source={} Join={} Data={} ", // method, source_parallelism, join_parallelism, data_size); // // // 使用 StreamEnvironment 构建并行流水线 @@ -614,9 +592,9 @@ TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { // [&](std::unique_ptr&) { total_matches.fetch_add(1, 
std::memory_order_relaxed); }); // // // 从配置读取默认阈值;方法由参数提供 -// candy::test::PipelineConfig pipeline_cfg_param{}; +// sageFlow::test::PipelineConfig pipeline_cfg_param{}; // double threshold_param = 0.8; -// if (candy::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg_param)) { +// if (sageFlow::test::TestConfigManager::loadPipelineConfig("config/join_pipeline_basic.toml", pipeline_cfg_param)) { // threshold_param = pipeline_cfg_param.similarity_threshold; // join_func_direct->setWindow(pipeline_cfg_param.window.time_ms, pipeline_cfg_param.window.trigger_interval_ms); // } @@ -637,7 +615,7 @@ TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { // auto end_time = std::chrono::high_resolution_clock::now(); // auto duration = std::chrono::duration_cast(end_time - start_time); // -// CANDY_LOG_INFO("TEST", "Configuration test completed duration_ms={} matches={} ", +// SAGEFLOW_LOG_INFO("TEST", "Configuration test completed duration_ms={} matches={} ", // duration.count(), total_matches.load()); // // // 验证处理完成且无崩溃 @@ -658,4 +636,4 @@ TEST_F(MultiThreadPipelineTest, HighConcurrencyDeadlockTest) { // ); } // namespace test -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/test/IntegrationTest/test_pipeline_execution.cpp b/test/IntegrationTest/test_pipeline_execution.cpp index 396ba18..706a09b 100644 --- a/test/IntegrationTest/test_pipeline_execution.cpp +++ b/test/IntegrationTest/test_pipeline_execution.cpp @@ -1,6 +1,6 @@ #include "../UnitTest/execution/test_common.h" -namespace candy { +namespace sageFlow { namespace test { /** @@ -89,4 +89,4 @@ TEST_F(MultiThreadPipelineExecutionTest, HighParallelism) { } } // namespace test -} // namespace candy +} // namespace sageFlow diff --git a/test/Performance/IndexTest.cpp b/test/Performance/IndexTest.cpp index 12ae78c..f50a7dc 100644 --- a/test/Performance/IndexTest.cpp +++ b/test/Performance/IndexTest.cpp @@ 
-4,7 +4,7 @@ #include #include // Keep for potential detailed monitoring -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS #include "operator/join_metrics.h" #endif @@ -22,13 +22,13 @@ #include "stream/data_stream_source/sift_stream_source.h" // Include SiftStreamSource explicitly using namespace std; // NOLINT -using namespace candy; // NOLINT +using namespace sageFlow; // NOLINT using namespace std::chrono; // NOLINT -const std::string CANDY_PATH = PROJECT_DIR; +const std::string SAGEFLOW_PATH = PROJECT_DIR; #define CONFIG_DIR "/config/" -namespace candy { +namespace sageFlow { // ------------------------------- @@ -47,7 +47,7 @@ void ValidateConfiguration(const ConfigMap &conf) { void SetupAndRunPipeline(const std::string &config_file_path) { StreamEnvironment env; - const auto conf = candy::StreamEnvironment::loadConfiguration(config_file_path); + const auto conf = sageFlow::StreamEnvironment::loadConfiguration(config_file_path); try { cerr << "Loading configuration..." << endl; ValidateConfiguration(conf); // Use the updated validation function @@ -311,17 +311,17 @@ void SetupAndRunPipeline(const std::string &config_file_path) { cout << "Queries Per Second (QPS): " << qps << " queries/second" << endl; } -} // namespace candy +} // namespace sageFlow // Main function remains the same auto main(int argc, char *argv[]) -> int { - const std::string default_config_file = CANDY_PATH + CONFIG_DIR + "index_test_config.toml"; + const std::string default_config_file = SAGEFLOW_PATH + CONFIG_DIR + "index_test_config.toml"; string config_file_path; if (argc < 2) { config_file_path = default_config_file; } else { - config_file_path = CANDY_PATH + CONFIG_DIR + string(argv[1]); + config_file_path = SAGEFLOW_PATH + CONFIG_DIR + string(argv[1]); } try { diff --git a/test/Performance/test_join_datasource_modes.cpp b/test/Performance/test_join_datasource_modes.cpp new file mode 100644 index 0000000..f89595a --- /dev/null +++ 
b/test/Performance/test_join_datasource_modes.cpp @@ -0,0 +1,548 @@ +#include +#include +#include +#include +#include +#include +#include +#include "operator/join_operator.h" +#include "stream/stream_environment.h" +#include "stream/data_stream_source/data_stream_source.h" +#include "function/sink_function.h" +#include "function/join_function.h" +#include "test_utils/test_data_generator.h" +#include "operator/join_metrics.h" +#include "concurrency/concurrency_manager.h" +#include "storage/storage_manager.h" +#include "test_utils/test_data_adapter.h" +#include "execution/collector.h" +#include "utils/logger.h" +#include "test_utils/dynamic_config.h" +#include "utils/log_config.h" +#include "test_utils/join_test_helper.h" +#include "test_utils/data_source/data_source_factory.h" +#include "test_utils/data_source/dataset_data_source.h" +#include "test_utils/data_writer/fvecs_writer.h" +#include "test_utils/data_writer/json_writer.h" +#include +#include +#include +#include +#include + +namespace sageFlow { +namespace test { + +// TestVectorStreamSource for feeding records into the pipeline +class TestVectorStreamSource : public DataStreamSource { + public: + explicit TestVectorStreamSource(std::string name, std::vector> records) + : DataStreamSource(std::move(name), DataStreamSourceType::None), records_(std::move(records)) {} + void Init() override { idx_=0; } + auto Next() -> std::unique_ptr override { + if (idx_ >= records_.size()) return nullptr; + return std::move(records_[idx_++]); + } + private: + std::vector> records_; + size_t idx_{0}; +}; + +// Configuration structure for data source modes tests +struct DataSourceModeConfig { + std::string name; + std::string mode; // "generate_save_load", "direct_load", "generate_direct_use" + std::vector methods; + std::vector sizes; + std::vector parallelism; + double threshold{0.8}; + std::vector win_ms_list{10000}; + uint64_t trig_ms{50}; + int vector_dim{128}; + int64_t time_interval_ms{10}; + uint32_t seed{42}; + + // 
Data source config + std::string data_source_type; // "random", "dataset", "json" + std::string data_source_file_path; + int data_source_expected_dim{128}; + bool data_source_loop{true}; + + // Storage config (for generate_save_load mode) + std::string storage_format; // "fvecs", "json" + std::string storage_file_path; +}; + +// Load configuration from TOML file +static std::vector loadDataSourceModeConfigs() { + std::vector configs; + std::vector perf_configs; + + if (!DynamicConfigManager::loadConfigs("config/perf_join_datasource_modes.toml", "performance_test", perf_configs)) { + SAGEFLOW_LOG_WARN("TEST", "Failed to load config from perf_join_datasource_modes.toml"); + return configs; + } + + // Set global log level if specified + DynamicConfig global_config; + if (DynamicConfigManager::loadConfig("config/perf_join_datasource_modes.toml", "", global_config)) { + auto log_level = global_config.get("log.level", "info"); + SAGEFLOW_LOG_INFO("TEST", "Setting log level to: {}", log_level); + sageFlow::init_log_level(log_level); + } + + for (const auto& config : perf_configs) { + DataSourceModeConfig mode_config; + mode_config.name = config.get("name", "unnamed_test"); + mode_config.mode = config.get("mode", "generate_direct_use"); + mode_config.methods = config.get>("methods", std::vector{"bruteforce_eager"}); + + auto sizes = config.get>("sizes", std::vector{}); + if (!sizes.empty()) { + mode_config.sizes = sizes; + } else { + auto records_count = config.get("records_count", 1000); + mode_config.sizes = {records_count}; + } + + mode_config.parallelism = config.get>("parallelism", std::vector{1}); + mode_config.threshold = config.get("similarity_threshold", 0.8); + + auto win_list = config.get>("window_time_ms", std::vector{}); + if (!win_list.empty()) { + mode_config.win_ms_list.clear(); + for (int v : win_list) mode_config.win_ms_list.push_back(static_cast(v)); + } else { + int win_single = config.get("window_time_ms", 10000); + mode_config.win_ms_list = 
{static_cast(win_single)}; + } + + mode_config.trig_ms = config.get("window_trigger_ms", 50); + mode_config.vector_dim = config.get("vector_dim", 128); + mode_config.time_interval_ms = config.get("time_interval", 10); + mode_config.seed = config.get("seed", 42); + + // Data source configuration + auto ds_type = config.get("data_source.type", "random"); + mode_config.data_source_type = ds_type; + mode_config.data_source_file_path = DynamicConfigManager::resolveProjectRelativePath( + config.get("data_source.file_path", "")); + + if (ds_type == "dataset") { + mode_config.data_source_expected_dim = config.get("data_source.expected_dim", 128); + int loop_val = config.get("data_source.loop", 1); + mode_config.data_source_loop = (loop_val != 0); + } + + // Storage configuration (for generate_save_load mode) + if (mode_config.mode == "generate_save_load") { + mode_config.storage_format = config.get("storage.format", "fvecs"); + mode_config.storage_file_path = DynamicConfigManager::resolveProjectRelativePath( + config.get("storage.file_path", "test/data/temp_generated.fvecs")); + } + + configs.push_back(mode_config); + + SAGEFLOW_LOG_INFO("TEST", "[CONFIG] Loaded test: name={} mode={} methods={} sizes={} vector_dim={}", + mode_config.name, mode_config.mode, mode_config.methods.size(), mode_config.sizes.size(), mode_config.vector_dim); + } + + return configs; +} + +// Compute expected matches using L2 distance and similarity threshold +static inline double l2_distance(const std::vector& a, const std::vector& b) { + double acc = 0.0; + const size_t n = std::min(a.size(), b.size()); + for (size_t i = 0; i < n; ++i) { + const double d = static_cast(a[i]) - static_cast(b[i]); + acc += d * d; + } + return std::sqrt(acc); +} + +static std::unordered_set, PairHash> + computeExpectedPairsByTraversal( + const std::vector>& left_records, + const std::vector>& right_records, + double similarity_threshold, + uint64_t window_ms, + double alpha = 0.1, + uint64_t modulo_base = 1000000ULL) 
{ + std::unordered_set, PairHash> expected; + expected.reserve(left_records.size()); + + const int64_t w = static_cast(window_ms); + size_t j_low = 0; + size_t j_high = 0; + + const size_t R = right_records.size(); + for (const auto& l : left_records) { + if (!l) continue; + const int64_t tl = l->timestamp_; + + while (j_low < R) { + const auto& rr = right_records[j_low]; + if (!rr) { ++j_low; continue; } + if (rr->timestamp_ >= tl - w) break; + ++j_low; + } + + if (j_high < j_low) j_high = j_low; + + while (j_high < R) { + const auto& rr = right_records[j_high]; + if (!rr) { ++j_high; continue; } + if (rr->timestamp_ > tl + w) break; + ++j_high; + } + + const auto lv = extractFloatVector(*l); + for (size_t j = j_low; j < j_high; ++j) { + const auto& r = right_records[j]; + if (!r) continue; + const auto rv = extractFloatVector(*r); + const double dist = l2_distance(lv, rv); + const double sim = std::exp(-alpha * dist); + if (sim >= similarity_threshold) { + expected.insert({l->uid_, r->uid_ % modulo_base}); + } + } + } + + return expected; +} + +// Test class for data source modes +class JoinDataSourceModesTest : public ::testing::TestWithParam> { +protected: + void SetUp() override { + JoinMetrics::instance().reset(); + concurrency_manager_ = std::make_shared(std::make_shared()); + } + + void TearDown() override { + std::filesystem::create_directories("build/metrics"); + std::string metrics_path = "build/metrics/join_datasource_modes_" + + std::to_string(std::chrono::system_clock::now().time_since_epoch().count()) + ".tsv"; + JoinMetrics::instance().dump_tsv(metrics_path); + SAGEFLOW_LOG_INFO("TEST", "Performance metrics saved to {}", metrics_path); + } + +protected: + std::shared_ptr concurrency_manager_; +}; + +TEST_P(JoinDataSourceModesTest, DataSourceModePerformance) { + auto [mode_config, method, data_size, parallelism, win_ms] = GetParam(); + + SAGEFLOW_LOG_INFO("TEST", "===== Running test: {} mode={} method={} size={} parallelism={} win_ms={} =====", + 
mode_config.name, mode_config.mode, method, data_size, parallelism, win_ms); + + // Prepare data based on mode + std::vector> base_records; + + if (mode_config.mode == "generate_save_load") { + // Mode 1: Generate -> Save -> Load + SAGEFLOW_LOG_INFO("TEST", "[MODE1] Generate-Save-Load: format={} path={}", + mode_config.storage_format, mode_config.storage_file_path); + + // Check if file already exists + bool file_exists = std::filesystem::exists(mode_config.storage_file_path); + + if (!file_exists) { + // Generate data + SAGEFLOW_LOG_INFO("TEST", "[MODE1] File doesn't exist, generating data"); + TestDataGenerator::Config gen_config; + gen_config.vector_dim = mode_config.vector_dim; + gen_config.similarity_threshold = mode_config.threshold; + gen_config.seed = mode_config.seed; + gen_config.base_timestamp = 1000000; + gen_config.time_interval = mode_config.time_interval_ms; + + int target_pos = static_cast(data_size * 0.10); + int target_neg = static_cast(data_size * 0.60); + int pos_pairs = target_pos / 2; + int neg_pairs = target_neg / 2; + int used = 2*pos_pairs + 2*neg_pairs; + int tail = std::max(0, data_size - used); + gen_config.positive_pairs = pos_pairs; + gen_config.near_threshold_pairs = 0; + gen_config.negative_pairs = neg_pairs; + gen_config.random_tail = tail; + + TestDataGenerator generator(gen_config); + auto [records, _] = generator.generateData(); + + // Save to file + std::filesystem::create_directories(std::filesystem::path(mode_config.storage_file_path).parent_path()); + std::shared_ptr writer; + if (mode_config.storage_format == "fvecs") { + writer = std::make_shared(); + } else { + writer = std::make_shared(); + } + generator.saveGeneratedVectors(mode_config.storage_file_path, writer); + SAGEFLOW_LOG_INFO("TEST", "[MODE1] Saved {} records to {}", records.size(), mode_config.storage_file_path); + } else { + SAGEFLOW_LOG_INFO("TEST", "[MODE1] File exists, skipping generation"); + } + + // Load from file + DatasetDataSource::Config ds_config; + 
ds_config.file_path = mode_config.storage_file_path; + ds_config.expected_dim = mode_config.vector_dim; + ds_config.loop = true; + + DatasetDataSource data_source(ds_config); + base_records.reserve(data_size); + int64_t base_ts = 1000000; + uint64_t uid = 1; + while (data_source.hasMore() && base_records.size() < static_cast(data_size)) { + auto vec = data_source.getNextVector(); + auto record = createVectorRecord(uid++, base_ts, vec); + base_ts += mode_config.time_interval_ms; + base_records.push_back(std::move(record)); + } + SAGEFLOW_LOG_INFO("TEST", "[MODE1] Loaded {} records from file", base_records.size()); + + } else if (mode_config.mode == "direct_load") { + // Mode 2: Direct Load from existing dataset + SAGEFLOW_LOG_INFO("TEST", "[MODE2] Direct-Load from: {}", mode_config.data_source_file_path); + + DatasetDataSource::Config ds_config; + ds_config.file_path = mode_config.data_source_file_path; + ds_config.expected_dim = mode_config.data_source_expected_dim; + ds_config.loop = mode_config.data_source_loop; + + DatasetDataSource data_source(ds_config); + base_records.reserve(data_size); + int64_t base_ts = 1000000; + uint64_t uid = 1; + while (data_source.hasMore() && base_records.size() < static_cast(data_size)) { + auto vec = data_source.getNextVector(); + auto record = createVectorRecord(uid++, base_ts, vec); + base_ts += mode_config.time_interval_ms; + base_records.push_back(std::move(record)); + } + SAGEFLOW_LOG_INFO("TEST", "[MODE2] Loaded {} records directly from dataset", base_records.size()); + + } else { + // Mode 3: Generate and use directly (no file I/O) + SAGEFLOW_LOG_INFO("TEST", "[MODE3] Generate-Direct-Use (no file I/O)"); + + TestDataGenerator::Config gen_config; + gen_config.vector_dim = mode_config.vector_dim; + gen_config.similarity_threshold = mode_config.threshold; + gen_config.seed = mode_config.seed; + gen_config.base_timestamp = 1000000; + gen_config.time_interval = mode_config.time_interval_ms; + + int target_pos = 
static_cast(data_size * 0.10); + int target_neg = static_cast(data_size * 0.60); + int pos_pairs = target_pos / 2; + int neg_pairs = target_neg / 2; + int used = 2*pos_pairs + 2*neg_pairs; + int tail = std::max(0, data_size - used); + gen_config.positive_pairs = pos_pairs; + gen_config.near_threshold_pairs = 0; + gen_config.negative_pairs = neg_pairs; + gen_config.random_tail = tail; + + TestDataGenerator generator(gen_config); + auto [records, _] = generator.generateData(); + base_records = std::move(records); + SAGEFLOW_LOG_INFO("TEST", "[MODE3] Generated {} records directly", base_records.size()); + } + + // Split into left and right streams using JoinTestHelper (already refactored pattern) + std::vector> left_records; + left_records.reserve(base_records.size()); + for (auto& r : base_records) { + left_records.push_back(std::move(r)); + } + + std::vector> right_records; + right_records.reserve(left_records.size()); + constexpr uint64_t kRightUidOffset = 500000; + for (auto& lr : left_records) { + right_records.push_back(std::make_unique(lr->uid_ + kRightUidOffset, lr->timestamp_, lr->data_)); + } + + const size_t expected_left = left_records.size(); + const size_t expected_right = right_records.size(); + + // Compute expected matches + auto expected_matches = computeExpectedPairsByTraversal(left_records, right_records, mode_config.threshold, win_ms); + + // Create stream sources + auto left_source = std::make_shared("DataModeLeft", std::move(left_records)); + auto right_source = std::make_shared("DataModeRight", std::move(right_records)); + + // Create join function + auto join_func = std::make_unique( + "DataModeJoin", + [](std::unique_ptr& left, + std::unique_ptr& right) -> std::unique_ptr { + auto lv = extractFloatVector(*left); + auto rv = extractFloatVector(*right); + std::vector out; + out.reserve(lv.size() + rv.size()); + out.insert(out.end(), lv.begin(), lv.end()); + out.insert(out.end(), rv.begin(), rv.end()); + uint64_t id = left->uid_ * 1000000 + 
right->uid_ % 1000000; + int64_t ts = std::max(left->timestamp_, right->timestamp_); + return createVectorRecord(id, ts, out); + }, mode_config.vector_dim); + join_func->setWindow(win_ms, mode_config.trig_ms); + + // Collect matches + std::mutex match_mutex; + std::unordered_set, PairHash> actual_pairs; + auto sink_func = std::make_unique("DataModeSink", [&](std::unique_ptr& rec){ + if (!rec) return; + uint64_t cid = rec->uid_; + uint64_t lid = cid / 1000000; + uint64_t rid = cid % 1000000; + std::lock_guard lg(match_mutex); + actual_pairs.insert({lid, rid}); + }); + + // Build pipeline + left_source->join(right_source, std::move(join_func), method, mode_config.threshold, (size_t)parallelism) + ->writeSink(std::move(sink_func), 1); + + // Execute + StreamEnvironment env; + JoinMetrics::instance().reset(); + env.addStream(left_source); + env.addStream(right_source); + + auto start_time = std::chrono::high_resolution_clock::now(); + env.execute(); + + // Wait for completion + { + using namespace std::chrono_literals; + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(600); + for (;;) { + uint64_t l = JoinMetrics::instance().total_records_left.load(); + uint64_t r = JoinMetrics::instance().total_records_right.load(); + if (l >= expected_left && r >= expected_right) break; + if (std::chrono::steady_clock::now() >= deadline) { + SAGEFLOW_LOG_WARN("TEST", "Timeout waiting for processing: left={}/{} right={}/{}", l, expected_left, r, expected_right); + break; + } + std::this_thread::sleep_for(5ms); + } + + // Wait for output stabilization + const auto stable_window = 50ms; + const auto max_wait = std::chrono::seconds(5); + uint64_t last = JoinMetrics::instance().total_emits.load(); + auto stable_since = std::chrono::steady_clock::now(); + auto end_by = std::chrono::steady_clock::now() + max_wait; + while (std::chrono::steady_clock::now() < end_by) { + std::this_thread::sleep_for(5ms); + uint64_t cur = 
JoinMetrics::instance().total_emits.load(); + if (cur != last) { last = cur; stable_since = std::chrono::steady_clock::now(); } + if (std::chrono::steady_clock::now() - stable_since >= stable_window) break; + } + } + + env.stop(); + env.awaitTermination(); + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end_time - start_time); + + // Calculate metrics + size_t match_count = 0; + for (auto ap : actual_pairs) { + if (expected_matches.count(ap)) match_count++; + } + + double recall = expected_matches.empty() ? 1.0 : static_cast(match_count) / static_cast(expected_matches.size()); + double precision = actual_pairs.empty() ? 0.0 : static_cast(match_count) / static_cast(actual_pairs.size()); + double f1 = (precision + recall) > 0 ? 2 * precision * recall / (precision + recall) : 0.0; + + SAGEFLOW_LOG_INFO("TEST", "Result: name={} mode={} method={} size={} parallelism={} time_ms={} matches={}/{} recall={:.3f} precision={:.3f} f1={:.3f}", + mode_config.name, mode_config.mode, method, data_size, parallelism, duration.count(), + match_count, expected_matches.size(), recall, precision, f1); + + // Write to report file + try { + const auto report_dir = +#ifdef PROJECT_DIR + std::filesystem::path(PROJECT_DIR) / "test" / "result" +#else + std::filesystem::current_path() / "test" / "result" +#endif + ; + std::filesystem::create_directories(report_dir); + const auto report_path_fs = report_dir / "datasource_modes_report.tsv"; + std::string report_path = report_path_fs.string(); + bool new_file = !std::filesystem::exists(report_path); + std::ofstream ofs(report_path, std::ios::app); + if (ofs.is_open()) { + if (new_file) { + ofs << "test_name\tmode\tmethod\tsize\tparallelism\twin_ms\ttime_ms\tmatches\texpected\trecall\tprecision\tf1\n"; + } + ofs << mode_config.name << '\t' << mode_config.mode << '\t' << method << '\t' << data_size << '\t' + << parallelism << '\t' << win_ms << '\t' << duration.count() << '\t' << match_count 
<< '\t' + << expected_matches.size() << '\t' << recall << '\t' << precision << '\t' << f1 << '\n'; + ofs.flush(); + SAGEFLOW_LOG_INFO("TEST", "Report written to {}", report_path); + } + } catch(const std::exception &e) { + SAGEFLOW_LOG_WARN("TEST", "Failed to write report: {}", e.what()); + } + + // Assertions + EXPECT_GE(recall, 0.85) << "Recall too low for " << mode_config.name; + EXPECT_GE(precision, 0.85) << "Precision too low for " << mode_config.name; +} + +// Generate test parameters +static std::vector> buildTestParams() { + std::vector> params; + + auto configs = loadDataSourceModeConfigs(); + for (const auto& config : configs) { + for (const auto& method : config.methods) { + for (int size : config.sizes) { + for (int par : config.parallelism) { + for (uint64_t win : config.win_ms_list) { + params.push_back({config, method, size, par, win}); + SAGEFLOW_LOG_INFO("TEST", "[PARAM] Generated test case: {} mode={} method={} size={} par={} win={}", + config.name, config.mode, method, size, par, win); + } + } + } + } + } + + return params; +} + +INSTANTIATE_TEST_SUITE_P( + DataSourceModes, + JoinDataSourceModesTest, + ::testing::ValuesIn(buildTestParams()), + [](const ::testing::TestParamInfo>& info) { + const DataSourceModeConfig& config = std::get<0>(info.param); + const std::string& method = std::get<1>(info.param); + int size = std::get<2>(info.param); + int parallelism = std::get<3>(info.param); + uint64_t win_ms = std::get<4>(info.param); + + std::string name = config.name + "_" + method + "_" + std::to_string(size) + "_p" + + std::to_string(parallelism) + "_w" + std::to_string(win_ms); + // Replace invalid characters + std::replace(name.begin(), name.end(), '/', '_'); + std::replace(name.begin(), name.end(), '.', '_'); + return name; + } +); + +} // namespace test +} // namespace sageFlow diff --git a/test/Performance/test_join_perf_scaling.cpp b/test/Performance/test_join_perf_scaling.cpp index e5ea102..ec2706e 100644 --- 
a/test/Performance/test_join_perf_scaling.cpp +++ b/test/Performance/test_join_perf_scaling.cpp @@ -19,13 +19,14 @@ #include "utils/logger.h" #include "test_utils/dynamic_config.h" #include "utils/log_config.h" +#include "test_utils/join_test_helper.h" #include #include #include #include #include -namespace candy { +namespace sageFlow { namespace test { // 本地复制 TestVectorStreamSource(避免跨测试文件定义冲突/依赖) @@ -57,7 +58,7 @@ class JoinPerformanceTest : public ::testing::Test { std::string metrics_path = "build/metrics/join_perf_" + std::to_string(std::chrono::system_clock::now().time_since_epoch().count()) + ".tsv"; JoinMetrics::instance().dump_tsv(metrics_path); - CANDY_LOG_INFO("TEST", "Performance metrics saved path={} ", metrics_path); + SAGEFLOW_LOG_INFO("TEST", "Performance metrics saved path={} ", metrics_path); } // 创建JoinFunction(使用 test_data_adapter 助手),支持配置维度与窗口 @@ -105,27 +106,27 @@ inline PerfConfigSets loadPerfConfig() { // 目前只取第一个配置块,若需要可扩展成多组 const auto& config = perf_configs.front(); - out.threshold = config.get("similarity_threshold", out.threshold); - // 注意:DynamicConfig 将整数优先存为 int,因此这里按 int 读取以避免类型不匹配导致默认值生效 - // window_time_ms 既支持数组也支持单值 - { - auto win_list = config.get>("window_time_ms", std::vector{}); - if (!win_list.empty()) { - out.win_ms_list.clear(); - out.win_ms_list.reserve(win_list.size()); - for (int v : win_list) out.win_ms_list.push_back(static_cast(v)); - } else { - int win_single = config.get("window_time_ms", 2000); - out.win_ms_list = { static_cast(win_single) }; + out.threshold = config.get("similarity_threshold", out.threshold); + // 注意:DynamicConfig 将整数优先存为 int,因此这里按 int 读取以避免类型不匹配导致默认值生效 + // window_time_ms 既支持数组也支持单值 + { + auto win_list = config.get>("window_time_ms", std::vector{}); + if (!win_list.empty()) { + out.win_ms_list.clear(); + out.win_ms_list.reserve(win_list.size()); + for (int v : win_list) out.win_ms_list.push_back(static_cast(v)); + } else { + int win_single = config.get("window_time_ms", 2000); + out.win_ms_list 
= { static_cast(win_single) }; + } } - } - - auto trigger_ms = config.get("window_trigger_ms", 0); - if (trigger_ms > 0) out.trig_ms = static_cast(trigger_ms); + + auto trigger_ms = config.get("window_trigger_ms", 0); + if (trigger_ms > 0) out.trig_ms = static_cast(trigger_ms); out.methods = config.get>("methods", out.methods); out.parallelism = config.get>("parallelism", out.parallelism); - out.vector_dim = config.get("vector_dim", out.vector_dim); + out.vector_dim = config.get("vector_dim", out.vector_dim); // 读取 time_interval(可选) out.time_interval_ms = config.get("time_interval", static_cast(out.time_interval_ms)); @@ -146,14 +147,14 @@ inline PerfConfigSets loadPerfConfig() { for (size_t i=0;i("log.level", "info"); std::cout << "[PerfTest] Setting log level to: " << log_level << std::endl; - candy::init_log_level(log_level); + sageFlow::init_log_level(log_level); } return out; @@ -242,15 +243,15 @@ class JoinScalingTest : public ::testing::TestWithParam { std::unique_ptr createSimpleJoinFunction() { auto join_func_lambda = [](std::unique_ptr& left, std::unique_ptr& right) -> std::unique_ptr { - auto lv = extractFloatVector(*left); - auto rv = extractFloatVector(*right); - std::vector out; - out.reserve(lv.size() + rv.size()); - out.insert(out.end(), lv.begin(), lv.end()); - out.insert(out.end(), rv.begin(), rv.end()); - uint64_t id = left->uid_ * 1000000 + right->uid_; - int64_t ts = std::max(left->timestamp_, right->timestamp_); - return createVectorRecord(id, ts, out); + auto lv = extractFloatVector(*left); + auto rv = extractFloatVector(*right); + std::vector out; + out.reserve(lv.size() + rv.size()); + out.insert(out.end(), lv.begin(), lv.end()); + out.insert(out.end(), rv.begin(), rv.end()); + uint64_t id = left->uid_ * 1000000 + right->uid_; + int64_t ts = std::max(left->timestamp_, right->timestamp_); + return createVectorRecord(id, ts, out); }; return std::make_unique("SimpleJoin", join_func_lambda, 128); @@ -272,7 +273,7 @@ TEST_P(JoinScalingTest, 
PerformanceScaling) { } // 开始前打印本轮参数(方法/规模/并行度) - CANDY_LOG_INFO("TEST", "[BEGIN] method={} size={} parallelism={} ", method, data_size, parallelism); + SAGEFLOW_LOG_INFO("TEST", "[BEGIN] method={} size={} parallelism={} ", method, data_size, parallelism); static PerfConfigSets g_sets_for_dim = loadPerfConfig(); TestDataGenerator::Config config; config.vector_dim = g_sets_for_dim.vector_dim; config.similarity_threshold = 0.8; @@ -304,7 +305,7 @@ TEST_P(JoinScalingTest, PerformanceScaling) { uint64_t trig_ms = g_sets.trig_ms; double threshold_override = g_sets.threshold; // 打印本轮的窗口/阈值等关键参数 - CANDY_LOG_INFO("TEST", "[PARAM] threshold={} win_ms={} trig_ms={} time_interval_ms={} ", threshold_override, win_ms, trig_ms, g_sets.time_interval_ms); + SAGEFLOW_LOG_INFO("TEST", "[PARAM] threshold={} win_ms={} trig_ms={} time_interval_ms={} ", threshold_override, win_ms, trig_ms, g_sets.time_interval_ms); // 构建环境与 Source StreamEnvironment env; @@ -380,7 +381,7 @@ TEST_P(JoinScalingTest, PerformanceScaling) { uint64_t r = JoinMetrics::instance().total_records_right.load(); if (l >= expected_left && r >= expected_right) break; if (std::chrono::steady_clock::now() >= deadline) { - CANDY_LOG_WARN("TEST", "wait_for_processed timeout l={}/{} r={}/{}", l, expected_left, r, expected_right); + SAGEFLOW_LOG_WARN("TEST", "wait_for_processed timeout l={}/{} r={}/{}", l, expected_left, r, expected_right); break; } std::this_thread::sleep_for(5ms); @@ -408,11 +409,11 @@ TEST_P(JoinScalingTest, PerformanceScaling) { // 精准匹配统计 size_t match_count = 0; for (auto ap : actual_pairs) { - // CANDY_LOG_INFO("TEST", " Actual match: L={} R={} ", ap.first, ap.second); + // SAGEFLOW_LOG_INFO("TEST", " Actual match: L={} R={} ", ap.first, ap.second); if (expected_matches.count(ap)) match_count++; } for (auto ep : expected_matches) { - // CANDY_LOG_INFO("TEST", " Expected match: L={} R={} ", ep.first, ep.second); + // SAGEFLOW_LOG_INFO("TESRT", " Expected match: L={} R={} ", ep.first, ep.second); } double 
recall = expected_matches.empty() ? 1.0 : static_cast(match_count) / static_cast(expected_matches.size()); @@ -420,7 +421,7 @@ TEST_P(JoinScalingTest, PerformanceScaling) { double precision = static_cast(match_count)/static_cast(actual_pairs.size()); double f1 = (precision+recall)>0 ? 2*precision*recall/(precision+recall):0.0; - CANDY_LOG_INFO("TEST", "Method={} Size={} Parallelism={} time_ms={} matches={} expected={} recall={} precision={} f1={} win_ms={} trig_ms={} ", + SAGEFLOW_LOG_INFO("TEST", "Method={} Size={} Parallelism={} time_ms={} matches={} expected={} recall={} precision={} f1={} win_ms={} trig_ms={} ", method, data_size, parallelism, duration.count(), match_count, expected_matches.size(), recall, precision, f1, win_ms, trig_ms); // 将结果追加写入报告 @@ -438,7 +439,7 @@ TEST_P(JoinScalingTest, PerformanceScaling) { bool new_file = !std::filesystem::exists(report_path); std::ofstream ofs(report_path, std::ios::app); if (!ofs.is_open()) { - CANDY_LOG_WARN("TEST", "[REPORT] cannot open {} for write", report_path); + SAGEFLOW_LOG_WARN("TEST", "[REPORT] cannot open {} for write", report_path); return; // 放弃写报告,但不影响测试断言 } if (new_file) { @@ -475,9 +476,9 @@ TEST_P(JoinScalingTest, PerformanceScaling) { << JoinMetrics::instance().candidate_fetch_ns.load() << '\t' << input_tput_rps << '\t' << output_tput_rps << '\t' << avg_apply_ms << '\t' << avg_e2e_ms << '\n'; ofs.flush(); - CANDY_LOG_INFO("TEST", "[REPORT] appended to {}", report_path); + SAGEFLOW_LOG_INFO("TEST", "[REPORT] appended to {}", report_path); } catch(const std::exception &e) { - CANDY_LOG_WARN("TEST", "write_report_failed what={} ", e.what()); + SAGEFLOW_LOG_WARN("TEST", "write_report_failed what={} ", e.what()); } // 基本性能与锁争用检测 @@ -501,7 +502,7 @@ static std::vector buildParams() { for (auto par : sets.parallelism) for (auto win : sets.win_ms_list) { - CANDY_LOG_INFO("TEST", "[PARAMGEN] method={} size={} parallelism={} win_ms={} ", m, sz, par, win); + SAGEFLOW_LOG_INFO("TEST", "[PARAMGEN] method={} 
size={} parallelism={} win_ms={} ", m, sz, par, win); params.push_back({m, sz, par, win}); } return params; @@ -517,104 +518,96 @@ TEST_F(JoinPerformanceTest, MethodSpeedComparison) { for (auto data_size : sets.sizes) { for (auto par : sets.parallelism) { for (auto win_ms : sets.win_ms_list) { - CANDY_LOG_INFO("TEST", "[BEGIN] MethodSpeedComparison size={} parallelism={} win_ms={} ", data_size, par, win_ms); - CANDY_LOG_INFO("TEST", "[PARAM] threshold={} win_ms={} trig_ms={} ", sets.threshold, win_ms, sets.trig_ms); - - std::vector> method_times; - - for (const auto& method : sets.methods) { - // 为当前方法生成确定性数据(按 data_size 严格对齐) - TestDataGenerator::Config cfg; cfg.vector_dim = sets.vector_dim; cfg.similarity_threshold = sets.threshold; cfg.seed = 42; cfg.time_interval = sets.time_interval_ms; - { - int target_pos = static_cast(data_size * 0.10); - int target_near = 0; // 可按需开启近邻样本 - int target_neg = static_cast(data_size * 0.60); - int pos_pairs = target_pos / 2; - int near_pairs = target_near / 2; - int neg_pairs = target_neg / 2; - int used = 2*pos_pairs + 2*near_pairs + 2*neg_pairs; - int tail = std::max(0, data_size - used); - cfg.positive_pairs = pos_pairs; - cfg.near_threshold_pairs = near_pairs; - cfg.negative_pairs = neg_pairs; - cfg.random_tail = tail; - } - - TestDataGenerator gen(cfg); - auto [records, expected_matches] = gen.generateData(); - - // 切分左右流,右侧UID偏移,基于流式管道进行计时(以体现并行度) - std::vector> left_records; - left_records.reserve(records.size()); - for (auto &r : records) left_records.push_back(std::move(r)); - - std::vector> right_records; - right_records.reserve(left_records.size()); - constexpr uint64_t kRightUidOffset = 500000; - for (auto &lr : left_records) { - right_records.push_back(std::make_unique(lr->uid_ + kRightUidOffset, lr->timestamp_, lr->data_)); - } - - // 记录期望输入计数后再 move 到 Source - const size_t expected_left = left_records.size(); - const size_t expected_right = right_records.size(); - auto left_source = std::make_shared("MSLeft", 
std::move(left_records)); - auto right_source = std::make_shared("MSRight", std::move(right_records)); - - auto join_func = createSimpleJoinFunction(sets.vector_dim, win_ms, sets.trig_ms); - - std::atomic match_count{0}; - auto sink_func = std::make_unique("MSSink", [&](std::unique_ptr& rec){ if (rec) match_count++; }); - - StreamEnvironment env; - JoinMetrics::instance().reset(); - - left_source->join(right_source, std::move(join_func), method, sets.threshold, (size_t)par) - ->writeSink(std::move(sink_func), 1); - env.addStream(left_source); - env.addStream(right_source); + SAGEFLOW_LOG_INFO("TEST", "[BEGIN] MethodSpeedComparison size={} parallelism={} win_ms={} ", data_size, par, win_ms); + SAGEFLOW_LOG_INFO("TEST", "[PARAM] threshold={} win_ms={} trig_ms={} ", sets.threshold, win_ms, sets.trig_ms); + + std::vector> method_times; + + for (const auto& method : sets.methods) { + // 为当前方法生成确定性数据(按 data_size 严格对齐) + TestDataGenerator::Config cfg; cfg.vector_dim = sets.vector_dim; cfg.similarity_threshold = sets.threshold; cfg.seed = 42; cfg.time_interval = sets.time_interval_ms; + { + int target_pos = static_cast(data_size * 0.10); + int target_near = 0; // 可按需开启近邻样本 + int target_neg = static_cast(data_size * 0.60); + int pos_pairs = target_pos / 2; + int near_pairs = target_near / 2; + int neg_pairs = target_neg / 2; + int used = 2*pos_pairs + 2*near_pairs + 2*neg_pairs; + int tail = std::max(0, data_size - used); + cfg.positive_pairs = pos_pairs; + cfg.near_threshold_pairs = near_pairs; + cfg.negative_pairs = neg_pairs; + cfg.random_tail = tail; + } - auto start = std::chrono::high_resolution_clock::now(); - env.execute(); - // 等待直到 JoinOperator 消费完所有输入(以指标计数为准) - { - using namespace std::chrono_literals; - const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); - for (;;) { - uint64_t l = JoinMetrics::instance().total_records_left.load(); - uint64_t r = JoinMetrics::instance().total_records_right.load(); - if (l >= expected_left && r >= 
expected_right) break; - if (std::chrono::steady_clock::now() >= deadline) { - CANDY_LOG_WARN("TEST", "wait_for_processed timeout l={}/{} r={}/{}", l, expected_left, r, expected_right); - break; + TestDataGenerator gen(cfg); + auto [records, expected_matches] = gen.generateData(); + + // 使用 JoinTestHelper 生成左右流(替代手动复制逻辑) + auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromGenerator(gen, true); + + // 记录期望输入计数后再 move 到 Source + const size_t expected_left = left_records.size(); + const size_t expected_right = right_records.size(); + auto left_source = std::make_shared("MSLeft", std::move(left_records)); + auto right_source = std::make_shared("MSRight", std::move(right_records)); + + auto join_func = createSimpleJoinFunction(sets.vector_dim, win_ms, sets.trig_ms); + + std::atomic match_count{0}; + auto sink_func = std::make_unique("MSSink", [&](std::unique_ptr& rec){ if (rec) match_count++; }); + + StreamEnvironment env; + JoinMetrics::instance().reset(); + + left_source->join(right_source, std::move(join_func), method, sets.threshold, (size_t)par) + ->writeSink(std::move(sink_func), 1); + env.addStream(left_source); + env.addStream(right_source); + + auto start = std::chrono::high_resolution_clock::now(); + env.execute(); + // 等待直到 JoinOperator 消费完所有输入(以指标计数为准) + { + using namespace std::chrono_literals; + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(60); + for (;;) { + uint64_t l = JoinMetrics::instance().total_records_left.load(); + uint64_t r = JoinMetrics::instance().total_records_right.load(); + if (l >= expected_left && r >= expected_right) break; + if (std::chrono::steady_clock::now() >= deadline) { + SAGEFLOW_LOG_WARN("TEST", "wait_for_processed timeout l={}/{} r={}/{}", l, expected_left, r, expected_right); + break; + } + std::this_thread::sleep_for(5ms); } - std::this_thread::sleep_for(5ms); } + env.stop(); + env.awaitTermination(); + auto end = std::chrono::high_resolution_clock::now(); + auto 
duration = std::chrono::duration_cast(end - start); + + method_times.emplace_back(method, duration.count()); + SAGEFLOW_LOG_INFO("TEST", "Method={} Size={} Par={} time_ms={} matches={} win_ms={} trig_ms={} ", + method, data_size, par, duration.count(), (size_t)match_count.load(), win_ms, sets.trig_ms); } - env.stop(); - env.awaitTermination(); - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start); - - method_times.emplace_back(method, duration.count()); - CANDY_LOG_INFO("TEST", "Method={} Size={} Par={} time_ms={} matches={} win_ms={} trig_ms={} ", - method, data_size, par, duration.count(), (size_t)match_count.load(), win_ms, sets.trig_ms); - } - // 验证:在较大规模(>=5000)下 IVF 应当快于 Bruteforce(放宽倍数关系以避免偶然波动) - auto bruteforce_time = std::find_if(method_times.begin(), method_times.end(), - [](const auto& p) { return p.first == "bruteforce_eager"; }); - auto ivf_time = std::find_if(method_times.begin(), method_times.end(), - [](const auto& p) { return p.first == "ivf_eager"; }); + // 验证:在较大规模(>=5000)下 IVF 应当快于 Bruteforce(放宽倍数关系以避免偶然波动) + auto bruteforce_time = std::find_if(method_times.begin(), method_times.end(), + [](const auto& p) { return p.first == "bruteforce_eager"; }); + auto ivf_time = std::find_if(method_times.begin(), method_times.end(), + [](const auto& p) { return p.first == "ivf_eager"; }); - if (data_size >= 5000 && bruteforce_time != method_times.end() && ivf_time != method_times.end()) { - EXPECT_LT(ivf_time->second * 2, bruteforce_time->second) - << "IVF should be faster than BruteForce for large datasets (size=" << data_size << ", par=" << par << ")"; - } + if (data_size >= 5000 && bruteforce_time != method_times.end() && ivf_time != method_times.end()) { + EXPECT_LT(ivf_time->second * 2, bruteforce_time->second) + << "IVF should be faster than BruteForce for large datasets (size=" << data_size << ", par=" << par << ")"; + } } } } } } // namespace test -} // namespace candy +} // namespace 
sageFlow diff --git a/test/Performance/test_join_performance_methods.cpp b/test/Performance/test_join_performance_methods.cpp index fc40b95..2afa93d 100644 --- a/test/Performance/test_join_performance_methods.cpp +++ b/test/Performance/test_join_performance_methods.cpp @@ -20,19 +20,19 @@ #include "test_utils/test_data_adapter.h" #include "execution/collector.h" -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS #include "operator/join_metrics.h" #endif #ifndef PROJECT_DIR -#define PROJECT_DIR "d:/Share Libary/candyFlow_zero" +#define PROJECT_DIR "d:/Share Libary/sageFlowFlow_zero" #endif using namespace std; -using namespace candy; +using namespace sageFlow; using namespace std::chrono; -namespace candy { +namespace sageFlow { // 模拟结果收集器,用于替代emit功能 class JoinResultCollector { @@ -309,40 +309,40 @@ class JoinMethodPerformanceTester { } }; -} // namespace candy +} // namespace sageFlow // Google Test 测试用例 TEST(JoinMethodPerformanceTest, BruteForceEager) { - candy::JoinMethodPerformanceTester tester; - double execution_time = tester.RunJoinMethodTest(candy::JoinMethodType::BRUTEFORCE_EAGER, 500); + sageFlow::JoinMethodPerformanceTester tester; + double execution_time = tester.RunJoinMethodTest(sageFlow::JoinMethodType::BRUTEFORCE_EAGER, 500); EXPECT_GT(execution_time, 0); } TEST(JoinMethodPerformanceTest, BruteForceLazy) { - candy::JoinMethodPerformanceTester tester; - double execution_time = tester.RunJoinMethodTest(candy::JoinMethodType::BRUTEFORCE_LAZY, 500); + sageFlow::JoinMethodPerformanceTester tester; + double execution_time = tester.RunJoinMethodTest(sageFlow::JoinMethodType::BRUTEFORCE_LAZY, 500); EXPECT_GT(execution_time, 0); } TEST(JoinMethodPerformanceTest, IVFEager) { - candy::JoinMethodPerformanceTester tester; - double execution_time = tester.RunJoinMethodTest(candy::JoinMethodType::IVF_EAGER, 500); + sageFlow::JoinMethodPerformanceTester tester; + double execution_time = tester.RunJoinMethodTest(sageFlow::JoinMethodType::IVF_EAGER, 500); 
EXPECT_GT(execution_time, 0); } TEST(JoinMethodPerformanceTest, IvfLazy) { - candy::JoinMethodPerformanceTester tester; - double execution_time = tester.RunJoinMethodTest(candy::JoinMethodType::IVF_EAGER, 500); + sageFlow::JoinMethodPerformanceTester tester; + double execution_time = tester.RunJoinMethodTest(sageFlow::JoinMethodType::IVF_EAGER, 500); EXPECT_GT(execution_time, 0); } TEST(JoinMethodPerformanceTest, CompareAllMethods) { - candy::JoinMethodPerformanceTester tester; + sageFlow::JoinMethodPerformanceTester tester; ASSERT_NO_THROW(tester.RunComparisonTest(500)); } TEST(JoinMethodPerformanceTest, ScalabilityTest) { - candy::JoinMethodPerformanceTester tester; + sageFlow::JoinMethodPerformanceTester tester; ASSERT_NO_THROW(tester.RunScalabilityTest()); } @@ -354,7 +354,7 @@ int main(int argc, char* argv[]) { // 如果命令行参数指定了特定的测试模式 if (argc > 1) { string mode = argv[1]; - candy::JoinMethodPerformanceTester tester; + sageFlow::JoinMethodPerformanceTester tester; if (mode == "performance") { cout << "Running Join Method Performance Comparison Test..." << endl; @@ -366,13 +366,13 @@ int main(int argc, char* argv[]) { return 0; } else if (mode == "eager") { cout << "Testing BruteForce Eager method..." << endl; - tester.RunJoinMethodTest(candy::JoinMethodType::BRUTEFORCE_EAGER, 1000); - tester.RunJoinMethodTest(candy::JoinMethodType::IVF_EAGER, 1000); + tester.RunJoinMethodTest(sageFlow::JoinMethodType::BRUTEFORCE_EAGER, 1000); + tester.RunJoinMethodTest(sageFlow::JoinMethodType::IVF_EAGER, 1000); return 0; } else if (mode == "lazy") { cout << "Testing BruteForce Lazy method..." 
<< endl; - tester.RunJoinMethodTest(candy::JoinMethodType::BRUTEFORCE_LAZY, 1000); - tester.RunJoinMethodTest(candy::JoinMethodType::IVF_LAZY, 1000); + tester.RunJoinMethodTest(sageFlow::JoinMethodType::BRUTEFORCE_LAZY, 1000); + tester.RunJoinMethodTest(sageFlow::JoinMethodType::IVF_LAZY, 1000); return 0; } } diff --git a/test/Performance/test_window_pipeline.cpp b/test/Performance/test_window_pipeline.cpp index e9e2482..916ec0f 100644 --- a/test/Performance/test_window_pipeline.cpp +++ b/test/Performance/test_window_pipeline.cpp @@ -16,11 +16,11 @@ #include #include "utils/logger.h" -#ifdef CANDY_ENABLE_METRICS +#ifdef SAGEFLOW_ENABLE_METRICS #include "operator/join_metrics.h" #endif -using namespace candy; +using namespace sageFlow; using namespace std; using namespace std::chrono; @@ -29,7 +29,7 @@ TEST(WindowTest, TumblingWindowPipeline) { // Create a simple stream source auto input_path = "./data/siftsmall/siftsmall_query.fvecs"; - CANDY_LOG_INFO("TEST", "Using SimpleStream path={} ", input_path); + SAGEFLOW_LOG_INFO("TEST", "Using SimpleStream path={} ", input_path); auto source_stream = make_shared("FilePerfSource", input_path); // Atomic counter for processed records @@ -49,7 +49,7 @@ TEST(WindowTest, TumblingWindowPipeline) { auto end_time = high_resolution_clock::now(); processed_count.fetch_add(1, memory_order_relaxed); lock_guard lock(data_mutex); - CANDY_LOG_INFO("TEST", "GET uid={} ", record->uid_); + SAGEFLOW_LOG_INFO("TEST", "GET uid={} ", record->uid_); auto it = start_times.find(record->uid_); if (it != start_times.end()) { auto start_time = it->second; @@ -85,7 +85,7 @@ TEST(WindowTest, SlidingWindowPipeline) { // Create a simple stream source auto input_path = "./data/siftsmall/siftsmall_query.fvecs"; - CANDY_LOG_INFO("TEST", "Using SimpleStream path={} ", input_path); + SAGEFLOW_LOG_INFO("TEST", "Using SimpleStream path={} ", input_path); auto source_stream = make_shared("FilePerfSource", input_path); // Atomic counter for processed records @@ 
-106,7 +106,7 @@ TEST(WindowTest, SlidingWindowPipeline) { processed_count.fetch_add(1, memory_order_relaxed); lock_guard lock(data_mutex); - CANDY_LOG_INFO("TEST", "GET uid={} ", record->uid_); + SAGEFLOW_LOG_INFO("TEST", "GET uid={} ", record->uid_); auto it = start_times.find(record->uid_); if (it != start_times.end()) { diff --git a/test/UnitTest/test_compute_engine.cpp b/test/UnitTest/test_compute_engine.cpp index 75f7507..b5af055 100644 --- a/test/UnitTest/test_compute_engine.cpp +++ b/test/UnitTest/test_compute_engine.cpp @@ -5,21 +5,21 @@ #include #include -std::unique_ptr makeVec(const std::vector& v) { - auto d = std::make_unique(v.size(), candy::DataType::Float32); +std::unique_ptr makeVec(const std::vector& v) { + auto d = std::make_unique(v.size(), sageFlow::DataType::Float32); std::memcpy(d->data_.get(), v.data(), v.size()*sizeof(float)); return d; } TEST(SimilarityTest, Identical) { - candy::ComputeEngine eng; + sageFlow::ComputeEngine eng; auto a = makeVec({1,2,3}), b = makeVec({1,2,3}); double sim = eng.Similarity(*a,*b, 0.5); EXPECT_DOUBLE_EQ(sim, 1.0); } TEST(SimilarityTest, KnownDist) { - candy::ComputeEngine eng; + sageFlow::ComputeEngine eng; auto a = makeVec({0,0}), b = makeVec({3,4}); double alpha = 0.1; double expected = std::exp(-alpha * 5.0); // 距离为 5 @@ -27,20 +27,20 @@ TEST(SimilarityTest, KnownDist) { } TEST(SimilarityTest, ZeroDim) { - candy::ComputeEngine eng; + sageFlow::ComputeEngine eng; auto a = makeVec({}), b = makeVec({}); EXPECT_DOUBLE_EQ(eng.Similarity(*a,*b,1.0), 1.0); } TEST(SimilarityTest, GreaterSimilarity) { - candy::ComputeEngine eng; + sageFlow::ComputeEngine eng; auto a = makeVec({0,0,0}), b = makeVec({1,1,1}); auto c = makeVec({2,2,2}); EXPECT_GT(eng.Similarity(*a,*b,1.0), eng.Similarity(*a,*c,1.0)); } TEST(DistanceTest, ComputeEuclideanDistance) { - candy::ComputeEngine eng; + sageFlow::ComputeEngine eng; auto a = makeVec({0,0}), b = makeVec({3,4}); EXPECT_DOUBLE_EQ(eng.EuclideanDistance(*a,*b), 5.0); } diff --git 
a/test/UnitTest/test_data_persistence.cpp b/test/UnitTest/test_data_persistence.cpp new file mode 100644 index 0000000..6b5c98c --- /dev/null +++ b/test/UnitTest/test_data_persistence.cpp @@ -0,0 +1,221 @@ +#include +#include "test_utils/test_data_generator.h" +#include "test_utils/data_source/random_data_source.h" +#include "test_utils/data_source/dataset_data_source.h" +#include "test_utils/data_source/json_data_source.h" +#include "test_utils/data_writer/fvecs_writer.h" +#include "test_utils/data_writer/json_writer.h" +#include +#include +#include + +namespace sageFlow { +namespace test { + +class DataPersistenceTest : public ::testing::Test { +protected: + void SetUp() override { + test_dir_ = "/tmp/sageflow_test_data"; + std::filesystem::create_directories(test_dir_); + } + + void TearDown() override { + // Clean up test files + std::filesystem::remove_all(test_dir_); + } + + std::string test_dir_; +}; + +// Test saving generated data to fvecs format +TEST_F(DataPersistenceTest, SaveToFvecsFormat) { + // Generate test data + TestDataGenerator::Config config; + config.vector_dim = 64; + config.positive_pairs = 10; + config.negative_pairs = 10; + config.random_tail = 20; + config.seed = 42; + + TestDataGenerator generator(config); + auto [records, matches] = generator.generateData(); + + // Save to fvecs + std::string fvecs_path = test_dir_ + "/test_data.fvecs"; + auto writer = std::make_shared(); + bool success = generator.saveGeneratedVectors(fvecs_path, writer); + + EXPECT_TRUE(success); + EXPECT_TRUE(std::filesystem::exists(fvecs_path)); + + // Verify file is not empty + std::ifstream check(fvecs_path, std::ios::binary | std::ios::ate); + EXPECT_TRUE(check.is_open()); + auto file_size = check.tellg(); + EXPECT_GT(file_size, 0); + check.close(); + + // Expected: (10 + 10 + 10) * 2 + 20 = 80 vectors + int expected_vectors = (config.positive_pairs + config.near_threshold_pairs + config.negative_pairs) * 2 + config.random_tail; + EXPECT_EQ(records.size(), 
expected_vectors); +} + +// Test saving generated data to JSON format +TEST_F(DataPersistenceTest, SaveToJsonFormat) { + // Generate test data + TestDataGenerator::Config config; + config.vector_dim = 32; + config.positive_pairs = 5; + config.negative_pairs = 5; + config.random_tail = 10; + config.seed = 123; + + TestDataGenerator generator(config); + auto [records, matches] = generator.generateData(); + + // Save to JSON + std::string json_path = test_dir_ + "/test_data.json"; + auto writer = std::make_shared(); + bool success = generator.saveGeneratedVectors(json_path, writer); + + EXPECT_TRUE(success); + EXPECT_TRUE(std::filesystem::exists(json_path)); + + // Verify JSON is human-readable + std::ifstream json_file(json_path); + std::string first_line; + std::getline(json_file, first_line); + EXPECT_TRUE(first_line.find("{") != std::string::npos); + json_file.close(); +} + +// Test round-trip: save to fvecs, load back, and verify +TEST_F(DataPersistenceTest, RoundTripFvecs) { + // Generate and save + TestDataGenerator::Config config; + config.vector_dim = 128; + config.positive_pairs = 10; + config.negative_pairs = 10; + config.random_tail = 20; + config.seed = 42; + + TestDataGenerator generator(config); + auto [records, matches] = generator.generateData(); + auto original_vectors = generator.getLastGeneratedVectors(); + + std::string fvecs_path = test_dir_ + "/roundtrip.fvecs"; + auto writer = std::make_shared(); + ASSERT_TRUE(generator.saveGeneratedVectors(fvecs_path, writer)); + + // Load back + DatasetDataSource::Config ds_config; + ds_config.file_path = fvecs_path; + ds_config.expected_dim = 128; + auto data_source = std::make_shared(ds_config); + + // Verify dimension + EXPECT_EQ(data_source->getDimension(), 128); + EXPECT_EQ(data_source->getTotalCount(), original_vectors.size()); + + // Verify vectors match + int count = 0; + while (data_source->hasMore() && count < static_cast(original_vectors.size())) { + auto loaded_vec = data_source->getNextVector(); 
+ ASSERT_EQ(loaded_vec.size(), original_vectors[count].size()); + + // Check values match (with floating point tolerance) + for (size_t i = 0; i < loaded_vec.size(); ++i) { + EXPECT_NEAR(loaded_vec[i], original_vectors[count][i], 1e-5); + } + count++; + } + EXPECT_EQ(count, original_vectors.size()); +} + +// Test round-trip: save to JSON, load back, and verify +TEST_F(DataPersistenceTest, RoundTripJson) { + // Generate and save + TestDataGenerator::Config config; + config.vector_dim = 64; + config.positive_pairs = 5; + config.negative_pairs = 5; + config.random_tail = 10; + config.seed = 999; + + TestDataGenerator generator(config); + auto [records, matches] = generator.generateData(); + auto original_vectors = generator.getLastGeneratedVectors(); + + std::string json_path = test_dir_ + "/roundtrip.json"; + auto writer = std::make_shared(); + ASSERT_TRUE(generator.saveGeneratedVectors(json_path, writer)); + + // Load back + JsonDataSource::Config ds_config; + ds_config.file_path = json_path; + auto data_source = std::make_shared(ds_config); + + // Verify dimension + EXPECT_EQ(data_source->getDimension(), 64); + EXPECT_EQ(data_source->getTotalCount(), original_vectors.size()); + + // Verify vectors match + int count = 0; + while (data_source->hasMore() && count < static_cast(original_vectors.size())) { + auto loaded_vec = data_source->getNextVector(); + ASSERT_EQ(loaded_vec.size(), original_vectors[count].size()); + + // Check values match (JSON has 6 decimal precision) + for (size_t i = 0; i < loaded_vec.size(); ++i) { + EXPECT_NEAR(loaded_vec[i], original_vectors[count][i], 1e-5); + } + count++; + } + EXPECT_EQ(count, original_vectors.size()); +} + +// Test using loaded data with TestDataGenerator +TEST_F(DataPersistenceTest, GenerateFromSavedData) { + // First generate and save some data + TestDataGenerator::Config config1; + config1.vector_dim = 64; + config1.positive_pairs = 10; + config1.negative_pairs = 10; + config1.random_tail = 20; + config1.seed = 42; + + 
TestDataGenerator generator1(config1); + auto [records1, matches1] = generator1.generateData(); + + std::string fvecs_path = test_dir_ + "/source_data.fvecs"; + auto writer = std::make_shared(); + ASSERT_TRUE(generator1.saveGeneratedVectors(fvecs_path, writer)); + + // Now load that data and use it with TestDataGenerator + DatasetDataSource::Config ds_config; + ds_config.file_path = fvecs_path; + ds_config.expected_dim = 64; + ds_config.loop = true; // Enable looping for reuse + auto data_source = std::make_shared(ds_config); + + TestDataGenerator::Config config2; + config2.similarity_threshold = 0.8; + config2.positive_pairs = 5; + config2.near_threshold_pairs = 0; // Set to 0 to avoid defaults + config2.negative_pairs = 5; + config2.random_tail = 10; + + TestDataGenerator generator2(config2, data_source); + auto [records2, matches2] = generator2.generateData(); + + // Verify records were created + int expected_records2 = (5 + 0 + 5) * 2 + 10; + EXPECT_EQ(records2.size(), expected_records2); + + // All records should use dimension from loaded data + for (const auto& record : records2) { + EXPECT_EQ(record->data_.dim_, 64); + } +} + +}} // namespace sageFlow::test diff --git a/test/UnitTest/test_data_source.cpp b/test/UnitTest/test_data_source.cpp new file mode 100644 index 0000000..f43962b --- /dev/null +++ b/test/UnitTest/test_data_source.cpp @@ -0,0 +1,184 @@ +#include +#include "test_utils/test_data_generator.h" +#include "test_utils/data_source/random_data_source.h" +#include "test_utils/data_source/dataset_data_source.h" +#include +#include + +namespace sageFlow { +namespace test { + +class DataSourceTest : public ::testing::Test { +protected: + void SetUp() override {} + void TearDown() override {} +}; + +// Test RandomDataSource +TEST_F(DataSourceTest, RandomDataSourceBasic) { + RandomDataSource::Config config; + config.vector_dim = 64; + config.seed = 42; + config.max_vectors = 10; + + auto data_source = std::make_shared(config); + + 
EXPECT_EQ(data_source->getDimension(), 64); + EXPECT_TRUE(data_source->hasMore()); + EXPECT_EQ(data_source->getTotalCount(), 10); + + int count = 0; + while (data_source->hasMore()) { + auto vec = data_source->getNextVector(); + EXPECT_EQ(vec.size(), 64); + + // Check that vector is normalized + float norm = 0.0f; + for (float v : vec) { + norm += v * v; + } + norm = std::sqrt(norm); + EXPECT_NEAR(norm, 1.0f, 1e-5f); + + count++; + } + EXPECT_EQ(count, 10); + + // After reset, should be able to get more vectors + data_source->reset(); + EXPECT_TRUE(data_source->hasMore()); + auto vec = data_source->getNextVector(); + EXPECT_EQ(vec.size(), 64); +} + +// Test DatasetDataSource with siftsmall dataset +TEST_F(DataSourceTest, DatasetDataSourceBasic) { + // Check if the dataset file exists + std::string project_dir = PROJECT_DIR; + std::string dataset_path = project_dir + "/data/siftsmall/siftsmall_query.fvecs"; + std::ifstream test_file(dataset_path); + if (!test_file.good()) { + GTEST_SKIP() << "Dataset file not found: " << dataset_path; + } + test_file.close(); + + DatasetDataSource::Config config; + config.file_path = dataset_path; + config.expected_dim = 128; + config.loop = false; + + auto data_source = std::make_shared(config); + + EXPECT_EQ(data_source->getDimension(), 128); + EXPECT_TRUE(data_source->hasMore()); + EXPECT_GT(data_source->getTotalCount(), 0); + + int initial_count = data_source->getTotalCount(); + std::cout << "Loaded " << initial_count << " vectors from dataset" << std::endl; + + // Get a few vectors + int count = 0; + while (data_source->hasMore() && count < 5) { + auto vec = data_source->getNextVector(); + EXPECT_EQ(vec.size(), 128); + count++; + } + EXPECT_EQ(count, 5); + + // Reset and read again + data_source->reset(); + EXPECT_TRUE(data_source->hasMore()); + auto vec = data_source->getNextVector(); + EXPECT_EQ(vec.size(), 128); +} + +// Test TestDataGenerator with custom data source +TEST_F(DataSourceTest, 
TestDataGeneratorWithRandomDataSource) { + // Create a random data source + RandomDataSource::Config ds_config; + ds_config.vector_dim = 64; + ds_config.seed = 123; + ds_config.max_vectors = -1; // Unlimited + + auto data_source = std::make_shared(ds_config); + + // Create TestDataGenerator with the data source + TestDataGenerator::Config gen_config; + gen_config.vector_dim = 64; + gen_config.positive_pairs = 10; + gen_config.negative_pairs = 10; + gen_config.near_threshold_pairs = 5; + gen_config.random_tail = 20; + gen_config.similarity_threshold = 0.8; + + TestDataGenerator generator(gen_config, data_source); + + auto [records, expected_matches] = generator.generateData(); + + // Verify data was generated + int expected_records = (10 + 5 + 10) * 2 + 20; + EXPECT_EQ(records.size(), expected_records); + EXPECT_GE(expected_matches.size(), 10); // At least positive pairs +} + +// Test TestDataGenerator with dataset data source +TEST_F(DataSourceTest, TestDataGeneratorWithDatasetDataSource) { + // Check if the dataset file exists + std::string project_dir = PROJECT_DIR; + std::string dataset_path = project_dir + "/data/siftsmall/siftsmall_query.fvecs"; + std::ifstream test_file(dataset_path); + if (!test_file.good()) { + GTEST_SKIP() << "Dataset file not found: " << dataset_path; + } + test_file.close(); + + // Create a dataset data source with looping enabled + DatasetDataSource::Config ds_config; + ds_config.file_path = dataset_path; + ds_config.expected_dim = 128; + ds_config.loop = true; // Enable looping to allow reuse + + auto data_source = std::make_shared(ds_config); + + // Create TestDataGenerator with the data source + TestDataGenerator::Config gen_config; + gen_config.vector_dim = 128; + gen_config.positive_pairs = 5; + gen_config.negative_pairs = 5; + gen_config.near_threshold_pairs = 2; + gen_config.random_tail = 10; + gen_config.similarity_threshold = 0.8; + + TestDataGenerator generator(gen_config, data_source); + + auto [records, expected_matches] = 
generator.generateData(); + + // Verify data was generated + int expected_records = (5 + 2 + 5) * 2 + 10; + EXPECT_EQ(records.size(), expected_records); + + // All records should have dimension 128 + for (const auto& record : records) { + EXPECT_EQ(record->data_.dim_, 128); + } +} + +// Test backward compatibility - default constructor still works +TEST_F(DataSourceTest, BackwardCompatibility) { + TestDataGenerator::Config config; + config.vector_dim = 64; + config.positive_pairs = 5; + config.negative_pairs = 5; + config.near_threshold_pairs = 2; + config.random_tail = 10; + + // Use default constructor (should create random data source internally) + TestDataGenerator generator(config); + + auto [records, expected_matches] = generator.generateData(); + + int expected_records = (5 + 2 + 5) * 2 + 10; + EXPECT_EQ(records.size(), expected_records); +} + +}} // namespace sageFlow::test diff --git a/test/UnitTest/test_file_stream_source.cpp b/test/UnitTest/test_file_stream_source.cpp index a67e36d..7c784d7 100644 --- a/test/UnitTest/test_file_stream_source.cpp +++ b/test/UnitTest/test_file_stream_source.cpp @@ -6,7 +6,7 @@ #include "stream/data_stream_source/file_stream_source.h" #include "common/data_types.h" -namespace candy { +namespace sageFlow { TEST(FileStreamSourceTest, BasicLoad) { // Create a temporary file with one serialized VectorRecord auto test_file = std::filesystem::temp_directory_path() / "test_source.dat"; @@ -72,4 +72,4 @@ TEST(FileStreamSourceTest, LargeLoad) { } EXPECT_EQ(count, 1500); } -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/test/UnitTest/test_join_bruteforce.cpp b/test/UnitTest/test_join_bruteforce.cpp index 18802b1..6c574d0 100644 --- a/test/UnitTest/test_join_bruteforce.cpp +++ b/test/UnitTest/test_join_bruteforce.cpp @@ -11,7 +11,7 @@ #include "storage/storage_manager.h" #include "execution/collector.h" -namespace candy { +namespace sageFlow { namespace test { class 
JoinBruteForceTest : public ::testing::Test { @@ -28,7 +28,7 @@ class JoinBruteForceTest : public ::testing::Test { void TearDown() override { if (::testing::Test::HasFailure()) { - CANDY_LOG_WARN("TEST", "BF Test failed. Metrics: WIN={}ns IDX={}ns SIM={}ns ", + SAGEFLOW_LOG_WARN("TEST", "BF Test failed. Metrics: WIN={}ns IDX={}ns SIM={}ns ", JoinMetrics::instance().window_insert_ns.load(), JoinMetrics::instance().index_insert_ns.load(), JoinMetrics::instance().similarity_ns.load()); @@ -264,4 +264,4 @@ INSTANTIATE_TEST_SUITE_P( ); } // namespace test -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/test/UnitTest/test_join_data_source.cpp b/test/UnitTest/test_join_data_source.cpp new file mode 100644 index 0000000..4ea8cac --- /dev/null +++ b/test/UnitTest/test_join_data_source.cpp @@ -0,0 +1,243 @@ +#include +#include "test_utils/join_data_source.h" +#include "test_utils/join_test_helper.h" +#include "test_utils/test_data_generator.h" +#include "test_utils/data_source/random_data_source.h" +#include "test_utils/data_source/dataset_data_source.h" +#include "test_utils/test_data_adapter.h" +#include +#include + +namespace sageFlow { +namespace test { + +class JoinDataSourceTest : public ::testing::Test { +protected: + void SetUp() override {} + void TearDown() override {} +}; + +// Test basic duplication mode +TEST_F(JoinDataSourceTest, DuplicateMode) { + // Create a simple random source + RandomDataSource::Config config; + config.vector_dim = 64; + config.seed = 42; + config.max_vectors = 50; + auto source = std::make_shared(config); + + // Create join pair in duplicate mode + auto join_config = JoinDataSourceFactory::createDuplicated(source, true); + JoinDataSourcePair pair(join_config); + + // Generate streams + auto [left_records, right_records] = pair.generateStreams(); + + // Verify + EXPECT_EQ(left_records.size(), 50); + EXPECT_EQ(right_records.size(), 50); + EXPECT_EQ(pair.getDimension(), 
64); + + // Verify UIDs are offset for right stream + for (size_t i = 0; i < left_records.size(); ++i) { + EXPECT_LT(left_records[i]->uid_, right_records[i]->uid_); + // Check that vectors are the same (duplicated) + auto left_vec = extractFloatVector(*left_records[i]); + auto right_vec = extractFloatVector(*right_records[i]); + ASSERT_EQ(left_vec.size(), right_vec.size()); + for (size_t j = 0; j < left_vec.size(); ++j) { + EXPECT_FLOAT_EQ(left_vec[j], right_vec[j]); + } + } +} + +// Test separate sources mode +TEST_F(JoinDataSourceTest, SeparateMode) { + // Create two different random sources with different seeds + RandomDataSource::Config config1; + config1.vector_dim = 32; + config1.seed = 111; + config1.max_vectors = 30; + auto left_source = std::make_shared(config1); + + RandomDataSource::Config config2; + config2.vector_dim = 32; + config2.seed = 222; + config2.max_vectors = 30; + auto right_source = std::make_shared(config2); + + // Create join pair in separate mode + auto join_config = JoinDataSourceFactory::createSeparate( + left_source, right_source, false); + JoinDataSourcePair pair(join_config); + + // Generate streams + auto [left_records, right_records] = pair.generateStreams(); + + // Verify + EXPECT_EQ(left_records.size(), 30); + EXPECT_EQ(right_records.size(), 30); + EXPECT_EQ(pair.getDimension(), 32); + + // Verify vectors are different (separate sources) + bool found_difference = false; + for (size_t i = 0; i < std::min(left_records.size(), right_records.size()); ++i) { + auto left_vec = extractFloatVector(*left_records[i]); + auto right_vec = extractFloatVector(*right_records[i]); + + for (size_t j = 0; j < left_vec.size(); ++j) { + if (std::abs(left_vec[j] - right_vec[j]) > 0.001f) { + found_difference = true; + break; + } + } + if (found_difference) break; + } + EXPECT_TRUE(found_difference) << "Expected different vectors from separate sources"; +} + +// Test helper function for TestDataGenerator (backward compatible) 
+TEST_F(JoinDataSourceTest, HelperWithTestDataGenerator) { + TestDataGenerator::Config config; + config.vector_dim = 128; + config.positive_pairs = 10; + config.near_threshold_pairs = 0; // Set to 0 to avoid defaults + config.negative_pairs = 10; + config.random_tail = 20; + config.seed = 99; + + TestDataGenerator generator(config); + + // Use helper to generate join streams + auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromGenerator(generator, true); + + // Verify counts + int expected = (10 + 0 + 10) * 2 + 20; // pairs * 2 + tail + EXPECT_EQ(left_records.size(), expected); + EXPECT_EQ(right_records.size(), expected); + + // Verify dimension + for (const auto& rec : left_records) { + EXPECT_EQ(rec->data_.dim_, 128); + } +} + +// Test helper with single source +TEST_F(JoinDataSourceTest, HelperWithSingleSource) { + RandomDataSource::Config config; + config.vector_dim = 64; + config.seed = 777; + config.max_vectors = 25; + auto source = std::make_shared(config); + + // Use helper to generate join streams + auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromSource(source, true, 25); + + EXPECT_EQ(left_records.size(), 25); + EXPECT_EQ(right_records.size(), 25); +} + +// Test helper with separate sources +TEST_F(JoinDataSourceTest, HelperWithSeparateSources) { + RandomDataSource::Config config1; + config1.vector_dim = 32; + config1.seed = 123; + config1.max_vectors = 15; + auto left_source = std::make_shared(config1); + + RandomDataSource::Config config2; + config2.vector_dim = 32; + config2.seed = 456; + config2.max_vectors = 15; + auto right_source = std::make_shared(config2); + + // Use helper + auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromSeparateSources( + left_source, right_source, false, 15); + + EXPECT_EQ(left_records.size(), 15); + EXPECT_EQ(right_records.size(), 15); +} + +// Test with dataset file (if available) +TEST_F(JoinDataSourceTest, WithDatasetSource) { + 
std::string dataset_path = PROJECT_DIR "/data/siftsmall/siftsmall_query.fvecs"; + std::ifstream test_file(dataset_path); + if (!test_file.good()) { + GTEST_SKIP() << "Dataset file not found: " << dataset_path; + } + test_file.close(); + + // Load dataset + DatasetDataSource::Config config; + config.file_path = dataset_path; + config.expected_dim = 128; + config.loop = false; + auto source = std::make_shared(config); + + // Use with join data source + auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromSource(source, true, 50); + + EXPECT_EQ(left_records.size(), 50); + EXPECT_EQ(right_records.size(), 50); + + // Verify all records have correct dimension + for (const auto& rec : left_records) { + EXPECT_EQ(rec->data_.dim_, 128); + } +} + +// Test max_records limiting +TEST_F(JoinDataSourceTest, MaxRecordsLimit) { + RandomDataSource::Config config; + config.vector_dim = 32; + config.seed = 999; + config.max_vectors = 100; + auto source = std::make_shared(config); + + auto join_config = JoinDataSourceFactory::createDuplicated(source, true); + JoinDataSourcePair pair(join_config); + + // Generate with limit + auto [left_records, right_records] = pair.generateStreams(25); + + EXPECT_EQ(left_records.size(), 25); + EXPECT_EQ(right_records.size(), 25); +} + +// Test reset functionality +TEST_F(JoinDataSourceTest, ResetFunctionality) { + RandomDataSource::Config config; + config.vector_dim = 16; + config.seed = 555; + config.max_vectors = 10; + auto source = std::make_shared(config); + + auto join_config = JoinDataSourceFactory::createDuplicated(source, true); + JoinDataSourcePair pair(join_config); + + // Generate first time + auto [left1, right1] = pair.generateStreams(); + EXPECT_EQ(left1.size(), 10); + + // Reset and generate again + pair.reset(); + auto [left2, right2] = pair.generateStreams(); + EXPECT_EQ(left2.size(), 10); + + // Verify data is the same (same seed, reset source) + for (size_t i = 0; i < left1.size(); ++i) { + auto vec1 = 
extractFloatVector(*left1[i]); + auto vec2 = extractFloatVector(*left2[i]); + ASSERT_EQ(vec1.size(), vec2.size()); + for (size_t j = 0; j < vec1.size(); ++j) { + EXPECT_FLOAT_EQ(vec1[j], vec2[j]); + } + } +} + +}} // namespace sageFlow::test diff --git a/test/UnitTest/test_join_ivf.cpp b/test/UnitTest/test_join_ivf.cpp index 8242497..a40d293 100644 --- a/test/UnitTest/test_join_ivf.cpp +++ b/test/UnitTest/test_join_ivf.cpp @@ -11,7 +11,7 @@ #include "storage/storage_manager.h" #include "execution/collector.h" -namespace candy { +namespace sageFlow { namespace test { // 通用 JoinFunction 工厂,供本文件所有测试复用 @@ -45,7 +45,7 @@ class JoinIVFTest : public ::testing::Test { void TearDown() override { if (::testing::Test::HasFailure()) { - CANDY_LOG_WARN("TEST", "IVF Test failed. Metrics: IDX={}ns CAND={}ns EMITS={} ", + SAGEFLOW_LOG_WARN("TEST", "IVF Test failed. Metrics: IDX={}ns CAND={}ns EMITS={} ", JoinMetrics::instance().index_insert_ns.load(), JoinMetrics::instance().candidate_fetch_ns.load(), JoinMetrics::instance().total_emits.load()); @@ -130,7 +130,7 @@ TEST_F(JoinIVFTest, IVFLargeScale) { uint64_t end_time = std::chrono::duration_cast( std::chrono::high_resolution_clock::now().time_since_epoch()).count(); - CANDY_LOG_INFO("TEST", "IVF LargeScale duration_ms={} expected={} actual={} ", (end_time - start_time) / 1000000, expected_matches.size(), actual_matches.size()); + SAGEFLOW_LOG_INFO("TEST", "IVF LargeScale duration_ms={} expected={} actual={} ", (end_time - start_time) / 1000000, expected_matches.size(), actual_matches.size()); // 仅验证大规模 pipeline 能正常跑完(不超时不崩溃) SUCCEED() << "IVFLargeScale executed without timeout/crash." 
@@ -225,7 +225,7 @@ TEST_P(IVFParameterizedTest, ParameterVariations) { actual_matches.insert({left_uid, right_uid}); } - CANDY_LOG_INFO("TEST", "Method={} Dim={} Threshold={} Expected={} Actual={} ", method, vector_dim, threshold, expected_matches.size(), actual_matches.size()); + SAGEFLOW_LOG_INFO("TEST", "Method={} Dim={} Threshold={} Expected={} Actual={} ", method, vector_dim, threshold, expected_matches.size(), actual_matches.size()); // 仅验证参数化场景能正常跑完(不超时不崩溃) SUCCEED() << "IVF ParameterVariations executed without timeout/crash." @@ -244,4 +244,4 @@ INSTANTIATE_TEST_SUITE_P( ); } // namespace test -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/test/examples/data_persistence_example.cpp b/test/examples/data_persistence_example.cpp new file mode 100644 index 0000000..c7e758e --- /dev/null +++ b/test/examples/data_persistence_example.cpp @@ -0,0 +1,225 @@ +// Example: Data Persistence - Generate, Save, and Load Test Data +// +// This example demonstrates: +// 1. Generating test data with TestDataGenerator +// 2. Saving to multiple formats (FVECS and JSON) +// 3. Loading saved data back +// 4. 
Using loaded data with TestDataGenerator + +#include +#include "test_utils/test_data_generator.h" +#include "test_utils/data_source/random_data_source.h" +#include "test_utils/data_source/dataset_data_source.h" +#include "test_utils/data_source/json_data_source.h" +#include "test_utils/data_writer/fvecs_writer.h" +#include "test_utils/data_writer/json_writer.h" +#include "test_utils/test_data_adapter.h" + +using namespace sageFlow::test; + +void example_save_generated_data() { + std::cout << "\n=== Example 1: Generate and Save Data ===" << std::endl; + + // Configure generator + TestDataGenerator::Config config; + config.vector_dim = 64; + config.positive_pairs = 50; + config.negative_pairs = 50; + config.random_tail = 100; + config.seed = 42; + + // Generate test data + TestDataGenerator generator(config); + std::cout << "Generating test data..." << std::endl; + auto [records, matches] = generator.generateData(); + + std::cout << "Generated " << records.size() << " records" << std::endl; + std::cout << "Expected matches: " << matches.size() << std::endl; + + // Save to FVECS format (binary, efficient) + std::string fvecs_path = "/tmp/test_data.fvecs"; + auto fvecs_writer = std::make_shared(); + if (generator.saveGeneratedVectors(fvecs_path, fvecs_writer)) { + std::cout << "✓ Saved to FVECS: " << fvecs_path << std::endl; + } + + // Save to JSON format (human-readable) + std::string json_path = "/tmp/test_data.json"; + auto json_writer = std::make_shared(); + if (generator.saveGeneratedVectors(json_path, json_writer)) { + std::cout << "✓ Saved to JSON: " << json_path << std::endl; + } +} + +void example_load_from_fvecs() { + std::cout << "\n=== Example 2: Load from FVECS File ===" << std::endl; + + // Load data from FVECS file + DatasetDataSource::Config config; + config.file_path = "/tmp/test_data.fvecs"; + config.expected_dim = 64; + config.loop = false; + + try { + auto data_source = std::make_shared(config); + + std::cout << "Dimension: " << 
data_source->getDimension() << std::endl; + std::cout << "Total vectors: " << data_source->getTotalCount() << std::endl; + + // Read first 3 vectors + int count = 0; + while (data_source->hasMore() && count < 3) { + auto vec = data_source->getNextVector(); + std::cout << "Vector " << count << " (first 5 components): "; + for (int i = 0; i < 5; ++i) { + std::cout << vec[i] << " "; + } + std::cout << std::endl; + count++; + } + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + } +} + +void example_load_from_json() { + std::cout << "\n=== Example 3: Load from JSON File ===" << std::endl; + + // Load data from JSON file + JsonDataSource::Config config; + config.file_path = "/tmp/test_data.json"; + config.loop = false; + + try { + auto data_source = std::make_shared(config); + + std::cout << "Dimension: " << data_source->getDimension() << std::endl; + std::cout << "Total vectors: " << data_source->getTotalCount() << std::endl; + + // Read first 3 vectors + int count = 0; + while (data_source->hasMore() && count < 3) { + auto vec = data_source->getNextVector(); + std::cout << "Vector " << count << " (first 5 components): "; + for (int i = 0; i < 5; ++i) { + std::cout << vec[i] << " "; + } + std::cout << std::endl; + count++; + } + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + } +} + +void example_reuse_saved_data() { + std::cout << "\n=== Example 4: Reuse Saved Data with TestDataGenerator ===" << std::endl; + + // Load previously saved data + DatasetDataSource::Config ds_config; + ds_config.file_path = "/tmp/test_data.fvecs"; + ds_config.expected_dim = 64; + ds_config.loop = true; // Enable looping for reuse + + try { + auto data_source = std::make_shared(ds_config); + + // Use loaded data to generate new test dataset + TestDataGenerator::Config gen_config; + gen_config.similarity_threshold = 0.8; + gen_config.positive_pairs = 20; + gen_config.near_threshold_pairs = 0; + 
gen_config.negative_pairs = 20; + gen_config.random_tail = 40; + + TestDataGenerator generator(gen_config, data_source); + auto [records, matches] = generator.generateData(); + + std::cout << "Generated " << records.size() << " records from loaded data" << std::endl; + std::cout << "Expected matches: " << matches.size() << std::endl; + + // Show first few records + int count = 0; + for (const auto& record : records) { + if (count >= 5) break; + auto vec = extractFloatVector(*record); + std::cout << "Record " << count + << " - UID: " << record->uid_ + << ", TS: " << record->timestamp_ + << ", Dim: " << vec.size() << std::endl; + count++; + } + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + } +} + +void example_workflow() { + std::cout << "\n=== Example 5: Complete Workflow ===" << std::endl; + + // Step 1: Generate reference dataset + std::cout << "\nStep 1: Generate reference dataset" << std::endl; + TestDataGenerator::Config ref_config; + ref_config.vector_dim = 128; + ref_config.positive_pairs = 100; + ref_config.negative_pairs = 100; + ref_config.random_tail = 200; + ref_config.seed = 12345; // Fixed seed for reproducibility + + TestDataGenerator ref_generator(ref_config); + ref_generator.generateData(); + + std::string ref_path = "/tmp/reference_dataset_v1.fvecs"; + auto writer = std::make_shared(); + ref_generator.saveGeneratedVectors(ref_path, writer); + std::cout << "✓ Saved reference dataset: " << ref_path << std::endl; + + // Step 2: Use reference dataset in test + std::cout << "\nStep 2: Load and use reference dataset" << std::endl; + DatasetDataSource::Config load_config; + load_config.file_path = ref_path; + load_config.expected_dim = 128; + load_config.loop = true; + + auto ref_source = std::make_shared(load_config); + + TestDataGenerator::Config test_config; + test_config.positive_pairs = 50; + test_config.near_threshold_pairs = 0; + test_config.negative_pairs = 50; + test_config.random_tail = 100; + + 
TestDataGenerator test_generator(test_config, ref_source); + auto [test_records, test_matches] = test_generator.generateData(); + + std::cout << "✓ Generated " << test_records.size() + << " test records from reference dataset" << std::endl; + + // Step 3: Save test-specific variant + std::cout << "\nStep 3: Save test-specific variant for debugging" << std::endl; + std::string debug_path = "/tmp/test_variant_debug.json"; + auto json_writer = std::make_shared(); + test_generator.saveGeneratedVectors(debug_path, json_writer); + std::cout << "✓ Saved debug variant: " << debug_path << std::endl; + std::cout << " (You can inspect this JSON file to debug test failures)" << std::endl; +} + +int main() { + std::cout << "Data Persistence Examples" << std::endl; + std::cout << "=========================" << std::endl; + + example_save_generated_data(); + example_load_from_fvecs(); + example_load_from_json(); + example_reuse_saved_data(); + example_workflow(); + + std::cout << "\n✓ All examples completed!" << std::endl; + std::cout << "\nGenerated files in /tmp:" << std::endl; + std::cout << " - test_data.fvecs (binary format)" << std::endl; + std::cout << " - test_data.json (human-readable format)" << std::endl; + std::cout << " - reference_dataset_v1.fvecs (reference dataset)" << std::endl; + std::cout << " - test_variant_debug.json (debug variant)" << std::endl; + + return 0; +} diff --git a/test/examples/test_data_source_example.cpp b/test/examples/test_data_source_example.cpp new file mode 100644 index 0000000..0d91496 --- /dev/null +++ b/test/examples/test_data_source_example.cpp @@ -0,0 +1,146 @@ +// Example: Using DatasetDataSource to generate test data from real datasets +// +// This example demonstrates how to use the data source framework to: +// 1. Load vectors from a dataset file +// 2. Generate test data with dataset vectors +// 3. 
Use the data with join operators + +#include +#include "test_utils/test_data_generator.h" +#include "test_utils/data_source/random_data_source.h" +#include "test_utils/data_source/dataset_data_source.h" +#include "test_utils/test_data_adapter.h" + +using namespace sageFlow::test; + +void example_random_data_source() { + std::cout << "\n=== Example 1: Random Data Source ===" << std::endl; + + // Configure random data source + RandomDataSource::Config ds_config; + ds_config.vector_dim = 64; + ds_config.seed = 42; + ds_config.max_vectors = 100; + + auto data_source = std::make_shared(ds_config); + + std::cout << "Dimension: " << data_source->getDimension() << std::endl; + std::cout << "Total vectors: " << data_source->getTotalCount() << std::endl; + + // Get first 5 vectors + int count = 0; + while (data_source->hasMore() && count < 5) { + auto vec = data_source->getNextVector(); + std::cout << "Vector " << count << " (first 5 components): "; + for (int i = 0; i < 5; ++i) { + std::cout << vec[i] << " "; + } + std::cout << std::endl; + count++; + } +} + +void example_dataset_data_source() { + std::cout << "\n=== Example 2: Dataset Data Source ===" << std::endl; + + // Configure dataset data source + DatasetDataSource::Config ds_config; + ds_config.file_path = PROJECT_DIR "/data/siftsmall/siftsmall_query.fvecs"; + ds_config.expected_dim = 128; + ds_config.loop = false; + + try { + auto data_source = std::make_shared(ds_config); + + std::cout << "Dimension: " << data_source->getDimension() << std::endl; + std::cout << "Total vectors: " << data_source->getTotalCount() << std::endl; + + // Get first 3 vectors + int count = 0; + while (data_source->hasMore() && count < 3) { + auto vec = data_source->getNextVector(); + std::cout << "Vector " << count << " (first 5 components): "; + for (int i = 0; i < 5; ++i) { + std::cout << vec[i] << " "; + } + std::cout << std::endl; + count++; + } + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + 
} +} + +void example_test_data_generator_with_dataset() { + std::cout << "\n=== Example 3: TestDataGenerator with Dataset ===" << std::endl; + + // Create dataset data source + DatasetDataSource::Config ds_config; + ds_config.file_path = PROJECT_DIR "/data/siftsmall/siftsmall_query.fvecs"; + ds_config.expected_dim = 128; + ds_config.loop = true; // Enable looping to allow reuse + + try { + auto data_source = std::make_shared(ds_config); + + // Configure test data generator + TestDataGenerator::Config config; + config.positive_pairs = 5; + config.negative_pairs = 5; + config.near_threshold_pairs = 2; + config.random_tail = 10; + config.similarity_threshold = 0.8; + + // Create generator with dataset source + TestDataGenerator generator(config, data_source); + + // Generate test data + auto [records, expected_matches] = generator.generateData(); + + std::cout << "Generated " << records.size() << " records" << std::endl; + std::cout << "Expected matches: " << expected_matches.size() << std::endl; + + // Show first few records + int count = 0; + for (const auto& record : records) { + if (count >= 5) break; + auto vec = extractFloatVector(*record); + std::cout << "Record " << count << " - UID: " << record->uid_ + << ", TS: " << record->timestamp_ + << ", Dim: " << vec.size() << std::endl; + count++; + } + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + } +} + +void example_backward_compatibility() { + std::cout << "\n=== Example 4: Backward Compatibility ===" << std::endl; + + // Old way - still works! 
+ TestDataGenerator::Config config; + config.vector_dim = 64; + config.positive_pairs = 10; + config.negative_pairs = 10; + config.random_tail = 20; + + TestDataGenerator generator(config); + auto [records, expected_matches] = generator.generateData(); + + std::cout << "Generated " << records.size() << " records (old way)" << std::endl; + std::cout << "Expected matches: " << expected_matches.size() << std::endl; +} + +int main() { + std::cout << "Data Source Framework Examples" << std::endl; + std::cout << "===============================" << std::endl; + + example_random_data_source(); + example_dataset_data_source(); + example_test_data_generator_with_dataset(); + example_backward_compatibility(); + + std::cout << "\nAll examples completed!" << std::endl; + return 0; +} diff --git a/test/result/analyze.py b/test/result/analyze.py new file mode 100644 index 0000000..eed0a1a --- /dev/null +++ b/test/result/analyze.py @@ -0,0 +1,392 @@ +import argparse +import urllib.request +from pathlib import Path + +import matplotlib + +# 必须在导入 pyplot 前设置无显卡后端 +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib import font_manager as fm +from matplotlib.font_manager import FontProperties +from matplotlib.patches import Patch + +PARALLELISM_ORDER = [1, 2, 4, 8, 16, 32, 40] +METHOD_ORDER = ["ivf_eager", "bruteforce_eager"] + +def _setup_chinese_font(): + """ + 优先使用系统已装中文字体;若都没有则自动下载 NotoSansSC-Regular 并注册使用。 + """ + candidates = [ + "Noto Sans CJK SC", "Noto Sans SC", "Source Han Sans CN", + "WenQuanYi Zen Hei", "SimHei", "Microsoft YaHei", "Sarasa UI SC", + ] + available = {f.name for f in fm.fontManager.ttflist} + + chosen_name = None + for name in candidates: + if name in available: + chosen_name = name + break + + if not chosen_name: + # 尝试使用项目内字体(若不存在则下载) + fonts_dir = Path(__file__).parent / ".fonts" + fonts_dir.mkdir(parents=True, exist_ok=True) + local_font = fonts_dir / "NotoSansSC-Regular.otf" + if not 
local_font.exists(): + url = "https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese/NotoSansCJKsc-Regular.otf" + try: + # 有些发行版对文件名敏感,下载后统一命名为 NotoSansSC-Regular.otf + tmp_path = fonts_dir / "NotoSansCJKsc-Regular.otf" + urllib.request.urlretrieve(url, tmp_path) + tmp_path.rename(local_font) + except Exception as e: + # 下载失败则放弃自动下载 + print(f"[warn] 下载中文字体失败:{e}") + + if local_font.exists(): + try: + fm.fontManager.addfont(str(local_font)) + chosen_name = FontProperties(fname=str(local_font)).get_name() + except Exception as e: + print(f"[warn] 注册本地字体失败:{e}") + + if chosen_name: + plt.rcParams["font.family"] = "sans-serif" + plt.rcParams["font.sans-serif"] = [chosen_name] + list(plt.rcParams.get("font.sans-serif", [])) + else: + print("[warn] 未找到可用中文字体,文本可能显示为方框。建议安装 fonts-noto-cjk。") + + # 解决负号显示为方块 + plt.rcParams["axes.unicode_minus"] = False + return chosen_name + +def make_grouped_bar(pivot_df: pd.DataFrame, ylabel: str, title: str, out_path: Path): + _setup_chinese_font() + # 只保留有数据的列(方法) + pivot_df = pivot_df[[c for c in METHOD_ORDER if c in pivot_df.columns]] + + x = np.arange(len(pivot_df.index)) + n_methods = len(pivot_df.columns) + width = 0.35 if n_methods == 2 else 0.6 / max(n_methods, 1) + offsets = (np.arange(n_methods) - (n_methods - 1) / 2.0) * width + + fig, ax = plt.subplots(figsize=(10, 5)) + colors = { + "ivf_eager": "#1f77b4", # 蓝 + "bruteforce_eager": "#ff7f0e", # 橙 + } + + for i, method in enumerate(pivot_df.columns): + y = pivot_df[method].values + ax.bar(x + offsets[i], y, width, label=method, color=colors.get(method, None)) + + ax.set_xlabel("并行度") + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.set_xticks(x, [str(p) for p in pivot_df.index]) + ax.legend() + ax.grid(axis="y", linestyle="--", alpha=0.3) + fig.tight_layout() + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, dpi=150) + plt.close(fig) + +def make_grouped_stacked_100(lock_df: pd.DataFrame, exec_df: pd.DataFrame, title: str, 
out_path: Path): + """绘制100%分层柱状图:每个柱子= 有效执行% + 锁等待%(两者相加为100)。 + + 参数: + - lock_df: 行为并行度,列为方法,值为锁等待百分比 + - exec_df: 行为并行度,列为方法,值为有效执行百分比 + """ + _setup_chinese_font() + + # 对齐列顺序和索引顺序 + methods = [c for c in METHOD_ORDER if c in lock_df.columns or c in exec_df.columns] + lock_df = lock_df.reindex(columns=methods) + exec_df = exec_df.reindex(columns=methods) + lock_df = lock_df.reindex(PARALLELISM_ORDER).dropna(how="all") + exec_df = exec_df.reindex(PARALLELISM_ORDER).dropna(how="all") + + x = np.arange(len(lock_df.index)) + n_methods = len(methods) + width = 0.35 if n_methods == 2 else 0.6 / max(n_methods, 1) + offsets = (np.arange(n_methods) - (n_methods - 1) / 2.0) * width + + fig, ax = plt.subplots(figsize=(11, 5.5)) + colors = { + "exec": "#2ca02c", # 绿色 有效执行 + "lock": "#d62728", # 红色 锁等待 + } + hatches = { + "ivf_eager": "//", + "bruteforce_eager": "\\\\", + } + + for i, method in enumerate(methods): + exec_vals = exec_df[method].values if method in exec_df.columns else np.zeros(len(x)) + lock_vals = lock_df[method].values if method in lock_df.columns else np.zeros(len(x)) + # 底层:有效执行 + ax.bar( + x + offsets[i], exec_vals, width, + label=None if i else "有效执行", + color=colors["exec"], + edgecolor="#333333", linewidth=0.6, hatch=hatches.get(method, None) + ) + # 顶层:锁等待 + ax.bar( + x + offsets[i], lock_vals, width, bottom=exec_vals, + label=None if i else "锁等待", + color=colors["lock"], + edgecolor="#333333", linewidth=0.6, hatch=hatches.get(method, None) + ) + + ax.set_xlabel("并行度") + ax.set_ylabel("百分比 (%)") + ax.set_title(title) + ax.set_xticks(x) + ax.set_xticklabels([str(p) for p in lock_df.index]) + # 图例:组件(有效执行/锁等待)+ 方法(ivf/bruteforce) + comp_handles = [ + Patch(facecolor=colors["exec"], edgecolor="#333333", label="有效执行"), + Patch(facecolor=colors["lock"], edgecolor="#333333", label="锁等待"), + ] + legend1 = ax.legend(handles=comp_handles, loc="upper right", title="组成") + method_handles = [] + method_labels_map = {"ivf_eager": "IVF", "bruteforce_eager": 
"BruteForce"} + for m in methods: + method_handles.append(Patch(facecolor="#dddddd", edgecolor="#333333", hatch=hatches.get(m, None), label=method_labels_map.get(m, m))) + ax.add_artist(legend1) + ax.legend(handles=method_handles, loc="upper left", title="方法") + ax.set_ylim(0, 100) + ax.grid(axis="y", linestyle="--", alpha=0.3) + fig.tight_layout() + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, dpi=150) + plt.close(fig) + +def make_grouped_stacked_multi(parts: list, title: str, out_path: Path): + """绘制按方法分组的多段100%堆叠柱状图。 + parts: 列表[(name, df, color)],df 形状为 (parallelism x method) 的百分比值(0-100)。 + """ + _setup_chinese_font() + + # 对齐列与索引 + if not parts: + return + methods = [c for c in METHOD_ORDER if c in parts[0][1].columns] + parts[0][1].index + for i in range(len(parts)): + parts[i] = (parts[i][0], parts[i][1].reindex(index=PARALLELISM_ORDER).reindex(columns=methods), parts[i][2]) + + x = np.arange(len(PARALLELISM_ORDER)) + n_methods = len(methods) + width = 0.35 if n_methods == 2 else 0.6 / max(n_methods, 1) + offsets = (np.arange(n_methods) - (n_methods - 1) / 2.0) * width + + fig, ax = plt.subplots(figsize=(12, 6)) + hatches = {"ivf_eager": "//", "bruteforce_eager": "\\\\"} + + for i, method in enumerate(methods): + bottom = np.zeros(len(PARALLELISM_ORDER)) + for name, dfp, color in parts: + vals = dfp[method].values if method in dfp.columns else np.zeros(len(PARALLELISM_ORDER)) + ax.bar( + x + offsets[i], vals, width, bottom=bottom, + label=name if (i == 0) else None, + color=color, edgecolor="#333333", linewidth=0.6, hatch=hatches.get(method, None) + ) + bottom = bottom + vals + + ax.set_xlabel("并行度") + ax.set_ylabel("百分比 (%)") + ax.set_title(title) + ax.set_xticks(x) + ax.set_xticklabels([str(p) for p in PARALLELISM_ORDER]) + ax.set_ylim(0, 100) + ax.grid(axis="y", linestyle="--", alpha=0.3) + + # 图例:阶段 + 方法 + legend1 = ax.legend(loc="upper right", title="阶段") + ax.add_artist(legend1) + method_labels_map = {"ivf_eager": "IVF", 
"bruteforce_eager": "BruteForce"} + method_handles = [Patch(facecolor="#dddddd", edgecolor="#333333", hatch=hatches.get(m, None), label=method_labels_map.get(m, m)) for m in methods] + ax.legend(handles=method_handles, loc="upper left", title="方法") + + fig.tight_layout() + out_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(out_path, dpi=150) + plt.close(fig) + +def main(): + # Use paths relative to the script's location for portability + script_dir = Path(__file__).resolve().parent + project_root = script_dir.parent.parent # Adjust as needed for your project structure + default_input = project_root / "test" / "result" / "perf_report.tsv" + default_outdir = project_root / "test" / "result" / "plots" + + parser = argparse.ArgumentParser(description="生成 size=4000 的分组柱状图(耗时/吞吐量)") + parser.add_argument("--input", "-i", type=str, + default=str(default_input), + help="perf_report.tsv 文件路径") + parser.add_argument("--outdir", "-o", type=str, + default=str(default_outdir), + help="输出图片目录") + args = parser.parse_args() + + in_path = Path(args.input) + out_dir = Path(args.outdir) + + if not in_path.exists(): + raise FileNotFoundError(f"未找到输入文件: {in_path}") + + df = pd.read_csv(in_path, sep="\t") + + # 仅保留 size=4000 且并行度在指定集合内的数据 + df_4k = df[(df["size"] == 4000) & (df["parallelism"].isin(PARALLELISM_ORDER))].copy() + + # 只取必要列(含锁等待占比计算需要的列) + cols = [ + "method", "parallelism", "time_ms", "output_tput_rps", + "lock_wait_ms", "window_ns", "index_ns", "sim_ns", "candidate_fetch_ns", # 注意:这里使用 candidate_fetch_ns 作为 candidate_ns + ] + missing = [c for c in cols if c not in df.columns] + if missing: + raise ValueError(f"缺少列: {missing}") + + if df_4k.empty: + print("警告:未找到 size=4000 的数据,图表将为空。请确认 perf_report.tsv。") + + # 按并行度排序 + df_4k["parallelism"] = pd.Categorical(df_4k["parallelism"], PARALLELISM_ORDER, ordered=True) + + # 为 time_ms 生成分组柱状图 + piv_time = df_4k.pivot_table(index="parallelism", columns="method", values="time_ms", aggfunc="first") + piv_time = 
piv_time.reindex(PARALLELISM_ORDER).dropna(how="all") + make_grouped_bar( + piv_time, + ylabel="耗时 (ms)", + title="size=4000: 耗时对比(越低越好)", + out_path=out_dir / "size4000_time_ms.png", + ) + + # 为 output_tput_rps 生成分组柱状图 + piv_tput = df_4k.pivot_table(index="parallelism", columns="method", values="output_tput_rps", aggfunc="first") + piv_tput = piv_tput.reindex(PARALLELISM_ORDER).dropna(how="all") + make_grouped_bar( + piv_tput, + ylabel="吞吐量 (RPS)", + title="size=4000: 吞吐量对比(越高越好)", + out_path=out_dir / "size4000_output_tput_rps.png", + ) + + # 计算 breakdown: + # compute_ms 采用 join 的所有阶段:window_ns + index_ns + sim_ns + joinF_ns + emit_ns + candidate_fetch_ns (+ 可选 expire_ns) + def col_or_zero(name: str): + return df_4k[name].astype(float) if name in df_4k.columns else 0.0 + compute_ms = ( + col_or_zero("window_ns") + + col_or_zero("index_ns") + + col_or_zero("sim_ns") + + col_or_zero("joinF_ns") + + col_or_zero("emit_ns") + + col_or_zero("candidate_fetch_ns") + + col_or_zero("expire_ns") + ) / 1e6 + with np.errstate(divide="ignore", invalid="ignore"): + # 用户要求:锁占比 = lock_wait_ms / compute_ms + lock_wait_pct = np.where(compute_ms > 0, df_4k["lock_wait_ms"].astype(float) / compute_ms * 100.0, np.nan) + # 100%堆叠图:以 lock/compute 为准 + lock_pct_100 = lock_wait_pct.copy() + exec_pct_100 = 100.0 - lock_pct_100 + total_with_lock_ms = compute_ms + df_4k["lock_wait_ms"].astype(float) + + df_4k = df_4k.assign( + compute_ms=compute_ms, + total_with_lock_ms=total_with_lock_ms, + lock_wait_pct=lock_wait_pct, + lock_pct_100=lock_pct_100, + exec_pct_100=exec_pct_100, + ) + + # 为 锁等待占比(单值)生成分组柱状图(按 lock/compute) + piv_lock = df_4k.pivot_table(index="parallelism", columns="method", values="lock_wait_pct", aggfunc="first") + piv_lock = piv_lock.reindex(PARALLELISM_ORDER).dropna(how="all") + make_grouped_bar( + piv_lock, + ylabel="锁等待占比 (%)", + title="size=4000: 锁等待占比(越低越好)", + out_path=out_dir / "size4000_lock_wait_pct.png", + ) + + # 100%分层柱状图:锁等待 vs 有效执行 + piv_lock100 = 
df_4k.pivot_table(index="parallelism", columns="method", values="lock_pct_100", aggfunc="first") + piv_exec100 = df_4k.pivot_table(index="parallelism", columns="method", values="exec_pct_100", aggfunc="first") + make_grouped_stacked_100( + piv_lock100, + piv_exec100, + title="size=4000: 锁等待 vs 有效执行(100%堆叠)", + out_path=out_dir / "size4000_lock_breakdown_pct.png", + ) + + # 生成四阶段(window/index/sim/candidate_fetch)的100%堆叠图 + # 以这四项之和为分母 + denom4 = ( + df_4k["window_ns"].astype(float) + df_4k["index_ns"].astype(float) + + df_4k["sim_ns"].astype(float) + df_4k["candidate_fetch_ns"].astype(float) + ) + with np.errstate(divide="ignore", invalid="ignore"): + w_pct = np.where(denom4 > 0, df_4k["window_ns"].astype(float) / denom4 * 100.0, np.nan) + i_pct = np.where(denom4 > 0, df_4k["index_ns"].astype(float) / denom4 * 100.0, np.nan) + s_pct = np.where(denom4 > 0, df_4k["sim_ns"].astype(float) / denom4 * 100.0, np.nan) + c_pct = np.where(denom4 > 0, df_4k["candidate_fetch_ns"].astype(float) / denom4 * 100.0, np.nan) + df_4k = df_4k.assign( + window_pct_4=w_pct, index_pct_4=i_pct, sim_pct_4=s_pct, candidate_fetch_pct_4=c_pct + ) + piv_w = df_4k.pivot_table(index="parallelism", columns="method", values="window_pct_4", aggfunc="first").reindex(PARALLELISM_ORDER) + piv_i = df_4k.pivot_table(index="parallelism", columns="method", values="index_pct_4", aggfunc="first").reindex(PARALLELISM_ORDER) + piv_s = df_4k.pivot_table(index="parallelism", columns="method", values="sim_pct_4", aggfunc="first").reindex(PARALLELISM_ORDER) + piv_c = df_4k.pivot_table(index="parallelism", columns="method", values="candidate_fetch_pct_4", aggfunc="first").reindex(PARALLELISM_ORDER) + + parts = [ + ("窗口(window)", piv_w, "#1f77b4"), + ("索引(index)", piv_i, "#ff7f0e"), + ("相似度(sim)", piv_s, "#2ca02c"), + ("候选抓取(candidate)", piv_c, "#9467bd"), + ] + make_grouped_stacked_multi( + parts, + title="size=4000: 计算阶段占比(100%堆叠)", + out_path=out_dir / "size4000_compute_breakdown_pct.png", + ) + + # 导出汇总 + # 
将计算得到的指标追加到汇总 + summary_cols = cols + [ + "compute_ms", + "total_with_lock_ms", + "lock_wait_pct", + "lock_pct_100", + "exec_pct_100", + ] + summary = df_4k.sort_values(["parallelism", "method"])[summary_cols] + out_dir.mkdir(parents=True, exist_ok=True) + summary.to_csv(out_dir / "size4000_summary.tsv", sep="\t", index=False) + + print( + "已生成:\n- {}\n- {}\n- {}\n- {}\n- {}\n- {}".format( + out_dir / "size4000_time_ms.png", + out_dir / "size4000_output_tput_rps.png", + out_dir / "size4000_lock_wait_pct.png", + out_dir / "size4000_lock_breakdown_pct.png", + out_dir / "size4000_compute_breakdown_pct.png", + out_dir / "size4000_summary.tsv", + ) + ) + +if __name__ == "__main__": + main() diff --git a/test/test_utils/JOIN_DATA_SOURCE_GUIDE.md b/test/test_utils/JOIN_DATA_SOURCE_GUIDE.md new file mode 100644 index 0000000..50c2d4a --- /dev/null +++ b/test/test_utils/JOIN_DATA_SOURCE_GUIDE.md @@ -0,0 +1,394 @@ +# Join Data Source Framework + +## Overview + +The Join Data Source Framework provides a flexible, modular architecture for generating test data for join operations. It extracts the data generation logic from individual test files into reusable components that support multiple data source strategies. + +## Problem Statement + +Previously, join tests followed a common pattern: +1. Generate data with `TestDataGenerator` +2. Duplicate the data to create left and right streams +3. Apply UID offsets to distinguish streams +4. 
Feed to join operators + +This pattern was repeated across multiple test files with slight variations, making it: +- **Repetitive** - Same logic duplicated in many places +- **Inflexible** - Hard to test with different data distributions +- **Coupled** - Data generation tied to test implementation + +## Solution + +The Join Data Source Framework provides: +- **`JoinDataSourcePair`** - Manages creation of left/right streams from various sources +- **`JoinDataSourceConfig`** - Configuration for different streaming strategies +- **`JoinTestHelper`** - Convenience functions for common patterns +- **`JoinDataSourceFactory`** - Factory methods for standard configurations + +## Architecture + +``` +┌──────────────────────────────────────────────────────────┐ +│ Join Test (test_join_*.cpp) │ +└────────────────────────┬─────────────────────────────────┘ + │ + ┌───────────────┴───────────────┐ + │ │ + ┌────▼──────────┐ ┌───────▼────────┐ + │ JoinTestHelper│ │JoinDataSourcePair│ + └────┬──────────┘ └───────┬────────┘ + │ │ + └───────────────┬───────────────┘ + │ + ┌──────────▼──────────┐ + │ JoinDataSourceConfig│ + └──────────┬──────────┘ + │ + ┌─────────────────┼─────────────────┐ + │ │ │ + ┌────▼────┐ ┌──────▼──────┐ ┌─────▼─────────┐ + │Duplicate│ │ Separate │ │DataSourceBase │ + │ Mode │ │ Mode │ │ Implementations│ + └─────────┘ └─────────────┘ └───────────────┘ + • RandomDataSource + • DatasetDataSource + • VectorListSource +``` + +## Modes + +### Mode 1: Duplicate Mode (Default) + +Duplicates a single data source to both left and right streams. This mode: +- Uses ONE data source for both streams +- Applies UID offset to distinguish right stream from left +- Works with ANY DataSourceBase implementation (Random, Dataset, VectorList, etc.) +- Backward compatible with existing tests + +### Mode 2: Separate Mode + +Uses different data sources for left and right streams. 
This mode: +- Uses TWO independent data sources +- Allows testing joins with different data distributions +- Optional UID offset for right stream +- Requires both sources to have the same vector dimension + + +## Usage + +### Using with TestDataGenerator (Backward Compatible) + +The most common pattern - generate data and duplicate to both streams: + +```cpp +#include "test_utils/join_test_helper.h" +#include "test_utils/test_data_generator.h" + +// Generate data as before +TestDataGenerator::Config config; +config.vector_dim = 128; +config.positive_pairs = 50; +config.negative_pairs = 100; + +TestDataGenerator generator(config); + +// NEW: Use helper to create join streams (automatically uses Duplicate mode) +auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromGenerator(generator); + +// Use with join operator (same as before) +for (auto& rec : left_records) { + // Process left stream +} +for (auto& rec : right_records) { + // Process right stream +} +``` + +**How it works:** +1. Generator creates vectors +2. Vectors are wrapped in a `VectorListSource` +3. Source is duplicated to both streams in Duplicate mode +4. 
UID offset applied to right stream + +### Duplicate Mode: From Any Source + +Test with a specific dataset or pattern by duplicating one source: + +```cpp +#include "test_utils/join_test_helper.h" +#include "test_utils/data_source/dataset_data_source.h" + +// Load a dataset +DatasetDataSource::Config ds_config; +ds_config.file_path = "data/siftsmall/siftsmall_query.fvecs"; +ds_config.expected_dim = 128; +auto source = std::make_shared(ds_config); + +// Generate join streams from dataset (Duplicate mode) +auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromSource(source); + +// Both streams contain same data (from dataset) +// Right stream UIDs are offset by default +``` + +### Separate Mode: Different Sources + +Test join with different data distributions on each side: + +```cpp +#include "test_utils/join_test_helper.h" +#include "test_utils/data_source/random_data_source.h" + +// Create different sources +RandomDataSource::Config left_config; +left_config.vector_dim = 128; +left_config.seed = 111; // Different seed +auto left_source = std::make_shared(left_config); + +RandomDataSource::Config right_config; +right_config.vector_dim = 128; +right_config.seed = 222; // Different seed +auto right_source = std::make_shared(right_config); + +// Generate from separate sources +auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromSeparateSources( + left_source, right_source); + +// Left and right streams have different data distributions +``` + +### Mode 4: Advanced Configuration + +For fine-grained control, use `JoinDataSourcePair` directly: + +```cpp +#include "test_utils/join_data_source.h" + +// Create custom configuration +JoinDataSourceConfig config; +config.mode = JoinDataSourceConfig::Mode::Duplicate; +config.single_source = my_source; +config.apply_right_uid_offset = false; // No UID offset +config.right_uid_offset = 1000000; // Custom offset if enabled +config.base_timestamp = 2000000; // Custom timestamps 
+config.time_interval = 50; // Custom intervals + +// Create pair and generate +JoinDataSourcePair pair(config); +auto [left, right] = pair.generateStreams(100); // Limit to 100 records +``` + +## API Reference + +### JoinTestHelper + +Helper functions for common patterns: + +```cpp +class JoinTestHelper { + // Generate from TestDataGenerator (backward compatible) + static pair, vector> + generateJoinStreamsFromGenerator( + TestDataGenerator& generator, + bool apply_uid_offset = true); + + // Generate from single source (duplicate mode) + static pair, vector> + generateJoinStreamsFromSource( + shared_ptr source, + bool apply_uid_offset = true, + size_t max_records = 0); + + // Generate from separate sources + static pair, vector> + generateJoinStreamsFromSeparateSources( + shared_ptr left_source, + shared_ptr right_source, + bool apply_uid_offset = false, + size_t max_records = 0); +}; +``` + +### JoinDataSourceFactory + +Factory methods for standard configurations: + +```cpp +class JoinDataSourceFactory { + // Create duplicate mode config + static JoinDataSourceConfig createDuplicated( + shared_ptr source, + bool apply_uid_offset = true); + + // Create separate mode config + static JoinDataSourceConfig createSeparate( + shared_ptr left_source, + shared_ptr right_source, + bool apply_uid_offset = false); + + // Create generated mode config (backward compatible) + static JoinDataSourceConfig createGenerated( + shared_ptr source, + bool apply_uid_offset = true); +}; +``` + +### JoinDataSourcePair + +Main class for generating join streams: + +```cpp +class JoinDataSourcePair { + explicit JoinDataSourcePair(const JoinDataSourceConfig& config); + + // Generate left and right streams + pair, vector> + generateStreams(size_t max_records = 0); + + // Get dimension + int getDimension() const; + + // Get total count + int getTotalCount() const; + + // Reset to beginning + void reset(); +}; +``` + +## Configuration Options + +### JoinDataSourceConfig + +```cpp +struct 
JoinDataSourceConfig { + enum class Mode { + Duplicate, // Same source duplicated to both sides + Separate, // Different sources for left/right + Generated // TestDataGenerator mode (backward compatible) + }; + + Mode mode = Mode::Generated; + + // For Duplicate/Generated mode + shared_ptr single_source; + + // For Separate mode + shared_ptr left_source; + shared_ptr right_source; + + // Common options + bool apply_right_uid_offset = true; // Offset right UIDs + uint64_t right_uid_offset = 500000; // Default offset value + int64_t base_timestamp = 1000000; // Starting timestamp + int64_t time_interval = 100; // Time increment +}; +``` + +## Migration Guide + +### Updating Existing Tests + +**Before:** +```cpp +TestDataGenerator generator(config); +auto [records, _] = generator.generateData(); + +// Manual duplication +std::vector> left_records; +for (auto& r : records) { + left_records.push_back(std::make_unique(*r)); +} + +std::vector> right_records; +constexpr uint64_t kOffset = 500000; +for (auto& r : records) { + right_records.push_back(std::make_unique( + r->uid_ + kOffset, r->timestamp_, r->data_)); +} +``` + +**After:** +```cpp +TestDataGenerator generator(config); +auto [left_records, right_records] = + JoinTestHelper::generateJoinStreamsFromGenerator(generator); +``` + +### Adding New Test Scenarios + +**Test with real dataset:** +```cpp +TEST_F(MyJoinTest, WithSIFTDataset) { + DatasetDataSource::Config config; + config.file_path = "data/siftsmall/siftsmall_query.fvecs"; + config.expected_dim = 128; + auto source = std::make_shared(config); + + auto [left, right] = + JoinTestHelper::generateJoinStreamsFromSource(source, true, 50); + + // Test join with real vectors + testJoinOperation(left, right); +} +``` + +**Test with asymmetric distributions:** +```cpp +TEST_F(MyJoinTest, AsymmetricDistributions) { + // Dense left, sparse right + auto left_source = createDenseVectorSource(); + auto right_source = createSparseVectorSource(); + + auto [left, right] = + 
JoinTestHelper::generateJoinStreamsFromSeparateSources( + left_source, right_source); + + // Test how join handles different distributions + testJoinOperation(left, right); +} +``` + +## Benefits + +1. **Code Reuse** - Eliminate duplication across test files +2. **Flexibility** - Easy to test with various data sources +3. **Maintainability** - Change data generation strategy in one place +4. **Testability** - Test framework itself is unit tested +5. **Backward Compatible** - Existing tests work unchanged +6. **Extensible** - Easy to add new modes or configurations + +## Testing + +Comprehensive tests in `test/UnitTest/test_join_data_source.cpp`: + +```bash +# Run join data source tests +./bin/test_join_data_source + +# All 8 test cases pass: +# - DuplicateMode +# - SeparateMode +# - HelperWithTestDataGenerator +# - HelperWithSingleSource +# - HelperWithSeparateSources +# - WithDatasetSource +# - MaxRecordsLimit +# - ResetFunctionality +``` + +## Examples + +See `test/UnitTest/test_join_data_source.cpp` for comprehensive examples of all usage patterns. + +## Future Enhancements + +Potential additions: +- **Streaming mode** - Generate data on-demand instead of all at once +- **Time-based patterns** - Configure complex timestamp patterns +- **UID strategies** - Pluggable UID generation strategies +- **Batch support** - Generate multiple batches with reset +- **Statistics** - Track generation statistics for analysis diff --git a/test/test_utils/data_source/README.md b/test/test_utils/data_source/README.md new file mode 100644 index 0000000..d372a5b --- /dev/null +++ b/test/test_utils/data_source/README.md @@ -0,0 +1,238 @@ +# Test Data Source Framework + +## Overview + +The test data source framework provides a modular and extensible way to generate test data for the sageFlow operators. The framework separates data generation logic from test code, making it easier to: + +- Use different data sources (random, dataset-based, etc.) 
+- Reuse data generation logic across tests +- Test with real-world datasets +- Maintain backward compatibility with existing tests + +## Architecture + +The framework consists of several data source implementations: + +### 1. DataSourceBase (Abstract Base Class) + +Located in: `test/test_utils/data_source/data_source_base.h` + +Defines the interface for all data sources: +- `getNextVector()`: Returns the next vector from the source +- `getDimension()`: Returns the dimension of vectors +- `hasMore()`: Checks if more data is available +- `reset()`: Resets the source to start from beginning +- `getTotalCount()`: Returns total number of vectors (if known) + +### 2. RandomDataSource + +Located in: `test/test_utils/data_source/random_data_source.h/cpp` + +Generates random normalized vectors using a configurable seed for reproducibility. + +**Configuration:** +```cpp +RandomDataSource::Config config; +config.vector_dim = 128; // Vector dimension +config.seed = 42; // Random seed for reproducibility +config.max_vectors = -1; // Max vectors to generate (-1 = unlimited) + +auto data_source = std::make_shared(config); +``` + +### 3. DatasetDataSource + +Located in: `test/test_utils/data_source/dataset_data_source.h/cpp` + +Loads vectors from fvecs format dataset files (commonly used in vector search benchmarks). + +**Configuration:** +```cpp +DatasetDataSource::Config config; +config.file_path = "data/siftsmall/siftsmall_query.fvecs"; +config.expected_dim = 128; // Expected dimension (-1 = auto-detect) +config.loop = true; // Loop back to start when reaching end + +auto data_source = std::make_shared(ds_config); +``` + +### 4. VectorListSource + +Located in: `test/test_utils/data_source/vector_list_source.h` + +A simple adapter that wraps an in-memory vector of float vectors. 
Useful for: +- Wrapping generated data from TestDataGenerator +- Testing with small, predefined datasets +- Creating data sources from computed vectors + +**Usage:** +```cpp +#include "test_utils/data_source/vector_list_source.h" + +// Create from a vector of vectors +std::vector> vectors = { + {0.1f, 0.2f, 0.3f}, + {0.4f, 0.5f, 0.6f}, + {0.7f, 0.8f, 0.9f} +}; + +auto data_source = std::make_shared(vectors); + +// Use like any other data source +while (data_source->hasMore()) { + auto vec = data_source->getNextVector(); + // Process vector +} +``` + +**Note:** This is primarily an internal utility used by JoinTestHelper to wrap TestDataGenerator output, but can be used directly if needed. + +### 5. JsonDataSource + +Located in: `test/test_utils/data_source/json_data_source.h/cpp` + +Loads vectors from JSON format files. Useful for debugging and human-readable datasets. + +**Configuration:** +```cpp +JsonDataSource::Config config; +config.file_path = "test_data.json"; +config.expected_dim = 128; // Optional validation +config.loop = false; // Whether to loop when reaching end + +auto data_source = std::make_shared(config); +``` + +## Usage + +### Using with TestDataGenerator + +The `TestDataGenerator` class has been updated to accept a custom data source: + +#### Option 1: Default Random Generation (Backward Compatible) +```cpp +TestDataGenerator::Config config; +config.vector_dim = 128; +config.positive_pairs = 100; +config.negative_pairs = 100; + +// Uses random data source internally +TestDataGenerator generator(config); +auto [records, expected_matches] = generator.generateData(); +``` + +#### Option 2: Custom Random Data Source +```cpp +// Create a custom random data source +RandomDataSource::Config ds_config; +ds_config.vector_dim = 64; +ds_config.seed = 123; +auto data_source = std::make_shared(ds_config); + +// Use with TestDataGenerator +TestDataGenerator::Config config; +config.similarity_threshold = 0.8; +config.positive_pairs = 50; + +TestDataGenerator 
generator(config, data_source); +auto [records, expected_matches] = generator.generateData(); +``` + +#### Option 3: Dataset-Based Generation +```cpp +// Load vectors from a dataset file +DatasetDataSource::Config ds_config; +ds_config.file_path = PROJECT_DIR "/data/siftsmall/siftsmall_query.fvecs"; +ds_config.expected_dim = 128; +ds_config.loop = true; // Enable looping for reuse +auto data_source = std::make_shared(ds_config); + +// Generate test data using dataset vectors +TestDataGenerator::Config config; +config.similarity_threshold = 0.8; +config.positive_pairs = 10; +config.negative_pairs = 10; + +TestDataGenerator generator(config, data_source); +auto [records, expected_matches] = generator.generateData(); +``` + +### Direct Use of Data Sources + +Data sources can also be used directly without TestDataGenerator: + +```cpp +// Create a data source +RandomDataSource::Config config; +config.vector_dim = 128; +config.seed = 42; +auto data_source = std::make_shared(config); + +// Get vectors directly +while (data_source->hasMore()) { + std::vector vec = data_source->getNextVector(); + // Use the vector... +} + +// Reset to start again +data_source->reset(); +``` + +## Available Datasets + +The repository includes the SIFT small dataset in `data/siftsmall/`: +- `siftsmall_base.fvecs` - Base vectors (10,000 vectors, 128D) +- `siftsmall_query.fvecs` - Query vectors (100 vectors, 128D) +- `siftsmall_learn.fvecs` - Learning vectors (25,000 vectors, 128D) + +## Extending the Framework + +To add a new data source: + +1. Create a new class inheriting from `DataSourceBase` +2. Implement all virtual methods +3. Add the new source files to `test/CMakeLists.txt` +4. 
Use it with `TestDataGenerator` or directly in tests + +Example skeleton: + +```cpp +class MyCustomDataSource : public DataSourceBase { +public: + struct Config { + // Your configuration options + }; + + explicit MyCustomDataSource(const Config& config); + + std::vector getNextVector() override; + int getDimension() const override; + bool hasMore() const override; + void reset() override; + int getTotalCount() const override; + +private: + Config config_; + // Your implementation details +}; +``` + +## Backward Compatibility + +All existing tests continue to work without modification. The default constructor of `TestDataGenerator` automatically creates a `RandomDataSource` internally, maintaining the original behavior. + +## Testing + +See `test/UnitTest/test_data_source.cpp` for comprehensive examples of using the data source framework. + +Run the data source tests: +```bash +cd build +./bin/test_data_source +``` + +Or run all unit tests: +```bash +cd build +ctest -L UNIT +``` diff --git a/test/test_utils/data_source/data_source_base.h b/test/test_utils/data_source/data_source_base.h new file mode 100644 index 0000000..c37e9ee --- /dev/null +++ b/test/test_utils/data_source/data_source_base.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include "common/data_types.h" + +namespace sageFlow { namespace test { + +/** + * @brief Base class for data sources in testing + * + * Provides a unified interface for obtaining vector data from different sources + * (random generation, datasets, etc.) 
+ */ +class DataSourceBase { +public: + virtual ~DataSourceBase() = default; + + /** + * @brief Get the next vector from the data source + * @return A vector of floats, or empty vector if no more data + */ + virtual std::vector getNextVector() = 0; + + /** + * @brief Get the dimension of vectors from this data source + * @return The vector dimension + */ + virtual int getDimension() const = 0; + + /** + * @brief Check if more data is available + * @return true if more vectors can be obtained, false otherwise + */ + virtual bool hasMore() const = 0; + + /** + * @brief Reset the data source to start from the beginning + */ + virtual void reset() = 0; + + /** + * @brief Get total number of vectors available (if known) + * @return Number of vectors, or -1 if unknown + */ + virtual int getTotalCount() const { return -1; } +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_source/data_source_factory.h b/test/test_utils/data_source/data_source_factory.h new file mode 100644 index 0000000..5ad486a --- /dev/null +++ b/test/test_utils/data_source/data_source_factory.h @@ -0,0 +1,65 @@ +#pragma once + +#include "test_utils/data_source/data_source_base.h" +#include "test_utils/data_source/random_data_source.h" +#include "test_utils/data_source/dataset_data_source.h" +#include "test_utils/data_source/json_data_source.h" +#include "test_utils/dynamic_config.h" +#include +#include +#include + +namespace sageFlow { namespace test { + +/** + * @brief Factory for creating data sources from configuration + */ +class DataSourceFactory { +public: + /** + * @brief Create a data source from dynamic configuration + * @param config Configuration containing data source settings + * @param default_dim Default vector dimension if not specified in config + * @param default_seed Default random seed if not specified in config + * @return Shared pointer to created data source + */ + static std::shared_ptr createFromConfig( + const DynamicConfig& config, + int default_dim = 
128, + uint32_t default_seed = 42) { + + std::string type = config.get("type", "random"); + + if (type == "random") { + RandomDataSource::Config ds_config; + ds_config.vector_dim = config.get("vector_dim", default_dim); + ds_config.seed = config.get("seed", static_cast(default_seed)); + ds_config.max_vectors = config.get("max_vectors", -1); + return std::make_shared(ds_config); + } + else if (type == "dataset") { + DatasetDataSource::Config ds_config; + ds_config.file_path = config.get("file_path", ""); + if (ds_config.file_path.empty()) { + throw std::runtime_error("Dataset data source requires 'file_path' in configuration"); + } + ds_config.expected_dim = config.get("expected_dim", default_dim); + ds_config.loop = (config.get("loop", 0) != 0); // Convert int to bool + return std::make_shared(ds_config); + } + else if (type == "json") { + JsonDataSource::Config ds_config; + ds_config.file_path = config.get("file_path", ""); + if (ds_config.file_path.empty()) { + throw std::runtime_error("JSON data source requires 'file_path' in configuration"); + } + ds_config.loop = (config.get("loop", 0) != 0); // Convert int to bool + return std::make_shared(ds_config); + } + else { + throw std::runtime_error("Unknown data source type: " + type); + } + } +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_source/dataset_data_source.cpp b/test/test_utils/data_source/dataset_data_source.cpp new file mode 100644 index 0000000..ab339ae --- /dev/null +++ b/test/test_utils/data_source/dataset_data_source.cpp @@ -0,0 +1,97 @@ +#include "test_utils/data_source/dataset_data_source.h" +#include "utils/logger.h" +#include +#include + +namespace sageFlow { namespace test { + +DatasetDataSource::DatasetDataSource(const Config& config) + : config_(config), dimension_(0), current_index_(0) { + loadVectors(); +} + +void DatasetDataSource::loadVectors() { + std::ifstream input(config_.file_path, std::ios::binary); + if (!input.is_open()) { + throw std::runtime_error("Cannot 
open file: " + config_.file_path); + } + + while (true) { + // Read dimension for the current vector + int32_t current_dim = 0; + input.read(reinterpret_cast(¤t_dim), sizeof(int32_t)); + + if (input.eof()) { + break; // End of file reached cleanly + } + if (input.fail()) { + throw std::runtime_error("Error reading dimension from file: " + config_.file_path); + } + + // Check dimension consistency + if (vectors_.empty()) { + dimension_ = current_dim; + if (config_.expected_dim != -1 && dimension_ != config_.expected_dim) { + throw std::runtime_error("Unexpected dimension in file " + config_.file_path + + ". Expected " + std::to_string(config_.expected_dim) + + ", got " + std::to_string(dimension_)); + } + if (dimension_ <= 0) { + throw std::runtime_error("Invalid dimension read from file: " + std::to_string(dimension_)); + } + } else if (current_dim != dimension_) { + throw std::runtime_error("Inconsistent dimension found in file " + config_.file_path + + ". Expected " + std::to_string(dimension_) + + ", found " + std::to_string(current_dim) + + " at vector index " + std::to_string(vectors_.size())); + } + + // Read vector data + std::vector vec(dimension_); + input.read(reinterpret_cast(vec.data()), dimension_ * sizeof(float)); + if (input.fail()) { + throw std::runtime_error("Error reading vector data from file: " + config_.file_path + + " at vector index " + std::to_string(vectors_.size())); + } + + vectors_.push_back(std::move(vec)); + } + + input.close(); + + if (vectors_.empty()) { + throw std::runtime_error("No vectors loaded from file: " + config_.file_path); + } + + SAGEFLOW_LOG_INFO("TEST", "[DatasetDataSource] Loaded {} vectors of dimension {} from {}", + vectors_.size(), dimension_, config_.file_path); +} + +std::vector DatasetDataSource::getNextVector() { + if (!hasMore()) { + return std::vector(); + } + + std::vector result = vectors_[current_index_]; + current_index_++; + + // If looping is enabled and we reached the end, reset + if (config_.loop && 
current_index_ >= vectors_.size()) { + current_index_ = 0; + } + + return result; +} + +bool DatasetDataSource::hasMore() const { + if (config_.loop) { + return !vectors_.empty(); // Always has more if looping + } + return current_index_ < vectors_.size(); +} + +void DatasetDataSource::reset() { + current_index_ = 0; +} + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_source/dataset_data_source.h b/test/test_utils/data_source/dataset_data_source.h new file mode 100644 index 0000000..1405de3 --- /dev/null +++ b/test/test_utils/data_source/dataset_data_source.h @@ -0,0 +1,39 @@ +#pragma once + +#include "test_utils/data_source/data_source_base.h" +#include + +namespace sageFlow { namespace test { + +/** + * @brief Data source that reads vectors from fvecs dataset files + * + * Reads vector data from standard fvecs format files (commonly used in vector search benchmarks). + * The fvecs format stores vectors as: [dimension(int)][vector_data(floats)]... + */ +class DatasetDataSource : public DataSourceBase { +public: + struct Config { + std::string file_path; + bool loop = false; // If true, loop back to start when reaching end + int expected_dim = -1; // Expected dimension, -1 means auto-detect + }; + + explicit DatasetDataSource(const Config& config); + + std::vector getNextVector() override; + int getDimension() const override { return dimension_; } + bool hasMore() const override; + void reset() override; + int getTotalCount() const override { return static_cast(vectors_.size()); } + +private: + void loadVectors(); + + Config config_; + std::vector> vectors_; + int dimension_; + size_t current_index_; +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_source/json_data_source.cpp b/test/test_utils/data_source/json_data_source.cpp new file mode 100644 index 0000000..dc6dc28 --- /dev/null +++ b/test/test_utils/data_source/json_data_source.cpp @@ -0,0 +1,118 @@ +#include "test_utils/data_source/json_data_source.h" +#include 
"utils/logger.h" +#include +#include +#include + +namespace sageFlow { namespace test { + +JsonDataSource::JsonDataSource(const Config& config) + : config_(config), dimension_(0), current_index_(0) { + loadVectors(); +} + +void JsonDataSource::loadVectors() { + std::ifstream input(config_.file_path); + if (!input.is_open()) { + throw std::runtime_error("Cannot open file: " + config_.file_path); + } + + // Simple JSON parsing (assumes well-formed JSON) + std::string line; + bool in_vectors = false; + std::vector current_vector; + + while (std::getline(input, line)) { + // Trim whitespace + size_t start = line.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) continue; + line = line.substr(start); + + // Parse dimension + if (line.find("\"dimension\"") != std::string::npos) { + size_t colon = line.find(':'); + if (colon != std::string::npos) { + std::string value = line.substr(colon + 1); + size_t comma = value.find(','); + if (comma != std::string::npos) { + value = value.substr(0, comma); + } + dimension_ = std::stoi(value); + } + } + + // Check for vectors array start + if (line.find("\"vectors\"") != std::string::npos) { + in_vectors = true; + continue; + } + + // Parse vector data + if (in_vectors && line.find('[') != std::string::npos && line.find(']') != std::string::npos) { + // Extract numbers between [ and ] + size_t start_bracket = line.find('['); + size_t end_bracket = line.find(']'); + std::string data = line.substr(start_bracket + 1, end_bracket - start_bracket - 1); + + current_vector.clear(); + std::stringstream ss(data); + std::string token; + while (std::getline(ss, token, ',')) { + try { + float value = std::stof(token); + current_vector.push_back(value); + } catch (...) 
{ + // Skip invalid tokens + } + } + + if (!current_vector.empty()) { + if (dimension_ == 0) { + dimension_ = static_cast(current_vector.size()); + } else if (static_cast(current_vector.size()) != dimension_) { + throw std::runtime_error("Inconsistent dimension in JSON file at vector " + + std::to_string(vectors_.size())); + } + vectors_.push_back(current_vector); + } + } + } + + input.close(); + + if (vectors_.empty()) { + throw std::runtime_error("No vectors loaded from file: " + config_.file_path); + } + + SAGEFLOW_LOG_INFO("TEST", "[JsonDataSource] Loaded {} vectors of dimension {} from {}", + vectors_.size(), dimension_, config_.file_path); +} + +std::vector JsonDataSource::getNextVector() { + if (!hasMore()) { + return std::vector(); + } + + std::vector result = vectors_[current_index_]; + current_index_++; + + // If looping is enabled and we reached the end, reset + if (config_.loop && current_index_ >= vectors_.size()) { + current_index_ = 0; + } + + return result; +} + +bool JsonDataSource::hasMore() const { + if (config_.loop) { + return !vectors_.empty(); // Always has more if looping + } + return current_index_ < vectors_.size(); +} + +void JsonDataSource::reset() { + current_index_ = 0; +} + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_source/json_data_source.h b/test/test_utils/data_source/json_data_source.h new file mode 100644 index 0000000..83754b7 --- /dev/null +++ b/test/test_utils/data_source/json_data_source.h @@ -0,0 +1,39 @@ +#pragma once + +#include "test_utils/data_source/data_source_base.h" +#include +#include + +namespace sageFlow { namespace test { + +/** + * @brief Data source that reads vectors from JSON files + * + * Reads vector data from JSON format files for easy debugging and visualization. 
+ * JSON format: {"dimension": N, "count": M, "vectors": [[...], [...], ...]} + */ +class JsonDataSource : public DataSourceBase { +public: + struct Config { + std::string file_path; + bool loop = false; // If true, loop back to start when reaching end + }; + + explicit JsonDataSource(const Config& config); + + std::vector getNextVector() override; + int getDimension() const override { return dimension_; } + bool hasMore() const override; + void reset() override; + int getTotalCount() const override { return static_cast(vectors_.size()); } + +private: + void loadVectors(); + + Config config_; + std::vector> vectors_; + int dimension_; + size_t current_index_; +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_source/random_data_source.cpp b/test/test_utils/data_source/random_data_source.cpp new file mode 100644 index 0000000..29626d8 --- /dev/null +++ b/test/test_utils/data_source/random_data_source.cpp @@ -0,0 +1,50 @@ +#include "test_utils/data_source/random_data_source.h" +#include + +namespace sageFlow { namespace test { + +RandomDataSource::RandomDataSource(const Config& config) + : config_(config), rng_(config.seed), generated_count_(0) {} + +std::vector RandomDataSource::getNextVector() { + if (!hasMore()) { + return std::vector(); + } + + std::vector vec(config_.vector_dim); + std::normal_distribution dist(0.0f, 1.0f); + + for (int i = 0; i < config_.vector_dim; ++i) { + vec[i] = dist(rng_); + } + + // Normalize the vector + float norm = 0.0f; + for (float v : vec) { + norm += v * v; + } + norm = std::sqrt(norm); + + if (norm > 1e-6f) { + for (float& v : vec) { + v /= norm; + } + } + + generated_count_++; + return vec; +} + +bool RandomDataSource::hasMore() const { + if (config_.max_vectors < 0) { + return true; // Unlimited + } + return generated_count_ < config_.max_vectors; +} + +void RandomDataSource::reset() { + generated_count_ = 0; + rng_.seed(config_.seed); +} + +}} // namespace sageFlow::test diff --git 
a/test/test_utils/data_source/random_data_source.h b/test/test_utils/data_source/random_data_source.h new file mode 100644 index 0000000..bc5f884 --- /dev/null +++ b/test/test_utils/data_source/random_data_source.h @@ -0,0 +1,35 @@ +#pragma once + +#include "test_utils/data_source/data_source_base.h" +#include + +namespace sageFlow { namespace test { + +/** + * @brief Data source that generates random normalized vectors + * + * This is the default data generation method used in the original TestDataGenerator. + */ +class RandomDataSource : public DataSourceBase { +public: + struct Config { + int vector_dim = 128; + uint32_t seed = 42; + int max_vectors = -1; // -1 means unlimited + }; + + explicit RandomDataSource(const Config& config); + + std::vector getNextVector() override; + int getDimension() const override { return config_.vector_dim; } + bool hasMore() const override; + void reset() override; + int getTotalCount() const override { return config_.max_vectors; } + +private: + Config config_; + std::mt19937 rng_; + int generated_count_; +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_source/vector_list_source.h b/test/test_utils/data_source/vector_list_source.h new file mode 100644 index 0000000..d02ab18 --- /dev/null +++ b/test/test_utils/data_source/vector_list_source.h @@ -0,0 +1,46 @@ +#pragma once + +#include "test_utils/data_source/data_source_base.h" +#include + +namespace sageFlow { namespace test { + +/** + * @brief Data source that provides vectors from an in-memory list + * + * This is a simple adapter that wraps a vector of float vectors and + * provides them through the DataSourceBase interface. Useful for + * testing and for wrapping generated data. 
+ */ +class VectorListSource : public DataSourceBase { +public: + explicit VectorListSource(const std::vector>& vectors) + : vectors_(vectors), index_(0) {} + + std::vector getNextVector() override { + if (index_ >= vectors_.size()) return {}; + return vectors_[index_++]; + } + + int getDimension() const override { + return vectors_.empty() ? 0 : static_cast(vectors_[0].size()); + } + + bool hasMore() const override { + return index_ < vectors_.size(); + } + + void reset() override { + index_ = 0; + } + + int getTotalCount() const override { + return static_cast(vectors_.size()); + } + +private: + std::vector> vectors_; + size_t index_; +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_writer/README.md b/test/test_utils/data_writer/README.md new file mode 100644 index 0000000..f94bcce --- /dev/null +++ b/test/test_utils/data_writer/README.md @@ -0,0 +1,290 @@ +# Data Persistence Framework + +## Overview + +The data persistence framework extends the data source framework to support saving generated test data to files and loading it back. This enables: + +1. **Reproducibility** - Save generated datasets for consistent testing across runs +2. **Sharing** - Share test datasets between team members +3. **Debugging** - Inspect generated data in human-readable formats +4. **Performance** - Save once, reuse many times without regeneration overhead + +## Architecture + +### Components + +``` +Data Persistence Framework +├── Writers (Output) +│ ├── DataWriterBase # Abstract writer interface +│ ├── FvecsWriter # Binary format (.fvecs) +│ └── JsonWriter # Human-readable format (.json) +├── Readers (Input - via DataSource) +│ ├── DatasetDataSource # Reads .fvecs files +│ └── JsonDataSource # Reads .json files +└── TestDataGenerator # Enhanced with save/load support +``` + +### Supported Formats + +#### 1. 
FVECS Format (.fvecs) +- **Type**: Binary format +- **Use Case**: Production, large datasets, efficiency +- **Format Spec**: `[dimension(int32)][vector_data(float32 * dimension)]` per vector +- **Pros**: Compact, fast I/O, industry standard +- **Cons**: Not human-readable + +#### 2. JSON Format (.json) +- **Type**: Text format +- **Use Case**: Debugging, visualization, small datasets +- **Format Spec**: `{"dimension": N, "count": M, "vectors": [[...], [...]]}` +- **Pros**: Human-readable, easy to inspect, portable +- **Cons**: Larger file size, slower I/O + +## Usage + +### Basic Usage: Save Generated Data + +```cpp +#include "test_utils/test_data_generator.h" +#include "test_utils/data_writer/fvecs_writer.h" +#include "test_utils/data_writer/json_writer.h" + +// Generate test data +TestDataGenerator::Config config; +config.vector_dim = 128; +config.positive_pairs = 100; +config.negative_pairs = 100; +config.random_tail = 200; + +TestDataGenerator generator(config); +auto [records, matches] = generator.generateData(); + +// Save to binary format (efficient for large datasets) +auto fvecs_writer = std::make_shared(); +generator.saveGeneratedVectors("test_data.fvecs", fvecs_writer); + +// OR save to JSON format (human-readable for debugging) +auto json_writer = std::make_shared(); +generator.saveGeneratedVectors("test_data.json", json_writer); +``` + +### Load and Use Saved Data + +```cpp +#include "test_utils/test_data_generator.h" +#include "test_utils/data_source/dataset_data_source.h" +#include "test_utils/data_source/json_data_source.h" + +// Load from FVECS file +DatasetDataSource::Config ds_config; +ds_config.file_path = "test_data.fvecs"; +ds_config.expected_dim = 128; +ds_config.loop = true; // Enable looping for reuse +auto data_source = std::make_shared(ds_config); + +// OR load from JSON file +JsonDataSource::Config json_config; +json_config.file_path = "test_data.json"; +json_config.loop = true; +auto json_source = std::make_shared(json_config); + +// 
Use loaded data with TestDataGenerator +TestDataGenerator::Config gen_config; +gen_config.similarity_threshold = 0.8; +gen_config.positive_pairs = 50; +gen_config.negative_pairs = 50; + +TestDataGenerator generator(gen_config, data_source); +auto [records, matches] = generator.generateData(); +``` + +### Workflow: Generate Once, Use Many Times + +```cpp +// Step 1: Generate and save reference dataset (run once) +void generateReferenceDataset() { + TestDataGenerator::Config config; + config.vector_dim = 128; + config.positive_pairs = 500; + config.negative_pairs = 500; + config.random_tail = 2000; + config.seed = 42; // Fixed seed for reproducibility + + TestDataGenerator generator(config); + generator.generateData(); + + auto writer = std::make_shared(); + generator.saveGeneratedVectors("reference_dataset_v1.fvecs", writer); +} + +// Step 2: Use in multiple tests without regeneration +TEST(MyTest, TestWithReferenceData) { + DatasetDataSource::Config config; + config.file_path = "reference_dataset_v1.fvecs"; + config.loop = true; + auto data_source = std::make_shared(config); + + TestDataGenerator::Config gen_config; + gen_config.positive_pairs = 100; + TestDataGenerator generator(gen_config, data_source); + + auto [records, matches] = generator.generateData(); + // Run tests... +} +``` + +## File Format Specifications + +### FVECS Format + +Binary format, little-endian: +``` +[int32: dimension] [float32: value_1] [float32: value_2] ... [float32: value_dim] +[int32: dimension] [float32: value_1] [float32: value_2] ... [float32: value_dim] +... 
+``` + +Example (dimension=3, 2 vectors): +``` +Bytes: [03 00 00 00] [3F 80 00 00] [40 00 00 00] [40 40 00 00] +Values: [3] [1.0] [2.0] [3.0] + + [03 00 00 00] [40 80 00 00] [40 A0 00 00] [40 C0 00 00] + [3] [4.0] [5.0] [6.0] +``` + +### JSON Format + +Text format with standard JSON syntax: +```json +{ + "dimension": 128, + "count": 1000, + "vectors": [ + [0.123456, -0.234567, 0.345678, ...], + [0.456789, -0.567890, 0.678901, ...], + ... + ] +} +``` + +## Integration with Existing Tests + +### Before (Direct Generation Only) +```cpp +TEST(MyTest, TestJoinOperator) { + TestDataGenerator::Config config; + config.vector_dim = 128; + TestDataGenerator generator(config); + auto [records, matches] = generator.generateData(); + // Test... +} +``` + +### After (Support Both Generation and Files) +```cpp +TEST(MyTest, TestJoinOperator) { + // Option 1: Direct generation (same as before) + TestDataGenerator::Config config; + config.vector_dim = 128; + TestDataGenerator generator(config); + auto [records, matches] = generator.generateData(); + + // Option 2: Load from file (NEW!) + DatasetDataSource::Config ds_config; + ds_config.file_path = "test_vectors.fvecs"; + auto data_source = std::make_shared(ds_config); + TestDataGenerator generator2(config, data_source); + auto [records2, matches2] = generator2.generateData(); + + // Both work identically! 
+} +``` + +## API Reference + +### DataWriterBase (Interface) + +```cpp +class DataWriterBase { + virtual bool writeVectors(const std::string& file_path, + const std::vector>& vectors, + int dimension) = 0; + virtual std::string getFileExtension() const = 0; + virtual std::string getFormatDescription() const = 0; +}; +``` + +### FvecsWriter + +```cpp +class FvecsWriter : public DataWriterBase { + // Writes vectors in FVECS binary format + // file_path: Output file path (e.g., "data.fvecs") + // Returns: true if successful, false otherwise +}; +``` + +### JsonWriter + +```cpp +class JsonWriter : public DataWriterBase { + // Writes vectors in JSON text format + // file_path: Output file path (e.g., "data.json") + // Returns: true if successful, false otherwise +}; +``` + +### JsonDataSource + +```cpp +class JsonDataSource : public DataSourceBase { + struct Config { + std::string file_path; // Path to JSON file + bool loop = false; // Loop back when reaching end + }; + // Reads vectors from JSON files +}; +``` + +### TestDataGenerator (Enhanced) + +```cpp +class TestDataGenerator { + // Save generated vectors to file + bool saveGeneratedVectors(const std::string& file_path, + std::shared_ptr writer); + + // Get last generated vectors (for custom processing) + std::vector> getLastGeneratedVectors() const; +}; +``` + +## Best Practices + +1. **Use FVECS for Production**: Binary format is efficient and industry-standard +2. **Use JSON for Debugging**: Human-readable format helps inspect data +3. **Version Your Datasets**: Include version in filename (e.g., `dataset_v1.fvecs`) +4. **Document Seed Values**: Always record the seed used to generate datasets +5. **Test Round-Trip**: Verify save/load preserves data integrity +6. 
**Enable Looping for Reuse**: Set `loop=true` when reusing datasets multiple times + +## Examples + +See `test/UnitTest/test_data_persistence.cpp` for comprehensive examples: +- Saving to different formats +- Round-trip testing (save and load back) +- Using loaded data with TestDataGenerator +- Format validation + +## Performance Considerations + +- **FVECS**: ~100-200 MB/s write speed, ~150-300 MB/s read speed +- **JSON**: ~20-50 MB/s write speed, ~30-80 MB/s read speed +- **Memory**: Writers process vectors in streaming fashion (low memory overhead) +- **Disk Space**: FVECS uses ~4 bytes/dimension/vector, JSON uses ~15-20 bytes/dimension/vector + +## Backward Compatibility + +All existing tests continue to work without modification. The persistence features are optional additions that don't affect the default behavior of TestDataGenerator or existing data sources. diff --git a/test/test_utils/data_writer/data_writer_base.h b/test/test_utils/data_writer/data_writer_base.h new file mode 100644 index 0000000..a2a885e --- /dev/null +++ b/test/test_utils/data_writer/data_writer_base.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +namespace sageFlow { namespace test { + +/** + * @brief Base class for writing vector data to files + * + * Provides a unified interface for persisting vector data to different formats. + * Implementations can support binary formats (fvecs), text formats (JSON, CSV), etc. 
+ */ +class DataWriterBase { +public: + virtual ~DataWriterBase() = default; + + /** + * @brief Write vectors to a file + * @param file_path Path to the output file + * @param vectors Vector data to write (each inner vector is one data point) + * @param dimension Vector dimension (for validation) + * @return true if write was successful, false otherwise + */ + virtual bool writeVectors(const std::string& file_path, + const std::vector<std::vector<float>>& vectors, + int dimension) = 0; + + /** + * @brief Get the file extension for this writer (e.g., ".fvecs", ".json") + */ + virtual std::string getFileExtension() const = 0; + + /** + * @brief Get a human-readable description of the format + */ + virtual std::string getFormatDescription() const = 0; +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_writer/fvecs_writer.cpp b/test/test_utils/data_writer/fvecs_writer.cpp new file mode 100644 index 0000000..42c6f9f --- /dev/null +++ b/test/test_utils/data_writer/fvecs_writer.cpp @@ -0,0 +1,60 @@ +#include "test_utils/data_writer/fvecs_writer.h" +#include "utils/logger.h" +#include <fstream> +#include <cstdint> + +namespace sageFlow { namespace test { + +bool FvecsWriter::writeVectors(const std::string& file_path, + const std::vector<std::vector<float>>& vectors, + int dimension) { + if (vectors.empty()) { + SAGEFLOW_LOG_ERROR("TEST", "[FvecsWriter] Error: No vectors to write"); + return false; + } + + // Validate all vectors have the correct dimension + for (size_t i = 0; i < vectors.size(); ++i) { + if (static_cast<int>(vectors[i].size()) != dimension) { + SAGEFLOW_LOG_ERROR("TEST", "[FvecsWriter] Error: Vector {} has dimension {}, expected {}", + i, vectors[i].size(), dimension); + return false; + } + } + + std::ofstream output(file_path, std::ios::binary); + if (!output.is_open()) { + SAGEFLOW_LOG_ERROR("TEST", "[FvecsWriter] Error: Cannot open file for writing: {}", file_path); + return false; + } + + try { + int32_t dim = static_cast<int32_t>(dimension); + + for (const auto& vec : vectors) { + // Write dimension + 
output.write(reinterpret_cast<const char*>(&dim), sizeof(int32_t)); + + // Write vector data + output.write(reinterpret_cast<const char*>(vec.data()), dim * sizeof(float)); + + if (!output.good()) { + SAGEFLOW_LOG_ERROR("TEST", "[FvecsWriter] Error: Write failed"); + output.close(); + return false; + } + } + + output.close(); + SAGEFLOW_LOG_INFO("TEST", "[FvecsWriter] Successfully wrote {} vectors of dimension {} to {}", + vectors.size(), dimension, file_path); + return true; + + } catch (const std::exception& e) { + SAGEFLOW_LOG_ERROR("TEST", "[FvecsWriter] Exception during write: {}", e.what()); + output.close(); + return false; + } +} + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_writer/fvecs_writer.h b/test/test_utils/data_writer/fvecs_writer.h new file mode 100644 index 0000000..635338a --- /dev/null +++ b/test/test_utils/data_writer/fvecs_writer.h @@ -0,0 +1,30 @@ +#pragma once + +#include "test_utils/data_writer/data_writer_base.h" + +namespace sageFlow { namespace test { + +/** + * @brief Writer for fvecs binary format + * + * fvecs format specification: + * - Each vector is stored as: [dimension(int32)] [vector_data(float32 * dimension)] + * - This is the standard format used in vector search benchmarks (SIFT, GIST, etc.) 
+ * - Binary format, efficient for large datasets + */ +class FvecsWriter : public DataWriterBase { +public: + FvecsWriter() = default; + + bool writeVectors(const std::string& file_path, + const std::vector<std::vector<float>>& vectors, + int dimension) override; + + std::string getFileExtension() const override { return ".fvecs"; } + + std::string getFormatDescription() const override { + return "FVECS binary format (dimension + float data per vector)"; + } +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_writer/json_writer.cpp b/test/test_utils/data_writer/json_writer.cpp new file mode 100644 index 0000000..4838c03 --- /dev/null +++ b/test/test_utils/data_writer/json_writer.cpp @@ -0,0 +1,71 @@ +#include "test_utils/data_writer/json_writer.h" +#include "utils/logger.h" +#include <fstream> +#include <iomanip> + +namespace sageFlow { namespace test { + +bool JsonWriter::writeVectors(const std::string& file_path, + const std::vector<std::vector<float>>& vectors, + int dimension) { + if (vectors.empty()) { + SAGEFLOW_LOG_ERROR("TEST", "[JsonWriter] Error: No vectors to write"); + return false; + } + + // Validate all vectors have the correct dimension + for (size_t i = 0; i < vectors.size(); ++i) { + if (static_cast<int>(vectors[i].size()) != dimension) { + SAGEFLOW_LOG_ERROR("TEST", "[JsonWriter] Error: Vector {} has dimension {}, expected {}", + i, vectors[i].size(), dimension); + return false; + } + } + + std::ofstream output(file_path); + if (!output.is_open()) { + SAGEFLOW_LOG_ERROR("TEST", "[JsonWriter] Error: Cannot open file for writing: {}", file_path); + return false; + } + + try { + output << std::fixed << std::setprecision(6); + + // Write JSON header + output << "{\n"; + output << " \"dimension\": " << dimension << ",\n"; + output << " \"count\": " << vectors.size() << ",\n"; + output << " \"vectors\": [\n"; + + // Write vectors + for (size_t i = 0; i < vectors.size(); ++i) { + output << " ["; + for (size_t j = 0; j < vectors[i].size(); ++j) { + output << vectors[i][j]; + if (j < 
vectors[i].size() - 1) { + output << ", "; + } + } + output << "]"; + if (i < vectors.size() - 1) { + output << ","; + } + output << "\n"; + } + + output << " ]\n"; + output << "}\n"; + + output.close(); + SAGEFLOW_LOG_INFO("TEST", "[JsonWriter] Successfully wrote {} vectors of dimension {} to {}", + vectors.size(), dimension, file_path); + return true; + + } catch (const std::exception& e) { + SAGEFLOW_LOG_ERROR("TEST", "[JsonWriter] Exception during write: {}", e.what()); + output.close(); + return false; + } +} + +}} // namespace sageFlow::test diff --git a/test/test_utils/data_writer/json_writer.h b/test/test_utils/data_writer/json_writer.h new file mode 100644 index 0000000..5bde3ab --- /dev/null +++ b/test/test_utils/data_writer/json_writer.h @@ -0,0 +1,40 @@ +#pragma once + +#include "test_utils/data_writer/data_writer_base.h" + +namespace sageFlow { namespace test { + +/** + * @brief Writer for JSON format + * + * JSON format for easy visualization and debugging: + * { + * "dimension": 128, + * "count": 1000, + * "vectors": [ + * [0.1, 0.2, ...], + * [0.3, 0.4, ...], + * ... 
+ * ] + * } + * + * - Human-readable text format + * - Easy to inspect and visualize + * - Less efficient for large datasets but good for debugging + */ +class JsonWriter : public DataWriterBase { +public: + JsonWriter() = default; + + bool writeVectors(const std::string& file_path, + const std::vector>& vectors, + int dimension) override; + + std::string getFileExtension() const override { return ".json"; } + + std::string getFormatDescription() const override { + return "JSON format (human-readable, good for visualization)"; + } +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/dynamic_config.cpp b/test/test_utils/dynamic_config.cpp index 54399bd..9f85d92 100644 --- a/test/test_utils/dynamic_config.cpp +++ b/test/test_utils/dynamic_config.cpp @@ -2,7 +2,7 @@ #include #include -namespace candy { +namespace sageFlow { namespace test { void DynamicConfig::set(const std::string& key, const ConfigValue& value) { config_map_[key]=value; } @@ -36,6 +36,29 @@ toml::table DynamicConfigManager::parseFileWithFallback(const std::string& path) throw std::runtime_error("Failed to open config: "+path); } +std::string resolveProjectRelativePath(const std::string& path) { + if (path.empty()) { + return path; + } + + std::filesystem::path fs_path(path); + if (fs_path.is_absolute()) { + return fs_path.lexically_normal().string(); + } + + std::filesystem::path base_path; +#ifdef PROJECT_DIR + base_path = std::filesystem::path(PROJECT_DIR); +#else + base_path = std::filesystem::current_path(); +#endif + return (base_path / fs_path).lexically_normal().string(); +} + +std::string DynamicConfigManager::resolveProjectRelativePath(const std::string& path) { + return sageFlow::test::resolveProjectRelativePath(path); +} + ConfigValue DynamicConfigManager::convertTomlValue(const toml::node& node) { if (auto v=node.value()) return *v; if (auto v=node.value()) return *v; @@ -82,4 +105,4 @@ bool DynamicConfigManager::loadConfigs(const std::string& config_path, const std bool 
DynamicConfigManager::loadRootConfig(const std::string& config_path, DynamicConfig& config) { return loadConfig(config_path, "", config); } } // namespace test -} // namespace candy \ No newline at end of file +} // namespace sageFlow \ No newline at end of file diff --git a/test/test_utils/dynamic_config.h b/test/test_utils/dynamic_config.h index ecdbc42..c702ab6 100644 --- a/test/test_utils/dynamic_config.h +++ b/test/test_utils/dynamic_config.h @@ -6,7 +6,7 @@ #include #include "toml++/toml.hpp" -namespace candy { namespace test { +namespace sageFlow { namespace test { using ConfigValue = std::variant,std::vector>; @@ -27,12 +27,15 @@ class DynamicConfigManager { static bool loadConfig(const std::string& config_path, const std::string& section, DynamicConfig& config); static bool loadConfigs(const std::string& config_path, const std::string& section, std::vector& configs); static bool loadRootConfig(const std::string& config_path, DynamicConfig& config); + static std::string resolveProjectRelativePath(const std::string& path); private: static void extractConfig(const toml::table& tbl, DynamicConfig& config, const std::string& prefix = ""); static ConfigValue convertTomlValue(const toml::node& node); static toml::table parseFileWithFallback(const std::string& path); }; +std::string resolveProjectRelativePath(const std::string& path); + template T DynamicConfig::get(const std::string& key) const { auto it = config_map_.find(key); diff --git a/test/test_utils/join_data_source.cpp b/test/test_utils/join_data_source.cpp new file mode 100644 index 0000000..9cafe99 --- /dev/null +++ b/test/test_utils/join_data_source.cpp @@ -0,0 +1,153 @@ +#include "test_utils/join_data_source.h" +#include "test_utils/test_data_adapter.h" +#include "utils/logger.h" +#include + +namespace sageFlow { namespace test { + +JoinDataSourcePair::JoinDataSourcePair(const JoinDataSourceConfig& config) + : config_(config) { + + // Validate configuration + if (config_.mode == 
JoinDataSourceConfig::Mode::Duplicate) { + if (!config_.single_source) { + throw std::runtime_error("Single source required for Duplicate mode"); + } + } else if (config_.mode == JoinDataSourceConfig::Mode::Separate) { + if (!config_.left_source || !config_.right_source) { + throw std::runtime_error("Both left and right sources required for Separate mode"); + } + // Verify dimensions match + if (config_.left_source->getDimension() != config_.right_source->getDimension()) { + throw std::runtime_error("Left and right sources must have same dimension"); + } + } +} + +std::pair>, + std::vector>> +JoinDataSourcePair::generateStreams(size_t max_records) { + std::vector> left_records; + std::vector> right_records; + + int64_t timestamp = config_.base_timestamp; + size_t count = 0; + + if (config_.mode == JoinDataSourceConfig::Mode::Duplicate) { + // Duplicate mode: generate from single source, duplicate to both sides + auto& source = config_.single_source; + source->reset(); + + while (source->hasMore() && (max_records == 0 || count < max_records)) { + auto vec = source->getNextVector(); + if (vec.empty()) break; + + // Create left record + uint64_t left_uid = next_left_uid_++; + left_records.push_back(createRecord(left_uid, vec, timestamp)); + + // Create right record (possibly with UID offset) + uint64_t right_uid = config_.apply_right_uid_offset ? 
+ (next_right_uid_++ + config_.right_uid_offset) : + next_right_uid_++; + right_records.push_back(createRecord(right_uid, vec, timestamp)); + + timestamp += config_.time_interval; + count++; + } + + } else { // Separate mode + auto& left = config_.left_source; + auto& right = config_.right_source; + left->reset(); + right->reset(); + + while (left->hasMore() && right->hasMore() && + (max_records == 0 || count < max_records)) { + auto left_vec = left->getNextVector(); + auto right_vec = right->getNextVector(); + + if (left_vec.empty() || right_vec.empty()) break; + + // Create left record + uint64_t left_uid = next_left_uid_++; + left_records.push_back(createRecord(left_uid, left_vec, timestamp)); + + // Create right record (possibly with UID offset) + uint64_t right_uid = config_.apply_right_uid_offset ? + (next_right_uid_++ + config_.right_uid_offset) : + next_right_uid_++; + right_records.push_back(createRecord(right_uid, right_vec, timestamp)); + + timestamp += config_.time_interval; + count++; + } + } + + SAGEFLOW_LOG_INFO("TEST", "[JoinDataSourcePair] Generated {} left and {} right records", + left_records.size(), right_records.size()); + + return {std::move(left_records), std::move(right_records)}; +} + +int JoinDataSourcePair::getDimension() const { + if (config_.mode == JoinDataSourceConfig::Mode::Separate) { + return config_.left_source->getDimension(); + } else { + return config_.single_source->getDimension(); + } +} + +int JoinDataSourcePair::getTotalCount() const { + if (config_.mode == JoinDataSourceConfig::Mode::Separate) { + return std::min(config_.left_source->getTotalCount(), + config_.right_source->getTotalCount()); + } else { + return config_.single_source->getTotalCount(); + } +} + +void JoinDataSourcePair::reset() { + next_left_uid_ = 1; + next_right_uid_ = 1; + + if (config_.mode == JoinDataSourceConfig::Mode::Separate) { + config_.left_source->reset(); + config_.right_source->reset(); + } else { + config_.single_source->reset(); + } +} + 
+std::unique_ptr JoinDataSourcePair::createRecord( + uint64_t uid, const std::vector& data, int64_t timestamp) { + auto record = createVectorRecord(uid, timestamp, data); + TestRecordSideManager::instance().setSide(uid, (uid % 2 == 0) ? Side::LEFT : Side::RIGHT); + return record; +} + +// Factory methods + +JoinDataSourceConfig JoinDataSourceFactory::createDuplicated( + std::shared_ptr source, + bool apply_uid_offset) { + JoinDataSourceConfig config; + config.mode = JoinDataSourceConfig::Mode::Duplicate; + config.single_source = source; + config.apply_right_uid_offset = apply_uid_offset; + return config; +} + +JoinDataSourceConfig JoinDataSourceFactory::createSeparate( + std::shared_ptr left_source, + std::shared_ptr right_source, + bool apply_uid_offset) { + JoinDataSourceConfig config; + config.mode = JoinDataSourceConfig::Mode::Separate; + config.left_source = left_source; + config.right_source = right_source; + config.apply_right_uid_offset = apply_uid_offset; + return config; +} + +}} // namespace sageFlow::test diff --git a/test/test_utils/join_data_source.h b/test/test_utils/join_data_source.h new file mode 100644 index 0000000..31b8ddc --- /dev/null +++ b/test/test_utils/join_data_source.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include "common/data_types.h" +#include "test_utils/data_source/data_source_base.h" + +namespace sageFlow { namespace test { + +/** + * @brief Configuration for join data source pair + * + * Defines how to create left and right data streams for join testing. + * Supports two modes: + * - Duplicate: Same data source duplicated to both sides (use single_source) + * - Separate: Different data sources for left and right (use left_source and right_source) + * + * Both modes can be used with any data source, including those created from generators. 
+ */ +struct JoinDataSourceConfig { + enum class Mode { + Duplicate, // Duplicate one source to both sides + Separate // Use two separate sources + }; + + Mode mode = Mode::Duplicate; + + // For Duplicate mode: use single_source + std::shared_ptr single_source; + + // For Separate mode: use left_source and right_source + std::shared_ptr left_source; + std::shared_ptr right_source; + + // Common options + bool apply_right_uid_offset = true; // Add offset to right stream UIDs + uint64_t right_uid_offset = 500000; // Default UID offset for right stream + int64_t base_timestamp = 1000000; // Starting timestamp + int64_t time_interval = 100; // Time increment between records +}; + +/** + * @brief Manages a pair of data sources for join testing + * + * Provides a unified interface for creating left and right data streams + * from various sources. Supports: + * - Duplicating a single source to both sides + * - Using separate sources for left and right + * - Applying UID offsets to distinguish streams + * - Generating VectorRecords with proper timestamps + */ +class JoinDataSourcePair { +public: + explicit JoinDataSourcePair(const JoinDataSourceConfig& config); + + /** + * @brief Generate left and right record streams + * @param max_records Maximum records to generate (0 = all available) + * @return Pair of (left_records, right_records) + */ + std::pair>, + std::vector>> + generateStreams(size_t max_records = 0); + + /** + * @brief Get the dimension of vectors in this pair + */ + int getDimension() const; + + /** + * @brief Get total available records (from smaller source if separate) + */ + int getTotalCount() const; + + /** + * @brief Reset both sources to beginning + */ + void reset(); + +private: + JoinDataSourceConfig config_; + uint64_t next_left_uid_ = 1; + uint64_t next_right_uid_ = 1; + + std::unique_ptr createRecord(uint64_t uid, const std::vector& data, int64_t timestamp); +}; + +/** + * @brief Factory for creating common join data source configurations + */ 
+class JoinDataSourceFactory { +public: + /** + * @brief Create config that duplicates a single source to both sides + * @param source Data source to duplicate + * @param apply_uid_offset Whether to offset right stream UIDs + */ + static JoinDataSourceConfig createDuplicated( + std::shared_ptr source, + bool apply_uid_offset = true); + + /** + * @brief Create config using separate sources for left and right + * @param left_source Source for left stream + * @param right_source Source for right stream + * @param apply_uid_offset Whether to offset right stream UIDs + */ + static JoinDataSourceConfig createSeparate( + std::shared_ptr left_source, + std::shared_ptr right_source, + bool apply_uid_offset = false); +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/join_test_helper.cpp b/test/test_utils/join_test_helper.cpp new file mode 100644 index 0000000..aaac046 --- /dev/null +++ b/test/test_utils/join_test_helper.cpp @@ -0,0 +1,63 @@ +#include "test_utils/join_test_helper.h" +#include "test_utils/data_source/vector_list_source.h" + +namespace sageFlow { namespace test { + +std::pair>, + std::vector>> +JoinTestHelper::generateJoinStreamsFromGenerator( + TestDataGenerator& generator, + bool apply_uid_offset) { + + // Generate data + auto [records, _] = generator.generateData(); + + // Get vectors for duplication + auto vectors = generator.getLastGeneratedVectors(); + if (vectors.empty()) { + throw std::runtime_error("No vectors generated from TestDataGenerator"); + } + + // Create a vector list source and use Duplicate mode + auto source = std::make_shared(vectors); + auto config = JoinDataSourceFactory::createDuplicated(source, apply_uid_offset); + JoinDataSourcePair pair(config); + + return pair.generateStreams(); +} + +std::pair>, + std::vector>> +JoinTestHelper::generateJoinStreams( + JoinDataSourcePair& pair, + size_t max_records) { + return pair.generateStreams(max_records); +} + +std::pair>, + std::vector>> 
+JoinTestHelper::generateJoinStreamsFromSource( + std::shared_ptr source, + bool apply_uid_offset, + size_t max_records) { + + auto config = JoinDataSourceFactory::createDuplicated(source, apply_uid_offset); + JoinDataSourcePair pair(config); + return pair.generateStreams(max_records); +} + +std::pair>, + std::vector>> +JoinTestHelper::generateJoinStreamsFromSeparateSources( + std::shared_ptr left_source, + std::shared_ptr right_source, + bool apply_uid_offset, + size_t max_records) { + + auto config = JoinDataSourceFactory::createSeparate( + left_source, right_source, apply_uid_offset); + JoinDataSourcePair pair(config); + return pair.generateStreams(max_records); +} + +}} // namespace sageFlow::test diff --git a/test/test_utils/join_test_helper.h b/test/test_utils/join_test_helper.h new file mode 100644 index 0000000..49300be --- /dev/null +++ b/test/test_utils/join_test_helper.h @@ -0,0 +1,85 @@ +#pragma once + +#include "test_utils/join_data_source.h" +#include "test_utils/test_data_generator.h" +#include "test_utils/data_source/random_data_source.h" + +namespace sageFlow { namespace test { + +/** + * @brief Helper functions for creating join test data + * + * Provides convenient wrappers for common join testing scenarios, + * maintaining backward compatibility with existing tests. + */ +class JoinTestHelper { +public: + /** + * @brief Create join streams from TestDataGenerator (backward compatible) + * + * This is the standard pattern used in existing tests: + * 1. Generate data with TestDataGenerator + * 2. Duplicate to left and right streams + * 3. 
Apply UID offset to right stream + * + * @param generator TestDataGenerator instance + * @param apply_uid_offset Whether to offset right UIDs (default: true) + * @return Pair of (left_records, right_records) + */ + static std::pair>, + std::vector>> + generateJoinStreamsFromGenerator( + TestDataGenerator& generator, + bool apply_uid_offset = true); + + /** + * @brief Create join streams using a data source pair + * + * @param pair JoinDataSourcePair to generate from + * @param max_records Maximum records (0 = all available) + * @return Pair of (left_records, right_records) + */ + static std::pair>, + std::vector>> + generateJoinStreams( + JoinDataSourcePair& pair, + size_t max_records = 0); + + /** + * @brief Create join streams by duplicating a single data source + * + * Useful for testing with dataset files or specific patterns. + * + * @param source Data source to duplicate + * @param apply_uid_offset Whether to offset right UIDs + * @param max_records Maximum records (0 = all available) + * @return Pair of (left_records, right_records) + */ + static std::pair>, + std::vector>> + generateJoinStreamsFromSource( + std::shared_ptr source, + bool apply_uid_offset = true, + size_t max_records = 0); + + /** + * @brief Create join streams from separate left and right sources + * + * Allows testing with different data distributions on each side. 
+ * + * @param left_source Source for left stream + * @param right_source Source for right stream + * @param apply_uid_offset Whether to offset right UIDs + * @param max_records Maximum records (0 = all available) + * @return Pair of (left_records, right_records) + */ + static std::pair>, + std::vector>> + generateJoinStreamsFromSeparateSources( + std::shared_ptr left_source, + std::shared_ptr right_source, + bool apply_uid_offset = false, + size_t max_records = 0); +}; + +}} // namespace sageFlow::test diff --git a/test/test_utils/test_config_manager.cpp b/test/test_utils/test_config_manager.cpp index 33ae88c..c45b1c3 100644 --- a/test/test_utils/test_config_manager.cpp +++ b/test/test_utils/test_config_manager.cpp @@ -2,7 +2,7 @@ #include #include -namespace candy { namespace test { +namespace sageFlow { namespace test { namespace { template diff --git a/test/test_utils/test_config_manager.h b/test/test_utils/test_config_manager.h index d013dae..945f070 100644 --- a/test/test_utils/test_config_manager.h +++ b/test/test_utils/test_config_manager.h @@ -5,7 +5,7 @@ #include #include "toml++/toml.hpp" -namespace candy { namespace test { +namespace sageFlow { namespace test { struct TestCaseConfig { std::string name; diff --git a/test/test_utils/test_data_adapter.h b/test/test_utils/test_data_adapter.h index 884813d..fe1706f 100644 --- a/test/test_utils/test_data_adapter.h +++ b/test/test_utils/test_data_adapter.h @@ -10,7 +10,7 @@ #include #include -namespace candy { namespace test { +namespace sageFlow { namespace test { enum class Side { LEFT, RIGHT, BOTH }; diff --git a/test/test_utils/test_data_generator.cpp b/test/test_utils/test_data_generator.cpp index d74642b..d9e6d00 100644 --- a/test/test_utils/test_data_generator.cpp +++ b/test/test_utils/test_data_generator.cpp @@ -1,37 +1,81 @@ #include "test_utils/test_data_generator.h" +#include "test_utils/data_source/random_data_source.h" +#include "test_utils/data_source/dataset_data_source.h" +#include 
"test_utils/data_source/data_source_factory.h" +#include "utils/logger.h" #include #include -namespace candy { namespace test { +namespace sageFlow { namespace test { -TestDataGenerator::TestDataGenerator(const Config& config) : config_(config), rng_(config.seed) {} +TestDataGenerator::TestDataGenerator(const Config& config) : config_(config), rng_(config.seed) { + // Create default random data source + RandomDataSource::Config ds_config; + ds_config.vector_dim = config_.vector_dim; + ds_config.seed = config_.seed; + ds_config.max_vectors = -1; // Unlimited + data_source_ = std::make_shared(ds_config); +} + +TestDataGenerator::TestDataGenerator(const Config& config, std::shared_ptr data_source) + : config_(config), rng_(config.seed), data_source_(std::move(data_source)) { + if (!data_source_) { + throw std::runtime_error("Data source cannot be null"); + } + // Update config dimension to match data source + config_.vector_dim = data_source_->getDimension(); +} + +TestDataGenerator TestDataGenerator::createFromConfig(const Config& config, const DynamicConfig* data_source_config) { + if (!data_source_config) { + // No data source config provided, use default random + return TestDataGenerator(config); + } + + // Create data source from config + auto data_source = DataSourceFactory::createFromConfig(*data_source_config, config.vector_dim, config.seed); + return TestDataGenerator(config, data_source); +} std::pair>, std::unordered_set, PairHash>> TestDataGenerator::generateData() { std::vector> records; std::unordered_set, PairHash> expected_matches; + last_generated_vectors_.clear(); // Clear cache for new generation + uint64_t uid_counter = next_uid_; int64_t timestamp = config_.base_timestamp; for (int i = 0; i < config_.positive_pairs; ++i) { - auto base_vector = generateRandomVector(); auto perturbed_vector = perturbVector(base_vector, config_.similarity_threshold + 0.05); + auto base_vector = getNextVector(); auto perturbed_vector = perturbVector(base_vector, 
config_.similarity_threshold + 0.05); uint64_t uid1 = uid_counter++; uint64_t uid2 = uid_counter++; records.push_back(createRecord(uid1, base_vector, timestamp)); records.push_back(createRecord(uid2, perturbed_vector, timestamp + config_.time_interval)); + last_generated_vectors_.push_back(base_vector); + last_generated_vectors_.push_back(perturbed_vector); expected_matches.insert({uid1, uid2}); timestamp += config_.time_interval * 2; } for (int i = 0; i < config_.near_threshold_pairs; ++i) { - auto base_vector = generateRandomVector(); double target_sim = config_.similarity_threshold + (i % 2 == 0 ? 0.001 : -0.001); + auto base_vector = getNextVector(); double target_sim = config_.similarity_threshold + (i % 2 == 0 ? 0.001 : -0.001); auto perturbed_vector = perturbVector(base_vector, target_sim); uint64_t uid1 = uid_counter++; uint64_t uid2 = uid_counter++; records.push_back(createRecord(uid1, base_vector, timestamp)); records.push_back(createRecord(uid2, perturbed_vector, timestamp + config_.time_interval)); + last_generated_vectors_.push_back(base_vector); + last_generated_vectors_.push_back(perturbed_vector); if (target_sim >= config_.similarity_threshold) expected_matches.insert({uid1, uid2}); timestamp += config_.time_interval * 2; } for (int i = 0; i < config_.negative_pairs; ++i) { - auto vec1 = generateRandomVector(); auto vec2 = generateRandomVector(); uint64_t uid1 = uid_counter++; uint64_t uid2 = uid_counter++; + auto vec1 = getNextVector(); auto vec2 = getNextVector(); uint64_t uid1 = uid_counter++; uint64_t uid2 = uid_counter++; records.push_back(createRecord(uid1, vec1, timestamp)); records.push_back(createRecord(uid2, vec2, timestamp + config_.time_interval)); + last_generated_vectors_.push_back(vec1); + last_generated_vectors_.push_back(vec2); if (calculateSimilarity(vec1, vec2) >= config_.similarity_threshold) expected_matches.insert({uid1, uid2}); timestamp += config_.time_interval * 2; } - for (int i = 0; i < config_.random_tail; ++i) { auto vec 
= generateRandomVector(); records.push_back(createRecord(uid_counter++, vec, timestamp)); timestamp += config_.time_interval; } + for (int i = 0; i < config_.random_tail; ++i) { + auto vec = getNextVector(); + records.push_back(createRecord(uid_counter++, vec, timestamp)); + last_generated_vectors_.push_back(vec); + timestamp += config_.time_interval; + } next_uid_ = uid_counter; return {std::move(records), std::move(expected_matches)}; } @@ -41,11 +85,12 @@ std::unique_ptr TestDataGenerator::createRecord(uint64_t uid, cons return record; } -std::vector TestDataGenerator::generateRandomVector() { - std::vector vec(config_.vector_dim); std::normal_distribution dist(0.0f, 1.0f); - for (int i = 0; i < config_.vector_dim; ++i) vec[i] = dist(rng_); - float norm = 0.0f; for (float v : vec) norm += v*v; norm = std::sqrt(norm); - if (norm > 1e-6f) for (float &v : vec) v /= norm; return vec; +std::vector TestDataGenerator::getNextVector() { + if (!data_source_->hasMore()) { + // If data source is exhausted, reset it to allow reuse + data_source_->reset(); + } + return data_source_->getNextVector(); } std::vector TestDataGenerator::perturbVector(const std::vector& base, double target_similarity) { @@ -96,4 +141,18 @@ double BaselineJoinChecker::computeCosineSimilarity(const std::vector& a, bool BaselineJoinChecker::areInSameWindow(int64_t ts1, int64_t ts2, int64_t window_size) { return std::abs(ts1-ts2) <= window_size; } +bool TestDataGenerator::saveGeneratedVectors(const std::string& file_path, std::shared_ptr writer) { + if (!writer) { + SAGEFLOW_LOG_ERROR("TEST", "[TestDataGenerator] Error: Writer cannot be null"); + return false; + } + + if (last_generated_vectors_.empty()) { + SAGEFLOW_LOG_ERROR("TEST", "[TestDataGenerator] Error: No data to save. 
Call generateData() first."); + return false; + } + + return writer->writeVectors(file_path, last_generated_vectors_, config_.vector_dim); +} + }} // namespace diff --git a/test/test_utils/test_data_generator.h b/test/test_utils/test_data_generator.h index 9cecd1e..58f749e 100644 --- a/test/test_utils/test_data_generator.h +++ b/test/test_utils/test_data_generator.h @@ -7,8 +7,11 @@ #include #include "common/data_types.h" #include "test_utils/test_data_adapter.h" +#include "test_utils/data_source/data_source_base.h" +#include "test_utils/data_writer/data_writer_base.h" +#include "test_utils/dynamic_config.h" -namespace candy { namespace test { +namespace sageFlow { namespace test { struct PairHash { size_t operator()(const std::pair& p) const noexcept { uint64_t a = std::min(p.first, p.second); uint64_t b = std::max(p.first, p.second); uint64_t mix = a * 1315423911u ^ ((b << 13) | (b >> 7)); return std::hash{}(mix); } }; @@ -28,11 +31,39 @@ class TestDataGenerator { int64_t time_interval = 100; }; explicit TestDataGenerator(const Config& config); + explicit TestDataGenerator(const Config& config, std::shared_ptr data_source); + + /** + * @brief Create TestDataGenerator from configuration with optional data source config + * @param config Test data generator configuration + * @param data_source_config Optional data source configuration (if empty, uses default random) + */ + static TestDataGenerator createFromConfig(const Config& config, const DynamicConfig* data_source_config = nullptr); + std::pair>, std::unordered_set, PairHash>> generateData(); + + /** + * @brief Save generated vectors to a file using the specified writer + * @param file_path Path to the output file + * @param writer DataWriter implementation (FvecsWriter, JsonWriter, etc.) 
+ * @return true if save was successful + */ + bool saveGeneratedVectors(const std::string& file_path, std::shared_ptr writer); + + /** + * @brief Get the last generated vectors (for saving after generation) + */ + std::vector> getLastGeneratedVectors() const { return last_generated_vectors_; } + private: - Config config_; std::mt19937 rng_; uint64_t next_uid_ = 1; + Config config_; + std::mt19937 rng_; + uint64_t next_uid_ = 1; + std::shared_ptr data_source_; + std::vector> last_generated_vectors_; // Cache for saving + std::unique_ptr createRecord(uint64_t uid, const std::vector& data, int64_t timestamp); - std::vector generateRandomVector(); + std::vector getNextVector(); std::vector perturbVector(const std::vector& base, double target_similarity); double calculateSimilarity(const std::vector& a, const std::vector& b); }; diff --git a/third-party/CMakeLists.txt b/third-party/CMakeLists.txt index b460c28..2b69a65 100644 --- a/third-party/CMakeLists.txt +++ b/third-party/CMakeLists.txt @@ -39,11 +39,21 @@ FetchContent_Declare( # Configure upstream projects before fetching -------------------------------- set(BUILD_GMOCK OFF CACHE BOOL "" FORCE) -set(BUILD_GTEST ON CACHE BOOL "" FORCE) +# 只在需要测试时构建 gtest +if(BUILD_TESTING) + set(BUILD_GTEST ON CACHE BOOL "" FORCE) +else() + set(BUILD_GTEST OFF CACHE BOOL "" FORCE) +endif() set(SPDLOG_FMT_EXTERNAL ON CACHE BOOL "" FORCE) # Fetch and build third-party projects --------------------------------------- -FetchContent_MakeAvailable(argparse fmt googletest spdlog tomlplusplus) +if(BUILD_TESTING) + FetchContent_MakeAvailable(argparse fmt googletest spdlog tomlplusplus) +else() + # 不构建测试时跳过 googletest + FetchContent_MakeAvailable(argparse fmt spdlog tomlplusplus) +endif() # Ensure compiled libraries are position independent (shared objects, pybind11) foreach(_pic_target fmt spdlog) @@ -86,9 +96,11 @@ if(_external_runtime_include_dirs) endif() # Interface library consumed by tests ---------------------------------------- 
-add_library(externalTestLibs INTERFACE) -target_link_libraries( - externalTestLibs INTERFACE - GTest::gtest - GTest::gtest_main -) \ No newline at end of file +if(BUILD_TESTING) + add_library(externalTestLibs INTERFACE) + target_link_libraries( + externalTestLibs INTERFACE + GTest::gtest + GTest::gtest_main + ) +endif() \ No newline at end of file