@@ -351,113 +351,119 @@ namespace xsimd
     // mask helpers
     namespace detail
     {
-        // Helper: first_set_count recursion
-        constexpr std::size_t first_set_count_impl(uint64_t bits, std::size_t n)
-        {
-            return (bits & 1u) != 0u
-                ? first_set_count_impl(bits >> 1, n + 1)
-                : n;
-        }
-
-        template <class Mask>
-        constexpr std::size_t first_set_count(Mask const& mask) noexcept
-        {
-            return ((static_cast<uint64_t>(mask.mask()) >> first_set_count_impl(static_cast<uint64_t>(mask.mask()), 0)) == 0u)
-                ? first_set_count_impl(static_cast<uint64_t>(mask.mask()), 0)
-                : (Mask::size + 1);
-        }
-
-        // Helper: last_set_count recursion
-        template <std::size_t Size>
-        constexpr std::size_t last_set_count_impl(uint64_t bits, std::size_t n)
-        {
-            return (n < Size && (bits & (uint64_t(1) << (Size - 1 - n))) != 0u)
-                ? last_set_count_impl<Size>(bits, n + 1)
-                : n;
-        }
-
         // safe mask for k bits (must be single return)
         constexpr uint64_t low_mask(std::size_t k)
         {
             return (k >= 64u) ? ~uint64_t(0) : ((uint64_t(1) << k) - 1u);
         }
-
-        template <class Mask>
-        constexpr std::size_t last_set_count(Mask const& mask) noexcept
-        {
-            return ((static_cast<uint64_t>(mask.mask()) & low_mask(Mask::size - last_set_count_impl<Mask::size>(static_cast<uint64_t>(mask.mask()), 0))) == 0u)
-                ? last_set_count_impl<Mask::size>(static_cast<uint64_t>(mask.mask()), 0)
-                : (Mask::size + 1);
-        }
     }

-    // masked_load
     template <class A, class T_in, class T_out>
-    XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem,
-                                             typename batch<T_out, A>::batch_bool_type const& mask,
-                                             convert<T_out>,
-                                             aligned_mode,
-                                             requires_arch<common>) noexcept
-    {
-        alignas(A::alignment()) T_out buffer[batch<T_out, A>::size] = {};
-        for (std::size_t i = 0; i < batch<T_out, A>::size; ++i)
-        {
+    XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem, typename batch<T_out, A>::batch_bool_type const& mask, convert<T_out>, requires_arch<common>) noexcept
+    {
+        constexpr std::size_t size = batch<T_out, A>::size;
+        const uint64_t m = mask.mask();
+        if (m == 0u)
+            return batch<T_out, A>(0);
+        if (m == detail::low_mask(size))
+            return batch<T_out, A>::load(mem, unaligned_mode {});
+        alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+        for (std::size_t i = 0; i < size; ++i)
             if (mask.get(i))
-            {
                 buffer[i] = static_cast<T_out>(mem[i]);
-            }
-        }
-        return batch<T_out, A>::load_aligned(buffer);
+        return batch<T_out, A>::load(buffer.data(), unaligned_mode {});
     }

     template <class A, class T_in, class T_out>
-    XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem,
-                                             typename batch<T_out, A>::batch_bool_type const& mask,
-                                             convert<T_out>,
-                                             unaligned_mode,
-                                             requires_arch<common>) noexcept
+    XSIMD_INLINE void masked_store(T_out* mem, batch<T_in, A> const& src, typename batch<T_in, A>::batch_bool_type const& mask, requires_arch<common>) noexcept
     {
-        alignas(A::alignment()) T_out buffer[batch<T_out, A>::size] = {};
-        for (std::size_t i = 0; i < batch<T_out, A>::size; ++i)
+        constexpr std::size_t size = batch<T_in, A>::size;
+        const uint64_t m = mask.mask();
+        if (m == 0u)
+            return;
+        if (m == detail::low_mask(size))
         {
-            if (mask.get(i))
-            {
-                buffer[i] = static_cast<T_out>(mem[i]);
-            }
+            src.store(mem, unaligned_mode {});
+            return;
         }
-        return batch<T_out, A>::load_aligned(buffer);
+        for (std::size_t i = 0; i < size; ++i)
+            if (mask.get(i))
+                mem[i] = static_cast<T_out>(src.get(i));
     }

-    // masked_store
-    template <class A, class T_in, class T_out>
-    XSIMD_INLINE void masked_store(T_out* mem,
-                                   batch<T_in, A> const& src,
-                                   typename batch<T_in, A>::batch_bool_type const& mask,
-                                   aligned_mode,
-                                   requires_arch<common>) noexcept
+    // COMPILE-TIME (single version each, XSIMD_IF_CONSTEXPR)
+    template <class A, class T_in, class T_out, bool... Values>
+    XSIMD_INLINE batch<T_out, A> masked_load(T_in const* mem, batch_bool_constant<T_out, A, Values...> mask, convert<T_out>, requires_arch<common>) noexcept
     {
-        for (std::size_t i = 0; i < batch<T_in, A>::size; ++i)
+        constexpr std::size_t size = batch<T_out, A>::size;
+        constexpr std::size_t n = mask.countr_one();
+        constexpr std::size_t l = mask.countl_one();
+
+        // All zeros / all ones fast paths
+        XSIMD_IF_CONSTEXPR(mask.is_all_zeros())
         {
-            if (mask.get(i))
-            {
-                mem[i] = static_cast<T_out>(src.get(i));
-            }
+            return batch<T_out, A>(0);
+        }
+        else XSIMD_IF_CONSTEXPR(mask.is_all_ones())
+        {
+            return batch<T_out, A>::load(mem, unaligned_mode {});
+        }
+        // Prefix-ones (n contiguous ones from LSB)
+        else XSIMD_IF_CONSTEXPR(n > 0)
+        {
+            alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+            for (std::size_t i = 0; i < n; ++i)
+                buffer[i] = static_cast<T_out>(mem[i]);
+            return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+        }
+        // Suffix-ones (l contiguous ones from MSB)
+        else XSIMD_IF_CONSTEXPR(l > 0)
+        {
+            alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
+            const std::size_t start = size - l;
+            for (std::size_t i = 0; i < l; ++i)
+                buffer[start + i] = static_cast<T_out>(mem[start + i]);
+            return batch<T_out, A>::load(buffer.data(), aligned_mode {});
+        }
+        else
+        {
+            // Fallback to runtime path for non prefix/suffix masks
+            return masked_load<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
         }
     }

-    template <class A, class T_in, class T_out>
-    XSIMD_INLINE void masked_store(T_out* mem,
-                                   batch<T_in, A> const& src,
-                                   typename batch<T_in, A>::batch_bool_type const& mask,
-                                   unaligned_mode,
-                                   requires_arch<common>) noexcept
+    template <class A, class T_in, class T_out, bool... Values>
+    XSIMD_INLINE void masked_store(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...> mask, requires_arch<common>) noexcept
     {
-        for (std::size_t i = 0; i < batch<T_in, A>::size; ++i)
+        constexpr std::size_t size = batch<T_in, A>::size;
+        constexpr std::size_t n = mask.countr_one();
+        constexpr std::size_t l = mask.countl_one();
+
+        // All zeros / all ones fast paths
+        XSIMD_IF_CONSTEXPR(mask.is_all_zeros())
         {
-            if (mask.get(i))
-            {
+            return;
+        }
+        else XSIMD_IF_CONSTEXPR(mask.is_all_ones())
+        {
+            src.store(mem, unaligned_mode {});
+        }
+        // Prefix-ones
+        else XSIMD_IF_CONSTEXPR(n > 0)
+        {
+            for (std::size_t i = 0; i < n; ++i)
                 mem[i] = static_cast<T_out>(src.get(i));
-            }
+        }
+        // Suffix-ones
+        else XSIMD_IF_CONSTEXPR(l > 0)
+        {
+            const std::size_t start = size - l;
+            for (std::size_t i = 0; i < l; ++i)
+                mem[start + i] = static_cast<T_out>(src.get(start + i));
+        }
+        else
+        {
+            masked_store<A>(mem, src, mask.as_batch_bool(), common {});
         }
     }

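The compile-time overloads above take a fast path only when the constant mask is all zeros, all ones, a run of ones starting at the lowest lane, or a run of ones ending at the highest lane; everything else falls through to the runtime kernel. The following standalone sketch (not xsimd code) illustrates that classification on a plain 64-bit lane mask; countr_one and countl_one here are hand-rolled stand-ins for what batch_bool_constant is assumed to expose, and low_mask mirrors detail::low_mask from the hunk above.

#include <cstddef>
#include <cstdint>
#include <iostream>

// Mirrors detail::low_mask above: a mask with the k lowest bits set.
constexpr std::uint64_t low_mask(std::size_t k)
{
    return (k >= 64u) ? ~std::uint64_t(0) : ((std::uint64_t(1) << k) - 1u);
}

// Number of contiguous set bits starting at the LSB (single-return constexpr recursion).
constexpr std::size_t countr_one(std::uint64_t bits, std::size_t n = 0)
{
    return (bits & 1u) != 0u ? countr_one(bits >> 1, n + 1) : n;
}

// Number of contiguous set bits ending at bit (size - 1), i.e. at the highest lane.
constexpr std::size_t countl_one(std::uint64_t bits, std::size_t size, std::size_t n = 0)
{
    return (n < size && (bits & (std::uint64_t(1) << (size - 1 - n))) != 0u)
        ? countl_one(bits, size, n + 1)
        : n;
}

int main()
{
    constexpr std::size_t size = 8;         // e.g. an 8-lane batch
    constexpr std::uint64_t prefix = 0x0Fu; // lanes 0..3 active
    constexpr std::uint64_t suffix = 0xE0u; // lanes 5..7 active
    constexpr std::uint64_t sparse = 0x66u; // 0b01100110: no run at either end

    static_assert(prefix == low_mask(4), "a prefix-ones mask equals low_mask(n)");
    static_assert(countr_one(prefix) == 4, "4 contiguous ones from the LSB");
    static_assert(countl_one(suffix, size) == 3, "3 contiguous ones from the MSB");
    static_assert(countr_one(sparse) == 0 && countl_one(sparse, size) == 0,
                  "neither a prefix nor a suffix run: runtime fallback");

    std::cout << countr_one(prefix) << ' ' << countl_one(suffix, size) << '\n'; // prints "4 3"
    return 0;
}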