diff --git a/src/algorithms/gradient_free/nelder_mead.rs b/src/algorithms/gradient_free/nelder_mead.rs index cd419ae..226de96 100644 --- a/src/algorithms/gradient_free/nelder_mead.rs +++ b/src/algorithms/gradient_free/nelder_mead.rs @@ -638,6 +638,19 @@ impl NelderMeadConfig { self.beta = value; self } + /// Set the reflection coefficient $`\alpha`$ (default = `1`) and the expansion coefficient $`\beta`$ (default = `2`) simultaneously + /// + /// # Panics + /// + /// This method will panic if $`\alpha <= 0`$, $`\beta <= 1`$, or $`\beta <= \alpha`$. + pub fn with_alpha_beta(mut self, alpha: Float, beta: Float) -> Self { + assert!(alpha > 0.0); + assert!(beta > 1.0); + assert!(beta > alpha); + self.alpha = alpha; + self.beta = beta; + self + } /// Set the contraction coefficient $`\gamma`$ (default = `0.5`). /// /// # Panics @@ -702,7 +715,7 @@ impl SupportsTransform for NelderMeadConfig { /// input vector. The algorithm is as follows: /// /// 0. Pick a starting simplex. The default implementation just takes one simplex point to be the -/// starting point and the others to be steps of equal size in each coordinate direction. +/// starting point and the others to be the starting point scaled by a small amount in each coordinate direction. /// 1. Compute $`f(\vec{x}_i)`$ for each point in the simplex. /// 2. Calculate the centroid of all but the worst point $`\vec{x}^\dagger`$ in the simplex, /// $`\vec{x}_o`$. @@ -1323,6 +1336,31 @@ mod tests { let _ = NelderMeadConfig::new([1.0, 1.0]).with_beta(1.0); } + #[test] + fn with_alpha_beta_sets_values() { + let nmc = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(1.1, 2.2); + assert_eq!(nmc.alpha, 1.1); + assert_eq!(nmc.beta, 2.2); + } + + #[test] + #[should_panic] + fn with_alpha_beta_panics_when_alpha_nonpositive() { + let _ = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(0.0, 2.0); + } + + #[test] + #[should_panic] + fn with_alpha_beta_panics_when_beta_not_gt_one() { + let _ = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(0.5, 1.0); + } + + #[test] + #[should_panic] + fn with_alpha_beta_panics_when_beta_not_gt_alpha() { + let _ = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(1.6, 1.5); + } + #[test] #[should_panic] fn with_beta_panics_when_not_gt_alpha() { diff --git a/src/algorithms/line_search/backtracking_line_search.rs b/src/algorithms/line_search/backtracking_line_search.rs index 7c03004..862c3a1 100644 --- a/src/algorithms/line_search/backtracking_line_search.rs +++ b/src/algorithms/line_search/backtracking_line_search.rs @@ -19,6 +19,33 @@ impl Default for BacktrackingLineSearch { Self { rho: 0.5, c: 1e-4 } } } +impl BacktrackingLineSearch { + /// Set the backtracking factor $`\rho`$ (default = `0.5`). + /// + /// On each unsuccessful Armijo check, the step is scaled by $`\rho`$. + /// + /// # Panics + /// + /// Panics if $`0 \ge \rho`$ or $`\rho \ge 1`$. + pub fn with_rho(mut self, rho: Float) -> Self { + assert!(0.0 < rho && rho < 1.0); + self.rho = rho; + self + } + + /// Set the Armijo parameter $`c`$ (default = `1e-4`). + /// + /// The Armijo condition is $`\phi(\alpha) \le \phi(0) + c\,\alpha\,\phi'(0)`$. + /// + /// # Panics + /// + /// Panics if $`0 \ge c`$ or $`c \ge 1`$. + pub fn with_c(mut self, c: Float) -> Self { + assert!(0.0 < c && c < 1.0); + self.c = c; + self + } +} impl LineSearch for BacktrackingLineSearch { fn search( @@ -58,3 +85,44 @@ impl LineSearch for BacktrackingLineSearch { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn with_rho_sets_value() { + let ls = BacktrackingLineSearch::default().with_rho(0.7); + assert_eq!(ls.rho, 0.7); + } + + #[test] + fn with_c_sets_value() { + let ls = BacktrackingLineSearch::default().with_c(1e-3); + assert_eq!(ls.c, 1e-3); + } + + #[test] + #[should_panic] + fn with_rho_panics_when_out_of_range_low() { + let _ = BacktrackingLineSearch::default().with_rho(0.0); + } + + #[test] + #[should_panic] + fn with_rho_panics_when_out_of_range_high() { + let _ = BacktrackingLineSearch::default().with_rho(1.0); + } + + #[test] + #[should_panic] + fn with_c_panics_when_out_of_range_low() { + let _ = BacktrackingLineSearch::default().with_c(0.0); + } + + #[test] + #[should_panic] + fn with_c_panics_when_out_of_range_high() { + let _ = BacktrackingLineSearch::default().with_c(1.0); + } +} diff --git a/src/algorithms/line_search/hager_zhang_line_search.rs b/src/algorithms/line_search/hager_zhang_line_search.rs index db9b550..fc96532 100644 --- a/src/algorithms/line_search/hager_zhang_line_search.rs +++ b/src/algorithms/line_search/hager_zhang_line_search.rs @@ -44,12 +44,11 @@ impl HagerZhangLineSearch { self.max_iters = max_iters; self } - /// Set the parameter $`\delta`$ used in the Armijo condition evaluation (defaults to - /// 0.1). + /// Set the parameter $`\delta`$ used in the Armijo condition evaluation (defaults to 0.1). /// /// # Panics /// - /// This method will panic if the condition $`0 < \delta < \sigma < 1`$ is not met. + /// This method will panic if the condition $`0 < \delta < \sigma`$ is not met. pub fn with_delta(mut self, delta: Float) -> Self { assert!(0.0 < delta); assert!(delta < self.sigma); @@ -60,13 +59,26 @@ impl HagerZhangLineSearch { /// /// # Panics /// - /// This method will panic if the condition $`0 < \delta < \sigma < 1`$ is not met. + /// This method will panic if the condition $`\delta < \sigma < 1`$ is not met. pub fn with_sigma(mut self, sigma: Float) -> Self { assert!(1.0 > sigma); assert!(sigma > self.delta); self.sigma = sigma; self } + /// Set the parameter $`\delta`$ used in the Armijo condition evaluation (defaults to 0.1) and the parameter $`\sigma`$ used in the second Wolfe condition (defaults to 0.9) simultaneously. + /// + /// # Panics + /// + /// This method will panic if the condition $`0 < \delta < \sigma < 1`$ is not met. + pub fn with_delta_sigma(mut self, delta: Float, sigma: Float) -> Self { + assert!(0.0 < delta); + assert!(1.0 > sigma); + assert!(delta < sigma); + self.delta = delta; + self.sigma = sigma; + self + } /// Set the tolerance parameter $`\epsilon`$ used in the approximate Wolfe termination /// conditions (defaults to `MACH_EPS^(1/3)`). /// @@ -291,3 +303,118 @@ impl LineSearch for HagerZhangLineSearch { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn with_delta_sets_value() { + let ls = HagerZhangLineSearch::default().with_delta(0.2); + assert_eq!(ls.delta, 0.2); + } + + #[test] + fn with_sigma_sets_value() { + let ls = HagerZhangLineSearch::default().with_sigma(0.7); + assert_eq!(ls.sigma, 0.7); + } + + #[test] + fn with_delta_sigma_sets_both() { + let ls = HagerZhangLineSearch::default().with_delta_sigma(0.05, 0.8); + assert_eq!(ls.delta, 0.05); + assert_eq!(ls.sigma, 0.8); + } + + #[test] + fn with_epsilon_sets_value() { + let ls = HagerZhangLineSearch::default().with_epsilon(1e-8); + assert_eq!(ls.epsilon, 1e-8); + assert!(ls.epsilon > 0.0); + } + + #[test] + fn with_theta_sets_value() { + let ls = HagerZhangLineSearch::default().with_theta(0.6); + assert_eq!(ls.theta, 0.6); + assert!(0.0 < ls.theta && ls.theta < 1.0); + } + + #[test] + fn with_gamma_sets_value() { + let ls = HagerZhangLineSearch::default().with_gamma(0.7); + assert_eq!(ls.gamma, 0.7); + assert!(0.0 < ls.gamma && ls.gamma < 1.0); + } + + #[test] + fn with_max_bisects_sets_value() { + let ls = HagerZhangLineSearch::default().with_max_bisects(7); + assert_eq!(ls.max_bisects, 7); + } + + #[test] + #[should_panic] + fn with_delta_panics_when_nonpositive() { + let _ = HagerZhangLineSearch::default().with_delta(0.0); + } + + #[test] + #[should_panic] + fn with_delta_panics_when_not_less_than_sigma() { + let _ = HagerZhangLineSearch::default() + .with_sigma(0.4) + .with_delta(0.5); + } + + #[test] + #[should_panic] + fn with_sigma_panics_when_not_less_than_one() { + let _ = HagerZhangLineSearch::default().with_sigma(1.0); + } + + #[test] + #[should_panic] + fn with_sigma_panics_when_not_greater_than_delta() { + let _ = HagerZhangLineSearch::default() + .with_delta(0.2) + .with_sigma(0.1); + } + + #[test] + #[should_panic] + fn with_delta_sigma_panics_when_bad_ordering() { + let _ = HagerZhangLineSearch::default().with_delta_sigma(0.5, 0.2); + } + + #[test] + #[should_panic] + fn with_delta_sigma_panics_when_sigma_not_less_than_one() { + let _ = HagerZhangLineSearch::default().with_delta_sigma(0.2, 1.0); + } + + #[test] + #[should_panic] + fn with_delta_sigma_panics_when_delta_not_positive() { + let _ = HagerZhangLineSearch::default().with_delta_sigma(0.0, 0.5); + } + + #[test] + #[should_panic] + fn with_epsilon_panics_when_nonpositive() { + let _ = HagerZhangLineSearch::default().with_epsilon(0.0); + } + + #[test] + #[should_panic] + fn with_theta_panics_when_out_of_range() { + let _ = HagerZhangLineSearch::default().with_theta(1.0); + } + + #[test] + #[should_panic] + fn with_gamma_panics_when_out_of_range() { + let _ = HagerZhangLineSearch::default().with_gamma(0.0); + } +} diff --git a/src/algorithms/line_search/more_thuente_line_search.rs b/src/algorithms/line_search/more_thuente_line_search.rs index 58a4452..d0f47a8 100644 --- a/src/algorithms/line_search/more_thuente_line_search.rs +++ b/src/algorithms/line_search/more_thuente_line_search.rs @@ -43,12 +43,11 @@ impl MoreThuenteLineSearch { self.max_zoom = max_zoom; self } - /// Set the first control parameter, used in the Armijo condition evaluation (defaults to - /// 1e-4). + /// Set the first control parameter, used in the Armijo condition evaluation (defaults to 1e-4). /// /// # Panics /// - /// This method will panic if the condition $`0 < c_1 < c_2 < 1`$ is not met. + /// This method will panic if the condition $`0 < c_1 < c_2`$ is not met. pub fn with_c1(mut self, c1: Float) -> Self { assert!(0.0 < c1); assert!(c1 < self.c2); @@ -59,13 +58,26 @@ impl MoreThuenteLineSearch { /// /// # Panics /// - /// This method will panic if the condition $`0 < c_1 < c_2 < 1`$ is not met. + /// This method will panic if the condition $`c_1 < c_2 < 1`$ is not met. pub fn with_c2(mut self, c2: Float) -> Self { assert!(1.0 > c2); assert!(c2 > self.c1); self.c2 = c2; self } + /// Set the first control parameter, used in the Armijo condition evaluation (defaults to 1e-4) and the second control parameter, used in the second Wolfe condition (defaults to 0.9) simultaneously. + /// + /// # Panics + /// + /// This method will panic if the condition $`0 < c_1 < c_2 < 1`$ is not met. + pub fn with_c1_c2(mut self, c1: Float, c2: Float) -> Self { + assert!(0.0 < c1); + assert!(1.0 > c2); + assert!(c1 < c2); + self.c1 = c1; + self.c2 = c2; + self + } } impl MoreThuenteLineSearch { @@ -203,3 +215,72 @@ impl LineSearch for MoreThuenteLineSearch { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn with_c1_sets_value() { + let ls = MoreThuenteLineSearch::default().with_c1(1e-3); + assert_eq!(ls.c1, 1e-3); + assert!(ls.c1 > 0.0 && ls.c1 < ls.c2); + } + + #[test] + fn with_c2_sets_value() { + let ls = MoreThuenteLineSearch::default().with_c2(0.8); + assert_eq!(ls.c2, 0.8); + assert!(ls.c2 < 1.0 && ls.c2 > ls.c1); + } + + #[test] + fn with_c1_c2_sets_both() { + let ls = MoreThuenteLineSearch::default().with_c1_c2(1e-5, 0.7); + assert_eq!(ls.c1, 1e-5); + assert_eq!(ls.c2, 0.7); + assert!(ls.c1 > 0.0 && ls.c2 < 1.0 && ls.c1 < ls.c2); + } + + #[test] + #[should_panic] + fn with_c1_panics_when_nonpositive() { + let _ = MoreThuenteLineSearch::default().with_c1(0.0); + } + + #[test] + #[should_panic] + fn with_c1_panics_when_not_less_than_c2() { + let _ = MoreThuenteLineSearch::default().with_c2(0.2).with_c1(0.3); + } + + #[test] + #[should_panic] + fn with_c2_panics_when_not_less_than_one() { + let _ = MoreThuenteLineSearch::default().with_c2(1.0); + } + + #[test] + #[should_panic] + fn with_c2_panics_when_not_greater_than_c1() { + let _ = MoreThuenteLineSearch::default().with_c1(1e-4).with_c2(1e-5); + } + + #[test] + #[should_panic] + fn with_c1_c2_panics_when_bad_ordering() { + let _ = MoreThuenteLineSearch::default().with_c1_c2(0.9, 0.1); + } + + #[test] + #[should_panic] + fn with_c1_c2_panics_when_c2_not_less_than_one() { + let _ = MoreThuenteLineSearch::default().with_c1_c2(1e-4, 1.0); + } + + #[test] + #[should_panic] + fn with_c1_c2_panics_when_c1_not_positive() { + let _ = MoreThuenteLineSearch::default().with_c1_c2(0.0, 0.5); + } +} diff --git a/src/algorithms/particles/swarm.rs b/src/algorithms/particles/swarm.rs index 06306e7..245bfd4 100644 --- a/src/algorithms/particles/swarm.rs +++ b/src/algorithms/particles/swarm.rs @@ -82,28 +82,25 @@ impl Swarm { Ok(()) } /// Sets the topology used by the swarm (default = [`SwarmTopology::Global`]). - pub const fn with_topology(&mut self, value: SwarmTopology) -> &mut Self { + pub const fn with_topology(mut self, value: SwarmTopology) -> Self { self.topology = value; self } /// Sets the update method used by the swarm (default = [`SwarmUpdateMethod::Synchronous`]). - pub const fn with_update_method(&mut self, value: SwarmUpdateMethod) -> &mut Self { + pub const fn with_update_method(mut self, value: SwarmUpdateMethod) -> Self { self.update_method = value; self } /// Set the [`PSO`](super::PSO)'s [`SwarmVelocityInitializer`]. pub fn with_velocity_initializer( - &mut self, + mut self, velocity_initializer: SwarmVelocityInitializer, - ) -> &mut Self { + ) -> Self { self.velocity_initializer = velocity_initializer; self } /// Set the [`SwarmBoundaryMethod`] for the [`PSO`](super::PSO). - pub const fn with_boundary_method( - &mut self, - boundary_method: SwarmBoundaryMethod, - ) -> &mut Self { + pub const fn with_boundary_method(mut self, boundary_method: SwarmBoundaryMethod) -> Self { self.boundary_method = boundary_method; self } diff --git a/src/algorithms/particles/swarm_status.rs b/src/algorithms/particles/swarm_status.rs index 1b86069..df6c2a1 100644 --- a/src/algorithms/particles/swarm_status.rs +++ b/src/algorithms/particles/swarm_status.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; /// A status for particle swarm optimization and similar methods. #[derive(Clone, Serialize, Deserialize)] pub struct SwarmStatus { - /// The global best position found by all particles (in unbounded space) + /// The global best position found by all particles pub gbest: Point>, /// An indicator of whether the swarm has converged pub converged: bool, diff --git a/src/traits/abort_signal.rs b/src/traits/abort_signal.rs index eda0d9b..b66fff6 100644 --- a/src/traits/abort_signal.rs +++ b/src/traits/abort_signal.rs @@ -5,7 +5,7 @@ use std::ops::ControlFlow; /// A trait for abort signals. /// This trait is used in minimizers to check if the user has requested to abort the calculation. -pub trait AbortSignal: DynClone { +pub trait AbortSignal: DynClone + Send + Sync { /// Return `true` if the user has requested to abort the calculation. fn is_aborted(&self) -> bool; /// Abort the calculation. Make `is_aborted()` return `true`. diff --git a/src/traits/algorithm.rs b/src/traits/algorithm.rs index 928df77..767e7ac 100644 --- a/src/traits/algorithm.rs +++ b/src/traits/algorithm.rs @@ -8,7 +8,7 @@ use std::convert::Infallible; /// /// This trait is implemented for the algorithms found in the [`algorithms`](`crate::algorithms`) module and contains /// all the methods needed to [`process`](`Algorithm::process`) a problem. -pub trait Algorithm { +pub trait Algorithm: Send + Sync { /// A type which holds a summary of the algorithm's ending state. type Summary; /// The configuration struct for the algorithm. diff --git a/src/traits/boundlike.rs b/src/traits/boundlike.rs index d2b009b..f2ad920 100644 --- a/src/traits/boundlike.rs +++ b/src/traits/boundlike.rs @@ -126,20 +126,18 @@ impl From<(Float, Float)> for Bound { impl From<(Option, Option)> for Bound { fn from(value: (Option, Option)) -> Self { match value { - (Some(sa), Some(sb)) => { - let (l, u) = if sa < sb { (sa, sb) } else { (sb, sa) }; - Self::LowerAndUpperBound(l, u) - } - (Some(l), None) => Self::LowerBound(l), - (None, Some(u)) => Self::UpperBound(u), - (None, None) => Self::NoBound, + (Some(a), Some(b)) => (a, b), + (Some(l), None) => (l, Float::INFINITY), + (None, Some(u)) => (Float::NEG_INFINITY, u), + (None, None) => (Float::NEG_INFINITY, Float::INFINITY), } + .into() } } /// A trait representing a transform specifically involving a parameter bound. #[typetag::serde] -pub trait BoundLike: DynClone + Debug { +pub trait BoundLike: DynClone + Debug + Send + Sync { /// The mapping to internal (unbounded) coordinates. fn to_internal_impl(&self, bound: Bound, x: Float) -> Float; /// The first derivative of the mapping to internal (unbounded) coordinates. diff --git a/src/traits/linesearch.rs b/src/traits/linesearch.rs index ac767f2..ff91cb2 100644 --- a/src/traits/linesearch.rs +++ b/src/traits/linesearch.rs @@ -19,7 +19,7 @@ pub struct LineSearchOutput { /// /// Line searches are one-dimensional minimizers typically used to determine optimal step sizes for /// [`Algorithm`](`crate::traits::Algorithm`)s which only provide a direction for the next optimal step. -pub trait LineSearch: DynClone { +pub trait LineSearch: DynClone + Send + Sync { /// The search method takes the current position of the minimizer, `x`, the search direction /// `p`, the objective function `func`, optional bounds `bounds`, and any arguments to the /// objective function `args`, and returns a [`Result`] containing another [`Result`]. The diff --git a/src/traits/transform.rs b/src/traits/transform.rs index 7261ca1..4926294 100644 --- a/src/traits/transform.rs +++ b/src/traits/transform.rs @@ -12,7 +12,7 @@ use crate::{ /// /// This can be used to restrict an algorithm to a space of valid coordinates, such as a bounded /// space or a space satisfying some constraints between parameters. -pub trait Transform: DynClone { +pub trait Transform: DynClone + Send + Sync { /// Map from internal to external coordinates. fn to_external<'a>(&'a self, z: &'a DVector) -> Cow<'a, DVector>; /// Map from external to internal coordinates.