Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
373 changes: 373 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

163 changes: 125 additions & 38 deletions src/ae_adaptive_predicate_eval.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,13 @@ template <typename E_, typename eval_type, typename allocator_type_>
class adaptive_eval_impl {
public:
using E = std::remove_cvref_t<E_>;
using allocator_type = std::remove_cvref_t<allocator_type_>;

explicit adaptive_eval_impl(allocator_type_ mem_pool)
: exact_storage{num_partials_for_exact<E>(),
std::forward<allocator_type_>(mem_pool)} {}
adaptive_eval_impl() = default;

explicit adaptive_eval_impl(allocator_type mem_pool_)
: cache(), mem_pool(std::forward<allocator_type_>(mem_pool_)),
exact_storage{num_partials_for_exact<E>(), mem_pool} {}

explicit adaptive_eval_impl(const E_ &) : adaptive_eval_impl() {}

Expand All @@ -78,17 +81,17 @@ class adaptive_eval_impl {
exact_storage.data() + begin_idx, num_partials_for_exact<subexpr_t>()};
}

// exact_eval_root computes the requested result of the (sub)expression to 1/2
// epsilon precision
// exact_eval_round computes the requested result of the (sub)expression to
// 1/2 epsilon precision
template <std::size_t branch_id>
std::pair<eval_type, eval_type> exact_eval_root(evaluatable auto &&expr) {
std::pair<eval_type, eval_type> exact_eval_round(evaluatable auto &&expr) {
using sub_expr = decltype(expr);
auto memory = get_memory<branch_id, sub_expr>();
exact_eval<branch_id>(std::forward<sub_expr>(expr), memory);
_impl::merge_sum(memory);

const eval_type exact_result = std::reduce(memory.begin(), memory.end());
cache[branch_id] = exact_result;
auto result_end =
exact_eval_merge<branch_id>(std::forward<sub_expr>(expr), memory);
const eval_type exact_result = std::reduce(memory.begin(), result_end);
cache[branch_id] =
std::pair{exact_result, std::distance(memory.begin(), result_end)};
return {exact_result, abs(exact_result) *
std::numeric_limits<eval_type>::epsilon() / 2.0};
}
Expand All @@ -108,12 +111,12 @@ class adaptive_eval_impl {
using Op = typename sub_expr::Op;

constexpr std::size_t subexpr_choice_latency =
std::max(exact_fp_rounding_latency<LHS>(),
exact_fp_rounding_latency<RHS>()) +
std::max(exact_fp_rounding_latency<eval_type, LHS>(),
exact_fp_rounding_latency<eval_type, RHS>()) +
2 * overshoot_latency() + 2 * cmp_latency() +
error_contrib_latency<Op>();

if constexpr (exact_fp_rounding_latency<sub_expr>() >
if constexpr (exact_fp_rounding_latency<eval_type, sub_expr>() >
subexpr_choice_latency &&
is_expr_v<LHS> && is_expr_v<RHS>) {
// We need to reduce error efficiently, so don't just exactly evaluate
Expand All @@ -132,7 +135,7 @@ class adaptive_eval_impl {
// guarantees we deal with the largest part of the error, making
// the fall-through case very unlikely
const auto [new_left, new_left_err] =
exact_eval_root<_impl::left_branch_id(branch_id)>(expr.lhs());
exact_eval_round<_impl::left_branch_id(branch_id)>(expr.lhs());

const auto [new_result, new_abs_err] =
_impl::eval_with_max_abs_err<Op>(new_left, new_left_err,
Expand All @@ -144,7 +147,7 @@ class adaptive_eval_impl {
}
} else {
const auto [new_right, new_right_err] =
exact_eval_root<_impl::right_branch_id<sub_expr>(branch_id)>(
exact_eval_round<_impl::right_branch_id<sub_expr>(branch_id)>(
expr.rhs());
const auto [new_result, new_abs_err] =
_impl::eval_with_max_abs_err<Op>(left_result, left_abs_err,
Expand All @@ -157,7 +160,7 @@ class adaptive_eval_impl {
}
}
}
return exact_eval_root<branch_id>(std::forward<sub_expr_>(expr));
return exact_eval_round<branch_id>(std::forward<sub_expr_>(expr));
}

// Returns the result and maximum absolute error from computing the expression
Expand All @@ -167,8 +170,8 @@ class adaptive_eval_impl {
if constexpr (is_expr_v<sub_expr>) {
const auto exact_eval_info = cache[branch_id];
if (exact_eval_info) {
return {*exact_eval_info,
abs(*exact_eval_info) *
return {exact_eval_info->first,
abs(exact_eval_info->first) *
std::numeric_limits<eval_type>::epsilon() / 2.0};
}
using Op = typename sub_expr::Op;
Expand Down Expand Up @@ -218,52 +221,136 @@ class adaptive_eval_impl {
}
}

// Expands the (sub)expression `e` into `partial_results` via exact_eval and
// returns an iterator one past the last element of the populated prefix.
//
// When linear_merge_lower_latency is false for this expression type, the
// populated prefix is additionally passed through _impl::merge_sum so that the
// correctly rounded value can be obtained from it with a plain reduction.
// When it is true, exact_eval is relied on to have merged already, and its end
// iterator is returned untouched.
template <std::size_t branch_id, typename sub_expr>
requires expr_type<sub_expr> || arith_number<sub_expr>
constexpr auto
exact_eval_merge(sub_expr &&e,
                 std::span<eval_type, num_partials_for_exact<sub_expr>()>
                     partial_results) noexcept
    -> decltype(partial_results.end()) {
  const auto expanded_end =
      exact_eval<branch_id>(std::forward<sub_expr>(e), partial_results);

  if constexpr (_impl::linear_merge_lower_latency<eval_type, sub_expr>()) {
    // The linear merge path applies here, so exact_eval has already produced
    // a merged expansion.
    return expanded_end;
  } else {
    // The linear merge path does not apply: merge the populated prefix now.
    const auto first = partial_results.begin();
    std::span populated{first, expanded_end};
    const auto merged_end = _impl::merge_sum(populated).second;
    // `populated` is a separately constructed span whose iterator type may
    // differ from that of `partial_results`, so carry the position over as an
    // offset from the start.
    return first + std::distance(populated.begin(), merged_end);
  }
}

template <std::size_t branch_id, typename sub_expr_>
requires expr_type<sub_expr_> || arith_number<sub_expr_>
constexpr void
constexpr auto
exact_eval(sub_expr_ &&e,
std::span<eval_type, num_partials_for_exact<sub_expr_>()>
partial_results) noexcept {
partial_results) noexcept -> decltype(partial_results.end()) {
using sub_expr = std::remove_cvref_t<sub_expr_>;
if constexpr (is_expr_v<sub_expr>) {
if constexpr (num_partials_for_exact<sub_expr>() == 0) {
return partial_results.begin();
} else if constexpr (is_expr_v<sub_expr>) {
if (cache[branch_id]) {
return;
return partial_results.begin() + cache[branch_id]->second;
}
constexpr std::size_t left_id = _impl::left_branch_id(branch_id);
constexpr std::size_t reserve_left =
num_partials_for_exact<typename sub_expr::LHS>();
const auto storage_left = partial_results.template first<reserve_left>();
exact_eval<_impl::left_branch_id(branch_id)>(e.lhs(), storage_left);

constexpr std::size_t reserve_right =
num_partials_for_exact<typename sub_expr::RHS>();
constexpr std::size_t start_left =
num_partials_for_exact<sub_expr>() - reserve_right - reserve_left;
const auto storage_left =
partial_results.template subspan<start_left, reserve_left>();

auto left_end = exact_eval<left_id>(e.lhs(), storage_left);

constexpr std::size_t start_right =
num_partials_for_exact<sub_expr>() - reserve_right;
const auto storage_right =
partial_results.template subspan<reserve_left, reserve_right>();
exact_eval<_impl::right_branch_id<sub_expr>(branch_id)>(e.rhs(),
storage_right);
partial_results.template subspan<start_right, reserve_right>();
auto right_end = exact_eval<_impl::right_branch_id<sub_expr>(branch_id)>(
e.rhs(), storage_right);

if constexpr (_impl::linear_merge_lower_latency<eval_type, sub_expr_>()) {
// Since we're at a point where we want to use linear merge sum,
// we need to ensure the lower levels of the tree have been merged
// together. If they're also at a point where linear merge sum is being
// used, then they're already merged and we don't need to do anything
if constexpr (!_impl::linear_merge_lower_latency<eval_type,
decltype(e.lhs())>()) {
const std::span nonzero_left{storage_left.begin(), left_end};
left_end = storage_left.begin() +
std::distance(nonzero_left.begin(),
_impl::merge_sum(nonzero_left).second);
}
if constexpr (!_impl::linear_merge_lower_latency<eval_type,
decltype(e.rhs())>()) {
const std::span nonzero_right{storage_right.begin(), right_end};
right_end = storage_right.begin() +
std::distance(nonzero_right.begin(),
_impl::merge_sum(nonzero_right).second);
}
}

using Op = typename sub_expr::Op;
if constexpr (std::is_same_v<std::plus<>, Op> ||
std::is_same_v<std::minus<>, Op>) {
if constexpr (std::is_same_v<std::minus<>, Op>) {
for (eval_type &v : storage_right) {
for (eval_type &v : std::span{storage_right.begin(), right_end}) {
v = -v;
}
}
if constexpr (_impl::linear_merge_lower_latency<eval_type,
sub_expr_>()) {
std::span left_span{storage_left.begin(), left_end};
std::vector<eval_type, allocator_type> left_copy(mem_pool);
left_copy.reserve(left_span.size());
for (const auto v : left_span) {
left_copy.push_back(v);
}
return _impl::merge_sum_linear(
partial_results, std::span{left_copy},
std::span{storage_right.begin(), right_end})
.second;
} else {
// We must compact so that the returned end iterator doesn't include
// garbage data
return std::copy(storage_right.begin(), right_end,
partial_results.begin() +
std::distance(storage_left.begin(), left_end));
}
} else if constexpr (std::is_same_v<std::multiplies<>, Op>) {
const auto storage_mult =
partial_results.template last<partial_results.size() -
reserve_left - reserve_right>();
_impl::sparse_mult(storage_left, storage_right, storage_mult);
if constexpr (_impl::linear_merge_lower_latency<eval_type,
sub_expr_>()) {
return _impl::sparse_mult_merge(
std::span{storage_left.begin(), left_end},
std::span{storage_right.begin(), right_end}, partial_results,
mem_pool);
} else {
return _impl::sparse_mult(storage_left, storage_right,
partial_results);
}
}
} else if constexpr (!std::is_same_v<additive_id, sub_expr>) {
// additive_id is zero, so we don't actually have memory allocated for it
} else {
*partial_results.begin() = eval_type(e);
return partial_results.begin() + 1;
}
}

std::array<std::optional<eval_type>, num_internal_nodes<E>()> cache;
// Cache the rounded value and the index of the end of the non-zero element
// subspan
std::array<std::optional<std::pair<eval_type, std::size_t>>,
num_internal_nodes<E>()>
cache;

allocator_type mem_pool;

std::vector<eval_type, std::remove_cvref_t<allocator_type_>> exact_storage;
std::vector<eval_type, allocator_type> exact_storage;
};

} // namespace _impl
Expand Down
Loading