|
1 | 1 | from __future__ import annotations |
2 | | -from collections import defaultdict |
3 | 2 | import numpy as np |
4 | 3 | from typing import TYPE_CHECKING, Union |
5 | 4 | from typing import List, Tuple |
|
11 | 10 |
|
12 | 11 | from autoarray.mask.mask_2d import Mask2D |
13 | 12 |
|
14 | | -from autoarray import numba_util |
15 | 13 |
|
16 | 14 | from autoarray import type as ty |
17 | 15 |
|
@@ -168,7 +166,6 @@ def sub_size_radial_bins_from( |
168 | 166 |
|
169 | 167 | return sub_size_list[bin_indices] |
170 | 168 |
|
171 | | - |
172 | 169 | def grid_2d_slim_over_sampled_via_mask_from( |
173 | 170 | mask_2d: np.ndarray, |
174 | 171 | pixel_scales: ty.PixelScales, |
@@ -218,54 +215,51 @@ def grid_2d_slim_over_sampled_via_mask_from( |
218 | 215 | sy, sx = pixel_scales |
219 | 216 | oy, ox = origin |
220 | 217 |
|
221 | | - # 1) Find unmasked pixels in row-major order |
| 218 | + # 1) Find unmasked pixel indices in row-major order |
222 | 219 | rows, cols = np.nonzero(~mask_2d) |
223 | 220 | Npix = rows.size |
224 | 221 |
|
225 | | - # 2) Normalize sub_size input |
| 222 | + # 2) Broadcast or validate sub_size array |
226 | 223 | sub_arr = np.asarray(sub_size) |
227 | | - sub_arr = np.full(Npix, sub_arr, dtype=int) if sub_arr.size == 1 else sub_arr |
228 | | - |
229 | | - # 3) Pixel centers in physical coords, y↑up |
| 224 | + if sub_arr.ndim == 0: |
| 225 | + sub_arr = np.full(Npix, int(sub_arr), int) |
| 226 | + elif sub_arr.ndim == 1 and sub_arr.size == Npix: |
| 227 | + sub_arr = sub_arr.astype(int) |
| 228 | + else: |
| 229 | + raise ValueError(f"sub_size must be scalar or length-{Npix} array, got shape {sub_arr.shape}") |
| 230 | + |
| 231 | + # 3) Compute pixel centers (y ↑ up, x → right) |
230 | 232 | cy = (H - 1) / 2.0 |
231 | 233 | cx = (W - 1) / 2.0 |
232 | 234 | y_pix = (cy - rows) * sy + oy |
233 | 235 | x_pix = (cols - cx) * sx + ox |
234 | 236 |
|
235 | | - # Pre‐group pixel indices by sub_size |
236 | | - groups = defaultdict(list) |
237 | | - for i, s in enumerate(sub_arr): |
238 | | - groups[s].append(i) |
239 | | - |
240 | | - # Prepare output |
241 | | - total = np.sum(sub_arr * sub_arr) |
242 | | - coords = np.empty((total, 2), float) |
243 | | - idx = 0 |
244 | | - |
245 | | - for s, pix_indices in groups.items(): |
246 | | - # Compute offsets once for this sub_size |
247 | | - dy, dx = sy / s, sx / s |
248 | | - y_off = np.linspace(+sy / 2 - dy / 2, -sy / 2 + dy / 2, s) |
249 | | - x_off = np.linspace(-sx / 2 + dx / 2, +sx / 2 - dx / 2, s) |
| 237 | + # 4) For each pixel, generate its sub-pixel coords and collect |
| 238 | + coords_list = [] |
| 239 | + for i in range(Npix): |
| 240 | + s = sub_arr[i] |
| 241 | + dy = sy / s |
| 242 | + dx = sx / s |
| 243 | + |
| 244 | + # y offsets: from top (+sy/2 - dy/2) down to bottom (-sy/2 + dy/2) |
| 245 | + y_off = np.linspace(+sy/2 - dy/2, -sy/2 + dy/2, s) |
| 246 | + # x offsets: left to right |
| 247 | + x_off = np.linspace(-sx/2 + dx/2, +sx/2 - dx/2, s) |
| 248 | + |
| 249 | + # build subgrid |
250 | 250 | y_sub, x_sub = np.meshgrid(y_off, x_off, indexing="ij") |
251 | 251 | y_sub = y_sub.ravel() |
252 | 252 | x_sub = x_sub.ravel() |
253 | | - n_sub = s * s |
254 | | - |
255 | | - # Now vectorize over all pixels in this group |
256 | | - pix_idx = np.array(pix_indices) |
257 | | - y_centers = y_pix[pix_idx] |
258 | | - x_centers = x_pix[pix_idx] |
259 | 253 |
|
260 | | - # Repeat‐tile to shape (len(pix_idx)*n_sub,) |
261 | | - all_y = np.repeat(y_centers, n_sub) + np.tile(y_sub, len(pix_idx)) |
262 | | - all_x = np.repeat(x_centers, n_sub) + np.tile(x_sub, len(pix_idx)) |
| 254 | + # center + offsets |
| 255 | + y_center = y_pix[i] |
| 256 | + x_center = x_pix[i] |
| 257 | + coords = np.stack([y_center + y_sub, x_center + x_sub], axis=1) |
263 | 258 |
|
264 | | - coords[idx : idx + all_y.size, 0] = all_y |
265 | | - coords[idx : idx + all_x.size, 1] = all_x |
266 | | - idx += all_y.size |
| 259 | + coords_list.append(coords) |
267 | 260 |
|
268 | | - return coords |
| 261 | + # 5) Concatenate all sub-pixel blocks in row-major pixel order |
| 262 | + return np.vstack(coords_list) |
269 | 263 |
|
270 | 264 |
|
271 | 265 | def over_sample_size_via_radial_bins_from( |
|
0 commit comments