Overview
The non-linear search logging currently displays misleading messages when JAX is being used. "Starting non-linear search with 1 cores" is inaccurate because JAX automatically uses all available cores/GPU. Similarly, "Running search where parallelization is disabled" is misleading for JAX since JAX handles its own parallelization. These messages should be JAX-aware and also report whether JAX is using CPU or GPU.
Plan
- Detect JAX usage via analysis._use_jax before logging the search start message
- Replace the "N cores" message with a JAX-specific message that reports CPU vs GPU when JAX is active
- Remove or replace the "parallelization is disabled" message in Nautilus's fit_x1_cpu when JAX is being used
- Keep existing logging unchanged for non-JAX (NumPy) code paths
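The JAX-aware start message and the CPU/GPU report from the plan above can be sketched as below. This is a minimal sketch that assumes a small helper-function design. The names jax_backend_description and search_start_message are hypothetical and are not existing PyAutoFit API. Only jax.devices() and the Device attributes platform and device_kind come from JAX itself. The message format follows the examples given in this issue.

```python
def jax_backend_description() -> str:
    """Describe the device JAX will run on, e.g. "CPU" or "GPU: NVIDIA A100...".

    Hypothetical helper: the name and the returned format are illustrative only.
    """
    try:
        import jax  # optional dependency, imported lazily
    except ImportError:
        return "not installed"
    try:
        # With no argument, jax.devices() returns the default backend's devices,
        # which is the GPU/TPU backend when one is available, else the CPU.
        devices = jax.devices()
    except RuntimeError:
        # Backend initialisation can fail, e.g. with a broken CUDA install.
        return "unknown device"
    if not devices:
        return "unknown device"
    device = devices[0]
    if device.platform == "cpu":
        return "CPU"
    # platform is e.g. "gpu" or "tpu"; device_kind names the hardware model.
    return f"{device.platform.upper()}: {device.device_kind}"


def search_start_message(use_jax: bool, number_of_cores: int) -> str:
    """Build the start-of-search log line (hypothetical function name)."""
    if use_jax:
        return f"Starting non-linear search with JAX ({jax_backend_description()})."
    # NumPy path: the existing message is kept unchanged.
    return f"Starting non-linear search with {number_of_cores} cores."
```

Inside fit(), the existing logger call would then become something like logger.info(search_start_message(analysis._use_jax, self.number_of_cores)). Reporting only the first device keeps the log line short.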
Detailed implementation plan
Affected Repositories
- rhayes777/PyAutoFit (primary)
Branch Survey
| Repository | Current Branch | Dirty? |
| --- | --- | --- |
| ./PyAutoFit | main_build | 3 untracked (new agent configs) |
Suggested branch: feature/jax-search-logging
Implementation Steps
- autofit/non_linear/search/abstract_search.py (~line 473): in the fit() method, check analysis._use_jax before logging. If JAX is active, detect CPU/GPU via jax.devices() and log something like "Starting non-linear search with JAX (GPU: NVIDIA A100)" or "Starting non-linear search with JAX (CPU)". Otherwise, keep the existing "Starting non-linear search with {self.number_of_cores} cores." message.
- autofit/non_linear/search/nest/nautilus/search.py (~lines 219-223): in fit_x1_cpu(), check whether JAX is being used (the fitness.use_jax_vmap flag is available here). If it is, skip the "parallelization is disabled" message or replace it with something accurate, such as "Running search with JAX vectorization (parallelization handled by JAX).".
- JAX device detection helper: add a small utility (in abstract_search.py or inline) that safely imports jax and calls jax.devices() to report the backend. Guard it with try/except in case JAX is not installed.
- Testing: run python -m pytest test_autofit/ to verify there are no regressions. The logging changes do not affect behaviour, so existing tests should pass unchanged.
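The Nautilus message change and a check of the emitted log lines can be sketched with the standard logging module alone. This is a sketch under assumptions. The names log_parallelization_mode and ListHandler are hypothetical. Inside fit_x1_cpu() the branch would read fitness.use_jax_vmap directly. The sketch follows the plan's "replace" option. The original prompt asks for the message to be removed outright, in which case the JAX branch would simply log nothing.

```python
import logging

logger = logging.getLogger("autofit.non_linear.search.nest.nautilus.search")


def log_parallelization_mode(use_jax_vmap: bool) -> None:
    """Stand-in for the message block in fit_x1_cpu() (hypothetical helper)."""
    if use_jax_vmap:
        # JAX vectorises the likelihood calls itself, so "disabled" would be wrong.
        logger.info(
            "Running search with JAX vectorization (parallelization handled by JAX)."
        )
    else:
        # NumPy path: keep the existing message unchanged.
        logger.info("Running search where parallelization is disabled.")


class ListHandler(logging.Handler):
    """Collects log messages in a list so a check can inspect them."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())
```

In the PyAutoFit test suite, pytest's built-in caplog fixture (via caplog.messages) can do the same job as ListHandler.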
Key Files
autofit/non_linear/search/abstract_search.py — Main search start log message (~line 473)
autofit/non_linear/search/nest/nautilus/search.py — "Parallelization disabled" message (~line 219)
autofit/non_linear/analysis/analysis.py — _use_jax flag definition (reference only)
Original Prompt
Read @PyAutoFit/autofit/non_linear/search for context where this is.
The logging of a fit can be improved as follows:
2026-04-01 18:46:35,500 - autofit.non_linear.search.abstract_search - INFO - Starting non-linear search with 1 cores.
This is confusing for JAX users, as JAX automatically uses all cores, so the above is not true.
Instead, if JAX is being used it should be something like:
2026-04-01 18:46:35,500 - autofit.non_linear.search.abstract_search - INFO - Starting non-linear search with JAX.
Can you make it so JAX displays if CPU or GPU is being used?
Running search where parallelization is disabled.
Again, this is not true in JAX, just remove this.