Skip to content

Commit 5fc944d

Browse files
Jammy2211Jammy2211
authored andcommitted
vectorized jax?
1 parent 5d9a629 commit 5fc944d

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

autoarray/structures/mesh/triangulation_2d.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,23 +26,32 @@ def scipy_delaunay_padded(points_np, max_simplices):
2626
return pts, padded
2727

2828

29-
def jax_delaunay(points):
29+
def delaunay_no_vmap(points):
3030

3131
import jax
3232
import jax.numpy as jnp
3333

34+
# prevent batching
35+
points = jax.lax.stop_gradient(points)
36+
3437
N = points.shape[0]
3538
max_simplices = 2 * N
3639

37-
pts_shape = jax.ShapeDtypeStruct((N, 2), points.dtype)
38-
simp_shape = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
40+
pts_spec = jax.ShapeDtypeStruct((N, 2), points.dtype)
41+
simp_spec = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
42+
43+
def _cb(pts):
44+
return scipy_delaunay_padded(pts, max_simplices)
3945

40-
return jax.pure_callback(
41-
lambda pts: scipy_delaunay_padded(pts, max_simplices),
42-
(pts_shape, simp_shape),
46+
pts_out, simplices_out = jax.pure_callback(
47+
_cb,
48+
(pts_spec, simp_spec),
4349
points,
50+
vectorized=False, # ← VERY IMPORTANT (JAX 0.4.33+)
4451
)
4552

53+
return pts_out, simplices_out
54+
4655

4756
def find_simplex_from(query_points, points, simplices):
4857
"""
@@ -146,7 +155,7 @@ def circumcenters_from(points, simplices, xp=np):
146155
return centers
147156

148157

149-
MAX_DEG_JAX = 32
158+
MAX_DEG_JAX = 64
150159

151160
def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
152161
"""
@@ -406,7 +415,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
406415

407416
import jax.numpy as jnp
408417

409-
points, simplices = jax_delaunay(mesh_grid)
418+
points, simplices = delaunay_no_vmap(mesh_grid)
410419
vertex_neighbor_vertices = None
411420

412421
else:

0 commit comments

Comments
 (0)