Skip to content
20 changes: 10 additions & 10 deletions autolens/analysis/analysis/lens.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import logging
import numpy as np
from typing import Dict, List, Optional, Union
Expand Down Expand Up @@ -98,7 +99,7 @@ def tracer_via_instance_from(
run_time_dict=run_time_dict,
)

def log_likelihood_positions_overwrite_from(
def log_likelihood_penalty_from(
self, instance: af.ModelInstance
) -> Optional[float]:
"""
Expand All @@ -120,21 +121,20 @@ def log_likelihood_positions_overwrite_from(
The penalty value of the positions log likelihood, if the positions do not trace close in the source plane,
else None is returned to indicate there is no penalty.
"""
if self.positions_likelihood_list is not None:
log_likelihood_penalty = jnp.array(0.0)

log_likelihood_overwrite = None
if self.positions_likelihood_list is not None:

try:
for positions_likelihood in self.positions_likelihood_list:
log_likelihood_with_penalty = positions_likelihood.log_likelihood_function_positions_overwrite(
log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from(
instance=instance, analysis=self
)
if log_likelihood_with_penalty is not None:
try:
log_likelihood_overwrite += log_likelihood_with_penalty
except TypeError:
log_likelihood_overwrite = log_likelihood_with_penalty

return log_likelihood_overwrite
log_likelihood_penalty += log_likelihood_penalty

return log_likelihood_penalty
except (ValueError, np.linalg.LinAlgError) as e:
raise exc.FitException from e

return log_likelihood_penalty
132 changes: 37 additions & 95 deletions autolens/analysis/positions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import jax
import jax.numpy as jnp
import numpy as np
from typing import Optional, Union
from os import path
Expand Down Expand Up @@ -71,6 +73,7 @@ def __init__(
The plane redshift of the lensed source multiple images, which is only required if position threshold
for a double source plane lens system is being used where the specific plane is required.
"""

self.positions = positions
self.threshold = threshold
self.plane_redshift = plane_redshift
Expand Down Expand Up @@ -133,79 +136,45 @@ def output_positions_info(
)
f.write("")

def log_likelihood_penalty_base_from(
self, dataset: Union[aa.Imaging, aa.Interferometer]
) -> float:
def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: AnalysisDataset) -> jnp.array:
"""
The fast log likelihood penalty scheme returns an alternative penalty log likelihood for any model where the
image-plane positions do not trace within a threshold distance of one another in the source-plane.
Returns a log-likelihood penalty used to constrain lens models where multiple image-plane
positions do not trace to within a threshold distance of one another in the source-plane.

This penalty is intended for use in `Analysis` classes that include the `PenaltyLH` mixin. It adds a
heavy penalty to the likelihood when the multiple images trace far apart in the source-plane, discouraging
models where the mapped source-plane positions are too widely separated.

This `log_likelihood_penalty` is defined as:
Specifically, if the maximum separation between traced positions in the source-plane exceeds
a defined threshold, a penalty term is applied to the log likelihood:

log_Likelihood_penalty_base - log_likelihood_penalty_factor * (max_source_plane_separation - threshold)
penalty = log_likelihood_penalty_factor * (max_separation - threshold)

The `log_likelihood_penalty` is only used if `max_source_plane_separation > threshold`.
If the separation is within the threshold, no penalty is applied.

This function returns the `log_likelihood_penalty_base`, which represents the lowest possible likelihood
solutions a model-fit can give. It is the chi-squared of model-data consisting of all zeros plus
the noise normalization term.
JAX Compatibility
-----------------
Because this function may be jitted or differentiated using JAX, it uses `jax.lax.cond` to apply
conditional logic in a way that is compatible with JAX's functional and tracing model.
Both branches (penalty and zero) are traced when the function is compiled, though only one is executed
at runtime depending on the condition.

Parameters
----------
dataset
The imaging or interferometer dataset from which the penalty base is computed.
"""

residual_map = aa.util.fit.residual_map_from(
data=dataset.data, model_data=np.zeros(dataset.data.shape)
)

if isinstance(dataset, aa.Imaging):
chi_squared_map = aa.util.fit.chi_squared_map_from(
residual_map=residual_map, noise_map=dataset.noise_map
)

chi_squared = aa.util.fit.chi_squared_from(
chi_squared_map=chi_squared_map.array
)

noise_normalization = aa.util.fit.noise_normalization_from(
noise_map=dataset.noise_map.array
)

else:
chi_squared_map = aa.util.fit.chi_squared_map_complex_from(
residual_map=residual_map, noise_map=dataset.noise_map
)

chi_squared = aa.util.fit.chi_squared_complex_from(
chi_squared_map=chi_squared_map
)

noise_normalization = aa.util.fit.noise_normalization_complex_from(
noise_map=dataset.noise_map
)

return -0.5 * (chi_squared + noise_normalization)
instance
The current model instance evaluated during the non-linear search.
analysis
The `Analysis` object calling this function, from which the `tracer` and `dataset` are derived.

def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]:
Returns
-------
penalty
A scalar log-likelihood penalty (≥ 0) if the max separation exceeds the threshold, or 0.0 otherwise.
"""
The fast log likelihood penalty scheme returns an alternative penalty log likelihood for any model where the
image-plane positions do not trace within a threshold distance of one another in the source-plane.

This `log_likelihood_penalty` is defined as:

log_Likelihood_penalty_base - log_likelihood_penalty_factor * (max_source_plane_separation - threshold)

The `log_likelihood_penalty` is only used if `max_source_plane_separation > threshold`.
tracer = analysis.tracer_via_instance_from(instance=instance)

Parameters
----------
dataset
The imaging or interferometer dataset from which the penalty base is computed.
"""
if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1:
return
return jnp.array(0.0),

positions_fit = SourceMaxSeparation(
data=self.positions,
Expand All @@ -214,41 +183,14 @@ def log_likelihood_penalty_from(self, tracer: Tracer) -> Optional[float]:
plane_redshift=self.plane_redshift,
)

if not positions_fit.max_separation_within_threshold(self.threshold):
max_separation = jnp.max(positions_fit.furthest_separations_of_plane_positions.array)

return self.log_likelihood_penalty_factor * (
positions_fit.max_separation_of_plane_positions - self.threshold
penalty = self.log_likelihood_penalty_factor * (
max_separation - self.threshold
)

def log_likelihood_function_positions_overwrite(
self, instance: af.ModelInstance, analysis: AnalysisDataset
) -> Optional[float]:
"""
This is called in the `log_likelihood_function` of certain `Analysis` classes to add the penalty term of
this class, which penalizes mass models which do not trace within the threshold of one another in the
source-plane.

Parameters
----------
instance
The instance of the lens model that is being fitted for this iteration of the non-linear search.
analysis
The analysis class from which the log likelihood function is called.
"""
tracer = analysis.tracer_via_instance_from(instance=instance)

if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1:
return

log_likelihood_positions_penalty = self.log_likelihood_penalty_from(
tracer=tracer
return jax.lax.cond(
max_separation > self.threshold,
lambda: penalty,
lambda: jnp.array(0.0),
)

if log_likelihood_positions_penalty is None:
return None

log_likelihood_penalty_base = self.log_likelihood_penalty_base_from(
dataset=analysis.dataset
)

return log_likelihood_penalty_base - log_likelihood_positions_penalty
6 changes: 6 additions & 0 deletions autolens/analysis/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,12 @@ def positions_likelihood_from(

positions = positions[distances > mass_centre_radial_distance_min]

mask = np.isfinite(positions.array).all(axis=1)

positions = aa.Grid2DIrregular(
positions[mask]
)

threshold = self.positions_threshold_from(
factor=factor,
minimum_threshold=minimum_threshold,
Expand Down
9 changes: 2 additions & 7 deletions autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,14 @@ def log_likelihood_function(self, instance: af.ModelInstance) -> float:
"""

try:
log_likelihood_positions_overwrite = self.log_likelihood_positions_overwrite_from(
log_likelihood_penalty = self.log_likelihood_penalty_from(
instance=instance
)
if log_likelihood_positions_overwrite is not None:
return log_likelihood_positions_overwrite
except Exception as e:
raise e

if log_likelihood_positions_overwrite is not None:
return log_likelihood_positions_overwrite

try:
return self.fit_from(instance=instance).figure_of_merit
return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty
except (
PixelizationException,
exc.PixelizationException,
Expand Down
8 changes: 3 additions & 5 deletions autolens/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,16 +152,14 @@ def log_likelihood_function(self, instance):
"""

try:
log_likelihood_positions_overwrite = (
self.log_likelihood_positions_overwrite_from(instance=instance)
log_likelihood_penalty = self.log_likelihood_penalty_from(
instance=instance
)
if log_likelihood_positions_overwrite is not None:
return log_likelihood_positions_overwrite
except Exception as e:
raise e

try:
return self.fit_from(instance=instance).figure_of_merit
return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty
except (
PixelizationException,
exc.PixelizationException,
Expand Down
5 changes: 3 additions & 2 deletions autolens/point/fit/fluxes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
from typing import Optional

import autoarray as aa
Expand Down Expand Up @@ -101,10 +102,10 @@ def model_data(self):
are used.
"""
return aa.ArrayIrregular(
values=[
values=jnp.array([
magnification * self.profile.flux
for magnification in self.magnifications_at_positions
]
])
)

@property
Expand Down
3 changes: 3 additions & 0 deletions autolens/point/fit/positions/image/pair.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import jax.numpy as jnp
import numpy as np
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn
from scipy.optimize import linear_sum_assignment

import autoarray as aa
Expand Down
31 changes: 16 additions & 15 deletions autolens/point/fit/positions/image/pair_all.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import jax.numpy as jnp
import numpy as np

import autoarray as aa

from autolens.point.fit.positions.image.abstract import AbstractFitPositionsImagePair


Expand Down Expand Up @@ -85,7 +84,7 @@ def log_p(
The log probability of the model coordinate explaining the observed coordinate.
"""
chi2 = self.square_distance(data_position, model_position) / sigma**2
return -np.log(np.sqrt(2 * np.pi * sigma**2)) - 0.5 * chi2
return -jnp.log(jnp.sqrt(2 * jnp.pi * sigma**2)) - 0.5 * chi2

def all_permutations_log_likelihoods(self) -> np.ndarray:
"""
Expand All @@ -101,21 +100,23 @@ def all_permutations_log_likelihoods(self) -> np.ndarray:

This is every way in which the coordinates generated by the model can explain the observed coordinates.
"""
return np.array(

model_data = self.model_data.array

return jnp.array(
[
np.log(
np.sum(
[
np.exp(
jnp.log(
jnp.sum(
jnp.array([
jnp.exp(
self.log_p(
data_position,
model_position,
sigma,
)
)
for model_position in self.model_data
if not np.isnan(model_position).any()
]
for model_position in model_data
])
)
)
for data_position, sigma in zip(self.data, self.noise_map)
Expand All @@ -140,12 +141,12 @@ def chi_squared(self) -> float:

This is every way in which the coordinates generated by the model can explain the observed coordinates.
"""
n_non_nan_model_positions = np.count_nonzero(
~np.isnan(
self.model_data,
n_non_nan_model_positions = jnp.count_nonzero(
jnp.isfinite(
self.model_data.array,
).any(axis=1)
)
n_permutations = n_non_nan_model_positions ** len(self.data)
return -2.0 * (
-np.log(n_permutations) + np.sum(self.all_permutations_log_likelihoods())
-jnp.log(n_permutations) + jnp.sum(self.all_permutations_log_likelihoods())
)
7 changes: 3 additions & 4 deletions autolens/point/fit/positions/image/pair_repeat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np

import jax.numpy as jnp
import autoarray as aa

from autolens.point.fit.positions.image.abstract import AbstractFitPositionsImagePair
Expand Down Expand Up @@ -63,6 +62,6 @@ def residual_map(self) -> aa.ArrayIrregular:
self.square_distance(model_position, position)
for model_position in self.model_data
]
residual_map.append(np.sqrt(min(distances)))
residual_map.append(jnp.sqrt(jnp.min(jnp.array(distances))))

return aa.ArrayIrregular(values=residual_map)
return aa.ArrayIrregular(values=jnp.array(residual_map))
5 changes: 3 additions & 2 deletions autolens/point/fit/times_delays.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax.numpy as jnp
import numpy as np
from typing import Optional

Expand Down Expand Up @@ -89,8 +90,8 @@ def residual_map(self) -> aa.ArrayIrregular:
from the dataset time delays and model time delays before the subtraction.
"""

data = self.data - np.min(self.data)
model_data = self.model_data - np.min(self.model_data)
data = self.data - jnp.min(self.data.array)
model_data = self.model_data - jnp.min(self.model_data.array)

residual_map = aa.util.fit.residual_map_from(data=data, model_data=model_data)
return aa.ArrayIrregular(values=residual_map)
Expand Down
Loading
Loading