Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np

from tinygp.helpers import JAXArray
from tinygp.kernels.base import Kernel
from tinygp.solvers.quasisep.block import Block
from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM
from tinygp.solvers.quasisep.general import GeneralQSM

Expand Down Expand Up @@ -96,7 +96,7 @@ def to_symm_qsm(self, X: JAXArray) -> SymmQSM:
q = h
p = h @ Pinf
d = jnp.sum(p * q, axis=1)
p = jax.vmap(jnp.dot)(p, a)
p = jax.vmap(lambda x, y: x @ y)(p, a)
return SymmQSM(diag=DiagQSM(d=d), lower=StrictLowerTriQSM(p=p, q=q, a=a))

def to_general_qsm(self, X1: JAXArray, X2: JAXArray) -> GeneralQSM:
Expand All @@ -117,11 +117,11 @@ def to_general_qsm(self, X1: JAXArray, X2: JAXArray) -> GeneralQSM:

i = jnp.clip(idx, 0, ql.shape[0] - 1)
Xi = jax.tree_util.tree_map(lambda x: jnp.asarray(x)[i], X2)
pl = jax.vmap(jnp.dot)(pl, jax.vmap(self.transition_matrix)(Xi, X1))
pl = jax.vmap(lambda x, y: x @ y)(pl, jax.vmap(self.transition_matrix)(Xi, X1))

i = jnp.clip(idx + 1, 0, pu.shape[0] - 1)
Xi = jax.tree_util.tree_map(lambda x: jnp.asarray(x)[i], X2)
qu = jax.vmap(jnp.dot)(jax.vmap(self.transition_matrix)(X1, Xi), qu)
qu = jax.vmap(lambda x, y: x @ y)(jax.vmap(self.transition_matrix)(X1, Xi), qu)

return GeneralQSM(pl=pl, ql=ql, pu=pu, qu=qu, a=a, idx=idx)

Expand Down Expand Up @@ -230,12 +230,10 @@ def coord_to_sortable(self, X: JAXArray) -> JAXArray:
return self.kernel1.coord_to_sortable(X)

def design_matrix(self) -> JAXArray:
return jsp.linalg.block_diag(
self.kernel1.design_matrix(), self.kernel2.design_matrix()
)
return Block(self.kernel1.design_matrix(), self.kernel2.design_matrix())

def stationary_covariance(self) -> JAXArray:
return jsp.linalg.block_diag(
return Block(
self.kernel1.stationary_covariance(),
self.kernel2.stationary_covariance(),
)
Expand All @@ -249,7 +247,7 @@ def observation_model(self, X: JAXArray) -> JAXArray:
)

