Skip to content

Commit 8ecdec0

Browse files
Jammy2211Jammy2211
authored andcommitted
more updates
1 parent fdc510f commit 8ecdec0

File tree

6 files changed

+50
-150
lines changed

6 files changed

+50
-150
lines changed

autoarray/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
from . import util
1010
from . import fixtures
1111
from . import mock as m
12-
from .inversion.inversion.interferometer.inversion_interferometer_util import (
13-
load_curvature_preload,
14-
)
12+
1513
from .dataset import preprocess
1614
from .dataset.abstract.dataset import AbstractDataset
1715
from .dataset.grids import GridsInterface

autoarray/dataset/interferometer/dataset.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,10 @@ def apply_sparse_linear_algebra(
201201
"INTERFEROMETER - Computing W-Tilde; runtime scales with visibility count and mask resolution, CPU run times may exceed hours."
202202
)
203203

204-
curvature_preload = inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
205-
noise_map_real=self.noise_map.array.real,
206-
uv_wavelengths=self.uv_wavelengths,
207-
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
208-
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
204+
curvature_preload = self.curvature_preload_from(
209205
chunk_k=chunk_k,
210-
show_memory=show_memory,
211206
show_progress=show_progress,
207+
show_memory=show_memory,
212208
use_jax=use_jax,
213209
)
214210

@@ -233,6 +229,24 @@ def apply_sparse_linear_algebra(
233229
sparse_linalg=sparse_linalg,
234230
)
235231

232+
def curvature_preload_from(
233+
self,
234+
chunk_k: int = 2048,
235+
show_progress: bool = False,
236+
show_memory: bool = False,
237+
use_jax: bool = False,
238+
):
239+
return inversion_interferometer_util.w_tilde_curvature_preload_interferometer_from(
240+
noise_map_real=self.noise_map.array.real,
241+
uv_wavelengths=self.uv_wavelengths,
242+
shape_masked_pixels_2d=self.transformer.grid.mask.shape_native_masked_pixels,
243+
grid_radians_2d=self.transformer.grid.mask.derive_grid.all_false.in_radians.native.array,
244+
chunk_k=chunk_k,
245+
show_memory=show_memory,
246+
show_progress=show_progress,
247+
use_jax=use_jax,
248+
)
249+
236250
@property
237251
def mask(self):
238252
return self.real_space_mask

autoarray/inversion/inversion/imaging/inversion_imaging_util.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ class ImagingSparseLinAlg:
294294

295295
data_native: np.ndarray
296296
noise_map_native: np.ndarray
297-
weight_map : np.ndarray
297+
weight_map: np.ndarray
298298
inverse_variances_native: "jax.Array" # (y, x) float64
299299
y_shape: int
300300
x_shape: int
@@ -341,10 +341,8 @@ def from_noise_map_and_psf(
341341
)
342342
inverse_variances_native = inverse_variances_native.native
343343

344-
weight_map = data.array / (noise_map.array ** 2)
345-
weight_map = Array2D(
346-
values=weight_map, mask=noise_map.mask
347-
)
344+
weight_map = data.array / (noise_map.array**2)
345+
weight_map = Array2D(values=weight_map, mask=noise_map.mask)
348346

349347
# If you *also* want to zero masked pixels explicitly:
350348
# mask_native = noise_map.mask (depends on your API; might be bool native)

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -528,36 +528,6 @@ def w_tilde_via_preload_from(curvature_preload, native_index_for_slim_index):
528528
return w_tilde_via_preload
529529

530530

