@@ -26,23 +26,32 @@ def scipy_delaunay_padded(points_np, max_simplices):
2626 return pts , padded
2727
2828
29- def jax_delaunay (points ):
29+ def delaunay_no_vmap (points ):
3030
3131 import jax
3232 import jax .numpy as jnp
3333
34+ # prevent batching
35+ points = jax .lax .stop_gradient (points )
36+
3437 N = points .shape [0 ]
3538 max_simplices = 2 * N
3639
37- pts_shape = jax .ShapeDtypeStruct ((N , 2 ), points .dtype )
38- simp_shape = jax .ShapeDtypeStruct ((max_simplices , 3 ), jnp .int32 )
40+ pts_spec = jax .ShapeDtypeStruct ((N , 2 ), points .dtype )
41+ simp_spec = jax .ShapeDtypeStruct ((max_simplices , 3 ), jnp .int32 )
42+
43+ def _cb (pts ):
44+ return scipy_delaunay_padded (pts , max_simplices )
3945
40- return jax .pure_callback (
41- lambda pts : scipy_delaunay_padded ( pts , max_simplices ) ,
42- (pts_shape , simp_shape ),
46+ pts_out , simplices_out = jax .pure_callback (
47+ _cb ,
48+ (pts_spec , simp_spec ),
4349 points ,
50+ vectorized = False , # ← VERY IMPORTANT (JAX 0.4.33+)
4451 )
4552
53+ return pts_out , simplices_out
54+
4655
4756def find_simplex_from (query_points , points , simplices ):
4857 """
@@ -146,7 +155,7 @@ def circumcenters_from(points, simplices, xp=np):
146155 return centers
147156
148157
149- MAX_DEG_JAX = 32
158+ MAX_DEG_JAX = 64
150159
151160def voronoi_areas_via_delaunay_from (points , simplices , xp = np ):
152161 """
@@ -406,7 +415,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
406415
407416 import jax .numpy as jnp
408417
409- points , simplices = jax_delaunay (mesh_grid )
418+ points , simplices = delaunay_no_vmap (mesh_grid )
410419 vertex_neighbor_vertices = None
411420
412421 else :
0 commit comments