From 416737e6140833ad8fd07b216ab7eb2abe44e4fb Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 10 Nov 2025 12:56:44 +0100
Subject: [PATCH 1/3] Fix logcdf of DiscreteUniform at lower bound

logcdf(x=lower) != -inf
---
 pymc/distributions/discrete.py       |  4 ++--
 tests/distributions/test_discrete.py | 20 +++++++++++---------
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py
index 70b37b9e42..f313d26cd4 100644
--- a/pymc/distributions/discrete.py
+++ b/pymc/distributions/discrete.py
@@ -1052,11 +1052,11 @@ def logp(value, lower, upper):
 
     def logcdf(value, lower, upper):
         res = pt.switch(
-            pt.le(value, lower),
+            pt.lt(value, lower),
             -np.inf,
             pt.switch(
                 pt.lt(value, upper),
-                pt.log(pt.minimum(pt.floor(value), upper) - lower + 1) - pt.log(upper - lower + 1),
+                pt.log(pt.floor(value) - lower + 1) - pt.log(upper - lower + 1),
                 0,
             ),
         )
diff --git a/tests/distributions/test_discrete.py b/tests/distributions/test_discrete.py
index 55e8c23128..1ec58a9eeb 100644
--- a/tests/distributions/test_discrete.py
+++ b/tests/distributions/test_discrete.py
@@ -41,9 +41,7 @@
     Nat,
     NatSmall,
     R,
-    Rdunif,
     Rplus,
-    Rplusdunif,
     Runif,
     Simplex,
     Unit,
@@ -95,32 +93,36 @@ def orderedprobit_logpdf(value, eta, cutpoints):
 
 
 class TestMatchesScipy:
-    def test_discrete_unif(self):
+    def test_discrete_uniform(self):
+        # Choose domain/paramdomain so we test edge cases as well
+        test_domain = Domain([-np.inf, -10, -1, 0, 1, 10, np.inf], dtype="int64")
+        test_paramdomain = Domain([-np.inf, 0, 10, np.inf], dtype="int64")
         check_logp(
             pm.DiscreteUniform,
-            Rdunif,
-            {"lower": -Rplusdunif, "upper": Rplusdunif},
+            test_domain,
+            {"lower": -test_paramdomain, "upper": test_paramdomain},
             lambda value, lower, upper: st.randint.logpmf(value, lower, upper + 1),
             skip_paramdomain_outside_edge_test=True,
         )
         check_logcdf(
             pm.DiscreteUniform,
-            Rdunif,
-            {"lower": -Rplusdunif, "upper": Rplusdunif},
+            test_domain,
+            {"lower": -test_paramdomain, "upper": test_paramdomain},
             lambda value, lower, upper: st.randint.logcdf(value, lower, upper + 1),
             skip_paramdomain_outside_edge_test=True,
         )
         check_selfconsistency_discrete_logcdf(
             pm.DiscreteUniform,
             Domain([-10, 0, 10], "int64"),
-            {"lower": -Rplusdunif, "upper": Rplusdunif},
+            {"lower": -test_paramdomain, "upper": test_paramdomain},
         )
         check_icdf(
             pm.DiscreteUniform,
-            {"lower": -Rplusdunif, "upper": Rplusdunif},
+            {"lower": -test_paramdomain, "upper": test_paramdomain},
             lambda q, lower, upper: st.randint.ppf(q=q, low=lower, high=upper + 1),
             skip_paramdomain_outside_edge_test=True,
         )
 
+        # Custom logp / logcdf check for invalid parameters
         invalid_dist = pm.DiscreteUniform.dist(lower=1, upper=0)
         with pytensor.config.change_flags(mode=Mode("py")):

From da3db75e7228cc6d834689922604467ddab36aad Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 10 Nov 2025 13:28:16 +0100
Subject: [PATCH 2/3] Handle edge case of logdiffexp(-inf, -inf)

---
 pymc/logprob/transforms.py | 3 ++-
 pymc/math.py               | 9 ++++++++-
 tests/test_math.py         | 8 ++++++--
 3 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index 8d2bbacd26..a56856aa13 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -123,6 +123,7 @@
     filter_measurable_variables,
     find_negated_var,
 )
+from pymc.math import logdiffexp
 
 
 class Transform(abc.ABC):
@@ -267,7 +268,7 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
         logcdf_zero = _logcdf_helper(measurable_input, 0)
         logcdf = pt.switch(
             pt.lt(backward_value, 0),
-            pt.log(pt.exp(logcdf_zero) - pt.exp(logcdf)),
+            logdiffexp(logcdf_zero, logcdf),
             pt.logaddexp(logccdf, logcdf_zero),
         )
     else:
diff --git a/pymc/math.py b/pymc/math.py
index 65ddacfb95..6a46fe11c8 100644
--- a/pymc/math.py
+++ b/pymc/math.py
@@ -282,7 +282,14 @@ def kron_diag(*diags):
 
 def logdiffexp(a, b):
     """Return log(exp(a) - exp(b))."""
-    return a + pt.log1mexp(b - a)
+    return pt.where(
+        # Handle special case where b is -inf
+        # If a == b == -inf, this will return the correct result of -inf
+        # whereas the default else branch would get a nan due to -inf - (-inf)
+        pt.isneginf(b),
+        a,
+        a + pt.log1mexp(b - a),
+    )
 
 
 invlogit = sigmoid
diff --git a/tests/test_math.py b/tests/test_math.py
index eeee1f164e..0acee624ee 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -144,9 +144,13 @@ def test_log1mexp_deprecation_warnings():
 def test_logdiffexp():
     a = np.log([1, 2, 3, 4])
     with warnings.catch_warnings():
-        warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
         b = np.log([0, 1, 2, 3])
-    assert np.allclose(logdiffexp(a, b).eval(), 0)
+    np.testing.assert_allclose(logdiffexp(a, b).eval(), 0, atol=1e-15)
+
+    np.testing.assert_allclose(
+        logdiffexp(-np.inf, -np.inf).eval(),
+        -np.inf,
+    )
 
 
 class TestLogDet:

From 474c8b750f89a099d53e9a3e76699a5de12cb465 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 10 Nov 2025 13:41:32 +0100
Subject: [PATCH 3/3] Infer logp and logcdf of abs of discrete variables

---
 pymc/logprob/transforms.py       |  95 +++++++++++++++++++------
 tests/logprob/test_transforms.py | 116 +++++++++++++++++--------------
 2 files changed, 138 insertions(+), 73 deletions(-)

diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py
index a56856aa13..07b7f1e10d 100644
--- a/pymc/logprob/transforms.py
+++ b/pymc/logprob/transforms.py
@@ -40,7 +40,7 @@
 import numpy as np
 import pytensor.tensor as pt
 
-from pytensor import scan
+from pytensor import graph_replace, scan
 from pytensor.gradient import jacobian
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.fg import FunctionGraph
@@ -163,6 +163,8 @@ def __str__(self):
 class MeasurableTransform(MeasurableElemwise):
     """A placeholder used to specify a log-likelihood for a transformed measurable variable."""
 
+    __props__ = ("scalar_op", "inplace_pattern", "is_discrete")
+
     valid_scalar_types = (
         Exp,
         Log,
@@ -187,16 +189,55 @@ class MeasurableTransform(MeasurableElemwise):
     transform_elemwise: Transform
     measurable_input_idx: int
 
-    def __init__(self, *args, transform: Transform, measurable_input_idx: int, **kwargs):
+    def __init__(
+        self, *args, transform: Transform, measurable_input_idx: int, is_discrete: bool, **kwargs
+    ):
         self.transform_elemwise = transform
         self.measurable_input_idx = measurable_input_idx
+        self.is_discrete = is_discrete
         super().__init__(*args, **kwargs)
 
 
+def abs_logprob(op, value, x, **kwargs):
+    """Compute the log-probability graph for an absolute value transformation.
+
+    For `Y = |X|`, we have `PDF_Y(y) = PDF_X(-y) + PDF_X(y)`,
+    except for discrete distributions, where there is the special case `P(Y=0) = P(X=0)`.
+    """
+    logprob_pos = _logprob_helper(x, value)
+    logprob_neg = graph_replace(logprob_pos, {value: -value})
+    if op.is_discrete:
+        logprob = pt.switch(
+            pt.eq(value, 0),
+            logprob_pos,
+            pt.logaddexp(logprob_pos, logprob_neg),
+        )
+    else:
+        logprob = pt.logaddexp(logprob_pos, logprob_neg)
+    logprob = pt.where(value < 0, -np.inf, logprob)
+    return logprob
+
+
+def abs_logcdf(op, value, x, **kwargs):
+    """Compute the log-CDF graph for an absolute value transformation.
+
+    For `Y = |X|`, we have `CDF_Y(y) = P(|X| <= y) = P(-y <= X <= y) = CDF_X(y) - CDF_X(-y)`.
+    """
+    logcdf_pos = _logcdf_helper(x, value)
+    neg_value = -value - 1 if op.is_discrete else -value
+    logcdf_neg = graph_replace(logcdf_pos, {value: neg_value})
+    logcdf = logdiffexp(logcdf_pos, logcdf_neg)
+    logcdf = pt.where(value < 0, -np.inf, logcdf)
+    return logcdf
+
+
 @_logprob.register(MeasurableTransform)
 def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwargs):
     """Compute the log-probability graph for a `MeasurabeTransform`."""
     # TODO: Could other rewrites affect the order of inputs?
+    if isinstance(op.scalar_op, Abs):
+        return abs_logprob(op, values[0], *inputs, **kwargs)
+
     (value,) = values
     other_inputs = list(inputs)
     measurable_input = other_inputs.pop(op.measurable_input_idx)
@@ -207,6 +248,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
 
     # Some transformations, like squaring may produce multiple backward values
     if isinstance(backward_value, tuple):
+        if op.is_discrete:
+            # Discrete variables tend to have a tricky x=0 case; bail out if we don't have a custom implementation
+            raise NotImplementedError(
+                "Logprob of transformed discrete variables with non-injective transforms not implemented"
+            )
         input_logprob = pt.logaddexp(
             *(
                 _logprob_helper(measurable_input, backward_val, **kwargs)
@@ -225,8 +271,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
         ndim_supp = value.ndim - input_logprob.ndim
         jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))
 
+    # Discrete transformations do not need the jacobian adjustment
+    logprob = input_logprob if op.is_discrete else input_logprob + jacobian
+
     # The jacobian is used to ensure a value in the supported domain was provided
-    return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
+    return pt.switch(pt.isnan(jacobian), -np.inf, logprob)
 
 
 MONOTONICALLY_INCREASING_OPS = (Exp, Log, Add, Sinh, Tanh, ArcSinh, ArcCosh, ArcTanh, Erf, Sigmoid)
@@ -236,6 +285,10 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
 @_logcdf.register(MeasurableTransform)
 def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
     """Compute the log-CDF graph for a `MeasurabeTransform`."""
+    if isinstance(op.scalar_op, Abs):
+        # Special case for absolute value transformation
+        return abs_logcdf(op, value, *inputs, **kwargs)
+
     other_inputs = list(inputs)
     measurable_input = other_inputs.pop(op.measurable_input_idx)
     backward_value = op.transform_elemwise.backward(value, *other_inputs)
@@ -245,10 +298,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
     if isinstance(backward_value, tuple):
         raise NotImplementedError
 
-    is_discrete = measurable_input.type.dtype.startswith("int")
-
     logcdf = _logcdf_helper(measurable_input, backward_value)
-    if is_discrete:
+    if op.is_discrete:
         logccdf = pt.log1mexp(_logcdf_helper(measurable_input, backward_value - 1))
     else:
         logccdf = pt.log1mexp(logcdf)
@@ -275,9 +326,6 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
         # We don't know if this Op is monotonically increasing/decreasing
         raise NotImplementedError
 
-    if is_discrete:
-        return logcdf
-
     # The jacobian is used to ensure a value in the supported domain was provided
     jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
     return pt.switch(pt.isnan(jacobian), -np.inf, logcdf)
@@ -286,13 +334,12 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
 @_icdf.register(MeasurableTransform)
 def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs):
     """Compute the inverse CDF graph for a `MeasurabeTransform`."""
+    if op.is_discrete:
+        raise NotImplementedError("icdf of transformed discrete variables not implemented")
+
     other_inputs = list(inputs)
     measurable_input = other_inputs.pop(op.measurable_input_idx)
 
-    # Do not apply rewrite to discrete variables
-    if measurable_input.type.dtype.startswith("int"):
-        raise NotImplementedError("icdf of transformed discrete variables not implemented")
-
     if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
         pass
     elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
@@ -323,7 +370,7 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
     # Fail if transformation is not injective
     # A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
     if isinstance(op.transform_elemwise.backward(icdf, *other_inputs), tuple):
-        raise NotImplementedError
+        raise NotImplementedError("icdf of non-injective transformations not implemented")
 
     return icdf
 
@@ -481,15 +528,22 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
     [measurable_input] = measurable_inputs
     [measurable_output] = node.outputs
 
-    # Do not apply rewrite to discrete variables except for their addition and negation
-    if measurable_input.type.dtype.startswith("int"):
+    # Do not apply rewrite to discrete variables except if:
+    # 1. Operation retains a discrete output
+    # 2. Operation doesn't create holes in the support
+    # Reason:
+    # 1. Due to a limitation in our IR we don't know the type of the MeasurableVariable
+    #    We don't want to make other rewrites think they are dealing with continuous variables when they are not
+    # 2. We don't want to add cumbersome within-domain checks
+    is_discrete = measurable_input.type.dtype.startswith("int")
+    if is_discrete:
+        if not measurable_output.type.dtype.startswith("int"):
+            return None
         if not (
-            find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add)
+            isinstance(node.op.scalar_op, Add | Abs)
+            or find_negated_var(measurable_output) is not None
         ):
             return None
-        # Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
-        if not measurable_output.type.dtype.startswith("int"):
-            return None
 
     # Check that other inputs are not potentially measurable, in which case this rewrite
     # would be invalid
@@ -545,6 +599,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
         scalar_op=scalar_op,
         transform=transform,
         measurable_input_idx=measurable_input_idx,
+        is_discrete=is_discrete,
     )
     transform_out = transform_op.make_node(*transform_inputs).default_output()
     return [transform_out]
diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py
index c9aeaa8abf..cd49904dbf 100644
--- a/tests/logprob/test_transforms.py
+++ b/tests/logprob/test_transforms.py
@@ -44,7 +44,7 @@
 from pytensor.graph.basic import equal_computations
 
 from pymc.distributions.continuous import Cauchy, ChiSquared
-from pymc.distributions.discrete import Bernoulli
+from pymc.distributions.discrete import Bernoulli, DiscreteUniform
 from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
 from pymc.logprob.transforms import (
     ArccoshTransform,
@@ -285,6 +285,27 @@ def test_loc_transform_rv(self, rv_size, loc_type, addition):
             sp.stats.norm(loc_test_val, 1).ppf(q_test_val),
         )
 
+    def test_shifted_discrete_rv_transform(self):
+        p = 0.7
+        rv = Bernoulli.dist(p=p) + 5
+        vv = rv.type()
+
+        rv_logp_fn = pytensor.function([vv], logp(rv, vv))
+        assert rv_logp_fn(4) == -np.inf
+        np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p))
+        np.testing.assert_allclose(rv_logp_fn(6), np.log(p))
+        assert rv_logp_fn(7) == -np.inf
+
+        rv_logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
+        assert rv_logcdf_fn(4) == -np.inf
+        np.testing.assert_allclose(rv_logcdf_fn(5), np.log(1 - p))
+        np.testing.assert_allclose(rv_logcdf_fn(6), 0)
+        assert rv_logcdf_fn(7) == 0
+
+        # icdf not supported yet
+        with pytest.raises(NotImplementedError):
+            icdf(rv, 0)
+
     @pytest.mark.parametrize(
         "rv_size, scale_type, product",
         [
@@ -337,6 +358,23 @@ def test_negated_rv_transform(self):
         np.testing.assert_allclose(x_logcdf_fn(-1.5), sp.stats.halfnorm.logsf(1.5))
         np.testing.assert_allclose(x_icdf_fn(0.3), -sp.stats.halfnorm.ppf(1 - 0.3))
 
+    def test_negated_discrete_rv_transform(self):
+        p = 0.7
+        rv = -Bernoulli.dist(p=p, shape=(4,))
+        vv = rv.type()
+
+        # A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
+        logp_fn = pytensor.function([vv], logp(rv, vv))
+        np.testing.assert_allclose(
+            logp_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), np.log(1 - p), -np.inf]
+        )
+
+        logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
+        np.testing.assert_allclose(logcdf_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), 0, 0])
+
+        with pytest.raises(NotImplementedError):
+            icdf(rv, [-2, -1, 0, 1])
+
     def test_subtracted_rv_transform(self):
         # Choose base RV that is asymmetric around zero
         x_rv = 5.0 - pt.random.normal(1.0)
@@ -501,21 +539,33 @@ def test_negative_value_frac_power_transform_logp(self, power):
         assert np.isneginf(x_logp_fn(-2.5))
 
 
-@pytest.mark.parametrize("test_val", (2.5, -2.5))
-def test_absolute_rv_transform(test_val):
-    x_rv = pt.abs(pt.random.normal())
-    y_rv = pt.random.halfnormal()
+@pytest.mark.parametrize("continuous", (True, False))
+def test_absolute_rv_transform(continuous):
+    if continuous:
+        x_rv = pt.abs(pt.random.normal(size=(5,)))
+        ref_rv = pt.random.halfnormal(size=(5,))
+    else:
+        x_rv = pt.abs(DiscreteUniform.dist(-4, 4, size=(5,)))
+        # |x_rv| = DiscreteUniform(0, 4) with P(X=0) halved relative to the other values
+        # We can use a Categorical to represent this
+        ref_rv = pt.random.categorical(
+            p=np.array([1, 2, 2, 2, 2]) / 9,
+            size=(5,),
+        )
 
-    x_vv = x_rv.clone()
-    y_vv = y_rv.clone()
-    x_logp_fn = pytensor.function([x_vv], logp(x_rv, x_vv))
-    with pytest.raises(NotImplementedError):
-        logcdf(x_rv, x_vv)
+    x_vv = x_rv.type()
+    ref_vv = ref_rv.type()
+    # Not working in log space because it's easier to debug the discrete case
+    x_pdf_fn = pytensor.function([x_vv], pt.exp(logp(x_rv, x_vv)))
+    x_cdf_fn = pytensor.function([x_vv], pt.exp(logcdf(x_rv, x_vv)))
    with pytest.raises(NotImplementedError):
         icdf(x_rv, x_vv)
 
-    y_logp_fn = pytensor.function([y_vv], logp(y_rv, y_vv))
-    np.testing.assert_allclose(x_logp_fn(test_val), y_logp_fn(test_val))
+    ref_pdf_fn = pytensor.function([ref_vv], pt.exp(logp(ref_rv, ref_vv)))
+    ref_cdf_fn = pytensor.function([ref_vv], pt.exp(logcdf(ref_rv, ref_vv)))
+    test_val = np.array([-2.5, -2.0, 0, 2.0, 2.5], dtype=x_vv.dtype)
+    np.testing.assert_allclose(x_pdf_fn(test_val), ref_pdf_fn(test_val))
+    np.testing.assert_allclose(x_cdf_fn(test_val), ref_cdf_fn(test_val))
 
 
 @pytest.mark.parametrize(
@@ -690,51 +740,11 @@ def test_not_implemented_discrete_rv_transform():
     with pytest.raises(RuntimeError, match="could not be derived"):
         conditional_logp({y_rv: y_rv.clone()})
 
-    y_rv = 5 * pt.random.poisson(1)
+    y_rv = 5.5 * pt.random.poisson(1)
     with pytest.raises(RuntimeError, match="could not be derived"):
         conditional_logp({y_rv: y_rv.clone()})
 
 
-def test_negated_discrete_rv_transform():
-    p = 0.7
-    rv = -Bernoulli.dist(p=p, shape=(4,))
-    vv = rv.type()
-
-    # A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
-    logp_fn = pytensor.function([vv], logp(rv, vv))
-    np.testing.assert_allclose(
-        logp_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), np.log(1 - p), -np.inf]
-    )
-
-    logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
-    np.testing.assert_allclose(logcdf_fn([-2, -1, 0, 1]), [-np.inf, np.log(p), 0, 0])
-
-    with pytest.raises(NotImplementedError):
-        icdf(rv, [-2, -1, 0, 1])
-
-
-def test_shifted_discrete_rv_transform():
-    p = 0.7
-    rv = Bernoulli.dist(p=p) + 5
-    vv = rv.type()
-
-    rv_logp_fn = pytensor.function([vv], logp(rv, vv))
-    assert rv_logp_fn(4) == -np.inf
-    np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p))
-    np.testing.assert_allclose(rv_logp_fn(6), np.log(p))
-    assert rv_logp_fn(7) == -np.inf
-
-    rv_logcdf_fn = pytensor.function([vv], logcdf(rv, vv))
-    assert rv_logcdf_fn(4) == -np.inf
-    np.testing.assert_allclose(rv_logcdf_fn(5), np.log(1 - p))
-    np.testing.assert_allclose(rv_logcdf_fn(6), 0)
-    assert rv_logcdf_fn(7) == 0
-
-    # icdf not supported yet
-    with pytest.raises(NotImplementedError):
-        icdf(rv, 0)
-
-
 @pytest.mark.xfail(reason="Check not implemented yet")
 def test_invalid_broadcasted_transform_rv_fails():
     loc = pt.vector("loc")
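
A quick usage sketch of the behaviour this series enables (not part of the patches; it assumes a PyMC build with these commits applied and only uses APIs exercised by the tests above):

    import pytensor
    import pytensor.tensor as pt

    import pymc as pm
    from pymc.logprob.basic import logcdf, logp

    # Patch 1: logcdf of DiscreteUniform at the lower bound is no longer -inf
    x = pm.DiscreteUniform.dist(lower=0, upper=4)
    print(logcdf(x, 0).eval())  # log(1/5) ~= -1.609, previously -inf

    # Patch 3: logp can now be inferred through pt.abs of a discrete variable
    y = pt.abs(pm.DiscreteUniform.dist(lower=-4, upper=4))
    yv = y.type()
    pmf = pytensor.function([yv], pt.exp(logp(y, yv)))
    print(pmf(0), pmf(2))  # ~1/9 and ~2/9, the folded DiscreteUniform(-4, 4)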