Skip to content

Commit 724ac33

Browse files
ahuber21Copilot
andauthored
fix(simd): umasked AVX2 load (#239)
The previous code _always_ performed a full width load on the provided data. In ragged-epilogue scenarios, where we request a masked load, this resulted in SEGV errors in certain runs with address sanitizer. ```c++ if (i < count.size()) { auto mask = create_mask<simd_width>(count); s0 = op.accumulate(mask, s0, op.load_a(mask, a + i), op.load_b(mask, b + i)); } ``` **Why wasn't this caught sooner?** The OS only triggers a segmentation fault if a read accesses an unmapped memory page. Since memory protection (typically) operates at a 4KB page granularity, reading past the end of a buffer is "safe" from the OS's perspective unless the overflow happens to cross exactly into an unmapped page. **Why is ASan catching it sporadically?** Since our underlying object storage is `std::vector`, ASan detection requires two specific conditions to align: * **No Spare Capacity:** The vector's `size()` must equal its `capacity()`. If there is spare capacity, the unsafe load simply reads valid (though uninitialized) memory owned by the vector. * **Alignment & Redzones:** The underlying heap allocation must be sized and aligned such that the full-width SIMD read (e.g., 32 bytes) actually crosses the allocation boundary into the ASan redzone. If the allocator adds padding for alignment, the read might land in that valid padding instead. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent e39360a commit 724ac33

File tree

9 files changed

+140
-37
lines changed

9 files changed

+140
-37
lines changed

.github/workflows/build-linux.yml

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,38 @@ concurrency:
3030

3131
jobs:
3232
build:
33-
name: ${{ matrix.cxx }}, ${{ matrix.build_type }}, ivf=${{ matrix.ivf }}
33+
name: ${{ matrix.cxx }}, ${{ matrix.build_type }}, ivf=${{ matrix.ivf }}, asan=${{ matrix.asan }}
3434
runs-on: ubuntu-22.04
3535
strategy:
3636
matrix:
3737
build_type: [RelWithDebugInfo]
3838
ivf: [OFF, ON]
3939
cxx: [g++-11, g++-12, clang++-15]
40+
asan: [OFF]
41+
cmake_extra_args: ["-DSVS_BUILD_BINARIES=YES -DSVS_BUILD_EXAMPLES=YES"]
42+
ctest_args: [""]
4043
include:
4144
- cxx: g++-11
4245
cc: gcc-11
4346
- cxx: g++-12
4447
cc: gcc-12
4548
- cxx: clang++-15
4649
cc: clang-15
50+
- cxx: clang++-18
51+
cc: clang-18
52+
build_type: Debug
53+
ivf: OFF
54+
asan: ON
55+
# address sanitizer flags
56+
cmake_extra_args: >-
57+
-DCMAKE_CXX_FLAGS='-fsanitize=address -fno-omit-frame-pointer -g'
58+
-DCMAKE_C_FLAGS='-fsanitize=address -fno-omit-frame-pointer -g'
59+
-DCMAKE_EXE_LINKER_FLAGS='-fsanitize=address'
60+
-DCMAKE_SHARED_LINKER_FLAGS='-fsanitize=address'
61+
-DSVS_BUILD_BINARIES=NO
62+
-DSVS_BUILD_EXAMPLES=NO
63+
# skip longer-running tests
64+
ctest_args: "-LE long"
4765
exclude:
4866
- cxx: g++-12
4967
ivf: ON
@@ -60,6 +78,13 @@ jobs:
6078
source /opt/intel/oneapi/setvars.sh
6179
printenv >> $GITHUB_ENV
6280
81+
- name: Install Clang 18
82+
if: matrix.cxx == 'clang++-18'
83+
run: |
84+
wget https://apt.llvm.org/llvm.sh
85+
chmod +x llvm.sh
86+
sudo ./llvm.sh 18
87+
6388
- name: Configure build
6489
working-directory: ${{ runner.temp }}
6590
env:
@@ -69,25 +94,24 @@ jobs:
6994
run: |
7095
cmake -B${TEMP_WORKSPACE}/build -S${GITHUB_WORKSPACE} \
7196
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
72-
-DSVS_BUILD_BINARIES=YES \
7397
-DSVS_BUILD_TESTS=YES \
74-
-DSVS_BUILD_EXAMPLES=YES \
75-
-DSVS_EXPERIMENTAL_LEANVEC=YES \
7698
-DSVS_NO_AVX512=NO \
77-
-DSVS_EXPERIMENTAL_ENABLE_IVF=${{ matrix.ivf }}
99+
-DSVS_EXPERIMENTAL_ENABLE_IVF=${{ matrix.ivf }} \
100+
${{ matrix.cmake_extra_args }}
78101
79102
- name: Build Tests and Utilities
80103
working-directory: ${{ runner.temp }}/build
81104
run: make -j$(nproc)
82105

83106
- name: Run tests
84107
env:
85-
CTEST_OUTPUT_ON_FAILURE: 1
108+
CTEST_OUTPUT_ON_FAILURE: 1
86109
working-directory: ${{ runner.temp }}/build/tests
87-
run: ctest -C ${{ matrix.build_type }}
110+
run: ctest -C ${{ matrix.build_type }} ${{ matrix.ctest_args }}
88111

89112
- name: Run Cpp Examples
113+
if: matrix.asan != 'ON'
90114
env:
91-
CTEST_OUTPUT_ON_FAILURE: 1
115+
CTEST_OUTPUT_ON_FAILURE: 1
92116
working-directory: ${{ runner.temp }}/build/examples/cpp
93-
run: ctest -C RelWithDebugInfo
117+
run: ctest -C ${{ matrix.build_type }}

include/svs/core/distance/simd_utils.h

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#if defined(__i386__) || defined(__x86_64__)
2020

2121
#include <array>
22+
#include <cstring>
2223
#include <limits>
2324
#include <type_traits>
2425

@@ -332,11 +333,10 @@ template <> struct ConvertToFloat<8> {
332333
// from float
333334
static __m256 load(const float* ptr) { return _mm256_loadu_ps(ptr); }
334335
static __m256 load(mask_t m, const float* ptr) {
335-
// AVX2 doesn't have native masked load, so we load and then blend
336-
auto data = _mm256_loadu_ps(ptr);
337-
auto zero = _mm256_setzero_ps();
338-
auto mask_vec = create_blend_mask_avx2(m);
339-
return _mm256_blendv_ps(zero, data, mask_vec);
336+
// Full width load with blending may cause out-of-bounds read (SEGV)
337+
// Therefore we use _mm256_maskload_ps which safely handles masked loads
338+
auto mask_vec = _mm256_castps_si256(create_blend_mask_avx2(m));
339+
return _mm256_maskload_ps(ptr, mask_vec);
340340
}
341341

342342
// from float16
@@ -345,10 +345,10 @@ template <> struct ConvertToFloat<8> {
345345
}
346346

347347
static __m256 load(mask_t m, const Float16* ptr) {
348-
auto data = _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ptr)));
349-
auto zero = _mm256_setzero_ps();
350-
auto mask_vec = create_blend_mask_avx2(m);
351-
return _mm256_blendv_ps(zero, data, mask_vec);
348+
// Safe masked load using a temporary buffer to avoid SEGV
349+
__m128i buffer = _mm_setzero_si128();
350+
std::memcpy(&buffer, ptr, __builtin_popcount(m) * sizeof(Float16));
351+
return _mm256_cvtph_ps(buffer);
352352
}
353353

