
feature/jax_cpu_batch_size_1#1169

Merged
Jammy2211 merged 2 commits into main_build from feature/jax_cpu_batch_size_1
Jan 30, 2026

Conversation

@Jammy2211
Collaborator

This pull request introduces a minor update to the initialization logic in autofit/non_linear/fitness.py, specifically to improve performance when using the JAX backend on CPUs. Now, when JAX is detected as the analysis backend and it is running on a CPU, the batch size is automatically set to 1, and a log message is generated to inform the user.

This is because CPUs do not get any speed-up from using a vmap with a batch size above 1, and batch sizes above 1 can cause issues with shared parallelism.

Performance optimization for JAX CPU backend:

  • Automatically sets batch_size to 1 and logs an informational message when JAX is used with the CPU backend to improve performance. (autofit/non_linear/fitness.py)

Minor cleanup:

  • Removes the unused xp parameter from the __init__ method signature. (autofit/non_linear/fitness.py)
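
The described behaviour can be sketched as follows (a minimal sketch: `resolve_batch_size` is a hypothetical helper, whereas the real logic lives inline in `Fitness.__init__`; the log message is taken from the diff):

```python
import logging

logger = logging.getLogger(__name__)

def resolve_batch_size(use_jax: bool, backend: str, batch_size):
    """Force the batch size to 1 on the JAX CPU backend, where vmap batches
    above 1 give no speed-up and can interfere with shared parallelism."""
    if use_jax and backend == "cpu":
        logger.info(
            "JAX using CPU backend, batch size set to 1 which will improve performance."
        )
        return 1
    return batch_size
```

In the real code the `backend` argument would come from `jax.default_backend()`, which returns `"cpu"`, `"gpu"`, or `"tpu"`.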

@Jammy2211 Jammy2211 requested a review from Copilot January 30, 2026 18:19
@Jammy2211 Jammy2211 changed the title CPU JAX uses batch size 1 feature/jax_cpu_batch_size_1 Jan 30, 2026

Copilot AI left a comment


Pull request overview

Updates Fitness initialization to better handle JAX-on-CPU scenarios by forcing a batch size of 1 and removing an xp initializer argument.

Changes:

  • Removed the xp parameter from Fitness.__init__.
  • Added JAX CPU-backend detection that logs an info message and sets batch_size = 1 when analysis._use_jax is enabled.


Comment on lines +129 to +133

if jax.default_backend() == "cpu":
    logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.")


Copilot AI Jan 30, 2026


Setting the local batch_size = 1 here only changes self.batch_size and does not actually constrain the size of batches passed into the vmapped likelihood (the effective vmap batch size is determined by the shape of the parameters array passed to call_wrap). If the goal is to avoid vmapping over batches >1 on CPU, this needs to be enforced at the call site (e.g., configuring vectorized samplers to use n_batch=1) or by adding logic in call_wrap to chunk/split incoming 2D parameter arrays to the desired batch size.
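
The chunking this comment proposes could look roughly like the following (a sketch with assumed names: `call_in_chunks` is hypothetical, and `fn` stands in for the vmapped likelihood that `call_wrap` invokes):

```python
import numpy as np

def call_in_chunks(fn, parameters: np.ndarray, batch_size: int) -> np.ndarray:
    """Split a 2D (n_samples, n_params) array into chunks of at most
    batch_size rows, apply fn to each chunk, and concatenate the results,
    so the effective vmap batch never exceeds batch_size."""
    chunks = [
        fn(parameters[i : i + batch_size])
        for i in range(0, parameters.shape[0], batch_size)
    ]
    return np.concatenate(chunks)
```

With `batch_size=1` on CPU this degrades gracefully to a per-sample loop, which is the behaviour the PR intends.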

Comment on lines +131 to +133

logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.")


Copilot AI Jan 30, 2026


This introduces new CPU-specific behavior for JAX (forcing batch_size to 1) but there doesn’t appear to be coverage asserting this branch (or that the informational log is emitted). Consider adding a unit test that simulates a JAX CPU backend (e.g., via monkeypatching jax.default_backend or a stub module) and verifies the resulting Fitness.batch_size value.
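
One way to write such a test without a real accelerator is to stub the `jax` module in `sys.modules` (a sketch: `make_batch_size` is a hypothetical stand-in for the `Fitness` construction path):

```python
import sys
import types
from unittest import mock

def make_batch_size(batch_size):
    # Resolve jax at call time so a stub installed in sys.modules is picked up.
    import jax
    if jax.default_backend() == "cpu":
        return 1
    return batch_size

def test_cpu_backend_forces_batch_size_1():
    # Build a minimal fake jax module that reports a CPU backend.
    stub = types.ModuleType("jax")
    stub.default_backend = lambda: "cpu"
    with mock.patch.dict(sys.modules, {"jax": stub}):
        assert make_batch_size(64) == 1
```

The same pattern with `stub.default_backend = lambda: "gpu"` verifies the branch is not taken on accelerators.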

Comment on lines 44 to 47
use_jax_vmap : bool = False,
batch_size : Optional[int] = None,
iterations_per_quick_update: Optional[int] = None,
xp=np,
):

Copilot AI Jan 30, 2026


The xp parameter was removed from __init__, but xp is still referenced later in this method (e.g., -xp.inf for resample_figure_of_merit / quick_update_max_lh). This will raise NameError at runtime during Fitness construction. Replace remaining xp references with np or self._xp (consistent with the existing _xp property) and ensure resample_figure_of_merit / quick_update_max_lh are initialized using that backend.
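
A sketch of the suggested fix, using the `_xp` property as the single source of the array backend (class and attribute names follow the comment, but the real `Fitness` carries much more state):

```python
import numpy as np

class Fitness:
    """Minimal sketch: sentinel figures of merit are initialised via the
    resolved backend rather than a removed `xp` constructor argument."""

    def __init__(self, use_jax: bool = False):
        self.use_jax = use_jax
        # Using self._xp here avoids the NameError left by deleting `xp`.
        self.resample_figure_of_merit = -self._xp.inf
        self.quick_update_max_lh = -self._xp.inf

    @property
    def _xp(self):
        if self.use_jax:
            import jax.numpy as jnp
            return jnp
        return np
```

Since `jnp.inf` and `np.inf` are both plain floats, using `np.inf` unconditionally would also work for these sentinels.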

@Jammy2211 Jammy2211 merged commit 49ad425 into main_build Jan 30, 2026
4 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax_cpu_batch_size_1 branch February 13, 2026 13:37
