@@ -25,61 +25,90 @@ def voronoi_areas_from(points_np):
2525 voronoi_vertices = vor .vertices
2626 voronoi_regions = vor .regions
2727 voronoi_point_region = vor .point_region
28- region_areas = np .zeros (N )
28+ voronoi_areas = np .zeros (N )
2929
3030 for i in range (N ):
3131 region_vertices_indexes = voronoi_regions [voronoi_point_region [i ]]
3232 if - 1 in region_vertices_indexes :
33- region_areas [i ] = - 1
33+ voronoi_areas [i ] = - 1
3434 else :
3535
3636 points_of_region = voronoi_vertices [region_vertices_indexes ]
3737
3838 x = points_of_region [:, 1 ]
3939 y = points_of_region [:, 0 ]
4040
41- region_areas [i ] = 0.5 * np .abs (
41+ voronoi_areas [i ] = 0.5 * np .abs (
4242 np .dot (x , np .roll (y , 1 )) - np .dot (y , np .roll (x , 1 ))
4343 )
4444
45- return region_areas
45+ voronoi_pixel_areas_for_split = voronoi_areas . copy ()
4646
47+ # 90th percentile
48+ max_area = np .percentile (voronoi_pixel_areas_for_split , 90.0 )
4749
48- def scipy_delaunay_voronoi (points_np , max_simplices ):
50+ voronoi_pixel_areas_for_split [voronoi_pixel_areas_for_split == - 1 ] = max_area
51+ voronoi_pixel_areas_for_split [voronoi_pixel_areas_for_split > max_area ] = max_area
52+
53+ half_region_area_sqrt_lengths = 0.5 * np .sqrt (voronoi_pixel_areas_for_split )
54+
55+ split_points = split_points_from (
56+ points = points_np ,
57+ area_weights = half_region_area_sqrt_lengths ,
58+ )
59+
60+ return voronoi_areas , split_points
61+
62+
63+ def scipy_delaunay_voronoi (points_np , query_points_np , max_simplices ):
4964 """Compute Delaunay simplices (padded) and Voronoi areas in one call."""
5065 from scipy .spatial import Delaunay
5166 import numpy as np
5267
53- N = points_np .shape [0 ]
54-
5568 # --- Delaunay ---
5669 tri = Delaunay (points_np )
70+
5771 pts = tri .points .astype (points_np .dtype )
5872 simplices = tri .simplices .astype (np .int32 )
5973
74+ # Pad simplices to max_simplices
6075 padded = - np .ones ((max_simplices , 3 ), dtype = np .int32 )
6176 padded [: simplices .shape [0 ]] = simplices
6277
63- areas = voronoi_areas_from (points_np )
78+ # ---------- Voronoi cell areas ----------
79+ areas , split_points = voronoi_areas_from (points_np )
80+
81+ # ---------- find_simplex ----------
82+ simplex_idx = tri .find_simplex (query_points_np ).astype (np .int32 ) # (Q,)
6483
65- return pts , padded , areas
84+ # ---------- find_simplex for split cross points ----------
85+ split_cross_idx = tri .find_simplex (split_points )
6686
87+ return pts , padded , areas , simplex_idx , split_cross_idx
6788
68- def jax_delaunay_voronoi (points ):
89+
90+ def jax_delaunay_voronoi (points , query_points ):
6991 import jax
7092 import jax .numpy as jnp
7193
7294 N = points .shape [0 ]
95+ Q = query_points .shape [0 ]
96+
97+ # Conservative pad (you can pass exact M if you want)
7398 max_simplices = 2 * N # same logic as before
7499
75100 pts_shape = jax .ShapeDtypeStruct ((N , 2 ), points .dtype )
76101 simp_shape = jax .ShapeDtypeStruct ((max_simplices , 3 ), jnp .int32 )
77102 area_shape = jax .ShapeDtypeStruct ((N ,), points .dtype )
103+ sx_shape = jax .ShapeDtypeStruct ((Q ,), jnp .int32 )
104+ sc_shape = jax .ShapeDtypeStruct ((N * 4 ,), jnp .int32 )
78105
79106 return jax .pure_callback (
80- lambda pts : scipy_delaunay_voronoi (pts , max_simplices ),
81- (pts_shape , simp_shape , area_shape ),
82- points ,
107+ lambda pts , qpts : scipy_delaunay_voronoi (
108+ np .asarray (pts ), np .asarray (qpts ), max_simplices
109+ ),
110+ (pts_shape , simp_shape , area_shape , sx_shape , sc_shape ),
111+ points , query_points ,
83112 )
84113
85114
@@ -110,61 +139,60 @@ def jax_delaunay_voronoi(points):
110139# return pts_out, simplices_out
111140
112141
113- def find_simplex_from (query_points , points , simplices ):
114- """
115- Return simplex index for each query point.
116- Returns -1 where no simplex contains the point.
117- """
118-
119- import jax
120- import jax .numpy as jnp
121-
122- # Mask padded simplices (marked with -1)
123- valid = simplices [:, 0 ] >= 0 # (M,)
124- simplices_clipped = simplices .clip (min = 0 )
125-
126- # Triangle vertices: (M, 3, 2)
127- tri = points [simplices_clipped ]
128-
129- p0 = tri [:, 0 ] # (M, 2)
130- p1 = tri [:, 1 ]
131- p2 = tri [:, 2 ]
132-
133- # Edges
134- v0 = p1 - p0 # (M, 2)
135- v1 = p2 - p0
136-
137- # Precomputed dot products
138- d00 = jnp .sum (v0 * v0 , axis = 1 ) # (M,)
139- d01 = jnp .sum (v0 * v1 , axis = 1 )
140- d11 = jnp .sum (v1 * v1 , axis = 1 )
141- denom = d00 * d11 - d01 * d01 # (M,)
142-
143- # Barycentric computation for each query point vs each triangle
144- diff = query_points [:, None , :] - p0 [None , :, :] # (Q, M, 2)
145-
146- a = jnp .sum (diff * v0 [None , :, :], axis = - 1 ) # (Q, M)
147- b = jnp .sum (diff * v1 [None , :, :], axis = - 1 )
148-
149- b0 = (a * d11 - b * d01 ) / denom # (Q, M)
150- b1 = (b * d00 - a * d01 ) / denom
151-
152- # Inside test
153- inside = (b0 >= 0.0 ) & (b1 >= 0.0 ) & (b0 + b1 <= 1.0 ) # (Q, M)
154-
155- # Remove padded simplices
156- inside = inside & valid [None , :]
157-
158- # First valid simplex per point
159- simplex_idx = jnp .argmax (inside , axis = 1 ) # (Q,)
160-
161- # Detect points with no simplex match
162- has_match = jnp .any (inside , axis = 1 ) # (Q,)
163-
164- # Replace unmatched with -1
165- simplex_idx = jnp .where (has_match , simplex_idx , - 1 )
166-
167- return simplex_idx
142+ # def find_simplex_from(query_points, points, simplices):
143+ # """
144+ # Return simplex index for each query point.
145+ # Returns -1 where no simplex contains the point.
146+ # """
147+ #
148+ # import jax.numpy as jnp
149+ #
150+ # # Mask padded simplices (marked with -1)
151+ # valid = simplices[:, 0] >= 0 # (M,)
152+ # simplices_clipped = simplices.clip(min=0)
153+ #
154+ # # Triangle vertices: (M, 3, 2)
155+ # tri = points[simplices_clipped]
156+ #
157+ # p0 = tri[:, 0] # (M, 2)
158+ # p1 = tri[:, 1]
159+ # p2 = tri[:, 2]
160+ #
161+ # # Edges
162+ # v0 = p1 - p0 # (M, 2)
163+ # v1 = p2 - p0
164+ #
165+ # # Precomputed dot products
166+ # d00 = jnp.sum(v0 * v0, axis=1) # (M,)
167+ # d01 = jnp.sum(v0 * v1, axis=1)
168+ # d11 = jnp.sum(v1 * v1, axis=1)
169+ # denom = d00 * d11 - d01 * d01 # (M,)
170+ #
171+ # # Barycentric computation for each query point vs each triangle
172+ # diff = query_points[:, None, :] - p0[None, :, :] # (Q, M, 2)
173+ #
174+ # a = jnp.sum(diff * v0[None, :, :], axis=-1) # (Q, M)
175+ # b = jnp.sum(diff * v1[None, :, :], axis=-1)
176+ #
177+ # b0 = (a * d11 - b * d01) / denom # (Q, M)
178+ # b1 = (b * d00 - a * d01) / denom
179+ #
180+ # # Inside test
181+ # inside = (b0 >= 0.0) & (b1 >= 0.0) & (b0 + b1 <= 1.0) # (Q, M)
182+ #
183+ # # Remove padded simplices
184+ # inside = inside & valid[None, :]
185+ #
186+ # # First valid simplex per point
187+ # simplex_idx = jnp.argmax(inside, axis=1) # (Q,)
188+ #
189+ # # Detect points with no simplex match
190+ # has_match = jnp.any(inside, axis=1) # (Q,)
191+ #
192+ # # Replace unmatched with -1
193+ # simplex_idx = jnp.where(has_match, simplex_idx, -1)
194+ #
195+ # return simplex_idx
168196
169197
170198def split_points_from (points , area_weights , xp = np ):
@@ -223,19 +251,18 @@ def split_points_from(points, area_weights, xp=np):
223251
224252class DelaunayInterface :
225253
226- def __init__ (self , ppoints , simplices , voronoi_areas , vertex_neighbor_vertices ):
254+ def __init__ (self , points , simplices , voronoi_areas , vertex_neighbor_vertices , simplex_index_for_sub_slim_index , splitted_simplex_index_for_sub_slim_index ):
227255
228- self .points = ppoints
256+ self .points = points
229257 self .simplices = simplices
230258 self .voronoi_areas = voronoi_areas
231259 self .vertex_neighbor_vertices = vertex_neighbor_vertices
232-
233- def find_simplex (self , query_points ):
234- return find_simplex_from (query_points , self .points , self .simplices )
260+ self .simplex_index_for_sub_slim_index = simplex_index_for_sub_slim_index
261+ self .splitted_simplex_index_for_sub_slim_index = splitted_simplex_index_for_sub_slim_index
235262
236263
237264class Mesh2DDelaunay (Abstract2DMesh ):
238- def __init__ (self , values : Union [np .ndarray , List ], _xp = np ):
265+ def __init__ (self , values : Union [np .ndarray , List ], source_plane_data_grid_over_sampled = None , _xp = np ):
239266 """
240267 An irregular 2D grid of (y,x) coordinates which represents both a Delaunay triangulation and Voronoi mesh.
241268
@@ -268,6 +295,8 @@ def __init__(self, values: Union[np.ndarray, List], _xp=np):
268295
269296 super ().__init__ (values , xp = _xp )
270297
298+ self ._source_plane_data_grid_over_sampled = source_plane_data_grid_over_sampled
299+
271300 @property
272301 def geometry (self ):
273302 shape_native_scaled = (
@@ -312,7 +341,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
312341
313342 import jax .numpy as jnp
314343
315- points , simplices , voronoi_areas = jax_delaunay_voronoi (mesh_grid )
344+ points , simplices , voronoi_areas , simplex_index_for_sub_slim_index , splitted_simplex_index_for_sub_slim_index = jax_delaunay_voronoi (mesh_grid , self . _source_plane_data_grid_over_sampled )
316345 vertex_neighbor_vertices = None
317346
318347 else :
@@ -324,9 +353,14 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
324353 vertex_neighbor_vertices = delaunay .vertex_neighbor_vertices
325354
326355 voronoi_areas = voronoi_areas_from (mesh_grid )
356+ split_points = self .split_cross
357+ simplex_index_for_sub_slim_index = delaunay .find_simplex (self .source_plane_data_grid .over_sampled )
358+ splitted_simplex_index_for_sub_slim_index = delaunay .find_simplex (
359+ self .split_cross
360+ )
327361
328362 return DelaunayInterface (
329- points , simplices , voronoi_areas , vertex_neighbor_vertices
363+ points , simplices , voronoi_areas , vertex_neighbor_vertices , simplex_index_for_sub_slim_index , splitted_simplex_index_for_sub_slim_index
330364 )
331365
332366 @property
0 commit comments