Skip to content

Commit b2dd830

Browse files
authored
Merge pull request #214 from Jammy2211/feature/matern_adaptive
Feature/matern adaptive
2 parents 4055b30 + 95cf5b7 commit b2dd830

File tree

7 files changed

+336
-17
lines changed

7 files changed

+336
-17
lines changed

autoarray/inversion/regularization/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
from .gaussian_kernel import GaussianKernel
1111
from .exponential_kernel import ExponentialKernel
1212
from .matern_kernel import MaternKernel
13+
from .matern_adaptive_brightness_kernel import MaternAdaptiveBrightnessKernel
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
from typing import TYPE_CHECKING
4+
5+
from autoarray.inversion.regularization.matern_kernel import MaternKernel
6+
7+
if TYPE_CHECKING:
8+
from autoarray.inversion.linear_obj.linear_obj import LinearObj
9+
10+
from autoarray.inversion.regularization.matern_kernel import matern_kernel
11+
12+
13+
def matern_cov_matrix_from(
14+
scale: float,
15+
nu: float,
16+
pixel_points,
17+
weights=None,
18+
xp=np,
19+
):
20+
"""
21+
Construct the regularization covariance matrix (N x N) using a Matérn kernel,
22+
optionally modulated by per-pixel weights.
23+
24+
If `weights` is provided (shape [N]), the covariance is:
25+
C_ij = K(d_ij; scale, nu) * w_i * w_j
26+
with a small diagonal jitter added for numerical stability.
27+
28+
Parameters
29+
----------
30+
scale
31+
Typical correlation length of the Matérn kernel.
32+
nu
33+
Smoothness parameter of the Matérn kernel.
34+
pixel_points
35+
Array-like of shape [N, 2] with (y, x) coordinates (or any 2D coords; only distances matter).
36+
weights
37+
Optional array-like of shape [N]. If None, treated as all ones.
38+
xp
39+
Backend (numpy or jax.numpy).
40+
41+
Returns
42+
-------
43+
covariance_matrix
44+
Array of shape [N, N].
45+
"""
46+
47+
# --------------------------------
48+
# Pairwise distances (broadcasted)
49+
# --------------------------------
50+
diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2)
51+
d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N)
52+
53+
# --------------------------------
54+
# Base Matérn covariance
55+
# --------------------------------
56+
covariance_matrix = matern_kernel(d_ij, l=scale, v=nu, xp=xp) # (N, N)
57+
58+
# --------------------------------
59+
# Apply weights: C_ij *= w_i * w_j
60+
# (broadcasted outer product, JAX-safe)
61+
# --------------------------------
62+
if weights is not None:
63+
w = xp.asarray(weights)
64+
# Ensure shape (N,) -> outer product (N,1)*(1,N) -> (N,N)
65+
covariance_matrix = covariance_matrix * (w[:, None] * w[None, :])
66+
67+
# --------------------------------
68+
# Add diagonal jitter (JAX-safe)
69+
# --------------------------------
70+
pixels = pixel_points.shape[0]
71+
covariance_matrix = covariance_matrix + 1e-8 * xp.eye(pixels)
72+
73+
return covariance_matrix
74+
75+
76+
class MaternAdaptiveBrightnessKernel(MaternKernel):
77+
def __init__(
78+
self,
79+
coefficient: float = 1.0,
80+
scale: float = 1.0,
81+
nu: float = 0.5,
82+
rho: float = 1.0,
83+
):
84+
"""
85+
Regularization which uses a Matern smoothing kernel to regularize the solution with regularization weights
86+
that adapt to the brightness of the source being reconstructed.
87+
88+
For this regularization scheme, every pixel is regularized with every other pixel. This contrasts many other
89+
schemes, where regularization is based on neighboring (e.g. do the pixels share a Delaunay edge?) or computing
90+
derivatives around the center of the pixel (where nearby pixels are regularization locally in similar ways).
91+
92+
This makes the regularization matrix fully dense and therefore may change the run times of the solution.
93+
It also leads to more overall smoothing which can lead to more stable linear inversions.
94+
95+
For the weighted regularization scheme, each pixel is given an 'effective regularization weight', which is
96+
applied when each set of pixel neighbors are regularized with one another. The motivation of this is that
97+
different regions of a pixelization's mesh require different levels of regularization (e.g., high smoothing where the
98+
no signal is present and less smoothing where it is, see (Nightingale, Dye and Massey 2018)).
99+
100+
This scheme is not used by Vernardos et al. (2022): https://arxiv.org/abs/2202.09378, but it follows
101+
a similar approach.
102+
103+
A full description of regularization and this matrix can be found in the parent `AbstractRegularization` class.
104+
105+
Parameters
106+
----------
107+
coefficient
108+
The regularization coefficient which controls the degree of smooth of the inversion reconstruction.
109+
scale
110+
The typical scale (correlation length) of the Matérn regularization kernel.
111+
nu
112+
Controls the smoothness (differentiability) of the Matérn kernel; ``nu=0.5`` corresponds to an
113+
exponential (Ornstein–Uhlenbeck) kernel, while a Gaussian covariance is obtained in the limit
114+
as ``nu`` approaches infinity.
115+
rho
116+
Controls how strongly the kernel weights adapt to pixel brightness. Larger values make bright pixels
117+
receive significantly higher weights (and faint pixels lower weights), while smaller values produce a
118+
more uniform weighting. Typical values are of order unity (e.g. 0.5–2.0).
119+
"""
120+
super().__init__(coefficient=coefficient, scale=scale, nu=nu)
121+
self.rho = rho
122+
123+
def covariance_kernel_weights_from(
124+
self, linear_obj: LinearObj, xp=np
125+
) -> np.ndarray:
126+
"""
127+
Returns per-pixel kernel weights that adapt to the reconstructed pixel brightness.
128+
"""
129+
# Assumes linear_obj.pixel_signals_from is xp-aware elsewhere in the codebase.
130+
pixel_signals = linear_obj.pixel_signals_from(signal_scale=1.0, xp=xp)
131+
132+
max_signal = xp.max(pixel_signals)
133+
max_signal = xp.maximum(max_signal, 1e-8) # avoid divide-by-zero (JAX-safe)
134+
135+
return xp.exp(-self.rho * (1.0 - pixel_signals / max_signal))
136+
137+
def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
138+
kernel_weights = self.covariance_kernel_weights_from(
139+
linear_obj=linear_obj, xp=xp
140+
)
141+
142+
# Follow the xp pattern used in the Matérn kernel module (often `.array` for grids).
143+
pixel_points = linear_obj.source_plane_mesh_grid.array
144+
145+
covariance_matrix = matern_cov_matrix_from(
146+
scale=self.scale,
147+
pixel_points=pixel_points,
148+
nu=self.nu,
149+
weights=kernel_weights,
150+
xp=xp,
151+
)
152+
153+
return self.coefficient * xp.linalg.inv(covariance_matrix)
154+
155+
def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
156+
"""
157+
Returns the regularization weights of this regularization scheme.
158+
"""
159+
return 1.0 / self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp)