354354
// from uint8
@@ -359,12 +359,10 @@ template <> struct ConvertToFloat<8> {
359359
}
360360

361361
static __m256 load(mask_t m, const uint8_t* ptr) {
362-
auto data = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
363-
_mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(ptr)))
364-
));
365-
auto zero = _mm256_setzero_ps();
366-
auto mask_vec = create_blend_mask_avx2(m);
367-
return _mm256_blendv_ps(zero, data, mask_vec);
362+
// Safe masked load using a temporary buffer to avoid SEGV
363+
int64_t buffer = 0;
364+
std::memcpy(&buffer, ptr, __builtin_popcount(m) * sizeof(uint8_t));
365+
return _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_cvtsi64_si128(buffer)));
368366
}
369367

370368
// from int8
@@ -375,12 +373,10 @@ template <> struct ConvertToFloat<8> {
375373
}
376374

377375
static __m256 load(mask_t m, const int8_t* ptr) {
378-
auto data = _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(
379-
_mm_cvtsi64_si128(*(reinterpret_cast<const int64_t*>(ptr)))
380-
));
381-
auto zero = _mm256_setzero_ps();
382-
auto mask_vec = create_blend_mask_avx2(m);
383-
return _mm256_blendv_ps(zero, data, mask_vec);
376+
// Safe masked load using a temporary buffer to avoid SEGV
377+
int64_t buffer = 0;
378+
std::memcpy(&buffer, ptr, __builtin_popcount(m) * sizeof(int8_t));
379+
return _mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(_mm_cvtsi64_si128(buffer)));
384380
}
385381

