> **Warning**
> Hyperoptax is still a work in progress and the API is subject to change; there are many rough edges left to smooth out. If you want to use it in a large-scale project, pin a specific version or install from source.
Hyperoptax is a lightweight toolbox for parallel hyperparameter optimization of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search across spaces in parallel – all while staying in pure JAX.
```bash
pip install hyperoptax
```

If you want to use the notebooks:

```bash
pip install "hyperoptax[notebooks]"
```

If you do not yet have JAX installed, pick the right wheel for your accelerator:

```bash
# CPU-only
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide
```

All optimizers follow the same stateless pattern: `Optimizer.init` returns a `(state, optimizer)` pair, and `optimizer.optimize` runs the search loop. Your objective function must have the signature `fn(key, params) -> scalar`.
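As a minimal, self-contained illustration of that signature, here is a toy quadratic objective (not part of the library; the names are illustrative) — anything with this shape can be handed to an optimizer:

```python
import jax

def quadratic_objective(key, params):
    # `params` is a pytree of values drawn from the search space;
    # `key` lets the objective be stochastic (e.g. a noisy evaluation).
    noise = 0.01 * jax.random.normal(key)
    return (params["x"] - 3.0) ** 2 + noise  # scalar, lower is better

loss = quadratic_objective(jax.random.PRNGKey(0), {"x": 2.0})
print(float(loss))  # roughly 1.0, up to the injected noise
```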
```python
import jax

from hyperoptax import BayesianSearch, LogSpace, LinearSpace


def train_nn(key, params):
    learning_rate = params["learning_rate"]
    final_lr_pct = params["final_lr_pct"]
    ...
    return val_loss  # scalar, lower is better


search_space = {
    "learning_rate": LogSpace(1e-5, 1e-1),
    "final_lr_pct": LinearSpace(0.01, 0.5),
}

state, optimizer = BayesianSearch.init(
    search_space,
    n_max=100,     # observation buffer size (= number of iterations)
    n_parallel=4,  # parallel workers per step
    maximize=False,
)

state, (params_hist, results_hist) = optimizer.optimize(
    state, jax.random.PRNGKey(0), train_nn
)
# params_hist: list of pytrees, one per iteration (each leaf has shape (n_parallel,))
# results_hist: list of arrays, one per iteration (each has shape (n_parallel,))

# Retrieve the best result
print(optimizer.best_result(state))
print(optimizer.best_params(state))
```

Other available optimizers:
```python
from hyperoptax import RandomSearch, GridSearch, DiscreteSpace

# Random search
state, optimizer = RandomSearch.init(search_space, n_parallel=8)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=50)

# Grid search (DiscreteSpace only)
# Note: shuffle=True
grid_space = {"lr": DiscreteSpace([1e-4, 1e-3, 1e-2]), "dropout": DiscreteSpace([0.1, 0.3, 0.5])}
state, optimizer = GridSearch.init(grid_space)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=9)
```

`optimize_scan()` has the same signature as `optimize()` but uses `jax.lax.scan` internally. This requires your objective function to be JAX-traceable (jit-compilable), and it returns stacked arrays rather than Python lists:
```python
state, (params_hist, results_hist) = optimizer.optimize_scan(
    state, jax.random.PRNGKey(0), train_nn, n_iterations=25
)
# params_hist: pytree where each leaf has shape (n_iterations, n_parallel, ...)
# results_hist: array of shape (n_iterations, n_parallel)
```

Return type difference: `optimize()` returns Python lists (easy to index by iteration), while `optimize_scan()` returns stacked JAX arrays (compatible with `jax.jit`, and faster for JAX-traceable objectives). Choose based on your objective function and use case.
Since we are working in pure JAX, the same sharp bits apply. Some consequences of this for Hyperoptax:
- Parameters that change the length of an evaluation (e.g. number of epochs or generations) can't be optimized in parallel.
- Neural network architectures can't be optimized in parallel either.
- Strings can't be used as hyperparameters.
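To see why evaluation length must be fixed: parallel candidate evaluation boils down to `jax.vmap` over the objective, so every candidate must trace to the same computation. A sketch (not library code) of the working pattern, where a fixed step count is baked in and only the learning rate varies per candidate:

```python
import jax
import jax.numpy as jnp

N_STEPS = 100  # fixed for all candidates; a traced value could not set the loop length

def objective(key, params):
    # Minimize f(x) = x**2 by gradient descent with a candidate-specific learning rate.
    x0 = jax.random.normal(key)
    step = lambda i, x: x - params["lr"] * 2.0 * x
    x_final = jax.lax.fori_loop(0, N_STEPS, step, x0)
    return x_final ** 2  # scalar loss

keys = jax.random.split(jax.random.PRNGKey(0), 3)
candidates = {"lr": jnp.array([0.01, 0.1, 0.4])}
losses = jax.vmap(objective)(keys, candidates)
print(losses.shape)  # (3,)
```

Had the step count itself been a candidate parameter used as a Python loop bound, tracing under `vmap` would fail; it must stay a static constant.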
We welcome pull requests! To get started:
- Open an issue describing the bug or feature.
- Fork the repository and create a feature branch (`git checkout -b my-feature`).
- Clone and install dependencies. We recommend uv for environment management:

```bash
git clone https://github.com/TheodoreWolf/hyperoptax
cd hyperoptax
uv pip install -e ".[all]"
```

- Run the test suite:

```bash
uv run pytest
```

- Ensure the notebooks still work.
- Format your code with `ruff`.
- Submit a pull request.
I'm developing this both as a passion project and as part of my PhD work. I have a few ideas on where to take the library:
- Callbacks!
- Reduce redundant kernel recomputation: currently the full K matrix is rebuilt each iteration when only the new row and column are needed.
- Length-scale tuning currently uses a fixed number of Adam steps; a smarter convergence criterion could help.
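For the kernel point above, the idea is the standard incremental Gram-matrix update. A sketch with an RBF kernel (not the library's implementation, and ignoring JAX's preference for fixed-size buffers over growing arrays):

```python
import jax.numpy as jnp

def rbf(a, b, length_scale=1.0):
    # Squared-exponential kernel, broadcasts over array inputs.
    return jnp.exp(-0.5 * ((a - b) / length_scale) ** 2)

def full_kernel(xs):
    # Full rebuild: O(n^2) kernel evaluations every iteration.
    return rbf(xs[:, None], xs[None, :])

def extend_kernel(K, xs, x_new):
    # Incremental update: only the new row/column, O(n) evaluations.
    k_vec = rbf(xs, x_new)  # covariances between x_new and existing points
    top = jnp.concatenate([K, k_vec[:, None]], axis=1)
    bottom = jnp.concatenate([k_vec, rbf(x_new, x_new)[None]])[None, :]
    return jnp.concatenate([top, bottom], axis=0)

xs = jnp.array([0.0, 1.0, 2.0])
K = extend_kernel(full_kernel(xs), xs, jnp.array(3.0))
print(K.shape)  # (4, 4)
```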
If you use Hyperoptax in academic work, please cite:
```bibtex
@misc{hyperoptax,
  author = {Theo Wolf},
  title  = {{Hyperoptax}: Parallel hyperparameter tuning with JAX},
  year   = {2025},
  url    = {https://github.com/TheodoreWolf/hyperoptax}
}
```
