diff --git a/autolens/analysis/analysis/lens.py b/autolens/analysis/analysis/lens.py index bcb1fc60e..840907c05 100644 --- a/autolens/analysis/analysis/lens.py +++ b/autolens/analysis/analysis/lens.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import logging import numpy as np from typing import Dict, List, Optional, Union @@ -98,7 +99,7 @@ def tracer_via_instance_from( run_time_dict=run_time_dict, ) - def log_likelihood_positions_overwrite_from( + def log_likelihood_penalty_from( self, instance: af.ModelInstance ) -> Optional[float]: """ @@ -120,21 +121,20 @@ def log_likelihood_positions_overwrite_from( The penalty value of the positions log likelihood, if the positions do not trace close in the source plane, else a None is returned to indicate there is no penalty. """ - if self.positions_likelihood_list is not None: + log_likelihood_penalty = jnp.array(0.0) - log_likelihood_overwrite = None + if self.positions_likelihood_list is not None: try: for positions_likelihood in self.positions_likelihood_list: - log_likelihood_with_penalty = positions_likelihood.log_likelihood_function_positions_overwrite( + log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from( instance=instance, analysis=self ) - if log_likelihood_with_penalty is not None: - try: - log_likelihood_overwrite += log_likelihood_with_penalty - except TypeError: - log_likelihood_overwrite = log_likelihood_with_penalty - return log_likelihood_overwrite + log_likelihood_penalty += log_likelihood_penalty + + return log_likelihood_penalty except (ValueError, np.linalg.LinAlgError) as e: raise exc.FitException from e + + return log_likelihood_penalty \ No newline at end of file diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index f467088b1..481c6a7fc 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -1,3 +1,5 @@ +import jax +import jax.numpy as jnp import numpy as np from typing import Optional, Union from os import path @@ -71,6 +73,7 @@ def __init__( The plane redshift of the lensed source multiple images, which is only required if position threshold for a double source plane lens system is being used where the specific plane is required. """ + self.positions = positions self.threshold = threshold self.plane_redshift = plane_redshift @@ -133,79 +136,45 @@ def output_positions_info( ) f.write("") - def log_likelihood_penalty_base_from( - self, dataset: Union[aa.Imaging, aa.Interferometer] - ) -> float: + def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: AnalysisDataset) -> jnp.array: """ - The fast log likelihood penalty scheme returns an alternative penalty log likelihood for any model where the - image-plane positions do not trace within a threshold distance of one another in the source-plane. + Returns a log-likelihood penalty used to constrain lens models where multiple image-plane + positions do not trace to within a threshold distance of one another in the source-plane. + + This penalty is intended for use in `Analysis` classes that include the `PenaltyLH` mixin. It adds a + heavy penalty to the likelihood when the multiple images traces far apart in the source-plane, discouraging + models where the mapped source-plane positions are too widely separated. - This `log_likelihood_penalty` is defined as: + Specifically, if the maximum separation between traced positions in the source-plane exceeds + a defined threshold, a penalty term is applied to the log likelihood: - log_Likelihood_penalty_base - log_likelihood_penalty_factor * (max_source_plane_separation - threshold) + penalty = log_likelihood_penalty_factor * (max_separation - threshold) - The `log_likelihood_penalty` is only used if `max_source_plane_separation > threshold`. + If the separation is within the threshold, no penalty is applied. - This function returns the `log_likelihood_penalty_base`, which represents the lowest possible likelihood - solutions a model-fit can give. It is the chi-squared of model-data consisting of all zeros plus - the noise normalziation term. + JAX Compatibility + ----------------- + Because this function may be jitted or differentiated using JAX, it uses `jax.lax.cond` to apply + conditional logic in a way that is compatible with JAX's functional and tracing model. + Both branches (penalty and zero) are evaluated at trace time, though only one is returned + at runtime depending on the condition. Parameters ---------- - dataset - The imaging or interferometer dataset from which the penalty base is computed. - """ - - residual_map = aa.util.fit.residual_map_from( - data=dataset.data, model_data=np.zeros(dataset.data.shape) - ) - - if isinstance(dataset, aa.Imaging): - chi_squared_map = aa.util.fit.chi_squared_map_from( - residual_map=residual_map, noise_map=dataset.noise_map - ) - - chi_squared = aa.util.fit.chi_squared_from( - chi_squared_map=chi_squared_map.array - ) - - noise_normalization = aa.util.fit.noise_normalization_from( - noise_map=dataset.noise_map.array - ) - - else: - chi_squared_map = aa.util.fit.chi_squared_map_complex_from( - residual_map=residual_map, noise_map=dataset.noise_map - ) - - chi_squared = aa.util.fit.chi_squared_complex_from( - chi_squared_map=chi_squared_map - ) - - noise_normalization = aa.util.fit.noise_normalization_complex_from( - noise_map=dataset.noise_map - ) - - return -0.5 * (chi_squared + noise_normalization) + instance + The current model instance evaluated during the non-linear search. + analysis + The `Analysis` object calling this function, from which the `tracer` and `dataset` are derived. - def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]: + Returns + ------- + penalty + A scalar log-likelihood penalty (≥ 0) if the max separation exceeds the threshold, or 0.0 otherwise. """ - The fast log likelihood penalty scheme returns an alternative penalty log likelihood for any model where the - image-plane positions to not trace within a threshold distance of one another in the source-plane. - - This `log_likelihood_penalty` is defined as: - - log_Likelihood_penalty_base - log_likelihood_penalty_factor * (max_source_plane_separation - threshold) - - The `log_likelihood_penalty` is only used if `max_source_plane_separation > threshold`. + tracer = analysis.tracer_via_instance_from(instance=instance) - Parameters - ---------- - dataset - The imaging or interferometer dataset from which the penalty base is computed. - """ if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1: - return + return jnp.array(0.0), positions_fit = SourceMaxSeparation( data=self.positions, @@ -214,41 +183,14 @@ def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]: plane_redshift=self.plane_redshift, ) - if not positions_fit.max_separation_within_threshold(self.threshold): + max_separation = jnp.max(positions_fit.furthest_separations_of_plane_positions.array) - return self.log_likelihood_penalty_factor * ( - positions_fit.max_separation_of_plane_positions - self.threshold + penalty = self.log_likelihood_penalty_factor * ( + max_separation - self.threshold ) - def log_likelihood_function_positions_overwrite( - self, instance: af.ModelInstance, analysis: AnalysisDataset - ) -> Optional[float]: - """ - This is called in the `log_likelihood_function` of certain `Analysis` classes to add the penalty term of - this class, which penalies mass models which do not trace within the threshold of one another in the - source-plane. - - Parameters - ---------- - instance - The instance of the lens model that is being fitted for this iteration of the non-linear search. - analysis - The analysis class from which the log likliehood function is called. - """ - tracer = analysis.tracer_via_instance_from(instance=instance) - - if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1: - return - - log_likelihood_positions_penalty = self.log_likelihood_penalty_from( - tracer=tracer + return jax.lax.cond( + max_separation > self.threshold, + lambda: penalty, + lambda: jnp.array(0.0), ) - - if log_likelihood_positions_penalty is None: - return None - - log_likelihood_penalty_base = self.log_likelihood_penalty_base_from( - dataset=analysis.dataset - ) - - return log_likelihood_penalty_base - log_likelihood_positions_penalty diff --git a/autolens/analysis/result.py b/autolens/analysis/result.py index cc07db977..a1c15b975 100644 --- a/autolens/analysis/result.py +++ b/autolens/analysis/result.py @@ -308,6 +308,12 @@ def positions_likelihood_from( positions = positions[distances > mass_centre_radial_distance_min] + mask = np.isfinite(positions.array).all(axis=1) + + positions = aa.Grid2DIrregular( + positions[mask] + ) + threshold = self.positions_threshold_from( factor=factor, minimum_threshold=minimum_threshold, diff --git a/autolens/imaging/model/analysis.py b/autolens/imaging/model/analysis.py index a63f30fb6..770d057e5 100644 --- a/autolens/imaging/model/analysis.py +++ b/autolens/imaging/model/analysis.py @@ -93,19 +93,14 @@ def log_likelihood_function(self, instance: af.ModelInstance) -> float: """ try: - log_likelihood_positions_overwrite = self.log_likelihood_positions_overwrite_from( + log_likelihood_penalty = self.log_likelihood_penalty_from( instance=instance ) - if log_likelihood_positions_overwrite is not None: - return log_likelihood_positions_overwrite except Exception as e: raise e - if log_likelihood_positions_overwrite is not None: - return log_likelihood_positions_overwrite - try: - return self.fit_from(instance=instance).figure_of_merit + return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty except ( PixelizationException, exc.PixelizationException, diff --git a/autolens/interferometer/model/analysis.py b/autolens/interferometer/model/analysis.py index 3fbced456..e7b062610 100644 --- a/autolens/interferometer/model/analysis.py +++ b/autolens/interferometer/model/analysis.py @@ -152,16 +152,14 @@ def log_likelihood_function(self, instance): """ try: - log_likelihood_positions_overwrite = ( - self.log_likelihood_positions_overwrite_from(instance=instance) + log_likelihood_penalty = self.log_likelihood_penalty_from( + instance=instance ) - if log_likelihood_positions_overwrite is not None: - return log_likelihood_positions_overwrite except Exception as e: raise e try: - return self.fit_from(instance=instance).figure_of_merit + return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty except ( PixelizationException, exc.PixelizationException, diff --git a/autolens/point/fit/fluxes.py b/autolens/point/fit/fluxes.py index aef37118f..a34c4e56c 100644 --- a/autolens/point/fit/fluxes.py +++ b/autolens/point/fit/fluxes.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp from typing import Optional import autoarray as aa @@ -101,10 +102,10 @@ def model_data(self): are used. """ return aa.ArrayIrregular( - values=[ + values=jnp.array([ magnification * self.profile.flux for magnification in self.magnifications_at_positions - ] + ]) ) @property diff --git a/autolens/point/fit/positions/image/pair.py b/autolens/point/fit/positions/image/pair.py index e7af82c2b..ab2cf99b4 100644 --- a/autolens/point/fit/positions/image/pair.py +++ b/autolens/point/fit/positions/image/pair.py @@ -1,4 +1,7 @@ +import jax.numpy as jnp import numpy as np +from ott.geometry import pointcloud +from ott.solvers.linear import sinkhorn from scipy.optimize import linear_sum_assignment import autoarray as aa diff --git a/autolens/point/fit/positions/image/pair_all.py b/autolens/point/fit/positions/image/pair_all.py index 00bd4a454..a21fd39b9 100644 --- a/autolens/point/fit/positions/image/pair_all.py +++ b/autolens/point/fit/positions/image/pair_all.py @@ -1,7 +1,6 @@ +import jax.numpy as jnp import numpy as np -import autoarray as aa - from autolens.point.fit.positions.image.abstract import AbstractFitPositionsImagePair @@ -85,7 +84,7 @@ def log_p( The log probability of the model coordinate explaining the observed coordinate. """ chi2 = self.square_distance(data_position, model_position) / sigma**2 - return -np.log(np.sqrt(2 * np.pi * sigma**2)) - 0.5 * chi2 + return -jnp.log(jnp.sqrt(2 * jnp.pi * sigma**2)) - 0.5 * chi2 def all_permutations_log_likelihoods(self) -> np.ndarray: """ @@ -101,21 +100,23 @@ def all_permutations_log_likelihoods(self) -> np.ndarray: This is every way in which the coordinates generated by the model can explain the observed coordinates. """ - return np.array( + + model_data = self.model_data.array + + return jnp.array( [ - np.log( - np.sum( - [ - np.exp( + jnp.log( + jnp.sum( + jnp.array([ + jnp.exp( self.log_p( data_position, model_position, sigma, ) ) - for model_position in self.model_data - if not np.isnan(model_position).any() - ] + for model_position in model_data + ]) ) ) for data_position, sigma in zip(self.data, self.noise_map) @@ -140,12 +141,12 @@ def chi_squared(self) -> float: This is every way in which the coordinates generated by the model can explain the observed coordinates. """ - n_non_nan_model_positions = np.count_nonzero( - ~np.isnan( - self.model_data, + n_non_nan_model_positions = jnp.count_nonzero( + jnp.isfinite( + self.model_data.array, ).any(axis=1) ) n_permutations = n_non_nan_model_positions ** len(self.data) return -2.0 * ( - -np.log(n_permutations) + np.sum(self.all_permutations_log_likelihoods()) + -jnp.log(n_permutations) + jnp.sum(self.all_permutations_log_likelihoods()) ) diff --git a/autolens/point/fit/positions/image/pair_repeat.py b/autolens/point/fit/positions/image/pair_repeat.py index 9d6da6b99..f3bc1dac1 100644 --- a/autolens/point/fit/positions/image/pair_repeat.py +++ b/autolens/point/fit/positions/image/pair_repeat.py @@ -1,5 +1,4 @@ -import numpy as np - +import jax.numpy as jnp import autoarray as aa from autolens.point.fit.positions.image.abstract import AbstractFitPositionsImagePair @@ -63,6 +62,6 @@ def residual_map(self) -> aa.ArrayIrregular: self.square_distance(model_position, position) for model_position in self.model_data ] - residual_map.append(np.sqrt(min(distances))) + residual_map.append(jnp.sqrt(jnp.min(jnp.array(distances)))) - return aa.ArrayIrregular(values=residual_map) + return aa.ArrayIrregular(values=jnp.array(residual_map)) diff --git a/autolens/point/fit/times_delays.py b/autolens/point/fit/times_delays.py index 6d8699d36..23a2cae55 100644 --- a/autolens/point/fit/times_delays.py +++ b/autolens/point/fit/times_delays.py @@ -1,3 +1,4 @@ +import jax.numpy as jnp import numpy as np from typing import Optional @@ -89,8 +90,8 @@ def residual_map(self) -> aa.ArrayIrregular: from the dataset time delays and model time delays before the subtraction. """ - data = self.data - np.min(self.data) - model_data = self.model_data - np.min(self.model_data) + data = self.data - jnp.min(self.data.array) + model_data = self.model_data - jnp.min(self.model_data.array) residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data) return aa.ArrayIrregular(values=residual_map) diff --git a/autolens/point/model/analysis.py b/autolens/point/model/analysis.py index aa4e56cca..820074b07 100644 --- a/autolens/point/model/analysis.py +++ b/autolens/point/model/analysis.py @@ -125,6 +125,8 @@ def log_likelihood_function(self, instance): fit = self.fit_from(instance=instance) return fit.log_likelihood except (AttributeError, ValueError, TypeError, NumbaException) as e: + print(e) + dfdsfd raise exc.FitException from e def fit_from( diff --git a/autolens/point/plot/fit_point_plotters.py b/autolens/point/plot/fit_point_plotters.py index d3b65c657..0ac2a59b7 100644 --- a/autolens/point/plot/fit_point_plotters.py +++ b/autolens/point/plot/fit_point_plotters.py @@ -144,7 +144,7 @@ def figures_2d(self, positions: bool = False, fluxes: bool = False): try: visuals_1d += visuals_1d.__class__( - model_fluxes=self.fit.flux.model_fluxes + model_fluxes=self.fit.flux.model_fluxes.array ) except AttributeError: pass diff --git a/autolens/point/solver/point_solver.py b/autolens/point/solver/point_solver.py index 300a6f3fe..4caee637c 100644 --- a/autolens/point/solver/point_solver.py +++ b/autolens/point/solver/point_solver.py @@ -56,7 +56,10 @@ def solve( tracer=tracer, points=kept_triangles.means ) - arr = aa.Grid2DIrregular([pair for pair in filtered_means]) + solution = aa.Grid2DIrregular([pair for pair in filtered_means]).array - mask = ~jnp.isnan(arr.array).any(axis=1) - return aa.Grid2DIrregular(arr.array[mask]) + is_nan = jnp.isnan(solution).any(axis=1) + sentinel = jnp.full_like(solution[0], fill_value=jnp.inf) + solution = jnp.where(is_nan[:, None], sentinel, solution) + + return aa.Grid2DIrregular(solution) \ No newline at end of file diff --git a/test_autolens/imaging/model/test_analysis_imaging.py b/test_autolens/imaging/model/test_analysis_imaging.py index 7453d6f59..a9a742599 100644 --- a/test_autolens/imaging/model/test_analysis_imaging.py +++ b/test_autolens/imaging/model/test_analysis_imaging.py @@ -71,17 +71,7 @@ def test__positions__likelihood_overwrites__changes_likelihood(masked_imaging_7x ) analysis_log_likelihood = analysis.log_likelihood_function(instance=instance) - log_likelihood_penalty_base = positions_likelihood.log_likelihood_penalty_base_from( - dataset=masked_imaging_7x7 - ) - log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from( - tracer=tracer - ) - - assert analysis_log_likelihood == pytest.approx( - log_likelihood_penalty_base - log_likelihood_penalty, 1.0e-4 - ) - assert analysis_log_likelihood == pytest.approx(-22048700558.9052, 1.0e-4) + assert analysis_log_likelihood == pytest.approx(44097289491.9806, 1.0e-4) def test__positions__likelihood_overwrites__changes_likelihood__double_source_plane_example(masked_imaging_7x7): @@ -106,7 +96,7 @@ def test__positions__likelihood_overwrites__changes_likelihood__double_source_pl ) analysis_log_likelihood = analysis.log_likelihood_function(instance=instance) - assert analysis_log_likelihood == pytest.approx(-44140499647.28964, 1.0e-4) + assert analysis_log_likelihood == pytest.approx(44097289491.8073, 1.0e-4) def test__profile_log_likelihood_function(masked_imaging_7x7): diff --git a/test_autolens/interferometer/model/test_analysis_interferometer.py b/test_autolens/interferometer/model/test_analysis_interferometer.py index 5ceea7199..c36d9cfe2 100644 --- a/test_autolens/interferometer/model/test_analysis_interferometer.py +++ b/test_autolens/interferometer/model/test_analysis_interferometer.py @@ -68,17 +68,7 @@ def test__positions__likelihood_overwrite__changes_likelihood( ) analysis_log_likelihood = analysis.log_likelihood_function(instance=instance) - log_likelihood_penalty_base = positions_likelihood.log_likelihood_penalty_base_from( - dataset=interferometer_7 - ) - log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from( - tracer=tracer - ) - - assert analysis_log_likelihood == pytest.approx( - log_likelihood_penalty_base - log_likelihood_penalty, 1.0e-4 - ) - assert analysis_log_likelihood == pytest.approx(-22048700567.590656, 1.0e-4) + assert analysis_log_likelihood == pytest.approx(44097289444.30784, 1.0e-4) def test__profile_log_likelihood_function(interferometer_7): diff --git a/test_autolens/lens/test_operate.py b/test_autolens/lens/test_operate.py index 0e1f6b32e..d5f7e09fb 100644 --- a/test_autolens/lens/test_operate.py +++ b/test_autolens/lens/test_operate.py @@ -122,10 +122,10 @@ def test__operate_image__galaxy_blurred_image_2d_dict_from( blurring_grid=blurring_grid_2d_7x7, ) - assert blurred_image_dict[g0].slim == pytest.approx(g0_blurred_image.slim, 1.0e-4) - assert blurred_image_dict[g1].slim == pytest.approx(g1_blurred_image.slim, 1.0e-4) - assert blurred_image_dict[g2].slim == pytest.approx(g2_blurred_image.slim, 1.0e-4) - assert blurred_image_dict[g3].slim == pytest.approx(g3_blurred_image.slim, 1.0e-4) + assert blurred_image_dict[g0].slim == pytest.approx(g0_blurred_image.slim.array, 1.0e-4) + assert blurred_image_dict[g1].slim == pytest.approx(g1_blurred_image.slim.array, 1.0e-4) + assert blurred_image_dict[g2].slim == pytest.approx(g2_blurred_image.slim.array, 1.0e-4) + assert blurred_image_dict[g3].slim == pytest.approx(g3_blurred_image.slim.array, 1.0e-4) def test__operate_image__galaxy_visibilities_dict_from_grid_and_transformer( diff --git a/test_autolens/point/fit/positions/image/test_pair_all.py b/test_autolens/point/fit/positions/image/test_pair_all.py index 838baef7f..c5bbb2b55 100644 --- a/test_autolens/point/fit/positions/image/test_pair_all.py +++ b/test_autolens/point/fit/positions/image/test_pair_all.py @@ -27,7 +27,7 @@ def noise_map(): @pytest.fixture def fit(data, noise_map): - model_positions = np.array( + model_positions = al.Grid2DIrregular( [ (-1.0749, -1.1), (1.19117, 1.175), @@ -59,15 +59,15 @@ def test_andrew_implementation(fit): # assert jax.jit(fit.log_likelihood)() == -4.40375330990644 -def test_nan_model_positions( +def test_inf_model_positions( data, noise_map, ): - model_positions = np.array( + model_positions = al.Grid2DIrregular( [ (-1.0749, -1.1), (1.19117, 1.175), - (np.nan, np.nan), + (np.inf, np.inf), ] ) fit = al.FitPositionsImagePairAll( @@ -78,6 +78,8 @@ def test_nan_model_positions( solver=al.mock.MockPointSolver(model_positions), ) + print(fit.all_permutations_log_likelihoods()) + assert np.allclose( fit.all_permutations_log_likelihoods(), [ @@ -92,7 +94,7 @@ def test_duplicate_model_position( data, noise_map, ): - model_positions = np.array( + model_positions = al.Grid2DIrregular( [ (-1.0749, -1.1), (1.19117, 1.175),