Skip to content

Commit 09be29c

Browse files
Jammy2211Jammy2211
authored andcommitted
added adaptive matern kernel
1 parent 4055b30 commit 09be29c

File tree

1 file changed

+52
-0
lines changed

1 file changed

+52
-0
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from autoarray.inversion.regularization.abstract import AbstractRegularization
2+
3+
from autoarray.inversion.regularization.matern_kernel import matern_cov_matrix_from
4+
5+
6+
class AdaptiveBrightnessMatern(AbstractRegularization):
7+
def __init__(
8+
self,
9+
coefficient: float = 1.0,
10+
scale: float = 1.0,
11+
nu: float = 0.5,
12+
rho: float = 1.0,
13+
):
14+
super().__init__(coefficient=coefficient, scale=scale, rho=rho)
15+
self.nu = nu
16+
17+
def covariance_kernel_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
18+
pixel_signals = linear_obj.pixel_signals_from(signal_scale=1.0)
19+
return np.exp(-self.rho * (1 - pixel_signals / pixel_signals.max()))
20+
21+
def regularization_matrix_from(self, linear_obj: LinearObj) -> np.ndarray:
22+
kernel_weights = self.covariance_kernel_weights_from(linear_obj=linear_obj)
23+
24+
covariance_matrix = matern_cov_matrix_from(
25+
scale=self.scale,
26+
pixel_points=linear_obj.source_plane_mesh_grid,
27+
nu=self.nu,
28+
weights=kernel_weights,
29+
)
30+
31+
return self.coefficient * np.linalg.inv(covariance_matrix)
32+
33+
def regularization_weights_from(self, linear_obj: LinearObj) -> np.ndarray:
34+
"""
35+
Returns the regularization weights of this regularization scheme.
36+
37+
The regularization weights define the level of regularization applied to each parameter in the linear object
38+
(e.g. the ``pixels`` in a ``Mapper``).
39+
40+
For standard regularization (e.g. ``Constant``) are weights are equal, however for adaptive schemes
41+
(e.g. ``AdaptiveBrightness``) they vary to adapt to the data being reconstructed.
42+
43+
Parameters
44+
----------
45+
linear_obj
46+
The linear object (e.g. a ``Mapper``) which uses these weights when performing regularization.
47+
48+
Returns
49+
-------
50+
The regularization weights.
51+
"""
52+
return 1.0/self.covariance_kernel_weights_from(linear_obj=linear_obj) #meaningless, but consistent with other regularization schemes

0 commit comments

Comments
 (0)