Hi, thank you for the great package.
I am trying to run a model for a few samples, save the state, and keep sampling. A minimal example would be along the lines of:
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from MultiHMCGibbs import MultiHMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

# One NUTS kernel per partition of the latent sites.
inner_kernels = [
    NUTS(model),
    NUTS(model)
]
# Each inner kernel is paired with a list of sites (here ['y'] and ['x']).
outer_kernel = MultiHMCGibbs(
    inner_kernels,
    [['y'], ['x']]
)

mcmc = MCMC(
    outer_kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=True
)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key)

# Continue sampling from where the first run stopped.
mcmc.post_warmup_state = mcmc.last_state
mcmc.run(mcmc.post_warmup_state.rng_key)
This gives the following progress bars:
sample: 100%|██████████| 200/200 [00:02<00:00, 68.52it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:00<00:00, 508.15it/s]
As you can see, the progress bar for the second run does not show the diagnostics (step sizes and acceptance probabilities) for the kernel's partitions.
The code below seems to fix the issue.
mcmc = MCMC(
    outer_kernel,
    num_warmup=100,
    num_samples=100,
    progress_bar=True
)
mcmc.run(rng_key)
mcmc.post_warmup_state = mcmc.last_state
# Only change: pass the original single key instead of the stacked state key.
mcmc.run(rng_key)

This gives:
sample: 100%|██████████| 200/200 [00:02<00:00, 68.79it/s, 3/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.96]
sample: 100%|██████████| 100/100 [00:02<00:00, 38.47it/s, 7/3 steps of size 8.04e-01/6.58e-01. acc. prob=0.94/0.95]
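For reference, here is a small JAX-only sketch of the two kinds of key. My unverified guess is that NumPyro's MCMC only fills in the progress-bar diagnostics when the key passed to run looks like a single key, since stacked keys are what it uses for vectorized chains; that would match the two outputs above. The names looks_like_single_key and continue_sampling are hypothetical, not part of NumPyro or MultiHMCGibbs, and the count of 3 stacked keys is arbitrary. The helper splits off a fresh key instead of reusing rng_key from the first run, since JAX keys are generally not meant to be reused.

```python
from jax import random

# A single PRNG key, like rng_key in the examples above: shape (2,).
rng_key = random.PRNGKey(0)

# Several keys stacked along a leading axis: shape (3, 2). This stands in
# for mcmc.post_warmup_state.rng_key; the count 3 is an arbitrary assumption.
stacked_keys = random.split(rng_key, 3)


def looks_like_single_key(key):
    """Hypothetical check: True for one raw key, False for stacked keys."""
    return key.ndim == 1


def continue_sampling(mcmc, rng_key):
    """Hypothetical helper: resume `mcmc` from its last state with a fresh key.

    Splits `rng_key` so the continuation run does not reuse the key that
    was passed to the first run, and returns the leftover key for later use.
    """
    rng_key, run_key = random.split(rng_key)
    mcmc.post_warmup_state = mcmc.last_state
    mcmc.run(run_key)
    return rng_key


print(rng_key.shape, looks_like_single_key(rng_key))            # (2,) True
print(stacked_keys.shape, looks_like_single_key(stacked_keys))  # (3, 2) False
```

It would be called as rng_key = continue_sampling(mcmc, rng_key) after the first mcmc.run(rng_key). I have not checked whether MultiHMCGibbs expects the stacked key in its state to be preserved, so this may carry the same unknown side effects as the workaround above.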
Here, rng_key is just a regular JAX PRNG key, while mcmc.post_warmup_state.rng_key is several keys stacked together. I am not certain how the kernel treats the two cases differently, what the proper way to continue sampling is, or whether there are other unintended side effects. I would really appreciate any advice! Thank you in advance. - Alan