Learning Methodologies for Autoregressive Neural Emulators.
Installation • Documentation • Quickstart • Background • Features • Citation
Convenience abstractions built on Optax for training neural networks to autoregressively emulate time-dependent problems. Trainax takes care of trajectory subsampling and offers a wide range of training methodologies (regarding unrolling length and including differentiable physics).
```bash
pip install trainax
```

Requires Python 3.10+ and JAX 0.4.13+. 👉 JAX install guide.
The documentation is available at fkoehler.site/trainax.
Train a linear convolution with kernel size 2 (and no bias) to become an emulator for the 1D advection problem.
```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax  # pip install optax
import trainax as tx
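# Courant number of the 1D advection problem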
CFL = -0.75
ref_data = tx.sample_data.advection_1d_periodic(
    cfl = CFL,
    key = jax.random.PRNGKey(0),
)
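# A linear convolution with kernel size 2, periodic padding, and no bias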
linear_conv_kernel_2 = eqx.nn.Conv1d(
    1, 1, 2,
    padding="SAME", padding_mode="CIRCULAR", use_bias=False,
    key=jax.random.PRNGKey(73)
)
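# Three otherwise identical supervised trainers with 1, 5, and 20 unrolled steps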
sup_1_trainer, sup_5_trainer, sup_20_trainer = (
    tx.trainer.SupervisedTrainer(
        ref_data,
        num_rollout_steps=r,
        optimizer=optax.adam(1e-2),
        num_training_steps=1000,
        batch_size=32,
    )
    for r in (1, 5, 20)
)
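# Train three copies of the same initial network, one per trainer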
sup_1_conv, sup_1_loss_history = sup_1_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_5_conv, sup_5_loss_history = sup_5_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
sup_20_conv, sup_20_loss_history = sup_20_trainer(
    linear_conv_kernel_2, key=jax.random.PRNGKey(42)
)
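# First-order upwind (FOU) stencil for this advection problem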
FOU_STENCIL = jnp.array([1+CFL, -CFL])
print(jnp.linalg.norm(sup_1_conv.weight - FOU_STENCIL))   # 0.033
print(jnp.linalg.norm(sup_5_conv.weight - FOU_STENCIL))   # 0.025
print(jnp.linalg.norm(sup_20_conv.weight - FOU_STENCIL))  # 0.017
```

Increasing the number of supervised unrolled steps during training brings the learned stencil closer to the numerical first-order upwind (FOU) stencil.
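A comparison of the returned loss histories makes this effect visible; a minimal sketch, assuming matplotlib is installed and each history is a 1D array of per-step loss values:

```python
import matplotlib.pyplot as plt  # pip install matplotlib

# Compare the convergence of the three unrolling configurations.
for label, history in [
    ("1 step", sup_1_loss_history),
    ("5 steps", sup_5_loss_history),
    ("20 steps", sup_20_loss_history),
]:
    plt.plot(history, label=label)
plt.xlabel("training step")
plt.ylabel("loss")
plt.yscale("log")
plt.legend()
plt.show()
```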
After the discretization of space and time, the simulation of a time-dependent partial differential equation amounts to the repeated application of a simulation operator $\mathcal{P}_h$. Here, we are interested in imitating/emulating this physical/numerical operator with a neural network $f_\theta$. This repository is concerned with an abstract implementation of all the ways we can frame a learning problem to inject "knowledge" from $\mathcal{P}_h$ into $f_\theta$.
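To make the autoregressive setting concrete: a rollout repeatedly feeds a stepper its own output. A minimal sketch in JAX, where `stepper` is a hypothetical stand-in for either $\mathcal{P}_h$ or $f_\theta$ (not part of the Trainax API):

```python
import jax


def rollout(stepper, u_0, num_steps: int):
    """Autoregressively apply `stepper`, collecting the produced trajectory."""

    def scan_fn(u, _):
        u_next = stepper(u)
        return u_next, u_next

    _, trajectory = jax.lax.scan(scan_fn, u_0, None, length=num_steps)
    return trajectory  # shape: (num_steps, *u_0.shape)
```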
Assume we have a distribution of initial conditions $u^{[0]} \sim \mathcal{Q}$ from which we sample, and that we create reference trajectories by repeatedly applying the simulation operator, $u^{[t+1]} = \mathcal{P}_h(u^{[t]})$.

For a one-step supervised learning task, we substack the training trajectories into windows of size 2, i.e., into pairs $(u^{[t]}, u^{[t+1]})$, and minimize

$$
L(\theta) = \mathbb{E}_{u^{[0]} \sim \mathcal{Q},\, t} \left[ \left\| f_\theta(u^{[t]}) - u^{[t+1]} \right\|_2^2 \right],
$$

where $u^{[t+1]} = \mathcal{P}_h(u^{[t]})$ is the reference next state.
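A minimal sketch of this substacking and the resulting one-step loss in plain JAX; the helper names are illustrative, and Trainax performs this trajectory subsampling internally:

```python
import jax
import jax.numpy as jnp


def one_step_pairs(trajectories):
    """Substack trajectories of shape (num_samples, num_steps + 1, ...) into pairs."""
    inputs = trajectories[:, :-1]   # u^[t]
    targets = trajectories[:, 1:]   # u^[t+1]
    # Merge the sample axis and the window axis into a single batch axis.
    merge = lambda a: a.reshape((-1, *a.shape[2:]))
    return merge(inputs), merge(targets)


def one_step_supervised_loss(model, inputs, targets):
    """Mean squared error of a single application of the network."""
    predictions = jax.vmap(model)(inputs)
    return jnp.mean((predictions - targets) ** 2)
```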
Trainax supports far more than one-step supervised learning, e.g., training with unrolled steps and including the (differentiable) reference simulator $\mathcal{P}_h$ in the loss computation.
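For example, an unrolled supervised loss compares an $n$-step rollout of $f_\theta$ against the reference trajectory at every intermediate step, letting gradients flow through the whole rollout. A conceptual sketch, not the Trainax implementation:

```python
import jax.numpy as jnp


def unrolled_supervised_loss(model, u_init, ref_trajectory):
    """Roll the network out and penalize the deviation at every step."""
    loss = 0.0
    u = u_init
    for u_ref in ref_trajectory:
        u = model(u)  # gradients flow through the entire rollout
        loss = loss + jnp.mean((u - u_ref) ** 2)
    return loss / len(ref_trajectory)
```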
- Wide collection of unrolled training methodologies:
  - Supervised
  - Diverted Chain (sketched conceptually after this list)
  - Mix Chain
  - Residuum
- Based on JAX:
  - One of the best Automatic Differentiation engines (forward & reverse)
  - Automatic vectorization
  - Backend-agnostic code (run on CPU, GPU, and TPU)
- Built on top of and compatible with Equinox
- Batch-Parallel Training
- Collection of Callbacks
- Composability
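As a conceptual sketch of the diverted-chain idea referenced in the list above: the main chain is the network's own rollout, and at each state one step of a differentiable reference simulator provides the target. All names here (`ref_stepper` in particular) are illustrative, not the Trainax API:

```python
import jax.numpy as jnp


def diverted_chain_loss(model, ref_stepper, u_init, num_rollout_steps: int):
    """Network rollout as the main chain; a one-step reference branch as target."""
    loss = 0.0
    u = u_init
    for _ in range(num_rollout_steps):
        prediction = model(u)
        # Divert from the current state with one step of the reference simulator.
        target = ref_stepper(u)
        loss = loss + jnp.mean((prediction - target) ** 2)
        u = prediction
    return loss / num_rollout_steps
```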
This package was developed as part of the APEBench paper (arxiv.org/abs/2411.00180), accepted at NeurIPS 2024. If you find it useful for your research, please consider citing it:
```bibtex
@article{koehler2024apebench,
  title={{APEBench}: A Benchmark for Autoregressive Neural Emulators of {PDE}s},
  author={Felix Koehler and Simon Niedermayr and R{\"u}diger Westermann and Nils Thuerey},
  journal={Advances in Neural Information Processing Systems (NeurIPS)},
  volume={38},
  year={2024}
}
```

(Feel free to also give the project a star on GitHub if you like it.)
Here you can find the APEBench benchmark suite.
The main author (Felix Koehler) is a PhD student in the group of Prof. Thuerey at TUM and his research is funded by the Munich Center for Machine Learning.
MIT, see here
fkoehler.site · GitHub @ceyron · X @felix_m_koehler · LinkedIn Felix Köhler
