66from autoarray .inversion .pixelization .mappers .abstract import PixSubWeights
77
88
9- def pix_indexes_for_sub_slim_index_delaunay_from (
10- source_plane_data_grid , # (N_sub, 2)
11- simplex_index_for_sub_slim_index , # (N_sub,)
12- pix_indexes_for_simplex_index , # (M, 3)
13- delaunay_points , # (N_points, 2)
14- xp = np , # <--- choose backend: np or jnp
15- ):
16- """
17- XP-compatible version of pix_indexes_for_sub_slim_index_delaunay_from.
18-
19- If xp=np: runs with NumPy (no JAX needed)
20- If xp=jnp: runs inside JAX (jit, vmap, GPU)
21-
22- Returns
23- -------
24- pix_indexes_for_sub_slim_index : (N_sub, 3)
25- pix_indexes_for_sub_slim_index_sizes : (N_sub,)
26- """
27-
28- # Helper: xp.ones_like that supports setting dtype
29- def ones_like (x , dtype ):
30- return xp .ones (x .shape , dtype = dtype )
31-
32- N_sub = source_plane_data_grid .shape [0 ]
33-
34- # Boolean mask for points that fall inside simplices
35- inside_mask = simplex_index_for_sub_slim_index >= 0 # shape (N_sub,)
36- outside_mask = ~ inside_mask
37-
38- # ----------------------------
39- # Case 1: inside a simplex
40- # ----------------------------
41- # (N_sub, 3)
42- pix_inside = xp .where (
43- inside_mask [:, None ],
44- pix_indexes_for_simplex_index [simplex_index_for_sub_slim_index ],
45- - ones_like ((inside_mask [:, None ] + 0 ), dtype = np .int32 ), # -1 filler
46- )
47-
48- # ----------------------------
49- # Case 2: outside any simplex → nearest delaunay point
50- # ----------------------------
51- # Squared distances: (N_sub, N_points)
52- d2 = xp .sum (
53- (source_plane_data_grid [:, None , :] - delaunay_points [None , :, :]) ** 2.0 ,
54- axis = - 1 ,
55- )
56-
57- nearest = xp .argmin (d2 , axis = 1 ).astype (np .int32 ) # (N_sub,)
58-
59- # (N_sub, 3) → [nearest, -1, -1]
60- nn_triplets = xp .stack (
61- [
62- nearest ,
63- - ones_like (nearest , dtype = np .int32 ),
64- - ones_like (nearest , dtype = np .int32 ),
65- ],
66- axis = 1 ,
67- )
689
69- pix_outside = xp .where (
70- outside_mask [:, None ],
71- nn_triplets ,
72- - ones_like ((outside_mask [:, None ] + 0 ), dtype = np .int32 ),
73- )
74-
75- # ----------------------------
76- # Combine inside + outside
77- # ----------------------------
78- pix_indexes_for_sub_slim_index = xp .where (
79- inside_mask [:, None ],
80- pix_inside ,
81- pix_outside ,
82- )
83-
84- # ----------------------------
85- # Count valid entries
86- # ----------------------------
87- pix_sizes = xp .sum (pix_indexes_for_sub_slim_index >= 0 , axis = 1 )
88-
89- return pix_indexes_for_sub_slim_index , pix_sizes
9010
9111
9212def triangle_area_xp (c0 , c1 , c2 , xp ):
@@ -269,22 +189,8 @@ def pix_sub_weights(self) -> PixSubWeights:
269189 """
270190 delaunay = self .delaunay
271191
272- simplex_index_for_sub_slim_index = delaunay .simplex_index_for_sub_slim_index
273- pix_indexes_for_simplex_index = delaunay .simplices
274-
275- mappings , sizes = pix_indexes_for_sub_slim_index_delaunay_from (
276- source_plane_data_grid = self .source_plane_data_grid .over_sampled ,
277- simplex_index_for_sub_slim_index = simplex_index_for_sub_slim_index ,
278- pix_indexes_for_simplex_index = pix_indexes_for_simplex_index ,
279- delaunay_points = delaunay .points ,
280- xp = self ._xp ,
281- )
282-
283- mappings = mappings .astype ("int" )
284- sizes = sizes .astype ("int" )
285-
286- print (mappings )
287- print (sizes )
192+ mappings = delaunay .mappings .astype ("int" )
193+ sizes = delaunay .sizes .astype ("int" )
288194
289195 weights = pixel_weights_delaunay_from (
290196 source_plane_data_grid = self .source_plane_data_grid .over_sampled ,
@@ -310,24 +216,10 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights:
310216 """
311217 delaunay = self .delaunay
312218
313- splitted_simplex_index_for_sub_slim_index = delaunay .splitted_simplex_index_for_sub_slim_index
314- pix_indexes_for_simplex_index = delaunay .simplices
315-
316- (
317- splitted_mappings ,
318- splitted_sizes ,
319- ) = pix_indexes_for_sub_slim_index_delaunay_from (
320- source_plane_data_grid = self .source_plane_mesh_grid .split_cross ,
321- simplex_index_for_sub_slim_index = splitted_simplex_index_for_sub_slim_index ,
322- pix_indexes_for_simplex_index = pix_indexes_for_simplex_index ,
323- delaunay_points = delaunay .points ,
324- xp = self ._xp ,
325- )
326-
327219 splitted_weights = pixel_weights_delaunay_from (
328- source_plane_data_grid = self . source_plane_mesh_grid .split_cross ,
220+ source_plane_data_grid = delaunay .split_cross ,
329221 source_plane_mesh_grid = self .source_plane_mesh_grid .array ,
330- pix_indexes_for_sub_slim_index = splitted_mappings .astype ("int" ),
222+ pix_indexes_for_sub_slim_index = delaunay . splitted_mappings .astype ("int" ),
331223 xp = self ._xp ,
332224 )
333225
@@ -336,8 +228,8 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights:
336228
337229 return PixSubWeights (
338230 mappings = self ._xp .hstack (
339- (splitted_mappings .astype (self ._xp .int32 ), append_line_int )
231+ (delaunay . splitted_mappings .astype (self ._xp .int32 ), append_line_int )
340232 ),
341- sizes = splitted_sizes .astype (self ._xp .int32 ),
233+ sizes = delaunay . splitted_sizes .astype (self ._xp .int32 ),
342234 weights = self ._xp .hstack ((splitted_weights , append_line_float )),
343235 )
0 commit comments