Add a function to create a tensor that represents the loop over the posterior #7885
base: main
@@ -42,6 +42,7 @@
     walk,
 )
 from pytensor.graph.fg import FunctionGraph
+from pytensor.scan.basic import scan
 from pytensor.tensor.random.var import RandomGeneratorSharedVariable
 from pytensor.tensor.sharedvar import SharedVariable, TensorSharedVariable
 from pytensor.tensor.variable import TensorConstant, TensorVariable
@@ -57,7 +58,12 @@
 from pymc.distributions.shape_utils import change_dist_size
 from pymc.model import Model, modelcontext
 from pymc.progress_bar import CustomProgress, default_progress_theme
-from pymc.pytensorf import compile, rvs_in_graph
+from pymc.pytensorf import (
+    clone_while_sharing_some_variables,
+    collect_default_updates,
+    compile,
+    rvs_in_graph,
+)
 from pymc.util import (
     RandomState,
     _get_seeds_per_chain,
@@ -68,6 +74,7 @@
 __all__ = (
     "compile_forward_sampling_function",
     "draw",
+    "loop_over_posterior",
     "sample_posterior_predictive",
     "sample_prior_predictive",
     "vectorize_over_posterior",
@@ -1083,3 +1090,122 @@ def vectorize_over_posterior(
                 f"The following random variables found in the extracted graph: {remaining_rvs}"
             )
     return vectorized_outputs
+
+
+def loop_over_posterior(
+    outputs: list[Variable],
+    posterior: xr.Dataset,
+    input_rvs: list[Variable],
+    input_tensors: Sequence[Variable] = (),
+    allow_rvs_in_graph: bool = True,
+    sample_dims: tuple[str, ...] = ("chain", "draw"),
+) -> tuple[list[Variable], dict[Variable, Variable]]:
+    """Loop over posterior samples of a subset of input RVs.
+
+    This function creates a new graph for the supplied outputs, where the required
+    subset of input RVs is replaced by their posterior samples (the chain and draw
+    dimensions, or the dimensions provided in sample_dims, are flattened). The other
+    input tensors are kept as is.
+
+    Parameters
+    ----------
+    outputs : list[Variable]
+        The list of variables to loop over the posterior samples.
+    posterior : xr.Dataset
+        The posterior samples to use as replacements for the `input_rvs`.
+    input_rvs : list[Variable]
+        The list of random variables to replace with their posterior samples.
+    input_tensors : Sequence[Variable]
+        The list of tensors to keep as is.
+    allow_rvs_in_graph : bool
+        Whether to allow random variables to be present in the graph. If False,
+        an error will be raised if any random variables are found in the graph. If
+        True, the remaining random variables will be resized to match the number of
+        draws from the posterior.
+    sample_dims : tuple[str, ...]
+        The dimensions of the posterior samples to use for looping over the `input_rvs`.
+
+    Returns
+    -------
+    looped_outputs : list[Variable]
+        The looped variables, reshaped to match the original shape of the outputs, but
+        with the sample_dims added to the left.
+    updates : dict[Variable, Variable]
+        Dictionary of updates needed to compile the pytensor function that produces the
+        outputs.
+
+    Raises
+    ------
+    RuntimeError
+        If random variables are found in the graph and `allow_rvs_in_graph` is False.
+    ValueError
+        If the supplied output tensors do not depend on the requested input tensors.
+    """
+    if not (set(input_tensors) <= set(ancestors(outputs))):
+        raise ValueError(  # pragma: no cover
+            "The supplied output tensors do not depend on the following requested "
+            f"input tensors: {set(input_tensors) - set(ancestors(outputs))}"
+        )
+    outputs_ancestors = ancestors(outputs, blockers=input_rvs)
+    rvs_from_posterior: list[TensorVariable] = [
+        cast(TensorVariable, rv) for rv in outputs_ancestors if rv in set(input_rvs)
+    ]
+    independent_rvs = [
+        rv
+        for rv in rvs_in_graph(outputs)
+        if rv in outputs_ancestors and rv not in rvs_from_posterior
+    ]
+
+    def step(*args):
+        input_values = args[: len(args) - len(input_tensors) - len(independent_rvs)]
+        non_sequences = args[len(args) - len(input_tensors) - len(independent_rvs) :]
+
+        # Compute output sample value for input sample values
+        replace = {
+            **dict(zip(rvs_from_posterior, input_values, strict=True)),
+        }
+        samples = clone_while_sharing_some_variables(
+            outputs, replace=replace, kept_variables=non_sequences
+        )
Review comment on lines +1167 to +1169:
I know VI does something like this, either by passing things that shouldn't be replaced as keys:values, or by using
+
+        # Collect updates if there are RV Ops in the graph
+        updates = collect_default_updates(outputs=samples, inputs=input_values)
+        return (*samples,), updates
+
+    sequences = []
+    batch_shape = tuple([len(posterior.coords[dim]) for dim in sample_dims])
+    nsamples = np.prod(batch_shape)
+    for rv in rvs_from_posterior:
+        values = posterior[rv.name].data
+        sequences.append(
+            pt.constant(
+                np.reshape(values, (nsamples, *values.shape[2:])),
+                name=rv.name,
+                dtype=rv.dtype,
+            )
+        )
+    scan_out, updates = scan(
+        fn=step,
+        sequences=sequences,
+        non_sequences=[*input_tensors, *independent_rvs],
+        n_steps=nsamples,
+    )
+    if len(outputs) == 1:
+        scan_out = [scan_out]  # pragma: no cover
+
+    looped: list[Variable] = []
+    for out in scan_out:
+        core_shape = tuple(
+            [
+                static if static is not None else dynamic
+                for static, dynamic in zip(out.type.shape[1:], out.shape[1:])
+            ]
+        )
+        looped.append(pt.reshape(out, (*batch_shape, *core_shape)))
+    if not allow_rvs_in_graph:
+        remaining_rvs = rvs_in_graph(looped)
+        if remaining_rvs:
+            raise RuntimeError(
+                f"The following random variables found in the extracted graph: {remaining_rvs}"
+            )
+    return looped, updates
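
A minimal usage sketch of the new helper, following the docstring above. The model, the hand-built posterior dataset, and the `from pymc.sampling.forward import loop_over_posterior` path are illustrative assumptions, not part of the diff:

```python
import numpy as np
import pymc as pm
import pytensor
import xarray as xr

from pymc.sampling.forward import loop_over_posterior  # assumed module path

# Hypothetical toy model: `det` is a deterministic function of `mu`
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    det = pm.Deterministic("det", 2.0 * mu)

# Stand-in posterior with the usual (chain, draw) layout for `mu`
posterior = xr.Dataset(
    {"mu": (("chain", "draw"), np.random.normal(size=(2, 100)))},
    coords={"chain": np.arange(2), "draw": np.arange(100)},
)

# Replace `mu` by its posterior samples and loop over the flattened (chain, draw) axis
looped, updates = loop_over_posterior(
    outputs=[model["det"]],
    posterior=posterior,
    input_rvs=[model["mu"]],
)

fn = pytensor.function([], looped, updates=updates)
(det_samples,) = fn()
print(det_samples.shape)  # (2, 100): sample_dims prepended to det's (scalar) shape
```

Under the hood this builds a Scan over the 200 flattened posterior draws and reshapes the result back to `(chain, draw)`.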
Review comment:
Instead of a separate function, have an argument:

    use_scan: bool = False

in the previous function? Most of the logic should be the same.

This flatten batch -> scan -> reshape logic should go in PyTensor; it's needed in general.
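
To make that second remark concrete, here is a rough sketch of what such a generic flatten-batch -> scan -> reshape helper could look like if it lived in PyTensor. The name `scan_over_batch` and its signature are hypothetical, not an existing PyTensor API:

```python
import pytensor.tensor as pt
from pytensor.scan.basic import scan


def scan_over_batch(fn, batched_inputs, batch_ndim, non_sequences=()):
    """Hypothetical helper: map `fn` over the leading `batch_ndim` dims of each input.

    The batch dimensions are flattened into one leading axis, `fn` is applied to each
    slice with `scan`, and the batch shape is restored on every output.
    """
    # Symbolic batch shape taken from the first input (all inputs must share it)
    batch_shape = [batched_inputs[0].shape[i] for i in range(batch_ndim)]
    n_steps = pt.prod(pt.stack(batch_shape))

    # Flatten all batch dimensions of every sequence into a single leading axis
    sequences = [
        pt.reshape(x, [n_steps, *(x.shape[i] for i in range(batch_ndim, x.ndim))])
        for x in batched_inputs
    ]

    outputs, updates = scan(
        fn=fn,
        sequences=sequences,
        non_sequences=list(non_sequences),
        n_steps=n_steps,
    )
    if not isinstance(outputs, list):
        outputs = [outputs]

    # Restore the original batch shape on each output
    reshaped = [
        pt.reshape(out, [*batch_shape, *(out.shape[i] for i in range(1, out.ndim))])
        for out in outputs
    ]
    return reshaped, updates
```

This mirrors the steps `loop_over_posterior` performs on the posterior draws: flatten `(chain, draw)` into one axis, scan over it, then reshape the result back.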