Skip to content

Commit 9324e25

Browse files
Jammy2211Jammy2211
authored andcommitted
fix adaptive rectangular
1 parent 6968472 commit 9324e25

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+371
-354
lines changed

autoarray/abstract_ndarray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def __getitem__(self, item):
342342

343343
def __setitem__(self, key, value):
344344
from jax import Array
345+
345346
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
346347
self._array = jnp.where(key, value, self._array)
347348
else:

autoarray/dataset/imaging/simulator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ def __init__(
100100
self.noise_seed = noise_seed
101101

102102
def via_image_from(
103-
self, image: Array2D, over_sample_size: Optional[Union[int, np.ndarray]] = None, xp=np
103+
self,
104+
image: Array2D,
105+
over_sample_size: Optional[Union[int, np.ndarray]] = None,
106+
xp=np,
104107
) -> Imaging:
105108
"""
106109
Simulate an `Imaging` dataset from an input image.

autoarray/fit/fit_dataset.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def chi_squared(self) -> float:
8383
"""
8484
Returns the chi-squared terms of the model data's fit to an dataset, by summing the chi-squared-map.
8585
"""
86-
return fit_util.chi_squared_from(chi_squared_map=self.chi_squared_map.array, xp=self._xp)
86+
return fit_util.chi_squared_from(
87+
chi_squared_map=self.chi_squared_map.array, xp=self._xp
88+
)
8789

8890
@property
8991
def noise_normalization(self) -> float:
@@ -92,7 +94,9 @@ def noise_normalization(self) -> float:
9294
9395
[Noise_Term] = sum(log(2*pi*[Noise]**2.0))
9496
"""
95-
return fit_util.noise_normalization_from(noise_map=self.noise_map.array, xp=self._xp)
97+
return fit_util.noise_normalization_from(
98+
noise_map=self.noise_map.array, xp=self._xp
99+
)
96100

97101
@property
98102
def log_likelihood(self) -> float:
@@ -113,7 +117,7 @@ def __init__(
113117
dataset,
114118
use_mask_in_fit: bool = False,
115119
dataset_model: DatasetModel = None,
116-
xp=np
120+
xp=np,
117121
):
118122
"""Class to fit a masked dataset where the dataset's data structures are any dimension.
119123
@@ -209,7 +213,10 @@ def normalized_residual_map(self) -> ty.DataLike:
209213
"""
210214
if self.use_mask_in_fit:
211215
return fit_util.normalized_residual_map_with_mask_from(
212-
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
216+
residual_map=self.residual_map,
217+
noise_map=self.noise_map,
218+
mask=self.mask,
219+
xp=self._xp,
213220
)
214221
return super().normalized_residual_map
215222

@@ -222,7 +229,10 @@ def chi_squared_map(self) -> ty.DataLike:
222229
"""
223230
if self.use_mask_in_fit:
224231
return fit_util.chi_squared_map_with_mask_from(
225-
residual_map=self.residual_map, noise_map=self.noise_map, mask=self.mask, xp=self._xp
232+
residual_map=self.residual_map,
233+
noise_map=self.noise_map,
234+
mask=self.mask,
235+
xp=self._xp,
226236
)
227237
return super().chi_squared_map
228238

autoarray/fit/fit_imaging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
dataset: Imaging,
1515
use_mask_in_fit: bool = False,
1616
dataset_model: DatasetModel = None,
17-
xp=np
17+
xp=np,
1818
):
1919
"""
2020
Class to fit a masked imaging dataset.
@@ -50,7 +50,7 @@ def __init__(
5050
dataset=dataset,
5151
use_mask_in_fit=use_mask_in_fit,
5252
dataset_model=dataset_model,
53-
xp=xp
53+
xp=xp,
5454
)
5555

5656
@property

autoarray/fit/fit_interferometer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(
1717
dataset: Interferometer,
1818
dataset_model: DatasetModel = None,
1919
use_mask_in_fit: bool = False,
20-
xp=np
20+
xp=np,
2121
):
2222
"""
2323
Class to fit a masked interferometer dataset.
@@ -58,7 +58,7 @@ def __init__(
5858
dataset=dataset,
5959
dataset_model=dataset_model,
6060
use_mask_in_fit=use_mask_in_fit,
61-
xp=xp
61+
xp=xp,
6262
)
6363

6464
@property

autoarray/fit/fit_util.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,9 @@ def chi_squared_map_with_mask_from(
247247
return xp.where(xp.asarray(mask) == 0, xp.square(residual_map / noise_map), 0)
248248

249249

250-
def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask, xp=np) -> float:
250+
def chi_squared_with_mask_from(
251+
*, chi_squared_map: ty.DataLike, mask: Mask, xp=np
252+
) -> float:
251253
"""
252254
Returns the chi-squared terms of each model data's fit to a masked dataset, by summing the masked
253255
chi-squared-map of the fit.
@@ -265,7 +267,12 @@ def chi_squared_with_mask_from(*, chi_squared_map: ty.DataLike, mask: Mask, xp=n
265267

266268

267269
def chi_squared_with_mask_fast_from(
268-
*, data: ty.DataLike, mask: Mask, model_data: ty.DataLike, noise_map: ty.DataLike, xp=np
270+
*,
271+
data: ty.DataLike,
272+
mask: Mask,
273+
model_data: ty.DataLike,
274+
noise_map: ty.DataLike,
275+
xp=np,
269276
) -> float:
270277
"""
271278
Returns the chi-squared terms of each model data's fit to a masked dataset, by summing the masked
@@ -302,7 +309,9 @@ def chi_squared_with_mask_fast_from(
302309
)
303310

304311

305-
def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask, xp=np) -> float:
312+
def noise_normalization_with_mask_from(
313+
*, noise_map: ty.DataLike, mask: Mask, xp=np
314+
) -> float:
306315
"""
307316
Returns the noise-map normalization terms of masked noise-map, summing the noise_map value in every pixel as:
308317
@@ -317,9 +326,7 @@ def noise_normalization_with_mask_from(*, noise_map: ty.DataLike, mask: Mask, xp
317326
mask
318327
The mask applied to the noise-map, where `False` entries are included in the calculation.
319328
"""
320-
return float(
321-
xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0))
322-
)
329+
return float(xp.sum(xp.log(2 * xp.pi * noise_map[xp.asarray(mask) == 0] ** 2.0)))
323330

324331

325332
def chi_squared_with_noise_covariance_from(

autoarray/inversion/inversion/abstract.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
from autoarray.util import misc_util
1919
from autoarray.inversion.inversion import inversion_util
2020

21+
2122
class AbstractInversion:
2223
def __init__(
2324
self,
2425
dataset: Union[Imaging, Interferometer, DatasetInterface],
2526
linear_obj_list: List[LinearObj],
2627
settings: SettingsInversion = SettingsInversion(),
2728
preloads: Preloads = None,
28-
xp=np
29+
xp=np,
2930
):
3031
"""
3132
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
@@ -75,8 +76,6 @@ def __init__(
7576

7677
self._xp = xp
7778

78-
79-
8079
@property
8180
def data(self):
8281
return self.dataset.data
@@ -333,10 +332,15 @@ def regularization_matrix(self) -> Optional[np.ndarray]:
333332
"""
334333
if self._xp.__name__.startswith("jax"):
335334
from jax.scipy.linalg import block_diag
335+
336336
return block_diag(
337-
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
337+
*[
338+
linear_obj.regularization_matrix
339+
for linear_obj in self.linear_obj_list
340+
]
338341
)
339342
from scipy.linalg import block_diag
343+
340344
return block_diag(
341345
*[linear_obj.regularization_matrix for linear_obj in self.linear_obj_list]
342346
)
@@ -448,7 +452,7 @@ def reconstruction(self) -> np.ndarray:
448452
data_vector=data_vector,
449453
curvature_reg_matrix=curvature_reg_matrix,
450454
settings=self.settings,
451-
xp=self._xp
455+
xp=self._xp,
452456
)
453457
)
454458

@@ -471,13 +475,13 @@ def reconstruction(self) -> np.ndarray:
471475
data_vector=self.data_vector,
472476
curvature_reg_matrix=self.curvature_reg_matrix,
473477
settings=self.settings,
474-
xp=self._xp
478+
xp=self._xp,
475479
)
476480

477481
return inversion_util.reconstruction_positive_negative_from(
478482
data_vector=self.data_vector,
479483
curvature_reg_matrix=self.curvature_reg_matrix,
480-
xp=self._xp
484+
xp=self._xp,
481485
)
482486

483487
@property
@@ -640,7 +644,9 @@ def regularization_term(self) -> float:
640644

641645
return self._xp.matmul(
642646
self.reconstruction_reduced.T,
643-
self._xp.matmul(self.regularization_matrix_reduced, self.reconstruction_reduced),
647+
self._xp.matmul(
648+
self.regularization_matrix_reduced, self.reconstruction_reduced
649+
),
644650
)
645651

646652
@property
@@ -654,7 +660,11 @@ def log_det_curvature_reg_matrix_term(self) -> float:
654660
return 0.0
655661

656662
return 2.0 * self._xp.sum(
657-
self._xp.log(self._xp.diag(self._xp.linalg.cholesky(self.curvature_reg_matrix_reduced)))
663+
self._xp.log(
664+
self._xp.diag(
665+
self._xp.linalg.cholesky(self.curvature_reg_matrix_reduced)
666+
)
667+
)
658668
)
659669

660670
@property
@@ -675,7 +685,11 @@ def log_det_regularization_matrix_term(self) -> float:
675685
return 0.0
676686

677687
return 2.0 * self._xp.sum(
678-
self._xp.log(self._xp.diag(self._xp.linalg.cholesky(self.regularization_matrix_reduced)))
688+
self._xp.log(
689+
self._xp.diag(
690+
self._xp.linalg.cholesky(self.regularization_matrix_reduced)
691+
)
692+
)
679693
)
680694

681695
@property
@@ -738,7 +752,9 @@ def regularization_weights_from(self, index: int) -> np.ndarray:
738752

739753
return np.zeros((pixels,))
740754

741-
return regularization.regularization_weights_from(linear_obj=linear_obj, xp=self._xp)
755+
return regularization.regularization_weights_from(
756+
linear_obj=linear_obj, xp=self._xp
757+
)
742758

743759
@property
744760
def regularization_weights_mapper_dict(self) -> Dict[LinearObj, np.ndarray]:

autoarray/inversion/inversion/factory.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def inversion_from(
2424
linear_obj_list: List[LinearObj],
2525
settings: SettingsInversion = SettingsInversion(),
2626
preloads: Preloads = None,
27-
xp=np
27+
xp=np,
2828
):
2929
"""
3030
Factory which given an input dataset and list of linear objects, creates an `Inversion`.
@@ -60,14 +60,11 @@ def inversion_from(
6060
linear_obj_list=linear_obj_list,
6161
settings=settings,
6262
preloads=preloads,
63-
xp=xp
63+
xp=xp,
6464
)
6565

6666
return inversion_interferometer_from(
67-
dataset=dataset,
68-
linear_obj_list=linear_obj_list,
69-
settings=settings,
70-
xp=xp
67+
dataset=dataset, linear_obj_list=linear_obj_list, settings=settings, xp=xp
7168
)
7269

7370

@@ -76,7 +73,7 @@ def inversion_imaging_from(
7673
linear_obj_list: List[LinearObj],
7774
settings: SettingsInversion = SettingsInversion(),
7875
preloads: Preloads = None,
79-
xp=np
76+
xp=np,
8077
):
8178
"""
8279
Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`.
@@ -129,23 +126,23 @@ def inversion_imaging_from(
129126
w_tilde=w_tilde,
130127
linear_obj_list=linear_obj_list,
131128
settings=settings,
132-
xp=xp
129+
xp=xp,
133130
)
134131

135132
return InversionImagingMapping(
136133
dataset=dataset,
137134
linear_obj_list=linear_obj_list,
138135
settings=settings,
139136
preloads=preloads,
140-
xp=xp
137+
xp=xp,
141138
)
142139

143140

144141
def inversion_interferometer_from(
145142
dataset: Union[Interferometer, DatasetInterface],
146143
linear_obj_list: List[LinearObj],
147144
settings: SettingsInversion = SettingsInversion(),
148-
xp=np
145+
xp=np,
149146
):
150147
"""
151148
Factory which given an input `Interferometer` dataset and list of linear objects, creates
@@ -199,13 +196,13 @@ def inversion_interferometer_from(
199196
w_tilde=w_tilde,
200197
linear_obj_list=linear_obj_list,
201198
settings=settings,
202-
xp=xp
199+
xp=xp,
203200
)
204201

205202
else:
206203
return InversionInterferometerMapping(
207204
dataset=dataset,
208205
linear_obj_list=linear_obj_list,
209206
settings=settings,
210-
xp=xp
207+
xp=xp,
211208
)

autoarray/inversion/inversion/imaging/abstract.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def __init__(
2020
linear_obj_list: List[LinearObj],
2121
settings: SettingsInversion = SettingsInversion(),
2222
preloads: Preloads = None,
23-
xp=np
23+
xp=np,
2424
):
2525
"""
2626
An `Inversion` reconstructs an input dataset using a list of linear objects (e.g. a list of analytic functions
@@ -93,7 +93,9 @@ def operated_mapping_matrix_list(self) -> List[np.ndarray]:
9393
return [
9494
(
9595
self.psf.convolved_mapping_matrix_from(
96-
mapping_matrix=linear_obj.mapping_matrix, mask=self.mask, xp=self._xp
96+
mapping_matrix=linear_obj.mapping_matrix,
97+
mask=self.mask,
98+
xp=self._xp,
9799
)
98100
if linear_obj.operated_mapping_matrix_override is None
99101
else self.linear_func_operated_mapping_matrix_dict[linear_obj]
@@ -137,7 +139,7 @@ def linear_func_operated_mapping_matrix_dict(self) -> Dict:
137139
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
138140
mapping_matrix=linear_func.mapping_matrix,
139141
mask=self.mask,
140-
xp=self._xp
142+
xp=self._xp,
141143
)
142144

143145
linear_func_operated_mapping_matrix_dict[linear_func] = (
@@ -217,9 +219,7 @@ def mapper_operated_mapping_matrix_dict(self) -> Dict:
217219

218220
for mapper in self.cls_list_from(cls=AbstractMapper):
219221
operated_mapping_matrix = self.psf.convolved_mapping_matrix_from(
220-
mapping_matrix=mapper.mapping_matrix,
221-
mask=self.mask,
222-
xp=self._xp
222+
mapping_matrix=mapper.mapping_matrix, mask=self.mask, xp=self._xp
223223
)
224224

225225
mapper_operated_mapping_matrix_dict[mapper] = operated_mapping_matrix

0 commit comments

Comments
 (0)