531-
def load_curvature_preload(
532-
file: Union[str, Path],
533-
) -> Optional[np.ndarray]:
534-
"""
535-
Load a saved curvature_preload if (and only if) it is compatible with the current mask geometry.
536-
537-
Parameters
538-
----------
539-
file
540-
Path to a previously saved NPZ.
541-
require_mask_hash
542-
If True, require the full mask content hash to match (safest).
543-
If False, only bbox + shape + pixel scales are checked.
544-
545-
Returns
546-
-------
547-
np.ndarray
548-
The loaded curvature_preload if compatible, otherwise raises ValueError.
549-
"""
550-
file = Path(file)
551-
if file.suffix.lower() != ".npz":
552-
file = file.with_suffix(".npz")
553-
554-
if not file.exists():
555-
raise FileNotFoundError(str(file))
556-
557-
with np.load(file, allow_pickle=False) as npz:
558-
return np.asarray(npz["curvature_preload"])
559-
560-
561531
@dataclass(frozen=True)
562532
class InterferometerSparseLinAlg:
563533
"""
@@ -718,44 +688,3 @@ def compute_block(start_col: int) -> jnp.ndarray:
718688
pix_weights_for_sub_slim_index,
719689
fft_index_for_masked_pixel,
720690
)
721-
722-
def save_curvature_preload(
723-
self,
724-
file: Union[str, Path],
725-
*,
726-
overwrite: bool = False,
727-
) -> Path:
728-
"""
729-
Save curvature_preload plus enough metadata to ensure it is only reused when safe.
730-
731-
Uses NPZ so we can store:
732-
- curvature_preload (array)
733-
- meta_json (string)
734-
735-
Parameters
736-
----------
737-
file
738-
Path to save to. Recommended suffix: ".npz".
739-
If you pass ".npy", we will still save an ".npz" next to it.
740-
overwrite
741-
If False and the file exists, raise FileExistsError.
742-
743-
Returns
744-
-------
745-
Path
746-
The path actually written (will end with ".npz").
747-
"""
748-
file = Path(file)
749-
750-
# Force .npz (storing metadata safely)
751-
if file.suffix.lower() != ".npz":
752-
file = file.with_suffix(".npz")
753-
754-
if file.exists() and not overwrite:
755-
raise FileExistsError(f"File already exists: {file}")
756-
757-
np.savez_compressed(
758-
file,
759-
curvature_preload=np.asarray(self.curvature_preload),
760-
)
761-
return file

test_autoarray/dataset/interferometer/test_dataset.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -149,55 +149,3 @@ def test__different_interferometer_without_mock_objects__customize_constructor_i
149149
assert (dataset.data == 1.0 + 1.0j * np.ones((19,))).all()
150150
assert (dataset.noise_map == 2.0 + 2.0j * np.ones((19,))).all()
151151
assert (dataset.uv_wavelengths == 3.0 * np.ones((19, 2))).all()
152-
153-
154-
def test__curvature_preload_metadata_from(
155-
visibilities_7,
156-
visibilities_noise_map_7,
157-
uv_wavelengths_7x2,
158-
mask_2d_7x7,
159-
):
160-
161-
dataset = aa.Interferometer(
162-
data=visibilities_7,
163-
noise_map=visibilities_noise_map_7,
164-
uv_wavelengths=uv_wavelengths_7x2,
165-
real_space_mask=mask_2d_7x7,
166-
)
167-
168-
dataset = dataset.apply_sparse_linear_algebra(use_jax=False)
169-
170-
file = f"{test_data_path}/curvature_preload_metadata"
171-
172-
dataset.sparse_linalg.save_curvature_preload(
173-
file=file,
174-
overwrite=True,
175-
)
176-
177-
curvature_preload = aa.load_curvature_preload(
178-
file=file, real_space_mask=dataset.real_space_mask
179-
)
180-
181-
assert curvature_preload[0, 0] == pytest.approx(1.75, 1.0e-4)
182-
183-
real_space_mask_changed = np.array(
184-
[
185-
[True, True, True, True, True, True, True],
186-
[True, True, True, True, True, True, True],
187-
[True, True, False, False, False, True, True],
188-
[True, True, False, True, False, True, True],
189-
[True, True, False, False, False, True, True],
190-
[True, True, True, True, True, True, True],
191-
[True, True, True, True, True, True, True],
192-
]
193-
)
194-
195-
real_space_mask_changed = aa.Mask2D(
196-
mask=real_space_mask_changed, pixel_scales=(1.0, 1.0)
197-
)
198-
199-
with pytest.raises(ValueError):
200-
201-
curvature_preload = aa.load_curvature_preload(
202-
file=file, real_space_mask=real_space_mask_changed
203-
)

test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,35 +37,46 @@ def test__w_tilde_imaging_from():
3737

3838

3939
def test__weighted_data_imaging_from():
40-
data = np.array(
41-
[
40+
41+
mask = aa.Mask2D(
42+
mask=[
43+
[True, True, True, True],
44+
[True, False, False, True],
45+
[True, False, False, True],
46+
[True, True, True, True],
47+
],
48+
pixel_scales=(1.0, 1.0),
49+
)
50+
51+
data = aa.Array2D(
52+
values=[
4253
[0.0, 0.0, 0.0, 0.0],
4354
[0.0, 2.0, 1.0, 0.0],
4455
[0.0, 1.0, 2.0, 0.0],
4556
[0.0, 0.0, 0.0, 0.0],
46-
]
57+
],
58+
mask=mask,
4759
)
4860

49-
noise_map = np.array(
50-
[
61+
noise_map = aa.Array2D(
62+
values=[
5163
[0.0, 0.0, 0.0, 0.0],
5264
[0.0, 1.0, 1.0, 0.0],
5365
[0.0, 1.0, 2.0, 0.0],
5466
[0.0, 0.0, 0.0, 0.0],
55-
]
67+
],
68+
mask=mask,
5669
)
5770

5871
kernel = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 2.0, 0.0]])
5972

6073
native_index_for_slim_index = np.array([[1, 1], [1, 2], [2, 1], [2, 2]])
6174

62-
weight_map = data.array / (noise_map.array ** 2)
63-
weight_map = aa.Array2D(
64-
values=weight_map, mask=noise_map.mask
65-
)
75+
weight_map = data / (noise_map**2)
76+
weight_map = aa.Array2D(values=weight_map, mask=mask)
6677

6778
weighted_data = aa.util.inversion_imaging.weighted_data_imaging_from(
68-
weight_map_native=weight_map.native,
79+
weight_map_native=weight_map.native.array,
6980
kernel_native=kernel,
7081
native_index_for_slim_index=native_index_for_slim_index,
7182
)
@@ -230,9 +241,11 @@ def test__data_vector_via_weighted_data_two_methods_agree():
230241
sub_fraction_slim=mapper.over_sampler.sub_fraction.array,
231242
)
232243

244+
weight_map = image.array / (noise_map.array**2)
245+
weight_map = aa.Array2D(values=weight_map, mask=noise_map.mask)
246+
233247
weighted_data = aa.util.inversion_imaging.weighted_data_imaging_from(
234-
image_native=image.native.array,
235-
noise_map_native=noise_map.native.array,
248+
weight_map_native=weight_map.native.array,
236249
kernel_native=kernel.native.array,
237250
native_index_for_slim_index=mask.derive_indexes.native_for_slim.astype(
238251
"int"

0 commit comments

Comments
 (0)