Skip to content

Commit 40581d1

Browse files
Jammy2211Jammy2211
authored andcommitted
fix copy and paste error
1 parent 354fed2 commit 40581d1

File tree

1 file changed

+63
-1
lines changed

1 file changed

+63
-1
lines changed

autoarray/structures/mesh/delaunay_2d.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,66 @@ def scipy_delaunay(points_np, query_points_np, use_voronoi_areas, areas_factor):
6464

6565
max_simplices = 2 * points_np.shape[0]
6666

67+
# --- Delaunay mesh using source plane data grid ---
68+
tri = Delaunay(points_np)
69+
70+
points = tri.points.astype(points_np.dtype)
71+
simplices = tri.simplices.astype(np.int32)
72+
73+
# Pad simplices to max_simplices
74+
simplices_padded = -np.ones((max_simplices, 3), dtype=np.int32)
75+
simplices_padded[: simplices.shape[0]] = simplices
76+
77+
# ---------- find_simplex for source plane data grid ----------
78+
simplex_idx = tri.find_simplex(query_points_np).astype(np.int32) # (Q,)
79+
80+
mappings = pix_indexes_for_sub_slim_index_delaunay_from(
81+
source_plane_data_grid=query_points_np,
82+
simplex_index_for_sub_slim_index=simplex_idx,
83+
pix_indexes_for_simplex_index=simplices,
84+
delaunay_points=points_np,
85+
)
86+
87+
# ---------- Voronoi or Barycentric Areas used to weight split points ----------
88+
89+
if use_voronoi_areas:
90+
91+
areas = voronoi_areas_numpy(
92+
points,
93+
)
94+
95+
max_area = np.percentile(areas, 90.0)
96+
97+
areas[areas == -1] = max_area
98+
areas[areas > max_area] = max_area
99+
100+
else:
101+
102+
areas = barycentric_dual_area_from(
103+
points,
104+
simplices,
105+
xp=np,
106+
)
107+
108+
split_point_areas = areas_factor * np.sqrt(areas)
109+
110+
# ---------- Compute split cross points for Split regularization ----------
111+
split_points = split_points_from(
112+
points=points_np,
113+
area_weights=split_point_areas,
114+
)
115+
116+
# ---------- find_simplex for split cross points ----------
117+
split_points_idx = tri.find_simplex(split_points)
118+
119+
splitted_mappings = pix_indexes_for_sub_slim_index_delaunay_from(
120+
source_plane_data_grid=split_points,
121+
simplex_index_for_sub_slim_index=split_points_idx,
122+
pix_indexes_for_simplex_index=simplices,
123+
delaunay_points=points_np,
124+
)
125+
126+
return points, simplices_padded, mappings, split_points, splitted_mappings
67127

68128

69129
def jax_delaunay(points, query_points, use_voronoi_areas, areas_factor=0.5):
@@ -81,7 +141,9 @@ def jax_delaunay(points, query_points, use_voronoi_areas, areas_factor=0.5):
81141
splitted_mappings_shape = jax.ShapeDtypeStruct((N * 4, 3), jnp.int32)
82142

83143
return jax.pure_callback(
84-
lambda points, qpts: scipy_delaunay(points, qpts, use_voronoi_areas, areas_factor),
144+
lambda points, qpts: scipy_delaunay(
145+
np.asarray(points), np.asarray(qpts), use_voronoi_areas, areas_factor
146+
),
85147
(
86148
points_shape,
87149
simplices_padded_shape,

0 commit comments

Comments
 (0)