Skip to content

Commit 3cf259f

Browse files
Jammy2211Jammy2211
authored andcommitted
add unit test on loading curvature preload
1 parent d185f9f commit 3cf259f

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

autoarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from . import util
1010
from . import fixtures
1111
from . import mock as m
12+
from .dataset.interferometer.w_tilde import load_curvature_preload_if_compatible
1213
from .dataset import preprocess
1314
from .dataset.abstract.dataset import AbstractDataset
1415
from .dataset.abstract.w_tilde import AbstractWTilde

test_autoarray/dataset/interferometer/test_dataset.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import shutil
55

66
import autoarray as aa
7+
import pytest
78

89
from autoarray.operators import transformer
910

@@ -148,3 +149,54 @@ def test__different_interferometer_without_mock_objects__customize_constructor_i
148149
assert (dataset.data == 1.0 + 1.0j * np.ones((19,))).all()
149150
assert (dataset.noise_map == 2.0 + 2.0j * np.ones((19,))).all()
150151
assert (dataset.uv_wavelengths == 3.0 * np.ones((19, 2))).all()
152+
153+
154+
def test__curvature_preload_metadata_from(
155+
visibilities_7,
156+
visibilities_noise_map_7,
157+
uv_wavelengths_7x2,
158+
mask_2d_7x7,
159+
):
160+
161+
dataset = aa.Interferometer(
162+
data=visibilities_7,
163+
noise_map=visibilities_noise_map_7,
164+
uv_wavelengths=uv_wavelengths_7x2,
165+
real_space_mask=mask_2d_7x7,
166+
)
167+
168+
dataset = dataset.apply_w_tilde(use_jax=False)
169+
170+
file = f"{test_data_path}/curvature_preload_metadata"
171+
172+
dataset.w_tilde.save_curvature_preload(
173+
file=file,
174+
overwrite=True,
175+
)
176+
177+
curvature_preload = aa.load_curvature_preload_if_compatible(
178+
file=file,
179+
real_space_mask=dataset.real_space_mask
180+
)
181+
182+
real_space_mask_changed = np.array(
183+
[
184+
[True, True, True, True, True, True, True],
185+
[True, True, True, True, True, True, True],
186+
[True, True, False, False, False, True, True],
187+
[True, True, False, True, False, True, True],
188+
[True, True, False, False, False, True, True],
189+
[True, True, True, True, True, True, True],
190+
[True, True, True, True, True, True, True],
191+
]
192+
)
193+
194+
real_space_mask_changed = aa.Mask2D(mask=real_space_mask_changed, pixel_scales=(1.0, 1.0))
195+
196+
with pytest.raises(ValueError):
197+
198+
curvature_preload = aa.load_curvature_preload_if_compatible(
199+
file=file,
200+
real_space_mask=real_space_mask_changed
201+
)
202+

0 commit comments

Comments
 (0)