Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion diffrax/_solver/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
UnderdampedLangevinTuple,
UnderdampedLangevinX,
)
from .base import AbstractAdaptiveSolver
from .foster_langevin_srk import (
AbstractCoeffs,
AbstractFosterLangevinSRK,
Expand Down Expand Up @@ -44,7 +45,9 @@ def __init__(self, beta, a1, b1, aa, chh):
_ErrorEstimate = UnderdampedLangevinTuple


class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
class ALIGN(
AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate], AbstractAdaptiveSolver
):
r"""The Adaptive Langevin via Interpolated Gradients and Noise method
designed by James Foster. This is a second order solver for the
Underdamped Langevin Diffusion, and accepts terms of the form
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_solver/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import equinox.internal as eqxi
import numpy as np

from .base import AbstractStratonovichSolver
from .base import AbstractAdaptiveSolver, AbstractStratonovichSolver
from .srk import AbstractSRK, GeneralCoeffs, StochasticButcherTableau


Expand Down Expand Up @@ -35,7 +35,7 @@
)


class SPaRK(AbstractSRK, AbstractStratonovichSolver):
class SPaRK(AbstractSRK, AbstractStratonovichSolver, AbstractAdaptiveSolver):
r"""The Splitting Path Runge-Kutta method.
It uses three evaluations of the drift and diffusion per step, and has the following
Expand Down
Loading