Skip to content

Commit c635031

Browse files
Jammy2211Jammy2211
authored andcommitted
more experitmental changes
1 parent 76856ce commit c635031

File tree

4 files changed

+113
-86
lines changed

4 files changed

+113
-86
lines changed

autoarray/inversion/pixelization/mappers/delaunay.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,7 @@ def pix_sub_weights(self) -> PixSubWeights:
269269
"""
270270
delaunay = self.delaunay
271271

272-
simplex_index_for_sub_slim_index = delaunay.find_simplex(
273-
self.source_plane_data_grid.over_sampled
274-
)
272+
simplex_index_for_sub_slim_index = delaunay.simplex_index_for_sub_slim_index
275273
pix_indexes_for_simplex_index = delaunay.simplices
276274

277275
mappings, sizes = pix_indexes_for_sub_slim_index_delaunay_from(
@@ -309,9 +307,7 @@ def pix_sub_weights_split_cross(self) -> PixSubWeights:
309307
"""
310308
delaunay = self.delaunay
311309

312-
splitted_simplex_index_for_sub_slim_index = delaunay.find_simplex(
313-
self.source_plane_mesh_grid.split_cross
314-
)
310+
splitted_simplex_index_for_sub_slim_index = delaunay.splitted_simplex_index_for_sub_slim_index
315311
pix_indexes_for_simplex_index = delaunay.simplices
316312

