From 244eb136efb57736e0ca315d477ca1aa507dbc25 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 19 Jun 2025 16:23:27 +0200 Subject: [PATCH] feat: add arguments for walnuts --- src/wrapper.rs | 130 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 129 insertions(+), 1 deletion(-) diff --git a/src/wrapper.rs b/src/wrapper.rs index f29620c..8f21240 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -17,7 +17,7 @@ use arrow::array::Array; use numpy::{PyArray1, PyReadonlyArray1}; use nuts_rs::{ ChainProgress, DiagGradNutsSettings, LowRankNutsSettings, ProgressCallback, Sampler, - SamplerWaitResult, Trace, TransformedNutsSettings, + SamplerWaitResult, Trace, TransformedNutsSettings, WalnutsOptions, }; use pyo3::{ exceptions::PyTimeoutError, @@ -654,6 +654,134 @@ impl PyNutsSettings { } Ok(()) } + + #[getter] + fn max_step_size_halvings(&self) -> Result> { + let walnuts = match &self.inner { + Settings::LowRank(inner) => inner.walnuts_options, + Settings::Diag(inner) => inner.walnuts_options, + Settings::Transforming(inner) => inner.walnuts_options, + }; + if let Some(walnuts) = walnuts { + Ok(Some(walnuts.max_step_size_halvings)) + } else { + Ok(None) + } + } + + #[setter(max_step_size_halvings)] + fn set_max_step_size_halvings(&mut self, val: Option) -> Result<()> { + let options = match &mut self.inner { + Settings::LowRank(inner) => &mut inner.walnuts_options, + Settings::Diag(inner) => &mut inner.walnuts_options, + Settings::Transforming(inner) => &mut inner.walnuts_options, + }; + + if let Some(max_halvings) = val { + if let Some(ref mut options) = options { + options.max_step_size_halvings = max_halvings; + } else { + let mut new_options = WalnutsOptions::default(); + new_options.max_step_size_halvings = max_halvings; + *options = Some(new_options); + } + } else { + *options = None; + } + + Ok(()) + } + + #[getter] + fn max_walnuts_energy_error(&self) -> Result> { + let walnuts = match &self.inner { + Settings::LowRank(inner) => inner.walnuts_options, + Settings::Diag(inner) => inner.walnuts_options, + Settings::Transforming(inner) => inner.walnuts_options, + }; + if let Some(walnuts) = walnuts { + Ok(Some(walnuts.max_energy_error)) + } else { + Ok(None) + } + } + + #[setter(max_walnuts_energy_error)] + fn set_max_walnuts_energy_error(&mut self, val: Option) -> Result<()> { + let options = match &mut self.inner { + Settings::LowRank(inner) => &mut inner.walnuts_options, + Settings::Diag(inner) => &mut inner.walnuts_options, + Settings::Transforming(inner) => &mut inner.walnuts_options, + }; + + if let Some(max_error) = val { + if let Some(ref mut options) = options { + options.max_energy_error = max_error; + } else { + let mut new_options = WalnutsOptions::default(); + new_options.max_energy_error = max_error; + *options = Some(new_options); + } + } else { + *options = None; + } + + Ok(()) + } + + #[getter] + fn fixed_step_size(&self) -> Result> { + match &self.inner { + Settings::LowRank(inner) => { + Ok(inner.adapt_options.dual_average_options.fixed_step_size) + } + Settings::Diag(inner) => Ok(inner.adapt_options.dual_average_options.fixed_step_size), + Settings::Transforming(inner) => { + Ok(inner.adapt_options.dual_average_options.fixed_step_size) + } + } + } + + #[setter(fixed_step_size)] + fn set_fixed_step_size(&mut self, val: Option) -> Result<()> { + match &mut self.inner { + Settings::LowRank(inner) => { + inner.adapt_options.dual_average_options.fixed_step_size = val; + } + Settings::Diag(inner) => { + inner.adapt_options.dual_average_options.fixed_step_size = val; + } + Settings::Transforming(inner) => { + inner.adapt_options.dual_average_options.fixed_step_size = val; + } + } + Ok(()) + } + + #[getter] + fn step_size_jitter(&self) -> Result> { + match &self.inner { + Settings::LowRank(inner) => Ok(inner.adapt_options.dual_average_options.jitter), + Settings::Diag(inner) => Ok(inner.adapt_options.dual_average_options.jitter), + Settings::Transforming(inner) => Ok(inner.adapt_options.dual_average_options.jitter), + } + } + + #[setter(step_size_jitter)] + fn set_step_size_jitter(&mut self, val: Option) -> Result<()> { + match &mut self.inner { + Settings::LowRank(inner) => { + inner.adapt_options.dual_average_options.jitter = val; + } + Settings::Diag(inner) => { + inner.adapt_options.dual_average_options.jitter = val; + } + Settings::Transforming(inner) => { + inner.adapt_options.dual_average_options.jitter = val; + } + } + Ok(()) + } } pub(crate) enum SamplerState {