Skip to content

Commit 9947093

Browse files
authored
Merge pull request #157 from Jammy2211/feature/jax_merge
feature/jax merge
2 parents 40be8b1 + 393d273 commit 9947093

File tree

108 files changed

+2781
-2574
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

108 files changed

+2781
-2574
lines changed

autoarray/__init__.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from .dataset.interferometer.dataset import Interferometer
1515
from .dataset.interferometer.simulator import SimulatorInterferometer
1616
from .dataset.interferometer.w_tilde import WTildeInterferometer
17-
from .dataset.over_sampling import OverSamplingDataset
1817
from .dataset.dataset_model import DatasetModel
1918
from .fit.fit_dataset import AbstractFit
2019
from .fit.fit_dataset import FitDataset
@@ -68,14 +67,8 @@
6867
from .structures.arrays.irregular import ArrayIrregular
6968
from .structures.grids.uniform_1d import Grid1D
7069
from .structures.grids.uniform_2d import Grid2D
71-
from .operators.over_sampling.decorator import perform_over_sampling_from
72-
from .operators.over_sampling.grid_oversampled import Grid2DOverSampled
73-
from .operators.over_sampling.uniform import OverSamplingUniform
74-
from .operators.over_sampling.iterate import OverSamplingIterate
75-
from .operators.over_sampling.uniform import OverSamplerUniform
76-
from .operators.over_sampling.iterate import OverSamplerIterate
70+
from .operators.over_sampling.over_sampler import OverSampler
7771
from .structures.grids.irregular_2d import Grid2DIrregular
78-
from .structures.grids.irregular_2d import Grid2DIrregularUniform
7972
from .structures.mesh.rectangular_2d import Mesh2DRectangular
8073
from .structures.mesh.voronoi_2d import Mesh2DVoronoi
8174
from .structures.mesh.delaunay_2d import Mesh2DDelaunay
@@ -98,4 +91,4 @@
9891

9992
conf.instance.register(__file__)
10093

101-
__version__ = "2024.11.6.1"
94+
__version__ = "2025.1.18.7"

autoarray/config/visualize/mat_wrap_2d.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ VectorYXQuiver: # wrapper for `plt.quiver()`: customize (y,x) vectors appea
146146
linewidth: 5
147147
pivot: middle
148148
units: xy
149+
DelaunayDrawer: # wrapper for `plt.fill()`: customize the appearance of Delaunay mesh's.
150+
figure:
151+
alpha: 0.7
152+
edgecolor: k
153+
linewidth: 0.0
154+
subplot:
155+
alpha: 0.7
156+
edgecolor: k
157+
linewidth: 0.0
149158
VoronoiDrawer: # wrapper for `plt.fill()`: customize the appearance of Voronoi mesh's.
150159
figure:
151160
alpha: 0.7

autoarray/config/visualize/plots.yaml

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,11 @@
55

