From bff1343cd05adb62b2c8b1b1e82be338fa65d206 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:27:53 +0000 Subject: [PATCH 01/35] log_prior_from_value in TruncatedNormalMessage --- .../summary/aggregate_csv/column.py | 30 ++++++++++++++----- autofit/mapper/prior_model/abstract.py | 3 +- autofit/messages/truncated_normal.py | 8 ++--- autofit/non_linear/fitness.py | 2 +- autofit/non_linear/search/mle/bfgs/search.py | 2 +- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/autofit/aggregator/summary/aggregate_csv/column.py b/autofit/aggregator/summary/aggregate_csv/column.py index 9635c1bde..77eedafde 100644 --- a/autofit/aggregator/summary/aggregate_csv/column.py +++ b/autofit/aggregator/summary/aggregate_csv/column.py @@ -64,20 +64,34 @@ def value(self, row: "Row"): result = {} if ValueType.Median in self.value_types: - result[""] = row.median_pdf_sample_kwargs[self.path] + try: + result[""] = row.median_pdf_sample_kwargs[self.path] + except KeyError: + result[""] = None if ValueType.MaxLogLikelihood in self.value_types: - result["max_lh"] = row.max_likelihood_kwargs[self.path] + try: + result["max_lh"] = row.max_likelihood_kwargs[self.path] + except KeyError: + result["max_lh"] = None if ValueType.ValuesAt1Sigma in self.value_types: - lower, upper = row.values_at_sigma_1_kwargs[self.path] - result["lower_1_sigma"] = lower - result["upper_1_sigma"] = upper + try: + lower, upper = row.values_at_sigma_1_kwargs[self.path] + result["lower_1_sigma"] = lower + result["upper_1_sigma"] = upper + except KeyError: + result["lower_1_sigma"] = None + result["upper_1_sigma"] = None if ValueType.ValuesAt3Sigma in self.value_types: - lower, upper = row.values_at_sigma_3_kwargs[self.path] - result["lower_3_sigma"] = lower - result["upper_3_sigma"] = upper + try: + lower, upper = row.values_at_sigma_3_kwargs[self.path] + result["lower_3_sigma"] = lower + result["upper_3_sigma"] = upper + except KeyError: + result["lower_3_sigma"] = None + 
result["upper_3_sigma"] = None return result diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index b41ea3c58..6e7778ad1 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1078,6 +1078,7 @@ def instance_from_prior_medians(self, ignore_assertions : bool = False): def log_prior_list_from_vector( self, vector: [float], + xp=np, ): """ Compute the log priors of every parameter in a vector, using the Prior of every parameter. @@ -1094,7 +1095,7 @@ def log_prior_list_from_vector( return list( map( lambda prior_tuple, value: prior_tuple.prior.log_prior_from_value( - value=value + value=value, xp=np ), self.prior_tuples_ordered_by_id, vector, diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 3d1614584..7329147a4 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -422,7 +422,7 @@ def value_for(self, unit: float) -> float: x_standard = norm.ppf(truncated_cdf) return self.mean + self.sigma * x_standard - def log_prior_from_value(self, value: float) -> float: + def log_prior_from_value(self, value: float, xp=np) -> float: """ Compute the log prior probability of a given physical value under this truncated Gaussian prior. 
@@ -446,11 +446,11 @@ def log_prior_from_value(self, value: float) -> float: Z = norm.cdf(b) - norm.cdf(a) z = (value -self.mean) / self.sigma - log_pdf = -0.5 * z ** 2 - np.log(self.sigma) - 0.5 * np.log(2 * np.pi) - log_trunc_pdf = log_pdf - np.log(Z) + log_pdf = -0.5 * z ** 2 - xp.log(self.sigma) - 0.5 * xp.log(2 * xp.pi) + log_trunc_pdf = log_pdf - xp.log(Z) in_bounds = (self.lower_limit <= value) & (value <= self.upper_limit) - return np.where(in_bounds, log_trunc_pdf, -np.inf) + return xp.where(in_bounds, log_trunc_pdf, -xp.inf) def __str__(self): """ diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index be0e63403..be9d916c6 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -179,7 +179,7 @@ def call(self, parameters): figure_of_merit = log_likelihood else: # Ensure prior list is compatible with JAX (must return a JAX array, not list) - log_prior_array = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters)) + log_prior_array = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) figure_of_merit = log_likelihood + self._xp.sum(log_prior_array) # Convert to chi-squared scale if requested diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 0ef3a5b34..7b5c389e5 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -145,7 +145,7 @@ def _fit( config_dict_options["maxiter"] = iterations search_internal = optimize.minimize( - fun=fitness.call_wrap, + fun=fitness._jit, x0=x0, method=self.method, options=config_dict_options, From be2dbd0cc49b143db3c53c2fdd6366f50722e153 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:29:04 +0000 Subject: [PATCH 02/35] log_prior_from_value in UniformPrior --- autofit/mapper/prior/uniform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autofit/mapper/prior/uniform.py 
b/autofit/mapper/prior/uniform.py index 08baed5bb..af04f23d5 100644 --- a/autofit/mapper/prior/uniform.py +++ b/autofit/mapper/prior/uniform.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Optional, Tuple from autofit.messages.normal import UniformNormalMessage @@ -128,7 +129,7 @@ def value_for(self, unit: float) -> float: round(super().value_for(unit), 14) ) - def log_prior_from_value(self, value): + def log_prior_from_value(self, value, xp=np): """ Returns the log prior of a physical value, so the log likelihood of a model evaluation can be converted to a posterior as log_prior + log_likelihood. From a00bcff1cf270beee4d429179db16a4f0976bbe3 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:30:50 +0000 Subject: [PATCH 03/35] remaining def log_prior_from_value --- autofit/mapper/prior/log_gaussian.py | 2 +- autofit/mapper/prior/log_uniform.py | 2 +- autofit/messages/normal.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autofit/mapper/prior/log_gaussian.py b/autofit/mapper/prior/log_gaussian.py index 6fa458950..d8f2a85b7 100644 --- a/autofit/mapper/prior/log_gaussian.py +++ b/autofit/mapper/prior/log_gaussian.py @@ -133,7 +133,7 @@ def value_for(self, unit: float) -> float: def parameter_string(self) -> str: return f"mean = {self.mean}, sigma = {self.sigma}" - def log_prior_from_value(self, value): + def log_prior_from_value(self, value, xp=np): if value <= 0: return float("-inf") diff --git a/autofit/mapper/prior/log_uniform.py b/autofit/mapper/prior/log_uniform.py index 63a9063b2..434af5b08 100644 --- a/autofit/mapper/prior/log_uniform.py +++ b/autofit/mapper/prior/log_uniform.py @@ -110,7 +110,7 @@ def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogUniformPrior __identifier_fields__ = ("lower_limit", "upper_limit") - def log_prior_from_value(self, value) -> float: + def log_prior_from_value(self, value, xp=np) -> float: """ Returns the log prior of a physical value, so the log likelihood of a 
model evaluation can be converted to a posterior as log_prior + log_likelihood. diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 9ff12ef3a..dff47e773 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -401,7 +401,7 @@ def value_for(self, unit: float) -> float: return self.mean + (self.sigma * np.sqrt(2) * inv) - def log_prior_from_value(self, value: float) -> float: + def log_prior_from_value(self, value: float, xp=np) -> float: """ Compute the log prior probability of a given physical value under this Gaussian prior. From 4874c6575137ccb37080e53fbb733adb0a22598c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:46:05 +0000 Subject: [PATCH 04/35] fix xp pass --- autofit/mapper/prior_model/abstract.py | 2 +- autofit/messages/truncated_normal.py | 19 ++++++++++++++++--- autofit/non_linear/search/mle/bfgs/search.py | 3 ++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 6e7778ad1..5310315a5 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1095,7 +1095,7 @@ def log_prior_list_from_vector( return list( map( lambda prior_tuple, value: prior_tuple.prior.log_prior_from_value( - value=value, xp=np + value=value, xp=xp ), self.prior_tuples_ordered_by_id, vector, diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 7329147a4..07a34464c 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -439,17 +439,30 @@ def log_prior_from_value(self, value: float, xp=np) -> float: ------- The log prior probability of the given value, or -inf if outside truncation bounds. 
""" - from scipy.stats import norm + if xp.__name__.startswith("jax"): + import jax.scipy.stats as jstats + norm = jstats.norm + else: + from scipy.stats import norm + + # Normalization term (truncation) a = (self.lower_limit - self.mean) / self.sigma b = (self.upper_limit - self.mean) / self.sigma Z = norm.cdf(b) - norm.cdf(a) - z = (value -self.mean) / self.sigma - log_pdf = -0.5 * z ** 2 - xp.log(self.sigma) - 0.5 * xp.log(2 * xp.pi) + # Log pdf + z = (value - self.mean) / self.sigma + log_pdf = ( + -0.5 * z ** 2 + - xp.log(self.sigma) + - 0.5 * xp.log(2.0 * xp.pi) + ) log_trunc_pdf = log_pdf - xp.log(Z) + # Truncation mask (must be xp.where for JAX) in_bounds = (self.lower_limit <= value) & (value <= self.upper_limit) + return xp.where(in_bounds, log_trunc_pdf, -xp.inf) def __str__(self): diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 7b5c389e5..9804f83b9 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -92,6 +92,7 @@ def _fit( fom_is_log_likelihood=False, resample_figure_of_merit=-np.inf, convert_to_chi_squared=True, + use_jax_vmap=True, store_history=self.should_plot_start_point ) @@ -145,7 +146,7 @@ def _fit( config_dict_options["maxiter"] = iterations search_internal = optimize.minimize( - fun=fitness._jit, + fun=fitness.call_wrap, x0=x0, method=self.method, options=config_dict_options, From 0629f55376fd08b78ac31b350c8397df6b9436b4 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:58:36 +0000 Subject: [PATCH 05/35] Put xp code throughout messages --- autofit/mapper/prior/arithmetic/assertion.py | 4 +++- autofit/mapper/prior/arithmetic/compound.py | 10 ++++++++++ autofit/mapper/prior/gaussian.py | 3 +++ autofit/mapper/prior_model/abstract.py | 6 +++++- autofit/mapper/prior_model/array.py | 1 + autofit/mapper/prior_model/collection.py | 1 + autofit/mapper/prior_model/prior_model.py | 1 + autofit/messages/normal.py | 
17 +++++++++++++++-- autofit/non_linear/fitness.py | 2 +- autofit/non_linear/search/mle/bfgs/search.py | 3 +++ 10 files changed, 43 insertions(+), 5 deletions(-) diff --git a/autofit/mapper/prior/arithmetic/assertion.py b/autofit/mapper/prior/arithmetic/assertion.py index 2e7315d2b..5d9315e43 100644 --- a/autofit/mapper/prior/arithmetic/assertion.py +++ b/autofit/mapper/prior/arithmetic/assertion.py @@ -24,7 +24,7 @@ def __le__(self, other): class GreaterThanLessThanAssertion(ComparisonAssertion): - def _instance_for_arguments(self, arguments, ignore_assertions=False): + def _instance_for_arguments(self, arguments, ignore_assertions=False, xp=np): """ Assert that the value in the dictionary associated with the lower prior is lower than the value associated with the greater prior. @@ -55,6 +55,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): """ Assert that the value in the dictionary associated with the lower @@ -90,6 +91,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.assertion_1.instance_for_arguments( arguments, diff --git a/autofit/mapper/prior/arithmetic/compound.py b/autofit/mapper/prior/arithmetic/compound.py index 4951340a1..ce9d1b910 100644 --- a/autofit/mapper/prior/arithmetic/compound.py +++ b/autofit/mapper/prior/arithmetic/compound.py @@ -212,6 +212,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -237,6 +238,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -256,6 +258,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -275,6 +278,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -294,6 +298,7 @@ def _instance_for_arguments( self, 
arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -313,6 +318,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -396,6 +402,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return -self.prior.instance_for_arguments( arguments, @@ -412,6 +419,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return abs( self.prior.instance_for_arguments( @@ -430,6 +438,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return np.log( self.prior.instance_for_arguments( @@ -448,6 +457,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return np.log10( self.prior.instance_for_arguments( diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index a4bcc5bb6..0ee8e568d 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Optional from autofit.messages.normal import NormalMessage @@ -13,6 +14,7 @@ def __init__( mean: float, sigma: float, id_: Optional[int] = None, + _xp=np ): """ A Gaussian prior defined by a normal distribution. 
@@ -50,6 +52,7 @@ def __init__( message=NormalMessage( mean=mean, sigma=sigma, + xp=_xp ), id_=id_, ) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 5310315a5..8c134687e 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -745,7 +745,7 @@ def physical_values_from_prior_medians(self): """ return self.vector_from_unit_vector([0.5] * len(self.unique_prior_tuples)) - def instance_from_vector(self, vector, ignore_assertions: bool = False): + def instance_from_vector(self, vector, ignore_assertions: bool = False, xp=np): """ Returns a ModelInstance, which has an attribute and class instance corresponding to every `Model` attributed to this instance. @@ -778,6 +778,7 @@ def instance_from_vector(self, vector, ignore_assertions: bool = False): return self.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) def has(self, cls: Union[Type, Tuple[Type, ...]]) -> bool: @@ -1309,6 +1310,7 @@ def _instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ): raise NotImplementedError() @@ -1316,6 +1318,7 @@ def instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ): """ Returns an instance of the model for a set of arguments @@ -1339,6 +1342,7 @@ def instance_for_arguments( return self._instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xpx=p ) def path_for_name(self, name: str) -> Tuple[str, ...]: diff --git a/autofit/mapper/prior_model/array.py b/autofit/mapper/prior_model/array.py index 7952b2743..317680818 100644 --- a/autofit/mapper/prior_model/array.py +++ b/autofit/mapper/prior_model/array.py @@ -57,6 +57,7 @@ def _instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ) -> np.ndarray: """ Create an array where the prior at each index is replaced with the diff --git 
a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py index 5d39dcdcd..36f7a14d8 100644 --- a/autofit/mapper/prior_model/collection.py +++ b/autofit/mapper/prior_model/collection.py @@ -208,6 +208,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): """ Parameters diff --git a/autofit/mapper/prior_model/prior_model.py b/autofit/mapper/prior_model/prior_model.py index cbf1cb285..778a15fdc 100644 --- a/autofit/mapper/prior_model/prior_model.py +++ b/autofit/mapper/prior_model/prior_model.py @@ -459,6 +459,7 @@ def _instance_for_arguments( self, arguments: {ModelObject: object}, ignore_assertions=False, + xp=np, ): """ Returns an instance of the associated class for a set of arguments diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index dff47e773..248e66c6f 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -23,6 +23,19 @@ def is_nan(value): is_nan_ = is_nan_.all() return is_nan_ +def assert_sigma_non_negative(sigma, xp=np): + sigma_arr = xp.asarray(sigma) + is_negative = xp.any(sigma_arr < 0) + + # Convert to Python bool safely: + try: + flag = bool(is_negative) + except Exception: + # JAX tracers need explicit .item() + flag = bool(is_negative.item()) + + if flag: + raise exc.MessageException("Sigma cannot be negative") class NormalMessage(AbstractMessage): @cached_property @@ -52,6 +65,7 @@ def __init__( sigma : Union[float, np.ndarray], log_norm : Optional[float] = 0.0, id_ : Optional[Hashable] = None, + xp=np ): """ A Gaussian (Normal) message representing a probability distribution over a continuous variable. @@ -73,8 +87,7 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. 
""" - if (np.array(sigma) < 0).any(): - raise exc.MessageException("Sigma cannot be negative") + assert_sigma_non_negative(sigma, xp=xp) super().__init__( mean, diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index be9d916c6..f0de611de 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -317,7 +317,7 @@ def manage_quick_update(self, parameters, log_likelihood): logger.info("Performing quick update of maximum log likelihood fit image and model.results") - instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters) + instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters, xp=self._xp) try: self.analysis.perform_quick_update(self.paths, instance) diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 9804f83b9..32acb089a 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -227,6 +227,9 @@ def samples_via_internal_from( weight_list = len(log_likelihood_list) * [1.0] + print(log_likelihood_list) + print(parameter_lists) + sample_list = Sample.from_lists( model=model, parameter_lists=parameter_lists, From 15d08785799ca491c257343098a8e8b672e158f8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 13:56:45 +0000 Subject: [PATCH 06/35] lots of hjorrible hacks --- autofit/mapper/prior/arithmetic/assertion.py | 1 + autofit/mapper/prior/gaussian.py | 3 +- autofit/mapper/prior_model/abstract.py | 2 +- autofit/mapper/prior_model/collection.py | 3 ++ autofit/mapper/prior_model/prior_model.py | 33 ++------------------ autofit/messages/normal.py | 16 +++++++++- autofit/non_linear/fitness.py | 8 ++--- autofit/non_linear/search/mle/bfgs/search.py | 7 ++--- 8 files changed, 31 insertions(+), 42 deletions(-) diff --git a/autofit/mapper/prior/arithmetic/assertion.py b/autofit/mapper/prior/arithmetic/assertion.py index 5d9315e43..b23a8c7a4 100644 
--- a/autofit/mapper/prior/arithmetic/assertion.py +++ b/autofit/mapper/prior/arithmetic/assertion.py @@ -1,4 +1,5 @@ from abc import ABC +import numpy as np from typing import Optional, Dict from autofit.mapper.prior.arithmetic.compound import CompoundPrior, Compound diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index 0ee8e568d..fe74b6016 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -14,7 +14,6 @@ def __init__( mean: float, sigma: float, id_: Optional[int] = None, - _xp=np ): """ A Gaussian prior defined by a normal distribution. @@ -48,11 +47,11 @@ def __init__( >>> prior = GaussianPrior(mean=1.0, sigma=2.0) >>> physical_value = prior.value_for(unit=0.5) # Returns ~1.0 (mean) """ + super().__init__( message=NormalMessage( mean=mean, sigma=sigma, - xp=_xp ), id_=id_, ) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 8c134687e..f6b7f3939 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1342,7 +1342,7 @@ def instance_for_arguments( return self._instance_for_arguments( arguments, ignore_assertions=ignore_assertions, - xpx=p + xp=xp ) def path_for_name(self, name: str) -> Tuple[str, ...]: diff --git a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py index 36f7a14d8..13073c10e 100644 --- a/autofit/mapper/prior_model/collection.py +++ b/autofit/mapper/prior_model/collection.py @@ -1,3 +1,5 @@ +import numpy as np + from collections.abc import Iterable from autofit.mapper.model import ModelInstance, assert_not_frozen @@ -229,6 +231,7 @@ def _instance_for_arguments( value = value.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) elif isinstance(value, Prior): value = arguments[value] diff --git a/autofit/mapper/prior_model/prior_model.py b/autofit/mapper/prior_model/prior_model.py index 778a15fdc..756899185 100644 --- 
a/autofit/mapper/prior_model/prior_model.py +++ b/autofit/mapper/prior_model/prior_model.py @@ -2,6 +2,7 @@ import copy import inspect import logging +import numpy as np import typing from typing import * @@ -420,36 +421,6 @@ def __getattr__(self, item): self.__getattribute__(item) - # def __getattr__(self, item): - # - # try: - # if ( - # "_" in item - # and item not in ("_is_frozen", "tuple_prior_tuples") - # and not item.startswith("_") - # ): - # return getattr( - # [v for k, v in self.tuple_prior_tuples if item.split("_")[0] == k][ - # 0 - # ], - # item, - # ) - # - # except IndexError: - # pass - # - # try: - # return getattr( - # self.instance_for_arguments( - # {prior: prior for prior in self.priors}, - # ), - # item, - # ) - # except (AttributeError, TypeError): - # pass - # - # self.__getattribute__(item) - @property def is_deferred_arguments(self): return len(self.direct_deferred_tuples) > 0 @@ -491,6 +462,7 @@ def _instance_for_arguments( ] = prior_model.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) prior_arguments = dict() @@ -533,6 +505,7 @@ def _instance_for_arguments( value = value.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) elif isinstance(value, Constant): value = value.value diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 248e66c6f..c5c36f1d6 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -24,6 +24,7 @@ def is_nan(value): return is_nan_ def assert_sigma_non_negative(sigma, xp=np): + sigma_arr = xp.asarray(sigma) is_negative = xp.any(sigma_arr < 0) @@ -65,7 +66,6 @@ def __init__( sigma : Union[float, np.ndarray], log_norm : Optional[float] = 0.0, id_ : Optional[Hashable] = None, - xp=np ): """ A Gaussian (Normal) message representing a probability distribution over a continuous variable. @@ -87,6 +87,20 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. 
""" + if isinstance(mean, (np.ndarray, float, int)): + xp = np + else: + import jax.numpy as jnp + xp = jnp + + + print(type(mean)) + print(type(mean)) + print(xp) + print(xp) + print(xp) + print(xp) + assert_sigma_non_negative(sigma, xp=xp) super().__init__( diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index f0de611de..e3a7ab0b0 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -157,7 +157,7 @@ def call(self, parameters): The figure of merit returned to the non-linear search, which is either the log likelihood or log posterior. """ # Get instance from model - instance = self.model.instance_from_vector(vector=parameters) + instance = self.model.instance_from_vector(vector=parameters, xp=self._xp) if self._xp.__name__.startswith("jax"): @@ -227,7 +227,7 @@ def call_wrap(self, parameters): if self.fom_is_log_likelihood: log_likelihood = figure_of_merit else: - log_prior_list = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters)) + log_prior_list = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) log_likelihood = figure_of_merit - self._xp.sum(log_prior_list) self.manage_quick_update(parameters=parameters, log_likelihood=log_likelihood) @@ -237,8 +237,8 @@ def call_wrap(self, parameters): if self.store_history: - self.parameters_history_list.append(parameters) - self.log_likelihood_history_list.append(log_likelihood) + self.parameters_history_list.append(np.array(parameters)) + self.log_likelihood_history_list.append(np.array(log_likelihood)) return figure_of_merit diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 32acb089a..42a849235 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -92,7 +92,6 @@ def _fit( fom_is_log_likelihood=False, resample_figure_of_merit=-np.inf, convert_to_chi_squared=True, - use_jax_vmap=True, 
store_history=self.should_plot_start_point ) @@ -146,7 +145,7 @@ def _fit( config_dict_options["maxiter"] = iterations search_internal = optimize.minimize( - fun=fitness.call_wrap, + fun=fitness._jit, x0=x0, method=self.method, options=config_dict_options, @@ -209,7 +208,6 @@ def samples_via_internal_from( x0 = search_internal.x total_iterations = search_internal.nit - if self.should_plot_start_point: parameter_lists = search_internal.parameters_history_list @@ -227,8 +225,9 @@ def samples_via_internal_from( weight_list = len(log_likelihood_list) * [1.0] - print(log_likelihood_list) print(parameter_lists) + print(log_likelihood_list) + print(log_prior_list) sample_list = Sample.from_lists( model=model, From 601c5f1fb74ff510b99782236901b52911315562 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 14:09:40 +0000 Subject: [PATCH 07/35] GaussianPrior works --- autofit/messages/abstract.py | 11 ++++++++-- autofit/messages/interface.py | 5 +++++ autofit/messages/normal.py | 41 +++++++++++++++++------------------ 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index d3b921108..2148c66e5 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -45,17 +45,24 @@ def __init__( lower_limit=-math.inf, upper_limit=math.inf, id_=None, + _xp=np ): + xp=_xp + self.lower_limit = float(lower_limit) self.upper_limit = float(upper_limit) self.id = next(self.ids) if id_ is None else id_ self.log_norm = log_norm - self._broadcast = np.broadcast(*parameters) + + if xp is np: + self._broadcast = np.broadcast(*parameters) + else: + self._broadcast = _xp.broadcast_arrays(*parameters) if self.shape: - self.parameters = tuple(np.asanyarray(p) for p in parameters) + self.parameters = tuple(xp.aarray(p) for p in parameters) else: self.parameters = tuple(parameters) diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 28d46d2ab..daff57479 100644 --- 
a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -23,6 +23,11 @@ def broadcast(self): @property def shape(self) -> Tuple[int, ...]: + + # JAX behaviour + if isinstance(self.broadcast, list): + return () + return self.broadcast.shape @property diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index c5c36f1d6..84bd3f2ff 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -25,18 +25,24 @@ def is_nan(value): def assert_sigma_non_negative(sigma, xp=np): - sigma_arr = xp.asarray(sigma) - is_negative = xp.any(sigma_arr < 0) - - # Convert to Python bool safely: - try: - flag = bool(is_negative) - except Exception: - # JAX tracers need explicit .item() - flag = bool(is_negative.item()) - - if flag: - raise exc.MessageException("Sigma cannot be negative") + is_negative = sigma < 0 + + if xp.__name__.startswith("jax"): + import jax + # JAX path: cannot convert to Python bool + # Raise using JAX control flow: + return jax.lax.cond( + is_negative, + lambda _: (_ for _ in ()).throw( + ValueError("Sigma cannot be negative") + ), + lambda _: None, + operand=None, + ) + else: + # NumPy path: normal boolean works + if bool(is_negative): + raise ValueError("Sigma cannot be negative") class NormalMessage(AbstractMessage): @cached_property @@ -93,21 +99,14 @@ def __init__( import jax.numpy as jnp xp = jnp - - print(type(mean)) - print(type(mean)) - print(xp) - print(xp) - print(xp) - print(xp) - - assert_sigma_non_negative(sigma, xp=xp) + # assert_sigma_non_negative(sigma, xp=xp) super().__init__( mean, sigma, log_norm=log_norm, id_=id_, + _xp=xp ) self.mean, self.sigma = self.parameters From e5dda2279d2d72e94ef9269e25b30f2a008580c9 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 14:24:43 +0000 Subject: [PATCH 08/35] various message functions now use xp --- .../declarative/factor/hierarchical.py | 7 +++++-- autofit/messages/abstract.py | 4 ++-- autofit/messages/beta.py | 10 +++++-----
autofit/messages/composed_transform.py | 7 +++---- autofit/messages/fixed.py | 5 ++--- autofit/messages/gamma.py | 9 ++++----- autofit/messages/interface.py | 19 +++++++++---------- autofit/messages/normal.py | 18 ++++++++---------- autofit/messages/truncated_normal.py | 14 ++++++-------- 9 files changed, 44 insertions(+), 49 deletions(-) diff --git a/autofit/graphical/declarative/factor/hierarchical.py b/autofit/graphical/declarative/factor/hierarchical.py index f4dfec1d0..0ec5fde9b 100644 --- a/autofit/graphical/declarative/factor/hierarchical.py +++ b/autofit/graphical/declarative/factor/hierarchical.py @@ -8,6 +8,7 @@ from autofit.messages import NormalMessage from autofit.non_linear.paths.abstract import AbstractPaths from autofit.tools.namer import namer + from .abstract import AbstractModelFactor @@ -19,6 +20,7 @@ def __init__( distribution: Type[Prior], optimiser=None, name: Optional[str] = None, + use_jax : bool = False, **kwargs, ): """ @@ -70,6 +72,7 @@ def __init__( self._name = name or namer(self.__class__.__name__) self._factors = list() self.optimiser = optimiser + self.use_jax = use_jax @property def name(self): @@ -159,7 +162,7 @@ def __init__( """ self.distribution_model = distribution_model self.drawn_prior = drawn_prior - self.use_jax = use_jax + self.use_jax = distribution_model.use_jax prior_variable_dict = {prior.name: prior for prior in distribution_model.priors} @@ -188,7 +191,7 @@ def variable(self): return self.drawn_prior def log_likelihood_function(self, instance): - return instance.distribution_model.message(instance.drawn_prior) + return instance.distribution_model.message(instance.drawn_prior, xp=self._xp) @property def priors(self) -> Set[Prior]: diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index 2148c66e5..84d0739d4 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -366,8 +366,8 @@ def _get_mean_variance( ) return mean, variance - def __call__(self, x): - return 
np.sum(self.logpdf(x)) + def __call__(self, x, xp=np): + return xp.sum(self.logpdf(x, xp=xp)) def factor_jacobian( self, x: np.ndarray, _variables: Optional[Tuple[str]] = ("x",) diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 2f4848e05..460b8a084 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -184,8 +184,7 @@ def log_partition(self) -> np.ndarray: return betaln(*self.parameters) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Compute the natural parameters of the Beta distribution. @@ -196,7 +195,8 @@ def natural_parameters(self) -> np.ndarray: """ return self.calc_natural_parameters( self.alpha, - self.beta + self.beta, + xp=xp ) @staticmethod @@ -258,7 +258,7 @@ def invert_sufficient_statistics( return cls.calc_natural_parameters(a, b) @classmethod - def to_canonical_form(cls, x: np.ndarray) -> np.ndarray: + def to_canonical_form(cls, x: np.ndarray, xp=np) -> np.ndarray: """ Convert a value x in (0,1) to the canonical sufficient statistics for Beta. @@ -271,7 +271,7 @@ def to_canonical_form(cls, x: np.ndarray) -> np.ndarray: ------- Canonical sufficient statistics [log(x), log(1 - x)]. 
""" - return np.array([np.log(x), np.log1p(-x)]) + return xp.array([xp.log(x), xp.log1p(-x)]) @cached_property def mean(self) -> Union[np.ndarray, float]: diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index 959647009..c06983dee 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -157,8 +157,7 @@ def project( def kl(self, dist): return self.base_message.kl(dist.base_message) - @property - def natural_parameters(self): + def natural_parameters(self, xp=np) -> np.ndarray: return self.base_message.natural_parameters @inverse_transform @@ -245,8 +244,8 @@ def calc_log_base_measure(self, x) -> np.ndarray: return self.base_message.calc_log_base_measure(x) @transform - def to_canonical_form(self, x) -> np.ndarray: - return self.base_message.to_canonical_form(x) + def to_canonical_form(self, x, xp=np) -> np.ndarray: + return self.base_message.to_canonical_form(x, xp=xp) @property @inverse_transform diff --git a/autofit/messages/fixed.py b/autofit/messages/fixed.py index 496a62804..8bf990a48 100644 --- a/autofit/messages/fixed.py +++ b/autofit/messages/fixed.py @@ -25,8 +25,7 @@ def __init__( def value_for(self, unit: float) -> float: raise NotImplemented() - @cached_property - def natural_parameters(self) -> Tuple[np.ndarray, ...]: + def natural_parameters(self, xp=np) -> Tuple[np.ndarray, ...]: return self.parameters @staticmethod @@ -35,7 +34,7 @@ def invert_natural_parameters(natural_parameters: np.ndarray return natural_parameters, @staticmethod - def to_canonical_form(x: np.ndarray) -> np.ndarray: + def to_canonical_form(x: np.ndarray, xp=np) -> np.ndarray: return x @cached_property diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index ae86a615e..3b4ed79c5 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -36,9 +36,8 @@ def __init__( def value_for(self, unit: float) -> float: raise NotImplemented() - @cached_property - def 
natural_parameters(self): - return self.calc_natural_parameters(self.alpha, self.beta) + def natural_parameters(self, xp=np) -> np.ndarray: + return self.calc_natural_parameters(self.alpha, self.beta, xp=xp) @staticmethod def calc_natural_parameters(alpha, beta): @@ -50,8 +49,8 @@ def invert_natural_parameters(natural_parameters): return eta1 + 1, -eta2 @staticmethod - def to_canonical_form(x): - return np.array([np.log(x), x]) + def to_canonical_form(x, xp=np): + return xp.array([np.log(x), x]) @classmethod def invert_sufficient_statistics(cls, suff_stats): diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index daff57479..3826cb821 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -44,32 +44,31 @@ def __eq__(self, other): def pdf(self, x: np.ndarray) -> np.ndarray: return np.exp(self.logpdf(x)) - def logpdf(self, x: Union[np.ndarray, float]) -> np.ndarray: - eta = self._broadcast_natural_parameters(x) - t = self.to_canonical_form(x) + def logpdf(self, x: Union[np.ndarray, float], xp=np) -> np.ndarray: + eta = self._broadcast_natural_parameters(x, xp=xp) + t = self.to_canonical_form(x, xp=xp) log_base = self.calc_log_base_measure(x) return self.natural_logpdf(eta, t, log_base, self.log_partition) - def _broadcast_natural_parameters(self, x): - shape = np.shape(x) + def _broadcast_natural_parameters(self, x, xp=np): + shape = xp.shape(x) if shape == self.shape: - return self.natural_parameters + return self.natural_parameters(xp=xp) elif shape[1:] == self.shape: - return self.natural_parameters[:, None, ...] + return self.natural_parameters(xp=xp)[:, None, ...] 
else: raise ValueError( f"shape of passed value {shape} does not " f"match message shape {self.shape}" ) - @cached_property @abstractmethod - def natural_parameters(self): + def natural_parameters(self, xp=np) -> np.ndarray: pass @staticmethod @abstractmethod - def to_canonical_form(x: Union[np.ndarray, float]) -> np.ndarray: + def to_canonical_form(x: Union[np.ndarray, float], xp=np) -> np.ndarray: pass @classmethod diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 84bd3f2ff..407b3c5a6 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -148,8 +148,7 @@ def ppf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: return norm.ppf(x, loc=self.mean, scale=self.sigma) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ The natural (canonical) parameters of the Gaussian distribution in exponential-family form. @@ -162,10 +161,10 @@ def natural_parameters(self) -> np.ndarray: ------- A NumPy array containing the two natural parameters [η₁, η₂]. """ - return self.calc_natural_parameters(self.mean, self.sigma) + return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod - def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert standard parameters of a Gaussian distribution (mean and standard deviation) into natural parameters used in its exponential family representation. 
@@ -184,7 +183,7 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, η₂ = -1 / (2σ²) """ precision = 1 / sigma**2 - return np.array([mu * precision, -precision / 2]) + return xp.array([mu * precision, -precision / 2]) @staticmethod def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: @@ -207,7 +206,7 @@ def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, f return mu, sigma @staticmethod - def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + def to_canonical_form(x : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. @@ -223,7 +222,7 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: ------- The sufficient statistics [x, x²]. """ - return np.array([x, x**2]) + return xp.array([x, x**2]) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: @@ -570,12 +569,11 @@ def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: """ return np.array([eta1, eta2]) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Return the natural parameters of this distribution. 
""" - return self.calc_natural_parameters(*self.parameters) + return self.calc_natural_parameters(*self.parameters, xp=xp) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 07a34464c..a26292292 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -141,8 +141,7 @@ def ppf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: b = (self.upper_limit - self.mean) / self.sigma return truncnorm.ppf(x, a=a, b=b, loc=self.mean, scale=self.sigma) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ The pseudo-natural (canonical) parameters of a truncated Gaussian distribution. @@ -159,7 +158,7 @@ def natural_parameters(self) -> np.ndarray: ------- A NumPy array containing the pseudo-natural parameters [η₁, η₂]. """ - return self.calc_natural_parameters(self.mean, self.sigma) + return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: @@ -217,7 +216,7 @@ def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, f return mu, sigma @staticmethod - def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + def to_canonical_form(x : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. @@ -234,7 +233,7 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: ------- The sufficient statistics [x, x²]. 
""" - return np.array([x, x**2]) + return xp.array([x, x**2]) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: @@ -638,12 +637,11 @@ def calc_natural_parameters( """ return np.array([eta1, eta2]) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Return the natural parameters of this distribution. """ - return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit) + return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit, xp=xp) @classmethod def invert_sufficient_statistics( From 6f28bba56fef7c5ade08f4575bfca4e72a8bb2ea Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 14:32:44 +0000 Subject: [PATCH 09/35] integration tests work, cleaning up unit tests --- autofit/messages/abstract.py | 2 +- autofit/messages/beta.py | 3 +-- autofit/messages/composed_transform.py | 3 +-- autofit/messages/fixed.py | 3 +-- autofit/messages/gamma.py | 6 +++--- autofit/messages/interface.py | 11 +++++------ autofit/messages/normal.py | 8 ++++---- autofit/messages/truncated_normal.py | 6 +++--- 8 files changed, 19 insertions(+), 23 deletions(-) diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index 84d0739d4..cb76b5f80 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -62,7 +62,7 @@ def __init__( self._broadcast = _xp.broadcast_arrays(*parameters) if self.shape: - self.parameters = tuple(xp.aarray(p) for p in parameters) + self.parameters = tuple(xp.asarray(p) for p in parameters) else: self.parameters = tuple(parameters) diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 460b8a084..8a4696d92 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -171,8 +171,7 @@ def value_for(self, unit: float) -> float: """ raise NotImplemented() - @cached_property - def log_partition(self) -> np.ndarray: + def 
log_partition(self, xp=np) -> np.ndarray: """ Compute the log partition function (log normalization constant) of the Beta distribution. diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index c06983dee..0b7f4f34f 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -228,8 +228,7 @@ def invert_natural_parameters( def cdf(self, x): return self.base_message.cdf(x) - @property - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: return self.base_message.log_partition def invert_sufficient_statistics(self, sufficient_statistics): diff --git a/autofit/messages/fixed.py b/autofit/messages/fixed.py index 8bf990a48..1280d5e95 100644 --- a/autofit/messages/fixed.py +++ b/autofit/messages/fixed.py @@ -37,8 +37,7 @@ def invert_natural_parameters(natural_parameters: np.ndarray def to_canonical_form(x: np.ndarray, xp=np) -> np.ndarray: return x - @cached_property - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: return 0. 
@classmethod diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index 3b4ed79c5..05c587552 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -6,11 +6,11 @@ class GammaMessage(AbstractMessage): - @property - def log_partition(self): + + def log_partition(self, xp=np): from scipy import special - alpha, beta = GammaMessage.invert_natural_parameters(self.natural_parameters) + alpha, beta = GammaMessage.invert_natural_parameters(self.natural_parameters(xp=xp)) return special.gammaln(alpha) - alpha * np.log(beta) log_base_measure = 0.0 diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 3826cb821..2ed1932e9 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -48,7 +48,7 @@ def logpdf(self, x: Union[np.ndarray, float], xp=np) -> np.ndarray: eta = self._broadcast_natural_parameters(x, xp=xp) t = self.to_canonical_form(x, xp=xp) log_base = self.calc_log_base_measure(x) - return self.natural_logpdf(eta, t, log_base, self.log_partition) + return self.natural_logpdf(eta, t, log_base, self.log_partition(xp=xp), xp=xp) def _broadcast_natural_parameters(self, x, xp=np): shape = xp.shape(x) @@ -75,15 +75,14 @@ def to_canonical_form(x: Union[np.ndarray, float], xp=np) -> np.ndarray: def calc_log_base_measure(cls, x): return cls.log_base_measure - @cached_property @abstractmethod - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: pass @classmethod - def natural_logpdf(cls, eta, t, log_base, log_partition): - eta_t = np.multiply(eta, t).sum(0) - return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf) + def natural_logpdf(cls, eta, t, log_base, log_partition, xp=np): + eta_t = xp.multiply(eta, t).sum(0) + return xp.nan_to_num(log_base + eta_t - log_partition, nan=-xp.inf) def numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 407b3c5a6..d8d6773cc 
100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -45,8 +45,8 @@ def assert_sigma_non_negative(sigma, xp=np): raise ValueError("Sigma cannot be negative") class NormalMessage(AbstractMessage): - @cached_property - def log_partition(self): + + def log_partition(self, xp=np): """ Compute the log-partition function (also called log-normalizer or cumulant function) for the normal distribution in its natural (canonical) parameterization. @@ -59,8 +59,8 @@ def log_partition(self): A(η) = η₁² / (-4η₂) - 0.5 * log(-2η₂) This ensures normalization of the exponential-family distribution. """ - eta1, eta2 = self.natural_parameters - return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2 + eta1, eta2 = self.natural_parameters(xp=xp) + return -(eta1**2) / 4 / eta2 - xp.log(-2 * eta2) / 2 log_base_measure = -0.5 * np.log(2 * np.pi) _support = ((-np.inf, np.inf),) diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index a26292292..e8cf91f83 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -25,8 +25,8 @@ def is_nan(value): class TruncatedNormalMessage(AbstractMessage): - @cached_property - def log_partition(self) -> float: + + def log_partition(self, xp=np) -> float: """ Compute the log-partition function (normalizer) of the truncated Gaussian. 
@@ -46,7 +46,7 @@ def log_partition(self) -> float: a = (self.lower_limit - self.mean) / self.sigma b = (self.upper_limit - self.mean) / self.sigma Z = norm.cdf(b) - norm.cdf(a) - return np.log(Z) if Z > 0 else -np.inf + return xp.log(Z) if Z > 0 else -xp.inf log_base_measure = -0.5 * np.log(2 * np.pi) From c62083063b3913c694b0611f87d9a0d61b6444fb Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:03:54 +0000 Subject: [PATCH 10/35] hacks which push us most of the way through --- autofit/messages/abstract.py | 9 +++++---- autofit/messages/beta.py | 5 +++-- autofit/messages/composed_transform.py | 8 ++++---- autofit/messages/gamma.py | 8 ++++---- autofit/messages/interface.py | 27 +++++++++++++++++++------- autofit/messages/normal.py | 9 +++++---- autofit/messages/truncated_normal.py | 17 ++++++++-------- 7 files changed, 50 insertions(+), 33 deletions(-) diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index cb76b5f80..e6a9fd860 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -56,6 +56,7 @@ def __init__( self.id = next(self.ids) if id_ is None else id_ self.log_norm = log_norm + if xp is np: self._broadcast = np.broadcast(*parameters) else: @@ -222,7 +223,7 @@ def __truediv__(self, other: Union["AbstractMessage", Real]) -> "AbstractMessage ) def __pow__(self, other: Real) -> "AbstractMessage": - natural = self.natural_parameters + natural = self.natural_parameters() new_params = other * natural log_norm = other * self.log_norm new = self.from_natural_parameters( @@ -313,12 +314,12 @@ def log_normalisation(self, *elems: Union["AbstractMessage", float]) -> np.ndarr ] # Calculate log product of message normalisation - log_norm = self.log_base_measure - self.log_partition - log_norm += sum(dist.log_base_measure - dist.log_partition for dist in dists) + log_norm = self.log_base_measure - self.log_partition() + log_norm += sum(dist.log_base_measure - dist.log_partition() for dist in dists) # Calculate 
log normalisation of product of messages prod_dist = self.sum_natural_parameters(*dists) - log_norm -= prod_dist.log_base_measure - prod_dist.log_partition + log_norm -= prod_dist.log_base_measure - prod_dist.log_partition() return log_norm diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 8a4696d92..9b9e6b919 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -201,7 +201,8 @@ def natural_parameters(self, xp=np) -> np.ndarray: @staticmethod def calc_natural_parameters( alpha: Union[float, np.ndarray], - beta: Union[float, np.ndarray] + beta: Union[float, np.ndarray], + xp=np ) -> np.ndarray: """ Calculate the natural parameters of a Beta distribution from alpha and beta. @@ -217,7 +218,7 @@ def calc_natural_parameters( ------- Natural parameters [alpha - 1, beta - 1]. """ - return np.array([alpha - 1, beta - 1]) + return xp.array([alpha - 1, beta - 1]) @staticmethod def invert_natural_parameters( diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index 0b7f4f34f..b4d1fa1d0 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -34,9 +34,9 @@ def transform(func): """ @functools.wraps(func) - def wrapper(self, x): + def wrapper(self, x, xp): x = self._transform(x) - return func(self, x) + return func(self, x, xp) return wrapper @@ -239,8 +239,8 @@ def value_for(self, unit): return self.base_message.value_for(unit) @transform - def calc_log_base_measure(self, x) -> np.ndarray: - return self.base_message.calc_log_base_measure(x) + def calc_log_base_measure(self, x, xp=np) -> np.ndarray: + return self.base_message.calc_log_base_measure(x, xp=xp) @transform def to_canonical_form(self, x, xp=np) -> np.ndarray: diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index 05c587552..3677186f0 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -40,8 +40,8 @@ def natural_parameters(self, xp=np) -> np.ndarray: 
return self.calc_natural_parameters(self.alpha, self.beta, xp=xp) @staticmethod - def calc_natural_parameters(alpha, beta): - return np.array([alpha - 1, -beta]) + def calc_natural_parameters(alpha, beta, xp=np): + return xp.array([alpha - 1, -beta]) @staticmethod def invert_natural_parameters(natural_parameters): @@ -97,13 +97,13 @@ def kl(self, dist): def logpdf_gradient(self, x): logl = self.logpdf(x) - eta1 = self.natural_parameters[0] + eta1 = self.natural_parameters()[0] gradl = eta1 / x - self.beta return logl, gradl def logpdf_gradient_hessian(self, x): logl = self.logpdf(x) - eta1 = self.natural_parameters[0] + eta1 = self.natural_parameters()[0] gradl = eta1 / x hessl = -gradl / x gradl -= self.beta diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 2ed1932e9..e7f3b3136 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -45,9 +45,11 @@ def pdf(self, x: np.ndarray) -> np.ndarray: return np.exp(self.logpdf(x)) def logpdf(self, x: Union[np.ndarray, float], xp=np) -> np.ndarray: + eta = self._broadcast_natural_parameters(x, xp=xp) t = self.to_canonical_form(x, xp=xp) - log_base = self.calc_log_base_measure(x) + log_base = self.calc_log_base_measure(x, xp=xp) + return self.natural_logpdf(eta, t, log_base, self.log_partition(xp=xp), xp=xp) def _broadcast_natural_parameters(self, x, xp=np): @@ -72,7 +74,7 @@ def to_canonical_form(x: Union[np.ndarray, float], xp=np) -> np.ndarray: pass @classmethod - def calc_log_base_measure(cls, x): + def calc_log_base_measure(cls, x, xp=np): return cls.log_base_measure @abstractmethod @@ -81,8 +83,19 @@ def log_partition(self, xp=np) -> np.ndarray: @classmethod def natural_logpdf(cls, eta, t, log_base, log_partition, xp=np): + + try: + eta = eta() + except TypeError: + pass + + try: + log_partition_in = log_partition(xp=xp) + except TypeError: + log_partition_in = log_partition + eta_t = xp.multiply(eta, t).sum(0) - return xp.nan_to_num(log_base + eta_t - 
log_partition, nan=-xp.inf) + return xp.nan_to_num(log_base + eta_t - log_partition_in, nan=-xp.inf) def numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 @@ -184,7 +197,7 @@ def check_support(self) -> np.ndarray: pass def check_finite(self) -> np.ndarray: - return np.isfinite(self.natural_parameters).all(0) + return np.isfinite(self.natural_parameters()).all(0) def check_valid(self) -> np.ndarray: return self.check_finite() & self.check_support() @@ -204,11 +217,11 @@ def sum_natural_parameters(self, *dists: "MessageInterface") -> "MessageInterfac """ new_params = sum( ( - dist.natural_parameters + dist.natural_parameters() for dist in self._iter_dists(dists) if isinstance(dist, MessageInterface) ), - self.natural_parameters, + self.natural_parameters(), ) return self.from_natural_parameters( new_params, @@ -220,7 +233,7 @@ def sub_natural_parameters(self, other: "MessageInterface") -> "MessageInterface of this distribution with another distribution of the same type""" log_norm = self.log_norm - other.log_norm - new_params = self.natural_parameters - other.natural_parameters + new_params = self.natural_parameters() - other.natural_parameters() return self.from_natural_parameters( new_params, log_norm=log_norm, diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index d8d6773cc..56af526aa 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -93,7 +93,8 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. """ - if isinstance(mean, (np.ndarray, float, int)): + + if isinstance(mean, (np.ndarray, float, int, list)): xp = np else: import jax.numpy as jnp @@ -470,7 +471,7 @@ def natural(self)-> "NaturalNormal": A natural form Gaussian with zeroed parameters but same configuration. 
""" return NaturalNormal.from_natural_parameters( - self.natural_parameters * 0.0, **self._init_kwargs + self.natural_parameters() * 0.0, **self._init_kwargs ) def zeros_like(self) -> "AbstractMessage": @@ -556,7 +557,7 @@ def mean(self) -> float: return np.nan_to_num(-self.parameters[0] / self.parameters[1] / 2) @staticmethod - def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: + def calc_natural_parameters(eta1: float, eta2: float, xp=np) -> np.ndarray: """ Return the natural parameters in array form (identity function for this class). @@ -567,7 +568,7 @@ def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: eta2 The second natural parameter. """ - return np.array([eta1, eta2]) + return xp.array([eta1, eta2]) def natural_parameters(self, xp=np) -> np.ndarray: """ diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index e8cf91f83..d7ac5c89b 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -161,7 +161,7 @@ def natural_parameters(self, xp=np) -> np.ndarray: return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod - def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert standard parameters of a Gaussian distribution (mean and standard deviation) into natural parameters used in its exponential family representation. 
@@ -189,7 +189,7 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, η₂ = -1 / (2σ²) """ precision = 1 / sigma**2 - return np.array([mu * precision, -precision / 2]) + return xp.array([mu * precision, -precision / 2]) @staticmethod def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: @@ -493,7 +493,7 @@ def natural(self)-> "NaturalNormal": A natural form Gaussian with zeroed parameters but same configuration. """ return TruncatedNaturalNormal.from_natural_parameters( - self.natural_parameters * 0.0, **self._init_kwargs + self.natural_parameters() * 0.0, **self._init_kwargs ) def zeros_like(self) -> "AbstractMessage": @@ -617,10 +617,11 @@ def mean(self) -> float: @staticmethod def calc_natural_parameters( - eta1: float, - eta2: float, - lower_limit: float = -np.inf, - upper_limit: float = np.inf + eta1: float, + eta2: float, + lower_limit: float = -np.inf, + upper_limit: float = np.inf, + xp=np ) -> np.ndarray: """ Return the natural parameters in array form (identity function for this class). @@ -635,7 +636,7 @@ def calc_natural_parameters( eta2 The second natural parameter. 
""" - return np.array([eta1, eta2]) + return xp.array([eta1, eta2]) def natural_parameters(self, xp=np) -> np.ndarray: """ From 6c78390b5e036d95850a8202ee41afab9f7b8f4f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:06:46 +0000 Subject: [PATCH 11/35] fix removing dodgy hacks --- autofit/messages/composed_transform.py | 4 ++-- autofit/messages/interface.py | 13 +------------ 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index b4d1fa1d0..28fed6d9a 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -158,7 +158,7 @@ def kl(self, dist): return self.base_message.kl(dist.base_message) def natural_parameters(self, xp=np) -> np.ndarray: - return self.base_message.natural_parameters + return self.base_message.natural_parameters(xp=xp) @inverse_transform def sample(self, n_samples: Optional[int] = None): @@ -229,7 +229,7 @@ def cdf(self, x): return self.base_message.cdf(x) def log_partition(self, xp=np) -> np.ndarray: - return self.base_message.log_partition + return self.base_message.log_partition(xp=xp) def invert_sufficient_statistics(self, sufficient_statistics): return self.base_message.invert_sufficient_statistics(sufficient_statistics) diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index e7f3b3136..c0d589553 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -83,19 +83,8 @@ def log_partition(self, xp=np) -> np.ndarray: @classmethod def natural_logpdf(cls, eta, t, log_base, log_partition, xp=np): - - try: - eta = eta() - except TypeError: - pass - - try: - log_partition_in = log_partition(xp=xp) - except TypeError: - log_partition_in = log_partition - eta_t = xp.multiply(eta, t).sum(0) - return xp.nan_to_num(log_base + eta_t - log_partition_in, nan=-xp.inf) + return xp.nan_to_num(log_base + eta_t - log_partition, nan=-xp.inf) def 
numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 From 0bbb3dfbb4865360b5f733e1cd252715fb1298cd Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:13:38 +0000 Subject: [PATCH 12/35] use_jax made private --- autofit/graphical/declarative/factor/hierarchical.py | 6 +++--- autofit/non_linear/analysis/analysis.py | 4 ++-- test_autofit/graphical/global/test_hierarchical.py | 3 +-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/autofit/graphical/declarative/factor/hierarchical.py b/autofit/graphical/declarative/factor/hierarchical.py index 0ec5fde9b..69bd3c272 100644 --- a/autofit/graphical/declarative/factor/hierarchical.py +++ b/autofit/graphical/declarative/factor/hierarchical.py @@ -72,7 +72,7 @@ def __init__( self._name = name or namer(self.__class__.__name__) self._factors = list() self.optimiser = optimiser - self.use_jax = use_jax + self._use_jax = use_jax @property def name(self): @@ -147,7 +147,7 @@ def __call__(self, **kwargs): class _HierarchicalFactor(AbstractModelFactor): def __init__( - self, distribution_model: HierarchicalFactor, drawn_prior: Prior, use_jax : bool = False + self, distribution_model: HierarchicalFactor, drawn_prior: Prior, ): """ A factor that links a variable to a parameterised distribution. 
@@ -162,7 +162,7 @@ def __init__( """ self.distribution_model = distribution_model self.drawn_prior = drawn_prior - self.use_jax = distribution_model.use_jax + self._use_jax = distribution_model._use_jax prior_variable_dict = {prior.name: prior for prior in distribution_model.priors} diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 996bb995f..441bc70b4 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -36,7 +36,7 @@ def __init__( self, use_jax : bool = False, **kwargs ): - self.use_jax = use_jax + self._use_jax = use_jax self.kwargs = kwargs def __getattr__(self, item: str): @@ -63,7 +63,7 @@ def method(*args, **kwargs): @property def _xp(self): - if self.use_jax: + if self._use_jax: import jax.numpy as jnp return jnp return np diff --git a/test_autofit/graphical/global/test_hierarchical.py b/test_autofit/graphical/global/test_hierarchical.py index 5e90d3704..26b3e4da3 100644 --- a/test_autofit/graphical/global/test_hierarchical.py +++ b/test_autofit/graphical/global/test_hierarchical.py @@ -55,8 +55,7 @@ def test_model_info(model): 2 - 3 one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0 factor - include_prior_factors True - use_jax False""" + include_prior_factors True""" ) From 2cb4a044c0503155b77a943f64b2766fba77a36b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:14:33 +0000 Subject: [PATCH 13/35] include prior factors made private --- autofit/graphical/declarative/abstract.py | 6 +++--- autofit/graphical/declarative/collection.py | 2 +- autofit/graphical/declarative/factor/abstract.py | 2 +- test_autofit/graphical/global/test_hierarchical.py | 4 +--- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/autofit/graphical/declarative/abstract.py b/autofit/graphical/declarative/abstract.py index e633c71df..35e1c177f 100644 --- a/autofit/graphical/declarative/abstract.py +++ b/autofit/graphical/declarative/abstract.py @@ 
-20,7 +20,7 @@ class AbstractDeclarativeFactor(Analysis, ABC): _plates: Tuple[Plate, ...] = () def __init__(self, include_prior_factors=False, use_jax : bool = False): - self.include_prior_factors = include_prior_factors + self._include_prior_factors = include_prior_factors super().__init__(use_jax=use_jax) @@ -63,7 +63,7 @@ def prior_counts(self) -> List[Tuple[Prior, int]]: for prior in factor.prior_model.priors: counter[prior] += 1 return [ - (prior, count + 1 if self.include_prior_factors else count) + (prior, count + 1 if self._include_prior_factors else count) for prior, count in counter.items() ] @@ -96,7 +96,7 @@ def graph(self) -> DeclarativeFactorGraph: The complete graph made by combining all factors and priors """ factors = [model for model in self.model_factors] - if self.include_prior_factors: + if self._include_prior_factors: factors += self.prior_factors # noinspection PyTypeChecker return DeclarativeFactorGraph(factors) diff --git a/autofit/graphical/declarative/collection.py b/autofit/graphical/declarative/collection.py index 97fb9824f..6dff83693 100644 --- a/autofit/graphical/declarative/collection.py +++ b/autofit/graphical/declarative/collection.py @@ -42,7 +42,7 @@ def __init__( def tree_flatten(self): return ( (self._model_factors,), - (self._name, self.include_prior_factors), + (self._name, self._include_prior_factors), ) @classmethod diff --git a/autofit/graphical/declarative/factor/abstract.py b/autofit/graphical/declarative/factor/abstract.py index 3df6a9f0c..67bfec786 100644 --- a/autofit/graphical/declarative/factor/abstract.py +++ b/autofit/graphical/declarative/factor/abstract.py @@ -42,7 +42,7 @@ def __init__( factor, **prior_variable_dict, name=name or namer(self.__class__.__name__), ) - self.include_prior_factors = include_prior_factors + self._include_prior_factors = include_prior_factors @property def info(self) -> str: diff --git a/test_autofit/graphical/global/test_hierarchical.py 
b/test_autofit/graphical/global/test_hierarchical.py index 26b3e4da3..6f22fa874 100644 --- a/test_autofit/graphical/global/test_hierarchical.py +++ b/test_autofit/graphical/global/test_hierarchical.py @@ -53,9 +53,7 @@ def test_model_info(model): 1 drawn_prior UniformPrior [1], lower_limit = 0.0, upper_limit = 1.0 2 - 3 - one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0 -factor - include_prior_factors True""" + one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0""" ) From fb9d8b97071c0141f3f93a28bfcebbca3158e683 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:17:54 +0000 Subject: [PATCH 14/35] fix issues with _use_jax now being private --- autofit/non_linear/analysis/analysis.py | 4 ++-- autofit/non_linear/fitness.py | 2 +- autofit/non_linear/search/nest/dynesty/search/abstract.py | 2 +- autofit/non_linear/search/nest/nautilus/search.py | 2 +- test_autofit/graphical/hierarchical/test_optimise.py | 7 ------- 5 files changed, 5 insertions(+), 12 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 441bc70b4..3bdeab220 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -109,7 +109,7 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model) - if self.use_jax: + if self._use_jax: import jax start = time.time() logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.") @@ -130,7 +130,7 @@ def batched_compute_latent(x): # batched JAX call on this chunk latent_values_batch = batched_compute_latent(batch) - if self.use_jax: + if self._use_jax: import jax.numpy as jnp latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) diff --git a/autofit/non_linear/fitness.py 
b/autofit/non_linear/fitness.py index e3a7ab0b0..a3792b701 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -134,7 +134,7 @@ def __init__( @property def _xp(self): - if self.analysis.use_jax: + if self.analysis._use_jax: import jax.numpy as jnp return jnp return np diff --git a/autofit/non_linear/search/nest/dynesty/search/abstract.py b/autofit/non_linear/search/nest/dynesty/search/abstract.py index b566a1c50..913bd2b62 100644 --- a/autofit/non_linear/search/nest/dynesty/search/abstract.py +++ b/autofit/non_linear/search/nest/dynesty/search/abstract.py @@ -146,7 +146,7 @@ def _fit( "parallel" ].get("force_x1_cpu") or self.kwargs.get("force_x1_cpu") - or analysis.use_jax + or analysis._use_jax ): raise RuntimeError diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 30531edbe..3977843b2 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -127,7 +127,7 @@ def _fit(self, model: AbstractPriorModel, analysis): if ( self.config_dict.get("force_x1_cpu") or self.kwargs.get("force_x1_cpu") - or analysis.use_jax + or analysis._use_jax ): fitness = Fitness( diff --git a/test_autofit/graphical/hierarchical/test_optimise.py b/test_autofit/graphical/hierarchical/test_optimise.py index 12caf64f2..f500029ae 100644 --- a/test_autofit/graphical/hierarchical/test_optimise.py +++ b/test_autofit/graphical/hierarchical/test_optimise.py @@ -11,13 +11,6 @@ def make_factor(hierarchical_factor): def test_optimise(factor): search = af.DynestyStatic(maxcall=100, dynamic_delta=False, delta=0.1,) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - - _, status = search.optimise( factor.mean_field_approximation().factor_approximation(factor) ) From ee14b5bfeca7d31907c6dc9adaca50091d9c9f6f Mon Sep 17 00:00:00 2001 From: 
Jammy2211 Date: Sun, 30 Nov 2025 16:18:28 +0000 Subject: [PATCH 15/35] remove print statements --- autofit/messages/composed_transform.py | 6 +++--- autofit/messages/normal.py | 2 +- autofit/messages/truncated_normal.py | 2 +- autofit/non_linear/search/mle/bfgs/search.py | 4 ---- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index 28fed6d9a..84792b062 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -34,7 +34,7 @@ def transform(func): """ @functools.wraps(func) - def wrapper(self, x, xp): + def wrapper(self, x, xp=np): x = self._transform(x) return func(self, x, xp) @@ -225,8 +225,8 @@ def invert_natural_parameters( return self.base_message.invert_natural_parameters(natural_parameters) @transform - def cdf(self, x): - return self.base_message.cdf(x) + def cdf(self, x, xp=np): + return self.base_message.cdf(x, xp=xp) def log_partition(self, xp=np) -> np.ndarray: return self.base_message.log_partition(xp=xp) diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 56af526aa..f7128aa97 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -111,7 +111,7 @@ def __init__( ) self.mean, self.sigma = self.parameters - def cdf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: + def cdf(self, x : Union[float, np.ndarray], xp=np) -> Union[float, np.ndarray]: """ Compute the cumulative distribution function (CDF) of the Gaussian distribution at a given value or array of values `x`. 
diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index d7ac5c89b..081e4d99d 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -96,7 +96,7 @@ def __init__( self.mean, self.sigma, self.lower_limit, self.upper_limit = self.parameters - def cdf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + def cdf(self, x: Union[float, np.ndarray], xp=np) -> Union[float, np.ndarray]: """ Compute the cumulative distribution function (CDF) of the truncated Gaussian distribution at a given value or array of values `x`. diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 42a849235..8ca8f064b 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -225,10 +225,6 @@ def samples_via_internal_from( weight_list = len(log_likelihood_list) * [1.0] - print(parameter_lists) - print(log_likelihood_list) - print(log_prior_list) - sample_list = Sample.from_lists( model=model, parameter_lists=parameter_lists, From 316c884fcd03fc8175c3ae24920164ca88a5b05b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 2 Dec 2025 16:05:56 +0000 Subject: [PATCH 16/35] fixs png maker --- .../aggregator/summary/aggregate_images.py | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/autofit/aggregator/summary/aggregate_images.py b/autofit/aggregator/summary/aggregate_images.py index da9b179f9..69a29d047 100644 --- a/autofit/aggregator/summary/aggregate_images.py +++ b/autofit/aggregator/summary/aggregate_images.py @@ -24,32 +24,11 @@ def subplot_filename(subplot: Enum) -> str: ) -class SubplotFit(Enum): - """ - The subplots that can be extracted from the subplot_fit image. - - The values correspond to the position of the subplot in the 4x3 grid. 
- """ - - Data = (0, 0) - DataSourceScaled = (1, 0) - SignalToNoiseMap = (2, 0) - ModelData = (3, 0) - LensLightModelData = (0, 1) - LensLightSubtractedImage = (1, 1) - SourceModelData = (2, 1) - SourcePlaneZoomed = (3, 1) - NormalizedResidualMap = (0, 2) - NormalizedResidualMapOneSigma = (1, 2) - ChiSquaredMap = (2, 2) - SourcePlaneNoZoom = (3, 2) - - class SubplotFitImage: def __init__( self, image: Image.Image, - suplot_type: Type[SubplotFit] = SubplotFit, + suplot_type, ): """ The subplot_fit image associated with one fit. @@ -173,7 +152,7 @@ def output_to_folder( self, folder: Path, name: Union[str, List[str]], - subplots: List[Union[SubplotFit, List[Image.Image], Callable]], + subplots: List[Union[List[Image.Image], Callable]], subplot_width: Optional[int] = sys.maxsize, ): """ @@ -226,7 +205,7 @@ def output_to_folder( def _matrix_for_result( i: int, result: SearchOutput, - subplots: List[Union[SubplotFit, List[Image.Image], Callable]], + subplots: List[Union[List[Image.Image], Callable]], subplot_width: int = sys.maxsize, ) -> List[List[Image.Image]]: """ @@ -272,7 +251,8 @@ class name but using snake_case. 
_images[subplot_type] = SubplotFitImage( result.image( subplot_filename(subplot_), - ) + ), + subplot_type ) return _images[subplot_type] From 58a7473c9f179d9ff08f6e9f99fbfabc29f224f3 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 2 Dec 2025 16:43:11 +0000 Subject: [PATCH 17/35] fix unit test --- .../summary_files/test_aggregate_images.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/test_autofit/aggregator/summary_files/test_aggregate_images.py b/test_autofit/aggregator/summary_files/test_aggregate_images.py index d4334408b..1906aa522 100644 --- a/test_autofit/aggregator/summary_files/test_aggregate_images.py +++ b/test_autofit/aggregator/summary_files/test_aggregate_images.py @@ -6,8 +6,27 @@ from PIL import Image from autofit.aggregator import Aggregator -from autofit.aggregator.summary.aggregate_images import AggregateImages, SubplotFit - +from autofit.aggregator.summary.aggregate_images import AggregateImages + +class SubplotFit(Enum): + """ + The subplots that can be extracted from the subplot_fit image. + + The values correspond to the position of the subplot in the 4x3 grid. 
+ """ + + Data = (0, 0) + DataSourceScaled = (1, 0) + SignalToNoiseMap = (2, 0) + ModelData = (3, 0) + LensLightModelData = (0, 1) + LensLightSubtractedImage = (1, 1) + SourceModelData = (2, 1) + SourcePlaneZoomed = (3, 1) + NormalizedResidualMap = (0, 2) + NormalizedResidualMapOneSigma = (1, 2) + ChiSquaredMap = (2, 2) + SourcePlaneNoZoom = (3, 2) @pytest.fixture def aggregate(aggregator): @@ -167,7 +186,7 @@ class SubplotFit(Enum): SubplotFit.Data, ] ) - assert result.size == (61, 120) + assert result.size == (244, 362) def test_bad_aggregator(): From 3ce01f40a5842f3f6d8d732777b00bc3d187529f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 9 Dec 2025 19:32:45 +0000 Subject: [PATCH 18/35] fixes to figure of merit --- autofit/non_linear/fitness.py | 14 +++++----- autofit/non_linear/search/mle/bfgs/search.py | 27 ++++++++++++++------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index a3792b701..5e5e6d981 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -173,6 +173,7 @@ def call(self, parameters): # Penalize NaNs in the log-likelihood log_likelihood = self._xp.where(self._xp.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood) + log_likelihood = self._xp.where(self._xp.isinf(log_likelihood), self.resample_figure_of_merit, log_likelihood) # Determine final figure of merit if self.fom_is_log_likelihood: @@ -222,19 +223,16 @@ def call_wrap(self, parameters): figure_of_merit = self._call(parameters) if self.convert_to_chi_squared: - figure_of_merit *= -0.5 - - if self.fom_is_log_likelihood: - log_likelihood = figure_of_merit + log_likelihood = -0.5 * figure_of_merit else: + log_likelihood = figure_of_merit + + if not self.fom_is_log_likelihood: log_prior_list = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) - log_likelihood = figure_of_merit - self._xp.sum(log_prior_list) + log_likelihood -= 
self._xp.sum(log_prior_list) self.manage_quick_update(parameters=parameters, log_likelihood=log_likelihood) - if self.convert_to_chi_squared: - log_likelihood *= -2.0 - if self.store_history: self.parameters_history_list.append(np.array(parameters)) diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 8ca8f064b..2b105e786 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -136,21 +136,32 @@ def _fit( maxiter = self.config_dict_options.get("maxiter", 1e8) while total_iterations < maxiter: - iterations_remaining = maxiter - total_iterations + iterations_remaining = maxiter - total_iterations iterations = min(self.iterations_per_full_update, iterations_remaining) if iterations > 0: config_dict_options = self.config_dict_options config_dict_options["maxiter"] = iterations - search_internal = optimize.minimize( - fun=fitness._jit, - x0=x0, - method=self.method, - options=config_dict_options, - **self.config_dict_search - ) + if analysis._use_jax: + + search_internal = optimize.minimize( + fun=fitness._jit, + x0=x0, + method=self.method, + options=config_dict_options, + **self.config_dict_search + ) + else: + + search_internal = optimize.minimize( + fun=fitness.__call__, + x0=x0, + method=self.method, + options=config_dict_options, + **self.config_dict_search + ) total_iterations += search_internal.nit From 7313572cb62f3b4a7d4e3c0912c08048a8c9e640 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 12 Dec 2025 19:33:38 +0000 Subject: [PATCH 19/35] minor fixes --- autofit/non_linear/fitness.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 5e5e6d981..a41fc7614 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -228,8 +228,8 @@ def call_wrap(self, parameters): log_likelihood = figure_of_merit if not self.fom_is_log_likelihood: - 
log_prior_list = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) - log_likelihood -= self._xp.sum(log_prior_list) + log_prior_list = np.array(self.model.log_prior_list_from_vector(vector=parameters, xp=np)) + log_likelihood -= np.sum(log_prior_list) self.manage_quick_update(parameters=parameters, log_likelihood=log_likelihood) From cecb125276ea7bc7d8ed61c14381c4fcf3aa8b4c Mon Sep 17 00:00:00 2001 From: GitHub Actions bot Date: Mon, 15 Dec 2025 09:29:13 +0000 Subject: [PATCH 20/35] 'Updated version in __init__ to 2025.12.15.1 --- autofit/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autofit/__init__.py b/autofit/__init__.py index 949baad80..639972c42 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -141,4 +141,4 @@ def save_abc(pickler, obj): pickle._Pickler.save_type(pickler, obj) -__version__ = "2025.11.5.1" \ No newline at end of file +__version__ = "2025.12.15.1" \ No newline at end of file From dabaa2aeb9768f9c30fa43a6512796f6883b7295 Mon Sep 17 00:00:00 2001 From: GitHub Actions bot Date: Mon, 15 Dec 2025 09:57:28 +0000 Subject: [PATCH 21/35] 'Updated version in __init__ to 2025.12.15.2 --- autofit/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autofit/__init__.py b/autofit/__init__.py index 639972c42..225e7e0df 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -141,4 +141,4 @@ def save_abc(pickler, obj): pickle._Pickler.save_type(pickler, obj) -__version__ = "2025.12.15.1" \ No newline at end of file +__version__ = "2025.12.15.2" \ No newline at end of file From 2e417c2f5ec45161ccddfd7521d16e4bbceaf529 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 16 Dec 2025 14:39:16 +0000 Subject: [PATCH 22/35] print staement works --- autofit/non_linear/analysis/analysis.py | 40 ++++++++++++++++++- .../non_linear/grid/grid_search/__init__.py | 2 - autofit/non_linear/grid/grid_search/job.py | 1 + 3 files changed, 40 insertions(+), 3 
deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 3bdeab220..84272e689 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -314,4 +314,42 @@ def profile_log_likelihood_function(self, paths: AbstractPaths, instance): pass def perform_quick_update(self, paths, instance): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError + + def print_vram_use(self, model, batch_size : int) -> str: + """ + Print JAX VRAM use for a given batch size. + + Parameters + ---------- + batch_size + The batch size to profile, which is the number of model evaluations JAX will perform simultaneously. + """ + import jax + import jax.numpy as jnp + + from autofit.non_linear.fitness import Fitness + + fitness = Fitness( + model=model, + analysis=self, + fom_is_log_likelihood=True, + use_jax_vmap=True, + batch_size=batch_size, + ) + + parameters = np.zeros((batch_size, model.total_free_parameters)) + + for i in range(batch_size): + parameters[i, :] = model.physical_values_from_prior_medians + + parameters = jnp.array(parameters) + + batched_call = jax.jit(jax.vmap(fitness.call)) + lowered = batched_call.lower(parameters) + compiled = lowered.compile() + memory_analysis = compiled.memory_analysis() + + print( + f"VRAM USE = {(memory_analysis.output_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024 ** 3:.3f} GB" + ) diff --git a/autofit/non_linear/grid/grid_search/__init__.py b/autofit/non_linear/grid/grid_search/__init__.py index 8768c28f5..5c42d201f 100644 --- a/autofit/non_linear/grid/grid_search/__init__.py +++ b/autofit/non_linear/grid/grid_search/__init__.py @@ -219,8 +219,6 @@ def _fit( result: GridSearchResult The result of the grid search """ - self.logger.info("...in parallel") - grid_priors = model.sort_priors_alphabetically(set(grid_priors)) lists = self.make_lists(grid_priors) diff --git a/autofit/non_linear/grid/grid_search/job.py 
b/autofit/non_linear/grid/grid_search/job.py index 892d8ad46..bef62184a 100644 --- a/autofit/non_linear/grid/grid_search/job.py +++ b/autofit/non_linear/grid/grid_search/job.py @@ -51,6 +51,7 @@ def __init__( self.info = info def perform(self): + result = self.search_instance.fit( model=self.model, analysis=self.analysis, From 322e3411e1b9c6722517b79fac067fef1fb7b987 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 16 Dec 2025 14:49:32 +0000 Subject: [PATCH 23/35] specific message about CPU use if 0 VRAM --- autofit/non_linear/analysis/analysis.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 84272e689..4887224ac 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -350,6 +350,17 @@ def print_vram_use(self, model, batch_size : int) -> str: compiled = lowered.compile() memory_analysis = compiled.memory_analysis() - print( - f"VRAM USE = {(memory_analysis.output_size_in_bytes + memory_analysis.temp_size_in_bytes) / 1024 ** 3:.3f} GB" + vram_bytes = ( + memory_analysis.output_size_in_bytes + + memory_analysis.temp_size_in_bytes ) + + if vram_bytes == 0: + print( + "VRAM USE = 0.000 GB " + "(this likely means JAX is running in CPU-only mode)" + ) + else: + print( + f"VRAM USE = {vram_bytes / 1024 ** 3:.3f} GB" + ) \ No newline at end of file From 9c9f9302451095e16fc4b51682893c9daaf3c42d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 16 Dec 2025 14:49:59 +0000 Subject: [PATCH 24/35] message for use_jax=False case --- autofit/non_linear/analysis/analysis.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 4887224ac..794d32361 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -325,6 +325,10 @@ def print_vram_use(self, model, batch_size : int) -> str: 
batch_size The batch size to profile, which is the number of model evaluations JAX will perform simultaneously. """ + if not self._use_jax: + print("use_jax=False for this analysis, therefore does not use GPU and VRAM use cannot be profiled.") + return + import jax import jax.numpy as jnp From a7fffaff83151fd4bdc1ccd77ec392c934ce00bd Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 16 Dec 2025 16:07:24 +0000 Subject: [PATCH 25/35] more tests --- autofit/non_linear/search/nest/nautilus/search.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 3977843b2..fa5363cc3 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -485,6 +485,10 @@ def samples_via_internal_from( samples_info=self.samples_info_from(search_internal=search_internal), ) + @property + def batch_size(self): + return self.config_dict_search.get("n_batch") + @property def config_dict(self): return conf.instance["non_linear"]["nest"][self.__class__.__name__] From 98452bc5a7e69484af92f5343e28af66654908b3 Mon Sep 17 00:00:00 2001 From: GitHub Actions bot Date: Sun, 21 Dec 2025 19:50:04 +0000 Subject: [PATCH 26/35] 'Updated version in __init__ to 2025.12.21.1 --- autofit/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autofit/__init__.py b/autofit/__init__.py index 225e7e0df..f5947118d 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -141,4 +141,4 @@ def save_abc(pickler, obj): pickle._Pickler.save_type(pickler, obj) -__version__ = "2025.12.15.2" \ No newline at end of file +__version__ = "2025.12.21.1" \ No newline at end of file From d0d03baa6d33bb6eb2887170f3b0c12f84ce30f6 Mon Sep 17 00:00:00 2001 From: GitHub Actions bot Date: Wed, 21 Jan 2026 19:10:00 +0000 Subject: [PATCH 27/35] 'Updated version in __init__ to 2026.1.21.3 --- autofit/__init__.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/autofit/__init__.py b/autofit/__init__.py index f5947118d..4b5783e7e 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -141,4 +141,4 @@ def save_abc(pickler, obj): pickle._Pickler.save_type(pickler, obj) -__version__ = "2025.12.21.1" \ No newline at end of file +__version__ = "2026.1.21.3" \ No newline at end of file From 1ae3a9c20631054f5a88bad65bb60b3553414ccc Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 28 Jan 2026 17:03:09 +0000 Subject: [PATCH 28/35] tested use of samples summary --- autofit/aggregator/search_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autofit/aggregator/search_output.py b/autofit/aggregator/search_output.py index 0a8facd85..d4484dbdf 100644 --- a/autofit/aggregator/search_output.py +++ b/autofit/aggregator/search_output.py @@ -250,7 +250,7 @@ def instance(self): try: return self.samples.max_log_likelihood() except (AttributeError, NotImplementedError): - return None + return self.samples_summary.instance @property def id(self) -> str: From fda58cf69d2be9cbd402560df71bd0113db108cb Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 30 Jan 2026 18:18:28 +0000 Subject: [PATCH 29/35] CPU JAX uses batch size 1 --- autofit/non_linear/fitness.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index a41fc7614..1eb5c9f91 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -44,7 +44,6 @@ def __init__( use_jax_vmap : bool = False, batch_size : Optional[int] = None, iterations_per_quick_update: Optional[int] = None, - xp=np, ): """ Interfaces with any non-linear search to fit the model to the data and return a log likelihood via @@ -123,6 +122,16 @@ def __init__( if self.use_jax_vmap: self._call = self._vmap + if analysis._use_jax: + + import jax + + if jax.default_backend() == "cpu": + + logger.info("JAX using CPU backend, batch size set to 1 
which will improve performance.") + + batch_size = 1 + self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update self.quick_update_max_lh_parameters = None From 074b97b5b936f90ea17fd4a67f0cba7076e59a23 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 30 Jan 2026 18:24:29 +0000 Subject: [PATCH 30/35] update xp uses in fitness --- autofit/non_linear/fitness.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 1eb5c9f91..fca95406d 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -108,7 +108,8 @@ def __init__( self.model = model self.paths = paths self.fom_is_log_likelihood = fom_is_log_likelihood - self.resample_figure_of_merit = resample_figure_of_merit or -xp.inf + + self.resample_figure_of_merit = resample_figure_of_merit or -self._xp.inf self.convert_to_chi_squared = convert_to_chi_squared self.store_history = store_history @@ -135,7 +136,7 @@ def __init__( self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update self.quick_update_max_lh_parameters = None - self.quick_update_max_lh = -xp.inf + self.quick_update_max_lh = -self._xp.inf self.quick_update_count = 0 if self.paths is not None: From 6cedcfcc09d42a21c3857007b76a5824c318bc2f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 30 Jan 2026 20:46:12 +0000 Subject: [PATCH 31/35] disable vmap for CPU --- autofit/non_linear/fitness.py | 18 +++++++++--------- .../non_linear/search/nest/nautilus/search.py | 9 +++++++-- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index fca95406d..0da6185a7 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -116,22 +116,22 @@ def __init__( self.parameters_history_list = [] self.log_likelihood_history_list = [] - self.use_jax_vmap = use_jax_vmap - - self._call = self.call - - if 
self.use_jax_vmap: - self._call = self._vmap - if analysis._use_jax: import jax if jax.default_backend() == "cpu": - logger.info("JAX using CPU backend, batch size set to 1 which will improve performance.") + logger.info("JAX using CPU backend, vmap disabled for faster performance.") - batch_size = 1 + use_jax_vmap = False + + self.use_jax_vmap = use_jax_vmap + + self._call = self.call + + if self.use_jax_vmap: + self._call = self._vmap self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index fa5363cc3..8a584706f 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -137,7 +137,7 @@ def _fit(self, model: AbstractPriorModel, analysis): fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, iterations_per_quick_update=self.iterations_per_quick_update, - use_jax_vmap=True, + use_jax_vmap=False, batch_size=self.config_dict_search["n_batch"], ) @@ -225,13 +225,18 @@ def fit_x1_cpu(self, fitness, model, analysis): except KeyError: pass + if fitness.use_jax_vmap: + vectorized = True + else: + vectorized = False + search_internal = self.sampler_cls( prior=PriorVectorized(model=model), likelihood=fitness.call_wrap, n_dim=model.prior_count, filepath=self.checkpoint_file, pool=None, - vectorized=True, + vectorized=vectorized, **config_dict, ) From a51ba48f846544fee72f56d24b90eab208255449 Mon Sep 17 00:00:00 2001 From: James Nightingale Date: Fri, 30 Jan 2026 20:52:12 +0000 Subject: [PATCH 32/35] Update autofit/non_linear/search/nest/nautilus/search.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- autofit/non_linear/search/nest/nautilus/search.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 
8a584706f..e9c0abad6 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -225,10 +225,7 @@ def fit_x1_cpu(self, fitness, model, analysis): except KeyError: pass - if fitness.use_jax_vmap: - vectorized = True - else: - vectorized = False + vectorized = fitness.use_jax_vmap search_internal = self.sampler_cls( prior=PriorVectorized(model=model), From 7a254da066d03909def6d538983d00dca1e0c98b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 30 Jan 2026 21:19:32 +0000 Subject: [PATCH 33/35] use_vmap_jax now manually input --- autofit/non_linear/fitness.py | 10 ---------- autofit/non_linear/search/nest/nautilus/search.py | 5 ++++- autofit/non_linear/settings.py | 3 +++ 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 0da6185a7..698430640 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -116,16 +116,6 @@ def __init__( self.parameters_history_list = [] self.log_likelihood_history_list = [] - if analysis._use_jax: - - import jax - - if jax.default_backend() == "cpu": - - logger.info("JAX using CPU backend, vmap disabled for faster performance.") - - use_jax_vmap = False - self.use_jax_vmap = use_jax_vmap self._call = self.call diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index e9c0abad6..dcf8cdeb0 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -41,6 +41,7 @@ def __init__( iterations_per_full_update: int = None, number_of_cores: int = None, session: Optional[sa.orm.Session] = None, + use_jax_vmap : bool = True, **kwargs ): """ @@ -90,6 +91,8 @@ def __init__( self.logger.debug("Creating Nautilus Search") + self.use_jax_vmap = use_jax_vmap + def _fit(self, model: AbstractPriorModel, analysis): """ Fit a model using the search and the Analysis class 
which contains the data and returns the log likelihood from @@ -137,7 +140,7 @@ def _fit(self, model: AbstractPriorModel, analysis): fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, iterations_per_quick_update=self.iterations_per_quick_update, - use_jax_vmap=False, + use_jax_vmap=self.use_jax_vmap, batch_size=self.config_dict_search["n_batch"], ) diff --git a/autofit/non_linear/settings.py b/autofit/non_linear/settings.py index a31ea873a..d0dc298aa 100644 --- a/autofit/non_linear/settings.py +++ b/autofit/non_linear/settings.py @@ -11,6 +11,7 @@ def __init__( number_of_cores: Optional[int] = 1, session: Optional[sa.orm.Session] = None, info: Optional[dict] = None, + use_jax_vmap: bool = True, ): """ Stores all the input settings that are used in search's and their `fit functions. @@ -40,6 +41,7 @@ def __init__( self.unique_tag = unique_tag self.number_of_cores = number_of_cores self.session = session + self.use_jax_vmap = use_jax_vmap self.info = info @@ -50,6 +52,7 @@ def search_dict(self): "unique_tag": self.unique_tag, "number_of_cores": self.number_of_cores, "session": self.session, + "use_jax_vmap": self.use_jax_vmap, } @property From ddad8283b2381b4967ea703c5816c2770ad49626 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Feb 2026 11:42:05 +0000 Subject: [PATCH 34/35] use analysis_xp in fitness --- autofit/non_linear/fitness.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 698430640..85b2be91b 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -134,10 +134,7 @@ def __init__( @property def _xp(self): - if self.analysis._use_jax: - import jax.numpy as jnp - return jnp - return np + return self.analysis._xp def call(self, parameters): """ From a39b5a9b39912143b79503f728db86952681a357 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 2 Apr 2026 19:06:47 +0100 Subject: [PATCH 35/35] add expanded unit tests for model 
mapping API 81 new tests covering collection composition, shared priors, vector mapping, tree navigation, assertions, subsetting, freezing, serialization and edge cases. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../mapper/test_model_mapping_expanded.py | 631 ++++++++++++++++++ 1 file changed, 631 insertions(+) create mode 100644 test_autofit/mapper/test_model_mapping_expanded.py diff --git a/test_autofit/mapper/test_model_mapping_expanded.py b/test_autofit/mapper/test_model_mapping_expanded.py new file mode 100644 index 000000000..8be6fd534 --- /dev/null +++ b/test_autofit/mapper/test_model_mapping_expanded.py @@ -0,0 +1,631 @@ +""" +Expanded tests for the model mapping API, covering gaps identified in: +- Collection composition and instance creation +- Shared (linked) priors across model types +- Direct use of instance_for_arguments with argument dicts +- Model tree navigation (object_for_path, path_for_prior, name_for_prior) +- Edge cases (empty models, deeply nested models, single-parameter models) +- Model subsetting (with_paths, without_paths) +- Freezing behavior +- Assertion checking +- from_instance round-trips +- mapper_from_prior_arguments and mapper_from_partial_prior_arguments +""" +import copy + +import numpy as np +import pytest + +import autofit as af +from autofit import exc +from autofit.mapper.prior.abstract import Prior + + +# --------------------------------------------------------------------------- +# Collection: composition, nesting, instance creation, iteration +# --------------------------------------------------------------------------- +class TestCollectionComposition: + def test_collection_from_dict(self): + model = af.Collection( + one=af.Model(af.m.MockClassx2), + two=af.Model(af.m.MockClassx2), + ) + assert model.prior_count == 4 + + def test_collection_from_list(self): + model = af.Collection([af.m.MockClassx2, af.m.MockClassx2]) + assert model.prior_count == 4 + + def test_collection_from_generator(self): + model = 
af.Collection(af.Model(af.m.MockClassx2) for _ in range(3)) + assert model.prior_count == 6 + + def test_nested_collection(self): + inner = af.Collection(a=af.m.MockClassx2) + outer = af.Collection(inner=inner, extra=af.m.MockClassx2) + assert outer.prior_count == 4 + + def test_deeply_nested_collection(self): + model = af.Collection( + level1=af.Collection( + level2=af.Collection( + leaf=af.m.MockClassx2, + ) + ) + ) + assert model.prior_count == 2 + + def test_collection_instance_attribute_access(self): + model = af.Collection(gaussian=af.m.MockClassx2, exp=af.m.MockClassx2) + instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0]) + assert instance.gaussian.one == 1.0 + assert instance.gaussian.two == 2.0 + assert instance.exp.one == 3.0 + assert instance.exp.two == 4.0 + + def test_collection_instance_index_access(self): + model = af.Collection([af.m.MockClassx2, af.m.MockClassx2]) + instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0]) + assert instance[0].one == 1.0 + assert instance[1].one == 3.0 + + def test_collection_len(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert len(model) == 2 + + def test_collection_contains(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert "a" in model + assert "c" not in model + + def test_collection_items(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + keys = [k for k, v in model.items()] + assert "a" in keys + assert "b" in keys + + def test_collection_getitem_string(self): + model = af.Collection(a=af.m.MockClassx2) + assert isinstance(model["a"], af.Model) + + def test_collection_append(self): + model = af.Collection() + model.append(af.m.MockClassx2) + model.append(af.m.MockClassx2) + assert model.prior_count == 4 + + def test_collection_mixed_model_and_fixed(self): + """Collection with one free model and one fixed instance.""" + model = af.Collection( + free=af.Model(af.m.MockClassx2), + ) + assert model.prior_count == 2 + + 
def test_empty_collection(self): + model = af.Collection() + assert model.prior_count == 0 + + +# --------------------------------------------------------------------------- +# Shared (linked) priors +# --------------------------------------------------------------------------- +class TestSharedPriors: + def test_link_within_model(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.prior_count == 1 + instance = model.instance_from_vector([5.0]) + assert instance.one == instance.two == 5.0 + + def test_link_across_collection_children(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + model.a.one = model.b.one # Link a.one to b.one + assert model.prior_count == 3 # 4 - 1 shared + + def test_linked_priors_same_value(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + model.a.one = model.b.one + instance = model.instance_from_vector([10.0, 20.0, 30.0]) + assert instance.a.one == instance.b.one + + def test_link_reduces_unique_prior_count(self): + model = af.Model(af.m.MockClassx2) + original_count = len(model.unique_prior_tuples) + model.one = model.two + assert len(model.unique_prior_tuples) == original_count - 1 + + def test_linked_prior_identity(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.one is model.two + + +# --------------------------------------------------------------------------- +# instance_for_arguments (direct argument dict usage) +# --------------------------------------------------------------------------- +class TestInstanceForArguments: + def test_model_instance_for_arguments(self): + model = af.Model(af.m.MockClassx2) + args = {model.one: 10.0, model.two: 20.0} + instance = model.instance_for_arguments(args) + assert instance.one == 10.0 + assert instance.two == 20.0 + + def test_collection_instance_for_arguments(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + args = {} + for name, prior in 
model.prior_tuples_ordered_by_id: + args[prior] = 1.0 + instance = model.instance_for_arguments(args) + assert instance.a.one == 1.0 + assert instance.b.two == 1.0 + + def test_shared_prior_in_arguments(self): + """When priors are linked, only one entry is needed in the arguments dict.""" + model = af.Model(af.m.MockClassx2) + model.one = model.two + shared_prior = model.one + args = {shared_prior: 42.0} + instance = model.instance_for_arguments(args) + assert instance.one == 42.0 + assert instance.two == 42.0 + + def test_missing_prior_raises(self): + model = af.Model(af.m.MockClassx2) + args = {model.one: 10.0} # missing model.two + with pytest.raises(KeyError): + model.instance_for_arguments(args) + + +# --------------------------------------------------------------------------- +# Vector and unit vector mapping +# --------------------------------------------------------------------------- +class TestVectorMapping: + def test_instance_from_vector_basic(self): + model = af.Model(af.m.MockClassx2) + instance = model.instance_from_vector([3.0, 4.0]) + assert instance.one == 3.0 + assert instance.two == 4.0 + + def test_vector_length_mismatch_raises(self): + model = af.Model(af.m.MockClassx2) + with pytest.raises(AssertionError): + model.instance_from_vector([1.0]) + + def test_unit_vector_length_mismatch_raises(self): + model = af.Model(af.m.MockClassx2) + with pytest.raises(AssertionError): + model.instance_from_unit_vector([0.5]) + + def test_vector_from_unit_vector(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + physical = model.vector_from_unit_vector([0.0, 1.0]) + assert physical[0] == pytest.approx(0.0, abs=1e-6) + assert physical[1] == pytest.approx(10.0, abs=1e-6) + + def test_instance_from_prior_medians(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + model.two = 
af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + instance = model.instance_from_prior_medians() + assert instance.one == pytest.approx(50.0) + assert instance.two == pytest.approx(50.0) + + def test_random_instance(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + instance = model.random_instance() + assert 0.0 <= instance.one <= 1.0 + assert 0.0 <= instance.two <= 1.0 + + +# --------------------------------------------------------------------------- +# Model tree navigation +# --------------------------------------------------------------------------- +class TestModelTreeNavigation: + def test_object_for_path_child_model(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + child = model.object_for_path(("g",)) + assert isinstance(child, af.Model) + + def test_object_for_path_prior(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + prior = model.object_for_path(("g", "one")) + assert isinstance(prior, Prior) + + def test_paths_matches_prior_count(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert len(model.paths) == model.prior_count + + def test_path_for_prior(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + prior = model.g.one + path = model.path_for_prior(prior) + assert path == ("g", "one") + + def test_name_for_prior(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + prior = model.g.one + name = model.name_for_prior(prior) + assert name == "g_one" + + def test_path_instance_tuples_for_class(self): + model = af.Collection(g=af.Model(af.m.MockClassx2)) + tuples = model.path_instance_tuples_for_class(Prior) + paths = [t[0] for t in tuples] + assert ("g", "one") in paths + assert ("g", "two") in paths + + def test_deeply_nested_path(self): + inner_model = af.Model(af.m.MockClassx2) + inner_collection = af.Collection(leaf=inner_model) + outer = 
af.Collection(branch=inner_collection) + + prior = outer.branch.leaf.one + path = outer.path_for_prior(prior) + assert path == ("branch", "leaf", "one") + + def test_direct_vs_recursive_prior_tuples(self): + model = af.Collection(a=af.m.MockClassx2) + assert len(model.direct_prior_tuples) == 0 # Collection has no direct priors + assert len(model.prior_tuples) == 2 # But has 2 recursive priors + + def test_direct_prior_model_tuples(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + assert len(model.direct_prior_model_tuples) == 2 + + +# --------------------------------------------------------------------------- +# instance_from_path_arguments and instance_from_prior_name_arguments +# --------------------------------------------------------------------------- +class TestPathAndNameArguments: + def test_instance_from_path_arguments(self): + model = af.Collection(g=af.m.MockClassx2) + instance = model.instance_from_path_arguments( + {("g", "one"): 10.0, ("g", "two"): 20.0} + ) + assert instance.g.one == 10.0 + assert instance.g.two == 20.0 + + def test_instance_from_prior_name_arguments(self): + model = af.Collection(g=af.m.MockClassx2) + instance = model.instance_from_prior_name_arguments( + {"g_one": 10.0, "g_two": 20.0} + ) + assert instance.g.one == 10.0 + assert instance.g.two == 20.0 + + +# --------------------------------------------------------------------------- +# Assertions +# --------------------------------------------------------------------------- +class TestAssertions: + def test_assertion_passes(self): + model = af.Model(af.m.MockClassx2) + model.add_assertion(model.one > model.two) + # one=10 > two=5 should pass + instance = model.instance_from_vector([10.0, 5.0]) + assert instance.one == 10.0 + + def test_assertion_fails(self): + model = af.Model(af.m.MockClassx2) + model.add_assertion(model.one > model.two) + with pytest.raises(exc.FitException): + model.instance_from_vector([1.0, 10.0]) + + def test_ignore_assertions(self): + 
model = af.Model(af.m.MockClassx2) + model.add_assertion(model.one > model.two) + # Should not raise even though assertion fails + instance = model.instance_from_vector([1.0, 10.0], ignore_assertions=True) + assert instance.one == 1.0 + + def test_multiple_assertions(self): + model = af.Model(af.m.MockClassx4) + model.add_assertion(model.one > model.two) + model.add_assertion(model.three > model.four) + # Both pass + instance = model.instance_from_vector([10.0, 5.0, 10.0, 5.0]) + assert instance.one == 10.0 + # First fails + with pytest.raises(exc.FitException): + model.instance_from_vector([1.0, 10.0, 10.0, 5.0]) + + def test_true_assertion_ignored(self): + """Adding True as an assertion should be a no-op.""" + model = af.Model(af.m.MockClassx2) + model.add_assertion(True) + assert len(model.assertions) == 0 + + +# --------------------------------------------------------------------------- +# Model subsetting (with_paths, without_paths) +# --------------------------------------------------------------------------- +class TestModelSubsetting: + def test_with_paths_single_child(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + subset = model.with_paths([("a",)]) + assert subset.prior_count == 2 + + def test_without_paths_single_child(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + subset = model.without_paths([("a",)]) + assert subset.prior_count == 2 + + def test_with_paths_specific_prior(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + subset = model.with_paths([("a", "one")]) + assert subset.prior_count == 1 + + def test_with_prefix(self): + model = af.Collection(ab_one=af.m.MockClassx2, cd_two=af.m.MockClassx2) + subset = model.with_prefix("ab") + assert subset.prior_count == 2 + + +# --------------------------------------------------------------------------- +# Freezing behavior +# --------------------------------------------------------------------------- +class TestFreezing: + def 
test_freeze_prevents_modification(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + with pytest.raises(AssertionError): + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + + def test_unfreeze_allows_modification(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + model.unfreeze() + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) + assert isinstance(model.one, af.UniformPrior) + + def test_frozen_model_still_creates_instances(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + instance = model.instance_from_vector([1.0, 2.0]) + assert instance.one == 1.0 + + def test_freeze_propagates_to_children(self): + model = af.Collection(a=af.m.MockClassx2) + model.freeze() + with pytest.raises(AssertionError): + model.a.one = 1.0 + + def test_cached_results_consistent(self): + model = af.Model(af.m.MockClassx2) + model.freeze() + result1 = model.prior_tuples_ordered_by_id + result2 = model.prior_tuples_ordered_by_id + assert result1 == result2 + + +# --------------------------------------------------------------------------- +# mapper_from_prior_arguments and related +# --------------------------------------------------------------------------- +class TestMapperFromPriorArguments: + def test_replace_all_priors(self): + model = af.Model(af.m.MockClassx2) + new_one = af.GaussianPrior(mean=0.0, sigma=1.0) + new_two = af.GaussianPrior(mean=5.0, sigma=2.0) + new_model = model.mapper_from_prior_arguments( + {model.one: new_one, model.two: new_two} + ) + assert new_model.prior_count == 2 + assert isinstance(new_model.one, af.GaussianPrior) + + def test_partial_replacement(self): + model = af.Model(af.m.MockClassx2) + new_one = af.GaussianPrior(mean=0.0, sigma=1.0) + new_model = model.mapper_from_partial_prior_arguments( + {model.one: new_one} + ) + assert new_model.prior_count == 2 + assert isinstance(new_model.one, af.GaussianPrior) + # two should retain its original prior type + assert new_model.two is not None + + def 
test_fix_via_mapper_from_prior_arguments(self): + """Replacing a prior with a float effectively fixes that parameter.""" + model = af.Model(af.m.MockClassx2) + new_model = model.mapper_from_prior_arguments( + {model.one: 5.0, model.two: model.two} + ) + assert new_model.prior_count == 1 + + def test_with_limits(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0) + new_model = model.with_limits([(10.0, 20.0), (30.0, 40.0)]) + assert new_model.prior_count == 2 + + +# --------------------------------------------------------------------------- +# from_instance round trips +# --------------------------------------------------------------------------- +class TestFromInstance: + def test_from_simple_instance(self): + instance = af.m.MockClassx2(1.0, 2.0) + model = af.AbstractPriorModel.from_instance(instance) + assert model.prior_count == 0 + + def test_from_instance_as_model(self): + instance = af.m.MockClassx2(1.0, 2.0) + model = af.AbstractPriorModel.from_instance(instance) + free_model = model.as_model() + assert free_model.prior_count == 2 + + def test_from_instance_with_model_classes(self): + instance = af.m.MockClassx2(1.0, 2.0) + model = af.AbstractPriorModel.from_instance( + instance, model_classes=(af.m.MockClassx2,) + ) + assert model.prior_count == 2 + + def test_from_list_instance(self): + instance_list = [af.m.MockClassx2(1.0, 2.0), af.m.MockClassx2(3.0, 4.0)] + model = af.AbstractPriorModel.from_instance(instance_list) + assert model.prior_count == 0 + + def test_from_dict_instance(self): + instance_dict = { + "one": af.m.MockClassx2(1.0, 2.0), + "two": af.m.MockClassx2(3.0, 4.0), + } + model = af.AbstractPriorModel.from_instance(instance_dict) + assert model.prior_count == 0 + + +# --------------------------------------------------------------------------- +# Fixing parameters and Constant values +# 
--------------------------------------------------------------------------- +class TestFixedParameters: + def test_fix_reduces_prior_count(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + assert model.prior_count == 1 + + def test_fixed_value_in_instance(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + instance = model.instance_from_vector([10.0]) + assert instance.one == 5.0 + assert instance.two == 10.0 + + def test_fix_all_parameters(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + model.two = 10.0 + assert model.prior_count == 0 + instance = model.instance_from_vector([]) + assert instance.one == 5.0 + assert instance.two == 10.0 + + +# --------------------------------------------------------------------------- +# take_attributes (prior passing) +# --------------------------------------------------------------------------- +class TestTakeAttributes: + def test_take_from_instance(self): + model = af.Model(af.m.MockClassx2) + source = af.m.MockClassx2(10.0, 20.0) + model.take_attributes(source) + assert model.prior_count == 0 + + def test_take_from_model(self): + """Taking attributes from another model copies priors.""" + source_model = af.Model(af.m.MockClassx2) + source_model.one = af.GaussianPrior(mean=5.0, sigma=1.0) + source_model.two = af.GaussianPrior(mean=10.0, sigma=2.0) + + target_model = af.Model(af.m.MockClassx2) + target_model.take_attributes(source_model) + assert isinstance(target_model.one, af.GaussianPrior) + + +# --------------------------------------------------------------------------- +# Serialization (dict / from_dict) +# --------------------------------------------------------------------------- +class TestSerialization: + def test_model_dict_roundtrip(self): + model = af.Model(af.m.MockClassx2) + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == model.prior_count + + def test_collection_dict_roundtrip(self): + model = af.Collection(a=af.m.MockClassx2, 
b=af.m.MockClassx2) + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == model.prior_count + + def test_fixed_parameter_survives_roundtrip(self): + model = af.Model(af.m.MockClassx2) + model.one = 5.0 + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == 1 + + def test_linked_prior_survives_roundtrip(self): + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.prior_count == 1 + d = model.dict() + loaded = af.AbstractPriorModel.from_dict(d) + assert loaded.prior_count == 1 + + +# --------------------------------------------------------------------------- +# Log prior computation +# --------------------------------------------------------------------------- +class TestLogPrior: + def test_log_prior_within_bounds(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + log_priors = model.log_prior_list_from_vector([5.0, 5.0]) + assert all(np.isfinite(lp) for lp in log_priors) + + def test_log_prior_outside_bounds(self): + model = af.Model(af.m.MockClassx2) + model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0) + log_priors = model.log_prior_list_from_vector([15.0, 5.0]) + # Out-of-bounds value should have a lower (or zero) log prior than in-bounds + assert log_priors[0] <= log_priors[1] + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- +class TestEdgeCases: + def test_single_parameter_model(self): + """A model with a single free parameter using explicit prior.""" + model = af.Model(af.m.MockClassx2) + model.two = 5.0 # Fix one parameter + assert model.prior_count == 1 + instance = model.instance_from_vector([42.0]) + assert instance.one == 42.0 + 
assert instance.two == 5.0 + + def test_model_copy_preserves_priors(self): + model = af.Model(af.m.MockClassx2) + copied = model.copy() + assert copied.prior_count == model.prior_count + # Priors are independent copies (different objects) + assert copied.one is not model.one + + def test_model_copy_linked_priors_independent(self): + """Copying a model with linked priors preserves the link in the copy.""" + model = af.Model(af.m.MockClassx2) + model.one = model.two + assert model.prior_count == 1 + copied = model.copy() + assert copied.prior_count == 1 + # The copy's internal link is preserved + assert copied.one is copied.two + + def test_prior_ordering_is_deterministic(self): + """prior_tuples_ordered_by_id should be stable across calls.""" + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2) + order1 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id] + order2 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id] + assert order1 == order2 + + def test_prior_count_equals_total_free_parameters(self): + model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx4) + assert model.prior_count == model.total_free_parameters + + def test_has_model(self): + model = af.Collection(a=af.Model(af.m.MockClassx2)) + assert model.has_model(af.m.MockClassx2) + assert not model.has_model(af.m.MockClassx4) + + def test_has_instance(self): + model = af.Model(af.m.MockClassx2) + assert model.has_instance(Prior) + assert not model.has_instance(af.m.MockClassx4)