From 70afbbccb2ec8a370aab051de0f2a14f65fa5a91 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Sun, 3 Aug 2025 17:54:38 -0400 Subject: [PATCH] adapt --- diffrax/_solver/align.py | 5 ++++- diffrax/_solver/spark.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/diffrax/_solver/align.py b/diffrax/_solver/align.py index 806a8a8f..750469b6 100644 --- a/diffrax/_solver/align.py +++ b/diffrax/_solver/align.py @@ -15,6 +15,7 @@ UnderdampedLangevinTuple, UnderdampedLangevinX, ) +from .base import AbstractAdaptiveSolver from .foster_langevin_srk import ( AbstractCoeffs, AbstractFosterLangevinSRK, @@ -44,7 +45,9 @@ def __init__(self, beta, a1, b1, aa, chh): _ErrorEstimate = UnderdampedLangevinTuple -class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]): +class ALIGN( + AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate], AbstractAdaptiveSolver +): r"""The Adaptive Langevin via Interpolated Gradients and Noise method designed by James Foster. This is a second order solver for the Underdamped Langevin Diffusion, and accepts terms of the form diff --git a/diffrax/_solver/spark.py b/diffrax/_solver/spark.py index bfaa0323..dda948b9 100644 --- a/diffrax/_solver/spark.py +++ b/diffrax/_solver/spark.py @@ -3,7 +3,7 @@ import equinox.internal as eqxi import numpy as np -from .base import AbstractStratonovichSolver +from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver from .srk import AbstractSRK, GeneralCoeffs, StochasticButcherTableau @@ -35,7 +35,7 @@ ) -class SPaRK(AbstractSRK, AbstractStratonovichSolver): +class SPaRK(AbstractSRK, AbstractStratonovichSolver, AbstractAdaptiveSolver): r"""The Splitting Path Runge-Kutta method. It uses three evaluations of the drift and diffusion per step, and has the following