Skip to content

Commit e366b49

Browse files
Jammy2211Jammy2211
authored andcommitted
end to end matern tests and skip areas added
1 parent ac9805b commit e366b49

File tree

5 files changed

+135
-24
lines changed

5 files changed

+135
-24
lines changed

autoarray/inversion/regularization/matern_adaptive_brightness_kernel.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from autoarray.inversion.regularization.matern_kernel import matern_kernel
1111

12+
1213
def matern_cov_matrix_from(
1314
scale: float,
1415
nu: float,
@@ -47,7 +48,7 @@ def matern_cov_matrix_from(
4748
# Pairwise distances (broadcasted)
4849
# --------------------------------
4950
diff = pixel_points[:, None, :] - pixel_points[None, :, :] # (N, N, 2)
50-
d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N)
51+
d_ij = xp.sqrt(diff[..., 0] ** 2 + diff[..., 1] ** 2) # (N, N)
5152

5253
# --------------------------------
5354
# Base Matérn covariance
@@ -113,7 +114,9 @@ def __init__(
113114
super().__init__(coefficient=coefficient, scale=scale, nu=nu)
114115
self.rho = rho
115116

116-
def covariance_kernel_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
117+
def covariance_kernel_weights_from(
118+
self, linear_obj: LinearObj, xp=np
119+
) -> np.ndarray:
117120
"""
118121
Returns per-pixel kernel weights that adapt to the reconstructed pixel brightness.
119122
"""
@@ -126,10 +129,12 @@ def covariance_kernel_weights_from(self, linear_obj: LinearObj, xp=np) -> np.nda
126129
return xp.exp(-self.rho * (1.0 - pixel_signals / max_signal))
127130

128131
def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:
129-
kernel_weights = self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp)
132+
kernel_weights = self.covariance_kernel_weights_from(
133+
linear_obj=linear_obj, xp=xp
134+
)
130135

131136
# Follow the xp pattern used in the Matérn kernel module (often `.array` for grids).
132-
pixel_points = linear_obj.source_plane_mesh_grid
137+
pixel_points = linear_obj.source_plane_mesh_grid.array
133138

134139
covariance_matrix = matern_cov_matrix_from(
135140
scale=self.scale,
@@ -145,4 +150,4 @@ def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarra
145150
"""
146151
Returns the regularization weights of this regularization scheme.
147152
"""
148-
return 1.0 / self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp)
153+
return 1.0 / self.covariance_kernel_weights_from(linear_obj=linear_obj, xp=xp)

autoarray/inversion/regularization/matern_kernel.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,20 @@ def kv_xp(v, z, xp=np):
4444
)
4545

4646

47+
def gamma_xp(x, xp=np):
48+
"""
49+
XP-compatible Gamma(x).
50+
"""
51+
if xp is np:
52+
import scipy.special as sc
53+
54+
return sc.gamma(x)
55+
else:
56+
import jax.scipy.special as jsp
57+
58+
return jsp.gamma(x)
59+
60+
4761
def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):
4862
"""
4963
XP-compatible Matérn kernel.
@@ -55,7 +69,7 @@ def matern_kernel(r, l: float = 1.0, v: float = 0.5, xp=np):
5569

5670
z = xp.sqrt(2.0 * v) * r / l
5771

58-
part1 = 2.0 ** (1.0 - v) / math.gamma(v) # scalar constant
72+
part1 = 2.0 ** (1.0 - v) / gamma_xp(v, xp) # scalar constant
5973
part2 = z**v
6074
part3 = kv_xp(v, z, xp)
6175

@@ -141,8 +155,8 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5
141155
"""
142156

143157
self.coefficient = coefficient
144-
self.scale = float(scale)
145-
self.nu = float(nu)
158+
self.scale = scale
159+
self.nu = nu
146160
super().__init__()
147161

148162
def regularization_weights_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray:

autoarray/preloads.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
linear_light_profile_blurred_mapping_matrix=None,
2727
use_voronoi_areas: bool = True,
2828
areas_factor: float = 0.5,
29+
skip_areas: bool = False,
2930
):
3031
"""
3132
Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance
@@ -123,3 +124,4 @@ def __init__(
123124

124125
self.use_voronoi_areas = use_voronoi_areas
125126
self.areas_factor = areas_factor
127+
self.skip_areas = skip_areas

autoarray/structures/mesh/delaunay_2d.py

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,65 @@ def pix_indexes_for_sub_slim_index_delaunay_from(
339339
return out
340340

341341

342+
def scipy_delaunay_matern(points_np, query_points_np):
343+
"""
344+
Minimal SciPy Delaunay callback for Matérn regularization.
345+
346+
Returns only what’s needed for mapping:
347+
- points (tri.points)
348+
- simplices_padded
349+
- mappings (pix indexes for each query point)
350+
"""
351+
352+
max_simplices = 2 * points_np.shape[0]
353+
354+
# --- Delaunay mesh ---
355+
tri = Delaunay(points_np)
356+
357+
points = tri.points.astype(points_np.dtype)
358+
simplices = tri.simplices.astype(np.int32)
359+
360+
# --- Pad simplices to fixed shape for JAX ---
361+
simplices_padded = -np.ones((max_simplices, 3), dtype=np.int32)
362+
simplices_padded[: simplices.shape[0]] = simplices
363+
364+
# --- find_simplex for query points ---
365+
simplex_idx = tri.find_simplex(query_points_np).astype(np.int32) # (Q,)
366+
367+
mappings = pix_indexes_for_sub_slim_index_delaunay_from(
368+
source_plane_data_grid=query_points_np,
369+
simplex_index_for_sub_slim_index=simplex_idx,
370+
pix_indexes_for_simplex_index=simplices,
371+
delaunay_points=points_np,
372+
)
373+
374+
return points, simplices_padded, mappings
375+
376+
377+
def jax_delaunay_matern(points, query_points):
378+
"""
379+
JAX wrapper using pure_callback to run SciPy Delaunay on CPU,
380+
returning only the minimal outputs needed for Matérn usage.
381+
"""
382+
import jax
383+
import jax.numpy as jnp
384+
385+
N = points.shape[0]
386+
Q = query_points.shape[0]
387+
max_simplices = 2 * N
388+
389+
points_shape = jax.ShapeDtypeStruct((N, 2), points.dtype)
390+
simplices_padded_shape = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
391+
mappings_shape = jax.ShapeDtypeStruct((Q, 3), jnp.int32)
392+
393+
return jax.pure_callback(
394+
lambda pts, qpts: scipy_delaunay_matern(np.asarray(pts), np.asarray(qpts)),
395+
(points_shape, simplices_padded_shape, mappings_shape),
396+
points,
397+
query_points,
398+
)
399+
400+
342401
class DelaunayInterface:
343402

344403
def __init__(
@@ -466,33 +525,60 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
466525

467526
use_voronoi_areas = self.preloads.use_voronoi_areas
468527
areas_factor = self.preloads.areas_factor
528+
skip_areas = self.preloads.skip_areas
469529

470530
else:
471531

472532
use_voronoi_areas = True
473533
areas_factor = 0.5
534+
skip_areas = False
474535

475-
if self._xp.__name__.startswith("jax"):
536+
if not skip_areas:
476537

477-
import jax.numpy as jnp
538+
if self._xp.__name__.startswith("jax"):
478539

479-
points, simplices, mappings, split_points, splitted_mappings = jax_delaunay(
480-
points=self.mesh_grid_xy,
481-
query_points=self._source_plane_data_grid_over_sampled,
482-
use_voronoi_areas=use_voronoi_areas,
483-
areas_factor=areas_factor,
484-
)
540+
import jax.numpy as jnp
541+
542+
points, simplices, mappings, split_points, splitted_mappings = (
543+
jax_delaunay(
544+
points=self.mesh_grid_xy,
545+
query_points=self._source_plane_data_grid_over_sampled,
546+
use_voronoi_areas=use_voronoi_areas,
547+
areas_factor=areas_factor,
548+
)
549+
)
550+
551+
else:
552+
553+
points, simplices, mappings, split_points, splitted_mappings = (
554+
scipy_delaunay(
555+
points_np=self.mesh_grid_xy,
556+
query_points_np=self._source_plane_data_grid_over_sampled,
557+
use_voronoi_areas=use_voronoi_areas,
558+
areas_factor=areas_factor,
559+
)
560+
)
485561

486562
else:
487563

488-
points, simplices, mappings, split_points, splitted_mappings = (
489-
scipy_delaunay(
564+
if self._xp.__name__.startswith("jax"):
565+
566+
import jax.numpy as jnp
567+
568+
points, simplices, mappings = jax_delaunay_matern(
569+
points=self.mesh_grid_xy,
570+
query_points=self._source_plane_data_grid_over_sampled,
571+
)
572+
573+
else:
574+
575+
points, simplices, mappings = scipy_delaunay_matern(
490576
points_np=self.mesh_grid_xy,
491577
query_points_np=self._source_plane_data_grid_over_sampled,
492-
use_voronoi_areas=use_voronoi_areas,
493-
areas_factor=areas_factor,
494578
)
495-
)
579+
580+
split_points = None
581+
splitted_mappings = None
496582

497583
return DelaunayInterface(
498584
points=points,

test_autoarray/inversion/regularizations/test_matern_adaptive_brightness_kernel.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
def test__regularization_matrix():
1010

11-
reg = aa.reg.MaternAdaptiveBrightnessKernel(coefficient=1.0, scale=2.0, nu=2.0, rho=1.0)
11+
reg = aa.reg.MaternAdaptiveBrightnessKernel(
12+
coefficient=1.0, scale=2.0, nu=2.0, rho=1.0
13+
)
1214

1315
neighbors = np.array(
1416
[
@@ -41,7 +43,9 @@ def test__regularization_matrix():
4143
assert regularization_matrix[0, 0] == pytest.approx(18.7439565009, 1.0e-4)
4244
assert regularization_matrix[0, 1] == pytest.approx(-8.786547368, 1.0e-4)
4345

44-
reg = aa.reg.MaternAdaptiveBrightnessKernel(coefficient=1.5, scale=2.5, nu=2.5, rho=1.5)
46+
reg = aa.reg.MaternAdaptiveBrightnessKernel(
47+
coefficient=1.5, scale=2.5, nu=2.5, rho=1.5
48+
)
4549

4650
pixel_signals = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
4751

@@ -52,4 +56,4 @@ def test__regularization_matrix():
5256
regularization_matrix = reg.regularization_matrix_from(linear_obj=mapper)
5357

5458
assert regularization_matrix[0, 0] == pytest.approx(121.0190770, 1.0e-4)
55-
assert regularization_matrix[0, 1] == pytest.approx(-66.9580331, 1.0e-4)
59+
assert regularization_matrix[0, 1] == pytest.approx(-66.9580331, 1.0e-4)

0 commit comments

Comments
 (0)