Skip to content

Commit 9b97bcb

Browse files
authored
Merge pull request #185 from Jammy2211/feature/jax_rectangular_slam
Feature/jax rectangular slam
2 parents 5d94e9a + f0a4a8d commit 9b97bcb

File tree

19 files changed

+516
-528
lines changed

19 files changed

+516
-528
lines changed

autoarray/inversion/inversion/abstract.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import jax.numpy as jnp
3+
from jax.scipy.linalg import block_diag
34
import numpy as np
45

56
from typing import Dict, List, Optional, Type, Union
@@ -334,8 +335,6 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
334335
If the `settings.force_edge_pixels_to_zeros` is `True`, the edge pixels of each mapper in the inversion
335336
are regularized so high their value is forced to zero.
336337
"""
337-
from scipy.linalg import block_diag
338-
339338
return block_diag(
340339
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
341340
)
@@ -664,30 +663,17 @@ def log_det_regularization_matrix_term(self) -> float:
664663
float
665664
The log determinant of the regularization matrix.
666665
"""
667-
from scipy.sparse import csc_matrix
668-
from scipy.sparse.linalg import splu
669-
670666
if not self.has(cls=AbstractRegularization):
671667
return 0.0
672668

673669
try:
674-
lu = splu(csc_matrix(self.regularization_matrix_reduced))
675-
diagL = lu.L.diagonal()
676-
diagU = lu.U.diagonal()
677-
diagL = diagL.astype(np.complex128)
678-
diagU = diagU.astype(np.complex128)
679-
680-
return np.real(np.log(diagL).sum() + np.log(diagU).sum())
681-
682-
except RuntimeError:
683-
try:
684-
return 2.0 * np.sum(
685-
np.log(
686-
np.diag(np.linalg.cholesky(self.regularization_matrix_reduced))
687-
)
670+
return 2.0 * np.sum(
671+
jnp.log(
672+
jnp.diag(jnp.linalg.cholesky(self.regularization_matrix_reduced))
688673
)
689-
except np.linalg.LinAlgError as e:
690-
raise exc.InversionException() from e
674+
)
675+
except np.linalg.LinAlgError as e:
676+
raise exc.InversionException() from e
691677

692678
@property
693679
def reconstruction_noise_map_with_covariance(self) -> np.ndarray:

autoarray/inversion/pixelization/mappers/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def pixel_signals_from(self, signal_scale: float) -> np.ndarray:
288288
pix_indexes_for_sub_slim_index=self.pix_indexes_for_sub_slim_index,
289289
pix_size_for_sub_slim_index=self.pix_sizes_for_sub_slim_index,
290290
slim_index_for_sub_slim_index=self.over_sampler.slim_for_sub_slim,
291-
adapt_data=np.array(self.adapt_data),
291+
adapt_data=self.adapt_data.array,
292292
)
293293

294294
def slim_indexes_for_pix_indexes(self, pix_indexes: List) -> List[List]:

autoarray/inversion/pixelization/mesh/mesh_util.py

Lines changed: 81 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from autoarray import numba_util
66

77

