Skip to content

Commit 9a316ef

Browse files
authored
Merge pull request #182 from Jammy2211/feature/jax_speed_up_general
Feature/jax speed up general
2 parents 2c0edb8 + 5f539ce commit 9a316ef

File tree

83 files changed

+394
-892
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+394
-892
lines changed

autoarray/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from autoconf.dictable import register_parser
2-
from autofit import conf
2+
from autoconf import conf
33

44
conf.instance.register(__file__)
55

autoarray/dataset/imaging/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,12 +430,12 @@ def apply_noise_scaling(
430430
"""
431431

432432
if signal_to_noise_value is None:
433-
noise_map = np.array(self.noise_map.native.array)
433+
noise_map = self.noise_map.native
434434
noise_map[mask.array == False] = noise_value
435435
else:
436436
noise_map = np.where(
437437
mask == False,
438-
np.median(self.data.native.array[mask.derive_mask.edge == False])
438+
np.median(self.data.native[mask.derive_mask.edge == False])
439439
/ signal_to_noise_value,
440440
self.noise_map.native.array,
441441
)

autoarray/dataset/interferometer/dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from astropy.io import fits
21
import logging
32
import numpy as np
43
from pathlib import Path
@@ -148,6 +147,9 @@ def from_fits(
148147
)
149148

150149
def w_tilde_preprocessing(self):
150+
151+
from astropy.io import fits
152+
151153
if self.preprocessing_directory.is_dir():
152154
filename = "{}/curvature_preload.fits".format(self.preprocessing_directory)
153155

autoarray/dataset/plot/imaging_plotters.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ def __init__(
1414
self,
1515
dataset: Imaging,
1616
get_visuals_2d: Callable,
17-
mat_plot_2d: MatPlot2D = MatPlot2D(),
18-
visuals_2d: Visuals2D = Visuals2D(),
19-
include_2d: Include2D = Include2D(),
17+
mat_plot_2d: MatPlot2D = None,
18+
visuals_2d: Visuals2D = None,
19+
include_2d: Include2D = None,
2020
):
2121
"""
2222
Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib
@@ -231,9 +231,9 @@ class ImagingPlotter(Plotter):
231231
def __init__(
232232
self,
233233
dataset: Imaging,
234-
mat_plot_2d: MatPlot2D = MatPlot2D(),
235-
visuals_2d: Visuals2D = Visuals2D(),
236-
include_2d: Include2D = Include2D(),
234+
mat_plot_2d: MatPlot2D = None,
235+
visuals_2d: Visuals2D = None,
236+
include_2d: Include2D = None,
237237
):
238238
"""
239239
Plots the attributes of `Imaging` objects using the matplotlib method `imshow()` and many other matplotlib

autoarray/dataset/plot/interferometer_plotters.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ class InterferometerPlotter(Plotter):
1414
def __init__(
1515
self,
1616
dataset: Interferometer,
17-
mat_plot_1d: MatPlot1D = MatPlot1D(),
18-
visuals_1d: Visuals1D = Visuals1D(),
19-
include_1d: Include1D = Include1D(),
20-
mat_plot_2d: MatPlot2D = MatPlot2D(),
21-
visuals_2d: Visuals2D = Visuals2D(),
22-
include_2d: Include2D = Include2D(),
17+
mat_plot_1d: MatPlot1D = None,
18+
visuals_1d: Visuals1D = None,
19+
include_1d: Include1D = None,
20+
mat_plot_2d: MatPlot2D = None,
21+
visuals_2d: Visuals2D = None,
22+
include_2d: Include2D = None,
2323
):
2424
"""
2525
Plots the attributes of `Interferometer` objects using the matplotlib methods `plot()`, `scatter()` and

autoarray/dataset/preprocess.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import numpy as np
2-
from scipy.stats import norm
32

43
from autoarray import exc
54

@@ -316,6 +315,7 @@ def background_noise_map_via_edges_from(image, no_edges):
316315
no_edges
317316
Number of edges used to estimate the background level.
318317
"""
318+
from scipy.stats import norm
319319

320320
from autoarray.structures.arrays.uniform_2d import Array2D
321321

autoarray/fit/fit_interferometer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def noise_normalization(self) -> float:
122122
[Noise_Term] = sum(log(2*pi*[Noise]**2.0))
123123
"""
124124
return fit_util.noise_normalization_complex_from(
125-
noise_map=self.noise_map,
125+
noise_map=self.noise_map.array,
126126
)
127127

128128
@property

autoarray/fit/fit_util.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,8 @@ def noise_normalization_complex_from(*, noise_map: jnp.ndarray) -> float:
174174
noise_map
175175
The masked noise-map of the dataset.
176176
"""
177-
noise_normalization_real = jnp.sum(
178-
jnp.log(2 * jnp.pi * np.array(noise_map).real ** 2.0)
179-
)
180-
noise_normalization_imag = jnp.sum(
181-
jnp.log(2 * jnp.pi * np.array(noise_map).imag ** 2.0)
182-
)
177+
noise_normalization_real = jnp.sum(jnp.log(2 * jnp.pi * noise_map.real**2.0))
178+
noise_normalization_imag = jnp.sum(jnp.log(2 * jnp.pi * noise_map.imag**2.0))
183179
return noise_normalization_real + noise_normalization_imag
184180

185181

autoarray/fit/mock/mock_fit_imaging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional
1+
from typing import Optional
22

33
from autoarray.dataset.mock.mock_dataset import MockDataset
44
from autoarray.dataset.dataset_model import DatasetModel

autoarray/fit/plot/fit_imaging_plotters.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ def __init__(
1313
self,
1414
fit,
1515
get_visuals_2d: Callable,
16-
mat_plot_2d: MatPlot2D = MatPlot2D(),
17-
visuals_2d: Visuals2D = Visuals2D(),
18-
include_2d: Include2D = Include2D(),
16+
mat_plot_2d: MatPlot2D = None,
17+
visuals_2d: Visuals2D = None,
18+
include_2d: Include2D = None,
1919
residuals_symmetric_cmap: bool = True,
2020
):
2121
"""
@@ -242,9 +242,9 @@ class FitImagingPlotter(Plotter):
242242
def __init__(
243243
self,
244244
fit: FitImaging,
245-
mat_plot_2d: MatPlot2D = MatPlot2D(),
246-
visuals_2d: Visuals2D = Visuals2D(),
247-
include_2d: Include2D = Include2D(),
245+
mat_plot_2d: MatPlot2D = None,
246+
visuals_2d: Visuals2D = None,
247+
include_2d: Include2D = None,
248248
):
249249
"""
250250
Plots the attributes of `FitImaging` objects using the matplotlib method `imshow()` and many other matplotlib

0 commit comments

Comments
 (0)