Skip to content

Commit ac9805b

Browse files
Jammy2211Jammy2211
authored andcommitted
unit tests added with updates for JAX and correct functionality
1 parent 09be29c commit ac9805b

File tree

4 files changed

+187
-35
lines changed

4 files changed

+187
-35
lines changed

autoarray/inversion/regularization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
from .gaussian_kernel import GaussianKernel
1111
from .exponential_kernel import ExponentialKernel
1212
from .matern_kernel import MaternKernel
13+
from .matern_adaptive_brightness_kernel import MaternAdaptiveBrightnessKernel
Lines changed: 130 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,148 @@
1-
from autoarray.inversion.regularization.abstract import AbstractRegularization
1+
from __future__ import annotations
2+
import numpy as np
3+
from typing import TYPE_CHECKING
24

3-
from autoarray.inversion.regularization.matern_kernel import matern_cov_matrix_from
5+
from autoarray.inversion.regularization.matern_kernel import MaternKernel
46

7+
if TYPE_CHECKING:
8+
from autoarray.inversion.linear_obj.linear_obj import LinearObj
59

6-
class AdaptiveBrightnessMatern(AbstractRegularization):
10+
from autoarray.inversion.regularization.matern_kernel import matern_kernel
11+
12+
def matern_cov_matrix_from(
13+
scale: float,
14+
nu: float,
15+
pixel_points,
16+
weights=None,
17+
xp=np,
18+
):
19+
"""
20+
Construct the regularization covariance matrix (N x N) using a Matérn kernel,
21+
optionally modulated by per-pixel weights.
22+
23+
If `weights` is provided (shape [N]), the covariance is:
24+
C_ij = K(d_ij; scale, nu) * w_i * w_j
25+
with a small diagonal jitter added for numerical stability.
26+
27+
Parameters
28+
----------
29+
scale
30+
Typical correlation length of the Matérn kernel.
31+
nu
32+
Smoothness parameter of the Matérn kernel.
33+
pixel_points
34+
Array-like of shape [N, 2] with (y, x) coordinates (or any 2D coords; only distances matter).
35+
weights
36+
Optional array-like of shape [N]. If None, treated as all ones.
37+
xp
38+
Backend (numpy or jax.numpy).
39+
40+
Returns
41+
-------
42+
covariance_matrix
43+
Array of shape [N, N].
44+
"""
45+
46+
# --------------------------------
47+
# Pairwise distances (broadcasted)
48+
# --------------------------------
49+
diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2)
50+
d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N)
51+
52+
# --------------------------------
53+
# Base Matérn covariance
54+
# --------------------------------
55+
covariance_matrix = matern_kernel(d_ij, l=scale, v=nu, xp=xp) # (N, N)
56+
57+
# --------------------------------
58+
# Apply weights: C_ij *= w_i * w_j
59+
# (broadcasted outer product, JAX-safe)
60+
# --------------------------------
61+
if weights is not None:
62+
w = xp.asarray(weights)
63+
# Ensure shape (N,) -> outer product (N,1)*(1,N) -> (N,N)
64+
covariance_matrix = covariance_matrix * (w[:, None] * w[None, :])
65+
66+
# --------------------------------
67+
# Add diagonal jitter (JAX-safe)
68+
# --------------------------------
69+
pixels = pixel_points.shape[0]
70+
covariance_matrix = covariance_matrix + 1e-8 * xp.eye(pixels)
71+
72+
return covariance_matrix
73+
74+
75+
class MaternAdaptiveBrightnessKernel(MaternKernel):
776
def __init__(
8-
self,
9-
coefficient: float = 1.0,
10-
scale: float = 1.0,
11-
nu: float = 0.5,
12-
rho: float = 1.0,
77+
self,
78+
coefficient: float = 1.0,
79+
scale: float = 1.0,
80+
nu: float = 0.5,
81+
rho: float = 1.0,
1382
):
14-
super().__init__(coefficient=coefficient, scale=scale, rho=rho)
15-
self.nu = nu
83+
"""
84+
Regularization which uses a Matern smoothing kernel to regularize the solution with regularization weights
85+
that adapt to the brightness of the source being reconstructed.
86+
87+
For this regularization scheme, every pixel is regularized with every other pixel. This contrasts many other
88+
schemes, where regularization is based on neighboring (e.g. do the pixels share a Delaunay edge?) or computing
89+
derivates around the center of the pixel (where nearby pixels are regularization locally in similar ways).
90+
91+
This makes the regularization matrix fully dense and therefore maybe change the run times of the solution.
92+
It also leads to more overall smoothing which can lead to more stable linear inversions.
93+
94+
For the weighted regularization scheme, each pixel is given an 'effective regularization weight', which is
95+
applied when each set of pixel neighbors are regularized with one another. The motivation of this is that
96+
different regions of a pixelization's mesh require different levels of regularization (e.g., high smoothing where the
97+
no signal is present and less smoothing where it is, see (Nightingale, Dye and Massey 2018)).
98+
99+
This scheme is not used by Vernardos et al. (2022): https://arxiv.org/abs/2202.09378, but it follows
100+
a similar approach.
16101
17-
def covariance_kernel_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
18-
pixel_signals = linear_obj.pixel_signals_from(signal_scale=1.0)
19-
return np.exp(-self.rho * (1 - pixel_signals / pixel_signals.max()))
102+
A full description of regularization and this matrix can be found in the parent `AbstractRegularization` class.
20103
21-
def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
22-
kernel_weights = self.covariance_kernel_weights_from(linear_obj=linear_obj)
104+
Parameters
105+
----------
106+
coefficient
107+
The regularization coefficient which controls the degree of smooth of the inversion reconstruction.
108+
scale
109+
The typical scale of the exponential regularization pattern.
110+
nu
111+
Controls the derivative of the regularization pattern (`nu=0.5` is a Gaussian).
112+
"""
113+
super().__init__(coefficient=coefficient, scale=scale, nu=nu)
114+
self.rho = rho
115+
116+
def covariance_kernel_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
117+
"""
118+
Returns per-pixel kernel weights that adapt to the reconstructed pixel brightness.
119+
"""
120+
# Assumes linear_obj.pixel_signals_from is xp-aware elsewhere in the codebase.
121+
pixel_signals = linear_obj.pixel_signals_from(signal_scale=1.0, xp=xp)
122+
123+
max_signal = xp.max(pixel_signals)
124+
max_signal = xp.maximum(max_signal, 1e-8) # avoid divide-by-zero (JAX-safe)
125+
126+
return xp.exp(-self.rho * (1.0 - pixel_signals / max_signal))
127+
128+
def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
129+
kernel_weights = self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp)
130+
131+
# Follow the xp pattern used in the Matérn kernel module (often `.array` for grids).
132+
pixel_points = linear_obj.source_plane_mesh_grid
23133

