Skip to content

Commit 8c71cae

Browse files
authored
Merge pull request #179 from Jammy2211/feature/jax_remove_profiling
Feature/jax remove profiling
2 parents a6d7729 + f0892ba commit 8c71cae

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+265
-1735
lines changed

autoarray/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from autoconf.dictable import register_parser
2+
from autofit import conf
3+
4+
conf.instance.register(__file__)
5+
16
from . import exc
27
from . import type
38
from . import util
49
from . import fixtures
510
from . import mock as m
6-
from .numba_util import profile_func
711
from .dataset import preprocess
812
from .dataset.abstract.dataset import AbstractDataset
913
from .dataset.abstract.w_tilde import AbstractWTilde

autoarray/config/general.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
jax:
2+
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
13
fits:
24
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
35
inversion:

autoarray/dataset/grids.py

Lines changed: 18 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
1010
from autoconf import cached_property
1111

12+
from autoarray import exc
13+
1214

1315
class GridsDataset:
1416
def __init__(
@@ -24,7 +26,7 @@ def __init__(
2426
2527
The following grids are contained:
2628
27-
- `uniform`: A grids of (y,x) coordinates which aligns with the centre of every image pixel of the image data,
29+
- `lp`: A grids of (y,x) coordinates which aligns with the centre of every image pixel of the image data,
2830
which is used for most normal calculations (e.g. evaluating the amount of light that falls in an pixel
2931
from a light profile).
3032
@@ -60,72 +62,30 @@ def __init__(
6062
self.over_sample_size_pixelization = over_sample_size_pixelization
6163
self.psf = psf
6264

63-
@cached_property
64-
def lp(self) -> Union[Grid1D, Grid2D]:
65-
"""
66-
Returns the grid of (y,x) Cartesian coordinates at the centre of every pixel in the masked data, which is used
67-
to perform most normal calculations (e.g. evaluating the amount of light that falls in an pixel from a light
68-
profile).
69-
70-
This grid is computed based on the mask, in particular its pixel-scale and sub-grid size.
71-
72-
Returns
73-
-------
74-
The (y,x) coordinates of every pixel in the data.
75-
"""
76-
return Grid2D.from_mask(
65+
self.lp = Grid2D.from_mask(
7766
mask=self.mask,
7867
over_sample_size=self.over_sample_size_lp,
7968
)
69+
self.lp.over_sampled
8070

81-
@cached_property
82-
def pixelization(self) -> Grid2D:
83-
"""
84-
Returns the grid of (y,x) Cartesian coordinates of every pixel in the masked data which is used
85-
specifically for calculations associated with a pixelization.
86-
87-
The `pixelization` grid is identical to the `uniform` grid but often uses a different over sampling scheme
88-
when performing calculations. For example, the pixelization may benefit from using a a higher `sub_size` than
89-
the `uniform` grid, in order to better prevent aliasing effects.
90-
91-
This grid is computed based on the mask, in particular its pixel-scale and sub-grid size.
92-
93-
Returns
94-
-------
95-
The (y,x) coordinates of every pixel in the data, used for pixelization / inversion calculations.
96-
"""
97-
return Grid2D.from_mask(
71+
self.pixelization = Grid2D.from_mask(
9872
mask=self.mask,
9973
over_sample_size=self.over_sample_size_pixelization,
10074
)
101-
102-
@cached_property
103-
def blurring(self) -> Optional[Grid2D]:
104-
"""
105-
Returns a blurring-grid from a mask and the 2D shape of the PSF kernel.
106-
107-
A blurring grid consists of all pixels that are masked (and therefore have their values set to (0.0, 0.0)),
108-
but are close enough to the unmasked pixels that their values will be convolved into the unmasked those pixels.
109-
This when computing images from light profile objects.
110-
111-
This uses lazy allocation such that the calculation is only performed when the blurring grid is used, ensuring
112-
efficient set up of the `Imaging` class.
113-
114-
Returns
115-
-------
116-
The blurring grid given the mask of the imaging data.
117-
"""
75+
self.pixelization.over_sampled
11876

11977
if self.psf is None:
120-
return None
121-
122-
return self.lp.blurring_grid_via_kernel_shape_from(
123-
kernel_shape_native=self.psf.shape_native,
124-
)
125-
126-
@cached_property
127-
def border_relocator(self) -> BorderRelocator:
128-
return BorderRelocator(
78+
self.blurring = None
79+
else:
80+
try:
81+
self.blurring = self.lp.blurring_grid_via_kernel_shape_from(
82+
kernel_shape_native=self.psf.shape_native,
83+
)
84+
self.blurring.over_sampled
85+
except exc.MaskException:
86+
self.blurring = None
87+
88+
self.border_relocator = BorderRelocator(
12989
mask=self.mask, sub_size=self.over_sample_size_pixelization
13090
)
13191

autoarray/dataset/imaging/dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,7 @@ def __init__(
170170
if psf.mask.shape[0] % 2 == 0 or psf.mask.shape[1] % 2 == 0:
171171
raise exc.KernelException("Kernel2D Kernel2D must be odd")
172172

173-
@cached_property
174-
def grids(self):
175-
return GridsDataset(
173+
self.grids = GridsDataset(
176174
mask=self.data.mask,
177175
over_sample_size_lp=self.over_sample_size_lp,
178176
over_sample_size_pixelization=self.over_sample_size_pixelization,
@@ -511,7 +509,7 @@ def apply_over_sampling(
511509
passed into the calculations performed in the `inversion` module.
512510
"""
513511

514-
return Imaging(
512+
dataset = Imaging(
515513
data=self.data,
516514
noise_map=self.noise_map,
517515
psf=self.psf,
@@ -522,6 +520,8 @@ def apply_over_sampling(
522520
check_noise_map=False,
523521
)
524522

523+
return dataset
524+
525525
def output_to_fits(
526526
self,
527527
data_path: Union[Path, str],

autoarray/dataset/interferometer/dataset.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def __init__(
101101
else None
102102
)
103103

104-
@cached_property
105-
def grids(self):
106-
return GridsDataset(
104+
self.grids = GridsDataset(
107105
mask=self.real_space_mask,
108106
over_sample_size_lp=self.over_sample_size_lp,
109107
over_sample_size_pixelization=self.over_sample_size_pixelization,

autoarray/exc.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,3 @@ class PlottingException(Exception):
104104
"""
105105

106106
pass
107-
108-
109-
class ProfilingException(Exception):
110-
"""
111-
Raises exceptions associated with in-built profiling tools (e.g. the `profile_func` decorator).
112-
"""
113-
114-
pass

autoarray/fit/fit_dataset.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from autoarray.inversion.inversion.abstract import AbstractInversion
1414
from autoarray.mask.mask_2d import Mask2D
1515

16-
from autoarray.numba_util import profile_func
1716
from autoarray import type as ty
1817

1918

@@ -116,7 +115,6 @@ def __init__(
116115
dataset,
117116
use_mask_in_fit: bool = False,
118117
dataset_model: DatasetModel = None,
119-
run_time_dict: Optional[Dict] = None,
120118
):
121119
"""Class to fit a masked dataset where the dataset's data structures are any dimension.
122120
@@ -149,7 +147,6 @@ def __init__(
149147
self.dataset = dataset
150148
self.use_mask_in_fit = use_mask_in_fit
151149
self.dataset_model = dataset_model or DatasetModel()
152-
self.run_time_dict = run_time_dict
153150

154151
@property
155152
def mask(self) -> Mask2D:
@@ -320,7 +317,6 @@ def log_evidence(self) -> float:
320317
)
321318

322319
@property
323-
@profile_func
324320
def figure_of_merit(self) -> float:
325321
if self.inversion is not None:
326322
return self.log_evidence

autoarray/fit/fit_imaging.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def __init__(
1414
dataset: Imaging,
1515
use_mask_in_fit: bool = False,
1616
dataset_model: DatasetModel = None,
17-
run_time_dict: Optional[Dict] = None,
1817
):
1918
"""
2019
Class to fit a masked imaging dataset.
@@ -50,7 +49,6 @@ def __init__(
5049
dataset=dataset,
5150
use_mask_in_fit=use_mask_in_fit,
5251
dataset_model=dataset_model,
53-
run_time_dict=run_time_dict,
5452
)
5553

5654
@property

autoarray/fit/fit_interferometer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def __init__(
1818
dataset: Interferometer,
1919
dataset_model: DatasetModel = None,
2020
use_mask_in_fit: bool = False,
21-
run_time_dict: Optional[Dict] = None,
2221
):
2322
"""
2423
Class to fit a masked interferometer dataset.
@@ -59,7 +58,6 @@ def __init__(
5958
dataset=dataset,
6059
dataset_model=dataset_model,
6160
use_mask_in_fit=use_mask_in_fit,
62-
run_time_dict=run_time_dict,
6361
)
6462

6563
@property

autoarray/fit/mock/mock_fit_imaging.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@ def __init__(
1515
model_data=None,
1616
inversion=None,
1717
blurred_image=None,
18-
run_time_dict: Optional[Dict] = None,
1918
):
2019
super().__init__(
2120
dataset=dataset or MockDataset(),
2221
dataset_model=dataset_model,
2322
use_mask_in_fit=use_mask_in_fit,
24-
run_time_dict=run_time_dict,
2523
)
2624

2725
self._noise_map = noise_map

0 commit comments

Comments
 (0)