From 2586e779f7fe475a9f2c6f1209b4a42757345ba5 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 15 Jun 2025 17:01:08 +0100 Subject: [PATCH 01/15] point solver jax fix --- autolens/point/model/analysis.py | 2 ++ autolens/point/solver/point_solver.py | 3 ++- test_autolens/lens/test_operate.py | 8 ++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/autolens/point/model/analysis.py b/autolens/point/model/analysis.py index aa4e56cca..820074b07 100644 --- a/autolens/point/model/analysis.py +++ b/autolens/point/model/analysis.py @@ -125,6 +125,8 @@ def log_likelihood_function(self, instance): fit = self.fit_from(instance=instance) return fit.log_likelihood except (AttributeError, ValueError, TypeError, NumbaException) as e: + print(e) + dfdsfd raise exc.FitException from e def fit_from( diff --git a/autolens/point/solver/point_solver.py b/autolens/point/solver/point_solver.py index 300a6f3fe..5e6f7d5a3 100644 --- a/autolens/point/solver/point_solver.py +++ b/autolens/point/solver/point_solver.py @@ -59,4 +59,5 @@ def solve( arr = aa.Grid2DIrregular([pair for pair in filtered_means]) mask = ~jnp.isnan(arr.array).any(axis=1) - return aa.Grid2DIrregular(arr.array[mask]) + + return aa.Grid2DIrregular(jnp.take(arr.array, jnp.nonzero(mask, size=mask.shape[0])[0], axis=0)) \ No newline at end of file diff --git a/test_autolens/lens/test_operate.py b/test_autolens/lens/test_operate.py index 0e1f6b32e..d5f7e09fb 100644 --- a/test_autolens/lens/test_operate.py +++ b/test_autolens/lens/test_operate.py @@ -122,10 +122,10 @@ def test__operate_image__galaxy_blurred_image_2d_dict_from( blurring_grid=blurring_grid_2d_7x7, ) - assert blurred_image_dict[g0].slim == pytest.approx(g0_blurred_image.slim, 1.0e-4) - assert blurred_image_dict[g1].slim == pytest.approx(g1_blurred_image.slim, 1.0e-4) - assert blurred_image_dict[g2].slim == pytest.approx(g2_blurred_image.slim, 1.0e-4) - assert blurred_image_dict[g3].slim == pytest.approx(g3_blurred_image.slim, 1.0e-4) + assert blurred_image_dict[g0].slim == pytest.approx(g0_blurred_image.slim.array, 1.0e-4) + assert blurred_image_dict[g1].slim == pytest.approx(g1_blurred_image.slim.array, 1.0e-4) + assert blurred_image_dict[g2].slim == pytest.approx(g2_blurred_image.slim.array, 1.0e-4) + assert blurred_image_dict[g3].slim == pytest.approx(g3_blurred_image.slim.array, 1.0e-4) def test__operate_image__galaxy_visibilities_dict_from_grid_and_transformer( From 6d9fb5a83c4812bde9b32734cde679db48e43b01 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 15 Jun 2025 17:04:43 +0100 Subject: [PATCH 02/15] convert fit position image pair repeat residual map to JAx --- autolens/point/fit/positions/image/pair_repeat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autolens/point/fit/positions/image/pair_repeat.py b/autolens/point/fit/positions/image/pair_repeat.py index 9d6da6b99..fb280fe2a 100644 --- a/autolens/point/fit/positions/image/pair_repeat.py +++ b/autolens/point/fit/positions/image/pair_repeat.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np import autoarray as aa @@ -63,6 +64,6 @@ def residual_map(self) -> aa.ArrayIrregular: self.square_distance(model_position, position) for model_position in self.model_data ] - residual_map.append(np.sqrt(min(distances))) + residual_map.append(jnp.sqrt(jnp.min(jnp.array(distances)))) return aa.ArrayIrregular(values=residual_map) From 5eb989a07e0d28677ff59354cf5e5f0344739a1d Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 15 Jun 2025 17:07:12 +0100 Subject: [PATCH 03/15] another resdidual map fix --- autolens/point/fit/positions/image/pair_repeat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autolens/point/fit/positions/image/pair_repeat.py b/autolens/point/fit/positions/image/pair_repeat.py index fb280fe2a..c6e089ab3 100644 --- a/autolens/point/fit/positions/image/pair_repeat.py +++ b/autolens/point/fit/positions/image/pair_repeat.py @@ -66,4 +66,4 @@ def residual_map(self) -> aa.ArrayIrregular: ] residual_map.append(jnp.sqrt(jnp.min(jnp.array(distances)))) - return aa.ArrayIrregular(values=residual_map) + return aa.ArrayIrregular(values=jnp.array(residual_map)) From 5570bede19a330537843d1dde33d35172d8ab3a6 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 15 Jun 2025 18:11:17 +0100 Subject: [PATCH 04/15] fluxes jax'd --- autolens/point/fit/fluxes.py | 5 ++-- autolens/point/fit/positions/image/pair.py | 3 ++ .../point/fit/positions/image/pair_all.py | 30 +++++++++++-------- .../point/fit/positions/image/pair_repeat.py | 2 -- .../fit/positions/image/test_pair_all.py | 6 ++-- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/autolens/point/fit/fluxes.py b/autolens/point/fit/fluxes.py index aef37118f..a34c4e56c 100644 --- a/autolens/point/fit/fluxes.py +++ b/autolens/point/fit/fluxes.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp from typing import Optional import autoarray as aa @@ -101,10 +102,10 @@ def model_data(self): are used. """ return aa.ArrayIrregular( - values=[ + values=jnp.array([ magnification * self.profile.flux for magnification in self.magnifications_at_positions - ] + ]) ) @property diff --git a/autolens/point/fit/positions/image/pair.py b/autolens/point/fit/positions/image/pair.py index e7af82c2b..ab2cf99b4 100644 --- a/autolens/point/fit/positions/image/pair.py +++ b/autolens/point/fit/positions/image/pair.py @@ -1,4 +1,7 @@ +import jax.numpy as jnp import numpy as np +from ott.geometry import pointcloud +from ott.solvers.linear import sinkhorn from scipy.optimize import linear_sum_assignment import autoarray as aa diff --git a/autolens/point/fit/positions/image/pair_all.py b/autolens/point/fit/positions/image/pair_all.py index 00bd4a454..940786ef9 100644 --- a/autolens/point/fit/positions/image/pair_all.py +++ b/autolens/point/fit/positions/image/pair_all.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np import autoarray as aa @@ -85,7 +86,7 @@ def log_p( The log probability of the model coordinate explaining the observed coordinate. """ chi2 = self.square_distance(data_position, model_position) / sigma**2 - return -np.log(np.sqrt(2 * np.pi * sigma**2)) - 0.5 * chi2 + return -jnp.log(jnp.sqrt(2 * jnp.pi * sigma**2)) - 0.5 * chi2 def all_permutations_log_likelihoods(self) -> np.ndarray: """ @@ -101,21 +102,24 @@ def all_permutations_log_likelihoods(self) -> np.ndarray: This is every way in which the coordinates generated by the model can explain the observed coordinates. """ - return np.array( + is_nan = jnp.isnan(self.model_data.array).any(axis=1) + sentinel = jnp.full_like(self.model_data.array[0], fill_value=jnp.inf) + model_data = jnp.where(is_nan[:, None], sentinel, self.model_data.array) + + return jnp.array( [ - np.log( - np.sum( - [ - np.exp( + jnp.log( + jnp.sum( + jnp.array([ + jnp.exp( self.log_p( data_position, model_position, sigma, ) ) - for model_position in self.model_data - if not np.isnan(model_position).any() - ] + for model_position in model_data + ]) ) ) for data_position, sigma in zip(self.data, self.noise_map) @@ -140,12 +144,12 @@ def chi_squared(self) -> float: This is every way in which the coordinates generated by the model can explain the observed coordinates. """ - n_non_nan_model_positions = np.count_nonzero( - ~np.isnan( - self.model_data, + n_non_nan_model_positions = jnp.count_nonzero( + ~jnp.isnan( + self.model_data.array, ).any(axis=1) ) n_permutations = n_non_nan_model_positions ** len(self.data) return -2.0 * ( - -np.log(n_permutations) + np.sum(self.all_permutations_log_likelihoods()) + -jnp.log(n_permutations) + jnp.sum(self.all_permutations_log_likelihoods()) ) diff --git a/autolens/point/fit/positions/image/pair_repeat.py b/autolens/point/fit/positions/image/pair_repeat.py index c6e089ab3..f3bc1dac1 100644 --- a/autolens/point/fit/positions/image/pair_repeat.py +++ b/autolens/point/fit/positions/image/pair_repeat.py @@ -1,6 +1,4 @@ import jax.numpy as jnp -import numpy as np - import autoarray as aa from autolens.point.fit.positions.image.abstract import AbstractFitPositionsImagePair diff --git a/test_autolens/point/fit/positions/image/test_pair_all.py b/test_autolens/point/fit/positions/image/test_pair_all.py index 838baef7f..cefad8282 100644 --- a/test_autolens/point/fit/positions/image/test_pair_all.py +++ b/test_autolens/point/fit/positions/image/test_pair_all.py @@ -27,7 +27,7 @@ def noise_map(): @pytest.fixture def fit(data, noise_map): - model_positions = np.array( + model_positions = al.Grid2DIrregular( [ (-1.0749, -1.1), (1.19117, 1.175), @@ -63,7 +63,7 @@ def test_nan_model_positions( data, noise_map, ): - model_positions = np.array( + model_positions = al.Grid2DIrregular( [ (-1.0749, -1.1), (1.19117, 1.175), @@ -92,7 +92,7 @@ def test_duplicate_model_position( data, noise_map, ): - model_positions = np.array( + model_positions = al.Grid2DIrregular( [ (-1.0749, -1.1), (1.19117, 1.175), From 64b1c1239362f54b0ecdaae50fcf83994fea7cb4 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 15 Jun 2025 18:34:26 +0100 Subject: [PATCH 05/15] fixing removal of infs from positions --- autolens/analysis/positions.py | 5 ++++- autolens/point/fit/positions/image/pair_all.py | 5 ++--- autolens/point/fit/times_delays.py | 5 +++-- autolens/point/solver/point_solver.py | 8 +++++--- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index f467088b1..6f0d8c87a 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -71,7 +71,10 @@ def __init__( The plane redshift of the lensed source multiple images, which is only required if position threshold for a double source plane lens system is being used where the specific plane is required. """ - self.positions = positions + + self.positions = aa.Grid2DIrregular( + np.isfinite(positions.array) + ) self.threshold = threshold self.plane_redshift = plane_redshift diff --git a/autolens/point/fit/positions/image/pair_all.py b/autolens/point/fit/positions/image/pair_all.py index 940786ef9..6429b333c 100644 --- a/autolens/point/fit/positions/image/pair_all.py +++ b/autolens/point/fit/positions/image/pair_all.py @@ -102,9 +102,8 @@ def all_permutations_log_likelihoods(self) -> np.ndarray: This is every way in which the coordinates generated by the model can explain the observed coordinates. """ - is_nan = jnp.isnan(self.model_data.array).any(axis=1) - sentinel = jnp.full_like(self.model_data.array[0], fill_value=jnp.inf) - model_data = jnp.where(is_nan[:, None], sentinel, self.model_data.array) + + model_data = self.model_data.array return jnp.array( [ diff --git a/autolens/point/fit/times_delays.py b/autolens/point/fit/times_delays.py index 6d8699d36..df576510d 100644 --- a/autolens/point/fit/times_delays.py +++ b/autolens/point/fit/times_delays.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import Optional @@ -89,8 +90,8 @@ def residual_map(self) -> aa.ArrayIrregular: from the dataset time delays and model time delays before the subtraction. """ - data = self.data - np.min(self.data) - model_data = self.model_data - np.min(self.model_data) + data = self.data - jnp.min(self.data) + model_data = self.model_data - jnp.min(self.model_data) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) return aa.ArrayIrregular(values=residual_map) diff --git a/autolens/point/solver/point_solver.py b/autolens/point/solver/point_solver.py index 5e6f7d5a3..4caee637c 100644 --- a/autolens/point/solver/point_solver.py +++ b/autolens/point/solver/point_solver.py @@ -56,8 +56,10 @@ def solve( tracer=tracer, points=kept_triangles.means ) - arr = aa.Grid2DIrregular([pair for pair in filtered_means]) + solution = aa.Grid2DIrregular([pair for pair in filtered_means]).array - mask = ~jnp.isnan(arr.array).any(axis=1) + is_nan = jnp.isnan(solution).any(axis=1) + sentinel = jnp.full_like(solution[0], fill_value=jnp.inf) + solution = jnp.where(is_nan[:, None], sentinel, solution) - return aa.Grid2DIrregular(jnp.take(arr.array, jnp.nonzero(mask, size=mask.shape[0])[0], axis=0)) \ No newline at end of file + return aa.Grid2DIrregular(solution) \ No newline at end of file From f1b47462b0f1375cc83f2bf60503b51fe26a11e8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 15 Jun 2025 18:45:04 +0100 Subject: [PATCH 06/15] fix position input --- autolens/analysis/positions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index 6f0d8c87a..a6271e018 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -72,8 +72,13 @@ def __init__( for a double source plane lens system is being used where the specific plane is required. """ + try: + mask = np.isfinite(positions.array).all(axis=1) + except AttributeError: + mask = np.isfinite(positions).all(axis=1) + self.positions = aa.Grid2DIrregular( - np.isfinite(positions.array) + positions[mask] ) self.threshold = threshold self.plane_redshift = plane_redshift From 3f823d194fa574c1f25228dd57c9dd9ddc5cc01c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 16 Jun 2025 17:24:36 +0100 Subject: [PATCH 07/15] filter positions in Result --- autolens/analysis/positions.py | 9 +-------- autolens/analysis/result.py | 6 ++++++ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index a6271e018..f21abce0c 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -72,14 +72,7 @@ def __init__( for a double source plane lens system is being used where the specific plane is required. """ - try: - mask = np.isfinite(positions.array).all(axis=1) - except AttributeError: - mask = np.isfinite(positions).all(axis=1) - - self.positions = aa.Grid2DIrregular( - positions[mask] - ) + self.positions = positions self.threshold = threshold self.plane_redshift = plane_redshift diff --git a/autolens/analysis/result.py b/autolens/analysis/result.py index cc07db977..a1c15b975 100644 --- a/autolens/analysis/result.py +++ b/autolens/analysis/result.py @@ -308,6 +308,12 @@ def positions_likelihood_from( positions = positions[distances > mass_centre_radial_distance_min] + mask = np.isfinite(positions.array).all(axis=1) + + positions = aa.Grid2DIrregular( + positions[mask] + ) + threshold = self.positions_threshold_from( factor=factor, minimum_threshold=minimum_threshold, From db3a82590ae93fa328263c6200e36fb055566ae7 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 16 Jun 2025 19:21:12 +0100 Subject: [PATCH 08/15] LH PEnalty uses where logic --- autolens/analysis/positions.py | 20 +++++++++++++++++--- autolens/point/max_separation.py | 3 ++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index f21abce0c..9cf0ba1db 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -1,3 +1,5 @@ +import jax +import jax.numpy as jnp import numpy as np from typing import Optional, Union from os import path @@ -215,12 +217,24 @@ def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]: plane_redshift=self.plane_redshift, ) - if not positions_fit.max_separation_within_threshold(self.threshold): + max_separation = positions_fit.max_separation_of_plane_positions - return self.log_likelihood_penalty_factor * ( - positions_fit.max_separation_of_plane_positions - self.threshold + penalty = self.log_likelihood_penalty_factor * ( + max_separation - self.threshold ) + + return jnp.where(max_separation < self.threshold, 0.0, penalty) + + # if not positions_fit.max_separation_within_threshold(self.threshold): + # + # return + # return jax.lax.cond( + # positions_fit.max_separation_within_threshold(self.threshold), + # lambda: compute_penalty(), + # lambda: jnp.array(-1.0e10), + # ) + def log_likelihood_function_positions_overwrite( self, instance: af.ModelInstance, analysis: AnalysisDataset ) -> Optional[float]: diff --git a/autolens/point/max_separation.py b/autolens/point/max_separation.py index f17e65de3..f4080a8c2 100644 --- a/autolens/point/max_separation.py +++ b/autolens/point/max_separation.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp from typing import Optional import autoarray as aa @@ -68,7 +69,7 @@ def furthest_separations_of_plane_positions(self) -> aa.ArrayIrregular: @property def max_separation_of_plane_positions(self) -> float: - return max(self.furthest_separations_of_plane_positions) + return jnp.max(self.furthest_separations_of_plane_positions.array) def max_separation_within_threshold(self, threshold) -> bool: return self.max_separation_of_plane_positions <= threshold From 8d887bbf5ac6d37399b00dd7ae1b121ec6d488ab Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 16 Jun 2025 19:21:30 +0100 Subject: [PATCH 09/15] deleted commented code --- autolens/analysis/positions.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index 9cf0ba1db..eed5577dc 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -226,15 +226,6 @@ def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]: return jnp.where(max_separation < self.threshold, 0.0, penalty) - # if not positions_fit.max_separation_within_threshold(self.threshold): - # - # return - # return jax.lax.cond( - # positions_fit.max_separation_within_threshold(self.threshold), - # lambda: compute_penalty(), - # lambda: jnp.array(-1.0e10), - # ) - def log_likelihood_function_positions_overwrite( self, instance: af.ModelInstance, analysis: AnalysisDataset ) -> Optional[float]: From 4d23adbfe0bc9398a18c0ce22a81beed95fbb758 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 16 Jun 2025 19:26:42 +0100 Subject: [PATCH 10/15] now use jax.lax.cond --- autolens/analysis/positions.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index eed5577dc..ec043a73f 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -223,8 +223,11 @@ def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]: max_separation - self.threshold ) - - return jnp.where(max_separation < self.threshold, 0.0, penalty) + return jax.lax.cond( + max_separation > self.threshold, + lambda: penalty, + lambda: jnp.array(0.0), + ) def log_likelihood_function_positions_overwrite( self, instance: af.ModelInstance, analysis: AnalysisDataset From 5051c5a9a3ab20733af0685025698002c94ad2f0 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 16 Jun 2025 20:27:26 +0100 Subject: [PATCH 11/15] update imaging unitt ests with working positon LH calc --- autolens/analysis/analysis/lens.py | 20 +-- autolens/analysis/positions.py | 121 ++++-------------- autolens/imaging/model/analysis.py | 9 +- autolens/interferometer/model/analysis.py | 8 +- .../imaging/model/test_analysis_imaging.py | 14 +- 5 files changed, 44 insertions(+), 128 deletions(-) diff --git a/autolens/analysis/analysis/lens.py b/autolens/analysis/analysis/lens.py index bcb1fc60e..840907c05 100644 --- a/autolens/analysis/analysis/lens.py +++ b/autolens/analysis/analysis/lens.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import logging import numpy as np from typing import Dict, List, Optional, Union @@ -98,7 +99,7 @@ def tracer_via_instance_from( run_time_dict=run_time_dict, ) - def log_likelihood_positions_overwrite_from( + def log_likelihood_penalty_from( self, instance: af.ModelInstance ) -> Optional[float]: """ @@ -120,21 +121,20 @@ def log_likelihood_positions_overwrite_from( The penalty value of the positions log likelihood, if the positions do not trace close in the source plane, else a None is returned to indicate there is no penalty. """ - if self.positions_likelihood_list is not None: + log_likelihood_penalty = jnp.array(0.0) - log_likelihood_overwrite = None + if self.positions_likelihood_list is not None: try: for positions_likelihood in self.positions_likelihood_list: - log_likelihood_with_penalty = positions_likelihood.log_likelihood_function_positions_overwrite( + log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from( instance=instance, analysis=self ) - if log_likelihood_with_penalty is not None: - try: - log_likelihood_overwrite += log_likelihood_with_penalty - except TypeError: - log_likelihood_overwrite = log_likelihood_with_penalty - return log_likelihood_overwrite + log_likelihood_penalty += log_likelihood_penalty + + return log_likelihood_penalty except (ValueError, np.linalg.LinAlgError) as e: raise exc.FitException from e + + return log_likelihood_penalty \ No newline at end of file diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index ec043a73f..ea77f3c5b 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -136,79 +136,45 @@ def output_positions_info( ) f.write("") - def log_likelihood_penalty_base_from( - self, dataset: Union[aa.Imaging, aa.Interferometer] - ) -> float: + def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: AnalysisDataset) -> jnp.array: """ - The fast log likelihood penalty scheme returns an alternative penalty log likelihood for any model where the - image-plane positions do not trace within a threshold distance of one another in the source-plane. + Returns a log-likelihood penalty used to constrain lens models where multiple image-plane + positions do not trace to within a threshold distance of one another in the source-plane. - This `log_likelihood_penalty` is defined as: + This penalty is intended for use in `Analysis` classes that include the `PenaltyLH` mixin. It adds a + heavy penalty to the likelihood when the multiple images traces far apart in the source-plane, discouraging + models where the mapped source-plane positions are too widely separated. - log_Likelihood_penalty_base - log_likelihood_penalty_factor * (max_source_plane_separation - threshold) + Specifically, if the maximum separation between traced positions in the source-plane exceeds + a defined threshold, a penalty term is applied to the log likelihood: - The `log_likelihood_penalty` is only used if `max_source_plane_separation > threshold`. + penalty = log_likelihood_penalty_factor * (max_separation - threshold) - This function returns the `log_likelihood_penalty_base`, which represents the lowest possible likelihood - solutions a model-fit can give. It is the chi-squared of model-data consisting of all zeros plus - the noise normalziation term. + If the separation is within the threshold, no penalty is applied. + + JAX Compatibility + ----------------- + Because this function may be jitted or differentiated using JAX, it uses `jax.lax.cond` to apply + conditional logic in a way that is compatible with JAX's functional and tracing model. + Both branches (penalty and zero) are evaluated at trace time, though only one is returned + at runtime depending on the condition. Parameters ---------- - dataset - The imaging or interferometer dataset from which the penalty base is computed. - """ - - residual_map = aa.util.fit.residual_map_from( - data=dataset.data, model_data=np.zeros(dataset.data.shape) - ) - - if isinstance(dataset, aa.Imaging): - chi_squared_map = aa.util.fit.chi_squared_map_from( - residual_map=residual_map, noise_map=dataset.noise_map - ) - - chi_squared = aa.util.fit.chi_squared_from( - chi_squared_map=chi_squared_map.array - ) - - noise_normalization = aa.util.fit.noise_normalization_from( - noise_map=dataset.noise_map.array - ) - - else: - chi_squared_map = aa.util.fit.chi_squared_map_complex_from( - residual_map=residual_map, noise_map=dataset.noise_map - ) - - chi_squared = aa.util.fit.chi_squared_complex_from( - chi_squared_map=chi_squared_map - ) - - noise_normalization = aa.util.fit.noise_normalization_complex_from( - noise_map=dataset.noise_map - ) - - return -0.5 * (chi_squared + noise_normalization) + instance + The current model instance evaluated during the non-linear search. + analysis + The `Analysis` object calling this function, from which the `tracer` and `dataset` are derived. - def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]: + Returns + ------- + penalty + A scalar log-likelihood penalty (≥ 0) if the max separation exceeds the threshold, or 0.0 otherwise. """ - The fast log likelihood penalty scheme returns an alternative penalty log likelihood for any model where the - image-plane positions to not trace within a threshold distance of one another in the source-plane. - - This `log_likelihood_penalty` is defined as: - - log_Likelihood_penalty_base - log_likelihood_penalty_factor * (max_source_plane_separation - threshold) - - The `log_likelihood_penalty` is only used if `max_source_plane_separation > threshold`. + tracer = analysis.tracer_via_instance_from(instance=instance) - Parameters - ---------- - dataset - The imaging or interferometer dataset from which the penalty base is computed. - """ if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1: - return + return jnp.array(0.0), positions_fit = SourceMaxSeparation( data=self.positions, @@ -228,36 +194,3 @@ def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]: lambda: penalty, lambda: jnp.array(0.0), ) - - def log_likelihood_function_positions_overwrite( - self, instance: af.ModelInstance, analysis: AnalysisDataset - ) -> Optional[float]: - """ - This is called in the `log_likelihood_function` of certain `Analysis` classes to add the penalty term of - this class, which penalies mass models which do not trace within the threshold of one another in the - source-plane. - - Parameters - ---------- - instance - The instance of the lens model that is being fitted for this iteration of the non-linear search. - analysis - The analysis class from which the log likliehood function is called. - """ - tracer = analysis.tracer_via_instance_from(instance=instance) - - if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1: - return - - log_likelihood_positions_penalty = self.log_likelihood_penalty_from( - tracer=tracer - ) - - if log_likelihood_positions_penalty is None: - return None - - log_likelihood_penalty_base = self.log_likelihood_penalty_base_from( - dataset=analysis.dataset - ) - - return log_likelihood_penalty_base - log_likelihood_positions_penalty diff --git a/autolens/imaging/model/analysis.py b/autolens/imaging/model/analysis.py index a63f30fb6..770d057e5 100644 --- a/autolens/imaging/model/analysis.py +++ b/autolens/imaging/model/analysis.py @@ -93,19 +93,14 @@ def log_likelihood_function(self, instance: af.ModelInstance) -> float: """ try: - log_likelihood_positions_overwrite = self.log_likelihood_positions_overwrite_from( + log_likelihood_penalty = self.log_likelihood_penalty_from( instance=instance ) - if log_likelihood_positions_overwrite is not None: - return log_likelihood_positions_overwrite except Exception as e: raise e - if log_likelihood_positions_overwrite is not None: - return log_likelihood_positions_overwrite - try: - return self.fit_from(instance=instance).figure_of_merit + return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty except ( PixelizationException, exc.PixelizationException, diff --git a/autolens/interferometer/model/analysis.py b/autolens/interferometer/model/analysis.py index 3fbced456..e7b062610 100644 --- a/autolens/interferometer/model/analysis.py +++ b/autolens/interferometer/model/analysis.py @@ -152,16 +152,14 @@ def log_likelihood_function(self, instance): """ try: - log_likelihood_positions_overwrite = ( - self.log_likelihood_positions_overwrite_from(instance=instance) + log_likelihood_penalty = self.log_likelihood_penalty_from( + instance=instance ) - if log_likelihood_positions_overwrite is not None: - return log_likelihood_positions_overwrite except Exception as e: raise e try: - return self.fit_from(instance=instance).figure_of_merit + return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty except ( PixelizationException, exc.PixelizationException, diff --git a/test_autolens/imaging/model/test_analysis_imaging.py b/test_autolens/imaging/model/test_analysis_imaging.py index 7453d6f59..a9a742599 100644 --- a/test_autolens/imaging/model/test_analysis_imaging.py +++ b/test_autolens/imaging/model/test_analysis_imaging.py @@ -71,17 +71,7 @@ def test__positions__likelihood_overwrites__changes_likelihood(masked_imaging_7x ) analysis_log_likelihood = analysis.log_likelihood_function(instance=instance) - log_likelihood_penalty_base = positions_likelihood.log_likelihood_penalty_base_from( - dataset=masked_imaging_7x7 - ) - log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from( - tracer=tracer - ) - - assert analysis_log_likelihood == pytest.approx( - log_likelihood_penalty_base - log_likelihood_penalty, 1.0e-4 - ) - assert analysis_log_likelihood == pytest.approx(-22048700558.9052, 1.0e-4) + assert analysis_log_likelihood == pytest.approx(44097289491.9806, 1.0e-4) def test__positions__likelihood_overwrites__changes_likelihood__double_source_plane_example(masked_imaging_7x7): @@ -106,7 +96,7 @@ def test__positions__likelihood_overwrites__changes_likelihood__double_source_pl ) analysis_log_likelihood = analysis.log_likelihood_function(instance=instance) - assert analysis_log_likelihood == pytest.approx(-44140499647.28964, 1.0e-4) + assert analysis_log_likelihood == pytest.approx(44097289491.8073, 1.0e-4) def test__profile_log_likelihood_function(masked_imaging_7x7): From 1b02b56a8d6b37fa03403e09defb9f5de06b2cab Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 16 Jun 2025 20:30:12 +0100 Subject: [PATCH 12/15] fix interferometer --- .../model/test_analysis_interferometer.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/test_autolens/interferometer/model/test_analysis_interferometer.py b/test_autolens/interferometer/model/test_analysis_interferometer.py index 5ceea7199..c36d9cfe2 100644 --- a/test_autolens/interferometer/model/test_analysis_interferometer.py +++ b/test_autolens/interferometer/model/test_analysis_interferometer.py @@ -68,17 +68,7 @@ def test__positions__likelihood_overwrite__changes_likelihood( ) analysis_log_likelihood = analysis.log_likelihood_function(instance=instance) - log_likelihood_penalty_base = positions_likelihood.log_likelihood_penalty_base_from( - dataset=interferometer_7 - ) - log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from( - tracer=tracer - ) - - assert analysis_log_likelihood == pytest.approx( - log_likelihood_penalty_base - log_likelihood_penalty, 1.0e-4 - ) - assert analysis_log_likelihood == pytest.approx(-22048700567.590656, 1.0e-4) + assert analysis_log_likelihood == pytest.approx(44097289444.30784, 1.0e-4) def test__profile_log_likelihood_function(interferometer_7): From dbb73f6e41e74dd74e26154d9571f68acca9c144 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 16 Jun 2025 20:58:20 +0100 Subject: [PATCH 13/15] full positions refactor --- autolens/analysis/positions.py | 2 +- autolens/point/fit/positions/image/pair_all.py | 2 -- autolens/point/max_separation.py | 3 +-- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index ea77f3c5b..481c6a7fc 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -183,7 +183,7 @@ def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: Anal plane_redshift=self.plane_redshift, ) - max_separation = positions_fit.max_separation_of_plane_positions + max_separation = jnp.max(positions_fit.furthest_separations_of_plane_positions.array) penalty = self.log_likelihood_penalty_factor * ( max_separation - self.threshold diff --git a/autolens/point/fit/positions/image/pair_all.py b/autolens/point/fit/positions/image/pair_all.py index 6429b333c..b76c27fb1 100644 --- a/autolens/point/fit/positions/image/pair_all.py +++ b/autolens/point/fit/positions/image/pair_all.py @@ -1,8 +1,6 @@ import jax.numpy as jnp import numpy as np -import autoarray as aa - from autolens.point.fit.positions.image.abstract import AbstractFitPositionsImagePair diff --git a/autolens/point/max_separation.py b/autolens/point/max_separation.py index f4080a8c2..f17e65de3 100644 --- a/autolens/point/max_separation.py +++ b/autolens/point/max_separation.py @@ -1,4 +1,3 @@ -import jax.numpy as jnp from typing import Optional import autoarray as aa @@ -69,7 +68,7 @@ def furthest_separations_of_plane_positions(self) -> aa.ArrayIrregular: @property def max_separation_of_plane_positions(self) -> float: - return jnp.max(self.furthest_separations_of_plane_positions.array) + return max(self.furthest_separations_of_plane_positions) def max_separation_within_threshold(self, threshold) -> bool: return self.max_separation_of_plane_positions <= threshold From df77e4a24999cdb8e5f9fb229734307b953de17f Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 16:20:04 +0100 Subject: [PATCH 14/15] fix point plot --- autolens/point/plot/fit_point_plotters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autolens/point/plot/fit_point_plotters.py b/autolens/point/plot/fit_point_plotters.py index d3b65c657..0ac2a59b7 100644 --- a/autolens/point/plot/fit_point_plotters.py +++ b/autolens/point/plot/fit_point_plotters.py @@ -144,7 +144,7 @@ def figures_2d(self, positions: bool = False, fluxes: bool = False): try: visuals_1d += visuals_1d.__class__( - model_fluxes=self.fit.flux.model_fluxes + model_fluxes=self.fit.flux.model_fluxes.array ) except AttributeError: pass From e725f0b142d8654cade6d7bf3c74b2734b7c9d74 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 16:51:48 +0100 Subject: [PATCH 15/15] fix FitPoisitonsAll --- autolens/point/fit/positions/image/pair_all.py | 2 +- autolens/point/fit/times_delays.py | 4 ++-- test_autolens/point/fit/positions/image/test_pair_all.py | 6 ++++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/autolens/point/fit/positions/image/pair_all.py b/autolens/point/fit/positions/image/pair_all.py index b76c27fb1..a21fd39b9 100644 --- a/autolens/point/fit/positions/image/pair_all.py +++ b/autolens/point/fit/positions/image/pair_all.py @@ -142,7 +142,7 @@ def chi_squared(self) -> float: This is every way in which the coordinates generated by the model can explain the observed coordinates. """ n_non_nan_model_positions = jnp.count_nonzero( - ~jnp.isnan( + jnp.isfinite( self.model_data.array, ).any(axis=1) ) diff --git a/autolens/point/fit/times_delays.py b/autolens/point/fit/times_delays.py index df576510d..23a2cae55 100644 --- a/autolens/point/fit/times_delays.py +++ b/autolens/point/fit/times_delays.py @@ -90,8 +90,8 @@ def residual_map(self) -> aa.ArrayIrregular: from the dataset time delays and model time delays before the subtraction. """ - data = self.data - jnp.min(self.data) - model_data = self.model_data - jnp.min(self.model_data) + data = self.data - jnp.min(self.data.array) + model_data = self.model_data - jnp.min(self.model_data.array) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) return aa.ArrayIrregular(values=residual_map) diff --git a/test_autolens/point/fit/positions/image/test_pair_all.py b/test_autolens/point/fit/positions/image/test_pair_all.py index cefad8282..c5bbb2b55 100644 --- a/test_autolens/point/fit/positions/image/test_pair_all.py +++ b/test_autolens/point/fit/positions/image/test_pair_all.py @@ -59,7 +59,7 @@ def test_andrew_implementation(fit): # assert jax.jit(fit.log_likelihood)() == -4.40375330990644 -def test_nan_model_positions( +def test_inf_model_positions( data, noise_map, ): @@ -67,7 +67,7 @@ def test_nan_model_positions( [ (-1.0749, -1.1), (1.19117, 1.175), - (np.nan, np.nan), + (np.inf, np.inf), ] ) fit = al.FitPositionsImagePairAll( @@ -78,6 +78,8 @@ def test_nan_model_positions( solver=al.mock.MockPointSolver(model_positions), ) + print(fit.all_permutations_log_likelihoods()) + assert np.allclose( fit.all_permutations_log_likelihoods(), [