Skip to content

Commit 40be8b1

Browse files
authored
Merge pull request #158 from ickc/feature/jax_wrapper-mask_2d_circular_from
fix jit on mask_2d_circular_from
2 parents 311f7cd + ba31c1c commit 40be8b1

File tree

1 file changed

+7
-19
lines changed

1 file changed

+7
-19
lines changed

autoarray/mask/mask_2d_util.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def total_pixels_2d_from(mask_2d: np.ndarray) -> int:
8181
return total_regular_pixels
8282

8383

84-
@numba_util.jit()
84+
@numba_util.jit(static_argnums=0)
8585
def mask_2d_circular_from(
8686
shape_native: Tuple[int, int],
8787
pixel_scales: ty.PixelScales,
@@ -114,24 +114,12 @@ def mask_2d_circular_from(
114114
mask = mask_circular_from(
115115
shape=(10, 10), pixel_scales=0.1, radius=0.5, centre=(0.0, 0.0))
116116
"""
117-
118-
mask_2d = np.full(shape_native, True)
119-
120-
centres_scaled = mask_2d_centres_from(
121-
shape_native=mask_2d.shape, pixel_scales=pixel_scales, centre=centre
122-
)
123-
124-
for y in range(mask_2d.shape[0]):
125-
for x in range(mask_2d.shape[1]):
126-
y_scaled = (y - centres_scaled[0]) * pixel_scales[0]
127-
x_scaled = (x - centres_scaled[1]) * pixel_scales[1]
128-
129-
r_scaled = np.sqrt(x_scaled**2 + y_scaled**2)
130-
131-
if r_scaled <= radius:
132-
mask_2d[y, x] = False
133-
134-
return mask_2d
117+
centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre)
118+
ys, xs = np.indices(shape_native)
119+
return (radius * radius) < (
120+
np.square((ys - centres_scaled[0]) * pixel_scales[0]) +
121+
np.square((xs - centres_scaled[1]) * pixel_scales[1])
122+
)
135123

136124

137125
@numba_util.jit()

0 commit comments

Comments
 (0)