From 4fa4cd9abc671dbe81a1f70255aceded767364f1 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 26 Nov 2025 18:02:57 +0000 Subject: [PATCH] this is a mess --- autofit/messages/abstract.py | 15 +++++++++++++-- autofit/messages/interface.py | 18 +++++++++++++++--- autofit/messages/normal.py | 29 ++++++++++++++++++++++++----- 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/autofit/messages/abstract.py b/autofit/messages/abstract.py index d3b921108..b856a5a3b 100644 --- a/autofit/messages/abstract.py +++ b/autofit/messages/abstract.py @@ -52,10 +52,21 @@ def __init__( self.id = next(self.ids) if id_ is None else id_ self.log_norm = log_norm - self._broadcast = np.broadcast(*parameters) + + self._broadcast = None + self._broadcast_jnp = None + + if isinstance(parameters[0], (np.float64, float, int)): + self._broadcast = np.broadcast(*parameters) + else: + import jax.numpy as jnp + self._broadcast_jnp = jnp.broadcast_arrays(*parameters) if self.shape: - self.parameters = tuple(np.asanyarray(p) for p in parameters) + if isinstance(parameters[0], (np.float64, float, int)): + self.parameters = tuple(np.asanyarray(p) for p in parameters) + else: + self.parameters = tuple(jnp.asarray(p) for p in parameters) else: self.parameters = tuple(parameters) diff --git a/autofit/messages/interface.py b/autofit/messages/interface.py index 28d46d2ab..2ec6e8a10 100644 --- a/autofit/messages/interface.py +++ b/autofit/messages/interface.py @@ -23,7 +23,11 @@ def broadcast(self): @property def shape(self) -> Tuple[int, ...]: - return self.broadcast.shape + + if self.broadcast is not None: + return self.broadcast.shape + + return () @property def size(self) -> int: @@ -47,6 +51,7 @@ def logpdf(self, x: Union[np.ndarray, float]) -> np.ndarray: def _broadcast_natural_parameters(self, x): shape = np.shape(x) + print(shape, self.shape) if shape == self.shape: return self.natural_parameters elif shape[1:] == self.shape: @@ -78,8 +83,15 @@ def log_partition(self) -> np.ndarray: @classmethod def natural_logpdf(cls, eta, t, log_base, log_partition): - eta_t = np.multiply(eta, t).sum(0) - return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf) + + if isinstance(eta, (np.ndarray, np.float64)): + eta_t = np.multiply(eta, t).sum(0) + return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf) + + import jax.numpy as jnp + + eta_t = jnp.multiply(eta, t).sum(0) + return jnp.nan_to_num(log_base + eta_t - log_partition, nan=-jnp.inf) def numerical_logpdf_gradient( self, x: np.ndarray, eps: float = 1e-6 diff --git a/autofit/messages/normal.py b/autofit/messages/normal.py index 9ff12ef3a..de8285770 100644 --- a/autofit/messages/normal.py +++ b/autofit/messages/normal.py @@ -40,7 +40,13 @@ def log_partition(self): This ensures normalization of the exponential-family distribution. """ eta1, eta2 = self.natural_parameters - return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2 + + if isinstance(eta1, (np.ndarray, np.float64)): + return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2 + + import jax.numpy as jnp + + return -(eta1**2) / 4 / eta2 - jnp.log(-2 * eta2) / 2 log_base_measure = -0.5 * np.log(2 * np.pi) _support = ((-np.inf, np.inf),) @@ -73,8 +79,9 @@ def __init__( id_ An optional unique identifier used to track the message in larger probabilistic graphs or models. """ - if (np.array(sigma) < 0).any(): - raise exc.MessageException("Sigma cannot be negative") + if isinstance(sigma, (float, int, np.ndarray)): + if (np.array(sigma) < 0).any(): + raise exc.MessageException("Sigma cannot be negative") super().__init__( mean, @@ -158,7 +165,13 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float, η₂ = -1 / (2σ²) """ precision = 1 / sigma**2 - return np.array([mu * precision, -precision / 2]) + + if isinstance(mu, (np.ndarray, np.float64)): + return np.array([mu * precision, -precision / 2]) + + import jax.numpy as jnp + + return jnp.array([mu * precision, -precision / 2]) @staticmethod def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]: @@ -197,7 +210,13 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray: ------- The sufficient statistics [x, x²]. """ - return np.array([x, x**2]) + + if isinstance(x, (np.ndarray, np.float64)): + return np.array([x, x**2]) + + import jax.numpy as jnp + + return jnp.array([x, x**2]) @classmethod def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray: