Skip to content

Commit 911d491

Browse files
committed
Merge remote-tracking branch 'origin/main' into dev/add-yml-linter
2 parents fb18cd8 + 724ac33 commit 911d491

File tree

10 files changed

+139
-36
lines changed

10 files changed

+139
-36
lines changed

.github/workflows/build-linux.yml

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,38 @@ concurrency:
3030

3131
jobs:
3232
build:
33-
name: ${{ matrix.cxx }}, ${{ matrix.build_type }}, ivf=${{ matrix.ivf }}
33+
name: ${{ matrix.cxx }}, ${{ matrix.build_type }}, ivf=${{ matrix.ivf }}, asan=${{ matrix.asan }}
3434
runs-on: ubuntu-22.04
3535
strategy:
3636
matrix:
3737
build_type: [RelWithDebugInfo]
3838
ivf: [OFF, ON]
3939
cxx: [g++-11, g++-12, clang++-15]
40+
asan: [OFF]
41+
cmake_extra_args: [-DSVS_BUILD_BINARIES=YES -DSVS_BUILD_EXAMPLES=YES]
42+
ctest_args: ['']
4043
include:
4144
- cxx: g++-11
4245
cc: gcc-11
4346
- cxx: g++-12
4447
cc: gcc-12
4548
- cxx: clang++-15
4649
cc: clang-15
50+
- cxx: clang++-18
51+
cc: clang-18
52+
build_type: Debug
53+
ivf: OFF
54+
asan: ON
55+
# address sanitizer flags
56+
cmake_extra_args: >-
57+
-DCMAKE_CXX_FLAGS='-fsanitize=address -fno-omit-frame-pointer -g'
58+
-DCMAKE_C_FLAGS='-fsanitize=address -fno-omit-frame-pointer -g'
59+
-DCMAKE_EXE_LINKER_FLAGS='-fsanitize=address'
60+
-DCMAKE_SHARED_LINKER_FLAGS='-fsanitize=address'
61+
-DSVS_BUILD_BINARIES=NO
62+
-DSVS_BUILD_EXAMPLES=NO
63+
# skip longer-running tests
64+
ctest_args: -LE long
4765
exclude:
4866
- cxx: g++-12
4967
ivf: ON
@@ -60,6 +78,13 @@ jobs:
6078
source /opt/intel/oneapi/setvars.sh
6179
printenv >> $GITHUB_ENV
6280
81+
- name: Install Clang 18
82+
if: matrix.cxx == 'clang++-18'
83+
run: |
84+
wget https://apt.llvm.org/llvm.sh
85+
chmod +x llvm.sh
86+
sudo ./llvm.sh 18
87+
6388
- name: Configure build
6489
working-directory: ${{ runner.temp }}
6590
env:
@@ -69,12 +94,10 @@ jobs:
6994
run: |
7095
cmake -B${TEMP_WORKSPACE}/build -S${GITHUB_WORKSPACE} \
7196
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
72-
-DSVS_BUILD_BINARIES=YES \
7397
-DSVS_BUILD_TESTS=YES \
74-
-DSVS_BUILD_EXAMPLES=YES \
75-
-DSVS_EXPERIMENTAL_LEANVEC=YES \
7698
-DSVS_NO_AVX512=NO \
77-
-DSVS_EXPERIMENTAL_ENABLE_IVF=${{ matrix.ivf }}
99+
-DSVS_EXPERIMENTAL_ENABLE_IVF=${{ matrix.ivf }} \
100+
${{ matrix.cmake_extra_args }}
78101
79102
- name: Build Tests and Utilities
80103
working-directory: ${{ runner.temp }}/build
@@ -84,10 +107,11 @@ jobs:
84107
env:
85108
CTEST_OUTPUT_ON_FAILURE: 1
86109
working-directory: ${{ runner.temp }}/build/tests
87-
run: ctest -C ${{ matrix.build_type }}
110+
run: ctest -C ${{ matrix.build_type }} ${{ matrix.ctest_args }}
88111

89112
- name: Run Cpp Examples
113+
if: matrix.asan != 'ON'
90114
env:
91115
CTEST_OUTPUT_ON_FAILURE: 1
92116
working-directory: ${{ runner.temp }}/build/examples/cpp
93-
run: ctest -C RelWithDebugInfo
117+
run: ctest -C ${{ matrix.build_type }}

include/svs/core/distance/simd_utils.h

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#if defined(__i386__) || defined(__x86_64__)
2020

2121
#include <array>
22+
#include <cstring>
2223
#include <limits>
2324
#include <type_traits>
2425

@@ -332,11 +333,10 @@ template <> struct ConvertToFloat<8> {
332333
// from float
333334
static __m256 load(const float* ptr) { return _mm256_loadu_ps(ptr); }
334335
static __m256 load(mask_t m, const float* ptr) {
335-
// AVX2 doesn't have native masked load, so we load and then blend
336-
auto data = _mm256_loadu_ps(ptr);
337-
auto zero = _mm256_setzero_ps();
338-
auto mask_vec = create_blend_mask_avx2(m);
339-
return _mm256_blendv_ps(zero, data, mask_vec);
336+
// Full width load with blending may cause out-of-bounds read (SEGV)
337+
// Therefore we use _mm256_maskload_ps which safely handles masked loads
338+
auto mask_vec = _mm256_castps_si256(create_blend_mask_avx2(m));
339+
return _mm256_maskload_ps(ptr, mask_vec);
340340
}
341341

342342
// from float16
@@ -345,10 +345,10 @@ template <> struct ConvertToFloat<8> {
345345
}
346346

347347
static __m256 load(mask_t m, const Float16* ptr) {
348-
auto data = _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr)));
349-
auto zero = _mm256_setzero_ps();
350-
auto mask_vec = create_blend_mask_avx2(m);
351-
return _mm256_blendv_ps(zero, data, mask_vec);
348+
// Safe masked load using a temporary buffer to avoid SEGV
349+
__m128i buffer = _mm_setzero_si128();
350+
std::memcpy(&buffer, ptr, __builtin_popcount(m) * sizeof(Float16));
351+
return _mm256_cvtph_ps(buffer);
352352
}
353353

354354
// from uint8
@@ -359,12 +359,10 @@ template <> struct ConvertToFloat<8> {
359359
}
360360

361361
static __m256 load(mask_t m, const uint8_t* ptr) {
362-
auto data = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
363-
_mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(ptr)))
364-
));
365-
auto zero = _mm256_setzero_ps();
366-
auto mask_vec = create_blend_mask_avx2(m);
367-
return _mm256_blendv_ps(zero, data, mask_vec);
362+
// Safe masked load using a temporary buffer to avoid SEGV
363+
int64_t buffer = 0;
364+
std::memcpy(&buffer, ptr, __builtin_popcount(m) * sizeof(uint8_t));
365+
return _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_cvtsi64_si128(buffer)));
368366
}
369367

370368
// from int8
@@ -375,12 +373,10 @@ template <> struct ConvertToFloat<8> {
375373
}
376374

377375
static __m256 load(mask_t m, const int8_t* ptr) {
378-
auto data = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
379-
_mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(ptr)))
380-
));
381-
auto zero = _mm256_setzero_ps();
382-
auto mask_vec = create_blend_mask_avx2(m);
383-
return _mm256_blendv_ps(zero, data, mask_vec);
376+
// Safe masked load using a temporary buffer to avoid SEGV
377+
int64_t buffer = 0;
378+
std::memcpy(&buffer, ptr, __builtin_popcount(m) * sizeof(int8_t));
379+
return _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_cvtsi64_si128(buffer)));
384380
}
385381

386382
// We do not need to treat the left or right-hand differently.

include/svs/index/vamana/search_buffer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ template <typename Idx, typename Cmp = std::less<>> class SearchBuffer {
340340
/// returns ``true``.
341341
///
342342
bool can_skip(float distance) const {
343-
return compare_(back().distance(), distance) && full();
343+
return full() && (capacity() == 0 || compare_(back().distance(), distance));
344344
}
345345

346346
///

tests/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ set(CMAKE_CXX_STANDARD ${SVS_CXX_STANDARD})
3737
FetchContent_Declare(
3838
Catch2
3939
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
40-
GIT_TAG v3.4.0
40+
GIT_TAG v3.11.0
4141
)
4242

4343
FetchContent_MakeAvailable(Catch2)
@@ -230,5 +230,4 @@ target_include_directories(tests PRIVATE ${PROJECT_SOURCE_DIR})
230230
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
231231
include(CTest)
232232
include(Catch)
233-
catch_discover_tests(tests)
234-
233+
catch_discover_tests(tests ADD_TAGS_AS_LABELS SKIP_IS_FAILURE)

tests/svs/core/distance.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,14 @@
1818
#include "svs/core/distance.h"
1919

2020
// catch 2
21+
#include "catch2/catch_template_test_macros.hpp"
2122
#include "catch2/catch_test_macros.hpp"
2223

24+
#include <numeric>
25+
#include <vector>
26+
27+
#include "svs/lib/avx_detection.h"
28+
2329
namespace {
2430

2531
std::string_view test_table = R"(
@@ -94,3 +100,70 @@ CATCH_TEST_CASE("Distance Utils", "[core][distance][distance_type]") {
94100
}
95101
}
96102
}
103+
104+
CATCH_TEMPLATE_TEST_CASE(
105+
"Distance ASan",
106+
"[distance][simd][asan]",
107+
svs::DistanceL2,
108+
svs::DistanceIP,
109+
svs::DistanceCosineSimilarity
110+
) {
111+
using Distance = TestType;
112+
113+
auto run_test = []() {
114+
// some full-width AVX2/AVX512 registers plus (crucially) ragged epilogue
115+
constexpr size_t size = 64 + 2;
116+
std::vector<float> a(size);
117+
std::vector<float> b(size);
118+
119+
std::iota(a.begin(), a.end(), 1.0f);
120+
std::iota(b.begin(), b.end(), 2.0f);
121+
122+
// Ensure no spare capacity
123+
a.shrink_to_fit();
124+
b.shrink_to_fit();
125+
126+
auto dist = svs::distance::compute(Distance(), std::span(a), std::span(b));
127+
CATCH_REQUIRE(dist >= 0);
128+
};
129+
130+
CATCH_SECTION("Default") { run_test(); }
131+
132+
#ifdef __x86_64__
133+
if (svs::detail::avx_runtime_flags.is_avx512vnni_supported()) {
134+
CATCH_SECTION("No AVX512VNNI") {
135+
auto& mutable_flags =
136+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
137+
auto original = mutable_flags;
138+
mutable_flags.avx512vnni = false;
139+
run_test();
140+
mutable_flags = original;
141+
}
142+
}
143+
144+
if (svs::detail::avx_runtime_flags.is_avx512f_supported()) {
145+
CATCH_SECTION("No AVX512F") {
146+
auto& mutable_flags =
147+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
148+
auto original = mutable_flags;
149+
mutable_flags.avx512vnni = false;
150+
mutable_flags.avx512f = false;
151+
run_test();
152+
mutable_flags = original;
153+
}
154+
}
155+
156+
if (svs::detail::avx_runtime_flags.is_avx2_supported()) {
157+
CATCH_SECTION("No AVX2") {
158+
auto& mutable_flags =
159+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
160+
auto original = mutable_flags;
161+
mutable_flags.avx512vnni = false;
162+
mutable_flags.avx512f = false;
163+
mutable_flags.avx2 = false;
164+
run_test();
165+
mutable_flags = original;
166+
}
167+
}
168+
#endif // __x86_64__
169+
}

tests/svs/index/inverted/clustering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ void test_end_to_end_clustering(
385385

386386
} // namespace
387387

388-
CATCH_TEST_CASE("Random Clustering - End to End", "[inverted][random_clustering]") {
388+
CATCH_TEST_CASE("Random Clustering - End to End", "[long][inverted][random_clustering]") {
389389
CATCH_SECTION("Uncompressed Data") {
390390
auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
391391
test_end_to_end_clustering(data, svs::DistanceL2(), 1.2f);

tests/svs/index/inverted/memory_based.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "tests/utils/test_dataset.h"
2424
#include <filesystem>
2525

26-
CATCH_TEST_CASE("InvertedIndex Logging Test", "[logging]") {
26+
CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") {
2727
// Vector to store captured log messages
2828
std::vector<std::string> captured_logs;
2929
std::vector<std::string> global_captured_logs;

tests/svs/index/vamana/index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") {
181181
CATCH_REQUIRE(captured_logs[2].find("Batch Size:") != std::string::npos);
182182
}
183183

184-
CATCH_TEST_CASE("Vamana Index Default Parameters", "[parameter][vamana]") {
184+
CATCH_TEST_CASE("Vamana Index Default Parameters", "[long][parameter][vamana]") {
185185
using Catch::Approx;
186186
std::filesystem::path data_path = test_dataset::data_svs_file();
187187

tests/svs/index/vamana/multi.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ template <typename Distance> float pick_alpha(Distance SVS_UNUSED(dist)) {
4848

4949
CATCH_TEMPLATE_TEST_CASE(
5050
"Multi-vector dynamic vamana index",
51-
"[index][vamana][multi]",
51+
"[long][index][vamana][multi]",
5252
svs::DistanceL2,
5353
svs::DistanceIP,
5454
svs::DistanceCosineSimilarity

tests/svs/lib/avx_detection.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,15 @@ CATCH_TEST_CASE("AVX detection", "[lib][lib-avx-detection]") {
2929
<< svs::detail::avx_runtime_flags.is_avx512f_supported() << "\n";
3030
std::cout << "AVX512VNNI: " << std::boolalpha
3131
<< svs::detail::avx_runtime_flags.is_avx512vnni_supported() << "\n";
32+
33+
#ifdef __x86_64__
34+
CATCH_SECTION("Patching") {
35+
auto& mutable_flags =
36+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
37+
auto original = mutable_flags.avx512f;
38+
mutable_flags.avx512f = false;
39+
CATCH_REQUIRE(svs::detail::avx_runtime_flags.is_avx512f_supported() == false);
40+
mutable_flags.avx512f = original;
41+
}
42+
#endif
3243
}

0 commit comments

Comments
 (0)