From e85682ca5c77585503b3d5b2d2db3edb4cfaf25d Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Feb 2025 15:08:47 +0000 Subject: [PATCH 1/9] move jax serialise/deserialise test util to conftest --- test_autofit/conftest.py | 12 ++++++++++++ test_autofit/jax/test_pytrees.py | 21 +++++++-------------- 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/test_autofit/conftest.py b/test_autofit/conftest.py index bac54a0a6..731535395 100644 --- a/test_autofit/conftest.py +++ b/test_autofit/conftest.py @@ -21,6 +21,18 @@ directory = Path(__file__).parent +@pytest.fixture(name="recreate") +def recreate(): + jax = pytest.importorskip("jax") + + def _recreate(o): + flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)] + children, aux_data = flatten_func(o) + return unflatten_func(aux_data, children) + + return _recreate + + @pytest.fixture(autouse=True) def turn_off_gc(monkeypatch): monkeypatch.setattr(abstract_search, "gc", MagicMock()) diff --git a/test_autofit/jax/test_pytrees.py b/test_autofit/jax/test_pytrees.py index d5667a325..56d080adf 100644 --- a/test_autofit/jax/test_pytrees.py +++ b/test_autofit/jax/test_pytrees.py @@ -4,16 +4,9 @@ import autofit as af - jax = pytest.importorskip("jax") -def recreate(o): - flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)] - children, aux_data = flatten_func(o) - return unflatten_func(aux_data, children) - - @pytest.fixture(name="gaussian") def make_gaussian(): return af.Gaussian(centre=1.0, sigma=1.0, normalization=1.0) @@ -33,7 +26,7 @@ def vmapped(gaussian, size=1000): return list(f(np.arange(size))) -def test_gaussian_prior(): +def test_gaussian_prior(recreate): prior = af.GaussianPrior(mean=1.0, sigma=1.0) new = recreate(prior) @@ -56,7 +49,7 @@ def _model(): ) -def test_model(model): +def test_model(model, recreate): new = recreate(model) assert new.cls == af.Gaussian @@ -66,7 +59,7 @@ def test_model(model): assert centre.id == model.centre.id -def test_instance(model): +def 
test_instance(model, recreate): instance = model.instance_from_prior_medians() new = recreate(instance) @@ -77,7 +70,7 @@ def test_instance(model): assert new.sigma == instance.sigma -def test_uniform_prior(): +def test_uniform_prior(recreate): prior = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) new = recreate(prior) @@ -87,7 +80,7 @@ def test_uniform_prior(): assert new.id == prior.id -def test_model_instance(model): +def test_model_instance(model, recreate): collection = af.Collection(gaussian=model) instance = collection.instance_from_prior_medians() new = recreate(instance) @@ -96,7 +89,7 @@ def test_model_instance(model): assert isinstance(new.gaussian, af.Gaussian) -def test_collection(model): +def test_collection(model, recreate): collection = af.Collection(gaussian=model) new = recreate(collection) @@ -120,7 +113,7 @@ def __init__(self, **kwargs): self.__dict__.update(kwargs) -def test_kwargs(): +def test_kwargs(recreate): model = af.Model(KwargClass, a=1, b=2) instance = model.instance_from_prior_medians() From 4fe526057b9b0b8c7efea241cdb2808e15ccce3f Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Feb 2025 15:23:09 +0000 Subject: [PATCH 2/9] pytree methods for FactorGraphModel --- autofit/graphical/declarative/collection.py | 73 +++++++++---------- .../graphical/gaussian/test_declarative.py | 10 ++- 2 files changed, 41 insertions(+), 42 deletions(-) diff --git a/autofit/graphical/declarative/collection.py b/autofit/graphical/declarative/collection.py index 9bad8fb29..ad82799a5 100644 --- a/autofit/graphical/declarative/collection.py +++ b/autofit/graphical/declarative/collection.py @@ -5,16 +5,16 @@ from autofit.tools.namer import namer from .abstract import AbstractDeclarativeFactor +from autofit.jax_wrapper import register_pytree_node_class + +@register_pytree_node_class class FactorGraphModel(AbstractDeclarativeFactor): def __init__( - self, - *model_factors: Union[ - AbstractDeclarativeFactor, - HierarchicalFactor - ], - name=None, - 
include_prior_factors=True, + self, + *model_factors: Union[AbstractDeclarativeFactor, HierarchicalFactor], + name=None, + include_prior_factors=True, ): """ A collection of factors that describe models, which can be @@ -33,6 +33,20 @@ def __init__( self._model_factors = list(model_factors) self._name = name or namer(self.__class__.__name__) + def tree_flatten(self): + return ( + (self._model_factors,), + (self._name, self.include_prior_factors), + ) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + *children[0], + name=aux_data[0], + include_prior_factors=aux_data[1], + ) + @property def prior_model(self): """ @@ -40,11 +54,10 @@ def prior_model(self): in each model factor """ from autofit.mapper.prior_model.collection import Collection - return Collection({ - factor.name: factor.prior_model - for factor - in self.model_factors - }) + + return Collection( + {factor.name: factor.prior_model for factor in self.model_factors} + ) @property def optimiser(self): @@ -61,21 +74,13 @@ def info(self) -> str: def name(self): return self._name - def add( - self, - model_factor: AbstractDeclarativeFactor - ): + def add(self, model_factor: AbstractDeclarativeFactor): """ Add another factor to this collection. """ - self._model_factors.append( - model_factor - ) + self._model_factors.append(model_factor) - def log_likelihood_function( - self, - instance: ModelInstance - ) -> float: + def log_likelihood_function(self, instance: ModelInstance) -> float: """ Compute the combined likelihood of each factor from a collection of instances with the same ordering as the factors. 
@@ -90,13 +95,8 @@ def log_likelihood_function( The combined likelihood of all factors """ log_likelihood = 0 - for model_factor, instance_ in zip( - self.model_factors, - instance - ): - log_likelihood += model_factor.log_likelihood_function( - instance_ - ) + for model_factor, instance_ in zip(self.model_factors, instance): + log_likelihood += model_factor.log_likelihood_function(instance_) return log_likelihood @@ -104,15 +104,8 @@ def log_likelihood_function( def model_factors(self): model_factors = list() for model_factor in self._model_factors: - if isinstance( - model_factor, - HierarchicalFactor - ): - model_factors.extend( - model_factor.factors - ) + if isinstance(model_factor, HierarchicalFactor): + model_factors.extend(model_factor.factors) else: - model_factors.append( - model_factor - ) + model_factors.append(model_factor) return model_factors diff --git a/test_autofit/graphical/gaussian/test_declarative.py b/test_autofit/graphical/gaussian/test_declarative.py index 00a80c92b..0dbd105a6 100644 --- a/test_autofit/graphical/gaussian/test_declarative.py +++ b/test_autofit/graphical/gaussian/test_declarative.py @@ -9,7 +9,7 @@ @pytest.fixture(name="make_model_factor") def make_make_model_factor(normalization, normalization_prior, x): def make_factor_model( - centre: float, sigma: float, optimiser=None + centre: float, sigma: float, optimiser=None ) -> ep.AnalysisFactor: """ We'll make a LikelihoodModel for each Gaussian we're fitting. 
@@ -108,7 +108,9 @@ def _test_optimise_factor_model(factor_model): """ laplace = ep.LaplaceOptimiser() - collection = factor_model.optimise(laplace, ) + collection = factor_model.optimise( + laplace, + ) """ And what we get back is actually a PriorModelCollection @@ -171,3 +173,7 @@ def test_prior_model_node(likelihood_model): ) assert isinstance(result, ep.FactorValue) + + +def test_pytrees(recreate, factor_model): + recreate(factor_model) From f3e49fe5b188d24200f25b9d7469da0c596b1a5b Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Feb 2025 15:44:20 +0000 Subject: [PATCH 3/9] pytree methods for AnalysisFactor --- .../graphical/declarative/factor/analysis.py | 22 +++++++++++++++++++ .../graphical/gaussian/test_declarative.py | 9 +++++++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/autofit/graphical/declarative/factor/analysis.py b/autofit/graphical/declarative/factor/analysis.py index 485872229..5496c4bbd 100644 --- a/autofit/graphical/declarative/factor/analysis.py +++ b/autofit/graphical/declarative/factor/analysis.py @@ -9,6 +9,8 @@ from autofit.non_linear.paths.abstract import AbstractPaths from .abstract import AbstractModelFactor +from autofit.jax_wrapper import register_pytree_node_class + class FactorCallable: def __init__( @@ -43,6 +45,7 @@ def __call__(self, **kwargs: np.ndarray) -> float: return self.analysis.log_likelihood_function(instance) +@register_pytree_node_class class AnalysisFactor(AbstractModelFactor): @property def prior_model(self): @@ -86,6 +89,25 @@ def __init__( name=name, ) + def tree_flatten(self): + return ( + (self.prior_model,), + ( + self.analysis, + self.optimiser, + self.name, + ), + ) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls( + children[0], + analysis=aux_data[0], + optimiser=aux_data[1], + name=aux_data[2], + ) + def __getstate__(self): return self.__dict__ diff --git a/test_autofit/graphical/gaussian/test_declarative.py 
b/test_autofit/graphical/gaussian/test_declarative.py index 0dbd105a6..f59406c96 100644 --- a/test_autofit/graphical/gaussian/test_declarative.py +++ b/test_autofit/graphical/gaussian/test_declarative.py @@ -175,5 +175,12 @@ def test_prior_model_node(likelihood_model): assert isinstance(result, ep.FactorValue) -def test_pytrees(recreate, factor_model): +def test_pytrees( + recreate, + factor_model, + make_model_factor, +): recreate(factor_model) + + model_factor = make_model_factor(centre=60, sigma=15) + recreate(model_factor) From dd410240091e007958f614f2f79efd8083c2f3f8 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Feb 2025 15:52:09 +0000 Subject: [PATCH 4/9] tree flatten for LogUniformPrior --- autofit/mapper/prior/log_uniform.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/autofit/mapper/prior/log_uniform.py b/autofit/mapper/prior/log_uniform.py index baeb14a14..1068fb34e 100644 --- a/autofit/mapper/prior/log_uniform.py +++ b/autofit/mapper/prior/log_uniform.py @@ -2,6 +2,7 @@ import numpy as np +from autofit.jax_wrapper import register_pytree_node_class from autofit import exc from autofit.messages.normal import UniformNormalMessage from autofit.messages.transform import log_10_transform, LinearShiftTransform @@ -9,6 +10,7 @@ from ...messages.composed_transform import TransformedMessage +@register_pytree_node_class class LogUniformPrior(Prior): def __init__( self, @@ -67,6 +69,9 @@ def __init__( id_=id_, ) + def tree_flatten(self): + return (self.lower_limit, self.upper_limit), (self.id,) + @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogUniformPrior": """ From 39b8e8ad86fc972b526df9c64983c3c3e1d8c5d4 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Feb 2025 16:00:18 +0000 Subject: [PATCH 5/9] prior ids as children when creating pytrees --- autofit/mapper/prior/abstract.py | 2 +- autofit/mapper/prior/gaussian.py | 2 +- autofit/mapper/prior/log_gaussian.py | 11 +++++++++++ autofit/mapper/prior/log_uniform.py 
| 6 +++++- autofit/mapper/prior/uniform.py | 2 +- 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/autofit/mapper/prior/abstract.py b/autofit/mapper/prior/abstract.py index 91605ed2c..8d4bc66ef 100644 --- a/autofit/mapper/prior/abstract.py +++ b/autofit/mapper/prior/abstract.py @@ -61,7 +61,7 @@ def tree_unflatten(cls, aux_data, children): ------- An instance of this class """ - return cls(*children, id_=aux_data[0]) + return cls(*children) @property def lower_unit_limit(self) -> float: diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index 84f9a112c..1d3a68907 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -65,7 +65,7 @@ def __init__( ) def tree_flatten(self): - return (self.mean, self.sigma, self.lower_limit, self.upper_limit), (self.id,) + return (self.mean, self.sigma, self.lower_limit, self.upper_limit, self.id), () @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "GaussianPrior": diff --git a/autofit/mapper/prior/log_gaussian.py b/autofit/mapper/prior/log_gaussian.py index a02d77e7b..1cf461393 100644 --- a/autofit/mapper/prior/log_gaussian.py +++ b/autofit/mapper/prior/log_gaussian.py @@ -2,12 +2,14 @@ import numpy as np +from autofit.jax_wrapper import register_pytree_node_class from autofit.messages.normal import NormalMessage from .abstract import Prior from ...messages.composed_transform import TransformedMessage from ...messages.transform import log_transform +@register_pytree_node_class class LogGaussianPrior(Prior): __identifier_fields__ = ("lower_limit", "upper_limit", "mean", "sigma") __database_args__ = ("mean", "sigma", "lower_limit", "upper_limit", "id_") @@ -71,6 +73,15 @@ def __init__( id_=id_, ) + def tree_flatten(self): + return ( + self.mean, + self.sigma, + self.lower_limit, + self.upper_limit, + self.id, + ), () + @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogGaussianPrior": """ diff --git 
a/autofit/mapper/prior/log_uniform.py b/autofit/mapper/prior/log_uniform.py index 1068fb34e..5b071da85 100644 --- a/autofit/mapper/prior/log_uniform.py +++ b/autofit/mapper/prior/log_uniform.py @@ -70,7 +70,11 @@ def __init__( ) def tree_flatten(self): - return (self.lower_limit, self.upper_limit), (self.id,) + return ( + self.lower_limit, + self.upper_limit, + self.id, + ), () @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogUniformPrior": diff --git a/autofit/mapper/prior/uniform.py b/autofit/mapper/prior/uniform.py index 3cea04a90..c2d83c157 100644 --- a/autofit/mapper/prior/uniform.py +++ b/autofit/mapper/prior/uniform.py @@ -62,7 +62,7 @@ def __init__( ) def tree_flatten(self): - return (self.lower_limit, self.upper_limit), (self.id,) + return (self.lower_limit, self.upper_limit, self.id), () def with_limits( self, From 528ef9cc98a2f7ec3fb3f32f2260076b86c2cbbe Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Mon, 17 Mar 2025 13:17:15 +0000 Subject: [PATCH 6/9] main merge --- autofit/graphical/declarative/factor/analysis.py | 1 - 1 file changed, 1 deletion(-) diff --git a/autofit/graphical/declarative/factor/analysis.py b/autofit/graphical/declarative/factor/analysis.py index 67e47e5f2..c9eb22fec 100644 --- a/autofit/graphical/declarative/factor/analysis.py +++ b/autofit/graphical/declarative/factor/analysis.py @@ -88,7 +88,6 @@ def __init__( prior_variable_dict=prior_variable_dict, name=name, ) - print(name) def tree_flatten(self): return ( From a68537fa457df8d4c858f6ae538df18ed9d774de Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 15:27:48 +0100 Subject: [PATCH 7/9] specific function which generates initial samples without pool --- autofit/non_linear/initializer.py | 114 +++++++++++++++++++++++++----- 1 file changed, 95 insertions(+), 19 deletions(-) diff --git a/autofit/non_linear/initializer.py b/autofit/non_linear/initializer.py index 5a762a54a..2cf0d3127 100644 --- 
a/autofit/non_linear/initializer.py +++ b/autofit/non_linear/initializer.py @@ -13,6 +13,8 @@ from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.parallel import SneakyPool +from autofit import jax_wrapper + logger = logging.getLogger(__name__) @@ -39,14 +41,14 @@ def figure_of_metric(args) -> Optional[float]: return None def samples_from_model( - self, - total_points: int, - model: AbstractPriorModel, - fitness, - paths: AbstractPaths, - use_prior_medians: bool = False, - test_mode_samples: bool = True, - n_cores: int = 1, + self, + total_points: int, + model: AbstractPriorModel, + fitness, + paths: AbstractPaths, + use_prior_medians: bool = False, + test_mode_samples: bool = True, + n_cores: int = 1, ): """ Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform @@ -64,6 +66,14 @@ def samples_from_model( if os.environ.get("PYAUTOFIT_TEST_MODE") == "1" and test_mode_samples: return self.samples_in_test_mode(total_points=total_points, model=model) + if jax_wrapper.use_jax: + return self.samples_jax( + total_points=total_points, + model=model, + fitness=fitness, + use_prior_medians=use_prior_medians + ) + unit_parameter_lists = [] parameter_lists = [] figures_of_merit_list = [] @@ -92,13 +102,13 @@ def samples_from_model( unit_parameter_lists_.append(unit_parameter_list) for figure_of_merit, unit_parameter_list, parameter_list in zip( - sneaky_pool.map( - function=self.figure_of_metric, - args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_], - log_info=False - ), - unit_parameter_lists_, - parameter_lists_, + sneaky_pool.map( + function=self.figure_of_metric, + args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_], + log_info=False + ), + unit_parameter_lists_, + parameter_lists_, ): if figure_of_merit is not None: unit_parameter_lists.append(unit_parameter_list) @@ -106,16 +116,81 @@ def samples_from_model( 
figures_of_merit_list.append(figure_of_merit) if total_points > 1 and np.allclose( - a=figures_of_merit_list[0], b=figures_of_merit_list[1:] + a=figures_of_merit_list[0], b=figures_of_merit_list[1:] ): raise exc.InitializerException( """ The initial samples all have the same figure of merit (e.g. log likelihood values). - + The non-linear search will therefore not progress correctly. - + Possible causes for this behaviour are: - + + - The `log_likelihood_function` of the analysis class is defined incorrectly. + - The model parameterization creates numerically inaccurate log likelihoods. + - The`log_likelihood_function` is always returning `nan` values. + """ + ) + + logger.info(f"Initial samples generated, starting non-linear search") + + return unit_parameter_lists, parameter_lists, figures_of_merit_list + + def samples_jax( + self, + total_points: int, + model: AbstractPriorModel, + fitness, + use_prior_medians: bool = False, + ): + """ + Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform + distribution between the ball_lower_limit and ball_upper_limit values. + + Parameters + ---------- + total_points + The number of points in non-linear parameter space which initial points are created for. + model + An object that represents possible instances of some model with a given dimensionality which is the number + of free dimensions of the model. 
+ """ + + unit_parameter_lists = [] + parameter_lists = [] + figures_of_merit_list = [] + + logger.info(f"Generating initial samples of model using JAX LH Function cores") + + while len(figures_of_merit_list) < total_points: + + if not use_prior_medians: + unit_parameter_list = self._generate_unit_parameter_list(model) + else: + unit_parameter_list = [0.5] * model.prior_count + + parameter_list = model.vector_from_unit_vector( + unit_vector=unit_parameter_list + ) + + figure_of_merit = self.figure_of_metric((fitness, parameter_list)) + + if figure_of_merit is not None: + unit_parameter_lists.append(unit_parameter_list) + parameter_lists.append(parameter_list) + figures_of_merit_list.append(figure_of_merit) + + if total_points > 1 and np.allclose( + a=figures_of_merit_list[0], b=figures_of_merit_list[1:] + ): + raise exc.InitializerException( + """ + The initial samples all have the same figure of merit (e.g. log likelihood values). + + The non-linear search will therefore not progress correctly. + + Possible causes for this behaviour are: + - The `log_likelihood_function` of the analysis class is defined incorrectly. - The model parameterization creates numerically inaccurate log likelihoods. - The`log_likelihood_function` is always returning `nan` values. 
@@ -321,6 +396,7 @@ def info_value_from(self, value : Tuple[float, float]) -> float: """ return (value[1] + value[0]) / 2.0 + class Initializer(AbstractInitializer): def __init__(self, lower_limit: float, upper_limit: float): """ From fba57fba892f636a35dbf5214841cf46582b77c4 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 15:37:49 +0100 Subject: [PATCH 8/9] just ignore nans currently, asked for help --- autofit/non_linear/fitness.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 257fd1547..99adacb8e 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -1,4 +1,4 @@ -import numpy as np + import os from typing import Optional @@ -6,6 +6,8 @@ from autofit import exc +from autofit.jax_wrapper import numpy as np + from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.paths.abstract import AbstractPaths from autofit.non_linear.analysis import Analysis @@ -155,7 +157,7 @@ def __call__(self, parameters, *kwargs): instance = self.model.instance_from_vector(vector=parameters) log_likelihood = self.log_likelihood_function(instance=instance) - if np.isnan(log_likelihood): + if not jax_wrapper.use_jax and np.isnan(log_likelihood): return self.resample_figure_of_merit except exc.FitException: From 3f0fe22cab05751c630452c1501542d84e6b0363 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Tue, 1 Apr 2025 16:05:46 +0100 Subject: [PATCH 9/9] nans now handled by np.nan --- autofit/non_linear/fitness.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 99adacb8e..1987beb6c 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -156,9 +156,7 @@ def __call__(self, parameters, *kwargs): try: instance = self.model.instance_from_vector(vector=parameters) log_likelihood = 
self.log_likelihood_function(instance=instance) - - if not jax_wrapper.use_jax and np.isnan(log_likelihood): - return self.resample_figure_of_merit + log_likelihood = np.where(np.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood) except exc.FitException: return self.resample_figure_of_merit