diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index 59c93db1..d503298e 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -28,6 +28,34 @@ namespace dwave::optimization { /// A contiguous block of numbers. class NumberNode : public ArrayOutputMixin, public DecisionNode { public: + /// Allowable axis-wise bound operators. + enum BoundAxisOperator { Equal, LessEqual, GreaterEqual }; + + /// Struct for stateless axis-wise bound information. Given an `axis`, define + /// constraints on the sum of the values in each slice along `axis`. + /// Constraints can be defined for ALL slices along `axis` or PER slice along + /// `axis`. Allowable operators are defined by `BoundAxisOperator`. + struct BoundAxisInfo { + /// To reduce the # of `IntegerNode` and `BinaryNode` constructors, we + /// allow only one constructor. + BoundAxisInfo(ssize_t axis, std::vector axis_operators, + std::vector axis_bounds); + /// The bound axis + const ssize_t axis; + /// Operator for ALL axis slices (vector has length one) or operator*s* PER + /// slice (length of vector is equal to the number of slices). + const std::vector operators; + /// Bound for ALL axis slices (vector has length one) or bound*s* PER slice + /// (length of vector is equal to the number of slices). + const std::vector bounds; + + /// Obtain the bound associated with a given slice along `axis`. + double get_bound(const ssize_t slice) const; + + /// Obtain the operator associated with a given slice along `axis`. + BoundAxisOperator get_operator(const ssize_t slice) const; + }; + NumberNode() = delete; // Overloads needed by the Array ABC ************************************** @@ -68,6 +96,12 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // Initialize the state of the node randomly template void initialize_state(State& state, Generator& rng) const { + // Currently, we do not support random node Initialization with + // axis wise bounds. + if (bound_axes_info_.size() > 0) { + throw std::invalid_argument("Cannot randomly initialize_state with bound axes"); + } + std::vector values; const ssize_t size = this->size(); values.reserve(size); @@ -106,21 +140,38 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // in a given index. void clip_and_set_value(State& state, ssize_t index, double value) const; + /// Return vector of axis-wise bounds. + const std::vector& axis_wise_bounds() const; + + /// Return vector containing the bound axis sums in a given state. + const std::vector>& bound_axis_sums(State& state) const; + protected: explicit NumberNode(std::span shape, std::vector lower_bound, - std::vector upper_bound); + std::vector upper_bound, + std::optional> bound_axes = std::nullopt); - // Return truth statement: 'value is valid in a given index'. + /// Return truth statement: 'value is valid in a given index'. virtual bool is_valid(ssize_t index, double value) const = 0; - // Default value in a given index. + /// Default value in a given index. virtual double default_value(ssize_t index) const = 0; + /// Update the running bound axis sums where the value stored at `index` is + /// changed by `value_change` in a given state. + void update_bound_axis_slice_sums(State& state, const ssize_t index, + const double value_change) const; + + /// Statelss global minimum and maximum of the values stored in NumberNode. double min_; double max_; + /// Stateless index-wise upper and lower bounds. std::vector lower_bounds_; std::vector upper_bounds_; + + /// Stateless information on each bound axis. + const std::vector bound_axes_info_; }; /// A contiguous block of integer numbers. @@ -134,33 +185,45 @@ class IntegerNode : public NumberNode { // Default to a single scalar integer with default bounds IntegerNode() : IntegerNode({}) {} - // Create an integer array with the user-defined bounds. - // Defaulting to the specified default bounds. + // Create an integer array with the user-defined index- and axis-wise bounds. + // Index-wise bounds default to the specified default bounds. IntegerNode(std::span shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::initializer_list shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(ssize_t size, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::span shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(ssize_t size, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::span shape, std::optional> lower_bound, - double upper_bound); + double upper_bound, + std::optional> bound_axes = std::nullopt); IntegerNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound); - IntegerNode(ssize_t size, std::optional> lower_bound, double upper_bound); - - IntegerNode(std::span shape, double lower_bound, double upper_bound); - IntegerNode(std::initializer_list shape, double lower_bound, double upper_bound); - IntegerNode(ssize_t size, double lower_bound, double upper_bound); + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + IntegerNode(ssize_t size, std::optional> lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + + IntegerNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + IntegerNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + IntegerNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); // Overloads needed by the Node ABC *************************************** @@ -190,33 +253,45 @@ class BinaryNode : public IntegerNode { /// A binary scalar variable with lower_bound = 0.0 and upper_bound = 1.0 BinaryNode() : BinaryNode({}) {} - // Create a binary array with the user-defined bounds. - // Defaulting to lower_bound = 0.0 and upper_bound = 1.0 + // Create a binary array with the user-defined index- and axis-wise bounds. + // Index-wise bounds default to lower_bound = 0.0 and upper_bound = 1.0. BinaryNode(std::span shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::initializer_list shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(ssize_t size, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::span shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(ssize_t size, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::span shape, std::optional> lower_bound, - double upper_bound); + double upper_bound, + std::optional> bound_axes = std::nullopt); BinaryNode(std::initializer_list shape, std::optional> lower_bound, - double upper_bound); - BinaryNode(ssize_t size, std::optional> lower_bound, double upper_bound); - - BinaryNode(std::span shape, double lower_bound, double upper_bound); - BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound); - BinaryNode(ssize_t size, double lower_bound, double upper_bound); + double upper_bound, + std::optional> bound_axes = std::nullopt); + BinaryNode(ssize_t size, std::optional> lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + + BinaryNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + BinaryNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); // Flip the value (0 -> 1 or 1 -> 0) at index i in the given state. void flip(State& state, ssize_t i) const; diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index 5ad26c99..d7525b3f 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -15,79 +15,360 @@ #include "dwave-optimization/nodes/numbers.hpp" #include +#include +#include #include +#include #include #include +#include #include "_state.hpp" +#include "dwave-optimization/array.hpp" +#include "dwave-optimization/common.hpp" namespace dwave::optimization { -// Base class to be used as interfaces. +NumberNode::BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, + std::vector axis_operators, + std::vector axis_bounds) + : axis(bound_axis), operators(std::move(axis_operators)), bounds(std::move(axis_bounds)) { + const size_t num_operators = operators.size(); + const size_t num_bounds = bounds.size(); + + if ((num_operators == 0) || (num_bounds == 0)) { + throw std::invalid_argument("Axis-wise `operators` and `bounds` must have non-zero size."); + } + + // If `operators` and `bounds` are both defined PER hyperslice along + // `axis`, they must have the same size. + if ((num_operators > 1) && (num_bounds > 1) && (num_bounds != num_operators)) { + throw std::invalid_argument( + "Axis-wise `operators` and `bounds` should have same size if neither has size 1."); + } +} + +double NumberNode::BoundAxisInfo::get_bound(const ssize_t slice) const { + assert(0 <= slice); + if (bounds.size() == 0) return bounds[0]; + assert(slice < static_cast(bounds.size())); + return bounds[slice]; +} + +NumberNode::BoundAxisOperator NumberNode::BoundAxisInfo::get_operator(const ssize_t slice) const { + assert(0 <= slice); + if (operators.size() == 0) return operators[0]; + assert(slice < static_cast(operators.size())); + return operators[slice]; +} + +/// State dependant data attached to NumberNode +struct NumberNodeStateData : public ArrayNodeStateData { + NumberNodeStateData(std::vector input) : ArrayNodeStateData(std::move(input)) {} + NumberNodeStateData(std::vector input, std::vector> bound_axes_sums) + : ArrayNodeStateData(std::move(input)), + bound_axes_sums(std::move(bound_axes_sums)), + prior_bound_axes_sums(this->bound_axes_sums) {} + /// For each bound axis and for each hyperslice along said axis, we + /// track the sum of the values within the hyperslice. + /// bound_axes_sums[i][j] = "sum of the values within the jth + /// hyperslice along the ith bound axis" + std::vector> bound_axes_sums; + // Store a copy for NumberNode::revert() and commit() + std::vector> prior_bound_axes_sums; +}; double const* NumberNode::buff(const State& state) const noexcept { - return data_ptr(state)->buff(); + return data_ptr(state)->buff(); } std::span NumberNode::diff(const State& state) const noexcept { - return data_ptr(state)->diff(); + return data_ptr(state)->diff(); } double NumberNode::min() const { return min_; } double NumberNode::max() const { return max_; } +/// Given a NumberNode and an assingnment of it's variables (number_data), +/// compute and return a vector containing the sum of the values within each +/// hyperslice along each bound axis. +std::vector> get_bound_axes_sums(const NumberNode* node, + const std::vector& number_data) { + std::span node_shape = node->shape(); + const auto& bound_axes_info = node->axis_wise_bounds(); + const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); + assert(num_bound_axes <= static_cast(node_shape.size())); + assert(std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) == + static_cast(number_data.size())); + + // For each bound axis, initialize the sum of the values contained in each + // of it's hyperslice to 0. Define bound_axes_sums[i][j] = "sum of the + // values within the jth hyperslice along the ith bound axis" + std::vector> bound_axes_sums; + bound_axes_sums.reserve(num_bound_axes); + for (const NumberNode::BoundAxisInfo& axis_info : bound_axes_info) { + assert(0 <= axis_info.axis && axis_info.axis < static_cast(node_shape.size())); + bound_axes_sums.emplace_back(node_shape[axis_info.axis], 0.0); + } + + // Define a BufferIterator for `number_data` given the shape and strides of + // NumberNode and iterate over it. + for (BufferIterator it(number_data.data(), node_shape, node->strides()); + it != std::default_sentinel; ++it) { + // Increment the appropriate hyperslice along each bound axis. + for (ssize_t bound_axis = 0; bound_axis < num_bound_axes; ++bound_axis) { + const ssize_t axis = bound_axes_info[bound_axis].axis; + assert(0 <= axis && axis < static_cast(it.location().size())); + const ssize_t slice = it.location()[axis]; + assert(0 <= slice && slice < static_cast(bound_axes_sums[bound_axis].size())); + bound_axes_sums[bound_axis][slice] += *it; + } + } + + return bound_axes_sums; +} + +/// Determine whether the sum of the values within each hyperslice along +/// each bound axis satisfies the axis-wise bounds. +bool satisfies_axis_wise_bounds(const std::vector& bound_axes_info, + const std::vector>& bound_axes_sums) { + assert(bound_axes_info.size() == bound_axes_sums.size()); + // Check that each hyperslice satisfies the axis-wise bounds. + for (ssize_t i = 0, stop_i = static_cast(bound_axes_info.size()); i < stop_i; ++i) { + const auto& bound_axis_info = bound_axes_info[i]; + const auto& bound_axis_sums = bound_axes_sums[i]; + + for (ssize_t slice = 0, stop_slice = static_cast(bound_axis_sums.size()); + slice < stop_slice; ++slice) { + switch (bound_axis_info.get_operator(slice)) { + case NumberNode::Equal: + if (bound_axis_sums[slice] != bound_axis_info.get_bound(slice)) return false; + break; + case NumberNode::LessEqual: + if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) return false; + break; + case NumberNode::GreaterEqual: + if (bound_axis_sums[slice] < bound_axis_info.get_bound(slice)) return false; + break; + default: + unreachable(); + } + } + } + return true; +} + void NumberNode::initialize_state(State& state, std::vector&& number_data) const { if (number_data.size() != static_cast(this->size())) { throw std::invalid_argument("Size of data provided does not match node size"); } + for (ssize_t index = 0, stop = this->size(); index < stop; ++index) { if (!is_valid(index, number_data[index])) { throw std::invalid_argument("Invalid data provided for node"); } } - emplace_data_ptr(state, std::move(number_data)); + if (bound_axes_info_.size() == 0) { // No bound axes to consider. + emplace_data_ptr(state, std::move(number_data)); + return; + } + + // Given the assingnment to NumberNode, `number_data`, get the sum of the + // values within each hyperslice along each bound axis. + std::vector> bound_axes_sums = get_bound_axes_sums(this, number_data); + + if (!satisfies_axis_wise_bounds(bound_axes_info_, bound_axes_sums)) { + throw std::invalid_argument("Initialized values do not satisfy axis-wise bounds."); + } + + emplace_data_ptr(state, std::move(number_data), + std::move(bound_axes_sums)); +} + +/// Given a `span` (typically containing strides or shape), we reorder the +/// values of the span such that the given `axis` is moved to the 0th index. +std::vector reorder_to_move_along_axis(const std::span span, + const ssize_t axis) { + const ssize_t ndim = span.size(); + std::vector output; + output.reserve(ndim); + output.emplace_back(span[axis]); + + for (ssize_t i = 0; i < ndim; ++i) { + if (i != axis) output.emplace_back(span[i]); + } + return output; +} + +/// Given a `slice` along a bound axis in a NumberNode where the sum of it's +/// values are given by `sum`, determine the non-negative amount `delta` +/// needed to be added to `sum` to satisfy the expression: (sum+delta) op bound +/// e.g. Given (sum, op, bound) := (10, ==, 12), delta = 2 +/// e.g. Given (sum, op, bound) := (10, <=, 12), delta = 0 +/// e.g. Given (sum, op, bound) := (10, >=, 12), delta = 2 +/// Throws an error if `delta` is negative (corresponding with an infeasible axis-wise bound); +double compute_bound_axis_slice_delta(const ssize_t slice, const double sum, + const NumberNode::BoundAxisOperator op, const double bound) { + switch (op) { + case NumberNode::Equal: + if (sum > bound) throw std::invalid_argument("Infeasible axis-wise bounds."); + // If error was not thrown, return amount needed to satisfy bound. + return bound - sum; + case NumberNode::LessEqual: + if (sum > bound) throw std::invalid_argument("Infeasible axis-wise bounds."); + // If error was not thrown, sum satisfies bound. + return 0.0; + case NumberNode::GreaterEqual: + // If sum is less than bound, return the amount needed to equal it. + if (sum < bound) return bound - sum; + // Otherwise, sum satisfies bound. + return 0.0; + default: + unreachable(); + } +} + +/// Given a NumberNod and exactly one axis-wise bound defined for NumberNode, +/// assign values to `values` (in-place) to satisfy the axis-wise bound. This method +/// 1) Initially sets `values[i] = lower_bound(i)` for all i. +/// 2) Incremements the values within each hyperslice until they satisfy +/// the axis-wise bound (should this be possible). +void construct_state_given_exactly_one_bound_axis(const NumberNode* node, + std::vector& values) { + const std::span node_shape = node->shape(); + const ssize_t ndim = node_shape.size(); + + // 1) Initialize all elements to their lower bounds. + for (ssize_t i = 0, stop = node->size(); i < stop; ++i) { + values.push_back(node->lower_bound(i)); + } + // 2) Determine the hyperslice sums for the bound axis. This could be + // done during the previous loop if we want to improve performance. + assert(node->axis_wise_bounds().size() == 1); + const std::vector bound_axis_sums = get_bound_axes_sums(node, values)[0]; + // Obtain the axis-wise bound + const NumberNode::BoundAxisInfo& bound_axis_info = node->axis_wise_bounds()[0]; + const ssize_t bound_axis = bound_axis_info.axis; + assert(0 <= bound_axis && bound_axis < ndim); + + // We need a way to iterate over each hyperslice along the bound axis and + // adjust it`s values until they satisfy the axis-wise bounds. We do this + // by defining an iterator of `values` that can be used to iterate over the + // values within each hyperslice along the bound axis one after another. We + // can do this by modifying the NumberNode shape and strides such that the + // data for the bound_axis is moved to position 0 (remaining indices are + // shifted back). + std::vector new_shape = reorder_to_move_along_axis(node_shape, bound_axis); + std::vector new_strides = reorder_to_move_along_axis(node->strides(), bound_axis); + // Define an iterator for `values` corresponding to the beginning of slice + // 0 along the bound axis. This iterater will be used to define the start + // of a slice iterater. + BufferIterator slice_0_it(values.data(), ndim, new_shape.data(), + new_strides.data()); + // Determine the size of each slice along the bound axis. + const ssize_t slice_size = std::accumulate(new_shape.begin() + 1, new_shape.end(), 1.0, + std::multiplies()); + + // 3) Iterate over each hyperslice and adjust it's values until they + // satisfy the axis-wise bounds. + for (ssize_t slice = 0, stop = node_shape[bound_axis]; slice < stop; ++slice) { + // Determine the amount we need to adjust the initialized values within + // the slice. + double delta = compute_bound_axis_slice_delta(slice, bound_axis_sums[slice], + bound_axis_info.get_operator(slice), + bound_axis_info.get_bound(slice)); + if (delta == 0) continue; // Axis-wise bounds are satisfied for slice. + assert(delta >= 0); // Should only increment. + + // Determine how much we need to offset slice_0_it to get to the first + // value in the given `slice` + const ssize_t offset = slice * slice_size; + + for (auto slice_it = slice_0_it + offset, slice_end_it = slice_0_it + offset + slice_size; + slice_it != slice_end_it; ++slice_it) { + assert(slice_it.location()[0] == slice); // We should be in the right slice. + + // Determine the index of `it` from `slice_0_it` + const ssize_t index = static_cast(slice_it - slice_0_it); + assert(0 <= index && index < static_cast(values.size())); + // Determine the amount we can increment the value in the given index. + ssize_t inc = std::min(delta, node->upper_bound(index) - *slice_it); + + if (inc > 0) { // Apply the increment to both `it` and `delta`. + *slice_it += inc; + delta -= inc; + if (delta == 0) break; // Axis-wise bounds are now satisfied for slice. + } + } + + if (delta != 0) throw std::invalid_argument("Infeasible axis-wise bounds."); + } } void NumberNode::initialize_state(State& state) const { std::vector values; values.reserve(this->size()); - for (ssize_t i = 0, stop = this->size(); i < stop; ++i) { - values.push_back(default_value(i)); + + if (bound_axes_info_.size() == 0) { // No bound axes to consider + for (ssize_t i = 0, stop = this->size(); i < stop; ++i) { + values.push_back(default_value(i)); + } + initialize_state(state, std::move(values)); + return; + } else if (bound_axes_info_.size() == 1) { + construct_state_given_exactly_one_bound_axis(this, values); + initialize_state(state, std::move(values)); + return; } - initialize_state(state, std::move(values)); + + throw std::invalid_argument("Cannot initialize state with multiple bound axes."); } void NumberNode::commit(State& state) const noexcept { - data_ptr(state)->commit(); + auto node_data = data_ptr(state); + // Manually store a copy of bound_axes_sums. + node_data->prior_bound_axes_sums = node_data->bound_axes_sums; + node_data->commit(); } void NumberNode::revert(State& state) const noexcept { - data_ptr(state)->revert(); + auto node_data = data_ptr(state); + // Manually reset bound_axes_sums. + node_data->bound_axes_sums = node_data->prior_bound_axes_sums; + node_data->revert(); } void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); // We expect the exchange to obey the index-wise bounds. assert(lower_bound(i) <= ptr->get(j)); assert(upper_bound(i) >= ptr->get(j)); assert(lower_bound(j) <= ptr->get(i)); assert(upper_bound(j) >= ptr->get(i)); - // Assert that i and j are valid indices occurs in ptr->exchange(). - // Exchange occurs IFF (i != j) and (buffer[i] != buffer[j]). - ptr->exchange(i, j); + // assert() that i and j are valid indices occurs in ptr->exchange(). + // State change occurs IFF (i != j) and (buffer[i] != buffer[j]). + if (ptr->exchange(i, j)) { + // If the values at indices i and j were exchanged, update the bound + // axis sums. + const double difference = ptr->get(i) - ptr->get(j); + // Index i changed from (what is now) ptr->get(j) to ptr->get(i) + update_bound_axis_slice_sums(state, i, difference); + // Index j changed from (what is now) ptr->get(i) to ptr->get(j) + update_bound_axis_slice_sums(state, j, -difference); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } double NumberNode::get_value(State& state, ssize_t i) const { - return data_ptr(state)->get(i); + return data_ptr(state)->get(i); } double NumberNode::lower_bound(ssize_t index) const { if (lower_bounds_.size() == 1) { return lower_bounds_[0]; } - assert(lower_bounds_.size() > 1); assert(0 <= index && index < static_cast(lower_bounds_.size())); return lower_bounds_[index]; } @@ -104,7 +385,6 @@ double NumberNode::upper_bound(ssize_t index) const { if (upper_bounds_.size() == 1) { return upper_bounds_[0]; } - assert(upper_bounds_.size() > 1); assert(0 <= index && index < static_cast(upper_bounds_.size())); return upper_bounds_[index]; } @@ -118,10 +398,22 @@ double NumberNode::upper_bound() const { } void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) const { + auto ptr = data_ptr(state); value = std::clamp(value, lower_bound(index), upper_bound(index)); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(index, value); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[index] . + if (ptr->set(index, value)) { + update_bound_axis_slice_sums(state, index, value - diff(state).back().old); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } +} + +const std::vector& NumberNode::axis_wise_bounds() const { + return bound_axes_info_; +} + +const std::vector>& NumberNode::bound_axis_sums(State& state) const { + return data_ptr(state)->bound_axes_sums; } template @@ -164,13 +456,62 @@ void check_index_wise_bounds(const NumberNode& node, const std::vector& } } +/// Check the user defined axis-wise bounds for NumberNode +void check_axis_wise_bounds(const std::vector& bound_axes_info, + const std::span shape) { + if (bound_axes_info.size() == 0) return; // No bound axes to check. + + // Used to asses if an axis have been bound multiple times. + std::vector axis_bound(shape.size(), false); + + // For each set of bound axis data + for (const NumberNode::BoundAxisInfo& bound_axis_info : bound_axes_info) { + const ssize_t axis = bound_axis_info.axis; + + if (axis < 0 || axis >= static_cast(shape.size())) { + throw std::invalid_argument("Invalid bound axis given number array shape."); + } + + // The number of operators defined for the given bound axis + const ssize_t num_operators = static_cast(bound_axis_info.operators.size()); + if ((num_operators > 1) && (num_operators != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise operators given number array shape."); + } + + // The number of operators defined for the given bound axis + const ssize_t num_bounds = static_cast(bound_axis_info.bounds.size()); + if ((num_bounds > 1) && (num_bounds != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise bounds given number array shape."); + } + + // Checked in BoundAxisInfo constructor + assert(num_operators == num_bounds || num_operators == 1 || num_bounds == 1); + + if (axis_bound[axis]) { + throw std::invalid_argument( + "Cannot define multiple axis-wise bounds for a single axis."); + } + axis_bound[axis] = true; + } + + // *Currently*, we only support axis-wise bounds for up to one axis. + if (bound_axes_info.size() > 1) { + throw std::invalid_argument("Axis-wise bounds are supported for at most one axis."); + } +} + +// Base class to be used as interfaces. NumberNode::NumberNode(std::span shape, std::vector lower_bound, - std::vector upper_bound) + std::vector upper_bound, + std::optional> bound_axes) : ArrayOutputMixin(shape), min_(get_extreme_index_wise_bound(lower_bound)), max_(get_extreme_index_wise_bound(upper_bound)), lower_bounds_(std::move(lower_bound)), - upper_bounds_(std::move(upper_bound)) { + upper_bounds_(std::move(upper_bound)), + bound_axes_info_(bound_axes ? std::move(*bound_axes) : std::vector{}) { if ((shape.size() > 0) && (shape[0] < 0)) { throw std::invalid_argument("Number array cannot have dynamic size."); } @@ -180,59 +521,123 @@ NumberNode::NumberNode(std::span shape, std::vector lower } check_index_wise_bounds(*this, lower_bounds_, upper_bounds_); + check_axis_wise_bounds(bound_axes_info_, this->shape()); +} + +void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index, + const double value_change) const { + const auto& bound_axes_info = bound_axes_info_; + if (bound_axes_info.size() == 0) return; // No axis-wise bounds to satisfy + + // Get multidimensional indices for `index` so we can identify the slices + // `index` lies on per bound axis. + const std::vector multi_index = unravel_index(index, this->shape()); + assert(bound_axes_info.size() <= multi_index.size()); + // Get the hyperslice sums of all bound axes. + auto& bound_axes_sums = data_ptr(state)->bound_axes_sums; + assert(bound_axes_info.size() == bound_axes_sums.size()); + + for (ssize_t bound_axis = 0, stop = static_cast(bound_axes_info.size()); + bound_axis < stop; ++bound_axis) { + assert(0 <= bound_axes_info[bound_axis].axis); + assert(bound_axes_info[bound_axis].axis < static_cast(multi_index.size())); + // Get the slice along the bound axis the `value_change` occurs in + const ssize_t slice = multi_index[bound_axes_info[bound_axis].axis]; + assert(0 <= slice && slice < static_cast(bound_axes_sums[bound_axis].size())); + // Offset running sum in slice + bound_axes_sums[bound_axis][slice] += value_change; + } } // Integer Node *************************************************************** +/// Check the user defined axis-wise bounds for IntegerNode +void check_integrality_of_axis_wise_bounds( + const std::vector& bound_axes_info) { + if (bound_axes_info.size() == 0) return; // No bound axes to check. + + for (const NumberNode::BoundAxisInfo& bound_axis_info : bound_axes_info) { + for (const double& bound : bound_axis_info.bounds) { + if (bound != std::round(bound)) { + throw std::invalid_argument( + "Axis wise bounds for integral number arrays must be intregral."); + } + } + } +} + IntegerNode::IntegerNode(std::span shape, std::optional> lower_bound, - std::optional> upper_bound) + std::optional> upper_bound, + std::optional> bound_axes) : NumberNode(shape, lower_bound.has_value() ? std::move(*lower_bound) : std::vector{default_lower_bound}, upper_bound.has_value() ? std::move(*upper_bound) - : std::vector{default_upper_bound}) { + : std::vector{default_upper_bound}, + std::move(bound_axes)) { if (min_ < minimum_lower_bound || max_ > maximum_upper_bound) { throw std::invalid_argument("range provided for integers exceeds supported range"); } + + check_integrality_of_axis_wise_bounds(bound_axes_info_); } IntegerNode::IntegerNode(std::initializer_list shape, std::optional> lower_bound, - std::optional> upper_bound) - : IntegerNode(std::span(shape), std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode(std::span(shape), std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, std::optional> lower_bound, - std::optional> upper_bound) - : IntegerNode({size}, std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode({size}, std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::span shape, double lower_bound, - std::optional> upper_bound) - : IntegerNode(shape, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode(shape, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound) - : IntegerNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, double lower_bound, - std::optional> upper_bound) - : IntegerNode({size}, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode({size}, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::span shape, - std::optional> lower_bound, double upper_bound) - : IntegerNode(shape, std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode(shape, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound) - : IntegerNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, std::optional> lower_bound, - double upper_bound) - : IntegerNode({size}, std::move(lower_bound), std::vector{upper_bound}) {} - -IntegerNode::IntegerNode(std::span shape, double lower_bound, double upper_bound) - : IntegerNode(shape, std::vector{lower_bound}, std::vector{upper_bound}) {} + double upper_bound, std::optional> bound_axes) + : IntegerNode({size}, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} + +IntegerNode::IntegerNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode(shape, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, double lower_bound, - double upper_bound) + double upper_bound, std::optional> bound_axes) : IntegerNode(std::span(shape), std::vector{lower_bound}, - std::vector{upper_bound}) {} -IntegerNode::IntegerNode(ssize_t size, double lower_bound, double upper_bound) - : IntegerNode({size}, std::vector{lower_bound}, std::vector{upper_bound}) {} + std::vector{upper_bound}, std::move(bound_axes)) {} +IntegerNode::IntegerNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode({size}, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} bool IntegerNode::integral() const { return true; } @@ -242,13 +647,17 @@ bool IntegerNode::is_valid(ssize_t index, double value) const { } void IntegerNode::set_value(State& state, ssize_t index, double value) const { + auto ptr = data_ptr(state); // We expect `value` to obey the index-wise bounds and to be an integer. assert(lower_bound(index) <= value); assert(upper_bound(index) >= value); assert(value == std::round(value)); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(index, value); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[index]. + if (ptr->set(index, value)) { + update_bound_axis_slice_sums(state, index, value - diff(state).back().old); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } double IntegerNode::default_value(ssize_t index) const { @@ -287,69 +696,105 @@ std::vector limit_bound_to_bool_domain(std::optional BinaryNode::BinaryNode(std::span shape, std::optional> lower_bound, - std::optional> upper_bound) + std::optional> upper_bound, + std::optional> bound_axes) : IntegerNode(shape, limit_bound_to_bool_domain(lower_bound), - limit_bound_to_bool_domain(upper_bound)) {} + limit_bound_to_bool_domain(upper_bound), bound_axes) {} BinaryNode::BinaryNode(std::initializer_list shape, std::optional> lower_bound, - std::optional> upper_bound) - : BinaryNode(std::span(shape), std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode(std::span(shape), std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, std::optional> lower_bound, - std::optional> upper_bound) - : BinaryNode({size}, std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode({size}, std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::span shape, double lower_bound, - std::optional> upper_bound) - : BinaryNode(shape, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode(shape, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound) - : BinaryNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, double lower_bound, - std::optional> upper_bound) - : BinaryNode({size}, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode({size}, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::span shape, - std::optional> lower_bound, double upper_bound) - : BinaryNode(shape, std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode(shape, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound) - : BinaryNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, std::optional> lower_bound, - double upper_bound) - : BinaryNode({size}, std::move(lower_bound), std::vector{upper_bound}) {} - -BinaryNode::BinaryNode(std::span shape, double lower_bound, double upper_bound) - : BinaryNode(shape, std::vector{lower_bound}, std::vector{upper_bound}) {} -BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound) + double upper_bound, std::optional> bound_axes) + : BinaryNode({size}, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} + +BinaryNode::BinaryNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode(shape, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} +BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::optional> bound_axes) : BinaryNode(std::span(shape), std::vector{lower_bound}, - std::vector{upper_bound}) {} -BinaryNode::BinaryNode(ssize_t size, double lower_bound, double upper_bound) - : BinaryNode({size}, std::vector{lower_bound}, std::vector{upper_bound}) {} + std::vector{upper_bound}, std::move(bound_axes)) {} +BinaryNode::BinaryNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode({size}, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} void BinaryNode::flip(State& state, ssize_t i) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); // Variable should not be fixed. assert(lower_bound(i) != upper_bound(i)); - // Assert that i is a valid index occurs in ptr->set(). - // Set occurs IFF `value` != buffer[i] . - ptr->set(i, !ptr->get(i)); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. + if (ptr->set(i, !ptr->get(i))) { + // If value changed from 0 -> 1, update the bound axis sums by 1. + // If value changed from 1 -> 0, update the bound axis sums by -1. + update_bound_axis_slice_sums(state, i, (ptr->get(i) == 1) ? 1 : -1); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } void BinaryNode::set(State& state, ssize_t i) const { + auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(upper_bound(i) == 1.0); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(i, 1.0); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. + if (ptr->set(i, 1.0)) { + // If value changed from 0 -> 1, update the bound axis sums by 1. + update_bound_axis_slice_sums(state, i, 1.0); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } void BinaryNode::unset(State& state, ssize_t i) const { + auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(lower_bound(i) == 0.0); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(i, 0.0); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. + if (ptr->set(i, 0.0)) { + // If value changed from 1 -> 0, update the bound axis sums by -1. + update_bound_axis_slice_sums(state, i, -1.0); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } } // namespace dwave::optimization diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index b62e6bdd..778d8cdf 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -18,6 +18,7 @@ #include "catch2/catch_test_macros.hpp" #include "catch2/matchers/catch_matchers.hpp" #include "catch2/matchers/catch_matchers_all.hpp" +#include "catch2/matchers/catch_matchers_range_equals.hpp" #include "dwave-optimization/graph.hpp" #include "dwave-optimization/nodes/numbers.hpp" @@ -25,6 +26,56 @@ using Catch::Matchers::RangeEquals; namespace dwave::optimization { +TEST_CASE("BoundAxisInfo") { + GIVEN("BoundAxisInfo(axis = 0, operators = {}, bounds = {1.0})") { + std::vector operators; + std::vector bounds{1.0}; + REQUIRE_THROWS_WITH(NumberNode::BoundAxisInfo(0, operators, bounds), + "Axis-wise `operators` and `bounds` must have non-zero size."); + } + + GIVEN("BoundAxisInfo(axis = 0, operators = {<=}, bounds = {})") { + std::vector operators{NumberNode::LessEqual}; + std::vector bounds; + REQUIRE_THROWS_WITH(NumberNode::BoundAxisInfo(0, operators, bounds), + "Axis-wise `operators` and `bounds` must have non-zero size."); + } + + GIVEN("BoundAxisInfo(axis = 1, operators = {<=, ==, ==}, bounds = {2.0, 1.0})") { + std::vector operators{NumberNode::LessEqual, + NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{2.0, 1.0}; + REQUIRE_THROWS_WITH( + NumberNode::BoundAxisInfo(1, operators, bounds), + "Axis-wise `operators` and `bounds` should have same size if neither has size 1."); + } + + GIVEN("BoundAxisInfo(axis = 2, operators = {==}, bounds = {1.0})") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + NumberNode::BoundAxisInfo bound_axis(2, operators, bounds); + + THEN("The bound axis info is correct") { + CHECK(bound_axis.axis == 2); + CHECK_THAT(bound_axis.operators, RangeEquals(operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bounds)); + } + } + + GIVEN("BoundAxisInfo(axis = 2, operators = {==, <=, >=}, bounds = {1.0, 2.0, 3.0})") { + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + NumberNode::BoundAxisInfo bound_axis(2, operators, bounds); + + THEN("The bound axis info is correct") { + CHECK(bound_axis.axis == 2); + CHECK_THAT(bound_axis.operators, RangeEquals(operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bounds)); + } + } +} + TEST_CASE("BinaryNode") { auto graph = Graph(); @@ -296,8 +347,8 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with index-wise bounds") { - auto bnode_ptr = graph.emplace_node( - 3, std::vector{-1, 0, 1}, std::vector{2, 1, 1}); + auto bnode_ptr = graph.emplace_node(3, std::vector{-1, 0, 1}, + std::vector{2, 1, 1}); THEN("The shape, max, min, and bounds are correct") { CHECK(bnode_ptr->size() == 3); @@ -394,8 +445,7 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with index-wise upper bound and general lower bound") { - auto bnode_ptr = graph.emplace_node( - 2, -2.0, std::vector{0.0, 1.1}); + auto bnode_ptr = graph.emplace_node(2, -2.0, std::vector{0.0, 1.1}); THEN("The max, min, and bounds are correct") { CHECK(bnode_ptr->max() == 1.0); @@ -411,8 +461,7 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with index-wise lower bound and general upper bound") { - auto bnode_ptr = graph.emplace_node( - 2, std::vector{-1.0, 1.0}, 100.0); + auto bnode_ptr = graph.emplace_node(2, std::vector{-1.0, 1.0}, 100.0); THEN("The max, min, and bounds are correct") { CHECK(bnode_ptr->max() == 1.0); @@ -428,19 +477,556 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with invalid index-wise lower bounds at index 0") { - REQUIRE_THROWS(graph.emplace_node( - 2, std::vector{2, 0}, std::vector{1, 1})); + REQUIRE_THROWS(graph.emplace_node(2, std::vector{2, 0}, + std::vector{1, 1})); } GIVEN("Binary node with invalid index-wise upper bounds at index 1") { - REQUIRE_THROWS(graph.emplace_node( - 2, std::vector{0, 0}, std::vector{1, -1})); + REQUIRE_THROWS(graph.emplace_node(2, std::vector{0, 0}, + std::vector{1, -1})); } GIVEN("Invalid dynamically sized BinaryNode") { REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{-1, 2}), "Number array cannot have dynamic size."); } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis -1") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{-1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis 2") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{2, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many operators.") { + std::vector operators{ + NumberNode::LessEqual, NumberNode::Equal, NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few operators.") { + std::vector operators{NumberNode::LessEqual, + NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many bounds.") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0, 2.0, 3.0, 4.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few bounds.") { + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{1.0, 2.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with duplicate axis-wise bounds on axis: 1") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + NumberNode::BoundAxisInfo bound_axis{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis, bound_axis}), + "Cannot define multiple axis-wise bounds for a single axis."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axes: 0 and 1") { + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{1.0}; + NumberNode::BoundAxisInfo bound_axis_0{0, operators, bounds}; + NumberNode::BoundAxisInfo bound_axis_1{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis_0, bound_axis_1}), + "Axis-wise bounds are supported for at most one axis."); + } + + GIVEN("(2x3x4)-BinaryNode with non-integral axis-wise bounds") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{0.1}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Axis wise bounds for integral number arrays must be intregral."); + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{5.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + // Each hyperslice along axis 0 has size 4. There is no feasible + // assignment to the values in slice 0 (along axis 0) that results in a + // sum equal to 5. + graph.emplace_node(std::initializer_list{3, 2, 2}, std::nullopt, + std::nullopt, bound_axes); + + WHEN("We create a state by initialize_state()") { + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); + } + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{5.0, 7.0}; + std::vector bound_axes{{1, operators, bounds}}; + graph.emplace_node(std::initializer_list{3, 2, 2}, std::nullopt, + std::nullopt, bound_axes); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 1 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 1) that results in a + // sum greater than or equal to 7. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); + } + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector operators{NumberNode::Equal, + NumberNode::LessEqual}; + std::vector bounds{5.0, -1.0}; + std::vector bound_axes{{2, operators, bounds}}; + graph.emplace_node(std::initializer_list{3, 2, 2}, std::nullopt, + std::nullopt, bound_axes); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 2 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 2) that results in a + // sum less than or equal to -1. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[0, :, :].flatten()) + // ... [0 1 2 3] + // print(a[1, :, :].flatten()) + // ... [4 5 6 7] + // print(a[2, :, :].flatten()) + // ... [ 8 9 10 11] + std::vector expected_init{1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0}; + // Cannonically least state that satisfies bounds + // slice 0 slice 1 slice 2 + // 1, 0 0, 0 1, 1 + // 0, 0 0, 0 1, 0 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 0, 3})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{NumberNode::LessEqual, + NumberNode::GreaterEqual}; + std::vector bounds{1.0, 5.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[:, 0, :].flatten()) + // ... [0 1 4 5 8 9] + // print(a[:, 1, :].flatten()) + // ... [ 2 3 6 7 10 11] + std::vector expected_init{0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0}; + // Cannonically least state that satisfies bounds + // slice 0 slice 1 + // 0, 0 1, 1 + // 0, 0 1, 1 + // 0, 0 1, 0 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({0, 5})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{3.0, 6.0}; + std::vector bound_axes{{2, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[:, :, 0].flatten()) + // ... [ 0 2 4 6 8 10] + // print(a[:, :, 1].flatten()) + // ... [ 1 3 5 7 9 11] + std::vector expected_init{1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1}; + // Cannonically least state that satisfies bounds + // slice 0 slice 1 + // 1, 1 1, 1 + // 1, 0 1, 1 + // 0, 0 1, 1 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({3, 6})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with an axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We initialize three invalid states") { + auto state = graph.empty_state(); + // This state violates the 0th hyperslice along axis 0 + std::vector init_values{1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; + // import numpy as np + // a = np.asarray([1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([2, 2, 4]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 1st hyperslice along axis 0 + init_values = {0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1}; + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([1, 3, 4]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 2nd hyperslice along axis 0 + init_values = {0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0}; + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([1, 2, 2]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + } + + WHEN("We initialize a valid state") { + auto state = graph.empty_state(); + std::vector init_values{0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; + bnode_ptr->initialize_state(state, init_values); + graph.initialize_state(state); + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + // **Python Code 1** + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + THEN("We exchange() some values") { + bnode_ptr->exchange(state, 0, 3); // Does nothing. + bnode_ptr->exchange(state, 1, 6); // Does nothing. + bnode_ptr->exchange(state, 1, 3); + std::swap(init_values[0], init_values[3]); + std::swap(init_values[1], init_values[6]); + std::swap(init_values[1], init_values[3]); + // state is now: [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(1, a.shape)] = 0 + // a[np.unravel_index(3, a.shape)] = 1 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 2); // 2 updates per exchange + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We clip_and_set_value() some values") { + bnode_ptr->clip_and_set_value(state, 5, -1); // Does nothing. + bnode_ptr->clip_and_set_value(state, 7, -1); + bnode_ptr->clip_and_set_value(state, 9, 1); // Does nothing. + bnode_ptr->clip_and_set_value(state, 11, 0); + bnode_ptr->clip_and_set_value(state, 11, 1); + bnode_ptr->clip_and_set_value(state, 10, 0); + init_values[5] = 0; + init_values[7] = 0; + init_values[9] = 1; + init_values[11] = 1; + init_values[10] = 0; + // state is now: [0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(5, a.shape)] = 0 + // a[np.unravel_index(7, a.shape)] = 0 + // a[np.unravel_index(9, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 4); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We set_value() some values") { + bnode_ptr->set_value(state, 0, 0); // Does nothing. + bnode_ptr->set_value(state, 6, 0); + bnode_ptr->set_value(state, 7, 0); + bnode_ptr->set_value(state, 4, 1); + bnode_ptr->set_value(state, 10, 1); // Does nothing. + bnode_ptr->set_value(state, 11, 0); + init_values[0] = 0; + init_values[6] = 0; + init_values[7] = 0; + init_values[4] = 1; + init_values[10] = 1; + init_values[11] = 0; + // state is now: [0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(0, a.shape)] = 0 + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(7, a.shape)] = 0 + // a[np.unravel_index(4, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 4); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We flip() some values") { + bnode_ptr->flip(state, 6); // 1 -> 0 + bnode_ptr->flip(state, 4); // 0 -> 1 + bnode_ptr->flip(state, 11); // 1 -> 0 + init_values[6] = !init_values[6]; + init_values[4] = !init_values[4]; + init_values[11] = !init_values[11]; + // state is now: [0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(4, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 3})); + CHECK(bnode_ptr->diff(state).size() == 3); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We unset() some values") { + bnode_ptr->unset(state, 0); // Does nothing. + bnode_ptr->unset(state, 6); + bnode_ptr->unset(state, 11); + init_values[0] = 0; + init_values[6] = 0; + init_values[11] = 0; + // state is now: [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(0, a.shape)] = 0 + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 2); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We commit and set() some values") { + graph.commit(state); + + bnode_ptr->set(state, 10); // Does nothing. + bnode_ptr->set(state, 11); + init_values[10] = 1; + init_values[11] = 1; + // state is now: [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + THEN("The bound axis sums updated correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 4})); + CHECK(bnode_ptr->diff(state).size() == 1); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], + RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + } + } + } } TEST_CASE("IntegerNode") { @@ -461,7 +1047,8 @@ TEST_CASE("IntegerNode") { } } - GIVEN("Double precision numbers, which may fall outside integer range or are not integral") { + GIVEN("Double precision numbers, which may fall outside integer range or are not " + "integral") { IntegerNode inode({1}); THEN("The state is not deterministic") { CHECK(!inode.deterministic_state()); } @@ -534,8 +1121,8 @@ TEST_CASE("IntegerNode") { } GIVEN("Integer node with index-wise bounds") { - auto inode_ptr = graph.emplace_node( - 3, std::vector{-1, 3, 5}, std::vector{1, 7, 7}); + auto inode_ptr = graph.emplace_node(3, std::vector{-1, 3, 5}, + std::vector{1, 7, 7}); THEN("The shape, max, min, and bounds are correct") { CHECK(inode_ptr->size() == 3); @@ -596,8 +1183,7 @@ TEST_CASE("IntegerNode") { } GIVEN("Integer node with index-wise upper bound and general integer lower bound") { - auto inode_ptr = graph.emplace_node( - 2, 10, std::vector{20, 10}); + auto inode_ptr = graph.emplace_node(2, 10, std::vector{20, 10}); THEN("The max, min, and bounds are correct") { CHECK(inode_ptr->max() == 20.0); @@ -613,8 +1199,8 @@ TEST_CASE("IntegerNode") { } GIVEN("Integer node with invalid index-wise bounds at index 0") { - REQUIRE_THROWS(graph.emplace_node( - 2, std::vector{19, 12}, std::vector{20, 11})); + REQUIRE_THROWS(graph.emplace_node(2, std::vector{19, 12}, + std::vector{20, 11})); } GIVEN("An Integer Node representing an 1d array of 10 elements with lower bound -10") { @@ -739,6 +1325,462 @@ TEST_CASE("IntegerNode") { REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{-1, 3}), "Number array cannot have dynamic size."); } + + GIVEN("(2x3)-IntegerNode with axis-wise bounds on the invalid axis -2") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{20.0}; + std::vector bound_axes{{-2, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on the invalid axis 3") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{10.0}; + std::vector bound_axes{{3, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many operators.") { + std::vector operators{ + NumberNode::LessEqual, NumberNode::Equal, NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{-10.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few operators.") { + std::vector operators{NumberNode::LessEqual, + NumberNode::Equal}; + std::vector bounds{-11.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many bounds.") { + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{-10.0, 20.0, 30.0, 40.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few bounds.") { + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{111.0, -223.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with duplicate axis-wise bounds on axis: 1") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{100.0}; + NumberNode::BoundAxisInfo bound_axis{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis, bound_axis}), + "Cannot define multiple axis-wise bounds for a single axis."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axes: 0 and 1") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{100.0}; + NumberNode::BoundAxisInfo bound_axis_0{0, operators, bounds}; + NumberNode::BoundAxisInfo bound_axis_1{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis_0, bound_axis_1}), + "Axis-wise bounds are supported for at most one axis."); + } + + GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{11.0, 12.0001, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Axis wise bounds for integral number arrays must be intregral."); + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{NumberNode::Equal, + NumberNode::LessEqual}; + std::vector bounds{5.0, -31.0}; + std::vector bound_axes{{0, operators, bounds}}; + graph.emplace_node(std::initializer_list{2, 3, 2}, -5, 8, bound_axes); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 0 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 0) that results in a + // sum less than or equal to -5*6-1 = -31. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); + } + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{NumberNode::GreaterEqual, + NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{33.0, 0.0, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + graph.emplace_node(std::initializer_list{2, 3, 2}, -5, 8, bound_axes); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 1 has size 4. There is no feasible + // assignment to the values in slice 0 (along axis 1) that results in a + // sum greater than or equal to 4*8+1 = 33. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); + } + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector operators{NumberNode::GreaterEqual, + NumberNode::Equal}; + std::vector bounds{-1.0, 49.0}; + std::vector bound_axes{{2, operators, bounds}}; + graph.emplace_node(std::initializer_list{2, 3, 2}, -5, 8, bound_axes); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 2 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 2) that results in a + // sum or equal to 6*8+1 = 49 + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{-21.0, 9.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[0, :, :].flatten()) + // ... [0 1 2 3 4 5] + // print(a[1, :, :].flatten()) + // ... [ 6 7 8 9 10 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 + // [4, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 1 + // [4, -5, -5, -5, -5, -5, 8, 8, 8, -5, -5, -5] + std::vector expected_init{4, -5, -5, -5, -5, -5, 8, 8, 8, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({-21.0, 9.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{ + NumberNode::Equal, NumberNode::GreaterEqual, NumberNode::LessEqual}; + std::vector bounds{0.0, -2.0, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[:, 0, :].flatten()) + // ... [0 1 6 7] + // print(a[:, 1, :].flatten()) + // ... [2 3 8 9] + // print(a[:, 2, :].flatten()) + // ... [ 4 5 10 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 w/ [8, 2, -5, -5] + // [8, 2, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 1 w/ [8, 0, -5, -5] + // [8, 2, 8, 0, -5, -5, -5, -5, -5, -5, -5, -5] + // no need to repair slice 2 + std::vector expected_init{8, 2, 8, 0, -5, -5, -5, -5, -5, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({0.0, -2.0, -20.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{23.0, 14.0}; + std::vector bound_axes{{2, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[:, :, 0].flatten()) + // ... [ 0 2 4 6 8 10] + // print(a[:, :, 0].flatten()) + // ... [ 1 3 5 7 9 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 w/ [8, 8, 8, 8, -4, -5] + // [8, -5, 8, -5, 8, -5, 8, -5, -4, -5, -5, -5] + // repair slice 0 w/ [8, 8, 8, 0, -5, -5] + // [8, 8, 8, 8, 8, 8, 8, 0, -4, -5, -5, -5] + std::vector expected_init{8, 8, 8, 8, 8, 8, 8, 0, -4, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({23.0, 14.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with index-wise bounds and an axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{11.0, 2.0, 5.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto inode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(inode_ptr->axis_wise_bounds().size() == 1); + const NumberNode::BoundAxisInfo inode_bound_axis_ptr = + inode_ptr->axis_wise_bounds().data()[0]; + CHECK(bound_axes[0].axis == inode_bound_axis_ptr.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(inode_bound_axis_ptr.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(inode_bound_axis_ptr.bounds)); + } + + WHEN("We initialize three invalid states") { + auto state = graph.empty_state(); + // This state violates the 0th hyperslice along axis 1 + std::vector init_values{5, 6, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3}; + // import numpy as np + // a = np.asarray([5, 6, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([15, 2, 7]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 1st hyperslice along axis 1 + init_values = {5, 2, 0, 0, 3, 1, 4, 0, 2, 1, 0, 3}; + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 2, 1, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 3, 7]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 2nd hyperslice along axis 1 + init_values = {5, 2, 0, 0, 3, 1, 4, 0, 1, 0, 0, 0}; + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 1, 0, 0, 0]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 1, 4]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + } + + WHEN("We initialize a valid state") { + auto state = graph.empty_state(); + std::vector init_values{5, 2, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3}; + inode_ptr->initialize_state(state, init_values); + graph.initialize_state(state); + + auto bound_axis_sums = inode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + // **Python Code 2** + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 2, 7]) + CHECK(inode_ptr->bound_axis_sums(state).size() == 1); + CHECK(inode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + THEN("We exchange() some values") { + inode_ptr->exchange(state, 2, 3); // Does nothing. + inode_ptr->exchange(state, 1, 8); // Does nothing. + inode_ptr->exchange(state, 8, 10); + inode_ptr->exchange(state, 0, 1); + std::swap(init_values[2], init_values[3]); + std::swap(init_values[1], init_values[8]); + std::swap(init_values[8], init_values[10]); + std::swap(init_values[0], init_values[1]); + // state is now: [2, 5, 0, 0, 3, 1, 4, 0, 0, 0, 2, 3] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(8, a.shape)] = 0 + // a[np.unravel_index(10, a.shape)] = 2 + // a[np.unravel_index(0, a.shape)] = 2 + // a[np.unravel_index(1, a.shape)] = 5 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 0, 9})); + CHECK(inode_ptr->diff(state).size() == 4); // 2 updates per exchange + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We clip_and_set_value() some values") { + inode_ptr->clip_and_set_value(state, 0, 5); // Does nothing. + inode_ptr->clip_and_set_value(state, 8, -300); + inode_ptr->clip_and_set_value(state, 10, 100); + init_values[8] = -5; + init_values[10] = 8; + // state is now: [5, 2, 0, 0, 3, 1, 4, 0, -5, 0, 8, 3] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(8, a.shape)] = -5 + // a[np.unravel_index(10, a.shape)] = 8 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, -5, 15})); + CHECK(inode_ptr->diff(state).size() == 2); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We set_value() some values") { + inode_ptr->set_value(state, 0, 5); // Does nothing. + inode_ptr->set_value(state, 8, 0); + inode_ptr->set_value(state, 9, 1); + inode_ptr->set_value(state, 10, 5); + inode_ptr->set_value(state, 11, 0); + init_values[0] = 5; + init_values[8] = 0; + init_values[9] = 1; + init_values[10] = 5; + init_values[11] = 0; + // state is now: [5, 2, 0, 0, 3, 1, 4, 0, 0, 1, 5, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(0, a.shape)] = 5 + // a[np.unravel_index(8, a.shape)] = 0 + // a[np.unravel_index(9, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 5 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 1, 9})); + CHECK(inode_ptr->diff(state).size() == 4); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bound_axis_sums[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + } + } } } // namespace dwave::optimization