From 3b13a6988a709099ad710ebf9e7c445bb007adbe Mon Sep 17 00:00:00 2001 From: Michal-Novomestsky Date: Mon, 11 Aug 2025 21:15:16 +1000 Subject: [PATCH 1/2] refactor: move _precition_mv_normal_logp into a seperate function --- pymc/distributions/multivariate.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 4415f5ec25..1856d33c96 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -338,10 +338,7 @@ def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None): method=method, )(rng, size, mean, tau) - -@_logprob.register -def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs): - [value] = value +def _precision_mv_normal_logp(value, mean, tau): k = value.shape[-1].astype("floatX") delta = value - mean @@ -349,6 +346,14 @@ def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, ta logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau)) logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet + return logp, posdef + +@_logprob.register +def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs): + [value] = value + + logp, posdef = _precision_mv_normal_logp(value, mean, tau) + return check_parameters( logp, posdef, From 67dc8173ad734db1820aa465922c516ab1922191 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Aug 2025 09:40:34 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pymc/distributions/multivariate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index 1856d33c96..464f970797 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -338,6 +338,7 @@ def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None): method=method, )(rng, size, mean, tau) + def _precision_mv_normal_logp(value, mean, tau): k = value.shape[-1].astype("floatX") @@ -348,6 +349,7 @@ def _precision_mv_normal_logp(value, mean, tau): return logp, posdef + @_logprob.register def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs): [value] = value