Educational project for learning score-based diffusion models through 2D toy datasets. A simple MLP learns to denoise samples from various 2D distributions, using VP-SDE or VE-SDE forward processes and DPM-Solver++(2M) for fast sampling.
⚡ This diffusion model can be trained in a few minutes on a consumer-grade GPU. Let's give it a try!
Learn a 2D data distribution via score-based generative modeling, then generate new samples by reversing the diffusion process.
Forward process (noising):
x_0 ~ p_data → x_t = α(t) x_0 + σ(t) ε, ε ~ N(0, I)
Reverse process (denoising):
x_T ~ N(0, I) → x_{t-1} → ... → x_0 ~ p_data
- Score network: 4-layer MLP with ReLU activations
- Predictor types: `x_start`, `noise`, or `score` parameterization
- Forward SDEs:
  - VP-SDE (Variance Preserving): linear beta schedule from `beta_min` to `beta_max`
  - VE-SDE (Variance Exploding): geometric sigma schedule from `sigma_min` to `sigma_max`
- ODE Solver: DPM-Solver++(2M), a 2nd-order multistep solver in x0-prediction form
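As a rough illustration of the two forward SDEs, the sketch below computes the marginal coefficients `α(t)` and `σ(t)` used in `x_t = α(t) x_0 + σ(t) ε`. The function names and the default values for `beta_min`, `beta_max`, `sigma_min`, and `sigma_max` are assumptions chosen for the example, not necessarily what this project uses. The VE version follows the common convention of using `σ(t)` directly as the noise scale.

```python
import numpy as np


def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """Return (alpha(t), sigma(t)) for a VP-SDE with a linear beta schedule.

    Default beta values are illustrative assumptions.
    """
    t = np.asarray(t, dtype=np.float64)
    # Integral of beta(s) = beta_min + s * (beta_max - beta_min) over [0, t].
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    alpha = np.exp(-0.5 * integral)
    sigma = np.sqrt(1.0 - alpha**2)  # variance preserving: alpha^2 + sigma^2 = 1
    return alpha, sigma


def ve_marginal(t, sigma_min=0.01, sigma_max=10.0):
    """Return (alpha(t), sigma(t)) for a VE-SDE with a geometric sigma schedule.

    Default sigma values are illustrative assumptions.
    """
    t = np.asarray(t, dtype=np.float64)
    alpha = np.ones_like(t)  # the signal is never scaled down
    sigma = sigma_min * (sigma_max / sigma_min) ** t
    return alpha, sigma


if __name__ == "__main__":
    ts = np.linspace(0.0, 1.0, 5)
    print("VP:", vp_marginal(ts))
    print("VE:", ve_marginal(ts))
```

With VP the signal fades while the total variance stays bounded; with VE the signal is kept and the noise scale grows geometrically from `sigma_min` to `sigma_max`.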
Seven built-in 2D datasets are available:
| Dataset | Source |
|---|---|
| `two_moons` | scikit-learn |
| `swiss_roll` | scikit-learn |
| `circles` | scikit-learn |
| `s_curve` | scikit-learn (projected to 2D) |
| `spiral` | Custom |
| `pinwheel` | Custom |
| `checkerboard` | Custom (Numba-accelerated) |
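For a feel of what the scikit-learn-backed datasets look like, here is a minimal sketch that builds `two_moons` and a 2D projection of `s_curve`. The helper names, default arguments, the choice of which axis to drop, and the standardization step are assumptions made for this example; the project's own loaders may differ.

```python
import numpy as np
from sklearn.datasets import make_moons, make_s_curve


def load_two_moons(n_samples=2000, noise=0.05, seed=0):
    """Hypothetical helper: two interleaving half circles, shape (n_samples, 2)."""
    x, _ = make_moons(n_samples=n_samples, noise=noise, random_state=seed)
    return x.astype(np.float32)


def load_s_curve_2d(n_samples=2000, noise=0.05, seed=0):
    """Hypothetical helper: 3D S-curve projected to 2D by dropping the middle axis."""
    x, _ = make_s_curve(n_samples=n_samples, noise=noise, random_state=seed)
    return x[:, [0, 2]].astype(np.float32)


def standardize(x):
    """Shift and scale each coordinate to zero mean and unit variance."""
    return (x - x.mean(axis=0)) / x.std(axis=0)


if __name__ == "__main__":
    moons = standardize(load_two_moons())
    s_curve = standardize(load_s_curve_2d())
    print(moons.shape, s_curve.shape)
```

Standardizing the data is one common way to keep toy datasets on a scale comparable to the `N(0, I)` prior.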
- Sample a batch `x_0` from the training data
- Sample a random time `t ~ U(eps, 1)`
- Compute the noisy sample `x_t` via the forward SDE's marginal distribution
- Predict the target (x_start, noise, or score) with the MLP
- Minimize the weighted MSE loss
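The steps above can be sketched as a single loss evaluation. This example assumes the `noise` parameterization, a VP-SDE with illustrative `beta_min`/`beta_max` values, and a uniform loss weight of 1. `model` is a placeholder callable standing in for the MLP, and NumPy is used only to keep the sketch dependency-light; the gradient step itself is left to whatever framework trains the network.

```python
import numpy as np


def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """(alpha(t), sigma(t)) for a VP-SDE with linear beta; defaults are assumptions."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    alpha = np.exp(-0.5 * integral)
    return alpha, np.sqrt(1.0 - alpha**2)


def noise_prediction_loss(model, x0, rng, eps=1e-3):
    """One Monte Carlo estimate of the noise-prediction MSE loss.

    model: callable (x_t, t) -> predicted noise, same shape as x_t (placeholder).
    x0:    array of shape (batch, 2) drawn from the training data.
    """
    batch = x0.shape[0]
    t = rng.uniform(eps, 1.0, size=(batch, 1))  # t ~ U(eps, 1), one per sample
    alpha, sigma = vp_marginal(t)
    noise = rng.standard_normal(x0.shape)       # epsilon ~ N(0, I)
    xt = alpha * x0 + sigma * noise             # sample from the forward marginal
    pred = model(xt, t)
    # Squared error summed over coordinates, averaged over the batch (weight = 1).
    return np.mean(np.sum((pred - noise) ** 2, axis=1))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = rng.standard_normal((1024, 2))
    untrained = lambda xt, t: np.zeros_like(xt)  # stand-in for an untrained network
    print("loss:", noise_prediction_loss(untrained, x0, rng))
```

Switching to the `x_start` or `score` parameterization changes only the regression target (`x0`, or `-noise / sigma`) and, typically, the weighting.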
At sampling time, DPM-Solver++(2M) integrates the probability flow ODE
from t=1 back to t=eps in a configurable number of steps.
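To make the sampler concrete, here is a compact sketch of the DPM-Solver++(2M) update in x0-prediction form, following the published multistep rule. It assumes a VP-SDE with illustrative beta values and a time grid that is uniform in `t`. `x0_pred` is a placeholder for the trained network's data prediction, and the function and argument names are invented for this example rather than taken from the project.

```python
import numpy as np


def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """(alpha(t), sigma(t)) for a VP-SDE with linear beta; defaults are assumptions."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    alpha = np.exp(-0.5 * integral)
    return alpha, np.sqrt(1.0 - alpha**2)


def dpm_solver_pp_2m(x0_pred, x_T, n_steps=20, t_start=1.0, eps=1e-3):
    """Integrate the probability flow ODE from t_start down to eps.

    x0_pred: callable (x_t, t) -> estimate of x_0 (placeholder for the network).
    x_T:     array of initial noise samples, shape (n, 2).
    """
    ts = np.linspace(t_start, eps, n_steps + 1)  # uniform-in-time grid (assumption)
    x = x_T
    prev_x0, prev_h = None, None
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        a_cur, s_cur = vp_marginal(t_cur)
        a_next, s_next = vp_marginal(t_next)
        # Step size in half-log-SNR: lambda(t) = log(alpha(t) / sigma(t)).
        h = np.log(a_next / s_next) - np.log(a_cur / s_cur)
        x0 = x0_pred(x, t_cur)
        if prev_x0 is None:
            d = x0  # first step: no history yet, fall back to first order
        else:
            r = prev_h / h
            # Second-order correction from the previous data prediction.
            d = (1.0 + 1.0 / (2.0 * r)) * x0 - (1.0 / (2.0 * r)) * prev_x0
        x = (s_next / s_cur) * x - a_next * np.expm1(-h) * d
        prev_x0, prev_h = x0, h
    return x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    target = np.array([1.5, -0.5])
    # Toy predictor: pretend the data distribution is a single point.
    point_mass = lambda x, t: np.broadcast_to(target, x.shape)
    samples = dpm_solver_pp_2m(point_mass, rng.standard_normal((8, 2)))
    print(samples)
```

Because only one new network evaluation is needed per step (the previous prediction is reused), the second-order accuracy comes at essentially first-order cost.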
Make sure uv is installed on your system. If it is not, follow the installation instructions in the official uv documentation.
uv sync --extra dev
uv pip install -e .
Train:
uv run scripts/run_training.py --config config.yaml
This will train the model and save outputs (samples, plots, TensorBoard logs) in the outputs/ directory.
All settings are defined in config.yaml:
- `device` (CUDA or CPU)
- `seed` (reproducibility)
- `data.*` (dataset type, number of samples, noise level)
- `model.*` (predictor type, ODE solver, hidden size, SDE parameters)
- `training.*` (learning rate, weight decay, steps, batch size)
- `sampling.*` (number of generated samples, solver steps)
Install dev dependencies (includes pre-commit and ruff), then install hooks:
uv sync --extra dev
pre-commit install
Run all hooks manually if needed:
pre-commit run --all-files