diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py
index 3d577aac9..da3506c1f 100644
--- a/pymc/logprob/tensor.py
+++ b/pymc/logprob/tensor.py
@@ -33,20 +33,23 @@
 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
+import typing
 from pathlib import Path
 
 from pytensor import tensor as pt
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.npy_2_compat import normalize_axis_index
 from pytensor.tensor import TensorVariable
-from pytensor.tensor.basic import Join, MakeVector
+from pytensor.tensor.basic import Join, MakeVector, Split
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.rewriting import (
     local_dimshuffle_rv_lift,
 )
 
+from pymc.exceptions import NotConstantValueError
 from pymc.logprob.abstract import (
     MeasurableOp,
     ValuedRV,
@@ -70,7 +73,7 @@
 
 
 class MeasurableMakeVector(MeasurableOp, MakeVector):
-    """A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
+    """A placeholder used to specify a log-likelihood for a make_vector sub-graph."""
 
 
 @_logprob.register(MeasurableMakeVector)
@@ -183,6 +186,66 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
     return [measurable_stack]
 
 
+class MeasurableSplit(MeasurableOp, Split):
+    """A placeholder used to specify a log-likelihood for a split sub-graph."""
+
+
+@node_rewriter([Split])
+def find_measurable_splits(fgraph, node) -> list[TensorVariable] | None:
+    if isinstance(node.op, MeasurableOp):
+        return None
+
+    x, axis, splits = node.inputs
+    if not filter_measurable_variables([x]):
+        return None
+
+    return MeasurableSplit(node.op.len_splits).make_node(x, axis, splits).outputs
+
+
+@_logprob.register(MeasurableSplit)
+def logprob_split(op: MeasurableSplit, values, x, axis, splits, **kwargs):
+    """Compute the log-likelihood graph for a `MeasurableSplit`."""
+    if len(values) != op.len_splits:
+        # TODO: Don't rewrite the split in the first place if not all parts are linked to value variables.
+        # This would also allow handling some cases where not all splits are used.
+        raise ValueError("Split logp requires the number of values to match the number of splits")
+
+    # Reverse the effects of split on the value variable
+    join_value = pt.join(axis, *values)
+
+    join_logp = _logprob_helper(x, join_value)
+
+    reduced_dims = join_value.ndim - join_logp.ndim
+
+    if reduced_dims:
+        # This happens for multivariate distributions
+        try:
+            [constant_axis] = constant_fold([axis])
+        except NotConstantValueError:
+            raise NotImplementedError("Cannot split multivariate logp with non-constant axis")
+
+        constant_axis = normalize_axis_index(constant_axis, join_value.ndim)  # type: ignore[arg-type, assignment]
+        if constant_axis >= join_logp.ndim:
+            # If the axis is over a dimension that was reduced in the logp (a multivariate support dim),
+            # we cannot split the logp into distinct entries: the mapping between values and densities breaks.
+            # Instead, return the joint logp weighted by the relative split sizes (as good a solution as any).
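+            # E.g., splitting a length-5 Dirichlet's support dim into sizes (2, 3) assigns 2/5 and 3/5
+            # of the joint logp to the respective parts; this is exercised by test_multivariate below.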
+            split_weights = splits / pt.sum(splits)
+            return [join_logp * split_weights[i] for i in range(typing.cast(int, op.len_splits))]
+        else:
+            # Otherwise we can split the logp as if the split were over batch dimensions.
+            # We just need to be sure to use the positive (normalized) axis index.
+            axis = constant_axis
+
+    return pt.split(
+        join_logp,
+        splits_size=splits,
+        n_splits=op.len_splits,
+        axis=axis,
+    )
+
+
 class MeasurableDimShuffle(MeasurableOp, DimShuffle):
     """A placeholder used to specify a log-likelihood for a dimshuffle sub-graph."""
 
@@ -308,3 +369,10 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
     "basic",
     "tensor",
 )
+
+measurable_ir_rewrites_db.register(
+    "find_measurable_splits",
+    find_measurable_splits,
+    "basic",
+    "tensor",
+)
diff --git a/tests/logprob/test_tensor.py b/tests/logprob/test_tensor.py
index df5c7052f..fe9a87597 100644
--- a/tests/logprob/test_tensor.py
+++ b/tests/logprob/test_tensor.py
@@ -40,6 +40,7 @@
 from pytensor import tensor as pt
 from pytensor.graph import RewriteDatabaseQuery
+from pytensor.tensor.random.type import random_generator_type
 from scipy import stats as st
 
 from pymc.logprob.basic import conditional_logp, logp
@@ -352,7 +353,7 @@
     np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value))
 
 
-def test_unmeargeable_dimshuffles():
+def test_unmeasurable_dimshuffles():
     # Test that graphs with DimShuffles that cannot be lifted/merged fail
 
     # Initial support axis is at axis=-1
@@ -372,3 +373,156 @@
     # TODO: Check that logp is correct if this type of graphs is ever supported
     with pytest.raises(RuntimeError, match="could not be derived"):
         conditional_logp({w: w_vv})
+
+
+class TestMeasurableSplit:
+    def test_univariate(self):
+        rng = np.random.default_rng(388)
+        mu = np.arange(6)[:, None]
+        sigma = np.arange(5) + 1
+
+        x = pt.random.normal(mu, sigma, size=(6, 5), name="x")
+
+        # axis=0
+        x_parts = pt.split(x, splits_size=[2, 4], n_splits=2, axis=0)
+        x_parts_vv = [x_part.clone() for x_part in x_parts]
+        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
+
+        logp_fn = pytensor.function(x_parts_vv, logp_parts)
+        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
+        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
+        np.testing.assert_allclose(
+            logp_x1_eval,
+            st.norm.logpdf(x_parts_test[0], mu[:2], sigma),
+        )
+        np.testing.assert_allclose(
+            logp_x2_eval,
+            st.norm.logpdf(x_parts_test[1], mu[2:], sigma),
+        )
+
+        # axis=1
+        x_parts = pt.split(x, splits_size=[2, 1, 2], n_splits=3, axis=1)
+        x_parts_vv = [x_part.clone() for x_part in x_parts]
+        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
+
+        logp_fn = pytensor.function(x_parts_vv, logp_parts)
+        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
+        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
+        np.testing.assert_allclose(
+            logp_x1_eval,
+            st.norm.logpdf(x_parts_test[0], mu, sigma[:2]),
+        )
+        np.testing.assert_allclose(
+            logp_x2_eval,
+            st.norm.logpdf(x_parts_test[1], mu, sigma[2:3]),
+        )
+        np.testing.assert_allclose(
+            logp_x3_eval,
+            st.norm.logpdf(x_parts_test[2], mu, sigma[3:]),
+        )
+
+    def test_multivariate(self):
+        @np.vectorize(signature="(n),(n)->()")
+        def scipy_dirichlet_logpdf(x, alpha):
+            """Compute the logpdf of a Dirichlet distribution using scipy."""
+            return st.dirichlet.logpdf(x, alpha)
+
+        # (3, 5) Dirichlet
+        rng = np.random.default_rng(426)
+        rng_pt = random_generator_type("rng")
+        alpha = np.linspace(1, 10, 5) * np.array([1, 10, 100])[:, None]
+        x = pt.random.dirichlet(alpha, rng=rng_pt)
+
+        # axis=-2 (i.e., axis 0, the batch dimension)
+        x_parts = pt.split(x, splits_size=[2, 1], n_splits=2, axis=-2)
+        x_parts_vv = [x_part.clone() for x_part in x_parts]
+        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
+        assert logp_parts[0].type.shape == (2,)
+        assert logp_parts[1].type.shape == (1,)
+
+        logp_fn = pytensor.function(x_parts_vv, logp_parts)
+        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
+        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
+        np.testing.assert_allclose(
+            logp_x1_eval,
+            scipy_dirichlet_logpdf(x_parts_test[0], alpha[:2]),
+        )
+        np.testing.assert_allclose(
+            logp_x2_eval,
+            scipy_dirichlet_logpdf(x_parts_test[1], alpha[2:]),
+        )
+
+        # axis=-1 (i.e., axis 1, the support dimension)
+        x_parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=-1)
+        x_parts_vv = [x_part.clone() for x_part in x_parts]
+        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
+
+        assert logp_parts[0].type.shape == (3,)
+        assert logp_parts[1].type.shape == (3,)
+        logp_fn = pytensor.function(x_parts_vv, logp_parts)
+
+        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
+        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
+        # The joint logp is weighted by the relative split sizes (2/5 and 3/5)
+        np.testing.assert_allclose(logp_x1_eval * 3, logp_x2_eval * 2)
+        logp_total = logp_x1_eval + logp_x2_eval
+        np.testing.assert_allclose(
+            logp_total,
+            scipy_dirichlet_logpdf(np.concatenate(x_parts_test, axis=1), alpha),
+        )
+
+    @pytest.mark.xfail(
+        reason="Rewrite from partial split to split on subtensor not implemented yet"
+    )
+    def test_not_all_splits_used(self):
+        x = pt.random.normal(mu=pt.arange(6), name="x")
+        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
+            ::2
+        ]  # Only use the first and last splits
+        x_parts_vv = [x_part.clone() for x_part in x_parts]
+        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
+        assert len(logp_parts) == 2
+
+        logp_fn = pytensor.function(x_parts_vv, logp_parts)
+        x_parts_test = [x_part.eval() for x_part in x_parts_vv]
+        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
+        np.testing.assert_allclose(
+            logp_x1_eval,
+            st.norm.logpdf(x_parts_test[0], loc=[0, 1]),
+        )
+        np.testing.assert_allclose(
+            logp_x2_eval,
+            st.norm.logpdf(x_parts_test[1], loc=[4, 5]),
+        )
+
+    def test_not_all_splits_used_core_dim(self):
+        # TODO: We could support this for univariate/batch dimensions by rewriting
+        # split(x, splits_size=[2, 2, 2], n_splits=3, axis=1)[:2] -> split(x[:-2], splits_size=[2, 2], n_splits=2, axis=1)
+        # and letting logp infer the probability of x[:-2]
+        x = pt.random.dirichlet(alphas=pt.ones(6), name="x")
+        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
+            :2
+        ]  # Only use the first two splits
+        x_parts_vv = [x_part.clone() for x_part in x_parts]
+
+        with pytest.raises(
+            ValueError,
+            match="Split logp requires the number of values to match the number of splits",
+        ):
+            conditional_logp(dict(zip(x_parts, x_parts_vv)))
+
+    @pytest.mark.xfail(reason="Rewrite from subtensor to split not implemented yet")
+    def test_subtensor_converted_to_splits(self):
+        rng = np.random.default_rng(388)
+        x = pt.random.normal(mu=pt.arange(5), name="x")
+
+        x_parts = [x[:2], x[2:3], x[3:]]
+        x_parts_vv = [x_part.clone() for x_part in x_parts]
+        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
+        assert len(logp_parts) == 3
+        logp_fn = pytensor.function(x_parts_vv, logp_parts)
+        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
+        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
+        np.testing.assert_allclose(logp_x1_eval, st.norm.logpdf(x_parts_test[0], loc=[0, 1]))
+        np.testing.assert_allclose(logp_x2_eval, st.norm.logpdf(x_parts_test[1], loc=[2]))
+        np.testing.assert_allclose(logp_x3_eval, st.norm.logpdf(x_parts_test[2], loc=[3, 4]))
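
Reviewer note (not part of the patch): a minimal end-to-end sketch of what the new
MeasurableSplit rewrite enables, mirroring test_univariate above. Variable names are
illustrative; the expected values follow scipy's normal logpdf exactly as in the tests.

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from scipy import stats as st

    from pymc.logprob.basic import conditional_logp

    # Split a measurable normal RV and request a logp for every part
    x = pt.random.normal(mu=np.arange(4), name="x")
    x1, x2 = pt.split(x, splits_size=[1, 3], n_splits=2, axis=0)

    # One value variable per split part
    x1_vv, x2_vv = x1.clone(), x2.clone()
    logp1, logp2 = conditional_logp({x1: x1_vv, x2: x2_vv}).values()
    logp_fn = pytensor.function([x1_vv, x2_vv], [logp1, logp2])

    x1_test, x2_test = np.zeros(1), np.zeros(3)
    logp1_eval, logp2_eval = logp_fn(x1_test, x2_test)

    # Each part recovers the densities of the corresponding slice of x
    np.testing.assert_allclose(logp1_eval, st.norm.logpdf(x1_test, loc=[0]))
    np.testing.assert_allclose(logp2_eval, st.norm.logpdf(x2_test, loc=[1, 2, 3]))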