From 064799e850ff8df6a806c750c46c44d2c5953f34 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 27 Jun 2025 20:08:04 +0200 Subject: [PATCH 1/3] feat: add untested walnuts implementation --- src/adapt_strategy.rs | 1 + src/euclidean_hamiltonian.rs | 3 +- src/hamiltonian.rs | 25 ++++++ src/nuts.rs | 144 +++++++++++++++++++++++++++++++-- src/sampler.rs | 21 +++-- src/stepsize/adapt.rs | 5 +- src/transformed_hamiltonian.rs | 3 +- 7 files changed, 185 insertions(+), 17 deletions(-) diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs index 79939e4..eea83a9 100644 --- a/src/adapt_strategy.rs +++ b/src/adapt_strategy.rs @@ -482,6 +482,7 @@ mod test { store_unconstrained: true, check_turning: true, store_divergences: false, + walnuts_options: None, }; let rng = { diff --git a/src/euclidean_hamiltonian.rs b/src/euclidean_hamiltonian.rs index 5059171..370eb01 100644 --- a/src/euclidean_hamiltonian.rs +++ b/src/euclidean_hamiltonian.rs @@ -225,6 +225,7 @@ impl> Hamiltonian for EuclideanHamiltonian, dir: Direction, + step_size_factor: f64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -237,7 +238,7 @@ impl> Hamiltonian for EuclideanHamiltonian -1, }; - let epsilon = (sign as f64) * self.step_size; + let epsilon = (sign as f64) * self.step_size * step_size_factor; start .point() diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs index e4abcf8..f55bede 100644 --- a/src/hamiltonian.rs +++ b/src/hamiltonian.rs @@ -28,12 +28,36 @@ pub struct DivergenceInfo { pub logp_function_error: Option>, } +impl DivergenceInfo { + pub fn new() -> Self { + DivergenceInfo { + start_momentum: None, + start_location: None, + start_gradient: None, + end_location: None, + energy_error: None, + end_idx_in_trajectory: None, + start_idx_in_trajectory: None, + logp_function_error: None, + } + } +} + #[derive(Debug, Copy, Clone)] pub enum Direction { Forward, Backward, } +impl Direction { + pub fn reverse(&self) -> Self { + match self { + Direction::Forward => Direction::Backward, + Direction::Backward => Direction::Forward, + } + } +} + impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> Direction { if rng.random::() { @@ -82,6 +106,7 @@ pub trait Hamiltonian: SamplerStats + Sized { math: &mut M, start: &State, dir: Direction, + step_size_factor: f64, collector: &mut C, ) -> LeapfrogResult; diff --git a/src/nuts.rs b/src/nuts.rs index 6c5af53..70656a1 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -1,3 +1,4 @@ +use serde::Serialize; use thiserror::Error; use std::{fmt::Debug, marker::PhantomData}; @@ -120,7 +121,7 @@ impl, C: Collector> NutsTree { H: Hamiltonian, R: rand::Rng + ?Sized, { - let mut other = match self.single_step(math, hamiltonian, direction, collector) { + let mut other = match self.single_step(math, hamiltonian, direction, options, collector) { Ok(Ok(tree)) => tree, Ok(Err(info)) => return ExtendResult::Diverging(self, info), Err(err) => return ExtendResult::Err(err), @@ -213,19 +214,141 @@ impl, C: Collector> NutsTree { math: &mut M, hamiltonian: &mut H, direction: Direction, + options: &NutsOptions, collector: &mut C, ) -> Result, DivergenceInfo>> { let start = match direction { Direction::Forward => &self.right, Direction::Backward => &self.left, }; - let end = match hamiltonian.leapfrog(math, start, direction, collector) { - LeapfrogResult::Divergence(info) => return Ok(Err(info)), - LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), - LeapfrogResult::Ok(end) => end, + + let (log_size, end) = match 
options.walnuts_options {
+            Some(ref options) => {
+                // Walnuts implementation
+                // TODO: Shouldn't all be in this one big function...
+                let mut step_size_factor = 1.0;
+                let mut num_steps = 1;
+                let mut current = start.clone();
+
+                let mut success = false;
+
+                'step_size_search: for _ in 0..options.max_step_size_halvings {
+                    current = start.clone();
+                    let mut min_energy = current.energy();
+                    let mut max_energy = min_energy;
+
+                    for _ in 0..num_steps {
+                        current = match hamiltonian.leapfrog(
+                            math,
+                            &current,
+                            direction,
+                            step_size_factor,
+                            collector,
+                        ) {
+                            LeapfrogResult::Ok(state) => state,
+                            LeapfrogResult::Divergence(_) => {
+                                num_steps *= 2;
+                                step_size_factor *= 0.5;
+                                continue 'step_size_search;
+                            }
+                            LeapfrogResult::Err(err) => {
+                                return Err(NutsError::LogpFailure(err.into()));
+                            }
+                        };
+
+                        // Update min/max energies
+                        let current_energy = current.energy();
+                        min_energy = min_energy.min(current_energy);
+                        max_energy = max_energy.max(current_energy);
+                    }
+
+                    if max_energy - min_energy > options.max_energy_error {
+                        num_steps *= 2;
+                        step_size_factor *= 0.5;
+                        continue 'step_size_search;
+                    }
+
+                    success = true;
+                    break 'step_size_search;
+                }
+
+                if !success {
+                    // TODO: More info
+                    return Ok(Err(DivergenceInfo::new()));
+                }
+
+                // TODO
+                let back = direction.reverse();
+                let mut current_backward;
+
+                let mut reversible = true;
+
+                'rev_step_size: while num_steps >= 2 {
+                    num_steps /= 2;
+                    step_size_factor *= 0.5;
+
+                    // TODO: Can we share code for the micro steps in the two directions?
+                    current_backward = current.clone();
+
+                    let mut min_energy = current_backward.energy();
+                    let mut max_energy = min_energy;
+
+                    for _ in 0..num_steps {
+                        current_backward = match hamiltonian.leapfrog(
+                            math,
+                            &current_backward,
+                            back,
+                            step_size_factor,
+                            collector,
+                        ) {
+                            LeapfrogResult::Ok(state) => state,
+                            LeapfrogResult::Divergence(_) => {
+                                // We also reject in the backward direction, all is good so far...
+                                continue 'rev_step_size;
+                            }
+                            LeapfrogResult::Err(err) => {
+                                return Err(NutsError::LogpFailure(err.into()));
+                            }
+                        };
+
+                        // Update min/max energies
+                        let current_energy = current_backward.energy();
+                        min_energy = min_energy.min(current_energy);
+                        max_energy = max_energy.max(current_energy);
+                        if max_energy - min_energy > options.max_energy_error {
+                            // We also reject in the backward direction, all good so far...
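+                            // (A rejection while retracing at the coarser step size
+                            // matches the forward decision to refine, so
+                            // reversibility still holds.)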
+ continue 'rev_step_size; + } + } + + // We did not reject in the backward direction, so we are not reversible + reversible = false; + break; + } + + if reversible { + let log_size = -current.point().energy_error(); + (log_size, current) + } else { + // TODO: More info + return Ok(Err(DivergenceInfo::new())); + } + } + None => { + // Classical NUTS + // + let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) { + LeapfrogResult::Divergence(info) => return Ok(Err(info)), + LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())), + LeapfrogResult::Ok(end) => end, + }; + + let log_size = -end.point().energy_error(); + + (log_size, end) + } }; - let log_size = -end.point().energy_error(); Ok(Ok(NutsTree { right: end.clone(), left: end.clone(), @@ -248,6 +371,13 @@ impl, C: Collector> NutsTree { } } +#[derive(Debug, Clone, Copy, Serialize)] +pub struct WalnutsOptions { + pub max_energy_error: f64, + pub max_step_size_halvings: u64, +} + +#[derive(Debug, Clone, Copy)] pub struct NutsOptions { pub maxdepth: u64, pub mindepth: u64, @@ -255,6 +385,8 @@ pub struct NutsOptions { pub store_unconstrained: bool, pub check_turning: bool, pub store_divergences: bool, + + pub walnuts_options: Option, } pub(crate) fn draw( diff --git a/src/sampler.rs b/src/sampler.rs index 0ccb9ca..72b1fc0 100644 --- a/src/sampler.rs +++ b/src/sampler.rs @@ -20,17 +20,17 @@ use std::{ }; use crate::{ - DiagAdaptExpSettings, + DiagAdaptExpSettings, Model, SamplerStats, adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions}, chain::{AdaptStrategy, Chain, NutsChain, StatOptions}, euclidean_hamiltonian::EuclideanHamiltonian, - mass_matrix::DiagMassMatrix, - mass_matrix::Strategy as DiagMassMatrixStrategy, - mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings}, + mass_matrix::{ + DiagMassMatrix, LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings, + Strategy as DiagMassMatrixStrategy, + }, math_base::Math, - model::Model, - nuts::NutsOptions, - sampler_stats::{SamplerStats, StatsDims}, + nuts::{NutsOptions, WalnutsOptions}, + sampler_stats::StatsDims, storage::{ChainStorage, StorageConfig, TraceStorage}, transform_adapt_strategy::{TransformAdaptation, TransformedSettings}, transformed_hamiltonian::{TransformedHamiltonian, TransformedPointStatsOptions}, @@ -185,6 +185,7 @@ pub struct NutsSettings { pub num_chains: usize, pub seed: u64, + pub walnuts_options: Option, } pub type DiagGradNutsSettings = NutsSettings>; @@ -206,6 +207,7 @@ impl Default for DiagGradNutsSettings { check_turning: true, seed: 0, num_chains: 6, + walnuts_options: None, } } } @@ -225,6 +227,7 @@ impl Default for LowRankNutsSettings { check_turning: true, seed: 0, num_chains: 6, + walnuts_options: None, }; vals.adapt_options.mass_matrix_update_freq = 10; vals @@ -246,6 +249,7 @@ impl Default for TransformedNutsSettings { check_turning: true, seed: 0, num_chains: 1, + walnuts_options: None, } } } @@ -278,6 +282,7 @@ impl Settings for LowRankNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); @@ -346,6 +351,7 @@ impl Settings for DiagGradNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = 
rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); @@ -411,6 +417,7 @@ impl Settings for TransformedNutsSettings { store_divergences: self.store_divergences, store_unconstrained: self.store_unconstrained, check_turning: self.check_turning, + walnuts_options: self.walnuts_options, }; let rng = rand::rngs::SmallRng::try_from_rng(&mut rng).expect("Could not seed rng"); diff --git a/src/stepsize/adapt.rs b/src/stepsize/adapt.rs index 7dc9134..b8afdc2 100644 --- a/src/stepsize/adapt.rs +++ b/src/stepsize/adapt.rs @@ -103,7 +103,8 @@ impl Strategy { *hamiltonian.step_size_mut() = self.options.initial_step; - let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, &mut collector); + let state_next = + hamiltonian.leapfrog(math, &state, Direction::Forward, 1.0, &mut collector); let LeapfrogResult::Ok(_) = state_next else { return Ok(()); @@ -119,7 +120,7 @@ impl Strategy { for _ in 0..100 { let mut collector = AcceptanceRateCollector::new(); collector.register_init(math, &state, options); - let state_next = hamiltonian.leapfrog(math, &state, dir, &mut collector); + let state_next = hamiltonian.leapfrog(math, &state, dir, 1.0, &mut collector); let LeapfrogResult::Ok(_) = state_next else { *hamiltonian.step_size_mut() = self.options.initial_step; return Ok(()); diff --git a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs index 7b97482..7bf7fea 100644 --- a/src/transformed_hamiltonian.rs +++ b/src/transformed_hamiltonian.rs @@ -303,6 +303,7 @@ impl Hamiltonian for TransformedHamiltonian { math: &mut M, start: &State, dir: Direction, + step_size_factor: f64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -316,7 +317,7 @@ impl Hamiltonian for TransformedHamiltonian { Direction::Backward => -1, }; - let epsilon = (sign as f64) * self.step_size; + let epsilon = (sign as f64) * self.step_size * step_size_factor; start .point() From 47bd2ca53eaece3b67d12d92885e221a4d9a6929 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 27 Jun 2025 20:08:04 +0200 Subject: [PATCH 2/3] doc: add temporary comments for comparison with walnuts c++ --- src/nuts.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/nuts.rs b/src/nuts.rs index 70656a1..3946b46 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -60,22 +60,41 @@ pub struct SampleInfo { } /// A part of the trajectory tree during NUTS sampling. +/// +/// Corresponds to SpanW in walnuts C++ code struct NutsTree, C: Collector> { /// The left position of the tree. /// /// The left side always has the smaller index_in_trajectory. /// Leapfrogs in backward direction will replace the left. + /// + /// theta_bk_, rho_bk_, grad_theta_bk_, logp_bk_ in C++ code left: State, + + /// The right position of the tree. + /// + /// theta_fw_, rho_fw_, grad_theta_fw_, logp_fw_ in C++ code right: State, /// A draw from the trajectory between left and right using /// multinomial sampling. + /// + /// theta_select_ in C++ code draw: State, + + /// Constant for acceptance probability + /// + /// logp_ in C++ code log_size: f64, + + /// The depth of the tree depth: u64, /// A tree is the main tree if it contains the initial point /// of the trajectory. 
+    ///
+    /// This is used to determine whether to use Metropolis
+    /// acceptance or Barker acceptance
     is_main: bool,
     _phantom2: PhantomData<C>,
 }
@@ -172,6 +191,7 @@ impl<M: Math, P: Point<M>, C: Collector<M, P>> NutsTree<M, P, C> {
         }
     }
 
+    // `combine` in C++ code
     fn merge_into(
         &mut self,
         _math: &mut M,
@@ -209,6 +229,7 @@ impl<M: Math, P: Point<M>, C: Collector<M, P>> NutsTree<M, P, C> {
         self.log_size = log_size;
     }
 
+    // Corresponds to `build_leaf` in C++ code
     fn single_step(
         &self,
         math: &mut M,

From 3e0efa2970e003d5be2705dadd60ad6d9a632c37 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Sat, 28 Jun 2025 12:27:47 +0200
Subject: [PATCH 3/3] feat: clean up walnuts a little bit

---
 src/adapt_strategy.rs           |   5 +-
 src/chain.rs                    |   4 +-
 src/euclidean_hamiltonian.rs    |  44 ++++-----
 src/hamiltonian.rs              |  99 ++++++++++++++++++-
 src/lib.rs                      |   2 +-
 src/nuts.rs                     | 169 ++++++++++++++------------------
 src/stepsize/adapt.rs           |   5 +-
 src/stepsize/dual_avg.rs        |   1 +
 src/transform_adapt_strategy.rs |   1 +
 src/transformed_hamiltonian.rs  |  43 ++++----
 10 files changed, 219 insertions(+), 154 deletions(-)

diff --git a/src/adapt_strategy.rs b/src/adapt_strategy.rs
index eea83a9..b1eca0f 100644
--- a/src/adapt_strategy.rs
+++ b/src/adapt_strategy.rs
@@ -291,11 +291,12 @@ where
         start: &State<M, P>,
         end: &State<M, P>,
         divergence_info: Option<&DivergenceInfo>,
+        num_substeps: u64,
     ) {
         self.collector1
-            .register_leapfrog(math, start, end, divergence_info);
+            .register_leapfrog(math, start, end, divergence_info, num_substeps);
         self.collector2
-            .register_leapfrog(math, start, end, divergence_info);
+            .register_leapfrog(math, start, end, divergence_info, num_substeps);
     }
 
     fn register_draw(&mut self, math: &mut M, state: &State<M, P>, info: &crate::nuts::SampleInfo) {
diff --git a/src/chain.rs b/src/chain.rs
index 026312a..f8f5981 100644
--- a/src/chain.rs
+++ b/src/chain.rs
@@ -157,6 +157,7 @@ where
             &mut self.hamiltonian,
             &self.options,
             &mut self.collector,
+            self.draw_count < 70,
         )?;
         let mut position: Box<[f64]> = vec![0f64; math.dim()].into();
         state.write_position(math, &mut position);
@@ -235,6 +236,7 @@ pub struct NutsStats<P, A: Storable, D: Storable
> pub divergence_end: Option>, #[storable(dims("unconstrained_parameter"))] pub divergence_momentum: Option>, + non_reversible: Option, //pub divergence_message: Option, #[storable(ignore)] _phantom: PhantomData P>, @@ -303,7 +305,7 @@ impl> SamplerStats for NutsChain> Hamiltonian for EuclideanHamiltonian, dir: Direction, - step_size_factor: f64, + step_size_splits: u64, collector: &mut C, ) -> LeapfrogResult { let mut out = self.pool().new_state(math); @@ -238,7 +238,7 @@ impl> Hamiltonian for EuclideanHamiltonian -1, }; - let epsilon = (sign as f64) * self.step_size * step_size_factor; + let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64); start .point() @@ -250,17 +250,9 @@ impl> Hamiltonian for EuclideanHamiltonian> Hamiltonian for EuclideanHamiltonian self.max_energy_error) | !energy_error.is_finite() { - let divergence_info = DivergenceInfo { - logp_function_error: None, - start_location: Some(math.box_array(start.point().position())), - start_gradient: Some(math.box_array(start.point().gradient())), - end_location: Some(math.box_array(&out_point.position)), - start_momentum: Some(math.box_array(&out_point.momentum)), - start_idx_in_trajectory: Some(start.index_in_trajectory()), - end_idx_in_trajectory: Some(out.index_in_trajectory()), - energy_error: Some(energy_error), - }; - collector.register_leapfrog(math, start, &out, Some(&divergence_info)); + let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out); + collector.register_leapfrog( + math, + start, + &out, + Some(&divergence_info), + step_size_splits, + ); return LeapfrogResult::Divergence(divergence_info); } - collector.register_leapfrog(math, start, &out, None); + collector.register_leapfrog(math, start, &out, None, step_size_splits); LeapfrogResult::Ok(out) } @@ -363,4 +353,8 @@ impl> Hamiltonian for EuclideanHamiltonian &mut f64 { &mut self.step_size } + + fn max_energy_error(&self) -> f64 { + self.max_energy_error + } } diff --git a/src/hamiltonian.rs b/src/hamiltonian.rs index f55bede..b8632ae 100644 --- a/src/hamiltonian.rs +++ b/src/hamiltonian.rs @@ -16,6 +16,7 @@ use crate::{ /// a cutoff value or nan. 
 /// - The logp function caused a recoverable error (eg if an ODE solver
 ///   failed)
+#[non_exhaustive]
 #[derive(Debug, Clone)]
 pub struct DivergenceInfo {
     pub start_momentum: Option<Box<[f64]>>,
@@ -26,6 +27,7 @@ pub struct DivergenceInfo {
     pub end_idx_in_trajectory: Option<i64>,
     pub start_idx_in_trajectory: Option<i64>,
     pub logp_function_error: Option<Arc<Box<dyn std::error::Error + Send + Sync>>>,
+    pub non_reversible: bool,
 }
 
 impl DivergenceInfo {
@@ -39,8 +41,67 @@ impl DivergenceInfo {
             end_idx_in_trajectory: None,
             start_idx_in_trajectory: None,
             logp_function_error: None,
+            non_reversible: false,
         }
     }
+
+    pub fn new_energy_error_too_large<M: Math>(
+        math: &mut M,
+        start: &State<M, impl Point<M>>,
+        stop: &State<M, impl Point<M>>,
+    ) -> Self {
+        DivergenceInfo {
+            logp_function_error: None,
+            start_location: Some(math.box_array(start.point().position())),
+            start_gradient: Some(math.box_array(start.point().gradient())),
+            // TODO
+            start_momentum: None,
+            start_idx_in_trajectory: Some(start.index_in_trajectory()),
+            end_location: Some(math.box_array(stop.point().position())),
+            end_idx_in_trajectory: Some(stop.index_in_trajectory()),
+            // TODO
+            energy_error: None,
+            non_reversible: false,
+        }
+    }
+
+    pub fn new_logp_function_error<M: Math>(
+        math: &mut M,
+        start: &State<M, impl Point<M>>,
+        logp_function_error: Arc<Box<dyn std::error::Error + Send + Sync>>,
+    ) -> Self {
+        DivergenceInfo {
+            logp_function_error: Some(logp_function_error),
+            start_location: Some(math.box_array(start.point().position())),
+            start_gradient: Some(math.box_array(start.point().gradient())),
+            // TODO
+            start_momentum: None,
+            start_idx_in_trajectory: Some(start.index_in_trajectory()),
+            end_location: None,
+            end_idx_in_trajectory: None,
+            energy_error: None,
+            non_reversible: false,
+        }
+    }
+
+    pub fn new_not_reversible<M: Math>(math: &mut M, start: &State<M, impl Point<M>>) -> Self {
+        // TODO add info about what went wrong
+        DivergenceInfo {
+            logp_function_error: None,
+            start_location: Some(math.box_array(start.point().position())),
+            start_gradient: Some(math.box_array(start.point().gradient())),
+            // TODO
+            start_momentum: None,
+            start_idx_in_trajectory: Some(start.index_in_trajectory()),
+            end_location: None,
+            end_idx_in_trajectory: None,
+            energy_error: None,
+            non_reversible: true,
+        }
+    }
+
+    pub fn new_max_step_size_halvings<M: Math>(_math: &mut M, _num_steps: u64, info: Self) -> Self {
+        info // TODO
+    }
 }
 
 #[derive(Debug, Copy, Clone)]
 pub enum Direction {
     Forward,
     Backward,
 }
@@ -106,10 +167,44 @@ pub trait Hamiltonian<M: Math>: SamplerStats<M> + Sized {
         math: &mut M,
         start: &State<M, Self::Point>,
         dir: Direction,
-        step_size_factor: f64,
+        step_size_splits: u64,
         collector: &mut C,
     ) -> LeapfrogResult<M, Self::Point>;
 
+    fn split_leapfrog<C: Collector<M, Self::Point>>(
+        &mut self,
+        math: &mut M,
+        start: &State<M, Self::Point>,
+        dir: Direction,
+        num_steps: u64,
+        collector: &mut C,
+        max_error: f64,
+    ) -> LeapfrogResult<M, Self::Point> {
+        let mut state = start.clone();
+
+        let mut min_energy = start.energy();
+        let mut max_energy = min_energy;
+
+        for _ in 0..num_steps {
+            state = match self.leapfrog(math, &state, dir, num_steps, collector) {
+                LeapfrogResult::Ok(state) => state,
+                LeapfrogResult::Divergence(info) => return LeapfrogResult::Divergence(info),
+                LeapfrogResult::Err(err) => return LeapfrogResult::Err(err),
+            };
+            let energy = state.energy();
+            min_energy = min_energy.min(energy);
+            max_energy = max_energy.max(energy);
+
+            // TODO: The walnuts paper says to use abs, but the C++ code doesn't?
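+            // Reject the whole macro step as soon as the energy spread over
+            // its substates exceeds the threshold.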
+ if max_energy - min_energy > max_error { + let info = DivergenceInfo::new_energy_error_too_large(math, start, &state); + return LeapfrogResult::Divergence(info); + } + } + + LeapfrogResult::Ok(state) + } + fn is_turning( &self, math: &mut M, @@ -141,4 +236,6 @@ pub trait Hamiltonian: SamplerStats + Sized { fn step_size(&self) -> f64; fn step_size_mut(&mut self) -> &mut f64; + + fn max_energy_error(&self) -> f64; } diff --git a/src/lib.rs b/src/lib.rs index b722803..c0e8c8a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,7 +125,7 @@ pub use cpu_math::{CpuLogpFunc, CpuMath, CpuMathError}; pub use hamiltonian::DivergenceInfo; pub use math_base::{LogpError, Math}; pub use model::Model; -pub use nuts::NutsError; +pub use nuts::{NutsError, WalnutsOptions}; pub use sampler::{ ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, NutsSettings, Progress, ProgressCallback, Sampler, SamplerWaitResult, Settings, TransformedNutsSettings, diff --git a/src/nuts.rs b/src/nuts.rs index 3946b46..8cfd4e7 100644 --- a/src/nuts.rs +++ b/src/nuts.rs @@ -35,6 +35,7 @@ pub trait Collector> { _start: &State, _end: &State, _divergence_info: Option<&DivergenceInfo>, + _num_substeps: u64, ) { } fn register_draw(&mut self, _math: &mut M, _state: &State, _info: &SampleInfo) {} @@ -135,20 +136,23 @@ impl, C: Collector> NutsTree { direction: Direction, collector: &mut C, options: &NutsOptions, + early: bool, ) -> ExtendResult where H: Hamiltonian, R: rand::Rng + ?Sized, { - let mut other = match self.single_step(math, hamiltonian, direction, options, collector) { - Ok(Ok(tree)) => tree, - Ok(Err(info)) => return ExtendResult::Diverging(self, info), - Err(err) => return ExtendResult::Err(err), - }; + let mut other = + match self.single_step(math, hamiltonian, direction, options, collector, early) { + Ok(Ok(tree)) => tree, + Ok(Err(info)) => return ExtendResult::Diverging(self, info), + Err(err) => return ExtendResult::Err(err), + }; while other.depth < self.depth { use ExtendResult::*; - other = match other.extend(math, rng, hamiltonian, direction, collector, options) { + other = match other.extend(math, rng, hamiltonian, direction, collector, options, early) + { Ok(tree) => tree, Turning(_) => { return Turning(self); @@ -237,6 +241,7 @@ impl, C: Collector> NutsTree { direction: Direction, options: &NutsOptions, collector: &mut C, + early: bool, ) -> Result, DivergenceInfo>> { let start = match direction { Direction::Forward => &self.right, @@ -247,118 +252,79 @@ impl, C: Collector> NutsTree { Some(ref options) => { // Walnuts implementation // TODO: Shouldn't all be in this one big function... 
-                let mut step_size_factor = 1.0;
                 let mut num_steps = 1;
                 let mut current = start.clone();
 
-                let mut success = false;
-
-                'step_size_search: for _ in 0..options.max_step_size_halvings {
-                    current = start.clone();
-                    let mut min_energy = current.energy();
-                    let mut max_energy = min_energy;
-
-                    for _ in 0..num_steps {
-                        current = match hamiltonian.leapfrog(
-                            math,
-                            &current,
-                            direction,
-                            step_size_factor,
-                            collector,
-                        ) {
-                            LeapfrogResult::Ok(state) => state,
-                            LeapfrogResult::Divergence(_) => {
-                                num_steps *= 2;
-                                step_size_factor *= 0.5;
-                                continue 'step_size_search;
-                            }
-                            LeapfrogResult::Err(err) => {
-                                return Err(NutsError::LogpFailure(err.into()));
-                            }
-                        };
-
-                        // Update min/max energies
-                        let current_energy = current.energy();
-                        min_energy = min_energy.min(current_energy);
-                        max_energy = max_energy.max(current_energy);
-                    }
-
-                    if max_energy - min_energy > options.max_energy_error {
-                        num_steps *= 2;
-                        step_size_factor *= 0.5;
-                        continue 'step_size_search;
-                    }
-
-                    success = true;
-                    break 'step_size_search;
+                let mut last_divergence = None;
+
+                for _ in 0..options.max_step_size_halvings {
+                    current = match hamiltonian.split_leapfrog(
+                        math,
+                        start,
+                        direction,
+                        num_steps,
+                        collector,
+                        options.max_energy_error,
+                    ) {
+                        LeapfrogResult::Ok(state) => {
+                            last_divergence = None;
+                            state
+                        }
+                        LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
+                        LeapfrogResult::Divergence(info) => {
+                            num_steps *= 2;
+                            last_divergence = Some(info);
+                            continue;
+                        }
+                    };
+                    break;
                 }
 
-                if !success {
-                    // TODO: More info
-                    return Ok(Err(DivergenceInfo::new()));
+                if let Some(info) = last_divergence {
+                    let info = DivergenceInfo::new_max_step_size_halvings(math, num_steps, info);
+                    return Ok(Err(info));
                 }
 
-                // TODO
                 let back = direction.reverse();
-                let mut current_backward;
-
                 let mut reversible = true;
 
-                'rev_step_size: while num_steps >= 2 {
+                while num_steps >= 2 {
                     num_steps /= 2;
-                    step_size_factor *= 0.5;
-
-                    // TODO: Can we share code for the micro steps in the two directions?
-                    current_backward = current.clone();
-
-                    let mut min_energy = current_backward.energy();
-                    let mut max_energy = min_energy;
-
-                    for _ in 0..num_steps {
-                        current_backward = match hamiltonian.leapfrog(
-                            math,
-                            &current_backward,
-                            back,
-                            step_size_factor,
-                            collector,
-                        ) {
-                            LeapfrogResult::Ok(state) => state,
-                            LeapfrogResult::Divergence(_) => {
-                                // We also reject in the backward direction, all is good so far...
-                                continue 'rev_step_size;
-                            }
-                            LeapfrogResult::Err(err) => {
-                                return Err(NutsError::LogpFailure(err.into()));
-                            }
-                        };
-
-                        // Update min/max energies
-                        let current_energy = current_backward.energy();
-                        min_energy = min_energy.min(current_energy);
-                        max_energy = max_energy.max(current_energy);
-                        if max_energy - min_energy > options.max_energy_error {
-                            // We also reject in the backward direction, all good so far...
-                            continue 'rev_step_size;
-                        }
-                    }
+
+                    match hamiltonian.split_leapfrog(
+                        math,
+                        &current,
+                        back,
+                        num_steps,
+                        collector,
+                        options.max_energy_error,
+                    ) {
+                        LeapfrogResult::Ok(_) => (),
+                        LeapfrogResult::Divergence(_) => {
+                            // We also reject in the backward direction, all is good so far...
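+                            // (A rejection at the coarser step size matches the
+                            // forward decision to refine further.)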
+                            continue;
+                        }
+                        LeapfrogResult::Err(err) => {
+                            return Err(NutsError::LogpFailure(err.into()));
+                        }
+                    };
 
                     // We did not reject in the backward direction, so we are not reversible
                     reversible = false;
                     break;
                 }
 
-                if reversible {
+                if reversible || early {
                     let log_size = -current.point().energy_error();
                     (log_size, current)
                 } else {
-                    // TODO: More info
-                    return Ok(Err(DivergenceInfo::new()));
+                    return Ok(Err(DivergenceInfo::new_not_reversible(math, start)));
                 }
             }
             None => {
-                // Classical NUTS
-                //
+                // Classical NUTS.
+                // TODO: Is this equivalent to walnuts with max_step_size_halvings = 0?
-                let end = match hamiltonian.leapfrog(math, start, direction, 1.0, collector) {
+                let end = match hamiltonian.leapfrog(math, start, direction, 1, collector) {
                     LeapfrogResult::Divergence(info) => return Ok(Err(info)),
                     LeapfrogResult::Err(err) => return Err(NutsError::LogpFailure(err.into())),
                     LeapfrogResult::Ok(end) => end,
                 };
@@ -392,10 +358,20 @@ impl<M: Math, P: Point<M>, C: Collector<M, P>> NutsTree<M, P, C> {
     }
 }
 
+#[non_exhaustive]
 #[derive(Debug, Clone, Copy, Serialize)]
 pub struct WalnutsOptions {
-    pub max_energy_error: f64,
     pub max_step_size_halvings: u64,
+    pub max_energy_error: f64,
+}
+
+impl Default for WalnutsOptions {
+    fn default() -> Self {
+        WalnutsOptions {
+            max_step_size_halvings: 10,
+            max_energy_error: 5.0,
+        }
+    }
 }
 
 #[derive(Debug, Clone, Copy)]
@@ -417,6 +393,7 @@ pub(crate) fn draw<M, H, R, C>(
     hamiltonian: &mut H,
     options: &NutsOptions,
     collector: &mut C,
+    early: bool,
 ) -> Result<(State<M, H::Point>, SampleInfo)>
 where
     M: Math,
@@ -437,7 +414,7 @@ where
     while tree.depth < options.maxdepth {
         let direction: Direction = rng.random();
 
-        tree = match tree.extend(math, rng, hamiltonian, direction, collector, options) {
+        tree = match tree.extend(math, rng, hamiltonian, direction, collector, options, early) {
             ExtendResult::Ok(tree) => tree,
             ExtendResult::Turning(tree) => {
                 if tree.depth < options.mindepth {
diff --git a/src/stepsize/adapt.rs b/src/stepsize/adapt.rs
index b8afdc2..49cdd77 100644
--- a/src/stepsize/adapt.rs
+++ b/src/stepsize/adapt.rs
@@ -103,8 +103,7 @@ impl Strategy {
 
         *hamiltonian.step_size_mut() = self.options.initial_step;
 
-        let state_next =
-            hamiltonian.leapfrog(math, &state, Direction::Forward, 1.0, &mut collector);
+        let state_next = hamiltonian.leapfrog(math, &state, Direction::Forward, 1, &mut collector);
 
         let LeapfrogResult::Ok(_) = state_next else {
             return Ok(());
@@ -120,7 +119,7 @@ impl Strategy {
         for _ in 0..100 {
             let mut collector = AcceptanceRateCollector::new();
             collector.register_init(math, &state, options);
-            let state_next = hamiltonian.leapfrog(math, &state, dir, 1.0, &mut collector);
+            let state_next = hamiltonian.leapfrog(math, &state, dir, 1, &mut collector);
             let LeapfrogResult::Ok(_) = state_next else {
                 *hamiltonian.step_size_mut() = self.options.initial_step;
                 return Ok(());
diff --git a/src/stepsize/dual_avg.rs b/src/stepsize/dual_avg.rs
index 3f6d613..7c989a0 100644
--- a/src/stepsize/dual_avg.rs
+++ b/src/stepsize/dual_avg.rs
@@ -126,6 +126,7 @@ impl<M: Math, P: Point<M>> Collector<M, P> for AcceptanceRateCollector {
         _start: &State<M, P>,
         end: &State<M, P>,
         divergence_info: Option<&DivergenceInfo>,
+        _num_substeps: u64,
     ) {
         match divergence_info {
             Some(_) => {
diff --git a/src/transform_adapt_strategy.rs b/src/transform_adapt_strategy.rs
index 6360ab1..307c0fa 100644
--- a/src/transform_adapt_strategy.rs
+++ b/src/transform_adapt_strategy.rs
@@ -81,6 +81,7 @@ impl<M: Math, P: Point<M>> Collector<M, P> for DrawCollector {
         _start: &State<M, P>,
         end: &State<M, P>,
         divergence_info: Option<&crate::DivergenceInfo>,
+        _num_substeps: u64,
     ) {
         if divergence_info.is_some() {
             return;
diff --git 
a/src/transformed_hamiltonian.rs b/src/transformed_hamiltonian.rs
index 7bf7fea..9f0cf27 100644
--- a/src/transformed_hamiltonian.rs
+++ b/src/transformed_hamiltonian.rs
@@ -303,7 +303,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
         math: &mut M,
         start: &State<M, Self::Point>,
         dir: Direction,
-        step_size_factor: f64,
+        step_size_splits: u64,
         collector: &mut C,
     ) -> LeapfrogResult<M, Self::Point> {
         let mut out = self.pool().new_state(math);
@@ -317,7 +317,7 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
             Direction::Backward => -1,
         };
 
-        let epsilon = (sign as f64) * self.step_size * step_size_factor;
+        let epsilon = (sign as f64) * self.step_size / (step_size_splits as f64);
 
         start
             .point()
@@ -328,17 +328,9 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
             if !logp_error.is_recoverable() {
                 return LeapfrogResult::Err(logp_error);
             }
-            let div_info = DivergenceInfo {
-                logp_function_error: Some(Arc::new(Box::new(logp_error))),
-                start_location: Some(math.box_array(start.point().position())),
-                start_gradient: Some(math.box_array(start.point().gradient())),
-                start_momentum: None,
-                end_location: None,
-                start_idx_in_trajectory: Some(start.point().index_in_trajectory()),
-                end_idx_in_trajectory: None,
-                energy_error: None,
-            };
-            collector.register_leapfrog(math, start, &out, Some(&div_info));
+            let logp_error = Arc::new(Box::new(logp_error));
+            let div_info = DivergenceInfo::new_logp_function_error(math, start, logp_error);
+            collector.register_leapfrog(math, start, &out, Some(&div_info), step_size_splits);
             return LeapfrogResult::Divergence(div_info);
         }
 
@@ -349,21 +341,18 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
         let energy_error = out_point.energy_error();
         if (energy_error > self.max_energy_error) | !energy_error.is_finite() {
-            let divergence_info = DivergenceInfo {
-                logp_function_error: None,
-                start_location: Some(math.box_array(start.point().position())),
-                start_gradient: Some(math.box_array(start.point().gradient())),
-                end_location: Some(math.box_array(out_point.position())),
-                start_momentum: None,
-                start_idx_in_trajectory: Some(start.index_in_trajectory()),
-                end_idx_in_trajectory: Some(out.index_in_trajectory()),
-                energy_error: Some(energy_error),
-            };
-            collector.register_leapfrog(math, start, &out, Some(&divergence_info));
+            let divergence_info = DivergenceInfo::new_energy_error_too_large(math, start, &out);
+            collector.register_leapfrog(
+                math,
+                start,
+                &out,
+                Some(&divergence_info),
+                step_size_splits,
+            );
             return LeapfrogResult::Divergence(divergence_info);
         }
 
-        collector.register_leapfrog(math, start, &out, None);
+        collector.register_leapfrog(math, start, &out, None, step_size_splits);
 
         LeapfrogResult::Ok(out)
     }
@@ -465,4 +454,8 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
     fn step_size_mut(&mut self) -> &mut f64 {
         &mut self.step_size
     }
+
+    fn max_energy_error(&self) -> f64 {
+        self.max_energy_error
+    }
 }
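
Note on usage: after patch 3, WALNUTS is opt-in through the new
`walnuts_options` field on `NutsSettings`, and `WalnutsOptions` is re-exported
from lib.rs. A minimal sketch of how a caller might enable it, assuming the
crate is consumed as `nuts_rs`:

    use nuts_rs::{DiagGradNutsSettings, WalnutsOptions};

    let mut settings = DiagGradNutsSettings::default();
    // `WalnutsOptions` is #[non_exhaustive], so build it via `Default`,
    // which per patch 3 uses max_step_size_halvings = 10 and
    // max_energy_error = 5.0.
    settings.walnuts_options = Some(WalnutsOptions::default());

Leaving `walnuts_options` at its default of `None` keeps the classical NUTS
single-leapfrog path in `single_step`.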