def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
return jsp.linalg.block_diag(
return Block(
self.kernel1.transition_matrix(X1, X2),
self.kernel2.transition_matrix(X1, X2),
)
Expand Down
2 changes: 2 additions & 0 deletions src/tinygp/solvers/kalman.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def step(carry, data): # type: ignore
return Pk, (sk, Kk)

init = Pinf
if hasattr(init, "to_dense"):
init = init.to_dense()
return jax.lax.scan(step, init, (A, H, diag))[1]


Expand Down
120 changes: 120 additions & 0 deletions src/tinygp/solvers/quasisep/block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.linalg import block_diag

from tinygp.helpers import JAXArray


class Block(eqx.Module):
    """A lazy block-diagonal matrix built from a tuple of sub-blocks.

    Arithmetic (`*`, `+`, `-`, `@`) is applied block-wise where possible so
    the dense block-diagonal matrix never needs to be materialized; mixed
    operations with non-``Block`` operands fall back to :meth:`to_dense`.
    Being an ``eqx.Module``, instances are frozen pytrees, which is why the
    operator methods can be decorated with ``jax.jit`` directly.
    """

    # The diagonal sub-blocks, in order. Each is array-like; most methods
    # assume every block is a square 2-D matrix (enforced by asserts below).
    blocks: tuple[Any, ...]
    # High priority so that `ndarray @ Block` dispatches to our __rmatmul__
    # instead of NumPy's own __matmul__.
    __array_priority__ = 1999

    def __init__(self, *blocks: Any):
        self.blocks = blocks

    def __getitem__(self, idx: Any) -> "Block":
        # Index every block with the same index; presumably used to slice a
        # batch dimension off stacked blocks — TODO(review): confirm intent.
        return Block(*(b[idx] for b in self.blocks))

    def __len__(self) -> int:
        # Total number of rows (== columns) of the dense block-diagonal
        # matrix. Only meaningful when every block is a 2-D matrix.
        assert all(np.ndim(b) == 2 for b in self.blocks)
        return sum(len(b) for b in self.blocks)

    @property
    def ndim(self) -> int:
        # All blocks must share the same number of dimensions; the
        # single-element set unpacking enforces that.
        (ndim,) = {np.ndim(b) for b in self.blocks}
        return ndim

    @property
    def shape(self) -> tuple[int, int]:
        # The dense equivalent is square: (total size, total size).
        size = len(self)
        return (size, size)

    def transpose(self) -> "Block":
        # Transposing a block-diagonal matrix transposes each block in place
        # on the diagonal.
        return Block(*(b.transpose() for b in self.blocks))

    @property
    def T(self) -> "Block":
        return self.transpose()

    def to_dense(self) -> JAXArray:
        """Materialize the full dense block-diagonal matrix."""
        assert all(np.ndim(b) == 2 for b in self.blocks)
        return block_diag(*self.blocks)

    @jax.jit
    def __mul__(self, other: Any) -> "Block":
        # Scalar (or broadcastable) multiplication distributes over blocks.
        return Block(*(b * other for b in self.blocks))

    @jax.jit
    def __add__(self, other: Any) -> Any:
        if isinstance(other, Block):
            # Block + Block stays lazy, but only when the block structures
            # line up exactly (same count and per-block shapes).
            assert len(self) == len(other)
            assert all(
                np.shape(b1) == np.shape(b2)
                for b1, b2 in zip(self.blocks, other.blocks)
            )
            return Block(*(b1 + b2 for b1, b2 in zip(self.blocks, other.blocks)))
        else:
            # TODO(dfm): This could be optimized to avoid converting to dense.
            return self.to_dense() + other

    def __radd__(self, other: Any) -> Any:
        # TODO(dfm): This could be optimized to avoid converting to dense.
        return other + self.to_dense()

    @jax.jit
    def __sub__(self, other: Any) -> Any:
        if isinstance(other, Block):
            # Mirrors __add__: subtraction stays block-wise when structures
            # match exactly.
            assert len(self) == len(other)
            assert all(
                np.shape(b1) == np.shape(b2)
                for b1, b2 in zip(self.blocks, other.blocks)
            )
            return Block(*(b1 - b2 for b1, b2 in zip(self.blocks, other.blocks)))
        else:
            # TODO(dfm): This could be optimized to avoid converting to dense.
            return self.to_dense() - other

    def __rsub__(self, other: Any) -> Any:
        # TODO(dfm): This could be optimized to avoid converting to dense.
        return other - self.to_dense()

    @jax.jit
    def __matmul__(self, other: Any) -> Any:
        if isinstance(other, Block):
            # Block @ Block multiplies matching blocks pairwise. NOTE(review):
            # the equal-shape assert is stricter than matmul needs; it is safe
            # here because the blocks in use are square.
            assert len(self.blocks) == len(other.blocks)
            assert all(
                np.shape(b1) == np.shape(b2)
                for b1, b2 in zip(self.blocks, other.blocks)
            )
            return Block(*(b1 @ b2 for b1, b2 in zip(self.blocks, other.blocks)))
        # Dense right operand: multiply each block against its slice of the
        # rows of `other` and stitch the results back together, avoiding the
        # dense block-diagonal product entirely.
        assert all(np.ndim(b) == 2 for b in self.blocks)
        ndim = np.ndim(other)
        assert ndim >= 1
        idx = 0
        ys = []
        for b in self.blocks:
            size = len(b)
            # For a vector, slice the only axis; otherwise slice the
            # second-to-last (row) axis so batch dimensions broadcast.
            x = (
                other[idx : idx + size]
                if ndim == 1
                else other[..., idx : idx + size, :]
            )
            ys.append(b @ x)
            idx += size
        return jnp.concatenate(ys) if ndim == 1 else jnp.concatenate(ys, axis=-2)

    @jax.jit
    def __rmatmul__(self, other: Any) -> Any:
        # Dense left operand: slice the columns of `other` (its last axis)
        # per block and multiply from the left, then concatenate columns.
        assert all(np.ndim(b) == 2 for b in self.blocks)
        idx = 0
        ys = []
        for b in self.blocks:
            size = len(b)
            x = other[..., idx : idx + size]
            ys.append(x @ b)
            idx += size
        return jnp.concatenate(ys, axis=-1)
11 changes: 6 additions & 5 deletions src/tinygp/solvers/quasisep/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,17 @@ def forward(f, data): # type: ignore
lower = jax.vmap(jnp.dot)(jnp.where(mask[:, None], self.pl, 0), f[idx])

# Then a backward pass to apply the "upper" matrix
def backward(f, data): # type: ignore
def backward(carry, data): # type: ignore
f, ap = carry
p, a, x = data
fn = a.T @ f + jnp.outer(p, x)
return fn, fn
fn = ap.T @ f + jnp.outer(p, x)
return (fn, a), fn

init = jnp.zeros_like(jnp.outer(self.pu[-1], x[-1]))
init = jnp.zeros_like(jnp.outer(self.pu[-1], x[-1])), self.a[0]
_, f = jax.lax.scan(
backward,
init,
(self.pu, jnp.roll(self.a, -1, axis=0), x),
(self.pu, self.a, x),
reverse=True,
)
idx = jnp.clip(self.idx + 1, 0, f.shape[0] - 1)
Expand Down
10 changes: 8 additions & 2 deletions tests/test_kernels/test_quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,14 @@ def test_quasisep_kernels(data, kernel):
# Test that F is defined consistently
x1 = x[0]
x2 = x[1]
num_A = jsp.linalg.expm(kernel.design_matrix().T * (x2 - x1))
assert_allclose(kernel.transition_matrix(x1, x2), num_A)
arg = kernel.design_matrix().T * (x2 - x1)
if hasattr(arg, "to_dense"):
arg = arg.to_dense()
expected = jsp.linalg.expm(arg)
calculated = kernel.transition_matrix(x1, x2)
if hasattr(calculated, "to_dense"):
calculated = calculated.to_dense()
assert_allclose(calculated, expected)


def test_quasisep_kernel_as_pytree(data, kernel):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_solvers/test_quasisep/test_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# mypy: ignore-errors

import jax.numpy as jnp
import numpy as np
import pytest
from scipy.linalg import block_diag

from tinygp.solvers.quasisep.block import Block
from tinygp.test_utils import assert_allclose


@pytest.fixture(params=[(10,), (10, 3), (4, 10, 3), (2, 4, 10, 3)])
def block_params(request):
    """Build a deterministic Block and a random operand of the requested shape."""
    rng = np.random.default_rng(1234)
    # Draw the square diagonal blocks first and the operand second so the RNG
    # stream — and therefore every fixture value — stays reproducible.
    pieces = [rng.uniform(size=(n, n)) for n in (1, 3, 2, 4)]
    operand = rng.uniform(size=request.param)
    return Block(*pieces), operand


def test_block(block_params):
    """Compare Block products and transposes against the dense equivalent."""
    block, x = block_params
    dense = block_diag(*block.blocks)
    # A 1-D operand needs no transpose; otherwise swap the two trailing axes.
    xt = jnp.swapaxes(x, -1, -2) if x.ndim > 1 else x
    assert len(block) == len(dense)
    assert block.shape == dense.shape
    checks = [
        (block @ x, dense @ x),
        (xt @ block, xt @ dense),
        (block.T @ x, dense.T @ x),
        (xt @ block.T, xt @ dense.T),
    ]
    for lazy_result, dense_result in checks:
        assert_allclose(lazy_result, dense_result)