diff --git a/autofit/graphical/declarative/collection.py b/autofit/graphical/declarative/collection.py index 341e51c24..06caa0a8a 100644 --- a/autofit/graphical/declarative/collection.py +++ b/autofit/graphical/declarative/collection.py @@ -9,7 +9,10 @@ from autofit.non_linear.samples.summary import SamplesSummary from autofit.non_linear.analysis.combined import CombinedResult +from autofit.jax_wrapper import register_pytree_node_class + +@register_pytree_node_class class FactorGraphModel(AbstractDeclarativeFactor): def __init__( self, @@ -34,6 +37,20 @@ def __init__( self._model_factors = list(model_factors) self._name = name or namer(self.__class__.__name__) + def tree_flatten(self): + return ( + (self._model_factors,), + (self._name, self.include_prior_factors), + ) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + *children[0], + name=aux_data[0], + include_prior_factors=aux_data[1], + ) + @property def prior_model(self): """ diff --git a/autofit/graphical/declarative/factor/analysis.py b/autofit/graphical/declarative/factor/analysis.py index c28bd5ae0..c9eb22fec 100644 --- a/autofit/graphical/declarative/factor/analysis.py +++ b/autofit/graphical/declarative/factor/analysis.py @@ -9,6 +9,8 @@ from autofit.non_linear.paths.abstract import AbstractPaths from .abstract import AbstractModelFactor +from autofit.jax_wrapper import register_pytree_node_class + class FactorCallable: def __init__( @@ -43,6 +45,7 @@ def __call__(self, **kwargs: np.ndarray) -> float: return self.analysis.log_likelihood_function(instance) +@register_pytree_node_class class AnalysisFactor(AbstractModelFactor): @property def prior_model(self): @@ -85,7 +88,25 @@ def __init__( prior_variable_dict=prior_variable_dict, name=name, ) - print(name) + + def tree_flatten(self): + return ( + (self.prior_model,), + ( + self.analysis, + self.optimiser, + self.name, + ), + ) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + children[0], + analysis=aux_data[0], + optimiser=aux_data[1], + name=aux_data[2], + ) def __getstate__(self): return self.__dict__ diff --git a/autofit/mapper/prior/abstract.py b/autofit/mapper/prior/abstract.py index 63053123d..f0edd8a46 100644 --- a/autofit/mapper/prior/abstract.py +++ b/autofit/mapper/prior/abstract.py @@ -61,7 +61,7 @@ def tree_unflatten(cls, aux_data, children): ------- An instance of this class """ - return cls(*children, id_=aux_data[0]) + return cls(*children) @property def lower_unit_limit(self) -> float: diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index 84f9a112c..1d3a68907 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -65,7 +65,7 @@ def __init__( ) def tree_flatten(self): - return (self.mean, self.sigma, self.lower_limit, self.upper_limit), (self.id,) + return (self.mean, self.sigma, self.lower_limit, self.upper_limit, self.id), () @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "GaussianPrior": diff --git a/autofit/mapper/prior/log_gaussian.py b/autofit/mapper/prior/log_gaussian.py index a02d77e7b..1cf461393 100644 --- a/autofit/mapper/prior/log_gaussian.py +++ b/autofit/mapper/prior/log_gaussian.py @@ -2,12 +2,14 @@ import numpy as np +from autofit.jax_wrapper import register_pytree_node_class from autofit.messages.normal import NormalMessage from .abstract import Prior from ...messages.composed_transform import TransformedMessage from ...messages.transform import log_transform +@register_pytree_node_class class LogGaussianPrior(Prior): __identifier_fields__ = ("lower_limit", "upper_limit", "mean", "sigma") __database_args__ = ("mean", "sigma", "lower_limit", "upper_limit", "id_") @@ -71,6 +73,15 @@ def __init__( id_=id_, ) + def tree_flatten(self): + return ( + self.mean, + self.sigma, + self.lower_limit, + self.upper_limit, + self.id, + ), () + @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogGaussianPrior": """ diff --git a/autofit/mapper/prior/log_uniform.py b/autofit/mapper/prior/log_uniform.py index baeb14a14..5b071da85 100644 --- a/autofit/mapper/prior/log_uniform.py +++ b/autofit/mapper/prior/log_uniform.py @@ -2,6 +2,7 @@ import numpy as np +from autofit.jax_wrapper import register_pytree_node_class from autofit import exc from autofit.messages.normal import UniformNormalMessage from autofit.messages.transform import log_10_transform, LinearShiftTransform @@ -9,6 +10,7 @@ from ...messages.composed_transform import TransformedMessage +@register_pytree_node_class class LogUniformPrior(Prior): def __init__( self, @@ -67,6 +69,13 @@ def __init__( id_=id_, ) + def tree_flatten(self): + return ( + self.lower_limit, + self.upper_limit, + self.id, + ), () + @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogUniformPrior": """ diff --git a/autofit/mapper/prior/uniform.py b/autofit/mapper/prior/uniform.py index 3cea04a90..c2d83c157 100644 --- a/autofit/mapper/prior/uniform.py +++ b/autofit/mapper/prior/uniform.py @@ -62,7 +62,7 @@ def __init__( ) def tree_flatten(self): - return (self.lower_limit, self.upper_limit), (self.id,) + return (self.lower_limit, self.upper_limit, self.id), () def with_limits( self, diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 257fd1547..1987beb6c 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -1,4 +1,4 @@ -import numpy as np + import os from typing import Optional @@ -6,6 +6,8 @@ from autofit import exc +from autofit.jax_wrapper import numpy as np + from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.paths.abstract import AbstractPaths from autofit.non_linear.analysis import Analysis @@ -154,9 +156,7 @@ def __call__(self, parameters, *kwargs): try: instance = self.model.instance_from_vector(vector=parameters) log_likelihood = self.log_likelihood_function(instance=instance) - - if np.isnan(log_likelihood): - return self.resample_figure_of_merit + log_likelihood = np.where(np.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood) except exc.FitException: return self.resample_figure_of_merit diff --git a/autofit/non_linear/initializer.py b/autofit/non_linear/initializer.py index 5a762a54a..2cf0d3127 100644 --- a/autofit/non_linear/initializer.py +++ b/autofit/non_linear/initializer.py @@ -13,6 +13,8 @@ from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.parallel import SneakyPool +from autofit import jax_wrapper + logger = logging.getLogger(__name__) @@ -39,14 +41,14 @@ def figure_of_metric(args) -> Optional[float]: return None def samples_from_model( - self, - total_points: int, - model: AbstractPriorModel, - fitness, - paths: AbstractPaths, - use_prior_medians: bool = False, - test_mode_samples: bool = True, - n_cores: int = 1, + self, + total_points: int, + model: AbstractPriorModel, + fitness, + paths: AbstractPaths, + use_prior_medians: bool = False, + test_mode_samples: bool = True, + n_cores: int = 1, ): """ Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform @@ -64,6 +66,14 @@ def samples_from_model( if os.environ.get("PYAUTOFIT_TEST_MODE") == "1" and test_mode_samples: return self.samples_in_test_mode(total_points=total_points, model=model) + if jax_wrapper.use_jax: + return self.samples_jax( + total_points=total_points, + model=model, + fitness=fitness, + use_prior_medians=use_prior_medians + ) + unit_parameter_lists = [] parameter_lists = [] figures_of_merit_list = [] @@ -92,13 +102,13 @@ def samples_from_model( unit_parameter_lists_.append(unit_parameter_list) for figure_of_merit, unit_parameter_list, parameter_list in zip( - sneaky_pool.map( - function=self.figure_of_metric, - args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_], - log_info=False - ), - unit_parameter_lists_, - parameter_lists_, + sneaky_pool.map( + function=self.figure_of_metric, + args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_], + log_info=False + ), + unit_parameter_lists_, + parameter_lists_, ): if figure_of_merit is not None: unit_parameter_lists.append(unit_parameter_list) @@ -106,16 +116,81 @@ def samples_from_model( figures_of_merit_list.append(figure_of_merit) if total_points > 1 and np.allclose( - a=figures_of_merit_list[0], b=figures_of_merit_list[1:] + a=figures_of_merit_list[0], b=figures_of_merit_list[1:] ): raise exc.InitializerException( """ The initial samples all have the same figure of merit (e.g. log likelihood values). - + The non-linear search will therefore not progress correctly. - + Possible causes for this behaviour are: - + + - The `log_likelihood_function` of the analysis class is defined incorrectly. + - The model parameterization creates numerically inaccurate log likelihoods. + - The`log_likelihood_function` is always returning `nan` values. + """ + ) + + logger.info(f"Initial samples generated, starting non-linear search") + + return unit_parameter_lists, parameter_lists, figures_of_merit_list + + def samples_jax( + self, + total_points: int, + model: AbstractPriorModel, + fitness, + use_prior_medians: bool = False, + ): + """ + Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform + distribution between the ball_lower_limit and ball_upper_limit values. + + Parameters + ---------- + total_points + The number of points in non-linear paramemter space which initial points are created for. + model + An object that represents possible instances of some model with a given dimensionality which is the number + of free dimensions of the model. + """ + + unit_parameter_lists = [] + parameter_lists = [] + figures_of_merit_list = [] + + logger.info(f"Generating initial samples of model using JAX LH Function cores") + + while len(figures_of_merit_list) < total_points: + + if not use_prior_medians: + unit_parameter_list = self._generate_unit_parameter_list(model) + else: + unit_parameter_list = [0.5] * model.prior_count + + parameter_list = model.vector_from_unit_vector( + unit_vector=unit_parameter_list + ) + + figure_of_merit = self.figure_of_metric((fitness, parameter_list)) + + if figure_of_merit is not None: + unit_parameter_lists.append(unit_parameter_list) + parameter_lists.append(parameter_list) + figures_of_merit_list.append(figure_of_merit) + + if total_points > 1 and np.allclose( + a=figures_of_merit_list[0], b=figures_of_merit_list[1:] + ): + raise exc.InitializerException( + """ + The initial samples all have the same figure of merit (e.g. log likelihood values). + + The non-linear search will therefore not progress correctly. + + Possible causes for this behaviour are: + - The `log_likelihood_function` of the analysis class is defined incorrectly. - The model parameterization creates numerically inaccurate log likelihoods. - The`log_likelihood_function` is always returning `nan` values. @@ -321,6 +396,7 @@ def info_value_from(self, value : Tuple[float, float]) -> float: """ return (value[1] + value[0]) / 2.0 + class Initializer(AbstractInitializer): def __init__(self, lower_limit: float, upper_limit: float): """ diff --git a/test_autofit/conftest.py b/test_autofit/conftest.py index 2eeb884ae..0380fc9e9 100644 --- a/test_autofit/conftest.py +++ b/test_autofit/conftest.py @@ -21,6 +21,18 @@ directory = Path(__file__).parent +@pytest.fixture(name="recreate") +def recreate(): + jax = pytest.importorskip("jax") + + def _recreate(o): + flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)] + children, aux_data = flatten_func(o) + return unflatten_func(aux_data, children) + + return _recreate + + @pytest.fixture(autouse=True) def turn_off_gc(monkeypatch): monkeypatch.setattr(abstract_search, "gc", MagicMock()) diff --git a/test_autofit/graphical/gaussian/test_declarative.py b/test_autofit/graphical/gaussian/test_declarative.py index 00a80c92b..f59406c96 100644 --- a/test_autofit/graphical/gaussian/test_declarative.py +++ b/test_autofit/graphical/gaussian/test_declarative.py @@ -9,7 +9,7 @@ @pytest.fixture(name="make_model_factor") def make_make_model_factor(normalization, normalization_prior, x): def make_factor_model( - centre: float, sigma: float, optimiser=None + centre: float, sigma: float, optimiser=None ) -> ep.AnalysisFactor: """ We'll make a LikelihoodModel for each Gaussian we're fitting. @@ -108,7 +108,9 @@ def _test_optimise_factor_model(factor_model): """ laplace = ep.LaplaceOptimiser() - collection = factor_model.optimise(laplace, ) + collection = factor_model.optimise( + laplace, + ) """ And what we get back is actually a PriorModelCollection @@ -171,3 +173,14 @@ def test_prior_model_node(likelihood_model): ) assert isinstance(result, ep.FactorValue) + + +def test_pytrees( + recreate, + factor_model, + make_model_factor, +): + recreate(factor_model) + + model_factor = make_model_factor(centre=60, sigma=15) + recreate(model_factor) diff --git a/test_autofit/jax/test_pytrees.py b/test_autofit/jax/test_pytrees.py index d5667a325..56d080adf 100644 --- a/test_autofit/jax/test_pytrees.py +++ b/test_autofit/jax/test_pytrees.py @@ -4,16 +4,9 @@ import autofit as af - jax = pytest.importorskip("jax") -def recreate(o): - flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)] - children, aux_data = flatten_func(o) - return unflatten_func(aux_data, children) - - @pytest.fixture(name="gaussian") def make_gaussian(): return af.Gaussian(centre=1.0, sigma=1.0, normalization=1.0) @@ -33,7 +26,7 @@ def vmapped(gaussian, size=1000): return list(f(np.arange(size))) -def test_gaussian_prior(): +def test_gaussian_prior(recreate): prior = af.GaussianPrior(mean=1.0, sigma=1.0) new = recreate(prior) @@ -56,7 +49,7 @@ def _model(): ) -def test_model(model): +def test_model(model, recreate): new = recreate(model) assert new.cls == af.Gaussian @@ -66,7 +59,7 @@ def test_model(model): assert centre.id == model.centre.id -def test_instance(model): +def test_instance(model, recreate): instance = model.instance_from_prior_medians() new = recreate(instance) @@ -77,7 +70,7 @@ def test_instance(model): assert new.sigma == instance.sigma -def test_uniform_prior(): +def test_uniform_prior(recreate): prior = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) new = recreate(prior) @@ -87,7 +80,7 @@ def test_uniform_prior(): assert new.id == prior.id -def test_model_instance(model): +def test_model_instance(model, recreate): collection = af.Collection(gaussian=model) instance = collection.instance_from_prior_medians() new = recreate(instance) @@ -96,7 +89,7 @@ def test_model_instance(model): assert isinstance(new.gaussian, af.Gaussian) -def test_collection(model): +def test_collection(model, recreate): collection = af.Collection(gaussian=model) new = recreate(collection) @@ -120,7 +113,7 @@ def __init__(self, **kwargs): self.__dict__.update(kwargs) -def test_kwargs(): +def test_kwargs(recreate): model = af.Model(KwargClass, a=1, b=2) instance = model.instance_from_prior_medians()