386382
// We do not need to treat the left or right-hand differently.

tests/CMakeLists.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ set(CMAKE_CXX_STANDARD ${SVS_CXX_STANDARD})
3737
FetchContent_Declare(
3838
Catch2
3939
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
40-
GIT_TAG v3.4.0
40+
GIT_TAG v3.11.0
4141
)
4242

4343
FetchContent_MakeAvailable(Catch2)
@@ -230,5 +230,4 @@ target_include_directories(tests PRIVATE ${PROJECT_SOURCE_DIR})
230230
list(APPEND CMAKE_MODULE_PATH ${catch2_SOURCE_DIR}/extras)
231231
include(CTest)
232232
include(Catch)
233-
catch_discover_tests(tests)
234-
233+
catch_discover_tests(tests ADD_TAGS_AS_LABELS SKIP_IS_FAILURE)

tests/svs/core/distance.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,14 @@
1818
#include "svs/core/distance.h"
1919

2020
// catch 2
21+
#include "catch2/catch_template_test_macros.hpp"
2122
#include "catch2/catch_test_macros.hpp"
2223

24+
#include <numeric>
25+
#include <vector>
26+
27+
#include "svs/lib/avx_detection.h"
28+
2329
namespace {
2430

2531
std::string_view test_table = R"(
@@ -94,3 +100,70 @@ CATCH_TEST_CASE("Distance Utils", "[core][distance][distance_type]") {
94100
}
95101
}
96102
}
103+
104+
CATCH_TEMPLATE_TEST_CASE(
105+
"Distance ASan",
106+
"[distance][simd][asan]",
107+
svs::DistanceL2,
108+
svs::DistanceIP,
109+
svs::DistanceCosineSimilarity
110+
) {
111+
using Distance = TestType;
112+
113+
auto run_test = []() {
114+
// some full-width AVX2/AVX512 registers plus (crucially) ragged epilogue
115+
constexpr size_t size = 64 + 2;
116+
std::vector<float> a(size);
117+
std::vector<float> b(size);
118+
119+
std::iota(a.begin(), a.end(), 1.0f);
120+
std::iota(b.begin(), b.end(), 2.0f);
121+
122+
// Ensure no spare capacity
123+
a.shrink_to_fit();
124+
b.shrink_to_fit();
125+
126+
auto dist = svs::distance::compute(Distance(), std::span(a), std::span(b));
127+
CATCH_REQUIRE(dist >= 0);
128+
};
129+
130+
CATCH_SECTION("Default") { run_test(); }
131+
132+
#ifdef __x86_64__
133+
if (svs::detail::avx_runtime_flags.is_avx512vnni_supported()) {
134+
CATCH_SECTION("No AVX512VNNI") {
135+
auto& mutable_flags =
136+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
137+
auto original = mutable_flags;
138+
mutable_flags.avx512vnni = false;
139+
run_test();
140+
mutable_flags = original;
141+
}
142+
}
143+
144+
if (svs::detail::avx_runtime_flags.is_avx512f_supported()) {
145+
CATCH_SECTION("No AVX512F") {
146+
auto& mutable_flags =
147+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
148+
auto original = mutable_flags;
149+
mutable_flags.avx512vnni = false;
150+
mutable_flags.avx512f = false;
151+
run_test();
152+
mutable_flags = original;
153+
}
154+
}
155+
156+
if (svs::detail::avx_runtime_flags.is_avx2_supported()) {
157+
CATCH_SECTION("No AVX2") {
158+
auto& mutable_flags =
159+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
160+
auto original = mutable_flags;
161+
mutable_flags.avx512vnni = false;
162+
mutable_flags.avx512f = false;
163+
mutable_flags.avx2 = false;
164+
run_test();
165+
mutable_flags = original;
166+
}
167+
}
168+
#endif // __x86_64__
169+
}

