Skip to content

Commit e1ee5d3

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent ee0ccf3 commit e1ee5d3

File tree

10 files changed

+179
-184
lines changed

10 files changed

+179
-184
lines changed

autoarray/inversion/pixelization/mappers/delaunay.py

Lines changed: 26 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010

1111
def pix_indexes_for_sub_slim_index_delaunay_from(
12-
source_plane_data_grid, # (N_sub, 2)
12+
source_plane_data_grid, # (N_sub, 2)
1313
simplex_index_for_sub_slim_index, # (N_sub,)
14-
pix_indexes_for_simplex_index, # (M, 3)
15-
delaunay_points, # (N_points, 2)
16-
xp=np, # <--- choose backend: np or jnp
14+
pix_indexes_for_simplex_index, # (M, 3)
15+
delaunay_points, # (N_points, 2)
16+
xp=np, # <--- choose backend: np or jnp
1717
):
1818
"""
1919
XP-compatible version of pix_indexes_for_sub_slim_index_delaunay_from.
@@ -34,7 +34,7 @@ def ones_like(x, dtype):
3434
N_sub = source_plane_data_grid.shape[0]
3535

3636
# Boolean mask for points that fall inside simplices
37-
inside_mask = simplex_index_for_sub_slim_index >= 0 # shape (N_sub,)
37+
inside_mask = simplex_index_for_sub_slim_index >= 0 # shape (N_sub,)
3838
outside_mask = ~inside_mask
3939

4040
# ----------------------------
@@ -56,7 +56,7 @@ def ones_like(x, dtype):
5656
axis=-1,
5757
)
5858

59-
nearest = xp.argmin(d2, axis=1).astype(np.int32) # (N_sub,)
59+
nearest = xp.argmin(d2, axis=1).astype(np.int32) # (N_sub,)
6060

6161
# (N_sub, 3) → [nearest, -1, -1]
6262
nn_triplets = xp.stack(
@@ -93,23 +93,24 @@ def ones_like(x, dtype):
9393

9494
import numpy as np
9595

96+
9697
def triangle_area_xp(c0, c1, c2, xp):
9798
"""
9899
Twice triangle area using vector cross product magnitude.
99100
Calling via xp ensures NumPy or JAX backend operation.
100101
"""
101-
v0 = c1 - c0 # (..., 2)
102+
v0 = c1 - c0 # (..., 2)
102103
v1 = c2 - c0
103104
cross = v0[..., 0] * v1[..., 1] - v0[..., 1] * v1[..., 0]
104105
return xp.abs(cross)
105106

106107

107108
def pixel_weights_delaunay_from(
108-
source_plane_data_grid, # (N_sub, 2)
109-
source_plane_mesh_grid, # (N_pix, 2)
109+
source_plane_data_grid, # (N_sub, 2)
110+
source_plane_mesh_grid, # (N_pix, 2)
110111
slim_index_for_sub_slim_index, # (N_sub,) UNUSED? kept for signature compatibility
111-
pix_indexes_for_sub_slim_index, # (N_sub, 3), padded with -1
112-
xp=np, # backend: np (default) or jnp
112+
pix_indexes_for_sub_slim_index, # (N_sub, 3), padded with -1
113+
xp=np, # backend: np (default) or jnp
113114
):
114115
"""
115116
XP-compatible (NumPy/JAX) version of pixel_weights_delaunay_from.
@@ -124,7 +125,7 @@ def pixel_weights_delaunay_from(
124125
# -----------------------------
125126
# If pix_indexes_for_sub_slim_index[sub][1] == -1 → NOT in simplex
126127
has_simplex = pix_indexes_for_sub_slim_index[:, 1] != -1 # (N_sub,)
127-
no_simplex = xp.logical_not(has_simplex)
128+
no_simplex = xp.logical_not(has_simplex)
128129

129130
# -----------------------------
130131
# GATHER TRIANGLE VERTICES
@@ -135,12 +136,12 @@ def pixel_weights_delaunay_from(
135136
# (N_sub, 3, 2)
136137
vertices = source_plane_mesh_grid[safe_indices]
137138

138-
p0 = vertices[:, 0] # (N_sub, 2)
139+
p0 = vertices[:, 0] # (N_sub, 2)
139140
p1 = vertices[:, 1]
140141
p2 = vertices[:, 2]
141142

142143
# Query points
143-
q = source_plane_data_grid # (N_sub, 2)
144+
q = source_plane_data_grid # (N_sub, 2)
144145

145146
# -----------------------------
146147
# TRIANGLE AREAS (barycentric numerators)
@@ -164,23 +165,17 @@ def pixel_weights_delaunay_from(
164165
xp.zeros(N_sub),
165166
xp.zeros(N_sub),
166167
],
167-
axis=1
168+
axis=1,
168169
)
169170

170171
# -----------------------------
171172
# SELECT BETWEEN CASES
172173
# -----------------------------
173-
pixel_weights = xp.where(
174-
has_simplex[:, None],
175-
weights_bary,
176-
weights_nn
177-
)
174+
pixel_weights = xp.where(has_simplex[:, None], weights_bary, weights_nn)
178175

179176
return pixel_weights
180177

181178

182-
183-
184179
class MapperDelaunay(AbstractMapper):
185180
"""
186181
To understand a `Mapper` one must be familiar `Mesh` objects and the `mesh` and `pixelization` packages, where
@@ -285,14 +280,12 @@ def pix_sub_weights(self) -> PixSubWeights:
285280
)
286281
pix_indexes_for_simplex_index = delaunay.simplices
287282

288-
mappings, sizes = (
289-
pix_indexes_for_sub_slim_index_delaunay_from(
290-
source_plane_data_grid=self.source_plane_data_grid.over_sampled,
291-
simplex_index_for_sub_slim_index=simplex_index_for_sub_slim_index,
292-
pix_indexes_for_simplex_index=pix_indexes_for_simplex_index,
293-
delaunay_points=delaunay.points,
294-
xp=self._xp,
295-
)
283+
mappings, sizes = pix_indexes_for_sub_slim_index_delaunay_from(
284+
source_plane_data_grid=self.source_plane_data_grid.over_sampled,
285+
simplex_index_for_sub_slim_index=simplex_index_for_sub_slim_index,
286+
pix_indexes_for_simplex_index=pix_indexes_for_simplex_index,
287+
delaunay_points=delaunay.points,
288+
xp=self._xp,
296289
)
297290

298291
mappings = mappings.astype("int")
@@ -351,7 +344,9 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights:
351344
append_line_float = np.zeros((len(splitted_weights), 1), dtype="float")
352345

353346
return PixSubWeights(
354-
mappings=self._xp.hstack((splitted_mappings.astype(self._xp.int32), append_line_int)),
347+
mappings=self._xp.hstack(
348+
(splitted_mappings.astype(self._xp.int32), append_line_int)
349+
),
355350
sizes=splitted_sizes.astype(self._xp.int32),
356351
weights=self._xp.hstack((splitted_weights, append_line_float)),
357352
)

autoarray/inversion/pixelization/mesh/delaunay.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,4 @@ def mesh_grid_from(
5454
Settings controlling the pixelization for example if a border is used to relocate its exterior coordinates.
5555
"""
5656

57-
return Mesh2DDelaunay(
58-
values=source_plane_mesh_grid, _xp=xp
59-
)
57+
return Mesh2DDelaunay(values=source_plane_mesh_grid, _xp=xp)

autoarray/inversion/pixelization/mesh/triangulation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,21 @@ def mapper_grids_from(
6464
relocated_grid = self.relocated_grid_from(
6565
border_relocator=border_relocator,
6666
source_plane_data_grid=source_plane_data_grid,
67-
xp=xp
67+
xp=xp,
6868
)
6969

7070
relocated_mesh_grid = self.relocated_mesh_grid_from(
7171
border_relocator=border_relocator,
7272
source_plane_data_grid=relocated_grid.over_sampled,
7373
source_plane_mesh_grid=source_plane_mesh_grid,
74-
xp=xp
74+
xp=xp,
7575
)
7676

7777
try:
7878
source_plane_mesh_grid = self.mesh_grid_from(
7979
source_plane_data_grid=relocated_grid.over_sampled,
8080
source_plane_mesh_grid=relocated_mesh_grid,
81-
xp=xp
81+
xp=xp,
8282
)
8383
except ValueError as e:
8484
raise e

autoarray/inversion/regularization/adaptive_brightness_split.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,13 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
104104
splitted_mappings=pix_sub_weights_split_cross.mappings,
105105
splitted_sizes=pix_sub_weights_split_cross.sizes,
106106
splitted_weights=pix_sub_weights_split_cross.weights,
107-
xp=xp
107+
xp=xp,
108108
)
109109

110110
return regularization_util.pixel_splitted_regularization_matrix_from(
111111
regularization_weights=regularization_weights,
112112
splitted_mappings=splitted_mappings,
113113
splitted_sizes=splitted_sizes,
114114
splitted_weights=splitted_weights,
115-
xp=xp
115+
xp=xp,
116116
)

autoarray/inversion/regularization/adaptive_brightness_split_zeroth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
106106
splitted_mappings=pix_sub_weights_split_cross.mappings,
107107
splitted_sizes=pix_sub_weights_split_cross.sizes,
108108
splitted_weights=pix_sub_weights_split_cross.weights,
109-
xp=xp
109+
xp=xp,
110110
)
111111

112112
regularization_matrix = (
@@ -115,7 +115,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
115115
splitted_mappings=splitted_mappings,
116116
splitted_sizes=splitted_sizes,
117117
splitted_weights=splitted_weights,
118-
xp=xp
118+
xp=xp,
119119
)
120120
)
121121

autoarray/inversion/regularization/constant_split.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
6666
splitted_mappings=pix_sub_weights_split_cross.mappings,
6767
splitted_sizes=pix_sub_weights_split_cross.sizes,
6868
splitted_weights=pix_sub_weights_split_cross.weights,
69-
xp=xp
69+
xp=xp,
7070
)
7171

7272
pixels = int(len(splitted_mappings) / 4)
@@ -77,5 +77,5 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
7777
splitted_mappings=splitted_mappings,
7878
splitted_sizes=splitted_sizes,
7979
splitted_weights=splitted_weights,
80-
xp=xp
80+
xp=xp,
8181
)

autoarray/inversion/regularization/regularization_util.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def reg_split_np_from(
8282

8383
return splitted_mappings, splitted_sizes, splitted_weights
8484

85+
8586
def reg_split_from(
8687
splitted_mappings: np.ndarray,
8788
splitted_sizes: np.ndarray,
@@ -143,27 +144,27 @@ def reg_split_from(
143144
# 2. Pixel index for each row: i // 4
144145
# -------------------------------------------------------------
145146
pixel_index = (jnp.arange(N) // 4).astype(mappings.dtype) # (N,)
146-
pix_b = pixel_index[:, None] # (N,1)
147+
pix_b = pixel_index[:, None] # (N,1)
147148

148149
# -------------------------------------------------------------
149150
# 3. Mask of valid columns j < size[i]
150151
# -------------------------------------------------------------
151-
cols = jnp.arange(K)[None, :] # (1,4)
152-
valid_mask = cols < sizes[:, None] # (N,4)
152+
cols = jnp.arange(K)[None, :] # (1,4)
153+
valid_mask = cols < sizes[:, None] # (N,4)
153154

154155
# -------------------------------------------------------------
155156
# 4. Self match: mapping[i,j] == pixel_index AND j is valid
156157
# -------------------------------------------------------------
157-
self_mask = (mappings == pix_b) & valid_mask # (N,4)
158-
row_has_self = jnp.any(self_mask, axis=1) # (N,)
158+
self_mask = (mappings == pix_b) & valid_mask # (N,4)
159+
row_has_self = jnp.any(self_mask, axis=1) # (N,)
159160

160161
# Position of self per row
161-
self_pos = jnp.argmax(self_mask, axis=1) # (N,)
162+
self_pos = jnp.argmax(self_mask, axis=1) # (N,)
162163

163164
# -------------------------------------------------------------
164165
# 5. Add +1 weight at self_pos where row_has_self == True
165166
# -------------------------------------------------------------
166-
one_hot = jnn.one_hot(self_pos, K, dtype=weights.dtype) # (N,4)
167+
one_hot = jnn.one_hot(self_pos, K, dtype=weights.dtype) # (N,4)
167168
weights = weights + one_hot * row_has_self[:, None]
168169

169170
# -------------------------------------------------------------
@@ -172,21 +173,19 @@ def reg_split_from(
172173
no_self = ~row_has_self
173174

174175
# Insert position = sizes[i]
175-
insert_pos = sizes # (N,)
176+
insert_pos = sizes # (N,)
176177
insert_mask = no_self[:, None] & (cols == sizes[:, None])
177178

178179
# New mappings and weights
179180
mappings = jnp.where(insert_mask, pix_b, mappings)
180-
weights = jnp.where(insert_mask, jnp.array(1.0, weights.dtype), weights)
181+
weights = jnp.where(insert_mask, jnp.array(1.0, weights.dtype), weights)
181182

182183
# Updated sizes: +1 if no self detected
183184
sizes_new = sizes + no_self.astype(sizes.dtype)
184185

185186
return mappings, sizes_new, weights
186187

187188

188-
189-
190189
def pixel_splitted_regularization_matrix_np_from(
191190
regularization_weights: np.ndarray,
192191
splitted_mappings: np.ndarray,
@@ -228,12 +227,11 @@ def pixel_splitted_regularization_matrix_np_from(
228227
return regularization_matrix
229228

230229

231-
232230
def pixel_splitted_regularization_matrix_from(
233-
regularization_weights: np.ndarray, # (P,)
234-
splitted_mappings: np.ndarray, # (4P, 4)
235-
splitted_sizes: np.ndarray, # (4P,)
236-
splitted_weights: np.ndarray, # (4P, 4)
231+
regularization_weights: np.ndarray, # (P,)
232+
splitted_mappings: np.ndarray, # (4P, 4)
233+
splitted_sizes: np.ndarray, # (4P,)
234+
splitted_weights: np.ndarray, # (4P, 4)
237235
xp=np,
238236
):
239237
"""
@@ -279,33 +277,33 @@ def pixel_splitted_regularization_matrix_from(
279277
P = splitted_mappings.shape[0] // 4
280278

281279
# Square, positive regularization weights
282-
reg_w = regularization_weights**2.0 # (P,)
280+
reg_w = regularization_weights**2.0 # (P,)
283281

284282
# Add diagonal jitter (2e-8)
285-
reg_mat = jnp.eye(P) * 2e-8 # (P, P)
283+
reg_mat = jnp.eye(P) * 2e-8 # (P, P)
286284

287285
# ----- Build all 4P contributions at once -----
288286

289287
# Mask away padded entries (where mapping = -1)
290-
valid = splitted_mappings != -1 # (4P, 4)
288+
valid = splitted_mappings != -1 # (4P, 4)
291289

292290
# Extract valid mapping rows and weights
293291
# BUT keep fixed shape (4) and just zero out invalid ones
294292
map_fixed = jnp.where(valid, splitted_mappings, 0) # (4P, 4)
295-
w_fixed = jnp.where(valid, splitted_weights, 0.) # (4P, 4)
293+
w_fixed = jnp.where(valid, splitted_weights, 0.0) # (4P, 4)
296294

297295
# Compute all outer products of weights
298296
# w_fixed[:, :, None] * w_fixed[:, None, :] → (4P, 4, 4)
299-
outer = w_fixed[:, :, None] * w_fixed[:, None, :] # (4P, 4, 4)
297+
outer = w_fixed[:, :, None] * w_fixed[:, None, :] # (4P, 4, 4)
300298

301299
# Build corresponding row and col index grids
302-
rows = map_fixed[:, :, None] # (4P, 4, 1)
303-
cols = map_fixed[:, None, :] # (4P, 1, 4)
300+
rows = map_fixed[:, :, None] # (4P, 4, 1)
301+
cols = map_fixed[:, None, :] # (4P, 1, 4)
304302

305303
# Multiply each 4x4 block by its pixel’s regularization weight
306304
# Rows 0–3 belong to pixel 0, rows 4–7 to pixel 1, etc.
307-
pixel_index = jnp.arange(4 * P) // 4 # (4P,)
308-
block_scale = reg_w[pixel_index] # (4P,)
305+
pixel_index = jnp.arange(4 * P) // 4 # (4P,)
306+
block_scale = reg_w[pixel_index] # (4P,)
309307
outer_scaled = outer * block_scale[:, None, None]
310308

311309
# Now scatter-add all entries into the (P,P) matrix

0 commit comments

Comments
 (0)