We implement simulation-based inference (SBI) for pulse-based Drift–Diffusion Models (DDMs) using mixed neural likelihood estimation (MNLE) and Bayesian inference with MCMC.
We use:
- PyTorch for simulation and neural networks
- 'sbi (v0.25.0)' for neural likelihoods and MCMC
- uv for virtual environment handling
curl -LsSf https://astral.sh/uv/install.sh | sh
uv venv .venv
source .venv/bin/activate
uv sync
- Simulate Training Dataset
n_max, steps_per_pulse = pulse_schedule()
P = n_pulses_max_from_schedule(n_max, steps_per_pulse)
# Define prior over Theta
prior_theta = build_prior_theta()
# Define training proposals over Theta
pulse_prop = PulseSequenceProposal(P=P, p_success=cfg.P_SUCCESS, seed=0, device="cpu")
proposal_z = ExtendedProposal(theta_prior=prior_theta, pulse_proposal=pulse_prop, device="cpu")
# Simulate Training data
z_train, x_train = simulate_training_set_with_conditions(
proposal=proposal_z,
num_simulations=cfg.NUM_SIMULATIONS,
batch_size=cfg.TRAIN_BATCH_SIZE,
device="cpu",
mu_sensory=cfg.MU_SENSORY,
p_success=cfg.P_SUCCESS,
P=P,
log_rt=cfg.LOG_RT_MANUALLY,
)
# Summarize trial data
summarize_trials("train (sample)", x_train[torch.randperm(len(x_train))[:50_000]])
- Train neural likelihood (MNLE)
density_estimator = train_mnle(cfg, proposal_z, z_train, x_train, device="cpu")
# Save trained neural network (still working on function for this)
save_model(density_estimator, cfg)
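Since the save and load helpers are still in progress, here is a minimal sketch of one common PyTorch pattern: store only the `state_dict` and restore it into a freshly built network of the same architecture. It assumes the trained density estimator is a `torch.nn.Module`. The names `save_weights`, `load_weights`, and `build_stand_in` are hypothetical and are not part of this repository; the small `nn.Sequential` merely stands in for the real estimator.

```python
import os
import tempfile

import torch
from torch import nn


def save_weights(model: nn.Module, path: str) -> None:
    """Write only parameters and buffers, not the pickled module object."""
    torch.save(model.state_dict(), path)


def load_weights(model: nn.Module, path: str, device: str = "cpu") -> nn.Module:
    """Load weights into an already constructed model of the same architecture."""
    state = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state)
    return model.to(device).eval()


def build_stand_in() -> nn.Module:
    # Hypothetical stand-in; the real estimator would be rebuilt from cfg.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
trained = build_stand_in()
path = os.path.join(tempfile.mkdtemp(), "mnle_weights.pt")
save_weights(trained, path)

# Rebuild the architecture first, then restore the saved weights into it.
restored = load_weights(build_stand_in(), path)
x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))
```

Saving the `state_dict` rather than the whole object means the loader has to reconstruct the network before restoring weights, which may be why a loader would need the config and the proposal as arguments.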
# Simulate Observed Session
theta_true = prior_theta.sample((1,)).view(5)
x_o, pulses_o = simulate_observed_session(
theta_true,
num_trials=cfg.NUM_TRIALS_OBS,
device="cpu",
mu_sensory=cfg.MU_SENSORY,
p_success=cfg.P_SUCCESS,
P=P,
seed=123,
log_rt=cfg.LOG_RT_MANUALLY,
)
- Inference ONLY, load saved model:
# Working on function for this too
density_estimator = load_model(cfg, proposal_z, device="cpu")
# Run inference; this can be done right after training or in isolation
samples = run_inference_mcmc(cfg, prior_theta, density_estimator, x_o, pulses_o)
- Simulation-based Calibration (SBC)
To verify posterior correctness, run SBC:
run_sbc(
cfg,
prior_theta=prior_theta,
density_estimator=density_estimator,
device="cpu",
num_datasets=cfg.SBC_NUM_DATASETS,
posterior_samples_per_dataset=cfg.SBC_POST_SAMPLES,
seed=0,
param_names=("a0", "lam", "v", "B", "tau"),
outdir=sbc_outdir,
plot_bins=30,
)
This performs repeated cycles of:
- Sample $\theta$ ~ prior
- Simulate dataset
- Run MCMC to sample the posterior
- Compute rank statistics
- Plot rank histograms
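Approximately uniform rank histograms indicate well-calibrated inference; systematic deviations (for example U-shaped or skewed histograms) point to over-confident or biased posteriors.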
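The rank-statistic step can be illustrated on a toy problem. The sketch below is not the repository's `run_sbc`. It assumes a one-parameter conjugate Gaussian model (prior N(0, 1), one observation with unit noise) whose exact posterior N(x/2, 1/2) replaces the MCMC step, so the ranks should come out close to uniform. The function names `sbc_ranks` and `rank_histogram` are hypothetical.

```python
import random


def sbc_ranks(num_datasets, posterior_samples_per_dataset, seed=0):
    """Return one SBC rank per simulated dataset for a toy Gaussian model."""
    rng = random.Random(seed)
    ranks = []
    for _ in range(num_datasets):
        theta_true = rng.gauss(0.0, 1.0)       # sample theta from the prior
        x_obs = rng.gauss(theta_true, 1.0)     # simulate one observation
        # Exact conjugate posterior: N(x/2, 1/2). Stands in for MCMC here.
        post_mean, post_sd = x_obs / 2.0, 0.5 ** 0.5
        draws = [rng.gauss(post_mean, post_sd)
                 for _ in range(posterior_samples_per_dataset)]
        # Rank = number of posterior draws below the true parameter.
        ranks.append(sum(d < theta_true for d in draws))
    return ranks


def rank_histogram(ranks, num_bins, max_rank):
    """Count ranks (integers in 0..max_rank) in equal-width bins."""
    counts = [0] * num_bins
    for r in ranks:
        idx = min(r * num_bins // (max_rank + 1), num_bins - 1)
        counts[idx] += 1
    return counts


num_post = 99  # ranks then take 100 possible values, 0 through 99
ranks = sbc_ranks(num_datasets=2000, posterior_samples_per_dataset=num_post)
counts = rank_histogram(ranks, num_bins=10, max_rank=num_post)
print(counts)  # each bin should sit near 2000 / 10 = 200
```

Replacing the exact posterior with one that is too narrow (for example a smaller `post_sd`) should pile ranks up at both ends of the histogram, which is the kind of deviation SBC is meant to expose.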
All experiment parameters live in sbi_for_diffusion_models/run_config.py.
Key controls include:
NUM_SIMULATIONS # MNLE training size
NUM_TRIALS_OBS # Trials per dataset
POSTERIOR_SAMPLES # MCMC samples
SBC_NUM_DATASETS # Number of SBC repetitions
SBC_POST_SAMPLES # MCMC samples per SBC dataset