From 6a9966b2f905ef5e097c33dc5d32e1634ea79c5c Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Wed, 2 Jul 2025 14:21:18 -0400 Subject: [PATCH 1/2] Exploit band-diagonal structure in quasisep transitions for better scaling. --- src/tinygp/kernels/quasisep.py | 16 +-- src/tinygp/solvers/kalman.py | 2 + src/tinygp/solvers/quasisep/block.py | 120 ++++++++++++++++++ src/tinygp/solvers/quasisep/general.py | 11 +- tests/test_kernels/test_quasisep.py | 10 +- .../test_solvers/test_quasisep/test_block.py | 31 +++++ 6 files changed, 174 insertions(+), 16 deletions(-) create mode 100644 src/tinygp/solvers/quasisep/block.py create mode 100644 tests/test_solvers/test_quasisep/test_block.py diff --git a/src/tinygp/kernels/quasisep.py b/src/tinygp/kernels/quasisep.py index 21eb095c..af232811 100644 --- a/src/tinygp/kernels/quasisep.py +++ b/src/tinygp/kernels/quasisep.py @@ -31,11 +31,11 @@ import equinox as eqx import jax import jax.numpy as jnp -import jax.scipy as jsp import numpy as np from tinygp.helpers import JAXArray from tinygp.kernels.base import Kernel +from tinygp.solvers.quasisep.block import Block from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM from tinygp.solvers.quasisep.general import GeneralQSM @@ -96,7 +96,7 @@ def to_symm_qsm(self, X: JAXArray) -> SymmQSM: q = h p = h @ Pinf d = jnp.sum(p * q, axis=1) - p = jax.vmap(jnp.dot)(p, a) + p = jax.vmap(lambda x, y: x @ y)(p, a) return SymmQSM(diag=DiagQSM(d=d), lower=StrictLowerTriQSM(p=p, q=q, a=a)) def to_general_qsm(self, X1: JAXArray, X2: JAXArray) -> GeneralQSM: @@ -117,11 +117,11 @@ def to_general_qsm(self, X1: JAXArray, X2: JAXArray) -> GeneralQSM: i = jnp.clip(idx, 0, ql.shape[0] - 1) Xi = jax.tree_util.tree_map(lambda x: jnp.asarray(x)[i], X2) - pl = jax.vmap(jnp.dot)(pl, jax.vmap(self.transition_matrix)(Xi, X1)) + pl = jax.vmap(lambda x, y: x @ y)(pl, jax.vmap(self.transition_matrix)(Xi, X1)) i = jnp.clip(idx + 1, 0, pu.shape[0] - 1) Xi = jax.tree_util.tree_map(lambda x: jnp.asarray(x)[i], X2) - qu = jax.vmap(jnp.dot)(jax.vmap(self.transition_matrix)(X1, Xi), qu) + qu = jax.vmap(lambda x, y: x @ y)(jax.vmap(self.transition_matrix)(X1, Xi), qu) return GeneralQSM(pl=pl, ql=ql, pu=pu, qu=qu, a=a, idx=idx) @@ -230,12 +230,10 @@ def coord_to_sortable(self, X: JAXArray) -> JAXArray: return self.kernel1.coord_to_sortable(X) def design_matrix(self) -> JAXArray: - return jsp.linalg.block_diag( - self.kernel1.design_matrix(), self.kernel2.design_matrix() - ) + return Block(self.kernel1.design_matrix(), self.kernel2.design_matrix()) def stationary_covariance(self) -> JAXArray: - return jsp.linalg.block_diag( + return Block( self.kernel1.stationary_covariance(), self.kernel2.stationary_covariance(), ) @@ -249,7 +247,7 @@ def observation_model(self, X: JAXArray) -> JAXArray: ) def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: - return jsp.linalg.block_diag( + return Block( self.kernel1.transition_matrix(X1, X2), self.kernel2.transition_matrix(X1, X2), ) diff --git a/src/tinygp/solvers/kalman.py b/src/tinygp/solvers/kalman.py index e4a45166..d84f143f 100644 --- a/src/tinygp/solvers/kalman.py +++ b/src/tinygp/solvers/kalman.py @@ -101,6 +101,8 @@ def step(carry, data): # type: ignore return Pk, (sk, Kk) init = Pinf + if hasattr(init, "to_dense"): + init = init.to_dense() return jax.lax.scan(step, init, (A, H, diag))[1] diff --git a/src/tinygp/solvers/quasisep/block.py b/src/tinygp/solvers/quasisep/block.py new file mode 100644 index 00000000..5a470625 --- /dev/null +++ b/src/tinygp/solvers/quasisep/block.py @@ -0,0 +1,120 @@ +from typing import Any + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +from jax.scipy.linalg import block_diag + +from tinygp.helpers import JAXArray + + +class Block(eqx.Module): + blocks: tuple[Any, ...] + __array_priority__ = 1999 + + def __init__(self, *blocks: Any): + self.blocks = blocks + + def __getitem__(self, idx: Any) -> "Block": + return Block(*(b[idx] for b in self.blocks)) + + def __len__(self) -> int: + assert all(np.ndim(b) == 2 for b in self.blocks) + return sum(len(b) for b in self.blocks) + + @property + def ndim(self) -> int: + ndim, = {np.ndim(b) for b in self.blocks} + return ndim + + @property + def shape(self) -> tuple[int, int]: + size = len(self) + return (size, size) + + def transpose(self) -> "Block": + return Block(*(b.transpose() for b in self.blocks)) + + @property + def T(self) -> "Block": + return self.transpose() + + def to_dense(self) -> JAXArray: + assert all(np.ndim(b) == 2 for b in self.blocks) + return block_diag(*self.blocks) + + @jax.jit + def __mul__(self, other: Any) -> "Block": + return Block(*(b * other for b in self.blocks)) + + @jax.jit + def __add__(self, other: Any) -> Any: + if isinstance(other, Block): + assert len(self) == len(other) + assert all( + np.shape(b1) == np.shape(b2) + for b1, b2 in zip(self.blocks, other.blocks) + ) + return Block(*(b1 + b2 for b1, b2 in zip(self.blocks, other.blocks))) + else: + # TODO(dfm): This could be optimized to avoid converting to dense. + return self.to_dense() + other + + def __radd__(self, other: Any) -> Any: + # TODO(dfm): This could be optimized to avoid converting to dense. + return other + self.to_dense() + + @jax.jit + def __sub__(self, other: Any) -> Any: + if isinstance(other, Block): + assert len(self) == len(other) + assert all( + np.shape(b1) == np.shape(b2) + for b1, b2 in zip(self.blocks, other.blocks) + ) + return Block(*(b1 - b2 for b1, b2 in zip(self.blocks, other.blocks))) + else: + # TODO(dfm): This could be optimized to avoid converting to dense. + return self.to_dense() - other + + def __rsub__(self, other: Any) -> Any: + # TODO(dfm): This could be optimized to avoid converting to dense. + return other - self.to_dense() + + @jax.jit + def __matmul__(self, other: Any) -> Any: + if isinstance(other, Block): + assert len(self.blocks) == len(other.blocks) + assert all( + np.shape(b1) == np.shape(b2) + for b1, b2 in zip(self.blocks, other.blocks) + ) + return Block(*(b1 @ b2 for b1, b2 in zip(self.blocks, other.blocks))) + assert all(np.ndim(b) == 2 for b in self.blocks) + ndim = np.ndim(other) + assert ndim >= 1 + idx = 0 + ys = [] + for b in self.blocks: + size = len(b) + x = ( + other[idx : idx + size] + if ndim == 1 + else other[..., idx : idx + size, :] + ) + ys.append(b @ x) + idx += size + return jnp.concatenate(ys) if ndim == 1 else jnp.concatenate(ys, axis=-2) + + @jax.jit + def __rmatmul__(self, other: Any) -> Any: + assert all(np.ndim(b) == 2 for b in self.blocks) + idx = 0 + ys = [] + for b in self.blocks: + size = len(b) + x = other[..., idx : idx + size] + ys.append(x @ b) + idx += size + return jnp.concatenate(ys, axis=-1) diff --git a/src/tinygp/solvers/quasisep/general.py b/src/tinygp/solvers/quasisep/general.py index 5a62dd5c..e2a0c4f6 100644 --- a/src/tinygp/solvers/quasisep/general.py +++ b/src/tinygp/solvers/quasisep/general.py @@ -85,16 +85,17 @@ def forward(f, data): # type: ignore lower = jax.vmap(jnp.dot)(jnp.where(mask[:, None], self.pl, 0), f[idx]) # Then a backward pass to apply the "upper" matrix - def backward(f, data): # type: ignore + def backward(carry, data): # type: ignore + f, ap = carry p, a, x = data - fn = a.T @ f + jnp.outer(p, x) - return fn, fn + fn = ap.T @ f + jnp.outer(p, x) + return (fn, a), fn - init = jnp.zeros_like(jnp.outer(self.pu[-1], x[-1])) + init = jnp.zeros_like(jnp.outer(self.pu[-1], x[-1])), self.a[0] _, f = jax.lax.scan( backward, init, - (self.pu, jnp.roll(self.a, -1, axis=0), x), + (self.pu, self.a, x), reverse=True, ) idx = jnp.clip(self.idx + 1, 0, f.shape[0] - 1) diff --git a/tests/test_kernels/test_quasisep.py b/tests/test_kernels/test_quasisep.py index 1a259c5a..c426224d 100644 --- a/tests/test_kernels/test_quasisep.py +++ b/tests/test_kernels/test_quasisep.py @@ -61,8 +61,14 @@ def test_quasisep_kernels(data, kernel): # Test that F is defined consistently x1 = x[0] x2 = x[1] - num_A = jsp.linalg.expm(kernel.design_matrix().T * (x2 - x1)) - assert_allclose(kernel.transition_matrix(x1, x2), num_A) + arg = kernel.design_matrix().T * (x2 - x1) + if hasattr(arg, "to_dense"): + arg = arg.to_dense() + expected = jsp.linalg.expm(arg) + calculated = kernel.transition_matrix(x1, x2) + if hasattr(calculated, "to_dense"): + calculated = calculated.to_dense() + assert_allclose(calculated, expected) def test_quasisep_kernel_as_pytree(data, kernel): diff --git a/tests/test_solvers/test_quasisep/test_block.py b/tests/test_solvers/test_quasisep/test_block.py new file mode 100644 index 00000000..be089494 --- /dev/null +++ b/tests/test_solvers/test_quasisep/test_block.py @@ -0,0 +1,31 @@ +# mypy: ignore-errors + +import jax.numpy as jnp +import numpy as np +import pytest +from scipy.linalg import block_diag + +from tinygp.solvers.quasisep.block import Block +from tinygp.test_utils import assert_allclose + + +@pytest.fixture(params=[(10,), (10, 3), (4, 10, 3), (2, 4, 10, 3)]) +def block_params(request): + shape = request.param + random = np.random.default_rng(1234) + block_sizes = [1, 3, 2, 4] + block = Block(*(random.uniform(size=(s, s)) for s in block_sizes)) + x = random.uniform(size=shape) + return block, x + + +def test_block(block_params): + block, x = block_params + xt = x if len(x.shape) == 1 else jnp.swapaxes(x, -1, -2) + block_ = block_diag(*block.blocks) + assert len(block) == len(block_) + assert block.shape == block_.shape + assert_allclose(block @ x, block_ @ x) + assert_allclose(xt @ block, xt @ block_) + assert_allclose(block.T @ x, block_.T @ x) + assert_allclose(xt @ block.T, xt @ block_.T) From d587a931a50fc7146ee0269b76aaddbffb8113b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Jul 2025 16:06:25 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/tinygp/solvers/quasisep/block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tinygp/solvers/quasisep/block.py b/src/tinygp/solvers/quasisep/block.py index 5a470625..f5736064 100644 --- a/src/tinygp/solvers/quasisep/block.py +++ b/src/tinygp/solvers/quasisep/block.py @@ -25,7 +25,7 @@ def __len__(self) -> int: @property def ndim(self) -> int: - ndim, = {np.ndim(b) for b in self.blocks} + (ndim,) = {np.ndim(b) for b in self.blocks} return ndim @property