Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion src/algorithms/gradient_free/nelder_mead.rs
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,19 @@ impl NelderMeadConfig {
self.beta = value;
self
}
/// Set the reflection coefficient $`\alpha`$ (default = `1`) and the expansion coefficient $`\beta`$ (default = `2`) simultaneously
///
/// # Panics
///
/// This method will panic if $`\alpha <= 0`$, $`\beta <= 1`$, or $`\beta <= \alpha`$.
pub fn with_alpha_beta(mut self, alpha: Float, beta: Float) -> Self {
assert!(alpha > 0.0);
assert!(beta > 1.0);
assert!(beta > alpha);
self.alpha = alpha;
self.beta = beta;
self
}
/// Set the contraction coefficient $`\gamma`$ (default = `0.5`).
///
/// # Panics
Expand Down Expand Up @@ -702,7 +715,7 @@ impl SupportsTransform for NelderMeadConfig {
/// input vector. The algorithm is as follows:
///
/// 0. Pick a starting simplex. The default implementation just takes one simplex point to be the
/// starting point and the others to be steps of equal size in each coordinate direction.
/// starting point and the others to be the starting point scaled by a small amount in each coordinate direction.
/// 1. Compute $`f(\vec{x}_i)`$ for each point in the simplex.
/// 2. Calculate the centroid of all but the worst point $`\vec{x}^\dagger`$ in the simplex,
/// $`\vec{x}_o`$.
Expand Down Expand Up @@ -1323,6 +1336,31 @@ mod tests {
let _ = NelderMeadConfig::new([1.0, 1.0]).with_beta(1.0);
}

#[test]
fn with_alpha_beta_sets_values() {
let nmc = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(1.1, 2.2);
assert_eq!(nmc.alpha, 1.1);
assert_eq!(nmc.beta, 2.2);
}

#[test]
#[should_panic]
fn with_alpha_beta_panics_when_alpha_nonpositive() {
let _ = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(0.0, 2.0);
}

#[test]
#[should_panic]
fn with_alpha_beta_panics_when_beta_not_gt_one() {
let _ = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(0.5, 1.0);
}

#[test]
#[should_panic]
fn with_alpha_beta_panics_when_beta_not_gt_alpha() {
let _ = NelderMeadConfig::new([1.0, 1.0]).with_alpha_beta(1.6, 1.5);
}

#[test]
#[should_panic]
fn with_beta_panics_when_not_gt_alpha() {
Expand Down
68 changes: 68 additions & 0 deletions src/algorithms/line_search/backtracking_line_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,33 @@ impl Default for BacktrackingLineSearch {
Self { rho: 0.5, c: 1e-4 }
}
}
impl BacktrackingLineSearch {
/// Set the backtracking factor $`\rho`$ (default = `0.5`).
///
/// On each unsuccessful Armijo check, the step is scaled by $`\rho`$.
///
/// # Panics
///
/// Panics if $`0 \ge \rho`$ or $`\rho \ge 1`$.
pub fn with_rho(mut self, rho: Float) -> Self {
assert!(0.0 < rho && rho < 1.0);
self.rho = rho;
self
}

/// Set the Armijo parameter $`c`$ (default = `1e-4`).
///
/// The Armijo condition is $`\phi(\alpha) \le \phi(0) + c\,\alpha\,\phi'(0)`$.
///
/// # Panics
///
/// Panics if $`0 \ge c`$ or $`c \ge 1`$.
pub fn with_c(mut self, c: Float) -> Self {
assert!(0.0 < c && c < 1.0);
self.c = c;
self
}
}

impl<U, E> LineSearch<GradientStatus, U, E> for BacktrackingLineSearch {
fn search(
Expand Down Expand Up @@ -58,3 +85,44 @@ impl<U, E> LineSearch<GradientStatus, U, E> for BacktrackingLineSearch {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn with_rho_sets_value() {
let ls = BacktrackingLineSearch::default().with_rho(0.7);
assert_eq!(ls.rho, 0.7);
}

#[test]
fn with_c_sets_value() {
let ls = BacktrackingLineSearch::default().with_c(1e-3);
assert_eq!(ls.c, 1e-3);
}

#[test]
#[should_panic]
fn with_rho_panics_when_out_of_range_low() {
let _ = BacktrackingLineSearch::default().with_rho(0.0);
}

#[test]
#[should_panic]
fn with_rho_panics_when_out_of_range_high() {
let _ = BacktrackingLineSearch::default().with_rho(1.0);
}

#[test]
#[should_panic]
fn with_c_panics_when_out_of_range_low() {
let _ = BacktrackingLineSearch::default().with_c(0.0);
}

#[test]
#[should_panic]
fn with_c_panics_when_out_of_range_high() {
let _ = BacktrackingLineSearch::default().with_c(1.0);
}
}
135 changes: 131 additions & 4 deletions src/algorithms/line_search/hager_zhang_line_search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@ impl HagerZhangLineSearch {
self.max_iters = max_iters;
self
}
/// Set the parameter $`\delta`$ used in the Armijo condition evaluation (defaults to
/// 0.1).
/// Set the parameter $`\delta`$ used in the Armijo condition evaluation (defaults to 0.1).
///
/// # Panics
///
/// This method will panic if the condition $`0 < \delta < \sigma < 1`$ is not met.
/// This method will panic if the condition $`0 < \delta < \sigma`$ is not met.
pub fn with_delta(mut self, delta: Float) -> Self {
assert!(0.0 < delta);
assert!(delta < self.sigma);
Expand All @@ -60,13 +59,26 @@ impl HagerZhangLineSearch {
///
/// # Panics
///
/// This method will panic if the condition $`0 < \delta < \sigma < 1`$ is not met.
/// This method will panic if the condition $`\delta < \sigma < 1`$ is not met.
pub fn with_sigma(mut self, sigma: Float) -> Self {
assert!(1.0 > sigma);
assert!(sigma > self.delta);
self.sigma = sigma;
self
}
/// Set the parameter $`\delta`$ used in the Armijo condition evaluation (defaults to 0.1) and the parameter $`\sigma`$ used in the second Wolfe condition (defaults to 0.9) simultaneously.
///
/// # Panics
///
/// This method will panic if the condition $`0 < \delta < \sigma < 1`$ is not met.
pub fn with_delta_sigma(mut self, delta: Float, sigma: Float) -> Self {
assert!(0.0 < delta);
assert!(1.0 > sigma);
assert!(delta < sigma);
self.delta = delta;
self.sigma = sigma;
self
}
/// Set the tolerance parameter $`\epsilon`$ used in the approximate Wolfe termination
/// conditions (defaults to `MACH_EPS^(1/3)`).
///
Expand Down Expand Up @@ -291,3 +303,118 @@ impl<U, E> LineSearch<GradientStatus, U, E> for HagerZhangLineSearch {
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn with_delta_sets_value() {
let ls = HagerZhangLineSearch::default().with_delta(0.2);
assert_eq!(ls.delta, 0.2);
}

#[test]
fn with_sigma_sets_value() {
let ls = HagerZhangLineSearch::default().with_sigma(0.7);
assert_eq!(ls.sigma, 0.7);
}

#[test]
fn with_delta_sigma_sets_both() {
let ls = HagerZhangLineSearch::default().with_delta_sigma(0.05, 0.8);
assert_eq!(ls.delta, 0.05);
assert_eq!(ls.sigma, 0.8);
}

#[test]
fn with_epsilon_sets_value() {
let ls = HagerZhangLineSearch::default().with_epsilon(1e-8);
assert_eq!(ls.epsilon, 1e-8);
assert!(ls.epsilon > 0.0);
}

#[test]
fn with_theta_sets_value() {
let ls = HagerZhangLineSearch::default().with_theta(0.6);
assert_eq!(ls.theta, 0.6);
assert!(0.0 < ls.theta && ls.theta < 1.0);
}

#[test]
fn with_gamma_sets_value() {
let ls = HagerZhangLineSearch::default().with_gamma(0.7);
assert_eq!(ls.gamma, 0.7);
assert!(0.0 < ls.gamma && ls.gamma < 1.0);
}

#[test]
fn with_max_bisects_sets_value() {
let ls = HagerZhangLineSearch::default().with_max_bisects(7);
assert_eq!(ls.max_bisects, 7);
}

#[test]
#[should_panic]
fn with_delta_panics_when_nonpositive() {
let _ = HagerZhangLineSearch::default().with_delta(0.0);
}

#[test]
#[should_panic]
fn with_delta_panics_when_not_less_than_sigma() {
let _ = HagerZhangLineSearch::default()
.with_sigma(0.4)
.with_delta(0.5);
}

#[test]
#[should_panic]
fn with_sigma_panics_when_not_less_than_one() {
let _ = HagerZhangLineSearch::default().with_sigma(1.0);
}

#[test]
#[should_panic]
fn with_sigma_panics_when_not_greater_than_delta() {
let _ = HagerZhangLineSearch::default()
.with_delta(0.2)
.with_sigma(0.1);
}

#[test]
#[should_panic]
fn with_delta_sigma_panics_when_bad_ordering() {
let _ = HagerZhangLineSearch::default().with_delta_sigma(0.5, 0.2);
}

#[test]
#[should_panic]
fn with_delta_sigma_panics_when_sigma_not_less_than_one() {
let _ = HagerZhangLineSearch::default().with_delta_sigma(0.2, 1.0);
}

#[test]
#[should_panic]
fn with_delta_sigma_panics_when_delta_not_positive() {
let _ = HagerZhangLineSearch::default().with_delta_sigma(0.0, 0.5);
}

#[test]
#[should_panic]
fn with_epsilon_panics_when_nonpositive() {
let _ = HagerZhangLineSearch::default().with_epsilon(0.0);
}

#[test]
#[should_panic]
fn with_theta_panics_when_out_of_range() {
let _ = HagerZhangLineSearch::default().with_theta(1.0);
}

#[test]
#[should_panic]
fn with_gamma_panics_when_out_of_range() {
let _ = HagerZhangLineSearch::default().with_gamma(0.0);
}
}
Loading