Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
2607e67
Merge pull request #336 from Jammy2211/feature/jax_merge_2
Jammy2211 Mar 13, 2025
89a638a
Merge branch 'main' into feature/jax_wrapper
Jammy2211 Mar 24, 2025
e87785a
Merge branch 'main' into feature/jax_wrapper
Jammy2211 Apr 2, 2025
e0bae78
swap convolver for PSF
Jammy2211 Apr 2, 2025
fcf9394
black
Jammy2211 Apr 2, 2025
310c42d
Merge pull request #340 from Jammy2211/feature/remove_convolver
Jammy2211 Apr 3, 2025
8289253
remove grid relocate radial
Jammy2211 Apr 3, 2025
fc63e19
remove pylops docs
Jammy2211 Apr 4, 2025
bcbf354
finish
Jammy2211 Apr 8, 2025
5b1046c
update point solver
Jammy2211 Apr 8, 2025
df392c4
fix test_tracer_util
Jammy2211 Apr 8, 2025
7e15c8c
test_tracer
Jammy2211 Apr 8, 2025
cc2c92a
test_operate
Jammy2211 Apr 8, 2025
e2bf1c5
test_to_inversion passes
Jammy2211 Apr 8, 2025
b4fa7b2
test_fit_imaging
Jammy2211 Apr 8, 2025
85135eb
test_simulate_and_fit_imaging.py
Jammy2211 Apr 8, 2025
407b8c0
test_autolens/imaging/test_simulate_and_fit_imaging.py
Jammy2211 Apr 8, 2025
d9d090b
test_fit_interferometer
Jammy2211 Apr 8, 2025
6e71dcc
test_simulate_and_Fit_interferometer
Jammy2211 Apr 8, 2025
8d1af2a
test_analysos_imaging
Jammy2211 Apr 8, 2025
a7e15e6
test_analysis_interferometer
Jammy2211 Apr 8, 2025
8eb945d
black
Jammy2211 Apr 8, 2025
a15e097
Merge pull request #342 from Jammy2211/feature/jax_unit_tests
Jammy2211 Apr 8, 2025
815d429
minor fixes
Jammy2211 Apr 9, 2025
f167f1a
fix flux fit by pasing .,array
Jammy2211 Apr 9, 2025
43c9fa5
failing test
Jammy2211 Apr 9, 2025
a090290
fix another test
Jammy2211 Apr 9, 2025
61be7da
more casting to fix tests
Jammy2211 Apr 9, 2025
79e897f
all point solver issues fixed
Jammy2211 Apr 9, 2025
7df3ffd
black
Jammy2211 Apr 9, 2025
296376a
Merge pull request #344 from Jammy2211/feature/jax_wrapper_point_solver
Jammy2211 Apr 9, 2025
927d954
Merge branch 'main' into feature/jax_wrapper
Jammy2211 Apr 9, 2025
7290b93
wrapper
Apr 29, 2025
0ec1493
Merge branch 'main' into feature/jax_wrapper
Apr 29, 2025
2f394c6
add zoom
Apr 29, 2025
e66f82d
all test pass
May 4, 2025
5300290
merge
May 6, 2025
77bf809
merge
May 13, 2025
023479b
Merge branch 'main' into feature/jax_wrapper
May 16, 2025
7265c68
merge
May 21, 2025
2a7b40a
Merge branch 'main' into feature/jax_wrapper
Jun 1, 2025
97e226b
fix test simulate_and_fit_imaging
Jun 15, 2025
90078e9
Merge pull request #354 from Jammy2211/feature/jax_linear_light
Jammy2211 Jun 15, 2025
2586e77
point solver jax fix
Jun 15, 2025
6d9fb5a
convert fit position image pair repeat residual map to JAx
Jun 15, 2025
5eb989a
another resdidual map fix
Jun 15, 2025
5570bed
fluxes jax'd
Jun 15, 2025
64b1c12
fixing removal of infs from positions
Jun 15, 2025
f1b4746
fix position input
Jun 15, 2025
3f823d1
filter positions in Result
Jun 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autolens/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from autoarray.mask.mask_1d import Mask1D
from autoarray.mask.mask_2d import Mask2D
from autoarray.mask.derive.zoom_2d import Zoom2D
from autoarray.operators.convolver import Convolver
from autoarray.operators.over_sampling.over_sampler import OverSampler # noqa
from autoarray.inversion.inversion.dataset_interface import DatasetInterface
from autoarray.inversion.inversion.mapper_valued import MapperValued
Expand All @@ -27,6 +26,7 @@
from autoarray.inversion.pixelization.mappers.mapper_grids import MapperGrids
from autoarray.inversion.pixelization.mappers.factory import mapper_from as Mapper
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
from autoarray.inversion.convolver import Convolver
from autoarray.operators.transformer import TransformerDFT
from autoarray.operators.transformer import TransformerNUFFT
from autoarray.structures.arrays.uniform_1d import Array1D
Expand Down
7 changes: 5 additions & 2 deletions autolens/analysis/positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
The plane redshift of the lensed source multiple images, which is only required if position threshold
for a double source plane lens system is being used where the specific plane is required.
"""

self.positions = positions
self.threshold = threshold
self.plane_redshift = plane_redshift
Expand Down Expand Up @@ -165,10 +166,12 @@ def log_likelihood_penalty_base_from(
residual_map=residual_map, noise_map=dataset.noise_map
)

chi_squared = aa.util.fit.chi_squared_from(chi_squared_map=chi_squared_map)
chi_squared = aa.util.fit.chi_squared_from(
chi_squared_map=chi_squared_map.array
)

noise_normalization = aa.util.fit.noise_normalization_from(
noise_map=dataset.noise_map
noise_map=dataset.noise_map.array
)

else:
Expand Down
10 changes: 8 additions & 2 deletions autolens/analysis/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def positions_threshold_from(
data=positions, noise_map=None, tracer=tracer, plane_redshift=plane_redshift
)

threshold = factor * np.max(positions_fits.max_separation_of_plane_positions)
threshold = factor * np.nanmax(positions_fits.max_separation_of_plane_positions)

if minimum_threshold is not None:
if threshold < minimum_threshold:
Expand Down Expand Up @@ -285,7 +285,7 @@ def positions_likelihood_from(

Returns
-------
The `PositionsLH` object used to apply a likelihood penalty using the positions.
The `PositionsLH` object used to apply a likelihood penalty or resample the positions.
"""

