diff --git a/autofit/non_linear/search/abstract_search.py b/autofit/non_linear/search/abstract_search.py index 1051d80e0..6b58bd1d8 100644 --- a/autofit/non_linear/search/abstract_search.py +++ b/autofit/non_linear/search/abstract_search.py @@ -470,7 +470,20 @@ class represented by model M and gives a score for their fitness. """ self.check_model(model=model) - logger.info(f"Starting non-linear search with {self.number_of_cores} cores.") + if getattr(analysis, "_use_jax", False): + try: + import jax + devices = jax.devices() + device = devices[0] + backend = device.platform.upper() + device_name = getattr(device, "device_kind", backend) + logger.info( + f"Starting non-linear search with JAX ({backend}: {device_name})." + ) + except Exception: + logger.info("Starting non-linear search with JAX.") + else: + logger.info(f"Starting non-linear search with {self.number_of_cores} cores.") self._log_process_state() model = analysis.modify_model(model) diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index dcf8cdeb0..b70b3778d 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -216,11 +216,14 @@ def fit_x1_cpu(self, fitness, model, analysis): the log likelihood the search maximizes. """ - self.logger.info( - """ - Running search where parallelization is disabled. - """ - ) + if analysis._use_jax: + self.logger.info( + "Running search with JAX vectorization (parallelization handled by JAX)." + ) + else: + self.logger.info( + "Running search where parallelization is disabled." + ) config_dict = self.config_dict_search try: