diff --git a/autofit/__init__.py b/autofit/__init__.py index 25e87125f..6bd2be25a 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -1,13 +1,15 @@ +from autoconf.dictable import register_parser +from . import conf + +conf.instance.register(__file__) + import abc import pickle - from dill import register -from autoconf.dictable import register_parser -from .non_linear.grid.grid_search import GridSearch as SearchGridSearch -from . import conf from . import exc from . import mock as m +from .non_linear.grid.grid_search import GridSearch as SearchGridSearch from .aggregator.base import AggBase from .database.aggregator.aggregator import GridSearchAggregator from .graphical.expectation_propagation.history import EPHistory @@ -54,6 +56,7 @@ from .mapper.prior import GaussianPrior from .mapper.prior import LogGaussianPrior from .mapper.prior import LogUniformPrior +from .mapper.prior import TruncatedGaussianPrior from .mapper.prior.abstract import Prior from .mapper.prior.tuple_prior import TuplePrior from .mapper.prior import UniformPrior @@ -136,6 +139,6 @@ def save_abc(pickler, obj): pickle._Pickler.save_type(pickler, obj) -conf.instance.register(__file__) + __version__ = "2025.5.10.1" diff --git a/autofit/config/general.yaml b/autofit/config/general.yaml index 9ecbb7b8b..7cd467b8e 100644 --- a/autofit/config/general.yaml +++ b/autofit/config/general.yaml @@ -1,3 +1,5 @@ +jax: + use_jax: false # If True, PyAutoFit uses JAX internally, whereas False uses normal Numpy. analysis: n_cores: 1 # The number of cores a parallelized sum of Analysis classes uses by default. hpc: @@ -5,9 +7,7 @@ hpc: iterations_per_update: 5000 # The number of iterations between every update (visualization, results output, etc) in HPC mode. inversion: check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same. 
- reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor. -model: - ignore_prior_limits: false # If ``True`` the limits applied to priors will be ignored, where limits set upper / lower limits. This stops PriorLimitException's from being raised. + reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor. output: force_pickle_overwrite: false # If True pickle files output by a search (e.g. samples.pickle) are recreated when a new model-fit is performed. force_visualize_overwrite: false # If True, visualization images output by a search (e.g. subplots of the fit) are recreated when a new model-fit is performed. diff --git a/autofit/config/priors/Exponential.yaml b/autofit/config/priors/Exponential.yaml index c4c387009..47f645503 100644 --- a/autofit/config/priors/Exponential.yaml +++ b/autofit/config/priors/Exponential.yaml @@ -1,5 +1,5 @@ centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -9,7 +9,7 @@ centre: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -19,7 +19,7 @@ normalization: type: Relative value: 0.5 rate: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 diff --git a/autofit/config/priors/Gaussian.yaml b/autofit/config/priors/Gaussian.yaml index 29061ca82..65d7fbf01 100644 --- a/autofit/config/priors/Gaussian.yaml +++ b/autofit/config/priors/Gaussian.yaml @@ -1,5 +1,5 @@ centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -9,7 +9,7 @@ centre: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -19,7 +19,7 @@ normalization: type: Relative value: 0.5 sigma: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 diff --git a/autofit/config/priors/Gaussian2D.yaml 
b/autofit/config/priors/Gaussian2D.yaml index 30cf44490..07fff3b27 100644 --- a/autofit/config/priors/Gaussian2D.yaml +++ b/autofit/config/priors/Gaussian2D.yaml @@ -1,5 +1,5 @@ centre_0: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -9,7 +9,7 @@ centre_0: type: Absolute value: 20.0 centre_1: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -19,7 +19,7 @@ centre_1: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -29,7 +29,7 @@ normalization: type: Relative value: 0.5 sigma: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 diff --git a/autofit/config/priors/GaussianKurtosis.yaml b/autofit/config/priors/GaussianKurtosis.yaml index 739de5d5c..313d53614 100644 --- a/autofit/config/priors/GaussianKurtosis.yaml +++ b/autofit/config/priors/GaussianKurtosis.yaml @@ -1,12 +1,5 @@ -GaussianPrior: - lower_limit: - type: Constant - value: -inf - upper_limit: - type: Constant - value: inf centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -16,7 +9,7 @@ centre: type: Absolute value: 20.0 kurtosis: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -26,7 +19,7 @@ kurtosis: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -36,7 +29,7 @@ normalization: type: Relative value: 0.5 sigma: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 diff --git a/autofit/config/priors/MultiLevelGaussians.yaml b/autofit/config/priors/MultiLevelGaussians.yaml index 1356bbe2a..c4bc8eb89 100644 --- a/autofit/config/priors/MultiLevelGaussians.yaml +++ b/autofit/config/priors/MultiLevelGaussians.yaml @@ -1,5 +1,5 @@ higher_level_centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 diff --git a/autofit/config/priors/README.rst b/autofit/config/priors/README.rst index fab4dcd04..91c6147f3 100644 --- 
a/autofit/config/priors/README.rst +++ b/autofit/config/priors/README.rst @@ -13,7 +13,7 @@ They appear as follows: width_modifier: type: Absolute value: 20.0 - gaussian_limits: + limits: lower: -inf upper: inf @@ -28,9 +28,9 @@ The sections of this example config set the following: When the results of a search are passed to a subsequent search to set up the priors of its non-linear search, this entry describes how the Prior is passed. For a full description of prior passing, checkout the examples in 'autolens_workspace/examples/complex/linking'. - gaussian_limits + limits When the results of a search are passed to a subsequent search, they are passed using a GaussianPrior. The - gaussian_limits set the physical lower and upper limits of this GaussianPrior, such that parameter samples + limits set the physical lower and upper limits of this GaussianPrior, such that parameter samples can not go beyond these limits. The files ``template_module.yaml`` and ``TemplateObject.yaml`` give templates one can use to set up prior default diff --git a/autofit/config/priors/model.yaml b/autofit/config/priors/model.yaml index 7809624cf..5a89a4e98 100644 --- a/autofit/config/priors/model.yaml +++ b/autofit/config/priors/model.yaml @@ -1,6 +1,6 @@ Exponential: centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -10,7 +10,7 @@ Exponential: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -20,7 +20,7 @@ Exponential: type: Relative value: 0.5 rate: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 @@ -31,7 +31,7 @@ Exponential: value: 0.5 Gaussian: centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -41,7 +41,7 @@ Gaussian: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -51,7 +51,7 @@ Gaussian: type: Relative value: 0.5 sigma: - gaussian_limits: + limits: lower: 0.0 upper: inf 
lower_limit: 0.0 diff --git a/autofit/config/priors/profiles.yaml b/autofit/config/priors/profiles.yaml index 7809624cf..5a89a4e98 100644 --- a/autofit/config/priors/profiles.yaml +++ b/autofit/config/priors/profiles.yaml @@ -1,6 +1,6 @@ Exponential: centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -10,7 +10,7 @@ Exponential: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -20,7 +20,7 @@ Exponential: type: Relative value: 0.5 rate: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 @@ -31,7 +31,7 @@ Exponential: value: 0.5 Gaussian: centre: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -41,7 +41,7 @@ Gaussian: type: Absolute value: 20.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -51,7 +51,7 @@ Gaussian: type: Relative value: 0.5 sigma: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 diff --git a/autofit/config/priors/template.yaml b/autofit/config/priors/template.yaml index 82f5513c9..1a34ec6a4 100644 --- a/autofit/config/priors/template.yaml +++ b/autofit/config/priors/template.yaml @@ -1,6 +1,6 @@ ModelComponent0: parameter0: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -10,7 +10,7 @@ ModelComponent0: type: Absolute value: 20.0 parameter1: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -20,7 +20,7 @@ ModelComponent0: type: Relative value: 0.5 parameter2: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 0.0 @@ -31,7 +31,7 @@ ModelComponent0: value: 0.5 ModelComponent1: parameter0: - gaussian_limits: + limits: lower: -inf upper: inf lower_limit: 0.0 @@ -41,7 +41,7 @@ ModelComponent1: type: Absolute value: 20.0 parameter1: - gaussian_limits: + limits: lower: 0.0 upper: inf lower_limit: 1.0e-06 @@ -51,7 +51,7 @@ ModelComponent1: type: Relative value: 0.5 parameter2: - gaussian_limits: + 
limits: lower: 0.0 upper: inf lower_limit: 0.0 diff --git a/autofit/database/model/prior.py b/autofit/database/model/prior.py index 5e68565b3..495c424c1 100644 --- a/autofit/database/model/prior.py +++ b/autofit/database/model/prior.py @@ -100,6 +100,6 @@ class Prior(Object): def _from_object(cls, model: abstract.Prior): instance = cls() instance.cls = type(model) instance._add_children( [(key, getattr(model, key)) for key in model.__database_args__] ) diff --git a/autofit/exc.py b/autofit/exc.py index 744918f73..86a427419 100644 --- a/autofit/exc.py +++ b/autofit/exc.py @@ -19,10 +19,6 @@ class FitException(Exception): pass -class PriorLimitException(FitException, PriorException): - pass - - class PipelineException(Exception): pass diff --git a/autofit/graphical/factor_graphs/factor.py b/autofit/graphical/factor_graphs/factor.py index b36ff0fc4..c5a5752f6 100644 --- a/autofit/graphical/factor_graphs/factor.py +++ b/autofit/graphical/factor_graphs/factor.py @@ -1,16 +1,10 @@ from copy import deepcopy from inspect import getfullargspec +import jax from typing import Tuple, Dict, Any, Callable, Union, List, Optional, TYPE_CHECKING import numpy as np -try: - import jax - - _HAS_JAX = True -except ImportError: - _HAS_JAX = False - from autofit.graphical.utils import ( nested_filter, to_variabledata, @@ -294,13 +288,7 @@ def _set_jacobians( self._vjp = vjp self._jacfwd = jacfwd if vjp or factor_vjp: - if factor_vjp: - self._factor_vjp = factor_vjp - elif not _HAS_JAX: - raise ModuleNotFoundError( - "jax needed if `factor_vjp` not passed with vjp=True" - ) - + self._factor_vjp = factor_vjp self.func_jacobian = self._vjp_func_jacobian else: # This is set by default @@ -312,11 +300,10 @@ def _set_jacobians( self._jacobian = jacobian elif numerical_jacobian: self._factor_jacobian = self._numerical_factor_jacobian - elif _HAS_JAX: - if jacfwd: - self._jacobian = jax.jacfwd(self._factor, range(self.n_args)) - else: - self._jacobian =
jax.jacobian(self._factor, range(self.n_args)) + elif jacfwd: + self._jacobian = jax.jacfwd(self._factor, range(self.n_args)) + else: + self._jacobian = jax.jacobian(self._factor, range(self.n_args)) def _factor_value(self, raw_fval) -> FactorValue: """Converts the raw output of the factor into a `FactorValue` diff --git a/autofit/graphical/factor_graphs/jacobians.py b/autofit/graphical/factor_graphs/jacobians.py index 3b83a90d7..5e07b74d6 100644 --- a/autofit/graphical/factor_graphs/jacobians.py +++ b/autofit/graphical/factor_graphs/jacobians.py @@ -1,11 +1,3 @@ -try: - import jax - - _HAS_JAX = True -except ImportError: - _HAS_JAX = False - - import numpy as np from autoconf import cached_property diff --git a/autofit/jax_wrapper.py b/autofit/jax_wrapper.py index b64e27224..92e54f1ac 100644 --- a/autofit/jax_wrapper.py +++ b/autofit/jax_wrapper.py @@ -1,29 +1,45 @@ """ Allows the user to switch between using NumPy and JAX for linear algebra operations. -If USE_JAX=1 then JAX's NumPy is used, otherwise vanilla NumPy is used. +If USE_JAX=true in general.yaml then JAX's NumPy is used, otherwise vanilla NumPy is used. """ -from os import environ +import jax -use_jax = environ.get("USE_JAX", "0") == "1" +from autoconf import conf + +use_jax = conf.instance["general"]["jax"]["use_jax"] if use_jax: - try: - import jax - from jax import numpy - def jit(function, *args, **kwargs): - return jax.jit(function, *args, **kwargs) + from jax import numpy + + print( + + """ +***JAX ENABLED*** + +Using JAX for grad/jit and GPU/TPU acceleration. 
+To disable JAX, set: config -> general -> jax -> use_jax = false + """) - def grad(function, *args, **kwargs): - return jax.grad(function, *args, **kwargs) + def jit(function, *args, **kwargs): + return jax.jit(function, *args, **kwargs) + + def grad(function, *args, **kwargs): + return jax.grad(function, *args, **kwargs) + + from jax._src.scipy.special import erfinv - print("JAX mode enabled") - except ImportError: - raise ImportError( - "JAX is not installed. Please install it with `pip install jax`." - ) else: + + print( + """ +***JAX DISABLED*** + +Falling back to standard NumPy (no grad/jit or GPU support). +To enable JAX (if supported), set: config -> general -> jax -> use_jax = true + """) + import numpy # noqa from scipy.special.cython_special import erfinv # noqa @@ -33,20 +49,8 @@ def jit(function, *_, **__): def grad(function, *_, **__): return function -try: - from jax._src.tree_util import ( - register_pytree_node_class as register_pytree_node_class, - register_pytree_node as register_pytree_node, - ) - from jax._src.scipy.special import erfinv - -except ImportError: - - def register_pytree_node_class(cls): - return cls - - def register_pytree_node(*_, **__): - def decorator(cls): - return cls +from jax._src.tree_util import ( + register_pytree_node_class as register_pytree_node_class, + register_pytree_node as register_pytree_node, +) - return decorator diff --git a/autofit/mapper/prior/__init__.py b/autofit/mapper/prior/__init__.py index a54086622..8c5af8d33 100644 --- a/autofit/mapper/prior/__init__.py +++ b/autofit/mapper/prior/__init__.py @@ -1,4 +1,5 @@ from .gaussian import GaussianPrior +from .truncated_gaussian import TruncatedGaussianPrior from .log_gaussian import LogGaussianPrior from .log_uniform import LogUniformPrior from .uniform import UniformPrior diff --git a/autofit/mapper/prior/abstract.py b/autofit/mapper/prior/abstract.py index f0edd8a46..de5e852f1 100644 --- a/autofit/mapper/prior/abstract.py +++ 
b/autofit/mapper/prior/abstract.py @@ -2,6 +2,7 @@ import random from abc import ABC, abstractmethod from copy import copy +import jax from typing import Union, Tuple, Optional, Dict from autoconf import conf @@ -15,33 +16,24 @@ class Prior(Variable, ABC, ArithmeticMixin): - __database_args__ = ("lower_limit", "upper_limit", "id_") + __database_args__ = ("id_",) _ids = itertools.count() - def __init__(self, message, lower_limit=0.0, upper_limit=1.0, id_=None): + def __init__(self, message, id_=None): """ An object used to mappers a unit value to an attribute value for a specific class attribute. Parameters ---------- - lower_limit: Float - The lowest value this prior can return - upper_limit: Float - The highest value this prior can return + message + """ super().__init__(id_=id_) self.message = message message.id_ = self.id - self.lower_limit = float(lower_limit) - self.upper_limit = float(upper_limit) - if self.lower_limit >= self.upper_limit: - raise exc.PriorException( - "The upper limit of a prior must be greater than its lower limit" - ) - self.width_modifier = None @classmethod @@ -63,20 +55,6 @@ def tree_unflatten(cls, aux_data, children): """ return cls(*children) - @property - def lower_unit_limit(self) -> float: - """ - The lower limit for this prior in unit vector space - """ - return self.unit_value_for(self.lower_limit) - - @property - def upper_unit_limit(self) -> float: - """ - The upper limit for this prior in unit vector space - """ - return self.unit_value_for(self.upper_limit) - def unit_value_for(self, physical_value: float) -> float: """ Compute the unit value between 0 and 1 for the physical value. 
@@ -109,31 +87,6 @@ def factor(self): """ return self.message.factor - def assert_within_limits(self, value): - - def exception_message(): - raise exc.PriorLimitException( - "The physical value {} for a prior " - "was not within its limits {}, {}".format( - value, self.lower_limit, self.upper_limit - ) - ) - - if jax_wrapper.use_jax: - import jax - jax.lax.cond( - jax.numpy.logical_or( - value < self.lower_limit, - value > self.upper_limit - ), - lambda _: jax.debug.callback(exception_message), - lambda _: None, - None - ) - - elif not (self.lower_limit <= value <= self.upper_limit): - exception_message() - @staticmethod def for_class_and_attribute_name(cls, attribute_name): prior_dict = conf.instance.prior_config.for_class_and_suffix_path( @@ -141,26 +94,22 @@ def for_class_and_attribute_name(cls, attribute_name): ) return Prior.from_dict(prior_dict) - @property - def width(self): - return self.upper_limit - self.lower_limit - def random( self, - lower_limit=0.0, - upper_limit=1.0, + lower_limit : float = 0.0, + upper_limit : float = 1.0 ) -> float: """ A random value sampled from this prior """ return self.value_for( random.uniform( - max(lower_limit, self.lower_unit_limit), - min(upper_limit, self.upper_unit_limit), + lower_limit, + upper_limit, ) ) - def value_for(self, unit: float, ignore_prior_limits=False) -> float: + def value_for(self, unit: float) -> float: """ Return a physical value for a value between 0 and 1 with the transformation described by this prior. @@ -174,10 +123,7 @@ def value_for(self, unit: float, ignore_prior_limits=False) -> float: ------- A physical value, mapped from the unit value accoridng to the prior. 
""" - result = self.message.value_for(unit) - if not ignore_prior_limits: - self.assert_within_limits(result) - return result + return self.message.value_for(unit) def instance_for_arguments( self, @@ -193,8 +139,6 @@ def project(self, samples, weights): samples=samples, log_weight_list=weights, id_=self.id, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) return result @@ -216,8 +160,8 @@ def __hash__(self): return hash(self.id) def __repr__(self): - return "<{} id={} lower_limit={} upper_limit={}>".format( - self.__class__.__name__, self.id, self.lower_limit, self.upper_limit + return "<{} id={}>".format( + self.__class__.__name__, self.id, ) def __str__(self): @@ -268,12 +212,14 @@ def from_dict( from .log_uniform import LogUniformPrior from .gaussian import GaussianPrior from .log_gaussian import LogGaussianPrior + from .truncated_gaussian import TruncatedGaussianPrior prior_type_dict = { "Uniform": UniformPrior, "LogUniform": LogUniformPrior, "Gaussian": GaussianPrior, "LogGaussian": LogGaussianPrior, + "TruncatedGaussian" : TruncatedGaussianPrior, "Constant": Constant, } @@ -282,7 +228,7 @@ def from_dict( **{ key: value for key, value in prior_dict.items() - if key not in ("type", "width_modifier", "gaussian_limits", "id") + if key not in ("type", "width_modifier", "limits", "id") }, ) if id_ is not None: @@ -294,8 +240,6 @@ def dict(self) -> dict: A dictionary representation of this prior """ prior_dict = { - "lower_limit": self.lower_limit, - "upper_limit": self.upper_limit, "type": self.name_of_class(), "id": self.id, } @@ -310,7 +254,7 @@ def name_of_class(cls) -> str: @property def limits(self) -> Tuple[float, float]: - return self.lower_limit, self.upper_limit + return (float("-inf"), float("inf")) def gaussian_prior_model_for_arguments(self, arguments): return arguments[self] diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index 1d3a68907..d8bdee3f9 100644 --- a/autofit/mapper/prior/gaussian.py +++ 
b/autofit/mapper/prior/gaussian.py @@ -8,64 +8,57 @@ @register_pytree_node_class class GaussianPrior(Prior): - __identifier_fields__ = ("lower_limit", "upper_limit", "mean", "sigma") - __database_args__ = ("mean", "sigma", "lower_limit", "upper_limit", "id_") + __identifier_fields__ = ("mean", "sigma") + __database_args__ = ("mean", "sigma", "id_") def __init__( self, mean: float, sigma: float, - lower_limit: float = float("-inf"), - upper_limit: float = float("inf"), id_: Optional[int] = None, ): """ - A prior with a uniform distribution, defined between a lower limit and upper limit. + A Gaussian prior defined by a normal distribution. - The conversion of an input unit value, ``u``, to a physical value, ``p``, via the prior is as follows: + The prior transforms a unit interval input `u` in [0, 1] into a physical parameter `p` via + the inverse error function (erfcinv) based on the Gaussian CDF: .. math:: + p = \mu + \sigma \sqrt{2} \, \mathrm{erfcinv}(2 \times (1 - u)) - p = \mu + (\sigma * sqrt(2) * erfcinv(2.0 * (1.0 - u)) + where :math:`\mu` is the mean and :math:`\sigma` the standard deviation. - For example for ``prior = GaussianPrior(mean=1.0, sigma=2.0)``, an - input ``prior.value_for(unit=0.5)`` is equal to 1.0. + For example, with `mean=1.0` and `sigma=2.0`, the value at `u=0.5` corresponds to the mean, 1.0. - The mapping is performed using the message functionality, where a message represents the distirubtion - of this prior. + This mapping is implemented using a NormalMessage instance, encapsulating + the Gaussian distribution and any specified truncation limits. Parameters ---------- mean - The mean of the Gaussian distribution defining the prior. + The mean (center) of the Gaussian prior distribution. sigma - The sigma value of the Gaussian distribution defining the prior. - lower_limit - A lower limit of the Gaussian distribution; physical values below this value are rejected. 
- upper_limit - A upper limit of the Gaussian distribution; physical values below this value are rejected. + The standard deviation (spread) of the Gaussian prior. + id_ : Optional[int], optional + Optional identifier for the prior instance. Examples -------- + Create a GaussianPrior with mean 1.0, sigma 2.0, truncated between 0.0 and 2.0: - prior = af.GaussianPrior(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0) - - physical_value = prior.value_for(unit=0.5) + >>> prior = GaussianPrior(mean=1.0, sigma=2.0) + >>> physical_value = prior.value_for(unit=0.5) # Returns ~1.0 (mean) """ super().__init__( - lower_limit=lower_limit, - upper_limit=upper_limit, message=NormalMessage( mean=mean, sigma=sigma, - lower_limit=lower_limit, - upper_limit=upper_limit, ), id_=id_, ) def tree_flatten(self): - return (self.mean, self.sigma, self.lower_limit, self.upper_limit, self.id), () + return (self.mean, self.sigma, self.id), () @classmethod def with_limits(cls, lower_limit: float, upper_limit: float) -> "GaussianPrior": @@ -98,11 +91,19 @@ def with_limits(cls, lower_limit: float, upper_limit: float) -> "GaussianPrior": def dict(self) -> dict: """ - A dictionary representation of this prior + Return a dictionary representation of this GaussianPrior instance, + including mean and sigma. + + Returns + ------- + Dictionary containing prior parameters. """ prior_dict = super().dict() return {**prior_dict, "mean": self.mean, "sigma": self.sigma} @property def parameter_string(self) -> str: + """ + Return a human-readable string summarizing the GaussianPrior parameters. 
+ """ return f"mean = {self.mean}, sigma = {self.sigma}" diff --git a/autofit/mapper/prior/log_gaussian.py b/autofit/mapper/prior/log_gaussian.py index 1cf461393..aaab73c5e 100644 --- a/autofit/mapper/prior/log_gaussian.py +++ b/autofit/mapper/prior/log_gaussian.py @@ -11,15 +11,13 @@ @register_pytree_node_class class LogGaussianPrior(Prior): - __identifier_fields__ = ("lower_limit", "upper_limit", "mean", "sigma") - __database_args__ = ("mean", "sigma", "lower_limit", "upper_limit", "id_") + __identifier_fields__ = ("mean", "sigma") + __database_args__ = ("mean", "sigma", "id_") def __init__( self, mean: float, sigma: float, - lower_limit: float = 0.0, - upper_limit: float = float("inf"), id_: Optional[int] = None, ): """ @@ -43,20 +41,14 @@ def __init__( sigma The spread of this distribution in *natural log* space, e.g. sigma=1.0 means P(ln x) has a standard deviation of 1. - lower_limit - A lower limit in *real space* (not log); physical values below this are rejected. - upper_limit - A upper limit in *real space* (not log); physical values above this are rejected. Examples -------- - prior = af.LogGaussianPrior(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0) + prior = af.LogGaussianPrior(mean=1.0, sigma=2.0) physical_value = prior.value_for(unit=0.5) """ - lower_limit = float(lower_limit) - upper_limit = float(upper_limit) self.mean = mean self.sigma = sigma @@ -68,8 +60,6 @@ def __init__( super().__init__( message=message, - lower_limit=lower_limit, - upper_limit=upper_limit, id_=id_, ) @@ -77,8 +67,6 @@ def tree_flatten(self): return ( self.mean, self.sigma, - self.lower_limit, - self.upper_limit, self.id, ), () @@ -127,7 +115,7 @@ def _new_for_base_message(self, message): id_=self.instance().id, ) - def value_for(self, unit: float, ignore_prior_limits: bool = False) -> float: + def value_for(self, unit: float) -> float: """ Return a physical value for a value between 0 and 1 with the transformation described by this prior. 
@@ -141,7 +129,7 @@ def value_for(self, unit: float, ignore_prior_limits: bool = False) -> float: ------- A physical value, mapped from the unit value accoridng to the prior. """ - return super().value_for(unit, ignore_prior_limits=ignore_prior_limits) + return super().value_for(unit) @property def parameter_string(self) -> str: diff --git a/autofit/mapper/prior/log_uniform.py b/autofit/mapper/prior/log_uniform.py index 5b071da85..ffcb33912 100644 --- a/autofit/mapper/prior/log_uniform.py +++ b/autofit/mapper/prior/log_uniform.py @@ -1,17 +1,20 @@ -from typing import Optional +from typing import Optional, Tuple import numpy as np from autofit.jax_wrapper import register_pytree_node_class -from autofit import exc from autofit.messages.normal import UniformNormalMessage from autofit.messages.transform import log_10_transform, LinearShiftTransform from .abstract import Prior from ...messages.composed_transform import TransformedMessage +from autofit import exc @register_pytree_node_class class LogUniformPrior(Prior): + __identifier_fields__ = ("lower_limit", "upper_limit") + __database_args__ = ("lower_limit", "upper_limit", "id_") + def __init__( self, lower_limit: float = 1e-6, @@ -45,27 +48,29 @@ def __init__( physical_value = prior.value_for(unit=0.2) """ - if lower_limit <= 0.0: + self.lower_limit = float(lower_limit) + self.upper_limit = float(upper_limit) + + if self.lower_limit <= 0.0: raise exc.PriorException( "The lower limit of a LogUniformPrior cannot be zero or negative." 
) - - lower_limit = float(lower_limit) - upper_limit = float(upper_limit) + if self.lower_limit >= self.upper_limit: + raise exc.PriorException( + "The upper limit of a prior must be greater than its lower limit" + ) message = TransformedMessage( UniformNormalMessage, LinearShiftTransform( - shift=np.log10(lower_limit), - scale=np.log10(upper_limit / lower_limit), + shift=np.log10(self.lower_limit), + scale=np.log10(self.upper_limit / self.lower_limit), ), log_10_transform, ) super().__init__( message=message, - lower_limit=lower_limit, - upper_limit=upper_limit, id_=id_, ) @@ -121,7 +126,7 @@ def log_prior_from_value(self, value) -> float: """ return 1.0 / value - def value_for(self, unit: float, ignore_prior_limits: bool = False) -> float: + def value_for(self, unit: float) -> float: """ Returns a physical value from an input unit value according to the limits of the log10 uniform prior. @@ -142,7 +147,23 @@ def value_for(self, unit: float, ignore_prior_limits: bool = False) -> float: physical_value = prior.value_for(unit=0.2) """ - return super().value_for(unit, ignore_prior_limits=ignore_prior_limits) + return super().value_for(unit) + + def dict(self) -> dict: + """ + Return a dictionary representation of this GaussianPrior instance, + including mean and sigma. + + Returns + ------- + Dictionary containing prior parameters. 
+ """ + prior_dict = super().dict() + return {**prior_dict, "lower_limit": self.lower_limit, "upper_limit": self.upper_limit} + + @property + def limits(self) -> Tuple[float, float]: + return self.lower_limit, self.upper_limit @property def parameter_string(self) -> str: diff --git a/autofit/mapper/prior/truncated_gaussian.py b/autofit/mapper/prior/truncated_gaussian.py new file mode 100644 index 000000000..5cd3b9813 --- /dev/null +++ b/autofit/mapper/prior/truncated_gaussian.py @@ -0,0 +1,134 @@ +from typing import Optional, Tuple + +from autofit.jax_wrapper import register_pytree_node_class + +from autofit.messages.truncated_normal import TruncatedNormalMessage +from .abstract import Prior + + +@register_pytree_node_class +class TruncatedGaussianPrior(Prior): + __identifier_fields__ = ("lower_limit", "upper_limit", "mean", "sigma") + __database_args__ = ("mean", "sigma", "lower_limit", "upper_limit", "id_") + + def __init__( + self, + mean: float, + sigma: float, + lower_limit: float = float("-inf"), + upper_limit: float = float("inf"), + id_: Optional[int] = None, + ): + """ + A Gaussian prior defined by a normal distribution with optional truncation limits. + + This prior represents a Gaussian (normal) distribution with mean `mean` + and standard deviation `sigma`, optionally truncated between `lower_limit` + and `upper_limit`. The transformation from a unit interval input `u ∈ [0, 1]` + to a physical parameter value `p` uses the inverse error function (erfcinv) applied + to the Gaussian CDF, adjusted for truncation: + + .. math:: + p = \mu + \sigma \sqrt{2} \, \mathrm{erfcinv}(2 \times (1 - u)) + + where :math:`\mu` is the mean and :math:`\sigma` the standard deviation. + + If truncation limits are specified, values outside the interval + [`lower_limit`, `upper_limit`] are disallowed and the distribution is + normalized over this interval. + + Parameters + ---------- + mean + The mean (center) of the Gaussian prior distribution. 
+ sigma + The standard deviation (spread) of the Gaussian prior. + lower_limit : float, optional + The lower truncation limit (default: -∞). + upper_limit : float, optional + The upper truncation limit (default: +∞). + id_ : Optional[int], optional + Optional identifier for the prior instance. + + Examples + -------- + Create a TruncatedGaussianPrior with mean 1.0, sigma 2.0, truncated between 0.0 and 2.0: + + >>> prior = TruncatedGaussianPrior(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0) + >>> physical_value = prior.value_for(unit=0.5) # Returns a value near 1.0 (mean) + """ + super().__init__( + message=TruncatedNormalMessage( + mean=mean, + sigma=sigma, + lower_limit=lower_limit, + upper_limit=upper_limit, + ), + id_=id_, + ) + + def tree_flatten(self): + return (self.mean, self.sigma, self.lower_limit, self.upper_limit, self.id), () + + @classmethod + def with_limits(cls, lower_limit: float, upper_limit: float) -> "TruncatedGaussianPrior": + """ + Create a new truncated gaussian prior centred between two limits + with sigma distance between this limits. + + Note that these limits are not strict so exceptions will not + be raised for values outside of the limits. + + This function is typically used in prior passing, where the + result of a model-fit are used to create new Gaussian priors + centred on the previously estimated median PDF model. + + Parameters + ---------- + lower_limit + The lower limit of the new Gaussian prior. + upper_limit + The upper limit of the new Gaussian Prior. + + Returns + ------- + A new prior instance centered between the limits. + """ + return cls( + mean=(lower_limit + upper_limit) / 2, + sigma=(upper_limit - lower_limit), + lower_limit=lower_limit, + upper_limit=upper_limit, + ) + + def dict(self) -> dict: + """ + Return a dictionary representation of this GaussianPrior instance, + including mean and sigma. + + Returns + ------- + Dictionary containing prior parameters. 
+ """ + prior_dict = super().dict() + return { + **prior_dict, "mean": self.mean, + "sigma": self.sigma, + "lower_limit": self.lower_limit, + "upper_limit": self.upper_limit + } + + @property + def limits(self) -> Tuple[float, float]: + return self.lower_limit, self.upper_limit + + @property + def parameter_string(self) -> str: + """ + Return a human-readable string summarizing the GaussianPrior parameters. + """ + return (f"mean = {self.mean}, " + f"sigma = {self.sigma}, " + f"lower_limit = {self.lower_limit}, " + f"upper_limit = {self.upper_limit}" + ) diff --git a/autofit/mapper/prior/uniform.py b/autofit/mapper/prior/uniform.py index c2d83c157..0e240eb04 100644 --- a/autofit/mapper/prior/uniform.py +++ b/autofit/mapper/prior/uniform.py @@ -1,5 +1,5 @@ from autofit.jax_wrapper import register_pytree_node_class -from typing import Optional +from typing import Optional, Tuple from autofit.messages.normal import UniformNormalMessage from .abstract import Prior @@ -7,10 +7,12 @@ from ...messages.composed_transform import TransformedMessage from ...messages.transform import LinearShiftTransform +from autofit import exc @register_pytree_node_class class UniformPrior(Prior): __identifier_fields__ = ("lower_limit", "upper_limit") + __database_args__ = ("lower_limit", "upper_limit", "id_") def __init__( self, @@ -45,25 +47,30 @@ def __init__( physical_value = prior.value_for(unit=0.2) """ - lower_limit = float(lower_limit) - upper_limit = float(upper_limit) + self.lower_limit = float(lower_limit) + self.upper_limit = float(upper_limit) + + if self.lower_limit >= self.upper_limit: + raise exc.PriorException( + "The upper limit of a prior must be greater than its lower limit" + ) message = TransformedMessage( UniformNormalMessage, - LinearShiftTransform(shift=lower_limit, scale=upper_limit - lower_limit), - lower_limit=lower_limit, - upper_limit=upper_limit, + LinearShiftTransform(shift=self.lower_limit, scale=self.upper_limit - self.lower_limit), ) super().__init__( 
message, - lower_limit=lower_limit, - upper_limit=upper_limit, id_=id_, ) def tree_flatten(self): return (self.lower_limit, self.upper_limit, self.id), () + @property + def width(self): + return self.upper_limit - self.lower_limit + def with_limits( self, lower_limit: float, @@ -82,11 +89,23 @@ def logpdf(self, x): x -= epsilon return self.message.logpdf(x) + def dict(self) -> dict: + """ + Return a dictionary representation of this GaussianPrior instance, + including mean and sigma. + + Returns + ------- + Dictionary containing prior parameters. + """ + prior_dict = super().dict() + return {**prior_dict, "lower_limit": self.lower_limit, "upper_limit": self.upper_limit} + @property def parameter_string(self) -> str: return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}" - def value_for(self, unit: float, ignore_prior_limits: bool = False) -> float: + def value_for(self, unit: float) -> float: """ Returns a physical value from an input unit value according to the limits of the uniform prior. @@ -108,7 +127,7 @@ def value_for(self, unit: float, ignore_prior_limits: bool = False) -> float: physical_value = prior.value_for(unit=0.2) """ return float( - round(super().value_for(unit, ignore_prior_limits=ignore_prior_limits), 14) + round(super().value_for(unit), 14) ) def log_prior_from_value(self, value): @@ -121,3 +140,7 @@ def log_prior_from_value(self, value): For a UniformPrior this is always zero, provided the value is between the lower and upper limit. 
""" return 0.0 + + @property + def limits(self) -> Tuple[float, float]: + return self.lower_limit, self.upper_limit \ No newline at end of file diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index a2659aa04..3af441ad8 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1,5 +1,7 @@ import copy import inspect +import jax.numpy as jnp +import jax import json import logging import random @@ -12,9 +14,11 @@ from autoconf import conf from autoconf.exc import ConfigException from autofit import exc +from autofit import jax_wrapper from autofit.mapper import model from autofit.mapper.model import AbstractModel, frozen_cache from autofit.mapper.prior import GaussianPrior +from autofit.mapper.prior import TruncatedGaussianPrior from autofit.mapper.prior import UniformPrior from autofit.mapper.prior.abstract import Prior from autofit.mapper.prior.constant import Constant @@ -42,7 +46,7 @@ class Limits: @staticmethod def for_class_and_attributes_name(cls, attribute_name): limit_dict = conf.instance.prior_config.for_class_and_suffix_path( - cls, [attribute_name, "gaussian_limits"] + cls, [attribute_name, "limits"] ) return limit_dict["lower"], limit_dict["upper"] @@ -519,7 +523,7 @@ def assert_no_assertions(obj): except AttributeError: pass - def instance_from_unit_vector(self, unit_vector, ignore_prior_limits=False): + def instance_from_unit_vector(self, unit_vector, ignore_assertions : bool = False): """ Returns a ModelInstance, which has an attribute and class instance corresponding to every `Model` attributed to this instance. @@ -527,14 +531,15 @@ def instance_from_unit_vector(self, unit_vector, ignore_prior_limits=False): physical values via their priors. 
Parameters ---------- - ignore_prior_limits - If true then no exception is thrown if priors fall outside defined limits unit_vector: [float] A unit hypercube vector that is mapped to an instance of physical values via the priors. Returns ------- model_instance : autofit.mapper.model.ModelInstance An object containing reconstructed model_mapper instances + ignore_assertions + If True, the assertions attached to this model (e.g. that one parameter > another parameter) are ignored. + Raises ------ exc.FitException @@ -560,7 +565,6 @@ def instance_from_unit_vector(self, unit_vector, ignore_prior_limits=False): prior_tuple.prior, prior_tuple.prior.value_for( unit, - ignore_prior_limits=ignore_prior_limits, ), ), self.prior_tuples_ordered_by_id, @@ -570,7 +574,7 @@ def instance_from_unit_vector(self, unit_vector, ignore_prior_limits=False): return self.instance_for_arguments( arguments, - ignore_assertions=ignore_prior_limits, + ignore_assertions=ignore_assertions ) @property @@ -608,16 +612,12 @@ def prior_tuples_ordered_by_id(self): def priors_ordered_by_id(self): return [prior for _, prior in self.prior_tuples_ordered_by_id] - def vector_from_unit_vector(self, unit_vector, ignore_prior_limits=False): + def vector_from_unit_vector(self, unit_vector): """ Parameters ---------- unit_vector: [float] A unit hypercube vector - ignore_prior_limits - Set to True to prevent an exception being raised if - the physical value of a prior is outside the allowable - limits Returns ------- @@ -627,7 +627,7 @@ def vector_from_unit_vector(self, unit_vector, ignore_prior_limits=False): return list( map( lambda prior_tuple, unit: prior_tuple.prior.value_for( - unit, ignore_prior_limits=ignore_prior_limits + unit, ), self.prior_tuples_ordered_by_id, unit_vector, @@ -647,8 +647,8 @@ def random_unit_vector_within_limits( """ return [ random.uniform( - max(lower_limit, prior.lower_unit_limit), - min(upper_limit, prior.upper_unit_limit), + lower_limit, + upper_limit, ) for prior in 
self.priors_ordered_by_id ] @@ -749,7 +749,7 @@ def physical_values_from_prior_medians(self): """ return self.vector_from_unit_vector([0.5] * len(self.unique_prior_tuples)) - def instance_from_vector(self, vector, ignore_prior_limits=False): + def instance_from_vector(self, vector, ignore_assertions: bool = False): """ Returns a ModelInstance, which has an attribute and class instance corresponding to every `Model` attributed to this instance. @@ -759,8 +759,8 @@ def instance_from_vector(self, vector, ignore_prior_limits=False): ---------- vector: [float] A vector of physical parameter values that is mapped to an instance. - ignore_prior_limits - If True don't check that physical values are within expected limits. + ignore_assertions + If True, any assertions attached to this object are ignored and not checked. Returns ------- @@ -779,13 +779,9 @@ def instance_from_vector(self, vector, ignore_prior_limits=False): ) ) - if not ignore_prior_limits: - for prior, value in arguments.items(): - prior.assert_within_limits(value) - return self.instance_for_arguments( arguments, - ignore_assertions=ignore_prior_limits, + ignore_assertions=ignore_assertions, ) def has(self, cls: Union[Type, Tuple[Type, ...]]) -> bool: @@ -921,8 +917,6 @@ def mapper_from_prior_means(self, means, a=None, r=None, no_limits=False): ---------- means The median PDF value of every Gaussian, which centres each `GaussianPrior`. 
- no_limits - If `True` generated priors have infinite limits r The relative width to be assigned to gaussian priors a @@ -976,7 +970,7 @@ def mapper_from_prior_means(self, means, a=None, r=None, no_limits=False): sigma = width - new_prior = GaussianPrior(mean, sigma, *limits) + new_prior = TruncatedGaussianPrior(mean, sigma, *limits) new_prior.id = prior.id new_prior.width_modifier = prior.width_modifier arguments[prior] = new_prior @@ -1036,17 +1030,18 @@ def mapper_from_uniform_floats(self, floats, b): return self.mapper_from_prior_arguments(arguments) - def instance_from_prior_medians(self, ignore_prior_limits=False): + def instance_from_prior_medians(self, ignore_assertions : bool = False): """ Returns a list of physical values from the median values of the priors. Returns ------- - physical_values : [float] - A list of physical values + ignore_assertions + If True, the assertions attached to this model (e.g. that one parameter > another parameter) are ignored + and not checked. """ return self.instance_from_unit_vector( unit_vector=[0.5] * self.prior_count, - ignore_prior_limits=ignore_prior_limits, + ignore_assertions=ignore_assertions ) def log_prior_list_from_vector( @@ -1075,18 +1070,22 @@ def log_prior_list_from_vector( ) ) - def random_instance(self, ignore_prior_limits=False): + def random_instance(self, ignore_assertions : bool = False): """ Returns a random instance of the model. + + Parameters + ---------- + ignore_assertions + If True, the assertions attached to this model (e.g. that one parameter > another parameter) are ignored. 
""" - logger.debug(f"Creating a random instance") - if ignore_prior_limits: + if ignore_assertions: return self.instance_from_unit_vector( unit_vector=[random.random() for _ in range(self.prior_count)], - ignore_prior_limits=ignore_prior_limits, + ignore_assertions=ignore_assertions, ) - return self.instance_for_arguments( - {prior: prior.random() for prior in self.priors} + return self.instance_from_unit_vector( + unit_vector=[random.random() for _ in range(self.prior_count)], ) @staticmethod @@ -1294,7 +1293,7 @@ def instance_for_arguments( arguments Dictionary mapping priors to attribute analysis_path and value pairs ignore_assertions - If True, assertions will not be checked + If True, the assertions attached to this model (e.g. that one parameter > another parameter) are ignored. Returns ------- @@ -1371,7 +1370,7 @@ def instance_from_prior_name_arguments( name of the prior have been joined by underscores, mapped to corresponding values. ignore_assertions - If True, assertions will not be checked + If True, the assertions attached to this model (e.g. that one parameter > another parameter) are ignored. Returns ------- @@ -1402,7 +1401,7 @@ def instance_from_path_arguments( specified once. If multiple paths for the same prior are specified then the value for the last path will be used. ignore_assertions - If True, assertions will not be checked + If True, the assertions attached to this model (e.g. that one parameter > another parameter) are ignored. 
Returns ------- diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index a3f3a3808..bdc7fbdf4 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -46,8 +46,10 @@ def __init__( upper_limit=math.inf, id_=None, ): + self.lower_limit = lower_limit self.upper_limit = upper_limit + self.id = next(self.ids) if id_ is None else id_ self.log_norm = log_norm self._broadcast = np.broadcast(*parameters) @@ -66,8 +68,6 @@ def _init_kwargs(self): return dict( log_norm=self.log_norm, id_=self.id, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) def check_support(self) -> np.ndarray: @@ -92,8 +92,6 @@ def copy(self): result = cls( *(copy(params) for params in self.parameters), log_norm=self.log_norm, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) result.id = self.id return result @@ -199,8 +197,6 @@ def __mul__(self, other: Union["AbstractMessage", Real]) -> "AbstractMessage": *self.parameters, log_norm=log_norm, id_=self.id, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) def __rmul__(self, other: "AbstractMessage") -> "AbstractMessage": @@ -216,8 +212,6 @@ def __truediv__(self, other: Union["AbstractMessage", Real]) -> "AbstractMessage *self.parameters, log_norm=log_norm, id_=self.id, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) def __pow__(self, other: Real) -> "AbstractMessage": @@ -228,8 +222,6 @@ def __pow__(self, other: Real) -> "AbstractMessage": new_params, log_norm=log_norm, id_=self.id, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) return new @@ -260,7 +252,6 @@ def __str__(self) -> str: __repr__ = __str__ def factor(self, x): - # self.assert_within_limits(x) return self.logpdf(x) @classmethod @@ -341,8 +332,6 @@ def update_invalid(self, other: "AbstractMessage") -> "AbstractMessage": *valid_parameters, log_norm=self.log_norm, id_=self.id, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) return new @@ -410,16 
+399,12 @@ def _reconstruct( parameters: Tuple[np.ndarray, ...], log_norm: float, id_, - lower_limit, - upper_limit, *args, ): return cls( *parameters, log_norm=log_norm, id_=id_, - lower_limit=lower_limit, - upper_limit=upper_limit, ) def __reduce__(self): @@ -430,8 +415,6 @@ def __reduce__(self): self.parameters, self.log_norm, self.id, - self.lower_limit, - self.upper_limit, ), ) diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 09abae38b..bd5b28a7e 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -1,6 +1,6 @@ import math import warnings -from typing import Union +from typing import Tuple, Union import numpy as np from scipy.special import betaln @@ -10,23 +10,71 @@ from ..messages.abstract import AbstractMessage -def grad_betaln(ab): +def grad_betaln(ab: np.ndarray) -> np.ndarray: + """ + Compute the gradient of the log Beta function with respect to parameters a and b. + + Parameters + ---------- + ab + Array of shape (N, 2) where each row contains parameters [a, b] of Beta distributions. + + Returns + ------- + Gradient array of shape (N, 2) with derivatives of log Beta function w.r.t a and b. + """ psiab = psi(ab.sum(axis=1, keepdims=True)) return psi(ab) - psiab -def jac_grad_betaln(ab): +def jac_grad_betaln(ab: np.ndarray) -> np.ndarray: + """ + Compute the Jacobian matrix of the gradient of the log Beta function. + + Parameters + ---------- + ab + Array of shape (N, 2) with Beta parameters [a, b]. + + Returns + ------- + Array of shape (N, 2, 2), the Jacobian matrices for each parameter pair. 
+ """ psi1ab = polygamma(1, ab.sum(axis=1, keepdims=True)) fii = polygamma(1, ab) - psi1ab fij = -psi1ab[:, 0] return np.array([[fii[:, 0], fij], [fij, fii[:, 1]]]).T -def inv_beta_suffstats(lnX, ln1X): - """Solve for a, b for, +def inv_beta_suffstats( + lnX: Union[np.ndarray, float], + ln1X: Union[np.ndarray, float], +) -> Tuple[Union[np.ndarray, float], Union[np.ndarray, float]]: + """ + Solve for Beta distribution parameters (a, b) given log sufficient statistics. + + The system solves: + + psi(a) - psi(a + b) = lnX + psi(b) - psi(a + b) = ln1X - psi(a) + psi(a + b) = lnX - psi(b) + psi(a + b) = ln1X + Parameters + ---------- + lnX + Logarithm of the expected value of X. + ln1X + Logarithm of the expected value of 1 - X. + + Returns + ------- + a + Estimated alpha parameter(s) of the Beta distribution. + b + Estimated beta parameter(s) of the Beta distribution. + + Warnings + -------- + Emits a RuntimeWarning if negative parameters are found, and clamps them to 0.5. """ _lnX, _ln1X = np.ravel(lnX), np.ravel(ln1X) lnXs = np.c_[_lnX, _ln1X] @@ -61,9 +109,7 @@ def inv_beta_suffstats(lnX, ln1X): class BetaMessage(AbstractMessage): - """ - Models a Beta distribution - """ + log_base_measure = 0 _support = ((0, 1),) _min = 0 @@ -72,34 +118,77 @@ class BetaMessage(AbstractMessage): _parameter_support = ((0, np.inf), (0, np.inf)) def __init__( - self, - alpha=0.5, - beta=0.5, - lower_limit=-math.inf, - upper_limit=math.inf, - log_norm=0, - id_=None + self, + alpha: float = 0.5, + beta: float = 0.5, + log_norm: float = 0, + id_: Union[str, None] = None, ): + """ + Represents a Beta distribution message in natural parameter form. + + Parameters + ---------- + alpha + Alpha (shape) parameter of the Beta distribution. Default is 0.5. + beta + Beta (shape) parameter of the Beta distribution. Default is 0.5. + log_norm + Logarithm of normalization constant for message passing. Default is 0. + id_ + Identifier for the message. Default is None. 
+ """ self.alpha = alpha self.beta = beta super().__init__( alpha, beta, - lower_limit=lower_limit, - upper_limit=upper_limit, log_norm=log_norm, id_=id_ ) def value_for(self, unit: float) -> float: + """ + Map a unit interval value (0 to 1) to a value consistent with the Beta distribution. + + Parameters + ---------- + unit + Input value in the unit interval [0, 1]. + + Returns + ------- + float + Corresponding Beta-distributed value. + + Raises + ------ + NotImplementedError + This method should be implemented by subclasses. + """ raise NotImplemented() @cached_property def log_partition(self) -> np.ndarray: + """ + Compute the log partition function (log normalization constant) of the Beta distribution. + + Returns + ------- + The value of the log Beta function, i.e. betaln(alpha, beta). + """ return betaln(*self.parameters) @cached_property def natural_parameters(self) -> np.ndarray: + """ + Compute the natural parameters of the Beta distribution. + + Returns + ------- + np.ndarray + Natural parameters array [alpha - 1, beta - 1]. + """ return self.calc_natural_parameters( self.alpha, self.beta @@ -110,43 +199,137 @@ def calc_natural_parameters( alpha: Union[float, np.ndarray], beta: Union[float, np.ndarray] ) -> np.ndarray: + """ + Calculate the natural parameters of a Beta distribution from alpha and beta. + + Parameters + ---------- + alpha + Alpha (shape) parameter(s) of the Beta distribution. + beta + Beta (shape) parameter(s) of the Beta distribution. + + Returns + ------- + Natural parameters [alpha - 1, beta - 1]. + """ return np.array([alpha - 1, beta - 1]) @staticmethod def invert_natural_parameters( natural_parameters: np.ndarray ) -> np.ndarray: + """ + Convert natural parameters back to standard Beta distribution parameters. + + Parameters + ---------- + natural_parameters + Array of natural parameters [alpha - 1, beta - 1]. + + Returns + ------- + Standard Beta parameters [alpha, beta]. 
+ """ return natural_parameters + 1 @classmethod def invert_sufficient_statistics( cls, sufficient_statistics: np.ndarray ) -> np.ndarray: + """ + Estimate natural parameters from sufficient statistics using inverse operations. + + Parameters + ---------- + sufficient_statistics + Sufficient statistics (e.g. expectations of log X and log(1 - X)). + + Returns + ------- + Natural parameters computed from sufficient statistics. + """ a, b = inv_beta_suffstats(*sufficient_statistics) return cls.calc_natural_parameters(a, b) @classmethod def to_canonical_form(cls, x: np.ndarray) -> np.ndarray: + """ + Convert a value x in (0,1) to the canonical sufficient statistics for Beta. + + Parameters + ---------- + x + Values in the support of the Beta distribution (0 < x < 1). + + Returns + ------- + Canonical sufficient statistics [log(x), log(1 - x)]. + """ return np.array([np.log(x), np.log1p(-x)]) @cached_property def mean(self) -> Union[np.ndarray, float]: + """ + Compute the mean of the Beta distribution. + + Returns + ------- + Mean value alpha / (alpha + beta). + """ return self.alpha / (self.alpha + self.beta) @cached_property def variance(self) -> Union[np.ndarray, float]: + """ + Compute the variance of the Beta distribution. + + Returns + ------- + Variance value of the Beta distribution. + """ return ( self.alpha * self.beta / (self.alpha + self.beta) ** 2 / (self.alpha + self.beta + 1) ) - def sample(self, n_samples=None): + def sample(self, n_samples: int = None) -> np.ndarray: + """ + Draw samples from the Beta distribution. + + Parameters + ---------- + n_samples + Number of samples to draw. If None, returns a single sample. + + Returns + ------- + Samples drawn from Beta(alpha, beta). + """ a, b = self.parameters shape = (n_samples,) + self.shape if n_samples else self.shape return np.random.beta(a, b, size=shape) - def kl(self, dist): + def kl(self, dist: "BetaMessage") -> float: + """ + Calculate the Kullback-Leibler divergence KL(self || dist). 
+ + Parameters + ---------- + dist + The Beta distribution to compare against. + + Returns + ------- + float + The KL divergence value. + + Raises + ------ + TypeError + If the support of the two distributions does not match. + """ # TODO check this is correct # https://arxiv.org/pdf/0911.4863.pdf if self._support != dist._support: @@ -161,13 +344,51 @@ def kl(self, dist): + (aQ - aP + bQ - bP) * psi(aP + bP) ) - def logpdf_gradient(self, x): + def logpdf_gradient( + self, + x: Union[float, np.ndarray] + ) -> Tuple[Union[float, np.ndarray], Union[float, np.ndarray]]: + """ + Compute the log probability density function and its gradient at x. + + Parameters + ---------- + x + Point(s) in (0, 1) where to evaluate the logpdf and gradient. + + Returns + ------- + logl + Log of the PDF evaluated at x. + gradl + Gradient of the log PDF at x. + """ logl = self.logpdf(x) a, b = self.parameters gradl = (a - 1) / x + (b - 1) / (x - 1) return logl, gradl - def logpdf_gradient_hessian(self, x): + def logpdf_gradient_hessian( + self, + x: Union[float, np.ndarray] + ) -> Tuple[Union[float, np.ndarray], Union[float, np.ndarray], Union[float, np.ndarray]]: + """ + Compute the logpdf, its gradient, and Hessian at x. + + Parameters + ---------- + x + Point(s) in (0, 1) where to evaluate the logpdf, gradient, and Hessian. + + Returns + ------- + logl + Log of the PDF evaluated at x. + gradl + Gradient of the log PDF at x. + hessl + Hessian (second derivative) of the log PDF at x. + """ logl = self.logpdf(x) a, b = self.parameters ax, bx = (a - 1) / x, (b - 1) / (x - 1) diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index c040ea0bb..959647009 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -80,8 +80,6 @@ def __init__( A list of transforms applied left to right. 
For example, a shifted uniform normal message is first converted to uniform normal then shifted id_ - lower_limit - upper_limit """ while isinstance(base_message, TransformedMessage): transforms = base_message.transforms + transforms diff --git a/autofit/messages/fixed.py b/autofit/messages/fixed.py index 7645fe8cd..db4db6482 100644 --- a/autofit/messages/fixed.py +++ b/autofit/messages/fixed.py @@ -13,16 +13,12 @@ class FixedMessage(AbstractMessage): def __init__( self, value: np.ndarray, - lower_limit=-math.inf, - upper_limit=math.inf, log_norm: np.ndarray = 0., id_=None ): self.value = value super().__init__( value, - lower_limit=lower_limit, - upper_limit=upper_limit, log_norm=log_norm, id_=id_ ) diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index 7a2de6cb8..e21989197 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -22,8 +22,6 @@ def __init__( self, alpha=1.0, beta=1.0, - lower_limit=-math.inf, - upper_limit=math.inf, log_norm=0.0, id_=None ): @@ -32,8 +30,6 @@ def __init__( super().__init__( alpha, beta, - lower_limit=lower_limit, - upper_limit=upper_limit, log_norm=log_norm, id_=id_ ) diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index f0eaef14f..28d46d2ab 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -15,8 +15,6 @@ class MessageInterface(ABC): log_base_measure: float log_norm: float id: int - lower_limit: float - upper_limit: float @property @abstractmethod @@ -212,8 +210,6 @@ def sum_natural_parameters(self, *dists: "MessageInterface") -> "MessageInterfac return self.from_natural_parameters( new_params, id_=self.id, - lower_limit=self.lower_limit, - upper_limit=self.upper_limit, ) def sub_natural_parameters(self, other: "MessageInterface") -> "MessageInterface": @@ -226,8 +222,6 @@ def sub_natural_parameters(self, other: "MessageInterface") -> "MessageInterface new_params, log_norm=log_norm, id_=self.id, - lower_limit=self.lower_limit, - 
upper_limit=self.upper_limit, ) @abstractmethod diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index d789d5aa8..2afe449fe 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -1,5 +1,6 @@ +from collections.abc import Hashable import math -from typing import Tuple, Union +from typing import Optional, Tuple, Union import numpy as np from autofit.jax_wrapper import erfinv @@ -28,6 +29,18 @@ def is_nan(value): class NormalMessage(AbstractMessage): @cached_property def log_partition(self): + """ + Compute the log-partition function (also called log-normalizer or cumulant function) + for the normal distribution in its natural (canonical) parameterization. + # + Let the natural parameters be: + η₁ = μ / σ² + η₂ = -1 / (2σ²) + + Then the log-partition function A(η) for the Gaussian is: + A(η) = η₁² / (-4η₂) - 0.5 * log(-2η₂) + This ensures normalization of the exponential-family distribution. + """ eta1, eta2 = self.natural_parameters return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2 @@ -37,13 +50,31 @@ def log_partition(self): def __init__( self, - mean, - sigma, - lower_limit=-math.inf, - upper_limit=math.inf, - log_norm=0.0, - id_=None, + mean : Union[float, np.ndarray], + sigma : Union[float, np.ndarray], + log_norm : Optional[float] = 0.0, + id_ : Optional[Hashable] = None, ): + """ + A Gaussian (Normal) message representing a probability distribution over a continuous variable. + + This message defines a Normal distribution parameterized by its mean and standard deviation (sigma). + + Parameters + ---------- + mean + The mean (μ) of the normal distribution. + + sigma + The standard deviation (σ) of the distribution. Must be non-negative. + + log_norm + An additive constant to the log probability of the message. Used internally for message-passing normalization. + Default is 0.0. + + id_ + An optional unique identifier used to track the message in larger probabilistic graphs or models. 
+ """ if (np.array(sigma) < 0).any(): raise exc.MessageException("Sigma cannot be negative") @@ -51,49 +82,159 @@ def __init__( mean, sigma, log_norm=log_norm, - lower_limit=lower_limit, - upper_limit=upper_limit, id_=id_, ) self.mean, self.sigma = self.parameters - def cdf(self, x): + def cdf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """ + Compute the cumulative distribution function (CDF) of the Gaussian distribution + at a given value or array of values `x`. + + Parameters + ---------- + x + The value(s) at which to evaluate the CDF. + + Returns + ------- + The cumulative probability P(X ≤ x). + """ return norm.cdf(x, loc=self.mean, scale=self.sigma) - def ppf(self, x): + def ppf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """ + Compute the percent-point function (inverse CDF) of the Gaussian distribution. + + This function maps a probability value `x` in [0, 1] to the corresponding value + of the distribution with that cumulative probability. + + Parameters + ---------- + x + The cumulative probability or array of probabilities. + + Returns + ------- + The value(s) corresponding to the input quantiles. + """ return norm.ppf(x, loc=self.mean, scale=self.sigma) @cached_property - def natural_parameters(self): + def natural_parameters(self) -> np.ndarray: + """ + The natural (canonical) parameters of the Gaussian distribution in exponential-family form. + + For a Normal distribution with mean μ and standard deviation σ, the natural parameters η are: + + η₁ = μ / σ² + η₂ = -1 / (2σ²) + + Returns + ------- + A NumPy array containing the two natural parameters [η₁, η₂]. 
+ """ return self.calc_natural_parameters(self.mean, self.sigma) @staticmethod - def calc_natural_parameters(mu, sigma): + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + """ + Convert standard parameters of a Gaussian distribution (mean and standard deviation) + into natural parameters used in its exponential family representation. + + Parameters + ---------- + mu + Mean of the Gaussian distribution. + sigma + Standard deviation of the Gaussian distribution. + + Returns + ------- + Natural parameters [η₁, η₂], where: + η₁ = μ / σ² + η₂ = -1 / (2σ²) + """ precision = 1 / sigma**2 return np.array([mu * precision, -precision / 2]) @staticmethod - def invert_natural_parameters(natural_parameters): + def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: + """ + Convert natural parameters [η₁, η₂] back into standard parameters (mean and sigma) + of a Gaussian distribution. + + Parameters + ---------- + natural_parameters + The natural parameters [η₁, η₂] from the exponential family form. + + Returns + ------- + The corresponding (mean, sigma) of the Gaussian distribution. + """ eta1, eta2 = natural_parameters mu = -0.5 * eta1 / eta2 sigma = np.sqrt(-0.5 / eta2) return mu, sigma @staticmethod - def to_canonical_form(x): + def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + """ + Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. + + The sufficient statistics for a normal distribution are [x, x²], which correspond to the + inner product with the natural parameters in the exponential-family log-likelihood. + + Parameters + ---------- + x + Input data point or array of points. + + Returns + ------- + The sufficient statistics [x, x²]. 
+ """ return np.array([x, x**2]) @classmethod - def invert_sufficient_statistics(cls, suff_stats): + def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: + """ + Convert sufficient statistics [E[x], E[x²]] into natural parameters [η₁, η₂]. + + Parameters + ---------- + suff_stats + First and second moments of the distribution. + + Returns + ------- + Natural parameters of the Gaussian. + """ m1, m2 = suff_stats sigma = np.sqrt(m2 - m1**2) return cls.calc_natural_parameters(m1, sigma) @cached_property - def variance(self): + def variance(self) -> np.ndarray: + """ + Return the variance σ² of the Gaussian distribution. + """ return self.sigma**2 - def sample(self, n_samples=None): + def sample(self, n_samples: Optional[int] = None) -> np.ndarray: + """ + Draw samples from the Gaussian distribution. + + Parameters + ---------- + n_samples + Number of samples to draw. If None, returns a single sample. + + Returns + ------- + Sample(s) from the distribution. + """ if n_samples: x = np.random.randn(n_samples, *self.shape) if self.shape: @@ -103,7 +244,20 @@ def sample(self, n_samples=None): return x * self.sigma + self.mean - def kl(self, dist): + def kl(self, dist : "NormalMessage") -> float: + """ + Compute the Kullback-Leibler (KL) divergence to another Gaussian distribution. + + Parameters + ---------- + dist : Gaussian + The target distribution for the KL divergence. + + Returns + ------- + float + The KL divergence KL(self || dist). + """ return ( np.log(dist.sigma / self.sigma) + (self.sigma**2 + (self.mean - dist.mean) ** 2) / 2 / dist.sigma**2 @@ -112,17 +266,53 @@ def kl(self, dist): @classmethod def from_mode( - cls, mode: np.ndarray, covariance: Union[float, LinearOperator] = 1.0, **kwargs - ): + cls, + mode: np.ndarray, + covariance: Union[float, LinearOperator] = 1.0, + **kwargs + ) -> "NormalMessage": + """ + Construct a Gaussian from its mode and covariance. 
+ + Parameters + ---------- + mode + The mode (same as mean for Gaussian). + covariance + The covariance or a linear operator with `.diagonal()` method. + + Returns + ------- + An instance of the NormalMessage class. + """ if isinstance(covariance, LinearOperator): variance = covariance.diagonal() else: mode, variance = cls._get_mean_variance(mode, covariance) + + if kwargs.get("upper_limit") is not None: + kwargs.pop("upper_limit") + + if kwargs.get("lower_limit") is not None: + kwargs.pop("lower_limit") + return cls(mode, np.abs(variance) ** 0.5, **kwargs) def _normal_gradient_hessian( self, x: np.ndarray ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute the log-pdf, gradient, and Hessian of a Gaussian with respect to x. + + Parameters + ---------- + x + Points at which to evaluate. + + Returns + ------- + Log-pdf values, gradients, and Hessians. + """ # raise Exception shape = np.shape(x) if shape: @@ -148,9 +338,33 @@ def _normal_gradient_hessian( return logl, grad_logl, hess_logl def logpdf_gradient(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Return the gradient of the log-pdf of the Gaussian evaluated at `x`. + + Parameters + ---------- + x + Evaluation points. + + Returns + ------- + Log-pdf values and gradients. + """ return self._normal_gradient_hessian(x)[:2] def logpdf_gradient_hessian(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Return the gradient and Hessian of the log-pdf of the Gaussian at `x`. + + Parameters + ---------- + x + Evaluation points. + + Returns + ------- + Gradient and Hessian of the log-pdf. + """ return self._normal_gradient_hessian(x) __name__ = "gaussian_prior" @@ -159,24 +373,21 @@ def logpdf_gradient_hessian(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray def value_for(self, unit: float) -> float: """ - Returns a physical value from an input unit value according to the Gaussian distribution of the prior. 
+ Map a unit value in [0, 1] to a physical value drawn from this Gaussian prior. Parameters ---------- unit - A unit value between 0 and 1. + A unit value between 0 and 1 representing a uniform draw. Returns ------- - value - The unit value mapped to a physical value according to the prior. + A physical value sampled from the Gaussian prior corresponding to the given unit. Examples -------- - - prior = af.GaussianPrior(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0) - - physical_value = prior.value_for(unit=0.5) + >>> prior = af.GaussianPrior(mean=1.0, sigma=2.0) + >>> physical_value = prior.value_for(unit=0.5) """ try: inv = erfinv(1 - 2.0 * (1.0 - unit)) @@ -185,39 +396,63 @@ def value_for(self, unit: float) -> float: inv = scipy_erfinv(1 - 2.0 * (1.0 - unit)) return self.mean + (self.sigma * np.sqrt(2) * inv) - def log_prior_from_value(self, value): + def log_prior_from_value(self, value: float) -> float: """ - Returns the log prior of a physical value, so the log likelihood of a model evaluation can be converted to a - posterior as log_prior + log_likelihood. - This is used by certain non-linear searches (e.g. Emcee) in the log likelihood function evaluation. + Compute the log prior probability of a given physical value under this Gaussian prior. + + Used to convert a likelihood to a posterior in non-linear searches (e.g., Emcee). + Parameters ---------- value - The physical value of this prior's corresponding parameter in a `NonLinearSearch` sample. + A physical parameter value for which the log prior is evaluated. + + Returns + ------- + The log prior probability of the given value. """ return (value - self.mean) ** 2.0 / (2 * self.sigma**2.0) def __str__(self): """ - The line of text describing this prior for the model_mapper.info file + Generate a short string summary describing the prior for use in model summaries. 
""" return f"NormalMessage, mean = {self.mean}, sigma = {self.sigma}" def __repr__(self): + """ + Return the official string representation of this Gaussian prior including + the ID, mean, sigma, and optional bounds. + """ return ( - "<NormalMessage id={} mean={} sigma={} lower_limit={} upper_limit={}>".format( - self.id, self.mean, self.sigma, self.lower_limit, self.upper_limit + "<NormalMessage id={} mean={} sigma={}>".format( + self.id, self.mean, self.sigma, ) ) @property - def natural(self): + def natural(self)-> "NaturalNormal": + """ + Return a 'zeroed' natural parameterization of this Gaussian prior. + + Returns + ------- + A natural form Gaussian with zeroed parameters but same configuration. + """ return NaturalNormal.from_natural_parameters( self.natural_parameters * 0.0, **self._init_kwargs ) def zeros_like(self) -> "AbstractMessage": + """ + Return a new instance of this prior with the same structure but zeroed natural parameters. + + Useful for initializing messages in variational inference frameworks. + + Returns + ------- + A new prior object with zeroed natural parameters. + """ return self.natural.zeros_like() @@ -231,54 +466,144 @@ class NaturalNormal(NormalMessage): def __init__( self, - eta1, - eta2, - lower_limit=-math.inf, - upper_limit=math.inf, - log_norm=0.0, - id_=None, + eta1 : float, + eta2 : float, + log_norm : Optional[float] = 0.0, + id_ : Optional[Hashable] = None, ): + """ + A natural parameterization of a Gaussian distribution. + + This class behaves like `NormalMessage`, but allows non-normalized or degenerate distributions, + including those with negative or infinite variance. This flexibility is useful in advanced + inference settings like message passing or variational approximations, where intermediate + natural parameter values may fall outside standard constraints. + + In natural form, the parameters `eta1` and `eta2` correspond to: + - eta1 = mu / sigma^2 + - eta2 = -1 / (2 * sigma^2) + + Parameters + ---------- + eta1 + First natural parameter, related to the mean. 
+ eta2 + Second natural parameter, related to the variance (must be < 0). + log_norm + Optional additive normalization term for use in message passing. + id_ + Optional identifier for the distribution instance. + """ AbstractMessage.__init__( self, eta1, eta2, log_norm=log_norm, - lower_limit=lower_limit, - upper_limit=upper_limit, id_=id_, ) @cached_property - def sigma(self): + def sigma(self)-> float: + """ + Return the standard deviation corresponding to the natural parameters. + + Returns + ------- + The standard deviation σ, derived from eta2 via σ² = -1/(2η₂). + """ precision = -2 * self.parameters[1] return precision**-0.5 @cached_property - def mean(self): + def mean(self) -> float: + """ + Return the mean corresponding to the natural parameters. + + Returns + ------- + The mean μ = -η₁ / (2η₂), with NaNs replaced by 0 for numerical stability. + """ return np.nan_to_num(-self.parameters[0] / self.parameters[1] / 2) @staticmethod - def calc_natural_parameters(eta1, eta2): + def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: + """ + Return the natural parameters in array form (identity function for this class). + + Parameters + ---------- + eta1 + The first natural parameter. + eta2 + The second natural parameter. + """ return np.array([eta1, eta2]) @cached_property - def natural_parameters(self): + def natural_parameters(self) -> np.ndarray: + """ + Return the natural parameters of this distribution. + """ return self.calc_natural_parameters(*self.parameters) @classmethod - def invert_sufficient_statistics(cls, suff_stats): + def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: + """ + Convert sufficient statistics back to natural parameters. + + Parameters + ---------- + suff_stats + Tuple of first and second moments: (mean, second_moment). + + Returns + ------- + Natural parameters [eta1, eta2] recovered from the sufficient statistics. 
+ """ m1, m2 = suff_stats precision = 1 / (m2 - m1**2) return cls.calc_natural_parameters(m1 * precision, -precision / 2) @staticmethod - def invert_natural_parameters(natural_parameters): + def invert_natural_parameters(natural_parameters: np.ndarray) -> np.ndarray: + """ + Identity function for natural parameters (no inversion needed). + + Parameters + ---------- + natural_parameters : np.ndarray + Natural parameters [eta1, eta2]. + + Returns + ------- + np.ndarray + The same input array. + """ return natural_parameters @classmethod def from_mode( - cls, mode: np.ndarray, covariance: Union[float, LinearOperator] = 1.0, **kwargs - ): + cls, + mode: np.ndarray, + covariance: Union[float, LinearOperator] = 1.0, + **kwargs + ) -> "NaturalNormal": + """ + Construct a `NaturalNormal` distribution from mode and covariance. + + Parameters + ---------- + mode + The mode (mean) of the distribution. + covariance + Covariance of the distribution. If a `LinearOperator`, its inverse is used for precision. + kwargs + Additional keyword arguments passed to the constructor. + + Returns + ------- + An instance of `NaturalNormal` with the corresponding natural parameters. + """ if isinstance(covariance, LinearOperator): precision = covariance.inv().diagonal() else: @@ -290,7 +615,10 @@ def from_mode( zeros_like = AbstractMessage.zeros_like @property - def natural(self): + def natural(self) -> "NaturalNormal": + """ + Return self — already in natural form -- for clean API. 
+ """ return self diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py new file mode 100644 index 000000000..1be286725 --- /dev/null +++ b/autofit/messages/truncated_normal.py @@ -0,0 +1,725 @@ +from collections.abc import Hashable +import math +from scipy.stats import truncnorm +from typing import Optional, Tuple, Union + +import numpy as np +from autofit.jax_wrapper import erfinv +from scipy.stats import norm + +from autoconf import cached_property +from autofit.mapper.operator import LinearOperator +from autofit.messages.abstract import AbstractMessage +from .composed_transform import TransformedMessage +from .transform import ( + phi_transform, + log_transform, + multinomial_logit_transform, + log_10_transform, +) +from .. import exc + + +def is_nan(value): + is_nan_ = np.isnan(value) + if isinstance(is_nan_, np.ndarray): + is_nan_ = is_nan_.all() + return is_nan_ + + +class TruncatedNormalMessage(AbstractMessage): + @cached_property + def log_partition(self) -> float: + """ + Compute the log-partition function (normalizer) of the truncated Gaussian. + + This is the log of the normalization constant Z of the truncated normal: + + Z = Φ((b - μ)/σ) - Φ((a - μ)/σ) + + where Φ is the standard normal CDF and [a, b] are the truncation bounds. + + Returns + ------- + float + The log-partition (log of the normalizing constant). 
+ """ + a = (self.lower_limit - self.mean) / self.sigma + b = (self.upper_limit - self.mean) / self.sigma + Z = norm.cdf(b) - norm.cdf(a) + return np.log(Z) if Z > 0 else -np.inf + + log_base_measure = -0.5 * np.log(2 * np.pi) + + @property + def _support(self): + return ((self.lower_limit, self.upper_limit),) + + _parameter_support = ((-np.inf, np.inf), (0, np.inf)) + + def __init__( + self, + mean : Union[float, np.ndarray], + sigma : Union[float, np.ndarray], + lower_limit=-math.inf, + upper_limit=math.inf, + log_norm : Optional[float] = 0.0, + id_ : Optional[Hashable] = None, + ): + """ + A Gaussian (Normal) message representing a probability distribution over a continuous variable. + + This message defines a Normal distribution parameterized by its mean and standard deviation (sigma). + + Parameters + ---------- + mean + The mean (μ) of the normal distribution. + + sigma + The standard deviation (σ) of the distribution. Must be non-negative. + + log_norm + An additive constant to the log probability of the message. Used internally for message-passing normalization. + Default is 0.0. + + id_ + An optional unique identifier used to track the message in larger probabilistic graphs or models. + """ + if (np.array(sigma) < 0).any(): + raise exc.MessageException("Sigma cannot be negative") + + super().__init__( + mean, + sigma, + float(lower_limit), + float(upper_limit), + log_norm=log_norm, + id_=id_, + ) + self.mean, self.sigma, self.lower_limit, self.upper_limit = self.parameters + + def cdf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """ + Compute the cumulative distribution function (CDF) of the truncated Gaussian distribution + at a given value or array of values `x`. + + The CDF is computed using `scipy.stats.truncnorm`, which handles the normalization + over the truncated interval [lower_limit, upper_limit]. + + Parameters + ---------- + x + The value(s) at which to evaluate the CDF. 
+ + Returns + ------- + The cumulative probability P(X ≤ x) under the truncated Gaussian. + """ + a = (self.lower_limit - self.mean) / self.sigma + b = (self.upper_limit - self.mean) / self.sigma + return truncnorm.cdf(x, a=a, b=b, loc=self.mean, scale=self.sigma) + + def ppf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + """ + Compute the percent-point function (inverse CDF) of the truncated Gaussian distribution. + + This function maps a probability value `x` in [0, 1] to the corresponding value + under the truncated Gaussian distribution. + + Parameters + ---------- + x + The cumulative probability or array of probabilities. + + Returns + ------- + The value(s) corresponding to the input quantiles. + """ + a = (self.lower_limit - self.mean) / self.sigma + b = (self.upper_limit - self.mean) / self.sigma + return truncnorm.ppf(x, a=a, b=b, loc=self.mean, scale=self.sigma) + + @cached_property + def natural_parameters(self) -> np.ndarray: + """ + The pseudo-natural (canonical) parameters of a truncated Gaussian distribution. + + For a Gaussian with mean μ and standard deviation σ, the untruncated natural parameters η are: + + η₁ = μ / σ² + η₂ = -1 / (2σ²) + + These are returned here even for the truncated case, but note that due to truncation, + the distribution is no longer in the exponential family and the log-partition function + depends on the lower and upper truncation limits. + + Returns + ------- + A NumPy array containing the pseudo-natural parameters [η₁, η₂]. + """ + return self.calc_natural_parameters(self.mean, self.sigma) + + @staticmethod + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + """ + Convert standard parameters of a Gaussian distribution (mean and standard deviation) + into natural parameters used in its exponential family representation. + + This function does **not** directly account for truncation. 
In the case of a truncated Gaussian, + these parameters are treated as pseudo-natural parameters, meaning they are defined analogously + to the untruncated case but do not fully characterize the distribution. This is because truncation + modifies the normalization constant (log-partition function), making the distribution fall outside + the strict exponential family. + + For truncated Gaussians, any computations involving expectations, gradients, or log-partition + functions must incorporate the effects of truncation separately. + + Parameters + ---------- + mu + Mean of the Gaussian distribution. + sigma + Standard deviation of the Gaussian distribution. + + Returns + ------- + Natural parameters [η₁, η₂], where: + η₁ = μ / σ² + η₂ = -1 / (2σ²) + """ + precision = 1 / sigma**2 + return np.array([mu * precision, -precision / 2]) + + @staticmethod + def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: + """ + Convert natural parameters [η₁, η₂] back into standard parameters (mean and sigma) + of a Gaussian distribution. + + For a truncated Gaussian, this inversion treats the natural parameters as if they + came from an untruncated distribution. That is, the computed (mean, sigma) are + the parameters of the *underlying* Gaussian prior to truncation. + + Parameters + ---------- + natural_parameters + The natural parameters [η₁, η₂] from the exponential family form. + + Returns + ------- + The corresponding (mean, sigma) of the Gaussian distribution. + """ + eta1, eta2 = natural_parameters + mu = -0.5 * eta1 / eta2 + sigma = np.sqrt(-0.5 / eta2) + return mu, sigma + + @staticmethod + def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + """ + Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. + + This form is unchanged by truncation, as sufficient statistics remain [x, x²] regardless + of whether the distribution is truncated. 
However, note that for a truncated Gaussian, + expectations (e.g. E[x], E[x²]) must be computed over the truncated support. + + Parameters + ---------- + x + Input data point or array of points. + + Returns + ------- + The sufficient statistics [x, x²]. + """ + return np.array([x, x**2]) + + @classmethod + def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: + """ + Convert sufficient statistics [E[x], E[x²]] into natural parameters [η₁, η₂]. + + These moments are assumed to be expectations *under the truncated Gaussian* distribution, + meaning that the inferred natural parameters correspond to the truncated form indirectly. + + Parameters + ---------- + suff_stats + First and second moments of the distribution. + + Returns + ------- + Natural parameters of the Gaussian. + """ + m1, m2 = suff_stats + sigma = np.sqrt(m2 - m1**2) + return cls.calc_natural_parameters(m1, sigma) + + @cached_property + def variance(self) -> np.ndarray: + """ + Return the variance σ² of the Gaussian distribution. + """ + return self.sigma**2 + + def sample(self, n_samples: Optional[int] = None) -> np.ndarray: + """ + Draw samples from a truncated Gaussian distribution using inverse transform sampling. + + Samples are drawn using `scipy.stats.truncnorm`, which samples directly from the + distribution normalized over the [lower_limit, upper_limit] interval (no rejection step). + + Parameters + ---------- + n_samples + Number of samples to draw. If None, returns a single sample. + + Returns + ------- + Sample(s) from the truncated Gaussian distribution. 
+ """ + a, b = (self.lower_limit - self.mean) / self.sigma, (self.upper_limit - self.mean) / self.sigma + shape = (n_samples,) + self.shape if n_samples else self.shape + samples = truncnorm.rvs(a, b, loc=self.mean, scale=self.sigma, size=shape) + + return samples + + def kl(self, dist : "TruncatedNormalMessage") -> float: + """ + Compute the Kullback-Leibler (KL) divergence between two truncated Gaussian distributions. + + This is an approximate KL divergence that assumes both distributions are truncated Gaussians + with the same support (i.e. the same lower and upper limits). If the supports differ, this + expression is invalid and should raise an error or be corrected for normalization. + + Parameters + ---------- + dist + The target distribution for the KL divergence. + + Returns + ------- + float + The KL divergence KL(self || dist). + """ + if (self.lower_limit != dist.lower_limit) or (self.upper_limit != dist.upper_limit): + raise ValueError("KL divergence between truncated Gaussians with different support is not implemented.") + + return ( + np.log(dist.sigma / self.sigma) + + (self.sigma**2 + (self.mean - dist.mean) ** 2) / 2 / dist.sigma**2 + - 1 / 2 + ) + + @classmethod + def from_mode( + cls, + mode: np.ndarray, + covariance: Union[float, LinearOperator] = 1.0, + **kwargs + ) -> "TruncatedNormalMessage": + """ + Construct a truncated Gaussian from its mode and covariance. + + For a Gaussian, the mode equals the mean. This method uses that identity to construct + the message from point estimates. + + Parameters + ---------- + mode + The mode (same as mean for Gaussian). + covariance + The covariance or a linear operator with `.diagonal()` method. + **kwargs + Additional keyword arguments passed to the constructor, such as truncation bounds. + + Returns + ------- + An instance of the TruncatedNormalMessage class. 
+ """ + if isinstance(covariance, LinearOperator): + variance = covariance.diagonal() + else: + mode, variance = cls._get_mean_variance(mode, covariance) + return cls(mode, np.abs(variance) ** 0.5, **kwargs) + + def _normal_gradient_hessian( + self, x: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + raise NotImplementedError + + def logpdf_gradient(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Return the gradient of the log-pdf of the Gaussian evaluated at `x`. + + Parameters + ---------- + x + Evaluation points. + + Returns + ------- + Log-pdf values and gradients. + """ + return self._normal_gradient_hessian(x)[:2] + + def logpdf_gradient_hessian(self, x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Return the gradient and Hessian of the log-pdf of the Gaussian at `x`. + + Parameters + ---------- + x + Evaluation points. + + Returns + ------- + Gradient and Hessian of the log-pdf. + """ + return self._normal_gradient_hessian(x) + + __name__ = "truncated_gaussian_prior" + + __default_fields__ = ("log_norm", "id_") + + def value_for(self, unit: float) -> float: + """ + Map a unit value in [0, 1] to a physical value drawn from this truncated Gaussian prior. + + For a truncated Gaussian, this is done using the percent-point function (inverse CDF) + that accounts for the truncation bounds. + + Parameters + ---------- + unit + A unit value between 0 and 1 representing a uniform draw. + + Returns + ------- + A physical value sampled from the truncated Gaussian prior corresponding to the given unit. 
+ + Examples + -------- + >>> prior = af.TruncatedNormalMessage(mean=1.0, sigma=2.0, lower_limit=0.0, upper_limit=2.0) + >>> physical_value = prior.value_for(unit=0.5) + """ + # Standardized truncation bounds + a = (self.lower_limit - self.mean) / self.sigma + b = (self.upper_limit - self.mean) / self.sigma + + # Interpolate unit into [Phi(a), Phi(b)] + lower_cdf = norm.cdf(a) + upper_cdf = norm.cdf(b) + truncated_cdf = lower_cdf + unit * (upper_cdf - lower_cdf) + + # Map back to x using inverse CDF, then rescale + x_standard = norm.ppf(truncated_cdf) + return self.mean + self.sigma * x_standard + + def log_prior_from_value(self, value: float) -> float: + """ + Compute the log prior probability of a given physical value under this truncated Gaussian prior. + + This accounts for truncation by normalizing the Gaussian density over the + interval [lower_limit, upper_limit], returning -inf if the value lies outside + these limits. + + Parameters + ---------- + value + A physical parameter value for which the log prior is evaluated. + + Returns + ------- + The log prior probability of the given value, or -inf if outside truncation bounds. + """ + # Check truncation bounds + if not (self.lower_limit <= value <= self.upper_limit): + return -np.inf + + # Standardized truncation limits + a, b = (self.lower_limit - self.mean) / self.sigma, (self.upper_limit - self.mean) / self.sigma + + # Normalization constant for truncated Gaussian + Z = norm.cdf(b) - norm.cdf(a) + + # Standardized value + z = (value - self.mean) / self.sigma + + # Log probability density of normal (up to normalization) + log_pdf = -0.5 * z ** 2 - np.log(self.sigma) - 0.5 * np.log(2 * np.pi) + + # Adjust for truncation normalization + return log_pdf - np.log(Z) + + def __str__(self): + """ + Generate a short string summary describing the prior for use in model summaries. 
+ """ + return (f"TruncatedNormalMessage, mean = {self.mean}, sigma = {self.sigma}, " + f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}") + + def __repr__(self): + """ + Return the official string representation of this Gaussian prior including + the ID, mean, sigma, and optional bounds. + """ + return ( + "<TruncatedNormalMessage id={} mean={} sigma={} lower_limit={} upper_limit={}>".format( + self.id, self.mean, self.sigma, self.lower_limit, self.upper_limit + ) + ) + + @property + def natural(self)-> "TruncatedNaturalNormal": + """ + Return a 'zeroed' natural parameterization of this Gaussian prior. + + Returns + ------- + A natural form Gaussian with zeroed parameters but same configuration. + """ + return TruncatedNaturalNormal.from_natural_parameters( + self.natural_parameters * 0.0, **self._init_kwargs + ) + + def zeros_like(self) -> "AbstractMessage": + """ + Return a new instance of this prior with the same structure but zeroed natural parameters. + + Useful for initializing messages in variational inference frameworks. + + Returns + ------- + A new prior object with zeroed natural parameters. + """ + return self.natural.zeros_like() + + +class TruncatedNaturalNormal(TruncatedNormalMessage): + """ + Identical to the TruncatedNormalMessage but allows non-normalised values, + e.g. negative or infinite variances + """ + + _parameter_support = ((-np.inf, np.inf), (-np.inf, 0)) + + def __init__( + self, + eta1 : float, + eta2 : float, + lower_limit=-math.inf, + upper_limit=math.inf, + log_norm : Optional[float] = 0.0, + id_ : Optional[Hashable] = None, + ): + """ + A natural parameterization of a Gaussian distribution. + + This class behaves like `TruncatedNormalMessage`, but allows non-normalized or degenerate distributions, + including those with negative or infinite variance. This flexibility is useful in advanced + inference settings like message passing or variational approximations, where intermediate + natural parameter values may fall outside standard constraints. 
+ + In natural form, the parameters `eta1` and `eta2` correspond to: + - eta1 = mu / sigma^2 + - eta2 = -1 / (2 * sigma^2) + + Parameters + ---------- + eta1 + First natural parameter, related to the mean. + eta2 + Second natural parameter, related to the variance (must be < 0). + log_norm + Optional additive normalization term for use in message passing. + id_ + Optional identifier for the distribution instance. + """ + AbstractMessage.__init__( + self, + eta1, + eta2, + log_norm=log_norm, + lower_limit=lower_limit, + upper_limit=upper_limit, + id_=id_, + ) + + @cached_property + def sigma(self) -> float: + """ + Return the standard deviation σ of the truncated Gaussian corresponding to + the natural parameters and truncation limits. + + Uses scipy.stats.truncnorm to compute std dev on the truncated interval. + + Returns + ------- + The truncated Gaussian standard deviation σ. + """ + precision = -2 * self.parameters[1] + if precision <= 0 or np.isinf(precision) or np.isnan(precision): + # Degenerate or invalid precision: fallback to NaN or zero + return np.nan + + mean = -self.parameters[0] / (2 * self.parameters[1]) + std = precision ** -0.5 + + a, b = (self.lower_limit - mean) / std, (self.upper_limit - mean) / std + + # Compute truncated std dev + truncated_std = truncnorm.std(a, b, loc=mean, scale=std) + return truncated_std + + @cached_property + def mean(self) -> float: + """ + Return the mean μ of the truncated Gaussian corresponding to the natural parameters + and truncation limits. + + Uses scipy.stats.truncnorm to compute mean on the truncated interval. + + Returns + ------- + The truncated Gaussian mean μ. 
+ """ + precision = -2 * self.parameters[1] + if precision <= 0 or np.isinf(precision) or np.isnan(precision): + # Degenerate or invalid precision: fallback to NaN or zero + return np.nan + + mean = -self.parameters[0] / (2 * self.parameters[1]) + std = precision**-0.5 + + a, b = (self.lower_limit - mean) / std, (self.upper_limit - mean) / std + + # Compute truncated mean + truncated_mean = truncnorm.mean(a, b, loc=mean, scale=std) + return truncated_mean + + @staticmethod + def calc_natural_parameters( + eta1: float, + eta2: float, + lower_limit: float = -np.inf, + upper_limit: float = np.inf + ) -> np.ndarray: + """ + Return the natural parameters in array form (identity function for this class). + + Currently returns eta1 and eta2 ignoring truncation, + but can be extended to adjust natural parameters based on truncation. + + Parameters + ---------- + eta1 + The first natural parameter. + eta2 + The second natural parameter. + """ + return np.array([eta1, eta2]) + + @cached_property + def natural_parameters(self) -> np.ndarray: + """ + Return the natural parameters of this distribution. + """ + return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit) + + @classmethod + def invert_sufficient_statistics( + cls, + suff_stats: Tuple[float, float], + lower_limit: float = -np.inf, + upper_limit: float = np.inf + ) -> np.ndarray: + """ + Convert sufficient statistics back to natural parameters. + + Parameters + ---------- + suff_stats + Tuple of first and second moments: (mean, second_moment). + + Returns + ------- + Natural parameters [eta1, eta2] recovered from the sufficient statistics. + """ + m1, m2 = suff_stats + precision = 1 / (m2 - m1**2) + return cls.calc_natural_parameters(m1 * precision, -precision / 2, lower_limit, upper_limit) + + @staticmethod + def invert_natural_parameters(natural_parameters: np.ndarray) -> np.ndarray: + """ + Identity function for natural parameters (no inversion needed). 
+ + Parameters + ---------- + natural_parameters : np.ndarray + Natural parameters [eta1, eta2]. + + Returns + ------- + np.ndarray + The same input array. + """ + return natural_parameters + + @classmethod + def from_mode( + cls, + mode: np.ndarray, + covariance: Union[float, LinearOperator] = 1.0, + lower_limit: float = -np.inf, + upper_limit: float = np.inf, + **kwargs + ) -> "TruncatedNaturalNormal": + """ + Construct a `TruncatedNaturalNormal` distribution from mode and covariance. + + Parameters + ---------- + mode + The mode (mean) of the distribution. + covariance + Covariance of the distribution. If a `LinearOperator`, its inverse is used for precision. + kwargs + Additional keyword arguments passed to the constructor. + + Returns + ------- + An instance of `TruncatedNaturalNormal` with the corresponding natural parameters. + """ + if isinstance(covariance, LinearOperator): + precision = covariance.inv().diagonal() + else: + mode, variance = cls._get_mean_variance(mode, covariance) + precision = 1 / variance + + return cls(mode * precision, -precision / 2, lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + + zeros_like = AbstractMessage.zeros_like + + @property + def natural(self) -> "TruncatedNaturalNormal": + """ + Return self — already in natural form — for a clean API. 
+ """ + return self + + +UniformNormalMessage = TransformedMessage(TruncatedNormalMessage(0, 1), phi_transform) + +Log10UniformNormalMessage = TransformedMessage(UniformNormalMessage, log_10_transform) + +LogNormalMessage = TransformedMessage(TruncatedNormalMessage(0, 1), log_transform) +Log10NormalMessage = TransformedMessage(TruncatedNormalMessage(0, 1), log_10_transform) + +# Support is the simplex +MultiLogitNormalMessage = TransformedMessage( + TruncatedNormalMessage(0, 1), multinomial_logit_transform +) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 84c05c2d3..7b1774631 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -4,9 +4,10 @@ from autoconf import conf from autoconf import cached_property +from autofit import jax_wrapper +from autofit.jax_wrapper import numpy as np from autofit import exc -from autofit.jax_wrapper import numpy as np from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.paths.abstract import AbstractPaths @@ -14,9 +15,6 @@ from timeout_decorator import timeout -from autofit import jax_wrapper - - def get_timeout_seconds(): try: @@ -24,10 +22,8 @@ def get_timeout_seconds(): except KeyError: pass - timeout_seconds = get_timeout_seconds() - class Fitness: def __init__( self, diff --git a/autofit/non_linear/grid/grid_search/__init__.py b/autofit/non_linear/grid/grid_search/__init__.py index 0a33d0af0..8768c28f5 100644 --- a/autofit/non_linear/grid/grid_search/__init__.py +++ b/autofit/non_linear/grid/grid_search/__init__.py @@ -123,6 +123,13 @@ def make_lists(self, grid_priors): def make_arguments(self, values, grid_priors): arguments = {} for value, grid_prior in zip(values, grid_priors): + try: + grid_prior.lower_limit + grid_prior.upper_limit + except AttributeError: + raise exc.PriorException( + "Priors passed to the grid search must have upper and lower limit (e.g. 
be UniformPrior)" + ) if ( float("-inf") == grid_prior.lower_limit or float("inf") == grid_prior.upper_limit diff --git a/autofit/non_linear/initializer.py b/autofit/non_linear/initializer.py index 2cf0d3127..3c0ffb20a 100644 --- a/autofit/non_linear/initializer.py +++ b/autofit/non_linear/initializer.py @@ -66,7 +66,7 @@ def samples_from_model( if os.environ.get("PYAUTOFIT_TEST_MODE") == "1" and test_mode_samples: return self.samples_in_test_mode(total_points=total_points, model=model) - if jax_wrapper.use_jax: + if jax_wrapper.use_jax or n_cores == 1: return self.samples_jax( total_points=total_points, model=model, diff --git a/autofit/non_linear/parallel/process.py b/autofit/non_linear/parallel/process.py index bd3c6e3d2..b034611a5 100644 --- a/autofit/non_linear/parallel/process.py +++ b/autofit/non_linear/parallel/process.py @@ -62,6 +62,7 @@ def __init__( job_queue: multiprocessing.Queue The queue through which jobs are submitted """ + super().__init__(name=name) self.logger = logging.getLogger( f"process {name}" diff --git a/autofit/non_linear/parallel/sneaky.py b/autofit/non_linear/parallel/sneaky.py index f47b99e2f..321c5778c 100644 --- a/autofit/non_linear/parallel/sneaky.py +++ b/autofit/non_linear/parallel/sneaky.py @@ -64,6 +64,7 @@ def __init__(self, function, *args): args The arguments to that function """ + super().__init__() if _is_likelihood_function(function): self.function = None diff --git a/autofit/non_linear/paths/database.py b/autofit/non_linear/paths/database.py index 5845a1442..75f288db3 100644 --- a/autofit/non_linear/paths/database.py +++ b/autofit/non_linear/paths/database.py @@ -86,9 +86,7 @@ def create_child( """ self.fit.is_grid_search = True if self.fit.instance is None: - self.fit.instance = self.model.instance_from_prior_medians( - ignore_prior_limits=True - ) + self.fit.instance = self.model.instance_from_prior_medians() child = type(self)( session=self.session, name=name or self.name, diff --git 
a/autofit/non_linear/samples/interface.py b/autofit/non_linear/samples/interface.py index 636416814..7e34a0fac 100644 --- a/autofit/non_linear/samples/interface.py +++ b/autofit/non_linear/samples/interface.py @@ -180,7 +180,7 @@ def model_bounded(self, b: float) -> AbstractPriorModel: ) def _instance_from_vector(self, vector: List[float]) -> ModelInstance: - return self.model.instance_from_vector(vector=vector, ignore_prior_limits=True) + return self.model.instance_from_vector(vector=vector) @property def prior_means(self) -> [List]: diff --git a/autofit/non_linear/samples/sample.py b/autofit/non_linear/samples/sample.py index 0346fbabe..930a1bdfe 100644 --- a/autofit/non_linear/samples/sample.py +++ b/autofit/non_linear/samples/sample.py @@ -207,10 +207,9 @@ def instance_for_model( ) except KeyError: - # TODO: Does this get used? If so, why? return model.instance_from_vector( self.parameter_lists_for_model(model), - ignore_prior_limits=ignore_assertions, + ignore_assertions=ignore_assertions, ) @split_paths diff --git a/autofit/non_linear/samples/samples.py b/autofit/non_linear/samples/samples.py index 742270b38..2bcf5db78 100644 --- a/autofit/non_linear/samples/samples.py +++ b/autofit/non_linear/samples/samples.py @@ -75,7 +75,6 @@ def instances(self): sample.parameter_lists_for_paths( self.paths if sample.is_path_kwargs else self.names ), - ignore_prior_limits=True, ) for sample in self.sample_list ] diff --git a/autofit/non_linear/search/abstract_search.py b/autofit/non_linear/search/abstract_search.py index 93eff3ba2..f17f9cbe8 100644 --- a/autofit/non_linear/search/abstract_search.py +++ b/autofit/non_linear/search/abstract_search.py @@ -4,8 +4,6 @@ import logging import multiprocessing as mp import os -import signal -import sys import time import warnings from abc import ABC, abstractmethod @@ -256,7 +254,6 @@ def __init__( if jax_wrapper.use_jax: self.number_of_cores = 1 - logger.warning(f"JAX is enabled. 
Setting number of cores to 1.") self.number_of_cores = number_of_cores @@ -1198,6 +1195,7 @@ def make_sneaky_pool(self, fitness: Fitness) -> Optional[SneakyPool]: ------- An implementation of a multiprocessing pool """ + self.logger.warning( "...using SneakyPool. This copies the likelihood function " "to each process on instantiation to avoid copying multiple " diff --git a/autofit/non_linear/search/mle/pyswarms/search/abstract.py b/autofit/non_linear/search/mle/pyswarms/search/abstract.py index 17718471b..6076eb0a1 100644 --- a/autofit/non_linear/search/mle/pyswarms/search/abstract.py +++ b/autofit/non_linear/search/mle/pyswarms/search/abstract.py @@ -188,10 +188,10 @@ def _fit(self, model: AbstractPriorModel, analysis): ## TODO : Use actual limits vector_lower = model.vector_from_unit_vector( - unit_vector=[1e-6] * model.prior_count, ignore_prior_limits=True + unit_vector=[1e-6] * model.prior_count, ) vector_upper = model.vector_from_unit_vector( - unit_vector=[0.9999999] * model.prior_count, ignore_prior_limits=True + unit_vector=[0.9999999] * model.prior_count, ) lower_bounds = [lower for lower in vector_lower] diff --git a/autofit/non_linear/search/nest/dynesty/search/abstract.py b/autofit/non_linear/search/nest/dynesty/search/abstract.py index 69eb81560..19dc7d624 100644 --- a/autofit/non_linear/search/nest/dynesty/search/abstract.py +++ b/autofit/non_linear/search/nest/dynesty/search/abstract.py @@ -20,7 +20,7 @@ def prior_transform(cube, model): phys_cube = model.vector_from_unit_vector( - unit_vector=cube, ignore_prior_limits=True + unit_vector=cube, ) for i in range(len(phys_cube)): @@ -115,8 +115,6 @@ def _fit( set of accepted samples of the fit. 
""" - from dynesty.pool import Pool - fitness = Fitness( model=model, analysis=analysis, @@ -152,6 +150,8 @@ def _fit( ): raise RuntimeError + from dynesty.pool import Pool + with Pool( njobs=self.number_of_cores, loglike=fitness, diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 3864210f8..a6bae604d 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -19,7 +19,7 @@ def prior_transform(cube, model): - return model.vector_from_unit_vector(unit_vector=cube, ignore_prior_limits=True) + return model.vector_from_unit_vector(unit_vector=cube) class Nautilus(abstract_nest.AbstractNest): diff --git a/autofit/non_linear/search/nest/ultranest/search.py b/autofit/non_linear/search/nest/ultranest/search.py index ebf56123b..e81cdadb8 100644 --- a/autofit/non_linear/search/nest/ultranest/search.py +++ b/autofit/non_linear/search/nest/ultranest/search.py @@ -125,7 +125,6 @@ def _fit(self, model: AbstractPriorModel, analysis): def prior_transform(cube): return model.vector_from_unit_vector( unit_vector=cube, - ignore_prior_limits=True ) log_dir = self.paths.search_internal_path diff --git a/autofit/visualise.py b/autofit/visualise.py index 33614b144..68dbc1977 100644 --- a/autofit/visualise.py +++ b/autofit/visualise.py @@ -11,6 +11,7 @@ from autofit.mapper.prior.gaussian import GaussianPrior from autofit.mapper.prior.log_gaussian import LogGaussianPrior from autofit.mapper.prior.log_uniform import LogUniformPrior +from autofit.mapper.prior.truncated_gaussian import TruncatedGaussianPrior from autofit.mapper.prior_model.prior_model import ModelObject from autofit.mapper.prior_model.prior_model import Model from autofit.mapper.prior_model.collection import Collection @@ -42,6 +43,8 @@ def str_for_object(obj: ModelObject) -> str: return f"{obj.id}:LogGaussianPrior({obj.mean}, {obj.sigma})" if isinstance(obj, LogUniformPrior): return 
f"{obj.id}:LogUniformPrior({obj.lower_limit}, {obj.upper_limit})" + if isinstance(obj, TruncatedGaussianPrior): + return f"{obj.id}:TruncatedGaussianPrior({obj.mean}, {obj.sigma}, ({obj.lower_limit}, {obj.upper_limit})" return repr(obj) @@ -115,6 +118,7 @@ def colours(self) -> Dict[type, str]: GaussianPrior, LogGaussianPrior, LogUniformPrior, + TruncatedGaussianPrior, } | {model.cls for _, model in self.model.attribute_tuples_with_type(Model)} if isinstance(self.model, Model): types.add(self.model.cls) diff --git a/docs/cookbooks/configs.rst b/docs/cookbooks/configs.rst index 35a41446d..469caf94e 100644 --- a/docs/cookbooks/configs.rst +++ b/docs/cookbooks/configs.rst @@ -106,7 +106,7 @@ You should see the following text: lower_limit: 0.0 upper_limit: 1.0 parameter1: - type: Gaussian + type: TruncatedGaussian mean: 0.0 sigma: 0.1 lower_limit: 0.0 @@ -152,7 +152,7 @@ The ``.yaml`` file should read as follows: lower_limit: 0.0 upper_limit: 100.0 normalization: - type: Gaussian + type: TruncatedGaussian mean: 0.0 sigma: 0.1 lower_limit: 0.0 diff --git a/pyproject.toml b/pyproject.toml index 3d42be5cd..fae6c14d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "typing-inspect>=0.4.0", "emcee>=3.1.6", "gprof2dot==2021.2.21", + "jax==0.5.3", "matplotlib", "numpydoc>=1.0.0", "pyprojroot==0.2.0", diff --git a/test_autofit/aggregator/search_output/files/model.json b/test_autofit/aggregator/search_output/files/model.json index d4c7ebfd6..1eea7c604 100644 --- a/test_autofit/aggregator/search_output/files/model.json +++ b/test_autofit/aggregator/search_output/files/model.json @@ -3,24 +3,18 @@ "type": "model", "arguments": { "centre": { - "lower_limit": "-inf", - "upper_limit": "inf", "type": "Gaussian", "id": 0, "mean": 1.0, "sigma": 1.0 }, "normalization": { - "lower_limit": "-inf", - "upper_limit": "inf", "type": "Gaussian", "id": 1, "mean": 1.0, "sigma": 1.0 }, "sigma": { - "lower_limit": "-inf", - "upper_limit": "inf", "type": 
"Gaussian", "id": 2, "mean": 1.0, diff --git a/test_autofit/aggregator/test_reference.py b/test_autofit/aggregator/test_reference.py index cb56bfce1..611f95340 100644 --- a/test_autofit/aggregator/test_reference.py +++ b/test_autofit/aggregator/test_reference.py @@ -23,6 +23,7 @@ def test_without(directory): def test_with(): + aggregator = Aggregator.from_directory( Path(__file__).parent, reference={"": get_class_path(af.Exponential)}, @@ -84,11 +85,12 @@ def test_database_info( database_aggregator, output_directory, ): + print((output_directory / "database.info").read_text()) assert ( (output_directory / "database.info").read_text() == """ unique_id,name,unique_tag,total_free_parameters,is_complete - c4bf344d706947aa66b129ed2e05e1bd, , , 4, True -c4bf344d706947aa66b129ed2e05e1bd_0, , , 0, -c4bf344d706947aa66b129ed2e05e1bd_1, , , 0, + d05be1e6380082adea5c918af392d2b9, , , 4, True +d05be1e6380082adea5c918af392d2b9_0, , , 0, +d05be1e6380082adea5c918af392d2b9_1, , , 0, """ ) diff --git a/test_autofit/aggregator/test_scrape.py b/test_autofit/aggregator/test_scrape.py index 5033e7739..70cf35b56 100644 --- a/test_autofit/aggregator/test_scrape.py +++ b/test_autofit/aggregator/test_scrape.py @@ -42,24 +42,18 @@ def test_add_files(fit): "type": "model", "arguments": { "centre": { - "lower_limit": "-inf", - "upper_limit": "inf", "type": "Gaussian", "id": 0, "mean": 1.0, "sigma": 1.0, }, "normalization": { - "lower_limit": "-inf", - "upper_limit": "inf", "type": "Gaussian", "id": 1, "mean": 1.0, "sigma": 1.0, }, "sigma": { - "lower_limit": "-inf", - "upper_limit": "inf", "type": "Gaussian", "id": 2, "mean": 1.0, diff --git a/test_autofit/config/general.yaml b/test_autofit/config/general.yaml index 62b6374aa..db6860056 100644 --- a/test_autofit/config/general.yaml +++ b/test_autofit/config/general.yaml @@ -1,3 +1,5 @@ +jax: + use_jax: false # If True, PyAutoFit uses JAX internally, whereas False uses normal Numpy. 
analysis: n_cores: 1 # The number of cores a parallelized sum of Analysis classes uses by default. hpc: @@ -6,8 +8,6 @@ hpc: inversion: check_reconstruction: true # If True, the inversion's reconstruction is checked to ensure the solution of a meshs's mapper is not an invalid solution where the values are all the same. reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor. -model: - ignore_prior_limits: false # If ``True`` the limits applied to priors will be ignored, where limits set upper / lower limits. This stops PriorLimitException's from being raised. output: force_pickle_overwrite: false # If True pickle files output by a search (e.g. samples.pickle) are recreated when a new model-fit is performed. info_whitespace_length: 80 # Length of whitespace between the parameter names and values in the model.info / result.info diff --git a/test_autofit/config/non_linear/mcmc.yaml b/test_autofit/config/non_linear/mcmc.yaml index 0e99e781f..004334db6 100644 --- a/test_autofit/config/non_linear/mcmc.yaml +++ b/test_autofit/config/non_linear/mcmc.yaml @@ -9,6 +9,7 @@ Emcee: ball_upper_limit: 0.51 method: prior parallel: + force_x1_cpu: true number_of_cores: 1 printing: silence: false diff --git a/test_autofit/config/non_linear/nest.yaml b/test_autofit/config/non_linear/nest.yaml index 03d83adf3..e8ce09b75 100644 --- a/test_autofit/config/non_linear/nest.yaml +++ b/test_autofit/config/non_linear/nest.yaml @@ -35,7 +35,7 @@ DynestyStatic: initialize: method: prior parallel: - force_x1_cpu: false + force_x1_cpu: true number_of_cores: 1 printing: silence: true diff --git a/test_autofit/config/priors/mock_model.yaml b/test_autofit/config/priors/mock_model.yaml index 7e8761902..a36e81818 100644 --- a/test_autofit/config/priors/mock_model.yaml +++ b/test_autofit/config/priors/mock_model.yaml @@ -1,6 +1,6 @@ Parameter: value: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 
@@ -11,7 +11,7 @@ Parameter: value: 0.2 MockChildTuple: tup_0: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -21,7 +21,7 @@ MockChildTuple: type: Absolute value: 0.2 tup_1: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -32,7 +32,7 @@ MockChildTuple: value: 0.2 MockChildTuplex2: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -42,7 +42,7 @@ MockChildTuplex2: type: Absolute value: 0.2 tup_0: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -52,7 +52,7 @@ MockChildTuplex2: type: Absolute value: 0.2 tup_1: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -62,7 +62,7 @@ MockChildTuplex2: type: Absolute value: 0.2 two: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -73,7 +73,7 @@ MockChildTuplex2: value: 0.2 MockChildTuplex3: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -83,7 +83,7 @@ MockChildTuplex3: type: Absolute value: 0.2 three: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -93,7 +93,7 @@ MockChildTuplex3: type: Absolute value: 0.2 tup_0: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -103,7 +103,7 @@ MockChildTuplex3: type: Absolute value: 0.2 tup_1: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -113,7 +113,7 @@ MockChildTuplex3: type: Absolute value: 0.2 two: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -124,7 +124,7 @@ MockChildTuplex3: value: 0.2 MockClassInf: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: -inf @@ -134,7 +134,7 @@ MockClassInf: type: Absolute value: 0.2 two: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -145,7 +145,7 @@ MockClassInf: value: 0.2 MockClassRelativeWidth: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -155,7 +155,7 @@ MockClassRelativeWidth: type: Absolute value: 0.1 three: - 
gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -165,7 +165,7 @@ MockClassRelativeWidth: type: Absolute value: 1.0 two: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -176,7 +176,7 @@ MockClassRelativeWidth: value: 0.5 MockClassx2: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -186,7 +186,7 @@ MockClassx2: type: Absolute value: 1.0 two: - gaussian_limits: + limits: lower: 0.0 upper: 2.0 lower_limit: 0.0 @@ -197,7 +197,7 @@ MockClassx2: value: 2.0 MockClassx2FormatExp: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -207,7 +207,7 @@ MockClassx2FormatExp: type: Absolute value: 1.0 two_exp: - gaussian_limits: + limits: lower: 0.0 upper: 2.0 lower_limit: 0.0 @@ -218,7 +218,7 @@ MockClassx2FormatExp: value: 2.0 MockClassx2NoSuperScript: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -228,7 +228,7 @@ MockClassx2NoSuperScript: type: Absolute value: 1.0 two: - gaussian_limits: + limits: lower: 0.0 upper: 2.0 lower_limit: 0.0 @@ -239,7 +239,7 @@ MockClassx2NoSuperScript: value: 2.0 MockClassx2Tuple: one_tuple_0: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -249,7 +249,7 @@ MockClassx2Tuple: type: Absolute value: 0.2 one_tuple_1: - gaussian_limits: + limits: lower: 0.0 upper: 2.0 lower_limit: 0.0 @@ -260,7 +260,7 @@ MockClassx2Tuple: value: 0.2 MockClassx3TupleFloat: one_tuple_0: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -270,7 +270,7 @@ MockClassx3TupleFloat: type: Absolute value: 0.2 one_tuple_1: - gaussian_limits: + limits: lower: 0.0 upper: 2.0 lower_limit: 0.0 @@ -280,7 +280,7 @@ MockClassx3TupleFloat: type: Absolute value: 0.2 two: - gaussian_limits: + limits: lower: 0.0 upper: 2.0 lower_limit: 0.0 @@ -291,7 +291,7 @@ MockClassx3TupleFloat: value: 0.2 MockClassx4: four: - gaussian_limits: + limits: lower: -120.0 upper: 120.0 lower_limit: -120.0 @@ -301,7 +301,7 @@ MockClassx4: type: 
Absolute value: 2.0 one: - gaussian_limits: + limits: lower: -120.0 upper: 120.0 lower_limit: -120.0 @@ -311,7 +311,7 @@ MockClassx4: type: Absolute value: 1.0 three: - gaussian_limits: + limits: lower: -120.0 upper: 120.0 lower_limit: -120.0 @@ -321,7 +321,7 @@ MockClassx4: type: Absolute value: 2.0 two: - gaussian_limits: + limits: lower: -120.0 upper: 120.0 lower_limit: -120.0 @@ -332,7 +332,7 @@ MockClassx4: value: 2.0 MockComponents: parameter: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -350,7 +350,7 @@ MockDeferredClass: type: Deferred MockDistanceClass: one: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -360,7 +360,7 @@ MockDistanceClass: type: Absolute value: 0.2 two: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -389,7 +389,7 @@ MockPositionClass: upper_limit: 1.0 MockWithFloat: value: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -400,7 +400,7 @@ MockWithFloat: value: 0.2 MockWithTuple: tup_0: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -410,7 +410,7 @@ MockWithTuple: type: Absolute value: 0.2 tup_1: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 diff --git a/test_autofit/config/priors/model.yaml b/test_autofit/config/priors/model.yaml index 16e704f99..31da0ebb3 100644 --- a/test_autofit/config/priors/model.yaml +++ b/test_autofit/config/priors/model.yaml @@ -1,6 +1,6 @@ Gaussian: centre: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -10,7 +10,7 @@ Gaussian: type: Absolute value: 1.0 normalization: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -20,7 +20,7 @@ Gaussian: type: Absolute value: 1.0 sigma: - gaussian_limits: + limits: lower: 0.0 upper: 1.0 lower_limit: 0.0 @@ -36,36 +36,28 @@ PhysicalNFW: type: Gaussian mean: 0.0 sigma: 1.0 - lower_limit: -inf - upper_limit: inf centre_1: type: Gaussian mean: 0.0 sigma: 1.0 - lower_limit: -inf - upper_limit: inf 
ell_comps_0: type: Gaussian mean: 0.0 sigma: 0.3 - lower_limit: -1.0 - upper_limit: 1.0 width_modifier: type: Absolute value: 0.2 - gaussian_limits: + limits: lower: -1.0 upper: 1.0 ell_comps_1: type: Gaussian mean: 0.0 sigma: 0.3 - lower_limit: -1.0 - upper_limit: 1.0 width_modifier: type: Absolute value: 0.2 - gaussian_limits: + limits: lower: -1.0 upper: 1.0 log10m: diff --git a/test_autofit/config/priors/prior.yaml b/test_autofit/config/priors/prior.yaml deleted file mode 100644 index 4cc72c4ae..000000000 --- a/test_autofit/config/priors/prior.yaml +++ /dev/null @@ -1,7 +0,0 @@ -gaussian.GaussianPrior: - lower_limit: - type: Constant - value: -inf - upper_limit: - type: Constant - value: inf diff --git a/test_autofit/conftest.py b/test_autofit/conftest.py index 0380fc9e9..928cde646 100644 --- a/test_autofit/conftest.py +++ b/test_autofit/conftest.py @@ -1,3 +1,4 @@ +import jax import multiprocessing import os import shutil @@ -23,7 +24,6 @@ @pytest.fixture(name="recreate") def recreate(): - jax = pytest.importorskip("jax") def _recreate(o): flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)] diff --git a/test_autofit/database/identifier/test_identifiers.py b/test_autofit/database/identifier/test_identifiers.py index 47300c6d1..fc29f4c19 100644 --- a/test_autofit/database/identifier/test_identifiers.py +++ b/test_autofit/database/identifier/test_identifiers.py @@ -204,7 +204,7 @@ def test__identifier_description(): centre=af.UniformPrior(lower_limit=0.0, upper_limit=1.0), normalization=af.LogUniformPrior(lower_limit=0.001, upper_limit=0.01), sigma=af.GaussianPrior( - mean=0.5, sigma=2.0, lower_limit=-1.0, upper_limit=1.0 + mean=0.5, sigma=2.0, ), ) ) @@ -213,8 +213,8 @@ def test__identifier_description(): description = identifier.description.splitlines() - i = 0 + i = 0 assert description[i] == "Collection" i += 1 assert description[i] == "item_number" @@ -257,14 +257,6 @@ def test__identifier_description(): i += 1 assert description[i] == 
"GaussianPrior" i += 1 - assert description[i] == "lower_limit" - i += 1 - assert description[i] == "-1.0" - i += 1 - assert description[i] == "upper_limit" - i += 1 - assert description[i] == "1.0" - i += 1 assert description[i] == "mean" i += 1 assert description[i] == "0.5" @@ -282,7 +274,7 @@ def test__identifier_description__after_model_and_instance(): centre=af.UniformPrior(lower_limit=0.0, upper_limit=1.0), normalization=af.LogUniformPrior(lower_limit=0.001, upper_limit=0.01), sigma=af.GaussianPrior( - mean=0.5, sigma=2.0, lower_limit=-1.0, upper_limit=1.0 + mean=0.5, sigma=2.0, ), ) ) @@ -305,6 +297,7 @@ def test__identifier_description__after_model_and_instance(): identifier = Identifier([model]) description = identifier.description + assert ( description == """Collection @@ -315,7 +308,7 @@ def test__identifier_description__after_model_and_instance(): cls autofit.example.model.Gaussian centre -GaussianPrior +TruncatedGaussianPrior lower_limit 0.0 upper_limit @@ -328,10 +321,6 @@ def test__identifier_description__after_model_and_instance(): 0.00316228 sigma GaussianPrior -lower_limit --1.0 -upper_limit -1.0 mean 0.5 sigma @@ -346,7 +335,7 @@ def test__identifier_description__after_take_attributes(): centre=af.UniformPrior(lower_limit=0.0, upper_limit=1.0), normalization=af.LogUniformPrior(lower_limit=0.001, upper_limit=0.01), sigma=af.GaussianPrior( - mean=0.5, sigma=2.0, lower_limit=-1.0, upper_limit=1.0 + mean=0.5, sigma=2.0, ), ) ) @@ -403,14 +392,6 @@ def test__identifier_description__after_take_attributes(): i += 1 assert description[i] == "GaussianPrior" i += 1 - assert description[i] == "lower_limit" - i += 1 - assert description[i] == "-1.0" - i += 1 - assert description[i] == "upper_limit" - i += 1 - assert description[i] == "1.0" - i += 1 assert description[i] == "mean" i += 1 assert description[i] == "0.5" diff --git a/test_autofit/graphical/functionality/test_factor_graph.py b/test_autofit/graphical/functionality/test_factor_graph.py index 
9af8500f8..b475070f4 100644 --- a/test_autofit/graphical/functionality/test_factor_graph.py +++ b/test_autofit/graphical/functionality/test_factor_graph.py @@ -100,29 +100,29 @@ def func(a, b): assert grad[c] == pytest.approx(3) -def test_nested_factor_jax(): - def func(a, b): - a0 = a[0] - c = a[1]["c"] - return a0 * c * b - - a, b, c = graph.variables("a, b, c") - - f = func((1, {"c": 2}), 3) - values = {a: 1.0, b: 3.0, c: 2.0} - - pytest.importorskip("jax") - - factor = graph.Factor(func, (a, {"c": c}), b, vjp=True) - - assert factor(values) == pytest.approx(f) - - fval, grad = factor.func_gradient(values) - - assert fval == pytest.approx(f) - assert grad[a] == pytest.approx(6) - assert grad[b] == pytest.approx(2) - assert grad[c] == pytest.approx(3) +# def test_nested_factor_jax(): +# def func(a, b): +# a0 = a[0] +# c = a[1]["c"] +# return a0 * c * b +# +# a, b, c = graph.variables("a, b, c") +# +# f = func((1, {"c": 2}), 3) +# values = {a: 1.0, b: 3.0, c: 2.0} +# +# pytest.importorskip("jax") +# +# factor = graph.Factor(func, (a, {"c": c}), b, vjp=True) +# +# assert factor(values) == pytest.approx(f) +# +# fval, grad = factor.func_gradient(values) +# +# assert fval == pytest.approx(f) +# assert grad[a] == pytest.approx(6) +# assert grad[b] == pytest.approx(2) +# assert grad[c] == pytest.approx(3) class TestFactorGraph: diff --git a/test_autofit/graphical/functionality/test_jacobians.py b/test_autofit/graphical/functionality/test_jacobians.py index 5a20c2c34..690c93e35 100644 --- a/test_autofit/graphical/functionality/test_jacobians.py +++ b/test_autofit/graphical/functionality/test_jacobians.py @@ -1,15 +1,8 @@ from itertools import combinations - +import jax import numpy as np import pytest -# try: -# import jax -# -# _HAS_JAX = True -# except ImportError: -_HAS_JAX = False - from autofit.mapper.variable import variables from autofit.graphical.factor_graphs import ( Factor, @@ -17,136 +10,132 @@ ) -def test_jacobian_equiv(): - if not _HAS_JAX: - return - - 
def linear(x, a, b, c): - z = x.dot(a) + b - return (z**2).sum(), z - - x_, a_, b_, c_, z_ = variables("x, a, b, c, z") - x = np.arange(10.0).reshape(5, 2) - a = np.arange(2.0).reshape(2, 1) - b = np.ones(1) - c = -1.0 - - factors = [ - Factor( - linear, - x_, - a_, - b_, - c_, - factor_out=(FactorValue, z_), - numerical_jacobian=False, - ), - Factor( - linear, - x_, - a_, - b_, - c_, - factor_out=(FactorValue, z_), - numerical_jacobian=False, - jacfwd=False, - ), - Factor( - linear, - x_, - a_, - b_, - c_, - factor_out=(FactorValue, z_), - numerical_jacobian=False, - vjp=True, - ), - Factor( - linear, - x_, - a_, - b_, - c_, - factor_out=(FactorValue, z_), - numerical_jacobian=True, - ), - ] - - values = {x_: x, a_: a, b_: b, c_: c} - outputs = [factor.func_jacobian(values) for factor in factors] - - tol = pytest.approx(0, abs=1e-4) - pairs = combinations(outputs, 2) - g0 = FactorValue(1.0, {z_: np.ones((5, 1))}) - for (val1, jac1), (val2, jac2) in pairs: - assert val1 == val2 - - # test with different ways of calculating gradients - grad1, grad2 = jac1.grad(g0), jac2.grad(g0) - assert (grad1 - grad2).norm() == tol - grad1 = g0.to_dict() * jac1 - assert (grad1 - grad2).norm() == tol - grad2 = g0.to_dict() * jac2 - assert (grad1 - grad2).norm() == tol - - grad1, grad2 = jac1.grad(val1), jac2.grad(val2) - assert (grad1 - grad2).norm() == tol - - # test getting gradient with no args - assert (jac1.grad() - jac2.grad()).norm() == tol - - -def test_jac_model(): - if not _HAS_JAX: - return - - def linear(x, a, b): - z = x.dot(a) + b - return (z**2).sum(), z - - def likelihood(y, z): - return ((y - z) ** 2).sum() - - def combined(x, y, a, b): - like, z = linear(x, a, b) - return like + likelihood(y, z) - - x_, a_, b_, y_, z_ = variables("x, a, b, y, z") - x = np.arange(10.0).reshape(5, 2) - a = np.arange(2.0).reshape(2, 1) - b = np.ones(1) - y = np.arange(0.0, 10.0, 2).reshape(5, 1) - values = {x_: x, y_: y, a_: a, b_: b} - linear_factor = Factor(linear, x_, a_, b_, 
factor_out=(FactorValue, z_), vjp=True) - like_factor = Factor(likelihood, y_, z_, vjp=True) - full_factor = Factor(combined, x_, y_, a_, b_, vjp=True) - model_factor = like_factor * linear_factor - - x = np.arange(10.0).reshape(5, 2) - a = np.arange(2.0).reshape(2, 1) - b = np.ones(1) - y = np.arange(0.0, 10.0, 2).reshape(5, 1) - values = {x_: x, y_: y, a_: a, b_: b} - - # Fully working problem - fval, jac = full_factor.func_jacobian(values) - grad = jac.grad() - - model_val, model_jac = model_factor.func_jacobian(values) - model_grad = model_jac.grad() - - linear_val, linear_jac = linear_factor.func_jacobian(values) - like_val, like_jac = like_factor.func_jacobian( - {**values, **linear_val.deterministic_values} - ) - combined_val = like_val + linear_val - - # Manually back propagate - combined_grads = linear_jac.grad(like_jac.grad()) - - vals = (fval, model_val, combined_val) - grads = (grad, model_grad, combined_grads) - pairs = combinations(zip(vals, grads), 2) - for (val1, grad1), (val2, grad2) in pairs: - assert val1 == val2 - assert (grad1 - grad2).norm() == pytest.approx(0, 1e-6) +# def test_jacobian_equiv(): +# +# def linear(x, a, b, c): +# z = x.dot(a) + b +# return (z**2).sum(), z +# +# x_, a_, b_, c_, z_ = variables("x, a, b, c, z") +# x = np.arange(10.0).reshape(5, 2) +# a = np.arange(2.0).reshape(2, 1) +# b = np.ones(1) +# c = -1.0 +# +# factors = [ +# Factor( +# linear, +# x_, +# a_, +# b_, +# c_, +# factor_out=(FactorValue, z_), +# numerical_jacobian=False, +# ), +# Factor( +# linear, +# x_, +# a_, +# b_, +# c_, +# factor_out=(FactorValue, z_), +# numerical_jacobian=False, +# jacfwd=False, +# ), +# Factor( +# linear, +# x_, +# a_, +# b_, +# c_, +# factor_out=(FactorValue, z_), +# numerical_jacobian=False, +# vjp=True, +# ), +# Factor( +# linear, +# x_, +# a_, +# b_, +# c_, +# factor_out=(FactorValue, z_), +# numerical_jacobian=True, +# ), +# ] +# +# values = {x_: x, a_: a, b_: b, c_: c} +# outputs = [factor.func_jacobian(values) for factor in 
factors] +# +# tol = pytest.approx(0, abs=1e-4) +# pairs = combinations(outputs, 2) +# g0 = FactorValue(1.0, {z_: np.ones((5, 1))}) +# for (val1, jac1), (val2, jac2) in pairs: +# assert val1 == val2 +# +# # test with different ways of calculating gradients +# grad1, grad2 = jac1.grad(g0), jac2.grad(g0) +# assert (grad1 - grad2).norm() == tol +# grad1 = g0.to_dict() * jac1 +# assert (grad1 - grad2).norm() == tol +# grad2 = g0.to_dict() * jac2 +# assert (grad1 - grad2).norm() == tol +# +# grad1, grad2 = jac1.grad(val1), jac2.grad(val2) +# assert (grad1 - grad2).norm() == tol +# +# # test getting gradient with no args +# assert (jac1.grad() - jac2.grad()).norm() == tol +# +# +# def test_jac_model(): +# +# def linear(x, a, b): +# z = x.dot(a) + b +# return (z**2).sum(), z +# +# def likelihood(y, z): +# return ((y - z) ** 2).sum() +# +# def combined(x, y, a, b): +# like, z = linear(x, a, b) +# return like + likelihood(y, z) +# +# x_, a_, b_, y_, z_ = variables("x, a, b, y, z") +# x = np.arange(10.0).reshape(5, 2) +# a = np.arange(2.0).reshape(2, 1) +# b = np.ones(1) +# y = np.arange(0.0, 10.0, 2).reshape(5, 1) +# values = {x_: x, y_: y, a_: a, b_: b} +# linear_factor = Factor(linear, x_, a_, b_, factor_out=(FactorValue, z_), vjp=True) +# like_factor = Factor(likelihood, y_, z_, vjp=True) +# full_factor = Factor(combined, x_, y_, a_, b_, vjp=True) +# model_factor = like_factor * linear_factor +# +# x = np.arange(10.0).reshape(5, 2) +# a = np.arange(2.0).reshape(2, 1) +# b = np.ones(1) +# y = np.arange(0.0, 10.0, 2).reshape(5, 1) +# values = {x_: x, y_: y, a_: a, b_: b} +# +# # Fully working problem +# fval, jac = full_factor.func_jacobian(values) +# grad = jac.grad() +# +# model_val, model_jac = model_factor.func_jacobian(values) +# model_grad = model_jac.grad() +# +# linear_val, linear_jac = linear_factor.func_jacobian(values) +# like_val, like_jac = like_factor.func_jacobian( +# {**values, **linear_val.deterministic_values} +# ) +# combined_val = like_val + linear_val 
+# +# # Manually back propagate +# combined_grads = linear_jac.grad(like_jac.grad()) +# +# vals = (fval, model_val, combined_val) +# grads = (grad, model_grad, combined_grads) +# pairs = combinations(zip(vals, grads), 2) +# for (val1, grad1), (val2, grad2) in pairs: +# assert val1 == val2 +# assert (grad1 - grad2).norm() == pytest.approx(0, 1e-6) diff --git a/test_autofit/graphical/functionality/test_nested.py b/test_autofit/graphical/functionality/test_nested.py index e5d436f9e..6ecfad756 100644 --- a/test_autofit/graphical/functionality/test_nested.py +++ b/test_autofit/graphical/functionality/test_nested.py @@ -232,7 +232,6 @@ def test_nested_items(): ], ] - # Need jax version > 0.4 if hasattr(tree_util, "tree_flatten_with_path"): jax_flat = tree_util.tree_flatten_with_path(obj1)[0] af_flat = utils.nested_items(obj2) diff --git a/test_autofit/graphical/global/test_global.py b/test_autofit/graphical/global/test_global.py index 24e7b8760..c554c254a 100644 --- a/test_autofit/graphical/global/test_global.py +++ b/test_autofit/graphical/global/test_global.py @@ -48,7 +48,6 @@ def test_single_factor(self, model_factor, unit_value, likelihood): model_factor.log_likelihood_function( model_factor.global_prior_model.instance_from_unit_vector( [unit_value], - ignore_prior_limits=True, )[0] ) == likelihood @@ -60,7 +59,7 @@ def test_collection(self, model_factor, unit_value, likelihood): assert ( collection.log_likelihood_function( collection.global_prior_model.instance_from_unit_vector( - [unit_value], ignore_prior_limits=True + [unit_value] ) ) == likelihood @@ -75,7 +74,7 @@ def test_two_factor(self, model_factor, model_factor_2, unit_vector, likelihood) assert ( collection.log_likelihood_function( collection.global_prior_model.instance_from_unit_vector( - unit_vector, ignore_prior_limits=True + unit_vector ) ) == likelihood diff --git a/test_autofit/graphical/global/test_hierarchical.py b/test_autofit/graphical/global/test_hierarchical.py index 0426ab529..26b3e4da3 
100644 --- a/test_autofit/graphical/global/test_hierarchical.py +++ b/test_autofit/graphical/global/test_hierarchical.py @@ -34,6 +34,7 @@ def reset_ids(): def test_model_info(model): + print(model.info) assert ( model.info == """Total Free Parameters = 4 @@ -47,8 +48,6 @@ def test_model_info(model): distribution_model mean GaussianPrior [2], mean = 0.5, sigma = 0.1 sigma GaussianPrior [3], mean = 1.0, sigma = 0.01 - lower_limit -inf - upper_limit inf 0 drawn_prior UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0 1 diff --git a/test_autofit/graphical/hierarchical/test_embedded.py b/test_autofit/graphical/hierarchical/test_embedded.py index d46034acc..53836bb95 100644 --- a/test_autofit/graphical/hierarchical/test_embedded.py +++ b/test_autofit/graphical/hierarchical/test_embedded.py @@ -19,7 +19,7 @@ def make_centre_model(): return g.HierarchicalFactor( af.GaussianPrior, mean=af.GaussianPrior(mean=100, sigma=10), - sigma=af.GaussianPrior(mean=10, sigma=5, lower_limit=0), + sigma=af.GaussianPrior(mean=10, sigma=5), ) diff --git a/test_autofit/graphical/regression/test_identifier.py b/test_autofit/graphical/regression/test_identifier.py index 5a2a0132b..944c32e29 100644 --- a/test_autofit/graphical/regression/test_identifier.py +++ b/test_autofit/graphical/regression/test_identifier.py @@ -5,22 +5,24 @@ @pytest.fixture( - name="gaussian_prior" + name="truncated_gaussian_prior" ) -def make_gaussian_prior(): +def make_truncated_gaussian_prior(): return Identifier( - af.GaussianPrior( + af.TruncatedGaussianPrior( mean=1.0, - sigma=2.0 + sigma=2.0, + lower_limit="-inf", + upper_limit="inf" ) ) -def test_gaussian_prior_fields( - gaussian_prior +def test_truncated_gaussian_prior_fields( + truncated_gaussian_prior ): - assert gaussian_prior.hash_list == [ - 'GaussianPrior', + assert truncated_gaussian_prior.hash_list == [ + 'TruncatedGaussianPrior', 'lower_limit', '-inf', 'upper_limit', @@ -32,12 +34,12 @@ def test_gaussian_prior_fields( ] -def test_gaussian_prior( - 
gaussian_prior +def test_truncated_gaussian_prior( + truncated_gaussian_prior ): assert str( - gaussian_prior - ) == "218e05b43472cb7661b4712da640a81c" + truncated_gaussian_prior + ) == "b9d2c8e380214bf0888a5f65f651eb5c" @pytest.fixture( diff --git a/test_autofit/graphical/test_unification.py b/test_autofit/graphical/test_unification.py index a403a49d9..8f654f250 100644 --- a/test_autofit/graphical/test_unification.py +++ b/test_autofit/graphical/test_unification.py @@ -91,7 +91,7 @@ def test_uniform_prior(lower_limit, upper_limit, unit_value, physical_value): lower_limit=lower_limit, upper_limit=upper_limit, ).value_for( - unit_value, ignore_prior_limits=True + unit_value ) == pytest.approx(physical_value) diff --git a/test_autofit/jax/test_pytrees.py b/test_autofit/jax/test_pytrees.py index 56d080adf..eb1d6af78 100644 --- a/test_autofit/jax/test_pytrees.py +++ b/test_autofit/jax/test_pytrees.py @@ -27,7 +27,7 @@ def vmapped(gaussian, size=1000): def test_gaussian_prior(recreate): - prior = af.GaussianPrior(mean=1.0, sigma=1.0) + prior = af.TruncatedGaussianPrior(mean=1.0, sigma=1.0) new = recreate(prior) @@ -44,8 +44,8 @@ def _model(): return af.Model( af.Gaussian, centre=af.GaussianPrior(mean=1.0, sigma=1.0), - normalization=af.GaussianPrior(mean=1.0, sigma=1.0, lower_limit=0.0), - sigma=af.GaussianPrior(mean=1.0, sigma=1.0, lower_limit=0.0), + normalization=af.GaussianPrior(mean=1.0, sigma=1.0), + sigma=af.GaussianPrior(mean=1.0, sigma=1.0), ) diff --git a/test_autofit/mapper/functionality/test_take_attributes.py b/test_autofit/mapper/functionality/test_take_attributes.py index 156df8c4b..002e43a72 100644 --- a/test_autofit/mapper/functionality/test_take_attributes.py +++ b/test_autofit/mapper/functionality/test_take_attributes.py @@ -264,7 +264,7 @@ def test_limits( source_gaussian, target_gaussian ): - source_gaussian.centre = af.GaussianPrior( + source_gaussian.centre = af.TruncatedGaussianPrior( mean=0, sigma=1, lower_limit=-1, diff --git 
a/test_autofit/mapper/model/test_model_instance.py b/test_autofit/mapper/model/test_model_instance.py index 05e29ee37..e131f8e50 100644 --- a/test_autofit/mapper/model/test_model_instance.py +++ b/test_autofit/mapper/model/test_model_instance.py @@ -71,7 +71,7 @@ def test_simple_model(self): mapper.mock_class = af.m.MockClassx2 model_map = mapper.instance_from_unit_vector( - [1.0, 1.0], ignore_prior_limits=True + [1.0, 1.0] ) assert isinstance(model_map.mock_class, af.m.MockClassx2) @@ -85,7 +85,7 @@ def test_two_object_model(self): mapper.mock_class_2 = af.m.MockClassx2 model_map = mapper.instance_from_unit_vector( - [1.0, 0.0, 0.0, 1.0], ignore_prior_limits=True + [1.0, 0.0, 0.0, 1.0] ) assert isinstance(model_map.mock_class_1, af.m.MockClassx2) diff --git a/test_autofit/mapper/model/test_model_mapper.py b/test_autofit/mapper/model/test_model_mapper.py index 41e42d7f8..198c30f65 100644 --- a/test_autofit/mapper/model/test_model_mapper.py +++ b/test_autofit/mapper/model/test_model_mapper.py @@ -532,7 +532,7 @@ def test_prior_replacement(self): mapper = af.ModelMapper(mock_class=af.m.MockClassx2) result = mapper.mapper_from_prior_means([10, 5]) - assert isinstance(result.mock_class.one, af.GaussianPrior) + assert isinstance(result.mock_class.one, af.TruncatedGaussianPrior) assert {prior.id for prior in mapper.priors} == { prior.id for prior in result.priors } @@ -541,19 +541,19 @@ def test_replace_priors_with_gaussians_from_tuples(self): mapper = af.ModelMapper(mock_class=af.m.MockClassx2) result = mapper.mapper_from_prior_means([10, 5]) - assert isinstance(result.mock_class.one, af.GaussianPrior) + assert isinstance(result.mock_class.one, af.TruncatedGaussianPrior) def test_replacing_priors_for_profile(self): mapper = af.ModelMapper(mock_class=af.m.MockClassx3TupleFloat) result = mapper.mapper_from_prior_means([10, 5, 5]) assert isinstance( - result.mock_class.one_tuple.unique_prior_tuples[0][1], af.GaussianPrior + 
result.mock_class.one_tuple.unique_prior_tuples[0][1], af.TruncatedGaussianPrior ) assert isinstance( - result.mock_class.one_tuple.unique_prior_tuples[1][1], af.GaussianPrior + result.mock_class.one_tuple.unique_prior_tuples[1][1], af.TruncatedGaussianPrior ) - assert isinstance(result.mock_class.two, af.GaussianPrior) + assert isinstance(result.mock_class.two, af.TruncatedGaussianPrior) def test_replace_priors_for_two_classes(self): mapper = af.ModelMapper(one=af.m.MockClassx2, two=af.m.MockClassx2) diff --git a/test_autofit/mapper/model/test_regression.py b/test_autofit/mapper/model/test_regression.py index fa381fb6a..de6a3c35e 100644 --- a/test_autofit/mapper/model/test_regression.py +++ b/test_autofit/mapper/model/test_regression.py @@ -96,8 +96,8 @@ def test_passing_priors(): model = af.Model(af.m.MockWithTuple) new_model = model.mapper_from_prior_means([1, 1]) - assert isinstance(new_model.tup_0, af.GaussianPrior) - assert isinstance(new_model.tup_1, af.GaussianPrior) + assert isinstance(new_model.tup_0, af.TruncatedGaussianPrior) + assert isinstance(new_model.tup_1, af.TruncatedGaussianPrior) def test_passing_fixed(): @@ -150,12 +150,12 @@ def make_model_with_assertion(): def test_instance_from_vector(model_with_assertion): model_with_assertion.instance_from_vector( [0.5, 0.5, 0.5], - ignore_prior_limits=True, + ignore_assertions=True ) def test_random_instance(model_with_assertion): - model_with_assertion.random_instance(ignore_prior_limits=True) + model_with_assertion.random_instance(ignore_assertions=True) class TestModel: diff --git a/test_autofit/mapper/prior/test_arithmetic.py b/test_autofit/mapper/prior/test_arithmetic.py index 38831f64c..ea77bd534 100644 --- a/test_autofit/mapper/prior/test_arithmetic.py +++ b/test_autofit/mapper/prior/test_arithmetic.py @@ -51,7 +51,7 @@ class TestDivision: def test_prior_over_prior(self, prior): division_prior = prior / prior assert ( - division_prior.instance_from_unit_vector([0.5], ignore_prior_limits=True) + 
division_prior.instance_from_unit_vector([0.5]) == 1 ) @@ -73,14 +73,14 @@ class TestFloorDiv: def test_prior_over_int(self, ten_prior): division_prior = ten_prior // 2 assert ( - division_prior.instance_from_unit_vector([0.5], ignore_prior_limits=True) + division_prior.instance_from_unit_vector([0.5]) == 2.0 ) def test_int_over_prior(self, ten_prior): division_prior = 3 // ten_prior assert ( - division_prior.instance_from_unit_vector([0.2], ignore_prior_limits=True) + division_prior.instance_from_unit_vector([0.2]) == 1.0 ) @@ -89,13 +89,13 @@ class TestMod: def test_prior_mod_int(self, ten_prior): mod_prior = ten_prior % 3 assert ( - mod_prior.instance_from_unit_vector([0.5], ignore_prior_limits=True) == 2.0 + mod_prior.instance_from_unit_vector([0.5]) == 2.0 ) def test_int_mod_prior(self, ten_prior): mod_prior = 5.0 % ten_prior assert mod_prior.instance_from_unit_vector( - [0.3], ignore_prior_limits=True + [0.3] ) == pytest.approx(2.0) @@ -110,19 +110,19 @@ class TestPowers: def test_prior_to_prior(self, ten_prior): power_prior = ten_prior ** ten_prior assert power_prior.instance_from_unit_vector( - [0.2], ignore_prior_limits=True + [0.2] ) == pytest.approx(4.0) def test_prior_to_float(self, ten_prior): power_prior = ten_prior ** 3 assert power_prior.instance_from_unit_vector( - [0.2], ignore_prior_limits=True + [0.2] ) == pytest.approx(8.0) def test_float_to_prior(self, ten_prior): power_prior = 3.0 ** ten_prior assert power_prior.instance_from_unit_vector( - [0.2], ignore_prior_limits=True + [0.2] ) == pytest.approx(9.0) @@ -130,12 +130,12 @@ class TestInequality: def test_prior_lt_prior(self, prior): inequality_prior = (prior * prior) < prior result = inequality_prior.instance_from_unit_vector( - [0.5], ignore_prior_limits=True + [0.5] ) assert result inequality_prior = (prior * prior) > prior assert not ( - inequality_prior.instance_from_unit_vector([0.5], ignore_prior_limits=True) + inequality_prior.instance_from_unit_vector([0.5]) ) diff --git 
a/test_autofit/mapper/prior/test_limits.py b/test_autofit/mapper/prior/test_limits.py deleted file mode 100644 index 5f94a5197..000000000 --- a/test_autofit/mapper/prior/test_limits.py +++ /dev/null @@ -1,85 +0,0 @@ -import numpy as np -import pytest - -import autofit as af -from autofit.exc import PriorLimitException - - -@pytest.fixture(name="prior") -def make_prior(): - return af.GaussianPrior(mean=3.0, sigma=5.0, lower_limit=0.0) - - -def test_intrinsic_lower_limit(prior): - with pytest.raises(PriorLimitException): - prior.value_for(0.0) - - -def test_optional(prior): - prior.value_for(0.0, ignore_prior_limits=True) - - -@pytest.fixture(name="model") -def make_model(prior): - return af.Model(af.Gaussian, centre=prior) - - -def test_vector_from_unit_vector(model): - with pytest.raises(PriorLimitException): - model.vector_from_unit_vector([0, 0, 0]) - - -def test_vector_ignore_limits(model): - model.vector_from_unit_vector([0, 0, 0], ignore_prior_limits=True) - - -@pytest.mark.parametrize( - "prior", - [ - af.LogUniformPrior(), - af.UniformPrior(), - af.GaussianPrior(mean=0, sigma=1, lower_limit=0.0, upper_limit=1.0,), - ], -) -@pytest.mark.parametrize("value", [-1.0, 2.0]) -def test_all_priors(prior, value): - with pytest.raises(PriorLimitException): - prior.value_for(value) - - prior.value_for(value, ignore_prior_limits=True) - - -@pytest.fixture(name="limitless_prior") -def make_limitless_prior(): - return af.GaussianPrior(mean=1.0, sigma=2.0,) - - -@pytest.mark.parametrize("value", np.arange(0, 1, 0.1)) -def test_invert_limits(value, limitless_prior): - value = float(value) - assert limitless_prior.message.cdf( - limitless_prior.value_for(value) - ) == pytest.approx(value) - - -def test_unit_limits(): - prior = af.GaussianPrior(mean=1.0, sigma=2.0, lower_limit=-10, upper_limit=5,) - EPSILON = 0.00001 - assert prior.value_for(prior.lower_unit_limit) - assert prior.value_for(prior.upper_unit_limit - EPSILON) - - with pytest.raises(PriorLimitException): - 
prior.value_for(prior.lower_unit_limit - EPSILON) - with pytest.raises(PriorLimitException): - prior.value_for(prior.upper_unit_limit + EPSILON) - - -def test_infinite_limits(limitless_prior): - assert limitless_prior.lower_unit_limit == 0.0 - assert limitless_prior.upper_unit_limit == 1.0 - - -def test_uniform_prior(): - uniform_prior = af.UniformPrior(lower_limit=1.0, upper_limit=2.0,) - assert uniform_prior.lower_unit_limit == pytest.approx(0.0) - assert uniform_prior.upper_unit_limit == pytest.approx(1.0) diff --git a/test_autofit/mapper/prior/test_log_gaussian.py b/test_autofit/mapper/prior/test_log_gaussian.py index 8abfd2d4e..6bef0d064 100644 --- a/test_autofit/mapper/prior/test_log_gaussian.py +++ b/test_autofit/mapper/prior/test_log_gaussian.py @@ -28,11 +28,5 @@ def test_pickle(log_gaussian): loaded = pickle.loads(pickle.dumps(log_gaussian)) assert loaded == log_gaussian - -def test_attributes(log_gaussian): - assert log_gaussian.lower_limit == 0 - assert log_gaussian.upper_limit == float("inf") - - def test_identifier(log_gaussian): Identifier(log_gaussian) diff --git a/test_autofit/mapper/prior/test_prior.py b/test_autofit/mapper/prior/test_prior.py index ad2207071..afeb64545 100644 --- a/test_autofit/mapper/prior/test_prior.py +++ b/test_autofit/mapper/prior/test_prior.py @@ -10,41 +10,6 @@ class TestPriorLimits: def test_out_of_order_prior_limits(self): with pytest.raises(af.exc.PriorException): af.UniformPrior(1.0, 0) - with pytest.raises(af.exc.PriorException): - af.GaussianPrior(0, 1, 1, 0) - - def test_in_or_out(self): - prior = af.GaussianPrior(0, 1, 0, 1) - with pytest.raises(af.exc.PriorLimitException): - prior.assert_within_limits(-1) - - with pytest.raises(af.exc.PriorLimitException): - prior.assert_within_limits(1.1) - - prior.assert_within_limits(0.0) - prior.assert_within_limits(0.5) - prior.assert_within_limits(1.0) - - def test_no_limits(self): - prior = af.GaussianPrior(0, 1) - - prior.assert_within_limits(100) - 
prior.assert_within_limits(-100) - prior.assert_within_limits(0) - prior.assert_within_limits(0.5) - - def test_uniform_prior(self): - prior = af.UniformPrior(0, 1) - - with pytest.raises(af.exc.PriorLimitException): - prior.assert_within_limits(-1) - - with pytest.raises(af.exc.PriorLimitException): - prior.assert_within_limits(1.1) - - prior.assert_within_limits(0.0) - prior.assert_within_limits(0.5) - prior.assert_within_limits(1.0) def test_prior_creation(self): mapper = af.ModelMapper() @@ -58,18 +23,6 @@ def test_prior_creation(self): assert prior_tuples[1].prior.lower_limit == 0 assert prior_tuples[1].prior.upper_limit == 2 - def test_out_of_limits(self): - mm = af.ModelMapper() - mm.mock_class_gaussian = af.m.MockClassx2 - - assert mm.instance_from_vector([1, 2]) is not None - - with pytest.raises(af.exc.PriorLimitException): - mm.instance_from_vector(([1, 3])) - - with pytest.raises(af.exc.PriorLimitException): - mm.instance_from_vector(([-1, 2])) - def test_inf(self): mm = af.ModelMapper() mm.mock_class_inf = af.m.MockClassInf @@ -84,12 +37,6 @@ def test_inf(self): assert mm.instance_from_vector([-10000, 10000]) is not None - with pytest.raises(af.exc.PriorLimitException): - mm.instance_from_vector(([1, 0])) - - with pytest.raises(af.exc.PriorLimitException): - mm.instance_from_vector(([0, -1])) - def test_preserve_limits_tuples(self): mm = af.ModelMapper() mm.mock_class_gaussian = af.m.MockClassx2 @@ -314,7 +261,7 @@ def test__log_prior_from_value(self, mean, sigma, value, expected): def test_log_gaussian_prior_log_prior_from_value(): log_gaussian_prior = af.LogGaussianPrior( - mean=0.0, sigma=1.0, lower_limit=0.0, upper_limit=1.0 + mean=0.0, sigma=1.0, ) assert log_gaussian_prior.log_prior_from_value(value=0.0) == float("-inf") diff --git a/test_autofit/mapper/prior/test_prior_parsing.py b/test_autofit/mapper/prior/test_prior_parsing.py index a7ed4aae3..0ebb4803d 100644 --- a/test_autofit/mapper/prior/test_prior_parsing.py +++ 
b/test_autofit/mapper/prior/test_prior_parsing.py @@ -35,6 +35,15 @@ def make_log_uniform_prior(log_uniform_dict): def make_gaussian_dict(): return { "type": "Gaussian", + "mean": 3, + "sigma": 4, + "id": 0, + } + +@pytest.fixture(name="truncated_gaussian_dict") +def make_truncated_gaussian_dict(): + return { + "type": "TruncatedGaussian", "lower_limit": -10.0, "upper_limit": 10.0, "mean": 3, @@ -42,11 +51,13 @@ def make_gaussian_dict(): "id": 0, } - @pytest.fixture(name="gaussian_prior") def make_gaussian_prior(gaussian_dict): return af.Prior.from_dict(gaussian_dict) +@pytest.fixture(name="truncated_gaussian_prior") +def make_truncated_gaussian_prior(truncated_gaussian_dict): + return af.Prior.from_dict(truncated_gaussian_dict) @pytest.fixture(name="relative_width_dict") def make_relative_width_dict(): @@ -87,6 +98,10 @@ def test_default(self): class TestDict: def test_uniform(self, uniform_prior, uniform_dict, remove_ids): + + print(uniform_dict) + print(remove_ids(uniform_prior.dict())) + assert remove_ids(uniform_prior.dict()) == uniform_dict def test_log_uniform(self, log_uniform_prior, log_uniform_dict, remove_ids): @@ -109,11 +124,16 @@ def test_log_uniform(self, log_uniform_prior, absolute_width_modifier): def test_gaussian(self, gaussian_prior): assert isinstance(gaussian_prior, af.GaussianPrior) - assert gaussian_prior.lower_limit == -10 - assert gaussian_prior.upper_limit == 10 assert gaussian_prior.mean == 3 assert gaussian_prior.sigma == 4 + def test_truncated_gaussian(self, truncated_gaussian_prior): + assert isinstance(truncated_gaussian_prior, af.TruncatedGaussianPrior) + assert truncated_gaussian_prior.lower_limit == -10 + assert truncated_gaussian_prior.upper_limit == 10 + assert truncated_gaussian_prior.mean == 3 + assert truncated_gaussian_prior.sigma == 4 + def test_constant(self): result = af.Prior.from_dict({"type": "Constant", "value": 1.5}) assert result == 1.5 diff --git a/test_autofit/mapper/prior/test_regression.py 
b/test_autofit/mapper/prior/test_regression.py index 1d91d5c0e..b1b25c3a4 100644 --- a/test_autofit/mapper/prior/test_regression.py +++ b/test_autofit/mapper/prior/test_regression.py @@ -5,9 +5,11 @@ import autofit as af +# TODO : Use TruncatedGaussianPrior + @pytest.fixture(name="prior") def make_prior(): - return af.GaussianPrior(mean=1, sigma=2, lower_limit=3, upper_limit=4) + return af.GaussianPrior(mean=1, sigma=2) @pytest.fixture(name="message") @@ -15,21 +17,20 @@ def make_message(prior): return prior.message -def test_copy_limits(message): - copied = message.copy() - assert message.lower_limit == copied.lower_limit - assert message.upper_limit == copied.upper_limit - - -def test_multiply_limits(message): - multiplied = message * message - assert message.lower_limit == multiplied.lower_limit - assert message.upper_limit == multiplied.upper_limit - - multiplied = 1 * message - assert message.lower_limit == multiplied.lower_limit - assert message.upper_limit == multiplied.upper_limit - +# def test_copy_limits(message): +# copied = message.copy() +# assert message.lower_limit == copied.lower_limit +# assert message.upper_limit == copied.upper_limit +# +# +# def test_multiply_limits(message): +# multiplied = message * message +# assert message.lower_limit == multiplied.lower_limit +# assert message.upper_limit == multiplied.upper_limit +# +# multiplied = 1 * message +# assert message.lower_limit == multiplied.lower_limit +# assert message.upper_limit == multiplied.upper_limit @pytest.fixture(name="uniform_prior") def make_uniform_prior(): diff --git a/test_autofit/mapper/prior/test_truncated_gaussian.py b/test_autofit/mapper/prior/test_truncated_gaussian.py new file mode 100644 index 000000000..b0afeba7b --- /dev/null +++ b/test_autofit/mapper/prior/test_truncated_gaussian.py @@ -0,0 +1,26 @@ +import pickle + +import pytest + +import autofit as af +from autofit.mapper.identifier import Identifier + + +@pytest.fixture(name="truncated_gaussian") +def 
make_truncated_gaussian(): + return af.TruncatedGaussianPrior(mean=1.0, sigma=2.0, lower_limit=0.95, upper_limit=1.05) + + +@pytest.mark.parametrize( + "unit, value", + [ + # (0.0, 0.0), + (0.001, 0.95), + (0.5, 1.0), + (0.999, 1.05), + ], +) +def test_values(truncated_gaussian, unit, value): + print(unit, value) + assert truncated_gaussian.value_for(unit) == pytest.approx(value, rel=0.1) + diff --git a/test_autofit/mapper/test_array.py b/test_autofit/mapper/test_array.py index 836eb91f9..afc3732b2 100644 --- a/test_autofit/mapper/test_array.py +++ b/test_autofit/mapper/test_array.py @@ -81,32 +81,24 @@ def array_dict(): ], }, "prior_0_0": { - "lower_limit": float("-inf"), "mean": 0.0, "sigma": 1.0, "type": "Gaussian", - "upper_limit": float("inf"), }, "prior_0_1": { - "lower_limit": float("-inf"), "mean": 0.0, "sigma": 1.0, "type": "Gaussian", - "upper_limit": float("inf"), }, "prior_1_0": { - "lower_limit": float("-inf"), "mean": 0.0, "sigma": 1.0, "type": "Gaussian", - "upper_limit": float("inf"), }, "prior_1_1": { - "lower_limit": float("-inf"), "mean": 0.0, "sigma": 1.0, "type": "Gaussian", - "upper_limit": float("inf"), }, "shape": {"type": "tuple", "values": [2, 2]}, }, diff --git a/test_autofit/non_linear/grid/test_optimizer_grid_search.py b/test_autofit/non_linear/grid/test_optimizer_grid_search.py index 129177d2a..ae779eb24 100644 --- a/test_autofit/non_linear/grid/test_optimizer_grid_search.py +++ b/test_autofit/non_linear/grid/test_optimizer_grid_search.py @@ -130,7 +130,16 @@ def test_different_prior_width(self, grid_search, mapper): def test_raises_exception_for_bad_limits(self, grid_search, mapper): mapper.component.one_tuple.one_tuple_0 = af.GaussianPrior( - 0.0, 2.0, lower_limit=float("-inf"), upper_limit=float("inf") + 0.0, 2.0, + ) + with pytest.raises(exc.PriorException): + list( + grid_search.make_arguments( + [[0, 1]], grid_priors=[mapper.component.one_tuple.one_tuple_0] + ) + ) + mapper.component.one_tuple.one_tuple_0 = af.UniformPrior( + 
lower_limit=float("-inf"), upper_limit=float("inf") ) with pytest.raises(exc.PriorException): list( diff --git a/test_autofit/non_linear/grid/test_sensitivity/test_run.py b/test_autofit/non_linear/grid/test_sensitivity/test_run.py index 86564bf03..c7bd0caa6 100644 --- a/test_autofit/non_linear/grid/test_sensitivity/test_run.py +++ b/test_autofit/non_linear/grid/test_sensitivity/test_run.py @@ -1,7 +1,7 @@ # from autoconf.conf import with_config # # -# @with_config("general", "model", "ignore_prior_limits", value=True) +# @with_config("general", "model", value=True) # def test_sensitivity(sensitivity): # results = sensitivity.run() # assert len(results) == 8 diff --git a/test_autofit/non_linear/samples/test_pdf.py b/test_autofit/non_linear/samples/test_pdf.py index 5f54dabc9..9ba92d046 100644 --- a/test_autofit/non_linear/samples/test_pdf.py +++ b/test_autofit/non_linear/samples/test_pdf.py @@ -180,7 +180,7 @@ def test__median_pdf__unconverged(): assert median_pdf[1] == pytest.approx(1.9, 1.0e-4) -@with_config("general", "model", "ignore_prior_limits", value=True) +@with_config("general", "model", value=True) def test__converged__vector_and_instance_at_upper_and_lower_sigma(): parameters = [ [0.1, 0.4], diff --git a/test_autofit/non_linear/test_regression.py b/test_autofit/non_linear/test_regression.py index 64876ce63..469f1f52c 100644 --- a/test_autofit/non_linear/test_regression.py +++ b/test_autofit/non_linear/test_regression.py @@ -50,7 +50,7 @@ def test_skip_assertions(model): with pytest.raises(exc.FitException): model.instance_from_prior_medians() - model.instance_from_prior_medians(ignore_prior_limits=True) + model.instance_from_prior_medians(ignore_assertions=True) def test_recursive_skip_assertions(model): @@ -58,4 +58,4 @@ def test_recursive_skip_assertions(model): with pytest.raises(exc.FitException): model.instance_from_prior_medians() - model.instance_from_prior_medians(ignore_prior_limits=True) + 
model.instance_from_prior_medians(ignore_assertions=True) diff --git a/test_autofit/serialise/test_sum_prior.py b/test_autofit/serialise/test_sum_prior.py index 00b501ae8..6ca13ecd0 100644 --- a/test_autofit/serialise/test_sum_prior.py +++ b/test_autofit/serialise/test_sum_prior.py @@ -10,16 +10,12 @@ def make_prior_dict(): "type": "compound", "compound_type": "SumPrior", "left": { - "lower_limit": float("-inf"), - "upper_limit": float("inf"), "type": "Gaussian", "id": 0, "mean": 1.0, "sigma": 2.0, }, "right": { - "lower_limit": float("-inf"), - "upper_limit": float("inf"), "type": "Gaussian", "id": 0, "mean": 1.0,