Skip to content

Commit 7ef7104

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent 84cd539 commit 7ef7104

File tree

2 files changed

+44
-13
lines changed

2 files changed

+44
-13
lines changed

autoarray/structures/mesh/triangulation_2d.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from autoarray.structures.grids import grid_2d_util
1313

1414

15-
1615
def scipy_delaunay_padded(points_np, max_simplices):
1716
tri = scipy.spatial.Delaunay(points_np)
1817

@@ -44,6 +43,33 @@ def jax_delaunay(points):
4443
)
4544

4645

46+
# def delaunay_no_vmap(points):
47+
#
48+
# import jax
49+
# import jax.numpy as jnp
50+
#
51+
# # prevent batching
52+
# points = jax.lax.stop_gradient(points)
53+
#
54+
# N = points.shape[0]
55+
# max_simplices = 2 * N
56+
#
57+
# pts_spec = jax.ShapeDtypeStruct((N, 2), points.dtype)
58+
# simp_spec = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
59+
#
60+
# def _cb(pts):
61+
# return scipy_delaunay_padded(pts, max_simplices)
62+
#
63+
# pts_out, simplices_out = jax.pure_callback(
64+
# _cb,
65+
# (pts_spec, simp_spec),
66+
# points,
67+
# vectorized=False, # ← VERY IMPORTANT (JAX 0.4.33+)
68+
# )
69+
#
70+
# return pts_out, simplices_out
71+
72+
4773
def find_simplex_from(query_points, points, simplices):
4874
"""
4975
Return simplex index for each query point.
@@ -148,6 +174,7 @@ def circumcenters_from(points, simplices, xp=np):
148174

149175
MAX_DEG_JAX = 64
150176

177+
151178
def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
152179
"""
153180
Compute 'Voronoi-ish' cell areas for each vertex in a 2D Delaunay triangulation.
@@ -185,21 +212,21 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
185212

186213
# 2) Build a flattened vertex-triangle incidence
187214
# Each triangle contributes its circumcenter to 3 vertices.
188-
vert_ids = tris.reshape(-1) # (3M,)
215+
vert_ids = tris.reshape(-1) # (3M,)
189216
centers_rep = xp.repeat(centers, 3, axis=0) # (3M, 2)
190217

191218
# 3) Sort by vertex id so all entries for a given vertex are contiguous
192219
order = xp.argsort(vert_ids)
193-
vert_sorted = vert_ids[order] # (3M,)
220+
vert_sorted = vert_ids[order] # (3M,)
194221
centers_sorted = centers_rep[order] # (3M, 2)
195222

196223
# 4) Compute how many triangles are incident to each vertex
197224
if xp is np:
198-
counts = xp.bincount(vert_sorted, minlength=N) # (N,)
225+
counts = xp.bincount(vert_sorted, minlength=N) # (N,)
199226
max_deg = int(counts.max())
200227
else:
201-
counts = xp.bincount(vert_sorted, length=N) # (N,)
202-
max_deg = MAX_DEG_JAX # static upper bound
228+
counts = xp.bincount(vert_sorted, length=N) # (N,)
229+
max_deg = MAX_DEG_JAX # static upper bound
203230

204231
# 5) Compute start index for each vertex's block in vert_sorted
205232
# start[v] = cumulative sum of counts up to v
@@ -209,8 +236,8 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
209236
arange_all = xp.arange(3 * M, dtype=int)
210237

211238
# Position within each vertex block: pos = i - start[vertex]
212-
start_per_entry = start[vert_sorted] # (3M,)
213-
pos = arange_all - start_per_entry # (3M,)
239+
start_per_entry = start[vert_sorted] # (3M,)
240+
pos = arange_all - start_per_entry # (3M,)
214241

215242
# 6) Scatter into a padded (N, max_deg, 2) array of circumcenters
216243
circum_padded = xp.zeros((N, max_deg, 2), dtype=pts.dtype)
@@ -227,14 +254,14 @@ def voronoi_areas_via_delaunay_from(points, simplices, xp=np):
227254

228255
# Mark which slots are valid (j < count[v])
229256
j_idx = xp.arange(max_deg)[None, :]
230-
valid_mask = j_idx < counts[:, None] # (N, max_deg)
257+
valid_mask = j_idx < counts[:, None] # (N, max_deg)
231258

232259
# For invalid entries, set angle to a big constant so they go to the end
233260
big_angle = xp.array(1e9, dtype=angles.dtype)
234261
angles_masked = xp.where(valid_mask, angles, big_angle)
235262

236263
# Sort indices by angle for each vertex
237-
order_angles = xp.argsort(angles_masked, axis=1) # (N, max_deg)
264+
order_angles = xp.argsort(angles_masked, axis=1) # (N, max_deg)
238265

239266
# Reorder circumcenters accordingly
240267
centers_sorted2 = xp.take_along_axis(

test_autoarray/structures/mesh/test_delaunay.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from autoarray.structures.mesh.triangulation_2d import voronoi_areas_via_delaunay_from
77

8+
89
def test__edge_pixel_list():
910
grid = np.array(
1011
[
@@ -65,12 +66,15 @@ def test__voronoi_areas_via_delaunay_from():
6566

6667
import scipy.spatial
6768

68-
mesh_grid = np.array([[0.0, 0.0], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]])
69+
mesh_grid = np.array(
70+
[[0.0, 0.0], [1.1, 0.6], [2.1, 0.1], [0.4, 1.1], [1.1, 7.1], [2.1, 1.1]]
71+
)
6972

7073
delaunay = scipy.spatial.Delaunay(mesh_grid)
7174

7275
voronoi_areas = voronoi_areas_via_delaunay_from(
73-
mesh_grid, delaunay.simplices,
76+
mesh_grid,
77+
delaunay.simplices,
7478
)
7579

7680
voronoi = scipy.spatial.Voronoi(
@@ -100,4 +104,4 @@ def test__voronoi_areas_via_delaunay_from():
100104

101105
# Old Voronoi cell code put -1 in edge pixels, new code puts large area
102106

103-
assert voronoi_areas[4] == pytest.approx(32.83847776, 1.0e-4)
107+
assert voronoi_areas[4] == pytest.approx(32.83847776, 1.0e-4)

0 commit comments

Comments
 (0)