-
Notifications
You must be signed in to change notification settings - Fork 267
Closed
Milestone
Description
Bug Description
plate_stack doesn't work as expected. I'll describe the issue with an example below:
Steps to Reproduce
In the code below, in the `_model` callable, I want `a` to have the shape (2, 1, 3) so that it broadcasts properly when multiplied with `a_fraction`, which has the shape (2, 4, 3). However, I get the following error when running this model with the NUTS sampler (although I am able to get a trace of this model, which works fine):
Incompatible shapes for broadcasting: shapes=[(4, 1, 1), (2, 1, 3)]
ValueError: Incompatible shapes for broadcasting: shapes=[(4, 1, 1), (2, 1, 3)]
By contrast, `_model_without_plate_stack` works, even though it is meant to be the same model as `_model`, just written with explicit nested plates instead of `plate_stack`.
`_model_03` raises the same error as `_model`; the only difference between these two models is that in `_model_03` I have also specified the dim of the outer plate explicitly (although by default it should be -1).
from jax import random
import numpyro as pyro
from numpyro import distributions as dist
from numpyro.infer import NUTS, MCMC
def _model():
    """Reproduction model: ``a`` under an explicit plate, ``a_fraction`` under ``plate_stack``.

    Per the report, ``a`` is intended to have shape (2, 1, 3) and
    ``a_fraction`` shape (2, 4, 3), so that their product broadcasts.
    Tracing this model is reported to succeed, while running it with NUTS
    is reported to fail with
    ``ValueError: Incompatible shapes for broadcasting``.
    """
    n_outer = 3
    inner_sizes = [2, 4]

    # "a": explicit size-2 plate pinned at dim=-3, nested inside the outer plate.
    with pyro.plate("outer_plate", n_outer), pyro.plate(
        "inner_plate_0", inner_sizes[0], dim=-3
    ):
        a = pyro.sample("a", dist.Exponential(0.2))

    # "a_fraction": the inner plates are created by plate_stack.
    # NOTE(review): plate_stack presumably names its plates
    # "<prefix>_<i>" counting from the rightmost dim, which would make
    # "inner_plate_0" here a size-4 plate at dim=-2 and clash with the
    # size-2 "inner_plate_0" at dim=-3 above -- confirm against the
    # plate_stack documentation; this may be the cause of the NUTS failure.
    with pyro.plate("outer_plate", n_outer), pyro.plate_stack(
        "inner_plate", inner_sizes, rightmost_dim=-2
    ):
        a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))

    # NOTE(review): indentation was lost in the paste; this statement is
    # assumed to sit outside both plate contexts -- confirm.
    pyro.deterministic("b", a_fraction * a)
def _model_without_plate_stack():
    """Variant of ``_model`` that spells out the inner plates by hand.

    The ``plate_stack`` call is replaced by two explicit nested plates,
    ``inner_plate_1`` (size 4) and ``inner_plate_0`` (size 2), whose dims
    are left for the library to assign. The report states that this
    variant both traces and runs under NUTS without error.
    """
    n_outer = 3
    inner_sizes = [2, 4]

    # "a": explicit size-2 plate pinned at dim=-3, nested inside the outer plate.
    with pyro.plate("outer_plate", n_outer), pyro.plate(
        "inner_plate_0", inner_sizes[0], dim=-3
    ):
        a = pyro.sample("a", dist.Exponential(0.2))

    # Stands in for:
    #   pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2)
    # Plates are entered in the order outer -> inner_plate_1 -> inner_plate_0.
    with pyro.plate("outer_plate", n_outer), pyro.plate(
        "inner_plate_1", inner_sizes[1]
    ), pyro.plate("inner_plate_0", inner_sizes[0]):
        a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))

    # NOTE(review): indentation was lost in the paste; this statement is
    # assumed to sit outside all plate contexts -- confirm.
    pyro.deterministic("b", a_fraction * a)
def _model_03():
    """Same as ``_model`` but with the outer plate's dim given explicitly.

    The outer plate is declared with ``dim=-1`` in both contexts; everything
    else matches ``_model``. The report states that this variant traces
    successfully and fails under NUTS with the same broadcasting error.
    """
    n_outer = 3
    inner_sizes = [2, 4]

    # "a": explicit size-2 plate pinned at dim=-3, outer plate pinned at dim=-1.
    with pyro.plate("outer_plate", n_outer, dim=-1), pyro.plate(
        "inner_plate_0", inner_sizes[0], dim=-3
    ):
        a = pyro.sample("a", dist.Exponential(0.2))

    # NOTE(review): as in _model, plate_stack presumably creates its own
    # "inner_plate_0" (size 4, dim=-2), clashing with the size-2 plate of
    # the same name above -- confirm against the plate_stack documentation.
    with pyro.plate("outer_plate", n_outer, dim=-1), pyro.plate_stack(
        "inner_plate", inner_sizes, rightmost_dim=-2
    ):
        a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))

    # NOTE(review): indentation was lost in the paste; this statement is
    # assumed to sit outside both plate contexts -- confirm.
    pyro.deterministic("b", a_fraction * a)
def trace(key, model):
    """Execute ``model`` once under a seeded handler and return its trace.

    :param key: value passed as ``rng_seed`` to the seed handler.
    :param model: zero-argument model callable to trace.
    :return: the result of ``get_trace()`` on the traced model.
    """
    seeded = pyro.handlers.seed(rng_seed=key)
    with seeded:
        model_trace = pyro.handlers.trace(model).get_trace()
    print("Trace successful")
    return model_trace
def run(key, model):
    """Sample ``model`` with NUTS (100 warmup, 100 samples) and return the MCMC object.

    :param key: PRNG key handed to ``MCMC.run``.
    :param model: model callable wrapped in a ``NUTS`` kernel.
    :return: the ``MCMC`` instance after ``run`` has completed.
    """
    sampler = MCMC(NUTS(model), num_samples=100, num_warmup=100)
    sampler.run(key)
    print("Run successful")
    return sampler
# Driver for the reproduction: each model is traced once and then sampled
# with NUTS, always using the same PRNG key. The trailing comments record the
# outcome observed by the reporter for each call. Exceptions are not caught
# here, so a failing call stops the script at that line.
key = random.key(0)
trace01 = trace(key, _model) # works
mcmc01 = run(key, _model) # fails
trace02 = trace(key, _model_without_plate_stack) # works
mcmc02 = run(key, _model_without_plate_stack) # works
trace03 = trace(key, _model_03) # works
mcmc03 = run(key, _model_03) # fails
Expected Behavior
mcmc01 = run(key, _model) should run as expected.
Metadata
Metadata
Assignees
Labels
No labels