Skip to content

Commit cbd68f0

Browse files
authored
Merge pull request #228 from Jammy2211/feature/documentation-fit
Improve docstrings for autoarray/fit package
2 parents 1ec232c + 21a1744 commit cbd68f0

File tree

4 files changed

+113
-47
lines changed

4 files changed

+113
-47
lines changed

autoarray/fit/fit_dataset.py

Lines changed: 35 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -145,7 +145,7 @@ def __init__(
145145
noise_normalization
146146
The overall normalization term of the noise_map, summed over every data point.
147147
log_likelihood
148-
The overall log likelihood of the model's fit to the dataset, summed over evey data point.
148+
The overall log likelihood of the model's fit to the dataset, summed over every data point.
149149
"""
150150
self.dataset = dataset
151151
self.use_mask_in_fit = use_mask_in_fit
@@ -162,6 +162,12 @@ def __init__(
162162

163163
@property
164164
def _xp(self):
165+
"""
166+
Returns the array module in use: `numpy` if JAX is disabled or `jax.numpy` if JAX is enabled.
167+
168+
This is controlled by the `use_jax` flag set during initialisation and is the single point of control
169+
for switching between NumPy and JAX computation paths throughout the fit.
170+
"""
165171
if self.use_jax:
166172
import jax.numpy as jnp
167173

@@ -170,10 +176,19 @@ def _xp(self):
170176

171177
@property
172178
def mask(self) -> Mask2D:
179+
"""
180+
The 2D mask of the dataset being fitted, where `False` entries are unmasked and included in the fit
181+
and `True` entries are masked and excluded.
182+
"""
173183
return self.dataset.mask
174184

175185
@property
176186
def grids(self) -> GridsInterface:
187+
"""
188+
The grids of (y,x) coordinates associated with the dataset, adjusted by any `grid_offset` specified in
189+
the `dataset_model`. Each grid (`lp`, `pixelization`, `blurring`) has the offset subtracted from it
190+
before being returned.
191+
"""
177192

178193
def subtracted_from(grid, offset):
179194
if grid is None:
@@ -200,10 +215,16 @@ def subtracted_from(grid, offset):
200215

201216
@property
202217
def data(self) -> ty.DataLike:
218+
"""
219+
The data of the dataset being fitted.
220+
"""
203221
return self.dataset.data
204222

205223
@property
206224
def noise_map(self) -> ty.DataLike:
225+
"""
226+
The noise-map of the dataset being fitted, representing the RMS noise in each pixel.
227+
"""
207228
return self.dataset.noise_map
208229

209230
@property
@@ -310,19 +331,7 @@ def log_evidence(self) -> float:
310331
Log Evidence = -0.5*[Chi_Squared_Term + Regularization_Term + Log(Covariance_Regularization_Term) -
311332
Log(Regularization_Matrix_Term) + Noise_Term]
312333
313-
Parameters
314-
----------
315-
chi_squared
316-
The chi-squared term of the inversion's fit to the data.
317-
regularization_term
318-
The regularization term of the inversion, which is the sum of the difference between reconstructed \
319-
flux of every pixel multiplied by the regularization coefficient.
320-
log_curvature_regularization_term
321-
The log of the determinant of the sum of the curvature and regularization matrices.
322-
log_regularization_term
323-
The log of the determinant o the regularization matrix.
324-
noise_normalization
325-
The normalization noise_map-term for the data's noise-map.
334+
Returns `None` if no inversion is present, in which case `log_likelihood` is used as the figure of merit.
326335
"""
327336
if self.inversion is not None:
328337
return fit_util.log_evidence_from(
@@ -335,6 +344,11 @@ def log_evidence(self) -> float:
335344

336345
@property
337346
def figure_of_merit(self) -> float:
347+
"""
348+
The overall goodness-of-fit of the model to the dataset.
349+
350+
If the fit uses an inversion, this is the `log_evidence`; otherwise it is the `log_likelihood`.
351+
"""
338352
if self.inversion is not None:
339353
return self.log_evidence
340354

@@ -371,4 +385,11 @@ def inversion(self) -> Optional[AbstractInversion]:
371385

372386
@property
373387
def reduced_chi_squared(self) -> float:
388+
"""
389+
The reduced chi-squared of the model's fit to the dataset, defined as:
390+
391+
Reduced_Chi_Squared = Chi_Squared / N_unmasked
392+
393+
where `N_unmasked` is the number of unmasked (i.e. `False`) pixels in the mask.
394+
"""
374395
return self.chi_squared / int(np.size(self.mask) - np.sum(self.mask))

autoarray/fit/fit_imaging.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def __init__(
4343
noise_normalization
4444
The overall normalization term of the noise_map, summed over every data point.
4545
log_likelihood
46-
The overall log likelihood of the model's fit to the dataset, summed over evey data point.
46+
The overall log likelihood of the model's fit to the dataset, summed over every data point.
4747
"""
4848

4949
super().__init__(
@@ -55,8 +55,20 @@ def __init__(
5555

5656
@property
5757
def data(self) -> ty.DataLike:
58+
"""
59+
The imaging data being fitted, with any background sky level subtracted.
60+
61+
The background sky is taken from `dataset_model.background_sky_level`, which defaults to 0.0 if not
62+
set, meaning this property is equivalent to `dataset.data` in the common case.
63+
"""
5864
return self.dataset.data - self.dataset_model.background_sky_level
5965

6066
@property
6167
def blurred_image(self) -> Array2D:
68+
"""
69+
The PSF-convolved (blurred) model image of the fit, as a 2D `Array2D`.
70+
71+
This is the model image after it has been convolved with the dataset's PSF. It must be implemented by
72+
child classes (e.g. a tracer fit) that produce a blurred model image as part of their fitting procedure.
73+
"""
6274
raise NotImplementedError

autoarray/fit/fit_interferometer.py

Lines changed: 46 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -24,15 +24,10 @@ def __init__(
2424
2525
Parameters
2626
----------
27-
dataset : MaskedInterferometer
28-
The masked interferometer dataset that is fitted.
27+
dataset
28+
The interferometer dataset that is fitted, containing the observed visibilities and noise-map.
2929
dataset_model
3030
Attributes which allow for parts of a dataset to be treated as a model (e.g. the background sky level).
31-
model_data : Visibilities
32-
The model visibilities the masked imaging is fitted with.
33-
inversion : Inversion
34-
If the fit uses an `Inversion` this is the instance of the object used to perform the fit. This determines
35-
if the `log_likelihood` or `log_evidence` is used as the `figure_of_merit`.
3631
use_mask_in_fit
3732
If `True`, masked data points are omitted from the fit. If `False` they are not (in most use cases the
3833
`dataset` will have been processed to remove masked points, for example the `slim` representation).
@@ -51,7 +46,7 @@ def __init__(
5146
noise_normalization
5247
The overall normalization term of the noise_map, summed over every data point.
5348
log_likelihood
54-
The overall log likelihood of the model's fit to the dataset, summed over evey data point.
49+
The overall log likelihood of the model's fit to the dataset, summed over every data point.
5550
"""
5651

5752
super().__init__(
@@ -63,10 +58,22 @@ def __init__(
6358

6459
@property
6560
def mask(self) -> np.ndarray:
61+
"""
62+
The mask of the interferometer fit, returned as an all-`False` array matching the shape of the visibility data.
63+
64+
Interferometer data is not spatially masked in the same way as imaging data — all visibility measurements
65+
are included in the fit — so this always returns an unmasked array.
66+
"""
6667
return np.full(shape=self.data.shape, fill_value=False)
6768

6869
@property
6970
def transformer(self) -> ty.Transformer:
71+
"""
72+
The Fourier transformer used to map between image space and visibility (uv-plane) space.
73+
74+
This is taken directly from the interferometer dataset and is used internally to compute the
75+
`dirty_*` image-space representations of the fit quantities.
76+
"""
7077
return self.dataset.transformer
7178

7279
@property
@@ -135,19 +142,9 @@ def log_evidence(self) -> float:
135142
Log Evidence = -0.5*[Chi_Squared_Term + Regularization_Term + Log(Covariance_Regularization_Term) -
136143
Log(Regularization_Matrix_Term) + Noise_Term]
137144
138-
Parameters
139-
----------
140-
chi_squared
141-
The chi-squared term of the inversion's fit to the data.
142-
regularization_term
143-
The regularization term of the inversion, which is the sum of the difference between reconstructed \
144-
flux of every pixel multiplied by the regularization coefficient.
145-
log_curvature_regularization_term
146-
The log of the determinant of the sum of the curvature and regularization matrices.
147-
log_regularization_term
148-
The log of the determinant o the regularization matrix.
149-
noise_normalization
150-
The normalization noise_map-term for the data's noise-map.
145+
For interferometer fits the chi-squared uses the fast inversion chi-squared (`inversion.fast_chi_squared`).
146+
147+
Returns `None` if no inversion is present, in which case `log_likelihood` is used as the figure of merit.
151148
"""
152149
if self.inversion is not None:
153150
return fit_util.log_evidence_from(
@@ -160,28 +157,56 @@ def log_evidence(self) -> float:
160157

161158
@property
162159
def dirty_image(self) -> Array2D:
160+
"""
161+
The dirty image of the observed visibility data, computed by applying the inverse Fourier transform to the
162+
data visibilities. This is the image-space representation of the observed data before any deconvolution.
163+
"""
163164
return self.transformer.image_from(visibilities=self.data)
164165

165166
@property
166167
def dirty_noise_map(self) -> Array2D:
168+
"""
169+
The dirty noise-map, computed by applying the inverse Fourier transform to the noise-map visibilities.
170+
This gives an image-space representation of the noise level across the field of view.
171+
"""
167172
return self.transformer.image_from(visibilities=self.noise_map)
168173

169174
@property
170175
def dirty_signal_to_noise_map(self) -> Array2D:
176+
"""
177+
The dirty signal-to-noise map, computed by applying the inverse Fourier transform to the signal-to-noise
178+
visibilities. This gives an image-space representation of the signal-to-noise ratio across the field of view.
179+
"""
171180
return self.transformer.image_from(visibilities=self.signal_to_noise_map)
172181

173182
@property
174183
def dirty_model_image(self) -> Array2D:
184+
"""
185+
The dirty model image, computed by applying the inverse Fourier transform to the model data visibilities.
186+
This is the image-space representation of the model before any deconvolution.
187+
"""
175188
return self.transformer.image_from(visibilities=self.model_data)
176189

177190
@property
178191
def dirty_residual_map(self) -> Array2D:
192+
"""
193+
The dirty residual map, computed by applying the inverse Fourier transform to the residual-map visibilities
194+
(data - model_data). This is the image-space representation of the residuals.
195+
"""
179196
return self.transformer.image_from(visibilities=self.residual_map)
180197

181198
@property
182199
def dirty_normalized_residual_map(self) -> Array2D:
200+
"""
201+
The dirty normalized residual map, computed by applying the inverse Fourier transform to the
202+
normalized residual-map visibilities ((data - model_data) / noise_map).
203+
"""
183204
return self.transformer.image_from(visibilities=self.normalized_residual_map)
184205

185206
@property
186207
def dirty_chi_squared_map(self) -> Array2D:
208+
"""
209+
The dirty chi-squared map, computed by applying the inverse Fourier transform to the chi-squared-map
210+
visibilities (((data - model_data) / noise_map) ** 2.0).
211+
"""
187212
return self.transformer.image_from(visibilities=self.chi_squared_map)

autoarray/fit/fit_util.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,14 @@
77

88

99
def to_new_array(func):
10+
"""
11+
Decorator that wraps the return value of a fit utility function in the same data structure as its first keyword argument.
12+
13+
After computing the result, it calls `.with_new_array(result)` on the first keyword argument. If the first
14+
argument does not have a `with_new_array` method (e.g. it is a plain ndarray), the raw result is returned
15+
instead.
16+
"""
17+
1018
@wraps(func)
1119
def wrapper(**kwargs):
1220
result = func(**kwargs)
@@ -28,8 +36,6 @@ def residual_map_from(*, data: ty.DataLike, model_data: ty.DataLike) -> ty.DataL
2836
----------
2937
data
3038
The data that is fitted.
31-
mask
32-
The mask applied to the dataset, where `False` entries are included in the calculation.
3339
model_data
3440
The model data used to fit the data.
3541
"""
@@ -50,8 +56,6 @@ def normalized_residual_map_from(
5056
The residual-map of the model-data fit to the dataset.
5157
noise_map
5258
The noise-map of the dataset.
53-
mask
54-
The mask applied to the residual-map, where `False` entries are included in the calculation.
5559
"""
5660
return residual_map / noise_map
5761

@@ -129,7 +133,7 @@ def chi_squared_map_complex_from(
129133
*, residual_map: np.ndarray, noise_map: np.ndarray
130134
) -> np.ndarray:
131135
"""
132-
Returnss the chi-squared-map of the fit of complex model-data to a dataset, where:
136+
Returns the chi-squared-map of the fit of complex model-data to a dataset, where:
133137
134138
Chi_Squared = ((Residuals) / (Noise)) ** 2.0 = ((Data - Model)**2.0)/(Variances)
135139
@@ -229,7 +233,7 @@ def chi_squared_map_with_mask_from(
229233
*, residual_map: ty.DataLike, noise_map: ty.DataLike, mask: Mask, xp=np
230234
) -> ty.DataLike:
231235
"""
232-
Returnss the chi-squared-map of the fit of model-data to a masked dataset, where:
236+
Returns the chi-squared-map of the fit of model-data to a masked dataset, where:
233237
234238
Chi_Squared = ((Residuals) / (Noise)) ** 2.0 = ((Data - Model)**2.0)/(Variances)
235239
@@ -289,10 +293,14 @@ def chi_squared_with_mask_fast_from(
289293
290294
Parameters
291295
----------
292-
chi_squared_map
293-
The chi-squared-map of values of the model-data fit to the dataset.
296+
data
297+
The data that is fitted.
294298
mask
295-
The mask applied to the chi-squared-map, where `False` entries are included in the calculation.
299+
The mask applied to the dataset, where `False` entries are included in the calculation.
300+
model_data
301+
The model data used to fit the data.
302+
noise_map
303+
The noise-map of the dataset.
296304
"""
297305
return float(
298306
xp.sum(
@@ -412,7 +420,7 @@ def log_evidence_from(
412420
log_curvature_regularization_term
413421
The log of the determinant of the sum of the curvature and regularization matrices.
414422
log_regularization_term
415-
The log of the determinant o the regularization matrix.
423+
The log of the determinant of the regularization matrix.
416424
noise_normalization
417425
The normalization noise_map-term for the dataset's noise-map.
418426
"""
@@ -447,7 +455,7 @@ def residual_flux_fraction_map_with_mask_from(
447455
*, residual_map: np.ndarray, data: np.ndarray, mask: Mask, xp=np
448456
) -> np.ndarray:
449457
"""
450-
Returnss the residual flux fraction map of the fit of model-data to a masked dataset, where:
458+
Returns the residual flux fraction map of the fit of model-data to a masked dataset, where:
451459
452460
Residual_Flux_Fraction = Residuals / Data = (Data - Model)/Data
453461

0 commit comments

Comments
 (0)