From 380f73b397f1e898e568abe36b84ac4168ae479c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 3 Apr 2026 19:07:32 +0100 Subject: [PATCH] Make search logging JAX-aware: report GPU/CPU instead of misleading core count When JAX is active, the "Starting non-linear search with 1 cores" message was inaccurate (JAX handles its own parallelization). Now logs the JAX backend and device instead. Similarly, Nautilus's "parallelization is disabled" message is replaced with a JAX-specific message when appropriate. Closes rhayes777/PyAutoFit#1175 Co-Authored-By: Claude Opus 4.6 (1M context) --- autofit/non_linear/search/abstract_search.py | 15 ++++++++++++++- autofit/non_linear/search/nest/nautilus/search.py | 13 ++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/autofit/non_linear/search/abstract_search.py b/autofit/non_linear/search/abstract_search.py index 1051d80e0..6b58bd1d8 100644 --- a/autofit/non_linear/search/abstract_search.py +++ b/autofit/non_linear/search/abstract_search.py @@ -470,7 +470,20 @@ class represented by model M and gives a score for their fitness. """ self.check_model(model=model) - logger.info(f"Starting non-linear search with {self.number_of_cores} cores.") + if getattr(analysis, "_use_jax", False): + try: + import jax + devices = jax.devices() + device = devices[0] + backend = device.platform.upper() + device_name = getattr(device, "device_kind", backend) + logger.info( + f"Starting non-linear search with JAX ({backend}: {device_name})." + ) + except Exception: + logger.info("Starting non-linear search with JAX.") + else: + logger.info(f"Starting non-linear search with {self.number_of_cores} cores.") self._log_process_state() model = analysis.modify_model(model) diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index dcf8cdeb0..b70b3778d 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -216,11 +216,14 @@ def fit_x1_cpu(self, fitness, model, analysis): the log likelihood the search maximizes. """ - self.logger.info( - """ - Running search where parallelization is disabled. - """ - ) + if analysis._use_jax: + self.logger.info( + "Running search with JAX vectorization (parallelization handled by JAX)." + ) + else: + self.logger.info( + "Running search where parallelization is disabled." + ) config_dict = self.config_dict_search try: