Skip to content

Commit a03a6b4

Browse files
Jammy2211Jammy2211
authored andcommitted
remove mesh_tuil
1 parent c916c52 commit a03a6b4

File tree

6 files changed

+730
-744
lines changed

6 files changed

+730
-744
lines changed

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 0 additions & 349 deletions
Original file line numberDiff line numberDiff line change
@@ -3,359 +3,10 @@
33
from typing import Tuple
44

55

6-
def forward_interp(xp, yp, x):
76

8-
import jax
9-
import jax.numpy as jnp
107

11-
return jax.vmap(jnp.interp, in_axes=(1, 1, 1, None, None), out_axes=(1))(
12-
x, xp, yp, 0, 1
13-
)
148

159

16-
def reverse_interp(xp, yp, x):
17-
import jax
18-
import jax.numpy as jnp
19-
20-
return jax.vmap(jnp.interp, in_axes=(1, 1, 1), out_axes=(1))(x, xp, yp)
21-
22-
23-
def forward_interp_np(xp, yp, x):
24-
"""
25-
xp: (N, M)
26-
yp: (N, M)
27-
x : (M,) ← one x per column
28-
"""
29-
30-
if yp.ndim == 1 and xp.ndim == 2:
31-
yp = np.broadcast_to(yp[:, None], xp.shape)
32-
33-
K, M = x.shape
34-
35-
out = np.empty((K, 2), dtype=xp.dtype)
36-
37-
for j in range(2):
38-
out[:, j] = np.interp(x[:, j], xp[:, j], yp[:, j], left=0, right=1)
39-
40-
return out
41-
42-
43-
def reverse_interp_np(xp, yp, x):
44-
"""
45-
xp : (N,) or (N, M)
46-
yp : (N, M)
47-
x : (K, M) query points per column
48-
"""
49-
50-
# Ensure xp is 2D: (N, M)
51-
if xp.ndim == 1 and yp.ndim == 2: # (N, 1)
52-
xp = np.broadcast_to(xp[:, None], yp.shape)
53-
54-
# Shapes
55-
K, M = x.shape
56-
57-
# Output
58-
out = np.empty((K, 2), dtype=yp.dtype)
59-
60-
# Column-wise interpolation (cannot avoid this loop in pure NumPy)
61-
for j in range(2):
62-
out[:, j] = np.interp(x[:, j], xp[:, j], yp[:, j])
63-
64-
return out
65-
66-
67-
def create_transforms(traced_points, mesh_weight_map=None, xp=np):
68-
69-
N = traced_points.shape[0] # // 2
70-
71-
if mesh_weight_map is None:
72-
t = xp.arange(1, N + 1) / (N + 1)
73-
t = xp.stack([t, t], axis=1)
74-
sort_points = xp.sort(traced_points, axis=0) # [::2]
75-
else:
76-
sdx = xp.argsort(traced_points, axis=0)
77-
sort_points = xp.take_along_axis(traced_points, sdx, axis=0)
78-
t = xp.stack([mesh_weight_map, mesh_weight_map], axis=1)
79-
t = xp.take_along_axis(t, sdx, axis=0)
80-
t = xp.cumsum(t, axis=0)
81-
82-
if xp.__name__.startswith("jax"):
83-
transform = partial(forward_interp, sort_points, t)
84-
inv_transform = partial(reverse_interp, t, sort_points)
85-
return transform, inv_transform
86-
87-
transform = partial(forward_interp_np, sort_points, t)
88-
inv_transform = partial(reverse_interp_np, t, sort_points)
89-
return transform, inv_transform
90-
91-
92-
def adaptive_rectangular_transformed_grid_from(
93-
source_plane_data_grid, grid, mesh_weight_map=None, xp=np
94-
):
95-
96-
mu = source_plane_data_grid.mean(axis=0)
97-
scale = source_plane_data_grid.std(axis=0).min()
98-
source_grid_scaled = (source_plane_data_grid - mu) / scale
99-
100-
transform, inv_transform = create_transforms(
101-
source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp
102-
)
103-
104-
def inv_full(U):
105-
return inv_transform(U) * scale + mu
106-
107-
return inv_full(grid)
108-
109-
110-
def adaptive_rectangular_areas_from(
111-
source_grid_shape, source_plane_data_grid, mesh_weight_map=None, xp=np
112-
):
113-
114-
edges_y = xp.linspace(1, 0, source_grid_shape[0] + 1)
115-
edges_x = xp.linspace(0, 1, source_grid_shape[1] + 1)
116-
117-
mu = source_plane_data_grid.mean(axis=0)
118-
scale = source_plane_data_grid.std(axis=0).min()
119-
source_grid_scaled = (source_plane_data_grid - mu) / scale
120-
121-
transform, inv_transform = create_transforms(
122-
source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp
123-
)
124-
125-
def inv_full(U):
126-
return inv_transform(U) * scale + mu
127-
128-
pixel_edges = inv_full(xp.stack([edges_y, edges_x]).T)
129-
pixel_lengths = xp.diff(pixel_edges, axis=0).squeeze() # shape (N_source, 2)
130-
131-
dy = pixel_lengths[:, 0]
132-
dx = pixel_lengths[:, 1]
133-
134-
return xp.abs(xp.outer(dy, dx).flatten())
135-
136-
137-
def adaptive_rectangular_mappings_weights_via_interpolation_from(
138-
source_grid_size: int,
139-
source_plane_data_grid,
140-
source_plane_data_grid_over_sampled,
141-
mesh_weight_map=None,
142-
xp=np,
143-
):
144-
"""
145-
Compute bilinear interpolation indices and weights for mapping an oversampled
146-
source-plane grid onto a regular rectangular pixelization.
147-
148-
This function takes a set of irregularly-sampled source-plane coordinates and
149-
builds an adaptive mapping onto a `source_grid_size x source_grid_size` rectangular
150-
pixelization using bilinear interpolation. The interpolation is expressed as:
151-
152-
f(x, y) ≈ w_bl * f(ix_down, iy_down) +
153-
w_br * f(ix_up, iy_down) +
154-
w_tl * f(ix_down, iy_up) +
155-
w_tr * f(ix_up, iy_up)
156-
157-
where `(ix_down, ix_up, iy_down, iy_up)` are the integer grid coordinates
158-
surrounding the continuous position `(x, y)`.
159-
160-
Steps performed:
161-
1. Normalize the source-plane grid by subtracting its mean and dividing by
162-
the minimum axis standard deviation (to balance scaling).
163-
2. Construct forward/inverse transforms which map the grid into the unit square [0,1]^2.
164-
3. Transform the oversampled source-plane grid into [0,1]^2, then scale it
165-
to index space `[0, source_grid_size)`.
166-
4. Compute floor/ceil along x and y axes to find the enclosing rectangular cell.
167-
5. Build the four corner indices: bottom-left (bl), bottom-right (br),
168-
top-left (tl), and top-right (tr).
169-
6. Flatten the 2D indices into 1D indices suitable for scatter operations,
170-
with a flipped row-major convention: row = source_grid_size - i, col = j.
171-
7. Compute bilinear interpolation weights (`w_bl, w_br, w_tl, w_tr`).
172-
8. Return arrays of flattened indices and weights of shape `(N, 4)`, where
173-
`N` is the number of oversampled coordinates.
174-
175-
Parameters
176-
----------
177-
source_grid_size : int
178-
The number of pixels along one dimension of the rectangular pixelization.
179-
The grid is square: (source_grid_size x source_grid_size).
180-
source_plane_data_grid : (M, 2) ndarray
181-
The base source-plane coordinates, used to define normalization and transforms.
182-
source_plane_data_grid_over_sampled : (N, 2) ndarray
183-
Oversampled source-plane coordinates to be interpolated onto the rectangular grid.
184-
mesh_weight_map
185-
The weight map used to weight the creation of the rectangular mesh grid, which is used for the
186-
`RectangularBrightness` mesh which adapts the size of its pixels to where the source is reconstructed.
187-
188-
Returns
189-
-------
190-
flat_indices : (N, 4) int ndarray
191-
The flattened indices of the four neighboring pixel corners for each oversampled point.
192-
Order: [bl, br, tl, tr].
193-
weights : (N, 4) float ndarray
194-
The bilinear interpolation weights for each of the four neighboring pixels.
195-
Order: [w_bl, w_br, w_tl, w_tr].
196-
"""
197-
198-
# --- Step 1. Normalize grid ---
199-
mu = source_plane_data_grid.mean(axis=0)
200-
scale = source_plane_data_grid.std(axis=0).min()
201-
source_grid_scaled = (source_plane_data_grid - mu) / scale
202-
203-
# --- Step 2. Build transforms ---
204-
transform, inv_transform = create_transforms(
205-
source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp
206-
)
207-
208-
# --- Step 3. Transform oversampled grid into index space ---
209-
grid_over_sampled_scaled = (source_plane_data_grid_over_sampled - mu) / scale
210-
grid_over_sampled_transformed = transform(grid_over_sampled_scaled)
211-
grid_over_index = (source_grid_size - 3) * grid_over_sampled_transformed + 1
212-
213-
# --- Step 4. Floor/ceil indices ---
214-
ix_down = xp.floor(grid_over_index[:, 0])
215-
ix_up = xp.ceil(grid_over_index[:, 0])
216-
iy_down = xp.floor(grid_over_index[:, 1])
217-
iy_up = xp.ceil(grid_over_index[:, 1])
218-
219-
# --- Step 5. Four corners ---
220-
idx_tl = xp.stack([ix_up, iy_down], axis=1)
221-
idx_tr = xp.stack([ix_up, iy_up], axis=1)
222-
idx_br = xp.stack([ix_down, iy_up], axis=1)
223-
idx_bl = xp.stack([ix_down, iy_down], axis=1)
224-
225-
# --- Step 6. Flatten indices ---
226-
def flatten(idx, n):
227-
row = n - idx[:, 0]
228-
col = idx[:, 1]
229-
return row * n + col
230-
231-
flat_tl = flatten(idx_tl, source_grid_size)
232-
flat_tr = flatten(idx_tr, source_grid_size)
233-
flat_bl = flatten(idx_bl, source_grid_size)
234-
flat_br = flatten(idx_br, source_grid_size)
235-
236-
flat_indices = xp.stack([flat_tl, flat_tr, flat_bl, flat_br], axis=1).astype(
237-
"int64"
238-
)
239-
240-
# --- Step 7. Bilinear interpolation weights ---
241-
t_row = (grid_over_index[:, 0] - ix_down) / (ix_up - ix_down + 1e-12)
242-
t_col = (grid_over_index[:, 1] - iy_down) / (iy_up - iy_down + 1e-12)
243-
244-
# Weights
245-
w_tl = (1 - t_row) * (1 - t_col)
246-
w_tr = (1 - t_row) * t_col
247-
w_bl = t_row * (1 - t_col)
248-
w_br = t_row * t_col
249-
weights = xp.stack([w_tl, w_tr, w_bl, w_br], axis=1)
250-
251-
return flat_indices, weights
252-
253-
254-
def rectangular_mappings_weights_via_interpolation_from(
255-
shape_native: Tuple[int, int],
256-
source_plane_data_grid: np.ndarray,
257-
source_plane_mesh_grid: np.ndarray,
258-
xp=np,
259-
):
260-
"""
261-
Compute bilinear interpolation weights and corresponding rectangular mesh indices for an irregular grid.
262-
263-
Given a flattened regular rectangular mesh grid and an irregular grid of data points, this function
264-
determines for each irregular point:
265-
- the indices of the 4 nearest rectangular mesh pixels (top-left, top-right, bottom-left, bottom-right), and
266-
- the bilinear interpolation weights with respect to those pixels.
267-
268-
The function supports JAX and is compatible with JIT compilation.
269-
270-
Parameters
271-
----------
272-
shape_native
273-
The shape (Ny, Nx) of the original rectangular mesh grid before flattening.
274-
source_plane_data_grid
275-
The irregular grid of (y, x) points to interpolate.
276-
source_plane_mesh_grid
277-
The flattened regular rectangular mesh grid of (y, x) coordinates.
278-
279-
Returns
280-
-------
281-
mappings : np.ndarray of shape (N, 4)
282-
Indices of the four nearest rectangular mesh pixels in the flattened mesh grid.
283-
Order is: top-left, top-right, bottom-left, bottom-right.
284-
weights : np.ndarray of shape (N, 4)
285-
Bilinear interpolation weights corresponding to the four nearest mesh pixels.
286-
287-
Notes
288-
-----
289-
- Assumes the mesh grid is uniformly spaced.
290-
- The weights sum to 1 for each irregular point.
291-
- Uses bilinear interpolation in the (y, x) coordinate system.
292-
"""
293-
source_plane_mesh_grid = source_plane_mesh_grid.reshape(*shape_native, 2)
294-
295-
# Assume mesh is shaped (Ny, Nx, 2)
296-
Ny, Nx = source_plane_mesh_grid.shape[:2]
297-
298-
# Get mesh spacings and lower corner
299-
y_coords = source_plane_mesh_grid[:, 0, 0] # shape (Ny,)
300-
x_coords = source_plane_mesh_grid[0, :, 1] # shape (Nx,)
301-
302-
dy = y_coords[1] - y_coords[0]
303-
dx = x_coords[1] - x_coords[0]
304-
305-
y_min = y_coords[0]
306-
x_min = x_coords[0]
307-
308-
# shape (N_irregular, 2)
309-
irregular = source_plane_data_grid
310-
311-
# Compute normalized mesh coordinates (floating indices)
312-
fy = (irregular[:, 0] - y_min) / dy
313-
fx = (irregular[:, 1] - x_min) / dx
314-
315-
# Integer indices of top-left corners
316-
ix = xp.floor(fx).astype(xp.int32)
317-
iy = xp.floor(fy).astype(xp.int32)
318-
319-
# Clip to stay within bounds
320-
ix = xp.clip(ix, 0, Nx - 2)
321-
iy = xp.clip(iy, 0, Ny - 2)
322-
323-
# Local coordinates inside the cell (0 <= tx, ty <= 1)
324-
tx = fx - ix
325-
ty = fy - iy
326-
327-
# Bilinear weights
328-
w00 = (1 - tx) * (1 - ty)
329-
w10 = tx * (1 - ty)
330-
w01 = (1 - tx) * ty
331-
w11 = tx * ty
332-
333-
weights = xp.stack([w00, w10, w01, w11], axis=1) # shape (N_irregular, 4)
334-
335-
# Compute indices of 4 surrounding pixels in the flattened mesh
336-
i00 = iy * Nx + ix
337-
i10 = iy * Nx + (ix + 1)
338-
i01 = (iy + 1) * Nx + ix
339-
i11 = (iy + 1) * Nx + (ix + 1)
340-
341-
mappings = xp.stack([i00, i10, i01, i11], axis=1) # shape (N_irregular, 4)
342-
343-
return mappings, weights
344-
345-
346-
def nearest_pixelization_index_for_slim_index_from_kdtree(grid, mesh_grid):
347-
from scipy.spatial import cKDTree
348-
349-
kdtree = cKDTree(mesh_grid)
350-
351-
sparse_index_for_slim_index = []
352-
353-
for i in range(grid.shape[0]):
354-
input_point = [grid[i, [0]], grid[i, 1]]
355-
index = kdtree.query(input_point)[1]
356-
sparse_index_for_slim_index.append(index)
357-
358-
return sparse_index_for_slim_index
35910

36011

36112
def adaptive_pixel_signals_from(

0 commit comments

Comments
 (0)