@@ -26,32 +26,23 @@ def scipy_delaunay_padded(points_np, max_simplices):
2626 return pts , padded
2727
2828
29- def delaunay_no_vmap (points ):
29+ def jax_delaunay (points ):
3030
3131 import jax
3232 import jax .numpy as jnp
3333
34- # prevent batching
35- points = jax .lax .stop_gradient (points )
36-
3734 N = points .shape [0 ]
3835 max_simplices = 2 * N
3936
40- pts_spec = jax .ShapeDtypeStruct ((N , 2 ), points .dtype )
41- simp_spec = jax .ShapeDtypeStruct ((max_simplices , 3 ), jnp .int32 )
42-
43- def _cb (pts ):
44- return scipy_delaunay_padded (pts , max_simplices )
37+ pts_shape = jax .ShapeDtypeStruct ((N , 2 ), points .dtype )
38+ simp_shape = jax .ShapeDtypeStruct ((max_simplices , 3 ), jnp .int32 )
4539
46- pts_out , simplices_out = jax .pure_callback (
47- _cb ,
48- (pts_spec , simp_spec ),
40+ return jax .pure_callback (
41+ lambda pts : scipy_delaunay_padded ( pts , max_simplices ) ,
42+ (pts_shape , simp_shape ),
4943 points ,
50- vectorized = False , # ← VERY IMPORTANT (JAX 0.4.33+)
5144 )
5245
53- return pts_out , simplices_out
54-
5546
5647def find_simplex_from (query_points , points , simplices ):
5748 """
@@ -415,7 +406,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
415406
416407 import jax .numpy as jnp
417408
418- points , simplices = delaunay_no_vmap (mesh_grid )
409+ points , simplices = jax_delaunay (mesh_grid )
419410 vertex_neighbor_vertices = None
420411
421412 else :
0 commit comments