Skip to content

Commit e5e1a07

Browse files
authored
Merge pull request #186 from Jammy2211/feature/jax_wrapper_rectangular_adaptive
Feature/jax wrapper rectangular adaptive
2 parents 3fecb5e + 52ca2ac commit e5e1a07

30 files changed

+876
-235
lines changed

autoarray/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .inversion.pixelization.mappers.rectangular import MapperRectangular
4444
from .inversion.pixelization.mappers.delaunay import MapperDelaunay
4545
from .inversion.pixelization.mappers.voronoi import MapperVoronoi
46+
from .inversion.pixelization.mappers.rectangular_uniform import MapperRectangularUniform
4647
from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh
4748
from .inversion.pixelization.mesh.abstract import AbstractMesh
4849
from .inversion.inversion.imaging.mapping import InversionImagingMapping
@@ -75,6 +76,7 @@
7576
from .operators.over_sampling.over_sampler import OverSampler
7677
from .structures.grids.irregular_2d import Grid2DIrregular
7778
from .structures.mesh.rectangular_2d import Mesh2DRectangular
79+
from .structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
7880
from .structures.mesh.voronoi_2d import Mesh2DVoronoi
7981
from .structures.mesh.delaunay_2d import Mesh2DDelaunay
8082
from .structures.arrays.kernel_2d import Kernel2D

autoarray/dataset/imaging/dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ def w_tilde(self):
252252
indexes=indexes.astype("int"),
253253
lengths=lengths.astype("int"),
254254
noise_map_value=self.noise_map[0],
255+
noise_map=self.noise_map,
256+
psf=self.psf,
257+
mask=self.mask,
255258
)
256259

257260
@classmethod

autoarray/dataset/imaging/w_tilde.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22
import logging
33
import numpy as np
44

5+
from autoconf import cached_property
6+
57
from autoarray.dataset.abstract.w_tilde import AbstractWTilde
68

9+
from autoarray.inversion.inversion.imaging import inversion_imaging_util
10+
711
logger = logging.getLogger(__name__)
812

913

@@ -13,6 +17,9 @@ def __init__(
1317
curvature_preload: np.ndarray,
1418
indexes: np.ndim,
1519
lengths: np.ndarray,
20+
noise_map: np.ndarray,
21+
psf: np.ndarray,
22+
mask: np.ndarray,
1623
noise_map_value: float,
1724
):
1825
"""
@@ -44,3 +51,56 @@ def __init__(
4451

4552
self.indexes = indexes
4653
self.lengths = lengths
54+
self.noise_map = noise_map
55+
self.psf = psf
56+
self.mask = mask
57+
58+
@cached_property
59+
def w_matrix(self):
60+
"""
61+
The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF
62+
convolution of every pair of image pixels given the noise map. This can be used to efficiently compute the
63+
curvature matrix via the mappings between image and source pixels, in a way that omits having to perform the
64+
PSF convolution on every individual source pixel. This provides a significant speed up for inversions of imaging
65+
datasets.
66+
67+
The limitation of this matrix is that the dimensions of [image_pixels, image_pixels] can exceed many 10s of GB's,
68+
making it impossible to store in memory and its use in linear algebra calculations extremely. The method
69+
`w_tilde_curvature_preload_imaging_from` describes a compressed representation that overcomes this hurdles. It is
70+
advised `w_tilde` and this method are only used for testing.
71+
72+
Parameters
73+
----------
74+
noise_map_native
75+
The two dimensional masked noise-map of values which w_tilde is computed from.
76+
kernel_native
77+
The two dimensional PSF kernel that w_tilde encodes the convolution of.
78+
native_index_for_slim_index
79+
An array of shape [total_x_pixels*sub_size] that maps pixels from the slimmed array to the native array.
80+
81+
Returns
82+
-------
83+
ndarray
84+
A matrix that encodes the PSF convolution values between the noise map that enables efficient calculation of
85+
the curvature matrix.
86+
"""
87+
88+
return inversion_imaging_util.w_tilde_curvature_imaging_from(
89+
noise_map_native=np.array(self.noise_map.native.array).astype("float64"),
90+
kernel_native=np.array(self.psf.native.array).astype("float64"),
91+
native_index_for_slim_index=np.array(
92+
self.mask.derive_indexes.native_for_slim
93+
).astype("int"),
94+
)
95+
96+
@cached_property
97+
def psf_operator_matrix_dense(self):
98+
99+
return inversion_imaging_util.psf_operator_matrix_dense_from(
100+
kernel_native=np.array(self.psf.native.array).astype("float64"),
101+
native_index_for_slim_index=np.array(
102+
self.mask.derive_indexes.native_for_slim
103+
).astype("int"),
104+
native_shape=self.noise_map.shape_native,
105+
correlate=False,
106+
)

autoarray/fixtures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def make_rectangular_mapper_7x7_3x3():
421421
adapt_data=aa.Array2D.ones(shape_native=(3, 3), pixel_scales=0.1),
422422
)
423423

424-
return aa.MapperRectangular(
424+
return aa.MapperRectangularUniform(
425425
mapper_grids=mapper_grids,
426426
border_relocator=make_border_relocator_2d_7x7(),
427427
regularization=make_regularization_constant(),

autoarray/inversion/inversion/imaging/inversion_imaging_util.py

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,69 @@
1+
from scipy.signal import convolve2d
2+
import jax.numpy as jnp
13
import numpy as np
24
from typing import Tuple
35

46
from autoarray import numba_util
7+
from scipy.signal import correlate2d
8+
9+
import numpy as np
10+
11+
12+
def psf_operator_matrix_dense_from(
13+
kernel_native: np.ndarray,
14+
native_index_for_slim_index: np.ndarray, # shape (N_pix, 2), native (y,x) coords of masked pixels
15+
native_shape: tuple[int, int],
16+
correlate: bool = True,
17+
) -> np.ndarray:
18+
"""
19+
Construct a dense PSF operator W (N_pix x N_pix) that maps masked image pixels to masked image pixels.
20+
21+
Parameters
22+
----------
23+
kernel_native : (Ky, Kx) PSF kernel.
24+
native_index_for_slim_index : (N_pix, 2) array of int
25+
Native (y, x) coords for each masked pixel.
26+
native_shape : (Ny, Nx)
27+
Native 2D image shape.
28+
correlate : bool, default True
29+
If True, use correlation convention (no kernel flip).
30+
If False, use convolution convention (flip kernel).
31+
32+
Returns
33+
-------
34+
W : ndarray, shape (N_pix, N_pix)
35+
Dense PSF operator.
36+
"""
37+
Ky, Kx = kernel_native.shape
38+
ph, pw = Ky // 2, Kx // 2
39+
Ny, Nx = native_shape
40+
N_pix = native_index_for_slim_index.shape[0]
41+
42+
ker = kernel_native if correlate else kernel_native[::-1, ::-1]
43+
44+
# Padded index grid: -1 everywhere, slim index where masked
45+
index_padded = -np.ones((Ny + 2 * ph, Nx + 2 * pw), dtype=np.int64)
46+
for p, (y, x) in enumerate(native_index_for_slim_index):
47+
index_padded[y + ph, x + pw] = p
48+
49+
# Neighborhood offsets
50+
dy = np.arange(Ky) - ph
51+
dx = np.arange(Kx) - pw
52+
53+
W = np.zeros((N_pix, N_pix), dtype=float)
54+
55+
for i, (y, x) in enumerate(native_index_for_slim_index):
56+
yp = y + ph
57+
xp = x + pw
58+
for j, dy_ in enumerate(dy):
59+
for k, dx_ in enumerate(dx):
60+
neigh = index_padded[yp + dy_, xp + dx_]
61+
if neigh >= 0:
62+
W[i, neigh] += ker[j, k]
63+
64+
return W
565

666

7-
@numba_util.jit()
867
def w_tilde_data_imaging_from(
968
image_native: np.ndarray,
1069
noise_map_native: np.ndarray,
@@ -44,32 +103,35 @@ def w_tilde_data_imaging_from(
44103
efficient calculation of the data vector.
45104
"""
46105

47-
kernel_shift_y = -(kernel_native.shape[1] // 2)
48-
kernel_shift_x = -(kernel_native.shape[0] // 2)
49-
50-
image_pixels = len(native_index_for_slim_index)
51-
52-
w_tilde_data = np.zeros((image_pixels,))
106+
# 1) weight map = image / noise^2 (safe where noise==0)
107+
weight_map = jnp.where(
108+
noise_map_native > 0.0, image_native / (noise_map_native**2), 0.0
109+
)
53110

54-
weight_map_native = image_native / noise_map_native**2.0
111+
Ky, Kx = kernel_native.shape
112+
ph, pw = Ky // 2, Kx // 2
55113

56-
for ip0 in range(image_pixels):
57-
ip0_y, ip0_x = native_index_for_slim_index[ip0]
58-
59-
value = 0.0
114+
# 2) pad so neighbourhood gathers never go OOB
115+
padded = jnp.pad(
116+
weight_map, ((ph, ph), (pw, pw)), mode="constant", constant_values=0.0
117+
)
60118

61-
for k0_y in range(kernel_native.shape[0]):
62-
for k0_x in range(kernel_native.shape[1]):
63-
weight_value = weight_map_native[
64-
ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x
65-
]
119+
# 3) build broadcasted neighbourhood indices for all requested pixels
120+
# shift pixel coords into the padded frame
121+
ys = native_index_for_slim_index[:, 0] + ph # (N,)
122+
xs = native_index_for_slim_index[:, 1] + pw # (N,)
66123

67-
if not np.isnan(weight_value):
68-
value += kernel_native[k0_y, k0_x] * weight_value
124+
# kernel-relative offsets
125+
dy = jnp.arange(Ky) - ph # (Ky,)
126+
dx = jnp.arange(Kx) - pw # (Kx,)
69127

70-
w_tilde_data[ip0] = value
128+
# broadcast to (N, Ky, Kx)
129+
Y = ys[:, None, None] + dy[None, :, None]
130+
X = xs[:, None, None] + dx[None, None, :]
71131

72-
return w_tilde_data
132+
# 4) gather patches and correlate (no kernel flip)
133+
patches = padded[Y, X] # (N, Ky, Kx)
134+
return jnp.sum(patches * kernel_native[None, :, :], axis=(1, 2)) # (N,)
73135

74136

75137
@numba_util.jit()

0 commit comments

Comments
 (0)