diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 257fd1547..1987beb6c 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -1,4 +1,4 @@ -import numpy as np + import os from typing import Optional @@ -6,6 +6,8 @@ from autofit import exc +from autofit.jax_wrapper import numpy as np + from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.paths.abstract import AbstractPaths from autofit.non_linear.analysis import Analysis @@ -154,9 +156,7 @@ def __call__(self, parameters, *kwargs): try: instance = self.model.instance_from_vector(vector=parameters) log_likelihood = self.log_likelihood_function(instance=instance) - - if np.isnan(log_likelihood): - return self.resample_figure_of_merit + log_likelihood = np.where(np.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood) except exc.FitException: return self.resample_figure_of_merit diff --git a/autofit/non_linear/initializer.py b/autofit/non_linear/initializer.py index 5a762a54a..2cf0d3127 100644 --- a/autofit/non_linear/initializer.py +++ b/autofit/non_linear/initializer.py @@ -13,6 +13,8 @@ from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.parallel import SneakyPool +from autofit import jax_wrapper + logger = logging.getLogger(__name__) @@ -39,14 +41,14 @@ def figure_of_metric(args) -> Optional[float]: return None def samples_from_model( - self, - total_points: int, - model: AbstractPriorModel, - fitness, - paths: AbstractPaths, - use_prior_medians: bool = False, - test_mode_samples: bool = True, - n_cores: int = 1, + self, + total_points: int, + model: AbstractPriorModel, + fitness, + paths: AbstractPaths, + use_prior_medians: bool = False, + test_mode_samples: bool = True, + n_cores: int = 1, ): """ Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform @@ -64,6 +66,14 @@ def samples_from_model( if os.environ.get("PYAUTOFIT_TEST_MODE") == "1" and test_mode_samples: return self.samples_in_test_mode(total_points=total_points, model=model) + if jax_wrapper.use_jax: + return self.samples_jax( + total_points=total_points, + model=model, + fitness=fitness, + use_prior_medians=use_prior_medians + ) + unit_parameter_lists = [] parameter_lists = [] figures_of_merit_list = [] @@ -92,13 +102,13 @@ def samples_from_model( unit_parameter_lists_.append(unit_parameter_list) for figure_of_merit, unit_parameter_list, parameter_list in zip( - sneaky_pool.map( - function=self.figure_of_metric, - args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_], - log_info=False - ), - unit_parameter_lists_, - parameter_lists_, + sneaky_pool.map( + function=self.figure_of_metric, + args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_], + log_info=False + ), + unit_parameter_lists_, + parameter_lists_, ): if figure_of_merit is not None: unit_parameter_lists.append(unit_parameter_list) @@ -106,16 +116,81 @@ def samples_from_model( figures_of_merit_list.append(figure_of_merit) if total_points > 1 and np.allclose( - a=figures_of_merit_list[0], b=figures_of_merit_list[1:] + a=figures_of_merit_list[0], b=figures_of_merit_list[1:] ): raise exc.InitializerException( """ The initial samples all have the same figure of merit (e.g. log likelihood values). - + The non-linear search will therefore not progress correctly. - + Possible causes for this behaviour are: - + + - The `log_likelihood_function` of the analysis class is defined incorrectly. + - The model parameterization creates numerically inaccurate log likelihoods. + - The`log_likelihood_function` is always returning `nan` values. + """ + ) + + logger.info(f"Initial samples generated, starting non-linear search") + + return unit_parameter_lists, parameter_lists, figures_of_merit_list + + def samples_jax( + self, + total_points: int, + model: AbstractPriorModel, + fitness, + use_prior_medians: bool = False, + ): + """ + Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform + distribution between the ball_lower_limit and ball_upper_limit values. + + Parameters + ---------- + total_points + The number of points in non-linear paramemter space which initial points are created for. + model + An object that represents possible instances of some model with a given dimensionality which is the number + of free dimensions of the model. + """ + + unit_parameter_lists = [] + parameter_lists = [] + figures_of_merit_list = [] + + logger.info(f"Generating initial samples of model using JAX LH Function cores") + + while len(figures_of_merit_list) < total_points: + + if not use_prior_medians: + unit_parameter_list = self._generate_unit_parameter_list(model) + else: + unit_parameter_list = [0.5] * model.prior_count + + parameter_list = model.vector_from_unit_vector( + unit_vector=unit_parameter_list + ) + + figure_of_merit = self.figure_of_metric((fitness, parameter_list)) + + if figure_of_merit is not None: + unit_parameter_lists.append(unit_parameter_list) + parameter_lists.append(parameter_list) + figures_of_merit_list.append(figure_of_merit) + + if total_points > 1 and np.allclose( + a=figures_of_merit_list[0], b=figures_of_merit_list[1:] + ): + raise exc.InitializerException( + """ + The initial samples all have the same figure of merit (e.g. log likelihood values). + + The non-linear search will therefore not progress correctly. + + Possible causes for this behaviour are: + - The `log_likelihood_function` of the analysis class is defined incorrectly. - The model parameterization creates numerically inaccurate log likelihoods. - The`log_likelihood_function` is always returning `nan` values. @@ -321,6 +396,7 @@ def info_value_from(self, value : Tuple[float, float]) -> float: """ return (value[1] + value[0]) / 2.0 + class Initializer(AbstractInitializer): def __init__(self, lower_limit: float, upper_limit: float): """