|
4 | 4 | import shutil |
5 | 5 |
|
6 | 6 | import autoarray as aa |
| 7 | +import pytest |
7 | 8 |
|
8 | 9 | from autoarray.operators import transformer |
9 | 10 |
|
@@ -148,3 +149,54 @@ def test__different_interferometer_without_mock_objects__customize_constructor_i |
148 | 149 | assert (dataset.data == 1.0 + 1.0j * np.ones((19,))).all() |
149 | 150 | assert (dataset.noise_map == 2.0 + 2.0j * np.ones((19,))).all() |
150 | 151 | assert (dataset.uv_wavelengths == 3.0 * np.ones((19, 2))).all() |
| 152 | + |
| 153 | + |
| 154 | +def test__curvature_preload_metadata_from( |
| 155 | + visibilities_7, |
| 156 | + visibilities_noise_map_7, |
| 157 | + uv_wavelengths_7x2, |
| 158 | + mask_2d_7x7, |
| 159 | +): |
| 160 | + |
| 161 | + dataset = aa.Interferometer( |
| 162 | + data=visibilities_7, |
| 163 | + noise_map=visibilities_noise_map_7, |
| 164 | + uv_wavelengths=uv_wavelengths_7x2, |
| 165 | + real_space_mask=mask_2d_7x7, |
| 166 | + ) |
| 167 | + |
| 168 | + dataset = dataset.apply_w_tilde(use_jax=False) |
| 169 | + |
| 170 | + file = f"{test_data_path}/curvature_preload_metadata" |
| 171 | + |
| 172 | + dataset.w_tilde.save_curvature_preload( |
| 173 | + file=file, |
| 174 | + overwrite=True, |
| 175 | + ) |
| 176 | + |
| 177 | + curvature_preload = aa.load_curvature_preload_if_compatible( |
| 178 | + file=file, |
| 179 | + real_space_mask=dataset.real_space_mask |
| 180 | + ) |
| 181 | + |
| 182 | + real_space_mask_changed = np.array( |
| 183 | + [ |
| 184 | + [True, True, True, True, True, True, True], |
| 185 | + [True, True, True, True, True, True, True], |
| 186 | + [True, True, False, False, False, True, True], |
| 187 | + [True, True, False, True, False, True, True], |
| 188 | + [True, True, False, False, False, True, True], |
| 189 | + [True, True, True, True, True, True, True], |
| 190 | + [True, True, True, True, True, True, True], |
| 191 | + ] |
| 192 | + ) |
| 193 | + |
| 194 | + real_space_mask_changed = aa.Mask2D(mask=real_space_mask_changed, pixel_scales=(1.0, 1.0)) |
| 195 | + |
| 196 | + with pytest.raises(ValueError): |
| 197 | + |
| 198 | + curvature_preload = aa.load_curvature_preload_if_compatible( |
| 199 | + file=file, |
| 200 | + real_space_mask=real_space_mask_changed |
| 201 | + ) |
| 202 | + |
0 commit comments