Skip to content

Commit 123c13f

Browse files
Jammy2211Jammy2211
authored andcommitted
revert to original voro
1 parent 4547f14 commit 123c13f

File tree

1 file changed

+108
-38
lines changed

1 file changed

+108
-38
lines changed

autoarray/structures/mesh/delaunay_2d.py

Lines changed: 108 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import scipy.spatial
3-
from scipy.spatial import cKDTree, Delaunay
3+
from scipy.spatial import cKDTree, Delaunay, Voronoi
44
from typing import List, Union, Optional, Tuple
55

66
from autoconf import cached_property
@@ -14,9 +14,11 @@
1414
from autoarray.inversion.pixelization.mesh import mesh_numba_util
1515

1616

17-
def scipy_delaunay(points_np, query_points_np, source_pixel_zeroed_indices, max_simplices=None):
17+
def scipy_delaunay(points_np, query_points_np, source_pixel_zeroed_indices):
1818
"""Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
1919

20+
max_simplices = 2 * points_np.shape[0]
21+
2022
# --- Delaunay mesh using source plane data grid ---
2123
tri = Delaunay(points_np)
2224

@@ -37,21 +39,31 @@ def scipy_delaunay(points_np, query_points_np, source_pixel_zeroed_indices, max_
3739
delaunay_points=points_np,
3840
)
3941

40-
# ---------- Baronicentric Dual areas ----------
41-
barycentric_dual_areas = barycentric_dual_area_from(
42-
points,
43-
simplices,
44-
xp=np,
45-
)
42+
# ---------- Baronicentric Dual used to weight split points ----------
43+
# barycentric_dual_areas = np.abs(voronoi_areas_numpy(
44+
# points,
45+
# ))
46+
47+
# barycentric_dual_areas = barycentric_dual_area_from(
48+
# points,
49+
# simplices,
50+
# xp=np,
51+
# )
4652

4753
# max_area = np.percentile(barycentric_dual_areas, 90.0)
4854
# barycentric_dual_areas[source_pixel_zeroed_indices] = max_area
4955

50-
max_area = 100.0 * np.max(barycentric_dual_areas)
51-
barycentric_dual_areas[source_pixel_zeroed_indices] = max_area
56+
# ---------- Voronoi Areas used to weight split points ----------
57+
areas = voronoi_areas_numpy(
58+
points,
59+
)
5260

53-
# ---------- Areas used to weight split points ----------
54-
split_point_areas = 0.5 * np.sqrt(barycentric_dual_areas)
61+
max_area = np.percentile(areas, 90.0)
62+
63+
areas[areas == -1] = max_area
64+
areas[areas > max_area] = max_area
65+
66+
split_point_areas = 0.5 * np.sqrt(areas)
5567

5668
# ---------- Compute split cross points for Split regularization ----------
5769
split_points = split_points_from(
@@ -88,7 +100,7 @@ def jax_delaunay(points, query_points, source_pixel_zeroed_indices):
88100

89101
return jax.pure_callback(
90102
lambda points, qpts, spzi: scipy_delaunay(
91-
np.asarray(points), np.asarray(qpts), np.asarray(spzi), max_simplices
103+
np.asarray(points), np.asarray(qpts), np.asarray(spzi),
92104
),
93105
(
94106
points_shape,
@@ -161,6 +173,84 @@ def barycentric_dual_area_from(
161173
return dual_area
162174

163175

176+
def voronoi_areas_numpy(points, qhull_options="Qbb Qc Qx Qm"):
177+
"""
178+
Compute Voronoi cell areas with a fully optimized pure-NumPy pipeline.
179+
Exact match to the per-cell SciPy Voronoi loop but much faster.
180+
"""
181+
vor = Voronoi(points, qhull_options=qhull_options)
182+
183+
vertices = vor.vertices
184+
point_region = vor.point_region
185+
regions = vor.regions
186+
N = len(point_region)
187+
188+
# ------------------------------------------------------------
189+
# 1) Collect all region lists in one go (list comprehension is fast)
190+
# ------------------------------------------------------------
191+
region_lists = [regions[r] for r in point_region]
192+
193+
# Precompute which regions are unbounded (vectorized test)
194+
unbounded = np.array([(-1 in r) for r in region_lists], dtype=bool)
195+
196+
# Filter only bounded region vertex indices
197+
clean_regions = [np.asarray([v for v in r if v != -1], dtype=int)
198+
for r in region_lists]
199+
200+
# Compute lengths once
201+
lengths = np.array([len(r) for r in clean_regions], dtype=int)
202+
max_len = lengths.max()
203+
204+
# ------------------------------------------------------------
205+
# 2) Build padded idx + mask in a vectorized-like way
206+
#
207+
# Instead of doing Python work inside the loop, we pre-pack
208+
# the flattened data and then reshape.
209+
# ------------------------------------------------------------
210+
idx = np.full((N, max_len), -1, dtype=int)
211+
mask = np.zeros((N, max_len), dtype=bool)
212+
213+
# Single loop remaining: extremely cheap
214+
for i, (r, L) in enumerate(zip(clean_regions, lengths)):
215+
if L:
216+
idx[i, :L] = r
217+
mask[i, :L] = True
218+
219+
# ------------------------------------------------------------
220+
# 3) Gather polygon vertices (vectorized)
221+
# ------------------------------------------------------------
222+
safe_idx = idx.clip(min=0)
223+
verts = vertices[safe_idx] # (N, max_len, 2)
224+
225+
# Extract x, y with masked invalid entries zeroed
226+
x = np.where(mask, verts[..., 1], 0.0)
227+
y = np.where(mask, verts[..., 0], 0.0)
228+
229+
# ------------------------------------------------------------
230+
# 4) Vectorized "previous index" per polygon
231+
# ------------------------------------------------------------
232+
safe_lengths = np.where(lengths == 0, 1, lengths)
233+
j = np.arange(max_len)
234+
prev = (j[None, :] - 1) % safe_lengths[:, None]
235+
236+
# Efficient take-along-axis
237+
x_prev = np.take_along_axis(x, prev, axis=1)
238+
y_prev = np.take_along_axis(y, prev, axis=1)
239+
240+
# ------------------------------------------------------------
241+
# 5) Shoelace vectorized
242+
# ------------------------------------------------------------
243+
cross = x * y_prev - y * x_prev
244+
areas = 0.5 * np.abs(cross.sum(axis=1))
245+
246+
# ------------------------------------------------------------
247+
# 6) Mark unbounded regions
248+
# ------------------------------------------------------------
249+
areas[unbounded] = -1.0
250+
251+
return areas
252+
253+
164254
def split_points_from(points, area_weights, xp=np):
165255
"""
166256
points : (N, 2)
@@ -348,7 +438,7 @@ def mesh_grid_xy(self):
348438
349439
Therefore, this property simply converts the (y,x) grid of irregular coordinates into an (x,y) grid.
350440
"""
351-
return self._xp.stack([self.array[:, 1], self.array[:, 0]]).T
441+
return self._xp.stack([self.array[:, 0], self.array[:, 1]]).T
352442

353443
@cached_property
354444
def delaunay(self) -> "scipy.spatial.Delaunay":
@@ -377,9 +467,9 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
377467
else:
378468

379469
points, simplices, mappings, split_points, splitted_mappings = scipy_delaunay(
380-
points=self.mesh_grid_xy,
381-
query_points_np=self.source_plane_data_grid_over_sampled,
382-
source_pixel_zeroed_indices=self.preloads.source_pixel_zeroed_indices
470+
points_np=self.mesh_grid_xy,
471+
query_points_np=self._source_plane_data_grid_over_sampled,
472+
source_pixel_zeroed_indices=self.preloads.source_pixel_zeroed_indices,
383473
)
384474

385475
return DelaunayInterface(
@@ -509,27 +599,7 @@ def voronoi(self) -> "scipy.spatial.Voronoi":
509599

510600
@property
511601
def voronoi_areas(self):
512-
513-
N = self.mesh_grid_xy.shape[0]
514-
515-
voronoi_areas = np.zeros(N)
516-
517-
for i in range(N):
518-
region_vertices_indexes = self.voronoi.regions[self.voronoi.point_region[i]]
519-
if -1 in region_vertices_indexes:
520-
voronoi_areas[i] = -1
521-
else:
522-
523-
points_of_region = self.voronoivertices[region_vertices_indexes]
524-
525-
x = points_of_region[:, 1]
526-
y = points_of_region[:, 0]
527-
528-
voronoi_areas[i] = 0.5 * np.abs(
529-
np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))
530-
)
531-
532-
return voronoi_areas
602+
return voronoi_areas_numpy(points=self.mesh_grid_xy)
533603

534604
@property
535605
def areas_for_magnification(self) -> np.ndarray:

0 commit comments

Comments
 (0)