Merged
16 changes: 13 additions & 3 deletions autofit/non_linear/fitness.py
Expand Up @@ -44,7 +44,6 @@ def __init__(
use_jax_vmap : bool = False,
batch_size : Optional[int] = None,
iterations_per_quick_update: Optional[int] = None,
xp=np,
):
Comment on lines 44 to 47

Copilot AI Jan 30, 2026

The xp parameter was removed from __init__, but xp is still referenced later in this method (e.g., -xp.inf for resample_figure_of_merit / quick_update_max_lh). This will raise NameError at runtime during Fitness construction. Replace remaining xp references with np or self._xp (consistent with the existing _xp property) and ensure resample_figure_of_merit / quick_update_max_lh are initialized using that backend.

Copilot uses AI. Check for mistakes.
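A minimal sketch of the fix this comment suggests, assuming the existing `_xp` property resolves the active array backend (the property body below is illustrative, not the real implementation):

```python
import numpy as np


class Fitness:
    """Sketch of the constructor after removing the `xp` argument."""

    def __init__(self, resample_figure_of_merit=None):
        # Build backend-dependent defaults from the _xp property instead of
        # the removed `xp` parameter, so no NameError is raised here.
        self.resample_figure_of_merit = resample_figure_of_merit or -self._xp.inf
        self.quick_update_max_lh = -self._xp.inf

    @property
    def _xp(self):
        # Illustrative backend resolution; the real property may return
        # jax.numpy when JAX is enabled.
        return np
```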
"""
Interfaces with any non-linear search to fit the model to the data and return a log likelihood via
@@ -109,7 +108,8 @@ def __init__(
self.model = model
self.paths = paths
self.fom_is_log_likelihood = fom_is_log_likelihood
self.resample_figure_of_merit = resample_figure_of_merit or -xp.inf

self.resample_figure_of_merit = resample_figure_of_merit or -self._xp.inf
self.convert_to_chi_squared = convert_to_chi_squared
self.store_history = store_history

@@ -123,10 +123,20 @@ def __init__(
if self.use_jax_vmap:
self._call = self._vmap

if analysis._use_jax:

import jax

if jax.default_backend() == "cpu":

logger.info("JAX is using the CPU backend; setting batch size to 1 to improve performance.")

Comment on lines +129 to +133

Copilot AI Jan 30, 2026


Setting the local batch_size = 1 here only changes self.batch_size and does not actually constrain the size of batches passed into the vmapped likelihood (the effective vmap batch size is determined by the shape of the parameters array passed to call_wrap). If the goal is to avoid vmapping over batches >1 on CPU, this needs to be enforced at the call site (e.g., configuring vectorized samplers to use n_batch=1) or by adding logic in call_wrap to chunk/split incoming 2D parameter arrays to the desired batch size.

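One option the comment mentions is enforcing the batch size inside the wrapper by chunking incoming 2D parameter arrays before they reach the vmapped likelihood. A hedged sketch of that approach (function and argument names are hypothetical):

```python
import numpy as np


def chunked_call(figure_of_merit_batch, parameters, batch_size):
    """
    Split a 2D parameter array into chunks of at most `batch_size` rows,
    evaluate each chunk, and concatenate the results, so the effective
    vmap batch size never exceeds `batch_size`.
    """
    n_rows = parameters.shape[0]
    results = [
        figure_of_merit_batch(parameters[start : start + batch_size])
        for start in range(0, n_rows, batch_size)
    ]
    return np.concatenate(results)
```

With `batch_size=1` this degrades gracefully to row-by-row evaluation, which is what the CPU branch in this diff intends.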
Comment on lines +131 to +133

Copilot AI Jan 30, 2026


This introduces new CPU-specific behavior for JAX (forcing batch_size to 1) but there doesn’t appear to be coverage asserting this branch (or that the informational log is emitted). Consider adding a unit test that simulates a JAX CPU backend (e.g., via monkeypatching jax.default_backend or a stub module) and verifies the resulting Fitness.batch_size value.

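A sketch of the unit test the comment asks for, stubbing `jax` in `sys.modules` so the CPU branch can be exercised without a real JAX install; `FakeFitness` is a stand-in reproducing only the branch under test, and all names here are illustrative:

```python
import sys
import types
from unittest import mock


class FakeFitness:
    """Stand-in reproducing only the CPU-backend branch of Fitness.__init__."""

    def __init__(self, batch_size, use_jax=True):
        if use_jax:
            import jax  # resolved via sys.modules, so a stub is picked up

            if jax.default_backend() == "cpu":
                batch_size = 1
        self.batch_size = batch_size


def make_stub_jax(backend):
    # Build a fake jax module whose default_backend reports the given name.
    stub = types.ModuleType("jax")
    stub.default_backend = lambda: backend
    return stub


def test_cpu_backend_forces_batch_size_of_one():
    with mock.patch.dict(sys.modules, {"jax": make_stub_jax("cpu")}):
        assert FakeFitness(batch_size=16).batch_size == 1


def test_gpu_backend_keeps_batch_size():
    with mock.patch.dict(sys.modules, {"jax": make_stub_jax("gpu")}):
        assert FakeFitness(batch_size=16).batch_size == 16
```

A real test in the suite would construct `Fitness` itself and could additionally assert the informational log via `caplog`.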
batch_size = 1

self.batch_size = batch_size
self.iterations_per_quick_update = iterations_per_quick_update
self.quick_update_max_lh_parameters = None
self.quick_update_max_lh = -xp.inf
self.quick_update_max_lh = -self._xp.inf
self.quick_update_count = 0

if self.paths is not None: