From 3ce01f40a5842f3f6d8d732777b00bc3d187529f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 9 Dec 2025 19:32:45 +0000 Subject: [PATCH] fixes to figure of merit --- autofit/non_linear/fitness.py | 14 +++++----- autofit/non_linear/search/mle/bfgs/search.py | 27 ++++++++++++++------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index a3792b701..5e5e6d981 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -173,6 +173,7 @@ def call(self, parameters): # Penalize NaNs in the log-likelihood log_likelihood = self._xp.where(self._xp.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood) + log_likelihood = self._xp.where(self._xp.isinf(log_likelihood), self.resample_figure_of_merit, log_likelihood) # Determine final figure of merit if self.fom_is_log_likelihood: @@ -222,19 +223,16 @@ def call_wrap(self, parameters): figure_of_merit = self._call(parameters) if self.convert_to_chi_squared: - figure_of_merit *= -0.5 - - if self.fom_is_log_likelihood: - log_likelihood = figure_of_merit + log_likelihood = -0.5 * figure_of_merit else: + log_likelihood = figure_of_merit + + if not self.fom_is_log_likelihood: log_prior_list = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) - log_likelihood = figure_of_merit - self._xp.sum(log_prior_list) + log_likelihood -= self._xp.sum(log_prior_list) self.manage_quick_update(parameters=parameters, log_likelihood=log_likelihood) - if self.convert_to_chi_squared: - log_likelihood *= -2.0 - if self.store_history: self.parameters_history_list.append(np.array(parameters)) diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 8ca8f064b..2b105e786 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -136,21 +136,32 @@ def _fit( maxiter = self.config_dict_options.get("maxiter", 1e8) while total_iterations < maxiter: - iterations_remaining = maxiter - total_iterations + iterations_remaining = maxiter - total_iterations iterations = min(self.iterations_per_full_update, iterations_remaining) if iterations > 0: config_dict_options = self.config_dict_options config_dict_options["maxiter"] = iterations - search_internal = optimize.minimize( - fun=fitness._jit, - x0=x0, - method=self.method, - options=config_dict_options, - **self.config_dict_search - ) + if analysis._use_jax: + + search_internal = optimize.minimize( + fun=fitness._jit, + x0=x0, + method=self.method, + options=config_dict_options, + **self.config_dict_search + ) + else: + + search_internal = optimize.minimize( + fun=fitness.__call__, + x0=x0, + method=self.method, + options=config_dict_options, + **self.config_dict_search + ) total_iterations += search_internal.nit