diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index a41fc7614..fca95406d 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -44,7 +44,6 @@ def __init__( use_jax_vmap : bool = False, batch_size : Optional[int] = None, iterations_per_quick_update: Optional[int] = None, - xp=np, ): """ Interfaces with any non-linear search to fit the model to the data and return a log likelihood via @@ -109,7 +108,8 @@ def __init__( self.model = model self.paths = paths self.fom_is_log_likelihood = fom_is_log_likelihood - self.resample_figure_of_merit = resample_figure_of_merit or -xp.inf + + self.resample_figure_of_merit = resample_figure_of_merit or -self._xp.inf self.convert_to_chi_squared = convert_to_chi_squared self.store_history = store_history @@ -123,10 +123,20 @@ def __init__( if self.use_jax_vmap: self._call = self._vmap + if analysis._use_jax: + + import jax + + if jax.default_backend() == "cpu": + + logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.") + + batch_size = 1 + self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update self.quick_update_max_lh_parameters = None - self.quick_update_max_lh = -xp.inf + self.quick_update_max_lh = -self._xp.inf self.quick_update_count = 0 if self.paths is not None: