@@ -64,6 +64,66 @@ def scipy_delaunay(points_np, query_points_np, use_voronoi_areas, areas_factor):
6464
6565 max_simplices = 2 * points_np .shape [0 ]
6666
67+ # --- Delaunay mesh using source plane data grid ---
68+ tri = Delaunay (points_np )
69+
70+ points = tri .points .astype (points_np .dtype )
71+ simplices = tri .simplices .astype (np .int32 )
72+
73+ # Pad simplices to max_simplices
74+ simplices_padded = - np .ones ((max_simplices , 3 ), dtype = np .int32 )
75+ simplices_padded [: simplices .shape [0 ]] = simplices
76+
77+ # ---------- find_simplex for source plane data grid ----------
78+ simplex_idx = tri .find_simplex (query_points_np ).astype (np .int32 ) # (Q,)
79+
80+ mappings = pix_indexes_for_sub_slim_index_delaunay_from (
81+ source_plane_data_grid = query_points_np ,
82+ simplex_index_for_sub_slim_index = simplex_idx ,
83+ pix_indexes_for_simplex_index = simplices ,
84+ delaunay_points = points_np ,
85+ )
86+
87+ # ---------- Voronoi or Barycentric Areas used to weight split points ----------
88+
89+ if use_voronoi_areas :
90+
91+ areas = voronoi_areas_numpy (
92+ points ,
93+ )
94+
95+ max_area = np .percentile (areas , 90.0 )
96+
97+ areas [areas == - 1 ] = max_area
98+ areas [areas > max_area ] = max_area
99+
100+ else :
101+
102+ areas = barycentric_dual_area_from (
103+ points ,
104+ simplices ,
105+ xp = np ,
106+ )
107+
108+ split_point_areas = areas_factor * np .sqrt (areas )
109+
110+ # ---------- Compute split cross points for Split regularization ----------
111+ split_points = split_points_from (
112+ points = points_np ,
113+ area_weights = split_point_areas ,
114+ )
115+
116+ # ---------- find_simplex for split cross points ----------
117+ split_points_idx = tri .find_simplex (split_points )
118+
119+ splitted_mappings = pix_indexes_for_sub_slim_index_delaunay_from (
120+ source_plane_data_grid = split_points ,
121+ simplex_index_for_sub_slim_index = split_points_idx ,
122+ pix_indexes_for_simplex_index = simplices ,
123+ delaunay_points = points_np ,
124+ )
125+
126+ return points , simplices_padded , mappings , split_points , splitted_mappings
67127
68128
69129def jax_delaunay (points , query_points , use_voronoi_areas , areas_factor = 0.5 ):
@@ -81,7 +141,9 @@ def jax_delaunay(points, query_points, use_voronoi_areas, areas_factor=0.5):
81141 splitted_mappings_shape = jax .ShapeDtypeStruct ((N * 4 , 3 ), jnp .int32 )
82142
83143 return jax .pure_callback (
84- lambda points , qpts : scipy_delaunay (points , qpts , use_voronoi_areas , areas_factor ),
144+ lambda points , qpts : scipy_delaunay (
145+ np .asarray (points ), np .asarray (qpts ), use_voronoi_areas , areas_factor
146+ ),
85147 (
86148 points_shape ,
87149 simplices_padded_shape ,
0 commit comments