diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py
index fca95406d..0da6185a7 100644
--- a/autofit/non_linear/fitness.py
+++ b/autofit/non_linear/fitness.py
@@ -116,22 +116,22 @@ def __init__(
 
         self.parameters_history_list = []
         self.log_likelihood_history_list = []
 
-        self.use_jax_vmap = use_jax_vmap
-
-        self._call = self.call
-
-        if self.use_jax_vmap:
-            self._call = self._vmap
-
         if analysis._use_jax:
             import jax
 
             if jax.default_backend() == "cpu":
-                logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.")
+                logger.info("JAX using CPU backend, vmap disabled for faster performance.")
 
-                batch_size = 1
+                use_jax_vmap = False
+
+        self.use_jax_vmap = use_jax_vmap
+
+        self._call = self.call
+
+        if self.use_jax_vmap:
+            self._call = self._vmap
 
         self.batch_size = batch_size
         self.iterations_per_quick_update = iterations_per_quick_update
 
diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py
index fa5363cc3..e9c0abad6 100644
--- a/autofit/non_linear/search/nest/nautilus/search.py
+++ b/autofit/non_linear/search/nest/nautilus/search.py
@@ -137,7 +137,7 @@ def _fit(self, model: AbstractPriorModel, analysis):
             fom_is_log_likelihood=True,
             resample_figure_of_merit=-1.0e99,
             iterations_per_quick_update=self.iterations_per_quick_update,
-            use_jax_vmap=True,
+            use_jax_vmap=False,
             batch_size=self.config_dict_search["n_batch"],
         )
 
@@ -225,13 +225,15 @@ def fit_x1_cpu(self, fitness, model, analysis):
         except KeyError:
             pass
 
+        vectorized = fitness.use_jax_vmap
+
         search_internal = self.sampler_cls(
             prior=PriorVectorized(model=model),
             likelihood=fitness.call_wrap,
             n_dim=model.prior_count,
             filepath=self.checkpoint_file,
             pool=None,
-            vectorized=True,
+            vectorized=vectorized,
             **config_dict,
         )