|
3 | 3 | from typing import Tuple |
4 | 4 |
|
5 | 5 |
|
6 | | -def forward_interp(xp, yp, x): |
7 | 6 |
|
8 | | - import jax |
9 | | - import jax.numpy as jnp |
10 | 7 |
|
11 | | - return jax.vmap(jnp.interp, in_axes=(1, 1, 1, None, None), out_axes=(1))( |
12 | | - x, xp, yp, 0, 1 |
13 | | - ) |
14 | 8 |
|
15 | 9 |
|
16 | | -def reverse_interp(xp, yp, x): |
17 | | - import jax |
18 | | - import jax.numpy as jnp |
19 | | - |
20 | | - return jax.vmap(jnp.interp, in_axes=(1, 1, 1), out_axes=(1))(x, xp, yp) |
21 | | - |
22 | | - |
23 | | -def forward_interp_np(xp, yp, x): |
24 | | - """ |
25 | | - xp: (N, M) |
26 | | - yp: (N, M) |
27 | | - x : (M,) ← one x per column |
28 | | - """ |
29 | | - |
30 | | - if yp.ndim == 1 and xp.ndim == 2: |
31 | | - yp = np.broadcast_to(yp[:, None], xp.shape) |
32 | | - |
33 | | - K, M = x.shape |
34 | | - |
35 | | - out = np.empty((K, 2), dtype=xp.dtype) |
36 | | - |
37 | | - for j in range(2): |
38 | | - out[:, j] = np.interp(x[:, j], xp[:, j], yp[:, j], left=0, right=1) |
39 | | - |
40 | | - return out |
41 | | - |
42 | | - |
43 | | -def reverse_interp_np(xp, yp, x): |
44 | | - """ |
45 | | - xp : (N,) or (N, M) |
46 | | - yp : (N, M) |
47 | | - x : (K, M) query points per column |
48 | | - """ |
49 | | - |
50 | | - # Ensure xp is 2D: (N, M) |
51 | | - if xp.ndim == 1 and yp.ndim == 2: # (N, 1) |
52 | | - xp = np.broadcast_to(xp[:, None], yp.shape) |
53 | | - |
54 | | - # Shapes |
55 | | - K, M = x.shape |
56 | | - |
57 | | - # Output |
58 | | - out = np.empty((K, 2), dtype=yp.dtype) |
59 | | - |
60 | | - # Column-wise interpolation (cannot avoid this loop in pure NumPy) |
61 | | - for j in range(2): |
62 | | - out[:, j] = np.interp(x[:, j], xp[:, j], yp[:, j]) |
63 | | - |
64 | | - return out |
65 | | - |
66 | | - |
67 | | -def create_transforms(traced_points, mesh_weight_map=None, xp=np): |
68 | | - |
69 | | - N = traced_points.shape[0] # // 2 |
70 | | - |
71 | | - if mesh_weight_map is None: |
72 | | - t = xp.arange(1, N + 1) / (N + 1) |
73 | | - t = xp.stack([t, t], axis=1) |
74 | | - sort_points = xp.sort(traced_points, axis=0) # [::2] |
75 | | - else: |
76 | | - sdx = xp.argsort(traced_points, axis=0) |
77 | | - sort_points = xp.take_along_axis(traced_points, sdx, axis=0) |
78 | | - t = xp.stack([mesh_weight_map, mesh_weight_map], axis=1) |
79 | | - t = xp.take_along_axis(t, sdx, axis=0) |
80 | | - t = xp.cumsum(t, axis=0) |
81 | | - |
82 | | - if xp.__name__.startswith("jax"): |
83 | | - transform = partial(forward_interp, sort_points, t) |
84 | | - inv_transform = partial(reverse_interp, t, sort_points) |
85 | | - return transform, inv_transform |
86 | | - |
87 | | - transform = partial(forward_interp_np, sort_points, t) |
88 | | - inv_transform = partial(reverse_interp_np, t, sort_points) |
89 | | - return transform, inv_transform |
90 | | - |
91 | | - |
92 | | -def adaptive_rectangular_transformed_grid_from( |
93 | | - source_plane_data_grid, grid, mesh_weight_map=None, xp=np |
94 | | -): |
95 | | - |
96 | | - mu = source_plane_data_grid.mean(axis=0) |
97 | | - scale = source_plane_data_grid.std(axis=0).min() |
98 | | - source_grid_scaled = (source_plane_data_grid - mu) / scale |
99 | | - |
100 | | - transform, inv_transform = create_transforms( |
101 | | - source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp |
102 | | - ) |
103 | | - |
104 | | - def inv_full(U): |
105 | | - return inv_transform(U) * scale + mu |
106 | | - |
107 | | - return inv_full(grid) |
108 | | - |
109 | | - |
110 | | -def adaptive_rectangular_areas_from( |
111 | | - source_grid_shape, source_plane_data_grid, mesh_weight_map=None, xp=np |
112 | | -): |
113 | | - |
114 | | - edges_y = xp.linspace(1, 0, source_grid_shape[0] + 1) |
115 | | - edges_x = xp.linspace(0, 1, source_grid_shape[1] + 1) |
116 | | - |
117 | | - mu = source_plane_data_grid.mean(axis=0) |
118 | | - scale = source_plane_data_grid.std(axis=0).min() |
119 | | - source_grid_scaled = (source_plane_data_grid - mu) / scale |
120 | | - |
121 | | - transform, inv_transform = create_transforms( |
122 | | - source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp |
123 | | - ) |
124 | | - |
125 | | - def inv_full(U): |
126 | | - return inv_transform(U) * scale + mu |
127 | | - |
128 | | - pixel_edges = inv_full(xp.stack([edges_y, edges_x]).T) |
129 | | - pixel_lengths = xp.diff(pixel_edges, axis=0).squeeze() # shape (N_source, 2) |
130 | | - |
131 | | - dy = pixel_lengths[:, 0] |
132 | | - dx = pixel_lengths[:, 1] |
133 | | - |
134 | | - return xp.abs(xp.outer(dy, dx).flatten()) |
135 | | - |
136 | | - |
137 | | -def adaptive_rectangular_mappings_weights_via_interpolation_from( |
138 | | - source_grid_size: int, |
139 | | - source_plane_data_grid, |
140 | | - source_plane_data_grid_over_sampled, |
141 | | - mesh_weight_map=None, |
142 | | - xp=np, |
143 | | -): |
144 | | - """ |
145 | | - Compute bilinear interpolation indices and weights for mapping an oversampled |
146 | | - source-plane grid onto a regular rectangular pixelization. |
147 | | -
|
148 | | - This function takes a set of irregularly-sampled source-plane coordinates and |
149 | | - builds an adaptive mapping onto a `source_grid_size x source_grid_size` rectangular |
150 | | - pixelization using bilinear interpolation. The interpolation is expressed as: |
151 | | -
|
152 | | - f(x, y) ≈ w_bl * f(ix_down, iy_down) + |
153 | | - w_br * f(ix_up, iy_down) + |
154 | | - w_tl * f(ix_down, iy_up) + |
155 | | - w_tr * f(ix_up, iy_up) |
156 | | -
|
157 | | - where `(ix_down, ix_up, iy_down, iy_up)` are the integer grid coordinates |
158 | | - surrounding the continuous position `(x, y)`. |
159 | | -
|
160 | | - Steps performed: |
161 | | - 1. Normalize the source-plane grid by subtracting its mean and dividing by |
162 | | - the minimum axis standard deviation (to balance scaling). |
163 | | - 2. Construct forward/inverse transforms which map the grid into the unit square [0,1]^2. |
164 | | - 3. Transform the oversampled source-plane grid into [0,1]^2, then scale it |
165 | | - to index space `[0, source_grid_size)`. |
166 | | - 4. Compute floor/ceil along x and y axes to find the enclosing rectangular cell. |
167 | | - 5. Build the four corner indices: bottom-left (bl), bottom-right (br), |
168 | | - top-left (tl), and top-right (tr). |
169 | | - 6. Flatten the 2D indices into 1D indices suitable for scatter operations, |
170 | | - with a flipped row-major convention: row = source_grid_size - i, col = j. |
171 | | - 7. Compute bilinear interpolation weights (`w_bl, w_br, w_tl, w_tr`). |
172 | | - 8. Return arrays of flattened indices and weights of shape `(N, 4)`, where |
173 | | - `N` is the number of oversampled coordinates. |
174 | | -
|
175 | | - Parameters |
176 | | - ---------- |
177 | | - source_grid_size : int |
178 | | - The number of pixels along one dimension of the rectangular pixelization. |
179 | | - The grid is square: (source_grid_size x source_grid_size). |
180 | | - source_plane_data_grid : (M, 2) ndarray |
181 | | - The base source-plane coordinates, used to define normalization and transforms. |
182 | | - source_plane_data_grid_over_sampled : (N, 2) ndarray |
183 | | - Oversampled source-plane coordinates to be interpolated onto the rectangular grid. |
184 | | - mesh_weight_map |
185 | | - The weight map used to weight the creation of the rectangular mesh grid, which is used for the |
186 | | - `RectangularBrightness` mesh which adapts the size of its pixels to where the source is reconstructed. |
187 | | -
|
188 | | - Returns |
189 | | - ------- |
190 | | - flat_indices : (N, 4) int ndarray |
191 | | - The flattened indices of the four neighboring pixel corners for each oversampled point. |
192 | | - Order: [bl, br, tl, tr]. |
193 | | - weights : (N, 4) float ndarray |
194 | | - The bilinear interpolation weights for each of the four neighboring pixels. |
195 | | - Order: [w_bl, w_br, w_tl, w_tr]. |
196 | | - """ |
197 | | - |
198 | | - # --- Step 1. Normalize grid --- |
199 | | - mu = source_plane_data_grid.mean(axis=0) |
200 | | - scale = source_plane_data_grid.std(axis=0).min() |
201 | | - source_grid_scaled = (source_plane_data_grid - mu) / scale |
202 | | - |
203 | | - # --- Step 2. Build transforms --- |
204 | | - transform, inv_transform = create_transforms( |
205 | | - source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp |
206 | | - ) |
207 | | - |
208 | | - # --- Step 3. Transform oversampled grid into index space --- |
209 | | - grid_over_sampled_scaled = (source_plane_data_grid_over_sampled - mu) / scale |
210 | | - grid_over_sampled_transformed = transform(grid_over_sampled_scaled) |
211 | | - grid_over_index = (source_grid_size - 3) * grid_over_sampled_transformed + 1 |
212 | | - |
213 | | - # --- Step 4. Floor/ceil indices --- |
214 | | - ix_down = xp.floor(grid_over_index[:, 0]) |
215 | | - ix_up = xp.ceil(grid_over_index[:, 0]) |
216 | | - iy_down = xp.floor(grid_over_index[:, 1]) |
217 | | - iy_up = xp.ceil(grid_over_index[:, 1]) |
218 | | - |
219 | | - # --- Step 5. Four corners --- |
220 | | - idx_tl = xp.stack([ix_up, iy_down], axis=1) |
221 | | - idx_tr = xp.stack([ix_up, iy_up], axis=1) |
222 | | - idx_br = xp.stack([ix_down, iy_up], axis=1) |
223 | | - idx_bl = xp.stack([ix_down, iy_down], axis=1) |
224 | | - |
225 | | - # --- Step 6. Flatten indices --- |
226 | | - def flatten(idx, n): |
227 | | - row = n - idx[:, 0] |
228 | | - col = idx[:, 1] |
229 | | - return row * n + col |
230 | | - |
231 | | - flat_tl = flatten(idx_tl, source_grid_size) |
232 | | - flat_tr = flatten(idx_tr, source_grid_size) |
233 | | - flat_bl = flatten(idx_bl, source_grid_size) |
234 | | - flat_br = flatten(idx_br, source_grid_size) |
235 | | - |
236 | | - flat_indices = xp.stack([flat_tl, flat_tr, flat_bl, flat_br], axis=1).astype( |
237 | | - "int64" |
238 | | - ) |
239 | | - |
240 | | - # --- Step 7. Bilinear interpolation weights --- |
241 | | - t_row = (grid_over_index[:, 0] - ix_down) / (ix_up - ix_down + 1e-12) |
242 | | - t_col = (grid_over_index[:, 1] - iy_down) / (iy_up - iy_down + 1e-12) |
243 | | - |
244 | | - # Weights |
245 | | - w_tl = (1 - t_row) * (1 - t_col) |
246 | | - w_tr = (1 - t_row) * t_col |
247 | | - w_bl = t_row * (1 - t_col) |
248 | | - w_br = t_row * t_col |
249 | | - weights = xp.stack([w_tl, w_tr, w_bl, w_br], axis=1) |
250 | | - |
251 | | - return flat_indices, weights |
252 | | - |
253 | | - |
254 | | -def rectangular_mappings_weights_via_interpolation_from( |
255 | | - shape_native: Tuple[int, int], |
256 | | - source_plane_data_grid: np.ndarray, |
257 | | - source_plane_mesh_grid: np.ndarray, |
258 | | - xp=np, |
259 | | -): |
260 | | - """ |
261 | | - Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid. |
262 | | -
|
263 | | - Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function |
264 | | - determines for each irregular point: |
265 | | - - the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and |
266 | | - - the bilinear interpolation weights with respect to those pixels. |
267 | | -
|
268 | | - The function supports JAX and is compatible with JIT compilation. |
269 | | -
|
270 | | - Parameters |
271 | | - ---------- |
272 | | - shape_native |
273 | | - The shape (Ny, Nx) of the original rectangular mesh grid before flattening. |
274 | | - source_plane_data_grid |
275 | | - The irregular grid of (y, x) points to interpolate. |
276 | | - source_plane_mesh_grid |
277 | | - The flattened regular rectangular mesh grid of (y, x) coordinates. |
278 | | -
|
279 | | - Returns |
280 | | - ------- |
281 | | - mappings : np.ndarray of shape (N, 4) |
282 | | - Indices of the four nearest rectangular mesh pixels in the flattened mesh grid. |
283 | | - Order is: top-left, top-right, bottom-left, bottom-right. |
284 | | - weights : np.ndarray of shape (N, 4) |
285 | | - Bilinear interpolation weights corresponding to the four nearest mesh pixels. |
286 | | -
|
287 | | - Notes |
288 | | - ----- |
289 | | - - Assumes the mesh grid is uniformly spaced. |
290 | | - - The weights sum to 1 for each irregular point. |
291 | | - - Uses bilinear interpolation in the (y, x) coordinate system. |
292 | | - """ |
293 | | - source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2) |
294 | | - |
295 | | - # Assume mesh is shaped (Ny, Nx, 2) |
296 | | - Ny, Nx = source_plane_mesh_grid.shape[:2] |
297 | | - |
298 | | - # Get mesh spacings and lower corner |
299 | | - y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,) |
300 | | - x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,) |
301 | | - |
302 | | - dy = y_coords[1] - y_coords[0] |
303 | | - dx = x_coords[1] - x_coords[0] |
304 | | - |
305 | | - y_min = y_coords[0] |
306 | | - x_min = x_coords[0] |
307 | | - |
308 | | - # shape (N_irregular, 2) |
309 | | - irregular = source_plane_data_grid |
310 | | - |
311 | | - # Compute normalized mesh coordinates (floating indices) |
312 | | - fy = (irregular[:, 0] - y_min) / dy |
313 | | - fx = (irregular[:, 1] - x_min) / dx |
314 | | - |
315 | | - # Integer indices of top-left corners |
316 | | - ix = xp.floor(fx).astype(xp.int32) |
317 | | - iy = xp.floor(fy).astype(xp.int32) |
318 | | - |
319 | | - # Clip to stay within bounds |
320 | | - ix = xp.clip(ix, 0, Nx - 2) |
321 | | - iy = xp.clip(iy, 0, Ny - 2) |
322 | | - |
323 | | - # Local coordinates inside the cell (0 <= tx, ty <= 1) |
324 | | - tx = fx - ix |
325 | | - ty = fy - iy |
326 | | - |
327 | | - # Bilinear weights |
328 | | - w00 = (1 - tx) * (1 - ty) |
329 | | - w10 = tx * (1 - ty) |
330 | | - w01 = (1 - tx) * ty |
331 | | - w11 = tx * ty |
332 | | - |
333 | | - weights = xp.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4) |
334 | | - |
335 | | - # Compute indices of 4 surrounding pixels in the flattened mesh |
336 | | - i00 = iy * Nx + ix |
337 | | - i10 = iy * Nx + (ix + 1) |
338 | | - i01 = (iy + 1) * Nx + ix |
339 | | - i11 = (iy + 1) * Nx + (ix + 1) |
340 | | - |
341 | | - mappings = xp.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4) |
342 | | - |
343 | | - return mappings, weights |
344 | | - |
345 | | - |
346 | | -def nearest_pixelization_index_for_slim_index_from_kdtree(grid, mesh_grid): |
347 | | - from scipy.spatial import cKDTree |
348 | | - |
349 | | - kdtree = cKDTree(mesh_grid) |
350 | | - |
351 | | - sparse_index_for_slim_index = [] |
352 | | - |
353 | | - for i in range(grid.shape[0]): |
354 | | - input_point = [grid[i, [0]], grid[i, 1]] |
355 | | - index = kdtree.query(input_point)[1] |
356 | | - sparse_index_for_slim_index.append(index) |
357 | | - |
358 | | - return sparse_index_for_slim_index |
359 | 10 |
|
360 | 11 |
|
361 | 12 | def adaptive_pixel_signals_from( |
|
0 commit comments