-
Notifications
You must be signed in to change notification settings - Fork 266
Description
Bug Description
`TraceEnum_ELBO` ignores the `is_auxiliary` option of `infer` in sampling statements.
Steps to Reproduce
Apologies, these code snippets are a bit long.
Model
def mix_weights(beta):
    """
    Function to do the stick breaking construction.

    The k-th output along the last axis is ``beta[k] * prod(1 - beta[:k])``;
    the final output is the remaining stick ``prod(1 - beta)``. The result
    therefore has one more entry along the last axis than ``beta``, and the
    entries along that axis sum to one.

    :param beta: array of break fractions. The construction runs along the
        last axis; any leading axes are treated as batch dimensions.
    :return: array of weights whose last axis is one longer than ``beta``'s.
    """
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    # Pad only the last axis: a bare ``(0, 1)`` pad width would be applied to
    # every axis and corrupt batched input.
    batch_pad = [(0, 0)] * (jnp.ndim(beta) - 1)
    term1 = jnp.pad(beta, batch_pad + [(0, 1)], mode='constant', constant_values=1.)
    term2 = jnp.pad(beta1m_cumprod, batch_pad + [(1, 0)], mode='constant', constant_values=1.)
    return jnp.multiply(term1, term2)
def model(
    K,
    N,
    D_discrete,
    X_mixture=None,
    alpha=None
):
    """
    K-component stick-breaking mixture over rows of binary features.

    Sample sites: ``alpha`` (only when not supplied), ``v`` (K-1 break
    fractions), ``_phi_latent`` (a K x D_discrete Normal matrix), the
    per-row discrete ``assignment`` (marked for parallel enumeration) and
    the observed ``obs``. Deterministic sites: ``cluster_proba`` and ``phi``.

    :param K: number of mixture components.
    :param N: number of data rows (size of the ``data`` plate).
    :param D_discrete: number of binary features per row.
    :param X_mixture: observed data for ``obs``; ``None`` leaves ``obs``
        unobserved so it is sampled.
    :param alpha: fixed concentration value; when ``None`` it is sampled
        from ``Uniform(0.3, 10.0)``.
    """
    # priors
    if alpha is None:
        alpha = numpyro.sample("alpha", dist.Uniform(0.3, 10.0))
    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample("v", dist.Beta(1, alpha))
    cluster_probabilities = numpyro.deterministic(
        "cluster_proba",
        mix_weights(v),
    )
    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent",
            dist.Normal(
                loc=jnp.zeros((K, D_discrete)),
                scale=jnp.ones((K, D_discrete)),
            ).to_event(1)
        )
    # Squash the latent matrix into (0, 1) so it can be used as Bernoulli probabilities.
    phi = numpyro.deterministic("phi", jax.nn.sigmoid(_phi_latent))
    # model sampling
    with numpyro.plate('data', N):
        # Assignment is which cluster each row of data belongs to.
        assignment = numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities),
            infer={'enumerate': 'parallel'}
        )
        # Each row's features use the probability vector of its assigned cluster.
        numpyro.sample(
            'obs',
            dist.Bernoulli(phi[assignment, :]).to_event(1),
            obs=X_mixture,
        )
