Skip to content

Commit 1d88968

Browse files
Jammy2211Jammy2211
authored andcommitted
revert to simpler
1 parent 5fc944d commit 1d88968

File tree

1 file changed

+7
-16
lines changed

1 file changed

+7
-16
lines changed

autoarray/structures/mesh/triangulation_2d.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,32 +26,23 @@ def scipy_delaunay_padded(points_np, max_simplices):
2626
return pts, padded
2727

2828

29-
def delaunay_no_vmap(points):
29+
def jax_delaunay(points):
3030

3131
import jax
3232
import jax.numpy as jnp
3333

34-
# prevent batching
35-
points = jax.lax.stop_gradient(points)
36-
3734
N = points.shape[0]
3835
max_simplices = 2 * N
3936

40-
pts_spec = jax.ShapeDtypeStruct((N, 2), points.dtype)
41-
simp_spec = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
42-
43-
def _cb(pts):
44-
return scipy_delaunay_padded(pts, max_simplices)
37+
pts_shape = jax.ShapeDtypeStruct((N, 2), points.dtype)
38+
simp_shape = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
4539

46-
pts_out, simplices_out = jax.pure_callback(
47-
_cb,
48-
(pts_spec, simp_spec),
40+
return jax.pure_callback(
41+
lambda pts: scipy_delaunay_padded(pts, max_simplices),
42+
(pts_shape, simp_shape),
4943
points,
50-
vectorized=False, # ← VERY IMPORTANT (JAX 0.4.33+)
5144
)
5245

53-
return pts_out, simplices_out
54-
5546

5647
def find_simplex_from(query_points, points, simplices):
5748
"""
@@ -415,7 +406,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
415406

416407
import jax.numpy as jnp
417408

418-
points, simplices = delaunay_no_vmap(mesh_grid)
409+
points, simplices = jax_delaunay(mesh_grid)
419410
vertex_neighbor_vertices = None
420411

421412
else:

0 commit comments

Comments
 (0)