|
| 1 | +from scipy.signal import convolve2d |
| 2 | +import jax.numpy as jnp |
1 | 3 | import numpy as np |
2 | 4 | from typing import Tuple |
3 | 5 |
|
4 | 6 | from autoarray import numba_util |
| 7 | +from scipy.signal import correlate2d |
| 8 | + |
| 9 | +import numpy as np |
| 10 | + |
| 11 | + |
| 12 | +def psf_operator_matrix_dense_from( |
| 13 | + kernel_native: np.ndarray, |
| 14 | + native_index_for_slim_index: np.ndarray, # shape (N_pix, 2), native (y,x) coords of masked pixels |
| 15 | + native_shape: tuple[int, int], |
| 16 | + correlate: bool = True, |
| 17 | +) -> np.ndarray: |
| 18 | + """ |
| 19 | + Construct a dense PSF operator W (N_pix x N_pix) that maps masked image pixels to masked image pixels. |
| 20 | +
|
| 21 | + Parameters |
| 22 | + ---------- |
| 23 | + kernel_native : (Ky, Kx) PSF kernel. |
| 24 | + native_index_for_slim_index : (N_pix, 2) array of int |
| 25 | + Native (y, x) coords for each masked pixel. |
| 26 | + native_shape : (Ny, Nx) |
| 27 | + Native 2D image shape. |
| 28 | + correlate : bool, default True |
| 29 | + If True, use correlation convention (no kernel flip). |
| 30 | + If False, use convolution convention (flip kernel). |
| 31 | +
|
| 32 | + Returns |
| 33 | + ------- |
| 34 | + W : ndarray, shape (N_pix, N_pix) |
| 35 | + Dense PSF operator. |
| 36 | + """ |
| 37 | + Ky, Kx = kernel_native.shape |
| 38 | + ph, pw = Ky // 2, Kx // 2 |
| 39 | + Ny, Nx = native_shape |
| 40 | + N_pix = native_index_for_slim_index.shape[0] |
| 41 | + |
| 42 | + ker = kernel_native if correlate else kernel_native[::-1, ::-1] |
| 43 | + |
| 44 | + # Padded index grid: -1 everywhere, slim index where masked |
| 45 | + index_padded = -np.ones((Ny + 2 * ph, Nx + 2 * pw), dtype=np.int64) |
| 46 | + for p, (y, x) in enumerate(native_index_for_slim_index): |
| 47 | + index_padded[y + ph, x + pw] = p |
| 48 | + |
| 49 | + # Neighborhood offsets |
| 50 | + dy = np.arange(Ky) - ph |
| 51 | + dx = np.arange(Kx) - pw |
| 52 | + |
| 53 | + W = np.zeros((N_pix, N_pix), dtype=float) |
| 54 | + |
| 55 | + for i, (y, x) in enumerate(native_index_for_slim_index): |
| 56 | + yp = y + ph |
| 57 | + xp = x + pw |
| 58 | + for j, dy_ in enumerate(dy): |
| 59 | + for k, dx_ in enumerate(dx): |
| 60 | + neigh = index_padded[yp + dy_, xp + dx_] |
| 61 | + if neigh >= 0: |
| 62 | + W[i, neigh] += ker[j, k] |
| 63 | + |
| 64 | + return W |
5 | 65 |
|
6 | 66 |
|
7 | | -@numba_util.jit() |
8 | 67 | def w_tilde_data_imaging_from( |
9 | 68 | image_native: np.ndarray, |
10 | 69 | noise_map_native: np.ndarray, |
@@ -44,32 +103,35 @@ def w_tilde_data_imaging_from( |
44 | 103 | efficient calculation of the data vector. |
45 | 104 | """ |
46 | 105 |
|
47 | | - kernel_shift_y = -(kernel_native.shape[1] // 2) |
48 | | - kernel_shift_x = -(kernel_native.shape[0] // 2) |
49 | | - |
50 | | - image_pixels = len(native_index_for_slim_index) |
51 | | - |
52 | | - w_tilde_data = np.zeros((image_pixels,)) |
| 106 | + # 1) weight map = image / noise^2 (safe where noise==0) |
| 107 | + weight_map = jnp.where( |
| 108 | + noise_map_native > 0.0, image_native / (noise_map_native**2), 0.0 |
| 109 | + ) |
53 | 110 |
|
54 | | - weight_map_native = image_native / noise_map_native**2.0 |
| 111 | + Ky, Kx = kernel_native.shape |
| 112 | + ph, pw = Ky // 2, Kx // 2 |
55 | 113 |
|
56 | | - for ip0 in range(image_pixels): |
57 | | - ip0_y, ip0_x = native_index_for_slim_index[ip0] |
58 | | - |
59 | | - value = 0.0 |
| 114 | + # 2) pad so neighbourhood gathers never go OOB |
| 115 | + padded = jnp.pad( |
| 116 | + weight_map, ((ph, ph), (pw, pw)), mode="constant", constant_values=0.0 |
| 117 | + ) |
60 | 118 |
|
61 | | - for k0_y in range(kernel_native.shape[0]): |
62 | | - for k0_x in range(kernel_native.shape[1]): |
63 | | - weight_value = weight_map_native[ |
64 | | - ip0_y + k0_y + kernel_shift_y, ip0_x + k0_x + kernel_shift_x |
65 | | - ] |
| 119 | + # 3) build broadcasted neighbourhood indices for all requested pixels |
| 120 | + # shift pixel coords into the padded frame |
| 121 | + ys = native_index_for_slim_index[:, 0] + ph # (N,) |
| 122 | + xs = native_index_for_slim_index[:, 1] + pw # (N,) |
66 | 123 |
|
67 | | - if not np.isnan(weight_value): |
68 | | - value += kernel_native[k0_y, k0_x] * weight_value |
| 124 | + # kernel-relative offsets |
| 125 | + dy = jnp.arange(Ky) - ph # (Ky,) |
| 126 | + dx = jnp.arange(Kx) - pw # (Kx,) |
69 | 127 |
|
70 | | - w_tilde_data[ip0] = value |
| 128 | + # broadcast to (N, Ky, Kx) |
| 129 | + Y = ys[:, None, None] + dy[None, :, None] |
| 130 | + X = xs[:, None, None] + dx[None, None, :] |
71 | 131 |
|
72 | | - return w_tilde_data |
| 132 | + # 4) gather patches and correlate (no kernel flip) |
| 133 | + patches = padded[Y, X] # (N, Ky, Kx) |
| 134 | + return jnp.sum(patches * kernel_native[None, :, :], axis=(1, 2)) # (N,) |
73 | 135 |
|
74 | 136 |
|
75 | 137 | @numba_util.jit() |
|
0 commit comments