Skip to content

Commit fdac408

Browse files
Jammy2211Jammy2211
authored andcommitted
use simpler grid_2d_slim_over_sampled_via_mask_from which works
1 parent 650c8a0 commit fdac408

File tree

1 file changed

+30
-36
lines changed

1 file changed

+30
-36
lines changed

autoarray/operators/over_sampling/over_sample_util.py

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from __future__ import annotations
2-
from collections import defaultdict
32
import numpy as np
43
from typing import TYPE_CHECKING, Union
54
from typing import List, Tuple
@@ -11,7 +10,6 @@
1110

1211
from autoarray.mask.mask_2d import Mask2D
1312

14-
from autoarray import numba_util
1513

1614
from autoarray import type as ty
1715

@@ -168,7 +166,6 @@ def sub_size_radial_bins_from(
168166

169167
return sub_size_list[bin_indices]
170168

171-
172169
def grid_2d_slim_over_sampled_via_mask_from(
173170
mask_2d: np.ndarray,
174171
pixel_scales: ty.PixelScales,
@@ -218,54 +215,51 @@ def grid_2d_slim_over_sampled_via_mask_from(
218215
sy, sx = pixel_scales
219216
oy, ox = origin
220217

221-
# 1) Find unmasked pixels in row-major order
218+
# 1) Find unmasked pixel indices in row-major order
222219
rows, cols = np.nonzero(~mask_2d)
223220
Npix = rows.size
224221

225-
# 2) Normalize sub_size input
222+
# 2) Broadcast or validate sub_size array
226223
sub_arr = np.asarray(sub_size)
227-
sub_arr = np.full(Npix, sub_arr, dtype=int) if sub_arr.size == 1 else sub_arr
228-
229-
# 3) Pixel centers in physical coords, y↑up
224+
if sub_arr.ndim == 0:
225+
sub_arr = np.full(Npix, int(sub_arr), int)
226+
elif sub_arr.ndim == 1 and sub_arr.size == Npix:
227+
sub_arr = sub_arr.astype(int)
228+
else:
229+
raise ValueError(f"sub_size must be scalar or length-{Npix} array, got shape {sub_arr.shape}")
230+
231+
# 3) Compute pixel centers (y ↑ up, x → right)
230232
cy = (H - 1) / 2.0
231233
cx = (W - 1) / 2.0
232234
y_pix = (cy - rows) * sy + oy
233235
x_pix = (cols - cx) * sx + ox
234236

235-
# Pre‐group pixel indices by sub_size
236-
groups = defaultdict(list)
237-
for i, s in enumerate(sub_arr):
238-
groups[s].append(i)
239-
240-
# Prepare output
241-
total = np.sum(sub_arr * sub_arr)
242-
coords = np.empty((total, 2), float)
243-
idx = 0
244-
245-
for s, pix_indices in groups.items():
246-
# Compute offsets once for this sub_size
247-
dy, dx = sy / s, sx / s
248-
y_off = np.linspace(+sy / 2 - dy / 2, -sy / 2 + dy / 2, s)
249-
x_off = np.linspace(-sx / 2 + dx / 2, +sx / 2 - dx / 2, s)
237+
# 4) For each pixel, generate its sub-pixel coords and collect
238+
coords_list = []
239+
for i in range(Npix):
240+
s = sub_arr[i]
241+
dy = sy / s
242+
dx = sx / s
243+
244+
# y offsets: from top (+sy/2 - dy/2) down to bottom (-sy/2 + dy/2)
245+
y_off = np.linspace(+sy/2 - dy/2, -sy/2 + dy/2, s)
246+
# x offsets: left to right
247+
x_off = np.linspace(-sx/2 + dx/2, +sx/2 - dx/2, s)
248+
249+
# build subgrid
250250
y_sub, x_sub = np.meshgrid(y_off, x_off, indexing="ij")
251251
y_sub = y_sub.ravel()
252252
x_sub = x_sub.ravel()
253-
n_sub = s * s
254-
255-
# Now vectorize over all pixels in this group
256-
pix_idx = np.array(pix_indices)
257-
y_centers = y_pix[pix_idx]
258-
x_centers = x_pix[pix_idx]
259253

260-
# Repeat‐tile to shape (len(pix_idx)*n_sub,)
261-
all_y = np.repeat(y_centers, n_sub) + np.tile(y_sub, len(pix_idx))
262-
all_x = np.repeat(x_centers, n_sub) + np.tile(x_sub, len(pix_idx))
254+
# center + offsets
255+
y_center = y_pix[i]
256+
x_center = x_pix[i]
257+
coords = np.stack([y_center + y_sub, x_center + x_sub], axis=1)
263258

264-
coords[idx : idx + all_y.size, 0] = all_y
265-
coords[idx : idx + all_x.size, 1] = all_x
266-
idx += all_y.size
259+
coords_list.append(coords)
267260

268-
return coords
261+
# 5) Concatenate all sub-pixel blocks in row-major pixel order
262+
return np.vstack(coords_list)
269263

270264

271265
def over_sample_size_via_radial_bins_from(

0 commit comments

Comments
 (0)