diff --git a/diffrax/_solver/align.py b/diffrax/_solver/align.py index 806a8a8f..750469b6 100644 --- a/diffrax/_solver/align.py +++ b/diffrax/_solver/align.py @@ -15,6 +15,7 @@ UnderdampedLangevinTuple, UnderdampedLangevinX, ) +from .base import AbstractAdaptiveSolver from .foster_langevin_srk import ( AbstractCoeffs, AbstractFosterLangevinSRK, @@ -44,7 +45,9 @@ def __init__(self, beta, a1, b1, aa, chh): _ErrorEstimate = UnderdampedLangevinTuple -class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]): +class ALIGN( + AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate], AbstractAdaptiveSolver +): r"""The Adaptive Langevin via Interpolated Gradients and Noise method designed by James Foster. This is a second order solver for the Underdamped Langevin Diffusion, and accepts terms of the form diff --git a/diffrax/_solver/spark.py b/diffrax/_solver/spark.py index bfaa0323..dda948b9 100644 --- a/diffrax/_solver/spark.py +++ b/diffrax/_solver/spark.py @@ -3,7 +3,7 @@ import equinox.internal as eqxi import numpy as np -from .base import AbstractStratonovichSolver +from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver from .srk import AbstractSRK, GeneralCoeffs, StochasticButcherTableau @@ -35,7 +35,7 @@ ) -class SPaRK(AbstractSRK, AbstractStratonovichSolver): +class SPaRK(AbstractSRK, AbstractStratonovichSolver, AbstractAdaptiveSolver): r"""The Splitting Path Runge-Kutta method. It uses three evaluations of the drift and diffusion per step, and has the following