66
dataset: # Settings for plots of all datasets (e.g. ImagingPlotter, InterferometerPlotter).
77
subplot_dataset: true # Plot subplot containing all dataset quantities (e.g. the data, noise-map, etc.)?
8-
data: false # Plot the individual data of every dataset?
9-
noise_map: false # Plot the individual noise-map of every dataset?
10-
signal_to_noise_map: false # Plot the individual signal-to-noise-map of every dataset?
11-
over_sampling: false # Plot the over-sampling sub-size, used to evaluate light profiles, of every dataset?
12-
over_sampling_non_uniform: false # Plot the over-sampling sub-size, used to evaluate non uniform grids, of every dataset?
13-
over_sampling_pixelization: false # Plot the over-sampling sub-size, used to evaluate pixelizations, of every dataset?
148
imaging: # Settings for plots of imaging datasets (e.g. ImagingPlotter)
159
psf: false
1610
fit: # Settings for plots of all fits (e.g. FitImagingPlotter, FitInterferometerPlotter).
1711
subplot_fit: true # Plot subplot of all fit quantities for any dataset (e.g. the model data, residual-map, etc.)?
1812
subplot_fit_log10: true # Plot subplot of all fit quantities for any dataset using log10 color maps (e.g. the model data, residual-map, etc.)?
19-
all_at_end_png: true # Plot all individual plots listed below as .png (even if False)?
20-
all_at_end_fits: true # Plot all individual plots listed below as .fits (even if False)?
21-
all_at_end_pdf: false # Plot all individual plots listed below as publication-quality .pdf (even if False)?
2213
data: false # Plot individual plots of the data?
2314
noise_map: false # Plot individual plots of the noise-map?
2415
signal_to_noise_map: false # Plot individual plots of the signal-to-noise-map?
@@ -31,33 +22,14 @@ fit_imaging: {} # Settings for plots of fits to imagi
3122
inversion: # Settings for plots of inversions (e.g. InversionPlotter).
3223
subplot_inversion: true # Plot subplot of all quantities in each inversion (e.g. reconstrucuted image, reconstruction)?
3324
subplot_mappings: true # Plot subplot of the image-to-source pixels mappings of each pixelization?
34-
all_at_end_png: true # Plot all individual plots listed below as .png (even if False)?
35-
all_at_end_fits: true # Plot all individual plots listed below as .fits (even if False)?
36-
all_at_end_pdf: false # Plot all individual plots listed below as publication-quality .pdf (even if False)?
3725
data_subtracted: false # Plot individual plots of the data with the other inversion linear objects subtracted?
38-
errors: false # Plot image of the errors of every mesh-pixel reconstructed value?
26+
reconstruction_noise_map: false # Plot image of the noise of every mesh-pixel reconstructed value?
3927
sub_pixels_per_image_pixels: false # Plot the number of sub pixels per masked data pixels?
4028
mesh_pixels_per_image_pixels: false # Plot the number of image-plane mesh pixels per masked data pixels?
4129
image_pixels_per_mesh_pixels: false # Plot the number of image pixels in each pixel of the mesh?
4230
reconstructed_image: false # Plot image of the reconstructed data (e.g. in the image-plane)?
4331
reconstruction: false # Plot the reconstructed inversion (e.g. the pixelization's mesh in the source-plane)?
4432
regularization_weights: false # Plot the effective regularization weight of every inversion mesh pixel?
45-
interferometer: # Settings for plots of interferometer datasets (e.g. InterferometerPlotter).
46-
amplitudes_vs_uv_distances: false
47-
phases_vs_uv_distances: false
48-
uv_wavelengths: false
49-
dirty_image: false
50-
dirty_noise_map: false
51-
dirty_signal_to_noise_map: false
5233
fit_interferometer: # Settings for plots of fits to interferometer datasets (e.g. FitInterferometerPlotter).
5334
subplot_fit_dirty_images: false # Plot subplot of the dirty-images of all interferometer datasets?
54-
subplot_fit_real_space: false # Plot subplot of the real-space images of all interferometer datasets?
55-
amplitudes_vs_uv_distances: false
56-
phases_vs_uv_distances: false
57-
uv_wavelengths: false
58-
dirty_image: false
59-
dirty_noise_map: false
60-
dirty_signal_to_noise_map: false
61-
dirty_residual_map: false
62-
dirty_normalized_residual_map: false
63-
dirty_chi_squared_map: false
35+
subplot_fit_real_space: false # Plot subplot of the real-space images of all interferometer datasets?

autoarray/dataset/abstract/dataset.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,17 @@
44
import warnings
55
from typing import Optional, Union
66

7-
from autoarray.dataset.over_sampling import OverSamplingDataset
7+
from autoconf import cached_property
8+
89
from autoarray.dataset.grids import GridsDataset
910