if os.environ.get("PYAUTOFIT_TEST_MODE") == "1":
Expand All @@ -308,6 +308,12 @@ def positions_likelihood_from(

positions = positions[distances > mass_centre_radial_distance_min]

mask = np.isfinite(positions.array).all(axis=1)

positions = aa.Grid2DIrregular(
positions[mask]
)

threshold = self.positions_threshold_from(
factor=factor,
minimum_threshold=minimum_threshold,
Expand Down
14 changes: 11 additions & 3 deletions autolens/imaging/fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,17 @@ def blurred_image(self) -> aa.Array2D:
"""
Returns the image of all light profiles in the fit's tracer convolved with the imaging dataset's PSF.
"""

if len(self.tracer.cls_list_from(cls=ag.LightProfile)) == len(
self.tracer.cls_list_from(cls=ag.lp_operated.LightProfileOperated)
):
return self.tracer.image_2d_from(
grid=self.grids.lp,
)

return self.tracer.blurred_image_2d_from(
grid=self.grids.lp,
convolver=self.dataset.convolver,
psf=self.dataset.psf,
blurring_grid=self.grids.blurring,
)

Expand All @@ -93,7 +101,6 @@ def profile_subtracted_image(self) -> aa.Array2D:
"""
Returns the dataset's image with all blurred light profile images in the fit's tracer subtracted.
"""

return self.data - self.blurred_image

@property
Expand All @@ -103,6 +110,7 @@ def tracer_to_inversion(self) -> TracerToInversion:
data=self.profile_subtracted_image,
noise_map=self.noise_map,
grids=self.grids,
psf=self.dataset.psf,
convolver=self.dataset.convolver,
w_tilde=self.w_tilde,
)
Expand Down Expand Up @@ -163,7 +171,7 @@ def galaxy_model_image_dict(self) -> Dict[ag.Galaxy, np.ndarray]:

galaxy_blurred_image_2d_dict = self.tracer.galaxy_blurred_image_2d_dict_from(
grid=self.grids.lp,
convolver=self.dataset.convolver,
psf=self.dataset.psf,
blurring_grid=self.grids.blurring,
)

Expand Down
2 changes: 2 additions & 0 deletions autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def log_likelihood_function(self, instance: af.ModelInstance) -> float:
np.linalg.LinAlgError,
OverflowError,
) as e:
print(e)
fggdfg
raise exc.FitException from e

def fit_from(
Expand Down
6 changes: 3 additions & 3 deletions autolens/imaging/plot/fit_imaging_plotters.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def figures_2d_of_planes(
for plane_index in plane_indexes:

if use_source_vmax:
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[plane_index])
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[plane_index].array)

if subtracted_image:

Expand Down Expand Up @@ -765,7 +765,7 @@ def figures_2d(
if data:

if use_source_vmax:
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[1:])
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max([model_image.array for model_image in self.fit.model_images_of_planes_list[1:]])

self.mat_plot_2d.plot_array(
array=self.fit.data,
Expand Down Expand Up @@ -799,7 +799,7 @@ def figures_2d(
if model_image:

if use_source_vmax:
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max(self.fit.model_images_of_planes_list[1:])
self.mat_plot_2d.cmap.kwargs["vmax"] = np.max([model_image.array for model_image in self.fit.model_images_of_planes_list[1:]])

self.mat_plot_2d.plot_array(
array=self.fit.model_data,
Expand Down
3 changes: 2 additions & 1 deletion autolens/lens/to_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def lp_linear_func_list_galaxy_dict(
data=self.dataset.data,
noise_map=self.dataset.noise_map,
grids=grids,
convolver=self.convolver,
psf=self.psf,
convolver=self.dataset.convolver,
transformer=self.transformer,
w_tilde=self.dataset.w_tilde,
)
Expand Down
11 changes: 5 additions & 6 deletions autolens/lens/tracer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from abc import ABC
import numpy as np
from functools import wraps
from scipy.interpolate import griddata
from typing import Dict, List, Optional, Type, Union

Expand Down Expand Up @@ -549,9 +548,9 @@ def image_2d_via_input_plane_image_from(
)[plane_index]

image = griddata(
points=plane_grid,
values=plane_image,
xi=traced_grid.over_sampled,
points=plane_grid.array,
values=plane_image.array,
xi=traced_grid.over_sampled.array,
fill_value=0.0,
method="linear",
)
Expand Down Expand Up @@ -1191,5 +1190,5 @@ def set_snr_of_snr_light_profiles(
)

@aa.profile_func
def convolve_via_convolver(self, image, blurring_image, convolver):
return convolver.convolve_image(image=image, blurring_image=blurring_image)
def convolve_via_psf(self, image, blurring_image, psf):
return psf.convolve_image(image=image, blurring_image=blurring_image)
7 changes: 4 additions & 3 deletions autolens/point/fit/fluxes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
from typing import Optional

import autoarray as aa
Expand Down Expand Up @@ -101,10 +102,10 @@ def model_data(self):
are used.
"""
return aa.ArrayIrregular(
values=[
values=jnp.array([
magnification * self.profile.flux
for magnification in self.magnifications_at_positions
]
])
)

@property
Expand All @@ -128,5 +129,5 @@ def chi_squared(self) -> float:
RMS noise-map values squared.
"""
return ag.util.fit.chi_squared_from(
chi_squared_map=self.chi_squared_map,
chi_squared_map=self.chi_squared_map.array,
)
3 changes: 3 additions & 0 deletions autolens/point/fit/positions/image/pair.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import jax.numpy as jnp
import numpy as np
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn
from scipy.optimize import linear_sum_assignment

import autoarray as aa
Expand Down
29 changes: 16 additions & 13 deletions autolens/point/fit/positions/image/pair_all.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import numpy as np

import autoarray as aa
Expand Down Expand Up @@ -85,7 +86,7 @@ def log_p(
The log probability of the model coordinate explaining the observed coordinate.
"""
chi2 = self.square_distance(data_position, model_position) / sigma**2
return -np.log(np.sqrt(2 * np.pi * sigma**2)) - 0.5 * chi2
return -jnp.log(jnp.sqrt(2 * jnp.pi * sigma**2)) - 0.5 * chi2

def all_permutations_log_likelihoods(self) -> np.ndarray:
"""
Expand All @@ -101,21 +102,23 @@ def all_permutations_log_likelihoods(self) -> np.ndarray:

This is every way in which the coordinates generated by the model can explain the observed coordinates.
"""
return np.array(

model_data = self.model_data.array

return jnp.array(
[
np.log(
np.sum(
[
np.exp(
jnp.log(
jnp.sum(
jnp.array([
jnp.exp(
self.log_p(
data_position,
model_position,
sigma,
)
)
for model_position in self.model_data
if not np.isnan(model_position).any()
]
for model_position in model_data
])
)
)
for data_position, sigma in zip(self.data, self.noise_map)
Expand All @@ -140,12 +143,12 @@ def chi_squared(self) -> float:

This is every way in which the coordinates generated by the model can explain the observed coordinates.
"""
n_non_nan_model_positions = np.count_nonzero(
~np.isnan(
self.model_data,
n_non_nan_model_positions = jnp.count_nonzero(
~jnp.isnan(
self.model_data.array,
).any(axis=1)
)
n_permutations = n_non_nan_model_positions ** len(self.data)
return -2.0 * (
-np.log(n_permutations) + np.sum(self.all_permutations_log_likelihoods())
-jnp.log(n_permutations) + jnp.sum(self.all_permutations_log_likelihoods())
)
7 changes: 3 additions & 4 deletions autolens/point/fit/positions/image/pair_repeat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np

import jax.numpy as jnp
import autoarray as aa

from autolens.point.fit.positions.image.abstract import AbstractFitPositionsImagePair
Expand Down Expand Up @@ -63,6 +62,6 @@ def residual_map(self) -> aa.ArrayIrregular:
self.square_distance(model_position, position)
for model_position in self.model_data
]
residual_map.append(np.sqrt(min(distances)))
residual_map.append(jnp.sqrt(jnp.min(jnp.array(distances))))

return aa.ArrayIrregular(values=residual_map)
return aa.ArrayIrregular(values=jnp.array(residual_map))
13 changes: 8 additions & 5 deletions autolens/point/fit/positions/source/separations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from autoarray.numpy_wrapper import numpy as npw
import jax.numpy as jnp
import numpy as np
from typing import Optional

Expand Down Expand Up @@ -118,19 +118,22 @@ def chi_squared_map(self) -> float:
"""

return self.residual_map**2.0 / (
self.magnifications_at_positions**-2.0 * self.noise_map**2.0
self.magnifications_at_positions.array**-2.0 * self.noise_map.array**2.0
)

@property
def noise_normalization(self) -> float:
"""
Returns the normalization of the noise-map, which is the sum of the noise-map values squared.
"""
return npw.sum(
npw.log(
return jnp.sum(
jnp.log(
2
* np.pi
* (self.magnifications_at_positions**-2.0 * self.noise_map**2.0)
* (
self.magnifications_at_positions.array**-2.0
* self.noise_map.array**2.0
)
)
)

Expand Down
7 changes: 4 additions & 3 deletions autolens/point/fit/times_delays.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import numpy as np
from typing import Optional

Expand Down Expand Up @@ -89,8 +90,8 @@ def residual_map(self) -> aa.ArrayIrregular:
from the dataset time delays and model time delays before the subtraction.
"""

data = self.data - np.min(self.data)
model_data = self.model_data - np.min(self.model_data)
data = self.data - jnp.min(self.data)
model_data = self.model_data - jnp.min(self.model_data)

residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data)
return aa.ArrayIrregular(values=residual_map)
Expand All @@ -102,5 +103,5 @@ def chi_squared(self) -> float:
which is the residual values divided by the RMS noise-map squared.
"""
return ag.util.fit.chi_squared_from(
chi_squared_map=self.chi_squared_map,
chi_squared_map=self.chi_squared_map.array,
)
4 changes: 2 additions & 2 deletions autolens/point/max_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ def __init__(
except TypeError:
plane_index = -1

self.plane_positions = tracer.traced_grid_2d_list_from(grid=data)[plane_index]
self.plane_positions = aa.Grid2DIrregular(values=tracer.traced_grid_2d_list_from(grid=data)[plane_index])

@property
def furthest_separations_of_plane_positions(self) -> aa.ArrayIrregular:
"""
Returns the furthest distance of every source-plane (y,x) coordinate to the other source-plane (y,x)
coordinates.

For example, for the following source-plane positions:
For example, for the following plane positions:

plane_positions = [[(0.0, 0.0), (0.0, 1.0), (0.0, 3.0)]

Expand Down
2 changes: 2 additions & 0 deletions autolens/point/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ def log_likelihood_function(self, instance):
fit = self.fit_from(instance=instance)
return fit.log_likelihood
except (AttributeError, ValueError, TypeError, NumbaException) as e:
print(e)
dfdsfd
raise exc.FitException from e

def fit_from(
Expand Down
Loading
Loading