Commit 2ae15bf
WIP: optimizing masked ops
1 parent 1bd8e19

File tree: 10 files changed (+1841, -122 lines)

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 87 additions & 81 deletions
@@ -351,113 +351,119 @@ namespace xsimd
         // mask helpers
         namespace detail
         {
-            // Helper: first_set_count recursion
-            constexpr std::size_t first_set_count_impl(uint64_t bits, std::size_t n)
-            {
-                return (bits & 1u) != 0u
-                    ? first_set_count_impl(bits >> 1, n + 1)
-                    : n;
-            }
-
-            template <class Mask>
-            constexpr std::size_t first_set_count(Mask const& mask) noexcept
-            {
-                return ((static_cast<uint64_t>(mask.mask()) >> first_set_count_impl(static_cast<uint64_t>(mask.mask()), 0)) == 0u)
-                    ? first_set_count_impl(static_cast<uint64_t>(mask.mask()), 0)
-                    : (Mask::size + 1);
-            }
-
-            // Helper: last_set_count recursion
-            template <std::size_t Size>
-            constexpr std::size_t last_set_count_impl(uint64_t bits, std::size_t n)
-            {
-                return (n < Size && (bits & (uint64_t(1) << (Size - 1 - n))) != 0u)
-                    ? last_set_count_impl<Size>(bits, n + 1)
-                    : n;
-            }
-
             // safe mask for k bits (must be single return)
             constexpr uint64_t low_mask(std::size_t k)
             {
                 return (k >= 64u) ? ~uint64_t(0) : ((uint64_t(1) << k) - 1u);
             }
-
-            template <class Mask>
-            constexpr std::size_t last_set_count(Mask const& mask) noexcept
-            {
-                return ((static_cast<uint64_t>(mask.mask()) & low_mask(Mask::size - last_set_count_impl<Mask::size>(static_cast<uint64_t>(mask.mask()), 0))) == 0u)
-                    ? last_set_count_impl<Mask::size>(static_cast<uint64_t>(mask.mask()), 0)
-                    : (Mask::size + 1);
-            }
         }
 
-        // masked_load
         template <class A, class T_in, class T_out>
-        XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem,
-                                                 typename batch<T_out, A>::batch_bool_type const& mask,
-                                                 convert<T_out>,
-                                                 aligned_mode,
-                                                 requires_arch<common>) noexcept
-        {
-            alignas(A::alignment()) T_out buffer[batch<T_out, A>::size] = {};
-            for (std::size_t i = 0; i < batch<T_out, A>::size; ++i)
-            {
+        XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem, typename batch<T_out, A>::batch_bool_type const& mask, convert<T_out>, requires_arch<common>) noexcept
+        {
+            constexpr std::size_t size = batch<T_out, A>::size;
+            const uint64_t m = mask.mask();
+            if (m == 0u)
+                return batch<T_out, A>(0);
+            if (m == detail::low_mask(size))
+                return batch<T_out, A>::load(mem, unaligned_mode {});
+            alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+            for (std::size_t i = 0; i < size; ++i)
                 if (mask.get(i))
-                {
                     buffer[i] = static_cast<T_out>(mem[i]);
-                }
-            }
-            return batch<T_out, A>::load_aligned(buffer);
+            return batch<T_out, A>::load(buffer.data(), unaligned_mode {});
         }
 
         template <class A, class T_in, class T_out>
-        XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem,
-                                                 typename batch<T_out, A>::batch_bool_type const& mask,
-                                                 convert<T_out>,
-                                                 unaligned_mode,
-                                                 requires_arch<common>) noexcept
+        XSIMD_INLINE void masked_store(T_out* mem, batch<T_in, A> const& src, typename batch<T_in, A>::batch_bool_type const& mask, requires_arch<common>) noexcept
         {
-            alignas(A::alignment()) T_out buffer[batch<T_out, A>::size] = {};
-            for (std::size_t i = 0; i < batch<T_out, A>::size; ++i)
+            constexpr std::size_t size = batch<T_in, A>::size;
+            const uint64_t m = mask.mask();
+            if (m == 0u)
+                return;
+            if (m == detail::low_mask(size))
             {
-                if (mask.get(i))
-                {
-                    buffer[i] = static_cast<T_out>(mem[i]);
-                }
+                src.store(mem, unaligned_mode {});
+                return;
             }
-            return batch<T_out, A>::load_aligned(buffer);
+            for (std::size_t i = 0; i < size; ++i)
+                if (mask.get(i))
+                    mem[i] = static_cast<T_out>(src.get(i));
         }
 
-        // masked_store
-        template <class A, class T_in, class T_out>
-        XSIMD_INLINE void masked_store(T_out* mem,
-                                       batch<T_in, A> const& src,
-                                       typename batch<T_in, A>::batch_bool_type const& mask,
-                                       aligned_mode,
-                                       requires_arch<common>) noexcept
+        // COMPILE-TIME (single version each, XSIMD_IF_CONSTEXPR)
+        template <class A, class T_in, class T_out, bool... Values>
+        XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem, batch_bool_constant<T_out, A, Values...> mask, convert<T_out>, requires_arch<common>) noexcept
         {
-            for (std::size_t i = 0; i < batch<T_in, A>::size; ++i)
+            constexpr std::size_t size = batch<T_out, A>::size;
+            constexpr std::size_t n = mask.countr_one();
+            constexpr std::size_t l = mask.countl_one();
+
+            // All zeros / all ones fast paths
+            XSIMD_IF_CONSTEXPR(mask.is_all_zeros())
             {
-                if (mask.get(i))
-                {
-                    mem[i] = static_cast<T_out>(src.get(i));
-                }
+                return batch<T_out, A>(0);
+            }
+            else XSIMD_IF_CONSTEXPR(mask.is_all_ones())
+            {
+                return batch<T_out, A>::load(mem, unaligned_mode {});
+            }
+            // Prefix-ones (n contiguous ones from LSB)
+            else XSIMD_IF_CONSTEXPR(n > 0)
+            {
+                alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+                for (std::size_t i = 0; i < n; ++i)
+                    buffer[i] = static_cast<T_out>(mem[i]);
+                return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+            }
+            // Suffix-ones (l contiguous ones from MSB)
+            else XSIMD_IF_CONSTEXPR(l > 0)
+            {
+                alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+                const std::size_t start = size - l;
+                for (std::size_t i = 0; i < l; ++i)
+                    buffer[start + i] = static_cast<T_out>(mem[start + i]);
+                return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+            }
+            else
+            {
+                // Fallback to runtime path for non prefix/suffix masks
+                return masked_load<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
             }
         }
 
-        template <class A, class T_in, class T_out>
-        XSIMD_INLINE void masked_store(T_out* mem,
-                                       batch<T_in, A> const& src,
-                                       typename batch<T_in, A>::batch_bool_type const& mask,
-                                       unaligned_mode,
-                                       requires_arch<common>) noexcept
+        template <class A, class T_in, class T_out, bool... Values>
+        XSIMD_INLINE void masked_store(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...> mask, requires_arch<common>) noexcept
         {
-            for (std::size_t i = 0; i < batch<T_in, A>::size; ++i)
+            constexpr std::size_t size = batch<T_in, A>::size;
+            constexpr std::size_t n = mask.countr_one();
+            constexpr std::size_t l = mask.countl_one();
+
+            // All zeros / all ones fast paths
+            XSIMD_IF_CONSTEXPR(mask.is_all_zeros())
             {
-                if (mask.get(i))
-                {
+                return;
+            }
+            else XSIMD_IF_CONSTEXPR(mask.is_all_ones())
+            {
+                src.store(mem, unaligned_mode {});
+            }
+            // Prefix-ones
+            else XSIMD_IF_CONSTEXPR(n > 0)
+            {
+                for (std::size_t i = 0; i < n; ++i)
                     mem[i] = static_cast<T_out>(src.get(i));
-                }
+            }
+            // Suffix-ones
+            else XSIMD_IF_CONSTEXPR(l > 0)
+            {
+                const std::size_t start = size - l;
+                for (std::size_t i = 0; i < l; ++i)
+                    mem[start + i] = static_cast<T_out>(src.get(start + i));
+            }
+            else
+            {
+                masked_store<A>(mem, src, mask.as_batch_bool(), common {});
             }
         }
 

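The deleted first_set_count/last_set_count helpers and the new countr_one()/countl_one() calls on batch_bool_constant answer the same question: is the mask all-zeros, all-ones, or a contiguous run of ones anchored at one end of the lane range? Only masks that fail every test need the generic per-lane loop. Below is a standalone sketch of that classification; the names and the enum are illustrative, not xsimd API, and the single-return recursion mirrors the C++11 style of the deleted helpers.

#include <cstddef>
#include <cstdint>

constexpr uint64_t low_mask(std::size_t k)
{
    return (k >= 64u) ? ~uint64_t(0) : ((uint64_t(1) << k) - 1u);
}

// Contiguous set bits starting at bit 0 ("prefix ones").
constexpr std::size_t countr_one(uint64_t bits, std::size_t n = 0)
{
    return ((bits & 1u) != 0u) ? countr_one(bits >> 1, n + 1) : n;
}

// Contiguous set bits ending at bit size-1 ("suffix ones").
constexpr std::size_t countl_one(uint64_t bits, std::size_t size, std::size_t n = 0)
{
    return (n < size && (bits & (uint64_t(1) << (size - 1 - n))) != 0u)
        ? countl_one(bits, size, n + 1)
        : n;
}

enum class mask_kind { all_zeros, all_ones, prefix_ones, suffix_ones, generic };

constexpr mask_kind classify(uint64_t m, std::size_t size)
{
    return (m == 0u)                     ? mask_kind::all_zeros
        : (m == low_mask(size))          ? mask_kind::all_ones
        : (m == low_mask(countr_one(m))) ? mask_kind::prefix_ones
        : (m == (low_mask(size) ^ low_mask(size - countl_one(m, size))))
            ? mask_kind::suffix_ones
            : mask_kind::generic;
}

// For a 4-lane mask: 0011 is a prefix run, 1100 a suffix run,
// and 1010 is neither, so it takes the per-lane fallback.
static_assert(classify(0x3u, 4) == mask_kind::prefix_ones, "prefix");
static_assert(classify(0xCu, 4) == mask_kind::suffix_ones, "suffix");
static_assert(classify(0xAu, 4) == mask_kind::generic, "generic");

In the commit this classification runs at compile time over the batch_bool_constant's Values pack, so n and l are constexpr and each fast path reduces to a copy loop with a fixed trip count that the compiler can unroll or vectorize.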
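The compile-time overloads branch with XSIMD_IF_CONSTEXPR rather than a bare if constexpr so the header keeps building under pre-C++17 standards. A sketch of the usual shape of such a macro (a common pattern, not necessarily xsimd's exact definition):

#if defined(__cpp_if_constexpr) && __cpp_if_constexpr >= 201606L
// C++17 and later: untaken branches are discarded at compile time.
#define XSIMD_IF_CONSTEXPR(expr) if constexpr (expr)
#else
// Pre-C++17 fallback: an ordinary branch on a constant condition.
#define XSIMD_IF_CONSTEXPR(expr) if (expr)
#endif

Under the fallback every branch must still compile for every instantiation, which the code in this diff satisfies: size, n, and l are all constexpr, and each branch is well-formed for any batch width, leaving the dead branches to the optimizer.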