diff --git a/src/score_models/sbm/score_model.py b/src/score_models/sbm/score_model.py
index b0742f8..60d3848 100644
--- a/src/score_models/sbm/score_model.py
+++ b/src/score_models/sbm/score_model.py
@@ -10,6 +10,7 @@
 from ..solver import Solver, ODESolver
 from ..utils import DEVICE
 from ..save_load_utils import load_hyperparameters
+
 
 if TYPE_CHECKING:
     from score_models import HessianDiagonal
@@ -22,7 +23,7 @@ def __new__(cls, *args, **kwargs):
         path = kwargs.get("path", None)
         if path is not None:
             try:
-                hyperparameters = load_hyperparameters(path)
+                hyperparameters = load_hyperparameters(path)
                 formulation = hyperparameters.get("formulation", "original")
             except FileNotFoundError:
                 # Freak case where a new model is created from scratch with a path (so no hyperparameters is present)
@@ -31,6 +32,7 @@
             formulation = kwargs.get("formulation", "original")
         if formulation.lower() == "edm":
             from score_models import EDMScoreModel
+
             return super().__new__(EDMScoreModel)
         else:
             return super().__new__(cls)
@@ -43,7 +45,7 @@ def __init__(
         self,
         checkpoint: Optional[int] = None,
         hessian_diagonal_model: Optional["HessianDiagonal"] = None,
         device=DEVICE,
-        **hyperparameters
+        **hyperparameters,
     ):
         super().__init__(net, sde, path, checkpoint=checkpoint, device=device, **hyperparameters)
         if hessian_diagonal_model is not None:
@@ -81,7 +83,7 @@
         steps: int,
         t: float = 0.0,
         solver: Literal["Euler", "Heun", "RK4"] = "Euler",
-        **kwargs
+        **kwargs,
     ) -> Tensor:
         """
         Compute the log likelihood of point x using the probability flow ODE,
@@ -96,7 +98,14 @@
         solver = ODESolver(self, solver=solver, **kwargs)
         # Solve the probability flow ODE up in temperature to time t=1.
         xT, dlogp = solver(
-            x, *args, steps=steps, forward=True, t_min=t, **kwargs, return_dlogp=True, dlogp=self.dlogp
+            x,
+            *args,
+            steps=steps,
+            forward=True,
+            t_min=t,
+            **kwargs,
+            return_dlogp=True,
+            dlogp=self.dlogp,
         )
         # add boundary condition PDF probability
         logp = self.sde.prior(D).log_prob(xT) + dlogp
@@ -109,11 +118,17 @@
         shape: tuple,  # TODO grab dimensions from model hyperparams if available
         steps: int,
         solver: Literal[
-            "EMSDESolver", "HeunSDESolver", "RK4SDESolver", "EulerODESolver", "HeunODESolver", "RK4ODESolver"
+            "EMSDESolver",
+            "HeunSDESolver",
+            "HeunSDESolverAdaptive",
+            "RK4SDESolver",
+            "EulerODESolver",
+            "HeunODESolver",
+            "RK4ODESolver",
         ] = "EMSDESolver",
         progress_bar: bool = True,
         denoise_last_step: bool = True,
-        **kwargs
+        **kwargs,
     ) -> Tensor:
         """
         Sample from the score model by solving the reverse-time SDE using the Euler-Maruyama method.
@@ -125,14 +140,7 @@
         B, *D = shape
         solver = Solver(self, solver=solver, **kwargs)
         xT = self.sde.prior(D).sample([B])
-        x0 = solver(
-            xT,
-            *args,
-            steps=steps,
-            forward=False,
-            progress_bar=progress_bar,
-            **kwargs
-        )
+        x0 = solver(xT, *args, steps=steps, forward=False, progress_bar=progress_bar, **kwargs)
         if denoise_last_step:
             t = self.sde.t_min * torch.ones(B, device=self.device)
             x0 = self.tweedie(t, x0, *args, **kwargs)
@@ -146,10 +154,15 @@
         *args,
         steps: int,
         solver: Literal[
-            "EMSDESolver", "HeunSDESolver", "RK4SDESolver", "EulerODESolver", "HeunODESolver", "RK4ODESolver"
+            "EMSDESolver",
+            "HeunSDESolver",
+            "RK4SDESolver",
+            "EulerODESolver",
+            "HeunODESolver",
+            "RK4ODESolver",
         ] = "EMSDESolver",
         progress_bar: bool = True,
-        **kwargs
+        **kwargs,
     ) -> Tensor:
         """
         Sample from the score model by solving the reverse-time SDE using the Euler-Maruyama method.
@@ -158,13 +171,7 @@ def denoise(
         """
         x0 = Solver(self, solver=solver, **kwargs)(
-            xt,
-            *args,
-            t_max=t,
-            steps=steps,
-            forward=False,
-            progress_bar=progress_bar,
-            **kwargs
+            xt, *args, t_max=t, steps=steps, forward=False, progress_bar=progress_bar, **kwargs
         )
         # Denoise last step with Tweedie
         t = self.sde.t_min * torch.ones(x0.shape[0], device=self.device)
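For reference, a minimal sketch of how the new solver option is reached through this sampling API. It is illustration, not part of the patch: `model` is assumed to be an already-trained ScoreModel, the shape is illustrative, and the keyword values mirror the ones used in tests/test_solvers.py.

    # Sketch: sampling through the new adaptive solver (assumes a trained `model`).
    samples = model.sample(
        shape=(64, 2),  # illustrative: a batch of 64 two-dimensional samples
        steps=None,  # with steps=None, the adaptive solver needs dt_init instead
        solver="HeunSDESolverAdaptive",
        dt_init=1e-2,  # initial step size, as in tests/test_solvers.py
        accuracy=1e-1,  # local error tolerance relative to the noise increment
    )

    # Log-likelihoods still go through the fixed-step probability flow ODE.
    logp = model.log_prob(samples, steps=50)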
diff --git a/src/score_models/solver/__init__.py b/src/score_models/solver/__init__.py
index 10b47ce..57f688b 100644
--- a/src/score_models/solver/__init__.py
+++ b/src/score_models/solver/__init__.py
@@ -1,3 +1,4 @@
 from .ode import *
 from .sde import *
+from .sdeadaptive import *
 from .solver import *
diff --git a/src/score_models/solver/sdeadaptive.py b/src/score_models/solver/sdeadaptive.py
new file mode 100644
index 0000000..a82ea52
--- /dev/null
+++ b/src/score_models/solver/sdeadaptive.py
@@ -0,0 +1,207 @@
+from typing import Callable, Optional
+from contextlib import nullcontext
+
+from torch import Tensor
+from tqdm import tqdm
+import torch
+import numpy as np
+
+from .solver import Solver
+
+__all__ = ["SDESolverAdaptive", "HeunSDESolverAdaptive"]
+
+
+class SDESolverAdaptive(Solver):
+
+    @torch.no_grad()
+    def solve(
+        self,
+        x: Tensor,
+        *args: tuple,
+        forward: bool,
+        dt_init: Optional[float] = None,
+        steps: Optional[int] = None,
+        accuracy: float = 1e-1,
+        progress_bar: bool = True,
+        trace: bool = False,
+        kill_on_nan: bool = False,
+        corrector_steps: int = 0,
+        corrector_snr: float = 0.1,
+        hook: Optional[Callable] = None,
+        **kwargs,
+    ):
+        """
+        Integrate the diffusion SDE forward or backward in time with an adaptive step size.
+
+        Discretizes the SDE using the given method and integrates with
+
+        .. math::
+            x_{i+1} = x_i + \\frac{dx}{dt}(t_i, x_i) * dt + g(t_i, x_i) * dw
+
+        where :math:`\\frac{dx}{dt}` is the diffusion drift
+
+        .. math::
+            \\frac{dx}{dt} = f(t, x) - g(t, x)^2 s(t, x)
+
+        where :math:`f(t, x)` is the sde drift, :math:`g(t, x)` is the sde diffusion,
+        and :math:`s(t, x)` is the score.
+
+        Args:
+            x: Initial condition.
+            *args: Additional arguments to pass to the score model.
+            forward: Direction of integration.
+            dt_init: Initial time step. Defaults to ``1 / steps`` when ``steps`` is provided.
+            steps: Alternative way to set the initial time step as ``dt_init = 1 / steps``.
+            accuracy: Tolerance on the local error estimate of a step, measured relative
+                to the magnitude of the noise increment. Steps above the tolerance are
+                retried with a smaller time step.
+            progress_bar: Whether to display a progress bar.
+            trace: Whether to return the full path or just the last point.
+            kill_on_nan: Whether to raise an error if NaNs are encountered.
+            corrector_steps: Number of corrector steps to add after each SDE step (0 for no corrector steps).
+            corrector_snr: Signal-to-noise ratio for the corrector steps.
+            hook: Optional hook function to call after each step. Will be called with the signature ``hook(t, x, sde, score, solver)``.
+        """
+        B, *D = x.shape
+
+        # Initial step size
+        if dt_init is None and steps is not None:
+            dt_init = 1.0 / steps
+        t_min = kwargs.get("t_min", self.sde.t_min)
+        t_max = kwargs.get("t_max", self.sde.t_max)
+        if forward:
+            dt = torch.tensor(dt_init, device=x.device, dtype=x.dtype)
+            T = [t_min]
+        else:
+            dt = -torch.tensor(dt_init, device=x.device, dtype=x.dtype)
+            T = [t_max]
+        dt = dt.reshape(1, *[1] * len(D)).repeat(B, *[1] * len(D))
+        user_max_dt = kwargs.pop("max_dt", 1.0)
+
+        if trace:
+            path = [x]
+
+        if progress_bar:
+            ptotal = int((t_max - t_min) / dt_init)  # estimate only, the step size adapts
+            _pbar = tqdm(total=ptotal)
+            pcurrent = 0
+        else:
+            _pbar = nullcontext()
+        with _pbar as pbar:
+            while (forward and T[-1] < t_max * 0.999999) or (
+                not forward and T[-1] > t_min * 1.000001
+            ):
+                t = torch.tensor(T[-1], device=x.device, dtype=x.dtype).repeat(B)
+                if forward:  # don't pass integration endpoint
+                    max_dt = torch.tensor(
+                        min(t_max - T[-1], user_max_dt),
+                        device=x.device,
+                        dtype=x.dtype,
+                    )
+                else:
+                    max_dt = -torch.tensor(
+                        min(T[-1] - t_min, user_max_dt),
+                        device=x.device,
+                        dtype=x.dtype,
+                    )
+                # dt_used is the step actually taken; dt is the proposal for the next step
+                dx, dt_used, dt = self.step(t, x, args, dt, forward, accuracy, max_dt, **kwargs)
+                x = x + dx
+                T.append(T[-1] + dt_used.flatten()[0].item())
+                for _ in range(corrector_steps):
+                    x = self.corrector_step(t, x, args, corrector_snr, **kwargs)
+
+                # Logs
+                if progress_bar:
+                    if forward:
+                        frac_progress = int((T[-1] - t_min) / (t_max - t_min) * ptotal)
+                    else:
+                        frac_progress = int((t_max - T[-1]) / (t_max - t_min) * ptotal)
+                    if frac_progress > pcurrent:
+                        pbar.update(frac_progress - pcurrent)
+                        pcurrent = frac_progress
+                    pbar.set_description(
+                        f"t={T[-1]:.1g} | sigma={self.sde.sigma(t[0]).item():.1g} | "
+                        f"x={x.mean().item():.1g}\u00b1{x.std().item():.1g}"
+                    )
+                if kill_on_nan and torch.any(torch.isnan(x)):
+                    raise ValueError("NaN encountered in SDE solver")
+                if trace:
+                    path.append(x)
+                if hook is not None:
+                    hook(t, x, self.sde, self.sbm.score, self)
+        if trace:
+            return torch.stack(path), T
+        return x
+
+    def corrector_step(self, t, x, args, snr, **kwargs):
+        """Basic Langevin corrector step for the SDE."""
+        _, *D = x.shape
+        z = torch.randn_like(x)
+        epsilon = (snr * self.sde.sigma(t).view(-1, *[1] * len(D))) ** 2
+        return x + epsilon * self.sbm.score(t, x, *args, **kwargs) + z * torch.sqrt(2 * epsilon)
+
+    def drift(self, t: Tensor, x: Tensor, args: tuple, forward: bool, **kwargs):
+        """SDE drift term"""
+        f = self.sde.drift(t, x)
+        if forward:
+            return f
+        g = self.sde.diffusion(t, x)
+        s = self.sbm.score(t, x, *args, **kwargs)
+        return f - g**2 * s
+
+    def dx(self, t, x, args, dt, forward, dw=None, **kwargs):
+        """SDE differential element dx"""
+        if dw is None:
+            dw = torch.randn_like(x) * torch.sqrt(dt.abs())
+        return self.drift(t, x, args, forward, **kwargs) * dt + self.sde.diffusion(t, x) * dw
+
+
+class HeunSDESolverAdaptive(SDESolverAdaptive):
+    """
+    Adaptive SDE solver using a 2nd order Runge-Kutta (Heun) method. For more
+    details see Equation 2.5 in chapter 7.2 of the book "Introduction to
+    Stochastic Differential Equations" by Thomas C. Gard.
+
+    This solver adopts the Stratonovich interpretation of the SDE,
+    though we note that the interpretation does not affect our package
+    because our diffusion coefficients are homogeneous, i.e. they do not depend on x.
+    The dependence of sde.diffusion on x is artificial in that it's only used
+    to infer the shape of the state space.
+    """
+
+    def step(self, t, x, args, dt, forward, accuracy, max_dt, **kwargs):
+        # Clamp the proposed step so the integration cannot pass the endpoint
+        dt = torch.minimum(dt, max_dt) if forward else torch.maximum(dt, max_dt)
+        if "dw" in kwargs:
+            dw = kwargs.pop("dw")
+        else:
+            dw = torch.randn_like(x) * torch.sqrt(dt.abs())
+
+        k1 = self.dx(t, x, args, dt, forward, dw, **kwargs)
+        k2 = self.dx(t + dt.squeeze(), x + k1, args, dt, forward, dw, **kwargs)
+        dx = (k1 + k2) / 2
+
+        # Check that the local error estimate is within the noise level
+        dw_norm = (
+            self.sde.diffusion(t, x).flatten()[0]
+            * torch.sqrt(dt.flatten()[0].abs())
+            * np.sqrt(x.numel())
+        )
+        if torch.linalg.norm(k1 - dx).item() / dw_norm > accuracy:
+            # Retry with a smaller step; rescale dw so its variance stays consistent with dt
+            down = kwargs.get("stepsize_down", 0.5)
+            return self.step(
+                t,
+                x,
+                args,
+                dt * down,
+                forward,
+                accuracy,
+                max_dt,
+                dw=dw * np.sqrt(down),
+                **kwargs,
+            )
+
+        # Accept: return the increment, the dt used, and a grown dt proposal for the next step
+        return dx, dt, dt * kwargs.get("stepsize_up", 1.4)
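The adaptive behaviour can be inspected with `trace=True`, which makes `solve` return the full path together with the list of visited times. A hedged sketch, assuming a trained score-based model `sbm` and that the subclass is constructed directly with the model, as the tests do with `EMSDESolver(None)`:

    # Sketch: inspect the accepted step sizes (assumes a trained score model `sbm`).
    import torch
    from score_models.solver import HeunSDESolverAdaptive

    solver = HeunSDESolverAdaptive(sbm)
    xT = sbm.sde.prior((2,)).sample([16])  # start from the prior at t = t_max
    path, T = solver.solve(xT, forward=False, dt_init=1e-2, accuracy=1e-1, trace=True)

    # Consecutive differences of T are the accepted steps: shrunk by
    # stepsize_down (0.5) on rejection, grown by stepsize_up (1.4) on
    # acceptance, and clamped near the integration endpoint.
    dts = torch.tensor(T).diff()
    print(dts.abs().min(), dts.abs().max(), len(dts))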
diff --git a/tests/test_solvers.py b/tests/test_solvers.py
index 852ab28..82ebe63 100644
--- a/tests/test_solvers.py
+++ b/tests/test_solvers.py
@@ -15,8 +15,20 @@ def test_solver_constructor():
     with pytest.raises(TypeError):  # abstract class cant be created
         Solver(None)
+    for solver in [
+        "EMSDESolver",
+        "HeunSDESolver",
+        "HeunSDESolverAdaptive",
+        "RK4SDESolver",
+        "EulerODESolver",
+        "HeunODESolver",
+        "RK4ODESolver",
+    ]:
+        assert isinstance(Solver(None, solver=solver), Solver), f"{solver} not created"
     assert isinstance(Solver(None, solver="EMSDESolver"), EMSDESolver), "EMSDESolver not created"
-    assert isinstance(Solver(None, solver="HeunODESolver"), HeunODESolver), "HeunODESolver not created"
+    assert isinstance(
+        Solver(None, solver="HeunODESolver"), HeunODESolver
+    ), "HeunODESolver not created"
     assert isinstance(EMSDESolver(None), Solver), "EMSDESolver not created"
     with pytest.raises(ValueError):  # unknown solver
         Solver(None, solver="random_solver")
@@ -32,7 +44,16 @@
     ),
 )
 @pytest.mark.parametrize(
-    "solver", ["EMSDESolver", "HeunSDESolver", "RK4SDESolver", "EulerODESolver", "HeunODESolver", "RK4ODESolver"]
+    "solver",
+    [
+        "EMSDESolver",
+        "HeunSDESolver",
+        "HeunSDESolverAdaptive",
+        "RK4SDESolver",
+        "EulerODESolver",
+        "HeunODESolver",
+        "RK4ODESolver",
+    ],
 )
 def test_solver_sample(solver, mean, cov):
     sde = VESDE(sigma_min=1e-2, sigma_max=10)
@@ -48,6 +69,7 @@ def test_solver_sample(solver, mean, cov):
         steps=50,
         solver=solver,
         kill_on_nan=True,
+        progress_bar=True,
     )
     assert torch.all(torch.isfinite(samples))
     assert torch.allclose(samples.mean(dim=0), mean, atol=1), "mean not close"
@@ -62,7 +84,16 @@
     ),
 )
 @pytest.mark.parametrize(
-    "solver", ["EMSDESolver", "HeunSDESolver", "RK4SDESolver", "EulerODESolver", "HeunODESolver", "RK4ODESolver"]
+    "solver",
+    [
+        "EMSDESolver",
+        "HeunSDESolver",
+        "HeunSDESolverAdaptive",
+        "RK4SDESolver",
+        "EulerODESolver",
+        "HeunODESolver",
+        "RK4ODESolver",
+    ],
 )
 def test_solver_forward(solver, mean, cov):
     sde = VESDE(sigma_min=1e-2, sigma_max=10)
@@ -91,7 +122,19 @@
         (None, torch.cat((torch.logspace(0, -2, 49), torch.zeros(1)))),  # 50 steps with log spacing
     ),
 )
-def test_solver_step(steps, time_steps):
+@pytest.mark.parametrize(
+    "solver",
+    [
+        "EMSDESolver",
+        "HeunSDESolver",
+        "HeunSDESolverAdaptive",
+        "RK4SDESolver",
+        "EulerODESolver",
+        "HeunODESolver",
+        "RK4ODESolver",
+    ],
+)
+def test_solver_step(steps, time_steps, solver):
     sde = VESDE(sigma_min=1e-2, sigma_max=10)
     mean = torch.zeros(2, dtype=torch.float32)
     cov = torch.ones(2, dtype=torch.float32)
@@ -100,7 +143,13 @@ def test_solver_step(steps, time_steps):
         mean=mean,
         cov=cov,
     )
-    samples = model.sample(shape=(100, mean.shape[-1]), steps=steps, time_steps=time_steps)
+    if solver == "HeunSDESolverAdaptive" and steps is None:
+        kwargs = {"dt_init": 1e-2}
+    else:
+        kwargs = {}
+    samples = model.sample(
+        shape=(100, mean.shape[-1]), solver=solver, steps=steps, time_steps=time_steps, **kwargs
+    )
     assert torch.all(torch.isfinite(samples))
     assert torch.allclose(samples.mean(dim=0), mean, atol=1), "mean for MVG samples not close"
     assert torch.allclose(samples.std(dim=0), cov.sqrt(), atol=1), "std for MVG samples not close"
@@ -112,7 +161,10 @@
     (
         (50, None),  # 50 steps normally
         (None, torch.linspace(1, 0, 50)),  # 50 steps set by user
-        (None, torch.cat((torch.logspace(0, -2, 49), torch.zeros(1))),),  # 50 steps with log spacing
+        (
+            None,
+            torch.cat((torch.logspace(0, -2, 49), torch.zeros(1))),
+        ),  # 50 steps with log spacing
     ),
 )
 def test_solver_logprob(steps, time_steps):
@@ -129,7 +181,7 @@ def test_solver_logprob(steps, time_steps):
     x = torch.rand(100, 2, dtype=torch.float32)
     logp = model.log_prob(x, steps=steps, time_steps=time_steps)
     true_logp = torch.distributions.MultivariateNormal(mean, cov).log_prob(x)
-    
+
     assert torch.all(torch.isfinite(logp))
     assert torch.allclose(logp, true_logp, atol=1e-3), "logp for MVG samples not close"
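The Langevin corrector loop in `SDESolverAdaptive.solve` is not exercised by these tests. A sketch of enabling it through the same sampling entry point, assuming a trained `model`; `corrector_steps` and `corrector_snr` are forwarded through `**kwargs` to `solve`:

    # Sketch: corrector steps after each SDE step (assumes a trained `model`).
    # Each corrector step is x <- x + eps * score + sqrt(2 * eps) * z with
    # eps = (corrector_snr * sigma(t))**2, as in SDESolverAdaptive.corrector_step.
    samples = model.sample(
        shape=(64, 2),
        steps=None,
        solver="HeunSDESolverAdaptive",
        dt_init=1e-2,
        corrector_steps=2,  # 0 (the default) disables the corrector
        corrector_snr=0.1,
    )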