
uwplasma/vmec_jax


vmec-jax

End-to-end differentiable JAX implementation of VMEC2000 for fixed-boundary and free-boundary ideal-MHD equilibria.

Showcase (single-grid)

All figures below use the same single-grid run settings: NS_ARRAY=151, NITER_ARRAY=5000, FTOL_ARRAY=1e-14, NSTEP=500.
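To reproduce these settings on your own input, the multigrid arrays in the &INDATA namelist can be collapsed to a single entry. The following is a minimal stdlib sketch. It assumes standard VMEC2000 keys written one assignment per line, with &INDATA as the first namelist in the file. The helper name force_single_grid is hypothetical and not part of vmec_jax.

```python
import re

# Single-grid settings used for the showcase figures.
SINGLE_GRID = {
    "NS_ARRAY": "151",
    "NITER_ARRAY": "5000",
    "FTOL_ARRAY": "1e-14",
    "NSTEP": "500",
}


def force_single_grid(text: str, settings: dict = SINGLE_GRID) -> str:
    """Hypothetical helper: rewrite multigrid keys in an &INDATA namelist.

    Assumes each key starts its own assignment line (e.g. "  NS_ARRAY = 16 49 101")
    and that arrays do not continue onto following lines. Keys are matched
    case-insensitively; keys missing from the input are added before the first "/".
    """
    out_lines = []
    seen = set()
    for line in text.splitlines():
        m = re.match(r"\s*([A-Za-z_]+)\s*=", line)
        key = m.group(1).upper() if m else None
        if key in settings:
            out_lines.append(f"  {key} = {settings[key]}")
            seen.add(key)
        elif line.strip() == "/" and len(seen) < len(settings):
            # Append the keys the input did not set, then close the namelist.
            for k, v in settings.items():
                if k not in seen:
                    out_lines.append(f"  {k} = {v}")
            seen.update(settings)
            out_lines.append(line)
        else:
            out_lines.append(line)
    return "\n".join(out_lines) + "\n"


src = "&INDATA\n  NFP = 1\n  ns_array = 16 49\n  NITER_ARRAY = 1000 2000\n/\n"
print(force_single_grid(src))
```

For a real case you would pass the contents of, for example, examples/data/input.circular_tokamak and write the result to a new input file.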

[Figures: flux-surface cross-sections for ITERModel and LandremanPaul2021_QA_lowres (VMEC2000 vs vmec_jax)]
[Figures: rotational transform (iota) profiles for ITERModel and LandremanPaul2021_QA_lowres (VMEC2000 vs vmec_jax vs VMEC++)]

More visuals (single-grid)
[Figures: 3D last closed flux surface (LCFS) for ITERModel and LandremanPaul2021_QA_lowres (VMEC2000 vs vmec_jax)]
[Figures: |B| on the LCFS for ITERModel and LandremanPaul2021_QA_lowres (VMEC2000 vs vmec_jax)]

What it is

  • VMEC2000-parity solver for fixed-boundary and free-boundary equilibria.
  • Supports axisymmetric and non-axisymmetric configurations, both with lasym=False (stellarator-symmetric, or up-down-symmetric for tokamaks) and lasym=True (non-stellarator-symmetric, or up-down-asymmetric).
  • The default CLI invocation is vmec_jax input.name.
  • wout_*.nc outputs, iteration diagnostics, and manifest-based parity sweeps are built around VMEC2000-compatible workflows.
  • JAX-native kernels for geometry, transforms, and residual assembly.
  • Differentiable optimization workflows are available through the Python API and bundled examples.
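The transform kernels mentioned above evaluate Fourier series in the poloidal angle theta and the toroidal angle zeta. The sketch below evaluates a boundary in plain Python to illustrate the usual VMEC2000 input convention for lasym=False:

R = sum RBC(n,m) cos(m*theta - n*NFP*zeta)
Z = sum ZBS(n,m) sin(m*theta - n*NFP*zeta)

This is not the vmec_jax kernel, which works on JAX arrays. The function name eval_boundary is hypothetical.

```python
import math


def eval_boundary(rbc, zbs, nfp, theta, zeta):
    """Evaluate a stellarator-symmetric (lasym=False) VMEC-style boundary.

    rbc, zbs: dicts mapping (n, m) -> coefficient, with n counted in units of
    field periods as in VMEC2000 RBC(n,m)/ZBS(n,m) input lines.
    Returns (R, Z) at poloidal angle theta and toroidal angle zeta (radians).
    """
    r = 0.0
    z = 0.0
    for (n, m), c in rbc.items():
        r += c * math.cos(m * theta - n * nfp * zeta)
    for (n, m), c in zbs.items():
        z += c * math.sin(m * theta - n * nfp * zeta)
    return r, z


# Circular tokamak boundary: major radius 1.0, minor radius 0.3.
rbc = {(0, 0): 1.0, (0, 1): 0.3}
zbs = {(0, 1): 0.3}
print(eval_boundary(rbc, zbs, nfp=1, theta=0.0, zeta=0.0))  # (1.3, 0.0)
```

In a JAX version, the same sums become vectorized cos/sin evaluations over (m, n) arrays. That form is what makes the boundary differentiable with respect to the coefficients.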

Quickstart

Install (editable) and run the showcase:

python -m venv .venv
source .venv/bin/activate
python -m pip install -e .
python examples/showcase_axisym_input_to_wout.py --suite

If you want the bundled reference outputs and mgrid files, fetch the assets once:

python tools/fetch_assets.py

Lightweight clone (keeps full history, downloads blobs lazily):

git clone --filter=blob:none https://github.com/uwplasma/vmec_jax

Note: the repo history was rewritten on 2026-03-16 to remove large assets from all commits. If you cloned before that date, please re-clone (or prune and reset) to get the smaller history.

CLI (VMEC2000-style executable):

vmec_jax examples/data/input.circular_tokamak

Sanity check (verifies the console script is wired to the right interpreter):

vmec_jax --help

If the vmec_jax command is not found or raises ModuleNotFoundError, make sure you installed the package with the same interpreter you are running, and use the module entry point:

python -m pip install -e .
python -m vmec_jax examples/data/input.circular_tokamak
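To check programmatically whether the vmec_jax console script belongs to the Python you are running, you can compare the script's shebang with sys.executable. This is a POSIX-oriented sketch, and the helper names are hypothetical. On Windows, console scripts are .exe launchers, and the check returns None. Launchers this code cannot interpret also yield None. These include /usr/bin/env and the /bin/sh trampoline that pip may write for very long interpreter paths.

```python
import os
import shutil
import sys
from typing import Optional


def parse_shebang(first_line: str) -> Optional[str]:
    """Return the interpreter path from a '#!' line, or None if there is none."""
    line = first_line.strip()
    if not line.startswith("#!"):
        return None
    parts = line[2:].split()
    return parts[0] if parts else None


def console_script_matches(name: str = "vmec_jax") -> Optional[bool]:
    """Check whether a console script runs from the same bin/ directory as this Python.

    Returns True/False when the shebang decides it, and None when the script is
    missing or uses a launcher this sketch cannot interpret (env, sh, .exe).
    """
    path = shutil.which(name)
    if path is None:
        return None
    with open(path, "rb") as fh:
        interp = parse_shebang(fh.readline().decode("utf-8", errors="replace"))
    if interp is None or os.path.basename(interp) in ("env", "sh"):
        return None
    # Compare environment bin/ directories instead of resolving symlinks:
    # a venv's python is usually a symlink to the base interpreter.
    script_bin = os.path.dirname(os.path.abspath(interp))
    current_bin = os.path.dirname(os.path.abspath(sys.executable))
    return script_bin == current_bin


print(console_script_matches())
```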

For fixed-boundary inputs, the default CLI path now uses the optimized controller: it tries the fast final-grid scan route first, then escalates to staged continuation and strict parity finishing only when the input structure and residual history require it. Pass --parity to force the conservative VMEC2000 loop. Pass --solver-mode accelerated to request the optimized track explicitly.
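The escalation order described above follows a "try the cheapest route first" pattern. The sketch below is purely illustrative. The stage names, the Result type, and the solver callables are hypothetical stand-ins, not vmec_jax internals. It also decides escalation from the final residual alone, whereas the real controller also looks at input structure and residual history.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class Result:
    fsq: float        # final force residual reached by a stage
    converged: bool   # whether the stage reported convergence


# A stage is (name, solver); each solver maps an input label to a Result.
Stage = Tuple[str, Callable[[str], Result]]


def run_with_escalation(inp: str, stages: List[Stage], ftol: float):
    """Run stages in order of cost and stop at the first one that meets ftol.

    Returns (stage_name, result, attempted_names). If no stage succeeds, the
    last attempted stage and its result are returned.
    """
    attempted: List[str] = []
    name: Optional[str] = None
    result: Optional[Result] = None
    for name, solve in stages:
        attempted.append(name)
        result = solve(inp)
        if result.converged and result.fsq <= ftol:
            break
    return name, result, attempted


# Fake solvers: the fast scan stalls, staged continuation succeeds.
stages = [
    ("final_grid_scan", lambda inp: Result(fsq=1e-10, converged=False)),
    ("staged_continuation", lambda inp: Result(fsq=5e-15, converged=True)),
    ("strict_parity", lambda inp: Result(fsq=1e-16, converged=True)),
]
name, res, tried = run_with_escalation("input.circular_tokamak", stages, ftol=1e-14)
print(name, tried)  # staged_continuation ['final_grid_scan', 'staged_continuation']
```

With this structure, the expensive strict-parity stage never runs for inputs that the fast path already handles.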

Python driver comparison (reference track vs optimized CLI-style track):

python examples/fixed_boundary_driver_tracks.py \
  examples/data/input.circular_tokamak \
  --quiet --json

Run tests:

pytest -q

Full test suite (requires netCDF assets):

python tools/fetch_assets.py
RUN_FULL=1 pytest -q
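One common way to gate slow, asset-dependent tests behind a flag like RUN_FULL is a skip decorator. The project runs pytest, where pytest.mark.skipif plays this role. This stdlib unittest sketch only illustrates the pattern. It assumes the flag counts as enabled only when set to "1". The helper run_full_enabled and the test class are hypothetical, not taken from the vmec_jax suite.

```python
import os
import unittest
from typing import Mapping, Optional


def run_full_enabled(env: Optional[Mapping[str, str]] = None) -> bool:
    """Return True only when RUN_FULL is set to "1" (assumed convention)."""
    env = os.environ if env is None else env
    return env.get("RUN_FULL", "0") == "1"


class TestReferenceAssets(unittest.TestCase):
    @unittest.skipUnless(run_full_enabled(), "set RUN_FULL=1 after fetching assets")
    def test_reference_input_present(self):
        # Placeholder body: a real full-suite test would compare results
        # against a fetched wout_*.nc reference instead of checking a file.
        self.assertTrue(os.path.exists("examples/data/input.circular_tokamak"))
```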

Optimization tutorials (differentiable boundary tuning):

python examples/optimization/optimize_bmag_volume.py --case circular_tokamak --opt-steps 3
python examples/optimization/explicit_target_iota_volume.py --case circular_tokamak --opt-steps 3
python examples/optimization/implicit_target_iota_volume.py --case circular_tokamak --opt-steps 3
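The idea behind these tutorials can be seen without a full equilibrium solve. This toy gradient-descent loop tunes one boundary coefficient, the minor radius a (RBC(0,1) = ZBS(0,1) = a for a circular cross-section), to hit a target plasma volume. It uses the closed-form volume of a circular torus, V = 2*pi^2*R0*a^2, and a hand-written derivative. The real scripts instead use a VMEC solve and JAX autodiff. All names here are hypothetical.

```python
import math


def torus_volume(r0: float, a: float) -> float:
    """Volume of a torus with circular cross-section: V = 2*pi^2*R0*a^2."""
    return 2.0 * math.pi ** 2 * r0 * a ** 2


def tune_minor_radius(r0: float, a0: float, v_target: float,
                      lr: float = 1e-3, steps: int = 200) -> float:
    """Gradient descent on J(a) = (V(a) - V_target)^2 / 2 with respect to a."""
    a = a0
    for _ in range(steps):
        residual = torus_volume(r0, a) - v_target
        dv_da = 4.0 * math.pi ** 2 * r0 * a   # analytic dV/da
        a -= lr * residual * dv_da             # dJ/da = residual * dV/da
    return a


a_opt = tune_minor_radius(r0=1.0, a0=0.3, v_target=torus_volume(1.0, 0.35))
print(round(a_opt, 6))  # 0.35
```

In the differentiable workflow, dV/da comes from autodiff through the equilibrium rather than from a formula. The optimization loop itself has the same shape.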

Performance vs parity

  • Default runs aim for VMEC2000-compatible behavior while selecting the fastest stable path for the input.
  • Use --parity (or performance_mode=False in Python) to force the conservative VMEC2000 loop.
  • Use --solver-mode accelerated to force the optimized fixed-boundary controller explicitly.

Details, profiling guidance, and parity methodology:

  • docs/performance.rst
  • docs/validation.rst
  • tools/diagnostics/parity_manifest.toml + tools/diagnostics/parity_sweep_manifest.py

VMEC++ notes

The runtime plot includes VMEC++ (green) for context. Some inputs are not supported or do not converge under the same single-grid settings.

Inputs VMEC++ does not support in this benchmark (lasym=True):

  • LandremanSenguptaPlunk_section5p3_low_res
  • basic_non_stellsym_pressure
  • cth_like_free_bdy_lasym_small
  • up_down_asymmetric_tokamak

VMEC++ failed to converge (non-zero exit) on these lasym=False cases under the same single-grid settings:

  • DIII-D_lasym_false
  • LandremanPaul2021_QA_reactorScale_lowres
  • LandremanPaul2021_QH_reactorScale_lowres
  • LandremanSengupta2019_section5.4_B2_A80
  • cth_like_fixed_bdy

CLI output and NSTEP

The VMEC-style iteration loop prints every NSTEP iterations. Larger NSTEP means fewer print callbacks and faster runs.
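Each status line is a host-side callback out of the compiled loop, so the number of callbacks scales roughly as NITER/NSTEP. The sketch below counts them for an assumed schedule: iteration 1, then every multiple of NSTEP. This schedule is modeled on VMEC2000-style output, and vmec_jax may differ in detail. Runs that converge early also print less.

```python
from typing import List


def print_iterations(niter: int, nstep: int, include_first: bool = True) -> List[int]:
    """Iterations at which a VMEC-style loop would emit a status line (assumed schedule)."""
    return [
        i for i in range(1, niter + 1)
        if i % nstep == 0 or (include_first and i == 1)
    ]


# Showcase settings: NITER_ARRAY=5000, NSTEP=500.
print(len(print_iterations(5000, 500)))  # 11
print(len(print_iterations(5000, 10)))   # 501
```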

To disable live printing, set:

export VMEC_JAX_SCAN_PRINT=0

Quiet runs (--quiet or verbose=False) default the scan path to a minimal history mode (only fsqr/fsqz/fsql and w_history are kept) to reduce host/device traffic. You can override this with:

export VMEC_JAX_SCAN_MINIMAL=0  # keep full scan diagnostics even when quiet
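The sketch below summarizes how quiet mode and the two environment variables above interact. It is a hypothetical model of the documented behavior, not the code vmec_jax uses to read them. One detail is an assumption: non-quiet runs keep full history unless VMEC_JAX_SCAN_MINIMAL is set to a non-zero value.

```python
import os
from typing import Mapping, Optional


def scan_settings(quiet: bool, env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical model of the documented scan-output switches.

    - VMEC_JAX_SCAN_PRINT=0 disables live printing.
    - Quiet runs default to minimal history; VMEC_JAX_SCAN_MINIMAL=0 overrides it.
    - Assumption: non-quiet runs default to full history.
    """
    env = os.environ if env is None else env
    live_print = env.get("VMEC_JAX_SCAN_PRINT", "1") != "0"
    minimal = env.get("VMEC_JAX_SCAN_MINIMAL", "1" if quiet else "0") != "0"
    kept = ("fsqr", "fsqz", "fsql", "w_history") if minimal else ("all diagnostics",)
    return {"live_print": live_print, "minimal_history": minimal, "kept": kept}


print(scan_settings(quiet=True, env={"VMEC_JAX_SCAN_MINIMAL": "0"}))
```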

When to use vmec_jax

  • Use vmec_jax for fixed-boundary and free-boundary production runs, autodiff, rapid parameter sweeps, and JAX-native optimization workflows.
  • Use the VMEC2000 executable as an optional parity reference or regression oracle, not as an operational requirement.