Skip to content

feature/jax_cpu_jit#1170

Merged
Jammy2211 merged 2 commits intomain_buildfrom
feature/jax_cpu_jit
Jan 30, 2026
Merged

feature/jax_cpu_jit#1170
Jammy2211 merged 2 commits intomain_buildfrom
feature/jax_cpu_jit

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

This pull request primarily updates how JAX vectorization (vmap) is handled in the Nautilus search implementation, ensuring that vectorization is only enabled when appropriate for performance reasons. The changes disable vmap by default, especially when using the CPU backend, and ensure that the vectorization flag is consistently propagated throughout the codebase.

JAX vectorization handling:

  • In autofit/non_linear/fitness.py, the initialization logic now disables use_jax_vmap when the JAX backend is set to CPU, improving performance by avoiding unnecessary vectorization. The log message has also been updated to reflect this change.
  • In autofit/non_linear/search/nest/nautilus/search.py, the _fit method now explicitly sets use_jax_vmap=False when constructing the fitness object, ensuring vectorization is off by default.

Vectorization flag propagation:

  • In the fit_x1_cpu method of autofit/non_linear/search/nest/nautilus/search.py, the vectorized argument for the sampler is now set based on the fitness.use_jax_vmap flag, ensuring consistency in how vectorization is applied throughout the search process.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This pull request updates the JAX vectorization (vmap) handling in the Nautilus nested sampling search implementation to improve performance when using the CPU backend. The changes disable vmap by default for CPU-based JAX operations and ensure consistent propagation of the vectorization flag throughout the search process.

Changes:

  • Modified the Fitness class initialization to automatically disable use_jax_vmap when JAX is running on CPU backend, replacing the previous approach of setting batch_size to 1
  • Updated the Nautilus _fit method to explicitly set use_jax_vmap=False when creating the fitness object
  • Modified fit_x1_cpu to propagate the use_jax_vmap flag to the sampler's vectorized parameter

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
autofit/non_linear/fitness.py Reorganized initialization logic to disable vmap when JAX uses CPU backend, updated log message to reflect the change from batch_size adjustment to vmap disabling
autofit/non_linear/search/nest/nautilus/search.py Set use_jax_vmap=False in _fit method and added logic in fit_x1_cpu to propagate the vmap flag to the sampler's vectorized parameter

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@Jammy2211 Jammy2211 merged commit 21fb4ab into main_build Jan 30, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax_cpu_jit branch February 13, 2026 13:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants