Skip to content

Commit 28adc24

Browse files
authored
Merge pull request #191 from Jammy2211/feature/rectangular_weight
Feature/rectangular weight
2 parents 72c4d1d + 5347dd5 commit 28adc24

File tree

122 files changed

+1805
-1402
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+1805
-1402
lines changed

autoarray/abstract_ndarray.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from abc import abstractmethod
77
import jax.numpy as jnp
88
from jax._src.tree_util import register_pytree_node
9-
from jax import Array
9+
10+
import numpy as np
1011

1112
from autoconf.fitsable import output_to_fits
1213

@@ -64,7 +65,11 @@ def wrapper(self, other):
6465

6566

6667
class AbstractNDArray(ABC):
67-
def __init__(self, array):
68+
69+
__no_flatten__ = ()
70+
71+
def __init__(self, array, xp=np):
72+
6873
self._is_transformed = False
6974

7075
while isinstance(array, AbstractNDArray):
@@ -79,7 +84,7 @@ def __init__(self, array):
7984
except ValueError:
8085
pass
8186

82-
__no_flatten__ = ()
87+
self._xp = xp
8388

8489
def invert(self):
8590
new = self.copy()
@@ -102,12 +107,6 @@ def instance_flatten(cls, instance):
102107
)
103108
return values, keys
104109

105-
@staticmethod
106-
def flip_hdu_for_ds9(values):
107-
if conf.instance["general"]["fits"]["flip_for_ds9"]:
108-
return jnp.flipud(values)
109-
return values
110-
111110
@classmethod
112111
def instance_unflatten(cls, aux_data, children):
113112
"""
@@ -138,6 +137,12 @@ def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
138137
new_array._array = array
139138
return new_array
140139

140+
@staticmethod
141+
def flip_hdu_for_ds9(values):
142+
if conf.instance["general"]["fits"]["flip_for_ds9"]:
143+
return jnp.flipud(values)
144+
return values
145+
141146
def copy(self):
142147
new = copy(self)
143148
return new
@@ -336,6 +341,7 @@ def __getitem__(self, item):
336341
return result
337342

338343
def __setitem__(self, key, value):
344+
from jax import Array
339345
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
340346
self._array = jnp.where(key, value, self._array)
341347
else:

autoarray/config/general.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ inversion:
1313
reconstruction_vmax_factor: 0.5 # Plots of an Inversion's reconstruction use the reconstructed data's bright value multiplied by this factor.
1414
numba:
1515
use_numba: true
16-
cache: false
16+
cache: true
1717
nopython: true
1818
parallel: false
1919
pixelization:

autoarray/dataset/abstract/dataset.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
import warnings
55
from typing import Optional, Union
66

7-
from autoconf import cached_property
8-
9-
from autoarray.dataset.grids import GridsDataset
10-
117
from autoarray import exc
128
from autoarray.mask.mask_1d import Mask1D
139
from autoarray.mask.mask_2d import Mask2D
@@ -140,14 +136,6 @@ def __init__(
140136
def grid(self):
141137
return self.grids.lp
142138

143-
@cached_property
144-
def grids(self):
145-
return GridsDataset(
146-
mask=self.data.mask,
147-
over_sample_size_lp=self.over_sample_size_lp,
148-
over_sample_size_pixelization=self.over_sample_size_pixelization,
149-
)
150-
151139
@property
152140
def shape_native(self):
153141
return self.mask.shape_native
@@ -188,7 +176,7 @@ def signal_to_noise_max(self) -> float:
188176
"""
189177
return np.max(self.signal_to_noise_map)
190178

191-
@cached_property
179+
@property
192180
def noise_covariance_matrix_inv(self) -> np.ndarray:
193181
"""
194182
Returns the inverse of the noise covariance matrix, which is used when computing a chi-squared which accounts

autoarray/dataset/imaging/dataset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,6 @@
33
from pathlib import Path
44
from typing import Optional, Union
55

6-
from autoconf import cached_property
7-
from autoconf import instance
8-
96
from autoarray.dataset.abstract.dataset import AbstractDataset
107
from autoarray.dataset.grids import GridsDataset
118
from autoarray.dataset.imaging.w_tilde import WTildeImaging
@@ -194,7 +191,7 @@ def __init__(
194191
psf=self.psf,
195192
)
196193

197-
@cached_property
194+
@property
198195
def w_tilde(self):
199196
"""
200197
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked

autoarray/dataset/imaging/w_tilde.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import logging
22
import numpy as np
33

4-
from autoconf import cached_property
5-
64
from autoarray.dataset.abstract.w_tilde import AbstractWTilde
75

86
from autoarray.inversion.inversion.imaging import inversion_imaging_util
@@ -55,7 +53,7 @@ def __init__(
5553
self.psf = psf
5654
self.mask = mask
5755

58-
@cached_property
56+
@property
5957
def w_matrix(self):
6058
"""
6159
The matrix `w_tilde_curvature` is a matrix of dimensions [image_pixels, image_pixels] that encodes the PSF
@@ -93,7 +91,7 @@ def w_matrix(self):
9391
).astype("int"),
9492
)
9593

96-
@cached_property
94+
@property
9795
def psf_operator_matrix_dense(self):
9896

9997
return inversion_imaging_util.psf_operator_matrix_dense_from(

autoarray/dataset/interferometer/dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
from pathlib import Path
44

5-
from autoconf import cached_property
65
from autoconf.fitsable import ndarray_via_fits_from, output_to_fits
76

87
from autoarray.dataset.abstract.dataset import AbstractDataset
@@ -166,7 +165,7 @@ def w_tilde_preprocessing(self):
166165

167166
fits.writeto(filename, data=curvature_preload)
168167

169-
@cached_property
168+
@property
170169
def w_tilde(self):
171170
"""
172171
The w_tilde formalism of the linear algebra equations precomputes the Fourier Transform of all the visibilities

autoarray/exc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class MeshException(Exception):
7272
"""
7373
Raises exceptions associated with the `inversion/mesh` modules and `Mesh` classes.
7474
75-
For example if a `Rectangular` mesh has dimensions below 3x3.
75+
For example if a `RectangularMagnification` mesh has dimensions below 3x3.
7676
"""
7777

7878
pass

autoarray/fit/fit_dataset.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
import numpy as np
77

8-
from autoconf import cached_property
9-
108
from autoarray.dataset.grids import GridsInterface
119
from autoarray.dataset.dataset_model import DatasetModel
1210
from autoarray.fit import fit_util
@@ -85,7 +83,7 @@ def chi_squared(self) -> float:
8583
"""
8684
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
8785
"""
88-
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array)
86+
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array, xp=self._xp)
8987

9088
@property
9189
def noise_normalization(self) -> float:
@@ -94,7 +92,7 @@ def noise_normalization(self) -> float:
9492
9593
[Noise_Term] = sum(log(2*pi*[Noise]**2.0))
9694
"""
97-
return fit_util.noise_normalization_from(noise_map=self.noise_map.array)
95+
return fit_util.noise_normalization_from(noise_map=self.noise_map.array, xp=self._xp)
9896

9997
@property
10098
def log_likelihood(self) -> float:
@@ -115,6 +113,7 @@ def __init__(
115113
dataset,
116114
use_mask_in_fit: bool = False,
117115
dataset_model: DatasetModel = None,
116+
xp=np
118117
):
119118
"""Class to fit a masked dataset where the dataset's data structures are any dimension.
120119
@@ -147,12 +146,13 @@ def __init__(
147146
self.dataset = dataset
148147
self.use_mask_in_fit = use_mask_in_fit
149148
self.dataset_model = dataset_model or DatasetModel()
149+
self._xp = xp
150150

151151
@property
152152
def mask(self) -> Mask2D:
153153
return self.dataset.mask
154154

155-
@cached_property
155+
@property
156156
def grids(self) -> GridsInterface:
157157

158158
def subtracted_from(grid, offset):
@@ -196,7 +196,7 @@ def residual_map(self) -> ty.DataLike:
196196

197197
if self.use_mask_in_fit:
198198
return fit_util.residual_map_with_mask_from(
199-
data=self.data, model_data=self.model_data, mask=self.mask
199+
data=self.data, model_data=self.model_data, mask=self.mask, xp=self._xp
200200
)
201201
return super().residual_map
202202

@@ -209,7 +209,7 @@ def normalized_residual_map(self) -> ty.DataLike:
209209
"""
210210
if self.use_mask_in_fit:
211211
return fit_util.normalized_residual_map_with_mask_from(
212-
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask
212+
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
213213
)
214214
return super().normalized_residual_map
215215

@@ -222,7 +222,7 @@ def chi_squared_map(self) -> ty.DataLike:
222222
"""
223223
if self.use_mask_in_fit:
224224
return fit_util.chi_squared_map_with_mask_from(
225-
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask
225+
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
226226
)
227227
return super().chi_squared_map
228228

@@ -243,7 +243,7 @@ def chi_squared(self) -> float:
243243

244244
if self.use_mask_in_fit:
245245
return fit_util.chi_squared_with_mask_from(
246-
chi_squared_map=self.chi_squared_map, mask=self.mask
246+
chi_squared_map=self.chi_squared_map, mask=self.mask, xp=self._xp
247247
)
248248
return super().chi_squared
249249

@@ -256,7 +256,7 @@ def noise_normalization(self) -> float:
256256
"""
257257
if self.use_mask_in_fit:
258258
return fit_util.noise_normalization_with_mask_from(
259-
noise_map=self.noise_map, mask=self.mask
259+
noise_map=self.noise_map, mask=self.mask, xp=self._xp
260260
)
261261
return super().noise_normalization
262262

autoarray/fit/fit_imaging.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
import numpy as np
22

33
from autoarray.dataset.imaging.dataset import Imaging
44
from autoarray.dataset.dataset_model import DatasetModel
@@ -14,6 +14,7 @@ def __init__(
1414
dataset: Imaging,
1515
use_mask_in_fit: bool = False,
1616
dataset_model: DatasetModel = None,
17+
xp=np
1718
):
1819
"""
1920
Class to fit a masked imaging dataset.
@@ -49,6 +50,7 @@ def __init__(
4950
dataset=dataset,
5051
use_mask_in_fit=use_mask_in_fit,
5152
dataset_model=dataset_model,
53+
xp=xp
5254
)
5355

5456
@property

autoarray/fit/fit_interferometer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from typing import Dict, Optional
32

43
from autoarray.dataset.interferometer.dataset import Interferometer
54

@@ -18,6 +17,7 @@ def __init__(
1817
dataset: Interferometer,
1918
dataset_model: DatasetModel = None,
2019
use_mask_in_fit: bool = False,
20+
xp=np
2121
):
2222
"""
2323
Class to fit a masked interferometer dataset.
@@ -58,6 +58,7 @@ def __init__(
5858
dataset=dataset,
5959
dataset_model=dataset_model,
6060
use_mask_in_fit=use_mask_in_fit,
61+
xp=xp
6162
)
6263

6364
@property

0 commit comments

Comments
 (0)