Add a function to create a tensor that represents the loop over the posterior #7885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions docs/source/api/samplers.rst
@@ -15,6 +15,7 @@ This submodule contains functions for MCMC and forward sampling.
draw
compute_deterministics
vectorize_over_posterior
loop_over_posterior
init_nuts
sampling.jax.sample_blackjax_nuts
sampling.jax.sample_numpyro_nuts
40 changes: 40 additions & 0 deletions pymc/pytensorf.py
@@ -58,6 +58,7 @@

__all__ = [
"CallableTensor",
"clone_while_sharing_some_variables",
"compile",
"cont_inputs",
"convert_data",
@@ -1041,3 +1042,42 @@ def normalize_rng_param(rng: None | Variable) -> Variable:
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
return rng


def clone_while_sharing_some_variables(
outputs: list[Variable],
kept_variables: Sequence[Variable] = (),
replace: dict[Variable, Variable] | None = None,
) -> list[Variable]:
"""Clone graphs, applying replacements while preserving some original variables.

Parameters
----------
outputs : list[Variable]
The list of variables to clone.
kept_variables : Sequence[Variable]
The set of variables to preserve in the cloned graph.
replace : dict[Variable, Variable]
A dictionary of variables to replace in the cloned graph.
The keys are the variables to replace, and the values are the new variables
to use in their place.

Returns
-------
list[Variable]
The cloned graphs with the replacements applied.
"""
replace_dict = replace or {}

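# Seeding the memo with identity mappings makes clone_get_equiv reuse the kept
# variables as-is instead of cloning them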
memo = {rv: rv for rv in kept_variables}
clone_map = clone_get_equiv(
[],
outputs,
memo=memo,
)

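# Map the replacement keys and values through the clone map so the substitution
# targets the cloned graph rather than the original one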
replace_keys = [clone_map.get(key, key) for key in replace_dict]
replace_values = replace_vars_in_graphs(list(replace_dict.values()), clone_map)
fg = FunctionGraph(None, [clone_map[o] for o in outputs], clone=False)
fg.replace_all(list(zip(replace_keys, replace_values)), import_missing=True)
return fg.outputs
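For readers skimming the diff, a minimal usage sketch of the new helper (the graph and names here are illustrative, not part of the PR): `x` is preserved as the very same variable in the clone, while `d` is swapped out for zeros. This mirrors the behavior exercised by `test_clone_while_sharing_some_variables` below.

```python
import pytensor.tensor as pt

from pymc.pytensorf import clone_while_sharing_some_variables

# Illustrative graph: y depends on both x and d
x = pt.scalar("x")
d = pt.vector("d")
y = x * d

# Keep x shared with the original graph; replace d with zeros of the same shape
[y_clone] = clone_while_sharing_some_variables(
    [y], kept_variables=[x], replace={d: pt.zeros_like(d)}
)
assert y_clone is not y  # y itself was cloned; x is still the original variable
```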
128 changes: 127 additions & 1 deletion pymc/sampling/forward.py
@@ -42,6 +42,7 @@
walk,
)
from pytensor.graph.fg import FunctionGraph
from pytensor.scan.basic import scan
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -57,7 +58,12 @@
from pymc.distributions.shape_utils import change_dist_size
from pymc.model import Model, modelcontext
from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import compile, rvs_in_graph
from pymc.pytensorf import (
clone_while_sharing_some_variables,
collect_default_updates,
compile,
rvs_in_graph,
)
from pymc.util import (
RandomState,
_get_seeds_per_chain,
@@ -68,6 +74,7 @@
__all__ = (
"compile_forward_sampling_function",
"draw",
"loop_over_posterior",
"sample_posterior_predictive",
"sample_prior_predictive",
"vectorize_over_posterior",
@@ -1083,3 +1090,122 @@ def vectorize_over_posterior(
f"The following random variables found in the extracted graph: {remaining_rvs}"
)
return vectorized_outputs


def loop_over_posterior(
Review comment (Member): instead of a separate function, have an argument `use_scan: bool = False` in the previous function? Most logic should be the same.

This flatten batch -> scan -> reshape logic should go in PyTensor, it's needed in general (a minimal sketch of the pattern follows this function).
outputs: list[Variable],
posterior: xr.Dataset,
input_rvs: list[Variable],
input_tensors: Sequence[Variable] = (),
allow_rvs_in_graph: bool = True,
sample_dims: tuple[str, ...] = ("chain", "draw"),
) -> tuple[list[Variable], dict[Variable, Variable]]:
"""Loop over posterior samples of a subset of input RVs.

This function creates a new graph for the supplied outputs in which the
requested subset of input RVs is replaced by their posterior samples. The
sample dimensions (chain and draw, or those provided in `sample_dims`) are
flattened into a single axis that the graph is looped over. The other
input tensors are kept as is.

Parameters
----------
outputs : list[Variable]
The list of output variables to compute for each posterior sample.
posterior : xr.Dataset
The posterior samples to use as replacements for the `input_rvs`.
input_rvs : list[Variable]
The list of random variables to replace with their posterior samples.
input_tensors : Sequence[Variable]
The list of tensors to keep as is.
allow_rvs_in_graph : bool
Whether to allow random variables to be present in the graph. If False,
an error will be raised if any random variables are found in the graph. If
True, the remaining random variables will be resized to match the number of
draws from the posterior.
sample_dims : tuple[str, ...]
The dimensions of the posterior samples over which the `input_rvs` are looped.

Returns
-------
looped_outputs : list[Variable]
The looped variables, reshaped to match the original shape of the outputs,
with the `sample_dims` prepended on the left.
updates : dict[Variable, Variable]
Dictionary of updates needed to compile the pytensor function to produce the
outputs.

Raises
------
RuntimeError
If random variables are found in the graph and `allow_rvs_in_graph` is False
ValueError
If the supplied output tensors do not depend on the requested input tensors
"""
if not (set(input_tensors) <= set(ancestors(outputs))):
raise ValueError( # pragma: no cover
"The supplied output tensors do not depend on the following requested "
f"input tensors: {set(input_tensors) - set(ancestors(outputs))}"
)
outputs_ancestors = ancestors(outputs, blockers=input_rvs)
rvs_from_posterior: list[TensorVariable] = [
cast(TensorVariable, rv) for rv in outputs_ancestors if rv in set(input_rvs)
]
independent_rvs = [
rv
for rv in rvs_in_graph(outputs)
if rv in outputs_ancestors and rv not in rvs_from_posterior
]

def step(*args):
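# scan passes the sequence slices first, followed by the non_sequences;
# split them apart by position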
input_values = args[: len(args) - len(input_tensors) - len(independent_rvs)]
non_sequences = args[len(args) - len(input_tensors) - len(independent_rvs) :]

# Compute output sample value for input sample values
replace = dict(zip(rvs_from_posterior, input_values, strict=True))
samples = clone_while_sharing_some_variables(
outputs, replace=replace, kept_variables=non_sequences
)
Review comment (Member), on lines +1167 to +1169: I know VI does something like this, either by passing things that shouldn't be replaced as keys:values, or by using graph_replace, can you check if the same approach works?
# Collect updates if there are RV Ops in the graph
updates = collect_default_updates(outputs=samples, inputs=input_values)
return (*samples,), updates

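# Flatten the sample dims (chain and draw by default) into a single leading
# axis so scan can iterate over one sequence dimension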
sequences = []
batch_shape = tuple([len(posterior.coords[dim]) for dim in sample_dims])
nsamples = np.prod(batch_shape)
for rv in rvs_from_posterior:
values = posterior[rv.name].data
sequences.append(
pt.constant(
np.reshape(values, (nsamples, *values.shape[len(sample_dims) :])),
name=rv.name,
dtype=rv.dtype,
)
)
scan_out, updates = scan(
fn=step,
sequences=sequences,
non_sequences=[*input_tensors, *independent_rvs],
n_steps=nsamples,
)
if len(outputs) == 1:
scan_out = [scan_out] # pragma: no cover

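# Undo the flattening: each scan output has shape (nsamples, *core_shape)
# and is reshaped to (*batch_shape, *core_shape)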
looped: list[Variable] = []
for out in scan_out:
core_shape = tuple(
[
static if static is not None else dynamic
for static, dynamic in zip(out.type.shape[1:], out.shape[1:])
]
)
looped.append(pt.reshape(out, (*batch_shape, *core_shape)))
if not allow_rvs_in_graph:
remaining_rvs = rvs_in_graph(looped)
if remaining_rvs:
raise RuntimeError(
f"The following random variables found in the extracted graph: {remaining_rvs}"
)
return looped, updates
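The first review comment above suggests the flatten batch -> scan -> reshape logic belongs in PyTensor. A minimal standalone sketch of that pattern in plain PyTensor (illustrative only, not the PR's implementation):

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.scan.basic import scan

# A (chain, draw, k) batch of posterior samples
samples = pt.tensor3("samples")
chain, draw, k = samples.shape

# 1. Flatten the batch dims into one leading axis
flat = samples.reshape((chain * draw, k))

# 2. Scan a per-sample computation over the flattened axis
out, _ = scan(fn=lambda row: row.sum(), sequences=[flat])

# 3. Reshape the result back to the batch dims
result = out.reshape((chain, draw))

f = pytensor.function([samples], result)
assert f(np.ones((2, 3, 4), dtype=pytensor.config.floatX)).shape == (2, 3)
```

`loop_over_posterior` applies the same idea: the posterior sequences are flattened before `scan`, and each scan output is reshaped back to the `sample_dims` batch shape.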
113 changes: 113 additions & 0 deletions tests/sampling/test_forward.py
@@ -41,6 +41,7 @@
compile_forward_sampling_function,
get_constant_coords,
get_vars_in_point_list,
loop_over_posterior,
observed_dependent_deterministics,
vectorize_over_posterior,
)
@@ -1958,3 +1959,115 @@ def test_vectorize_over_posterior_matches_sample():
atol=0.6 / np.sqrt(10000),
)
assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)


def test_loop_over_posterior(
variable_to_vectorize,
input_rv_names,
allow_rvs_in_graph,
model_to_vectorize,
):
model, idata = model_to_vectorize

if not allow_rvs_in_graph and (len(input_rv_names) == 0 or "z" in variable_to_vectorize):
with pytest.raises(
RuntimeError,
match="The following random variables found in the extracted graph",
):
loop_over_posterior(
outputs=[model[name] for name in variable_to_vectorize],
posterior=idata.posterior,
input_rvs=[model[name] for name in input_rv_names],
input_tensors=[model["d"]],
allow_rvs_in_graph=allow_rvs_in_graph,
)
else:
vectorized, _ = loop_over_posterior(
outputs=[model[name] for name in variable_to_vectorize],
posterior=idata.posterior,
input_rvs=[model[name] for name in input_rv_names],
input_tensors=[model["d"]],
allow_rvs_in_graph=allow_rvs_in_graph,
)
assert all(
vectorized_var is not model[name]
for vectorized_var, name in zip(vectorized, variable_to_vectorize)
)
assert all(vectorized_var.type.shape == (1, 100, 3) for vectorized_var in vectorized)
assert all(
variable_depends_on(
vectorized_var.owner.inputs[0].owner.op.inner_outputs[0], model["d"]
)
for vectorized_var in vectorized
)
inner_graph_outputs = [
vectorized_var.owner.inputs[0].owner.op.inner_outputs[i]
for i, vectorized_var in enumerate(vectorized)
]
if len(vectorized) == 2:
assert variable_depends_on(
inner_graph_outputs[variable_to_vectorize.index("z_downstream")],
inner_graph_outputs[variable_to_vectorize.index("z")],
)
if len(input_rv_names) > 0:
for input_rv_name in input_rv_names:
if input_rv_name == "x_parent":
assert len(get_var_by_name(inner_graph_outputs, input_rv_name)) == 0
else:
[vectorized_rv] = get_var_by_name(vectorized, input_rv_name)
rv_posterior = idata.posterior[input_rv_name].data
assert isinstance(vectorized_rv, TensorConstant)
np.testing.assert_equal(
np.reshape(vectorized_rv.value, rv_posterior.shape),
rv_posterior,
strict=True,
)
else:
original_rvs = rvs_in_graph([model[name] for name in variable_to_vectorize])
expected_rv_shapes = {rv.type.shape for rv in original_rvs}
rvs = rvs_in_graph(inner_graph_outputs)
assert {rv.type.shape for rv in rvs} == expected_rv_shapes


def test_loop_over_posterior_matches_sample():
rng = np.random.default_rng(1234)
with pm.Model() as model:
x = pm.Normal("x")
sigma = 0.1
obs = pm.Normal("obs", x, sigma, observed=rng.normal(size=10))
det = pm.Deterministic("det", obs + 1)

chains = 2
draws = 100
x_posterior = np.broadcast_to(100 * np.arange(chains)[..., None], (chains, draws))
with model:
posterior = xr.Dataset(
{
"x": xr.DataArray(
x_posterior,
dims=("chain", "draw"),
coords={"chain": np.arange(chains), "draw": np.arange(draws)},
)
}
)
idata = InferenceData(posterior=posterior)
with model:
pp = pm.sample_posterior_predictive(idata, var_names=["obs", "det"], random_seed=1234)
vectorized, updates = loop_over_posterior(
outputs=[obs, det],
posterior=posterior,
input_rvs=[x],
allow_rvs_in_graph=True,
)
[vect_obs, vect_det] = compile(
inputs=[], outputs=vectorized, random_seed=1234, updates=updates
)()
assert pp.posterior_predictive["obs"].shape == vect_obs.shape
assert pp.posterior_predictive["det"].shape == vect_det.shape
np.testing.assert_allclose(vect_obs + 1, vect_det)
np.testing.assert_allclose(
pp.posterior_predictive["obs"].mean(dim=("chain", "draw")),
vect_obs.mean(axis=(0, 1)),
atol=0.6 / np.sqrt(10000),
)
assert np.all(np.abs(vect_obs - x_posterior[..., None]) < 1)
24 changes: 23 additions & 1 deletion tests/test_pytensorf.py
@@ -24,7 +24,7 @@
from pytensor import scan, shared
from pytensor.compile import UnusedInputError
from pytensor.compile.builders import OpFromGraph
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.basic import Variable, ancestors, equal_computations, get_var_by_name
from pytensor.tensor.subtensor import AdvancedIncSubtensor

import pymc as pm
@@ -36,6 +36,7 @@
from pymc.logprob.utils import ParameterValueError
from pymc.pytensorf import (
PointFunc,
clone_while_sharing_some_variables,
collect_default_updates,
compile,
constant_fold,
@@ -785,3 +786,24 @@ def test_pickle_point_func():
np.testing.assert_allclose(
point_f_unpickled({"y": [3], "x": [2]}), point_f({"y": [3], "x": [2]})
)


def test_clone_while_sharing_some_variables():
with pm.Model() as model:
x = pm.Normal("x")
d = pm.Data("d", np.array([1, 2, 3]))
obs = pm.Data("obs", np.ones_like(d.get_value()))
y = pm.Deterministic("y", x * d)
z = pm.Gamma("z", mu=pt.exp(y), sigma=pt.exp(y) * 0.1, observed=obs)

kept_variables = [*model.free_RVs, *model.data_vars]
d_replace = pt.zeros_like(d.get_value())
d_replace.name = "d"
z_clone = clone_while_sharing_some_variables([z], kept_variables, {d: d_replace})[0]
assert z_clone is not z
cloned_ancestors = list(ancestors([z_clone]))
for kept_var in [x, obs]:
assert kept_var in cloned_ancestors
for different_var in [d, y]:
assert different_var not in cloned_ancestors
assert np.all(get_var_by_name([z_clone], "d")[0].eval() == 0)