
fix: Improve non-linear search logging for JAX users #1175

@Jammy2211

Description

Overview

The non-linear search logging currently displays misleading messages when JAX is used. "Starting non-linear search with 1 cores" is inaccurate because JAX automatically uses all available CPU cores or the GPU. Similarly, "Running search where parallelization is disabled" is misleading under JAX, which handles its own parallelization. These messages should be JAX-aware, and the start message should also report whether JAX is running on CPU or GPU.

Plan

  • Detect JAX usage via analysis._use_jax before logging the search start message
  • Replace the "N cores" message with a JAX-specific message that reports CPU vs GPU when JAX is active
  • Remove or replace the "parallelization is disabled" message in Nautilus's fit_x1_cpu when JAX is being used
  • Keep existing logging unchanged for non-JAX (NumPy) code paths
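The plan reduces to choosing each message from a single JAX flag. A minimal sketch of that selection logic, assuming the flag and a backend label are passed in; search_start_message and parallelization_message are hypothetical helpers, not existing PyAutoFit functions, and use_jax stands in for analysis._use_jax or fitness.use_jax_vmap:

```python
from typing import Optional


def search_start_message(
    use_jax: bool, number_of_cores: int, backend: Optional[str] = None
) -> str:
    """Return the message logged when a non-linear search starts.

    Hypothetical helper: `use_jax` stands in for `analysis._use_jax` and
    `backend` is a short label such as "CPU" or "GPU: NVIDIA A100".
    """
    if use_jax:
        # JAX schedules work across devices itself, so a core count would mislead.
        if backend:
            return f"Starting non-linear search with JAX ({backend})."
        return "Starting non-linear search with JAX."
    # NumPy path: the existing message is kept unchanged.
    return f"Starting non-linear search with {number_of_cores} cores."


def parallelization_message(use_jax: bool) -> Optional[str]:
    """Return the single-process message for Nautilus, or None to log nothing.

    Hypothetical helper: `use_jax` stands in for `fitness.use_jax_vmap`.
    """
    if use_jax:
        # JAX handles its own parallelism, so the message is dropped entirely.
        return None
    return "Running search where parallelization is disabled."


print(search_start_message(True, 1, "GPU: NVIDIA A100"))
print(search_start_message(False, 1))
```

Keeping the wording in small pure functions makes it easy to assert in tests without running a search. Returning None on the JAX path follows the original prompt's request to remove the message; returning the replacement text proposed in step 2 instead would be an equally small change.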
Detailed implementation plan

Affected Repositories

  • rhayes777/PyAutoFit (primary)

Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| ./PyAutoFit | main_build | 3 untracked (new agent configs) |

Suggested branch: feature/jax-search-logging

Implementation Steps

  1. autofit/non_linear/search/abstract_search.py ~line 473 — In the fit() method, check analysis._use_jax before logging. If JAX is active, detect CPU/GPU via jax.devices() and log something like "Starting non-linear search with JAX (GPU: NVIDIA A100)" or "Starting non-linear search with JAX (CPU)". If not JAX, keep the existing "Starting non-linear search with {self.number_of_cores} cores." message.

  2. autofit/non_linear/search/nest/nautilus/search.py ~line 219-223 — In fit_x1_cpu(), check if JAX is being used (the fitness.use_jax_vmap flag is available here). If JAX, skip or replace the "parallelization is disabled" message with something accurate like "Running search with JAX vectorization (parallelization handled by JAX).".

  3. JAX device detection helper — Add a small utility (in abstract_search.py or inline) that safely imports jax and calls jax.devices() to report the backend. Guard with try/except in case JAX is not installed.

  4. Testing — Run python -m pytest test_autofit/ to verify no regressions. The logging changes don't affect behaviour, so existing tests should pass unchanged.
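Step 3 calls for a guarded device-detection helper. A minimal sketch of one way to write it, assuming the first device returned by jax.devices() is representative of the default backend; describe_jax_backend and format_jax_backend are hypothetical names, not existing PyAutoFit functions:

```python
from typing import Optional

# Platform names treated as GPUs. JAX typically reports CUDA devices as "gpu";
# the other names are included defensively.
GPU_PLATFORMS = {"gpu", "cuda", "rocm"}


def format_jax_backend(platform: str, device_kind: str) -> str:
    """Turn a JAX device's platform and kind into a short label for logging."""
    if platform.lower() in GPU_PLATFORMS:
        return f"GPU: {device_kind}"
    # CPU, TPU and anything else are reported by platform name only.
    return platform.upper()


def describe_jax_backend() -> Optional[str]:
    """Return e.g. "CPU" or "GPU: NVIDIA A100", or None if JAX is unavailable."""
    try:
        import jax
    except ImportError:
        # JAX is optional, so the NumPy code path must not depend on it.
        return None
    try:
        devices = jax.devices()
    except RuntimeError:
        # Backend initialisation can fail, e.g. with a broken CUDA installation.
        return None
    if not devices:
        return None
    # jax.devices() lists devices of the default backend; the first one labels it.
    first = devices[0]
    return format_jax_backend(first.platform, first.device_kind)
```

On a CPU-only install this would typically return "CPU", which fits the "Starting non-linear search with JAX (CPU)" wording from step 1. On a machine with several GPUs only the first is named; including len(jax.devices()) in the label would be one way to extend it.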

Key Files

  • autofit/non_linear/search/abstract_search.py: main search start log message (~line 473)
  • autofit/non_linear/search/nest/nautilus/search.py: "parallelization disabled" message (~line 219)
  • autofit/non_linear/analysis/analysis.py: _use_jax flag definition (reference only)

Original Prompt

Read @PyAutoFit/autofit/non_linear/search for context on where this code lives.

The logging of a fit can be improved as follows:

2026-04-01 18:46:35,500 - autofit.non_linear.search.abstract_search - INFO - Starting non-linear search with 1 cores.

This is confusing for JAX users, as JAX automatically uses all cores, so the above is not true.

Instead, if JAX is being used it should be something like:

2026-04-01 18:46:35,500 - autofit.non_linear.search.abstract_search - INFO - Starting non-linear search with JAX.

Can you make it so the JAX message displays whether CPU or GPU is being used?

Running search where parallelization is disabled.

Again, this is not true with JAX; just remove this message.
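The requested behaviour can be checked by capturing log output. A minimal sketch, assuming the quoted log lines come from a standard logging.Formatter with the layout shown above and that loggers are named after their modules; run_single_process_search is a hypothetical stand-in for the relevant part of Nautilus's fit_x1_cpu, and its use_jax argument stands in for fitness.use_jax_vmap:

```python
import io
import logging
from typing import Tuple

# Assumed format: reproduces the layout of the log lines quoted in the prompt,
# not necessarily PyAutoFit's own logging configuration.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"


def capture_logger(name: str) -> Tuple[logging.Logger, io.StringIO]:
    """Return a logger whose formatted output is written to an in-memory buffer."""
    buffer = io.StringIO()
    handler = logging.StreamHandler(buffer)
    handler.setFormatter(logging.Formatter(LOG_FORMAT))
    logger = logging.getLogger(name)
    logger.handlers = [handler]  # replace handlers left over from earlier calls
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate output via the root logger
    return logger, buffer


def run_single_process_search(logger: logging.Logger, use_jax: bool) -> None:
    """Hypothetical stand-in: only the NumPy path reports disabled parallelization."""
    if not use_jax:
        logger.info("Running search where parallelization is disabled.")
    # ... the sampler would run here ...


logger, buffer = capture_logger("autofit.non_linear.search.nest.nautilus.search")
run_single_process_search(logger, use_jax=True)
assert buffer.getvalue() == ""  # JAX path: the message is removed

run_single_process_search(logger, use_jax=False)
assert buffer.getvalue().rstrip().endswith(
    "autofit.non_linear.search.nest.nautilus.search - INFO - "
    "Running search where parallelization is disabled."
)
```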