8-
@numba_util.jit()
98
def rectangular_neighbors_from(
109
shape_native: Tuple[int, int],
1110
) -> Tuple[np.ndarray, np.ndarray]:
@@ -68,7 +67,6 @@ def rectangular_neighbors_from(
6867
return neighbors, neighbors_sizes
6968

7069

71-
@numba_util.jit()
7270
def rectangular_corner_neighbors(
7371
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
7472
) -> Tuple[np.ndarray, np.ndarray]:
@@ -113,7 +111,6 @@ def rectangular_corner_neighbors(
113111
return neighbors, neighbors_sizes
114112

115113

116-
@numba_util.jit()
117114
def rectangular_top_edge_neighbors(
118115
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
119116
) -> Tuple[np.ndarray, np.ndarray]:
@@ -136,17 +133,20 @@ def rectangular_top_edge_neighbors(
136133
-------
137134
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
138135
"""
139-
for pix in range(1, shape_native[1] - 1):
140-
pixel_index = pix
141-
neighbors[pixel_index, 0:3] = np.array(
142-
[pixel_index - 1, pixel_index + 1, pixel_index + shape_native[1]]
143-
)
144-
neighbors_sizes[pixel_index] = 3
136+
"""
137+
Vectorized version of the top edge neighbor update using NumPy arithmetic.
138+
"""
139+
# Pixels along the top edge, excluding corners
140+
top_edge_pixels = np.arange(1, shape_native[1] - 1)
141+
142+
neighbors[top_edge_pixels, 0] = top_edge_pixels - 1
143+
neighbors[top_edge_pixels, 1] = top_edge_pixels + 1
144+
neighbors[top_edge_pixels, 2] = top_edge_pixels + shape_native[1]
145+
neighbors_sizes[top_edge_pixels] = 3
145146

146147
return neighbors, neighbors_sizes
147148

148149

149-
@numba_util.jit()
150150
def rectangular_left_edge_neighbors(
151151
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
152152
) -> Tuple[np.ndarray, np.ndarray]:
@@ -169,21 +169,20 @@ def rectangular_left_edge_neighbors(
169169
-------
170170
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
171171
"""
172-
for pix in range(1, shape_native[0] - 1):
173-
pixel_index = pix * shape_native[1]
174-
neighbors[pixel_index, 0:3] = np.array(
175-
[
176-
pixel_index - shape_native[1],
177-
pixel_index + 1,
178-
pixel_index + shape_native[1],
179-
]
180-
)
181-
neighbors_sizes[pixel_index] = 3
172+
# Row indices (excluding top and bottom corners)
173+
rows = np.arange(1, shape_native[0] - 1)
174+
175+
# Convert to flat pixel indices for the left edge (first column)
176+
pixel_indices = rows * shape_native[1]
177+
178+
neighbors[pixel_indices, 0] = pixel_indices - shape_native[1]
179+
neighbors[pixel_indices, 1] = pixel_indices + 1
180+
neighbors[pixel_indices, 2] = pixel_indices + shape_native[1]
181+
neighbors_sizes[pixel_indices] = 3
182182

183183
return neighbors, neighbors_sizes
184184

185185

186-
@numba_util.jit()
187186
def rectangular_right_edge_neighbors(
188187
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
189188
) -> Tuple[np.ndarray, np.ndarray]:
@@ -206,21 +205,20 @@ def rectangular_right_edge_neighbors(
206205
-------
207206
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
208207
"""
209-
for pix in range(1, shape_native[0] - 1):
210-
pixel_index = pix * shape_native[1] + shape_native[1] - 1
211-
neighbors[pixel_index, 0:3] = np.array(
212-
[
213-
pixel_index - shape_native[1],
214-
pixel_index - 1,
215-
pixel_index + shape_native[1],
216-
]
217-
)
218-
neighbors_sizes[pixel_index] = 3
208+
# Rows excluding the top and bottom corners
209+
rows = np.arange(1, shape_native[0] - 1)
210+
211+
# Flat indices for the right edge pixels
212+
pixel_indices = rows * shape_native[1] + shape_native[1] - 1
213+
214+
neighbors[pixel_indices, 0] = pixel_indices - shape_native[1]
215+
neighbors[pixel_indices, 1] = pixel_indices - 1
216+
neighbors[pixel_indices, 2] = pixel_indices + shape_native[1]
217+
neighbors_sizes[pixel_indices] = 3
219218

220219
return neighbors, neighbors_sizes
221220

222221

223-
@numba_util.jit()
224222
def rectangular_bottom_edge_neighbors(
225223
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
226224
) -> Tuple[np.ndarray, np.ndarray]:
@@ -243,19 +241,21 @@ def rectangular_bottom_edge_neighbors(
243241
-------
244242
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
245243
"""
246-
pixels = int(shape_native[0] * shape_native[1])
244+
n_rows, n_cols = shape_native
245+
pixels = n_rows * n_cols
247246

248-
for pix in range(1, shape_native[1] - 1):
249-
pixel_index = pixels - pix - 1
250-
neighbors[pixel_index, 0:3] = np.array(
251-
[pixel_index - shape_native[1], pixel_index - 1, pixel_index + 1]
252-
)
253-
neighbors_sizes[pixel_index] = 3
247+
# Horizontal pixel positions along bottom row, excluding corners
248+
cols = np.arange(1, n_cols - 1)
249+
pixel_indices = pixels - cols - 1 # Reverse order from right to left
250+
251+
neighbors[pixel_indices, 0] = pixel_indices - n_cols
252+
neighbors[pixel_indices, 1] = pixel_indices - 1
253+
neighbors[pixel_indices, 2] = pixel_indices + 1
254+
neighbors_sizes[pixel_indices] = 3
254255

255256
return neighbors, neighbors_sizes
256257

257258

258-
@numba_util.jit()
259259
def rectangular_central_neighbors(
260260
neighbors: np.ndarray, neighbors_sizes: np.ndarray, shape_native: Tuple[int, int]
261261
) -> Tuple[np.ndarray, np.ndarray]:
@@ -279,46 +279,61 @@ def rectangular_central_neighbors(
279279
-------
280280
The arrays containing the 1D index of every pixel's neighbors and the number of neighbors that each pixel has.
281281
"""
282-
for x in range(1, shape_native[0] - 1):
283-
for y in range(1, shape_native[1] - 1):
284-
pixel_index = x * shape_native[1] + y
285-
neighbors[pixel_index, 0:4] = np.array(
286-
[
287-
pixel_index - shape_native[1],
288-
pixel_index - 1,
289-
pixel_index + 1,
290-
pixel_index + shape_native[1],
291-
]
292-
)
293-
neighbors_sizes[pixel_index] = 4
282+
n_rows, n_cols = shape_native
283+
284+
# Grid coordinates excluding edges
285+
xs = np.arange(1, n_rows - 1)
286+
ys = np.arange(1, n_cols - 1)
287+
288+
# 2D grid of central pixel indices
289+
grid_x, grid_y = np.meshgrid(xs, ys, indexing="ij")
290+
pixel_indices = grid_x * n_cols + grid_y
291+
pixel_indices = pixel_indices.ravel()
292+
293+
# Compute neighbor indices
294+
neighbors[pixel_indices, 0] = pixel_indices - n_cols # Up
295+
neighbors[pixel_indices, 1] = pixel_indices - 1 # Left
296+
neighbors[pixel_indices, 2] = pixel_indices + 1 # Right
297+
neighbors[pixel_indices, 3] = pixel_indices + n_cols # Down
298+
299+
neighbors_sizes[pixel_indices] = 4
294300

295301
return neighbors, neighbors_sizes
296302

297303

298-
def rectangular_edge_pixel_list_from(neighbors: np.ndarray) -> List:
304+
def rectangular_edge_pixel_list_from(shape_native: Tuple[int, int]) -> List[int]:
299305
"""
300-
Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization.
301-
302-
This is computed by searching the `neighbors` array for pixels that have a neighbor with index -1, meaning there
303-
is at least one neighbor from the 4 expected missing.
306+
Returns a list of the 1D indices of all pixels on the edge of a rectangular pixelization,
307+
based on its 2D shape.
304308
305309
Parameters
306310
----------
307-
neighbors
308-
An array of dimensions [total_pixels, 4] which provides the index of all neighbors of every pixel in the
309-
rectangular pixelization (entries of -1 correspond to no neighbor).
311+
shape_native
312+
The (rows, cols) shape of the rectangular 2D pixel grid.
310313
311314
Returns
312315
-------
313-
A list of the 1D indices of all pixels on the edge of a rectangular pixelization.
316+
A list of the 1D indices of all edge pixels.
314317
"""
315-
edge_pixel_list = []
318+
rows, cols = shape_native
319+
320+
# Top row
321+
top = np.arange(0, cols)
322+
323+
# Bottom row
324+
bottom = np.arange((rows - 1) * cols, rows * cols)
325+
326+
# Left column (excluding corners)
327+
left = np.arange(1, rows - 1) * cols
328+
329+
# Right column (excluding corners)
330+
right = (np.arange(1, rows - 1) + 1) * cols - 1
316331

317-
for i, neighbors in enumerate(neighbors):
318-
if -1 in neighbors:
319-
edge_pixel_list.append(i)
332+
# Concatenate all edge indices
333+
edge_pixel_indices = np.concatenate([top, left, right, bottom])
320334

321-
return edge_pixel_list
335+
# Sort and return
336+
return np.sort(edge_pixel_indices).tolist()
322337

323338

324339
@numba_util.jit()

0 commit comments

Comments
 (0)