diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 4415f5ec25..464f970797 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -339,9 +339,7 @@ def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None): )(rng, size, mean, tau) -@_logprob.register -def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs): - [value] = value +def _precision_mv_normal_logp(value, mean, tau): k = value.shape[-1].astype("floatX") delta = value - mean @@ -349,6 +347,15 @@ def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, ta logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau)) logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet + return logp, posdef + + +@_logprob.register +def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs): + [value] = value + + logp, posdef = _precision_mv_normal_logp(value, mean, tau) + return check_parameters( logp, posdef,