From fda58cf69d2be9cbd402560df71bd0113db108cb Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 30 Jan 2026 18:18:28 +0000 Subject: [PATCH 1/2] CPU JAX uses batch size 1 --- autofit/non_linear/fitness.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index a41fc7614..1eb5c9f91 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -44,7 +44,6 @@ def __init__( use_jax_vmap : bool = False, batch_size : Optional[int] = None, iterations_per_quick_update: Optional[int] = None, - xp=np, ): """ Interfaces with any non-linear search to fit the model to the data and return a log likelihood via @@ -123,6 +122,16 @@ def __init__( if self.use_jax_vmap: self._call = self._vmap + if analysis._use_jax: + + import jax + + if jax.default_backend() == "cpu": + + logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.") + + batch_size = 1 + self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update self.quick_update_max_lh_parameters = None From 074b97b5b936f90ea17fd4a67f0cba7076e59a23 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 30 Jan 2026 18:24:29 +0000 Subject: [PATCH 2/2] update xp uses in fitness --- autofit/non_linear/fitness.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 1eb5c9f91..fca95406d 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -108,7 +108,8 @@ def __init__( self.model = model self.paths = paths self.fom_is_log_likelihood = fom_is_log_likelihood - self.resample_figure_of_merit = resample_figure_of_merit or -xp.inf + + self.resample_figure_of_merit = resample_figure_of_merit or -self._xp.inf self.convert_to_chi_squared = convert_to_chi_squared self.store_history = store_history @@ -135,7 +136,7 @@ def __init__( self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update self.quick_update_max_lh_parameters = None - self.quick_update_max_lh = -xp.inf + self.quick_update_max_lh = -self._xp.inf self.quick_update_count = 0 if self.paths is not None: