Skip to content

Commit 92b6bb9

Browse files
Jammy2211Jammy2211
authored andcommitted
seems to work fast without crazy memory use, time to integration test
1 parent 5d07361 commit 92b6bb9

File tree

2 files changed

+186
-209
lines changed

2 files changed

+186
-209
lines changed

autoarray/inversion/pixelization/mappers/delaunay.py

Lines changed: 6 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -6,87 +6,7 @@
66
from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights
77

88

9-
def pix_indexes_for_sub_slim_index_delaunay_from(
10-
source_plane_data_grid, # (N_sub, 2)
11-
simplex_index_for_sub_slim_index, # (N_sub,)
12-
pix_indexes_for_simplex_index, # (M, 3)
13-
delaunay_points, # (N_points, 2)
14-
xp=np, # <--- choose backend: np or jnp
15-
):
16-
"""
17-
XP-compatible version of pix_indexes_for_sub_slim_index_delaunay_from.
18-
19-
If xp=np: runs with NumPy (no JAX needed)
20-
If xp=jnp: runs inside JAX (jit, vmap, GPU)
21-
22-
Returns
23-
-------
24-
pix_indexes_for_sub_slim_index : (N_sub, 3)
25-
pix_indexes_for_sub_slim_index_sizes : (N_sub,)
26-
"""
27-
28-
# Helper: xp.ones_like that supports setting dtype
29-
def ones_like(x, dtype):
30-
return xp.ones(x.shape, dtype=dtype)
31-
32-
N_sub = source_plane_data_grid.shape[0]
33-
34-
# Boolean mask for points that fall inside simplices
35-
inside_mask = simplex_index_for_sub_slim_index >= 0 # shape (N_sub,)
36-
outside_mask = ~inside_mask
37-
38-
# ----------------------------
39-
# Case 1: inside a simplex
40-
# ----------------------------
41-
# (N_sub, 3)
42-
pix_inside = xp.where(
43-
inside_mask[:, None],
44-
pix_indexes_for_simplex_index[simplex_index_for_sub_slim_index],
45-
-ones_like((inside_mask[:, None] + 0), dtype=np.int32), # -1 filler
46-
)
47-
48-
# ----------------------------
49-
# Case 2: outside any simplex → nearest delaunay point
50-
# ----------------------------
51-
# Squared distances: (N_sub, N_points)
52-
d2 = xp.sum(
53-
(source_plane_data_grid[:, None, :] - delaunay_points[None, :, :]) ** 2.0,
54-
axis=-1,
55-
)
56-
57-
nearest = xp.argmin(d2, axis=1).astype(np.int32) # (N_sub,)
58-
59-
# (N_sub, 3) → [nearest, -1, -1]
60-
nn_triplets = xp.stack(
61-
[
62-
nearest,
63-
-ones_like(nearest, dtype=np.int32),
64-
-ones_like(nearest, dtype=np.int32),
65-
],
66-
axis=1,
67-
)
689

69-
pix_outside = xp.where(
70-
outside_mask[:, None],
71-
nn_triplets,
72-
-ones_like((outside_mask[:, None] + 0), dtype=np.int32),
73-
)
74-
75-
# ----------------------------
76-
# Combine inside + outside
77-
# ----------------------------
78-
pix_indexes_for_sub_slim_index = xp.where(
79-
inside_mask[:, None],
80-
pix_inside,
81-
pix_outside,
82-
)
83-
84-
# ----------------------------
85-
# Count valid entries
86-
# ----------------------------
87-
pix_sizes = xp.sum(pix_indexes_for_sub_slim_index >= 0, axis=1)
88-
89-
return pix_indexes_for_sub_slim_index, pix_sizes
9010

9111

9212
def triangle_area_xp(c0, c1, c2, xp):
@@ -269,22 +189,8 @@ def pix_sub_weights(self) -> PixSubWeights:
269189
"""
270190
delaunay = self.delaunay
271191

272-
simplex_index_for_sub_slim_index = delaunay.simplex_index_for_sub_slim_index
273-
pix_indexes_for_simplex_index = delaunay.simplices
274-
275-
mappings, sizes = pix_indexes_for_sub_slim_index_delaunay_from(
276-
source_plane_data_grid=self.source_plane_data_grid.over_sampled,
277-
simplex_index_for_sub_slim_index=simplex_index_for_sub_slim_index,
278-
pix_indexes_for_simplex_index=pix_indexes_for_simplex_index,
279-
delaunay_points=delaunay.points,
280-
xp=self._xp,
281-
)
282-
283-
mappings = mappings.astype("int")
284-
sizes = sizes.astype("int")
285-
286-
print(mappings)
287-
print(sizes)
192+
mappings = delaunay.mappings.astype("int")
193+
sizes = delaunay.sizes.astype("int")
288194

289195
weights = pixel_weights_delaunay_from(
290196
source_plane_data_grid=self.source_plane_data_grid.over_sampled,
@@ -310,24 +216,10 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights:
310216
"""
311217
delaunay = self.delaunay
312218

313-
splitted_simplex_index_for_sub_slim_index = delaunay.splitted_simplex_index_for_sub_slim_index
314-
pix_indexes_for_simplex_index = delaunay.simplices
315-
316-
(
317-
splitted_mappings,
318-
splitted_sizes,
319-
) = pix_indexes_for_sub_slim_index_delaunay_from(
320-
source_plane_data_grid=self.source_plane_mesh_grid.split_cross,
321-
simplex_index_for_sub_slim_index=splitted_simplex_index_for_sub_slim_index,
322-
pix_indexes_for_simplex_index=pix_indexes_for_simplex_index,
323-
delaunay_points=delaunay.points,
324-
xp=self._xp,
325-
)
326-
327219
splitted_weights = pixel_weights_delaunay_from(
328-
source_plane_data_grid=self.source_plane_mesh_grid.split_cross,
220+
source_plane_data_grid=delaunay.split_cross,
329221
source_plane_mesh_grid=self.source_plane_mesh_grid.array,
330-
pix_indexes_for_sub_slim_index=splitted_mappings.astype("int"),
222+
pix_indexes_for_sub_slim_index=delaunay.splitted_mappings.astype("int"),
331223
xp=self._xp,
332224
)
333225

@@ -336,8 +228,8 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights:
336228

337229
return PixSubWeights(
338230
mappings=self._xp.hstack(
339-
(splitted_mappings.astype(self._xp.int32), append_line_int)
231+
(delaunay.splitted_mappings.astype(self._xp.int32), append_line_int)
340232
),
341-
sizes=splitted_sizes.astype(self._xp.int32),
233+
sizes=delaunay.splitted_sizes.astype(self._xp.int32),
342234
weights=self._xp.hstack((splitted_weights, append_line_float)),
343235
)

0 commit comments

Comments
 (0)