Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions autolens/analysis/analysis/lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ def log_likelihood_penalty_from(

try:
for positions_likelihood in self.positions_likelihood_list:
log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from(
instance=instance, analysis=self
log_likelihood_penalty = (
positions_likelihood.log_likelihood_penalty_from(
instance=instance, analysis=self
)
)

log_likelihood_penalty += log_likelihood_penalty
Expand All @@ -137,4 +139,4 @@ def log_likelihood_penalty_from(
except (ValueError, np.linalg.LinAlgError) as e:
raise exc.FitException from e

return log_likelihood_penalty
return log_likelihood_penalty
14 changes: 8 additions & 6 deletions autolens/analysis/positions.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def output_positions_info(
)
f.write("")

def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: AnalysisDataset) -> jnp.array:
def log_likelihood_penalty_from(
self, instance: af.ModelInstance, analysis: AnalysisDataset
) -> jnp.array:
"""
Returns a log-likelihood penalty used to constrain lens models where multiple image-plane
positions do not trace to within a threshold distance of one another in the source-plane.
Expand Down Expand Up @@ -174,7 +176,7 @@ def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: Anal
tracer = analysis.tracer_via_instance_from(instance=instance)

if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1:
return jnp.array(0.0),
return (jnp.array(0.0),)

positions_fit = SourceMaxSeparation(
data=self.positions,
Expand All @@ -183,11 +185,11 @@ def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: Anal
plane_redshift=self.plane_redshift,
)

max_separation = jnp.max(positions_fit.furthest_separations_of_plane_positions.array)
max_separation = jnp.max(
positions_fit.furthest_separations_of_plane_positions.array
)

penalty = self.log_likelihood_penalty_factor * (
max_separation - self.threshold
)
penalty = self.log_likelihood_penalty_factor * (max_separation - self.threshold)

return jax.lax.cond(
max_separation > self.threshold,
Expand Down
4 changes: 1 addition & 3 deletions autolens/analysis/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,7 @@ def positions_likelihood_from(

mask = np.isfinite(positions.array).all(axis=1)

positions = aa.Grid2DIrregular(
positions[mask]
)
positions = aa.Grid2DIrregular(positions[mask])

threshold = self.positions_threshold_from(
factor=factor,
Expand Down
11 changes: 7 additions & 4 deletions autolens/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,15 @@ def log_likelihood_function(self, instance):
"""

try:
log_likelihood_penalty = self.log_likelihood_penalty_from(
instance=instance
)
log_likelihood_penalty = self.log_likelihood_penalty_from(instance=instance)
except Exception as e:
raise e

try:
return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty
return (
self.fit_from(instance=instance).figure_of_merit
+ log_likelihood_penalty
)
except (
PixelizationException,
exc.PixelizationException,
Expand All @@ -171,6 +172,8 @@ def log_likelihood_function(self, instance):
np.linalg.LinAlgError,
OverflowError,
) as e:
print(e)
fggdfg
raise exc.FitException from e

def fit_from(
Expand Down
10 changes: 6 additions & 4 deletions autolens/mock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from autofit.jax_wrapper import register_pytree_node_class, numpy as np
import jax.numpy as jnp

from autofit.jax_wrapper import register_pytree_node_class
from autofit.mock import * # noqa
from autoarray.mock import * # noqa
from autogalaxy.mock import * # noqa
Expand All @@ -17,15 +19,15 @@ def __init__(self):
super().__init__([])

def deflections_yx_2d_from(self, grid):
return np.zeros_like(grid.array)
return jnp.zeros_like(grid.array)

def deflections_between_planes_from(self, grid, plane_i=0, plane_j=-1):
return np.zeros_like(grid.array)
return jnp.zeros_like(grid.array)

def magnification_2d_via_hessian_from(
self, grid, buffer: float = 0.01, deflections_func=None
) -> aa.ArrayIrregular:
return aa.ArrayIrregular(values=np.ones(grid.shape[0]))
return aa.ArrayIrregular(values=jnp.ones(grid.shape[0]))

def tree_flatten(self):
"""
Expand Down
10 changes: 6 additions & 4 deletions autolens/point/fit/fluxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,12 @@ def model_data(self):
are used.
"""
return aa.ArrayIrregular(
values=jnp.array([
magnification * self.profile.flux
for magnification in self.magnifications_at_positions
])
values=jnp.array(
[
magnification * self.profile.flux
for magnification in self.magnifications_at_positions
]
)
)

@property
Expand Down
20 changes: 11 additions & 9 deletions autolens/point/fit/positions/image/pair_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,18 @@ def all_permutations_log_likelihoods(self) -> np.ndarray:
[
jnp.log(
jnp.sum(
jnp.array([
jnp.exp(
self.log_p(
data_position,
model_position,
sigma,
jnp.array(
[
jnp.exp(
self.log_p(
data_position,
model_position,
sigma,
)
)
)
for model_position in model_data
])
for model_position in model_data
]
)
)
)
for data_position, sigma in zip(self.data, self.noise_map)
Expand Down
4 changes: 3 additions & 1 deletion autolens/point/max_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(
except TypeError:
plane_index = -1

self.plane_positions = aa.Grid2DIrregular(values=tracer.traced_grid_2d_list_from(grid=data)[plane_index])
self.plane_positions = aa.Grid2DIrregular(
values=tracer.traced_grid_2d_list_from(grid=data)[plane_index]
)

@property
def furthest_separations_of_plane_positions(self) -> aa.ArrayIrregular:
Expand Down
2 changes: 1 addition & 1 deletion autolens/point/solver/point_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ def solve(
sentinel = jnp.full_like(solution[0], fill_value=jnp.inf)
solution = jnp.where(is_nan[:, None], sentinel, solution)

return aa.Grid2DIrregular(solution)
return aa.Grid2DIrregular(solution)
16 changes: 4 additions & 12 deletions autolens/point/solver/shape_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import logging
import math

from typing import Tuple, List, Iterator, Type, Optional
from typing import Tuple, List, Iterator, Optional

import autoarray as aa

from autoarray.structures.triangles.shape import Shape
from autofit.jax_wrapper import register_pytree_node_class

from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import (
from autoarray.structures.triangles.coordinate_array import (
CoordinateArrayTriangles,
)
from autoarray.structures.triangles.abstract import AbstractTriangles
Expand Down Expand Up @@ -59,7 +59,6 @@ def for_grid(
grid: aa.Grid2D,
pixel_scale_precision: float,
magnification_threshold=0.1,
array_triangles_cls: Type[AbstractTriangles] = CoordinateArrayTriangles,
neighbor_degree: int = 1,
):
"""
Expand All @@ -75,9 +74,6 @@ def for_grid(
The precision to which the triangles should be subdivided.
magnification_threshold
The threshold for the magnification under which multiple images are filtered.
array_triangles_cls
The class to use for the triangles. JAX is used implicitly if USE_JAX=1 and
jax is installed.
max_containing_size
Only applies to JAX. This is the maximum number of multiple images expected.
We need to know this in advance to allocate memory for the JAX array.
Expand Down Expand Up @@ -106,7 +102,6 @@ def for_grid(
scale=scale,
pixel_scale_precision=pixel_scale_precision,
magnification_threshold=magnification_threshold,
array_triangles_cls=array_triangles_cls,
neighbor_degree=neighbor_degree,
)

Expand All @@ -120,7 +115,6 @@ def for_limits_and_scale(
scale=0.1,
pixel_scale_precision: float = 0.001,
magnification_threshold=0.1,
array_triangles_cls: Type[AbstractTriangles] = CoordinateArrayTriangles,
neighbor_degree: int = 1,
):
"""
Expand All @@ -141,17 +135,14 @@ def for_limits_and_scale(
The precision to which the triangles should be subdivided.
magnification_threshold
The threshold for the magnification under which multiple images are filtered.
array_triangles_cls
The class to use for the triangles. JAX is used implicitly if USE_JAX=1 and
jax is installed.
neighbor_degree
The number of times recursively add neighbors for the triangles that contain

Returns
-------
The solver.
"""
initial_triangles = array_triangles_cls.for_limits_and_scale(
initial_triangles = CoordinateArrayTriangles.for_limits_and_scale(
y_min=y_min,
y_max=y_max,
x_min=x_min,
Expand Down Expand Up @@ -310,6 +301,7 @@ def steps(
An iterator over the steps of the triangle solver algorithm.
"""
initial_triangles = self.initial_triangles

for number in range(self.n_steps):
plane_triangles = self._plane_triangles(
tracer=tracer,
Expand Down
6 changes: 0 additions & 6 deletions autolens/point/solver/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,6 @@
import autoarray as aa
from autoarray.numpy_wrapper import register_pytree_node_class

try:
from autoarray.structures.triangles.array.jax_array import ArrayTriangles
except ImportError:
from autoarray.structures.triangles.array import ArrayTriangles


logger = logging.getLogger(__name__)


Expand Down
16 changes: 12 additions & 4 deletions test_autolens/lens/test_operate.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,18 @@ def test__operate_image__galaxy_blurred_image_2d_dict_from(
blurring_grid=blurring_grid_2d_7x7,
)

assert blurred_image_dict[g0].slim == pytest.approx(g0_blurred_image.slim.array, 1.0e-4)
assert blurred_image_dict[g1].slim == pytest.approx(g1_blurred_image.slim.array, 1.0e-4)
assert blurred_image_dict[g2].slim == pytest.approx(g2_blurred_image.slim.array, 1.0e-4)
assert blurred_image_dict[g3].slim == pytest.approx(g3_blurred_image.slim.array, 1.0e-4)
assert blurred_image_dict[g0].slim == pytest.approx(
g0_blurred_image.slim.array, 1.0e-4
)
assert blurred_image_dict[g1].slim == pytest.approx(
g1_blurred_image.slim.array, 1.0e-4
)
assert blurred_image_dict[g2].slim == pytest.approx(
g2_blurred_image.slim.array, 1.0e-4
)
assert blurred_image_dict[g3].slim == pytest.approx(
g3_blurred_image.slim.array, 1.0e-4
)


def test__operate_image__galaxy_visibilities_dict_from_grid_and_transformer(
Expand Down
2 changes: 0 additions & 2 deletions test_autolens/point/triangles/test_extended.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
from autoarray.structures.triangles.shape import Circle
from autolens.mock import NullTracer
from autolens.point.solver.shape_solver import ShapeSolver
Expand All @@ -11,7 +10,6 @@ def solver(grid):
return ShapeSolver.for_grid(
grid=grid,
pixel_scale_precision=0.001,
array_triangles_cls=CoordinateArrayTriangles,
)


Expand Down
Loading
Loading