Skip to content

plate_stack doesn't work as expected, example below #2010

@vishu-tyagi

Description

@vishu-tyagi

Bug Description

plate_stack doesn't work as expected. I'll describe the issue with an example below:

Steps to Reproduce

In the code below, in the _model callable, I want the shape of a to be (2, 1, 3) so that it multiplies properly with a_fraction. a_fraction has the shape (2, 4, 3). But I get the following error when running this model with NUTS sampler. (Although, I'm able to get the trace of this model, which works fine.)

Incompatible shapes for broadcasting: shapes=[(4, 1, 1), (2, 1, 3)]

ValueError: Incompatible shapes for broadcasting: shapes=[(4, 1, 1), (2, 1, 3)]

But, the _model_without_plate_stack works, although it's the same model as _model, just without plate_stack.

_model_03 also raises the same error as _model, and the difference between these two models is that, in _model_03 I've specified the outer dimensions as well (although by default they should be -1).

from jax import random
import numpyro as pyro
from numpyro import distributions as dist
from numpyro.infer import NUTS, MCMC


def _model():
    outer_plate = 3
    inner_plate = [2, 4]

    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))

    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
            a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
            b = pyro.deterministic("b", a_fraction * a)


def _model_without_plate_stack():
    outer_plate = 3
    inner_plate = [2, 4]

    with pyro.plate("outer_plate", outer_plate):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))

    with pyro.plate("outer_plate", outer_plate):
        # with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
        with pyro.plate("inner_plate_1", inner_plate[1]):
            with pyro.plate("inner_plate_0", inner_plate[0]):
                a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
                b = pyro.deterministic("b", a_fraction * a)


def _model_03():
    outer_plate = 3
    inner_plate = [2, 4]

    with pyro.plate("outer_plate", outer_plate, dim=-1):
        with pyro.plate("inner_plate_0", inner_plate[0], dim=-3):
            a = pyro.sample("a", dist.Exponential(.2))

    with pyro.plate("outer_plate", outer_plate, dim=-1):
        with pyro.plate_stack("inner_plate", inner_plate, rightmost_dim=-2):
            a_fraction = pyro.sample("a_fraction", dist.Beta(2, 1))
            b = pyro.deterministic("b", a_fraction * a)


def trace(key, model):
    with pyro.handlers.seed(rng_seed=key):
        trace = pyro.handlers.trace(model).get_trace()
    print(f"Trace successful")
    return trace


def run(key, model):
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_samples=100, num_warmup=100)
    mcmc.run(key)
    print(f"Run successful")
    return mcmc
    

key = random.key(0)

trace01 = trace(key, _model) # works
mcmc01 = run(key, _model) # fails

trace02 = trace(key, _model_without_plate_stack) # works
mcmc02 = run(key, _model_without_plate_stack) # works

trace03 = trace(key, _model_03) # works
mcmc03 = run(key, _model_03) # fails

Expected Behavior

mcmc01 = run(key, _model) should run as expected.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions