1- from autoarray .inversion .regularization .abstract import AbstractRegularization
1+ from __future__ import annotations
2+ import numpy as np
3+ from typing import TYPE_CHECKING
24
3- from autoarray .inversion .regularization .matern_kernel import matern_cov_matrix_from
5+ from autoarray .inversion .regularization .matern_kernel import MaternKernel
46
7+ if TYPE_CHECKING :
8+ from autoarray .inversion .linear_obj .linear_obj import LinearObj
59
6- class AdaptiveBrightnessMatern (AbstractRegularization ):
10+ from autoarray .inversion .regularization .matern_kernel import matern_kernel
11+
12+ def matern_cov_matrix_from (
13+ scale : float ,
14+ nu : float ,
15+ pixel_points ,
16+ weights = None ,
17+ xp = np ,
18+ ):
19+ """
20+ Construct the regularization covariance matrix (N x N) using a Matérn kernel,
21+ optionally modulated by per-pixel weights.
22+
23+ If `weights` is provided (shape [N]), the covariance is:
24+ C_ij = K(d_ij; scale, nu) * w_i * w_j
25+ with a small diagonal jitter added for numerical stability.
26+
27+ Parameters
28+ ----------
29+ scale
30+ Typical correlation length of the Matérn kernel.
31+ nu
32+ Smoothness parameter of the Matérn kernel.
33+ pixel_points
34+ Array-like of shape [N, 2] with (y, x) coordinates (or any 2D coords; only distances matter).
35+ weights
36+ Optional array-like of shape [N]. If None, treated as all ones.
37+ xp
38+ Backend (numpy or jax.numpy).
39+
40+ Returns
41+ -------
42+ covariance_matrix
43+ Array of shape [N, N].
44+ """
45+
46+ # --------------------------------
47+ # Pairwise distances (broadcasted)
48+ # --------------------------------
49+ diff = pixel_points [:, None , :] - pixel_points [None , :, :] # (N, N, 2)
50+ d_ij = xp .sqrt (diff [..., 0 ] ** 2 + diff [..., 1 ] ** 2 ) # (N, N)
51+
52+ # --------------------------------
53+ # Base Matérn covariance
54+ # --------------------------------
55+ covariance_matrix = matern_kernel (d_ij , l = scale , v = nu , xp = xp ) # (N, N)
56+
57+ # --------------------------------
58+ # Apply weights: C_ij *= w_i * w_j
59+ # (broadcasted outer product, JAX-safe)
60+ # --------------------------------
61+ if weights is not None :
62+ w = xp .asarray (weights )
63+ # Ensure shape (N,) -> outer product (N,1)*(1,N) -> (N,N)
64+ covariance_matrix = covariance_matrix * (w [:, None ] * w [None , :])
65+
66+ # --------------------------------
67+ # Add diagonal jitter (JAX-safe)
68+ # --------------------------------
69+ pixels = pixel_points .shape [0 ]
70+ covariance_matrix = covariance_matrix + 1e-8 * xp .eye (pixels )
71+
72+ return covariance_matrix
73+
74+
75+ class MaternAdaptiveBrightnessKernel (MaternKernel ):
776 def __init__ (
8- self ,
9- coefficient : float = 1.0 ,
10- scale : float = 1.0 ,
11- nu : float = 0.5 ,
12- rho : float = 1.0 ,
77+ self ,
78+ coefficient : float = 1.0 ,
79+ scale : float = 1.0 ,
80+ nu : float = 0.5 ,
81+ rho : float = 1.0 ,
1382 ):
14- super ().__init__ (coefficient = coefficient , scale = scale , rho = rho )
15- self .nu = nu
83+ """
84+ Regularization which uses a Matern smoothing kernel to regularize the solution with regularization weights
85+ that adapt to the brightness of the source being reconstructed.
86+
87+ For this regularization scheme, every pixel is regularized with every other pixel. This contrasts many other
88+ schemes, where regularization is based on neighboring (e.g. do the pixels share a Delaunay edge?) or computing
89+ derivates around the center of the pixel (where nearby pixels are regularization locally in similar ways).
90+
91+ This makes the regularization matrix fully dense and therefore maybe change the run times of the solution.
92+ It also leads to more overall smoothing which can lead to more stable linear inversions.
93+
94+ For the weighted regularization scheme, each pixel is given an 'effective regularization weight', which is
95+ applied when each set of pixel neighbors are regularized with one another. The motivation of this is that
96+ different regions of a pixelization's mesh require different levels of regularization (e.g., high smoothing where the
97+ no signal is present and less smoothing where it is, see (Nightingale, Dye and Massey 2018)).
98+
99+ This scheme is not used by Vernardos et al. (2022): https://arxiv.org/abs/2202.09378, but it follows
100+ a similar approach.
16101
17- def covariance_kernel_weights_from (self , linear_obj : LinearObj ) -> np .ndarray :
18- pixel_signals = linear_obj .pixel_signals_from (signal_scale = 1.0 )
19- return np .exp (- self .rho * (1 - pixel_signals / pixel_signals .max ()))
102+ A full description of regularization and this matrix can be found in the parent `AbstractRegularization` class.
20103
21- def regularization_matrix_from (self , linear_obj : LinearObj ) -> np .ndarray :
22- kernel_weights = self .covariance_kernel_weights_from (linear_obj = linear_obj )
104+ Parameters
105+ ----------
106+ coefficient
107+ The regularization coefficient which controls the degree of smooth of the inversion reconstruction.
108+ scale
109+ The typical scale of the exponential regularization pattern.
110+ nu
111+ Controls the derivative of the regularization pattern (`nu=0.5` is a Gaussian).
112+ """
113+ super ().__init__ (coefficient = coefficient , scale = scale , nu = nu )
114+ self .rho = rho
115+
116+ def covariance_kernel_weights_from (self , linear_obj : LinearObj , xp = np ) -> np .ndarray :
117+ """
118+ Returns per-pixel kernel weights that adapt to the reconstructed pixel brightness.
119+ """
120+ # Assumes linear_obj.pixel_signals_from is xp-aware elsewhere in the codebase.
121+ pixel_signals = linear_obj .pixel_signals_from (signal_scale = 1.0 , xp = xp )
122+
123+ max_signal = xp .max (pixel_signals )
124+ max_signal = xp .maximum (max_signal , 1e-8 ) # avoid divide-by-zero (JAX-safe)
125+
126+ return xp .exp (- self .rho * (1.0 - pixel_signals / max_signal ))
127+
128+ def regularization_matrix_from (self , linear_obj : LinearObj , xp = np ) -> np .ndarray :
129+ kernel_weights = self .covariance_kernel_weights_from (linear_obj = linear_obj , xp = xp )
130+
131+ # Follow the xp pattern used in the Matérn kernel module (often `.array` for grids).
132+ pixel_points = linear_obj .source_plane_mesh_grid
23133
24134 covariance_matrix = matern_cov_matrix_from (
25135 scale = self .scale ,
26- pixel_points = linear_obj . source_plane_mesh_grid ,
136+ pixel_points = pixel_points ,
27137 nu = self .nu ,
28138 weights = kernel_weights ,
139+ xp = xp ,
29140 )
30141
31- return self .coefficient * np .linalg .inv (covariance_matrix )
142+ return self .coefficient * xp .linalg .inv (covariance_matrix )
32143
33- def regularization_weights_from (self , linear_obj : LinearObj ) -> np .ndarray :
144+ def regularization_weights_from (self , linear_obj : LinearObj , xp = np ) -> np .ndarray :
34145 """
35146 Returns the regularization weights of this regularization scheme.
36-
37- The regularization weights define the level of regularization applied to each parameter in the linear object
38- (e.g. the ``pixels`` in a ``Mapper``).
39-
40- For standard regularization (e.g. ``Constant``) are weights are equal, however for adaptive schemes
41- (e.g. ``AdaptiveBrightness``) they vary to adapt to the data being reconstructed.
42-
43- Parameters
44- ----------
45- linear_obj
46- The linear object (e.g. a ``Mapper``) which uses these weights when performing regularization.
47-
48- Returns
49- -------
50- The regularization weights.
51147 """
52- return 1.0 / self .covariance_kernel_weights_from (linear_obj = linear_obj ) #meaningless, but consistent with other regularization schemes
148+ return 1.0 / self .covariance_kernel_weights_from (linear_obj = linear_obj , xp = xp )
0 commit comments