Skip to content

Commit 53da643

Browse files
committed
initial implementation
1 parent 429da70 commit 53da643

File tree

10 files changed

+1037
-3
lines changed

10 files changed

+1037
-3
lines changed

docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Data transfer
1010
From memory:
1111

1212
+---------------------------------------+----------------------------------------------------+
13-
| :cpp:func:`load` | load values from memory |
13+
| :cpp:func:`load` | load values from memory (optionally masked) |
1414
+---------------------------------------+----------------------------------------------------+
1515
| :cpp:func:`load_aligned` | load values from aligned memory |
1616
+---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
3030
To memory:
3131

3232
+---------------------------------------+----------------------------------------------------+
33-
| :cpp:func:`store` | store values to memory |
33+
| :cpp:func:`store` | store values to memory (optionally masked) |
3434
+---------------------------------------+----------------------------------------------------+
3535
| :cpp:func:`store_aligned` | store values to aligned memory |
3636
+---------------------------------------+----------------------------------------------------+

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,119 @@ namespace xsimd
348348
return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
349349
}
350350

351+
// mask helpers
namespace detail
{
    // Bit mask with the lowest k bits set. Safe for any k in [0, 64]:
    // a plain (1 << k) - 1 would be undefined behavior for k == 64.
    constexpr uint64_t low_mask(std::size_t k)
    {
        return (k >= 64u) ? ~uint64_t(0) : ((uint64_t(1) << k) - 1u);
    }

    // Helper: length of the contiguous run of set bits starting at bit 0
    // (single-return recursion so it stays valid C++11 constexpr).
    constexpr std::size_t first_set_count_impl(uint64_t bits, std::size_t n)
    {
        return (bits & 1u) != 0u
            ? first_set_count_impl(bits >> 1, n + 1)
            : n;
    }

    // Length of the low contiguous run of set bits when the mask has the
    // shape 0...01...1 (including all-zero and all-one masks); otherwise
    // the sentinel value Mask::size + 1.
    template <class Mask>
    constexpr std::size_t first_set_count(Mask const& mask) noexcept
    {
        // Check the bits above the run with low_mask rather than a right
        // shift: for a fully-set 64-bit mask the run length is 64, and
        // (mask >> 64) would be undefined behavior.
        return ((static_cast<uint64_t>(mask.mask()) & ~low_mask(first_set_count_impl(static_cast<uint64_t>(mask.mask()), 0))) == 0u)
            ? first_set_count_impl(static_cast<uint64_t>(mask.mask()), 0)
            : (Mask::size + 1);
    }

    // Helper: length of the contiguous run of set bits starting at the top
    // bit (bit Size - 1) and moving downwards.
    template <std::size_t Size>
    constexpr std::size_t last_set_count_impl(uint64_t bits, std::size_t n)
    {
        return (n < Size && (bits & (uint64_t(1) << (Size - 1 - n))) != 0u)
            ? last_set_count_impl<Size>(bits, n + 1)
            : n;
    }

    // Length of the high contiguous run of set bits when the mask has the
    // shape 1...10...0 (including all-zero and all-one masks); otherwise
    // the sentinel value Mask::size + 1.
    template <class Mask>
    constexpr std::size_t last_set_count(Mask const& mask) noexcept
    {
        return ((static_cast<uint64_t>(mask.mask()) & low_mask(Mask::size - last_set_count_impl<Mask::size>(static_cast<uint64_t>(mask.mask()), 0))) == 0u)
            ? last_set_count_impl<Mask::size>(static_cast<uint64_t>(mask.mask()), 0)
            : (Mask::size + 1);
    }
}
393+
394+
// masked_load
395+
template <class A, class T_in, class T_out>
396+
XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem,
397+
typename batch<T_out, A>::batch_bool_type const& mask,
398+
convert<T_out>,
399+
aligned_mode,
400+
requires_arch<common>) noexcept
401+
{
402+
alignas(A::alignment()) T_out buffer[batch<T_out, A>::size] = {};
403+
for (std::size_t i = 0; i < batch<T_out, A>::size; ++i)
404+
{
405+
if (mask.get(i))
406+
{
407+
buffer[i] = static_cast<T_out>(mem[i]);
408+
}
409+
}
410+
return batch<T_out, A>::load_aligned(buffer);
411+
}
412+
413+
template <class A, class T_in, class T_out>
414+
XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem,
415+
typename batch<T_out, A>::batch_bool_type const& mask,
416+
convert<T_out>,
417+
unaligned_mode,
418+
requires_arch<common>) noexcept
419+
{
420+
alignas(A::alignment()) T_out buffer[batch<T_out, A>::size] = {};
421+
for (std::size_t i = 0; i < batch<T_out, A>::size; ++i)
422+
{
423+
if (mask.get(i))
424+
{
425+
buffer[i] = static_cast<T_out>(mem[i]);
426+
}
427+
}
428+
return batch<T_out, A>::load_aligned(buffer);
429+
}
430+
431+
// masked_store
432+
template <class A, class T_in, class T_out>
433+
XSIMD_INLINE void masked_store(T_out* mem,
434+
batch<T_in, A> const& src,
435+
typename batch<T_in, A>::batch_bool_type const& mask,
436+
aligned_mode,
437+
requires_arch<common>) noexcept
438+
{
439+
for (std::size_t i = 0; i < batch<T_in, A>::size; ++i)
440+
{
441+
if (mask.get(i))
442+
{
443+
mem[i] = static_cast<T_out>(src.get(i));
444+
}
445+
}
446+
}
447+
448+
template <class A, class T_in, class T_out>
449+
XSIMD_INLINE void masked_store(T_out* mem,
450+
batch<T_in, A> const& src,
451+
typename batch<T_in, A>::batch_bool_type const& mask,
452+
unaligned_mode,
453+
requires_arch<common>) noexcept
454+
{
455+
for (std::size_t i = 0; i < batch<T_in, A>::size; ++i)
456+
{
457+
if (mask.get(i))
458+
{
459+
mem[i] = static_cast<T_out>(src.get(i));
460+
}
461+
}
462+
}
463+
351464
// rotate_right
352465
template <size_t N, class A, class T>
353466
XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept

0 commit comments

Comments
 (0)