diff --git a/autofit/__init__.py b/autofit/__init__.py index 949baad80..4b5783e7e 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -141,4 +141,4 @@ def save_abc(pickler, obj): pickle._Pickler.save_type(pickler, obj) -__version__ = "2025.11.5.1" \ No newline at end of file +__version__ = "2026.1.21.3" \ No newline at end of file diff --git a/autofit/aggregator/search_output.py b/autofit/aggregator/search_output.py index 0a8facd85..d4484dbdf 100644 --- a/autofit/aggregator/search_output.py +++ b/autofit/aggregator/search_output.py @@ -250,7 +250,7 @@ def instance(self): try: return self.samples.max_log_likelihood() except (AttributeError, NotImplementedError): - return None + return self.samples_summary.instance @property def id(self) -> str: diff --git a/autofit/aggregator/summary/aggregate_csv/column.py b/autofit/aggregator/summary/aggregate_csv/column.py index 9635c1bde..77eedafde 100644 --- a/autofit/aggregator/summary/aggregate_csv/column.py +++ b/autofit/aggregator/summary/aggregate_csv/column.py @@ -64,20 +64,34 @@ def value(self, row: "Row"): result = {} if ValueType.Median in self.value_types: - result[""] = row.median_pdf_sample_kwargs[self.path] + try: + result[""] = row.median_pdf_sample_kwargs[self.path] + except KeyError: + result[""] = None if ValueType.MaxLogLikelihood in self.value_types: - result["max_lh"] = row.max_likelihood_kwargs[self.path] + try: + result["max_lh"] = row.max_likelihood_kwargs[self.path] + except KeyError: + result["max_lh"] = None if ValueType.ValuesAt1Sigma in self.value_types: - lower, upper = row.values_at_sigma_1_kwargs[self.path] - result["lower_1_sigma"] = lower - result["upper_1_sigma"] = upper + try: + lower, upper = row.values_at_sigma_1_kwargs[self.path] + result["lower_1_sigma"] = lower + result["upper_1_sigma"] = upper + except KeyError: + result["lower_1_sigma"] = None + result["upper_1_sigma"] = None if ValueType.ValuesAt3Sigma in self.value_types: - lower, upper = 
row.values_at_sigma_3_kwargs[self.path] - result["lower_3_sigma"] = lower - result["upper_3_sigma"] = upper + try: + lower, upper = row.values_at_sigma_3_kwargs[self.path] + result["lower_3_sigma"] = lower + result["upper_3_sigma"] = upper + except KeyError: + result["lower_3_sigma"] = None + result["upper_3_sigma"] = None return result diff --git a/autofit/aggregator/summary/aggregate_images.py b/autofit/aggregator/summary/aggregate_images.py index da9b179f9..69a29d047 100644 --- a/autofit/aggregator/summary/aggregate_images.py +++ b/autofit/aggregator/summary/aggregate_images.py @@ -24,32 +24,11 @@ def subplot_filename(subplot: Enum) -> str: ) -class SubplotFit(Enum): - """ - The subplots that can be extracted from the subplot_fit image. - - The values correspond to the position of the subplot in the 4x3 grid. - """ - - Data = (0, 0) - DataSourceScaled = (1, 0) - SignalToNoiseMap = (2, 0) - ModelData = (3, 0) - LensLightModelData = (0, 1) - LensLightSubtractedImage = (1, 1) - SourceModelData = (2, 1) - SourcePlaneZoomed = (3, 1) - NormalizedResidualMap = (0, 2) - NormalizedResidualMapOneSigma = (1, 2) - ChiSquaredMap = (2, 2) - SourcePlaneNoZoom = (3, 2) - - class SubplotFitImage: def __init__( self, image: Image.Image, - suplot_type: Type[SubplotFit] = SubplotFit, + suplot_type, ): """ The subplot_fit image associated with one fit. @@ -173,7 +152,7 @@ def output_to_folder( self, folder: Path, name: Union[str, List[str]], - subplots: List[Union[SubplotFit, List[Image.Image], Callable]], + subplots: List[Union[List[Image.Image], Callable]], subplot_width: Optional[int] = sys.maxsize, ): """ @@ -226,7 +205,7 @@ def output_to_folder( def _matrix_for_result( i: int, result: SearchOutput, - subplots: List[Union[SubplotFit, List[Image.Image], Callable]], + subplots: List[Union[List[Image.Image], Callable]], subplot_width: int = sys.maxsize, ) -> List[List[Image.Image]]: """ @@ -272,7 +251,8 @@ class name but using snake_case. 
_images[subplot_type] = SubplotFitImage( result.image( subplot_filename(subplot_), - ) + ), + subplot_type ) return _images[subplot_type] diff --git a/autofit/graphical/declarative/abstract.py b/autofit/graphical/declarative/abstract.py index e633c71df..35e1c177f 100644 --- a/autofit/graphical/declarative/abstract.py +++ b/autofit/graphical/declarative/abstract.py @@ -20,7 +20,7 @@ class AbstractDeclarativeFactor(Analysis, ABC): _plates: Tuple[Plate, ...] = () def __init__(self, include_prior_factors=False, use_jax : bool = False): - self.include_prior_factors = include_prior_factors + self._include_prior_factors = include_prior_factors super().__init__(use_jax=use_jax) @@ -63,7 +63,7 @@ def prior_counts(self) -> List[Tuple[Prior, int]]: for prior in factor.prior_model.priors: counter[prior] += 1 return [ - (prior, count + 1 if self.include_prior_factors else count) + (prior, count + 1 if self._include_prior_factors else count) for prior, count in counter.items() ] @@ -96,7 +96,7 @@ def graph(self) -> DeclarativeFactorGraph: The complete graph made by combining all factors and priors """ factors = [model for model in self.model_factors] - if self.include_prior_factors: + if self._include_prior_factors: factors += self.prior_factors # noinspection PyTypeChecker return DeclarativeFactorGraph(factors) diff --git a/autofit/graphical/declarative/collection.py b/autofit/graphical/declarative/collection.py index 97fb9824f..6dff83693 100644 --- a/autofit/graphical/declarative/collection.py +++ b/autofit/graphical/declarative/collection.py @@ -42,7 +42,7 @@ def __init__( def tree_flatten(self): return ( (self._model_factors,), - (self._name, self.include_prior_factors), + (self._name, self._include_prior_factors), ) @classmethod diff --git a/autofit/graphical/declarative/factor/abstract.py b/autofit/graphical/declarative/factor/abstract.py index 3df6a9f0c..67bfec786 100644 --- a/autofit/graphical/declarative/factor/abstract.py +++ 
b/autofit/graphical/declarative/factor/abstract.py @@ -42,7 +42,7 @@ def __init__( factor, **prior_variable_dict, name=name or namer(self.__class__.__name__), ) - self.include_prior_factors = include_prior_factors + self._include_prior_factors = include_prior_factors @property def info(self) -> str: diff --git a/autofit/graphical/declarative/factor/hierarchical.py b/autofit/graphical/declarative/factor/hierarchical.py index f4dfec1d0..69bd3c272 100644 --- a/autofit/graphical/declarative/factor/hierarchical.py +++ b/autofit/graphical/declarative/factor/hierarchical.py @@ -8,6 +8,7 @@ from autofit.messages import NormalMessage from autofit.non_linear.paths.abstract import AbstractPaths from autofit.tools.namer import namer + from .abstract import AbstractModelFactor @@ -19,6 +20,7 @@ def __init__( distribution: Type[Prior], optimiser=None, name: Optional[str] = None, + use_jax : bool = False, **kwargs, ): """ @@ -70,6 +72,7 @@ def __init__( self._name = name or namer(self.__class__.__name__) self._factors = list() self.optimiser = optimiser + self._use_jax = use_jax @property def name(self): @@ -144,7 +147,7 @@ def __call__(self, **kwargs): class _HierarchicalFactor(AbstractModelFactor): def __init__( - self, distribution_model: HierarchicalFactor, drawn_prior: Prior, use_jax : bool = False + self, distribution_model: HierarchicalFactor, drawn_prior: Prior, ): """ A factor that links a variable to a parameterised distribution. 
@@ -159,7 +162,7 @@ def __init__( """ self.distribution_model = distribution_model self.drawn_prior = drawn_prior - self.use_jax = use_jax + self._use_jax = distribution_model._use_jax prior_variable_dict = {prior.name: prior for prior in distribution_model.priors} @@ -188,7 +191,7 @@ def variable(self): return self.drawn_prior def log_likelihood_function(self, instance): - return instance.distribution_model.message(instance.drawn_prior) + return instance.distribution_model.message(instance.drawn_prior, xp=self._xp) @property def priors(self) -> Set[Prior]: diff --git a/autofit/mapper/prior/arithmetic/assertion.py b/autofit/mapper/prior/arithmetic/assertion.py index 2e7315d2b..b23a8c7a4 100644 --- a/autofit/mapper/prior/arithmetic/assertion.py +++ b/autofit/mapper/prior/arithmetic/assertion.py @@ -1,4 +1,5 @@ from abc import ABC +import numpy as np from typing import Optional, Dict from autofit.mapper.prior.arithmetic.compound import CompoundPrior, Compound @@ -24,7 +25,7 @@ def __le__(self, other): class GreaterThanLessThanAssertion(ComparisonAssertion): - def _instance_for_arguments(self, arguments, ignore_assertions=False): + def _instance_for_arguments(self, arguments, ignore_assertions=False, xp=np): """ Assert that the value in the dictionary associated with the lower prior is lower than the value associated with the greater prior. 
@@ -55,6 +56,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): """ Assert that the value in the dictionary associated with the lower @@ -90,6 +92,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.assertion_1.instance_for_arguments( arguments, diff --git a/autofit/mapper/prior/arithmetic/compound.py b/autofit/mapper/prior/arithmetic/compound.py index 4951340a1..ce9d1b910 100644 --- a/autofit/mapper/prior/arithmetic/compound.py +++ b/autofit/mapper/prior/arithmetic/compound.py @@ -212,6 +212,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -237,6 +238,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -256,6 +258,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -275,6 +278,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -294,6 +298,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -313,6 +318,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -396,6 +402,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return -self.prior.instance_for_arguments( arguments, @@ -412,6 +419,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return abs( self.prior.instance_for_arguments( @@ -430,6 +438,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return np.log( self.prior.instance_for_arguments( @@ -448,6 +457,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + 
xp=np, ): return np.log10( self.prior.instance_for_arguments( diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index a4bcc5bb6..fe74b6016 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Optional from autofit.messages.normal import NormalMessage @@ -46,6 +47,7 @@ def __init__( >>> prior = GaussianPrior(mean=1.0, sigma=2.0) >>> physical_value = prior.value_for(unit=0.5) # Returns ~1.0 (mean) """ + super().__init__( message=NormalMessage( mean=mean, diff --git a/autofit/mapper/prior/log_gaussian.py b/autofit/mapper/prior/log_gaussian.py index 6fa458950..d8f2a85b7 100644 --- a/autofit/mapper/prior/log_gaussian.py +++ b/autofit/mapper/prior/log_gaussian.py @@ -133,7 +133,7 @@ def value_for(self, unit: float) -> float: def parameter_string(self) -> str: return f"mean = {self.mean}, sigma = {self.sigma}" - def log_prior_from_value(self, value): + def log_prior_from_value(self, value, xp=np): if value <= 0: return float("-inf") diff --git a/autofit/mapper/prior/log_uniform.py b/autofit/mapper/prior/log_uniform.py index 63a9063b2..434af5b08 100644 --- a/autofit/mapper/prior/log_uniform.py +++ b/autofit/mapper/prior/log_uniform.py @@ -110,7 +110,7 @@ def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogUniformPrior __identifier_fields__ = ("lower_limit", "upper_limit") - def log_prior_from_value(self, value) -> float: + def log_prior_from_value(self, value, xp=np) -> float: """ Returns the log prior of a physical value, so the log likelihood of a model evaluation can be converted to a posterior as log_prior + log_likelihood. 
diff --git a/autofit/mapper/prior/uniform.py b/autofit/mapper/prior/uniform.py index 08baed5bb..af04f23d5 100644 --- a/autofit/mapper/prior/uniform.py +++ b/autofit/mapper/prior/uniform.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Optional, Tuple from autofit.messages.normal import UniformNormalMessage @@ -128,7 +129,7 @@ def value_for(self, unit: float) -> float: round(super().value_for(unit), 14) ) - def log_prior_from_value(self, value): + def log_prior_from_value(self, value, xp=np): """ Returns the log prior of a physical value, so the log likelihood of a model evaluation can be converted to a posterior as log_prior + log_likelihood. diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index b41ea3c58..f6b7f3939 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -745,7 +745,7 @@ def physical_values_from_prior_medians(self): """ return self.vector_from_unit_vector([0.5] * len(self.unique_prior_tuples)) - def instance_from_vector(self, vector, ignore_assertions: bool = False): + def instance_from_vector(self, vector, ignore_assertions: bool = False, xp=np): """ Returns a ModelInstance, which has an attribute and class instance corresponding to every `Model` attributed to this instance. @@ -778,6 +778,7 @@ def instance_from_vector(self, vector, ignore_assertions: bool = False): return self.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) def has(self, cls: Union[Type, Tuple[Type, ...]]) -> bool: @@ -1078,6 +1079,7 @@ def instance_from_prior_medians(self, ignore_assertions : bool = False): def log_prior_list_from_vector( self, vector: [float], + xp=np, ): """ Compute the log priors of every parameter in a vector, using the Prior of every parameter. 
@@ -1094,7 +1096,7 @@ def log_prior_list_from_vector( return list( map( lambda prior_tuple, value: prior_tuple.prior.log_prior_from_value( - value=value + value=value, xp=xp ), self.prior_tuples_ordered_by_id, vector, @@ -1308,6 +1310,7 @@ def _instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ): raise NotImplementedError() @@ -1315,6 +1318,7 @@ def instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ): """ Returns an instance of the model for a set of arguments @@ -1338,6 +1342,7 @@ def instance_for_arguments( return self._instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) def path_for_name(self, name: str) -> Tuple[str, ...]: diff --git a/autofit/mapper/prior_model/array.py b/autofit/mapper/prior_model/array.py index 7952b2743..317680818 100644 --- a/autofit/mapper/prior_model/array.py +++ b/autofit/mapper/prior_model/array.py @@ -57,6 +57,7 @@ def _instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ) -> np.ndarray: """ Create an array where the prior at each index is replaced with the diff --git a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py index 5d39dcdcd..13073c10e 100644 --- a/autofit/mapper/prior_model/collection.py +++ b/autofit/mapper/prior_model/collection.py @@ -1,3 +1,5 @@ +import numpy as np + from collections.abc import Iterable from autofit.mapper.model import ModelInstance, assert_not_frozen @@ -208,6 +210,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): """ Parameters @@ -228,6 +231,7 @@ def _instance_for_arguments( value = value.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) elif isinstance(value, Prior): value = arguments[value] diff --git a/autofit/mapper/prior_model/prior_model.py b/autofit/mapper/prior_model/prior_model.py index 
cbf1cb285..756899185 100644 --- a/autofit/mapper/prior_model/prior_model.py +++ b/autofit/mapper/prior_model/prior_model.py @@ -2,6 +2,7 @@ import copy import inspect import logging +import numpy as np import typing from typing import * @@ -420,36 +421,6 @@ def __getattr__(self, item): self.__getattribute__(item) - # def __getattr__(self, item): - # - # try: - # if ( - # "_" in item - # and item not in ("_is_frozen", "tuple_prior_tuples") - # and not item.startswith("_") - # ): - # return getattr( - # [v for k, v in self.tuple_prior_tuples if item.split("_")[0] == k][ - # 0 - # ], - # item, - # ) - # - # except IndexError: - # pass - # - # try: - # return getattr( - # self.instance_for_arguments( - # {prior: prior for prior in self.priors}, - # ), - # item, - # ) - # except (AttributeError, TypeError): - # pass - # - # self.__getattribute__(item) - @property def is_deferred_arguments(self): return len(self.direct_deferred_tuples) > 0 @@ -459,6 +430,7 @@ def _instance_for_arguments( self, arguments: {ModelObject: object}, ignore_assertions=False, + xp=np, ): """ Returns an instance of the associated class for a set of arguments @@ -490,6 +462,7 @@ def _instance_for_arguments( ] = prior_model.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) prior_arguments = dict() @@ -532,6 +505,7 @@ def _instance_for_arguments( value = value.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) elif isinstance(value, Constant): value = value.value diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index d3b921108..e6a9fd860 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -45,17 +45,25 @@ def __init__( lower_limit=-math.inf, upper_limit=math.inf, id_=None, + _xp=np ): + xp=_xp + self.lower_limit = float(lower_limit) self.upper_limit = float(upper_limit) self.id = next(self.ids) if id_ is None else id_ self.log_norm = log_norm - self._broadcast = 
np.broadcast(*parameters) + + + if xp is np: + self._broadcast = np.broadcast(*parameters) + else: + self._broadcast = _xp.broadcast_arrays(*parameters) if self.shape: - self.parameters = tuple(np.asanyarray(p) for p in parameters) + self.parameters = tuple(xp.asarray(p) for p in parameters) else: self.parameters = tuple(parameters) @@ -215,7 +223,7 @@ def __truediv__(self, other: Union["AbstractMessage", Real]) -> "AbstractMessage ) def __pow__(self, other: Real) -> "AbstractMessage": - natural = self.natural_parameters + natural = self.natural_parameters() new_params = other * natural log_norm = other * self.log_norm new = self.from_natural_parameters( @@ -306,12 +314,12 @@ def log_normalisation(self, *elems: Union["AbstractMessage", float]) -> np.ndarr ] # Calculate log product of message normalisation - log_norm = self.log_base_measure - self.log_partition - log_norm += sum(dist.log_base_measure - dist.log_partition for dist in dists) + log_norm = self.log_base_measure - self.log_partition() + log_norm += sum(dist.log_base_measure - dist.log_partition() for dist in dists) # Calculate log normalisation of product of messages prod_dist = self.sum_natural_parameters(*dists) - log_norm -= prod_dist.log_base_measure - prod_dist.log_partition + log_norm -= prod_dist.log_base_measure - prod_dist.log_partition() return log_norm @@ -359,8 +367,8 @@ def _get_mean_variance( ) return mean, variance - def __call__(self, x): - return np.sum(self.logpdf(x)) + def __call__(self, x, xp=np): + return xp.sum(self.logpdf(x, xp=xp)) def factor_jacobian( self, x: np.ndarray, _variables: Optional[Tuple[str]] = ("x",) diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 2f4848e05..9b9e6b919 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -171,8 +171,7 @@ def value_for(self, unit: float) -> float: """ raise NotImplemented() - @cached_property - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: """ Compute 
the log partition function (log normalization constant) of the Beta distribution. @@ -184,8 +183,7 @@ def log_partition(self) -> np.ndarray: return betaln(*self.parameters) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Compute the natural parameters of the Beta distribution. @@ -196,13 +194,15 @@ def natural_parameters(self) -> np.ndarray: """ return self.calc_natural_parameters( self.alpha, - self.beta + self.beta, + xp=xp ) @staticmethod def calc_natural_parameters( alpha: Union[float, np.ndarray], - beta: Union[float, np.ndarray] + beta: Union[float, np.ndarray], + xp=np ) -> np.ndarray: """ Calculate the natural parameters of a Beta distribution from alpha and beta. @@ -218,7 +218,7 @@ def calc_natural_parameters( ------- Natural parameters [alpha - 1, beta - 1]. """ - return np.array([alpha - 1, beta - 1]) + return xp.array([alpha - 1, beta - 1]) @staticmethod def invert_natural_parameters( @@ -258,7 +258,7 @@ def invert_sufficient_statistics( return cls.calc_natural_parameters(a, b) @classmethod - def to_canonical_form(cls, x: np.ndarray) -> np.ndarray: + def to_canonical_form(cls, x: np.ndarray, xp=np) -> np.ndarray: """ Convert a value x in (0,1) to the canonical sufficient statistics for Beta. @@ -271,7 +271,7 @@ def to_canonical_form(cls, x: np.ndarray) -> np.ndarray: ------- Canonical sufficient statistics [log(x), log(1 - x)]. 
""" - return np.array([np.log(x), np.log1p(-x)]) + return xp.array([xp.log(x), xp.log1p(-x)]) @cached_property def mean(self) -> Union[np.ndarray, float]: diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index 959647009..84792b062 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -34,9 +34,9 @@ def transform(func): """ @functools.wraps(func) - def wrapper(self, x): + def wrapper(self, x, xp=np): x = self._transform(x) - return func(self, x) + return func(self, x, xp) return wrapper @@ -157,9 +157,8 @@ def project( def kl(self, dist): return self.base_message.kl(dist.base_message) - @property - def natural_parameters(self): - return self.base_message.natural_parameters + def natural_parameters(self, xp=np) -> np.ndarray: + return self.base_message.natural_parameters(xp=xp) @inverse_transform def sample(self, n_samples: Optional[int] = None): @@ -226,12 +225,11 @@ def invert_natural_parameters( return self.base_message.invert_natural_parameters(natural_parameters) @transform - def cdf(self, x): - return self.base_message.cdf(x) + def cdf(self, x, xp=np): + return self.base_message.cdf(x, xp=xp) - @property - def log_partition(self) -> np.ndarray: - return self.base_message.log_partition + def log_partition(self, xp=np) -> np.ndarray: + return self.base_message.log_partition(xp=xp) def invert_sufficient_statistics(self, sufficient_statistics): return self.base_message.invert_sufficient_statistics(sufficient_statistics) @@ -241,12 +239,12 @@ def value_for(self, unit): return self.base_message.value_for(unit) @transform - def calc_log_base_measure(self, x) -> np.ndarray: - return self.base_message.calc_log_base_measure(x) + def calc_log_base_measure(self, x, xp=np) -> np.ndarray: + return self.base_message.calc_log_base_measure(x, xp=xp) @transform - def to_canonical_form(self, x) -> np.ndarray: - return self.base_message.to_canonical_form(x) + def to_canonical_form(self, x, 
xp=np) -> np.ndarray: + return self.base_message.to_canonical_form(x, xp=xp) @property @inverse_transform diff --git a/autofit/messages/fixed.py b/autofit/messages/fixed.py index 496a62804..1280d5e95 100644 --- a/autofit/messages/fixed.py +++ b/autofit/messages/fixed.py @@ -25,8 +25,7 @@ def __init__( def value_for(self, unit: float) -> float: raise NotImplemented() - @cached_property - def natural_parameters(self) -> Tuple[np.ndarray, ...]: + def natural_parameters(self, xp=np) -> Tuple[np.ndarray, ...]: return self.parameters @staticmethod @@ -35,11 +34,10 @@ def invert_natural_parameters(natural_parameters: np.ndarray return natural_parameters, @staticmethod - def to_canonical_form(x: np.ndarray) -> np.ndarray: + def to_canonical_form(x: np.ndarray, xp=np) -> np.ndarray: return x - @cached_property - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: return 0. @classmethod diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index ae86a615e..3677186f0 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -6,11 +6,11 @@ class GammaMessage(AbstractMessage): - @property - def log_partition(self): + + def log_partition(self, xp=np): from scipy import special - alpha, beta = GammaMessage.invert_natural_parameters(self.natural_parameters) + alpha, beta = GammaMessage.invert_natural_parameters(self.natural_parameters(xp=xp)) return special.gammaln(alpha) - alpha * np.log(beta) log_base_measure = 0.0 @@ -36,13 +36,12 @@ def __init__( def value_for(self, unit: float) -> float: raise NotImplemented() - @cached_property - def natural_parameters(self): - return self.calc_natural_parameters(self.alpha, self.beta) + def natural_parameters(self, xp=np) -> np.ndarray: + return self.calc_natural_parameters(self.alpha, self.beta, xp=xp) @staticmethod - def calc_natural_parameters(alpha, beta): - return np.array([alpha - 1, -beta]) + def calc_natural_parameters(alpha, beta, xp=np): + return xp.array([alpha - 1, 
-beta]) @staticmethod def invert_natural_parameters(natural_parameters): @@ -50,8 +49,8 @@ def invert_natural_parameters(natural_parameters): return eta1 + 1, -eta2 @staticmethod - def to_canonical_form(x): - return np.array([np.log(x), x]) + def to_canonical_form(x, xp=np): + return xp.array([np.log(x), x]) @classmethod def invert_sufficient_statistics(cls, suff_stats): @@ -98,13 +97,13 @@ def kl(self, dist): def logpdf_gradient(self, x): logl = self.logpdf(x) - eta1 = self.natural_parameters[0] + eta1 = self.natural_parameters()[0] gradl = eta1 / x - self.beta return logl, gradl def logpdf_gradient_hessian(self, x): logl = self.logpdf(x) - eta1 = self.natural_parameters[0] + eta1 = self.natural_parameters()[0] gradl = eta1 / x hessl = -gradl / x gradl -= self.beta diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 28d46d2ab..c0d589553 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -23,6 +23,11 @@ def broadcast(self): @property def shape(self) -> Tuple[int, ...]: + + # JAX behaviour + if isinstance(self.broadcast, list): + return () + return self.broadcast.shape @property @@ -39,47 +44,47 @@ def __eq__(self, other): def pdf(self, x: np.ndarray) -> np.ndarray: return np.exp(self.logpdf(x)) - def logpdf(self, x: Union[np.ndarray, float]) -> np.ndarray: - eta = self._broadcast_natural_parameters(x) - t = self.to_canonical_form(x) - log_base = self.calc_log_base_measure(x) - return self.natural_logpdf(eta, t, log_base, self.log_partition) + def logpdf(self, x: Union[np.ndarray, float], xp=np) -> np.ndarray: - def _broadcast_natural_parameters(self, x): - shape = np.shape(x) + eta = self._broadcast_natural_parameters(x, xp=xp) + t = self.to_canonical_form(x, xp=xp) + log_base = self.calc_log_base_measure(x, xp=xp) + + return self.natural_logpdf(eta, t, log_base, self.log_partition(xp=xp), xp=xp) + + def _broadcast_natural_parameters(self, x, xp=np): + shape = xp.shape(x) if shape == self.shape: - return 
self.natural_parameters + return self.natural_parameters(xp=xp) elif shape[1:] == self.shape: - return self.natural_parameters[:, None, ...] + return self.natural_parameters(xp=xp)[:, None, ...] else: raise ValueError( f"shape of passed value {shape} does not " f"match message shape {self.shape}" ) - @cached_property @abstractmethod - def natural_parameters(self): + def natural_parameters(self, xp=np) -> np.ndarray: pass @staticmethod @abstractmethod - def to_canonical_form(x: Union[np.ndarray, float]) -> np.ndarray: + def to_canonical_form(x: Union[np.ndarray, float], xp=np) -> np.ndarray: pass @classmethod - def calc_log_base_measure(cls, x): + def calc_log_base_measure(cls, x, xp=np): return cls.log_base_measure - @cached_property @abstractmethod - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: pass @classmethod - def natural_logpdf(cls, eta, t, log_base, log_partition): - eta_t = np.multiply(eta, t).sum(0) - return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf) + def natural_logpdf(cls, eta, t, log_base, log_partition, xp=np): + eta_t = xp.multiply(eta, t).sum(0) + return xp.nan_to_num(log_base + eta_t - log_partition, nan=-xp.inf) def numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 @@ -181,7 +186,7 @@ def check_support(self) -> np.ndarray: pass def check_finite(self) -> np.ndarray: - return np.isfinite(self.natural_parameters).all(0) + return np.isfinite(self.natural_parameters()).all(0) def check_valid(self) -> np.ndarray: return self.check_finite() & self.check_support() @@ -201,11 +206,11 @@ def sum_natural_parameters(self, *dists: "MessageInterface") -> "MessageInterfac """ new_params = sum( ( - dist.natural_parameters + dist.natural_parameters() for dist in self._iter_dists(dists) if isinstance(dist, MessageInterface) ), - self.natural_parameters, + self.natural_parameters(), ) return self.from_natural_parameters( new_params, @@ -217,7 +222,7 @@ def sub_natural_parameters(self, other: 
"MessageInterface") -> "MessageInterface of this distribution with another distribution of the same type""" log_norm = self.log_norm - other.log_norm - new_params = self.natural_parameters - other.natural_parameters + new_params = self.natural_parameters() - other.natural_parameters() return self.from_natural_parameters( new_params, log_norm=log_norm, diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 9ff12ef3a..f7128aa97 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -23,10 +23,30 @@ def is_nan(value): is_nan_ = is_nan_.all() return is_nan_ +def assert_sigma_non_negative(sigma, xp=np): + + is_negative = sigma < 0 + + if xp.__name__.startswith("jax"): + import jax + # JAX path: cannot convert to Python bool + # Raise using JAX control flow: + return jax.lax.cond( + is_negative, + lambda _: (_ for _ in ()).throw( + ValueError("Sigma cannot be negative") + ), + lambda _: None, + operand=None, + ) + else: + # NumPy path: normal boolean works + if bool(is_negative): + raise ValueError("Sigma cannot be negative") class NormalMessage(AbstractMessage): - @cached_property - def log_partition(self): + + def log_partition(self, xp=np): """ Compute the log-partition function (also called log-normalizer or cumulant function) for the normal distribution in its natural (canonical) parameterization. @@ -39,8 +59,8 @@ def log_partition(self): A(η) = η₁² / (-4η₂) - 0.5 * log(-2η₂) This ensures normalization of the exponential-family distribution. """ - eta1, eta2 = self.natural_parameters - return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2 + eta1, eta2 = self.natural_parameters(xp=xp) + return -(eta1**2) / 4 / eta2 - xp.log(-2 * eta2) / 2 log_base_measure = -0.5 * np.log(2 * np.pi) _support = ((-np.inf, np.inf),) @@ -73,18 +93,25 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. 
""" - if (np.array(sigma) < 0).any(): - raise exc.MessageException("Sigma cannot be negative") + + if isinstance(mean, (np.ndarray, float, int, list)): + xp = np + else: + import jax.numpy as jnp + xp = jnp + + # assert_sigma_non_negative(sigma, xp=xp) super().__init__( mean, sigma, log_norm=log_norm, id_=id_, + _xp=xp ) self.mean, self.sigma = self.parameters - def cdf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: + def cdf(self, x : Union[float, np.ndarray], xp=np) -> Union[float, np.ndarray]: """ Compute the cumulative distribution function (CDF) of the Gaussian distribution at a given value or array of values `x`. @@ -122,8 +149,7 @@ def ppf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: return norm.ppf(x, loc=self.mean, scale=self.sigma) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ The natural (canonical) parameters of the Gaussian distribution in exponential-family form. @@ -136,10 +162,10 @@ def natural_parameters(self) -> np.ndarray: ------- A NumPy array containing the two natural parameters [η₁, η₂]. """ - return self.calc_natural_parameters(self.mean, self.sigma) + return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod - def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert standard parameters of a Gaussian distribution (mean and standard deviation) into natural parameters used in its exponential family representation. 
@@ -158,7 +184,7 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, η₂ = -1 / (2σ²) """ precision = 1 / sigma**2 - return np.array([mu * precision, -precision / 2]) + return xp.array([mu * precision, -precision / 2]) @staticmethod def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: @@ -181,7 +207,7 @@ def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, f return mu, sigma @staticmethod - def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + def to_canonical_form(x : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. @@ -197,7 +223,7 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: ------- The sufficient statistics [x, x²]. """ - return np.array([x, x**2]) + return xp.array([x, x**2]) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: @@ -401,7 +427,7 @@ def value_for(self, unit: float) -> float: return self.mean + (self.sigma * np.sqrt(2) * inv) - def log_prior_from_value(self, value: float) -> float: + def log_prior_from_value(self, value: float, xp=np) -> float: """ Compute the log prior probability of a given physical value under this Gaussian prior. @@ -445,7 +471,7 @@ def natural(self)-> "NaturalNormal": A natural form Gaussian with zeroed parameters but same configuration. 
""" return NaturalNormal.from_natural_parameters( - self.natural_parameters * 0.0, **self._init_kwargs + self.natural_parameters() * 0.0, **self._init_kwargs ) def zeros_like(self) -> "AbstractMessage": @@ -531,7 +557,7 @@ def mean(self) -> float: return np.nan_to_num(-self.parameters[0] / self.parameters[1] / 2) @staticmethod - def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: + def calc_natural_parameters(eta1: float, eta2: float, xp=np) -> np.ndarray: """ Return the natural parameters in array form (identity function for this class). @@ -542,14 +568,13 @@ def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: eta2 The second natural parameter. """ - return np.array([eta1, eta2]) + return xp.array([eta1, eta2]) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Return the natural parameters of this distribution. """ - return self.calc_natural_parameters(*self.parameters) + return self.calc_natural_parameters(*self.parameters, xp=xp) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 3d1614584..081e4d99d 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -25,8 +25,8 @@ def is_nan(value): class TruncatedNormalMessage(AbstractMessage): - @cached_property - def log_partition(self) -> float: + + def log_partition(self, xp=np) -> float: """ Compute the log-partition function (normalizer) of the truncated Gaussian. 
@@ -46,7 +46,7 @@ def log_partition(self) -> float: a = (self.lower_limit - self.mean) / self.sigma b = (self.upper_limit - self.mean) / self.sigma Z = norm.cdf(b) - norm.cdf(a) - return np.log(Z) if Z > 0 else -np.inf + return xp.log(Z) if Z > 0 else -xp.inf log_base_measure = -0.5 * np.log(2 * np.pi) @@ -96,7 +96,7 @@ def __init__( self.mean, self.sigma, self.lower_limit, self.upper_limit = self.parameters - def cdf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + def cdf(self, x: Union[float, np.ndarray], xp=np) -> Union[float, np.ndarray]: """ Compute the cumulative distribution function (CDF) of the truncated Gaussian distribution at a given value or array of values `x`. @@ -141,8 +141,7 @@ def ppf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: b = (self.upper_limit - self.mean) / self.sigma return truncnorm.ppf(x, a=a, b=b, loc=self.mean, scale=self.sigma) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ The pseudo-natural (canonical) parameters of a truncated Gaussian distribution. @@ -159,10 +158,10 @@ def natural_parameters(self) -> np.ndarray: ------- A NumPy array containing the pseudo-natural parameters [η₁, η₂]. """ - return self.calc_natural_parameters(self.mean, self.sigma) + return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod - def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert standard parameters of a Gaussian distribution (mean and standard deviation) into natural parameters used in its exponential family representation. 
@@ -190,7 +189,7 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, η₂ = -1 / (2σ²) """ precision = 1 / sigma**2 - return np.array([mu * precision, -precision / 2]) + return xp.array([mu * precision, -precision / 2]) @staticmethod def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: @@ -217,7 +216,7 @@ def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, f return mu, sigma @staticmethod - def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + def to_canonical_form(x : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. @@ -234,7 +233,7 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: ------- The sufficient statistics [x, x²]. """ - return np.array([x, x**2]) + return xp.array([x, x**2]) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: @@ -422,7 +421,7 @@ def value_for(self, unit: float) -> float: x_standard = norm.ppf(truncated_cdf) return self.mean + self.sigma * x_standard - def log_prior_from_value(self, value: float) -> float: + def log_prior_from_value(self, value: float, xp=np) -> float: """ Compute the log prior probability of a given physical value under this truncated Gaussian prior. @@ -439,18 +438,31 @@ def log_prior_from_value(self, value: float) -> float: ------- The log prior probability of the given value, or -inf if outside truncation bounds. 
""" - from scipy.stats import norm + if xp.__name__.startswith("jax"): + import jax.scipy.stats as jstats + norm = jstats.norm + else: + from scipy.stats import norm + + # Normalization term (truncation) a = (self.lower_limit - self.mean) / self.sigma b = (self.upper_limit - self.mean) / self.sigma Z = norm.cdf(b) - norm.cdf(a) - z = (value -self.mean) / self.sigma - log_pdf = -0.5 * z ** 2 - np.log(self.sigma) - 0.5 * np.log(2 * np.pi) - log_trunc_pdf = log_pdf - np.log(Z) + # Log pdf + z = (value - self.mean) / self.sigma + log_pdf = ( + -0.5 * z ** 2 + - xp.log(self.sigma) + - 0.5 * xp.log(2.0 * xp.pi) + ) + log_trunc_pdf = log_pdf - xp.log(Z) + # Truncation mask (must be xp.where for JAX) in_bounds = (self.lower_limit <= value) & (value <= self.upper_limit) - return np.where(in_bounds, log_trunc_pdf, -np.inf) + + return xp.where(in_bounds, log_trunc_pdf, -xp.inf) def __str__(self): """ @@ -481,7 +493,7 @@ def natural(self)-> "NaturalNormal": A natural form Gaussian with zeroed parameters but same configuration. """ return TruncatedNaturalNormal.from_natural_parameters( - self.natural_parameters * 0.0, **self._init_kwargs + self.natural_parameters() * 0.0, **self._init_kwargs ) def zeros_like(self) -> "AbstractMessage": @@ -605,10 +617,11 @@ def mean(self) -> float: @staticmethod def calc_natural_parameters( - eta1: float, - eta2: float, - lower_limit: float = -np.inf, - upper_limit: float = np.inf + eta1: float, + eta2: float, + lower_limit: float = -np.inf, + upper_limit: float = np.inf, + xp=np ) -> np.ndarray: """ Return the natural parameters in array form (identity function for this class). @@ -623,14 +636,13 @@ def calc_natural_parameters( eta2 The second natural parameter. """ - return np.array([eta1, eta2]) + return xp.array([eta1, eta2]) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Return the natural parameters of this distribution. 
""" - return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit) + return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit, xp=xp) @classmethod def invert_sufficient_statistics( diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 996bb995f..794d32361 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -36,7 +36,7 @@ def __init__( self, use_jax : bool = False, **kwargs ): - self.use_jax = use_jax + self._use_jax = use_jax self.kwargs = kwargs def __getattr__(self, item: str): @@ -63,7 +63,7 @@ def method(*args, **kwargs): @property def _xp(self): - if self.use_jax: + if self._use_jax: import jax.numpy as jnp return jnp return np @@ -109,7 +109,7 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model) - if self.use_jax: + if self._use_jax: import jax start = time.time() logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.") @@ -130,7 +130,7 @@ def batched_compute_latent(x): # batched JAX call on this chunk latent_values_batch = batched_compute_latent(batch) - if self.use_jax: + if self._use_jax: import jax.numpy as jnp latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) @@ -314,4 +314,57 @@ def profile_log_likelihood_function(self, paths: AbstractPaths, instance): pass def perform_quick_update(self, paths, instance): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + def print_vram_use(self, model, batch_size : int) -> str: + """ + Print JAX VRAM use for a given batch size. 
+ + Parameters + ---------- + batch_size + The batch size to profile, which is the number of model evaluations JAX will perform simultaneously. + """ + if not self._use_jax: + print("use_jax=False for this analysis, therefore does not use GPU and VRAM use cannot be profiled.") + return + + import jax + import jax.numpy as jnp + + from autofit.non_linear.fitness import Fitness + + fitness = Fitness( + model=model, + analysis=self, + fom_is_log_likelihood=True, + use_jax_vmap=True, + batch_size=batch_size, + ) + + parameters = np.zeros((batch_size, model.total_free_parameters)) + + for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians + + parameters = jnp.array(parameters) + + batched_call = jax.jit(jax.vmap(fitness.call)) + lowered = batched_call.lower(parameters) + compiled = lowered.compile() + memory_analysis = compiled.memory_analysis() + + vram_bytes = ( + memory_analysis.output_size_in_bytes + + memory_analysis.temp_size_in_bytes + ) + + if vram_bytes == 0: + print( + "VRAM USE = 0.000 GB " + "(this likely means JAX is running in CPU-only mode)" + ) + else: + print( + f"VRAM USE = {vram_bytes / 1024 ** 3:.3f} GB" + ) \ No newline at end of file diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index be0e63403..85b2be91b 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -44,7 +44,6 @@ def __init__( use_jax_vmap : bool = False, batch_size : Optional[int] = None, iterations_per_quick_update: Optional[int] = None, - xp=np, ): """ Interfaces with any non-linear search to fit the model to the data and return a log likelihood via @@ -109,7 +108,8 @@ def __init__( self.model = model self.paths = paths self.fom_is_log_likelihood = fom_is_log_likelihood - self.resample_figure_of_merit = resample_figure_of_merit or -xp.inf + + self.resample_figure_of_merit = resample_figure_of_merit or -self._xp.inf self.convert_to_chi_squared = convert_to_chi_squared self.store_history = 
store_history @@ -126,7 +126,7 @@ def __init__( self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update self.quick_update_max_lh_parameters = None - self.quick_update_max_lh = -xp.inf + self.quick_update_max_lh = -self._xp.inf self.quick_update_count = 0 if self.paths is not None: @@ -134,10 +134,7 @@ def __init__( @property def _xp(self): - if self.analysis.use_jax: - import jax.numpy as jnp - return jnp - return np + return self.analysis._xp def call(self, parameters): """ @@ -157,7 +154,7 @@ def call(self, parameters): The figure of merit returned to the non-linear search, which is either the log likelihood or log posterior. """ # Get instance from model - instance = self.model.instance_from_vector(vector=parameters) + instance = self.model.instance_from_vector(vector=parameters, xp=self._xp) if self._xp.__name__.startswith("jax"): @@ -173,13 +170,14 @@ def call(self, parameters): # Penalize NaNs in the log-likelihood log_likelihood = self._xp.where(self._xp.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood) + log_likelihood = self._xp.where(self._xp.isinf(log_likelihood), self.resample_figure_of_merit, log_likelihood) # Determine final figure of merit if self.fom_is_log_likelihood: figure_of_merit = log_likelihood else: # Ensure prior list is compatible with JAX (must return a JAX array, not list) - log_prior_array = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters)) + log_prior_array = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) figure_of_merit = log_likelihood + self._xp.sum(log_prior_array) # Convert to chi-squared scale if requested @@ -222,23 +220,20 @@ def call_wrap(self, parameters): figure_of_merit = self._call(parameters) if self.convert_to_chi_squared: - figure_of_merit *= -0.5 - - if self.fom_is_log_likelihood: - log_likelihood = figure_of_merit + log_likelihood = -0.5 * figure_of_merit else: - log_prior_list = 
self._xp.array(self.model.log_prior_list_from_vector(vector=parameters)) - log_likelihood = figure_of_merit - self._xp.sum(log_prior_list) + log_likelihood = figure_of_merit - self.manage_quick_update(parameters=parameters, log_likelihood=log_likelihood) + if not self.fom_is_log_likelihood: + log_prior_list = np.array(self.model.log_prior_list_from_vector(vector=parameters, xp=np)) + log_likelihood -= np.sum(log_prior_list) - if self.convert_to_chi_squared: - log_likelihood *= -2.0 + self.manage_quick_update(parameters=parameters, log_likelihood=log_likelihood) if self.store_history: - self.parameters_history_list.append(parameters) - self.log_likelihood_history_list.append(log_likelihood) + self.parameters_history_list.append(np.array(parameters)) + self.log_likelihood_history_list.append(np.array(log_likelihood)) return figure_of_merit @@ -317,7 +312,7 @@ def manage_quick_update(self, parameters, log_likelihood): logger.info("Performing quick update of maximum log likelihood fit image and model.results") - instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters) + instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters, xp=self._xp) try: self.analysis.perform_quick_update(self.paths, instance) diff --git a/autofit/non_linear/grid/grid_search/__init__.py b/autofit/non_linear/grid/grid_search/__init__.py index 8768c28f5..5c42d201f 100644 --- a/autofit/non_linear/grid/grid_search/__init__.py +++ b/autofit/non_linear/grid/grid_search/__init__.py @@ -219,8 +219,6 @@ def _fit( result: GridSearchResult The result of the grid search """ - self.logger.info("...in parallel") - grid_priors = model.sort_priors_alphabetically(set(grid_priors)) lists = self.make_lists(grid_priors) diff --git a/autofit/non_linear/grid/grid_search/job.py b/autofit/non_linear/grid/grid_search/job.py index 892d8ad46..bef62184a 100644 --- a/autofit/non_linear/grid/grid_search/job.py +++ b/autofit/non_linear/grid/grid_search/job.py @@ -51,6 
+51,7 @@ def __init__( self.info = info def perform(self): + result = self.search_instance.fit( model=self.model, analysis=self.analysis, diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 0ef3a5b34..2b105e786 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -136,21 +136,32 @@ def _fit( maxiter = self.config_dict_options.get("maxiter", 1e8) while total_iterations < maxiter: - iterations_remaining = maxiter - total_iterations + iterations_remaining = maxiter - total_iterations iterations = min(self.iterations_per_full_update, iterations_remaining) if iterations > 0: config_dict_options = self.config_dict_options config_dict_options["maxiter"] = iterations - search_internal = optimize.minimize( - fun=fitness.call_wrap, - x0=x0, - method=self.method, - options=config_dict_options, - **self.config_dict_search - ) + if analysis._use_jax: + + search_internal = optimize.minimize( + fun=fitness._jit, + x0=x0, + method=self.method, + options=config_dict_options, + **self.config_dict_search + ) + else: + + search_internal = optimize.minimize( + fun=fitness.__call__, + x0=x0, + method=self.method, + options=config_dict_options, + **self.config_dict_search + ) total_iterations += search_internal.nit @@ -208,7 +219,6 @@ def samples_via_internal_from( x0 = search_internal.x total_iterations = search_internal.nit - if self.should_plot_start_point: parameter_lists = search_internal.parameters_history_list diff --git a/autofit/non_linear/search/nest/dynesty/search/abstract.py b/autofit/non_linear/search/nest/dynesty/search/abstract.py index b566a1c50..913bd2b62 100644 --- a/autofit/non_linear/search/nest/dynesty/search/abstract.py +++ b/autofit/non_linear/search/nest/dynesty/search/abstract.py @@ -146,7 +146,7 @@ def _fit( "parallel" ].get("force_x1_cpu") or self.kwargs.get("force_x1_cpu") - or analysis.use_jax + or analysis._use_jax ): raise RuntimeError diff 
--git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 30531edbe..dcf8cdeb0 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -41,6 +41,7 @@ def __init__( iterations_per_full_update: int = None, number_of_cores: int = None, session: Optional[sa.orm.Session] = None, + use_jax_vmap : bool = True, **kwargs ): """ @@ -90,6 +91,8 @@ def __init__( self.logger.debug("Creating Nautilus Search") + self.use_jax_vmap = use_jax_vmap + def _fit(self, model: AbstractPriorModel, analysis): """ Fit a model using the search and the Analysis class which contains the data and returns the log likelihood from @@ -127,7 +130,7 @@ def _fit(self, model: AbstractPriorModel, analysis): if ( self.config_dict.get("force_x1_cpu") or self.kwargs.get("force_x1_cpu") - or analysis.use_jax + or analysis._use_jax ): fitness = Fitness( @@ -137,7 +140,7 @@ def _fit(self, model: AbstractPriorModel, analysis): fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, iterations_per_quick_update=self.iterations_per_quick_update, - use_jax_vmap=True, + use_jax_vmap=self.use_jax_vmap, batch_size=self.config_dict_search["n_batch"], ) @@ -225,13 +228,15 @@ def fit_x1_cpu(self, fitness, model, analysis): except KeyError: pass + vectorized = fitness.use_jax_vmap + search_internal = self.sampler_cls( prior=PriorVectorized(model=model), likelihood=fitness.call_wrap, n_dim=model.prior_count, filepath=self.checkpoint_file, pool=None, - vectorized=True, + vectorized=vectorized, **config_dict, ) @@ -485,6 +490,10 @@ def samples_via_internal_from( samples_info=self.samples_info_from(search_internal=search_internal), ) + @property + def batch_size(self): + return self.config_dict_search.get("n_batch") + @property def config_dict(self): return conf.instance["non_linear"]["nest"][self.__class__.__name__] diff --git a/autofit/non_linear/settings.py b/autofit/non_linear/settings.py 
index a31ea873a..d0dc298aa 100644 --- a/autofit/non_linear/settings.py +++ b/autofit/non_linear/settings.py @@ -11,6 +11,7 @@ def __init__( number_of_cores: Optional[int] = 1, session: Optional[sa.orm.Session] = None, info: Optional[dict] = None, + use_jax_vmap: bool = True, ): """ Stores all the input settings that are used in search's and their `fit functions. @@ -40,6 +41,7 @@ def __init__( self.unique_tag = unique_tag self.number_of_cores = number_of_cores self.session = session + self.use_jax_vmap = use_jax_vmap self.info = info @@ -50,6 +52,7 @@ def search_dict(self): "unique_tag": self.unique_tag, "number_of_cores": self.number_of_cores, "session": self.session, + "use_jax_vmap": self.use_jax_vmap, } @property diff --git a/test_autofit/aggregator/summary_files/test_aggregate_images.py b/test_autofit/aggregator/summary_files/test_aggregate_images.py index d4334408b..1906aa522 100644 --- a/test_autofit/aggregator/summary_files/test_aggregate_images.py +++ b/test_autofit/aggregator/summary_files/test_aggregate_images.py @@ -6,8 +6,27 @@ from PIL import Image from autofit.aggregator import Aggregator -from autofit.aggregator.summary.aggregate_images import AggregateImages, SubplotFit - +from autofit.aggregator.summary.aggregate_images import AggregateImages + +class SubplotFit(Enum): + """ + The subplots that can be extracted from the subplot_fit image. + + The values correspond to the position of the subplot in the 4x3 grid. 
+ """ + + Data = (0, 0) + DataSourceScaled = (1, 0) + SignalToNoiseMap = (2, 0) + ModelData = (3, 0) + LensLightModelData = (0, 1) + LensLightSubtractedImage = (1, 1) + SourceModelData = (2, 1) + SourcePlaneZoomed = (3, 1) + NormalizedResidualMap = (0, 2) + NormalizedResidualMapOneSigma = (1, 2) + ChiSquaredMap = (2, 2) + SourcePlaneNoZoom = (3, 2) @pytest.fixture def aggregate(aggregator): @@ -167,7 +186,7 @@ class SubplotFit(Enum): SubplotFit.Data, ] ) - assert result.size == (61, 120) + assert result.size == (244, 362) def test_bad_aggregator(): diff --git a/test_autofit/graphical/global/test_hierarchical.py b/test_autofit/graphical/global/test_hierarchical.py index 5e90d3704..6f22fa874 100644 --- a/test_autofit/graphical/global/test_hierarchical.py +++ b/test_autofit/graphical/global/test_hierarchical.py @@ -53,10 +53,7 @@ def test_model_info(model): 1 drawn_prior UniformPrior [1], lower_limit = 0.0, upper_limit = 1.0 2 - 3 - one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0 -factor - include_prior_factors True - use_jax False""" + one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0""" ) diff --git a/test_autofit/graphical/hierarchical/test_optimise.py b/test_autofit/graphical/hierarchical/test_optimise.py index 12caf64f2..f500029ae 100644 --- a/test_autofit/graphical/hierarchical/test_optimise.py +++ b/test_autofit/graphical/hierarchical/test_optimise.py @@ -11,13 +11,6 @@ def make_factor(hierarchical_factor): def test_optimise(factor): search = af.DynestyStatic(maxcall=100, dynamic_delta=False, delta=0.1,) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - - _, status = search.optimise( factor.mean_field_approximation().factor_approximation(factor) ) diff --git a/test_autofit/mapper/test_model_mapping_expanded.py b/test_autofit/mapper/test_model_mapping_expanded.py new file mode 100644 index 000000000..8be6fd534 --- /dev/null +++ 
b/test_autofit/mapper/test_model_mapping_expanded.py @@ -0,0 +1,631 @@ +""" +Expanded tests for the model mapping API, covering gaps identified in: +- Collection composition and instance creation +- Shared (linked) priors across model types +- Direct use of instance_for_arguments with argument dicts +- Model tree navigation (object_for_path, path_for_prior, name_for_prior) +- Edge cases (empty models, deeply nested models, single-parameter models) +- Model subsetting (with_paths, without_paths) +- Freezing behavior +- Assertion checking +- from_instance round-trips +- mapper_from_prior_arguments and mapper_from_partial_prior_arguments +""" +import copy + +import numpy as np +import pytest + +import autofit as af +from autofit import exc +from autofit.mapper.prior.abstract import Prior + + +# --------------------------------------------------------------------------- +# Collection: composition, nesting, instance creation, iteration +# --------------------------------------------------------------------------- +class TestCollectionComposition: + def test_collection_from_dict(self): + model = af.Collection( + one=af.Model(af.m.MockClassx2), + two=af.Model(af.m.MockClassx2), + ) + assert model.prior_count == 4 + + def test_collection_from_list(self): + model = af.Collection([af.m.MockClassx2, af.m.MockClassx2]) + assert model.prior_count == 4 + + def test_collection_from_generator(self): + model = af.Collection(af.Model(af.m.MockClassx2) for _ in range(3)) + assert model.prior_count == 6 + + def test_nested_collection(self): + inner = af.Collection(a=af.m.MockClassx2) + outer = af.Collection(inner=inner, extra=af.m.MockClassx2) + assert outer.prior_count == 4 + + def test_deeply_nested_collection(self): + model = af.Collection( + level1=af.Collection( + level2=af.Collection( + leaf=af.m.MockClassx2, + ) + ) + ) + assert model.prior_count == 2 + + def test_collection_instance_attribute_access(self): + model = af.Collection(gaussian=af.m.MockClassx2, 
exp=af.m.MockClassx2) + instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0]) + assert instance.gaussian.one == 1.0 + assert instance.gaussian.two == 2.0 + assert instance.exp.one == 3.0 + assert instance.exp.two == 4.0 + + def test_collection_instance_index_access(self): + model = af.Collection([af.m.MockClassx2, af.m.MockClassx2]) + instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0]) + assert instance[0].one == 1.0 + assert instance[1].one == 3.0 + + def test_collection_len(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert len(model) == 2 + + def test_collection_contains(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert "a" in model + assert "c" not in model + + def test_collection_items(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + keys = [k for k, v in model.items()] + assert "a" in keys + assert "b" in keys + + def test_collection_getitem_string(self): + model = af.Collection(a=af.m.MockClassx2) + assert isinstance(model["a"], af.Model) + + def test_collection_append(self): + model = af.Collection() + model.append(af.m.MockClassx2) + model.append(af.m.MockClassx2) + assert model.prior_count == 4 + + def test_collection_mixed_model_and_fixed(self): + """Collection with one free model and one fixed instance.""" + model = af.Collection( + free=af.Model(af.m.MockClassx2), + ) + assert model.prior_count == 2 + + def test_empty_collection(self): + model = af.Collection() + assert model.prior_count == 0 + + +# --------------------------------------------------------------------------- +# Shared (linked) priors +# --------------------------------------------------------------------------- +class TestSharedPriors: + def test_link_within_model(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.prior_count == 1 + instance = model.instance_from_vector([5.0]) + assert instance.one == instance.two == 5.0 + + def 
test_link_across_collection_children(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + model.a.one = model.b.one # Link a.one to b.one + assert model.prior_count == 3 # 4 - 1 shared + + def test_linked_priors_same_value(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + model.a.one = model.b.one + instance = model.instance_from_vector([10.0, 20.0, 30.0]) + assert instance.a.one == instance.b.one + + def test_link_reduces_unique_prior_count(self): + model = af.Model(af.m.MockClassx2) + original_count = len(model.unique_prior_tuples) + model.one = model.two + assert len(model.unique_prior_tuples) == original_count - 1 + + def test_linked_prior_identity(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.one is model.two + + +# --------------------------------------------------------------------------- +# instance_for_arguments (direct argument dict usage) +# --------------------------------------------------------------------------- +class TestInstanceForArguments: + def test_model_instance_for_arguments(self): + model = af.Model(af.m.MockClassx2) + args = {model.one: 10.0, model.two: 20.0} + instance = model.instance_for_arguments(args) + assert instance.one == 10.0 + assert instance.two == 20.0 + + def test_collection_instance_for_arguments(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + args = {} + for name, prior in model.prior_tuples_ordered_by_id: + args[prior] = 1.0 + instance = model.instance_for_arguments(args) + assert instance.a.one == 1.0 + assert instance.b.two == 1.0 + + def test_shared_prior_in_arguments(self): + """When priors are linked, only one entry is needed in the arguments dict.""" + model = af.Model(af.m.MockClassx2) + model.one = model.two + shared_prior = model.one + args = {shared_prior: 42.0} + instance = model.instance_for_arguments(args) + assert instance.one == 42.0 + assert instance.two == 42.0 + + def test_missing_prior_raises(self): + 
model = af.Model(af.m.MockClassx2) + args = {model.one: 10.0} # missing model.two + with pytest.raises(KeyError): + model.instance_for_arguments(args) + + +# --------------------------------------------------------------------------- +# Vector and unit vector mapping +# --------------------------------------------------------------------------- +class TestVectorMapping: + def test_instance_from_vector_basic(self): + model = af.Model(af.m.MockClassx2) + instance = model.instance_from_vector([3.0, 4.0]) + assert instance.one == 3.0 + assert instance.two == 4.0 + + def test_vector_length_mismatch_raises(self): + model = af.Model(af.m.MockClassx2) + with pytest.raises(AssertionError): + model.instance_from_vector([1.0]) + + def test_unit_vector_length_mismatch_raises(self): + model = af.Model(af.m.MockClassx2) + with pytest.raises(AssertionError): + model.instance_from_unit_vector([0.5]) + + def test_vector_from_unit_vector(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + physical = model.vector_from_unit_vector([0.0, 1.0]) + assert physical[0] == pytest.approx(0.0, abs=1e-6) + assert physical[1] == pytest.approx(10.0, abs=1e-6) + + def test_instance_from_prior_medians(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + instance = model.instance_from_prior_medians() + assert instance.one == pytest.approx(50.0) + assert instance.two == pytest.approx(50.0) + + def test_random_instance(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + instance = model.random_instance() + assert 0.0 <= instance.one <= 1.0 + assert 0.0 <= instance.two <= 1.0 + + +# 
# ---------------------------------------------------------------------------
# Model tree navigation
# ---------------------------------------------------------------------------
class TestModelTreeNavigation:
    """Walking the model tree via paths, names and prior tuples."""

    def test_object_for_path_child_model(self):
        """A one-element path resolves to the child Model."""
        model = af.Collection(g=af.Model(af.m.MockClassx2))
        child = model.object_for_path(("g",))
        assert isinstance(child, af.Model)

    def test_object_for_path_prior(self):
        """A two-element path resolves down to a Prior leaf."""
        model = af.Collection(g=af.Model(af.m.MockClassx2))
        leaf = model.object_for_path(("g", "one"))
        assert isinstance(leaf, Prior)

    def test_paths_matches_prior_count(self):
        """There is exactly one path per free prior."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
        assert len(model.paths) == model.prior_count

    def test_path_for_prior(self):
        """path_for_prior inverts attribute access on the tree."""
        model = af.Collection(g=af.Model(af.m.MockClassx2))
        target = model.g.one
        assert model.path_for_prior(target) == ("g", "one")

    def test_name_for_prior(self):
        """Prior names join the path components with underscores."""
        model = af.Collection(g=af.Model(af.m.MockClassx2))
        target = model.g.one
        assert model.name_for_prior(target) == "g_one"

    def test_path_instance_tuples_for_class(self):
        """Searching by class yields (path, instance) pairs for every Prior."""
        model = af.Collection(g=af.Model(af.m.MockClassx2))
        pairs = model.path_instance_tuples_for_class(Prior)
        found_paths = [pair[0] for pair in pairs]
        assert ("g", "one") in found_paths
        assert ("g", "two") in found_paths

    def test_deeply_nested_path(self):
        """Paths are recovered correctly through nested Collections."""
        leaf_model = af.Model(af.m.MockClassx2)
        inner = af.Collection(leaf=leaf_model)
        outer = af.Collection(branch=inner)

        target = outer.branch.leaf.one
        assert outer.path_for_prior(target) == ("branch", "leaf", "one")

    def test_direct_vs_recursive_prior_tuples(self):
        """Direct tuples see only this node; prior_tuples recurse."""
        model = af.Collection(a=af.m.MockClassx2)
        assert len(model.direct_prior_tuples) == 0  # Collection has no direct priors
        assert len(model.prior_tuples) == 2  # But has 2 recursive priors

    def test_direct_prior_model_tuples(self):
        """Each named child Model appears as one direct prior-model tuple."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
        assert len(model.direct_prior_model_tuples) == 2


# ---------------------------------------------------------------------------
# instance_from_path_arguments and instance_from_prior_name_arguments
# ---------------------------------------------------------------------------
class TestPathAndNameArguments:
    """Building instances keyed by paths or underscore-joined names."""

    def test_instance_from_path_arguments(self):
        """Path-keyed arguments set the corresponding leaf values."""
        model = af.Collection(g=af.m.MockClassx2)
        inst = model.instance_from_path_arguments(
            {("g", "one"): 10.0, ("g", "two"): 20.0}
        )
        assert inst.g.one == 10.0
        assert inst.g.two == 20.0

    def test_instance_from_prior_name_arguments(self):
        """Name-keyed arguments behave like their path equivalents."""
        model = af.Collection(g=af.m.MockClassx2)
        inst = model.instance_from_prior_name_arguments(
            {"g_one": 10.0, "g_two": 20.0}
        )
        assert inst.g.one == 10.0
        assert inst.g.two == 20.0


# ---------------------------------------------------------------------------
# Assertions
# ---------------------------------------------------------------------------
class TestAssertions:
    """Model-level assertions gate instance creation."""

    def test_assertion_passes(self):
        """A satisfied assertion does not interfere with mapping."""
        model = af.Model(af.m.MockClassx2)
        model.add_assertion(model.one > model.two)
        # one=10 > two=5 should pass
        inst = model.instance_from_vector([10.0, 5.0])
        assert inst.one == 10.0

    def test_assertion_fails(self):
        """A violated assertion raises FitException."""
        model = af.Model(af.m.MockClassx2)
        model.add_assertion(model.one > model.two)
        with pytest.raises(exc.FitException):
            model.instance_from_vector([1.0, 10.0])

    def test_ignore_assertions(self):
        """ignore_assertions=True suppresses the assertion check."""
        model = af.Model(af.m.MockClassx2)
        model.add_assertion(model.one > model.two)
        # Should not raise even though assertion fails
        inst = model.instance_from_vector([1.0, 10.0], ignore_assertions=True)
        assert inst.one == 1.0

    def test_multiple_assertions(self):
        """All registered assertions must hold; any failure raises."""
        model = af.Model(af.m.MockClassx4)
        model.add_assertion(model.one > model.two)
        model.add_assertion(model.three > model.four)
        # Both pass
        inst = model.instance_from_vector([10.0, 5.0, 10.0, 5.0])
        assert inst.one == 10.0
        # First fails
        with pytest.raises(exc.FitException):
            model.instance_from_vector([1.0, 10.0, 10.0, 5.0])

    def test_true_assertion_ignored(self):
        """Adding True as an assertion should be a no-op."""
        model = af.Model(af.m.MockClassx2)
        model.add_assertion(True)
        assert len(model.assertions) == 0


# ---------------------------------------------------------------------------
# Model subsetting (with_paths, without_paths)
# ---------------------------------------------------------------------------
class TestModelSubsetting:
    """Carving submodels out of a Collection by path or prefix."""

    def test_with_paths_single_child(self):
        """Keeping one of two children keeps that child's two priors."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
        kept = model.with_paths([("a",)])
        assert kept.prior_count == 2

    def test_without_paths_single_child(self):
        """Dropping one of two children leaves the other's two priors."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
        remaining = model.without_paths([("a",)])
        assert remaining.prior_count == 2

    def test_with_paths_specific_prior(self):
        """A leaf-level path keeps exactly one prior."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
        kept = model.with_paths([("a", "one")])
        assert kept.prior_count == 1

    def test_with_prefix(self):
        """with_prefix selects children whose names share the prefix."""
        model = af.Collection(ab_one=af.m.MockClassx2, cd_two=af.m.MockClassx2)
        kept = model.with_prefix("ab")
        assert kept.prior_count == 2


# ---------------------------------------------------------------------------
# Freezing behavior
# ---------------------------------------------------------------------------
class TestFreezing:
    """freeze()/unfreeze() toggle mutability without breaking mapping."""

    def test_freeze_prevents_modification(self):
        """Setting an attribute on a frozen model raises."""
        model = af.Model(af.m.MockClassx2)
        model.freeze()
        with pytest.raises(AssertionError):
            model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)

    def test_unfreeze_allows_modification(self):
        """unfreeze() restores normal attribute assignment."""
        model = af.Model(af.m.MockClassx2)
        model.freeze()
        model.unfreeze()
        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
        assert isinstance(model.one, af.UniformPrior)

    def test_frozen_model_still_creates_instances(self):
        """Freezing is about mutation, not mapping — instances still work."""
        model = af.Model(af.m.MockClassx2)
        model.freeze()
        inst = model.instance_from_vector([1.0, 2.0])
        assert inst.one == 1.0

    def test_freeze_propagates_to_children(self):
        """Freezing a Collection freezes its child models too."""
        model = af.Collection(a=af.m.MockClassx2)
        model.freeze()
        with pytest.raises(AssertionError):
            model.a.one = 1.0

    def test_cached_results_consistent(self):
        """Repeated queries of a frozen model return equal results."""
        model = af.Model(af.m.MockClassx2)
        model.freeze()
        first = model.prior_tuples_ordered_by_id
        second = model.prior_tuples_ordered_by_id
        assert first == second


# ---------------------------------------------------------------------------
# mapper_from_prior_arguments and related
# ---------------------------------------------------------------------------
class TestMapperFromPriorArguments:
    """Deriving new mappers by substituting priors."""

    def test_replace_all_priors(self):
        """Replacing every prior yields a mapper with the new priors."""
        model = af.Model(af.m.MockClassx2)
        replacement_one = af.GaussianPrior(mean=0.0, sigma=1.0)
        replacement_two = af.GaussianPrior(mean=5.0, sigma=2.0)
        remapped = model.mapper_from_prior_arguments(
            {model.one: replacement_one, model.two: replacement_two}
        )
        assert remapped.prior_count == 2
        assert isinstance(remapped.one, af.GaussianPrior)

    def test_partial_replacement(self):
        """Partial substitution leaves the untouched priors in place."""
        model = af.Model(af.m.MockClassx2)
        replacement = af.GaussianPrior(mean=0.0, sigma=1.0)
        remapped = model.mapper_from_partial_prior_arguments(
            {model.one: replacement}
        )
        assert remapped.prior_count == 2
        assert isinstance(remapped.one, af.GaussianPrior)
        # two should retain its original prior type
        assert remapped.two is not None

    def test_fix_via_mapper_from_prior_arguments(self):
        """Replacing a prior with a float effectively fixes that parameter."""
        model = af.Model(af.m.MockClassx2)
        remapped = model.mapper_from_prior_arguments(
            {model.one: 5.0, model.two: model.two}
        )
        assert remapped.prior_count == 1

    def test_with_limits(self):
        """with_limits re-bounds each prior while keeping the free count."""
        model = af.Model(af.m.MockClassx2)
        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
        bounded = model.with_limits([(10.0, 20.0), (30.0, 40.0)])
        assert bounded.prior_count == 2


# ---------------------------------------------------------------------------
# from_instance round trips
# ---------------------------------------------------------------------------
class TestFromInstance:
    """Lifting concrete instances back into (fixed) models."""

    def test_from_simple_instance(self):
        """A plain instance lifts to a model with no free parameters."""
        source = af.m.MockClassx2(1.0, 2.0)
        lifted = af.AbstractPriorModel.from_instance(source)
        assert lifted.prior_count == 0

    def test_from_instance_as_model(self):
        """as_model() frees every parameter of a lifted instance."""
        source = af.m.MockClassx2(1.0, 2.0)
        lifted = af.AbstractPriorModel.from_instance(source)
        freed = lifted.as_model()
        assert freed.prior_count == 2

    def test_from_instance_with_model_classes(self):
        """model_classes frees parameters of the listed classes on lift."""
        source = af.m.MockClassx2(1.0, 2.0)
        lifted = af.AbstractPriorModel.from_instance(
            source, model_classes=(af.m.MockClassx2,)
        )
        assert lifted.prior_count == 2

    def test_from_list_instance(self):
        """A list of instances lifts to a fully fixed model."""
        sources = [af.m.MockClassx2(1.0, 2.0), af.m.MockClassx2(3.0, 4.0)]
        lifted = af.AbstractPriorModel.from_instance(sources)
        assert lifted.prior_count == 0

    def test_from_dict_instance(self):
        """A dict of instances lifts to a fully fixed model."""
        sources = {
            "one": af.m.MockClassx2(1.0, 2.0),
            "two": af.m.MockClassx2(3.0, 4.0),
        }
        lifted = af.AbstractPriorModel.from_instance(sources)
        assert lifted.prior_count == 0


# ---------------------------------------------------------------------------
# Fixing parameters and Constant values
# ---------------------------------------------------------------------------
class TestFixedParameters:
    """Assigning floats to attributes removes them from the free vector."""

    def test_fix_reduces_prior_count(self):
        """Fixing one of two parameters leaves one free prior."""
        model = af.Model(af.m.MockClassx2)
        model.one = 5.0
        assert model.prior_count == 1

    def test_fixed_value_in_instance(self):
        """The fixed value appears verbatim; the vector fills the rest."""
        model = af.Model(af.m.MockClassx2)
        model.one = 5.0
        inst = model.instance_from_vector([10.0])
        assert inst.one == 5.0
        assert inst.two == 10.0

    def test_fix_all_parameters(self):
        """A fully fixed model maps from the empty vector."""
        model = af.Model(af.m.MockClassx2)
        model.one = 5.0
        model.two = 10.0
        assert model.prior_count == 0
        inst = model.instance_from_vector([])
        assert inst.one == 5.0
        assert inst.two == 10.0


# ---------------------------------------------------------------------------
# take_attributes (prior passing)
# ---------------------------------------------------------------------------
class TestTakeAttributes:
    """Copying attribute values/priors from another object."""

    def test_take_from_instance(self):
        """Taking from a concrete instance fixes every parameter."""
        model = af.Model(af.m.MockClassx2)
        donor = af.m.MockClassx2(10.0, 20.0)
        model.take_attributes(donor)
        assert model.prior_count == 0

    def test_take_from_model(self):
        """Taking attributes from another model copies priors."""
        donor = af.Model(af.m.MockClassx2)
        donor.one = af.GaussianPrior(mean=5.0, sigma=1.0)
        donor.two = af.GaussianPrior(mean=10.0, sigma=2.0)

        receiver = af.Model(af.m.MockClassx2)
        receiver.take_attributes(donor)
        assert isinstance(receiver.one, af.GaussianPrior)


# ---------------------------------------------------------------------------
# Serialization (dict / from_dict)
# ---------------------------------------------------------------------------
class TestSerialization:
    """dict()/from_dict() round trips preserve model structure."""

    def test_model_dict_roundtrip(self):
        """A plain Model survives a dict round trip."""
        model = af.Model(af.m.MockClassx2)
        payload = model.dict()
        restored = af.AbstractPriorModel.from_dict(payload)
        assert restored.prior_count == model.prior_count

    def test_collection_dict_roundtrip(self):
        """A Collection of models survives a dict round trip."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
        payload = model.dict()
        restored = af.AbstractPriorModel.from_dict(payload)
        assert restored.prior_count == model.prior_count

    def test_fixed_parameter_survives_roundtrip(self):
        """A fixed parameter stays fixed after deserialization."""
        model = af.Model(af.m.MockClassx2)
        model.one = 5.0
        payload = model.dict()
        restored = af.AbstractPriorModel.from_dict(payload)
        assert restored.prior_count == 1

    def test_linked_prior_survives_roundtrip(self):
        """Two attributes sharing one prior still share after reload."""
        model = af.Model(af.m.MockClassx2)
        model.one = model.two
        assert model.prior_count == 1
        payload = model.dict()
        restored = af.AbstractPriorModel.from_dict(payload)
        assert restored.prior_count == 1


# ---------------------------------------------------------------------------
# Log prior computation
# ---------------------------------------------------------------------------
class TestLogPrior:
    """log_prior_list_from_vector behavior in and out of bounds."""

    def test_log_prior_within_bounds(self):
        """In-bounds values give finite log priors."""
        model = af.Model(af.m.MockClassx2)
        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
        log_priors = model.log_prior_list_from_vector([5.0, 5.0])
        assert all(np.isfinite(value) for value in log_priors)

    def test_log_prior_outside_bounds(self):
        """An out-of-bounds value is penalized relative to an in-bounds one."""
        model = af.Model(af.m.MockClassx2)
        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
        log_priors = model.log_prior_list_from_vector([15.0, 5.0])
        # Out-of-bounds value should have a lower (or zero) log prior than in-bounds
        assert log_priors[0] <= log_priors[1]


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    """Assorted boundary behavior of the mapper API."""

    def test_single_parameter_model(self):
        """A model with a single free parameter using explicit prior."""
        model = af.Model(af.m.MockClassx2)
        model.two = 5.0  # Fix one parameter
        assert model.prior_count == 1
        inst = model.instance_from_vector([42.0])
        assert inst.one == 42.0
        assert inst.two == 5.0

    def test_model_copy_preserves_priors(self):
        """copy() keeps the free-parameter count with fresh prior objects."""
        model = af.Model(af.m.MockClassx2)
        duplicate = model.copy()
        assert duplicate.prior_count == model.prior_count
        # Priors are independent copies (different objects)
        assert duplicate.one is not model.one

    def test_model_copy_linked_priors_independent(self):
        """Copying a model with linked priors preserves the link in the copy."""
        model = af.Model(af.m.MockClassx2)
        model.one = model.two
        assert model.prior_count == 1
        duplicate = model.copy()
        assert duplicate.prior_count == 1
        # The copy's internal link is preserved
        assert duplicate.one is duplicate.two

    def test_prior_ordering_is_deterministic(self):
        """prior_tuples_ordered_by_id should be stable across calls."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
        first = [(name, prior.id) for name, prior in model.prior_tuples_ordered_by_id]
        second = [(name, prior.id) for name, prior in model.prior_tuples_ordered_by_id]
        assert first == second

    def test_prior_count_equals_total_free_parameters(self):
        """prior_count and total_free_parameters agree."""
        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx4)
        assert model.prior_count == model.total_free_parameters

    def test_has_model(self):
        """has_model detects child Models of the given class only."""
        model = af.Collection(a=af.Model(af.m.MockClassx2))
        assert model.has_model(af.m.MockClassx2)
        assert not model.has_model(af.m.MockClassx4)

    def test_has_instance(self):
        """has_instance detects contained objects of the given type."""
        model = af.Model(af.m.MockClassx2)
        assert model.has_instance(Prior)
        assert not model.has_instance(af.m.MockClassx4)