
Block transition matrices break some addition and multiplication of kernels #265

@markfortune

Description

Hi, great work on the package!

The block transition matrices implemented for quasiseparable kernels are a neat optimisation, but I've noticed that some operations don't seem to have been updated to handle them. I give a few minimal examples here; in practice they cause a lot of headaches for some multi-wavelength light-curve fitting I'm doing, and also for GP optimisations I'm implementing which work with quasiseparable matrices directly. I'm not sure in which version these specifically became an issue, but it's certainly an issue with the latest versions.

For this example I'm running:

  • tinygp 0.3.1
  • jax, jaxlib 0.9.2

import tinygp
import jax.numpy as jnp

N_t = 100
t = jnp.linspace(-10., 10., N_t)

# k_beat has Block transition matrices
k_beat = tinygp.kernels.quasisep.Cosine(1.) + tinygp.kernels.quasisep.Cosine(2.)
banded_term = tinygp.noise.Banded(diag=jnp.ones(N_t), off_diags=jnp.ones((N_t, 1)))

# adding two QSMs fails when at least one has Block transition matrices
# (here the kernel's QSM plus the banded noise QSM)
gp1 = tinygp.GaussianProcess(k_beat, t, noise=banded_term)  # breaks

k_prod = k_beat * tinygp.kernels.quasisep.Exp(1.)

# product of two kernels where at least one of them has Block transition matrices fails
gp2 = tinygp.GaussianProcess(k_prod, t, diag=jnp.ones(N_t))  # breaks
gp3 = tinygp.GaussianProcess(k_beat * k_beat, t, diag=jnp.ones(N_t))  # breaks

Crash 1: Adding QSMs with Block transition matrices

The first crash gives the error TypeError: Cannot determine dtype of Block(blocks=(f32[2,2], f32[2,2])), which corresponds to:

211 p1, q1, a1 = self
212 p2, q2, a2 = other
213 return StrictLowerTriQSM(
214     p=jnp.concatenate((p1, p2)),
215     q=jnp.concatenate((q1, q2)),
216     a=block_diag(a1, a2),  # <-- fails here
217 )
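For reference, a plain-JAX sketch (no tinygp) of what that failing line is meant to compute: block_diag of two dense 2x2 transition matrices yields a 4x4 block-diagonal matrix. A Block container holding f32[2,2] leaves is a pytree rather than an array, so jnp cannot infer a single dtype for it at this point:

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

# stand-ins for the two kernels' dense transition matrices
a1 = jnp.eye(2)
a2 = 2.0 * jnp.eye(2)

# dense block-diagonal result the addition rule expects to build
a = block_diag(a1, a2)
print(a.shape)  # (4, 4)
```

With dense inputs this works fine; it is only the Block pytree that trips up the dtype inference.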

Crash 2: Product kernel with Block transition matrices

The second crash gives the error TypeError: dot_general requires contracting dimensions to have the same shape, got (0,) and (4,). which corresponds to Quasisep.to_symm_qsm:

95 h = jax.vmap(self.observation_model)(X)
96 q = h
97 p = h @ Pinf  # <-- fails here
98 d = jnp.sum(p * q, axis=1)
99 p = jax.vmap(lambda x, y: x @ y)(p, a)
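The failure here is that `h @ Pinf` assumes `Pinf` is a dense (4, 4) array, while the sum of two 2-dimensional kernels stores it as a Block of two 2x2 blocks. A plain-JAX sketch of the shapes the code expects, with the stationary covariance densified via block_diag (a hypothetical workaround, not tinygp's API):

```python
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

N_t = 100
h = jnp.ones((N_t, 4))      # stand-in for the vmapped observation model rows

# stand-ins for the two kernels' 2x2 stationary covariance blocks
P1 = jnp.eye(2)
P2 = jnp.eye(2)
Pinf = block_diag(P1, P2)   # densified (4, 4) stationary covariance

p = h @ Pinf                # (100, 4) @ (4, 4) -- no shape mismatch
print(p.shape)              # (100, 4)
```

When `Pinf` is instead the Block pytree, the matmul sees mismatched contracting dimensions, giving the `(0,)` and `(4,)` error above.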

Crash 3: Product of two Sum kernels

The third crash also happens when building a symmetric QSM:

 89 def to_symm_qsm(self, X: JAXArray) -> SymmQSM:
 90     """The symmetric quasiseparable representation of this kernel"""
 91     Pinf = self.stationary_covariance()  # <-- enters Product.stationary_covariance
 92     a = jax.vmap(self.transition_matrix)(
 93         jax.tree_util.tree_map(lambda y: jnp.append(y[0], y[:-1]), X), X
 94     )
 95     h = jax.vmap(self.observation_model)(X)

which enters Product.stationary_covariance:

273 def stationary_covariance(self) -> JAXArray:
274     return _prod_helper(
275         self.kernel1.stationary_covariance(),
276         self.kernel2.stationary_covariance(),
277     )

and in turn _prod_helper, whose fancy indexing assumes plain arrays:

639     return a1[i] * a2[j]
640 elif a1.ndim == 2:
641     return a1[i[:, None], i[None, :]] * a2[j[:, None], j[None, :]]  # <-- fails here
642 else:
643     raise NotImplementedError
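To see what that indexing computes on dense inputs: with one plausible choice of the index vectors i and j (tinygp's actual construction may differ in ordering), the 2-d branch reproduces a Kronecker product, which is exactly the product-kernel combination rule. The indexing requires NumPy-style fancy indexing, which a Block pytree does not support:

```python
import jax.numpy as jnp

a1 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
a2 = jnp.array([[5.0, 6.0], [7.0, 8.0]])

n1, n2 = a1.shape[0], a2.shape[0]
idx = jnp.arange(n1 * n2)
i, j = idx // n2, idx % n2  # index pairs that reproduce the Kronecker product

# the same expression as the 2-d branch of _prod_helper
out = a1[i[:, None], i[None, :]] * a2[j[:, None], j[None, :]]
print(bool(jnp.allclose(out, jnp.kron(a1, a2))))  # True
```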

which ultimately hits Block.__mul__:

47 @jax.jit
48 def __mul__(self, other: Any) -> "Block":
49     return Block(*(b * other for b in self.blocks))
    # TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'Block'
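The operand order is the problem: the scalar side is a tracer whose __mul__ returns NotImplemented for the unrecognized Block type, and since Block defines no __rmul__, Python raises the "unsupported operand type(s)" error. A minimal standalone sketch (a stand-in class, not tinygp's actual Block) showing that adding __rmul__ makes scalar * Block work under jit:

```python
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class Block:
    """Minimal stand-in for a block-diagonal transition matrix."""

    def __init__(self, *blocks):
        self.blocks = tuple(blocks)

    def __mul__(self, other):
        return Block(*(b * other for b in self.blocks))

    # Without this, `tracer * Block` raises: the tracer's __mul__ returns
    # NotImplemented and Python has no reflected operator to fall back on.
    __rmul__ = __mul__

    def tree_flatten(self):
        return self.blocks, None

    @classmethod
    def tree_unflatten(cls, aux, blocks):
        return cls(*blocks)

@jax.jit
def scale(x, b):
    return (x * b).blocks[0]  # exercises __rmul__ under tracing

out = scale(2.0, Block(jnp.eye(2), jnp.ones((2, 2))))
print(out.shape)  # (2, 2)
```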

Ideally I would like an option to turn off the formation of block matrices (as suggested in PR #240). That said, the computational savings they can offer are useful, and I'd imagine there is no fundamental obstacle to updating the addition and multiplication rules of kernels to account for Block transition matrices, so that would be a nicer long-term fix.
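As a stopgap, one could imagine densifying any Block-like operand before the existing rules run. A hypothetical sketch, assuming only that a Block-like object exposes a `.blocks` tuple of square arrays (all names here are illustrative, not tinygp's real API); note this gives up the computational savings that the block structure exists to provide:

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.scipy.linalg import block_diag

class ToyBlock(NamedTuple):
    """Illustrative stand-in for a block-structured transition matrix."""
    blocks: tuple

def densify(a):
    # Recursively expand any nested block structure into a dense array.
    if hasattr(a, "blocks"):
        return block_diag(*(densify(b) for b in a.blocks))
    return jnp.asarray(a)

def block_diag_any(a1, a2):
    # A Block-tolerant version of the failing block_diag call.
    return block_diag(densify(a1), densify(a2))

beat = ToyBlock((jnp.eye(2), jnp.eye(2)))  # mimics k_beat's 4x4 transition
exp = jnp.eye(2)                           # mimics a plain 2x2 transition
combined = block_diag_any(beat, exp)
print(combined.shape)  # (6, 6)
```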

Thanks again for all the great work!
