Skip to content

Reimplementing the code using JAX#3

Draft
htjb wants to merge 66 commits intomasterfrom
v2-jax
Draft

Reimplementing the code using JAX#3
htjb wants to merge 66 commits intomasterfrom
v2-jax

Conversation

@htjb
Copy link
Owner

@htjb htjb commented Nov 20, 2025

The goal is to reimplement maxsmooth in JAX to hopefully speed up the code and modernise the package with type hinting, more standardised docstrings and formatting with Ruff. An example call to the new code is shown below. Currently it tests every possible combination of signs on the derivatives.

"""Example use case for the qp solver."""

import time

import jax
import matplotlib.pyplot as plt
import tqdm
from jax import numpy as jnp

from maxsmooth.derivatives import make_derivative_functions
from maxsmooth.models import normalised_polynomial, normalised_polynomial_basis
from maxsmooth.qp import qp

jax.config.update("jax_enable_x64", True)

function = normalised_polynomial
basis_function = normalised_polynomial_basis

key = jax.random.PRNGKey(0)
x = jnp.linspace(50, 150, 100)
y = 5e6 * x ** (-2.5) + 0.01 * jax.random.normal(key, x.shape)
N = 10
pivot_point = len(x) // 2

qp = jax.jit(
    qp, static_argnames=("N", "pivot_point", "function", "basis_function")
)
start = time.time()
sol = qp(x, y, N, pivot_point, function, basis_function)
end = time.time()
print(f"First Call: QP solved in {end - start:.5f} seconds")

start = time.time()
sol = qp(x, y, N, pivot_point, function, basis_function)
end = time.time()
print(f"Second Call: QP solved in {end - start:.5f} seconds")

vmapped_function = jax.vmap(function, in_axes=(0, None, None, None))

objective_values = []
for i in tqdm.tqdm(range(len(sol["params"]))):
    params = sol["params"][i]
    # plt.plot(x, y, 'o', label='data')
    fit = vmapped_function(x, x[pivot_point], y[pivot_point], params)
    obj_val = jnp.sum((y - fit) ** 2)
    objective_values.append(obj_val)


plt.plot(objective_values, "o-")
plt.xlabel("Solution index")
plt.ylabel("Objective value")
plt.title("Objective values for different constraint sign combinations")
plt.show()

Key Changes:

  • Replacing CVXOPT with Jaxopt
  • Rewritten models in with Jax numpy and jit compiled
  • Using Jax grad and Jax jacobian to get the gradients and prefactors on the model parameters for the gradient e.g. the $G$ in $Ga \leq h$
  • Vectorised call to fit different sign spaces
  • Added pyproject.toml and uv files.
  • Added pre-commit-config.yaml
  • Added docstrings, type hinting and linting
  • Added a linting workflow and a workflow to check version number
  • Added a PR template
  • Added Contribution guidelines

To do:

  • Update the documentation
  • Re-write tests
  • Re-write the best_basis class as a function
  • Examples need rewriting
  • implement some form of the sign exploration/flipping

htjb added 30 commits November 4, 2025 12:15
@htjb
Copy link
Owner Author

htjb commented Jan 14, 2026

Jaxopt is deprecated so I'm not sure it is a good idea to rely on it. There is another code called qpax which might be a viable alternative.

https://github.com/kevin-tracy/qpax

@htjb
Copy link
Owner Author

htjb commented Jan 14, 2026

MPAX also looks good (https://github.com/MIT-Lu-Lab/MPAX).

After some initial testing, qpax seems to struggle to converge.

htjb and others added 26 commits January 16, 2026 10:31
…iggest issue is the interior point solver in v1 vs admm in v2
… to module level

- Replace jaxopt.OSQP (ADMM) with qpax primal-dual interior point solver;
  warm-call time drops from ~1400ms to ~2ms for N=6 on CPU
- Cache derivative_prefactors() results keyed by (fn, x, norm_x, norm_y, N)
  so the autodiff chain is only computed once per unique dataset/model
- Lift _dcf and _flip_sign to module level so JAX's JIT cache persists
  across calls (previously recompiled on every qp() invocation)
- Switch qpsignsearch to use chi-squared as comparison metric instead of
  OSQP internal error; simplify state tuple (remove c, Q, status)
- Expand benchmark to cover all four variants (v1-qp, v1-signflip,
  v2-qp, v2-signsearch) with residuals comparison panel

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…benchmark

- _dcf now returns (params, converged) so callers can inspect solver status
- qp() and qpsignsearch() return (params, chi2, converged: bool)
- qpsignsearch converged=True only if all QP solves along the path converged
- tests updated to unpack 3-tuple; polynomial coefficient test checks fit
  quality instead of exact params (float32 + constraints prevent exact recovery)
- constraint sign test uses N=6 (N=8 hits qpax's hardcoded 30-iter limit)
- benchmark adds 'qp conv?' column showing convergence for the winning solve

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant