diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..a612ad9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/src/ae_adaptive_predicate_eval.hpp b/src/ae_adaptive_predicate_eval.hpp index 55f7e1a..0b86685 100644 --- a/src/ae_adaptive_predicate_eval.hpp +++ b/src/ae_adaptive_predicate_eval.hpp @@ -53,10 +53,13 @@ template class adaptive_eval_impl { public: using E = std::remove_cvref_t; + using allocator_type = std::remove_cvref_t; - explicit adaptive_eval_impl(allocator_type_ mem_pool) - : exact_storage{num_partials_for_exact(), - std::forward(mem_pool)} {} + adaptive_eval_impl() = default; + + explicit adaptive_eval_impl(allocator_type mem_pool_) + : cache(), mem_pool(std::forward(mem_pool_)), + exact_storage{num_partials_for_exact(), mem_pool} {} explicit adaptive_eval_impl(const E_ &) : adaptive_eval_impl() {} @@ -78,17 +81,17 @@ class adaptive_eval_impl { exact_storage.data() + begin_idx, num_partials_for_exact()}; } - // exact_eval_root computes the requested result of the (sub)expression to 1/2 - // epsilon precision + // exact_eval_round computes the requested result of the (sub)expression to + // 1/2 epsilon precision template - std::pair exact_eval_root(evaluatable auto &&expr) { + std::pair exact_eval_round(evaluatable auto &&expr) { using sub_expr = decltype(expr); auto memory = get_memory(); - exact_eval(std::forward(expr), memory); - _impl::merge_sum(memory); - - const eval_type exact_result = std::reduce(memory.begin(), memory.end()); - cache[branch_id] = exact_result; + auto result_end = + exact_eval_merge(std::forward(expr), memory); + const eval_type exact_result = std::reduce(memory.begin(), result_end); + cache[branch_id] = + std::pair{exact_result, std::distance(memory.begin(), result_end)}; return {exact_result, abs(exact_result) * std::numeric_limits::epsilon() / 2.0}; } @@ -108,12 +111,12 @@ class adaptive_eval_impl { using Op = typename sub_expr::Op; constexpr std::size_t subexpr_choice_latency = - std::max(exact_fp_rounding_latency(), - exact_fp_rounding_latency()) + + std::max(exact_fp_rounding_latency(), + exact_fp_rounding_latency()) + 2 * overshoot_latency() + 2 * cmp_latency() + error_contrib_latency(); - if constexpr (exact_fp_rounding_latency() > + if constexpr (exact_fp_rounding_latency() > subexpr_choice_latency && is_expr_v && is_expr_v) { // We need to reduce error efficiently, so don't just exactly evaluate @@ -132,7 +135,7 @@ class adaptive_eval_impl { // guarantees we deal with the largest part of the error, making // the fall-through case very unlikely const auto [new_left, new_left_err] = - exact_eval_root<_impl::left_branch_id(branch_id)>(expr.lhs()); + exact_eval_round<_impl::left_branch_id(branch_id)>(expr.lhs()); const auto [new_result, new_abs_err] = _impl::eval_with_max_abs_err(new_left, new_left_err, @@ -144,7 +147,7 @@ class adaptive_eval_impl { } } else { const auto [new_right, new_right_err] = - exact_eval_root<_impl::right_branch_id(branch_id)>( + exact_eval_round<_impl::right_branch_id(branch_id)>( expr.rhs()); const auto [new_result, new_abs_err] = _impl::eval_with_max_abs_err(left_result, left_abs_err, @@ -157,7 +160,7 @@ class adaptive_eval_impl { } } } - return exact_eval_root(std::forward(expr)); + return exact_eval_round(std::forward(expr)); } // Returns the result and maximum absolute error from computing the expression @@ -167,8 +170,8 @@ class adaptive_eval_impl { if constexpr (is_expr_v) { const auto exact_eval_info = cache[branch_id]; if (exact_eval_info) { - return {*exact_eval_info, - abs(*exact_eval_info) * + return {exact_eval_info->first, + abs(exact_eval_info->first) * std::numeric_limits::epsilon() / 2.0}; } using Op = typename sub_expr::Op; @@ -218,52 +221,136 @@ class adaptive_eval_impl { } } + template + requires expr_type || arith_number + constexpr auto + exact_eval_merge(sub_expr &&e, + std::span()> + partial_results) noexcept + -> decltype(partial_results.end()) { + auto partial_last = + exact_eval(std::forward(e), partial_results); + if constexpr (!_impl::linear_merge_lower_latency()) { + // In cases where the linear merge algorithm doesn't make sense, we need + // to ensure we can compute the correctly rounded result with just a + // reduction. + // In cases where the linear merge algorithm does make sense, this is + // already handled by exact_eval + std::span nonzero_results{partial_results.begin(), partial_last}; + auto nonzero_last = _impl::merge_sum(nonzero_results).second; + partial_last = partial_results.begin() + + std::distance(nonzero_results.begin(), nonzero_last); + } + return partial_last; + } + template requires expr_type || arith_number - constexpr void + constexpr auto exact_eval(sub_expr_ &&e, std::span()> - partial_results) noexcept { + partial_results) noexcept -> decltype(partial_results.end()) { using sub_expr = std::remove_cvref_t; - if constexpr (is_expr_v) { + if constexpr (num_partials_for_exact() == 0) { + return partial_results.begin(); + } else if constexpr (is_expr_v) { if (cache[branch_id]) { - return; + return partial_results.begin() + cache[branch_id]->second; } + constexpr std::size_t left_id = _impl::left_branch_id(branch_id); constexpr std::size_t reserve_left = num_partials_for_exact(); - const auto storage_left = partial_results.template first(); - exact_eval<_impl::left_branch_id(branch_id)>(e.lhs(), storage_left); - constexpr std::size_t reserve_right = num_partials_for_exact(); + constexpr std::size_t start_left = + num_partials_for_exact() - reserve_right - reserve_left; + const auto storage_left = + partial_results.template subspan(); + + auto left_end = exact_eval(e.lhs(), storage_left); + + constexpr std::size_t start_right = + num_partials_for_exact() - reserve_right; const auto storage_right = - partial_results.template subspan(); - exact_eval<_impl::right_branch_id(branch_id)>(e.rhs(), - storage_right); + partial_results.template subspan(); + auto right_end = exact_eval<_impl::right_branch_id(branch_id)>( + e.rhs(), storage_right); + + if constexpr (_impl::linear_merge_lower_latency()) { + // Since we're at a point where we want to use linear merge sum, + // we need to ensure the lower levels of the tree have been merged + // together. If they're also at a point where linear merge sum is being + // used, then they're already merged and we don't need to do anything + if constexpr (!_impl::linear_merge_lower_latency()) { + const std::span nonzero_left{storage_left.begin(), left_end}; + left_end = storage_left.begin() + + std::distance(nonzero_left.begin(), + _impl::merge_sum(nonzero_left).second); + } + if constexpr (!_impl::linear_merge_lower_latency()) { + const std::span nonzero_right{storage_right.begin(), right_end}; + right_end = storage_right.begin() + + std::distance(nonzero_right.begin(), + _impl::merge_sum(nonzero_right).second); + } + } using Op = typename sub_expr::Op; if constexpr (std::is_same_v, Op> || std::is_same_v, Op>) { if constexpr (std::is_same_v, Op>) { - for (eval_type &v : storage_right) { + for (eval_type &v : std::span{storage_right.begin(), right_end}) { v = -v; } } + if constexpr (_impl::linear_merge_lower_latency()) { + std::span left_span{storage_left.begin(), left_end}; + std::vector left_copy(mem_pool); + left_copy.reserve(left_span.size()); + for (const auto v : left_span) { + left_copy.push_back(v); + } + return _impl::merge_sum_linear( + partial_results, std::span{left_copy}, + std::span{storage_right.begin(), right_end}) + .second; + } else { + // We must compact so that the returned end iterator doesn't include + // garbage data + return std::copy(storage_right.begin(), right_end, + partial_results.begin() + + std::distance(storage_left.begin(), left_end)); + } } else if constexpr (std::is_same_v, Op>) { - const auto storage_mult = - partial_results.template last(); - _impl::sparse_mult(storage_left, storage_right, storage_mult); + if constexpr (_impl::linear_merge_lower_latency()) { + return _impl::sparse_mult_merge( + std::span{storage_left.begin(), left_end}, + std::span{storage_right.begin(), right_end}, partial_results, + mem_pool); + } else { + return _impl::sparse_mult(storage_left, storage_right, + partial_results); + } } - } else if constexpr (!std::is_same_v) { - // additive_id is zero, so we don't actually have memory allocated for it + } else { *partial_results.begin() = eval_type(e); + return partial_results.begin() + 1; } } - std::array, num_internal_nodes()> cache; + // Cache the rounded value and the index of the end of the non-zero element + // subspan + std::array>, + num_internal_nodes()> + cache; + + allocator_type mem_pool; - std::vector> exact_storage; + std::vector exact_storage; }; } // namespace _impl diff --git a/src/ae_expr_utils.hpp b/src/ae_expr_utils.hpp index b632481..5640da3 100644 --- a/src/ae_expr_utils.hpp +++ b/src/ae_expr_utils.hpp @@ -4,6 +4,7 @@ #include "ae_expr.hpp" +#include #include namespace adaptive_expr { @@ -368,11 +369,19 @@ constexpr std::size_t get_memory_begin_idx() { return 0; } else { if constexpr (branch_id < _impl::right_branch_id(root_branch_id)) { - return get_memory_begin_idx() - + num_partials_for_exact() - + num_partials_for_exact(); + return left_start + + get_memory_begin_idx(); } else { - return num_partials_for_exact() + + constexpr std::size_t right_start = + num_partials_for_exact() - + num_partials_for_exact(); + return right_start + get_memory_begin_idx( root_branch_id), typename root_t::RHS>(); @@ -380,6 +389,41 @@ constexpr std::size_t get_memory_begin_idx() { } } +// Useful functors for filtering and merging +template +static constexpr auto is_nonzero(const eval_type v) { + return v != eval_type{0}; +} + +template +auto copy_nonzero(range_type &range, allocator_type_ &&mem_pool) { + using eval_type = std::remove_cvref_t; + using allocator_type = std::remove_cvref_t; + auto nonzero_range = range | std::views::filter(is_nonzero); + const std::size_t size = + std::distance(nonzero_range.begin(), nonzero_range.end()); + std::vector terms{size, mem_pool}; + std::ranges::copy(nonzero_range, terms.begin()); + return terms; +} + +template +static constexpr auto zero_prune_store(const eval_type v, iterator i) + -> iterator { + if constexpr (scalar_type) { + if (v) { + *i = v; + ++i; + } + } else { + *i = v; + ++i; + } + return i; +} + +// A simple (overly pessimistic) attempt to model latencies so decisions +// regarding algorithm choices can be made by the compiler template consteval std::size_t op_latency() { if constexpr (std::is_same_v, Op> || std::is_same_v, Op>) { @@ -401,10 +445,134 @@ template consteval std::size_t error_contrib_latency() { consteval std::size_t cmp_latency() { return 1; } consteval std::size_t abs_latency() { return 1; } consteval std::size_t fma_latency() { return 4; } +consteval std::size_t swap_latency() { return 2; } +consteval std::size_t mem_alloc_latency() { + // This probably results in a cache miss, so guess a high latency + return 100; +} + +template +consteval std::size_t unchecked_two_sum_latency() { + if (vector_type) { + // unchecked_dekker_sum has 2 subtractions, 1 addition + return op_latency>() + 2 * op_latency>(); + } else { + // knuth_sum has 2 additions, 4 subtractions. + return 2 * op_latency>() + 4 * op_latency>(); + } +} + +template consteval std::size_t two_sum_latency() { + if (vector_type) { + // Needs to compare the absolute values of the entries, and then use + // unchecked sum Assume branch prediction is optimized away + return unchecked_two_sum_latency() + 2 * abs_latency() + + cmp_latency(); + } else { + // knuth_sum has 2 additions, 4 subtractions. + return 2 * op_latency>() + 4 * op_latency>(); + } +} + consteval std::size_t overshoot_latency() { return abs_latency() + op_latency>(); } +// Computes the total (non-pipelined) latency of merging two sorted lists +// into one sorted list. +// Merging two sorted lists in linear time requires a memory allocation which +// can be pretty expensive. +consteval std::size_t merge_latency(std::size_t left_terms, + std::size_t right_terms) { + return (left_terms + right_terms - 1) * cmp_latency() + mem_alloc_latency(); +} + +// Computes the total (non-pipelined) latency of the quadratic merge sum +// algorithm +// Note that this ignores the zero elimination, assumes neither subtree has used +// the merge sum algorithm (so we can't just perform a partial merge) and is +// overly pessimistic +template +consteval std::size_t merge_sum_latency() { + using E = std::remove_cvref_t; + if constexpr (is_expr_v) { + const std::size_t terms = num_partials_for_exact(); + return (terms - 1) * terms * two_sum_latency() / 2; + } else { + return 0; + } +} + +// Computes the total (non-pipelined) latency of the linear merge sum +// algorithm. +// Note that this ignores zero elimination, and doesn't compute the latency of +// the required merges of the left and right subtree +template +consteval std::size_t merge_sum_linear_latency() { + using E = std::remove_cvref_t; + if constexpr (is_expr_v) { + using LHS = typename E::LHS; + using RHS = typename E::RHS; + const std::size_t left_terms = num_partials_for_exact(); + const std::size_t right_terms = num_partials_for_exact(); + return (left_terms + right_terms - 2) * two_sum_latency() + + unchecked_two_sum_latency() + + merge_latency(left_terms, right_terms); + } else { + return 0; + } +} + +template +consteval std::size_t total_merge_sum_latency(); + +// Determines whether it's more efficient to use the linear merge over the +// quadratic merge with the overly pessimistic latency model above +template +consteval bool linear_merge_lower_latency() { + using E = std::remove_cvref_t; + if constexpr (is_expr_v) { + using LHS = typename E::LHS; + using RHS = typename E::RHS; + constexpr std::size_t linear_latency = + merge_sum_linear_latency(); + constexpr std::size_t quadratic_latency = merge_sum_latency(); + // linear merge has a high constant latency, reduce template instantiations + // by checking that it's faster than the quadratic merge first + if constexpr (linear_latency < quadratic_latency) { + constexpr std::size_t total_linear_latency = + total_merge_sum_latency() + + total_merge_sum_latency() + linear_latency; + return total_linear_latency < quadratic_latency; + } else { + return false; + } + } else { + return false; + } +} + +template +consteval std::size_t total_merge_sum_latency() { + using E = std::remove_cvref_t; + if constexpr (is_expr_v) { + if constexpr (linear_merge_lower_latency()) { + using LHS = typename E::LHS; + using RHS = typename E::RHS; + return total_merge_sum_latency() + + total_merge_sum_latency() + + merge_sum_linear_latency(); + } else { + return merge_sum_latency(); + } + } else { + return 0; + } +} + +// The latency cost of all of the (addition doesn't need to do anything) +// negations and multiplications that need to happen to get a series which sums +// to the exact result template consteval std::size_t exact_fp_latency() { using E = std::remove_cvref_t; if constexpr (is_expr_v) { @@ -436,20 +604,16 @@ template consteval std::size_t exact_fp_latency() { } } -template consteval std::size_t exact_fp_rounding_latency() { +// The latency cost converting the sum into a value representable as `eval_type` +// with at most 1/2 epsilon rounding error +// Note that this ignores the zero elimination and is overly pessimistic +template +consteval std::size_t exact_fp_rounding_latency() { using E = std::remove_cvref_t; if constexpr (is_expr_v) { - const std::size_t fp_vals = num_partials_for_exact(); - // dekker_sum has 2 additions, 2 abs(), 1 comparison. - // Assume branch prediction is optimized away - const std::size_t two_sum_cost = - 2 * op_latency>() + 2 * abs_latency() + cmp_latency(); - - const std::size_t exact_latency = exact_fp_latency(); - // Note that this ignores the zero elimination and is overly pessimistic - const std::size_t merge_latency = - two_sum_cost * fp_vals * (fp_vals - 1) / 2; - const std::size_t accumulate_latency = fp_vals - 1; + const std::size_t exact_latency = exact_fp_latency(); + const std::size_t merge_latency = total_merge_sum_latency(); + const std::size_t accumulate_latency = num_partials_for_exact() - 1; return exact_latency + merge_latency + accumulate_latency; } else { return 0; diff --git a/src/ae_fp_eval.hpp b/src/ae_fp_eval.hpp index 2aaf407..36b5603 100644 --- a/src/ae_fp_eval.hpp +++ b/src/ae_fp_eval.hpp @@ -44,9 +44,8 @@ exactfp_eval(E &&e, allocator_type_ &&mem_pool = mem_pool, storage_needed}; std::span partial_span{partial_results_ptr.get(), storage_needed}; - _impl::exactfp_eval_impl(std::forward(e), partial_span); - const eval_type result = _impl::merge_sum(partial_span); - return result; + auto last = _impl::exactfp_eval_impl(std::forward(e), partial_span); + return _impl::merge_sum(std::span{partial_span.begin(), last}).first; } else { return static_cast(e); } diff --git a/src/ae_fp_eval_impl.hpp b/src/ae_fp_eval_impl.hpp index d722ebe..7cd8fed 100644 --- a/src/ae_fp_eval_impl.hpp +++ b/src/ae_fp_eval_impl.hpp @@ -15,29 +15,42 @@ namespace adaptive_expr { namespace _impl { -/* merge_sum_linear runs in-place in linear time, but requires the two sequences - * in storage to be strongly non-overlapping. +/* merge_sum_linear runs in linear time, but requires the two sequences + * left and right to be non-overlapping. + * + * merge_sum_linear_fast performs two fewer additions per value, but requires + * the two sequences left and right to be strongly non-overlapping. + * * That is, each sequence must be non-overlapping and elements which aren't * powers of two must be non-adjacent. Elements which are powers of two can be * adjacent to at most one other element in its sequence. * Elements a, b with abs(a) < abs(b) are adjacent if (a, b) is overlapping or * if (2 * a, b) is overlapping + * + * One of left or right can alias with the tail of the result so long as there + * is space for the two sequences to be merged, starting from the beginning of + * result */ -auto merge_sum_linear( - std::ranges::range auto &&storage, - const typename std::remove_cvref_t::iterator midpoint) -> - typename std::remove_cvref_t::value_type; +auto merge_sum_linear(std::ranges::range auto &&result, + std::ranges::range auto &&left, + std::ranges::range auto &&right) + -> std::pair::value_type, + typename std::remove_cvref_t>; auto merge_sum_linear_fast( std::ranges::range auto &&storage, const typename std::remove_cvref_t::iterator midpoint) -> typename std::remove_cvref_t::value_type; -constexpr auto merge_sum_quadratic(std::ranges::range auto &&storage) -> - typename std::remove_cvref_t::value_type; + +constexpr auto merge_sum_quadratic(std::ranges::range auto &&storage) + -> std::pair::value_type, + std::remove_cvref_t>; constexpr auto merge_sum_quadratic_keep_zeros(std::ranges::range auto &&storage) - -> typename std::remove_cvref_t::value_type; + -> std::pair::value_type, + std::remove_cvref_t>; -constexpr auto merge_sum(std::ranges::range auto storage) -> - typename decltype(storage)::value_type { +constexpr auto merge_sum(std::ranges::range auto storage) + -> std::pair::value_type, + std::remove_cvref_t> { if constexpr (vector_type) { return merge_sum_quadratic_keep_zeros(storage); } else { @@ -47,8 +60,15 @@ constexpr auto merge_sum(std::ranges::range auto storage) -> template -constexpr void sparse_mult(span_l storage_left, span_r storage_right, - span_m storage_mult); +constexpr auto sparse_mult(span_l storage_left, span_r storage_right, + span_m storage_mult) + -> std::remove_cvref_t; + +template +constexpr auto sparse_mult_merge(span_l left_terms, span_r right_terms, + span_result result, allocator_type_ &&mem_pool) + -> decltype(result.end()); template constexpr std::pair knuth_sum(const eval_type &lhs, @@ -183,40 +203,46 @@ error_overlaps(const eval_type left_result, const eval_type left_abs_err, template requires expr_type || arith_number -constexpr void exactfp_eval_impl(E_ &&e, span_t partial_results) noexcept { +constexpr auto exactfp_eval_impl(E_ &&e, span_t partial_results) noexcept + -> decltype(partial_results.end()) { using E = std::remove_cvref_t; - if constexpr (is_expr_v) { + if constexpr (num_partials_for_exact() == 0) { + return partial_results.begin(); + } else if constexpr (is_expr_v) { constexpr std::size_t reserve_left = num_partials_for_exact(); - const auto storage_left = partial_results.template first(); - exactfp_eval_impl(e.lhs(), storage_left); constexpr std::size_t reserve_right = num_partials_for_exact(); - const auto storage_right = - partial_results.template subspan(); - exactfp_eval_impl(e.rhs(), storage_right); + constexpr std::size_t left_start = + num_partials_for_exact() - reserve_left - reserve_right; + const auto storage_left = + partial_results.template subspan(); + const auto left_end = exactfp_eval_impl(e.lhs(), storage_left); + + const std::size_t right_start = + left_start + std::distance(storage_left.begin(), left_end); + const std::span storage_right{ + partial_results.begin() + right_start, reserve_right}; + const auto right_end = exactfp_eval_impl(e.rhs(), storage_right); + using Op = typename E::Op; - if constexpr (std::is_same_v, Op>) { - for (eval_type &v : storage_right) { - v = -v; - } - } else if constexpr (std::is_same_v, Op>) { - const auto storage_mult = [partial_results]() { - if constexpr (span_t::extent == std::dynamic_extent) { - return partial_results.last(partial_results.size() - reserve_left - - reserve_right); - } else { - return partial_results.template last(); + if constexpr (std::is_same_v, Op>) { + return sparse_mult(storage_left, storage_right, partial_results); + } else { + if constexpr (std::is_same_v, Op>) { + for (auto &v : std::span{storage_right.begin(), right_end}) { + v = -v; } - }(); - sparse_mult(storage_left, storage_right, storage_mult); + } + return partial_results.begin() + right_start + + std::distance(storage_right.begin(), right_end); } } else if constexpr (!std::is_same_v) { - partial_results[0] = eval_type(e); + return zero_prune_store(eval_type(e), partial_results.begin()); } } +// Linear merge sum which requires the inputs be strongly non-overlapping auto merge_sum_linear_fast( std::ranges::range auto &&storage, const typename std::remove_cvref_t::iterator midpoint) -> @@ -225,14 +251,25 @@ auto merge_sum_linear_fast( if (storage.size() > 1) { std::ranges::inplace_merge( storage, midpoint, [](const eval_type &left, const eval_type &right) { - return abs(left) < abs(right); + // Zero pruning ensures all of the zeros are at the ends of left and + // right, so we need to ensure that zero is considered greater than + // any non-zero number + // This algorithm technically works regardless of where the zeros are, + // but ensuring they remain at the end allows us to reduce the number + // of computations we have to perform + if (left == eval_type{0}) { + return false; + } else if (right == eval_type{0}) { + return true; + } else { + return abs(left) < abs(right); + } }); auto nonzero_itr = storage.begin(); for (; nonzero_itr != storage.end() && *nonzero_itr == eval_type{0}; ++nonzero_itr) { } std::ranges::rotate(storage, nonzero_itr); - auto [Q, q] = dekker_sum_unchecked(storage[1], storage[0]); auto out = storage.begin(); *out = q; @@ -254,40 +291,33 @@ auto merge_sum_linear_fast( } } -auto merge_sum_linear( - std::ranges::range auto &&storage, - const typename std::remove_cvref_t::iterator midpoint) -> - typename std::remove_cvref_t::value_type { - using eval_type = typename std::remove_cvref_t::value_type; - if (storage.size() > 1) { - std::ranges::inplace_merge( - storage, midpoint, [](const eval_type &left, const eval_type &right) { - return abs(left) < abs(right); - }); - auto nonzero_itr = storage.begin(); - for (; nonzero_itr != storage.end() && *nonzero_itr == eval_type{0}; - ++nonzero_itr) { - } - std::ranges::rotate(storage, nonzero_itr); - auto [Q, q] = dekker_sum_unchecked(storage[1], storage[0]); - auto out = storage.begin(); - for (auto h : std::span{storage.begin() + 2, storage.end()}) { +// Linear merge sum without the strongly non-overlapping requirement +auto merge_sum_linear(std::ranges::range auto &&result, + std::ranges::range auto &&left, + std::ranges::range auto &&right) + -> std::pair::value_type, + typename std::remove_cvref_t> { + using eval_type = typename std::remove_cvref_t::value_type; + const auto [left_last, right_last, result_last] = std::ranges::merge( + left, right, result.begin(), + [](eval_type l, eval_type r) { return abs(l) < abs(r); }); + if (std::distance(result.begin(), result_last) > 1) { + auto [Q, q] = dekker_sum_unchecked(result[1], result[0]); + auto out = result.begin(); + for (auto &h : std::span{result.begin() + 2, result_last}) { auto [R, g] = dekker_sum_unchecked(h, q); - *out = g; - ++out; + out = zero_prune_store(g, out); std::tie(Q, q) = two_sum(Q, R); } - *out = q; - ++out; - *out = Q; - ++out; + out = zero_prune_store(q, out); + out = zero_prune_store(Q, out); - return Q; - } else if (storage.size() == 1) { - return storage[0]; + return std::pair{Q, out}; + } else if (std::distance(result.begin(), result_last) == 1) { + return std::pair{result[0], result.begin() + 1}; } else { - return 0.0; + return std::pair{eval_type{0}, result.begin()}; } } @@ -296,42 +326,35 @@ constexpr auto merge_sum_append(auto begin, auto end, auto v) { auto out = begin; for (auto &e : std::span{begin, end}) { const auto [result, error] = two_sum(v, e); - e = eval_type{0.0}; + e = eval_type{0}; v = result; - if (error) { - *out = error; - ++out; - } + out = zero_prune_store(error, out); } return std::pair{out, v}; } -constexpr auto merge_sum_quadratic(std::ranges::range auto &&storage) -> - typename std::remove_cvref_t::value_type { +constexpr auto merge_sum_quadratic(std::ranges::range auto &&storage) + -> std::pair::value_type, + std::remove_cvref_t> { using eval_type = typename std::remove_cvref_t::value_type; if (storage.size() > 1) { auto out = storage.begin(); - for (eval_type &inp : storage | std::views::filter([](const eval_type v) { - return v != eval_type{0}; - })) { + for (eval_type &inp : storage) { eval_type v = inp; - inp = eval_type{0.0}; + inp = eval_type{0}; auto [new_out, result] = merge_sum_append(storage.begin(), out, v); out = new_out; - if (result) { - *out = result; - ++out; - } + out = zero_prune_store(result, out); } if (out == storage.begin()) { - return eval_type{0}; + return {eval_type{0}, out}; } else { - return *(out - 1); + return {*(out - 1), out}; } } else if (storage.size() == 1) { - return storage[0]; + return {storage[0], storage.end()}; } else { - return eval_type{0.0}; + return {eval_type{0}, storage.end()}; } } @@ -346,54 +369,153 @@ constexpr auto merge_sum_append_keep_zeros(auto begin, auto end) { } constexpr auto merge_sum_quadratic_keep_zeros(std::ranges::range auto &&storage) - -> typename std::remove_cvref_t::value_type { + -> std::pair::value_type, + std::remove_cvref_t> { using eval_type = typename std::remove_cvref_t::value_type; if (storage.size() > 1) { for (auto inp = storage.begin(); inp != storage.end(); ++inp) { *inp = merge_sum_append_keep_zeros(storage.begin(), inp); } - return std::reduce(storage.begin(), storage.end()); + return {std::reduce(storage.begin(), storage.end()), storage.end()}; } else if (storage.size() == 1) { - return storage[0]; + return {storage[0], storage.end()}; } else { - return eval_type{0.0}; + return {eval_type{0}, storage.end()}; } } template -constexpr void sparse_mult(span_l storage_left, span_r storage_right, - span_m storage_mult) { +constexpr auto sparse_mult(span_l storage_left, span_r storage_right, + span_m storage_mult) + -> std::remove_cvref_t { #ifndef __FMA__ static_assert(!vector_type, "Vectorization doesn't have a functional mul_sub method, " "cannot efficiently evaluate multiplications exactly"); #endif // __FMA__ - // This performs multiplication in-place for a contiguous piece of memory - // starting at storage_left.begin() and ending at storage_mult.end() + // This performs multiplication in-place for a contiguous piece of memory, + // where storage_left and storage_right can alias parts of storage_mult // // storage_mult is initially empty and written to first - // storage_right is overwritten second, each value in storage_left is finished + // storage_left is overwritten second, each value in storage_left is finished // and over-writable when its iteration of the outer loop finishes - // storage_left can be shown to only be is overwritten during the final + // storage_right can be shown to only be is overwritten during the final // iteration of the outer loop, the values in it are only overwritten after - // they've been multiplied If storage_left and storage_right are sorted by - // increasing magnitude before multiplying, the first element in the output is - // the least significant and the last element is the most significant - auto out_i = storage_mult.end() - 1; - for (auto r_itr = storage_right.rbegin(); r_itr != storage_right.rend(); - ++r_itr) { - const auto r = *r_itr; - for (auto l_itr = storage_left.rbegin(); l_itr != storage_left.rend(); - ++l_itr) { - const auto l = *l_itr; + // they've been multiplied. + // + // If storage_left and storage_right are sorted by increasing magnitude before + // multiplying, the first element in the output is the least significant and + // the last element is the most significant + auto out_i = storage_mult.begin(); + for (const auto l : storage_left) { + for (const auto r : storage_right) { auto [upper, lower] = exact_mult(r, l); - *out_i = upper; - --out_i; - *out_i = lower; - --out_i; + out_i = zero_prune_store(upper, out_i); + out_i = zero_prune_store(lower, out_i); } } + + return out_i; +} + +template +constexpr auto sparse_mult_merge_term(const span_l storage_left, + const eval_type v, iterator_t out) + -> iterator_t { + // h is the output list + // We only need two values of Q at a time + // T_i, t_i are transient + // + // (Q_2, h_1) <= exact_mult(l[0], v) + // for i : 1 ... m + // (T_i, t_i) <= exact_mult(l[i], v) + // (Q_{2i - 1}, h_{2i - 2}) <= two_sum(Q_{2i - 2}, t_i) + // (Q_{2i}, h_{2i - 1}) <= fast_two_sum(T_i, Q_{2i - 1}) + // h_{2m} <= Q_{2m} + auto [high, low] = exact_mult(storage_left[0], v); + eval_type accumulated = high; + out = zero_prune_store(low, out); + for (std::size_t i = 1; i < storage_left.size(); ++i) { + const auto [mult_high, mult_low] = exact_mult(storage_left[i], v); + std::tie(high, low) = two_sum(accumulated, mult_low); + out = zero_prune_store(low, out); + std::tie(accumulated, low) = two_sum(mult_high, high); + out = zero_prune_store(low, out); + } + out = zero_prune_store(accumulated, out); + return out; +} + +// Recursively merges the subspans with the linear merge algorithm +template +constexpr auto merge_spans(const iter_span_type &iter_spans, + allocator_type &&mem_pool) + -> std::remove_cvref_t { + if (iter_spans.size() == 1) { + // We have zero ranges, nothing to do + return iter_spans[0]; + } else if (iter_spans.size() == 2) { + // We have only one range, already merged + return iter_spans[1]; + } else { + const std::size_t midpoint_itr = (iter_spans.size() - 1) / 2; + + // The iterator marking beginning of the right subspan must be included in + // this subspan, so add 1 to midpoint + const std::span left_span = iter_spans | std::views::take(midpoint_itr + 1); + const std::span right_span = iter_spans | std::views::drop(midpoint_itr); + + const auto left_end = + merge_spans(left_span, std::forward(mem_pool)); + using eval_type = std::remove_cvref_t; + + const auto right_end = + merge_spans(right_span, std::forward(mem_pool)); + + // Merge the spans together. merge_sum_linear requires left not alias + // the output, so copy left out + std::vector> left_copy( + std::forward(mem_pool)); + left_copy.reserve(std::distance(left_span[0], left_end)); + for (auto &v : std::span{iter_spans[0], left_end}) { + left_copy.push_back(v); + } + + std::span results{iter_spans[0], *(iter_spans.end() - 1)}; + auto results_end = merge_sum_linear(results, std::span{left_copy}, + std::span{right_span[0], right_end}) + .second; + return iter_spans[0] + std::distance(results.begin(), results_end); + } +} + +template +constexpr auto sparse_mult_merge(span_l left_terms, span_r right_terms, + span_result result, allocator_type_ &&mem_pool) + -> decltype(result.end()) { + if (left_terms.size() < right_terms.size()) { + // We want to minimize the number of lists to merge at the end since merging + // has a high constant cost + return sparse_mult_merge(right_terms, left_terms, result, + std::forward(mem_pool)); + } else if (right_terms.size() > 0) { + // Multiply all left_terms by all right_terms, keep track of where the spans + // end so we can merge them all at the end + auto out_iter = result.begin(); + std::vector mult_spans; + mult_spans.reserve(right_terms.size() + 1); + mult_spans.push_back(out_iter); + for (const auto v : right_terms) { + out_iter = sparse_mult_merge_term(left_terms, v, out_iter); + mult_spans.push_back(out_iter); + } + + return merge_spans(mult_spans, std::forward(mem_pool)); + } else { + return result.begin(); + } } template diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f9f77d6..25b4add 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,7 +18,7 @@ if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") endif() endif() -add_executable(tests test_adaptive_expr.cpp test_geom_exprs.cpp test_determinant_4x4.cpp test_determinant_5x5.cpp) +add_executable(tests test_adaptive_expr.cpp test_exact_eval.cpp test_geom_exprs.cpp test_determinant_4x4.cpp test_determinant_5x5.cpp) target_compile_features(tests PRIVATE cxx_std_23) target_compile_options(tests PRIVATE -fprofile-arcs -ftest-coverage) target_link_libraries(tests PRIVATE adaptive_predicates shewchuk Catch2::Catch2WithMain fmt) diff --git a/tests/test_adaptive_expr.cpp b/tests/test_adaptive_expr.cpp index c216232..d9057c8 100644 --- a/tests/test_adaptive_expr.cpp +++ b/tests/test_adaptive_expr.cpp @@ -342,7 +342,7 @@ TEST_CASE("expr_template_eval_simple", "[expr_template_eval]") { REQUIRE(exactfp_eval(e.lhs()) == -15.0); REQUIRE(exactfp_eval(e) == -14.5); std::vector fp_vals{5.0, 10.0, 11.0, 11.0, 44.0}; - REQUIRE(merge_sum(std::span{fp_vals}) == 81.0); + REQUIRE(merge_sum(std::span{fp_vals}).first == 81.0); REQUIRE(correct_eval(e)); REQUIRE(!correct_eval(e + 14.5)); REQUIRE(*correct_eval(e.lhs().lhs().lhs().lhs()) == 0.0); @@ -389,23 +389,26 @@ TEST_CASE("nonoverlapping", "[eval_utils]") { CHECK( !is_nonoverlapping(std::vector{-0.375, 0.5, 1.5, 0, 0, 0, -14.0})); - std::vector merge_test1{0, - -1.5436178396078065e-49, - -2.184158631330676e-33, - -1.1470824290427116e-16, - 0, - 1.0353799381025734e-34, - -1.7308376953906192e-17, - -1.2053999999999998}; + std::vector merge_test1{ + -1.5436178396078065e-49, -2.184158631330676e-33, + -1.1470824290427116e-16, 0, + 1.0353799381025734e-34, -1.7308376953906192e-17, + -1.2053999999999998, 0}; const auto midpoint1 = merge_test1.begin() + 4; - CHECK(*midpoint1 == real{0}); CHECK(is_nonoverlapping(std::span{merge_test1.begin(), midpoint1})); CHECK(is_nonoverlapping(std::span{midpoint1, merge_test1.end()})); REQUIRE(midpoint1 > merge_test1.begin()); REQUIRE(midpoint1 < merge_test1.end()); - merge_sum_linear(merge_test1, midpoint1); - CHECK(is_nonoverlapping(merge_test1)); + std::vector left; + for (const auto v : std::span{merge_test1.begin(), merge_test1.begin() + 3}) { + left.push_back(v); + } + const auto merge_end = + merge_sum_linear(merge_test1, left, + std::span{midpoint1, merge_test1.end() - 1}) + .second; + CHECK(is_nonoverlapping(std::span{merge_test1.begin(), merge_end})); // Same strongly non-overlapping sequence but for merge_sum_linear_fast std::vector merge_test2{0, diff --git a/tests/test_exact_eval.cpp b/tests/test_exact_eval.cpp new file mode 100644 index 0000000..d169f52 --- /dev/null +++ b/tests/test_exact_eval.cpp @@ -0,0 +1,199 @@ + +#include + +#include +#include + +#include "ae_expr.hpp" +#include "ae_fp_eval.hpp" + +#include "testing_utils.hpp" + +using namespace adaptive_expr; +using namespace _impl; + +using real = double; + +auto mult_test_case() { + const std::vector left_terms{ + 9.2730153767185534643293381538690987834189119468944775061e-69, + -1.4836824602749685542926941046190558053470259115031164010e-67, + -8.0712325838958289353522559291276635810878209585769532213e-65, + -3.7146931176476335494342239740309342440743131677144559041e-50, + 1.8816263559049057348829376728746432426858986035939542834e-48, + -2.2999754552016361580557916276085068746722863910872779169e-34, + -1.0014835710813626435891085301441622056398128570720018615e-32, + -5.7128445182814558007931378008513521845906262593833835339e-17, + 8.8817841970012523233890533447265625000000000000000000000e-16, + 2.6181623585046946089960329118184745311737060546875000000e+01}; + const std::vector right_terms{ + 3.4159996254284212719523204452351641690112422793523625658e-53, + -1.2829270608442539101474575042327113018312945024504233750e-49, + 1.5826728910280977138562999662532483368295423737368913138e-33, + -5.3571703870932964415335228169487282410437732177140371381e-17}; + std::vector results(2 * left_terms.size() * right_terms.size()); + return std::tuple{std::move(results), std::move(left_terms), + std::move(right_terms)}; +} + +TEST_CASE("sparse_mult eval", "[sparse_mult]") { + auto [results, left, right] = mult_test_case(); + std::vector high_terms; + high_terms.reserve(left.size() * right.size()); + std::vector low_terms; + low_terms.reserve(left.size() * right.size()); + for (const auto l : left) { + for (const auto r : right) { + const auto high = l * r; + high_terms.push_back(high); + const auto low = std::fma(l, r, -high); + low_terms.push_back(low); + } + } + std::span result_span{results}; + const auto result_last = + sparse_mult(std::span{left}, std::span{right}, result_span); + for (const auto v : std::span{result_span.begin(), result_last}) { + // zero pruning check + CHECK(v != real{0}); + } + // check that all of the expected values are in the result + for (const auto v : high_terms) { + if (v != real{0}) { + REQUIRE(std::find(result_span.begin(), result_last, v) != result_last); + } + } + for (const auto v : low_terms) { + if (v != real{0}) { + REQUIRE(std::find(result_span.begin(), result_last, v) != result_last); + } + } + for (const auto v : std::span{result_span.begin(), result_last}) { + // Ensure that v is in either in_high or in_low, but not both + const bool in_high = std::ranges::find(high_terms, v) != high_terms.end(); + const bool in_low = std::ranges::find(low_terms, v) != low_terms.end(); + REQUIRE(in_high != in_low); + } +} + +TEST_CASE("sparse_mult_merge eval", "[sparse_mult_merge]") { + auto [results, left, right] = mult_test_case(); + REQUIRE(is_nonoverlapping(left)); + REQUIRE(is_nonoverlapping(right)); + std::vector expected_results_vec = results; + std::span result_span{results}; + auto result_last = + sparse_mult_merge(left, right, result_span, std::allocator()); + + result_span = std::span{result_span.begin(), result_last}; + REQUIRE(is_nonoverlapping(result_span)); + std::vector nonzero_results; + for (const auto v : result_span) { + // zero-pruning check + REQUIRE(v != real{0}); + nonzero_results.push_back(v); + } + // Check that the result is correct by subtracting values that sum to the same + // thing that a correct implementation produces + std::span expected_results{expected_results_vec}; + auto expected_last = sparse_mult(left, right, expected_results); + expected_results = std::span{expected_results.begin(), expected_last}; + for (const auto v : expected_results) { + nonzero_results.push_back(-v); + } + result_span = std::span{nonzero_results}; + const auto merge_result = merge_sum(result_span); + + REQUIRE(merge_result.first == real{0}); + REQUIRE(merge_result.second == result_span.begin()); + for (const auto v : nonzero_results) { + REQUIRE(v == real{0}); + } +} + +auto mult_inplace_test_case() { + auto [results, left, right] = mult_test_case(); + auto out = results.end(); + for (const auto v : right) { + --out; + *out = v; + } + const auto right_begin = out; + for (const auto v : left) { + --out; + *out = v; + } + const auto left_begin = out; + return std::tuple{std::move(results), std::span{left_begin, right_begin}, + std::span{right_begin, results.end()}}; +} + +TEST_CASE("sparse_mult inplace eval", "[sparse_mult_inplace]") { + auto [results, left, right] = mult_inplace_test_case(); + std::vector high_terms; + high_terms.reserve(2 * left.size() * right.size()); + std::vector low_terms; + low_terms.reserve(left.size() * right.size()); + for (const auto l : left) { + for (const auto r : right) { + const auto high = l * r; + high_terms.push_back(high); + const auto low = std::fma(l, r, -high); + low_terms.push_back(low); + } + } + std::span result_span{results}; + const auto result_last = + sparse_mult(std::span{left}, std::span{right}, result_span); + for (const auto v : std::span{std::span{results}.begin(), result_last}) { + // zero pruning check + CHECK(v != real{0}); + } + // check that all of the expected values are in the result + for (const auto v : high_terms) { + if (v != real{0}) { + REQUIRE(std::find(result_span.begin(), result_last, v) != result_last); + } + } + for (const auto v : low_terms) { + if (v != real{0}) { + REQUIRE(std::find(result_span.begin(), result_last, v) != result_last); + } + } + for (const auto v : std::span{result_span.begin(), result_last}) { + // Ensure that v is in either in_high or in_low + const bool in_high = std::ranges::find(high_terms, v) != high_terms.end(); + const bool in_low = std::ranges::find(low_terms, v) != low_terms.end(); + REQUIRE(in_high != in_low); + } +} + +TEST_CASE("sparse_mult_merge inplace eval", "[sparse_mult_merge_inplace]") { + auto [results, left, right] = mult_test_case(); + REQUIRE(is_nonoverlapping(left)); + REQUIRE(is_nonoverlapping(right)); + std::vector expected_results = results; + sparse_mult(left, right, std::span{expected_results}); + std::span result_span{results}; + auto result_last = + sparse_mult_merge(left, right, result_span, std::allocator()); + REQUIRE(is_nonoverlapping(std::span{result_span.begin(), result_last})); + std::vector nonzero_results; + for (const auto v : std::span{result_span.begin(), result_last}) { + // zero pruning check + REQUIRE(v != real{0}); + nonzero_results.push_back(v); + } + // Check that the result is correct by subtracting values that sum to the same + // thing that a correct implementation produces + for (const auto v : expected_results) { + nonzero_results.push_back(-v); + } + result_span = std::span{nonzero_results}; + const auto merge_result = merge_sum(result_span); + REQUIRE(merge_result.first == real{0}); + REQUIRE(merge_result.second == result_span.begin()); + for (const auto v : nonzero_results) { + REQUIRE(v == real{0}); + } +}