Skip to content

Commit fd0e777

Browse files
committed
1. Adds a new masked load/store API supporting both runtime and compile-time masks (load_masked and store_masked)
2. General use-case optimizations 3. New tests 4. x86 kernels
1 parent ac8a93f commit fd0e777

File tree

13 files changed

+3821
-10
lines changed

13 files changed

+3821
-10
lines changed

docs/source/api/data_transfer.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Data transfer
1010
From memory:
1111

1212
+---------------------------------------+----------------------------------------------------+
13-
| :cpp:func:`load` | load values from memory |
13+
| :cpp:func:`load` | load values from memory (optionally masked) |
1414
+---------------------------------------+----------------------------------------------------+
1515
| :cpp:func:`load_aligned` | load values from aligned memory |
1616
+---------------------------------------+----------------------------------------------------+
@@ -30,7 +30,7 @@ From a scalar:
3030
To memory:
3131

3232
+---------------------------------------+----------------------------------------------------+
33-
| :cpp:func:`store` | store values to memory |
33+
| :cpp:func:`store` | store values to memory (optionally masked) |
3434
+---------------------------------------+----------------------------------------------------+
3535
| :cpp:func:`store_aligned` | store values to aligned memory |
3636
+---------------------------------------+----------------------------------------------------+

include/xsimd/arch/common/xsimd_common_memory.hpp

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
#ifndef XSIMD_COMMON_MEMORY_HPP
1313
#define XSIMD_COMMON_MEMORY_HPP
1414

15+
#include "../../types/xsimd_batch_constant.hpp"
16+
#include "./xsimd_common_details.hpp"
1517
#include <algorithm>
18+
#include <array>
1619
#include <complex>
1720
#include <stdexcept>
1821

19-
#include "../../types/xsimd_batch_constant.hpp"
20-
#include "./xsimd_common_details.hpp"
21-
2222
namespace xsimd
2323
{
2424
template <typename T, class A, T... Values>
@@ -348,6 +348,153 @@ namespace xsimd
348348
return detail::load_unaligned<A>(mem, cvt, common {}, detail::conversion_type<A, T_in, T_out> {});
349349
}
350350

351+
template <class A, class T_in, class T_out>
352+
XSIMD_INLINE batch<T_out, A> load_masked(T_in const* mem, typename batch<T_out, A>::batch_bool_type const& mask, convert<T_out>, requires_arch<common>) noexcept
353+
{
354+
constexpr std::size_t size = batch<T_out, A>::size;
355+
if (mask.none())
356+
return batch<T_out, A>(0);
357+
if (mask.all())
358+
return batch<T_out, A>::load(mem, unaligned_mode {});
359+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
360+
for (std::size_t idx = 0; idx < size; ++idx)
361+
{
362+
if (mask.get(idx))
363+
{
364+
buffer[idx] = static_cast<T_out>(mem[idx]);
365+
}
366+
}
367+
return batch<T_out, A>::load(buffer.data(), unaligned_mode {});
368+
}
369+
370+
template <class A, class T_in, class T_out>
371+
XSIMD_INLINE void store_masked(T_out* mem, batch<T_in, A> const& src, typename batch<T_in, A>::batch_bool_type const& mask, requires_arch<common>) noexcept
372+
{
373+
constexpr std::size_t size = batch<T_in, A>::size;
374+
if (mask.none())
375+
return;
376+
if (mask.all())
377+
{
378+
src.store(mem, unaligned_mode {});
379+
return;
380+
}
381+
for (std::size_t idx = 0; idx < size; ++idx)
382+
{
383+
if (mask.get(idx))
384+
{
385+
mem[idx] = static_cast<T_out>(src.get(idx));
386+
}
387+
}
388+
}
389+
390+
// COMPILE-TIME (single version each, XSIMD_IF_CONSTEXPR)
391+
template <class A, class T_in, class T_out, bool... Values>
392+
XSIMD_INLINE batch<T_out, A> load_masked(T_in const* mem, batch_bool_constant<T_out, A, Values...> mask, convert<T_out>, requires_arch<common>) noexcept
393+
{
394+
constexpr std::size_t size = batch<T_out, A>::size;
395+
constexpr std::size_t n = mask.countr_one();
396+
constexpr std::size_t l = mask.countl_one();
397+
398+
// All zeros / all ones fast paths
399+
XSIMD_IF_CONSTEXPR(mask.none())
400+
{
401+
return batch<T_out, A>(0);
402+
}
403+
else XSIMD_IF_CONSTEXPR(mask.all())
404+
{
405+
return batch<T_out, A>::load(mem, unaligned_mode {});
406+
}
407+
// Prefix-ones (n contiguous ones from LSB)
408+
else XSIMD_IF_CONSTEXPR(n > 0)
409+
{
410+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
411+
for (std::size_t i = 0; i < n; ++i)
412+
buffer[i] = static_cast<T_out>(mem[i]);
413+
return batch<T_out, A>::load(buffer.data(), aligned_mode {});
414+
}
415+
// Suffix-ones (l contiguous ones from MSB)
416+
else XSIMD_IF_CONSTEXPR(l > 0)
417+
{
418+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
419+
const std::size_t start = size - l;
420+
for (std::size_t i = 0; i < l; ++i)
421+
buffer[start + i] = static_cast<T_out>(mem[start + i]);
422+
return batch<T_out, A>::load(buffer.data(), aligned_mode {});
423+
}
424+
else XSIMD_IF_CONSTEXPR(mask.popcount() > 0)
425+
{
426+
constexpr std::size_t first = mask.first_one_index();
427+
constexpr std::size_t last = mask.last_one_index();
428+
constexpr std::size_t span = last >= first ? (last - first + 1) : 0;
429+
XSIMD_IF_CONSTEXPR(span > 0 && mask.popcount() == span)
430+
{
431+
alignas(A::alignment()) std::array<T_out, size> buffer { 0 };
432+
for (std::size_t i = 0; i < span; ++i)
433+
buffer[first + i] = static_cast<T_out>(mem[first + i]);
434+
return batch<T_out, A>::load(buffer.data(), aligned_mode {});
435+
}
436+
else
437+
{
438+
return load_masked<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
439+
}
440+
}
441+
else
442+
{
443+
// Fallback to runtime path for non prefix/suffix masks
444+
return load_masked<A>(mem, mask.as_batch_bool(), convert<T_out> {}, common {});
445+
}
446+
}
447+
448+
template <class A, class T_in, class T_out, bool... Values>
449+
XSIMD_INLINE void store_masked(T_out* mem, batch<T_in, A> const& src, batch_bool_constant<T_in, A, Values...> mask, requires_arch<common>) noexcept
450+
{
451+
constexpr std::size_t size = batch<T_in, A>::size;
452+
constexpr std::size_t n = mask.countr_one();
453+
constexpr std::size_t l = mask.countl_one();
454+
455+
// All zeros / all ones fast paths
456+
XSIMD_IF_CONSTEXPR(mask.none())
457+
{
458+
return;
459+
}
460+
else XSIMD_IF_CONSTEXPR(mask.all())
461+
{
462+
src.store(mem, unaligned_mode {});
463+
}
464+
// Prefix-ones
465+
else XSIMD_IF_CONSTEXPR(n > 0)
466+
{
467+
for (std::size_t i = 0; i < n; ++i)
468+
mem[i] = static_cast<T_out>(src.get(i));
469+
}
470+
// Suffix-ones
471+
else XSIMD_IF_CONSTEXPR(l > 0)
472+
{
473+
const std::size_t start = size - l;
474+
for (std::size_t i = 0; i < l; ++i)
475+
mem[start + i] = static_cast<T_out>(src.get(start + i));
476+
}
477+
else XSIMD_IF_CONSTEXPR(mask.popcount() > 0)
478+
{
479+
constexpr std::size_t first = mask.first_one_index();
480+
constexpr std::size_t last = mask.last_one_index();
481+
constexpr std::size_t span = last >= first ? (last - first + 1) : 0;
482+
XSIMD_IF_CONSTEXPR(span > 0 && mask.popcount() == span)
483+
{
484+
for (std::size_t i = 0; i < span; ++i)
485+
mem[first + i] = static_cast<T_out>(src.get(first + i));
486+
}
487+
else
488+
{
489+
store_masked<A>(mem, src, mask.as_batch_bool(), common {});
490+
}
491+
}
492+
else
493+
{
494+
store_masked<A>(mem, src, mask.as_batch_bool(), common {});
495+
}
496+
}
497+
351498
// rotate_right
352499
template <size_t N, class A, class T>
353500
XSIMD_INLINE batch<T, A> rotate_right(batch<T, A> const& self, requires_arch<common>) noexcept

0 commit comments

Comments
 (0)