
TraceEnum_ELBO ignores the is_auxiliary option #2013

@jim-rafferty

Description

Bug Description

TraceEnum_ELBO ignores the is_auxiliary flag passed via the infer argument of numpyro.sample statements in the guide.
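
For context, is_auxiliary marks a guide-only site that has no counterpart in the model, so an ELBO implementation should skip it when matching guide sites against model sites. A minimal sketch of the intended usage (toy model and guide, names are illustrative):

import numpyro
import numpyro.distributions as dist

def toy_model():
    numpyro.sample("x", dist.Normal(0.0, 1.0))

def toy_guide():
    # Guide-only site: flagged so ELBO implementations do not look
    # for a matching "aux" site in the model.
    aux = numpyro.sample(
        "aux", dist.Normal(0.0, 1.0), infer={"is_auxiliary": True}
    )
    numpyro.sample("x", dist.Delta(aux))

The full model and guide below follow the same pattern; they run under TraceGraph_ELBO but fail under TraceEnum_ELBO.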

Steps to Reproduce

Apologies, these code snippets are a bit long. They assume the following imports:

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import arviz as az

Model

def mix_weights(beta):
    """
    Function to do the stick breaking construction
    """
    beta1m_cumprod = jnp.cumprod(1 - beta, axis=-1)
    term1 = jnp.pad(beta, (0, 1), mode='constant', constant_values=1.)
    term2 = jnp.pad(beta1m_cumprod, (1, 0), mode='constant', constant_values=1.)
    return jnp.multiply(term1, term2)


def model(
    K,
    N,
    D_discrete,
    X_mixture=None,
    alpha=None
):

    # priors
    if alpha is None:
        alpha = numpyro.sample("alpha", dist.Uniform(0.3, 10.0))

    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample("v", dist.Beta(1, alpha))

    cluster_probabilities = numpyro.deterministic(
        "cluster_proba",
        mix_weights(v),
    )

    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent",
            dist.Normal(
                loc=jnp.zeros((K, D_discrete)),
                scale=jnp.ones((K, D_discrete)),
            ).to_event(1)
        )

    phi = numpyro.deterministic("phi", jax.nn.sigmoid(_phi_latent))

    # model sampling
    with numpyro.plate('data', N):

        # Assignment is which cluster each row of data belongs to.
        assignment = numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities),
            infer={'enumerate': 'parallel'}
        )

        obs = numpyro.sample(
            'obs',
            dist.Bernoulli(phi[assignment, :]).to_event(1),
            obs=X_mixture,
        )
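
For reference, mix_weights is the standard stick-breaking construction: beta = [0.5, 0.5] yields weights [0.5, 0.25, 0.25], which sum to one.

# Quick sanity check of the stick-breaking weights (illustrative only)
print(mix_weights(jnp.array([0.5, 0.5])))  # [0.5  0.25 0.25]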

Guide

def guide(
    K,
    N,
    D_discrete,
    X_mixture=None,
    alpha=None,
):

    n_vars = 1 + (K - 1) + D_discrete * K


    _latent_loc = numpyro.param("_latent_loc", jnp.zeros(n_vars))
    # Lower-triangular initial value: 0.1 on the diagonal, 0.01 below it.
    tmp = jnp.identity(n_vars) * 0.1
    for idx in range(n_vars - 1):
        rng1 = jnp.arange(idx + 1, n_vars)
        rng2 = jnp.arange(n_vars - idx - 1)
        tmp = tmp.at[rng1, rng2].set(0.01)

    _latent_L = numpyro.param(
        "_latent_L",
        tmp,
        constraint=dist.constraints.corr_cholesky,
    )

    _latent_distribution = numpyro.sample(
        "_latent_distribution",
        dist.MultivariateNormal(_latent_loc, scale_tril=_latent_L),
        infer={'is_auxiliary': True}
    )

    if alpha is None:
        numpyro.sample("alpha", dist.Delta(_latent_distribution[0]))

    with numpyro.plate("v_plate", K-1):
        v = numpyro.sample(
            "v",
            dist.Delta(
                jax.nn.sigmoid(_latent_distribution[1:K])
            ),
        )

    cluster_probabilities = numpyro.deterministic(
        "cluster_proba",
        mix_weights(v)
    )

    with numpyro.plate("cluster_K", K):
        _phi_latent = numpyro.sample(
            "_phi_latent",
            dist.Delta(_latent_distribution[K:(K + D_discrete * K)].reshape(K, D_discrete)).to_event(1)
        )
    phi = numpyro.deterministic(
        "phi",
        jax.nn.sigmoid(_phi_latent)
    )

    with numpyro.plate('data', N):
        assignment = numpyro.sample(
            'assignment',
            dist.CategoricalProbs(cluster_probabilities)
        )

The site _latent_distribution in the guide has no counterpart in the model, which is why it is flagged with is_auxiliary.
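
One can confirm the flag is recorded in the guide trace (a quick sketch; the small N and D_discrete values are arbitrary):

from numpyro import handlers

tr = handlers.trace(handlers.seed(guide, jax.random.PRNGKey(0))).get_trace(
    K=10, N=5, D_discrete=3
)
print(tr["_latent_distribution"]["infer"])  # {'is_auxiliary': True}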

Fitting

def fit_SVI_manual(df, elbo_fn=numpyro.infer.TraceEnum_ELBO, steps=1000, samples=1000):
    learning_rate = 0.05
    particles = 5
    D_d = df[[i for i in df.keys() if i.startswith("d_var")]].shape[1]

    rng = jax.random.PRNGKey(0)


    elbo = elbo_fn(num_particles=particles, vectorize_particles=True)
    optimiser = numpyro.optim.Adam(learning_rate)

    svi = numpyro.infer.SVI(
        model,
        guide,
        optimiser,
        elbo,
        D_discrete=D_d,
        X_mixture=df[[i for i in df.keys() if i.startswith("d_var")]].to_numpy(),
        alpha=None,
        K=10,
        N=df.shape[0]
    )
    svi_result = svi.run(rng, steps, progress_bar=True)

    posterior_predictive_function = numpyro.infer.Predictive(
        model,
        guide=guide,
        params=svi_result.params,
        num_samples=samples
    )

    rng, rng_subkey = jax.random.split(key=rng)

    posterior_predictive_samples = posterior_predictive_function(
        rng_subkey,
        D_discrete=D_d,
        X_mixture=df[[i for i in df.keys() if i.startswith("d_var")]].to_numpy(),
        alpha=None,
        K=10,
        N=df.shape[0]
    )

    posterior_function = numpyro.infer.Predictive(
        guide,
        params=svi_result.params,
        num_samples=samples
    )
    posterior_samples = posterior_function(
        rng,
        D_discrete=D_d,
        X_mixture=df[[i for i in df.keys() if i.startswith("d_var")]].to_numpy(),
        alpha=None,
        K=10,
        N=df.shape[0]
    )

    # Add parameters to the posterior samples from the posterior predictive
    for k in posterior_predictive_samples:
        if not k.startswith("obs"):
            posterior_samples[k] = posterior_predictive_samples[k]


    log_lik = numpyro.infer.util.log_likelihood(
        model,
        posterior_samples,
        D_discrete=D_d,
        X_mixture=df[[i for i in df.keys() if i.startswith("d_var")]].to_numpy(),
        alpha=None,
        K=10,
        N=df.shape[0]
    )

    az_obj = az.from_dict(
        posterior={k: v.reshape([1] + list(v.shape)) for k, v in posterior_samples.items()},
        posterior_predictive={k: v.reshape([1] + list(v.shape)) for k, v in posterior_predictive_samples.items()},
        log_likelihood={k: v.reshape([1] + list(v.shape)) for k, v in log_lik.items()},
        sample_stats={"losses": svi_result.losses} # This is a bit hacky...
    )
    return az_obj

The input to the fitting function is a pandas DataFrame whose columns are named d_var<number> and contain 0/1 values.
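
For a self-contained run, a suitable synthetic input can be generated like this (the row and column counts are arbitrary):

import pandas as pd

rng_np = np.random.default_rng(0)
df = pd.DataFrame(
    rng_np.integers(0, 2, size=(200, 5)),
    columns=[f"d_var{i}" for i in range(5)],
)
az_obj = fit_SVI_manual(df, elbo_fn=numpyro.infer.TraceEnum_ELBO)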

Expected Behavior

The model fit should run using TraceEnum_ELBO. These functions run without error using TraceGraph_ELBO, but the parameter estimates are poor.

Actual Behavior

svi.run currently throws this (long) error:

---------------------------------------------------------------------------

KeyError                                  Traceback (most recent call last)

<ipython-input-20-26729e82f0ad> in <cell line: 0>()
----> 1 az_obj = fit_SVI_manual(df, elbo_fn=numpyro.infer.TraceEnum_ELBO)
      2 # Errors with KeyError: '_latent_distribution'

<ipython-input-18-b821992b9efc> in fit_SVI_manual(df, elbo_fn)
     23         N=df.shape[0]
     24     )
---> 25     svi_result = svi.run(rng, steps, progress_bar=True)
     26 
     27     log_lik_out=np.array([])

/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in run(self, rng_key, num_steps, progress_bar, stable_update, forward_mode_differentiation, init_state, init_params, *args, **kwargs)
    407                 batch = max(num_steps // 20, 1)
    408                 for i in t:
--> 409                     svi_state, loss = jit(body_fn)(svi_state, None)
    410                     losses.append(jax.device_get(loss))
    411                     if i % batch == 0:

    [... skipping hidden 13 frame]

/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in body_fn(svi_state, _)
    390                 )
    391             else:
--> 392                 svi_state, loss = self.update(
    393                     svi_state,
    394                     *args,

/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in update(self, svi_state, forward_mode_differentiation, *args, **kwargs)
    282             mutable_state=svi_state.mutable_state,
    283         )
--> 284         (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
    285             loss_fn,
    286             svi_state.optim_state,

/usr/local/lib/python3.11/dist-packages/numpyro/optim.py in eval_and_update(self, fn, state, forward_mode_differentiation)
    125         """
    126         params: _Params = self.get_params(state)
--> 127         (out, aux), grads = _value_and_grad(
    128             fn, x=params, forward_mode_differentiation=forward_mode_differentiation
    129         )

/usr/local/lib/python3.11/dist-packages/numpyro/optim.py in _value_and_grad(f, x, forward_mode_differentiation)
     48         return (out, aux), grads
     49     else:
---> 50         return value_and_grad(f, has_aux=True)(x)
     51 
     52 

    [... skipping hidden 16 frame]

/usr/local/lib/python3.11/dist-packages/numpyro/infer/svi.py in loss_fn(params)
     59         else:
     60             return (
---> 61                 elbo.loss(
     62                     rng_key, params, model, guide, *args, **kwargs, **static_kwargs
     63                 ),

/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
   1286             rng_keys = random.split(rng_key, self.num_particles)
   1287             return -jnp.mean(
-> 1288                 self.vectorize_particles_fn(single_particle_elbo, rng_keys)
   1289             )

/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in _apply_vmap(fn, keys)
     28 
     29 def _apply_vmap(fn, keys):
---> 30     return vmap(fn)(keys)
     31 
     32 

    [... skipping hidden 7 frame]

/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in single_particle_elbo(rng_key)
   1143             seeded_model = seed(model, model_seed)
   1144             seeded_guide = seed(guide, guide_seed)
-> 1145             model_trace, guide_trace, sum_vars = get_importance_trace_enum(
   1146                 seeded_model,
   1147                 seeded_guide,

/usr/local/lib/python3.11/dist-packages/numpyro/infer/elbo.py in get_importance_trace_enum(model, guide, args, kwargs, params, max_plate_nesting, model_deps, guide_desc)
    980                         kl_qp, output=funsor.Real, dim_to_name=dim_to_name
    981                     )
--> 982                 elif not is_model and (model_trace[name].get("kl") is not None):
    983                     # skip logq computation if analytic kl was computed
    984                     pass

KeyError: '_latent_distribution'
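
From the traceback, the failure appears to be in get_importance_trace_enum: it iterates over guide sites, including the auxiliary one, and indexes model_trace[name] without first checking the is_auxiliary flag. Purely as a sketch of the kind of guard that may be missing (not a tested patch against the actual numpyro source):

# Hypothetical guard inside the guide-site loop of get_importance_trace_enum;
# `site` is a guide trace entry and `name` its site name. An auxiliary site
# has no model counterpart, so the model_trace[name] lookup must be skipped.
if site["infer"].get("is_auxiliary", False):
    continue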

I think this issue also affects the AutoMultivariateNormal autoguide. TraceGraph_ELBO works with AutoMultivariateNormal when a block handler hides the discrete latent site:

guide = numpyro.infer.autoguide.AutoMultivariateNormal(
    numpyro.handlers.block(
        numpyro.handlers.seed(model, rng), 
        lambda site: site["name"] == "assignment"
    )
)
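
For completeness, that autoguide can be fitted the same way as the manual guide (a sketch reusing df, D_d and rng as defined in fit_SVI_manual above):

# Per the report, this runs under TraceGraph_ELBO but not TraceEnum_ELBO.
svi = numpyro.infer.SVI(
    model,
    guide,
    numpyro.optim.Adam(0.05),
    numpyro.infer.TraceGraph_ELBO(num_particles=5, vectorize_particles=True),
    D_discrete=D_d,
    X_mixture=df[[i for i in df.keys() if i.startswith("d_var")]].to_numpy(),
    alpha=None,
    K=10,
    N=df.shape[0],
)
svi_result = svi.run(rng, 1000, progress_bar=True)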

This was originally discussed on the Pyro forum (https://forum.pyro.ai/t/extra-sampling-site-in-manual-guide-compared-to-model/8714/1). The code is runnable in a Google Colab notebook: https://colab.research.google.com/drive/1FBs765Gyw0KsDFMevd1WyJYzby8n3but?usp=sharing

Thank you!
