Skip to content

Commit 9123588

Browse files
authored
Merge pull request #200 from Jammy2211/feature/matern_jax
feature/matern_jax
2 parents b3f8832 + 6f33782 commit 9123588

File tree

3 files changed

+88
-75
lines changed

3 files changed

+88
-75
lines changed

autoarray/inversion/regularization/matern_kernel.py

Lines changed: 72 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
2+
import jax.numpy as jnp
3+
import jax.scipy.special as jsp
24
import numpy as np
3-
45
import math
5-
import scipy.special as sc
66
from typing import TYPE_CHECKING
77

88
if TYPE_CHECKING:
@@ -12,35 +12,68 @@
1212

1313
from autoarray import numba_util
1414

15+
import jax.numpy as jnp
16+
17+
18+
def kv_xp(v, z, xp=np):
19+
"""
20+
XP-compatible modified Bessel K_v(v, z).
21+
22+
NumPy backend:
23+
-> scipy.special.kv
1524
16-
@numba_util.jit(cache=False)
17-
def matern_kernel(r: float, l: float = 1.0, v: float = 0.5):
25+
JAX backend:
26+
-> jax.scipy.special.kv if available
27+
-> else tfp.substrates.jax.math.bessel_kve * exp(-|z|)
1828
"""
19-
need to `pip install numba-scipy `
20-
see https://gaussianprocess.org/gpml/chapters/RW4.pdf for more info
2129

22-
the distance r need to be scalar
23-
l is the scale
24-
v is the order, better < 30, otherwise may have numerical NaN issue.
30+
# -------------------------
31+
# NumPy backend
32+
# -------------------------
33+
if xp is np:
34+
import scipy.special as sc
35+
36+
return sc.kv(v, z)
2537

26-
v control the smoothness level. the larger the v, the stronger smoothing condition (i.e., the solution is
27-
v-th differentiable) imposed by the kernel.
38+
# -------------------------
39+
# JAX backend
40+
# -------------------------
41+
else:
42+
try:
43+
import tensorflow_probability.substrates.jax as tfp
44+
45+
return tfp.math.bessel_kve(v, z) * xp.exp(-xp.abs(z))
46+
except ImportError:
47+
raise ImportError(
48+
"To use the JAX backend with the Matérn kernel, "
49+
"please install tensorflow-probability via `pip install tensorflow-probability==0.25.0`."
50+
)
51+
52+
53+
def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):
2854
"""
29-
r = abs(r)
30-
if r == 0:
31-
r = 0.00000001
32-
part1 = 2 ** (1 - v) / math.gamma(v)
33-
part2 = (math.sqrt(2 * v) * r / l) ** v
34-
part3 = sc.kv(v, math.sqrt(2 * v) * r / l)
55+
XP-compatible Matérn kernel.
56+
Works with NumPy or JAX.
57+
"""
58+
59+
# Avoid r = 0 singularity (JAX-safe)
60+
r = xp.maximum(xp.abs(r), 1e-8)
61+
62+
z = xp.sqrt(2.0 * v) * r / l
63+
64+
part1 = 2.0 ** (1.0 - v) / math.gamma(v) # scalar constant
65+
part2 = z**v
66+
part3 = kv_xp(v, z, xp)
67+
3568
return part1 * part2 * part3
3669

3770

38-
@numba_util.jit(cache=False)
3971
def matern_cov_matrix_from(
4072
scale: float,
4173
nu: float,
42-
pixel_points: np.ndarray,
43-
) -> np.ndarray:
74+
pixel_points,
75+
xp=np,
76+
):
4477
"""
4578
Consutruct the regularization covariance matrix, which is used to determined the regularization pattern (i.e,
4679
how the different pixels are correlated).
@@ -63,45 +96,27 @@ def matern_cov_matrix_from(
6396
The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels].
6497
"""
6598

66-
pixels = len(pixel_points)
67-
covariance_matrix = np.zeros(shape=(pixels, pixels))
68-
69-
for i in range(pixels):
70-
covariance_matrix[i, i] += 1e-8
71-
for j in range(pixels):
72-
xi = pixel_points[i, 1]
73-
yi = pixel_points[i, 0]
74-
xj = pixel_points[j, 1]
75-
yj = pixel_points[j, 0]
76-
d_ij = np.sqrt(
77-
(xi - xj) ** 2 + (yi - yj) ** 2
78-
) # distance between the pixel i and j
79-
80-
covariance_matrix[i, j] += matern_kernel(d_ij, l=scale, v=nu)
81-
82-
return covariance_matrix
99+
# --------------------------------
100+
# Pairwise distances (broadcasted)
101+
# --------------------------------
102+
# pixel_points[:, None, :] -> (N, 1, 2)
103+
# pixel_points[None, :, :] -> (1, N, 2)
104+
diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2)
83105

106+
d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N)
84107

85-
class NumbaScipyPlaceholder:
86-
pass
108+
# --------------------------------
109+
# Apply Matérn kernel elementwise
110+
# --------------------------------
111+
covariance_matrix = matern_kernel(d_ij, l=scale, v=nu, xp=xp)
87112

113+
# --------------------------------
114+
# Add diagonal jitter (JAX-safe)
115+
# --------------------------------
116+
pixels = pixel_points.shape[0]
117+
covariance_matrix = covariance_matrix + 1e-8 * xp.eye(pixels)
88118

89-
try:
90-
import numba_scipy
91-
92-
numba_scipy = object
93-
except ModuleNotFoundError:
94-
numba_scipy = NumbaScipyPlaceholder()
95-
96-
97-
def numba_scipy_exception():
98-
raise ModuleNotFoundError(
99-
"\n--------------------\n"
100-
"You are attempting to use the MaternKernel for Regularization.\n\n"
101-
"However, the optional library numba_scipy (https://pypi.org/project/numba-scipy/) is not installed.\n\n"
102-
"Install it via the command `pip install numba-scipy==0.3.1`.\n\n"
103-
"----------------------"
104-
)
119+
return covariance_matrix
105120

106121

107122
class MaternKernel(AbstractRegularization):
@@ -131,9 +146,6 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5
131146
Controls the derivative of the regularization pattern (`nu=0.5` is a Gaussian).
132147
"""
133148

134-
if isinstance(numba_scipy, NumbaScipyPlaceholder):
135-
numba_scipy_exception()
136-
137149
self.coefficient = coefficient
138150
self.scale = float(scale)
139151
self.nu = float(nu)
@@ -175,8 +187,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
175187
"""
176188
covariance_matrix = matern_cov_matrix_from(
177189
scale=self.scale,
178-
pixel_points=xp.array(linear_obj.source_plane_mesh_grid),
190+
pixel_points=linear_obj.source_plane_mesh_grid.array,
179191
nu=self.nu,
192+
xp=xp,
180193
)
181194

182195
return self.coefficient * xp.linalg.inv(covariance_matrix)

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ local_scheme = "no-local-version"
5252
[project.optional-dependencies]
5353
optional=[
5454
"numba",
55-
"pynufft"
55+
"pynufft",
56+
"ensorflow-probability==0.25.0"
5657
]
5758
test = ["pytest"]
5859
dev = ["pytest", "black"]

test_autoarray/inversion/regularizations/test_matern_kernel.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,17 @@
77

88

99
def test__regularization_matrix():
10-
pass
11-
12-
# reg = aa.reg.MaternKernel(coefficient=1.0, scale=2.0, nu=2.0)
13-
#
14-
# source_plane_mesh_grid = aa.Grid2D.no_mask(
15-
# values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]],
16-
# shape_native=(3, 2),
17-
# pixel_scales=1.0,
18-
# )
19-
#
20-
# mapper = aa.m.MockMapper(source_plane_mesh_grid=source_plane_mesh_grid)
21-
22-
# regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper)
23-
#
24-
# assert regularization_matrix[0, 0] == pytest.approx(3.540276762, 1.0e-4)
10+
11+
reg = aa.reg.MaternKernel(coefficient=1.0, scale=2.0, nu=2.0)
12+
13+
source_plane_mesh_grid = aa.Grid2D.no_mask(
14+
values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]],
15+
shape_native=(3, 2),
16+
pixel_scales=1.0,
17+
)
18+
19+
mapper = aa.m.MockMapper(source_plane_mesh_grid=source_plane_mesh_grid)
20+
21+
regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper)
22+
23+
assert regularization_matrix[0, 0] == pytest.approx(3.540276762, 1.0e-4)

0 commit comments

Comments
 (0)