diff --git a/environment.yml b/environment.yml
index be125e07..b48bd033 100644
--- a/environment.yml
+++ b/environment.yml
@@ -1,13 +1,12 @@
 name: ringdown
 channels:
   - conda-forge
-  - defaults
 dependencies:
   - python=3.12
   - numpy=1.26.4
   - lalsuite=7.23
-  - h5py=3.11
-  - arviz=0.19
+  - h5py=3.12
+  - arviz=0.20
   - pandas=2.2
   - qnm=0.4.3
   - seaborn=0.13
diff --git a/ringdown/fit.py b/ringdown/fit.py
index 67e4ab76..bb715c22 100644
--- a/ringdown/fit.py
+++ b/ringdown/fit.py
@@ -870,8 +870,7 @@ def run(self,
         if 'a_scale_max' in ms:
             ms['a_scale_max'] = ms['a_scale_max'] / self.strain_scale
         logging.info('making model')
-        model = make_model(self.modes.value, prior=prior, predictive=False,
-                           store_h_det=False, store_h_det_mode=False, **ms)
+        model = make_model(self.modes.value, prior=prior, **ms)
         if return_model:
             return model
diff --git a/ringdown/model.py b/ringdown/model.py
index 9b846ee0..14bf0ff7 100644
--- a/ringdown/model.py
+++ b/ringdown/model.py
@@ -3,6 +3,7 @@
 import numpy as np
 import jax.numpy as jnp
 import jax.scipy as jsp
+from jax import lax
 import numpyro
 import numpyro.distributions as dist
@@ -185,7 +186,7 @@ def chi_factors(chi, coeffs):
     log_sqrt_1m_chi2_4 = log_sqrt_1m_chi2_2*log_sqrt_1m_chi2_2
     log_sqrt_1m_chi2_5 = log_sqrt_1m_chi2_3*log_sqrt_1m_chi2_2
     log_sqrt_1m_chi2_6 = log_sqrt_1m_chi2_3*log_sqrt_1m_chi2_3
-    
+
     v = jnp.stack([
         1.,
         log_1m_chi,
@@ -199,7 +200,7 @@ def chi_factors(chi, coeffs):
         log_sqrt_1m_chi2_5,
         log_sqrt_1m_chi2_6
     ])
-    
+
     return jnp.dot(coeffs, v)
@@ -289,6 +290,146 @@ def get_quad_derived_quantities(nmodes, design_matrices, quads, a_scale, YpYc,
     return a, h_det
 
 
+def make_likelihood_update(dms, ls, strains):
+    """Build a marginalized-likelihood update step for a single detector,
+    to be used with 'jax.lax.scan' or 'numpyro.contrib.control_flow.scan'.
+
+    The returned function carries a Gaussian prior on the quadratures and
+    conditions it on one detector's data at a time; the carry variables
+    are:
+
+    mu : array_like
+        The prior mean; shape (nquads*nmode,).
+    Lambda_inv : array_like
+        The prior precision; shape (nquads*nmode, nquads*nmode).
+    Lambda_inv_chol : array_like
+        The Cholesky factor of the prior precision; shape (nquads*nmode,
+        nquads*nmode).
+
+    Arguments
+    ---------
+    dms : array_like
+        The design matrices, one per detector; shape
+        (ndet, ntime, nquads*nmode).
+    ls : array_like
+        The lower Cholesky factors of the noise covariance matrices, one
+        per detector; shape (ndet, ntime, ntime).
+    strains : array_like
+        The strain data, one series per detector; shape (ndet, ntime).
+
+    Returns
+    -------
+    likelihood_update : callable
+        A scan step taking the carry tuple (mu, Lambda_inv,
+        Lambda_inv_chol) and a detector index i, and returning the
+        updated carry tuple together with that detector's marginal log
+        likelihood.
+ """ + def likelihood_update(mu_Lambda_inv_Lambda_inv_chol, i): + # unpack (carry) variables and (current state) parameters + mu, Lambda_inv, Lambda_inv_chol = mu_Lambda_inv_Lambda_inv_chol + M = dms[i] + L = ls[i] + y = strains[i] + + # M acts as a coordinate transformation matrix, taking us + # from the space of quadratures to the space of the data , + # while M^T takes us from data space to quadrature space + # (M is ntime x nquads*nmode) + + # L whitens the noise in the detector, taking it from + # N(0, C) to N(0, I) (L is ntime x ntime) + + # we can use M and L to compute the precision (A_inv) of + # the marginal posterior on the quadratures (conditioned on + # the current data and nonlinear parameters), which is just + # the sum of the prior precision (Lambda_inv) and the + # likelihood precision (M^T C^-1 M): + # A_inv = Lambda_inv + M^T C^-1 M + # so that A and A_inv are (nquads*nmode, nquads*nmode) + A_inv = Lambda_inv + \ + jnp.dot(M.T, jsp.linalg.cho_solve((L, True), M)) + A_inv_chol = jsp.linalg.cholesky(A_inv, lower=True) + + # we can also compute the marginal-posterior mean (a), + # which is the precision-weighted sum of the prior mean + # (mu) and the likelihood mean (M^T C^-1 y): + # a = A_inv (Lambda_inv mu + M^T C^-1 y) + # so that a is (nquads*nmode,) + a = jsp.linalg.cho_solve( + (A_inv_chol, True), jnp.dot(Lambda_inv, mu) + + jnp.dot(M.T, jsp.linalg.cho_solve((L, True), y))) + + # the mean (b) of the marginal likelihood p(y|b, B), + # i.e., the likelihood obtained after integrating out + # the quadratures, is simply the value of the strain y + # corresponding to the mean quadratures, i.e., mu after + # a coordinate transformation: + # b = M mu + # so that b is (ntime,) + b = jnp.dot(M, mu) + + # the (co)variance of the marginal likelihood (B) is the + # sum of the variance from the noise (C) and the variance + # from the quadrature prior (Lambda): + # B = C + M Lambda M^T + # this is (ntime, ntime), which is large; but, to compute + # the marginal likelihood, we need the inverse covariance + # B^-1, so we can use the Woodbury identity to write: + # B^-1 = C^-1 - C^-1 M (Lambda^-1 + M^T C^-1 M)^-1 M^T C^-1 + # = C^-1 - C^-1 M A M^T C^-1 + # where A = A_inv^-1 per the above; this way we avoid + # inverting the large matrix B directly and take advantage + # of the precomputed Cholesky factor L to get C^-1 + + # with the residual r = y - b, the marginal log-likelihood + # becomes + # logl = -0.5 r^T B^-1 r - 0.5 log |2pi B| + # where |2pi B| is the determinant of 2pi*B and we can + # ignore the 2pi factor since it introduces a term like + # - 0.5*ntime*log(2pi), which is constant + r = y - b + Cinv_r = jsp.linalg.cho_solve((L, True), r) + + M_A_Mt_Cinv_r = jnp.dot(M, jsp.linalg.cho_solve( + (A_inv_chol, True), jnp.dot(M.T, Cinv_r))) + + Cinv_M_A_Mt_Cinv_r = \ + jsp.linalg.cho_solve((L, True), M_A_Mt_Cinv_r) + + # now all we have left to compute is the log determinant + # term, 0.5*log|B|; from the Gaussian refactorization, we + # have that + # |Lambda| |C| = |A| |B| + # and therefore + # log|B| = log|C| + log|Lambda| - log|A| + # furthermore, since |C| = |L|^2, we can write + # 0.5 log|C| = log|L| + # and |L| is the product of the diagonal entries of L; + # writing similarly for |A| and |Lambda|, we thus have + # that log_sqrt_det_B = 0.5 log|B| is + # (note that |A| = -|A_inv|) + log_sqrt_det_B = \ + jnp.sum(jnp.log(jnp.diag(L))) - \ + jnp.sum(jnp.log(jnp.diag(Lambda_inv_chol))) + \ + jnp.sum(jnp.log(jnp.diag(A_inv_chol))) + + # putting it all together we can get the 
+        logl = -0.5*jnp.dot(r, Cinv_r - Cinv_M_A_Mt_Cinv_r) \
+            - log_sqrt_det_B
+
+        # numpyro.factor is applied outside the scan on the stacked logls
+
+        # update the prior mean and precision for the next detector
+        mu = a
+        Lambda_inv = A_inv
+        Lambda_inv_chol = A_inv_chol
+
+        return (mu, Lambda_inv, Lambda_inv_chol), logl
+    return likelihood_update
+
+
 def make_model(modes: int | list[(int, int, int, int)],
                a_scale_max: float,
                marginalized: bool = True,
@@ -311,9 +452,9 @@ def make_model(modes: int | list[(int, int, int, int)],
                mode_ordering: None | str = None,
                single_polarization: bool = False,
                prior: bool = False,
-               predictive: bool = True,
-               store_h_det: bool = True,
-               store_h_det_mode: bool = True):
+               predictive: bool = False,
+               store_h_det: bool = False,
+               store_h_det_mode: bool = False):
     """
     Arguments
     ---------
@@ -652,107 +793,19 @@ def model(times, strains, ls, fps, fcs,
         # iterating over all detectors, we have turned the prior into
         # the posterior
-        for i in range(n_det):
-            # select the design matrix (M), the Cholesky factor (L),
-            # and the strain (y) for the current detector
-            # (ndet, ntime, nquads*nmode) => (i, ntime, nquads*nmode)
-            M = dms[i, :, :]
-            L = ls[i, :, :]
-            y = strains[i, :]
-
-            # M acts as a coordinate transformation matrix, taking us
-            # from the space of quadratures to the space of the data ,
-            # while M^T takes us from data space to quadrature space
-            # (M is ntime x nquads*nmode)
-
-            # L whitens the noise in the detector, taking it from
-            # N(0, C) to N(0, I) (L is ntime x ntime)
-
-            # we can use M and L to compute the precision (A_inv) of
-            # the marginal posterior on the quadratures (conditioned on
-            # the current data and nonlinear parameters), which is just
-            # the sum of the prior precision (Lambda_inv) and the
-            # likelihood precision (M^T C^-1 M):
-            #     A_inv = Lambda_inv + M^T C^-1 M
-            # so that A and A_inv are (nquads*nmode, nquads*nmode)
-            A_inv = Lambda_inv + \
-                jnp.dot(M.T, jsp.linalg.cho_solve((L, True), M))
-            A_inv_chol = jsp.linalg.cholesky(A_inv, lower=True)
-
-            # we can also compute the marginal-posterior mean (a),
-            # which is the precision-weighted sum of the prior mean
-            # (mu) and the likelihood mean (M^T C^-1 y):
-            #     a = A_inv (Lambda_inv mu + M^T C^-1 y)
-            # so that a is (nquads*nmode,)
-            a = jsp.linalg.cho_solve(
-                (A_inv_chol, True), jnp.dot(Lambda_inv, mu) +
-                jnp.dot(M.T, jsp.linalg.cho_solve((L, True), y)))
-
-            # the mean (b) of the marginal likelihood p(y|b, B),
-            # i.e., the likelihood obtained after integrating out
-            # the quadratures, is simply the value of the strain y
-            # corresponding to the mean quadratures, i.e., mu after
-            # a coordinate transformation:
-            #     b = M mu
-            # so that b is (ntime,)
-            b = jnp.dot(M, mu)
-
-            # the (co)variance of the marginal likelihood (B) is the
-            # sum of the variance from the noise (C) and the variance
-            # from the quadrature prior (Lambda):
-            #     B = C + M Lambda M^T
-            # this is (ntime, ntime), which is large; but, to compute
-            # the marginal likelihood, we need the inverse covariance
-            # B^-1, so we can use the Woodbury identity to write:
-            #     B^-1 = C^-1 - C^-1 M (Lambda^-1 + M^T C^-1 M)^-1 M^T C^-1
-            #          = C^-1 - C^-1 M A M^T C^-1
-            # where A = A_inv^-1 per the above; this way we avoid
-            # inverting the large matrix B directly and take advantage
-            # of the precomputed Cholesky factor L to get C^-1
-
-            # with the residual r = y - b, the marginal log-likelihood
-            # becomes
-            #     logl = -0.5 r^T B^-1 r - 0.5 log |2pi B|
-            # where |2pi B| is the determinant of 2pi*B and we can
-            # ignore the 2pi factor since it introduces a term like
-            #     - 0.5*ntime*log(2pi), which is constant
-            r = y - b
-            Cinv_r = jsp.linalg.cho_solve((L, True), r)
-
-            M_A_Mt_Cinv_r = jnp.dot(M, jsp.linalg.cho_solve(
-                (A_inv_chol, True), jnp.dot(M.T, Cinv_r)))
-
-            Cinv_M_A_Mt_Cinv_r = \
-                jsp.linalg.cho_solve((L, True), M_A_Mt_Cinv_r)
-
-            # now all we have left to compute is the log determinant
-            # term, 0.5*log|B|; from the Gaussian refactorization, we
-            # have that
-            #     |Lambda| |C| = |A| |B|
-            # and therefore
-            #     log|B| = log|C| + log|Lambda| - log|A|
-            # furthermore, since |C| = |L|^2, we can write
-            #     0.5 log|C| = log|L|
-            # and |L| is the product of the diagonal entries of L;
-            # writing similarly for |A| and |Lambda|, we thus have
-            # that log_sqrt_det_B = 0.5 log|B| is
-            # (note that |A| = -|A_inv|)
-            log_sqrt_det_B = \
-                jnp.sum(jnp.log(jnp.diag(L))) - \
-                jnp.sum(jnp.log(jnp.diag(Lambda_inv_chol))) + \
-                jnp.sum(jnp.log(jnp.diag(A_inv_chol)))
-
-            # putting it all together we can get the contribution
-            # to the log likelihood from this detector
-            logl = -0.5*jnp.dot(r, Cinv_r - Cinv_M_A_Mt_Cinv_r) \
-                - log_sqrt_det_B
+        likelihood_update = make_likelihood_update(dms, ls, strains)
-            numpyro.factor(f'logl_{i}', logl)
+        mu_Lambda_inv_Lambda_inv_chol, logls = lax.scan(
+            likelihood_update,
+            (mu, Lambda_inv, Lambda_inv_chol),
+            jnp.arange(n_det),
+        )
+        mu, Lambda_inv, Lambda_inv_chol = mu_Lambda_inv_Lambda_inv_chol
-            # update the prior mean and precision for the next detector
-            mu = a
-            Lambda_inv = A_inv
-            Lambda_inv_chol = A_inv_chol
+        # add each detector's marginal log likelihood to the potential
+        # TODO: check if we can use numpyro.contrib.control_flow.scan instead
+        for i, logl in enumerate(logls):
+            numpyro.factor(f'logl_{i}', logl)
 
     if predictive:
         # Generate the actual quadrature amplitudes by taking a draw
diff --git a/ringdown/utils/swsh.py b/ringdown/utils/swsh.py
index 3bbe541e..b493cf79 100644
--- a/ringdown/utils/swsh.py
+++ b/ringdown/utils/swsh.py
@@ -1,8 +1,9 @@
 __all__ = ['construct_sYlm', 'calc_YpYc']
 
 import numpy as np
+import jax
 import jax.numpy as jnp
-from scipy.special import factorial as fac
+from jax.scipy.special import factorial as fac
 
 
 def binom_coeff(n, k):
@@ -21,8 +22,8 @@ def binom_coeff(n, k):
 
     # binomial coefficient is zero if k>n, or generally if above formula
     # returns an inf
-    num[denom == 0] = 0
-    denom[denom == 0] = 1
+    num = jnp.where(denom == 0, 0, num)
+    denom = jnp.where(denom == 0, 1, denom)
 
     return num / denom
@@ -88,11 +89,12 @@ def ylm(cosi):
 
     ylm = np.sqrt(1/0.159) * prefactor * sin_th_2(cosi)**(2*ell)
 
-    summands = [(-1)**r * binom_coeff(ell - s, r)
-                * binom_coeff(ell + s, r + s - m)
-                * cot_th_2(cosi)**(2*r + s - m)
-                for r in rs]
-    ylm *= jnp.sum(jnp.array(summands), axis=0)
+    # vectorized sum over r via jax.vmap so it traces under jit
+    def get_summand(r):
+        return (-1)**r * binom_coeff(ell - s, r) * \
+            binom_coeff(ell + s, r + s - m) * \
+            cot_th_2(cosi)**(2*r + s - m)
+
+    ylm *= jnp.sum(jax.vmap(get_summand)(rs), axis=0)
 
     return ylm
 return ylm
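Reviewer note (not part of the patch): a minimal, self-contained sketch of how the new make_likelihood_update closure is driven by jax.lax.scan, mirroring the call added in model(). The toy sizes, the identity noise covariance (so each Cholesky factor L is the identity), and the unit Gaussian prior below are illustrative assumptions, not ringdown defaults.

import jax
import jax.numpy as jnp
from ringdown.model import make_likelihood_update  # added by this patch

# toy sizes: detectors, time samples, quadratures (nquads*nmode)
n_det, ntime, nq = 3, 64, 4
dms = jax.random.normal(jax.random.PRNGKey(0), (n_det, ntime, nq))
ls = jnp.broadcast_to(jnp.eye(ntime), (n_det, ntime, ntime))  # C = I => L = I
strains = jax.random.normal(jax.random.PRNGKey(1), (n_det, ntime))

# unit Gaussian prior on the quadratures: mu = 0, Lambda_inv = I
carry0 = (jnp.zeros(nq), jnp.eye(nq), jnp.eye(nq))

update = make_likelihood_update(dms, ls, strains)
carry, logls = jax.lax.scan(update, carry0, jnp.arange(n_det))
mu, Lambda_inv, Lambda_inv_chol = carry  # posterior on the quadratures
print(logls.shape)  # (3,): one marginal log-likelihood term per detector

Summing logls reproduces the total marginalized log likelihood that the per-detector numpyro.factor calls contribute, up to the constant -0.5*ntime*log(2pi) terms the code deliberately drops.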