autoarray/inversion/regularization/matern_kernel.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@ def kv_xp(v, z, xp=np):
4444
)
4545

4646

47+
def gamma_xp(x, xp=np):
48+
"""
49+
XP-compatible Gamma(x).
50+
"""
51+
if xp is np:
52+
import scipy.special as sc
53+
54+
return sc.gamma(x)
55+
else:
56+
import jax.scipy.special as jsp
57+
58+
return jsp.gamma(x)
59+
60+
4761
def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):
4862
"""
4963
XP-compatible Matérn kernel.
@@ -55,7 +69,7 @@ def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):
5569

5670
z = xp.sqrt(2.0 * v) * r / l
5771

58-
part1 = 2.0 ** (1.0 - v) / math.gamma(v) # scalar constant
72+
part1 = 2.0 ** (1.0 - v) / gamma_xp(v, xp) # scalar constant
5973
part2 = z**v
6074
part3 = kv_xp(v, z, xp)
6175

@@ -141,8 +155,8 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5
141155
"""
142156

143157
self.coefficient = coefficient
144-
self.scale = float(scale)
145-
self.nu = float(nu)
158+
self.scale = scale
159+
self.nu = nu
146160
super().__init__()
147161

148162
def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:

autoarray/preloads.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
linear_light_profile_blurred_mapping_matrix=None,
2727
use_voronoi_areas: bool = True,
2828
areas_factor: float = 0.5,
29+
skip_areas: bool = False,
2930
):
3031
"""
3132
Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance
@@ -81,6 +82,16 @@ def __init__(
8182
inversion, with the other component being the pixelization's pixels. These are fixed when the lens light
8283
is fixed to the maximum likelihood solution, allowing the blurred mapping matrix to be preloaded, but
8384
the intensity values will still be solved for during the inversion.
85+
use_voronoi_areas
86+
Whether to use Voronoi areas during Delaunay triangulation. When True, computes areas for each Voronoi
87+
region which can be used in certain regularization schemes. Default is True.
88+
areas_factor
89+
Factor used to scale the Voronoi areas during split point computation. Default is 0.5.
90+
skip_areas
91+
Whether to skip Voronoi area calculations and split point computations during Delaunay triangulation.
92+
When True, the Delaunay interface returns only the minimal set of outputs (points, simplices, mappings)
93+
without computing split_points or splitted_mappings. This optimization is useful for regularization
94+
schemes like Matérn kernels that don't require area-based calculations. Default is False.
8495
"""
8596
self.mapper_indices = None
8697
self.source_pixel_zeroed_indices = None
@@ -123,3 +134,4 @@ def __init__(
123134

124135
self.use_voronoi_areas = use_voronoi_areas
125136
self.areas_factor = areas_factor
137+
self.skip_areas = skip_areas

autoarray/structures/arrays/kernel_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def convolved_image_from(
612612
image,
613613
blurring_image,
614614
jax_method="direct",
615-
use_mixed_precision : bool = False,
615+
use_mixed_precision: bool = False,
616616
xp=np,
617617
):
618618
"""

autoarray/structures/mesh/delaunay_2d.py

Lines changed: 101 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,67 @@ def pix_indexes_for_sub_slim_index_delaunay_from(
339339
return out
340340

341341

342+
def scipy_delaunay_matern(points_np, query_points_np):
343+
"""
344+
Minimal SciPy Delaunay callback for Matérn regularization.
345+
346+
Returns only what’s needed for mapping:
347+
- points (tri.points)
348+
- simplices_padded
349+
- mappings: integer array of pixel indices for each query point,
350+
typically of shape (Q, 3), where each row gives the indices of the
351+
Delaunay mesh vertices ("pixels") associated with that query point.
352+
"""
353+
354+
max_simplices = 2 * points_np.shape[0]
355+
356+
# --- Delaunay mesh ---
357+
tri = Delaunay(points_np)
358+
359+
points = tri.points.astype(points_np.dtype)
360+
simplices = tri.simplices.astype(np.int32)
361+
362+
# --- Pad simplices to fixed shape for JAX ---
363+
simplices_padded = -np.ones((max_simplices, 3), dtype=np.int32)
364+
simplices_padded[: simplices.shape[0]] = simplices
365+
366+
# --- find_simplex for query points ---
367+
simplex_idx = tri.find_simplex(query_points_np).astype(np.int32) # (Q,)
368+
369+
mappings = pix_indexes_for_sub_slim_index_delaunay_from(
370+
source_plane_data_grid=query_points_np,
371+
simplex_index_for_sub_slim_index=simplex_idx,
372+
pix_indexes_for_simplex_index=simplices,
373+
delaunay_points=points_np,
374+
)
375+
376+
return points, simplices_padded, mappings
377+
378+
379+
def jax_delaunay_matern(points, query_points):
380+
"""
381+
JAX wrapper using pure_callback to run SciPy Delaunay on CPU,
382+
returning only the minimal outputs needed for Matérn usage.
383+
"""
384+
import jax
385+
import jax.numpy as jnp
386+
387+
N = points.shape[0]
388+
Q = query_points.shape[0]
389+
max_simplices = 2 * N
390+
391+
points_shape = jax.ShapeDtypeStruct((N, 2), points.dtype)
392+
simplices_padded_shape = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
393+
mappings_shape = jax.ShapeDtypeStruct((Q, 3), jnp.int32)
394+
395+
return jax.pure_callback(
396+
lambda pts, qpts: scipy_delaunay_matern(np.asarray(pts), np.asarray(qpts)),
397+
(points_shape, simplices_padded_shape, mappings_shape),
398+
points,
399+
query_points,
400+
)
401+
402+
342403
class DelaunayInterface:
343404

344405
def __init__(
@@ -466,33 +527,60 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
466527

467528
use_voronoi_areas = self.preloads.use_voronoi_areas
468529
areas_factor = self.preloads.areas_factor
530+
skip_areas = self.preloads.skip_areas
469531

470532
else:
471533

472534
use_voronoi_areas = True
473535
areas_factor = 0.5
536+
skip_areas = False
474537

475-
if self._xp.__name__.startswith("jax"):
538+
if not skip_areas:
476539

477-
import jax.numpy as jnp
540+
if self._xp.__name__.startswith("jax"):
478541

479-
points, simplices, mappings, split_points, splitted_mappings = jax_delaunay(
480-
points=self.mesh_grid_xy,
481-
query_points=self._source_plane_data_grid_over_sampled,
482-
use_voronoi_areas=use_voronoi_areas,
483-
areas_factor=areas_factor,
484-
)
542+
import jax.numpy as jnp
543+
544+
points, simplices, mappings, split_points, splitted_mappings = (
545+
jax_delaunay(
546+
points=self.mesh_grid_xy,
547+
query_points=self._source_plane_data_grid_over_sampled,
548+
use_voronoi_areas=use_voronoi_areas,
549+
areas_factor=areas_factor,
550+
)
551+
)
552+
553+
else:
554+
555+
points, simplices, mappings, split_points, splitted_mappings = (
556+
scipy_delaunay(
557+
points_np=self.mesh_grid_xy,
558+
query_points_np=self._source_plane_data_grid_over_sampled,
559+
use_voronoi_areas=use_voronoi_areas,
560+
areas_factor=areas_factor,
561+
)
562+
)
485563

486564
else:
487565

488-
points, simplices, mappings, split_points, splitted_mappings = (
489-
scipy_delaunay(
566+
if self._xp.__name__.startswith("jax"):
567+
568+
import jax.numpy as jnp
569+
570+
points, simplices, mappings = jax_delaunay_matern(
571+
points=self.mesh_grid_xy,
572+
query_points=self._source_plane_data_grid_over_sampled,
573+
)
574+
575+
else:
576+
577+
points, simplices, mappings = scipy_delaunay_matern(
490578
points_np=self.mesh_grid_xy,
491579
query_points_np=self._source_plane_data_grid_over_sampled,
492-
use_voronoi_areas=use_voronoi_areas,
493-
areas_factor=areas_factor,
494580
)
495-
)
581+
582+
split_points = None
583+
splitted_mappings = None
496584

497585
return DelaunayInterface(
498586
points=points,

0 commit comments

Comments
 (0)