24134
covariance_matrix = matern_cov_matrix_from(
25135
scale=self.scale,
26-
pixel_points=linear_obj.source_plane_mesh_grid,
136+
pixel_points=pixel_points,
27137
nu=self.nu,
28138
weights=kernel_weights,
139+
xp=xp,
29140
)
30141

31-
return self.coefficient * np.linalg.inv(covariance_matrix)
142+
return self.coefficient * xp.linalg.inv(covariance_matrix)
32143

33-
def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
144+
def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
34145
"""
35146
Returns the regularization weights of this regularization scheme.
36-
37-
The regularization weights define the level of regularization applied to each parameter in the linear object
38-
(e.g. the ``pixels`` in a ``Mapper``).
39-
40-
For standard regularization (e.g. ``Constant``) are weights are equal, however for adaptive schemes
41-
(e.g. ``AdaptiveBrightness``) they vary to adapt to the data being reconstructed.
42-
43-
Parameters
44-
----------
45-
linear_obj
46-
The linear object (e.g. a ``Mapper``) which uses these weights when performing regularization.
47-
48-
Returns
49-
-------
50-
The regularization weights.
51147
"""
52-
return 1.0/self.covariance_kernel_weights_from(linear_obj=linear_obj) #meaningless, but consistent with other regularization schemes
148+
return 1.0 / self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp)

autoarray/structures/arrays/kernel_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def convolved_image_from(
612612
image,
613613
blurring_image,
614614
jax_method="direct",
615-
use_mixed_precision : bool = False,
615+
use_mixed_precision: bool = False,
616616
xp=np,
617617
):
618618
"""
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
3+
import autoarray as aa
4+
import numpy as np
5+
6+
np.set_printoptions(threshold=np.inf)
7+
8+
9+
def test__regularization_matrix():
10+
11+
reg = aa.reg.MaternAdaptiveBrightnessKernel(coefficient=1.0, scale=2.0, nu=2.0, rho=1.0)
12+
13+
neighbors = np.array(
14+
[
15+
[1, 4, -1, -1],
16+
[2, 4, 0, -1],
17+
[3, 4, 5, 1],
18+
[5, 2, -1, -1],
19+
[5, 0, 1, 2],
20+
[2, 3, 4, -1],
21+
]
22+
)
23+
24+
neighbors_sizes = np.array([2, 3, 4, 2, 4, 3])
25+
pixel_signals = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
26+
27+
mesh_grid = aa.m.MockMeshGrid(neighbors=neighbors, neighbors_sizes=neighbors_sizes)
28+
29+
source_plane_mesh_grid = aa.Grid2D.no_mask(
30+
values=[[0.1, 0.1], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]],
31+
shape_native=(3, 2),
32+
pixel_scales=1.0,
33+
)
34+
35+
mapper = aa.m.MockMapper(
36+
source_plane_mesh_grid=source_plane_mesh_grid, pixel_signals=pixel_signals
37+
)
38+
39+
regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper)
40+
41+
assert regularization_matrix[0, 0] == pytest.approx(18.7439565009, 1.0e-4)
42+
assert regularization_matrix[0, 1] == pytest.approx(-8.786547368, 1.0e-4)
43+
44+
reg = aa.reg.MaternAdaptiveBrightnessKernel(coefficient=1.5, scale=2.5, nu=2.5, rho=1.5)
45+
46+
pixel_signals = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
47+
48+
mapper = aa.m.MockMapper(
49+
source_plane_mesh_grid=source_plane_mesh_grid, pixel_signals=pixel_signals
50+
)
51+
52+
regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper)
53+
54+
assert regularization_matrix[0, 0] == pytest.approx(121.0190770, 1.0e-4)
55+
assert regularization_matrix[0, 1] == pytest.approx(-66.9580331, 1.0e-4)

0 commit comments

Comments
 (0)