
Commit 3f3acf5

Feat: Load/Store masked API

1. Adds a new masked API with compile-time masks (load_masked and store_masked)
2. Optimizes the general use cases
3. Adds new tests
4. Adds x86 kernels
5. Adds convenience APIs to batch_bool_constant, resembling #include <bit>
6. Tests the new APIs

1 parent cbf693c commit 3f3acf5

14 files changed, +1787 -30 lines changed
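The batch_bool_constant conveniences mentioned in point 5 are what the x86 kernels below dispatch on: mask.none(), mask.all(), mask.countl_zero() and mask.countr_zero(). As a rough mental model only, here is a hypothetical stand-in (mask_info is an invented name; the real helpers are members of batch_bool_constant and are not shown in this commit) illustrating what the <bit>-style queries compute:

#include <cstddef>

// Illustration of the <bit>-style queries used by the kernels in this commit;
// not the actual xsimd implementation.
template <bool... Values>
struct mask_info
{
    static constexpr std::size_t size = sizeof...(Values);

    static constexpr bool none() // no lane selected
    {
        constexpr bool v[size] = { Values... };
        for (std::size_t i = 0; i < size; ++i)
            if (v[i])
                return false;
        return true;
    }

    static constexpr bool all() // every lane selected
    {
        constexpr bool v[size] = { Values... };
        for (std::size_t i = 0; i < size; ++i)
            if (!v[i])
                return false;
        return true;
    }

    static constexpr std::size_t countr_zero() // unselected lanes at the low end
    {
        constexpr bool v[size] = { Values... };
        std::size_t n = 0;
        while (n < size && !v[n])
            ++n;
        return n;
    }

    static constexpr std::size_t countl_zero() // unselected lanes at the high end
    {
        constexpr bool v[size] = { Values... };
        std::size_t n = 0;
        while (n < size && !v[size - 1 - n])
            ++n;
        return n;
    }
};

// The AVX kernels below branch on exactly these quantities, e.g. an 8-lane
// mask whose selected lanes all sit in the lower 128-bit half:
static_assert(mask_info<true, true, false, false, false, false, false, false>::countl_zero() == 6,
              "mask confined to the lower 128-bit half");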

docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ Data transfer
 From memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`load`                      | load values from memory                            |
+| :cpp:func:`load`                      | load values from memory (optionally masked)        |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`load_aligned`              | load values from aligned memory                    |
 +---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
 To memory:
 
 +---------------------------------------+----------------------------------------------------+
-| :cpp:func:`store`                     | store values to memory                             |
+| :cpp:func:`store`                     | store values to memory (optionally masked)         |
 +---------------------------------------+----------------------------------------------------+
 | :cpp:func:`store_aligned`             | store values to aligned memory                     |
 +---------------------------------------+----------------------------------------------------+
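For context, a minimal usage sketch of the documented behavior. The exact public entry points are defined outside the files shown in this commit, so the call shape below (pointer, mask, alignment tag) is an assumption based on the kernel signatures added here:

#include "xsimd/xsimd.hpp"

// Process only the first 3 elements of an 8-lane AVX float batch.
// Call shape of load_masked/store_masked assumed from this commit's kernels.
void scale_tail(float* dst, float const* src)
{
    using A = xsimd::avx;
    constexpr xsimd::batch_bool_constant<float, A,
                                         true, true, true, false, false, false, false, false>
        mask {};

    // Unselected lanes are zero-filled on load...
    auto v = xsimd::load_masked(src, mask, xsimd::unaligned_mode {});
    v = v * 2.0f;
    // ...and the corresponding memory is left untouched on store.
    xsimd::store_masked(dst, v, mask, xsimd::unaligned_mode {});
}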

include/xsimd/arch/common/xsimd_common_arithmetic.hpp

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 #include <limits>
 #include <type_traits>
 
+#include "../../types/xsimd_batch_constant.hpp"
 #include "./xsimd_common_details.hpp"
 
 namespace xsimd

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 97 additions & 0 deletions

@@ -13,6 +13,7 @@
 #define XSIMD_COMMON_MEMORY_HPP
 
 #include <algorithm>
+#include <array>
 #include <complex>
 #include <stdexcept>
 
@@ -348,6 +349,102 @@ namespace xsimd
             return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
         }
 
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A> load(T const* mem, aligned_mode, requires_arch<A>) noexcept
+        {
+            return load_aligned<A>(mem, convert<T> {}, A {});
+        }
+
+        template <class A, class T>
+        XSIMD_INLINE batch<T, A> load(T const* mem, unaligned_mode, requires_arch<A>) noexcept
+        {
+            return load_unaligned<A>(mem, convert<T> {}, A {});
+        }
+
+        template <class A, class T_in, class T_out, bool... Values, class alignment>
+        XSIMD_INLINE batch<T_out, A>
+        load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...>, convert<T_out>, alignment, requires_arch<common>) noexcept
+        {
+            constexpr std::size_t size = batch<T_out, A>::size;
+            alignas(A::alignment()) std::array<T_out, size> buffer {};
+            constexpr bool mask[size] = { Values... };
+
+            for (std::size_t i = 0; i < size; ++i)
+                buffer[i] = mask[i] ? static_cast<T_out>(mem[i]) : T_out(0);
+
+            return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+        }
+
+        template <class A, class T_in, class T_out, bool... Values, class alignment>
+        XSIMD_INLINE void
+        store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...>, alignment, requires_arch<common>) noexcept
+        {
+            constexpr std::size_t size = batch<T_in, A>::size;
+            constexpr bool mask[size] = { Values... };
+
+            for (std::size_t i = 0; i < size; ++i)
+                if (mask[i])
+                {
+                    mem[i] = static_cast<T_out>(src.get(i));
+                }
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE batch<int32_t, A> load_masked(int32_t const* mem, batch_bool_constant<int32_t, A, Values...>, convert<int32_t>, Mode, requires_arch<A>) noexcept
+        {
+            const auto f = load_masked<A>(reinterpret_cast<const float*>(mem), batch_bool_constant<float, A, Values...> {}, convert<float> {}, Mode {}, A {});
+            return bitwise_cast<int32_t>(f);
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE batch<uint32_t, A> load_masked(uint32_t const* mem, batch_bool_constant<uint32_t, A, Values...>, convert<uint32_t>, Mode, requires_arch<A>) noexcept
+        {
+            const auto f = load_masked<A>(reinterpret_cast<const float*>(mem), batch_bool_constant<float, A, Values...> {}, convert<float> {}, Mode {}, A {});
+            return bitwise_cast<uint32_t>(f);
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, batch<int64_t, A>>::type
+        load_masked(int64_t const* mem, batch_bool_constant<int64_t, A, Values...>, convert<int64_t>, Mode, requires_arch<A>) noexcept
+        {
+            const auto d = load_masked<A>(reinterpret_cast<const double*>(mem), batch_bool_constant<double, A, Values...> {}, convert<double> {}, Mode {}, A {});
+            return bitwise_cast<int64_t>(d);
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, batch<uint64_t, A>>::type
+        load_masked(uint64_t const* mem, batch_bool_constant<uint64_t, A, Values...>, convert<uint64_t>, Mode, requires_arch<A>) noexcept
+        {
+            const auto d = load_masked<A>(reinterpret_cast<const double*>(mem), batch_bool_constant<double, A, Values...> {}, convert<double> {}, Mode {}, A {});
+            return bitwise_cast<uint64_t>(d);
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE void store_masked(int32_t* mem, batch<int32_t, A> const& src, batch_bool_constant<int32_t, A, Values...>, Mode, requires_arch<A>) noexcept
+        {
+            store_masked<A>(reinterpret_cast<float*>(mem), bitwise_cast<float>(src), batch_bool_constant<float, A, Values...> {}, Mode {}, A {});
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE void store_masked(uint32_t* mem, batch<uint32_t, A> const& src, batch_bool_constant<uint32_t, A, Values...>, Mode, requires_arch<A>) noexcept
+        {
+            store_masked<A>(reinterpret_cast<float*>(mem), bitwise_cast<float>(src), batch_bool_constant<float, A, Values...> {}, Mode {}, A {});
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, void>::type
+        store_masked(int64_t* mem, batch<int64_t, A> const& src, batch_bool_constant<int64_t, A, Values...>, Mode, requires_arch<A>) noexcept
+        {
+            store_masked<A>(reinterpret_cast<double*>(mem), bitwise_cast<double>(src), batch_bool_constant<double, A, Values...> {}, Mode {}, A {});
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE typename std::enable_if<has_simd_register<double, A>::value, void>::type
+        store_masked(uint64_t* mem, batch<uint64_t, A> const& src, batch_bool_constant<uint64_t, A, Values...>, Mode, requires_arch<A>) noexcept
+        {
+            store_masked<A>(reinterpret_cast<double*>(mem), bitwise_cast<double>(src), batch_bool_constant<double, A, Values...> {}, Mode {}, A {});
+        }
+
         // rotate_right
         template <size_t N, class A, class T>
         XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept
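The common-arch fallback above pins down the semantics for every architecture: a masked load zero-fills unselected lanes and never dereferences their addresses, while a masked store never writes them. A scalar reference sketch of that contract (illustration only; load_masked_ref/store_masked_ref are invented names, not library code):

#include <cstddef>

// What load_masked computes, lane by lane.
template <class T, std::size_t N>
void load_masked_ref(T (&out)[N], T const* mem, bool const (&mask)[N])
{
    for (std::size_t i = 0; i < N; ++i)
        out[i] = mask[i] ? mem[i] : T(0); // unselected addresses are never read
}

// What store_masked computes, lane by lane.
template <class T, std::size_t N>
void store_masked_ref(T* mem, T const (&src)[N], bool const (&mask)[N])
{
    for (std::size_t i = 0; i < N; ++i)
        if (mask[i])
            mem[i] = src[i]; // unselected addresses are never written
}

Note how the integer overloads above reuse the same-width float and double kernels through bitwise_cast, so each x86 backend only has to implement the two floating-point kernels.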

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 153 additions & 9 deletions

@@ -18,6 +18,7 @@
 #include <type_traits>
 
 #include "../types/xsimd_avx_register.hpp"
+#include "../types/xsimd_batch_constant.hpp"
 
 namespace xsimd
 {
@@ -36,20 +37,35 @@ namespace xsimd
 
         namespace detail
         {
-            XSIMD_INLINE void split_avx(__m256i val, __m128i& low, __m128i& high) noexcept
+            XSIMD_INLINE __m128i lower_half(__m256i self) noexcept
             {
-                low = _mm256_castsi256_si128(val);
-                high = _mm256_extractf128_si256(val, 1);
+                return _mm256_castsi256_si128(self);
             }
-            XSIMD_INLINE void split_avx(__m256 val, __m128& low, __m128& high) noexcept
+            XSIMD_INLINE __m128 lower_half(__m256 self) noexcept
             {
-                low = _mm256_castps256_ps128(val);
-                high = _mm256_extractf128_ps(val, 1);
+                return _mm256_castps256_ps128(self);
             }
-            XSIMD_INLINE void split_avx(__m256d val, __m128d& low, __m128d& high) noexcept
+            XSIMD_INLINE __m128d lower_half(__m256d self) noexcept
             {
-                low = _mm256_castpd256_pd128(val);
-                high = _mm256_extractf128_pd(val, 1);
+                return _mm256_castpd256_pd128(self);
+            }
+            XSIMD_INLINE __m128i upper_half(__m256i self) noexcept
+            {
+                return _mm256_extractf128_si256(self, 1);
+            }
+            XSIMD_INLINE __m128 upper_half(__m256 self) noexcept
+            {
+                return _mm256_extractf128_ps(self, 1);
+            }
+            XSIMD_INLINE __m128d upper_half(__m256d self) noexcept
+            {
+                return _mm256_extractf128_pd(self, 1);
+            }
+            template <class Full, class Half>
+            XSIMD_INLINE void split_avx(Full val, Half& low, Half& high) noexcept
+            {
+                low = lower_half(val);
+                high = upper_half(val);
             }
             XSIMD_INLINE __m256i merge_sse(__m128i low, __m128i high) noexcept
             {
@@ -865,6 +881,134 @@ namespace xsimd
             return _mm256_loadu_pd(mem);
         }
 
+        // load_masked
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE batch<float, A> load_masked(float const* mem, batch_bool_constant<float, A, Values...> mask, convert<float>, Mode, requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return _mm256_setzero_ps();
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                return load<A>(mem, Mode {});
+            }
+            // confined to the lower 128-bit half (4 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4)
+            {
+                constexpr auto mlo = ::xsimd::detail::lower_half<sse4_2>(mask);
+                const auto lo = load_masked(mem, mlo, convert<float> {}, Mode {}, sse4_2 {});
+                return batch<float, A>(detail::merge_sse(lo, batch<float, sse4_2>(0.f)));
+            }
+            // confined to the upper 128-bit half (4 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4)
+            {
+                constexpr auto mhi = ::xsimd::detail::upper_half<sse4_2>(mask);
+                const auto hi = load_masked(mem + 4, mhi, convert<float> {}, Mode {}, sse4_2 {});
+                return batch<float, A>(detail::merge_sse(batch<float, sse4_2>(0.f), hi));
+            }
+            else
+            {
+                // crossing the 128-bit boundary → use a 256-bit masked load
+                return _mm256_maskload_ps(mem, mask.as_batch());
+            }
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE batch<double, A> load_masked(double const* mem, batch_bool_constant<double, A, Values...> mask, convert<double>, Mode, requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return _mm256_setzero_pd();
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                return load<A>(mem, Mode {});
+            }
+            // confined to the lower 128-bit half (2 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2)
+            {
+                constexpr auto mlo = ::xsimd::detail::lower_half<sse4_2>(mask);
+                const auto lo = load_masked(mem, mlo, convert<double> {}, Mode {}, sse4_2 {});
+                return batch<double, A>(detail::merge_sse(lo, batch<double, sse4_2>(0.0)));
+            }
+            // confined to the upper 128-bit half (2 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2)
+            {
+                constexpr auto mhi = ::xsimd::detail::upper_half<sse4_2>(mask);
+                const auto hi = load_masked(mem + 2, mhi, convert<double> {}, Mode {}, sse4_2 {});
+                return batch<double, A>(detail::merge_sse(batch<double, sse4_2>(0.0), hi));
+            }
+            else
+            {
+                // crossing the 128-bit boundary → use a 256-bit masked load
+                return _mm256_maskload_pd(mem, mask.as_batch());
+            }
+        }
+
+        // store_masked
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE void store_masked(float* mem, batch<float, A> const& src, batch_bool_constant<float, A, Values...> mask, Mode, requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return;
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                src.store(mem, Mode {});
+            }
+            // confined to the lower 128-bit half (4 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 4)
+            {
+                constexpr auto mlo = ::xsimd::detail::lower_half<sse4_2>(mask);
+                const batch<float, sse4_2> lo(_mm256_castps256_ps128(src));
+                store_masked<sse4_2>(mem, lo, mlo, Mode {}, sse4_2 {});
+            }
+            // confined to the upper 128-bit half (4 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 4)
+            {
+                constexpr auto mhi = ::xsimd::detail::upper_half<sse4_2>(mask);
+                const batch<float, sse4_2> hi(_mm256_extractf128_ps(src, 1));
+                store_masked<sse4_2>(mem + 4, hi, mhi, Mode {}, sse4_2 {});
+            }
+            else
+            {
+                // crossing the 128-bit boundary → use a 256-bit masked store
+                _mm256_maskstore_ps(mem, mask.as_batch(), src);
+            }
+        }
+
+        template <class A, bool... Values, class Mode>
+        XSIMD_INLINE void store_masked(double* mem, batch<double, A> const& src, batch_bool_constant<double, A, Values...> mask, Mode, requires_arch<avx>) noexcept
+        {
+            XSIMD_IF_CONSTEXPR(mask.none())
+            {
+                return;
+            }
+            else XSIMD_IF_CONSTEXPR(mask.all())
+            {
+                src.store(mem, Mode {});
+            }
+            // confined to the lower 128-bit half (2 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countl_zero() >= 2)
+            {
+                constexpr auto mlo = ::xsimd::detail::lower_half<sse2>(mask);
+                const batch<double, sse2> lo(_mm256_castpd256_pd128(src));
+                store_masked<sse2>(mem, lo, mlo, Mode {}, sse4_2 {});
+            }
+            // confined to the upper 128-bit half (2 lanes) → forward to the 128-bit kernel
+            else XSIMD_IF_CONSTEXPR(mask.countr_zero() >= 2)
+            {
+                constexpr auto mhi = ::xsimd::detail::upper_half<sse2>(mask);
+                const batch<double, sse2> hi(_mm256_extractf128_pd(src, 1));
+                store_masked<sse2>(mem + 2, hi, mhi, Mode {}, sse4_2 {});
+            }
+            else
+            {
+                // crossing the 128-bit boundary → use a 256-bit masked store
+                _mm256_maskstore_pd(mem, mask.as_batch(), src);
+            }
+        }
+
         // lt
         template <class A>
         XSIMD_INLINE batch_bool<float, A> lt(batch<float, A> const& self, batch<float, A> const& other, requires_arch<avx>) noexcept
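Putting the AVX dispatch together, these are the paths the float kernels above select for a few representative 8-lane masks (branch conditions copied from the code; illustration only):

mask (lane 0 first)    condition             emitted operation
0,0,0,0,0,0,0,0        mask.none()           _mm256_setzero_ps / no store
1,1,1,1,1,1,1,1        mask.all()            plain 256-bit load/store
1,1,0,0,0,0,0,0        countl_zero() >= 4    one 128-bit op on the lower half
0,0,0,0,0,0,1,1        countr_zero() >= 4    one 128-bit op on the upper half
0,0,1,1,1,1,0,0        crosses the split     _mm256_maskload_ps / _mm256_maskstore_ps

Since the mask is a compile-time constant, the branch is resolved at compile time via XSIMD_IF_CONSTEXPR and only the selected instruction sequence is emitted.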