tests/svs/index/inverted/clustering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,7 @@ void test_end_to_end_clustering(
385385

386386
} // namespace
387387

388-
CATCH_TEST_CASE("Random Clustering - End to End", "[inverted][random_clustering]") {
388+
CATCH_TEST_CASE("Random Clustering - End to End", "[long][inverted][random_clustering]") {
389389
CATCH_SECTION("Uncompressed Data") {
390390
auto data = svs::data::SimpleData<float>::load(test_dataset::data_svs_file());
391391
test_end_to_end_clustering(data, svs::DistanceL2(), 1.2f);

tests/svs/index/inverted/memory_based.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "tests/utils/test_dataset.h"
2424
#include <filesystem>
2525

26-
CATCH_TEST_CASE("InvertedIndex Logging Test", "[logging]") {
26+
CATCH_TEST_CASE("InvertedIndex Logging Test", "[long][logging]") {
2727
// Vector to store captured log messages
2828
std::vector<std::string> captured_logs;
2929
std::vector<std::string> global_captured_logs;

tests/svs/index/vamana/index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ CATCH_TEST_CASE("Static VamanaIndex Per-Index Logging", "[logging]") {
181181
CATCH_REQUIRE(captured_logs[2].find("Batch Size:") != std::string::npos);
182182
}
183183

184-
CATCH_TEST_CASE("Vamana Index Default Parameters", "[parameter][vamana]") {
184+
CATCH_TEST_CASE("Vamana Index Default Parameters", "[long][parameter][vamana]") {
185185
using Catch::Approx;
186186
std::filesystem::path data_path = test_dataset::data_svs_file();
187187

tests/svs/index/vamana/multi.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ template <typename Distance> float pick_alpha(Distance SVS_UNUSED(dist)) {
4848

4949
CATCH_TEMPLATE_TEST_CASE(
5050
"Multi-vector dynamic vamana index",
51-
"[index][vamana][multi]",
51+
"[long][index][vamana][multi]",
5252
svs::DistanceL2,
5353
svs::DistanceIP,
5454
svs::DistanceCosineSimilarity

tests/svs/lib/avx_detection.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,15 @@ CATCH_TEST_CASE("AVX detection", "[lib][lib-avx-detection]") {
2929
<< svs::detail::avx_runtime_flags.is_avx512f_supported() << "\n";
3030
std::cout << "AVX512VNNI: " << std::boolalpha
3131
<< svs::detail::avx_runtime_flags.is_avx512vnni_supported() << "\n";
32+
33+
#ifdef __x86_64__
34+
CATCH_SECTION("Patching") {
35+
auto& mutable_flags =
36+
const_cast<svs::detail::AVXRuntimeFlags&>(svs::detail::avx_runtime_flags);
37+
auto original = mutable_flags.avx512f;
38+
mutable_flags.avx512f = false;
39+
CATCH_REQUIRE(svs::detail::avx_runtime_flags.is_avx512f_supported() == false);
40+
mutable_flags.avx512f = original;
41+
}
42+
#endif
3243
}

0 commit comments

Comments
 (0)