Conversation
Pull request overview
Updates Fitness initialization to better handle JAX-on-CPU scenarios by forcing a batch size of 1 and removing an xp initializer argument.
Changes:
- Removed the `xp` parameter from `Fitness.__init__`.
- Added JAX CPU-backend detection that logs an info message and sets `batch_size = 1` when `analysis._use_jax` is enabled.
```python
if jax.default_backend() == "cpu":
    logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.")
```
Setting the local `batch_size = 1` here only changes `self.batch_size` and does not actually constrain the size of batches passed into the vmapped likelihood (the effective vmap batch size is determined by the shape of the parameters array passed to `call_wrap`). If the goal is to avoid vmapping over batches >1 on CPU, this needs to be enforced at the call site (e.g., configuring vectorized samplers to use `n_batch=1`) or by adding logic in `call_wrap` to chunk/split incoming 2D parameter arrays to the desired batch size.
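A minimal sketch of the chunking idea, in plain NumPy with a stand-in for the vmapped likelihood (the function names here are hypothetical illustrations, not autofit's actual API):

```python
import numpy as np

# Stand-in for the vmapped JAX likelihood: evaluates a whole batch of
# parameter rows at once. A hypothetical figure of merit is used here.
def batched_likelihood(params):
    return -np.sum(params ** 2, axis=1)

def call_wrap(parameters, batch_size=1):
    """Split an incoming 2D parameter array into chunks of at most
    `batch_size` rows, so the effective batch handed to the (v)mapped
    likelihood never exceeds the configured limit."""
    parameters = np.atleast_2d(parameters)
    chunks = [
        batched_likelihood(parameters[i:i + batch_size])
        for i in range(0, len(parameters), batch_size)
    ]
    return np.concatenate(chunks)
```

With `batch_size=1` each row is evaluated individually, which is the behavior the review suggests enforcing on CPU; the concatenated result is identical to evaluating the full batch at once.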
```python
logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.")
```
This introduces new CPU-specific behavior for JAX (forcing `batch_size` to 1), but there doesn't appear to be test coverage asserting this branch (or that the informational log is emitted). Consider adding a unit test that simulates a JAX CPU backend (e.g., via monkeypatching `jax.default_backend` or a stub module) and verifies the resulting `Fitness.batch_size` value.
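A sketch of such a test, using a stub `jax` module inserted into `sys.modules` so no real JAX install is needed. Here `resolve_batch_size` is a hypothetical extraction of the CPU-detection branch, not the actual `Fitness.__init__`:

```python
import sys
import types

def make_fake_jax(backend):
    """Stub module exposing only jax.default_backend, letting the test
    simulate any backend without importing real JAX."""
    mod = types.ModuleType("jax")
    mod.default_backend = lambda: backend
    return mod

def resolve_batch_size(use_jax, requested=None):
    """Hypothetical extraction of the branch under review: forces a
    batch size of 1 when JAX is in use on the CPU backend."""
    if use_jax:
        import jax  # resolves to the stub if one is in sys.modules
        if jax.default_backend() == "cpu":
            return 1
    return requested

def test_cpu_backend_forces_batch_size_one():
    original = sys.modules.get("jax")
    try:
        sys.modules["jax"] = make_fake_jax("cpu")
        assert resolve_batch_size(use_jax=True, requested=16) == 1
        sys.modules["jax"] = make_fake_jax("gpu")
        assert resolve_batch_size(use_jax=True, requested=16) == 16
        assert resolve_batch_size(use_jax=False, requested=16) == 16
    finally:
        if original is not None:
            sys.modules["jax"] = original
        else:
            sys.modules.pop("jax", None)
```

The same stub approach would let a test assert the info log via `caplog` in pytest.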
```python
    use_jax_vmap : bool = False,
    batch_size : Optional[int] = None,
    iterations_per_quick_update: Optional[int] = None,
    xp=np,
):
```
The `xp` parameter was removed from `__init__`, but `xp` is still referenced later in this method (e.g., `-xp.inf` for `resample_figure_of_merit` / `quick_update_max_lh`). This will raise `NameError` at runtime during `Fitness` construction. Replace remaining `xp` references with `np` or `self._xp` (consistent with the existing `_xp` property) and ensure `resample_figure_of_merit` / `quick_update_max_lh` are initialized using that backend.
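A minimal sketch of the suggested fix, assuming a `_xp`-style property that falls back to NumPy when JAX is not in use (the class body is abbreviated to the relevant attributes and is not the real autofit implementation):

```python
import numpy as np

class Fitness:
    def __init__(self, use_jax=False):
        self._use_jax = use_jax
        # Previously `-xp.inf`, which now raises NameError since the
        # `xp` parameter is gone; initialize via the backend property.
        self.resample_figure_of_merit = -self._xp.inf
        self.quick_update_max_lh = -self._xp.inf

    @property
    def _xp(self):
        # Select the array backend: jax.numpy when JAX is enabled,
        # plain NumPy otherwise.
        if self._use_jax:
            import jax.numpy as jnp
            return jnp
        return np
```

Since `-inf` is a plain Python/NumPy scalar either way, using `np` directly would also work here; routing through `_xp` just keeps the backend choice in one place.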
This pull request introduces a minor update to the initialization logic in `autofit/non_linear/fitness.py`, specifically to improve performance when using the JAX backend on CPUs. Now, when JAX is detected as the analysis backend and it is running on a CPU, the batch size is automatically set to 1 and a log message informs the user. This is because CPUs get no speedup from a vmap with a batch size above 1, and larger batches can cause issues with shared parallelism.
Performance optimization for JAX CPU backend:
- Sets `batch_size` to 1 and logs an informational message when JAX is used with the CPU backend, to improve performance. (`autofit/non_linear/fitness.py`)

Minor cleanup:
- Removes the `xp` parameter from the `__init__` method signature. (`autofit/non_linear/fitness.py`)