Skip to content

Commit 9e51798

Browse files
committed
Derive logprob for Split operation
1 parent dc7cfee commit 9e51798

File tree

2 files changed

+224
-3
lines changed

2 files changed

+224
-3
lines changed

pymc/logprob/tensor.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,20 +33,23 @@
3333
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3434
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3535
# SOFTWARE.
36+
import typing
3637

3738
from pathlib import Path
3839

3940
from pytensor import tensor as pt
4041
from pytensor.graph.fg import FunctionGraph
4142
from pytensor.graph.rewriting.basic import node_rewriter
43+
from pytensor.npy_2_compat import normalize_axis_index
4244
from pytensor.tensor import TensorVariable
43-
from pytensor.tensor.basic import Join, MakeVector
45+
from pytensor.tensor.basic import Join, MakeVector, Split
4446
from pytensor.tensor.elemwise import DimShuffle, Elemwise
4547
from pytensor.tensor.random.op import RandomVariable
4648
from pytensor.tensor.random.rewriting import (
4749
local_dimshuffle_rv_lift,
4850
)
4951

52+
from pymc.exceptions import NotConstantValueError
5053
from pymc.logprob.abstract import (
5154
MeasurableOp,
5255
ValuedRV,
@@ -70,7 +73,7 @@
7073

7174

7275
class MeasurableMakeVector(MeasurableOp, MakeVector):
    """A placeholder used to specify a log-likelihood for a make_vector sub-graph.

    Combines `MakeVector` with the `MeasurableOp` marker; a logprob
    implementation is registered for this type right below the class.
    """
7477

7578

7679
@_logprob.register(MeasurableMakeVector)
@@ -183,6 +186,64 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
183186
return [measurable_stack]
184187

185188

189+
class MeasurableSplit(MeasurableOp, Split):
    """A placeholder used to specify a log-likelihood for a split sub-graph.

    Combines `Split` with the `MeasurableOp` marker. Instances are created by
    the `find_measurable_splits` rewrite and their log-probability is computed
    by `logprob_split`.
    """
191+
192+
193+
@node_rewriter([Split])
def find_measurable_splits(fgraph, node) -> list[TensorVariable] | None:
    """Swap a `Split` of a measurable variable for a `MeasurableSplit`.

    Returns the outputs of the replacement node, or ``None`` when nothing is
    rewritten: either the node's op is already a `MeasurableOp`, or the
    variable being split is not reported as measurable by
    `filter_measurable_variables`.
    """
    already_measurable = isinstance(node.op, MeasurableOp)
    if already_measurable:
        return None

    base_var, split_axis, split_sizes = node.inputs
    if filter_measurable_variables([base_var]):
        # Same number of parts and same inputs; only the Op type changes
        measurable_op = MeasurableSplit(node.op.len_splits)
        measurable_node = measurable_op.make_node(base_var, split_axis, split_sizes)
        return measurable_node.outputs

    return None
203+
204+
205+
@_logprob.register(MeasurableSplit)
def logprob_split(op: MeasurableSplit, values, x, axis, splits, **kwargs):
    """Compute the log-likelihood graph for a `MeasurableSplit`.

    The values are joined back along `axis`, the logp of the un-split
    variable `x` is evaluated at the joined value, and the result is split
    again so that each value gets its own logp term.

    Parameters
    ----------
    op
        The `MeasurableSplit` op; ``op.len_splits`` is the number of parts.
    values
        Value variables, one for each output of the split.
    x
        The measurable variable that was split.
    axis
        Axis along which `x` was split.
    splits
        Size of each part along `axis`.

    Returns
    -------
    list
        One logp term per value, in the same order as `values`.

    Raises
    ------
    ValueError
        If the number of values does not match ``op.len_splits``.
    NotImplementedError
        If the logp has fewer dimensions than the value (multivariate case)
        and `axis` cannot be constant folded.
    """
    if len(values) != op.len_splits:
        # TODO: Don't rewrite the split in the first place if not all parts are linked to value variables
        # This also allows handling some cases where not all splits are used
        raise ValueError("Split logp requires the number of values to match the number of splits")

    # Reverse the effects of split on the value variable
    join_value = pt.join(axis, *values)

    join_logp = _logprob_helper(x, join_value)

    reduced_dims = join_value.ndim - join_logp.ndim

    if reduced_dims:
        # This happens for multivariate distributions
        try:
            # `constant_fold` takes a list of variables and returns one result per
            # variable; unpack the single entry rather than binding the whole sequence.
            [constant_axis] = constant_fold([axis])
        except NotConstantValueError as err:
            raise NotImplementedError(
                "Cannot split multivariate logp with non-constant axis"
            ) from err

        constant_axis = normalize_axis_index(constant_axis, join_value.ndim)  # type: ignore[arg-type, assignment]
        if constant_axis >= join_logp.ndim:
            # If the axis is over a dimension that was reduced in the logp (multivariate logp),
            # we cannot split it into distinct entries: the mapping between values and densities breaks.
            # Instead, each part receives the joint logp weighted by its share of the split sizes,
            # so that the parts still add up to the joint logp.
            split_weights = splits / pt.sum(splits)
            return [join_logp * split_weights[i] for i in range(typing.cast(int, op.len_splits))]
        else:
            # Otherwise we can split the logp as if the split were over batch dimensions.
            # We just need to be sure to use the positive axis index, because a negative
            # axis would index a different dimension in the lower-rank logp.
            axis = constant_axis

    return pt.split(
        join_logp,
        splits_size=splits,
        n_splits=op.len_splits,
        axis=axis,
    )
245+
246+
186247
class MeasurableDimShuffle(MeasurableOp, DimShuffle):
187248
"""A placeholder used to specify a log-likelihood for a dimshuffle sub-graph."""
188249

@@ -308,3 +369,10 @@ def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
308369
"basic",
309370
"tensor",
310371
)
372+
373+
# Add the Split rewrite to the measurable IR rewrite database under the "basic" and "tensor" tags
measurable_ir_rewrites_db.register(
    "find_measurable_splits", find_measurable_splits, "basic", "tensor"
)

tests/logprob/test_tensor.py

Lines changed: 154 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
from pytensor import tensor as pt
4242
from pytensor.graph import RewriteDatabaseQuery
43+
from pytensor.tensor.random.type import random_generator_type
4344
from scipy import stats as st
4445

4546
from pymc.logprob.basic import conditional_logp, logp
@@ -352,7 +353,7 @@ def test_measurable_dimshuffle(ds_order, multivariate):
352353
np.testing.assert_array_equal(ref_logp_fn(base_test_value), ds_logp_fn(ds_test_value))
353354

354355

355-
def test_unmeargeable_dimshuffles():
356+
def test_unmeasurable_dimshuffles():
356357
# Test that graphs with DimShuffles that cannot be lifted/merged fail
357358

358359
# Initial support axis is at axis=-1
@@ -372,3 +373,155 @@ def test_unmeargeable_dimshuffles():
372373
# TODO: Check that logp is correct if this type of graphs is ever supported
373374
with pytest.raises(RuntimeError, match="could not be derived"):
374375
conditional_logp({w: w_vv})
376+
377+
378+
class TestMeasurableSplit:
    """Tests for the logp derived for the outputs of `pt.split` on a random variable."""

    def test_univariate(self):
        """Parts of a split normal get the elementwise normal logp of their own slice."""
        rng = np.random.default_rng(388)
        mu = np.arange(6)[:, None]
        sigma = np.arange(5) + 1

        x = pt.random.normal(mu, sigma, size=(6, 5), name="x")

        # axis=0
        x_parts = pt.split(x, splits_size=[2, 4], n_splits=2, axis=0)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu[:2], sigma),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu[2:], sigma),
        )

        # axis=1
        x_parts = pt.split(x, splits_size=[2, 1, 2], n_splits=3, axis=1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], mu, sigma[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], mu, sigma[2:3]),
        )
        np.testing.assert_allclose(
            logp_x3_eval,
            st.norm.logpdf(x_parts_test[2], mu, sigma[3:]),
        )

    def test_multivariate(self):
        """Check splits of a batched Dirichlet over its batch and its support dimension."""

        @np.vectorize(signature=("(n),(n)->()"))
        def scipy_dirichlet_logpdf(x, alpha):
            """Compute the logpdf of a Dirichlet distribution using scipy."""
            return st.dirichlet.logpdf(x, alpha)

        # (3, 5) Dirichlet
        rng = np.random.default_rng(426)
        rng_pt = random_generator_type("rng")
        alpha = np.linspace(1, 10, 5) * np.array([1, 10, 100])[:, None]
        x = pt.random.dirichlet(alpha, rng=rng_pt)

        # axis=-2 (i.e., 0, - batch dimension)
        x_parts = pt.split(x, splits_size=[2, 1], n_splits=2, axis=-2)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert logp_parts[0].type.shape == (2,)
        assert logp_parts[1].type.shape == (1,)

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        # Draw valid (simplex) test points from the random graph itself
        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            scipy_dirichlet_logpdf(x_parts_test[0], alpha[:2]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            scipy_dirichlet_logpdf(x_parts_test[1], alpha[2:]),
        )

        # axis=-1 (i.e., 1, - support dimension)
        x_parts = pt.split(x, splits_size=[2, 3], n_splits=2, axis=-1)
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())

        assert logp_parts[0].type.shape == (3,)
        assert logp_parts[1].type.shape == (3,)
        logp_fn = pytensor.function(x_parts_vv, logp_parts)

        x_parts_test = pytensor.function([rng_pt], x_parts)(rng)
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        # The joint logp is shared between the parts in proportion to the split sizes (2 and 3),
        # so part 1 carries 2/5 of it and part 2 carries 3/5
        np.testing.assert_allclose(logp_x1_eval * 3, logp_x2_eval * 2)
        logp_total = logp_x1_eval + logp_x2_eval
        np.testing.assert_allclose(
            logp_total,
            scipy_dirichlet_logpdf(np.concatenate(x_parts_test, axis=1), alpha),
        )

    @pytest.mark.xfail(
        reason="Rewrite from partial split to split on subtensor not implemented yet"
    )
    def test_not_all_splits_used(self):
        """Logp of a subset of the split outputs (expected to fail until supported)."""
        x = pt.random.normal(mu=pt.arange(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            ::2
        ]  # Only use the first and last splits
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 2

        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        # Evaluate the random parts to get test points: the value variables are
        # un-owned clones and have no graph that could be evaluated.
        x_parts_test = [x_part.eval() for x_part in x_parts]
        logp_x1_eval, logp_x2_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(
            logp_x1_eval,
            st.norm.logpdf(x_parts_test[0], loc=[0, 1]),
        )
        np.testing.assert_allclose(
            logp_x2_eval,
            st.norm.logpdf(x_parts_test[1], loc=[4, 5]),
        )

    def test_not_all_splits_used_core_dim(self):
        """A partially used split over a support dimension raises a `ValueError`."""
        # TODO: We could support this for univariate/batch dimensions by rewriting as
        #  split(x, splits_size=[2, 2, 2], n_splits=3, axis=1)[:2] -> split(x[:, :-2], splits_size=[2, 2], n_splits=2, axis=1)
        #  And letting logp infer the probability of x[:, :-2]
        x = pt.random.dirichlet(alphas=pt.ones(6), name="x")
        x_parts = pt.split(x, splits_size=[2, 2, 2], n_splits=3, axis=0)[
            :2
        ]  # Only use first two splits
        x_parts_vv = [x_part.clone() for x_part in x_parts]

        with pytest.raises(
            ValueError,
            match="Split logp requires the number of values to match the number of splits",
        ):
            conditional_logp(dict(zip(x_parts, x_parts_vv)))

    @pytest.mark.xfail(reason="Rewrite from subtensor to split not implemented yet")
    def test_subtensor_converted_to_splits(self):
        """Contiguous slices covering `x` could be treated as a split (expected to fail for now)."""
        rng = np.random.default_rng(388)
        x = pt.random.normal(mu=pt.arange(5), name="x")

        x_parts = [x[:2], x[2:3], x[3:]]
        x_parts_vv = [x_part.clone() for x_part in x_parts]
        logp_parts = list(conditional_logp(dict(zip(x_parts, x_parts_vv))).values())
        assert len(logp_parts) == 3
        logp_fn = pytensor.function(x_parts_vv, logp_parts)
        x_parts_test = [rng.normal(size=x_part.type.shape) for x_part in x_parts_vv]
        logp_x1_eval, logp_x2_eval, logp_x3_eval = logp_fn(*x_parts_test)
        np.testing.assert_allclose(logp_x1_eval, st.norm.logpdf(x_parts_test[0], loc=[0, 1]))
        np.testing.assert_allclose(logp_x2_eval, st.norm.logpdf(x_parts_test[1], loc=[2]))
        np.testing.assert_allclose(logp_x3_eval, st.norm.logpdf(x_parts_test[2], loc=[3, 4]))

0 commit comments

Comments
 (0)