
feat #7713: implement ICARRV/CARRV as SymbolicRVs #7879

Open · wants to merge 1 commit into base: main
57 changes: 47 additions & 10 deletions pymc/distributions/multivariate.py
@@ -2133,6 +2133,7 @@ def logp(value, rng, size, mu, sigma, *covs):
return a


# TODO: change this code
class CARRV(RandomVariable):
name = "car"
signature = "(m),(m,m),(),(),()->(m)"
@@ -2344,21 +2345,56 @@ def logp(value, mu, W, alpha, tau, W_is_valid):
)


class ICARRV(RandomVariable):
class ICARRV(SymbolicMVNormalUsedInternally):
r"""A SymbolicRandomVariable representing an Intrinsic Conditional Autoregressive (ICAR) distribution.

This class contains the symbolic logic for the ICAR distribution, which is used by the
user-facing `pm.ICAR` class to generate random samples and compute log-probabilities.
"""

name = "icar"
signature = "(m,m),(),()->(m)"
dtype = "floatX"
extended_signature = "[rng],[size],(n,n),(),()->[rng],(n)"
_print_name = ("ICAR", "\\operatorname{ICAR}")

def __call__(self, W, sigma, zero_sum_stdev, size=None, **kwargs):
return super().__call__(W, sigma, zero_sum_stdev, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, size, W, sigma, zero_sum_stdev):
raise NotImplementedError("Cannot sample from ICAR prior")
def rv_op(cls, W, sigma, zero_sum_stdev, method="eigh", rng=None, size=None):
W = pt.as_tensor(W)
sigma = pt.as_tensor(sigma)
zero_sum_stdev = pt.as_tensor(zero_sum_stdev)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

if rv_size_is_none(size):
size = implicit_size_from_params(
W, sigma, zero_sum_stdev, ndims_params=cls.ndims_params
)

N = W.shape[0]

# Construct the precision matrix (graph Laplacian)
D = pt.diag(W.sum(axis=1))
Q = (D - W) / (sigma * sigma)

# Add regularization for the zero eigenvalue based on zero_sum_stdev
zero_sum_precision = 1.0 / (zero_sum_stdev * zero_sum_stdev)
Q_reg = Q + zero_sum_precision * pt.ones((N, N)) / N

# Convert precision to covariance matrix
cov = pt.linalg.inv(Q_reg) # TODO: Should this be matrix_inverse(Q_reg)

next_rng, mv_draws = multivariate_normal(
mean=pt.zeros(N),
cov=cov,
size=size,
rng=rng,
method=method,
).owner.outputs

icar = ICARRV()
return cls(
inputs=[rng, size, W, sigma, zero_sum_stdev],
outputs=[next_rng, mv_draws],
method=method,
)(rng, size, W, sigma, zero_sum_stdev)


class ICAR(Continuous):
@@ -2449,7 +2485,8 @@ class ICAR(Continuous):

"""

rv_op = icar
rv_type = ICARRV
rv_op = ICARRV.rv_op

@classmethod
def dist(cls, W, sigma=1, zero_sum_stdev=0.001, **kwargs):
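For reference, here is a minimal NumPy sketch (editorial, not part of the diff) of the covariance construction that `ICARRV.rv_op` performs symbolically: the scaled graph Laplacian `(D - W) / sigma**2` is singular because the all-ones vector lies in its null space, and the `zero_sum_stdev` term adds a soft sum-to-zero constraint that restores full rank. The adjacency matrix and hyperparameter values are illustrative only.

import numpy as np

W = np.array(
    [[0.0, 1.0, 0.0, 1.0],
     [1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [1.0, 0.0, 1.0, 0.0]]
)
sigma, zero_sum_stdev = 2.0, 0.1
N = W.shape[0]

D = np.diag(W.sum(axis=1))                  # node degrees
Q = (D - W) / sigma**2                      # scaled graph Laplacian (precision), rank N - 1
print(np.linalg.matrix_rank(Q))             # 3: singular, cannot be inverted directly
Q_reg = Q + (1.0 / zero_sum_stdev**2) * np.ones((N, N)) / N
print(np.linalg.matrix_rank(Q_reg))         # 4: the soft zero-sum constraint makes it invertible
cov = np.linalg.inv(Q_reg)                  # covariance handed to the internal MvNormal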
40 changes: 36 additions & 4 deletions tests/distributions/test_multivariate.py
@@ -2279,12 +2279,43 @@ def test_icar_logp(self):
).eval(), "logp inaccuracy"

def test_icar_rng_fn(self):
W = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]])
delta = 0.05 # limit for KS p-value
n_fails = 20 # number of prior draws the KS comparison is allowed to fail before the test fails
size = (100,)

W_val = np.array(
[[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0]]
)
sigma = 2.0
zero_sum_stdev = 0.1
N = W_val.shape[0]

D = np.diag(W_val.sum(axis=1))
Q = (D - W_val) / (sigma * sigma)
zero_sum_precision = 1.0 / (zero_sum_stdev**2)
Q_reg = Q + zero_sum_precision * np.ones((N, N)) / N
cov = np.linalg.inv(Q_reg)

# TODO: Should W be a pt.tensor ?
with pm.Model():
icar = pm.ICAR("icar", W=W_val, sigma=sigma, zero_sum_stdev=zero_sum_stdev, size=size)
mn = pm.MvNormal("mn", mu=0.0, cov=cov, size=size)
# Draw n_fails samples
check = pm.sample_prior_predictive(n_fails, return_inferencedata=False, random_seed=42)

RV = pm.ICAR.dist(W=W)
p, f = delta, n_fails
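# The test passes as soon as one prior draw has all component-wise KS p-values above delta;
# it fails only if every one of the n_fails draws yields a minimum p-value at or below delta.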
while p <= delta and f > 0:
icar_smp, mn_smp = check["icar"][f - 1, :, :], check["mn"][f - 1, :, :]
p = min(
st.ks_2samp(
np.atleast_1d(icar_smp[..., idx]).flatten(),
np.atleast_1d(mn_smp[..., idx]).flatten(),
)[1]
for idx in range(icar_smp.shape[-1])
)
f -= 1

with pytest.raises(NotImplementedError, match="Cannot sample from ICAR prior"):
pm.draw(RV)
assert p > delta

@pytest.mark.parametrize(
"W,msg",
@@ -2307,6 +2338,7 @@ def test_icar_matrix_checks(self, W, msg):
pm.ICAR("phi", W=W)


# TODO: Fix this after updating the rng approach
@pytest.mark.parametrize("sparse", [True, False])
def test_car_rng_fn(sparse):
delta = 0.05 # limit for KS p-value
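Usage note (an editorial sketch, not part of the diff): with `ICAR` backed by this `SymbolicRandomVariable`, forward sampling is expected to go through the internal MvNormal instead of raising `NotImplementedError("Cannot sample from ICAR prior")`. The model below only illustrates the intended call pattern; the variable name, adjacency matrix, and hyperparameter values are placeholders.

import numpy as np
import pymc as pm

W = np.array(
    [[0.0, 1.0, 0.0, 1.0],
     [1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0],
     [1.0, 0.0, 1.0, 0.0]]
)

with pm.Model():
    phi = pm.ICAR("phi", W=W, sigma=2.0, zero_sum_stdev=0.1)
    prior = pm.sample_prior_predictive(100, random_seed=42)

phi_draws = prior.prior["phi"]  # expected shape (1, 100, 4): chain, draw, node

Single draws via pm.draw(pm.ICAR.dist(W=W)) should likewise work once the rv_op above is in place.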