Guide
def guide(
    K,
    N,
    D_discrete,
    X_mixture=None,
    alpha=None,
):
    """
    Manual guide that draws all continuous latents from one joint Gaussian.

    A single auxiliary site ``_latent_distribution`` (flagged with
    ``infer={'is_auxiliary': True}``) is sampled from a multivariate normal
    of length ``1 + (K - 1) + D_discrete * K``. Its entries are then routed
    to the model's sites through ``Delta`` distributions:

    * index ``0`` -> ``alpha`` (only when ``alpha`` is not supplied),
    * indices ``1:K`` -> ``v`` (after a sigmoid),
    * indices ``K:K + D_discrete * K`` -> ``_phi_latent`` (reshaped to
      ``(K, D_discrete)``).

    The discrete ``assignment`` site is sampled from the resulting cluster
    probabilities.

    :param K: number of mixture components.
    :param N: number of data rows (size of the ``data`` plate).
    :param D_discrete: number of binary features per row.
    :param X_mixture: unused here; accepted so the signature matches the model.
    :param alpha: fixed concentration value; when ``None`` an ``alpha`` site
        is emitted.
    """
    n_vars = 1 + (K - 1) + D_discrete * K
    _latent_loc = numpyro.param("_latent_loc", jnp.zeros(n_vars))
    # Initial value for the scale factor: 0.1 on the diagonal, 0.01 on every
    # strictly-lower entry, 0 above the diagonal. Built in one shot instead
    # of one scatter update per sub-diagonal.
    init_scale_tril = (
        jnp.identity(n_vars) * 0.1
        + jnp.tril(jnp.full((n_vars, n_vars), 0.01), k=-1)
    )
    _latent_L = numpyro.param(
        "_latent_L",
        init_scale_tril,
        constraint=dist.constraints.corr_cholesky,
    )
    _latent_distribution = numpyro.sample(
        "_latent_distribution",
        dist.MultivariateNormal(_latent_loc, scale_tril=_latent_L),
        infer={'is_auxiliary': True}
    )
    if alpha is None:
        # NOTE(review): this point mass sits at the raw Gaussian coordinate,
        # which is unconstrained, while the model's prior for alpha is
        # Uniform(0.3, 10.0) -- confirm whether a transform is intended.
        numpyro.sample("alpha", dist.Delta(_latent_distribution[0]))
    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample(
            "v",
            dist.Delta(
                jax.nn.sigmoid(_latent_distribution[1:K])
            ),
        )
    cluster_probabilities = numpyro.deterministic(
        "cluster_proba",
        mix_weights(v)
    )
    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent",
            dist.Delta(
                _latent_distribution[K:(K + D_discrete * K)].reshape(K, D_discrete)
            ).to_event(1)
        )
    numpyro.deterministic(
        "phi",
        jax.nn.sigmoid(_phi_latent)
    )
    with numpyro.plate('data', N):
        numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities)
        )
The site `_latent_distribution` in the guide does not exist in the model, which is why it is marked with `infer={'is_auxiliary': True}`.
Fitting
def fit_SVI_manual(df, elbo_fn=numpyro.infer.TraceEnum_ELBO, steps=1000, samples=1000):
    """
    Fit ``model`` with ``guide`` by SVI and package the results for ArviZ.

    :param df: pandas DataFrame; every column whose name starts with
        ``d_var`` is used as a feature column.
    :param elbo_fn: ELBO class, instantiated with ``num_particles=5`` and
        ``vectorize_particles=True``.
    :param steps: number of SVI optimisation steps.
    :param samples: number of draws requested from each ``Predictive`` call.
        The original snippet referenced ``samples`` without defining it; the
        default of 1000 is a placeholder, not a value taken from the report.
    :return: the object built by ``az.from_dict`` from the posterior draws,
        posterior-predictive draws, log-likelihood and SVI losses.
    """
    learning_rate = 0.05
    particles = 5

    # Select the feature columns once and reuse the resulting array everywhere.
    feature_columns = [c for c in df.keys() if c.startswith("d_var")]
    X = df[feature_columns].to_numpy()
    D_d = X.shape[1]
    # Keyword arguments shared by every call into the model / guide.
    model_kwargs = dict(
        D_discrete=D_d,
        X_mixture=X,
        alpha=None,
        K=10,
        N=df.shape[0],
    )

    def _with_chain_dim(draws):
        # Prepend a length-1 axis to every array (one "chain" for az.from_dict).
        return {k: v.reshape([1] + list(v.shape)) for k, v in draws.items()}

    rng = jax.random.PRNGKey(0)
    elbo = elbo_fn(num_particles=particles, vectorize_particles=True)
    optimiser = numpyro.optim.Adam(learning_rate)
    svi = numpyro.infer.SVI(
        model,
        guide,
        optimiser,
        elbo,
        **model_kwargs
    )
    svi_result = svi.run(rng, steps, progress_bar=True)

    # Draws from the model, with latents supplied by the fitted guide.
    # (The original snippet passed an undefined ``guide_fn`` here.)
    posterior_predictive_function = numpyro.infer.Predictive(
        model,
        guide=guide,
        params=svi_result.params,
        num_samples=samples
    )
    rng, rng_subkey = jax.random.split(key=rng)
    posterior_predictive_samples = posterior_predictive_function(
        rng_subkey,
        **model_kwargs
    )

    # Draws from the fitted guide itself.
    posterior_function = numpyro.infer.Predictive(
        guide,
        params=svi_result.params,
        num_samples=samples
    )
    posterior_samples = posterior_function(
        rng,
        **model_kwargs
    )

    # Add parameters to the posterior samples from the posterior predictive
    for name, draws in posterior_predictive_samples.items():
        if not name.startswith("obs"):
            posterior_samples[name] = draws

    log_lik = numpyro.infer.util.log_likelihood(
        model,
        posterior_samples,
        **model_kwargs
    )

    az_obj = az.from_dict(
        posterior=_with_chain_dim(posterior_samples),
        posterior_predictive=_with_chain_dim(posterior_predictive_samples),
        log_likelihood=_with_chain_dim(log_lik),
        sample_stats={"losses": svi_result.losses}  # This is a bit hacky...
    )
    return az_obj
