101 changes: 74 additions & 27 deletions README.md
@@ -6,12 +6,12 @@
![CI status](https://github.com/TheodoreWolf/hyperoptax/actions/workflows/test.yml/badge.svg?branch=main)
[![codecov](https://codecov.io/gh/TheodoreWolf/hyperoptax/graph/badge.svg?token=Y582MZ25GG)](https://codecov.io/gh/TheodoreWolf/hyperoptax)

>[!WARNING]
> Hyperoptax is still a WIP and the API is subject to change. There are _many_ rough edges to smooth out. If you want to use it in a large-scale project, pin a specific version or install from source.

## ⛰️ Introduction

Hyperoptax is a lightweight toolbox for parallel hyperparameter optimization of pure JAX functions. It provides a concise API that lets you wrap any JAX-compatible loss or evaluation function and search across spaces __in parallel__ – all while staying in pure JAX.

## 🏗️ Installation

@@ -32,26 +32,76 @@
pip install --upgrade "jax[cpu]"
# or GPU/TPU – see the official JAX installation guide
```
## 🥜 In a nutshell
Hyperoptax offers a simple API for wrapping pure JAX functions for hyperparameter search while taking advantage of parallelization (currently via `vmap` only). See the [notebooks](https://github.com/TheodoreWolf/hyperoptax/tree/main/notebooks) for more examples.
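The `vmap`-based parallelism mentioned above can be pictured in plain JAX: one objective function is mapped over a batch of candidate hyperparameters in a single call. This is a minimal standalone illustration, independent of the hyperoptax API; the toy `objective` below is hypothetical.

```python
import jax
import jax.numpy as jnp

def objective(key, lr):
    # Toy stand-in for a training run: pretend lr = 1e-2 is optimal.
    return (jnp.log10(lr) + 2.0) ** 2

# Evaluate 8 candidate learning rates in parallel with one vmap call.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
lrs = jnp.logspace(-5, -1, 8)
losses = jax.vmap(objective)(keys, lrs)
best_lr = lrs[jnp.argmin(losses)]
```

Because the batch is expressed with `vmap`, XLA can fuse and parallelize the evaluations on whatever backend is available.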

All optimizers follow the same stateless pattern: `Optimizer.init` returns a `(state, optimizer)` pair, and `optimizer.optimize` runs the search loop. Your objective function must have the signature `fn(key, params) -> scalar`.

```python
import jax
from hyperoptax import BayesianSearch, LogSpace, LinearSpace

@jax.jit
def train_nn(key, params):
    learning_rate = params["learning_rate"]
    final_lr_pct = params["final_lr_pct"]
    ...
    return val_loss  # scalar, lower is better

search_space = {
    "learning_rate": LogSpace(1e-5, 1e-1),
    "final_lr_pct": LinearSpace(0.01, 0.5),
}

state, optimizer = BayesianSearch.init(
    search_space,
    n_max=100,      # observation buffer size (= number of iterations)
    n_parallel=4,   # parallel workers per step
    maximize=False,
)

state, (params_hist, results_hist) = optimizer.optimize(
    state, jax.random.PRNGKey(0), train_nn
)
# params_hist: list of pytrees, one per iteration (each leaf has shape (n_parallel,))
# results_hist: list of arrays, one per iteration (each has shape (n_parallel,))

# Retrieve best result
print(optimizer.best_result(state))
print(optimizer.best_params(state))
```
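Bayesian optimization picks each new batch of candidates by maximizing an acquisition function over a Gaussian-process posterior. As a generic illustration of the idea (not hyperoptax's actual internals), expected improvement for a minimization problem fits in a few lines of JAX:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def expected_improvement(mu, sigma, best_so_far):
    """EI for minimization: expected amount by which each candidate
    improves on the incumbent, given the GP posterior mean/std."""
    sigma = jnp.maximum(sigma, 1e-9)  # guard against division by zero
    z = (best_so_far - mu) / sigma
    return (best_so_far - mu) * norm.cdf(z) + sigma * norm.pdf(z)

mu = jnp.array([0.5, 0.2, 0.9])      # posterior means at 3 candidates
sigma = jnp.array([0.1, 0.3, 0.05])  # posterior stds at 3 candidates
ei = expected_improvement(mu, sigma, best_so_far=0.4)
next_idx = jnp.argmax(ei)            # candidate to evaluate next
```

Candidates with a low posterior mean or high posterior uncertainty score highest, which is the exploration/exploitation trade-off the optimizer balances each iteration.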

Other available optimizers:

```python
from hyperoptax import RandomSearch, GridSearch, DiscreteSpace

# Random search
state, optimizer = RandomSearch.init(search_space, n_parallel=8)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=50)

# Grid search (DiscreteSpace only)
# Note: shuffle=True
grid_space = {"lr": DiscreteSpace([1e-4, 1e-3, 1e-2]), "dropout": DiscreteSpace([0.1, 0.3, 0.5])}
state, optimizer = GridSearch.init(grid_space)
state, history = optimizer.optimize(state, jax.random.PRNGKey(0), train_nn, n_iterations=9)
```

### `optimize_scan()` — JAX-native loop

`optimize_scan()` has the same signature as `optimize()` but uses `jax.lax.scan` internally.
This requires your objective function to be JAX-traceable (jit-compilable), and returns
**stacked arrays** rather than Python lists:

```python
state, (params_hist, results_hist) = optimizer.optimize_scan(
state, jax.random.PRNGKey(0), train_nn, n_iterations=25
)
# params_hist: pytree where each leaf has shape (n_iterations, n_parallel, ...)
# results_hist: array of shape (n_iterations, n_parallel)
```

> **Return type difference:** `optimize()` returns Python lists (easy to index by iteration),
> while `optimize_scan()` returns stacked JAX arrays (compatible with `jax.jit`, faster for
> JAX-traceable objectives). Choose based on your objective function and use case.
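The stacking behavior that gives `optimize_scan()` its `(n_iterations, n_parallel)` outputs is a property of `jax.lax.scan` itself, which collects each step's output along a new leading axis. A generic JAX sketch (not hyperoptax code):

```python
import jax
import jax.numpy as jnp

def step(carry, key):
    # One "iteration": evaluate a batch of 4 random candidates.
    results = jax.random.uniform(key, (4,))
    return carry + 1, results  # carry acts as an iteration counter

keys = jax.random.split(jax.random.PRNGKey(0), 25)  # 25 iterations
n_done, results_hist = jax.lax.scan(step, 0, keys)
# results_hist has shape (25, 4): per-step outputs stacked on axis 0
```

Because the whole loop is a single traced computation, it compiles once under `jit` — which is also why the objective must be JAX-traceable.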

## 💪 Hyperoptax in action
<img src="./assets/gp_animation.gif" alt="BayesOpt animation" style="width:80%;"/>

@@ -68,31 +118,28 @@
We welcome pull requests! To get started:

1. Open an issue describing the bug or feature.
2. Fork the repository and create a feature branch (`git checkout -b my-feature`).
3. Clone and install dependencies. We recommend [uv](https://docs.astral.sh/uv/) for environment management:

```bash
git clone https://github.com/TheodoreWolf/hyperoptax
cd hyperoptax
uv pip install -e ".[all]"
```

4. Run the test suite:

```bash
uv run pytest
```
5. Ensure the notebooks still work.
6. Format your code with `ruff`.
7. Submit a pull request.

## Roadmap
I'm developing this both as a passion project and for my work in my PhD. I have a few ideas on where to go with this library:
- Callbacks!
- Inspired by wandb's sweeps, use a linear grid for all parameters and apply transformations at sample time.
- Reduce redundant kernel recomputation — currently the full K matrix is rebuilt each iteration when only the new row/column is needed.
- Length scale tuning currently uses a fixed Adam step count; smarter convergence criteria could help.

## 📝 Citation

@@ -105,4 +152,4 @@
If you use Hyperoptax in academic work, please cite:
year = {2025},
url = {https://github.com/TheodoreWolf/hyperoptax}
}
```
10 changes: 9 additions & 1 deletion docs/source/api/optimizers.rst
@@ -25,4 +25,12 @@ Grid Search
.. automodule:: hyperoptax.grid
:members:
:undoc-members:
:show-inheritance:

Random Search
-------------

.. automodule:: hyperoptax.random
:members:
:undoc-members:
:show-inheritance:
30 changes: 19 additions & 11 deletions docs/source/api/spaces.rst
@@ -11,22 +11,30 @@ Parameter spaces define the search domains for hyperparameter optimization.
Examples
--------

Creating a Linear Space
~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

    from hyperoptax import LinearSpace

    dropout_space = LinearSpace(0.0, 0.5)

Creating a Logarithmic Space
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

    from hyperoptax import LogSpace

    lr_space = LogSpace(1e-5, 1e-1)

Creating a Discrete Space
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code-block:: python

    from hyperoptax import DiscreteSpace

    optimizer_space = DiscreteSpace(["adam", "sgd", "rmsprop"])
    lr_space = DiscreteSpace([1e-4, 1e-3, 1e-2])
34 changes: 21 additions & 13 deletions docs/source/index.rst
@@ -21,21 +21,29 @@ Quick Start

.. code-block:: python

    import jax
    from hyperoptax import BayesianSearch, LogSpace, LinearSpace

    @jax.jit
    def train_nn(key, params):
        learning_rate = params["learning_rate"]
        final_lr_pct = params["final_lr_pct"]
        ...
        return val_loss  # scalar, lower is better

    search_space = {
        "learning_rate": LogSpace(1e-5, 1e-1),
        "final_lr_pct": LinearSpace(0.01, 0.5),
    }

    state, optimizer = BayesianSearch.init(
        search_space,
        n_max=100,
        maximize=False,
    )
    state, (params_hist, results_hist) = optimizer.optimize(
        state, jax.random.PRNGKey(0), train_nn
    )
    print(optimizer.best_params(state))

.. toctree::
:maxdepth: 2
Binary file modified notebooks/gp_animation.gif