Skip to content

Commit 61bbee2

Browse files
committed
readd numba utils
1 parent 4e75879 commit 61bbee2

File tree

1 file changed

+53
-6
lines changed

1 file changed

+53
-6
lines changed

autoarray/mask/mask_1d_util.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,43 @@
11
import numpy as np
22

3+
from autoarray import numba_util
34

5+
6+
@numba_util.jit()
7+
def total_pixels_1d_from(mask_1d: np.ndarray) -> int:
8+
"""
9+
Returns the total number of unmasked pixels in a mask.
10+
11+
Parameters
12+
----------
13+
mask_1d
14+
A 2D array of bools, where `False` values are unmasked and included when counting pixels.
15+
16+
Returns
17+
-------
18+
int
19+
The total number of pixels that are unmasked.
20+
21+
Examples
22+
--------
23+
24+
mask = np.array([[True, False, True],
25+
[False, False, False]
26+
[True, False, True]])
27+
28+
total_regular_pixels = total_regular_pixels_from(mask=mask)
29+
"""
30+
31+
total_regular_pixels = 0
32+
33+
for x in range(mask_1d.shape[0]):
34+
if not mask_1d[x]:
35+
total_regular_pixels += 1
36+
37+
return total_regular_pixels
38+
39+
40+
@numba_util.jit()
441
def native_index_for_slim_index_1d_from(
542
mask_1d: np.ndarray,
643
) -> np.ndarray:
@@ -25,12 +62,22 @@ def native_index_for_slim_index_1d_from(
2562
2663
Examples
2764
--------
28-
>>> mask_1d = np.array([True, False, True, False, False, True])
29-
>>> native_index_for_slim_index_1d_from(mask_1d)
30-
array([1, 3, 4])
65+
mask_2d = np.array([[True, True, True],
66+
[True, False, True]
67+
[True, True, True]])
68+
69+
native_index_for_slim_index_1d = native_index_for_slim_index_1d_from(mask_2d=mask_2d)
3170
3271
"""
33-
# Create an array of native indexes corresponding to unmasked pixels
34-
native_index_for_slim_index_1d = np.flatnonzero(~mask_1d)
3572

36-
return native_index_for_slim_index_1d
73+
total_pixels = total_pixels_1d_from(mask_1d=mask_1d)
74+
native_index_for_slim_index_1d = np.zeros(shape=total_pixels)
75+
76+
slim_index = 0
77+
78+
for x in range(mask_1d.shape[0]):
79+
if not mask_1d[x]:
80+
native_index_for_slim_index_1d[slim_index] = x
81+
slim_index += 1
82+
83+
return native_index_for_slim_index_1d

0 commit comments

Comments
 (0)