Merged
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
This pull request updates the JAX vectorization (vmap) handling in the Nautilus nested sampling search implementation to improve performance when using the CPU backend. The changes disable vmap by default for CPU-based JAX operations and ensure consistent propagation of the vectorization flag throughout the search process.
Changes:
- Modified the
Fitnessclass initialization to automatically disableuse_jax_vmapwhen JAX is running on CPU backend, replacing the previous approach of setting batch_size to 1 - Updated the Nautilus
_fitmethod to explicitly setuse_jax_vmap=Falsewhen creating the fitness object - Modified
fit_x1_cputo propagate theuse_jax_vmapflag to the sampler'svectorizedparameter
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| autofit/non_linear/fitness.py | Reorganized initialization logic to disable vmap when JAX uses CPU backend, updated log message to reflect the change from batch_size adjustment to vmap disabling |
| autofit/non_linear/search/nest/nautilus/search.py | Set use_jax_vmap=False in _fit method and added logic in fit_x1_cpu to propagate the vmap flag to the sampler's vectorized parameter |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This pull request primarily updates how JAX vectorization (
vmap) is handled in the Nautilus search implementation, ensuring that vectorization is only enabled when appropriate for performance reasons. The changes disablevmapby default, especially when using the CPU backend, and ensure that the vectorization flag is consistently propagated throughout the codebase.JAX vectorization handling:
autofit/non_linear/fitness.py, the initialization logic now disablesuse_jax_vmapwhen the JAX backend is set to CPU, improving performance by avoiding unnecessary vectorization. The log message has also been updated to reflect this change.autofit/non_linear/search/nest/nautilus/search.py, the_fitmethod now explicitly setsuse_jax_vmap=Falsewhen constructing the fitness object, ensuring vectorization is off by default.Vectorization flag propagation:
fit_x1_cpumethod ofautofit/non_linear/search/nest/nautilus/search.py, thevectorizedargument for the sampler is now set based on thefitness.use_jax_vmapflag, ensuring consistency in how vectorization is applied throughout the search process.