Hyperoptax: Parallel hyperparameter tuning with JAX

Warning

Hyperoptax is still a work in progress and the API is subject to change. There are many rough edges to smooth out. If you want to use it in a large-scale project, pin a specific version or install from source.

⛰️ Introduction

Hyperoptax is a lightweight toolbox for parallel hyperparameter optimization of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search across spaces in parallel – all while staying in pure JAX.

🏗️ Installation

pip install hyperoptax

If you want to use the notebooks:

pip install hyperoptax[notebooks]

If you do not yet have JAX installed, pick the right wheel for your accelerator:

# CPU-only
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide

🥜 In a nutshell

All optimizers follow the same stateless pattern: Optimizer.init returns a (state, optimizer) pair, and optimizer.optimize runs the search loop. Your objective function must have the signature fn(key, params) -> scalar.

import jax
from hyperoptax import BayesianSearch, LogSpace, LinearSpace

def train_nn(key, params):
    learning_rate = params["learning_rate"]
    final_lr_pct = params["final_lr_pct"]
    ...
    return val_loss  # scalar, lower is better

search_space = {
    "learning_rate": LogSpace(1e-5, 1e-1),
    "final_lr_pct": LinearSpace(0.01, 0.5),
}

state, optimizer = BayesianSearch.init(
    search_space,
    n_max=100,       # observation buffer size (= number of iterations)
    n_parallel=4,    # parallel workers per step
    maximize=False,
)

state, (params_hist, results_hist) = optimizer.optimize(
    state, jax.random.PRNGKey(0), train_nn
)
# params_hist: list of pytrees, one per iteration (each leaf has shape (n_parallel,))
# results_hist: list of arrays, one per iteration (each has shape (n_parallel,))

# Retrieve best result
print(optimizer.best_result(state))
print(optimizer.best_params(state))

Other available optimizers:

from hyperoptax import RandomSearch, GridSearch, DiscreteSpace

# Random search
state, optimizer = RandomSearch.init(search_space, n_parallel=8)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=50)

# Grid search (DiscreteSpace only)
# Note: shuffle=True
grid_space = {"lr": DiscreteSpace([1e-4, 1e-3, 1e-2]), "dropout": DiscreteSpace([0.1, 0.3, 0.5])}
state, optimizer = GridSearch.init(grid_space)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=9)

optimize_scan() — JAX-native loop

optimize_scan() has the same signature as optimize() but uses jax.lax.scan internally. This requires your objective function to be JAX-traceable (jit-compilable), and returns stacked arrays rather than Python lists:

state, (params_hist, results_hist) = optimizer.optimize_scan(
    state, jax.random.PRNGKey(0), train_nn, n_iterations=25
)
# params_hist: pytree where each leaf has shape (n_iterations, n_parallel, ...)
# results_hist: array of shape (n_iterations, n_parallel)

Return type difference: optimize() returns Python lists (easy to index by iteration), while optimize_scan() returns stacked JAX arrays (compatible with jax.jit, faster for JAX-traceable objectives). Choose based on your objective function and use case.
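In practice, hyperoptax's best_result/best_params helpers do this lookup for you, but the stacked output of optimize_scan() can also be indexed directly with jax.numpy. A minimal sketch with made-up shapes and values (the arrays below stand in for real optimize_scan() output):

```python
import jax.numpy as jnp

# Hypothetical stacked outputs, as if from optimize_scan with
# n_iterations=3 and n_parallel=2 (shapes are the point; values are made up).
results_hist = jnp.array([[0.9, 0.7],
                          [0.5, 0.8],
                          [0.6, 0.4]])
params_hist = {"learning_rate": jnp.array([[1e-1, 1e-2],
                                           [1e-3, 1e-4],
                                           [1e-5, 3e-3]])}

# Lower is better: locate the minimum across all (iteration, worker) pairs.
it, worker = jnp.unravel_index(jnp.argmin(results_hist), results_hist.shape)
best_loss = results_hist[it, worker]
best_params = {k: v[it, worker] for k, v in params_hist.items()}
```

Because every leaf shares the leading (n_iterations, n_parallel) axes, the same (it, worker) index recovers the matching hyperparameters from any leaf of the pytree.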

💪 Hyperoptax in action

BayesOpt animation

🔪 The Sharp Bits

Since we are working in pure JAX, the same sharp bits apply. Some consequences of this for Hyperoptax:

  1. Parameters that change the length of an evaluation (e.g. epochs, generations) can't be optimized in parallel.
  2. Neural network architectures can't be optimized in parallel either.
  3. Strings can't be used as hyperparameters.
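The first two restrictions follow from how jax.vmap traces a function once with batched abstract values: every parallel run must share the same shapes and the same loop lengths. A minimal illustration (plain JAX, not hyperoptax) of why a continuous hyperparameter vmaps fine while an epoch count would not:

```python
import jax
import jax.numpy as jnp

def objective(lr):
    # The loop length is a Python int, so it is static: every vmapped
    # run performs exactly 10 steps and all traces have identical shapes.
    w = 1.0
    for _ in range(10):
        w = w - lr * 2.0 * w   # gradient step on f(w) = w**2
    return w ** 2

lrs = jnp.array([0.01, 0.1, 0.3])
losses = jax.vmap(objective)(lrs)   # all three runs evaluated in parallel

# By contrast, an `epochs` hyperparameter would become a tracer inside vmap,
# and `range(tracer)` raises an error: loop lengths must be known at trace time.
```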

🫂 Contributing

We welcome pull requests! To get started:

  1. Open an issue describing the bug or feature.
  2. Fork the repository and create a feature branch (git checkout -b my-feature).
  3. Clone and install dependencies. We recommend uv for environment management:

     git clone https://github.com/TheodoreWolf/hyperoptax
     cd hyperoptax
     uv pip install -e ".[all]"

  4. Run the test suite:

     uv run pytest

  5. Ensure the notebooks still work.
  6. Format your code with ruff.
  7. Submit a pull request.

Roadmap

I'm developing this both as a passion project and as part of my PhD work. I have a few ideas on where to take this library:

  • Callbacks!
  • Reduce redundant kernel recomputation — currently the full K matrix is rebuilt each iteration when only the new row/column is needed.
  • Length scale tuning currently uses a fixed Adam step count; smarter convergence criteria could help.

📝 Citation

If you use Hyperoptax in academic work, please cite:

@misc{hyperoptax,
  author = {Theo Wolf},
  title = {{Hyperoptax}: Parallel hyperparameter tuning with JAX},
  year = {2025},
  url = {https://github.com/TheodoreWolf/hyperoptax}
}
