Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,22 +116,22 @@ def __init__(
self.parameters_history_list = []
self.log_likelihood_history_list = []

self.use_jax_vmap = use_jax_vmap

self._call = self.call

if self.use_jax_vmap:
self._call = self._vmap

if analysis._use_jax:

import jax

if jax.default_backend() == "cpu":

logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.")
logger.info("JAX using CPU backend, vmap disabled for faster performance.")

batch_size = 1
use_jax_vmap = False

self.use_jax_vmap = use_jax_vmap

self._call = self.call

if self.use_jax_vmap:
self._call = self._vmap

self.batch_size = batch_size
self.iterations_per_quick_update = iterations_per_quick_update
Expand Down
6 changes: 4 additions & 2 deletions autofit/non_linear/search/nest/nautilus/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _fit(self, model: AbstractPriorModel, analysis):
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
iterations_per_quick_update=self.iterations_per_quick_update,
use_jax_vmap=True,
use_jax_vmap=False,
batch_size=self.config_dict_search["n_batch"],
)

Expand Down Expand Up @@ -225,13 +225,15 @@ def fit_x1_cpu(self, fitness, model, analysis):
except KeyError:
pass

vectorized = fitness.use_jax_vmap

search_internal = self.sampler_cls(
prior=PriorVectorized(model=model),
likelihood=fitness.call_wrap,
n_dim=model.prior_count,
filepath=self.checkpoint_file,
pool=None,
vectorized=True,
vectorized=vectorized,
**config_dict,
)

Expand Down
Loading