From 82e10a80ea9af9f1f26514e1d941672fc2a51c2c Mon Sep 17 00:00:00 2001 From: fastbodin Date: Tue, 6 Jan 2026 09:31:39 -0800 Subject: [PATCH 1/7] Add stateless axis-wise bound info to NumberNode Data is stored at C++ level with the class `AxisBoundInfo` as private attribute to `NumberNode`. Added relevant C++ tests. --- .../dwave-optimization/nodes/numbers.hpp | 127 ++++-- dwave/optimization/src/nodes/numbers.cpp | 366 +++++++++++++----- tests/cpp/nodes/test_numbers.cpp | 144 +++++++ 3 files changed, 501 insertions(+), 136 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index 59c93db1..9760b27f 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -25,6 +25,36 @@ namespace dwave::optimization { +/// Allowable axis-wise bound operators. +enum BoundAxisOperator { Equal, LessEqual, GreaterEqual }; + +/// Class for stateless axis-wise bound information. Given an `axis`, define +/// constraints on the sum of the values in each slice along `axis`. +/// Constraints can be defined for ALL slices along `axis` or PER slice along +/// `axis`. Allowable operators are defined by `BoundAxisOperator`. +class BoundAxisInfo { + public: + /// To reduce the # of `IntegerNode` and `BinaryNode` constructors, we + /// allow only one constructor. + BoundAxisInfo(ssize_t axis, std::vector axis_operators, + std::vector axis_bounds); + /// The bound axis + const ssize_t axis; + /// Operator for ALL axis slices (vector has length one) or operator*s* PER + /// slice (length of vector is equal to the number of slices). + const std::vector operators; + /// Bound for ALL axis slices (vector has length one) or bound*s* PER slice + /// (length of vector is equal to the number of slices). + const std::vector bounds; + + private: + /// Obtain the bound associated with a given slice along bound axis. + double get_bound(const ssize_t slice) const; + + /// Obtain the operator associated with a given slice along bound axis. + BoundAxisOperator get_operator(const ssize_t slice) const; +}; + /// A contiguous block of numbers. class NumberNode : public ArrayOutputMixin, public DecisionNode { public: @@ -106,9 +136,16 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // in a given index. void clip_and_set_value(State& state, ssize_t index, double value) const; + /// The number of axes with axis-wise bounds. + ssize_t num_bound_axes() const; + + /// Return the bound information for the ith bound axis + const BoundAxisInfo* get_ith_bound_axis_info(const ssize_t i) const; + protected: explicit NumberNode(std::span shape, std::vector lower_bound, - std::vector upper_bound); + std::vector upper_bound, + std::optional> bound_axes = std::nullopt); // Return truth statement: 'value is valid in a given index'. virtual bool is_valid(ssize_t index, double value) const = 0; @@ -119,8 +156,12 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { double min_; double max_; + // Stateless index-wise upper and lower bounds std::vector lower_bounds_; std::vector upper_bounds_; + + /// Stateless information on each bound axis. + const std::vector bound_axes_info_; }; /// A contiguous block of integer numbers. @@ -134,33 +175,45 @@ class IntegerNode : public NumberNode { // Default to a single scalar integer with default bounds IntegerNode() : IntegerNode({}) {} - // Create an integer array with the user-defined bounds. - // Defaulting to the specified default bounds. + // Create an integer array with the user-defined index- and axis-wise bounds. + // Index-wise bounds default to the specified default bounds. IntegerNode(std::span shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::initializer_list shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(ssize_t size, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::span shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(ssize_t size, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); IntegerNode(std::span shape, std::optional> lower_bound, - double upper_bound); + double upper_bound, + std::optional> bound_axes = std::nullopt); IntegerNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound); - IntegerNode(ssize_t size, std::optional> lower_bound, double upper_bound); - - IntegerNode(std::span shape, double lower_bound, double upper_bound); - IntegerNode(std::initializer_list shape, double lower_bound, double upper_bound); - IntegerNode(ssize_t size, double lower_bound, double upper_bound); + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + IntegerNode(ssize_t size, std::optional> lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + + IntegerNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + IntegerNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + IntegerNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); // Overloads needed by the Node ABC *************************************** @@ -190,33 +243,45 @@ class BinaryNode : public IntegerNode { /// A binary scalar variable with lower_bound = 0.0 and upper_bound = 1.0 BinaryNode() : BinaryNode({}) {} - // Create a binary array with the user-defined bounds. - // Defaulting to lower_bound = 0.0 and upper_bound = 1.0 + // Create a binary array with the user-defined index- and axis-wise bounds. + // Index-wise bounds default to lower_bound = 0.0 and upper_bound = 1.0. BinaryNode(std::span shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::initializer_list shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(ssize_t size, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::span shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(ssize_t size, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::optional> bound_axes = std::nullopt); BinaryNode(std::span shape, std::optional> lower_bound, - double upper_bound); + double upper_bound, + std::optional> bound_axes = std::nullopt); BinaryNode(std::initializer_list shape, std::optional> lower_bound, - double upper_bound); - BinaryNode(ssize_t size, std::optional> lower_bound, double upper_bound); - - BinaryNode(std::span shape, double lower_bound, double upper_bound); - BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound); - BinaryNode(ssize_t size, double lower_bound, double upper_bound); + double upper_bound, + std::optional> bound_axes = std::nullopt); + BinaryNode(ssize_t size, std::optional> lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + + BinaryNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); + BinaryNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes = std::nullopt); // Flip the value (0 -> 1 or 1 -> 0) at index i in the given state. void flip(State& state, ssize_t i) const; diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index 5ad26c99..2e6be87c 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -23,7 +23,168 @@ namespace dwave::optimization { +BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, std::vector axis_operators, + std::vector axis_bounds) + : axis(bound_axis), operators(std::move(axis_operators)), bounds(std::move(axis_bounds)) { + const ssize_t num_operators = operators.size(); + const ssize_t num_bounds = bounds.size(); + + // Null `operators` and `bounds` are not accepted. + if ((num_operators == 0) || (num_bounds == 0)) { + throw std::invalid_argument("Bad axis-wise bounds for axis: " + std::to_string(axis) + + ", `operators` and `bounds` must each have non-zero size."); + } + + // If `operators` and `bounds` are defined PER hyperslice along `axis`, + // they must have the same size. + if ((num_operators > 1) && (num_bounds > 1) && (num_bounds != num_operators)) { + throw std::invalid_argument( + "Bad axis-wise bounds for axis: " + std::to_string(axis) + + ", `operators` and `bounds` should have same size if neither has size 1."); + } +} + +double BoundAxisInfo::get_bound(const ssize_t slice) const { + const ssize_t max_slice = bounds.size(); + // Negative indexing is not supported. + if ((slice < 0) || (slice >= max_slice)) { + throw std::invalid_argument("Out of range slice: " + std::to_string(slice) + + " along axis: " + std::to_string(axis)); + } + + if (max_slice == 1) { + return bounds[0]; + } + return bounds[slice]; +} + +BoundAxisOperator BoundAxisInfo::get_operator(const ssize_t slice) const { + const ssize_t max_slice = operators.size(); + // Negative indexing is not supported. + if ((slice < 0) || (slice >= max_slice)) { + throw std::invalid_argument("Out of range slice: " + std::to_string(slice) + + " along axis: " + std::to_string(axis)); + } + + if (max_slice == 1) { + return operators[0]; + } + return operators[slice]; +} + +template +double get_extreme_index_wise_bound(const std::vector& bound) { + assert(bound.size() > 0); + std::vector::const_iterator it; + if (maximum) { + it = std::max_element(bound.begin(), bound.end()); + } else { + it = std::min_element(bound.begin(), bound.end()); + } + return *it; +} + +void check_index_wise_bounds(const NumberNode& node, const std::vector& lower_bounds_, + const std::vector& upper_bounds_) { + bool index_wise_bound = false; + // If lower bound is index-wise, it must be correct size. + if (lower_bounds_.size() > 1) { + index_wise_bound = true; + if (static_cast(lower_bounds_.size()) != node.size()) { + throw std::invalid_argument("lower_bound must match size of node"); + } + } + // If upper bound is index-wise, it must be correct size. + if (upper_bounds_.size() > 1) { + index_wise_bound = true; + if (static_cast(upper_bounds_.size()) != node.size()) { + throw std::invalid_argument("upper_bound must match size of node"); + } + } + // If at least one of the bounds is index-wise, check that there are no + // violations at any of the indices. + if (index_wise_bound) { + for (ssize_t i = 0, stop = node.size(); i < stop; ++i) { + if (node.lower_bound(i) > node.upper_bound(i)) { + throw std::invalid_argument("Bounds of index " + std::to_string(i) + " clash"); + } + } + } +} + +/// Check the user defined axis-wise bounds for NumberNode +void check_axis_wise_bounds(const std::vector& bound_axes_info, + const std::span shape) { + if (bound_axes_info.size() == 0) { // No bound axes to check. + return; + } + + // Used to asses if an axis have been bound multiple times. + std::vector axis_bound(shape.size(), false); + + // For each set of bound axis data + for (const BoundAxisInfo& bound_axis_info : bound_axes_info) { + const ssize_t axis = bound_axis_info.axis; + + if (axis < 0 || axis >= shape.size()) { + throw std::invalid_argument( + "Invalid bound axis: " + std::to_string(axis) + + ". Note, negative indexing is not supported for axis-wise bounds."); + } + + // The number of operators defined for the given bound axis + const ssize_t num_operators = bound_axis_info.operators.size(); + if ((num_operators > 1) && (num_operators != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise operators along axis: " + std::to_string(axis) + + " given axis shape: " + std::to_string(shape[axis])); + } + + // The number of operators defined for the given bound axis + const ssize_t num_bounds = bound_axis_info.bounds.size(); + if ((num_bounds > 1) && (num_bounds != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise bounds along axis: " + std::to_string(axis) + + " given axis shape: " + std::to_string(shape[axis])); + } + + // Checked in BoundAxisInfo constructor + assert(num_operators == num_bounds || num_operators == 1 || num_bounds == 1); + + if (axis_bound[axis]) { + throw std::invalid_argument( + "Cannot define multiple axis-wise bounds for a single axis."); + } + axis_bound[axis] = true; + } + + // *Currently*, we only support axis-wise bounds for up to one axis. + if (bound_axes_info.size() > 1) { + throw std::invalid_argument("Axis-wise bounds are supported for at most one axis."); + } +} + // Base class to be used as interfaces. +NumberNode::NumberNode(std::span shape, std::vector lower_bound, + std::vector upper_bound, + std::optional> bound_axes) + : ArrayOutputMixin(shape), + min_(get_extreme_index_wise_bound(lower_bound)), + max_(get_extreme_index_wise_bound(upper_bound)), + lower_bounds_(std::move(lower_bound)), + upper_bounds_(std::move(upper_bound)), + bound_axes_info_(bound_axes ? std::move(*bound_axes) : std::vector{}) { + if ((shape.size() > 0) && (shape[0] < 0)) { + throw std::invalid_argument("Number array cannot have dynamic size."); + } + + if (max_ < min_) { + throw std::invalid_argument("Invalid range for number array provided."); + } + + check_index_wise_bounds(*this, lower_bounds_, upper_bounds_); + check_axis_wise_bounds(bound_axes_info_, this->shape()); +} double const* NumberNode::buff(const State& state) const noexcept { return data_ptr(state)->buff(); @@ -124,74 +285,29 @@ void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) c data_ptr(state)->set(index, value); } -template -double get_extreme_index_wise_bound(const std::vector& bound) { - assert(bound.size() > 0); - std::vector::const_iterator it; - if (maximum) { - it = std::max_element(bound.begin(), bound.end()); - } else { - it = std::min_element(bound.begin(), bound.end()); - } - return *it; -} +ssize_t NumberNode::num_bound_axes() const { + return static_cast(bound_axes_info_.size()); +}; -void check_index_wise_bounds(const NumberNode& node, const std::vector& lower_bounds_, - const std::vector& upper_bounds_) { - bool index_wise_bound = false; - // If lower bound is index-wise, it must be correct size. - if (lower_bounds_.size() > 1) { - index_wise_bound = true; - if (static_cast(lower_bounds_.size()) != node.size()) { - throw std::invalid_argument("lower_bound must match size of node"); - } +const BoundAxisInfo* NumberNode::get_ith_bound_axis_info(const ssize_t i) const { + if (i < 0 || i >= bound_axes_info_.size()) { + throw std::invalid_argument("Invalid ith bound axis requested: " + std::to_string(i)); } - // If upper bound is index-wise, it must be correct size. - if (upper_bounds_.size() > 1) { - index_wise_bound = true; - if (static_cast(upper_bounds_.size()) != node.size()) { - throw std::invalid_argument("upper_bound must match size of node"); - } - } - // If at least one of the bounds is index-wise, check that there are no - // violations at any of the indices. - if (index_wise_bound) { - for (ssize_t i = 0, stop = node.size(); i < stop; ++i) { - if (node.lower_bound(i) > node.upper_bound(i)) { - throw std::invalid_argument("Bounds of index " + std::to_string(i) + " clash"); - } - } - } -} - -NumberNode::NumberNode(std::span shape, std::vector lower_bound, - std::vector upper_bound) - : ArrayOutputMixin(shape), - min_(get_extreme_index_wise_bound(lower_bound)), - max_(get_extreme_index_wise_bound(upper_bound)), - lower_bounds_(std::move(lower_bound)), - upper_bounds_(std::move(upper_bound)) { - if ((shape.size() > 0) && (shape[0] < 0)) { - throw std::invalid_argument("Number array cannot have dynamic size."); - } - - if (max_ < min_) { - throw std::invalid_argument("Invalid range for number array provided."); - } - - check_index_wise_bounds(*this, lower_bounds_, upper_bounds_); -} + return &bound_axes_info_[i]; +}; // Integer Node *************************************************************** IntegerNode::IntegerNode(std::span shape, std::optional> lower_bound, - std::optional> upper_bound) + std::optional> upper_bound, + std::optional> bound_axes) : NumberNode(shape, lower_bound.has_value() ? std::move(*lower_bound) : std::vector{default_lower_bound}, upper_bound.has_value() ? std::move(*upper_bound) - : std::vector{default_upper_bound}) { + : std::vector{default_upper_bound}, + std::move(bound_axes)) { if (min_ < minimum_lower_bound || max_ > maximum_upper_bound) { throw std::invalid_argument("range provided for integers exceeds supported range"); } @@ -199,40 +315,59 @@ IntegerNode::IntegerNode(std::span shape, IntegerNode::IntegerNode(std::initializer_list shape, std::optional> lower_bound, - std::optional> upper_bound) - : IntegerNode(std::span(shape), std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode(std::span(shape), std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, std::optional> lower_bound, - std::optional> upper_bound) - : IntegerNode({size}, std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode({size}, std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::span shape, double lower_bound, - std::optional> upper_bound) - : IntegerNode(shape, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode(shape, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound) - : IntegerNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, double lower_bound, - std::optional> upper_bound) - : IntegerNode({size}, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : IntegerNode({size}, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::span shape, - std::optional> lower_bound, double upper_bound) - : IntegerNode(shape, std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode(shape, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound) - : IntegerNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, std::optional> lower_bound, - double upper_bound) - : IntegerNode({size}, std::move(lower_bound), std::vector{upper_bound}) {} - -IntegerNode::IntegerNode(std::span shape, double lower_bound, double upper_bound) - : IntegerNode(shape, std::vector{lower_bound}, std::vector{upper_bound}) {} + double upper_bound, std::optional> bound_axes) + : IntegerNode({size}, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} + +IntegerNode::IntegerNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode(shape, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, double lower_bound, - double upper_bound) + double upper_bound, std::optional> bound_axes) : IntegerNode(std::span(shape), std::vector{lower_bound}, - std::vector{upper_bound}) {} -IntegerNode::IntegerNode(ssize_t size, double lower_bound, double upper_bound) - : IntegerNode({size}, std::vector{lower_bound}, std::vector{upper_bound}) {} + std::vector{upper_bound}, std::move(bound_axes)) {} +IntegerNode::IntegerNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes) + : IntegerNode({size}, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} bool IntegerNode::integral() const { return true; } @@ -287,45 +422,66 @@ std::vector limit_bound_to_bool_domain(std::optional BinaryNode::BinaryNode(std::span shape, std::optional> lower_bound, - std::optional> upper_bound) + std::optional> upper_bound, + std::optional> bound_axes) : IntegerNode(shape, limit_bound_to_bool_domain(lower_bound), - limit_bound_to_bool_domain(upper_bound)) {} + limit_bound_to_bool_domain(upper_bound), bound_axes) {} BinaryNode::BinaryNode(std::initializer_list shape, std::optional> lower_bound, - std::optional> upper_bound) - : BinaryNode(std::span(shape), std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode(std::span(shape), std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, std::optional> lower_bound, - std::optional> upper_bound) - : BinaryNode({size}, std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode({size}, std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::span shape, double lower_bound, - std::optional> upper_bound) - : BinaryNode(shape, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode(shape, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound) - : BinaryNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, double lower_bound, - std::optional> upper_bound) - : BinaryNode({size}, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::optional> bound_axes) + : BinaryNode({size}, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::span shape, - std::optional> lower_bound, double upper_bound) - : BinaryNode(shape, std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode(shape, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound) - : BinaryNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, std::optional> lower_bound, - double upper_bound) - : BinaryNode({size}, std::move(lower_bound), std::vector{upper_bound}) {} - -BinaryNode::BinaryNode(std::span shape, double lower_bound, double upper_bound) - : BinaryNode(shape, std::vector{lower_bound}, std::vector{upper_bound}) {} -BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound) + double upper_bound, std::optional> bound_axes) + : BinaryNode({size}, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} + +BinaryNode::BinaryNode(std::span shape, double lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode(shape, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} +BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::optional> bound_axes) : BinaryNode(std::span(shape), std::vector{lower_bound}, - std::vector{upper_bound}) {} -BinaryNode::BinaryNode(ssize_t size, double lower_bound, double upper_bound) - : BinaryNode({size}, std::vector{lower_bound}, std::vector{upper_bound}) {} + std::vector{upper_bound}, std::move(bound_axes)) {} +BinaryNode::BinaryNode(ssize_t size, double lower_bound, double upper_bound, + std::optional> bound_axes) + : BinaryNode({size}, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} void BinaryNode::flip(State& state, ssize_t i) const { auto ptr = data_ptr(state); diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index b62e6bdd..df033a3f 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -25,6 +25,50 @@ using Catch::Matchers::RangeEquals; namespace dwave::optimization { +TEST_CASE("BoundAxisInfo") { + GIVEN("BoundAxisInfo(axis = 0, operators = {}, bounds = {1.0})") { + REQUIRE_THROWS_WITH( + BoundAxisInfo(0, std::vector{}, std::vector{1.0}), + "Bad axis-wise bounds for axis: 0, `operators` and `bounds` must each have " + "non-zero size."); + } + + GIVEN("BoundAxisInfo(axis = 0, operators = {<=}, bounds = {})") { + REQUIRE_THROWS_WITH( + BoundAxisInfo(0, std::vector{LessEqual}, std::vector{}), + "Bad axis-wise bounds for axis: 0, `operators` and `bounds` must each have " + "non-zero size."); + } + + GIVEN("BoundAxisInfo(axis = 1, operators = {<=, ==, ==}, bounds = {2.0, 1.0})") { + REQUIRE_THROWS_WITH( + BoundAxisInfo(1, std::vector{LessEqual, Equal, Equal}, + std::vector{2.0, 1.0}), + "Bad axis-wise bounds for axis: 1, `operators` and `bounds` should have same size " + "if neither has size 1."); + } + + GIVEN("BoundAxisInfo(axis = 2, operators = {==}, bounds = {1.0})") { + BoundAxisInfo bound_axis(2, std::vector{Equal}, + std::vector{1.0}); + THEN("The bound axis info is correct") { + CHECK(bound_axis.axis == 2); + CHECK_THAT(bound_axis.operators, RangeEquals({Equal})); + CHECK_THAT(bound_axis.bounds, RangeEquals({1.0})); + } + } + + GIVEN("BoundAxisInfo(axis = 2, operators = {==, <=, >=}, bounds = {1.0, 2.0, 3.0})") { + BoundAxisInfo bound_axis(2, std::vector{Equal, LessEqual, GreaterEqual}, + std::vector{1.0, 2.0, 3.0}); + THEN("The bound axis info is correct") { + CHECK(bound_axis.axis == 2); + CHECK_THAT(bound_axis.operators, RangeEquals({Equal, LessEqual, GreaterEqual})); + CHECK_THAT(bound_axis.bounds, RangeEquals({1.0, 2.0, 3.0})); + } + } +} + TEST_CASE("BinaryNode") { auto graph = Graph(); @@ -441,6 +485,106 @@ TEST_CASE("BinaryNode") { REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{-1, 2}), "Number array cannot have dynamic size."); } + + GIVEN("(2x3)-Binary node with axis-wise bounds on the invalid axis -1") { + BoundAxisInfo bound_axis{-1, std::vector{Equal}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid bound axis: -1. Note, negative indexing is not supported for " + "axis-wise bounds."); + } + + GIVEN("(2x3)-Binary node with axis-wise bounds on the invalid axis 2") { + BoundAxisInfo bound_axis{2, std::vector{Equal}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid bound axis: 2. Note, negative indexing is not supported for " + "axis-wise bounds."); + } + + GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too many operators.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal, Equal, Equal}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise operators along axis: 1 given axis shape: 3"); + } + + GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too few operators.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise operators along axis: 1 given axis shape: 3"); + } + + GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too many bounds.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual}, + std::vector{1.0, 2.0, 3.0, 4.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis shape: 3"); + } + + GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too few bounds.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual}, + std::vector{1.0, 2.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis shape: 3"); + } + + GIVEN("(2x3)-Binary node with duplicate axis-wise bounds on axis: 1") { + BoundAxisInfo bound_axis{1, std::vector{Equal}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis, bound_axis}), + "Cannot define multiple axis-wise bounds for a single axis."); + } + + GIVEN("(2x3)-Binary node with axis-wise bounds on axes: 0 and 1") { + BoundAxisInfo bound_axis_0{0, std::vector{LessEqual}, + std::vector{1.0}}; + BoundAxisInfo bound_axis_1{1, std::vector{LessEqual}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis_0, bound_axis_1}), + "Axis-wise bounds are supported for at most one axis."); + } + + GIVEN("(2x3x4)-Binary node with an axis-wise bound on axis: 1") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual}, + std::vector{1.0, 1.0, 0.0}}; + + auto bnode_ptr = graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis}); + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->num_bound_axes() == 1.0); + const BoundAxisInfo* bnode_bound_axis_ptr = bnode_ptr->get_ith_bound_axis_info(0); + CHECK(bound_axis.axis == bnode_bound_axis_ptr->axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis_ptr->operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis_ptr->bounds)); + CHECK_THROWS_WITH(bnode_ptr->get_ith_bound_axis_info(1), + "Invalid ith bound axis requested: 1"); + CHECK_THROWS_WITH(bnode_ptr->get_ith_bound_axis_info(-1), + "Invalid ith bound axis requested: -1"); + } + } } TEST_CASE("IntegerNode") { From 17fcea552dc97d5d88bc9f82048c96a05ee4b1e5 Mon Sep 17 00:00:00 2001 From: fastbodin Date: Tue, 6 Jan 2026 13:06:03 -0800 Subject: [PATCH 2/7] Add axis-wise bound state dependant data to NumberNode For each bound axis and each hyperslice along said axis, we store the running sum of the values within the hyperslice. This state dependant data is stored via `NumberNodeStateData`. If `NumberNode` is initialized with values, we check that all axis-wise bounds are satisfied. --- .../dwave-optimization/nodes/numbers.hpp | 12 +- dwave/optimization/src/nodes/numbers.cpp | 247 +++++++++++++++--- tests/cpp/nodes/test_numbers.cpp | 45 +++- 3 files changed, 266 insertions(+), 38 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index 9760b27f..9e7a95f8 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -47,7 +47,6 @@ class BoundAxisInfo { /// (length of vector is equal to the number of slices). const std::vector bounds; - private: /// Obtain the bound associated with a given slice along bound axis. double get_bound(const ssize_t slice) const; @@ -140,7 +139,16 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { ssize_t num_bound_axes() const; /// Return the bound information for the ith bound axis - const BoundAxisInfo* get_ith_bound_axis_info(const ssize_t i) const; + const BoundAxisInfo* bound_axis_info(const ssize_t axis) const; + + /// The number of hyperslice along the ith bound axis + ssize_t num_hyperslice_along_bound_axis(State& state, const ssize_t axis) const; + + /// Get the sum of the values in the given slice along the ith bound axis + double bound_axis_hyperslice_sum(State& state, const ssize_t axis, const ssize_t slice) const; + + /// Check whether the axis-wise bounds are satisfied + bool satisfies_axis_wise_bounds(State& state) const; protected: explicit NumberNode(std::span shape, std::vector lower_bound, diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index 2e6be87c..e3ebce2f 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -16,10 +16,13 @@ #include #include +#include #include #include +#include #include "_state.hpp" +#include "dwave-optimization/array.hpp" namespace dwave::optimization { @@ -45,30 +48,16 @@ BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, std::vector } double BoundAxisInfo::get_bound(const ssize_t slice) const { - const ssize_t max_slice = bounds.size(); - // Negative indexing is not supported. - if ((slice < 0) || (slice >= max_slice)) { - throw std::invalid_argument("Out of range slice: " + std::to_string(slice) + - " along axis: " + std::to_string(axis)); - } - - if (max_slice == 1) { - return bounds[0]; - } + assert(0 <= slice); + if (bounds.size() == 1) return bounds[0]; + assert(slice < bounds.size()); return bounds[slice]; } BoundAxisOperator BoundAxisInfo::get_operator(const ssize_t slice) const { - const ssize_t max_slice = operators.size(); - // Negative indexing is not supported. - if ((slice < 0) || (slice >= max_slice)) { - throw std::invalid_argument("Out of range slice: " + std::to_string(slice) + - " along axis: " + std::to_string(axis)); - } - - if (max_slice == 1) { - return operators[0]; - } + assert(0 <= slice); + if (operators.size() == 1) return operators[0]; + assert(slice < operators.size()); return operators[slice]; } @@ -112,6 +101,118 @@ void check_index_wise_bounds(const NumberNode& node, const std::vector& } } +struct NumberNodeDataHelper_ { + NumberNodeDataHelper_(std::vector input, const std::span& shape, + const std::span& strides, + const std::vector& bound_axes_info) + : values(std::move(input)) { + if (bound_axes_info.empty()) return; // No axis sums to compute. + compute_bound_axis_hyperslice_sums(shape, strides, bound_axes_info); + } + + /// Variable assignment to NumberNode + std::vector values; + /// For each bound axis and for each hyperslice along said axis, we track + /// the sum of the values within the hyperslice. + /// bound_axes_sums[i][j] = "sum of the values within the jth hyperslice along + /// the ith bound axis" + std::vector> bound_axes_sums; + + /// Determine the sum of the values of each hyperslice along each bound + /// axis given the variable assignment of NumberNode. + void compute_bound_axis_hyperslice_sums(const std::span& shape, + const std::span& strides, + const std::vector& bound_axes_info) { + const ssize_t num_bound_axes = bound_axes_info.size(); + bound_axes_sums.reserve(num_bound_axes); + + // For each variable assignment of NumberNode (stored in values), we + // need to add the variables value to the running sum for each + // hyperslice it is contained in (and that we are tracking). For each + // such variable i and each bound axis j, we can identify which + // hyperslice i lies in along j via `unravel_index(i, shape)[j]`. + // However, this is inefficient. Instead we track the running + // multidimensional index (for each bound axis we care about) and + // adjust it based on the strides of the NumberNode array as + // we iterate over the variable assignments of the NumberNode. + // + // To do this easily, we first compute the element strides from the + // byte strides of the NumberNode array. Formally + // element_strides[i] = "# of elements need get to the next hyperslice + // along the ith bound axis" + const ssize_t bytes_per_element = static_cast(sizeof(double)); + std::vector element_strides; + element_strides.reserve(num_bound_axes); + // A running stride counter for each bound axis. + // When remaining_axis_strides[i] = 0, we have moved to the next + // hyperslice along the ith bound axis. + std::vector remaining_axis_strides; + remaining_axis_strides.reserve(num_bound_axes); + + // For each bound axis + for (ssize_t i = 0; i < num_bound_axes; ++i) { + const ssize_t bound_axis = bound_axes_info[i].axis; + assert(0 <= bound_axis && bound_axis < shape.size()); + + const ssize_t num_axis_slices = shape[bound_axis]; + // Initialize the sums for each hyperslice along the bound axis. + bound_axes_sums.emplace_back(std::vector(num_axis_slices, 0.0)); + + // Update element stride data + assert(strides[bound_axis] % bytes_per_element == 0); + element_strides.emplace_back(strides[bound_axis] / bytes_per_element); + // Initialize by the total # of element_strides along the bound axis + remaining_axis_strides.push_back(element_strides[i]); + } + + // Running hyperslice index per bound axis + std::vector hyperslice_index(num_bound_axes, 0); + + // Iterate over variable assignments of NumberNode. + for (ssize_t i = 0, stop = static_cast(values.size()); i < stop; ++i) { + // Iterate over the bound axes. + for (ssize_t j = 0; j < num_bound_axes; ++j) { + const ssize_t bound_axis = bound_axes_info[j].axis; + // Check the computation of the hyperslice + assert(unravel_index(i, shape)[bound_axis] == hyperslice_index[j]); + // Accumulate sum in hyperslice along jth bound axis + bound_axes_sums[j][hyperslice_index[j]] += values[i]; + + // Update running multidimensional index + if (--remaining_axis_strides[j] == 0) { + // Moved to next hyperslice, reset `remaining_axis_strides` + remaining_axis_strides[j] = element_strides[j]; + + // Increment the multi_index along bound axis modulo the # + // of hyperslice along said axis + if (++hyperslice_index[j] == shape[bound_axis]) { + hyperslice_index[j] = 0; + } + } + } + } + } +}; + +// State dependant data attached to NumberNode + +struct NumberNodeStateData : public ArrayNodeStateData { + NumberNodeStateData(std::vector input, const std::span& shape, + const std::span& strides, + const std::vector& bound_axis_info) + : NumberNodeStateData( + NumberNodeDataHelper_(std::move(input), shape, strides, bound_axis_info)) {} + + NumberNodeStateData(NumberNodeDataHelper_&& helper) + : ArrayNodeStateData(std::move(helper.values)), + bound_axes_sums(helper.bound_axes_sums), + prior_bound_axes_sums(std::move(helper.bound_axes_sums)) {} + + std::vector> bound_axes_sums; + // Store a copy for NumberNode::revert() + std::vector> prior_bound_axes_sums; +}; + /// Check the user defined axis-wise bounds for NumberNode void check_axis_wise_bounds(const std::vector& bound_axes_info, const std::span shape) { @@ -187,11 +288,11 @@ NumberNode::NumberNode(std::span shape, std::vector lower } double const* NumberNode::buff(const State& state) const noexcept { - return data_ptr(state)->buff(); + return data_ptr(state)->buff(); } std::span NumberNode::diff(const State& state) const noexcept { - return data_ptr(state)->diff(); + return data_ptr(state)->diff(); } double NumberNode::min() const { return min_; } @@ -208,7 +309,12 @@ void NumberNode::initialize_state(State& state, std::vector&& number_dat } } - emplace_data_ptr(state, std::move(number_data)); + emplace_data_ptr(state, std::move(number_data), this->shape(), + this->strides(), this->bound_axes_info_); + + if (!this->satisfies_axis_wise_bounds(state)) { + throw std::invalid_argument("Initialized values do not satisfy axis-wise bounds."); + } } void NumberNode::initialize_state(State& state) const { @@ -221,11 +327,17 @@ void NumberNode::initialize_state(State& state) const { } void NumberNode::commit(State& state) const noexcept { - data_ptr(state)->commit(); + auto node_data = data_ptr(state); + node_data->commit(); + // Manually store a copy of axis_sums + node_data->prior_bound_axes_sums = node_data->bound_axes_sums; } void NumberNode::revert(State& state) const noexcept { - data_ptr(state)->revert(); + auto node_data = data_ptr(state); + node_data->revert(); + // Manually reset axis_sums + node_data->bound_axes_sums = node_data->prior_bound_axes_sums; } void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { @@ -289,13 +401,88 @@ ssize_t NumberNode::num_bound_axes() const { return static_cast(bound_axes_info_.size()); }; -const BoundAxisInfo* NumberNode::get_ith_bound_axis_info(const ssize_t i) const { - if (i < 0 || i >= bound_axes_info_.size()) { - throw std::invalid_argument("Invalid ith bound axis requested: " + std::to_string(i)); - } - return &bound_axes_info_[i]; +const BoundAxisInfo* NumberNode::bound_axis_info(const ssize_t axis) const { + assert(axis >= 0 && axis < bound_axes_info_.size()); + return &bound_axes_info_[axis]; }; +ssize_t NumberNode::num_hyperslice_along_bound_axis(State& state, const ssize_t axis) const { + assert(axis >= 0 && axis < data_ptr(state)->bound_axes_sums.size()); + return data_ptr(state)->bound_axes_sums[axis].size(); +} + +double NumberNode::bound_axis_hyperslice_sum(State& state, const ssize_t axis, + const ssize_t slice) const { + assert(axis >= 0 && slice >= 0); + assert(axis < data_ptr(state)->bound_axes_sums.size()); + assert(slice < data_ptr(state)->bound_axes_sums[axis].size()); + return data_ptr(state)->bound_axes_sums[axis][slice]; +} + +// /// Check whether the axis-wise bound is satisfied for the given hyperslice +// void check_hyperslice(const BoundAxisInfo& bound_axis_info, const ssize_t slice, +// const double slice_sum) { +// const double rhs_bound = bound_axis_info.get_bound(slice); +// std::cout << slice_sum; +// +// switch (bound_axis_info.get_operator(slice)) { +// case Equal: +// std::cout << " == " << rhs_bound << std::endl; +// if (slice_sum == rhs_bound) return; +// case LessEqual: +// std::cout << " <= " << rhs_bound << std::endl; +// if (slice_sum <= rhs_bound) return; +// case GreaterEqual: +// std::cout << " >= " << rhs_bound << std::endl; +// if (slice_sum >= rhs_bound) return; +// default: +// throw std::invalid_argument("Invalid axis-wise bound operator"); +// } +// +// throw std::invalid_argument("Initialized state does not satisfy axis-wise bounds."); +// } + +bool NumberNode::satisfies_axis_wise_bounds(State& state) const { + const ssize_t num_bound_axes = this->num_bound_axes(); + if (num_bound_axes == 0) return true; // No bounds to satisfy + + // Grab the hyperslice sums of all bound axes + const std::vector>& bound_axes_sums = + data_ptr(state)->bound_axes_sums; + assert(num_bound_axes == bound_axes_sums.size()); + + for (ssize_t bound_axis = 0; bound_axis < num_bound_axes; ++bound_axis) { + // Grab the stateless axis-wise bound data for the bound axis + const BoundAxisInfo& bound_axis_info = this->bound_axes_info_[bound_axis]; + + // Grab the sums of all hyperslices along the bound axis + const std::vector& bound_axis_sums = bound_axes_sums[bound_axis]; + + // For each hyperslice along said axis + for (ssize_t slice = 0, stop = bound_axis_sums.size(); slice < stop; ++slice) { + const double rhs_bound = bound_axis_info.get_bound(slice); + const double slice_sum = bound_axis_sums[slice]; + + // Check whether the axis-wise bound is satisfied for the given hyperslice + switch (bound_axis_info.get_operator(slice)) { + case Equal: + if (slice_sum != rhs_bound) return false; + continue; + case LessEqual: + if (slice_sum > rhs_bound) return false; + continue; + case GreaterEqual: + if (slice_sum < rhs_bound) return false; + continue; + default: + throw std::invalid_argument("Invalid axis-wise bound operator"); + } + } + } + + return true; +} + // Integer Node *************************************************************** IntegerNode::IntegerNode(std::span shape, diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index df033a3f..177b3c4f 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -567,22 +567,55 @@ TEST_CASE("BinaryNode") { } GIVEN("(2x3x4)-Binary node with an axis-wise bound on axis: 1") { + auto graph = Graph(); + BoundAxisInfo bound_axis{1, std::vector{LessEqual}, - std::vector{1.0, 1.0, 0.0}}; + std::vector{4.0, 4.0, 6.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, std::vector{bound_axis}); + THEN("Axis wise bound is correct") { CHECK(bnode_ptr->num_bound_axes() == 1.0); - const BoundAxisInfo* bnode_bound_axis_ptr = bnode_ptr->get_ith_bound_axis_info(0); + const BoundAxisInfo* bnode_bound_axis_ptr = bnode_ptr->bound_axis_info(0); CHECK(bound_axis.axis == bnode_bound_axis_ptr->axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis_ptr->operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis_ptr->bounds)); - CHECK_THROWS_WITH(bnode_ptr->get_ith_bound_axis_info(1), - "Invalid ith bound axis requested: 1"); - CHECK_THROWS_WITH(bnode_ptr->get_ith_bound_axis_info(-1), - "Invalid ith bound axis requested: -1"); + } + + WHEN("We initialize an invalid state") { + auto state = graph.empty_state(); + std::vector init_values(2 * 3 * 4, 1); + // import numpy as np + // a = np.ones((2,3,4)) + // a.sum(axis=(0, 2)) + // array([8, 8, 8]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + } + + WHEN("We initialize a state") { + auto state = graph.empty_state(); + std::vector init_values{ + 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, + }; + std::cout << "here" << std::endl; + bnode_ptr->initialize_state(state, init_values); + graph.initialize_state(state); + + // import numpy as np + // a = np.array([1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, + // ... 1, 0, 1, 1, 1, 1, 0]) + // a = a.reshape(2,3,4) + // a.sum(axis=(0, 2)) + // array([3, 4, 5]) + THEN("The sums of each hyperslice along axis 1 are correct") { + CHECK(bnode_ptr->num_hyperslice_along_bound_axis(state, 0) == 3); + CHECK(bnode_ptr->bound_axis_hyperslice_sum(state, 0, 0) == 3); + CHECK(bnode_ptr->bound_axis_hyperslice_sum(state, 0, 1) == 4); + CHECK(bnode_ptr->bound_axis_hyperslice_sum(state, 0, 2) == 5); + } } } } From 6730fb03bb402b58b6f500c0bf149366972061f3 Mon Sep 17 00:00:00 2001 From: fastbodin Date: Mon, 12 Jan 2026 12:58:56 -0800 Subject: [PATCH 3/7] Add NumberNode axis-wise bound methods Added satisfies_axis_wise_bounds(), update_bound_axis_slice_sums(), axis_wise_bounds(), and bound_axis_sums() to NumberNode. Updated various NumberNode, IntegerNode, and BinaryNode methods to reference NumberNodeStateData as opposed to ArrayNodeStateData. Updated all NumberNode mutate methods to reflect changes to the axis-wise bound running sums. Added C++ tests to check said mutate methods on BinaryNode and IntegerNode. --- .../dwave-optimization/nodes/numbers.hpp | 31 +- dwave/optimization/src/nodes/numbers.cpp | 490 ++++++++-------- tests/cpp/nodes/test_numbers.cpp | 551 ++++++++++++++++-- 3 files changed, 784 insertions(+), 288 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index 9e7a95f8..29993d90 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -135,36 +135,35 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // in a given index. void clip_and_set_value(State& state, ssize_t index, double value) const; - /// The number of axes with axis-wise bounds. - ssize_t num_bound_axes() const; + /// Return pointer to the vector of axis-wise bounds + const std::vector& axis_wise_bounds() const; - /// Return the bound information for the ith bound axis - const BoundAxisInfo* bound_axis_info(const ssize_t axis) const; - - /// The number of hyperslice along the ith bound axis - ssize_t num_hyperslice_along_bound_axis(State& state, const ssize_t axis) const; - - /// Get the sum of the values in the given slice along the ith bound axis - double bound_axis_hyperslice_sum(State& state, const ssize_t axis, const ssize_t slice) const; - - /// Check whether the axis-wise bounds are satisfied - bool satisfies_axis_wise_bounds(State& state) const; + // Return a pointer to the vector containing the bound axis sums + const std::vector>& bound_axis_sums(State& state) const; protected: explicit NumberNode(std::span shape, std::vector lower_bound, std::vector upper_bound, std::optional> bound_axes = std::nullopt); - // Return truth statement: 'value is valid in a given index'. + /// Return truth statement: 'value is valid in a given index'. virtual bool is_valid(ssize_t index, double value) const = 0; - // Default value in a given index. + /// Default value in a given index. virtual double default_value(ssize_t index) const = 0; + /// Check whether all axis-wise bounds are satisfied + bool satisfies_axis_wise_bounds(State& state) const; + + /// Update the running bound axis sums where `index` is changed by + /// `value_change` in a given state. + void update_bound_axis_slice_sums(State& state, const ssize_t index, + const double value_change) const; + double min_; double max_; - // Stateless index-wise upper and lower bounds + /// Stateless index-wise upper and lower bounds std::vector lower_bounds_; std::vector upper_bounds_; diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index e3ebce2f..9fad6545 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -29,8 +29,8 @@ namespace dwave::optimization { BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, std::vector axis_operators, std::vector axis_bounds) : axis(bound_axis), operators(std::move(axis_operators)), bounds(std::move(axis_bounds)) { - const ssize_t num_operators = operators.size(); - const ssize_t num_bounds = bounds.size(); + const ssize_t num_operators = static_cast(operators.size()); + const ssize_t num_bounds = static_cast(bounds.size()); // Null `operators` and `bounds` are not accepted. if ((num_operators == 0) || (num_bounds == 0)) { @@ -49,65 +49,26 @@ BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, std::vector double BoundAxisInfo::get_bound(const ssize_t slice) const { assert(0 <= slice); - if (bounds.size() == 1) return bounds[0]; - assert(slice < bounds.size()); + if (bounds.size() == 0) return bounds[0]; + assert(slice < static_cast(bounds.size())); return bounds[slice]; } BoundAxisOperator BoundAxisInfo::get_operator(const ssize_t slice) const { assert(0 <= slice); - if (operators.size() == 1) return operators[0]; - assert(slice < operators.size()); + if (operators.size() == 0) return operators[0]; + assert(slice < static_cast(operators.size())); return operators[slice]; } -template -double get_extreme_index_wise_bound(const std::vector& bound) { - assert(bound.size() > 0); - std::vector::const_iterator it; - if (maximum) { - it = std::max_element(bound.begin(), bound.end()); - } else { - it = std::min_element(bound.begin(), bound.end()); - } - return *it; -} - -void check_index_wise_bounds(const NumberNode& node, const std::vector& lower_bounds_, - const std::vector& upper_bounds_) { - bool index_wise_bound = false; - // If lower bound is index-wise, it must be correct size. - if (lower_bounds_.size() > 1) { - index_wise_bound = true; - if (static_cast(lower_bounds_.size()) != node.size()) { - throw std::invalid_argument("lower_bound must match size of node"); - } - } - // If upper bound is index-wise, it must be correct size. - if (upper_bounds_.size() > 1) { - index_wise_bound = true; - if (static_cast(upper_bounds_.size()) != node.size()) { - throw std::invalid_argument("upper_bound must match size of node"); - } - } - // If at least one of the bounds is index-wise, check that there are no - // violations at any of the indices. - if (index_wise_bound) { - for (ssize_t i = 0, stop = node.size(); i < stop; ++i) { - if (node.lower_bound(i) > node.upper_bound(i)) { - throw std::invalid_argument("Bounds of index " + std::to_string(i) + " clash"); - } - } - } -} - struct NumberNodeDataHelper_ { - NumberNodeDataHelper_(std::vector input, const std::span& shape, - const std::span& strides, - const std::vector& bound_axes_info) + NumberNodeDataHelper_(std::vector input, + const std::vector& bound_axes_info, + const std::span& shape, + const std::span& strides) : values(std::move(input)) { if (bound_axes_info.empty()) return; // No axis sums to compute. - compute_bound_axis_hyperslice_sums(shape, strides, bound_axes_info); + compute_bound_axis_hyperslice_sums(bound_axes_info, shape, strides); } /// Variable assignment to NumberNode @@ -120,24 +81,23 @@ struct NumberNodeDataHelper_ { /// Determine the sum of the values of each hyperslice along each bound /// axis given the variable assignment of NumberNode. - void compute_bound_axis_hyperslice_sums(const std::span& shape, - const std::span& strides, - const std::vector& bound_axes_info) { - const ssize_t num_bound_axes = bound_axes_info.size(); + void compute_bound_axis_hyperslice_sums(const std::vector& bound_axes_info, + const std::span& shape, + const std::span& strides) { + const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); bound_axes_sums.reserve(num_bound_axes); - // For each variable assignment of NumberNode (stored in values), we - // need to add the variables value to the running sum for each - // hyperslice it is contained in (and that we are tracking). For each - // such variable i and each bound axis j, we can identify which - // hyperslice i lies in along j via `unravel_index(i, shape)[j]`. - // However, this is inefficient. Instead we track the running - // multidimensional index (for each bound axis we care about) and - // adjust it based on the strides of the NumberNode array as - // we iterate over the variable assignments of the NumberNode. + // For each variable assignment to NumberNode (stored in `values`), we need + // to add the variables value to the running sum for each hyperslice it is + // contained in (and that we are tracking). For each such variable i and + // each bound axis j, we can identify which hyperslice i lies in along j + // via `unravel_index(i, shape)[j]`. However, this is inefficient. Instead + // we track the running multidimensional index (for each bound axis we care + // about) and adjust it based on the strides of the NumberNode array as we + // iterate over the variable assignments of the NumberNode. // - // To do this easily, we first compute the element strides from the - // byte strides of the NumberNode array. Formally + // To do this easily, we first compute the element strides from the byte + // strides of the NumberNode array. Formally // element_strides[i] = "# of elements need get to the next hyperslice // along the ith bound axis" const ssize_t bytes_per_element = static_cast(sizeof(double)); @@ -148,11 +108,14 @@ struct NumberNodeDataHelper_ { // hyperslice along the ith bound axis. std::vector remaining_axis_strides; remaining_axis_strides.reserve(num_bound_axes); + // Running hyperslice index per bound axis + std::vector hyperslice_index; + hyperslice_index.reserve(num_bound_axes); // For each bound axis for (ssize_t i = 0; i < num_bound_axes; ++i) { const ssize_t bound_axis = bound_axes_info[i].axis; - assert(0 <= bound_axis && bound_axis < shape.size()); + assert(0 <= bound_axis && bound_axis < static_cast(shape.size())); const ssize_t num_axis_slices = shape[bound_axis]; // Initialize the sums for each hyperslice along the bound axis. @@ -163,10 +126,10 @@ struct NumberNodeDataHelper_ { element_strides.emplace_back(strides[bound_axis] / bytes_per_element); // Initialize by the total # of element_strides along the bound axis remaining_axis_strides.push_back(element_strides[i]); - } - // Running hyperslice index per bound axis - std::vector hyperslice_index(num_bound_axes, 0); + // Initialize hyperslice index to 0 + hyperslice_index.emplace_back(0); + } // Iterate over variable assignments of NumberNode. for (ssize_t i = 0, stop = static_cast(values.size()); i < stop; ++i) { @@ -194,14 +157,14 @@ struct NumberNodeDataHelper_ { } }; -// State dependant data attached to NumberNode - +/// State dependant data attached to NumberNode struct NumberNodeStateData : public ArrayNodeStateData { - NumberNodeStateData(std::vector input, const std::span& shape, - const std::span& strides, - const std::vector& bound_axis_info) + NumberNodeStateData(std::vector input, + const std::vector& bound_axes_info, + const std::span& shape, + const std::span& strides) : NumberNodeStateData( - NumberNodeDataHelper_(std::move(input), shape, strides, bound_axis_info)) {} + NumberNodeDataHelper_(std::move(input), bound_axes_info, shape, strides)) {} NumberNodeStateData(NumberNodeDataHelper_&& helper) : ArrayNodeStateData(std::move(helper.values)), @@ -209,84 +172,10 @@ struct NumberNodeStateData : public ArrayNodeStateData { prior_bound_axes_sums(std::move(helper.bound_axes_sums)) {} std::vector> bound_axes_sums; - // Store a copy for NumberNode::revert() + // Store a copy for NumberNode::revert() and commit() std::vector> prior_bound_axes_sums; }; -/// Check the user defined axis-wise bounds for NumberNode -void check_axis_wise_bounds(const std::vector& bound_axes_info, - const std::span shape) { - if (bound_axes_info.size() == 0) { // No bound axes to check. - return; - } - - // Used to asses if an axis have been bound multiple times. - std::vector axis_bound(shape.size(), false); - - // For each set of bound axis data - for (const BoundAxisInfo& bound_axis_info : bound_axes_info) { - const ssize_t axis = bound_axis_info.axis; - - if (axis < 0 || axis >= shape.size()) { - throw std::invalid_argument( - "Invalid bound axis: " + std::to_string(axis) + - ". Note, negative indexing is not supported for axis-wise bounds."); - } - - // The number of operators defined for the given bound axis - const ssize_t num_operators = bound_axis_info.operators.size(); - if ((num_operators > 1) && (num_operators != shape[axis])) { - throw std::invalid_argument( - "Invalid number of axis-wise operators along axis: " + std::to_string(axis) + - " given axis shape: " + std::to_string(shape[axis])); - } - - // The number of operators defined for the given bound axis - const ssize_t num_bounds = bound_axis_info.bounds.size(); - if ((num_bounds > 1) && (num_bounds != shape[axis])) { - throw std::invalid_argument( - "Invalid number of axis-wise bounds along axis: " + std::to_string(axis) + - " given axis shape: " + std::to_string(shape[axis])); - } - - // Checked in BoundAxisInfo constructor - assert(num_operators == num_bounds || num_operators == 1 || num_bounds == 1); - - if (axis_bound[axis]) { - throw std::invalid_argument( - "Cannot define multiple axis-wise bounds for a single axis."); - } - axis_bound[axis] = true; - } - - // *Currently*, we only support axis-wise bounds for up to one axis. - if (bound_axes_info.size() > 1) { - throw std::invalid_argument("Axis-wise bounds are supported for at most one axis."); - } -} - -// Base class to be used as interfaces. -NumberNode::NumberNode(std::span shape, std::vector lower_bound, - std::vector upper_bound, - std::optional> bound_axes) - : ArrayOutputMixin(shape), - min_(get_extreme_index_wise_bound(lower_bound)), - max_(get_extreme_index_wise_bound(upper_bound)), - lower_bounds_(std::move(lower_bound)), - upper_bounds_(std::move(upper_bound)), - bound_axes_info_(bound_axes ? std::move(*bound_axes) : std::vector{}) { - if ((shape.size() > 0) && (shape[0] < 0)) { - throw std::invalid_argument("Number array cannot have dynamic size."); - } - - if (max_ < min_) { - throw std::invalid_argument("Invalid range for number array provided."); - } - - check_index_wise_bounds(*this, lower_bounds_, upper_bounds_); - check_axis_wise_bounds(bound_axes_info_, this->shape()); -} - double const* NumberNode::buff(const State& state) const noexcept { return data_ptr(state)->buff(); } @@ -309,8 +198,8 @@ void NumberNode::initialize_state(State& state, std::vector&& number_dat } } - emplace_data_ptr(state, std::move(number_data), this->shape(), - this->strides(), this->bound_axes_info_); + emplace_data_ptr(state, std::move(number_data), bound_axes_info_, + this->shape(), this->strides()); if (!this->satisfies_axis_wise_bounds(state)) { throw std::invalid_argument("Initialized values do not satisfy axis-wise bounds."); @@ -328,16 +217,16 @@ void NumberNode::initialize_state(State& state) const { void NumberNode::commit(State& state) const noexcept { auto node_data = data_ptr(state); - node_data->commit(); - // Manually store a copy of axis_sums + // Manually store a copy of bound_axes_sums. node_data->prior_bound_axes_sums = node_data->bound_axes_sums; + node_data->commit(); } void NumberNode::revert(State& state) const noexcept { auto node_data = data_ptr(state); - node_data->revert(); - // Manually reset axis_sums + // Manually reset bound_axes_sums. node_data->bound_axes_sums = node_data->prior_bound_axes_sums; + node_data->revert(); } void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { @@ -349,18 +238,26 @@ void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { assert(upper_bound(j) >= ptr->get(i)); // Assert that i and j are valid indices occurs in ptr->exchange(). // Exchange occurs IFF (i != j) and (buffer[i] != buffer[j]). - ptr->exchange(i, j); + if (ptr->exchange(i, j)) { + // If the values at indices i and j were exchanged, update the bound + // axis sums. + const double difference = ptr->get(i) - ptr->get(j); + // Index i changed from (what is now) ptr->get(j) to ptr->get(i) + update_bound_axis_slice_sums(state, i, difference); + // Index j changed from (what is now) ptr->get(i) to ptr->get(j) + update_bound_axis_slice_sums(state, j, -difference); + assert(satisfies_axis_wise_bounds(state)); + } } double NumberNode::get_value(State& state, ssize_t i) const { - return data_ptr(state)->get(i); + return data_ptr(state)->get(i); } double NumberNode::lower_bound(ssize_t index) const { if (lower_bounds_.size() == 1) { return lower_bounds_[0]; } - assert(lower_bounds_.size() > 1); assert(0 <= index && index < static_cast(lower_bounds_.size())); return lower_bounds_[index]; } @@ -377,7 +274,6 @@ double NumberNode::upper_bound(ssize_t index) const { if (upper_bounds_.size() == 1) { return upper_bounds_[0]; } - assert(upper_bounds_.size() > 1); assert(0 <= index && index < static_cast(upper_bounds_.size())); return upper_bounds_[index]; } @@ -391,100 +287,214 @@ double NumberNode::upper_bound() const { } void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) const { + auto ptr = data_ptr(state); value = std::clamp(value, lower_bound(index), upper_bound(index)); // Assert that i is a valid index occurs in data_ptr->set(). // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(index, value); + if (ptr->set(index, value)) { + // Update the bound axis sums. + update_bound_axis_slice_sums(state, index, value - diff(state).back().old); + assert(satisfies_axis_wise_bounds(state)); + } } -ssize_t NumberNode::num_bound_axes() const { - return static_cast(bound_axes_info_.size()); -}; +const std::vector& NumberNode::axis_wise_bounds() const { return bound_axes_info_; } -const BoundAxisInfo* NumberNode::bound_axis_info(const ssize_t axis) const { - assert(axis >= 0 && axis < bound_axes_info_.size()); - return &bound_axes_info_[axis]; -}; +const std::vector>& NumberNode::bound_axis_sums(State& state) const { + return data_ptr(state)->bound_axes_sums; +} -ssize_t NumberNode::num_hyperslice_along_bound_axis(State& state, const ssize_t axis) const { - assert(axis >= 0 && axis < data_ptr(state)->bound_axes_sums.size()); - return data_ptr(state)->bound_axes_sums[axis].size(); +template +double get_extreme_index_wise_bound(const std::vector& bound) { + assert(bound.size() > 0); + std::vector::const_iterator it; + if (maximum) { + it = std::max_element(bound.begin(), bound.end()); + } else { + it = std::min_element(bound.begin(), bound.end()); + } + return *it; } -double NumberNode::bound_axis_hyperslice_sum(State& state, const ssize_t axis, - const ssize_t slice) const { - assert(axis >= 0 && slice >= 0); - assert(axis < data_ptr(state)->bound_axes_sums.size()); - assert(slice < data_ptr(state)->bound_axes_sums[axis].size()); - return data_ptr(state)->bound_axes_sums[axis][slice]; +void check_index_wise_bounds(const NumberNode& node, const std::vector& lower_bounds_, + const std::vector& upper_bounds_) { + bool index_wise_bound = false; + // If lower bound is index-wise, it must be correct size. + if (lower_bounds_.size() > 1) { + index_wise_bound = true; + if (static_cast(lower_bounds_.size()) != node.size()) { + throw std::invalid_argument("lower_bound must match size of node"); + } + } + // If upper bound is index-wise, it must be correct size. + if (upper_bounds_.size() > 1) { + index_wise_bound = true; + if (static_cast(upper_bounds_.size()) != node.size()) { + throw std::invalid_argument("upper_bound must match size of node"); + } + } + // If at least one of the bounds is index-wise, check that there are no + // violations at any of the indices. + if (index_wise_bound) { + for (ssize_t i = 0, stop = node.size(); i < stop; ++i) { + if (node.lower_bound(i) > node.upper_bound(i)) { + throw std::invalid_argument("Bounds of index " + std::to_string(i) + " clash"); + } + } + } } -// /// Check whether the axis-wise bound is satisfied for the given hyperslice -// void check_hyperslice(const BoundAxisInfo& bound_axis_info, const ssize_t slice, -// const double slice_sum) { -// const double rhs_bound = bound_axis_info.get_bound(slice); -// std::cout << slice_sum; -// -// switch (bound_axis_info.get_operator(slice)) { -// case Equal: -// std::cout << " == " << rhs_bound << std::endl; -// if (slice_sum == rhs_bound) return; -// case LessEqual: -// std::cout << " <= " << rhs_bound << std::endl; -// if (slice_sum <= rhs_bound) return; -// case GreaterEqual: -// std::cout << " >= " << rhs_bound << std::endl; -// if (slice_sum >= rhs_bound) return; -// default: -// throw std::invalid_argument("Invalid axis-wise bound operator"); -// } -// -// throw std::invalid_argument("Initialized state does not satisfy axis-wise bounds."); -// } +/// Check the user defined axis-wise bounds for NumberNode +void check_axis_wise_bounds(const std::vector& bound_axes_info, + const std::span shape) { + if (bound_axes_info.size() == 0) return; // No bound axes to check. -bool NumberNode::satisfies_axis_wise_bounds(State& state) const { - const ssize_t num_bound_axes = this->num_bound_axes(); - if (num_bound_axes == 0) return true; // No bounds to satisfy + // Used to asses if an axis have been bound multiple times. + std::vector axis_bound(shape.size(), false); - // Grab the hyperslice sums of all bound axes - const std::vector>& bound_axes_sums = - data_ptr(state)->bound_axes_sums; - assert(num_bound_axes == bound_axes_sums.size()); + // For each set of bound axis data + for (const BoundAxisInfo& bound_axis_info : bound_axes_info) { + const ssize_t axis = bound_axis_info.axis; - for (ssize_t bound_axis = 0; bound_axis < num_bound_axes; ++bound_axis) { - // Grab the stateless axis-wise bound data for the bound axis - const BoundAxisInfo& bound_axis_info = this->bound_axes_info_[bound_axis]; + if (axis < 0 || axis >= static_cast(shape.size())) { + throw std::invalid_argument( + "Invalid bound axis: " + std::to_string(axis) + + ". Note, negative indexing is not supported for axis-wise bounds."); + } - // Grab the sums of all hyperslices along the bound axis - const std::vector& bound_axis_sums = bound_axes_sums[bound_axis]; + // The number of operators defined for the given bound axis + const ssize_t num_operators = static_cast(bound_axis_info.operators.size()); + if ((num_operators > 1) && (num_operators != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise operators along axis: " + std::to_string(axis) + + " given axis size: " + std::to_string(shape[axis])); + } - // For each hyperslice along said axis - for (ssize_t slice = 0, stop = bound_axis_sums.size(); slice < stop; ++slice) { - const double rhs_bound = bound_axis_info.get_bound(slice); - const double slice_sum = bound_axis_sums[slice]; + // The number of operators defined for the given bound axis + const ssize_t num_bounds = static_cast(bound_axis_info.bounds.size()); + if ((num_bounds > 1) && (num_bounds != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise bounds along axis: " + std::to_string(axis) + + " given axis size: " + std::to_string(shape[axis])); + } + // Checked in BoundAxisInfo constructor + assert(num_operators == num_bounds || num_operators == 1 || num_bounds == 1); + + if (axis_bound[axis]) { + throw std::invalid_argument( + "Cannot define multiple axis-wise bounds for a single axis."); + } + axis_bound[axis] = true; + } + + // *Currently*, we only support axis-wise bounds for up to one axis. + if (bound_axes_info.size() > 1) { + throw std::invalid_argument("Axis-wise bounds are supported for at most one axis."); + } +} + +// Base class to be used as interfaces. +NumberNode::NumberNode(std::span shape, std::vector lower_bound, + std::vector upper_bound, + std::optional> bound_axes) + : ArrayOutputMixin(shape), + min_(get_extreme_index_wise_bound(lower_bound)), + max_(get_extreme_index_wise_bound(upper_bound)), + lower_bounds_(std::move(lower_bound)), + upper_bounds_(std::move(upper_bound)), + bound_axes_info_(bound_axes ? std::move(*bound_axes) : std::vector{}) { + if ((shape.size() > 0) && (shape[0] < 0)) { + throw std::invalid_argument("Number array cannot have dynamic size."); + } + + if (max_ < min_) { + throw std::invalid_argument("Invalid range for number array provided."); + } + + check_index_wise_bounds(*this, lower_bounds_, upper_bounds_); + check_axis_wise_bounds(bound_axes_info_, this->shape()); +} + +bool NumberNode::satisfies_axis_wise_bounds(State& state) const { + const auto& bound_axes_info = bound_axes_info_; + if (bound_axes_info.size() == 0) return true; // No axis-wise bounds to satisfy + + // Get the hyperslice sums of all bound axes. + const auto& bound_axes_sums = data_ptr(state)->bound_axes_sums; + assert(bound_axes_info.size() == bound_axes_sums.size()); + + for (ssize_t bound_axis = 0, stop = static_cast(bound_axes_info.size()); + bound_axis < stop; ++bound_axis) { + // Get the stateless axis-wise bound for the bound axis + const BoundAxisInfo& bound_axis_info = bound_axes_info[bound_axis]; + // Get the sums of all hyperslices along the bound axis + const std::vector& bound_axis_sums = bound_axes_sums[bound_axis]; + + // Possible To Do: We could "optimize" here if axis has uniform bounds + // and or operators for all slices. + for (ssize_t slice = 0, stop = static_cast(bound_axis_sums.size()); slice < stop; + ++slice) { // Check whether the axis-wise bound is satisfied for the given hyperslice switch (bound_axis_info.get_operator(slice)) { case Equal: - if (slice_sum != rhs_bound) return false; - continue; + if (bound_axis_sums[slice] == bound_axis_info.get_bound(slice)) continue; + return false; case LessEqual: - if (slice_sum > rhs_bound) return false; - continue; + if (bound_axis_sums[slice] <= bound_axis_info.get_bound(slice)) continue; + return false; case GreaterEqual: - if (slice_sum < rhs_bound) return false; - continue; + if (bound_axis_sums[slice] >= bound_axis_info.get_bound(slice)) continue; + return false; default: throw std::invalid_argument("Invalid axis-wise bound operator"); } } } - return true; } +void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index, + const double value_change) const { + const auto& bound_axes_info = bound_axes_info_; + if (bound_axes_info.size() == 0) return; // No axis-wise bounds to satisfy + + // Obtain the multidimensional indices for `index` so we can identify the + // slices `index` lies on per bound axis. + // Possible To Do: We could optimize this get the bound axes indices only. + const std::vector multi_index = unravel_index(index, this->shape()); + assert(bound_axes_info.size() <= multi_index.size()); + // Get the hyperslice sums of all bound axes. + auto& bound_axes_sums = data_ptr(state)->bound_axes_sums; + assert(bound_axes_info.size() == bound_axes_sums.size()); + + for (ssize_t bound_axis = 0, stop = static_cast(bound_axes_info.size()); + bound_axis < stop; ++bound_axis) { + assert(bound_axes_info[bound_axis].axis < static_cast(multi_index.size())); + // Get the slice along the bound axis the `value_change` occurs in + const ssize_t slice = multi_index[bound_axes_info[bound_axis].axis]; + assert(slice < static_cast(bound_axes_sums[bound_axis].size())); + // Offset running sum in slice + bound_axes_sums[bound_axis][slice] += value_change; + } +} + // Integer Node *************************************************************** +/// Check the user defined axis-wise bounds for IntegerNode +void check_integrality_of_axis_wise_bounds(const std::vector& bound_axes_info) { + if (bound_axes_info.size() == 0) return; // No bound axes to check. + + for (const BoundAxisInfo& bound_axis_info : bound_axes_info) { + for (const double& bound : bound_axis_info.bounds) { + if (bound != std::round(bound)) { + throw std::invalid_argument( + "Axis wise bounds for integral number arrays must be intregral."); + } + } + } +} + IntegerNode::IntegerNode(std::span shape, std::optional> lower_bound, std::optional> upper_bound, @@ -498,6 +508,8 @@ IntegerNode::IntegerNode(std::span shape, if (min_ < minimum_lower_bound || max_ > maximum_upper_bound) { throw std::invalid_argument("range provided for integers exceeds supported range"); } + + check_integrality_of_axis_wise_bounds(bound_axes_info_); } IntegerNode::IntegerNode(std::initializer_list shape, @@ -564,13 +576,18 @@ bool IntegerNode::is_valid(ssize_t index, double value) const { } void IntegerNode::set_value(State& state, ssize_t index, double value) const { + auto ptr = data_ptr(state); // We expect `value` to obey the index-wise bounds and to be an integer. assert(lower_bound(index) <= value); assert(upper_bound(index) >= value); assert(value == std::round(value)); // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(index, value); + // set() occurs IFF `value` != buffer[i]. + if (ptr->set(index, value)) { + // Update the bound axis. + update_bound_axis_slice_sums(state, index, value - diff(state).back().old); + assert(satisfies_axis_wise_bounds(state)); + } } double IntegerNode::default_value(ssize_t index) const { @@ -675,24 +692,37 @@ void BinaryNode::flip(State& state, ssize_t i) const { // Variable should not be fixed. assert(lower_bound(i) != upper_bound(i)); // Assert that i is a valid index occurs in ptr->set(). - // Set occurs IFF `value` != buffer[i] . - ptr->set(i, !ptr->get(i)); + // set() occurs IFF `value` != buffer[i]. + if (ptr->set(i, !ptr->get(i))) { + // If value changed from 0 -> 1, update the bound axis sums by 1. + // If value changed from 1 -> 0, update the bound axis sums by -1. + update_bound_axis_slice_sums(state, i, (ptr->get(i) == 1) ? 1 : -1); + assert(satisfies_axis_wise_bounds(state)); + } } void BinaryNode::set(State& state, ssize_t i) const { // We expect the set to obey the index-wise bounds. assert(upper_bound(i) == 1.0); // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(i, 1.0); + // set() occurs IFF `value` != buffer[i]. + if (data_ptr(state)->set(i, 1.0)) { + // If value changed from 0 -> 1, update the bound axis sums by 1. + update_bound_axis_slice_sums(state, i, 1.0); + assert(satisfies_axis_wise_bounds(state)); + } } void BinaryNode::unset(State& state, ssize_t i) const { // We expect the set to obey the index-wise bounds. assert(lower_bound(i) == 0.0); // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(i, 0.0); + // set occurs IFF `value` != buffer[i]. + if (data_ptr(state)->set(i, 0.0)) { + // If value changed from 1 -> 0, update the bound axis sums by -1. + update_bound_axis_slice_sums(state, i, -1.0); + assert(satisfies_axis_wise_bounds(state)); + } } } // namespace dwave::optimization diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index 177b3c4f..20e58bf3 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -18,6 +18,7 @@ #include "catch2/catch_test_macros.hpp" #include "catch2/matchers/catch_matchers.hpp" #include "catch2/matchers/catch_matchers_all.hpp" +#include "catch2/matchers/catch_matchers_range_equals.hpp" #include "dwave-optimization/graph.hpp" #include "dwave-optimization/nodes/numbers.hpp" @@ -486,7 +487,7 @@ TEST_CASE("BinaryNode") { "Number array cannot have dynamic size."); } - GIVEN("(2x3)-Binary node with axis-wise bounds on the invalid axis -1") { + GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis -1") { BoundAxisInfo bound_axis{-1, std::vector{Equal}, std::vector{1.0}}; REQUIRE_THROWS_WITH(graph.emplace_node( @@ -496,7 +497,7 @@ TEST_CASE("BinaryNode") { "axis-wise bounds."); } - GIVEN("(2x3)-Binary node with axis-wise bounds on the invalid axis 2") { + GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis 2") { BoundAxisInfo bound_axis{2, std::vector{Equal}, std::vector{1.0}}; REQUIRE_THROWS_WITH(graph.emplace_node( @@ -506,45 +507,45 @@ TEST_CASE("BinaryNode") { "axis-wise bounds."); } - GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too many operators.") { + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many operators.") { BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal, Equal, Equal}, std::vector{1.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise operators along axis: 1 given axis shape: 3"); + "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); } - GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too few operators.") { + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few operators.") { BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal}, std::vector{1.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise operators along axis: 1 given axis shape: 3"); + "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); } - GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too many bounds.") { + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many bounds.") { BoundAxisInfo bound_axis{1, std::vector{LessEqual}, std::vector{1.0, 2.0, 3.0, 4.0}}; REQUIRE_THROWS_WITH(graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis shape: 3"); + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); } - GIVEN("(2x3)-Binary node with axis-wise bounds on axis: 1 with too few bounds.") { + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few bounds.") { BoundAxisInfo bound_axis{1, std::vector{LessEqual}, std::vector{1.0, 2.0}}; REQUIRE_THROWS_WITH(graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis shape: 3"); + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); } - GIVEN("(2x3)-Binary node with duplicate axis-wise bounds on axis: 1") { + GIVEN("(2x3)-BinaryNode with duplicate axis-wise bounds on axis: 1") { BoundAxisInfo bound_axis{1, std::vector{Equal}, std::vector{1.0}}; REQUIRE_THROWS_WITH( @@ -554,7 +555,7 @@ TEST_CASE("BinaryNode") { "Cannot define multiple axis-wise bounds for a single axis."); } - GIVEN("(2x3)-Binary node with axis-wise bounds on axes: 0 and 1") { + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axes: 0 and 1") { BoundAxisInfo bound_axis_0{0, std::vector{LessEqual}, std::vector{1.0}}; BoundAxisInfo bound_axis_1{1, std::vector{LessEqual}, @@ -566,55 +567,259 @@ TEST_CASE("BinaryNode") { "Axis-wise bounds are supported for at most one axis."); } - GIVEN("(2x3x4)-Binary node with an axis-wise bound on axis: 1") { + GIVEN("(2x3x4)-BinaryNode with an axis-wise bound on axis: 0") { auto graph = Graph(); - BoundAxisInfo bound_axis{1, std::vector{LessEqual}, - std::vector{4.0, 4.0, 6.0}}; + BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual, GreaterEqual}, + std::vector{1.0, 2.0, 3.0}}; auto bnode_ptr = graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, std::vector{bound_axis}); THEN("Axis wise bound is correct") { - CHECK(bnode_ptr->num_bound_axes() == 1.0); - const BoundAxisInfo* bnode_bound_axis_ptr = bnode_ptr->bound_axis_info(0); - CHECK(bound_axis.axis == bnode_bound_axis_ptr->axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis_ptr->operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis_ptr->bounds)); + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axis.axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); } - WHEN("We initialize an invalid state") { + WHEN("We initialize three invalid states") { auto state = graph.empty_state(); - std::vector init_values(2 * 3 * 4, 1); + // This state violates the 0th hyperslice along axis 0 + std::vector init_values{1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; // import numpy as np - // a = np.ones((2,3,4)) - // a.sum(axis=(0, 2)) - // array([8, 8, 8]) + // a = np.asarray([1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([2, 2, 4]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 1st hyperslice along axis 0 + init_values = {0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1}; + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([1, 3, 4]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 2nd hyperslice along axis 0 + init_values = {0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0}; + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([1, 2, 2]) CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), "Initialized values do not satisfy axis-wise bounds."); } - WHEN("We initialize a state") { + WHEN("We initialize a valid state") { auto state = graph.empty_state(); - std::vector init_values{ - 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, - }; - std::cout << "here" << std::endl; + std::vector init_values{0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; bnode_ptr->initialize_state(state, init_values); graph.initialize_state(state); - // import numpy as np - // a = np.array([1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, - // ... 1, 0, 1, 1, 1, 1, 0]) - // a = a.reshape(2,3,4) - // a.sum(axis=(0, 2)) - // array([3, 4, 5]) - THEN("The sums of each hyperslice along axis 1 are correct") { - CHECK(bnode_ptr->num_hyperslice_along_bound_axis(state, 0) == 3); - CHECK(bnode_ptr->bound_axis_hyperslice_sum(state, 0, 0) == 3); - CHECK(bnode_ptr->bound_axis_hyperslice_sum(state, 0, 1) == 4); - CHECK(bnode_ptr->bound_axis_hyperslice_sum(state, 0, 2) == 5); + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + // **Python Code 1** + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + THEN("We exchange() some values") { + bnode_ptr->exchange(state, 0, 3); // Does nothing. + bnode_ptr->exchange(state, 1, 6); // Does nothing. + bnode_ptr->exchange(state, 1, 3); + std::swap(init_values[0], init_values[3]); + std::swap(init_values[1], init_values[6]); + std::swap(init_values[1], init_values[3]); + // state is now: [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(1, a.shape)] = 0 + // a[np.unravel_index(3, a.shape)] = 1 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 2); // 2 updates per exchange + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We clip_and_set_value() some values") { + bnode_ptr->clip_and_set_value(state, 5, -1); // Does nothing. + bnode_ptr->clip_and_set_value(state, 7, -1); + bnode_ptr->clip_and_set_value(state, 9, 1); // Does nothing. + bnode_ptr->clip_and_set_value(state, 11, 0); + bnode_ptr->clip_and_set_value(state, 11, 1); + bnode_ptr->clip_and_set_value(state, 10, 0); + init_values[5] = 0; + init_values[7] = 0; + init_values[9] = 1; + init_values[11] = 1; + init_values[10] = 0; + // state is now: [0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(5, a.shape)] = 0 + // a[np.unravel_index(7, a.shape)] = 0 + // a[np.unravel_index(9, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 4); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We set_value() some values") { + bnode_ptr->set_value(state, 0, 0); // Does nothing. + bnode_ptr->set_value(state, 6, 0); + bnode_ptr->set_value(state, 7, 0); + bnode_ptr->set_value(state, 4, 1); + bnode_ptr->set_value(state, 10, 1); // Does nothing. + bnode_ptr->set_value(state, 11, 0); + init_values[0] = 0; + init_values[6] = 0; + init_values[7] = 0; + init_values[4] = 1; + init_values[10] = 1; + init_values[11] = 0; + // state is now: [0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(0, a.shape)] = 0 + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(7, a.shape)] = 0 + // a[np.unravel_index(4, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 4); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We flip() some values") { + bnode_ptr->flip(state, 6); // 1 -> 0 + bnode_ptr->flip(state, 4); // 0 -> 1 + bnode_ptr->flip(state, 11); // 1 -> 0 + init_values[6] = !init_values[6]; + init_values[4] = !init_values[4]; + init_values[11] = !init_values[11]; + // state is now: [0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(4, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 3})); + CHECK(bnode_ptr->diff(state).size() == 3); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We unset() some values") { + bnode_ptr->unset(state, 0); // Does nothing. + bnode_ptr->unset(state, 6); + bnode_ptr->unset(state, 11); + init_values[0] = 0; + init_values[6] = 0; + init_values[11] = 0; + // state is now: [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(0, a.shape)] = 0 + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 2); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We commit and set() some values") { + graph.commit(state); + + bnode_ptr->set(state, 10); // Does nothing. + bnode_ptr->set(state, 11); + init_values[10] = 1; + init_values[11] = 1; + // state is now: [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + THEN("The bound axis sums updated correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 4})); + CHECK(bnode_ptr->diff(state).size() == 1); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], + RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } } } } @@ -916,6 +1121,268 @@ TEST_CASE("IntegerNode") { REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{-1, 3}), "Number array cannot have dynamic size."); } + + GIVEN("(2x3)-IntegerNode with axis-wise bounds on the invalid axis -2") { + BoundAxisInfo bound_axis{-2, std::vector{Equal}, + std::vector{20.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid bound axis: -2. Note, negative indexing is not supported for " + "axis-wise bounds."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on the invalid axis 3") { + BoundAxisInfo bound_axis{3, std::vector{Equal}, + std::vector{10.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid bound axis: 3. Note, negative indexing is not supported for " + "axis-wise bounds."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many operators.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal, Equal, Equal}, + std::vector{-10.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few operators.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal}, + std::vector{-11.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many bounds.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual}, + std::vector{-10.0, 20.0, 30.0, 40.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few bounds.") { + BoundAxisInfo bound_axis{1, std::vector{LessEqual}, + std::vector{111.0, -223.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + } + + GIVEN("(2x3x4)-IntegerNode with duplicate axis-wise bounds on axis: 1") { + BoundAxisInfo bound_axis{1, std::vector{Equal}, + std::vector{100.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis, bound_axis}), + "Cannot define multiple axis-wise bounds for a single axis."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axes: 0 and 1") { + BoundAxisInfo bound_axis_0{0, std::vector{LessEqual}, + std::vector{11.0}}; + BoundAxisInfo bound_axis_1{1, std::vector{LessEqual}, + std::vector{12.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis_0, bound_axis_1}), + "Axis-wise bounds are supported for at most one axis."); + } + + GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { + BoundAxisInfo bound_axis{2, std::vector{LessEqual}, + std::vector{11.0, 12.0001, 0.0, 0.0}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Axis wise bounds for integral number arrays must be intregral."); + } + + GIVEN("(2x3x2)-IntegerNode with index-wise bounds and an axis-wise bound on axis: 1") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{1, std::vector{Equal, LessEqual, GreaterEqual}, + std::vector{11.0, 2.0, 5.0}}; + + auto inode_ptr = graph.emplace_node( + std::initializer_list{2, 3, 2}, -5, 8, + std::vector{bound_axis}); + + THEN("Axis wise bound is correct") { + CHECK(inode_ptr->axis_wise_bounds().size() == 1); + const BoundAxisInfo inode_bound_axis_ptr = inode_ptr->axis_wise_bounds().data()[0]; + CHECK(bound_axis.axis == inode_bound_axis_ptr.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(inode_bound_axis_ptr.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(inode_bound_axis_ptr.bounds)); + } + + WHEN("We initialize three invalid states") { + auto state = graph.empty_state(); + // This state violates the 0th hyperslice along axis 1 + std::vector init_values{5, 6, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3}; + // import numpy as np + // a = np.asarray([5, 6, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([15, 2, 7]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 1st hyperslice along axis 1 + init_values = {5, 2, 0, 0, 3, 1, 4, 0, 2, 1, 0, 3}; + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 2, 1, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 3, 7]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 2nd hyperslice along axis 1 + init_values = {5, 2, 0, 0, 3, 1, 4, 0, 1, 0, 0, 0}; + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 1, 0, 0, 0]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 1, 4]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + } + + WHEN("We initialize a valid state") { + auto state = graph.empty_state(); + std::vector init_values{5, 2, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3}; + inode_ptr->initialize_state(state, init_values); + graph.initialize_state(state); + + auto bound_axis_sums = inode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + // **Python Code 2** + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 2, 7]) + CHECK(inode_ptr->bound_axis_sums(state).size() == 1); + CHECK(inode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + THEN("We exchange() some values") { + inode_ptr->exchange(state, 2, 3); // Does nothing. + inode_ptr->exchange(state, 1, 8); // Does nothing. + inode_ptr->exchange(state, 8, 10); + inode_ptr->exchange(state, 0, 1); + std::swap(init_values[2], init_values[3]); + std::swap(init_values[1], init_values[8]); + std::swap(init_values[8], init_values[10]); + std::swap(init_values[0], init_values[1]); + // state is now: [2, 5, 0, 0, 3, 1, 4, 0, 0, 0, 2, 3] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(8, a.shape)] = 0 + // a[np.unravel_index(10, a.shape)] = 2 + // a[np.unravel_index(0, a.shape)] = 2 + // a[np.unravel_index(1, a.shape)] = 5 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 0, 9})); + CHECK(inode_ptr->diff(state).size() == 4); // 2 updates per exchange + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We clip_and_set_value() some values") { + inode_ptr->clip_and_set_value(state, 0, 5); // Does nothing. + inode_ptr->clip_and_set_value(state, 8, -300); + inode_ptr->clip_and_set_value(state, 10, 100); + init_values[8] = -5; + init_values[10] = 8; + // state is now: [5, 2, 0, 0, 3, 1, 4, 0, -5, 0, 8, 3] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(8, a.shape)] = -5 + // a[np.unravel_index(10, a.shape)] = 8 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, -5, 15})); + CHECK(inode_ptr->diff(state).size() == 2); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We set_value() some values") { + inode_ptr->set_value(state, 0, 5); // Does nothing. + inode_ptr->set_value(state, 8, 0); + inode_ptr->set_value(state, 9, 1); + inode_ptr->set_value(state, 10, 5); + inode_ptr->set_value(state, 11, 0); + init_values[0] = 5; + init_values[8] = 0; + init_values[9] = 1; + init_values[10] = 5; + init_values[11] = 0; + // state is now: [5, 2, 0, 0, 3, 1, 4, 0, 0, 1, 5, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(0, a.shape)] = 5 + // a[np.unravel_index(8, a.shape)] = 0 + // a[np.unravel_index(9, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 5 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 1, 9})); + CHECK(inode_ptr->diff(state).size() == 4); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bound_axis_sums[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + } + } } } // namespace dwave::optimization From 556132c92c5d194b2dd49a424f77e2961f08b2ad Mon Sep 17 00:00:00 2001 From: fastbodin Date: Wed, 28 Jan 2026 13:51:26 -0800 Subject: [PATCH 4/7] Simplify NumberNodeStateData Make use of BufferIterators to compute the sum of the values within each hyperslice along each bound axis as opposed making a custom method to do this. --- .../dwave-optimization/nodes/numbers.hpp | 6 +- dwave/optimization/src/nodes/numbers.cpp | 273 +++++++----------- 2 files changed, 109 insertions(+), 170 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index 29993d90..2735c173 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -152,18 +152,16 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { /// Default value in a given index. virtual double default_value(ssize_t index) const = 0; - /// Check whether all axis-wise bounds are satisfied - bool satisfies_axis_wise_bounds(State& state) const; - /// Update the running bound axis sums where `index` is changed by /// `value_change` in a given state. void update_bound_axis_slice_sums(State& state, const ssize_t index, const double value_change) const; + /// Statelss global minimum and maximum of the values stored in NumberNode. double min_; double max_; - /// Stateless index-wise upper and lower bounds + /// Stateless index-wise upper and lower bounds. std::vector lower_bounds_; std::vector upper_bounds_; diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index 9fad6545..574b8f59 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -15,6 +15,8 @@ #include "dwave-optimization/nodes/numbers.hpp" #include +#include +#include #include #include #include @@ -61,116 +63,17 @@ BoundAxisOperator BoundAxisInfo::get_operator(const ssize_t slice) const { return operators[slice]; } -struct NumberNodeDataHelper_ { - NumberNodeDataHelper_(std::vector input, - const std::vector& bound_axes_info, - const std::span& shape, - const std::span& strides) - : values(std::move(input)) { - if (bound_axes_info.empty()) return; // No axis sums to compute. - compute_bound_axis_hyperslice_sums(bound_axes_info, shape, strides); - } - - /// Variable assignment to NumberNode - std::vector values; - /// For each bound axis and for each hyperslice along said axis, we track - /// the sum of the values within the hyperslice. - /// bound_axes_sums[i][j] = "sum of the values within the jth hyperslice along - /// the ith bound axis" - std::vector> bound_axes_sums; - - /// Determine the sum of the values of each hyperslice along each bound - /// axis given the variable assignment of NumberNode. - void compute_bound_axis_hyperslice_sums(const std::vector& bound_axes_info, - const std::span& shape, - const std::span& strides) { - const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); - bound_axes_sums.reserve(num_bound_axes); - - // For each variable assignment to NumberNode (stored in `values`), we need - // to add the variables value to the running sum for each hyperslice it is - // contained in (and that we are tracking). For each such variable i and - // each bound axis j, we can identify which hyperslice i lies in along j - // via `unravel_index(i, shape)[j]`. However, this is inefficient. Instead - // we track the running multidimensional index (for each bound axis we care - // about) and adjust it based on the strides of the NumberNode array as we - // iterate over the variable assignments of the NumberNode. - // - // To do this easily, we first compute the element strides from the byte - // strides of the NumberNode array. Formally - // element_strides[i] = "# of elements need get to the next hyperslice - // along the ith bound axis" - const ssize_t bytes_per_element = static_cast(sizeof(double)); - std::vector element_strides; - element_strides.reserve(num_bound_axes); - // A running stride counter for each bound axis. - // When remaining_axis_strides[i] = 0, we have moved to the next - // hyperslice along the ith bound axis. - std::vector remaining_axis_strides; - remaining_axis_strides.reserve(num_bound_axes); - // Running hyperslice index per bound axis - std::vector hyperslice_index; - hyperslice_index.reserve(num_bound_axes); - - // For each bound axis - for (ssize_t i = 0; i < num_bound_axes; ++i) { - const ssize_t bound_axis = bound_axes_info[i].axis; - assert(0 <= bound_axis && bound_axis < static_cast(shape.size())); - - const ssize_t num_axis_slices = shape[bound_axis]; - // Initialize the sums for each hyperslice along the bound axis. - bound_axes_sums.emplace_back(std::vector(num_axis_slices, 0.0)); - - // Update element stride data - assert(strides[bound_axis] % bytes_per_element == 0); - element_strides.emplace_back(strides[bound_axis] / bytes_per_element); - // Initialize by the total # of element_strides along the bound axis - remaining_axis_strides.push_back(element_strides[i]); - - // Initialize hyperslice index to 0 - hyperslice_index.emplace_back(0); - } - - // Iterate over variable assignments of NumberNode. - for (ssize_t i = 0, stop = static_cast(values.size()); i < stop; ++i) { - // Iterate over the bound axes. - for (ssize_t j = 0; j < num_bound_axes; ++j) { - const ssize_t bound_axis = bound_axes_info[j].axis; - // Check the computation of the hyperslice - assert(unravel_index(i, shape)[bound_axis] == hyperslice_index[j]); - // Accumulate sum in hyperslice along jth bound axis - bound_axes_sums[j][hyperslice_index[j]] += values[i]; - - // Update running multidimensional index - if (--remaining_axis_strides[j] == 0) { - // Moved to next hyperslice, reset `remaining_axis_strides` - remaining_axis_strides[j] = element_strides[j]; - - // Increment the multi_index along bound axis modulo the # - // of hyperslice along said axis - if (++hyperslice_index[j] == shape[bound_axis]) { - hyperslice_index[j] = 0; - } - } - } - } - } -}; - /// State dependant data attached to NumberNode struct NumberNodeStateData : public ArrayNodeStateData { - NumberNodeStateData(std::vector input, - const std::vector& bound_axes_info, - const std::span& shape, - const std::span& strides) - : NumberNodeStateData( - NumberNodeDataHelper_(std::move(input), bound_axes_info, shape, strides)) {} - - NumberNodeStateData(NumberNodeDataHelper_&& helper) - : ArrayNodeStateData(std::move(helper.values)), - bound_axes_sums(helper.bound_axes_sums), - prior_bound_axes_sums(std::move(helper.bound_axes_sums)) {} - + NumberNodeStateData(std::vector input) : ArrayNodeStateData(std::move(input)) {} + NumberNodeStateData(std::vector input, std::vector> bound_axes_sums) + : ArrayNodeStateData(std::move(input)), + bound_axes_sums(std::move(bound_axes_sums)), + prior_bound_axes_sums(this->bound_axes_sums) {} + /// For each bound axis and for each hyperslice along said axis, we + /// track the sum of the values within the hyperslice. + /// bound_axes_sums[i][j] = "sum of the values within the jth + /// hyperslice along the ith bound axis" std::vector> bound_axes_sums; // Store a copy for NumberNode::revert() and commit() std::vector> prior_bound_axes_sums; @@ -188,22 +91,96 @@ double NumberNode::min() const { return min_; } double NumberNode::max() const { return max_; } +std::vector> get_bound_axes_sums( + const std::vector& number_data, const std::vector bound_axes_info, + std::span node_shape, std::span node_strides) { + assert(node_shape.size() == node_strides.size()); + assert(bound_axes_info.size() <= node_shape.size()); + assert(std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) == + static_cast(number_data.size())); + + const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); + // For each bound axis, initialize the sum of the values contained in each + // of it's hyperslice to 0. + std::vector> bound_axes_sums; + bound_axes_sums.reserve(num_bound_axes); + for (const BoundAxisInfo& axis_info : bound_axes_info) { + assert(0 <= axis_info.axis && axis_info.axis < static_cast(node_shape.size())); + bound_axes_sums.emplace_back(node_shape[axis_info.axis], 0.0); + } + + // Define a BufferIterator for number_data (contiguous block of doubles) + // given the shape and strides of the NumberNode. + BufferIterator it(number_data.data(), node_shape, node_strides); + + // Iterate over number_data. + for (; it != std::default_sentinel; ++it) { + // Increment the appropriate slice in each bound axis. + for (ssize_t i = 0; i < num_bound_axes; ++i) { + const ssize_t axis = bound_axes_info[i].axis; + assert(0 <= axis && axis < it.location().size()); + const ssize_t slice = it.location()[axis]; + assert(0 <= slice && slice < bound_axes_sums[i].size()); + bound_axes_sums[i][slice] += *it; + } + } + + return bound_axes_sums; +} + +bool satisfies_axis_wise_bounds(const std::vector& bound_axes_info, + const std::vector>& bound_axes_sums) { + assert(bound_axes_info.size() == bound_axes_sums.size()); + // Check that each hyperslice satisfies the axis-wise bounds. + for (ssize_t i = 0, stop_i = static_cast(bound_axes_info.size()); i < stop_i; ++i) { + const std::vector& bound_axis_sums = bound_axes_sums[i]; + const BoundAxisInfo& bound_axis_info = bound_axes_info[i]; + + for (ssize_t slice = 0, stop_slice = static_cast(bound_axis_sums.size()); + slice < stop_slice; ++slice) { + switch (bound_axis_info.get_operator(slice)) { + case Equal: + if (bound_axis_sums[slice] != bound_axis_info.get_bound(slice)) return false; + break; + case LessEqual: + if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) return false; + break; + case GreaterEqual: + if (bound_axis_sums[slice] < bound_axis_info.get_bound(slice)) return false; + break; + default: + throw std::invalid_argument("Invalid axis-wise bound operator"); + } + } + } + return true; +} + void NumberNode::initialize_state(State& state, std::vector&& number_data) const { if (number_data.size() != static_cast(this->size())) { throw std::invalid_argument("Size of data provided does not match node size"); } + for (ssize_t index = 0, stop = this->size(); index < stop; ++index) { if (!is_valid(index, number_data[index])) { throw std::invalid_argument("Invalid data provided for node"); } } - emplace_data_ptr(state, std::move(number_data), bound_axes_info_, - this->shape(), this->strides()); + if (bound_axes_info_.size() == 0) { // No bound axes to consider. + emplace_data_ptr(state, std::move(number_data)); + return; + } - if (!this->satisfies_axis_wise_bounds(state)) { + std::vector> bound_axes_sums = + get_bound_axes_sums(number_data, bound_axes_info_, this->shape(), this->strides()); + + if (!satisfies_axis_wise_bounds(bound_axes_info_, bound_axes_sums)) { throw std::invalid_argument("Initialized values do not satisfy axis-wise bounds."); } + + emplace_data_ptr(state, std::move(number_data), + std::move(bound_axes_sums)); } void NumberNode::initialize_state(State& state) const { @@ -212,6 +189,7 @@ void NumberNode::initialize_state(State& state) const { for (ssize_t i = 0, stop = this->size(); i < stop; ++i) { values.push_back(default_value(i)); } + /// Set all to mins initialize_state(state, std::move(values)); } @@ -230,7 +208,7 @@ void NumberNode::revert(State& state) const noexcept { } void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); // We expect the exchange to obey the index-wise bounds. assert(lower_bound(i) <= ptr->get(j)); assert(upper_bound(i) >= ptr->get(j)); @@ -246,7 +224,7 @@ void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { update_bound_axis_slice_sums(state, i, difference); // Index j changed from (what is now) ptr->get(i) to ptr->get(j) update_bound_axis_slice_sums(state, j, -difference); - assert(satisfies_axis_wise_bounds(state)); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } } @@ -287,14 +265,14 @@ double NumberNode::upper_bound() const { } void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); value = std::clamp(value, lower_bound(index), upper_bound(index)); // Assert that i is a valid index occurs in data_ptr->set(). // Set occurs IFF `value` != buffer[i] . if (ptr->set(index, value)) { // Update the bound axis sums. update_bound_axis_slice_sums(state, index, value - diff(state).back().old); - assert(satisfies_axis_wise_bounds(state)); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } } @@ -416,52 +394,13 @@ NumberNode::NumberNode(std::span shape, std::vector lower check_axis_wise_bounds(bound_axes_info_, this->shape()); } -bool NumberNode::satisfies_axis_wise_bounds(State& state) const { - const auto& bound_axes_info = bound_axes_info_; - if (bound_axes_info.size() == 0) return true; // No axis-wise bounds to satisfy - - // Get the hyperslice sums of all bound axes. - const auto& bound_axes_sums = data_ptr(state)->bound_axes_sums; - assert(bound_axes_info.size() == bound_axes_sums.size()); - - for (ssize_t bound_axis = 0, stop = static_cast(bound_axes_info.size()); - bound_axis < stop; ++bound_axis) { - // Get the stateless axis-wise bound for the bound axis - const BoundAxisInfo& bound_axis_info = bound_axes_info[bound_axis]; - // Get the sums of all hyperslices along the bound axis - const std::vector& bound_axis_sums = bound_axes_sums[bound_axis]; - - // Possible To Do: We could "optimize" here if axis has uniform bounds - // and or operators for all slices. - for (ssize_t slice = 0, stop = static_cast(bound_axis_sums.size()); slice < stop; - ++slice) { - // Check whether the axis-wise bound is satisfied for the given hyperslice - switch (bound_axis_info.get_operator(slice)) { - case Equal: - if (bound_axis_sums[slice] == bound_axis_info.get_bound(slice)) continue; - return false; - case LessEqual: - if (bound_axis_sums[slice] <= bound_axis_info.get_bound(slice)) continue; - return false; - case GreaterEqual: - if (bound_axis_sums[slice] >= bound_axis_info.get_bound(slice)) continue; - return false; - default: - throw std::invalid_argument("Invalid axis-wise bound operator"); - } - } - } - return true; -} - void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index, const double value_change) const { const auto& bound_axes_info = bound_axes_info_; if (bound_axes_info.size() == 0) return; // No axis-wise bounds to satisfy - // Obtain the multidimensional indices for `index` so we can identify the - // slices `index` lies on per bound axis. - // Possible To Do: We could optimize this get the bound axes indices only. + // Get multidimensional indices for `index` so we can identify the slices + // `index` lies on per bound axis. const std::vector multi_index = unravel_index(index, this->shape()); assert(bound_axes_info.size() <= multi_index.size()); // Get the hyperslice sums of all bound axes. @@ -576,7 +515,7 @@ bool IntegerNode::is_valid(ssize_t index, double value) const { } void IntegerNode::set_value(State& state, ssize_t index, double value) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); // We expect `value` to obey the index-wise bounds and to be an integer. assert(lower_bound(index) <= value); assert(upper_bound(index) >= value); @@ -586,7 +525,7 @@ void IntegerNode::set_value(State& state, ssize_t index, double value) const { if (ptr->set(index, value)) { // Update the bound axis. update_bound_axis_slice_sums(state, index, value - diff(state).back().old); - assert(satisfies_axis_wise_bounds(state)); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } } @@ -688,7 +627,7 @@ BinaryNode::BinaryNode(ssize_t size, double lower_bound, double upper_bound, std::move(bound_axes)) {} void BinaryNode::flip(State& state, ssize_t i) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); // Variable should not be fixed. assert(lower_bound(i) != upper_bound(i)); // Assert that i is a valid index occurs in ptr->set(). @@ -697,31 +636,33 @@ void BinaryNode::flip(State& state, ssize_t i) const { // If value changed from 0 -> 1, update the bound axis sums by 1. // If value changed from 1 -> 0, update the bound axis sums by -1. update_bound_axis_slice_sums(state, i, (ptr->get(i) == 1) ? 1 : -1); - assert(satisfies_axis_wise_bounds(state)); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } } void BinaryNode::set(State& state, ssize_t i) const { + auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(upper_bound(i) == 1.0); - // Assert that i is a valid index occurs in data_ptr->set(). + // Assert that i is a valid index occurs in ptr->set(). // set() occurs IFF `value` != buffer[i]. - if (data_ptr(state)->set(i, 1.0)) { + if (ptr->set(i, 1.0)) { // If value changed from 0 -> 1, update the bound axis sums by 1. update_bound_axis_slice_sums(state, i, 1.0); - assert(satisfies_axis_wise_bounds(state)); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } } void BinaryNode::unset(State& state, ssize_t i) const { + auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(lower_bound(i) == 0.0); - // Assert that i is a valid index occurs in data_ptr->set(). + // Assert that i is a valid index occurs in ptr->set(). // set occurs IFF `value` != buffer[i]. - if (data_ptr(state)->set(i, 0.0)) { + if (ptr->set(i, 0.0)) { // If value changed from 1 -> 0, update the bound axis sums by -1. update_bound_axis_slice_sums(state, i, -1.0); - assert(satisfies_axis_wise_bounds(state)); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } } From 304859e428a3dc468c062ff59c56a6c29e0c8ca3 Mon Sep 17 00:00:00 2001 From: fastbodin Date: Wed, 28 Jan 2026 16:10:14 -0800 Subject: [PATCH 5/7] NumberNode: Construct state given exactly one axis-wise bound. Defined method to initialize_state() given exactly one axis-wise bound. Fill state with lower bounds and increment until state satisfies axis-wise bounds or determines infeasible. Added appropriate C++ IntegerNode and BinaryNode tests. --- .../dwave-optimization/nodes/numbers.hpp | 4 + dwave/optimization/src/nodes/numbers.cpp | 141 ++++++- tests/cpp/nodes/test_numbers.cpp | 396 +++++++++++++++++- 3 files changed, 521 insertions(+), 20 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index 2735c173..c42afdbf 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -97,6 +97,10 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // Initialize the state of the node randomly template void initialize_state(State& state, Generator& rng) const { + if (bound_axes_info_.size() > 0) { + throw std::invalid_argument("Cannot randomly initialize_state with bound axes"); + } + std::vector values; const ssize_t size = this->size(); values.reserve(size); diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index 574b8f59..6f30045f 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -25,6 +25,7 @@ #include "_state.hpp" #include "dwave-optimization/array.hpp" +#include "dwave-optimization/common.hpp" namespace dwave::optimization { @@ -91,15 +92,16 @@ double NumberNode::min() const { return min_; } double NumberNode::max() const { return max_; } -std::vector> get_bound_axes_sums( - const std::vector& number_data, const std::vector bound_axes_info, - std::span node_shape, std::span node_strides) { - assert(node_shape.size() == node_strides.size()); - assert(bound_axes_info.size() <= node_shape.size()); +std::vector> get_bound_axes_sums(const NumberNode* node, + const std::vector& number_data) { + std::span node_shape = node->shape(); + const std::vector& bound_axes_info = node->axis_wise_bounds(); + const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); + + assert(num_bound_axes <= node_shape.size()); assert(std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) == static_cast(number_data.size())); - const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); // For each bound axis, initialize the sum of the values contained in each // of it's hyperslice to 0. std::vector> bound_axes_sums; @@ -110,18 +112,18 @@ std::vector> get_bound_axes_sums( } // Define a BufferIterator for number_data (contiguous block of doubles) - // given the shape and strides of the NumberNode. - BufferIterator it(number_data.data(), node_shape, node_strides); + // given the shape and strides of NumberNode. + BufferIterator it(number_data.data(), node_shape, node->strides()); // Iterate over number_data. for (; it != std::default_sentinel; ++it) { - // Increment the appropriate slice in each bound axis. - for (ssize_t i = 0; i < num_bound_axes; ++i) { - const ssize_t axis = bound_axes_info[i].axis; + // Increment the appropriate hyperslice along each bound axis. + for (ssize_t bound_axis = 0; bound_axis < num_bound_axes; ++bound_axis) { + const ssize_t axis = bound_axes_info[bound_axis].axis; assert(0 <= axis && axis < it.location().size()); const ssize_t slice = it.location()[axis]; - assert(0 <= slice && slice < bound_axes_sums[i].size()); - bound_axes_sums[i][slice] += *it; + assert(0 <= slice && slice < bound_axes_sums[bound_axis].size()); + bound_axes_sums[bound_axis][slice] += *it; } } @@ -149,7 +151,7 @@ bool satisfies_axis_wise_bounds(const std::vector& bound_axes_inf if (bound_axis_sums[slice] < bound_axis_info.get_bound(slice)) return false; break; default: - throw std::invalid_argument("Invalid axis-wise bound operator"); + unreachable(); } } } @@ -172,8 +174,7 @@ void NumberNode::initialize_state(State& state, std::vector&& number_dat return; } - std::vector> bound_axes_sums = - get_bound_axes_sums(number_data, bound_axes_info_, this->shape(), this->strides()); + std::vector> bound_axes_sums = get_bound_axes_sums(this, number_data); if (!satisfies_axis_wise_bounds(bound_axes_info_, bound_axes_sums)) { throw std::invalid_argument("Initialized values do not satisfy axis-wise bounds."); @@ -183,13 +184,115 @@ void NumberNode::initialize_state(State& state, std::vector&& number_dat std::move(bound_axes_sums)); } +void construct_state_given_exactly_one_bound_axis(const NumberNode* node, + std::vector& values) { + const std::span node_shape = node->shape(); + const std::span node_strides = node->strides(); + assert(node_shape.size() == node_strides.size()); + const ssize_t ndim = node_shape.size(); + + // We need to construct a state that satisfies the axis wise bounds. + // First, initialize all elements to their lower bounds. + for (ssize_t i = 0, stop = node->size(); i < stop; ++i) { + values.push_back(node->lower_bound(i)); + } + // Second, determine the hyperslice sums for the bound axis. This could be + // done during the previous loop if we want to improve performance. + assert(node->axis_wise_bounds().size() == 1); + std::vector bound_axis_sums = get_bound_axes_sums(node, values)[0]; + const BoundAxisInfo& bound_axis_info = node->axis_wise_bounds()[0]; + const ssize_t bound_axis = bound_axis_info.axis; + assert(0 <= bound_axis && bound_axis < ndim); + // Iterator to the beginning of `values`. + BufferIterator values_begin(values.data(), ndim, node_shape.data(), + node_strides.data()); + // Offset used to perterb `values_begin` to the first element of the + // hyperslice along the given bound axis. + std::vector offset(ndim, 0); + + // Third, we iterate over each hyperslice and adjust its values until + // it satisfies the axis-wise bounds. + for (ssize_t slice = 0, stop = node_shape[bound_axis]; slice < stop; ++slice) { + // Determine the amount we need to adjust the initialized values by + // to satisfy the axis-wise bounds for the given hyperslice. + double delta = 0; + + switch (bound_axis_info.get_operator(slice)) { + case Equal: + if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) { + throw std::invalid_argument("Axis-wise bounds are infeasible."); + } + delta = bound_axis_info.get_bound(slice) - bound_axis_sums[slice]; + assert(delta >= 0); + // If error was not thrown, either (delta > 0) and (sum < + // bound) or (delta == 0) and (sum == bound). + break; + case LessEqual: + if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) { + throw std::invalid_argument("Axis-wise bounds are infeasible."); + } + // If error was not thrown, then (delta == 0) and (sum <= bound) + break; + case GreaterEqual: + if (bound_axis_sums[slice] < bound_axis_info.get_bound(slice)) { + delta = bound_axis_info.get_bound(slice) - bound_axis_sums[slice]; + } + assert(delta >= 0); + // Either (delta == 0) and (sum >= bound) or (delta > 0) and + // (sum < bound). + break; + default: + unreachable(); + } + + if (delta == 0) continue; // axis-wise bounds are satisfied for slice. + + // Define iterator to the cannonically least index in the given slice + // along the bound axis. + offset[bound_axis] = slice; + BufferIterator it = values_begin + offset; + + // Iterate over all remaining elements in values. + for (; it != std::default_sentinel_t(); ++it) { + // Only consider values that fall in the slice. + if (it.location()[bound_axis] != slice) continue; + + // Determine the index of `it` from `values_begin` + const ssize_t index = static_cast(it - values_begin); + assert(0 <= index && index < values.size()); + // Determine the amount we can increment the value in the given index. + ssize_t inc = std::min(delta, node->upper_bound(index) - *it); + + if (inc > 0) { // Apply the increment to both `it` and `delta`. + *it += inc; + delta -= inc; + if (delta == 0) break; // Axis-wise bounds are now satisfied for slice. + } + } + + if (delta != 0) { + throw std::invalid_argument("Axis-wise bounds are infeasible."); + } + } +} + void NumberNode::initialize_state(State& state) const { std::vector values; values.reserve(this->size()); - for (ssize_t i = 0, stop = this->size(); i < stop; ++i) { - values.push_back(default_value(i)); + + if (bound_axes_info_.size() == 0) { // No bound axes to consider + for (ssize_t i = 0, stop = this->size(); i < stop; ++i) { + values.push_back(default_value(i)); + } + initialize_state(state, std::move(values)); + return; } - /// Set all to mins + + if (bound_axes_info_.size() != 1) { + throw std::invalid_argument("Cannot initialize state with multiple bound axes."); + } + + construct_state_given_exactly_one_bound_axis(this, values); initialize_state(state, std::move(values)); } diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index 20e58bf3..53df1eae 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -567,7 +567,206 @@ TEST_CASE("BinaryNode") { "Axis-wise bounds are supported for at most one axis."); } - GIVEN("(2x3x4)-BinaryNode with an axis-wise bound on axis: 0") { + GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { + BoundAxisInfo bound_axis{1, std::vector{Equal}, + std::vector{0.1}}; + REQUIRE_THROWS_WITH(graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, + std::nullopt, std::vector{bound_axis}), + "Axis wise bounds for integral number arrays must be intregral."); + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 0") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual, GreaterEqual}, + std::vector{5.0, 2.0, 3.0}}; + + // Each hyperslice along axis 0 has size 4. There is no feasible + // assignment to the values in slice 0 (along axis 0) that results in a + // sum equal to 5. + graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, + std::vector{bound_axis}); + + WHEN("We create a state by initialize_state()") { + REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + } + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 1") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{1, std::vector{Equal, GreaterEqual}, + std::vector{5.0, 7.0}}; + + graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, + std::vector{bound_axis}); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 1 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 1) that results in a + // sum greater than or equal to 7. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + } + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 2") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{2, std::vector{Equal, LessEqual}, + std::vector{5.0, -1.0}}; + + graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, + std::vector{bound_axis}); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 2 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 2) that results in a + // sum less than or equal to -1. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 0") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual, GreaterEqual}, + std::vector{1.0, 2.0, 3.0}}; + + auto bnode_ptr = graph.emplace_node( + std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, + std::vector{bound_axis}); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axis.axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[0, :, :].flatten()) + // ... [0 1 2 3] + // print(a[1, :, :].flatten()) + // ... [4 5 6 7] + // print(a[2, :, :].flatten()) + // ... [ 8 9 10 11] + std::vector expected_init{1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0}; + // Cannonically least state that satisfies bounds + // slice 0 slice 1 slice 2 + // 1, 0 0, 0 1, 1 + // 0, 0 0, 0 1, 0 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 0, 3})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 1") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{1, std::vector{LessEqual, GreaterEqual}, + std::vector{1.0, 5.0}}; + + auto bnode_ptr = graph.emplace_node( + std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, + std::vector{bound_axis}); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axis.axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[:, 0, :].flatten()) + // ... [0 1 4 5 8 9] + // print(a[:, 1, :].flatten()) + // ... [ 2 3 6 7 10 11] + std::vector expected_init{0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0}; + // Cannonically least state that satisfies bounds + // slice 0 slice 1 + // 0, 0 1, 1 + // 0, 0 1, 1 + // 0, 0 1, 0 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({0, 5})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 2") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{2, std::vector{Equal, GreaterEqual}, + std::vector{3.0, 6.0}}; + + auto bnode_ptr = graph.emplace_node( + std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, + std::vector{bound_axis}); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axis.axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[:, :, 0].flatten()) + // ... [ 0 2 4 6 8 10] + // print(a[:, :, 1].flatten()) + // ... [ 1 3 5 7 9 11] + std::vector expected_init{1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1}; + // Cannonically least state that satisfies bounds + // slice 0 slice 1 + // 1, 1 1, 1 + // 1, 0 1, 1 + // 0, 0 1, 1 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({3, 6})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with an axis-wise bound on axis: 0") { auto graph = Graph(); BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual, GreaterEqual}, @@ -1211,6 +1410,201 @@ TEST_CASE("IntegerNode") { "Axis wise bounds for integral number arrays must be intregral."); } + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 0") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual}, + std::vector{5.0, -31.0}}; + + graph.emplace_node( + std::initializer_list{2, 3, 2}, -5, 8, + std::vector{bound_axis}); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 0 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 0) that results in a + // sum less than or equal to -5*6-1 = -31. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + } + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 1") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{1, std::vector{GreaterEqual, Equal, Equal}, + std::vector{33.0, 0.0, 0.0}}; + + graph.emplace_node( + std::initializer_list{2, 3, 2}, -5, 8, + std::vector{bound_axis}); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 1 has size 4. There is no feasible + // assignment to the values in slice 0 (along axis 1) that results in a + // sum greater than or equal to 4*8+1 = 33. + REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + } + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 2") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{2, std::vector{GreaterEqual, Equal}, + std::vector{-1.0, 49.0}}; + + graph.emplace_node( + std::initializer_list{2, 3, 2}, -5, 8, + std::vector{bound_axis}); + + WHEN("We create a state by initialize_state()") { + // Each hyperslice along axis 2 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 2) that results in a + // sum or equal to 6*8+1 = 49 + REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 0") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{0, std::vector{Equal, GreaterEqual}, + std::vector{-21.0, 9.0}}; + + auto bnode_ptr = graph.emplace_node( + std::initializer_list{2, 3, 2}, -5, 8, + std::vector{bound_axis}); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axis.axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[0, :, :].flatten()) + // ... [0 1 2 3 4 5] + // print(a[1, :, :].flatten()) + // ... [ 6 7 8 9 10 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 + // [4, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 1 + // [4, -5, -5, -5, -5, -5, 8, 8, 8, -5, -5, -5] + std::vector expected_init{4, -5, -5, -5, -5, -5, 8, 8, 8, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({-21.0, 9.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 1") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{1, std::vector{Equal, GreaterEqual, LessEqual}, + std::vector{0.0, -2.0, 0.0}}; + + auto bnode_ptr = graph.emplace_node( + std::initializer_list{2, 3, 2}, -5, 8, + std::vector{bound_axis}); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axis.axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[:, 0, :].flatten()) + // ... [0 1 6 7] + // print(a[:, 1, :].flatten()) + // ... [2 3 8 9] + // print(a[:, 2, :].flatten()) + // ... [ 4 5 10 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 w/ [8, 2, -5, -5] + // [8, 2, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 1 w/ [8, 0, -5, -5] + // [8, 2, 8, 0, -5, -5, -5, -5, -5, -5, -5, -5] + // no need to repair slice 2 + std::vector expected_init{8, 2, 8, 0, -5, -5, -5, -5, -5, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({0.0, -2.0, -20.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 2") { + auto graph = Graph(); + + BoundAxisInfo bound_axis{2, std::vector{Equal, GreaterEqual}, + std::vector{23.0, 14.0}}; + + auto bnode_ptr = graph.emplace_node( + std::initializer_list{2, 3, 2}, -5, 8, + std::vector{bound_axis}); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axis.axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[:, :, 0].flatten()) + // ... [ 0 2 4 6 8 10] + // print(a[:, :, 0].flatten()) + // ... [ 1 3 5 7 9 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 w/ [8, 8, 8, 8, -4, -5] + // [8, -5, 8, -5, 8, -5, 8, -5, -4, -5, -5, -5] + // repair slice 0 w/ [8, 8, 8, 0, -5, -5] + // [8, 8, 8, 8, 8, 8, 8, 0, -4, -5, -5, -5] + std::vector expected_init{8, 8, 8, 8, 8, 8, 8, 0, -4, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({23.0, 14.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + GIVEN("(2x3x2)-IntegerNode with index-wise bounds and an axis-wise bound on axis: 1") { auto graph = Graph(); From 1b1df7ce2db0723996132f7777e9ed3954256a53 Mon Sep 17 00:00:00 2001 From: fastbodin Date: Thu, 29 Jan 2026 15:35:47 -0800 Subject: [PATCH 6/7] Improve NumberNode bound axes Made BoundAxisInfo and BoundAxisOperators members of NumberNode. Updated all C++ tests. Optimized BufferIterator use in initialize_state(). --- .../dwave-optimization/nodes/numbers.hpp | 57 ++- dwave/optimization/src/nodes/numbers.cpp | 133 +++--- tests/cpp/nodes/test_numbers.cpp | 412 +++++++++++------- 3 files changed, 349 insertions(+), 253 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index c42afdbf..bf96fdad 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -25,38 +25,37 @@ namespace dwave::optimization { -/// Allowable axis-wise bound operators. -enum BoundAxisOperator { Equal, LessEqual, GreaterEqual }; - -/// Class for stateless axis-wise bound information. Given an `axis`, define -/// constraints on the sum of the values in each slice along `axis`. -/// Constraints can be defined for ALL slices along `axis` or PER slice along -/// `axis`. Allowable operators are defined by `BoundAxisOperator`. -class BoundAxisInfo { - public: - /// To reduce the # of `IntegerNode` and `BinaryNode` constructors, we - /// allow only one constructor. - BoundAxisInfo(ssize_t axis, std::vector axis_operators, - std::vector axis_bounds); - /// The bound axis - const ssize_t axis; - /// Operator for ALL axis slices (vector has length one) or operator*s* PER - /// slice (length of vector is equal to the number of slices). - const std::vector operators; - /// Bound for ALL axis slices (vector has length one) or bound*s* PER slice - /// (length of vector is equal to the number of slices). - const std::vector bounds; - - /// Obtain the bound associated with a given slice along bound axis. - double get_bound(const ssize_t slice) const; - - /// Obtain the operator associated with a given slice along bound axis. - BoundAxisOperator get_operator(const ssize_t slice) const; -}; - /// A contiguous block of numbers. class NumberNode : public ArrayOutputMixin, public DecisionNode { public: + /// Allowable axis-wise bound operators. + enum BoundAxisOperator { Equal, LessEqual, GreaterEqual }; + + /// Struct for stateless axis-wise bound information. Given an `axis`, define + /// constraints on the sum of the values in each slice along `axis`. + /// Constraints can be defined for ALL slices along `axis` or PER slice along + /// `axis`. Allowable operators are defined by `BoundAxisOperator`. + struct BoundAxisInfo { + /// To reduce the # of `IntegerNode` and `BinaryNode` constructors, we + /// allow only one constructor. + BoundAxisInfo(ssize_t axis, std::vector axis_operators, + std::vector axis_bounds); + /// The bound axis + const ssize_t axis; + /// Operator for ALL axis slices (vector has length one) or operator*s* PER + /// slice (length of vector is equal to the number of slices). + const std::vector operators; + /// Bound for ALL axis slices (vector has length one) or bound*s* PER slice + /// (length of vector is equal to the number of slices). + const std::vector bounds; + + /// Obtain the bound associated with a given slice along bound axis. + double get_bound(const ssize_t slice) const; + + /// Obtain the operator associated with a given slice along bound axis. + BoundAxisOperator get_operator(const ssize_t slice) const; + }; + NumberNode() = delete; // Overloads needed by the Array ABC ************************************** diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index 6f30045f..cb6108d7 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -29,8 +29,9 @@ namespace dwave::optimization { -BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, std::vector axis_operators, - std::vector axis_bounds) +NumberNode::BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, + std::vector axis_operators, + std::vector axis_bounds) : axis(bound_axis), operators(std::move(axis_operators)), bounds(std::move(axis_bounds)) { const ssize_t num_operators = static_cast(operators.size()); const ssize_t num_bounds = static_cast(bounds.size()); @@ -50,14 +51,14 @@ BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, std::vector } } -double BoundAxisInfo::get_bound(const ssize_t slice) const { +double NumberNode::BoundAxisInfo::get_bound(const ssize_t slice) const { assert(0 <= slice); if (bounds.size() == 0) return bounds[0]; assert(slice < static_cast(bounds.size())); return bounds[slice]; } -BoundAxisOperator BoundAxisInfo::get_operator(const ssize_t slice) const { +NumberNode::BoundAxisOperator NumberNode::BoundAxisInfo::get_operator(const ssize_t slice) const { assert(0 <= slice); if (operators.size() == 0) return operators[0]; assert(slice < static_cast(operators.size())); @@ -95,7 +96,7 @@ double NumberNode::max() const { return max_; } std::vector> get_bound_axes_sums(const NumberNode* node, const std::vector& number_data) { std::span node_shape = node->shape(); - const std::vector& bound_axes_info = node->axis_wise_bounds(); + const std::vector& bound_axes_info = node->axis_wise_bounds(); const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); assert(num_bound_axes <= node_shape.size()); @@ -106,7 +107,7 @@ std::vector> get_bound_axes_sums(const NumberNode* node, // of it's hyperslice to 0. std::vector> bound_axes_sums; bound_axes_sums.reserve(num_bound_axes); - for (const BoundAxisInfo& axis_info : bound_axes_info) { + for (const NumberNode::BoundAxisInfo& axis_info : bound_axes_info) { assert(0 <= axis_info.axis && axis_info.axis < static_cast(node_shape.size())); bound_axes_sums.emplace_back(node_shape[axis_info.axis], 0.0); } @@ -130,24 +131,24 @@ std::vector> get_bound_axes_sums(const NumberNode* node, return bound_axes_sums; } -bool satisfies_axis_wise_bounds(const std::vector& bound_axes_info, +bool satisfies_axis_wise_bounds(const std::vector& bound_axes_info, const std::vector>& bound_axes_sums) { assert(bound_axes_info.size() == bound_axes_sums.size()); // Check that each hyperslice satisfies the axis-wise bounds. for (ssize_t i = 0, stop_i = static_cast(bound_axes_info.size()); i < stop_i; ++i) { const std::vector& bound_axis_sums = bound_axes_sums[i]; - const BoundAxisInfo& bound_axis_info = bound_axes_info[i]; + const NumberNode::BoundAxisInfo& bound_axis_info = bound_axes_info[i]; for (ssize_t slice = 0, stop_slice = static_cast(bound_axis_sums.size()); slice < stop_slice; ++slice) { switch (bound_axis_info.get_operator(slice)) { - case Equal: + case NumberNode::Equal: if (bound_axis_sums[slice] != bound_axis_info.get_bound(slice)) return false; break; - case LessEqual: + case NumberNode::LessEqual: if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) return false; break; - case GreaterEqual: + case NumberNode::GreaterEqual: if (bound_axis_sums[slice] < bound_axis_info.get_bound(slice)) return false; break; default: @@ -184,11 +185,42 @@ void NumberNode::initialize_state(State& state, std::vector&& number_dat std::move(bound_axes_sums)); } +std::vector reorder_span(const std::span span, const ssize_t axis) { + std::vector output; + const ssize_t ndim = span.size(); + output.reserve(ndim); + output.emplace_back(span[axis]); + for (ssize_t i = 0; i < ndim; ++i) { + if (i == axis) continue; + output.emplace_back(span[i]); + } + return output; +} + +double compute_bound_axis_slice_delta(const ssize_t slice, const double sum, + const NumberNode::BoundAxisOperator op, const double bound) { + switch (op) { + case NumberNode::Equal: + if (sum > bound) throw std::invalid_argument("Infeasible axis-wise bounds."); + // If error was not thrown, return amount needed to satisfy bound. + return bound - sum; + case NumberNode::LessEqual: + if (sum > bound) throw std::invalid_argument("Infeasible axis-wise bounds."); + // If error was not thrown, sum satisfies bound. + return 0.0; + case NumberNode::GreaterEqual: + // If sum is less than bound, return the amount needed to equal it. + if (sum < bound) return bound - sum; + // Otherwise, sum satisfies bound. + return 0.0; + default: + unreachable(); + } +} + void construct_state_given_exactly_one_bound_axis(const NumberNode* node, std::vector& values) { const std::span node_shape = node->shape(); - const std::span node_strides = node->strides(); - assert(node_shape.size() == node_strides.size()); const ssize_t ndim = node_shape.size(); // We need to construct a state that satisfies the axis wise bounds. @@ -200,12 +232,20 @@ void construct_state_given_exactly_one_bound_axis(const NumberNode* node, // done during the previous loop if we want to improve performance. assert(node->axis_wise_bounds().size() == 1); std::vector bound_axis_sums = get_bound_axes_sums(node, values)[0]; - const BoundAxisInfo& bound_axis_info = node->axis_wise_bounds()[0]; + + const NumberNode::BoundAxisInfo& bound_axis_info = node->axis_wise_bounds()[0]; const ssize_t bound_axis = bound_axis_info.axis; assert(0 <= bound_axis && bound_axis < ndim); + // Iterator to the beginning of `values`. - BufferIterator values_begin(values.data(), ndim, node_shape.data(), - node_strides.data()); + std::vector slice_shape = reorder_span(node_shape, bound_axis); + std::vector slice_strides = reorder_span(node->strides(), bound_axis); + BufferIterator values_begin(values.data(), ndim, slice_shape.data(), + slice_strides.data()); + std::vector one_more(ndim, 0); + one_more[0] = 1; + auto values_next = values_begin + one_more; + // Offset used to perterb `values_begin` to the first element of the // hyperslice along the given bound axis. std::vector offset(ndim, 0); @@ -215,47 +255,19 @@ void construct_state_given_exactly_one_bound_axis(const NumberNode* node, for (ssize_t slice = 0, stop = node_shape[bound_axis]; slice < stop; ++slice) { // Determine the amount we need to adjust the initialized values by // to satisfy the axis-wise bounds for the given hyperslice. - double delta = 0; - - switch (bound_axis_info.get_operator(slice)) { - case Equal: - if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) { - throw std::invalid_argument("Axis-wise bounds are infeasible."); - } - delta = bound_axis_info.get_bound(slice) - bound_axis_sums[slice]; - assert(delta >= 0); - // If error was not thrown, either (delta > 0) and (sum < - // bound) or (delta == 0) and (sum == bound). - break; - case LessEqual: - if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) { - throw std::invalid_argument("Axis-wise bounds are infeasible."); - } - // If error was not thrown, then (delta == 0) and (sum <= bound) - break; - case GreaterEqual: - if (bound_axis_sums[slice] < bound_axis_info.get_bound(slice)) { - delta = bound_axis_info.get_bound(slice) - bound_axis_sums[slice]; - } - assert(delta >= 0); - // Either (delta == 0) and (sum >= bound) or (delta > 0) and - // (sum < bound). - break; - default: - unreachable(); - } + double delta = compute_bound_axis_slice_delta(slice, bound_axis_sums[slice], + bound_axis_info.get_operator(slice), + bound_axis_info.get_bound(slice)); + assert(delta >= 0); if (delta == 0) continue; // axis-wise bounds are satisfied for slice. + offset[0] = slice; // Define iterator to the cannonically least index in the given slice // along the bound axis. - offset[bound_axis] = slice; - BufferIterator it = values_begin + offset; - - // Iterate over all remaining elements in values. - for (; it != std::default_sentinel_t(); ++it) { + for (auto it = values_begin + offset, it_end = values_next + offset; it != it_end; ++it) { // Only consider values that fall in the slice. - if (it.location()[bound_axis] != slice) continue; + assert(it.location()[0] == slice); // Determine the index of `it` from `values_begin` const ssize_t index = static_cast(it - values_begin); @@ -266,13 +278,11 @@ void construct_state_given_exactly_one_bound_axis(const NumberNode* node, if (inc > 0) { // Apply the increment to both `it` and `delta`. *it += inc; delta -= inc; - if (delta == 0) break; // Axis-wise bounds are now satisfied for slice. + if (delta == 0) break; // Axis-wise bounds are now satisfied for slice. } } - if (delta != 0) { - throw std::invalid_argument("Axis-wise bounds are infeasible."); - } + if (delta != 0) throw std::invalid_argument("Infeasible axis-wise bounds."); } } @@ -379,7 +389,9 @@ void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) c } } -const std::vector& NumberNode::axis_wise_bounds() const { return bound_axes_info_; } +const std::vector& NumberNode::axis_wise_bounds() const { + return bound_axes_info_; +} const std::vector>& NumberNode::bound_axis_sums(State& state) const { return data_ptr(state)->bound_axes_sums; @@ -426,7 +438,7 @@ void check_index_wise_bounds(const NumberNode& node, const std::vector& } /// Check the user defined axis-wise bounds for NumberNode -void check_axis_wise_bounds(const std::vector& bound_axes_info, +void check_axis_wise_bounds(const std::vector& bound_axes_info, const std::span shape) { if (bound_axes_info.size() == 0) return; // No bound axes to check. @@ -434,7 +446,7 @@ void check_axis_wise_bounds(const std::vector& bound_axes_info, std::vector axis_bound(shape.size(), false); // For each set of bound axis data - for (const BoundAxisInfo& bound_axis_info : bound_axes_info) { + for (const NumberNode::BoundAxisInfo& bound_axis_info : bound_axes_info) { const ssize_t axis = bound_axis_info.axis; if (axis < 0 || axis >= static_cast(shape.size())) { @@ -524,10 +536,11 @@ void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index, // Integer Node *************************************************************** /// Check the user defined axis-wise bounds for IntegerNode -void check_integrality_of_axis_wise_bounds(const std::vector& bound_axes_info) { +void check_integrality_of_axis_wise_bounds( + const std::vector& bound_axes_info) { if (bound_axes_info.size() == 0) return; // No bound axes to check. - for (const BoundAxisInfo& bound_axis_info : bound_axes_info) { + for (const NumberNode::BoundAxisInfo& bound_axis_info : bound_axes_info) { for (const double& bound : bound_axis_info.bounds) { if (bound != std::round(bound)) { throw std::invalid_argument( diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index 53df1eae..3116c094 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -29,42 +29,54 @@ namespace dwave::optimization { TEST_CASE("BoundAxisInfo") { GIVEN("BoundAxisInfo(axis = 0, operators = {}, bounds = {1.0})") { REQUIRE_THROWS_WITH( - BoundAxisInfo(0, std::vector{}, std::vector{1.0}), + NumberNode::BoundAxisInfo(0, std::vector{}, + std::vector{1.0}), "Bad axis-wise bounds for axis: 0, `operators` and `bounds` must each have " "non-zero size."); } GIVEN("BoundAxisInfo(axis = 0, operators = {<=}, bounds = {})") { REQUIRE_THROWS_WITH( - BoundAxisInfo(0, std::vector{LessEqual}, std::vector{}), + NumberNode::BoundAxisInfo(0, + std::vector{ + NumberNode::NumberNode::LessEqual}, + std::vector{}), "Bad axis-wise bounds for axis: 0, `operators` and `bounds` must each have " "non-zero size."); } GIVEN("BoundAxisInfo(axis = 1, operators = {<=, ==, ==}, bounds = {2.0, 1.0})") { REQUIRE_THROWS_WITH( - BoundAxisInfo(1, std::vector{LessEqual, Equal, Equal}, - std::vector{2.0, 1.0}), + NumberNode::BoundAxisInfo( + 1, + std::vector{ + NumberNode::LessEqual, NumberNode::Equal, NumberNode::Equal}, + std::vector{2.0, 1.0}), "Bad axis-wise bounds for axis: 1, `operators` and `bounds` should have same size " "if neither has size 1."); } GIVEN("BoundAxisInfo(axis = 2, operators = {==}, bounds = {1.0})") { - BoundAxisInfo bound_axis(2, std::vector{Equal}, - std::vector{1.0}); + NumberNode::BoundAxisInfo bound_axis( + 2, std::vector{NumberNode::Equal}, + std::vector{1.0}); THEN("The bound axis info is correct") { CHECK(bound_axis.axis == 2); - CHECK_THAT(bound_axis.operators, RangeEquals({Equal})); + CHECK_THAT(bound_axis.operators, RangeEquals({NumberNode::Equal})); CHECK_THAT(bound_axis.bounds, RangeEquals({1.0})); } } GIVEN("BoundAxisInfo(axis = 2, operators = {==, <=, >=}, bounds = {1.0, 2.0, 3.0})") { - BoundAxisInfo bound_axis(2, std::vector{Equal, LessEqual, GreaterEqual}, - std::vector{1.0, 2.0, 3.0}); + NumberNode::BoundAxisInfo bound_axis( + 2, + std::vector{NumberNode::Equal, NumberNode::LessEqual, + NumberNode::GreaterEqual}, + std::vector{1.0, 2.0, 3.0}); THEN("The bound axis info is correct") { CHECK(bound_axis.axis == 2); - CHECK_THAT(bound_axis.operators, RangeEquals({Equal, LessEqual, GreaterEqual})); + CHECK_THAT(bound_axis.operators, RangeEquals({NumberNode::Equal, NumberNode::LessEqual, + NumberNode::GreaterEqual})); CHECK_THAT(bound_axis.bounds, RangeEquals({1.0, 2.0, 3.0})); } } @@ -488,161 +500,189 @@ TEST_CASE("BinaryNode") { } GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis -1") { - BoundAxisInfo bound_axis{-1, std::vector{Equal}, - std::vector{1.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid bound axis: -1. Note, negative indexing is not supported for " - "axis-wise bounds."); + NumberNode::BoundAxisInfo bound_axis{ + -1, std::vector{NumberNode::Equal}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid bound axis: -1. Note, negative indexing is not supported for " + "axis-wise bounds."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis 2") { - BoundAxisInfo bound_axis{2, std::vector{Equal}, - std::vector{1.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid bound axis: 2. Note, negative indexing is not supported for " - "axis-wise bounds."); + NumberNode::BoundAxisInfo bound_axis{ + 2, std::vector{NumberNode::Equal}, + std::vector{1.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid bound axis: 2. Note, negative indexing is not supported for " + "axis-wise bounds."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many operators.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal, Equal, Equal}, - std::vector{1.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, + std::vector{NumberNode::LessEqual, NumberNode::Equal, + NumberNode::Equal, NumberNode::Equal}, + std::vector{1.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), + std::vector{bound_axis}), "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few operators.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal}, - std::vector{1.0}}; + NumberNode::BoundAxisInfo bound_axis{1, + std::vector{ + NumberNode::LessEqual, NumberNode::Equal}, + std::vector{1.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), + std::vector{bound_axis}), "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many bounds.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual}, - std::vector{1.0, 2.0, 3.0, 4.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + NumberNode::BoundAxisInfo bound_axis{ + 1, std::vector{NumberNode::LessEqual}, + std::vector{1.0, 2.0, 3.0, 4.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few bounds.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual}, - std::vector{1.0, 2.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + NumberNode::BoundAxisInfo bound_axis{ + 1, std::vector{NumberNode::LessEqual}, + std::vector{1.0, 2.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); } GIVEN("(2x3)-BinaryNode with duplicate axis-wise bounds on axis: 1") { - BoundAxisInfo bound_axis{1, std::vector{Equal}, - std::vector{1.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, std::vector{NumberNode::Equal}, + std::vector{1.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis, bound_axis}), + std::vector{bound_axis, bound_axis}), "Cannot define multiple axis-wise bounds for a single axis."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axes: 0 and 1") { - BoundAxisInfo bound_axis_0{0, std::vector{LessEqual}, - std::vector{1.0}}; - BoundAxisInfo bound_axis_1{1, std::vector{LessEqual}, - std::vector{1.0}}; + NumberNode::BoundAxisInfo bound_axis_0{ + 0, std::vector{NumberNode::LessEqual}, + std::vector{1.0}}; + NumberNode::BoundAxisInfo bound_axis_1{ + 1, std::vector{NumberNode::LessEqual}, + std::vector{1.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis_0, bound_axis_1}), + std::vector{bound_axis_0, bound_axis_1}), "Axis-wise bounds are supported for at most one axis."); } GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { - BoundAxisInfo bound_axis{1, std::vector{Equal}, - std::vector{0.1}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Axis wise bounds for integral number arrays must be intregral."); + NumberNode::BoundAxisInfo bound_axis{ + 1, std::vector{NumberNode::Equal}, + std::vector{0.1}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Axis wise bounds for integral number arrays must be intregral."); } GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 0") { auto graph = Graph(); - BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual, GreaterEqual}, - std::vector{5.0, 2.0, 3.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 0, + std::vector{NumberNode::Equal, NumberNode::LessEqual, + NumberNode::GreaterEqual}, + std::vector{5.0, 2.0, 3.0}}; // Each hyperslice along axis 0 has size 4. There is no feasible // assignment to the values in slice 0 (along axis 0) that results in a // sum equal to 5. - graph.emplace_node(std::initializer_list{3, 2, 2}, - std::nullopt, std::nullopt, - std::vector{bound_axis}); + graph.emplace_node( + std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, + std::vector{bound_axis}); WHEN("We create a state by initialize_state()") { - REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); } } GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 1") { auto graph = Graph(); - BoundAxisInfo bound_axis{1, std::vector{Equal, GreaterEqual}, - std::vector{5.0, 7.0}}; + NumberNode::BoundAxisInfo bound_axis{1, + std::vector{ + NumberNode::Equal, NumberNode::GreaterEqual}, + std::vector{5.0, 7.0}}; - graph.emplace_node(std::initializer_list{3, 2, 2}, - std::nullopt, std::nullopt, - std::vector{bound_axis}); + graph.emplace_node( + std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, + std::vector{bound_axis}); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 1 has size 6. There is no feasible // assignment to the values in slice 1 (along axis 1) that results in a // sum greater than or equal to 7. - REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); } } GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 2") { auto graph = Graph(); - BoundAxisInfo bound_axis{2, std::vector{Equal, LessEqual}, - std::vector{5.0, -1.0}}; + NumberNode::BoundAxisInfo bound_axis{2, + std::vector{ + NumberNode::Equal, NumberNode::LessEqual}, + std::vector{5.0, -1.0}}; - graph.emplace_node(std::initializer_list{3, 2, 2}, - std::nullopt, std::nullopt, - std::vector{bound_axis}); + graph.emplace_node( + std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, + std::vector{bound_axis}); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 2 has size 6. There is no feasible // assignment to the values in slice 1 (along axis 2) that results in a // sum less than or equal to -1. - REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); } } GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 0") { auto graph = Graph(); - BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual, GreaterEqual}, - std::vector{1.0, 2.0, 3.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 0, + std::vector{NumberNode::Equal, NumberNode::LessEqual, + NumberNode::GreaterEqual}, + std::vector{1.0, 2.0, 3.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); - BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; CHECK(bound_axis.axis == bnode_bound_axis.axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); @@ -679,16 +719,19 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 1") { auto graph = Graph(); - BoundAxisInfo bound_axis{1, std::vector{LessEqual, GreaterEqual}, - std::vector{1.0, 5.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, + std::vector{NumberNode::LessEqual, + NumberNode::GreaterEqual}, + std::vector{1.0, 5.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); - BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; CHECK(bound_axis.axis == bnode_bound_axis.axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); @@ -724,16 +767,18 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 2") { auto graph = Graph(); - BoundAxisInfo bound_axis{2, std::vector{Equal, GreaterEqual}, - std::vector{3.0, 6.0}}; + NumberNode::BoundAxisInfo bound_axis{2, + std::vector{ + NumberNode::Equal, NumberNode::GreaterEqual}, + std::vector{3.0, 6.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); - BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; CHECK(bound_axis.axis == bnode_bound_axis.axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); @@ -769,16 +814,19 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with an axis-wise bound on axis: 0") { auto graph = Graph(); - BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual, GreaterEqual}, - std::vector{1.0, 2.0, 3.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 0, + std::vector{NumberNode::Equal, NumberNode::LessEqual, + NumberNode::GreaterEqual}, + std::vector{1.0, 2.0, 3.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); - BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; CHECK(bound_axis.axis == bnode_bound_axis.axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); @@ -1322,161 +1370,188 @@ TEST_CASE("IntegerNode") { } GIVEN("(2x3)-IntegerNode with axis-wise bounds on the invalid axis -2") { - BoundAxisInfo bound_axis{-2, std::vector{Equal}, - std::vector{20.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid bound axis: -2. Note, negative indexing is not supported for " - "axis-wise bounds."); + NumberNode::BoundAxisInfo bound_axis{ + -2, std::vector{NumberNode::Equal}, + std::vector{20.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid bound axis: -2. Note, negative indexing is not supported for " + "axis-wise bounds."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on the invalid axis 3") { - BoundAxisInfo bound_axis{3, std::vector{Equal}, - std::vector{10.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid bound axis: 3. Note, negative indexing is not supported for " - "axis-wise bounds."); + NumberNode::BoundAxisInfo bound_axis{ + 3, std::vector{NumberNode::Equal}, + std::vector{10.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid bound axis: 3. Note, negative indexing is not supported for " + "axis-wise bounds."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many operators.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal, Equal, Equal}, - std::vector{-10.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, + std::vector{NumberNode::LessEqual, NumberNode::Equal, + NumberNode::Equal, NumberNode::Equal}, + std::vector{-10.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), + std::vector{bound_axis}), "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few operators.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual, Equal}, - std::vector{-11.0}}; + NumberNode::BoundAxisInfo bound_axis{1, + std::vector{ + NumberNode::LessEqual, NumberNode::Equal}, + std::vector{-11.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), + std::vector{bound_axis}), "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many bounds.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual}, - std::vector{-10.0, 20.0, 30.0, 40.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + NumberNode::BoundAxisInfo bound_axis{ + 1, std::vector{NumberNode::LessEqual}, + std::vector{-10.0, 20.0, 30.0, 40.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few bounds.") { - BoundAxisInfo bound_axis{1, std::vector{LessEqual}, - std::vector{111.0, -223.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + NumberNode::BoundAxisInfo bound_axis{ + 1, std::vector{NumberNode::LessEqual}, + std::vector{111.0, -223.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); } GIVEN("(2x3x4)-IntegerNode with duplicate axis-wise bounds on axis: 1") { - BoundAxisInfo bound_axis{1, std::vector{Equal}, - std::vector{100.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, std::vector{NumberNode::Equal}, + std::vector{100.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis, bound_axis}), + std::vector{bound_axis, bound_axis}), "Cannot define multiple axis-wise bounds for a single axis."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axes: 0 and 1") { - BoundAxisInfo bound_axis_0{0, std::vector{LessEqual}, - std::vector{11.0}}; - BoundAxisInfo bound_axis_1{1, std::vector{LessEqual}, - std::vector{12.0}}; + NumberNode::BoundAxisInfo bound_axis_0{ + 0, std::vector{NumberNode::LessEqual}, + std::vector{11.0}}; + NumberNode::BoundAxisInfo bound_axis_1{ + 1, std::vector{NumberNode::LessEqual}, + std::vector{12.0}}; REQUIRE_THROWS_WITH( graph.emplace_node( std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis_0, bound_axis_1}), + std::vector{bound_axis_0, bound_axis_1}), "Axis-wise bounds are supported for at most one axis."); } GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { - BoundAxisInfo bound_axis{2, std::vector{LessEqual}, - std::vector{11.0, 12.0001, 0.0, 0.0}}; - REQUIRE_THROWS_WITH(graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, - std::nullopt, std::vector{bound_axis}), - "Axis wise bounds for integral number arrays must be intregral."); + NumberNode::BoundAxisInfo bound_axis{ + 2, std::vector{NumberNode::LessEqual}, + std::vector{11.0, 12.0001, 0.0, 0.0}}; + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis}), + "Axis wise bounds for integral number arrays must be intregral."); } GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 0") { auto graph = Graph(); - BoundAxisInfo bound_axis{0, std::vector{Equal, LessEqual}, - std::vector{5.0, -31.0}}; + NumberNode::BoundAxisInfo bound_axis{0, + std::vector{ + NumberNode::Equal, NumberNode::LessEqual}, + std::vector{5.0, -31.0}}; graph.emplace_node( std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector{bound_axis}); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 0 has size 6. There is no feasible // assignment to the values in slice 1 (along axis 0) that results in a // sum less than or equal to -5*6-1 = -31. - REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); } } GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 1") { auto graph = Graph(); - BoundAxisInfo bound_axis{1, std::vector{GreaterEqual, Equal, Equal}, - std::vector{33.0, 0.0, 0.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, + std::vector{NumberNode::GreaterEqual, + NumberNode::Equal, NumberNode::Equal}, + std::vector{33.0, 0.0, 0.0}}; graph.emplace_node( std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector{bound_axis}); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 1 has size 4. There is no feasible // assignment to the values in slice 0 (along axis 1) that results in a // sum greater than or equal to 4*8+1 = 33. - REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); } } GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 2") { auto graph = Graph(); - BoundAxisInfo bound_axis{2, std::vector{GreaterEqual, Equal}, - std::vector{-1.0, 49.0}}; + NumberNode::BoundAxisInfo bound_axis{2, + std::vector{ + NumberNode::GreaterEqual, NumberNode::Equal}, + std::vector{-1.0, 49.0}}; graph.emplace_node( std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector{bound_axis}); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 2 has size 6. There is no feasible // assignment to the values in slice 1 (along axis 2) that results in a // sum or equal to 6*8+1 = 49 - REQUIRE_THROWS_WITH(graph.initialize_state(), "Axis-wise bounds are infeasible."); + REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); } } GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 0") { auto graph = Graph(); - BoundAxisInfo bound_axis{0, std::vector{Equal, GreaterEqual}, - std::vector{-21.0, 9.0}}; + NumberNode::BoundAxisInfo bound_axis{0, + std::vector{ + NumberNode::Equal, NumberNode::GreaterEqual}, + std::vector{-21.0, 9.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); - BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; CHECK(bound_axis.axis == bnode_bound_axis.axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); @@ -1513,16 +1588,19 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 1") { auto graph = Graph(); - BoundAxisInfo bound_axis{1, std::vector{Equal, GreaterEqual, LessEqual}, - std::vector{0.0, -2.0, 0.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, + std::vector{ + NumberNode::Equal, NumberNode::GreaterEqual, NumberNode::LessEqual}, + std::vector{0.0, -2.0, 0.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); - BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; CHECK(bound_axis.axis == bnode_bound_axis.axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); @@ -1562,16 +1640,18 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 2") { auto graph = Graph(); - BoundAxisInfo bound_axis{2, std::vector{Equal, GreaterEqual}, - std::vector{23.0, 14.0}}; + NumberNode::BoundAxisInfo bound_axis{2, + std::vector{ + NumberNode::Equal, NumberNode::GreaterEqual}, + std::vector{23.0, 14.0}}; auto bnode_ptr = graph.emplace_node( std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); - BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; CHECK(bound_axis.axis == bnode_bound_axis.axis); CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); @@ -1608,16 +1688,20 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with index-wise bounds and an axis-wise bound on axis: 1") { auto graph = Graph(); - BoundAxisInfo bound_axis{1, std::vector{Equal, LessEqual, GreaterEqual}, - std::vector{11.0, 2.0, 5.0}}; + NumberNode::BoundAxisInfo bound_axis{ + 1, + std::vector{NumberNode::Equal, NumberNode::LessEqual, + NumberNode::GreaterEqual}, + std::vector{11.0, 2.0, 5.0}}; auto inode_ptr = graph.emplace_node( std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector{bound_axis}); THEN("Axis wise bound is correct") { CHECK(inode_ptr->axis_wise_bounds().size() == 1); - const BoundAxisInfo inode_bound_axis_ptr = inode_ptr->axis_wise_bounds().data()[0]; + const NumberNode::BoundAxisInfo inode_bound_axis_ptr = + inode_ptr->axis_wise_bounds().data()[0]; CHECK(bound_axis.axis == inode_bound_axis_ptr.axis); CHECK_THAT(bound_axis.operators, RangeEquals(inode_bound_axis_ptr.operators)); CHECK_THAT(bound_axis.bounds, RangeEquals(inode_bound_axis_ptr.bounds)); From bebd1459b048357fd372d7f64b099a483c74a656 Mon Sep 17 00:00:00 2001 From: fastbodin Date: Fri, 30 Jan 2026 14:40:01 -0800 Subject: [PATCH 7/7] Clean up axis-wise bound NumberNode C++ code Improved comments. Improved methods. Cleaned up C++ tests. Added static_casts where necessary for CircleCI. --- .../dwave-optimization/nodes/numbers.hpp | 14 +- dwave/optimization/src/nodes/numbers.cpp | 195 +++--- tests/cpp/nodes/test_numbers.cpp | 622 ++++++++---------- 3 files changed, 384 insertions(+), 447 deletions(-) diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index bf96fdad..d503298e 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -49,10 +49,10 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { /// (length of vector is equal to the number of slices). const std::vector bounds; - /// Obtain the bound associated with a given slice along bound axis. + /// Obtain the bound associated with a given slice along `axis`. double get_bound(const ssize_t slice) const; - /// Obtain the operator associated with a given slice along bound axis. + /// Obtain the operator associated with a given slice along `axis`. BoundAxisOperator get_operator(const ssize_t slice) const; }; @@ -96,6 +96,8 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // Initialize the state of the node randomly template void initialize_state(State& state, Generator& rng) const { + // Currently, we do not support random node Initialization with + // axis wise bounds. if (bound_axes_info_.size() > 0) { throw std::invalid_argument("Cannot randomly initialize_state with bound axes"); } @@ -138,10 +140,10 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // in a given index. void clip_and_set_value(State& state, ssize_t index, double value) const; - /// Return pointer to the vector of axis-wise bounds + /// Return vector of axis-wise bounds. const std::vector& axis_wise_bounds() const; - // Return a pointer to the vector containing the bound axis sums + /// Return vector containing the bound axis sums in a given state. const std::vector>& bound_axis_sums(State& state) const; protected: @@ -155,8 +157,8 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { /// Default value in a given index. virtual double default_value(ssize_t index) const = 0; - /// Update the running bound axis sums where `index` is changed by - /// `value_change` in a given state. + /// Update the running bound axis sums where the value stored at `index` is + /// changed by `value_change` in a given state. void update_bound_axis_slice_sums(State& state, const ssize_t index, const double value_change) const; diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index cb6108d7..d7525b3f 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -33,21 +33,18 @@ NumberNode::BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, std::vector axis_operators, std::vector axis_bounds) : axis(bound_axis), operators(std::move(axis_operators)), bounds(std::move(axis_bounds)) { - const ssize_t num_operators = static_cast(operators.size()); - const ssize_t num_bounds = static_cast(bounds.size()); + const size_t num_operators = operators.size(); + const size_t num_bounds = bounds.size(); - // Null `operators` and `bounds` are not accepted. if ((num_operators == 0) || (num_bounds == 0)) { - throw std::invalid_argument("Bad axis-wise bounds for axis: " + std::to_string(axis) + - ", `operators` and `bounds` must each have non-zero size."); + throw std::invalid_argument("Axis-wise `operators` and `bounds` must have non-zero size."); } - // If `operators` and `bounds` are defined PER hyperslice along `axis`, - // they must have the same size. + // If `operators` and `bounds` are both defined PER hyperslice along + // `axis`, they must have the same size. if ((num_operators > 1) && (num_bounds > 1) && (num_bounds != num_operators)) { throw std::invalid_argument( - "Bad axis-wise bounds for axis: " + std::to_string(axis) + - ", `operators` and `bounds` should have same size if neither has size 1."); + "Axis-wise `operators` and `bounds` should have same size if neither has size 1."); } } @@ -93,18 +90,21 @@ double NumberNode::min() const { return min_; } double NumberNode::max() const { return max_; } +/// Given a NumberNode and an assingnment of it's variables (number_data), +/// compute and return a vector containing the sum of the values within each +/// hyperslice along each bound axis. std::vector> get_bound_axes_sums(const NumberNode* node, const std::vector& number_data) { std::span node_shape = node->shape(); - const std::vector& bound_axes_info = node->axis_wise_bounds(); + const auto& bound_axes_info = node->axis_wise_bounds(); const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); - - assert(num_bound_axes <= node_shape.size()); + assert(num_bound_axes <= static_cast(node_shape.size())); assert(std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) == static_cast(number_data.size())); // For each bound axis, initialize the sum of the values contained in each - // of it's hyperslice to 0. + // of it's hyperslice to 0. Define bound_axes_sums[i][j] = "sum of the + // values within the jth hyperslice along the ith bound axis" std::vector> bound_axes_sums; bound_axes_sums.reserve(num_bound_axes); for (const NumberNode::BoundAxisInfo& axis_info : bound_axes_info) { @@ -112,18 +112,16 @@ std::vector> get_bound_axes_sums(const NumberNode* node, bound_axes_sums.emplace_back(node_shape[axis_info.axis], 0.0); } - // Define a BufferIterator for number_data (contiguous block of doubles) - // given the shape and strides of NumberNode. - BufferIterator it(number_data.data(), node_shape, node->strides()); - - // Iterate over number_data. - for (; it != std::default_sentinel; ++it) { + // Define a BufferIterator for `number_data` given the shape and strides of + // NumberNode and iterate over it. + for (BufferIterator it(number_data.data(), node_shape, node->strides()); + it != std::default_sentinel; ++it) { // Increment the appropriate hyperslice along each bound axis. for (ssize_t bound_axis = 0; bound_axis < num_bound_axes; ++bound_axis) { const ssize_t axis = bound_axes_info[bound_axis].axis; - assert(0 <= axis && axis < it.location().size()); + assert(0 <= axis && axis < static_cast(it.location().size())); const ssize_t slice = it.location()[axis]; - assert(0 <= slice && slice < bound_axes_sums[bound_axis].size()); + assert(0 <= slice && slice < static_cast(bound_axes_sums[bound_axis].size())); bound_axes_sums[bound_axis][slice] += *it; } } @@ -131,13 +129,15 @@ std::vector> get_bound_axes_sums(const NumberNode* node, return bound_axes_sums; } +/// Determine whether the sum of the values within each hyperslice along +/// each bound axis satisfies the axis-wise bounds. bool satisfies_axis_wise_bounds(const std::vector& bound_axes_info, const std::vector>& bound_axes_sums) { assert(bound_axes_info.size() == bound_axes_sums.size()); // Check that each hyperslice satisfies the axis-wise bounds. for (ssize_t i = 0, stop_i = static_cast(bound_axes_info.size()); i < stop_i; ++i) { - const std::vector& bound_axis_sums = bound_axes_sums[i]; - const NumberNode::BoundAxisInfo& bound_axis_info = bound_axes_info[i]; + const auto& bound_axis_info = bound_axes_info[i]; + const auto& bound_axis_sums = bound_axes_sums[i]; for (ssize_t slice = 0, stop_slice = static_cast(bound_axis_sums.size()); slice < stop_slice; ++slice) { @@ -175,6 +175,8 @@ void NumberNode::initialize_state(State& state, std::vector&& number_dat return; } + // Given the assingnment to NumberNode, `number_data`, get the sum of the + // values within each hyperslice along each bound axis. std::vector> bound_axes_sums = get_bound_axes_sums(this, number_data); if (!satisfies_axis_wise_bounds(bound_axes_info_, bound_axes_sums)) { @@ -185,18 +187,28 @@ void NumberNode::initialize_state(State& state, std::vector&& number_dat std::move(bound_axes_sums)); } -std::vector reorder_span(const std::span span, const ssize_t axis) { - std::vector output; +/// Given a `span` (typically containing strides or shape), we reorder the +/// values of the span such that the given `axis` is moved to the 0th index. +std::vector reorder_to_move_along_axis(const std::span span, + const ssize_t axis) { const ssize_t ndim = span.size(); + std::vector output; output.reserve(ndim); output.emplace_back(span[axis]); + for (ssize_t i = 0; i < ndim; ++i) { - if (i == axis) continue; - output.emplace_back(span[i]); + if (i != axis) output.emplace_back(span[i]); } return output; } +/// Given a `slice` along a bound axis in a NumberNode where the sum of it's +/// values are given by `sum`, determine the non-negative amount `delta` +/// needed to be added to `sum` to satisfy the expression: (sum+delta) op bound +/// e.g. Given (sum, op, bound) := (10, ==, 12), delta = 2 +/// e.g. Given (sum, op, bound) := (10, <=, 12), delta = 0 +/// e.g. Given (sum, op, bound) := (10, >=, 12), delta = 2 +/// Throws an error if `delta` is negative (corresponding with an infeasible axis-wise bound); double compute_bound_axis_slice_delta(const ssize_t slice, const double sum, const NumberNode::BoundAxisOperator op, const double bound) { switch (op) { @@ -218,65 +230,74 @@ double compute_bound_axis_slice_delta(const ssize_t slice, const double sum, } } +/// Given a NumberNod and exactly one axis-wise bound defined for NumberNode, +/// assign values to `values` (in-place) to satisfy the axis-wise bound. This method +/// 1) Initially sets `values[i] = lower_bound(i)` for all i. +/// 2) Incremements the values within each hyperslice until they satisfy +/// the axis-wise bound (should this be possible). void construct_state_given_exactly_one_bound_axis(const NumberNode* node, std::vector& values) { const std::span node_shape = node->shape(); const ssize_t ndim = node_shape.size(); - // We need to construct a state that satisfies the axis wise bounds. - // First, initialize all elements to their lower bounds. + // 1) Initialize all elements to their lower bounds. for (ssize_t i = 0, stop = node->size(); i < stop; ++i) { values.push_back(node->lower_bound(i)); } - // Second, determine the hyperslice sums for the bound axis. This could be + // 2) Determine the hyperslice sums for the bound axis. This could be // done during the previous loop if we want to improve performance. assert(node->axis_wise_bounds().size() == 1); - std::vector bound_axis_sums = get_bound_axes_sums(node, values)[0]; - + const std::vector bound_axis_sums = get_bound_axes_sums(node, values)[0]; + // Obtain the axis-wise bound const NumberNode::BoundAxisInfo& bound_axis_info = node->axis_wise_bounds()[0]; const ssize_t bound_axis = bound_axis_info.axis; assert(0 <= bound_axis && bound_axis < ndim); - // Iterator to the beginning of `values`. - std::vector slice_shape = reorder_span(node_shape, bound_axis); - std::vector slice_strides = reorder_span(node->strides(), bound_axis); - BufferIterator values_begin(values.data(), ndim, slice_shape.data(), - slice_strides.data()); - std::vector one_more(ndim, 0); - one_more[0] = 1; - auto values_next = values_begin + one_more; - - // Offset used to perterb `values_begin` to the first element of the - // hyperslice along the given bound axis. - std::vector offset(ndim, 0); - - // Third, we iterate over each hyperslice and adjust its values until - // it satisfies the axis-wise bounds. + // We need a way to iterate over each hyperslice along the bound axis and + // adjust it`s values until they satisfy the axis-wise bounds. We do this + // by defining an iterator of `values` that can be used to iterate over the + // values within each hyperslice along the bound axis one after another. We + // can do this by modifying the NumberNode shape and strides such that the + // data for the bound_axis is moved to position 0 (remaining indices are + // shifted back). + std::vector new_shape = reorder_to_move_along_axis(node_shape, bound_axis); + std::vector new_strides = reorder_to_move_along_axis(node->strides(), bound_axis); + // Define an iterator for `values` corresponding to the beginning of slice + // 0 along the bound axis. This iterater will be used to define the start + // of a slice iterater. + BufferIterator slice_0_it(values.data(), ndim, new_shape.data(), + new_strides.data()); + // Determine the size of each slice along the bound axis. + const ssize_t slice_size = std::accumulate(new_shape.begin() + 1, new_shape.end(), 1.0, + std::multiplies()); + + // 3) Iterate over each hyperslice and adjust it's values until they + // satisfy the axis-wise bounds. for (ssize_t slice = 0, stop = node_shape[bound_axis]; slice < stop; ++slice) { - // Determine the amount we need to adjust the initialized values by - // to satisfy the axis-wise bounds for the given hyperslice. + // Determine the amount we need to adjust the initialized values within + // the slice. double delta = compute_bound_axis_slice_delta(slice, bound_axis_sums[slice], bound_axis_info.get_operator(slice), bound_axis_info.get_bound(slice)); - assert(delta >= 0); + if (delta == 0) continue; // Axis-wise bounds are satisfied for slice. + assert(delta >= 0); // Should only increment. - if (delta == 0) continue; // axis-wise bounds are satisfied for slice. + // Determine how much we need to offset slice_0_it to get to the first + // value in the given `slice` + const ssize_t offset = slice * slice_size; - offset[0] = slice; - // Define iterator to the cannonically least index in the given slice - // along the bound axis. - for (auto it = values_begin + offset, it_end = values_next + offset; it != it_end; ++it) { - // Only consider values that fall in the slice. - assert(it.location()[0] == slice); + for (auto slice_it = slice_0_it + offset, slice_end_it = slice_0_it + offset + slice_size; + slice_it != slice_end_it; ++slice_it) { + assert(slice_it.location()[0] == slice); // We should be in the right slice. - // Determine the index of `it` from `values_begin` - const ssize_t index = static_cast(it - values_begin); - assert(0 <= index && index < values.size()); + // Determine the index of `it` from `slice_0_it` + const ssize_t index = static_cast(slice_it - slice_0_it); + assert(0 <= index && index < static_cast(values.size())); // Determine the amount we can increment the value in the given index. - ssize_t inc = std::min(delta, node->upper_bound(index) - *it); + ssize_t inc = std::min(delta, node->upper_bound(index) - *slice_it); if (inc > 0) { // Apply the increment to both `it` and `delta`. - *it += inc; + *slice_it += inc; delta -= inc; if (delta == 0) break; // Axis-wise bounds are now satisfied for slice. } @@ -296,14 +317,13 @@ void NumberNode::initialize_state(State& state) const { } initialize_state(state, std::move(values)); return; + } else if (bound_axes_info_.size() == 1) { + construct_state_given_exactly_one_bound_axis(this, values); + initialize_state(state, std::move(values)); + return; } - if (bound_axes_info_.size() != 1) { - throw std::invalid_argument("Cannot initialize state with multiple bound axes."); - } - - construct_state_given_exactly_one_bound_axis(this, values); - initialize_state(state, std::move(values)); + throw std::invalid_argument("Cannot initialize state with multiple bound axes."); } void NumberNode::commit(State& state) const noexcept { @@ -327,8 +347,8 @@ void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { assert(upper_bound(i) >= ptr->get(j)); assert(lower_bound(j) <= ptr->get(i)); assert(upper_bound(j) >= ptr->get(i)); - // Assert that i and j are valid indices occurs in ptr->exchange(). - // Exchange occurs IFF (i != j) and (buffer[i] != buffer[j]). + // assert() that i and j are valid indices occurs in ptr->exchange(). + // State change occurs IFF (i != j) and (buffer[i] != buffer[j]). if (ptr->exchange(i, j)) { // If the values at indices i and j were exchanged, update the bound // axis sums. @@ -380,10 +400,9 @@ double NumberNode::upper_bound() const { void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) const { auto ptr = data_ptr(state); value = std::clamp(value, lower_bound(index), upper_bound(index)); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[index] . if (ptr->set(index, value)) { - // Update the bound axis sums. update_bound_axis_slice_sums(state, index, value - diff(state).back().old); assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } @@ -450,25 +469,21 @@ void check_axis_wise_bounds(const std::vector& bound_ const ssize_t axis = bound_axis_info.axis; if (axis < 0 || axis >= static_cast(shape.size())) { - throw std::invalid_argument( - "Invalid bound axis: " + std::to_string(axis) + - ". Note, negative indexing is not supported for axis-wise bounds."); + throw std::invalid_argument("Invalid bound axis given number array shape."); } // The number of operators defined for the given bound axis const ssize_t num_operators = static_cast(bound_axis_info.operators.size()); if ((num_operators > 1) && (num_operators != shape[axis])) { throw std::invalid_argument( - "Invalid number of axis-wise operators along axis: " + std::to_string(axis) + - " given axis size: " + std::to_string(shape[axis])); + "Invalid number of axis-wise operators given number array shape."); } // The number of operators defined for the given bound axis const ssize_t num_bounds = static_cast(bound_axis_info.bounds.size()); if ((num_bounds > 1) && (num_bounds != shape[axis])) { throw std::invalid_argument( - "Invalid number of axis-wise bounds along axis: " + std::to_string(axis) + - " given axis size: " + std::to_string(shape[axis])); + "Invalid number of axis-wise bounds given number array shape."); } // Checked in BoundAxisInfo constructor @@ -524,10 +539,11 @@ void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index, for (ssize_t bound_axis = 0, stop = static_cast(bound_axes_info.size()); bound_axis < stop; ++bound_axis) { + assert(0 <= bound_axes_info[bound_axis].axis); assert(bound_axes_info[bound_axis].axis < static_cast(multi_index.size())); // Get the slice along the bound axis the `value_change` occurs in const ssize_t slice = multi_index[bound_axes_info[bound_axis].axis]; - assert(slice < static_cast(bound_axes_sums[bound_axis].size())); + assert(0 <= slice && slice < static_cast(bound_axes_sums[bound_axis].size())); // Offset running sum in slice bound_axes_sums[bound_axis][slice] += value_change; } @@ -636,10 +652,9 @@ void IntegerNode::set_value(State& state, ssize_t index, double value) const { assert(lower_bound(index) <= value); assert(upper_bound(index) >= value); assert(value == std::round(value)); - // Assert that i is a valid index occurs in data_ptr->set(). - // set() occurs IFF `value` != buffer[i]. + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[index]. if (ptr->set(index, value)) { - // Update the bound axis. update_bound_axis_slice_sums(state, index, value - diff(state).back().old); assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); } @@ -746,8 +761,8 @@ void BinaryNode::flip(State& state, ssize_t i) const { auto ptr = data_ptr(state); // Variable should not be fixed. assert(lower_bound(i) != upper_bound(i)); - // Assert that i is a valid index occurs in ptr->set(). - // set() occurs IFF `value` != buffer[i]. + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. if (ptr->set(i, !ptr->get(i))) { // If value changed from 0 -> 1, update the bound axis sums by 1. // If value changed from 1 -> 0, update the bound axis sums by -1. @@ -760,8 +775,8 @@ void BinaryNode::set(State& state, ssize_t i) const { auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(upper_bound(i) == 1.0); - // Assert that i is a valid index occurs in ptr->set(). - // set() occurs IFF `value` != buffer[i]. + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. if (ptr->set(i, 1.0)) { // If value changed from 0 -> 1, update the bound axis sums by 1. update_bound_axis_slice_sums(state, i, 1.0); @@ -773,8 +788,8 @@ void BinaryNode::unset(State& state, ssize_t i) const { auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(lower_bound(i) == 0.0); - // Assert that i is a valid index occurs in ptr->set(). - // set occurs IFF `value` != buffer[i]. + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. if (ptr->set(i, 0.0)) { // If value changed from 1 -> 0, update the bound axis sums by -1. update_bound_axis_slice_sums(state, i, -1.0); diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index 3116c094..778d8cdf 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -28,56 +28,50 @@ namespace dwave::optimization { TEST_CASE("BoundAxisInfo") { GIVEN("BoundAxisInfo(axis = 0, operators = {}, bounds = {1.0})") { - REQUIRE_THROWS_WITH( - NumberNode::BoundAxisInfo(0, std::vector{}, - std::vector{1.0}), - "Bad axis-wise bounds for axis: 0, `operators` and `bounds` must each have " - "non-zero size."); + std::vector operators; + std::vector bounds{1.0}; + REQUIRE_THROWS_WITH(NumberNode::BoundAxisInfo(0, operators, bounds), + "Axis-wise `operators` and `bounds` must have non-zero size."); } GIVEN("BoundAxisInfo(axis = 0, operators = {<=}, bounds = {})") { - REQUIRE_THROWS_WITH( - NumberNode::BoundAxisInfo(0, - std::vector{ - NumberNode::NumberNode::LessEqual}, - std::vector{}), - "Bad axis-wise bounds for axis: 0, `operators` and `bounds` must each have " - "non-zero size."); + std::vector operators{NumberNode::LessEqual}; + std::vector bounds; + REQUIRE_THROWS_WITH(NumberNode::BoundAxisInfo(0, operators, bounds), + "Axis-wise `operators` and `bounds` must have non-zero size."); } GIVEN("BoundAxisInfo(axis = 1, operators = {<=, ==, ==}, bounds = {2.0, 1.0})") { + std::vector operators{NumberNode::LessEqual, + NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{2.0, 1.0}; REQUIRE_THROWS_WITH( - NumberNode::BoundAxisInfo( - 1, - std::vector{ - NumberNode::LessEqual, NumberNode::Equal, NumberNode::Equal}, - std::vector{2.0, 1.0}), - "Bad axis-wise bounds for axis: 1, `operators` and `bounds` should have same size " - "if neither has size 1."); + NumberNode::BoundAxisInfo(1, operators, bounds), + "Axis-wise `operators` and `bounds` should have same size if neither has size 1."); } GIVEN("BoundAxisInfo(axis = 2, operators = {==}, bounds = {1.0})") { - NumberNode::BoundAxisInfo bound_axis( - 2, std::vector{NumberNode::Equal}, - std::vector{1.0}); + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + NumberNode::BoundAxisInfo bound_axis(2, operators, bounds); + THEN("The bound axis info is correct") { CHECK(bound_axis.axis == 2); - CHECK_THAT(bound_axis.operators, RangeEquals({NumberNode::Equal})); - CHECK_THAT(bound_axis.bounds, RangeEquals({1.0})); + CHECK_THAT(bound_axis.operators, RangeEquals(operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bounds)); } } GIVEN("BoundAxisInfo(axis = 2, operators = {==, <=, >=}, bounds = {1.0, 2.0, 3.0})") { - NumberNode::BoundAxisInfo bound_axis( - 2, - std::vector{NumberNode::Equal, NumberNode::LessEqual, - NumberNode::GreaterEqual}, - std::vector{1.0, 2.0, 3.0}); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + NumberNode::BoundAxisInfo bound_axis(2, operators, bounds); + THEN("The bound axis info is correct") { CHECK(bound_axis.axis == 2); - CHECK_THAT(bound_axis.operators, RangeEquals({NumberNode::Equal, NumberNode::LessEqual, - NumberNode::GreaterEqual})); - CHECK_THAT(bound_axis.bounds, RangeEquals({1.0, 2.0, 3.0})); + CHECK_THAT(bound_axis.operators, RangeEquals(operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bounds)); } } } @@ -353,8 +347,8 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with index-wise bounds") { - auto bnode_ptr = graph.emplace_node( - 3, std::vector{-1, 0, 1}, std::vector{2, 1, 1}); + auto bnode_ptr = graph.emplace_node(3, std::vector{-1, 0, 1}, + std::vector{2, 1, 1}); THEN("The shape, max, min, and bounds are correct") { CHECK(bnode_ptr->size() == 3); @@ -451,8 +445,7 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with index-wise upper bound and general lower bound") { - auto bnode_ptr = graph.emplace_node( - 2, -2.0, std::vector{0.0, 1.1}); + auto bnode_ptr = graph.emplace_node(2, -2.0, std::vector{0.0, 1.1}); THEN("The max, min, and bounds are correct") { CHECK(bnode_ptr->max() == 1.0); @@ -468,8 +461,7 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with index-wise lower bound and general upper bound") { - auto bnode_ptr = graph.emplace_node( - 2, std::vector{-1.0, 1.0}, 100.0); + auto bnode_ptr = graph.emplace_node(2, std::vector{-1.0, 1.0}, 100.0); THEN("The max, min, and bounds are correct") { CHECK(bnode_ptr->max() == 1.0); @@ -485,13 +477,13 @@ TEST_CASE("BinaryNode") { } GIVEN("Binary node with invalid index-wise lower bounds at index 0") { - REQUIRE_THROWS(graph.emplace_node( - 2, std::vector{2, 0}, std::vector{1, 1})); + REQUIRE_THROWS(graph.emplace_node(2, std::vector{2, 0}, + std::vector{1, 1})); } GIVEN("Binary node with invalid index-wise upper bounds at index 1") { - REQUIRE_THROWS(graph.emplace_node( - 2, std::vector{0, 0}, std::vector{1, -1})); + REQUIRE_THROWS(graph.emplace_node(2, std::vector{0, 0}, + std::vector{1, -1})); } GIVEN("Invalid dynamically sized BinaryNode") { @@ -500,127 +492,113 @@ TEST_CASE("BinaryNode") { } GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis -1") { - NumberNode::BoundAxisInfo bound_axis{ - -1, std::vector{NumberNode::Equal}, - std::vector{1.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid bound axis: -1. Note, negative indexing is not supported for " - "axis-wise bounds."); + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{-1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis 2") { - NumberNode::BoundAxisInfo bound_axis{ - 2, std::vector{NumberNode::Equal}, - std::vector{1.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid bound axis: 2. Note, negative indexing is not supported for " - "axis-wise bounds."); + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{2, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many operators.") { - NumberNode::BoundAxisInfo bound_axis{ - 1, - std::vector{NumberNode::LessEqual, NumberNode::Equal, - NumberNode::Equal, NumberNode::Equal}, - std::vector{1.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); + std::vector operators{ + NumberNode::LessEqual, NumberNode::Equal, NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few operators.") { - NumberNode::BoundAxisInfo bound_axis{1, - std::vector{ - NumberNode::LessEqual, NumberNode::Equal}, - std::vector{1.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); + std::vector operators{NumberNode::LessEqual, + NumberNode::Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many bounds.") { - NumberNode::BoundAxisInfo bound_axis{ - 1, std::vector{NumberNode::LessEqual}, - std::vector{1.0, 2.0, 3.0, 4.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0, 2.0, 3.0, 4.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few bounds.") { - NumberNode::BoundAxisInfo bound_axis{ - 1, std::vector{NumberNode::LessEqual}, - std::vector{1.0, 2.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{1.0, 2.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); } GIVEN("(2x3)-BinaryNode with duplicate axis-wise bounds on axis: 1") { - NumberNode::BoundAxisInfo bound_axis{ - 1, std::vector{NumberNode::Equal}, - std::vector{1.0}}; + std::vector operators{NumberNode::Equal}; + std::vector bounds{1.0}; + NumberNode::BoundAxisInfo bound_axis{1, operators, bounds}; + REQUIRE_THROWS_WITH( - graph.emplace_node( + graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, std::vector{bound_axis, bound_axis}), "Cannot define multiple axis-wise bounds for a single axis."); } GIVEN("(2x3)-BinaryNode with axis-wise bounds on axes: 0 and 1") { - NumberNode::BoundAxisInfo bound_axis_0{ - 0, std::vector{NumberNode::LessEqual}, - std::vector{1.0}}; - NumberNode::BoundAxisInfo bound_axis_1{ - 1, std::vector{NumberNode::LessEqual}, - std::vector{1.0}}; + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{1.0}; + NumberNode::BoundAxisInfo bound_axis_0{0, operators, bounds}; + NumberNode::BoundAxisInfo bound_axis_1{1, operators, bounds}; + REQUIRE_THROWS_WITH( - graph.emplace_node( + graph.emplace_node( std::initializer_list{2, 3}, std::nullopt, std::nullopt, std::vector{bound_axis_0, bound_axis_1}), "Axis-wise bounds are supported for at most one axis."); } - GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { - NumberNode::BoundAxisInfo bound_axis{ - 1, std::vector{NumberNode::Equal}, - std::vector{0.1}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Axis wise bounds for integral number arrays must be intregral."); + GIVEN("(2x3x4)-BinaryNode with non-integral axis-wise bounds") { + std::vector operators{NumberNode::Equal}; + std::vector bounds{0.1}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Axis wise bounds for integral number arrays must be intregral."); } GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 0") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{ - 0, - std::vector{NumberNode::Equal, NumberNode::LessEqual, - NumberNode::GreaterEqual}, - std::vector{5.0, 2.0, 3.0}}; - + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{5.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; // Each hyperslice along axis 0 has size 4. There is no feasible // assignment to the values in slice 0 (along axis 0) that results in a // sum equal to 5. - graph.emplace_node( - std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + graph.emplace_node(std::initializer_list{3, 2, 2}, std::nullopt, + std::nullopt, bound_axes); WHEN("We create a state by initialize_state()") { REQUIRE_THROWS_WITH(graph.initialize_state(), "Infeasible axis-wise bounds."); @@ -629,15 +607,12 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 1") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{1, - std::vector{ - NumberNode::Equal, NumberNode::GreaterEqual}, - std::vector{5.0, 7.0}}; - - graph.emplace_node( - std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{5.0, 7.0}; + std::vector bound_axes{{1, operators, bounds}}; + graph.emplace_node(std::initializer_list{3, 2, 2}, std::nullopt, + std::nullopt, bound_axes); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 1 has size 6. There is no feasible @@ -649,15 +624,12 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 2") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{2, - std::vector{ - NumberNode::Equal, NumberNode::LessEqual}, - std::vector{5.0, -1.0}}; - - graph.emplace_node( - std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector operators{NumberNode::Equal, + NumberNode::LessEqual}; + std::vector bounds{5.0, -1.0}; + std::vector bound_axes{{2, operators, bounds}}; + graph.emplace_node(std::initializer_list{3, 2, 2}, std::nullopt, + std::nullopt, bound_axes); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 2 has size 6. There is no feasible @@ -669,23 +641,19 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 0") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{ - 0, - std::vector{NumberNode::Equal, NumberNode::LessEqual, - NumberNode::GreaterEqual}, - std::vector{1.0, 2.0, 3.0}}; - - auto bnode_ptr = graph.emplace_node( - std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; - CHECK(bound_axis.axis == bnode_bound_axis.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); } WHEN("We create a state by initialize_state()") { @@ -718,23 +686,19 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 1") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{ - 1, - std::vector{NumberNode::LessEqual, - NumberNode::GreaterEqual}, - std::vector{1.0, 5.0}}; - - auto bnode_ptr = graph.emplace_node( - std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector operators{NumberNode::LessEqual, + NumberNode::GreaterEqual}; + std::vector bounds{1.0, 5.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; - CHECK(bound_axis.axis == bnode_bound_axis.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); } WHEN("We create a state by initialize_state()") { @@ -766,22 +730,19 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 2") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{2, - std::vector{ - NumberNode::Equal, NumberNode::GreaterEqual}, - std::vector{3.0, 6.0}}; - - auto bnode_ptr = graph.emplace_node( - std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{3.0, 6.0}; + std::vector bound_axes{{2, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; - CHECK(bound_axis.axis == bnode_bound_axis.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); } WHEN("We create a state by initialize_state()") { @@ -813,23 +774,19 @@ TEST_CASE("BinaryNode") { GIVEN("(3x2x2)-BinaryNode with an axis-wise bound on axis: 0") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{ - 0, - std::vector{NumberNode::Equal, NumberNode::LessEqual, - NumberNode::GreaterEqual}, - std::vector{1.0, 2.0, 3.0}}; - - auto bnode_ptr = graph.emplace_node( - std::initializer_list{3, 2, 2}, std::nullopt, std::nullopt, - std::vector{bound_axis}); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; - CHECK(bound_axis.axis == bnode_bound_axis.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); } WHEN("We initialize three invalid states") { @@ -1090,7 +1047,8 @@ TEST_CASE("IntegerNode") { } } - GIVEN("Double precision numbers, which may fall outside integer range or are not integral") { + GIVEN("Double precision numbers, which may fall outside integer range or are not " + "integral") { IntegerNode inode({1}); THEN("The state is not deterministic") { CHECK(!inode.deterministic_state()); } @@ -1163,8 +1121,8 @@ TEST_CASE("IntegerNode") { } GIVEN("Integer node with index-wise bounds") { - auto inode_ptr = graph.emplace_node( - 3, std::vector{-1, 3, 5}, std::vector{1, 7, 7}); + auto inode_ptr = graph.emplace_node(3, std::vector{-1, 3, 5}, + std::vector{1, 7, 7}); THEN("The shape, max, min, and bounds are correct") { CHECK(inode_ptr->size() == 3); @@ -1225,8 +1183,7 @@ TEST_CASE("IntegerNode") { } GIVEN("Integer node with index-wise upper bound and general integer lower bound") { - auto inode_ptr = graph.emplace_node( - 2, 10, std::vector{20, 10}); + auto inode_ptr = graph.emplace_node(2, 10, std::vector{20, 10}); THEN("The max, min, and bounds are correct") { CHECK(inode_ptr->max() == 20.0); @@ -1242,8 +1199,8 @@ TEST_CASE("IntegerNode") { } GIVEN("Integer node with invalid index-wise bounds at index 0") { - REQUIRE_THROWS(graph.emplace_node( - 2, std::vector{19, 12}, std::vector{20, 11})); + REQUIRE_THROWS(graph.emplace_node(2, std::vector{19, 12}, + std::vector{20, 11})); } GIVEN("An Integer Node representing an 1d array of 10 elements with lower bound -10") { @@ -1370,123 +1327,109 @@ TEST_CASE("IntegerNode") { } GIVEN("(2x3)-IntegerNode with axis-wise bounds on the invalid axis -2") { - NumberNode::BoundAxisInfo bound_axis{ - -2, std::vector{NumberNode::Equal}, - std::vector{20.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid bound axis: -2. Note, negative indexing is not supported for " - "axis-wise bounds."); + std::vector operators{NumberNode::Equal}; + std::vector bounds{20.0}; + std::vector bound_axes{{-2, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on the invalid axis 3") { - NumberNode::BoundAxisInfo bound_axis{ - 3, std::vector{NumberNode::Equal}, - std::vector{10.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid bound axis: 3. Note, negative indexing is not supported for " - "axis-wise bounds."); + std::vector operators{NumberNode::Equal}; + std::vector bounds{10.0}; + std::vector bound_axes{{3, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many operators.") { - NumberNode::BoundAxisInfo bound_axis{ - 1, - std::vector{NumberNode::LessEqual, NumberNode::Equal, - NumberNode::Equal, NumberNode::Equal}, - std::vector{-10.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); + std::vector operators{ + NumberNode::LessEqual, NumberNode::Equal, NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{-10.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few operators.") { - NumberNode::BoundAxisInfo bound_axis{1, - std::vector{ - NumberNode::LessEqual, NumberNode::Equal}, - std::vector{-11.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise operators along axis: 1 given axis size: 3"); + std::vector operators{NumberNode::LessEqual, + NumberNode::Equal}; + std::vector bounds{-11.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many bounds.") { - NumberNode::BoundAxisInfo bound_axis{ - 1, std::vector{NumberNode::LessEqual}, - std::vector{-10.0, 20.0, 30.0, 40.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{-10.0, 20.0, 30.0, 40.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few bounds.") { - NumberNode::BoundAxisInfo bound_axis{ - 1, std::vector{NumberNode::LessEqual}, - std::vector{111.0, -223.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Invalid number of axis-wise bounds along axis: 1 given axis size: 3"); + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{111.0, -223.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); } GIVEN("(2x3x4)-IntegerNode with duplicate axis-wise bounds on axis: 1") { - NumberNode::BoundAxisInfo bound_axis{ - 1, std::vector{NumberNode::Equal}, - std::vector{100.0}}; + std::vector operators{NumberNode::Equal}; + std::vector bounds{100.0}; + NumberNode::BoundAxisInfo bound_axis{1, operators, bounds}; + REQUIRE_THROWS_WITH( - graph.emplace_node( + graph.emplace_node( std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, std::vector{bound_axis, bound_axis}), "Cannot define multiple axis-wise bounds for a single axis."); } GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axes: 0 and 1") { - NumberNode::BoundAxisInfo bound_axis_0{ - 0, std::vector{NumberNode::LessEqual}, - std::vector{11.0}}; - NumberNode::BoundAxisInfo bound_axis_1{ - 1, std::vector{NumberNode::LessEqual}, - std::vector{12.0}}; + std::vector operators{NumberNode::Equal}; + std::vector bounds{100.0}; + NumberNode::BoundAxisInfo bound_axis_0{0, operators, bounds}; + NumberNode::BoundAxisInfo bound_axis_1{1, operators, bounds}; + REQUIRE_THROWS_WITH( - graph.emplace_node( + graph.emplace_node( std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, std::vector{bound_axis_0, bound_axis_1}), "Axis-wise bounds are supported for at most one axis."); } GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { - NumberNode::BoundAxisInfo bound_axis{ - 2, std::vector{NumberNode::LessEqual}, - std::vector{11.0, 12.0001, 0.0, 0.0}}; - REQUIRE_THROWS_WITH( - graph.emplace_node( - std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, - std::vector{bound_axis}), - "Axis wise bounds for integral number arrays must be intregral."); + std::vector operators{NumberNode::LessEqual}; + std::vector bounds{11.0, 12.0001, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Axis wise bounds for integral number arrays must be intregral."); } GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 0") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{0, - std::vector{ - NumberNode::Equal, NumberNode::LessEqual}, - std::vector{5.0, -31.0}}; - - graph.emplace_node( - std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector operators{NumberNode::Equal, + NumberNode::LessEqual}; + std::vector bounds{5.0, -31.0}; + std::vector bound_axes{{0, operators, bounds}}; + graph.emplace_node(std::initializer_list{2, 3, 2}, -5, 8, bound_axes); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 0 has size 6. There is no feasible @@ -1498,16 +1441,11 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 1") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{ - 1, - std::vector{NumberNode::GreaterEqual, - NumberNode::Equal, NumberNode::Equal}, - std::vector{33.0, 0.0, 0.0}}; - - graph.emplace_node( - std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector operators{NumberNode::GreaterEqual, + NumberNode::Equal, NumberNode::Equal}; + std::vector bounds{33.0, 0.0, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + graph.emplace_node(std::initializer_list{2, 3, 2}, -5, 8, bound_axes); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 1 has size 4. There is no feasible @@ -1519,15 +1457,11 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 2") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{2, - std::vector{ - NumberNode::GreaterEqual, NumberNode::Equal}, - std::vector{-1.0, 49.0}}; - - graph.emplace_node( - std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector operators{NumberNode::GreaterEqual, + NumberNode::Equal}; + std::vector bounds{-1.0, 49.0}; + std::vector bound_axes{{2, operators, bounds}}; + graph.emplace_node(std::initializer_list{2, 3, 2}, -5, 8, bound_axes); WHEN("We create a state by initialize_state()") { // Each hyperslice along axis 2 has size 6. There is no feasible @@ -1539,22 +1473,19 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 0") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{0, - std::vector{ - NumberNode::Equal, NumberNode::GreaterEqual}, - std::vector{-21.0, 9.0}}; - - auto bnode_ptr = graph.emplace_node( - std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{-21.0, 9.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; - CHECK(bound_axis.axis == bnode_bound_axis.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); } WHEN("We create a state by initialize_state()") { @@ -1587,23 +1518,19 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 1") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{ - 1, - std::vector{ - NumberNode::Equal, NumberNode::GreaterEqual, NumberNode::LessEqual}, - std::vector{0.0, -2.0, 0.0}}; - - auto bnode_ptr = graph.emplace_node( - std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector operators{ + NumberNode::Equal, NumberNode::GreaterEqual, NumberNode::LessEqual}; + std::vector bounds{0.0, -2.0, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; - CHECK(bound_axis.axis == bnode_bound_axis.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); } WHEN("We create a state by initialize_state()") { @@ -1639,22 +1566,19 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 2") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{2, - std::vector{ - NumberNode::Equal, NumberNode::GreaterEqual}, - std::vector{23.0, 14.0}}; - - auto bnode_ptr = graph.emplace_node( - std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector operators{NumberNode::Equal, + NumberNode::GreaterEqual}; + std::vector bounds{23.0, 14.0}; + std::vector bound_axes{{2, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); THEN("Axis wise bound is correct") { CHECK(bnode_ptr->axis_wise_bounds().size() == 1); NumberNode::BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; - CHECK(bound_axis.axis == bnode_bound_axis.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(bnode_bound_axis.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(bnode_bound_axis.bounds)); + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); } WHEN("We create a state by initialize_state()") { @@ -1687,24 +1611,20 @@ TEST_CASE("IntegerNode") { GIVEN("(2x3x2)-IntegerNode with index-wise bounds and an axis-wise bound on axis: 1") { auto graph = Graph(); - - NumberNode::BoundAxisInfo bound_axis{ - 1, - std::vector{NumberNode::Equal, NumberNode::LessEqual, - NumberNode::GreaterEqual}, - std::vector{11.0, 2.0, 5.0}}; - - auto inode_ptr = graph.emplace_node( - std::initializer_list{2, 3, 2}, -5, 8, - std::vector{bound_axis}); + std::vector operators{ + NumberNode::Equal, NumberNode::LessEqual, NumberNode::GreaterEqual}; + std::vector bounds{11.0, 2.0, 5.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto inode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); THEN("Axis wise bound is correct") { CHECK(inode_ptr->axis_wise_bounds().size() == 1); const NumberNode::BoundAxisInfo inode_bound_axis_ptr = inode_ptr->axis_wise_bounds().data()[0]; - CHECK(bound_axis.axis == inode_bound_axis_ptr.axis); - CHECK_THAT(bound_axis.operators, RangeEquals(inode_bound_axis_ptr.operators)); - CHECK_THAT(bound_axis.bounds, RangeEquals(inode_bound_axis_ptr.bounds)); + CHECK(bound_axes[0].axis == inode_bound_axis_ptr.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(inode_bound_axis_ptr.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(inode_bound_axis_ptr.bounds)); } WHEN("We initialize three invalid states") {