1212from autoarray .structures .grids import grid_2d_util
1313
1414
15-
1615def scipy_delaunay_padded (points_np , max_simplices ):
1716 tri = scipy .spatial .Delaunay (points_np )
1817
@@ -44,6 +43,33 @@ def jax_delaunay(points):
4443 )
4544
4645
46+ # def delaunay_no_vmap(points):
47+ #
48+ # import jax
49+ # import jax.numpy as jnp
50+ #
51+ # # prevent batching
52+ # points = jax.lax.stop_gradient(points)
53+ #
54+ # N = points.shape[0]
55+ # max_simplices = 2 * N
56+ #
57+ # pts_spec = jax.ShapeDtypeStruct((N, 2), points.dtype)
58+ # simp_spec = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
59+ #
60+ # def _cb(pts):
61+ # return scipy_delaunay_padded(pts, max_simplices)
62+ #
63+ # pts_out, simplices_out = jax.pure_callback(
64+ # _cb,
65+ # (pts_spec, simp_spec),
66+ # points,
67+ # vectorized=False, # ← VERY IMPORTANT (JAX 0.4.33+)
68+ # )
69+ #
70+ # return pts_out, simplices_out
71+
72+
4773def find_simplex_from (query_points , points , simplices ):
4874 """
4975 Return simplex index for each query point.
@@ -148,6 +174,7 @@ def circumcenters_from(points, simplices, xp=np):
148174
149175MAX_DEG_JAX = 64
150176
177+
151178def voronoi_areas_via_delaunay_from (points , simplices , xp = np ):
152179 """
153180 Compute 'Voronoi-ish' cell areas for each vertex in a 2D Delaunay triangulation.
@@ -185,21 +212,21 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
185212
186213 # 2) Build a flattened vertex-triangle incidence
187214 # Each triangle contributes its circumcenter to 3 vertices.
188- vert_ids = tris .reshape (- 1 ) # (3M,)
215+ vert_ids = tris .reshape (- 1 ) # (3M,)
189216 centers_rep = xp .repeat (centers , 3 , axis = 0 ) # (3M, 2)
190217
191218 # 3) Sort by vertex id so all entries for a given vertex are contiguous
192219 order = xp .argsort (vert_ids )
193- vert_sorted = vert_ids [order ] # (3M,)
220+ vert_sorted = vert_ids [order ] # (3M,)
194221 centers_sorted = centers_rep [order ] # (3M, 2)
195222
196223 # 4) Compute how many triangles are incident to each vertex
197224 if xp is np :
198- counts = xp .bincount (vert_sorted , minlength = N ) # (N,)
225+ counts = xp .bincount (vert_sorted , minlength = N ) # (N,)
199226 max_deg = int (counts .max ())
200227 else :
201- counts = xp .bincount (vert_sorted , length = N ) # (N,)
202- max_deg = MAX_DEG_JAX # static upper bound
228+ counts = xp .bincount (vert_sorted , length = N ) # (N,)
229+ max_deg = MAX_DEG_JAX # static upper bound
203230
204231 # 5) Compute start index for each vertex's block in vert_sorted
205232 # start[v] = cumulative sum of counts up to v
@@ -209,8 +236,8 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
209236 arange_all = xp .arange (3 * M , dtype = int )
210237
211238 # Position within each vertex block: pos = i - start[vertex]
212- start_per_entry = start [vert_sorted ] # (3M,)
213- pos = arange_all - start_per_entry # (3M,)
239+ start_per_entry = start [vert_sorted ] # (3M,)
240+ pos = arange_all - start_per_entry # (3M,)
214241
215242 # 6) Scatter into a padded (N, max_deg, 2) array of circumcenters
216243 circum_padded = xp .zeros ((N , max_deg , 2 ), dtype = pts .dtype )
@@ -227,14 +254,14 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
227254
228255 # Mark which slots are valid (j < count[v])
229256 j_idx = xp .arange (max_deg )[None , :]
230- valid_mask = j_idx < counts [:, None ] # (N, max_deg)
257+ valid_mask = j_idx < counts [:, None ] # (N, max_deg)
231258
232259 # For invalid entries, set angle to a big constant so they go to the end
233260 big_angle = xp .array (1e9 , dtype = angles .dtype )
234261 angles_masked = xp .where (valid_mask , angles , big_angle )
235262
236263 # Sort indices by angle for each vertex
237- order_angles = xp .argsort (angles_masked , axis = 1 ) # (N, max_deg)
264+ order_angles = xp .argsort (angles_masked , axis = 1 ) # (N, max_deg)
238265
239266 # Reorder circumcenters accordingly
240267 centers_sorted2 = xp .take_along_axis (
0 commit comments