1414from autoarray .inversion .pixelization .mesh import mesh_numba_util
1515
1616
17- def scipy_delaunay (points_np , query_points_np , max_simplices ):
17+ def scipy_delaunay (points_np , query_points_np , source_pixel_zeroed_indices , max_simplices = None ):
1818 """Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
1919
2020 # --- Delaunay mesh using source plane data grid ---
@@ -44,6 +44,12 @@ def scipy_delaunay(points_np, query_points_np, max_simplices):
4444 xp = np ,
4545 )
4646
47+ # max_area = np.percentile(barycentric_dual_areas, 90.0)
48+ # barycentric_dual_areas[source_pixel_zeroed_indices] = max_area
49+
50+ max_area = 100.0 * np .max (barycentric_dual_areas )
51+ barycentric_dual_areas [source_pixel_zeroed_indices ] = max_area
52+
4753 # ---------- Areas used to weight split points ----------
4854 split_point_areas = 0.5 * np .sqrt (barycentric_dual_areas )
4955
@@ -66,7 +72,7 @@ def scipy_delaunay(points_np, query_points_np, max_simplices):
6672 return points , simplices_padded , mappings , split_points , splitted_mappings
6773
6874
69- def jax_delaunay (points , query_points ):
75+ def jax_delaunay (points , query_points , source_pixel_zeroed_indices ):
7076 import jax
7177 import jax .numpy as jnp
7278
@@ -81,8 +87,8 @@ def jax_delaunay(points, query_points):
8187 splitted_mappings_shape = jax .ShapeDtypeStruct ((N * 4 , 3 ), jnp .int32 )
8288
8389 return jax .pure_callback (
84- lambda points , qpts : scipy_delaunay (
85- np .asarray (points ), np .asarray (qpts ), max_simplices
90+ lambda points , qpts , spzi : scipy_delaunay (
91+ np .asarray (points ), np .asarray (qpts ), np . asarray ( spzi ), max_simplices
8692 ),
8793 (
8894 points_shape ,
@@ -93,6 +99,7 @@ def jax_delaunay(points, query_points):
9399 ),
94100 points ,
95101 query_points ,
102+ source_pixel_zeroed_indices
96103 )
97104
98105
@@ -272,6 +279,7 @@ def __init__(
272279 self ,
273280 values : Union [np .ndarray , List ],
274281 source_plane_data_grid_over_sampled = None ,
282+ preloads = None ,
275283 _xp = np ,
276284 ):
277285 """
@@ -307,6 +315,7 @@ def __init__(
307315 super ().__init__ (values , xp = _xp )
308316
309317 self ._source_plane_data_grid_over_sampled = source_plane_data_grid_over_sampled
318+ self .preloads = preloads
310319
311320 @property
312321 def geometry (self ):
@@ -361,14 +370,16 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
361370
362371 points , simplices , mappings , split_points , splitted_mappings = jax_delaunay (
363372 points = self .mesh_grid_xy ,
364- query_points = self ._source_plane_data_grid_over_sampled
373+ query_points = self ._source_plane_data_grid_over_sampled ,
374+ source_pixel_zeroed_indices = self .preloads .source_pixel_zeroed_indices
365375 )
366376
367377 else :
368378
369379 points , simplices , mappings , split_points , splitted_mappings = scipy_delaunay (
370380 points = self .mesh_grid_xy ,
371- query_points_np = self .source_plane_data_grid_over_sampled
381+ query_points_np = self .source_plane_data_grid_over_sampled ,
382+ source_pixel_zeroed_indices = self .preloads .source_pixel_zeroed_indices
372383 )
373384
374385 return DelaunayInterface (
0 commit comments