Skip to content

Draft: add JAX VMEC/Boozer optimization path#604

Draft
rogeriojorge wants to merge 19 commits intohiddenSymmetries:masterfrom
rogeriojorge:codex/qh-boozer-jax-wip
Draft

Draft: add JAX VMEC/Boozer optimization path#604
rogeriojorge wants to merge 19 commits intohiddenSymmetries:masterfrom
rogeriojorge:codex/qh-boozer-jax-wip

Conversation

@rogeriojorge
Copy link
Contributor

@rogeriojorge rogeriojorge commented Mar 21, 2026

Summary

  • add SIMSOPT wrappers for vmec_jax and booz_xform_jax
  • add a JAX least-squares solve path plus QH Boozer JAX example/profilers
  • add regression tests and diagnostics used during the QH parity/performance investigation
  • include the paired tracked Boozer and QH example changes required by the new tests/examples

Included changes

  • new VmecJax and BoozerJax wrappers under simsopt.mhd
  • new JAX-based solve helpers under simsopt.solve
  • QH_fixed_resolution_boozer_jax.py and focused profiling helpers
  • tracked Boozer support for use_wout_file and MPI-group-safe serial Boozer execution
  • explicit proc0_print(..., flush=True) default and safe VMEC cleanup behavior for missing auxiliary files
  • wrapper/parity tests covering initialization, Boozer-input parity, staged input defaults, and QH mismatch diagnostics

Dependency

Reviewer fixes in latest update

  • VmecJax now derives default solver controls from NITER_ARRAY / FTOL_ARRAY when present, instead of silently falling back to scalar NITER / FTOL
  • the wrapper-side warm start is now explicit and opt-in; default warm_start_iters=0 uses the plain VMEC-style initial guess
  • the missing tracked Boozer / QH example changes are now included in the PR branch, so test_boozer_jax_compare.py no longer depends on an uncommitted API extension

Measured status

  • initial QH QS on the JAX path improved from roughly 1.59e-2 to 7.30e-3
  • reference initial QS from the original VMEC2000/BoozXform workflow is 1.7644570754324314e-3
  • cold-start runtime is still much slower than the original example, so this PR is reviewable but not ready to merge

Validation

  • pytest -q tests/mhd/test_vmec_jax_wrapper.py
  • pytest -q tests/mhd/test_boozer_jax_compare.py
  • pytest -q tests/mhd/test_vmec_jax_wrapper.py tests/mhd/test_booz_input_parity.py tests/mhd/test_vmec_jax_qh_mismatch_diagnostics.py

Review focus

  • wrapper/API shape for VmecJax, BoozerJax, and JAX least-squares integration
  • whether the example/profiling structure is reasonable for continued parity work
  • test coverage around the known QH mismatch points

Status

This remains a draft PR for early review. The JAX path is functional and several wrapper bugs are fixed, but full parity and runtime targets versus the original VMEC2000/BoozXform workflow are still in progress.

@rogeriojorge rogeriojorge changed the title WIP: add JAX VMEC/Boozer optimization path Draft: add JAX VMEC/Boozer optimization path Mar 21, 2026
@rogeriojorge
Copy link
Contributor Author

This draft depends on uwplasma/vmec_jax#2. The wrapper-side initialization bug is fixed, but the remaining gap is upstream in the VMEC-JAX equilibrium/lambda-current path rather than in the SIMSOPT API layer.

@rogeriojorge
Copy link
Contributor Author

Thanks for the review. I pushed a follow-up commit that addresses the three concrete blockers:

  • test_boozer_jax_compare.py is now valid on the PR branch because the tracked Boozer(..., use_wout_file=...) change was added to the branch instead of only existing in the local worktree.
  • VmecJax now derives defaults from staged VMEC inputs. When NITER_ARRAY / FTOL_ARRAY are present, the wrapper uses their final staged values by default instead of silently reading scalar NITER / FTOL.
  • the hidden 20-step L-BFGS presolve is gone as a default. warm_start_iters now defaults to 0, and there is a regression that the context starts from the plain VMEC-style initial guess unless warm starting is explicitly requested.

I also added the missing tracked example updates (QH_fixed_resolution_boozer.py and QH_fixed_resolution.py) to the branch, plus the small MPI/VMEC cleanup changes those examples rely on.

Validation after the follow-up commit:

  • pytest -q tests/mhd/test_vmec_jax_wrapper.py -> 8 passed
  • pytest -q tests/mhd/test_boozer_jax_compare.py -> 1 passed

On the booz_xform_jax point: there is still no separate PR there because this line of work did not require source changes in booz_xform_jax; the dependency here is the existing upstream package, not a patched branch.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we revert changes here? What there a reason to change the resolution?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way if these are good fixes anyways, but unrelated to the jax_vmec stuff, just submit this as a separate, small PR, and we will approve extremely fast.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment, is there some reason to change this file? Looks like you are doing a more robust, higher res solve, but no vmec or boozer jax functionality here, no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not very readable. Do we need all the parser arguments here? Can we put this in the source code instead or something?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify what you are doing here? Hard to tell if this is related to the jax booz_xform or this is a different fix?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and other new files, please add detailed docstrings (with cursor) for the new functions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic of this change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment, these functions need a lot of docstrings.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filename is very confusing because it conflicts with external jax package! Please change to jax_solve or jax_least_squares_solve or something.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstrings needed too

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this change needed?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants