Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions autofit/messages/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,21 @@ def __init__(

self.id = next(self.ids) if id_ is None else id_
self.log_norm = log_norm
self._broadcast = np.broadcast(*parameters)

self._broadcast = None
self._broadcast_jnp = None

if isinstance(parameters[0], (np.float64, float, int)):
self._broadcast = np.broadcast(*parameters)
else:
import jax.numpy as jnp
self._broadcast_jnp = jnp.broadcast_arrays(*parameters)

if self.shape:
self.parameters = tuple(np.asanyarray(p) for p in parameters)
if isinstance(parameters[0], (np.float64, float, int)):
self.parameters = tuple(np.asanyarray(p) for p in parameters)
else:
self.parameters = tuple(jnp.asarray(p) for p in parameters)
else:
self.parameters = tuple(parameters)

Expand Down
18 changes: 15 additions & 3 deletions autofit/messages/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ def broadcast(self):

@property
def shape(self) -> Tuple[int, ...]:
    """
    Shape of the message's parameters after broadcasting.

    Returns
    -------
    The shape of the numpy broadcast object when one was computed, or
    an empty tuple when no numpy broadcast is available (e.g. the
    parameters were JAX arrays and only ``_broadcast_jnp`` was set —
    NOTE(review): this means JAX-backed messages always report a
    scalar shape; confirm that is intended).
    """
    # Fix: the residual unconditional `return self.broadcast.shape`
    # (pre-change diff line) made the None-guard unreachable and would
    # raise AttributeError when broadcast is None; only the guarded
    # form below is intended.
    broadcast = self.broadcast
    if broadcast is None:
        return ()
    return broadcast.shape

@property
def size(self) -> int:
Expand All @@ -47,6 +51,7 @@ def logpdf(self, x: Union[np.ndarray, float]) -> np.ndarray:

def _broadcast_natural_parameters(self, x):
shape = np.shape(x)
print(shape, self.shape)
if shape == self.shape:
return self.natural_parameters
elif shape[1:] == self.shape:
Expand Down Expand Up @@ -78,8 +83,15 @@ def log_partition(self) -> np.ndarray:

@classmethod
def natural_logpdf(cls, eta, t, log_base, log_partition):
    """
    Evaluate the log-pdf of an exponential-family distribution from its
    natural parameters.

    Computes ``log_base + eta . t - log_partition``, where the inner
    product sums over the leading (parameter) axis, and maps NaN results
    to -inf so invalid regions register as zero probability.

    Parameters
    ----------
    eta
        Natural parameters stacked along axis 0 (numpy or JAX array).
    t
        Sufficient statistics of the evaluation points, broadcastable
        against ``eta``.
    log_base
        Log of the base measure.
    log_partition
        Log-partition (normalising) term.

    Returns
    -------
    The log probability density with NaNs replaced by -inf.
    """
    # Dispatch on input type: np.generic covers *every* numpy scalar
    # (float32, float64, ...), not just np.float64, so numpy inputs can
    # never fall through to the JAX branch and trigger an unnecessary
    # jax import. Behaviour for np.ndarray/np.float64 is unchanged.
    if isinstance(eta, (np.ndarray, np.generic)):
        eta_t = np.multiply(eta, t).sum(0)
        return np.nan_to_num(log_base + eta_t - log_partition, nan=-np.inf)

    # JAX-traced inputs: import lazily so numpy-only users never need jax.
    import jax.numpy as jnp

    eta_t = jnp.multiply(eta, t).sum(0)
    return jnp.nan_to_num(log_base + eta_t - log_partition, nan=-jnp.inf)

def numerical_logpdf_gradient(
self, x: np.ndarray, eps: float = 1e-6
Expand Down
29 changes: 24 additions & 5 deletions autofit/messages/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,13 @@ def log_partition(self):
This ensures normalization of the exponential-family distribution.
"""
eta1, eta2 = self.natural_parameters
return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2

if isinstance(eta1, (np.ndarray, np.float64)):
return -(eta1**2) / 4 / eta2 - np.log(-2 * eta2) / 2

import jax.numpy as jnp

return -(eta1**2) / 4 / eta2 - jnp.log(-2 * eta2) / 2

log_base_measure = -0.5 * np.log(2 * np.pi)
_support = ((-np.inf, np.inf),)
Expand Down Expand Up @@ -73,8 +79,9 @@ def __init__(
id_
An optional unique identifier used to track the message in larger probabilistic graphs or models.
"""
if (np.array(sigma) < 0).any():
raise exc.MessageException("Sigma cannot be negative")
if isinstance(sigma, (float, int, np.ndarray)):
if (np.array(sigma) < 0).any():
raise exc.MessageException("Sigma cannot be negative")

super().__init__(
mean,
Expand Down Expand Up @@ -158,7 +165,13 @@ def calc_natural_parameters(mu : Union[float, np.ndarray], sigma : Union[float,
η₂ = -1 / (2σ²)
"""
precision = 1 / sigma**2
return np.array([mu * precision, -precision / 2])

if isinstance(mu, (np.ndarray, np.float64)):
return np.array([mu * precision, -precision / 2])

import jax.numpy as jnp

return jnp.array([mu * precision, -precision / 2])

@staticmethod
def invert_natural_parameters(natural_parameters : np.ndarray) -> Tuple[float, float]:
Expand Down Expand Up @@ -197,7 +210,13 @@ def to_canonical_form(x : Union[float, np.ndarray]) -> np.ndarray:
-------
The sufficient statistics [x, x²].
"""
return np.array([x, x**2])

if isinstance(x, (np.ndarray, np.float64)):
return np.array([x, x**2])

import jax.numpy as jnp

return jnp.array([x, x**2])

@classmethod
def invert_sufficient_statistics(cls, suff_stats: Tuple[float, float]) -> np.ndarray:
Expand Down
Loading