From f1799933791c0f6e6ae960c70a66eee2ef6d2c7e Mon Sep 17 00:00:00 2001 From: Thibaud Kloczko Date: Mon, 10 Aug 2020 12:30:18 +0200 Subject: [PATCH] Polymorphic sparse scheme to solve following use case: * Let us consider a Finite-Element or Finite-Volume code for which we have to solve a linear system using different libraries (Hypre, Petsc, Trilinos, Mumps, PastiX, ...). We would like to avoid conversions between different formats of sparse matrix. As the linear solver drives the choice of the sparse matrix format, we would like that the matrix the FE code has to fill has directly the format the linear solver requires. The choice of the solver is done at runtime, so it means that the underlying sparse scheme of the matrix should be instanciated at runtime. It implies that the sparse array or sparse tensor is a wrapper around a polymorphic sparse scheme. This latter is a bridge over an abstract scheme that uses type erasure to handle the different format implementation. --- include/xtensor-sparse/xsparse_scheme.hpp | 864 ++++++++++++++++++++++ 1 file changed, 864 insertions(+) create mode 100644 include/xtensor-sparse/xsparse_scheme.hpp diff --git a/include/xtensor-sparse/xsparse_scheme.hpp b/include/xtensor-sparse/xsparse_scheme.hpp new file mode 100644 index 0000000..26062d2 --- /dev/null +++ b/include/xtensor-sparse/xsparse_scheme.hpp @@ -0,0 +1,864 @@ +#ifndef XSPARSE_SCHEME_HPP +#define XSPARSE_SCHEME_HPP + +#include +#include + +namespace xt +{ + + /*********************************************************************** + * xsparse_polymorphic_scheme_nz_iterator as a bridge for type erasure * + ***********************************************************************/ + + template + class xsparse_abstract_scheme_nz_iterator; + + template + class xsparse_polymorphic_scheme_nz_iterator + { + public: + + using self_type = xsparse_polymorphic_scheme_nz_iterator; + using abstract_iterator = xsparse_abstract_scheme_nz_iterator; + using index_type = xtl::any; + using value_type = T; + using reference = value_type&; + using pointer = value_type*; + using difference_type = std::ptrdiff_t; + + xsparse_polymorphic_scheme_nz_iterator(abstract_iterator *it); + ~xsparse_polymorphic_scheme_nz_iterator(); + + self_type& operator++(); + self_type& operator--(); + + self_type& operator+=(difference_type n); + self_type& operator-=(difference_type n); + + difference_type operator-(const self_type& rhs) const; + + reference operator*() const; + pointer operator->() const; + const index_type& index() const; + + bool equal(const self_type& rhs) const; + bool less_than(const self_type& rhs) const; + + private: + abstract_iterator *m_it = nullptr; + }; + + template + bool operator == (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs); + + template + bool operator < (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs); + + /*************************************************************************** + * xsparse_abstract_scheme_nz_iterator as top-level class for type erasure * + ***************************************************************************/ + + template + class xsparse_abstract_scheme_nz_iterator + { + public: + + using self_type = xsparse_abstract_scheme_nz_iterator; + using index_type = xtl::any; + using value_type = T; + using reference = value_type&; + using pointer = value_type*; + using difference_type = std::ptrdiff_t; + + virtual ~xsparse_abstract_scheme_nz_iterator() = default; + + virtual reference operator*() const = 0; + virtual pointer operator->() const = 0; + virtual const index_type& index() const = 0; + + virtual bool equal(const self_type& rhs) const = 0; + virtual bool less_than(const self_type& rhs) const = 0; + + virtual difference_type distance(const self_type& rhs) const = 0; + + virtual void advance(void) = 0; + virtual void rewind(void) = 0; + virtual void advance(difference_type n) = 0; + virtual void rewind(difference_type n) = 0; + }; + + /****************************************************************** + * xsparse_crtp_scheme_nz_iterator as base class for type erasure * + ******************************************************************/ + + template + class xsparse_crtp_scheme_nz_iterator : public xsparse_abstract_scheme_nz_iterator + { + public: + + using derived_type = D; + + using self_type = xsparse_crtp_scheme_nz_iterator; + using index_type = xtl::any; + using value_type = T; + using reference = value_type&; + using pointer = value_type*; + using difference_type = std::ptrdiff_t; + + const index_type& index() const final; + + bool equal(const self_type& rhs) const final; + bool less_than(const self_type& rhs) const final; + + difference_type distance(const self_type& rhs) const final; + + void advance(void) final; + void rewind(void) final; + void advance(difference_type n) final; + void rewind(difference_type n) final; + + private: + + derived_type& derived_cast() & noexcept; + const derived_type& derived_cast() const & noexcept; + derived_type derived_cast() && noexcept; + + index_type m_index; + }; + + /************************************************ + * xsparse_coo_scheme_nz_iterator as an example * + ************************************************/ + + namespace detail + { + template + struct xsparse_coo_scheme_storage_type + { + using storage_type = typename scheme::storage_type; + using value_iterator = typename storage_type::iterator; + }; + + template + struct xsparse_coo_scheme_storage_type + { + using storage_type = typename scheme::storage_type; + using value_iterator = typename storage_type::const_iterator; + }; + + template + struct xsparse_coo_scheme_nz_iterator_types : xsparse_coo_scheme_storage_type + { + using base_type = xsparse_coo_scheme_storage_type; + using index_type = typename scheme::index_type; + using coordinate_type = typename scheme::coordinate_type; + using coordinate_iterator = typename coordinate_type::const_iterator; + using value_iterator = typename base_type::value_iterator; + using value_type = typename value_iterator::value_type; + using reference = typename value_iterator::reference; + using pointer = typename value_iterator::pointer; + using difference_type = typename value_iterator::difference_type; + }; + } + + template + class xsparse_coo_scheme_nz_iterator : public xsparse_crtp_scheme_nz_iterator>, + xtl::xrandom_access_iterator_base3, + detail::xsparse_coo_scheme_nz_iterator_types> + { + public: + + using self_type = xsparse_coo_scheme_nz_iterator; + using scheme_type = scheme; + using iterator_types = detail::xsparse_coo_scheme_nz_iterator_types; + using index_type = typename iterator_types::index_type; + using coordinate_type = typename iterator_types::coordinate_type; + using coordinate_iterator = typename iterator_types::coordinate_iterator; + using value_iterator = typename iterator_types::value_iterator; + using value_type = typename iterator_types::value_type; + using reference = typename iterator_types::reference; + using pointer = typename iterator_types::pointer; + using difference_type = typename iterator_types::difference_type; + using iterator_category = std::random_access_iterator_tag; + + xsparse_coo_scheme_nz_iterator() = default; + xsparse_coo_scheme_nz_iterator(scheme& s, coordinate_iterator cit, value_iterator vit); + + self_type& operator++(); + self_type& operator--(); + + self_type& operator+=(difference_type n); + self_type& operator-=(difference_type n); + + difference_type operator-(const self_type& rhs) const; + + reference operator*() const; + pointer operator->() const; + const index_type& index() const; + + bool equal(const self_type& rhs) const; + bool less_than(const self_type& rhs) const; + + private: + + scheme_type* p_scheme = nullptr; + coordinate_iterator m_cit; + value_iterator m_vit; + }; + + template + bool operator==(const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs); + + template + bool operator<(const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs); + + /********************************************************* + * xsparse_polymorphic_scheme_nz_iterator implementation * + *********************************************************/ + + template + inline xsparse_polymorphic_scheme_nz_iterator::xsparse_polymorphic_scheme_nz_iterator(abstract_iterator *it) : m_it(it) + { + } + + template + inline xsparse_polymorphic_scheme_nz_iterator::~xsparse_polymorphic_scheme_nz_iterator() + { + if (m_it) + delete m_it; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator++() -> self_type& + { + m_it->advance(); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator--() -> self_type& + { + m_it->rewind(); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator+=(difference_type n) -> self_type& + { + m_it->advance(n); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator-=(difference_type n) -> self_type& + { + m_it->rewind(n); + return *this; + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator-(const self_type& rhs) const -> difference_type + { + return m_it->distance(rhs); + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator*() const -> reference + { + return m_it->reference(); + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::operator->() const -> pointer + { + return m_it->pointer(); + } + + template + inline auto xsparse_polymorphic_scheme_nz_iterator::index() const -> const index_type& + { + return m_it->index(); + } + + template + inline bool xsparse_polymorphic_scheme_nz_iterator::equal(const self_type& rhs) const + { + return m_it->equal(*(rhs->m_it)); + } + + template + inline bool xsparse_polymorphic_scheme_nz_iterator::less_than(const self_type& rhs) const + { + return m_it->less_than(*(rhs->m_it)); + } + + template + inline bool operator == (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs) + { + return lhs->equal(rhs); + } + + template + inline bool operator < (const xsparse_polymorphic_scheme_nz_iterator& lhs, + const xsparse_polymorphic_scheme_nz_iterator& rhs) + { + return lhs->less_than(rhs); + } + + /************************************************** + * xsparse_crtp_scheme_nz_iterator implementation * + **************************************************/ + + template + inline auto xsparse_crtp_scheme_nz_iterator::index() const -> const index_type& + { + m_index = this->derived_cast().index(); + return m_index; + } + + template + inline bool xsparse_crtp_scheme_nz_iterator::equal(const self_type& rhs) const + { + return this->derived_cast() == static_cast(rhs); + } + + template + inline bool xsparse_crtp_scheme_nz_iterator::less_than(const self_type& rhs) const + { + return this->derived_cast() < static_cast(rhs); + } + + template + inline auto xsparse_crtp_scheme_nz_iterator::distance(const self_type& rhs) const -> difference_type + { + auto self = this->derived_cast(); + auto other = static_cast(rhs); + + auto diff = self - other; + return (difference_type)(diff); + } + + template + inline void xsparse_crtp_scheme_nz_iterator::advance(void) + { + ++(this->derived_cast()); + } + + template + inline void xsparse_crtp_scheme_nz_iterator::rewind(void) + { + --(this->derived_cast()); + } + + template + inline void xsparse_crtp_scheme_nz_iterator::advance(difference_type n) + { + (this->derived_cast()) += n; + } + + template + inline void xsparse_crtp_scheme_nz_iterator::rewind(difference_type n) + { + (this->derived_cast()) -= n; + } + + template + inline auto xsparse_crtp_scheme_nz_iterator::derived_cast() & noexcept -> derived_type& + { + return static_cast(*this); + } + + template + inline auto xsparse_crtp_scheme_nz_iterator::derived_cast() const & noexcept -> const derived_type& + { + return static_cast(*this); + } + + /************************************************* + * xsparse_coo_scheme_nz_iterator implementation * + *************************************************/ + + template + inline xsparse_coo_scheme_nz_iterator::xsparse_coo_scheme_nz_iterator(S& s, coordinate_iterator cit, value_iterator vit) + : p_scheme(&s) + , m_cit(cit) + , m_vit(vit) + { + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator++() -> self_type& + { + ++m_cit; + ++m_vit; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator--() -> self_type& + { + --m_cit; + --m_vit; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator+=(difference_type n) -> self_type& + { + m_cit += n; + m_vit += n; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator-=(difference_type n) -> self_type& + { + m_cit -= n; + m_vit -= n; + return *this; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator-(const self_type& rhs) const -> difference_type + { + return m_cit - rhs.m_cit; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator*() const -> reference + { + return *m_vit; + } + + template + inline auto xsparse_coo_scheme_nz_iterator::operator->() const -> pointer + { + return &(*m_vit); + } + + template + inline auto xsparse_coo_scheme_nz_iterator::index() const -> const index_type& + { + return *m_cit; + } + + template + inline bool xsparse_coo_scheme_nz_iterator::equal(const self_type& rhs) const + { + return p_scheme == rhs.p_scheme && m_cit == rhs.m_cit && m_vit == rhs.m_vit; + } + + template + inline bool xsparse_coo_scheme_nz_iterator::less_than(const self_type& rhs) const + { + return p_scheme == rhs.p_scheme && m_cit < rhs.m_cit && m_vit < rhs.m_vit; + } + + template + inline bool operator == (const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs) + { + return lhs.equal(rhs); + } + + template + inline bool operator < (const xsparse_coo_scheme_nz_iterator& lhs, + const xsparse_coo_scheme_nz_iterator& rhs) + { + return lhs.less_than(rhs); + } + + /*********************************************************** + * xsparse_polymorphic_scheme as a bridge for type erasure * + ***********************************************************/ + + template + class xsparse_abstract_scheme; + + template + class xsparse_polymorphic_scheme + { + public: + + using self_type = xsparse_polymorphic_scheme; + using index_type = xtl::any; + + using value_type = T; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + + using size_type = std::size_t; + using shape_type = svector; + using strides_type = svector; + using inner_shape_type = shape_type; + + using nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + using const_nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + + xsparse_polymorphic_scheme(); + xsparse_polymorphic_scheme(xsparse_abstract_scheme *scheme); + ~xsparse_polymorphic_scheme(); + + pointer find_element(const index_type& index); + const_pointer find_element(const index_type& index) const; + void insert_element(const index_type& index, const_reference value); + void remove_element(const index_type& index); + + void update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape); + + nz_iterator nz_begin(); + nz_iterator nz_end(); + const_nz_iterator nz_begin() const; + const_nz_iterator nz_end() const; + const_nz_iterator nz_cbegin() const; + const_nz_iterator nz_cend() const; + + private: + class xsparse_abstract_scheme *m_scheme = nullptr; + }; + + /*********************************************************** + * xsparse_abstract_scheme as base class for type erasure * + ***********************************************************/ + + template + class xsparse_abstract_scheme + { + public: + + using self_type = xsparse_abstract_scheme; + using index_type = xtl::any; + + using value_type = T; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + + using size_type = std::size_t; + using shape_type = svector; + using strides_type = svector; + using inner_shape_type = shape_type; + + using nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + using const_nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + + + virtual ~xsparse_abstract_scheme = default; + + virtual pointer find_element(const index_type& index) = 0; + virtual const_pointer find_element(const index_type& index) const = 0; + virtual void insert_element(const index_type& index, const_reference value) = 0; + virtual void remove_element(const index_type& index) = 0; + + virtual void update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape) = 0; + + virtual nz_iterator nz_begin() = 0; + virtual nz_iterator nz_end() = 0; + virtual const_nz_iterator nz_begin() const = 0; + virtual const_nz_iterator nz_end() const = 0; + virtual const_nz_iterator nz_cbegin() const = 0; + virtual const_nz_iterator nz_cend() const = 0; + }; + + /********************** + * xsparse_coo_scheme * + **********************/ + + template > + class xsparse_coo_scheme + { + public: + + using self_type = xsparse_coo_scheme; + using position_type = P; + using coordinate_type = C; + using storage_type = ST; + using index_type = IT; + + using value_type = typename storage_type::value_type; + using reference = typename storage_type::reference; + using const_reference = typename storage_type::const_reference; + using pointer = typename storage_type::pointer; + using const_pointer = typename storage_type::const_pointer; + + using nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + using const_nz_iterator = xsparse_polymorphic_scheme_nz_iterator; + + using coo_nz_iterator = xsparse_coo_scheme_nz_iterator; + using coo_const_nz_iterator = xsparse_coo_scheme_nz_iterator; + + xsparse_coo_scheme(); + + const position_type& position() const; + const coordinate_type& coordinate() const; + const storage_type& storage() const; + + + pointer find_element(const index_type& index); + const_pointer find_element(const index_type& index) const; + void insert_element(const index_type& index, const_reference value); + void remove_element(const index_type& index); + + template + void update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape); + + nz_iterator nz_begin(); + nz_iterator nz_end(); + const_nz_iterator nz_begin() const; + const_nz_iterator nz_end() const; + const_nz_iterator nz_cbegin() const; + const_nz_iterator nz_cend() const; + + private: + + const_pointer find_element_impl(const index_type& index) const; + + position_type m_pos; + coordinate_type m_coords; + storage_type m_storage; + + friend class xsparse_coo_scheme_nz_iterator; + friend class xsparse_coo_scheme_nz_iterator; + }; + + + /********************************************* + * xsparse_polymorphic_scheme implementation * + *********************************************/ + + template + inline xsparse_polymorphic_scheme::xsparse_polymorphic_scheme() + { + // m_scheme = xt::scheme_policy().scheme(); + } + + template + inline xsparse_polymorphic_scheme::xsparse_polymorphic_scheme(xsparse_abstract_scheme *scheme) : m_scheme(scheme) + { + + } + + template + inline xsparse_polymorphic_scheme::~xsparse_polymorphic_scheme() + { + if (m_scheme) + delete m_scheme; + } + + template + inline auto xsparse_polymorphic_scheme::find_element(const index_type& index) -> pointer + { + return m_scheme->find_element(index); + } + + template + inline auto xsparse_polymorphic_scheme::find_element(const index_type& index) const -> const_pointer + { + return m_scheme->find_element(index); + } + + template + inline void xsparse_polymorphic_scheme::insert_element(const index_type& index, const_reference value) + { + m_scheme->insert_element(index, value); + } + + template + inline void xsparse_polymorphic_scheme::remove_element(const index_type& index) + { + m_scheme->remove_element(index); + } + + template + inline void xsparse_polymorphic_scheme::update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type& new_shape) + { + m_scheme->update_entries(old_strides, new_strides, new_shape); + } + + template + inline auto xsparse_polymorphic_scheme::nz_begin() -> nz_iterator + { + return m_scheme->nz_begin(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_end() -> nz_iterator + { + return m_scheme->nz_end(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_begin() const -> const_nz_iterator + { + return m_scheme->nz_begin(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_end() const -> const_nz_iterator + { + return m_scheme->nz_end(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_cbegin() const -> const_nz_iterator + { + return m_scheme->nz_cbegin(); + } + + template + inline auto xsparse_polymorphic_scheme::nz_cend() const -> const_nz_iterator + { + return m_scheme->nz_cend(); + } + + /****************************************** + * xsparse_abstract_scheme implementation * + ******************************************/ + + template + inline xsparse_coo_scheme::xsparse_coo_scheme() + : m_pos(P{{0u, 0u}}) + + + template + inline auto xsparse_coo_scheme::position() const -> const position_type& + { + return m_pos; + } + + template + inline auto xsparse_coo_scheme::coordinate() const -> const coordinate_type& + { + return m_coords; + } + + template + inline auto xsparse_coo_scheme::storage() const -> const storage_type& + { + return m_storage; + } + + template + inline auto xsparse_coo_scheme::find_element(const index_type& index) -> pointer + { + return const_cast(find_element_impl(index)); + } + + template + inline auto xsparse_coo_scheme::find_element(const index_type& index) const -> const_pointer + { + return find_element_impl(index); + } + + template + inline void xsparse_coo_scheme::insert_element(const index_type& index, const_reference value) + { + auto it = std::upper_bound(m_coords.cbegin(), m_coords.cend(), index); + if (it != m_coords.cend()) + { + auto diff = std::distance(m_coords.cbegin(), it); + m_coords.insert(it, index); + m_storage.insert(m_storage.cbegin() + diff, value); + } + else + { + m_coords.push_back(index); + m_storage.push_back(value); + } + ++m_pos.back(); + } + + template + inline void xsparse_coo_scheme::remove_element(const index_type& index) + { + auto it = std::find(m_coords.begin(), m_coords.end(), index); + if (it != m_coords.end()) + { + auto diff = it - m_coords.begin(); + m_coords.erase(it); + m_pos.back()--; + m_storage.erase(m_storage.begin() + diff); + } + } + + template + template + inline void xsparse_coo_scheme::update_entries(const strides_type& old_strides, + const strides_type& new_strides, + const shape_type&) + { + coordinate_type new_coords; + + for(auto& old_index: m_coords) + { + std::size_t offset = element_offset(old_strides, old_index.cbegin(), old_index.cend()); + index_type new_index = unravel_from_strides(offset, new_strides); + new_coords.push_back(new_index); + } + using std::swap; + swap(m_coords, new_coords); + } + + template + inline auto xsparse_coo_scheme::find_element_impl(const index_type& index) const -> const_pointer + { + auto it = std::find(m_coords.begin(), m_coords.end(), index); + return it == m_coords.end() ? nullptr : &*(m_storage.begin() + (it - m_coords.begin())); + } + + template + inline auto xsparse_coo_scheme::nz_begin() -> nz_iterator + { + return nz_iterator(new coo_nz_iterator(*this, m_coords.cbegin(), m_storage.begin())); + } + + template + inline auto xsparse_coo_scheme::nz_end() -> nz_iterator + { + return nz_iterator(new coo_nz_iterator(*this, m_coords.cend(), m_storage.end())); + } + + template + inline auto xsparse_coo_scheme::nz_begin() const -> const_nz_iterator + { + return nz_cbegin(); + } + + template + inline auto xsparse_coo_scheme::nz_end() const -> const_nz_iterator + { + return nz_cend(); + } + + template + inline auto xsparse_coo_scheme::nz_cbegin() const -> const_nz_iterator + { + return const_nz_iterator(new coo_const_nz_iterator(*this, m_coords.cbegin(), m_storage.cbegin())); + } + + template + inline auto xsparse_coo_scheme::nz_cend() const -> const_nz_iterator + { + return const_nz_iterator(new coo_const_nz_iterator(*this, m_coords.cend(), m_storage.cend())); + } +} + +#endif