317313
(

autoarray/inversion/pixelization/mappers/rectangular.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,6 @@ def pix_sub_weights(self) -> PixSubWeights:
9494
dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index`
9595
are equal to 1.0.
9696
"""
97-
98-
weight_map = self.mapper_grids
99-
10097
mappings, weights = (
10198
mapper_util.adaptive_rectangular_mappings_weights_via_interpolation_from(
10299
source_grid_size=self.shape_native[0],

autoarray/inversion/pixelization/mesh/delaunay.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def mesh_grid_from(
5858
Settings controlling the pixelization for example if a border is used to relocate its exterior coordinates.
5959
"""
6060

61-
return Mesh2DDelaunay(values=source_plane_mesh_grid, _xp=xp)
61+
return Mesh2DDelaunay(values=source_plane_mesh_grid, source_plane_data_grid_over_sampled=source_plane_data_grid, _xp=xp)
6262

6363
def mapper_grids_from(
6464
self,

autoarray/structures/mesh/delaunay_2d.py

Lines changed: 110 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -25,61 +25,90 @@ def voronoi_areas_from(points_np):
2525
voronoi_vertices = vor.vertices
2626
voronoi_regions = vor.regions
2727
voronoi_point_region = vor.point_region
28-
region_areas = np.zeros(N)
28+
voronoi_areas = np.zeros(N)
2929

3030
for i in range(N):
3131
region_vertices_indexes = voronoi_regions[voronoi_point_region[i]]
3232
if -1 in region_vertices_indexes:
33-
region_areas[i] = -1
33+
voronoi_areas[i] = -1
3434
else:
3535

3636
points_of_region = voronoi_vertices[region_vertices_indexes]
3737

3838
x = points_of_region[:, 1]
3939
y = points_of_region[:, 0]
4040

41-
region_areas[i] = 0.5 * np.abs(
41+
voronoi_areas[i] = 0.5 * np.abs(
4242
np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))
4343
)
4444

45-
return region_areas
45+
voronoi_pixel_areas_for_split = voronoi_areas.copy()
4646

47+
# 90th percentile
48+
max_area = np.percentile(voronoi_pixel_areas_for_split, 90.0)
4749

48-
def scipy_delaunay_voronoi(points_np, max_simplices):
50+
voronoi_pixel_areas_for_split[voronoi_pixel_areas_for_split == -1] = max_area
51+
voronoi_pixel_areas_for_split[voronoi_pixel_areas_for_split > max_area] = max_area
52+
53+
half_region_area_sqrt_lengths = 0.5 * np.sqrt(voronoi_pixel_areas_for_split)
54+
55+
split_points = split_points_from(
56+
points=points_np,
57+
area_weights=half_region_area_sqrt_lengths,
58+
)
59+
60+
return voronoi_areas, split_points
61+
62+
63+
def scipy_delaunay_voronoi(points_np, query_points_np, max_simplices):
4964
"""Compute Delaunay simplices (padded) and Voronoi areas in one call."""
5065
from scipy.spatial import Delaunay
5166
import numpy as np
5267

53-
N = points_np.shape[0]
54-
5568
# --- Delaunay ---
5669
tri = Delaunay(points_np)
70+
5771
pts = tri.points.astype(points_np.dtype)
5872
simplices = tri.simplices.astype(np.int32)
5973

74+
# Pad simplices to max_simplices
6075
padded = -np.ones((max_simplices, 3), dtype=np.int32)
6176
padded[: simplices.shape[0]] = simplices
6277

63-
areas = voronoi_areas_from(points_np)
78+
# ---------- Voronoi cell areas ----------
79+
areas, split_points = voronoi_areas_from(points_np)
80+
81+
# ---------- find_simplex ----------
82+
simplex_idx = tri.find_simplex(query_points_np).astype(np.int32) # (Q,)
6483

65-
return pts, padded, areas
84+
# ---------- find_simplex for split cross points ----------
85+
split_cross_idx = tri.find_simplex(split_points)
6686

87+
return pts, padded, areas, simplex_idx, split_cross_idx
6788

68-
def jax_delaunay_voronoi(points):
89+
90+
def jax_delaunay_voronoi(points, query_points):
6991
import jax
7092
import jax.numpy as jnp
7193

7294
N = points.shape[0]
95+
Q = query_points.shape[0]
96+
97+
# Conservative pad (you can pass exact M if you want)
7398
max_simplices = 2 * N # same logic as before
7499

75100
pts_shape = jax.ShapeDtypeStruct((N, 2), points.dtype)
76101
simp_shape = jax.ShapeDtypeStruct((max_simplices, 3), jnp.int32)
77102
area_shape = jax.ShapeDtypeStruct((N,), points.dtype)
103+
sx_shape = jax.ShapeDtypeStruct((Q,), jnp.int32)
104+
sc_shape = jax.ShapeDtypeStruct((N*4,), jnp.int32)
78105

79106
return jax.pure_callback(
80-
lambda pts: scipy_delaunay_voronoi(pts, max_simplices),
81-
(pts_shape, simp_shape, area_shape),
82-
points,
107+
lambda pts, qpts: scipy_delaunay_voronoi(
108+
np.asarray(pts), np.asarray(qpts), max_simplices
109+
),
110+
(pts_shape, simp_shape, area_shape, sx_shape, sc_shape),
111+
points, query_points,
83112
)
84113

85114

@@ -110,61 +139,60 @@ def jax_delaunay_voronoi(points):
110139
# return pts_out, simplices_out
111140

112141

113-
def find_simplex_from(query_points, points, simplices):
114-
"""
115-
Return simplex index for each query point.
116-
Returns -1 where no simplex contains the point.
117-
"""
118-
119-
import jax
120-
import jax.numpy as jnp
121-
122-
# Mask padded simplices (marked with -1)
123-
valid = simplices[:, 0] >= 0 # (M,)
124-
simplices_clipped = simplices.clip(min=0)
125-
126-
# Triangle vertices: (M, 3, 2)
127-
tri = points[simplices_clipped]
128-
129-
p0 = tri[:, 0] # (M, 2)
130-
p1 = tri[:, 1]
131-
p2 = tri[:, 2]
132-
133-
# Edges
134-
v0 = p1 - p0 # (M, 2)
135-
v1 = p2 - p0
136-
137-
# Precomputed dot products
138-
d00 = jnp.sum(v0 * v0, axis=1) # (M,)
139-
d01 = jnp.sum(v0 * v1, axis=1)
140-
d11 = jnp.sum(v1 * v1, axis=1)
141-
denom = d00 * d11 - d01 * d01 # (M,)
142-
143-
# Barycentric computation for each query point vs each triangle
144-
diff = query_points[:, None, :] - p0[None, :, :] # (Q, M, 2)
145-
146-
a = jnp.sum(diff * v0[None, :, :], axis=-1) # (Q, M)
147-
b = jnp.sum(diff * v1[None, :, :], axis=-1)
148-
149-
b0 = (a * d11 - b * d01) / denom # (Q, M)
150-
b1 = (b * d00 - a * d01) / denom
151-
152-
# Inside test
153-
inside = (b0 >= 0.0) & (b1 >= 0.0) & (b0 + b1 <= 1.0) # (Q, M)
154-
155-
# Remove padded simplices
156-
inside = inside & valid[None, :]
157-
158-
# First valid simplex per point
159-
simplex_idx = jnp.argmax(inside, axis=1) # (Q,)
160-
161-
# Detect points with no simplex match
162-
has_match = jnp.any(inside, axis=1) # (Q,)
163-
164-
# Replace unmatched with -1
165-
simplex_idx = jnp.where(has_match, simplex_idx, -1)
166-
167-
return simplex_idx
142+
# def find_simplex_from(query_points, points, simplices):
143+
# """
144+
# Return simplex index for each query point.
145+
# Returns -1 where no simplex contains the point.
146+
# """
147+
#
148+
# import jax.numpy as jnp
149+
#
150+
# # Mask padded simplices (marked with -1)
151+
# valid = simplices[:, 0] >= 0 # (M,)
152+
# simplices_clipped = simplices.clip(min=0)
153+
#
154+
# # Triangle vertices: (M, 3, 2)
155+
# tri = points[simplices_clipped]
156+
#
157+
# p0 = tri[:, 0] # (M, 2)
158+
# p1 = tri[:, 1]
159+
# p2 = tri[:, 2]
160+
#
161+
# # Edges
162+
# v0 = p1 - p0 # (M, 2)
163+
# v1 = p2 - p0
164+
#
165+
# # Precomputed dot products
166+
# d00 = jnp.sum(v0 * v0, axis=1) # (M,)
167+
# d01 = jnp.sum(v0 * v1, axis=1)
168+
# d11 = jnp.sum(v1 * v1, axis=1)
169+
# denom = d00 * d11 - d01 * d01 # (M,)
170+
#
171+
# # Barycentric computation for each query point vs each triangle
172+
# diff = query_points[:, None, :] - p0[None, :, :] # (Q, M, 2)
173+
#
174+
# a = jnp.sum(diff * v0[None, :, :], axis=-1) # (Q, M)
175+
# b = jnp.sum(diff * v1[None, :, :], axis=-1)
176+
#
177+
# b0 = (a * d11 - b * d01) / denom # (Q, M)
178+
# b1 = (b * d00 - a * d01) / denom
179+
#
180+
# # Inside test
181+
# inside = (b0 >= 0.0) & (b1 >= 0.0) & (b0 + b1 <= 1.0) # (Q, M)
182+
#
183+
# # Remove padded simplices
184+
# inside = inside & valid[None, :]
185+
#
186+
# # First valid simplex per point
187+
# simplex_idx = jnp.argmax(inside, axis=1) # (Q,)
188+
#
189+
# # Detect points with no simplex match
190+
# has_match = jnp.any(inside, axis=1) # (Q,)
191+
#
192+
# # Replace unmatched with -1
193+
# simplex_idx = jnp.where(has_match, simplex_idx, -1)
194+
#
195+
# return simplex_idx
168196

169197

170198
def split_points_from(points, area_weights, xp=np):
@@ -223,19 +251,18 @@ def split_points_from(points, area_weights, xp=np):
223251

224252
class DelaunayInterface:
225253

226-
def __init__(self, ppoints, simplices, voronoi_areas, vertex_neighbor_vertices):
254+
def __init__(self, points, simplices, voronoi_areas, vertex_neighbor_vertices, simplex_index_for_sub_slim_index, splitted_simplex_index_for_sub_slim_index):
227255

228-
self.points = ppoints
256+
self.points = points
229257
self.simplices = simplices
230258
self.voronoi_areas = voronoi_areas
231259
self.vertex_neighbor_vertices = vertex_neighbor_vertices
232-
233-
def find_simplex(self, query_points):
234-
return find_simplex_from(query_points, self.points, self.simplices)
260+
self.simplex_index_for_sub_slim_index = simplex_index_for_sub_slim_index
261+
self.splitted_simplex_index_for_sub_slim_index = splitted_simplex_index_for_sub_slim_index
235262

236263

237264
class Mesh2DDelaunay(Abstract2DMesh):
238-
def __init__(self, values: Union[np.ndarray, List], _xp=np):
265+
def __init__(self, values: Union[np.ndarray, List], source_plane_data_grid_over_sampled=None, _xp=np):
239266
"""
240267
An irregular 2D grid of (y,x) coordinates which represents both a Delaunay triangulation and Voronoi mesh.
241268
@@ -268,6 +295,8 @@ def __init__(self, values: Union[np.ndarray, List], _xp=np):
268295

269296
super().__init__(values, xp=_xp)
270297

298+
self._source_plane_data_grid_over_sampled = source_plane_data_grid_over_sampled
299+
271300
@property
272301
def geometry(self):
273302
shape_native_scaled = (
@@ -312,7 +341,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
312341

313342
import jax.numpy as jnp
314343

315-
points, simplices, voronoi_areas = jax_delaunay_voronoi(mesh_grid)
344+
points, simplices, voronoi_areas, simplex_index_for_sub_slim_index, splitted_simplex_index_for_sub_slim_index = jax_delaunay_voronoi(mesh_grid, self._source_plane_data_grid_over_sampled)
316345
vertex_neighbor_vertices = None
317346

318347
else:
@@ -324,9 +353,14 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
324353
vertex_neighbor_vertices = delaunay.vertex_neighbor_vertices
325354

326355
voronoi_areas = voronoi_areas_from(mesh_grid)
356+
split_points = self.split_cross
357+
simplex_index_for_sub_slim_index = delaunay.find_simplex(self.source_plane_data_grid.over_sampled)
358+
splitted_simplex_index_for_sub_slim_index = delaunay.find_simplex(
359+
self.split_cross
360+
)
327361

328362
return DelaunayInterface(
329-
points, simplices, voronoi_areas, vertex_neighbor_vertices
363+
points, simplices, voronoi_areas, vertex_neighbor_vertices, simplex_index_for_sub_slim_index, splitted_simplex_index_for_sub_slim_index
330364
)
331365

332366
@property

0 commit comments

Comments
 (0)