The input to the fitting function is a pandas DataFrame whose feature columns are named `d_var` followed by a number (the code selects every column whose name starts with `d_var`) and contain only 0 or 1.
Expected Behavior
Model fitting should work using `TraceEnum_ELBO`. These functions run correctly using `TraceGraph_ELBO`, but the parameter estimates are poor.
Currently, `TraceEnum_ELBO` throws this (long) error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-20-26729e82f0ad> in <cell line: 0>()
----> 1 az_obj = fit_SVI_manual(df, elbo_fn=numpyro.infer.TraceEnum_ELBO)
2 # Errors with KeyError: '_latent_distribution'
10 frames
<ipython-input-18-b821992b9efc> in fit_SVI_manual(df, elbo_fn)
23 N=df.shape[0]
24 )
---> 25 svi_result = svi.run(rng, steps, progress_bar=True)
26
27 log_lik_out=np.array([])
/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in run(self, rng_key, num_steps, progress_bar, stable_update, forward_mode_differentiation, init_state, init_params, *args, **kwargs)
407 batch = max(num_steps // 20, 1)
408 for i in t:
--> 409 svi_state, loss = jit(body_fn)(svi_state, None)
410 losses.append(jax.device_get(loss))
411 if i % batch == 0:
[... skipping hidden 13 frame]
/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in body_fn(svi_state, _)
390 )
391 else:
--> 392 svi_state, loss = self.update(
393 svi_state,
394 *args,
/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in update(self, svi_state, forward_mode_differentiation, *args, **kwargs)
282 mutable_state=svi_state.mutable_state,
283 )
--> 284 (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
285 loss_fn,
286 svi_state.optim_state,
/usr/local/lib/python3.11/dist-packages/numpyro/optim.py in eval_and_update(self, fn, state, forward_mode_differentiation)
125 """
126 params: _Params = self.get_params(state)
--> 127 (out, aux), grads = _value_and_grad(
128 fn, x=params, forward_mode_differentiation=forward_mode_differentiation
129 )
/usr/local/lib/python3.11/dist-packages/numpyro/optim.py in _value_and_grad(f, x, forward_mode_differentiation)
48 return (out, aux), grads
49 else:
---> 50 return value_and_grad(f, has_aux=True)(x)
51
52
[... skipping hidden 16 frame]
/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in loss_fn(params)
59 else:
60 return (
---> 61 elbo.loss(
62 rng_key, params, model, guide, *args, **kwargs, **static_kwargs
63 ),
/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
1286 rng_keys = random.split(rng_key, self.num_particles)
1287 return -jnp.mean(
-> 1288 self.vectorize_particles_fn(single_particle_elbo, rng_keys)
1289 )
/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in _apply_vmap(fn, keys)
28
29 def _apply_vmap(fn, keys):
---> 30 return vmap(fn)(keys)
31
32
[... skipping hidden 7 frame]
/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in single_particle_elbo(rng_key)
1143 seeded_model = seed(model, model_seed)
1144 seeded_guide = seed(guide, guide_seed)
-> 1145 model_trace, guide_trace, sum_vars = get_importance_trace_enum(
1146 seeded_model,
1147 seeded_guide,
/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in get_importance_trace_enum(model, guide, args, kwargs, params, max_plate_nesting, model_deps, guide_desc)
980 kl_qp, output=funsor.Real, dim_to_name=dim_to_name
981 )
--> 982 elif not is_model and (model_trace[name].get("kl") is not None):
983 # skip logq computation if analytic kl was computed
984 pass
KeyError: '_latent_distribution'
I think this issue also affects the `AutoMultivariateNormal` autoguide. `TraceGraph_ELBO` works with `AutoMultivariateNormal` when a `block` handler is used to block the discrete latent site:
# Build the autoguide over a seeded copy of the model in which the
# "assignment" site is hidden by the block handler.
guide = numpyro.infer.autoguide.AutoMultivariateNormal(
    numpyro.handlers.block(
        numpyro.handlers.seed(model, rng),
        hide_fn=lambda site: site["name"] == "assignment",
    )
)
This was originally discussed on the numpyro forums (https://forum.pyro.ai/t/extra-sampling-site-in-manual-guide-compared-to-model/8714/1). Code is runnable in a google colab sheet: https://colab.research.google.com/drive/1FBs765Gyw0KsDFMevd1WyJYzby8n3but?usp=sharing
Thank you!