1011
from autoarray import exc
1112
from autoarray.mask.mask_1d import Mask1D
1213
from autoarray.mask.mask_2d import Mask2D
1314
from autoarray.structures.abstract_structure import Structure
1415
from autoarray.structures.arrays.uniform_2d import Array2D
15-
from autoconf import cached_property
16+
17+
from autoarray.operators.over_sampling import over_sample_util
1618

1719

1820
logger = logging.getLogger(__name__)
@@ -24,7 +26,8 @@ def __init__(
2426
data: Structure,
2527
noise_map: Structure,
2628
noise_covariance_matrix: Optional[np.ndarray] = None,
27-
over_sampling: Optional[OverSamplingDataset] = OverSamplingDataset(),
29+
over_sample_size_lp: Union[int, Array2D] = 4,
30+
over_sample_size_pixelization: Union[int, Array2D] = 4,
2831
):
2932
"""
3033
An abstract dataset, containing the image data, noise-map, PSF and associated quantities for calculations
@@ -45,6 +48,32 @@ def __init__(
4548
over sampling calculations built in which approximate the 2D line integral of these calculations within a
4649
pixel. This is explained in more detail in the `GridsDataset` class.
4750
51+
**Over Sampling**
52+
53+
If a grid is uniform and the centre of each point on the grid is the centre of a 2D pixel, evaluating
54+
the value of a function on the grid requires a 2D line integral to compute it precisely. This can be
55+
computationally expensive and difficult to implement.
56+
57+
Over sampling is a numerical technique where the function is evaluated on a sub-grid within each grid pixel
58+
which is higher resolution than the grid itself. This approximates more closely the value of the function
59+
within a 2D line intergral of the values in the square pixel that the grid is centred.
60+
61+
For example, in PyAutoGalaxy and PyAutoLens the light profiles and galaxies are evaluated in order to determine
62+
how much light falls in each pixel. This uses over sampling and therefore a higher resolution grid than the
63+
image data to ensure the calculation is accurate.
64+
65+
This class controls how over sampling is performed for 2 different types of grids:
66+
67+
- `lp`: A grids of (y,x) coordinates which aligns with the centre of every image pixel of the image data
68+
and is used to evaluate light profiles for model-fititng.
69+
70+
- `pixelization`: A grid of (y,x) coordinates which again align with the centre of every image pixel of
71+
the image data. This grid is used specifically for pixelizations computed via the `inversion` module, which
72+
can benefit from using different oversampling schemes than the normal grid.
73+
74+
Different calculations typically benefit from different over sampling, which this class enables
75+
the customization of.
76+
4877
Parameters
4978
----------
5079
data
@@ -56,10 +85,13 @@ def __init__(
5685
noise_covariance_matrix
5786
A noise-map covariance matrix representing the covariance between noise in every `data` value, which
5887
can be used via a bespoke fit to account for correlated noise in the data.
59-
over_sampling
60-
The over sampling schemes which divide the grids into sub grids of smaller pixels within their host image
61-
pixels when using the grid to evaluate a function (e.g. images) to better approximate the 2D line integral
62-
This class controls over sampling for all the different grids (e.g. `grid`, `grids.pixelization).
88+
over_sample_size_lp
89+
The over sampling scheme size, which divides the grid into a sub grid of smaller pixels when computing
90+
values (e.g. images) from the grid to approximate the 2D line integral of the amount of light that falls
91+
into each pixel.
92+
over_sample_size_pixelization
93+
How over sampling is performed for the grid which is associated with a pixelization, which is therefore
94+
passed into the calculations performed in the `inversion` module.
6395
"""
6496

6597
self.data = data
@@ -93,15 +125,28 @@ def __init__(
93125

94126
self.noise_map = noise_map
95127

96-
self.over_sampling = over_sampling
128+
self.over_sample_size_lp = (
129+
over_sample_util.over_sample_size_convert_to_array_2d_from(
130+
over_sample_size=over_sample_size_lp, mask=self.mask
131+
)
132+
)
133+
self.over_sample_size_pixelization = (
134+
over_sample_util.over_sample_size_convert_to_array_2d_from(
135+
over_sample_size=over_sample_size_pixelization, mask=self.mask
136+
)
137+
)
97138

98139
@property
99140
def grid(self):
100-
return self.grids.uniform
141+
return self.grids.lp
101142

102143
@cached_property
103144
def grids(self):
104-
return GridsDataset(mask=self.data.mask, over_sampling=self.over_sampling)
145+
return GridsDataset(
146+
mask=self.data.mask,
147+
over_sample_size_lp=self.over_sample_size_lp,
148+
over_sample_size_pixelization=self.over_sample_size_pixelization,
149+
)
105150

106151
@property
107152
def shape_native(self):
@@ -161,4 +206,16 @@ def trimmed_after_convolution_from(self, kernel_shape) -> "AbstractDataset":
161206
kernel_shape=kernel_shape
162207
)
163208

209+
dataset.over_sample_size_lp = (
210+
dataset.over_sample_size_lp.trimmed_after_convolution_from(
211+
kernel_shape=kernel_shape
212+
)
213+
)
214+
215+
dataset.over_sample_size_pixelization = (
216+
dataset.over_sample_size_pixelization.trimmed_after_convolution_from(
217+
kernel_shape=kernel_shape
218+
)
219+
)
220+
164221
return dataset

autoarray/dataset/grids.py

Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from typing import Optional, Union
22

3-
from autoarray.dataset.over_sampling import OverSamplingDataset
43
from autoarray.mask.mask_2d import Mask2D
4+
from autoarray.structures.arrays.uniform_2d import Array2D
55
from autoarray.structures.arrays.kernel_2d import Kernel2D
66
from autoarray.structures.grids.uniform_1d import Grid1D
77
from autoarray.structures.grids.uniform_2d import Grid2D
88

9-
from autoarray.operators.over_sampling.uniform import OverSamplingUniform
109
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
1110
from autoconf import cached_property
1211

@@ -15,7 +14,8 @@ class GridsDataset:
1514
def __init__(
1615
self,
1716
mask: Mask2D,
18-
over_sampling: OverSamplingDataset,
17+
over_sample_size_lp: Union[int, Array2D],
18+
over_sample_size_pixelization: Union[int, Array2D],
1919
psf: Optional[Kernel2D] = None,
2020
):
2121
"""
@@ -28,11 +28,6 @@ def __init__(
2828
which is used for most normal calculations (e.g. evaluating the amount of light that falls in an pixel
2929
from a light profile).
3030
31-
- `non_uniform`: A grid of (y,x) coordinates which aligns with the centre of every image pixel of the image
32-
data, but where their values are going to be deflected to become non-uniform such that the adaptive over
33-
sampling scheme used for the main grid does not apply. This is used to compute over sampled light profiles of
34-
lensed sources in PyAutoLens.
35-
3631
- `pixelization`: A grid of (y,x) coordinates which again align with the centre of every image pixel of
3732
the image data. This grid is used specifically for pixelizations computed via the `inversion` module, which
3833
can benefit from using different oversampling schemes than the normal grid.
@@ -49,16 +44,24 @@ def __init__(
4944
5045
Parameters
5146
----------
52-
mask
53-
over_sampling
47+
over_sample_size_lp
48+
The over sampling scheme size, which divides the grid into a sub grid of smaller pixels when computing
49+
values (e.g. images) from the grid to approximate the 2D line integral of the amount of light that falls
50+
into each pixel.
51+
over_sample_size_pixelization
52+
How over sampling is performed for the grid which is associated with a pixelization, which is therefore
53+
passed into the calculations performed in the `inversion` module.
5454
psf
55+
The Point Spread Function kernel of the image which accounts for diffraction due to the telescope optics
56+
via 2D convolution.
5557
"""
5658
self.mask = mask
57-
self.over_sampling = over_sampling
59+
self.over_sample_size_lp = over_sample_size_lp
60+
self.over_sample_size_pixelization = over_sample_size_pixelization
5861
self.psf = psf
5962

6063
@cached_property
61-
def uniform(self) -> Union[Grid1D, Grid2D]:
64+
def lp(self) -> Union[Grid1D, Grid2D]:
6265
"""
6366
Returns the grid of (y,x) Cartesian coordinates at the centre of every pixel in the masked data, which is used
6467
to perform most normal calculations (e.g. evaluating the amount of light that falls in an pixel from a light
@@ -70,38 +73,9 @@ def uniform(self) -> Union[Grid1D, Grid2D]:
7073
-------
7174
The (y,x) coordinates of every pixel in the data.
7275
"""
73-
74-
return Grid2D.from_mask(
75-
mask=self.mask,
76-
over_sampling=self.over_sampling.uniform,
77-
)
78-
79-
@cached_property
80-
def non_uniform(self) -> Optional[Union[Grid1D, Grid2D]]:
81-
"""
82-
Returns the grid of (y,x) Cartesian coordinates at the centre of every pixel in the masked data, but
83-
with a different over sampling scheme designed for
84-
85-
where
86-
their values are going to be deflected to become non-uniform such that the adaptive over sampling scheme used
87-
for the main grid does not apply.
88-
89-
This is used to compute over sampled light profiles of lensed sources in PyAutoLens.
90-
91-
92-
This grid is computed based on the mask, in particular its pixel-scale and sub-grid size.
93-
94-
Returns
95-
-------
96-
The (y,x) coordinates of every pixel in the data.
97-
"""
98-
99-
if self.over_sampling.non_uniform is None:
100-
return None
101-
10276
return Grid2D.from_mask(
10377
mask=self.mask,
104-
over_sampling=self.over_sampling.non_uniform,
78+
over_sample_size=self.over_sample_size_lp,
10579
)
10680

10781
@cached_property
@@ -120,15 +94,9 @@ def pixelization(self) -> Grid2D:
12094
-------
12195
The (y,x) coordinates of every pixel in the data, used for pixelization / inversion calculations.
12296
"""
123-
124-
over_sampling = self.over_sampling.pixelization
125-
126-
if over_sampling is None:
127-
over_sampling = OverSamplingUniform(sub_size=4)
128-
12997
return Grid2D.from_mask(
13098
mask=self.mask,
131-
over_sampling=over_sampling,
99+
over_sample_size=self.over_sample_size_pixelization,
132100
)
133101

134102
@cached_property
@@ -151,36 +119,26 @@ def blurring(self) -> Optional[Grid2D]:
151119
if self.psf is None:
152120
return None
153121

154-
return self.uniform.blurring_grid_via_kernel_shape_from(
122+
return self.lp.blurring_grid_via_kernel_shape_from(
155123
kernel_shape_native=self.psf.shape_native,
156124
)
157125

158-
@cached_property
159-
def over_sampler_non_uniform(self):
160-
return self.non_uniform.over_sampling.over_sampler_from(mask=self.mask)
161-
162-
@cached_property
163-
def over_sampler_pixelization(self):
164-
return self.pixelization.over_sampling.over_sampler_from(mask=self.mask)
165-
166126
@cached_property
167127
def border_relocator(self) -> BorderRelocator:
168128
return BorderRelocator(
169-
mask=self.mask, sub_size=self.pixelization.over_sampling.sub_size
129+
mask=self.mask, sub_size=self.over_sample_size_pixelization
170130
)
171131

172132

173133
class GridsInterface:
174134
def __init__(
175135
self,
176-
uniform=None,
177-
non_uniform=None,
136+
lp=None,
178137
pixelization=None,
179138
blurring=None,
180139
border_relocator=None,
181140
):
182-
self.uniform = uniform
183-
self.non_uniform = non_uniform
141+
self.lp = lp
184142
self.pixelization = pixelization
185143
self.blurring = blurring
186144
self.border_relocator = border_relocator

0 commit comments

Comments
 (0)