From bff1343cd05adb62b2c8b1b1e82be338fa65d206 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:27:53 +0000 Subject: [PATCH 01/15] log_prior_from_value in TruncatedNormalMessae --- .../summary/aggregate_csv/column.py | 30 ++++++++++++++----- autofit/mapper/prior_model/abstract.py | 3 +- autofit/messages/truncated_normal.py | 8 ++--- autofit/non_linear/fitness.py | 2 +- autofit/non_linear/search/mle/bfgs/search.py | 2 +- 5 files changed, 30 insertions(+), 15 deletions(-) diff --git a/autofit/aggregator/summary/aggregate_csv/column.py b/autofit/aggregator/summary/aggregate_csv/column.py index 9635c1bde..77eedafde 100644 --- a/autofit/aggregator/summary/aggregate_csv/column.py +++ b/autofit/aggregator/summary/aggregate_csv/column.py @@ -64,20 +64,34 @@ def value(self, row: "Row"): result = {} if ValueType.Median in self.value_types: - result[""] = row.median_pdf_sample_kwargs[self.path] + try: + result[""] = row.median_pdf_sample_kwargs[self.path] + except KeyError: + result[""] = None if ValueType.MaxLogLikelihood in self.value_types: - result["max_lh"] = row.max_likelihood_kwargs[self.path] + try: + result["max_lh"] = row.max_likelihood_kwargs[self.path] + except KeyError: + result["max_lh"] = None if ValueType.ValuesAt1Sigma in self.value_types: - lower, upper = row.values_at_sigma_1_kwargs[self.path] - result["lower_1_sigma"] = lower - result["upper_1_sigma"] = upper + try: + lower, upper = row.values_at_sigma_1_kwargs[self.path] + result["lower_1_sigma"] = lower + result["upper_1_sigma"] = upper + except KeyError: + result["lower_1_sigma"] = None + result["upper_1_sigma"] = None if ValueType.ValuesAt3Sigma in self.value_types: - lower, upper = row.values_at_sigma_3_kwargs[self.path] - result["lower_3_sigma"] = lower - result["upper_3_sigma"] = upper + try: + lower, upper = row.values_at_sigma_3_kwargs[self.path] + result["lower_3_sigma"] = lower + result["upper_3_sigma"] = upper + except KeyError: + result["lower_3_sigma"] = None + 
result["upper_3_sigma"] = None return result diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index b41ea3c58..6e7778ad1 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1078,6 +1078,7 @@ def instance_from_prior_medians(self, ignore_assertions : bool = False): def log_prior_list_from_vector( self, vector: [float], + xp=np, ): """ Compute the log priors of every parameter in a vector, using the Prior of every parameter. @@ -1094,7 +1095,7 @@ def log_prior_list_from_vector( return list( map( lambda prior_tuple, value: prior_tuple.prior.log_prior_from_value( - value=value + value=value, xp=np ), self.prior_tuples_ordered_by_id, vector, diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 3d1614584..7329147a4 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -422,7 +422,7 @@ def value_for(self, unit: float) -> float: x_standard = norm.ppf(truncated_cdf) return self.mean + self.sigma * x_standard - def log_prior_from_value(self, value: float) -> float: + def log_prior_from_value(self, value: float, xp=np) -> float: """ Compute the log prior probability of a given physical value under this truncated Gaussian prior. 
@@ -446,11 +446,11 @@ def log_prior_from_value(self, value: float) -> float: Z = norm.cdf(b) - norm.cdf(a) z = (value -self.mean) / self.sigma - log_pdf = -0.5 * z ** 2 - np.log(self.sigma) - 0.5 * np.log(2 * np.pi) - log_trunc_pdf = log_pdf - np.log(Z) + log_pdf = -0.5 * z ** 2 - xp.log(self.sigma) - 0.5 * xp.log(2 * xp.pi) + log_trunc_pdf = log_pdf - xp.log(Z) in_bounds = (self.lower_limit <= value) & (value <= self.upper_limit) - return np.where(in_bounds, log_trunc_pdf, -np.inf) + return xp.where(in_bounds, log_trunc_pdf, -xp.inf) def __str__(self): """ diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index be0e63403..be9d916c6 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -179,7 +179,7 @@ def call(self, parameters): figure_of_merit = log_likelihood else: # Ensure prior list is compatible with JAX (must return a JAX array, not list) - log_prior_array = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters)) + log_prior_array = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) figure_of_merit = log_likelihood + self._xp.sum(log_prior_array) # Convert to chi-squared scale if requested diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 0ef3a5b34..7b5c389e5 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -145,7 +145,7 @@ def _fit( config_dict_options["maxiter"] = iterations search_internal = optimize.minimize( - fun=fitness.call_wrap, + fun=fitness._jit, x0=x0, method=self.method, options=config_dict_options, From be2dbd0cc49b143db3c53c2fdd6366f50722e153 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:29:04 +0000 Subject: [PATCH 02/15] log_prior_from_value in UniformPrior --- autofit/mapper/prior/uniform.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autofit/mapper/prior/uniform.py 
b/autofit/mapper/prior/uniform.py index 08baed5bb..af04f23d5 100644 --- a/autofit/mapper/prior/uniform.py +++ b/autofit/mapper/prior/uniform.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Optional, Tuple from autofit.messages.normal import UniformNormalMessage @@ -128,7 +129,7 @@ def value_for(self, unit: float) -> float: round(super().value_for(unit), 14) ) - def log_prior_from_value(self, value): + def log_prior_from_value(self, value, xp=np): """ Returns the log prior of a physical value, so the log likelihood of a model evaluation can be converted to a posterior as log_prior + log_likelihood. From a00bcff1cf270beee4d429179db16a4f0976bbe3 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:30:50 +0000 Subject: [PATCH 03/15] remaining def log_prior_from_value --- autofit/mapper/prior/log_gaussian.py | 2 +- autofit/mapper/prior/log_uniform.py | 2 +- autofit/messages/normal.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/autofit/mapper/prior/log_gaussian.py b/autofit/mapper/prior/log_gaussian.py index 6fa458950..d8f2a85b7 100644 --- a/autofit/mapper/prior/log_gaussian.py +++ b/autofit/mapper/prior/log_gaussian.py @@ -133,7 +133,7 @@ def value_for(self, unit: float) -> float: def parameter_string(self) -> str: return f"mean = {self.mean}, sigma = {self.sigma}" - def log_prior_from_value(self, value): + def log_prior_from_value(self, value, xp=np): if value <= 0: return float("-inf") diff --git a/autofit/mapper/prior/log_uniform.py b/autofit/mapper/prior/log_uniform.py index 63a9063b2..434af5b08 100644 --- a/autofit/mapper/prior/log_uniform.py +++ b/autofit/mapper/prior/log_uniform.py @@ -110,7 +110,7 @@ def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogUniformPrior __identifier_fields__ = ("lower_limit", "upper_limit") - def log_prior_from_value(self, value) -> float: + def log_prior_from_value(self, value, xp=np) -> float: """ Returns the log prior of a physical value, so the log likelihood of a 
model evaluation can be converted to a posterior as log_prior + log_likelihood. diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 9ff12ef3a..dff47e773 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -401,7 +401,7 @@ def value_for(self, unit: float) -> float: return self.mean + (self.sigma * np.sqrt(2) * inv) - def log_prior_from_value(self, value: float) -> float: + def log_prior_from_value(self, value: float, xp=np) -> float: """ Compute the log prior probability of a given physical value under this Gaussian prior. From 4874c6575137ccb37080e53fbb733adb0a22598c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:46:05 +0000 Subject: [PATCH 04/15] fix xp pass --- autofit/mapper/prior_model/abstract.py | 2 +- autofit/messages/truncated_normal.py | 19 ++++++++++++++++--- autofit/non_linear/search/mle/bfgs/search.py | 3 ++- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 6e7778ad1..5310315a5 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1095,7 +1095,7 @@ def log_prior_list_from_vector( return list( map( lambda prior_tuple, value: prior_tuple.prior.log_prior_from_value( - value=value, xp=np + value=value, xp=xp ), self.prior_tuples_ordered_by_id, vector, diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 7329147a4..07a34464c 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -439,17 +439,30 @@ def log_prior_from_value(self, value: float, xp=np) -> float: ------- The log prior probability of the given value, or -inf if outside truncation bounds. 
""" - from scipy.stats import norm + if xp.__name__.startswith("jax"): + import jax.scipy.stats as jstats + norm = jstats.norm + else: + from scipy.stats import norm + + # Normalization term (truncation) a = (self.lower_limit - self.mean) / self.sigma b = (self.upper_limit - self.mean) / self.sigma Z = norm.cdf(b) - norm.cdf(a) - z = (value -self.mean) / self.sigma - log_pdf = -0.5 * z ** 2 - xp.log(self.sigma) - 0.5 * xp.log(2 * xp.pi) + # Log pdf + z = (value - self.mean) / self.sigma + log_pdf = ( + -0.5 * z ** 2 + - xp.log(self.sigma) + - 0.5 * xp.log(2.0 * xp.pi) + ) log_trunc_pdf = log_pdf - xp.log(Z) + # Truncation mask (must be xp.where for JAX) in_bounds = (self.lower_limit <= value) & (value <= self.upper_limit) + return xp.where(in_bounds, log_trunc_pdf, -xp.inf) def __str__(self): diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 7b5c389e5..9804f83b9 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -92,6 +92,7 @@ def _fit( fom_is_log_likelihood=False, resample_figure_of_merit=-np.inf, convert_to_chi_squared=True, + use_jax_vmap=True, store_history=self.should_plot_start_point ) @@ -145,7 +146,7 @@ def _fit( config_dict_options["maxiter"] = iterations search_internal = optimize.minimize( - fun=fitness._jit, + fun=fitness.call_wrap, x0=x0, method=self.method, options=config_dict_options, From 0629f55376fd08b78ac31b350c8397df6b9436b4 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 12:58:36 +0000 Subject: [PATCH 05/15] Put xp code throughout messages --- autofit/mapper/prior/arithmetic/assertion.py | 4 +++- autofit/mapper/prior/arithmetic/compound.py | 10 ++++++++++ autofit/mapper/prior/gaussian.py | 3 +++ autofit/mapper/prior_model/abstract.py | 6 +++++- autofit/mapper/prior_model/array.py | 1 + autofit/mapper/prior_model/collection.py | 1 + autofit/mapper/prior_model/prior_model.py | 1 + autofit/messages/normal.py | 
17 +++++++++++++++-- autofit/non_linear/fitness.py | 2 +- autofit/non_linear/search/mle/bfgs/search.py | 3 +++ 10 files changed, 43 insertions(+), 5 deletions(-) diff --git a/autofit/mapper/prior/arithmetic/assertion.py b/autofit/mapper/prior/arithmetic/assertion.py index 2e7315d2b..5d9315e43 100644 --- a/autofit/mapper/prior/arithmetic/assertion.py +++ b/autofit/mapper/prior/arithmetic/assertion.py @@ -24,7 +24,7 @@ def __le__(self, other): class GreaterThanLessThanAssertion(ComparisonAssertion): - def _instance_for_arguments(self, arguments, ignore_assertions=False): + def _instance_for_arguments(self, arguments, ignore_assertions=False, xp=np): """ Assert that the value in the dictionary associated with the lower prior is lower than the value associated with the greater prior. @@ -55,6 +55,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): """ Assert that the value in the dictionary associated with the lower @@ -90,6 +91,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.assertion_1.instance_for_arguments( arguments, diff --git a/autofit/mapper/prior/arithmetic/compound.py b/autofit/mapper/prior/arithmetic/compound.py index 4951340a1..ce9d1b910 100644 --- a/autofit/mapper/prior/arithmetic/compound.py +++ b/autofit/mapper/prior/arithmetic/compound.py @@ -212,6 +212,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -237,6 +238,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -256,6 +258,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -275,6 +278,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -294,6 +298,7 @@ def _instance_for_arguments( self, 
arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -313,6 +318,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return self.left_for_arguments( arguments, @@ -396,6 +402,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return -self.prior.instance_for_arguments( arguments, @@ -412,6 +419,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return abs( self.prior.instance_for_arguments( @@ -430,6 +438,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return np.log( self.prior.instance_for_arguments( @@ -448,6 +457,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): return np.log10( self.prior.instance_for_arguments( diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index a4bcc5bb6..0ee8e568d 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -1,3 +1,4 @@ +import numpy as np from typing import Optional from autofit.messages.normal import NormalMessage @@ -13,6 +14,7 @@ def __init__( mean: float, sigma: float, id_: Optional[int] = None, + _xp=np ): """ A Gaussian prior defined by a normal distribution. 
@@ -50,6 +52,7 @@ def __init__( message=NormalMessage( mean=mean, sigma=sigma, + xp=_xp ), id_=id_, ) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 5310315a5..8c134687e 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -745,7 +745,7 @@ def physical_values_from_prior_medians(self): """ return self.vector_from_unit_vector([0.5] * len(self.unique_prior_tuples)) - def instance_from_vector(self, vector, ignore_assertions: bool = False): + def instance_from_vector(self, vector, ignore_assertions: bool = False, xp=np): """ Returns a ModelInstance, which has an attribute and class instance corresponding to every `Model` attributed to this instance. @@ -778,6 +778,7 @@ def instance_from_vector(self, vector, ignore_assertions: bool = False): return self.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) def has(self, cls: Union[Type, Tuple[Type, ...]]) -> bool: @@ -1309,6 +1310,7 @@ def _instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ): raise NotImplementedError() @@ -1316,6 +1318,7 @@ def instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ): """ Returns an instance of the model for a set of arguments @@ -1339,6 +1342,7 @@ def instance_for_arguments( return self._instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xpx=p ) def path_for_name(self, name: str) -> Tuple[str, ...]: diff --git a/autofit/mapper/prior_model/array.py b/autofit/mapper/prior_model/array.py index 7952b2743..317680818 100644 --- a/autofit/mapper/prior_model/array.py +++ b/autofit/mapper/prior_model/array.py @@ -57,6 +57,7 @@ def _instance_for_arguments( self, arguments: Dict[Prior, float], ignore_assertions: bool = False, + xp=np, ) -> np.ndarray: """ Create an array where the prior at each index is replaced with the diff --git 
a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py index 5d39dcdcd..36f7a14d8 100644 --- a/autofit/mapper/prior_model/collection.py +++ b/autofit/mapper/prior_model/collection.py @@ -208,6 +208,7 @@ def _instance_for_arguments( self, arguments, ignore_assertions=False, + xp=np, ): """ Parameters diff --git a/autofit/mapper/prior_model/prior_model.py b/autofit/mapper/prior_model/prior_model.py index cbf1cb285..778a15fdc 100644 --- a/autofit/mapper/prior_model/prior_model.py +++ b/autofit/mapper/prior_model/prior_model.py @@ -459,6 +459,7 @@ def _instance_for_arguments( self, arguments: {ModelObject: object}, ignore_assertions=False, + xp=np, ): """ Returns an instance of the associated class for a set of arguments diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index dff47e773..248e66c6f 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -23,6 +23,19 @@ def is_nan(value): is_nan_ = is_nan_.all() return is_nan_ +def assert_sigma_non_negative(sigma, xp=np): + sigma_arr = xp.asarray(sigma) + is_negative = xp.any(sigma_arr < 0) + + # Convert to Python bool safely: + try: + flag = bool(is_negative) + except Exception: + # JAX tracers need explicit .item() + flag = bool(is_negative.item()) + + if flag: + raise exc.MessageException("Sigma cannot be negative") class NormalMessage(AbstractMessage): @cached_property @@ -52,6 +65,7 @@ def __init__( sigma : Union[float, np.ndarray], log_norm : Optional[float] = 0.0, id_ : Optional[Hashable] = None, + xp=np ): """ A Gaussian (Normal) message representing a probability distribution over a continuous variable. @@ -73,8 +87,7 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. 
""" - if (np.array(sigma) < 0).any(): - raise exc.MessageException("Sigma cannot be negative") + assert_sigma_non_negative(sigma, xp=xp) super().__init__( mean, diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index be9d916c6..f0de611de 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -317,7 +317,7 @@ def manage_quick_update(self, parameters, log_likelihood): logger.info("Performing quick update of maximum log likelihood fit image and model.results") - instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters) + instance = self.model.instance_from_vector(vector=self.quick_update_max_lh_parameters, xp=self._xp) try: self.analysis.perform_quick_update(self.paths, instance) diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 9804f83b9..32acb089a 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -227,6 +227,9 @@ def samples_via_internal_from( weight_list = len(log_likelihood_list) * [1.0] + print(log_likelihood_list) + print(parameter_lists) + sample_list = Sample.from_lists( model=model, parameter_lists=parameter_lists, From 15d08785799ca491c257343098a8e8b672e158f8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 13:56:45 +0000 Subject: [PATCH 06/15] lots of hjorrible hacks --- autofit/mapper/prior/arithmetic/assertion.py | 1 + autofit/mapper/prior/gaussian.py | 3 +- autofit/mapper/prior_model/abstract.py | 2 +- autofit/mapper/prior_model/collection.py | 3 ++ autofit/mapper/prior_model/prior_model.py | 33 ++------------------ autofit/messages/normal.py | 16 +++++++++- autofit/non_linear/fitness.py | 8 ++--- autofit/non_linear/search/mle/bfgs/search.py | 7 ++--- 8 files changed, 31 insertions(+), 42 deletions(-) diff --git a/autofit/mapper/prior/arithmetic/assertion.py b/autofit/mapper/prior/arithmetic/assertion.py index 5d9315e43..b23a8c7a4 100644 
--- a/autofit/mapper/prior/arithmetic/assertion.py +++ b/autofit/mapper/prior/arithmetic/assertion.py @@ -1,4 +1,5 @@ from abc import ABC +import numpy as np from typing import Optional, Dict from autofit.mapper.prior.arithmetic.compound import CompoundPrior, Compound diff --git a/autofit/mapper/prior/gaussian.py b/autofit/mapper/prior/gaussian.py index 0ee8e568d..fe74b6016 100644 --- a/autofit/mapper/prior/gaussian.py +++ b/autofit/mapper/prior/gaussian.py @@ -14,7 +14,6 @@ def __init__( mean: float, sigma: float, id_: Optional[int] = None, - _xp=np ): """ A Gaussian prior defined by a normal distribution. @@ -48,11 +47,11 @@ def __init__( >>> prior = GaussianPrior(mean=1.0, sigma=2.0) >>> physical_value = prior.value_for(unit=0.5) # Returns ~1.0 (mean) """ + super().__init__( message=NormalMessage( mean=mean, sigma=sigma, - xp=_xp ), id_=id_, ) diff --git a/autofit/mapper/prior_model/abstract.py b/autofit/mapper/prior_model/abstract.py index 8c134687e..f6b7f3939 100644 --- a/autofit/mapper/prior_model/abstract.py +++ b/autofit/mapper/prior_model/abstract.py @@ -1342,7 +1342,7 @@ def instance_for_arguments( return self._instance_for_arguments( arguments, ignore_assertions=ignore_assertions, - xpx=p + xp=xp ) def path_for_name(self, name: str) -> Tuple[str, ...]: diff --git a/autofit/mapper/prior_model/collection.py b/autofit/mapper/prior_model/collection.py index 36f7a14d8..13073c10e 100644 --- a/autofit/mapper/prior_model/collection.py +++ b/autofit/mapper/prior_model/collection.py @@ -1,3 +1,5 @@ +import numpy as np + from collections.abc import Iterable from autofit.mapper.model import ModelInstance, assert_not_frozen @@ -229,6 +231,7 @@ def _instance_for_arguments( value = value.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) elif isinstance(value, Prior): value = arguments[value] diff --git a/autofit/mapper/prior_model/prior_model.py b/autofit/mapper/prior_model/prior_model.py index 778a15fdc..756899185 100644 --- 
a/autofit/mapper/prior_model/prior_model.py +++ b/autofit/mapper/prior_model/prior_model.py @@ -2,6 +2,7 @@ import copy import inspect import logging +import numpy as np import typing from typing import * @@ -420,36 +421,6 @@ def __getattr__(self, item): self.__getattribute__(item) - # def __getattr__(self, item): - # - # try: - # if ( - # "_" in item - # and item not in ("_is_frozen", "tuple_prior_tuples") - # and not item.startswith("_") - # ): - # return getattr( - # [v for k, v in self.tuple_prior_tuples if item.split("_")[0] == k][ - # 0 - # ], - # item, - # ) - # - # except IndexError: - # pass - # - # try: - # return getattr( - # self.instance_for_arguments( - # {prior: prior for prior in self.priors}, - # ), - # item, - # ) - # except (AttributeError, TypeError): - # pass - # - # self.__getattribute__(item) - @property def is_deferred_arguments(self): return len(self.direct_deferred_tuples) > 0 @@ -491,6 +462,7 @@ def _instance_for_arguments( ] = prior_model.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) prior_arguments = dict() @@ -533,6 +505,7 @@ def _instance_for_arguments( value = value.instance_for_arguments( arguments, ignore_assertions=ignore_assertions, + xp=xp ) elif isinstance(value, Constant): value = value.value diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 248e66c6f..c5c36f1d6 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -24,6 +24,7 @@ def is_nan(value): return is_nan_ def assert_sigma_non_negative(sigma, xp=np): + sigma_arr = xp.asarray(sigma) is_negative = xp.any(sigma_arr < 0) @@ -65,7 +66,6 @@ def __init__( sigma : Union[float, np.ndarray], log_norm : Optional[float] = 0.0, id_ : Optional[Hashable] = None, - xp=np ): """ A Gaussian (Normal) message representing a probability distribution over a continuous variable. @@ -87,6 +87,20 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. 
""" + if isinstance(mean, (np.ndarray, float, int)): + xp = np + else: + import jax.numpy as jnp + xp = jnp + + + print(type(mean)) + print(type(mean)) + print(xp) + print(xp) + print(xp) + print(xp) + assert_sigma_non_negative(sigma, xp=xp) super().__init__( diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index f0de611de..e3a7ab0b0 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -157,7 +157,7 @@ def call(self, parameters): The figure of merit returned to the non-linear search, which is either the log likelihood or log posterior. """ # Get instance from model - instance = self.model.instance_from_vector(vector=parameters) + instance = self.model.instance_from_vector(vector=parameters, xp=self._xp) if self._xp.__name__.startswith("jax"): @@ -227,7 +227,7 @@ def call_wrap(self, parameters): if self.fom_is_log_likelihood: log_likelihood = figure_of_merit else: - log_prior_list = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters)) + log_prior_list = self._xp.array(self.model.log_prior_list_from_vector(vector=parameters, xp=self._xp)) log_likelihood = figure_of_merit - self._xp.sum(log_prior_list) self.manage_quick_update(parameters=parameters, log_likelihood=log_likelihood) @@ -237,8 +237,8 @@ def call_wrap(self, parameters): if self.store_history: - self.parameters_history_list.append(parameters) - self.log_likelihood_history_list.append(log_likelihood) + self.parameters_history_list.append(np.array(parameters)) + self.log_likelihood_history_list.append(np.array(log_likelihood)) return figure_of_merit diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 32acb089a..42a849235 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -92,7 +92,6 @@ def _fit( fom_is_log_likelihood=False, resample_figure_of_merit=-np.inf, convert_to_chi_squared=True, - use_jax_vmap=True, 
store_history=self.should_plot_start_point ) @@ -146,7 +145,7 @@ def _fit( config_dict_options["maxiter"] = iterations search_internal = optimize.minimize( - fun=fitness.call_wrap, + fun=fitness._jit, x0=x0, method=self.method, options=config_dict_options, @@ -209,7 +208,6 @@ def samples_via_internal_from( x0 = search_internal.x total_iterations = search_internal.nit - if self.should_plot_start_point: parameter_lists = search_internal.parameters_history_list @@ -227,8 +225,9 @@ def samples_via_internal_from( weight_list = len(log_likelihood_list) * [1.0] - print(log_likelihood_list) print(parameter_lists) + print(log_likelihood_list) + print(log_prior_list) sample_list = Sample.from_lists( model=model, From 601c5f1fb74ff510b99782236901b52911315562 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 14:09:40 +0000 Subject: [PATCH 07/15] GaussianPrior works --- autofit/messages/abstract.py | 11 ++++++++-- autofit/messages/interface.py | 5 +++++ autofit/messages/normal.py | 41 +++++++++++++++++------------------ 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index d3b921108..2148c66e5 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -45,17 +45,24 @@ def __init__( lower_limit=-math.inf, upper_limit=math.inf, id_=None, + _xp=np ): + xp=_xp + self.lower_limit = float(lower_limit) self.upper_limit = float(upper_limit) self.id = next(self.ids) if id_ is None else id_ self.log_norm = log_norm - self._broadcast = np.broadcast(*parameters) + + if xp is np: + self._broadcast = np.broadcast(*parameters) + else: + self._broadcast = _xp.broadcast_arrays(*parameters) if self.shape: - self.parameters = tuple(np.asanyarray(p) for p in parameters) + self.parameters = tuple(xp.aarray(p) for p in parameters) else: self.parameters = tuple(parameters) diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 28d46d2ab..daff57479 100644 --- 
a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -23,6 +23,11 @@ def broadcast(self): @property def shape(self) -> Tuple[int, ...]: + + # JAX behaviour + if isinstance(self.broadcast, list): + return () + return self.broadcast.shape @property diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index c5c36f1d6..84bd3f2ff 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -25,18 +25,24 @@ def is_nan(value): def assert_sigma_non_negative(sigma, xp=np): - sigma_arr = xp.asarray(sigma) - is_negative = xp.any(sigma_arr < 0) - - # Convert to Python bool safely: - try: - flag = bool(is_negative) - except Exception: - # JAX tracers need explicit .item() - flag = bool(is_negative.item()) - - if flag: - raise exc.MessageException("Sigma cannot be negative") + is_negative = sigma < 0 + + if xp.__name__.startswith("jax"): + import jax + # JAX path: cannot convert to Python bool + # Raise using JAX control flow: + return jax.lax.cond( + is_negative, + lambda _: (_ for _ in ()).throw( + ValueError("Sigma cannot be negative") + ), + lambda _: None, + operand=None, + ) + else: + # NumPy path: normal boolean works + if bool(is_negative): + raise ValueError("Sigma cannot be negative") class NormalMessage(AbstractMessage): @cached_property @@ -93,21 +99,14 @@ def __init__( import jax.numpy as jnp xp = jnp - - print(type(mean)) - print(type(mean)) - print(xp) - print(xp) - print(xp) - print(xp) - - assert_sigma_non_negative(sigma, xp=xp) + # assert_sigma_non_negative(sigma, xp=xp) super().__init__( mean, sigma, log_norm=log_norm, id_=id_, + _xp=xp ) self.mean, self.sigma = self.parameters From e5dda2279d2d72e94ef9269e25b30f2a008580c9 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 14:24:43 +0000 Subject: [PATCH 08/15] various message funcitons now use xp --- .../declarative/factor/hierarchical.py | 7 +++++-- autofit/messages/abstract.py | 4 ++-- autofit/messages/beta.py | 10 +++++----- 
autofit/messages/composed_transform.py | 7 +++---- autofit/messages/fixed.py | 5 ++--- autofit/messages/gamma.py | 9 ++++----- autofit/messages/interface.py | 19 +++++++++---------- autofit/messages/normal.py | 18 ++++++++---------- autofit/messages/truncated_normal.py | 14 ++++++-------- 9 files changed, 44 insertions(+), 49 deletions(-) diff --git a/autofit/graphical/declarative/factor/hierarchical.py b/autofit/graphical/declarative/factor/hierarchical.py index f4dfec1d0..0ec5fde9b 100644 --- a/autofit/graphical/declarative/factor/hierarchical.py +++ b/autofit/graphical/declarative/factor/hierarchical.py @@ -8,6 +8,7 @@ from autofit.messages import NormalMessage from autofit.non_linear.paths.abstract import AbstractPaths from autofit.tools.namer import namer + from .abstract import AbstractModelFactor @@ -19,6 +20,7 @@ def __init__( distribution: Type[Prior], optimiser=None, name: Optional[str] = None, + use_jax : bool = False, **kwargs, ): """ @@ -70,6 +72,7 @@ def __init__( self._name = name or namer(self.__class__.__name__) self._factors = list() self.optimiser = optimiser + self.use_jax = use_jax @property def name(self): @@ -159,7 +162,7 @@ def __init__( """ self.distribution_model = distribution_model self.drawn_prior = drawn_prior - self.use_jax = use_jax + self.use_jax = distribution_model.use_jax prior_variable_dict = {prior.name: prior for prior in distribution_model.priors} @@ -188,7 +191,7 @@ def variable(self): return self.drawn_prior def log_likelihood_function(self, instance): - return instance.distribution_model.message(instance.drawn_prior) + return instance.distribution_model.message(instance.drawn_prior, xp=self._xp) @property def priors(self) -> Set[Prior]: diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index 2148c66e5..84d0739d4 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -366,8 +366,8 @@ def _get_mean_variance( ) return mean, variance - def __call__(self, x): - return 
np.sum(self.logpdf(x)) + def __call__(self, x, xp=np): + return xp.sum(self.logpdf(x, xp=xp)) def factor_jacobian( self, x: np.ndarray, _variables: Optional[Tuple[str]] = ("x",) diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 2f4848e05..460b8a084 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -184,8 +184,7 @@ def log_partition(self) -> np.ndarray: return betaln(*self.parameters) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Compute the natural parameters of the Beta distribution. @@ -196,7 +195,8 @@ def natural_parameters(self) -> np.ndarray: """ return self.calc_natural_parameters( self.alpha, - self.beta + self.beta, + xp=xp ) @staticmethod @@ -258,7 +258,7 @@ def invert_sufficient_statistics( return cls.calc_natural_parameters(a, b) @classmethod - def to_canonical_form(cls, x: np.ndarray) -> np.ndarray: + def to_canonical_form(cls, x: np.ndarray, xp=np) -> np.ndarray: """ Convert a value x in (0,1) to the canonical sufficient statistics for Beta. @@ -271,7 +271,7 @@ def to_canonical_form(cls, x: np.ndarray) -> np.ndarray: ------- Canonical sufficient statistics [log(x), log(1 - x)]. 
""" - return np.array([np.log(x), np.log1p(-x)]) + return xp.array([xp.log(x), xp.log1p(-x)]) @cached_property def mean(self) -> Union[np.ndarray, float]: diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index 959647009..c06983dee 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -157,8 +157,7 @@ def project( def kl(self, dist): return self.base_message.kl(dist.base_message) - @property - def natural_parameters(self): + def natural_parameters(self, xp=np) -> np.ndarray: return self.base_message.natural_parameters @inverse_transform @@ -245,8 +244,8 @@ def calc_log_base_measure(self, x) -> np.ndarray: return self.base_message.calc_log_base_measure(x) @transform - def to_canonical_form(self, x) -> np.ndarray: - return self.base_message.to_canonical_form(x) + def to_canonical_form(self, x, xp=np) -> np.ndarray: + return self.base_message.to_canonical_form(x, xp=xp) @property @inverse_transform diff --git a/autofit/messages/fixed.py b/autofit/messages/fixed.py index 496a62804..8bf990a48 100644 --- a/autofit/messages/fixed.py +++ b/autofit/messages/fixed.py @@ -25,8 +25,7 @@ def __init__( def value_for(self, unit: float) -> float: raise NotImplemented() - @cached_property - def natural_parameters(self) -> Tuple[np.ndarray, ...]: + def natural_parameters(self, xp=np) -> Tuple[np.ndarray, ...]: return self.parameters @staticmethod @@ -35,7 +34,7 @@ def invert_natural_parameters(natural_parameters: np.ndarray return natural_parameters, @staticmethod - def to_canonical_form(x: np.ndarray) -> np.ndarray: + def to_canonical_form(x: np.ndarray, xp=np) -> np.ndarray: return x @cached_property diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index ae86a615e..3b4ed79c5 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -36,9 +36,8 @@ def __init__( def value_for(self, unit: float) -> float: raise NotImplemented() - @cached_property - def 
natural_parameters(self):
-        return self.calc_natural_parameters(self.alpha, self.beta)
+    def natural_parameters(self, xp=np) -> np.ndarray:
+        return self.calc_natural_parameters(self.alpha, self.beta, xp=xp)
 
     @staticmethod
     def calc_natural_parameters(alpha, beta):
@@ -50,8 +49,8 @@ def invert_natural_parameters(natural_parameters):
         return eta1 + 1, -eta2
 
     @staticmethod
-    def to_canonical_form(x):
-        return np.array([np.log(x), x])
+    def to_canonical_form(x, xp=np):
+        return xp.array([xp.log(x), x])
 
     @classmethod
     def invert_sufficient_statistics(cls, suff_stats):
diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py
index daff57479..3826cb821 100644
--- a/autofit/messages/interface.py
+++ b/autofit/messages/interface.py
@@ -44,32 +44,31 @@ def __eq__(self, other):
     def pdf(self, x: np.ndarray) -> np.ndarray:
         return np.exp(self.logpdf(x))
 
-    def logpdf(self, x: Union[np.ndarray, float]) -> np.ndarray:
-        eta = self._broadcast_natural_parameters(x)
-        t = self.to_canonical_form(x)
+    def logpdf(self, x: Union[np.ndarray, float], xp=np) -> np.ndarray:
+        eta = self._broadcast_natural_parameters(x, xp=xp)
+        t = self.to_canonical_form(x, xp=xp)
         log_base = self.calc_log_base_measure(x)
         return self.natural_logpdf(eta, t, log_base, self.log_partition)
 
-    def _broadcast_natural_parameters(self, x):
-        shape = np.shape(x)
+    def _broadcast_natural_parameters(self, x, xp=np):
+        shape = xp.shape(x)
         if shape == self.shape:
-            return self.natural_parameters
+            return self.natural_parameters(xp=xp)
         elif shape[1:] == self.shape:
-            return self.natural_parameters[:, None, ...]
+            return self.natural_parameters(xp=xp)[:, None, ...]
         
else: raise ValueError( f"shape of passed value {shape} does not " f"match message shape {self.shape}" ) - @cached_property @abstractmethod - def natural_parameters(self): + def natural_parameters(self, xp=np) -> np.ndarray: pass @staticmethod @abstractmethod - def to_canonical_form(x: Union[np.ndarray, float]) -> np.ndarray: + def to_canonical_form(x: Union[np.ndarray, float], xp=np) -> np.ndarray: pass @classmethod diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 84bd3f2ff..407b3c5a6 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -148,8 +148,7 @@ def ppf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: return norm.ppf(x, loc=self.mean, scale=self.sigma) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ The natural (canonical) parameters of the Gaussian distribution in exponential-family form. @@ -162,10 +161,10 @@ def natural_parameters(self) -> np.ndarray: ------- A NumPy array containing the two natural parameters [η₁, η₂]. """ - return self.calc_natural_parameters(self.mean, self.sigma) + return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod - def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert standard parameters of a Gaussian distribution (mean and standard deviation) into natural parameters used in its exponential family representation. 
@@ -184,7 +183,7 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, η₂ = -1 / (2σ²) """ precision = 1 / sigma**2 - return np.array([mu * precision, -precision / 2]) + return xp.array([mu * precision, -precision / 2]) @staticmethod def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: @@ -207,7 +206,7 @@ def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, f return mu, sigma @staticmethod - def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + def to_canonical_form(x : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. @@ -223,7 +222,7 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: ------- The sufficient statistics [x, x²]. """ - return np.array([x, x**2]) + return xp.array([x, x**2]) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: @@ -570,12 +569,11 @@ def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: """ return np.array([eta1, eta2]) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Return the natural parameters of this distribution. 
""" - return self.calc_natural_parameters(*self.parameters) + return self.calc_natural_parameters(*self.parameters, xp=xp) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index 07a34464c..a26292292 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -141,8 +141,7 @@ def ppf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: b = (self.upper_limit - self.mean) / self.sigma return truncnorm.ppf(x, a=a, b=b, loc=self.mean, scale=self.sigma) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ The pseudo-natural (canonical) parameters of a truncated Gaussian distribution. @@ -159,7 +158,7 @@ def natural_parameters(self) -> np.ndarray: ------- A NumPy array containing the pseudo-natural parameters [η₁, η₂]. """ - return self.calc_natural_parameters(self.mean, self.sigma) + return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: @@ -217,7 +216,7 @@ def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, f return mu, sigma @staticmethod - def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: + def to_canonical_form(x : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert a scalar input `x` to its sufficient statistics for the Gaussian exponential family. @@ -234,7 +233,7 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: ------- The sufficient statistics [x, x²]. 
""" - return np.array([x, x**2]) + return xp.array([x, x**2]) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: @@ -638,12 +637,11 @@ def calc_natural_parameters( """ return np.array([eta1, eta2]) - @cached_property - def natural_parameters(self) -> np.ndarray: + def natural_parameters(self, xp=np) -> np.ndarray: """ Return the natural parameters of this distribution. """ - return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit) + return self.calc_natural_parameters(*self.parameters, self.lower_limit, self.upper_limit, xp=xp) @classmethod def invert_sufficient_statistics( From 6f28bba56fef7c5ade08f4575bfca4e72a8bb2ea Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 14:32:44 +0000 Subject: [PATCH 09/15] integration tests work, cleaning up unit tests --- autofit/messages/abstract.py | 2 +- autofit/messages/beta.py | 3 +-- autofit/messages/composed_transform.py | 3 +-- autofit/messages/fixed.py | 3 +-- autofit/messages/gamma.py | 6 +++--- autofit/messages/interface.py | 11 +++++------ autofit/messages/normal.py | 8 ++++---- autofit/messages/truncated_normal.py | 6 +++--- 8 files changed, 19 insertions(+), 23 deletions(-) diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index 84d0739d4..cb76b5f80 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -62,7 +62,7 @@ def __init__( self._broadcast = _xp.broadcast_arrays(*parameters) if self.shape: - self.parameters = tuple(xp.aarray(p) for p in parameters) + self.parameters = tuple(xp.asarray(p) for p in parameters) else: self.parameters = tuple(parameters) diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 460b8a084..8a4696d92 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -171,8 +171,7 @@ def value_for(self, unit: float) -> float: """ raise NotImplemented() - @cached_property - def log_partition(self) -> np.ndarray: + def 
log_partition(self, xp=np) -> np.ndarray: """ Compute the log partition function (log normalization constant) of the Beta distribution. diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index c06983dee..0b7f4f34f 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -228,8 +228,7 @@ def invert_natural_parameters( def cdf(self, x): return self.base_message.cdf(x) - @property - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: return self.base_message.log_partition def invert_sufficient_statistics(self, sufficient_statistics): diff --git a/autofit/messages/fixed.py b/autofit/messages/fixed.py index 8bf990a48..1280d5e95 100644 --- a/autofit/messages/fixed.py +++ b/autofit/messages/fixed.py @@ -37,8 +37,7 @@ def invert_natural_parameters(natural_parameters: np.ndarray def to_canonical_form(x: np.ndarray, xp=np) -> np.ndarray: return x - @cached_property - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: return 0. 
@classmethod diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index 3b4ed79c5..05c587552 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -6,11 +6,11 @@ class GammaMessage(AbstractMessage): - @property - def log_partition(self): + + def log_partition(self, xp=np): from scipy import special - alpha, beta = GammaMessage.invert_natural_parameters(self.natural_parameters) + alpha, beta = GammaMessage.invert_natural_parameters(self.natural_parameters(xp=xp)) return special.gammaln(alpha) - alpha * np.log(beta) log_base_measure = 0.0 diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 3826cb821..2ed1932e9 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -48,7 +48,7 @@ def logpdf(self, x: Union[np.ndarray, float], xp=np) -> np.ndarray: eta = self._broadcast_natural_parameters(x, xp=xp) t = self.to_canonical_form(x, xp=xp) log_base = self.calc_log_base_measure(x) - return self.natural_logpdf(eta, t, log_base, self.log_partition) + return self.natural_logpdf(eta, t, log_base, self.log_partition(xp=xp), xp=xp) def _broadcast_natural_parameters(self, x, xp=np): shape = xp.shape(x) @@ -75,15 +75,14 @@ def to_canonical_form(x: Union[np.ndarray, float], xp=np) -> np.ndarray: def calc_log_base_measure(cls, x): return cls.log_base_measure - @cached_property @abstractmethod - def log_partition(self) -> np.ndarray: + def log_partition(self, xp=np) -> np.ndarray: pass @classmethod - def natural_logpdf(cls, eta, t, log_base, log_partition): - eta_t = np.multiply(eta, t).sum(0) - return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf) + def natural_logpdf(cls, eta, t, log_base, log_partition, xp=np): + eta_t = xp.multiply(eta, t).sum(0) + return xp.nan_to_num(log_base + eta_t - log_partition, nan=-xp.inf) def numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 407b3c5a6..d8d6773cc 
100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -45,8 +45,8 @@ def assert_sigma_non_negative(sigma, xp=np): raise ValueError("Sigma cannot be negative") class NormalMessage(AbstractMessage): - @cached_property - def log_partition(self): + + def log_partition(self, xp=np): """ Compute the log-partition function (also called log-normalizer or cumulant function) for the normal distribution in its natural (canonical) parameterization. @@ -59,8 +59,8 @@ def log_partition(self): A(η) = η₁² / (-4η₂) - 0.5 * log(-2η₂) This ensures normalization of the exponential-family distribution. """ - eta1, eta2 = self.natural_parameters - return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2 + eta1, eta2 = self.natural_parameters(xp=xp) + return -(eta1**2) / 4 / eta2 - xp.log(-2 * eta2) / 2 log_base_measure = -0.5 * np.log(2 * np.pi) _support = ((-np.inf, np.inf),) diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index a26292292..e8cf91f83 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -25,8 +25,8 @@ def is_nan(value): class TruncatedNormalMessage(AbstractMessage): - @cached_property - def log_partition(self) -> float: + + def log_partition(self, xp=np) -> float: """ Compute the log-partition function (normalizer) of the truncated Gaussian. 
@@ -46,7 +46,7 @@ def log_partition(self) -> float: a = (self.lower_limit - self.mean) / self.sigma b = (self.upper_limit - self.mean) / self.sigma Z = norm.cdf(b) - norm.cdf(a) - return np.log(Z) if Z > 0 else -np.inf + return xp.log(Z) if Z > 0 else -xp.inf log_base_measure = -0.5 * np.log(2 * np.pi) From c62083063b3913c694b0611f87d9a0d61b6444fb Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:03:54 +0000 Subject: [PATCH 10/15] hacks which push us most of the way through --- autofit/messages/abstract.py | 9 +++++---- autofit/messages/beta.py | 5 +++-- autofit/messages/composed_transform.py | 8 ++++---- autofit/messages/gamma.py | 8 ++++---- autofit/messages/interface.py | 27 +++++++++++++++++++------- autofit/messages/normal.py | 9 +++++---- autofit/messages/truncated_normal.py | 17 ++++++++-------- 7 files changed, 50 insertions(+), 33 deletions(-) diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index cb76b5f80..e6a9fd860 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -56,6 +56,7 @@ def __init__( self.id = next(self.ids) if id_ is None else id_ self.log_norm = log_norm + if xp is np: self._broadcast = np.broadcast(*parameters) else: @@ -222,7 +223,7 @@ def __truediv__(self, other: Union["AbstractMessage", Real]) -> "AbstractMessage ) def __pow__(self, other: Real) -> "AbstractMessage": - natural = self.natural_parameters + natural = self.natural_parameters() new_params = other * natural log_norm = other * self.log_norm new = self.from_natural_parameters( @@ -313,12 +314,12 @@ def log_normalisation(self, *elems: Union["AbstractMessage", float]) -> np.ndarr ] # Calculate log product of message normalisation - log_norm = self.log_base_measure - self.log_partition - log_norm += sum(dist.log_base_measure - dist.log_partition for dist in dists) + log_norm = self.log_base_measure - self.log_partition() + log_norm += sum(dist.log_base_measure - dist.log_partition() for dist in dists) # Calculate 
log normalisation of product of messages prod_dist = self.sum_natural_parameters(*dists) - log_norm -= prod_dist.log_base_measure - prod_dist.log_partition + log_norm -= prod_dist.log_base_measure - prod_dist.log_partition() return log_norm diff --git a/autofit/messages/beta.py b/autofit/messages/beta.py index 8a4696d92..9b9e6b919 100644 --- a/autofit/messages/beta.py +++ b/autofit/messages/beta.py @@ -201,7 +201,8 @@ def natural_parameters(self, xp=np) -> np.ndarray: @staticmethod def calc_natural_parameters( alpha: Union[float, np.ndarray], - beta: Union[float, np.ndarray] + beta: Union[float, np.ndarray], + xp=np ) -> np.ndarray: """ Calculate the natural parameters of a Beta distribution from alpha and beta. @@ -217,7 +218,7 @@ def calc_natural_parameters( ------- Natural parameters [alpha - 1, beta - 1]. """ - return np.array([alpha - 1, beta - 1]) + return xp.array([alpha - 1, beta - 1]) @staticmethod def invert_natural_parameters( diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index 0b7f4f34f..b4d1fa1d0 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -34,9 +34,9 @@ def transform(func): """ @functools.wraps(func) - def wrapper(self, x): + def wrapper(self, x, xp): x = self._transform(x) - return func(self, x) + return func(self, x, xp) return wrapper @@ -239,8 +239,8 @@ def value_for(self, unit): return self.base_message.value_for(unit) @transform - def calc_log_base_measure(self, x) -> np.ndarray: - return self.base_message.calc_log_base_measure(x) + def calc_log_base_measure(self, x, xp=np) -> np.ndarray: + return self.base_message.calc_log_base_measure(x, xp=xp) @transform def to_canonical_form(self, x, xp=np) -> np.ndarray: diff --git a/autofit/messages/gamma.py b/autofit/messages/gamma.py index 05c587552..3677186f0 100644 --- a/autofit/messages/gamma.py +++ b/autofit/messages/gamma.py @@ -40,8 +40,8 @@ def natural_parameters(self, xp=np) -> np.ndarray: 
return self.calc_natural_parameters(self.alpha, self.beta, xp=xp) @staticmethod - def calc_natural_parameters(alpha, beta): - return np.array([alpha - 1, -beta]) + def calc_natural_parameters(alpha, beta, xp=np): + return xp.array([alpha - 1, -beta]) @staticmethod def invert_natural_parameters(natural_parameters): @@ -97,13 +97,13 @@ def kl(self, dist): def logpdf_gradient(self, x): logl = self.logpdf(x) - eta1 = self.natural_parameters[0] + eta1 = self.natural_parameters()[0] gradl = eta1 / x - self.beta return logl, gradl def logpdf_gradient_hessian(self, x): logl = self.logpdf(x) - eta1 = self.natural_parameters[0] + eta1 = self.natural_parameters()[0] gradl = eta1 / x hessl = -gradl / x gradl -= self.beta diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 2ed1932e9..e7f3b3136 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -45,9 +45,11 @@ def pdf(self, x: np.ndarray) -> np.ndarray: return np.exp(self.logpdf(x)) def logpdf(self, x: Union[np.ndarray, float], xp=np) -> np.ndarray: + eta = self._broadcast_natural_parameters(x, xp=xp) t = self.to_canonical_form(x, xp=xp) - log_base = self.calc_log_base_measure(x) + log_base = self.calc_log_base_measure(x, xp=xp) + return self.natural_logpdf(eta, t, log_base, self.log_partition(xp=xp), xp=xp) def _broadcast_natural_parameters(self, x, xp=np): @@ -72,7 +74,7 @@ def to_canonical_form(x: Union[np.ndarray, float], xp=np) -> np.ndarray: pass @classmethod - def calc_log_base_measure(cls, x): + def calc_log_base_measure(cls, x, xp=np): return cls.log_base_measure @abstractmethod @@ -81,8 +83,19 @@ def log_partition(self, xp=np) -> np.ndarray: @classmethod def natural_logpdf(cls, eta, t, log_base, log_partition, xp=np): + + try: + eta = eta() + except TypeError: + pass + + try: + log_partition_in = log_partition(xp=xp) + except TypeError: + log_partition_in = log_partition + eta_t = xp.multiply(eta, t).sum(0) - return xp.nan_to_num(log_base + eta_t - 
log_partition, nan=-xp.inf) + return xp.nan_to_num(log_base + eta_t - log_partition_in, nan=-xp.inf) def numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 @@ -184,7 +197,7 @@ def check_support(self) -> np.ndarray: pass def check_finite(self) -> np.ndarray: - return np.isfinite(self.natural_parameters).all(0) + return np.isfinite(self.natural_parameters()).all(0) def check_valid(self) -> np.ndarray: return self.check_finite() & self.check_support() @@ -204,11 +217,11 @@ def sum_natural_parameters(self, *dists: "MessageInterface") -> "MessageInterfac """ new_params = sum( ( - dist.natural_parameters + dist.natural_parameters() for dist in self._iter_dists(dists) if isinstance(dist, MessageInterface) ), - self.natural_parameters, + self.natural_parameters(), ) return self.from_natural_parameters( new_params, @@ -220,7 +233,7 @@ def sub_natural_parameters(self, other: "MessageInterface") -> "MessageInterface of this distribution with another distribution of the same type""" log_norm = self.log_norm - other.log_norm - new_params = self.natural_parameters - other.natural_parameters + new_params = self.natural_parameters() - other.natural_parameters() return self.from_natural_parameters( new_params, log_norm=log_norm, diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index d8d6773cc..56af526aa 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -93,7 +93,8 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. """ - if isinstance(mean, (np.ndarray, float, int)): + + if isinstance(mean, (np.ndarray, float, int, list)): xp = np else: import jax.numpy as jnp @@ -470,7 +471,7 @@ def natural(self)-> "NaturalNormal": A natural form Gaussian with zeroed parameters but same configuration. 
""" return NaturalNormal.from_natural_parameters( - self.natural_parameters * 0.0, **self._init_kwargs + self.natural_parameters() * 0.0, **self._init_kwargs ) def zeros_like(self) -> "AbstractMessage": @@ -556,7 +557,7 @@ def mean(self) -> float: return np.nan_to_num(-self.parameters[0] / self.parameters[1] / 2) @staticmethod - def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: + def calc_natural_parameters(eta1: float, eta2: float, xp=np) -> np.ndarray: """ Return the natural parameters in array form (identity function for this class). @@ -567,7 +568,7 @@ def calc_natural_parameters(eta1: float, eta2: float) -> np.ndarray: eta2 The second natural parameter. """ - return np.array([eta1, eta2]) + return xp.array([eta1, eta2]) def natural_parameters(self, xp=np) -> np.ndarray: """ diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index e8cf91f83..d7ac5c89b 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -161,7 +161,7 @@ def natural_parameters(self, xp=np) -> np.ndarray: return self.calc_natural_parameters(self.mean, self.sigma, xp=xp) @staticmethod - def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray]) -> np.ndarray: + def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, np.ndarray], xp=np) -> np.ndarray: """ Convert standard parameters of a Gaussian distribution (mean and standard deviation) into natural parameters used in its exponential family representation. 
@@ -189,7 +189,7 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, η₂ = -1 / (2σ²) """ precision = 1 / sigma**2 - return np.array([mu * precision, -precision / 2]) + return xp.array([mu * precision, -precision / 2]) @staticmethod def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: @@ -493,7 +493,7 @@ def natural(self)-> "NaturalNormal": A natural form Gaussian with zeroed parameters but same configuration. """ return TruncatedNaturalNormal.from_natural_parameters( - self.natural_parameters * 0.0, **self._init_kwargs + self.natural_parameters() * 0.0, **self._init_kwargs ) def zeros_like(self) -> "AbstractMessage": @@ -617,10 +617,11 @@ def mean(self) -> float: @staticmethod def calc_natural_parameters( - eta1: float, - eta2: float, - lower_limit: float = -np.inf, - upper_limit: float = np.inf + eta1: float, + eta2: float, + lower_limit: float = -np.inf, + upper_limit: float = np.inf, + xp=np ) -> np.ndarray: """ Return the natural parameters in array form (identity function for this class). @@ -635,7 +636,7 @@ def calc_natural_parameters( eta2 The second natural parameter. 
""" - return np.array([eta1, eta2]) + return xp.array([eta1, eta2]) def natural_parameters(self, xp=np) -> np.ndarray: """ From 6c78390b5e036d95850a8202ee41afab9f7b8f4f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:06:46 +0000 Subject: [PATCH 11/15] fix removing dodgy hacks --- autofit/messages/composed_transform.py | 4 ++-- autofit/messages/interface.py | 13 +------------ 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index b4d1fa1d0..28fed6d9a 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -158,7 +158,7 @@ def kl(self, dist): return self.base_message.kl(dist.base_message) def natural_parameters(self, xp=np) -> np.ndarray: - return self.base_message.natural_parameters + return self.base_message.natural_parameters(xp=xp) @inverse_transform def sample(self, n_samples: Optional[int] = None): @@ -229,7 +229,7 @@ def cdf(self, x): return self.base_message.cdf(x) def log_partition(self, xp=np) -> np.ndarray: - return self.base_message.log_partition + return self.base_message.log_partition(xp=xp) def invert_sufficient_statistics(self, sufficient_statistics): return self.base_message.invert_sufficient_statistics(sufficient_statistics) diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index e7f3b3136..c0d589553 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -83,19 +83,8 @@ def log_partition(self, xp=np) -> np.ndarray: @classmethod def natural_logpdf(cls, eta, t, log_base, log_partition, xp=np): - - try: - eta = eta() - except TypeError: - pass - - try: - log_partition_in = log_partition(xp=xp) - except TypeError: - log_partition_in = log_partition - eta_t = xp.multiply(eta, t).sum(0) - return xp.nan_to_num(log_base + eta_t - log_partition_in, nan=-xp.inf) + return xp.nan_to_num(log_base + eta_t - log_partition, nan=-xp.inf) def 
numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 From 0bbb3dfbb4865360b5f733e1cd252715fb1298cd Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:13:38 +0000 Subject: [PATCH 12/15] use_jax made private --- autofit/graphical/declarative/factor/hierarchical.py | 6 +++--- autofit/non_linear/analysis/analysis.py | 4 ++-- test_autofit/graphical/global/test_hierarchical.py | 3 +-- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/autofit/graphical/declarative/factor/hierarchical.py b/autofit/graphical/declarative/factor/hierarchical.py index 0ec5fde9b..69bd3c272 100644 --- a/autofit/graphical/declarative/factor/hierarchical.py +++ b/autofit/graphical/declarative/factor/hierarchical.py @@ -72,7 +72,7 @@ def __init__( self._name = name or namer(self.__class__.__name__) self._factors = list() self.optimiser = optimiser - self.use_jax = use_jax + self._use_jax = use_jax @property def name(self): @@ -147,7 +147,7 @@ def __call__(self, **kwargs): class _HierarchicalFactor(AbstractModelFactor): def __init__( - self, distribution_model: HierarchicalFactor, drawn_prior: Prior, use_jax : bool = False + self, distribution_model: HierarchicalFactor, drawn_prior: Prior, ): """ A factor that links a variable to a parameterised distribution. 
@@ -162,7 +162,7 @@ def __init__( """ self.distribution_model = distribution_model self.drawn_prior = drawn_prior - self.use_jax = distribution_model.use_jax + self._use_jax = distribution_model._use_jax prior_variable_dict = {prior.name: prior for prior in distribution_model.priors} diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 996bb995f..441bc70b4 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -36,7 +36,7 @@ def __init__( self, use_jax : bool = False, **kwargs ): - self.use_jax = use_jax + self._use_jax = use_jax self.kwargs = kwargs def __getattr__(self, item: str): @@ -63,7 +63,7 @@ def method(*args, **kwargs): @property def _xp(self): - if self.use_jax: + if self._use_jax: import jax.numpy as jnp return jnp return np diff --git a/test_autofit/graphical/global/test_hierarchical.py b/test_autofit/graphical/global/test_hierarchical.py index 5e90d3704..26b3e4da3 100644 --- a/test_autofit/graphical/global/test_hierarchical.py +++ b/test_autofit/graphical/global/test_hierarchical.py @@ -55,8 +55,7 @@ def test_model_info(model): 2 - 3 one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0 factor - include_prior_factors True - use_jax False""" + include_prior_factors True""" ) From 2cb4a044c0503155b77a943f64b2766fba77a36b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:14:33 +0000 Subject: [PATCH 13/15] include prior factors made private --- autofit/graphical/declarative/abstract.py | 6 +++--- autofit/graphical/declarative/collection.py | 2 +- autofit/graphical/declarative/factor/abstract.py | 2 +- test_autofit/graphical/global/test_hierarchical.py | 4 +--- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/autofit/graphical/declarative/abstract.py b/autofit/graphical/declarative/abstract.py index e633c71df..35e1c177f 100644 --- a/autofit/graphical/declarative/abstract.py +++ b/autofit/graphical/declarative/abstract.py @@ 
-20,7 +20,7 @@ class AbstractDeclarativeFactor(Analysis, ABC): _plates: Tuple[Plate, ...] = () def __init__(self, include_prior_factors=False, use_jax : bool = False): - self.include_prior_factors = include_prior_factors + self._include_prior_factors = include_prior_factors super().__init__(use_jax=use_jax) @@ -63,7 +63,7 @@ def prior_counts(self) -> List[Tuple[Prior, int]]: for prior in factor.prior_model.priors: counter[prior] += 1 return [ - (prior, count + 1 if self.include_prior_factors else count) + (prior, count + 1 if self._include_prior_factors else count) for prior, count in counter.items() ] @@ -96,7 +96,7 @@ def graph(self) -> DeclarativeFactorGraph: The complete graph made by combining all factors and priors """ factors = [model for model in self.model_factors] - if self.include_prior_factors: + if self._include_prior_factors: factors += self.prior_factors # noinspection PyTypeChecker return DeclarativeFactorGraph(factors) diff --git a/autofit/graphical/declarative/collection.py b/autofit/graphical/declarative/collection.py index 97fb9824f..6dff83693 100644 --- a/autofit/graphical/declarative/collection.py +++ b/autofit/graphical/declarative/collection.py @@ -42,7 +42,7 @@ def __init__( def tree_flatten(self): return ( (self._model_factors,), - (self._name, self.include_prior_factors), + (self._name, self._include_prior_factors), ) @classmethod diff --git a/autofit/graphical/declarative/factor/abstract.py b/autofit/graphical/declarative/factor/abstract.py index 3df6a9f0c..67bfec786 100644 --- a/autofit/graphical/declarative/factor/abstract.py +++ b/autofit/graphical/declarative/factor/abstract.py @@ -42,7 +42,7 @@ def __init__( factor, **prior_variable_dict, name=name or namer(self.__class__.__name__), ) - self.include_prior_factors = include_prior_factors + self._include_prior_factors = include_prior_factors @property def info(self) -> str: diff --git a/test_autofit/graphical/global/test_hierarchical.py 
b/test_autofit/graphical/global/test_hierarchical.py index 26b3e4da3..6f22fa874 100644 --- a/test_autofit/graphical/global/test_hierarchical.py +++ b/test_autofit/graphical/global/test_hierarchical.py @@ -53,9 +53,7 @@ def test_model_info(model): 1 drawn_prior UniformPrior [1], lower_limit = 0.0, upper_limit = 1.0 2 - 3 - one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0 -factor - include_prior_factors True""" + one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0""" ) From fb9d8b97071c0141f3f93a28bfcebbca3158e683 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 30 Nov 2025 15:17:54 +0000 Subject: [PATCH 14/15] fix issues with _use_jax now being private --- autofit/non_linear/analysis/analysis.py | 4 ++-- autofit/non_linear/fitness.py | 2 +- autofit/non_linear/search/nest/dynesty/search/abstract.py | 2 +- autofit/non_linear/search/nest/nautilus/search.py | 2 +- test_autofit/graphical/hierarchical/test_optimise.py | 7 ------- 5 files changed, 5 insertions(+), 12 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 441bc70b4..3bdeab220 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -109,7 +109,7 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model) - if self.use_jax: + if self._use_jax: import jax start = time.time() logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.") @@ -130,7 +130,7 @@ def batched_compute_latent(x): # batched JAX call on this chunk latent_values_batch = batched_compute_latent(batch) - if self.use_jax: + if self._use_jax: import jax.numpy as jnp latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) diff --git a/autofit/non_linear/fitness.py 
b/autofit/non_linear/fitness.py index e3a7ab0b0..a3792b701 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -134,7 +134,7 @@ def __init__( @property def _xp(self): - if self.analysis.use_jax: + if self.analysis._use_jax: import jax.numpy as jnp return jnp return np diff --git a/autofit/non_linear/search/nest/dynesty/search/abstract.py b/autofit/non_linear/search/nest/dynesty/search/abstract.py index b566a1c50..913bd2b62 100644 --- a/autofit/non_linear/search/nest/dynesty/search/abstract.py +++ b/autofit/non_linear/search/nest/dynesty/search/abstract.py @@ -146,7 +146,7 @@ def _fit( "parallel" ].get("force_x1_cpu") or self.kwargs.get("force_x1_cpu") - or analysis.use_jax + or analysis._use_jax ): raise RuntimeError diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 30531edbe..3977843b2 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -127,7 +127,7 @@ def _fit(self, model: AbstractPriorModel, analysis): if ( self.config_dict.get("force_x1_cpu") or self.kwargs.get("force_x1_cpu") - or analysis.use_jax + or analysis._use_jax ): fitness = Fitness( diff --git a/test_autofit/graphical/hierarchical/test_optimise.py b/test_autofit/graphical/hierarchical/test_optimise.py index 12caf64f2..f500029ae 100644 --- a/test_autofit/graphical/hierarchical/test_optimise.py +++ b/test_autofit/graphical/hierarchical/test_optimise.py @@ -11,13 +11,6 @@ def make_factor(hierarchical_factor): def test_optimise(factor): search = af.DynestyStatic(maxcall=100, dynamic_delta=False, delta=0.1,) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - print(type(factor.analysis)) - - _, status = search.optimise( factor.mean_field_approximation().factor_approximation(factor) ) From ee14b5bfeca7d31907c6dc9adaca50091d9c9f6f Mon Sep 17 00:00:00 2001 From: 
Jammy2211 Date: Sun, 30 Nov 2025 16:18:28 +0000 Subject: [PATCH 15/15] remove print statements --- autofit/messages/composed_transform.py | 6 +++--- autofit/messages/normal.py | 2 +- autofit/messages/truncated_normal.py | 2 +- autofit/non_linear/search/mle/bfgs/search.py | 4 ---- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/autofit/messages/composed_transform.py b/autofit/messages/composed_transform.py index 28fed6d9a..84792b062 100644 --- a/autofit/messages/composed_transform.py +++ b/autofit/messages/composed_transform.py @@ -34,7 +34,7 @@ def transform(func): """ @functools.wraps(func) - def wrapper(self, x, xp): + def wrapper(self, x, xp=np): x = self._transform(x) return func(self, x, xp) @@ -225,8 +225,8 @@ def invert_natural_parameters( return self.base_message.invert_natural_parameters(natural_parameters) @transform - def cdf(self, x): - return self.base_message.cdf(x) + def cdf(self, x, xp=np): + return self.base_message.cdf(x, xp=xp) def log_partition(self, xp=np) -> np.ndarray: return self.base_message.log_partition(xp=xp) diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 56af526aa..f7128aa97 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -111,7 +111,7 @@ def __init__( ) self.mean, self.sigma = self.parameters - def cdf(self, x : Union[float, np.ndarray]) -> Union[float, np.ndarray]: + def cdf(self, x : Union[float, np.ndarray], xp=np) -> Union[float, np.ndarray]: """ Compute the cumulative distribution function (CDF) of the Gaussian distribution at a given value or array of values `x`. 
diff --git a/autofit/messages/truncated_normal.py b/autofit/messages/truncated_normal.py index d7ac5c89b..081e4d99d 100644 --- a/autofit/messages/truncated_normal.py +++ b/autofit/messages/truncated_normal.py @@ -96,7 +96,7 @@ def __init__( self.mean, self.sigma, self.lower_limit, self.upper_limit = self.parameters - def cdf(self, x: Union[float, np.ndarray]) -> Union[float, np.ndarray]: + def cdf(self, x: Union[float, np.ndarray], xp=np) -> Union[float, np.ndarray]: """ Compute the cumulative distribution function (CDF) of the truncated Gaussian distribution at a given value or array of values `x`. diff --git a/autofit/non_linear/search/mle/bfgs/search.py b/autofit/non_linear/search/mle/bfgs/search.py index 42a849235..8ca8f064b 100644 --- a/autofit/non_linear/search/mle/bfgs/search.py +++ b/autofit/non_linear/search/mle/bfgs/search.py @@ -225,10 +225,6 @@ def samples_via_internal_from( weight_list = len(log_likelihood_list) * [1.0] - print(parameter_lists) - print(log_likelihood_list) - print(log_prior_list) - sample_list = Sample.from_lists( model=model, parameter_lists=parameter_lists,