
rng_key input for sequential MCMC #9

@AZhou00


Hi, thank you for the great package.

I am trying to run a model for a few samples, save the state, and then continue sampling. A minimal example would be along the lines of:

```python
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from MultiHMCGibbs import MultiHMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

# Two inner NUTS kernels, one per Gibbs block of sites
inner_kernels = [
    NUTS(model),
    NUTS(model)
]
outer_kernel = MultiHMCGibbs(
    inner_kernels,
    [['y'], ['x']]
)

rng_key = random.PRNGKey(0)  # a single regular JAX PRNG key

mcmc = MCMC(
    outer_kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=True
)
mcmc.run(rng_key)
# Continue sampling from where the first run stopped
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)
```

This gives the following progress bars:

```
sample: 100%|██████████| 200/200 [00:02<00:00, 68.52it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:00<00:00, 508.15it/s]
```

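As you can see, the progress bar for the second run somehow does not show the diagnostics (number of steps, step sizes, acceptance probabilities) for the partitions of the kernel at all, whereas the first run shows them for both partitions.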

The code below seems to fix the issue.

```python
mcmc = MCMC(
    outer_kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=True
)
mcmc.run(rng_key)
mcmc.post_warmup_state = mcmc.last_state
# Pass the original single key instead of mcmc.post_warmup_state.rng_key
mcmc.run(rng_key)
```

```
sample: 100%|██████████| 200/200 [00:02<00:00, 68.79it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:02<00:00, 38.47it/s, 7/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.95]
```

Here `rng_key` is a single regular JAX PRNG key, whereas `mcmc.post_warmup_state.rng_key` is several keys stacked together. I am not certain how the kernel treats these two cases differently, which of the two is the proper way to continue sampling, or whether either one has unintended side effects. I would really appreciate any advice! Thank you in advance. - Alan
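To make the difference concrete, a small JAX-only sketch of the two kinds of key: a single key versus a stack of keys with an extra leading axis. The stack size of 2 is an assumption for illustration; the actual number of keys `MultiHMCGibbs` stores is not shown here. The last lines show one way to get a fresh single key for the second run (splitting the original key) instead of passing the same `rng_key` to both runs. Whether `MultiHMCGibbs` behaves the same with such a key is an assumption, based only on the second set of progress bars above.

```python
from jax import random

# A single key, like the rng_key passed to the first mcmc.run call
rng_key = random.PRNGKey(0)

# A stack of keys with one extra leading axis; the size 2 is illustrative only
stacked_keys = random.split(rng_key, 2)

print(rng_key.shape)       # one key (shape (2,) with the default PRNG impl)
print(stacked_keys.shape)  # same key shape with an extra leading axis

# A fresh single key for a follow-up run, instead of reusing rng_key itself
rng_key, run_key = random.split(rng_key)
print(run_key.shape == stacked_keys[0].shape)  # same shape as one key
```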
