11import numpy as np
22import scipy .spatial
3- from scipy .spatial import cKDTree , Delaunay
3+ from scipy .spatial import cKDTree , Delaunay , Voronoi
44from typing import List , Union , Optional , Tuple
55
66from autoconf import cached_property
1414from autoarray .inversion .pixelization .mesh import mesh_numba_util
1515
1616
17- def scipy_delaunay (points_np , query_points_np , source_pixel_zeroed_indices , max_simplices = None ):
17+ def scipy_delaunay (points_np , query_points_np , source_pixel_zeroed_indices ):
1818 """Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
1919
20+ max_simplices = 2 * points_np .shape [0 ]
21+
2022 # --- Delaunay mesh using source plane data grid ---
2123 tri = Delaunay (points_np )
2224
@@ -37,21 +39,31 @@ def scipy_delaunay(points_np, query_points_np, source_pixel_zeroed_indices, max_
3739 delaunay_points = points_np ,
3840 )
3941
40- # ---------- Baronicentric Dual areas ----------
41- barycentric_dual_areas = barycentric_dual_area_from (
42- points ,
43- simplices ,
44- xp = np ,
45- )
42+ # ---------- Baronicentric Dual used to weight split points ----------
43+ # barycentric_dual_areas = np.abs(voronoi_areas_numpy(
44+ # points,
45+ # ))
46+
47+ # barycentric_dual_areas = barycentric_dual_area_from(
48+ # points,
49+ # simplices,
50+ # xp=np,
51+ # )
4652
4753 # max_area = np.percentile(barycentric_dual_areas, 90.0)
4854 # barycentric_dual_areas[source_pixel_zeroed_indices] = max_area
4955
50- max_area = 100.0 * np .max (barycentric_dual_areas )
51- barycentric_dual_areas [source_pixel_zeroed_indices ] = max_area
56+ # ---------- Voronoi Areas used to weight split points ----------
57+ areas = voronoi_areas_numpy (
58+ points ,
59+ )
5260
53- # ---------- Areas used to weight split points ----------
54- split_point_areas = 0.5 * np .sqrt (barycentric_dual_areas )
61+ max_area = np .percentile (areas , 90.0 )
62+
63+ areas [areas == - 1 ] = max_area
64+ areas [areas > max_area ] = max_area
65+
66+ split_point_areas = 0.5 * np .sqrt (areas )
5567
5668 # ---------- Compute split cross points for Split regularization ----------
5769 split_points = split_points_from (
@@ -88,7 +100,7 @@ def jax_delaunay(points, query_points, source_pixel_zeroed_indices):
88100
89101 return jax .pure_callback (
90102 lambda points , qpts , spzi : scipy_delaunay (
91- np .asarray (points ), np .asarray (qpts ), np .asarray (spzi ), max_simplices
103+ np .asarray (points ), np .asarray (qpts ), np .asarray (spzi ),
92104 ),
93105 (
94106 points_shape ,
@@ -161,6 +173,84 @@ def barycentric_dual_area_from(
161173 return dual_area
162174
163175
176+ def voronoi_areas_numpy (points , qhull_options = "Qbb Qc Qx Qm" ):
177+ """
178+ Compute Voronoi cell areas with a fully optimized pure-NumPy pipeline.
179+ Exact match to the per-cell SciPy Voronoi loop but much faster.
180+ """
181+ vor = Voronoi (points , qhull_options = qhull_options )
182+
183+ vertices = vor .vertices
184+ point_region = vor .point_region
185+ regions = vor .regions
186+ N = len (point_region )
187+
188+ # ------------------------------------------------------------
189+ # 1) Collect all region lists in one go (list comprehension is fast)
190+ # ------------------------------------------------------------
191+ region_lists = [regions [r ] for r in point_region ]
192+
193+ # Precompute which regions are unbounded (vectorized test)
194+ unbounded = np .array ([(- 1 in r ) for r in region_lists ], dtype = bool )
195+
196+ # Filter only bounded region vertex indices
197+ clean_regions = [np .asarray ([v for v in r if v != - 1 ], dtype = int )
198+ for r in region_lists ]
199+
200+ # Compute lengths once
201+ lengths = np .array ([len (r ) for r in clean_regions ], dtype = int )
202+ max_len = lengths .max ()
203+
204+ # ------------------------------------------------------------
205+ # 2) Build padded idx + mask in a vectorized-like way
206+ #
207+ # Instead of doing Python work inside the loop, we pre-pack
208+ # the flattened data and then reshape.
209+ # ------------------------------------------------------------
210+ idx = np .full ((N , max_len ), - 1 , dtype = int )
211+ mask = np .zeros ((N , max_len ), dtype = bool )
212+
213+ # Single loop remaining: extremely cheap
214+ for i , (r , L ) in enumerate (zip (clean_regions , lengths )):
215+ if L :
216+ idx [i , :L ] = r
217+ mask [i , :L ] = True
218+
219+ # ------------------------------------------------------------
220+ # 3) Gather polygon vertices (vectorized)
221+ # ------------------------------------------------------------
222+ safe_idx = idx .clip (min = 0 )
223+ verts = vertices [safe_idx ] # (N, max_len, 2)
224+
225+ # Extract x, y with masked invalid entries zeroed
226+ x = np .where (mask , verts [..., 1 ], 0.0 )
227+ y = np .where (mask , verts [..., 0 ], 0.0 )
228+
229+ # ------------------------------------------------------------
230+ # 4) Vectorized "previous index" per polygon
231+ # ------------------------------------------------------------
232+ safe_lengths = np .where (lengths == 0 , 1 , lengths )
233+ j = np .arange (max_len )
234+ prev = (j [None , :] - 1 ) % safe_lengths [:, None ]
235+
236+ # Efficient take-along-axis
237+ x_prev = np .take_along_axis (x , prev , axis = 1 )
238+ y_prev = np .take_along_axis (y , prev , axis = 1 )
239+
240+ # ------------------------------------------------------------
241+ # 5) Shoelace vectorized
242+ # ------------------------------------------------------------
243+ cross = x * y_prev - y * x_prev
244+ areas = 0.5 * np .abs (cross .sum (axis = 1 ))
245+
246+ # ------------------------------------------------------------
247+ # 6) Mark unbounded regions
248+ # ------------------------------------------------------------
249+ areas [unbounded ] = - 1.0
250+
251+ return areas
252+
253+
164254def split_points_from (points , area_weights , xp = np ):
165255 """
166256 points : (N, 2)
@@ -348,7 +438,7 @@ def mesh_grid_xy(self):
348438
349439 Therefore, this property simply converts the (y,x) grid of irregular coordinates into an (x,y) grid.
350440 """
351- return self ._xp .stack ([self .array [:, 1 ], self .array [:, 0 ]]).T
441+ return self ._xp .stack ([self .array [:, 0 ], self .array [:, 1 ]]).T
352442
353443 @cached_property
354444 def delaunay (self ) -> "scipy.spatial.Delaunay" :
@@ -377,9 +467,9 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
377467 else :
378468
379469 points , simplices , mappings , split_points , splitted_mappings = scipy_delaunay (
380- points = self .mesh_grid_xy ,
381- query_points_np = self .source_plane_data_grid_over_sampled ,
382- source_pixel_zeroed_indices = self .preloads .source_pixel_zeroed_indices
470+ points_np = self .mesh_grid_xy ,
471+ query_points_np = self ._source_plane_data_grid_over_sampled ,
472+ source_pixel_zeroed_indices = self .preloads .source_pixel_zeroed_indices ,
383473 )
384474
385475 return DelaunayInterface (
@@ -509,27 +599,7 @@ def voronoi(self) -> "scipy.spatial.Voronoi":
509599
510600 @property
511601 def voronoi_areas (self ):
512-
513- N = self .mesh_grid_xy .shape [0 ]
514-
515- voronoi_areas = np .zeros (N )
516-
517- for i in range (N ):
518- region_vertices_indexes = self .voronoi .regions [self .voronoi .point_region [i ]]
519- if - 1 in region_vertices_indexes :
520- voronoi_areas [i ] = - 1
521- else :
522-
523- points_of_region = self .voronoivertices [region_vertices_indexes ]
524-
525- x = points_of_region [:, 1 ]
526- y = points_of_region [:, 0 ]
527-
528- voronoi_areas [i ] = 0.5 * np .abs (
529- np .dot (x , np .roll (y , 1 )) - np .dot (y , np .roll (x , 1 ))
530- )
531-
532- return voronoi_areas
602+ return voronoi_areas_numpy (points = self .mesh_grid_xy )
533603
534604 @property
535605 def areas_for_magnification (self ) -> np .ndarray :
0 commit comments