From 6cedcfcc09d42a21c3857007b76a5824c318bc2f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 30 Jan 2026 20:46:12 +0000 Subject: [PATCH 1/2] disable vmap for CPU --- autofit/non_linear/fitness.py | 18 +++++++++--------- .../non_linear/search/nest/nautilus/search.py | 9 +++++++-- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index fca95406d..0da6185a7 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -116,22 +116,22 @@ def __init__( self.parameters_history_list = [] self.log_likelihood_history_list = [] - self.use_jax_vmap = use_jax_vmap - - self._call = self.call - - if self.use_jax_vmap: - self._call = self._vmap - if analysis._use_jax: import jax if jax.default_backend() == "cpu": - logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.") + logger.info("JAX using CPU backend, vmap disabled for faster performance.") - batch_size = 1 + use_jax_vmap = False + + self.use_jax_vmap = use_jax_vmap + + self._call = self.call + + if self.use_jax_vmap: + self._call = self._vmap self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index fa5363cc3..8a584706f 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -137,7 +137,7 @@ def _fit(self, model: AbstractPriorModel, analysis): fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, iterations_per_quick_update=self.iterations_per_quick_update, - use_jax_vmap=True, + use_jax_vmap=False, batch_size=self.config_dict_search["n_batch"], ) @@ -225,13 +225,18 @@ def fit_x1_cpu(self, fitness, model, analysis): except KeyError: pass + if fitness.use_jax_vmap: + vectorized = True + else: + vectorized = False + search_internal = self.sampler_cls( prior=PriorVectorized(model=model), likelihood=fitness.call_wrap, n_dim=model.prior_count, filepath=self.checkpoint_file, pool=None, - vectorized=True, + vectorized=vectorized, **config_dict, ) From a51ba48f846544fee72f56d24b90eab208255449 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 30 Jan 2026 20:52:12 +0000 Subject: [PATCH 2/2] Update autofit/non_linear/search/nest/nautilus/search.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- autofit/non_linear/search/nest/nautilus/search.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 8a584706f..e9c0abad6 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -225,10 +225,7 @@ def fit_x1_cpu(self, fitness, model, analysis): except KeyError: pass - if fitness.use_jax_vmap: - vectorized = True - else: - vectorized = False + vectorized = fitness.use_jax_vmap search_internal = self.sampler_cls( prior=PriorVectorized(model=model),