From 8e4b8e25bdb6b60eebdac3357ff6853dd46a51d2 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 19:16:36 +0100 Subject: [PATCH 1/4] jax solver tests reinstated --- autolens/point/solver/shape_solver.py | 2 +- .../point/triangles/test_solver_jax.py | 210 +++++++----------- 2 files changed, 77 insertions(+), 135 deletions(-) diff --git a/autolens/point/solver/shape_solver.py b/autolens/point/solver/shape_solver.py index 48a88ca99..3ecf2f371 100644 --- a/autolens/point/solver/shape_solver.py +++ b/autolens/point/solver/shape_solver.py @@ -9,7 +9,7 @@ from autoarray.structures.triangles.shape import Shape from autofit.jax_wrapper import register_pytree_node_class -from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import ( +from autoarray.structures.triangles.coordinate_array import ( CoordinateArrayTriangles, ) from autoarray.structures.triangles.abstract import AbstractTriangles diff --git a/test_autolens/point/triangles/test_solver_jax.py b/test_autolens/point/triangles/test_solver_jax.py index ce9ae5a82..8e3a40e36 100644 --- a/test_autolens/point/triangles/test_solver_jax.py +++ b/test_autolens/point/triangles/test_solver_jax.py @@ -1,141 +1,83 @@ +import numpy as np +import pytest import time from typing import Tuple -import pytest - import autogalaxy as ag import autofit as af -import numpy as np from autolens import PointSolver, Tracer -# -# try: -# from autoarray.structures.triangles.coordinate_array.jax_coordinate_array import ( -# CoordinateArrayTriangles, -# ) -# -# except ImportError: -# from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles -# -# from autolens.mock import NullTracer -# -# pytest.importorskip("jax") -# -# -# @pytest.fixture(autouse=True) -# def register(tracer): -# af.Model.from_instance(tracer) -# -# -# @pytest.fixture -# def solver(grid): -# return PointSolver.for_grid( -# grid=grid, -# pixel_scale_precision=0.01, -# array_triangles_cls=CoordinateArrayTriangles, -# ) -# -# -# def test_solver(solver): -# mass_profile = ag.mp.Isothermal( -# centre=(0.0, 0.0), -# einstein_radius=1.0, -# ) -# tracer = Tracer( -# galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)], -# ) -# result = solver.solve( -# tracer, -# source_plane_coordinate=(0.0, 0.0), -# ) -# print(result) -# assert result -# -# -# @pytest.mark.parametrize( -# "source_plane_coordinate", -# [ -# (0.0, 0.0), -# (0.0, 1.0), -# (1.0, 0.0), -# (1.0, 1.0), -# (0.5, 0.5), -# (0.1, 0.1), -# (-1.0, -1.0), -# ], -# ) -# def test_trivial( -# source_plane_coordinate: Tuple[float, float], -# grid, -# solver, -# ): -# coordinates = solver.solve( -# NullTracer(), -# source_plane_coordinate=source_plane_coordinate, -# ) -# coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)] -# assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1) -# -# -# def test_real_example(grid, tracer): -# solver = PointSolver.for_grid( -# grid=grid, -# pixel_scale_precision=0.001, -# array_triangles_cls=CoordinateArrayTriangles, -# ) -# -# result = solver.solve(tracer, (0.07, 0.07)) -# assert len(result) == 5 -# -# -# def _test_jax(grid): -# sizes = (5, 10, 15, 20, 25, 30, 35, 40, 45, 50) -# run_times = [] -# init_times = [] -# -# for size in sizes: -# start = time.time() -# solver = PointSolver.for_grid( -# grid=grid, -# pixel_scale_precision=0.001, -# array_triangles_cls=CoordinateArrayTriangles, -# max_containing_size=size, -# ) -# -# solver.solve(NullTracer(), (0.07, 0.07)) -# -# repeats = 100 -# -# done_init_time = time.time() -# init_time = done_init_time - start -# for _ in range(repeats): -# _ = solver.solve(NullTracer(), (0.07, 0.07)) -# -# # print(result) -# -# init_times.append(init_time) -# -# run_time = (time.time() - done_init_time) / repeats -# run_times.append(run_time) -# -# print(f"Time taken for {size}: {run_time} ({init_time} to init)") -# -# from matplotlib import pyplot as plt -# -# plt.plot(sizes, run_times) -# plt.show() -# -# -# def test_real_example_jax(grid, tracer): -# jax_solver = PointSolver.for_grid( -# grid=grid, -# pixel_scale_precision=0.001, -# array_triangles_cls=CoordinateArrayTriangles, -# ) -# -# result = jax_solver.solve( -# tracer=tracer, -# source_plane_coordinate=(0.07, 0.07), -# ) -# -# assert len(result) == 5 +from autoarray.structures.triangles.coordinate_array import ( + CoordinateArrayTriangles, +) + +from autolens.mock import NullTracer + + +@pytest.fixture(autouse=True) +def register(tracer): + af.Model.from_instance(tracer) + + +@pytest.fixture +def solver(grid): + return PointSolver.for_grid( + grid=grid, + pixel_scale_precision=0.01, + array_triangles_cls=CoordinateArrayTriangles, + ) + + +def test_solver(solver): + mass_profile = ag.mp.Isothermal( + centre=(0.0, 0.0), + einstein_radius=1.0, + ) + tracer = Tracer( + galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)], + ) + result = solver.solve( + tracer, + source_plane_coordinate=(0.0, 0.0), + ) + print(result) + assert result + + +@pytest.mark.parametrize( + "source_plane_coordinate", + [ + (0.0, 0.0), + (0.0, 1.0), + (1.0, 0.0), + (1.0, 1.0), + (0.5, 0.5), + (0.1, 0.1), + (-1.0, -1.0), + ], +) +def test_trivial( + source_plane_coordinate: Tuple[float, float], + grid, + solver, +): + coordinates = solver.solve( + NullTracer(), + source_plane_coordinate=source_plane_coordinate, + ) + coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)] + assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1) + +def test_real_example_jax(grid, tracer): + jax_solver = PointSolver.for_grid( + grid=grid, + pixel_scale_precision=0.001, + array_triangles_cls=CoordinateArrayTriangles, + ) + + result = jax_solver.solve( + tracer=tracer, + source_plane_coordinate=(0.07, 0.07), + ) + + assert len(result) == 15 From 5b67e42f4ad70ff365d6e9bbe4396dab8c7ab626 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 19:38:55 +0100 Subject: [PATCH 2/4] refactor done but one test fails --- autolens/mock.py | 10 ++- autolens/point/solver/shape_solver.py | 13 +-- autolens/point/solver/step.py | 6 -- .../point/triangles/test_extended.py | 2 - test_autolens/point/triangles/test_solver.py | 73 +++++++--------- .../point/triangles/test_solver_jax.py | 83 ------------------- 6 files changed, 37 insertions(+), 150 deletions(-) delete mode 100644 test_autolens/point/triangles/test_solver_jax.py diff --git a/autolens/mock.py b/autolens/mock.py index 128af87ac..65fa0b921 100644 --- a/autolens/mock.py +++ b/autolens/mock.py @@ -1,4 +1,6 @@ -from autofit.jax_wrapper import register_pytree_node_class, numpy as np +import jax.numpy as jnp + +from autofit.jax_wrapper import register_pytree_node_class from autofit.mock import * # noqa from autoarray.mock import * # noqa from autogalaxy.mock import * # noqa @@ -17,15 +19,15 @@ def __init__(self): super().__init__([]) def deflections_yx_2d_from(self, grid): - return np.zeros_like(grid.array) + return jnp.zeros_like(grid.array) def deflections_between_planes_from(self, grid, plane_i=0, plane_j=-1): - return np.zeros_like(grid.array) + return jnp.zeros_like(grid.array) def magnification_2d_via_hessian_from( self, grid, buffer: float = 0.01, deflections_func=None ) -> aa.ArrayIrregular: - return aa.ArrayIrregular(values=np.ones(grid.shape[0])) + return aa.ArrayIrregular(values=jnp.ones(grid.shape[0])) def tree_flatten(self): """ diff --git a/autolens/point/solver/shape_solver.py b/autolens/point/solver/shape_solver.py index 3ecf2f371..f55fbf59f 100644 --- a/autolens/point/solver/shape_solver.py +++ b/autolens/point/solver/shape_solver.py @@ -2,7 +2,7 @@ import logging import math -from typing import Tuple, List, Iterator, Type, Optional +from typing import Tuple, List, Iterator, Optional import autoarray as aa @@ -59,7 +59,6 @@ def for_grid( grid: aa.Grid2D, pixel_scale_precision: float, magnification_threshold=0.1, - array_triangles_cls: Type[AbstractTriangles] = CoordinateArrayTriangles, neighbor_degree: int = 1, ): """ @@ -75,9 +74,6 @@ def for_grid( The precision to which the triangles should be subdivided. magnification_threshold The threshold for the magnification under which multiple images are filtered. - array_triangles_cls - The class to use for the triangles. JAX is used implicitly if USE_JAX=1 and - jax is installed. max_containing_size Only applies to JAX. This is the maximum number of multiple images expected. We need to know this in advance to allocate memory for the JAX array. @@ -106,7 +102,6 @@ def for_grid( scale=scale, pixel_scale_precision=pixel_scale_precision, magnification_threshold=magnification_threshold, - array_triangles_cls=array_triangles_cls, neighbor_degree=neighbor_degree, ) @@ -120,7 +115,6 @@ def for_limits_and_scale( scale=0.1, pixel_scale_precision: float = 0.001, magnification_threshold=0.1, - array_triangles_cls: Type[AbstractTriangles] = CoordinateArrayTriangles, neighbor_degree: int = 1, ): """ @@ -141,9 +135,6 @@ def for_limits_and_scale( The precision to which the triangles should be subdivided. magnification_threshold The threshold for the magnification under which multiple images are filtered. - array_triangles_cls - The class to use for the triangles. JAX is used implicitly if USE_JAX=1 and - jax is installed. neighbor_degree The number of times recursively add neighbors for the triangles that contain @@ -151,7 +142,7 @@ def for_limits_and_scale( ------- The solver. """ - initial_triangles = array_triangles_cls.for_limits_and_scale( + initial_triangles = CoordinateArrayTriangles.for_limits_and_scale( y_min=y_min, y_max=y_max, x_min=x_min, diff --git a/autolens/point/solver/step.py b/autolens/point/solver/step.py index 27a8ae779..626f1f0a1 100644 --- a/autolens/point/solver/step.py +++ b/autolens/point/solver/step.py @@ -4,12 +4,6 @@ import autoarray as aa from autoarray.numpy_wrapper import register_pytree_node_class -try: - from autoarray.structures.triangles.array.jax_array import ArrayTriangles -except ImportError: - from autoarray.structures.triangles.array import ArrayTriangles - - logger = logging.getLogger(__name__) diff --git a/test_autolens/point/triangles/test_extended.py b/test_autolens/point/triangles/test_extended.py index a0d21666a..f524ae6f4 100644 --- a/test_autolens/point/triangles/test_extended.py +++ b/test_autolens/point/triangles/test_extended.py @@ -1,6 +1,5 @@ import pytest -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles from autoarray.structures.triangles.shape import Circle from autolens.mock import NullTracer from autolens.point.solver.shape_solver import ShapeSolver @@ -11,7 +10,6 @@ def solver(grid): return ShapeSolver.for_grid( grid=grid, pixel_scale_precision=0.001, - array_triangles_cls=CoordinateArrayTriangles, ) diff --git a/test_autolens/point/triangles/test_solver.py b/test_autolens/point/triangles/test_solver.py index 62b3b35dc..68813254a 100644 --- a/test_autolens/point/triangles/test_solver.py +++ b/test_autolens/point/triangles/test_solver.py @@ -1,13 +1,22 @@ -from typing import Tuple - import numpy as np import pytest +import time +from typing import Tuple -import autolens as al import autogalaxy as ag -from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles +import autofit as af +from autolens import PointSolver, Tracer + +from autoarray.structures.triangles.coordinate_array import ( + CoordinateArrayTriangles, +) + from autolens.mock import NullTracer -from autolens.point.solver import PointSolver + + +@pytest.fixture(autouse=True) +def register(tracer): + af.Model.from_instance(tracer) @pytest.fixture @@ -18,30 +27,19 @@ def solver(grid): ) -def test_solver_basic(solver): - tracer = al.Tracer( - galaxies=[ - al.Galaxy( - redshift=0.5, - mass=ag.mp.Isothermal( - centre=(0.0, 0.0), - einstein_radius=1.0, - ), - ), - al.Galaxy( - redshift=1.0, - ), - ] +def test_solver(solver): + mass_profile = ag.mp.Isothermal( + centre=(0.0, 0.0), + einstein_radius=1.0, ) - - assert solver.solve( - tracer=tracer, + tracer = Tracer( + galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)], + ) + result = solver.solve( + tracer, source_plane_coordinate=(0.0, 0.0), ) - - -def test_steps(solver): - assert solver.n_steps == 7 + assert result @pytest.mark.parametrize( @@ -59,32 +57,19 @@ def test_steps(solver): def test_trivial( source_plane_coordinate: Tuple[float, float], grid, + solver, ): - solver = PointSolver.for_grid( - grid=grid, - pixel_scale_precision=0.01, - ) coordinates = solver.solve( - tracer=NullTracer(), + NullTracer(), source_plane_coordinate=source_plane_coordinate, ) - + coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)] assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1) - -def triangle_set(triangles): - return { - tuple(sorted([tuple(np.round(pair, 4)) for pair in triangle])) - for triangle in triangles.triangles.tolist() - if not np.isnan(triangle).any() - } - - -def test_real_example_normal(grid, tracer): +def test_real_example_jax(grid, tracer): jax_solver = PointSolver.for_grid( grid=grid, pixel_scale_precision=0.001, - array_triangles_cls=CoordinateArrayTriangles, ) result = jax_solver.solve( @@ -92,4 +77,4 @@ def test_real_example_normal(grid, tracer): source_plane_coordinate=(0.07, 0.07), ) - assert len(result) == 5 + assert len(result) == 15 diff --git a/test_autolens/point/triangles/test_solver_jax.py b/test_autolens/point/triangles/test_solver_jax.py deleted file mode 100644 index 8e3a40e36..000000000 --- a/test_autolens/point/triangles/test_solver_jax.py +++ /dev/null @@ -1,83 +0,0 @@ -import numpy as np -import pytest -import time -from typing import Tuple - -import autogalaxy as ag -import autofit as af -from autolens import PointSolver, Tracer - -from autoarray.structures.triangles.coordinate_array import ( - CoordinateArrayTriangles, -) - -from autolens.mock import NullTracer - - -@pytest.fixture(autouse=True) -def register(tracer): - af.Model.from_instance(tracer) - - -@pytest.fixture -def solver(grid): - return PointSolver.for_grid( - grid=grid, - pixel_scale_precision=0.01, - array_triangles_cls=CoordinateArrayTriangles, - ) - - -def test_solver(solver): - mass_profile = ag.mp.Isothermal( - centre=(0.0, 0.0), - einstein_radius=1.0, - ) - tracer = Tracer( - galaxies=[ag.Galaxy(redshift=0.5, mass=mass_profile)], - ) - result = solver.solve( - tracer, - source_plane_coordinate=(0.0, 0.0), - ) - print(result) - assert result - - -@pytest.mark.parametrize( - "source_plane_coordinate", - [ - (0.0, 0.0), - (0.0, 1.0), - (1.0, 0.0), - (1.0, 1.0), - (0.5, 0.5), - (0.1, 0.1), - (-1.0, -1.0), - ], -) -def test_trivial( - source_plane_coordinate: Tuple[float, float], - grid, - solver, -): - coordinates = solver.solve( - NullTracer(), - source_plane_coordinate=source_plane_coordinate, - ) - coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)] - assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1) - -def test_real_example_jax(grid, tracer): - jax_solver = PointSolver.for_grid( - grid=grid, - pixel_scale_precision=0.001, - array_triangles_cls=CoordinateArrayTriangles, - ) - - result = jax_solver.solve( - tracer=tracer, - source_plane_coordinate=(0.07, 0.07), - ) - - assert len(result) == 15 From 71948fd38af55f9f77cb13961b4db9e68597747c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 17 Jun 2025 19:39:12 +0100 Subject: [PATCH 3/4] black --- autolens/analysis/analysis/lens.py | 8 +++++--- autolens/analysis/positions.py | 14 +++++++------ autolens/analysis/result.py | 4 +--- autolens/interferometer/model/analysis.py | 9 +++++---- autolens/point/fit/fluxes.py | 10 ++++++---- .../point/fit/positions/image/pair_all.py | 20 ++++++++++--------- autolens/point/max_separation.py | 4 +++- autolens/point/solver/point_solver.py | 2 +- test_autolens/lens/test_operate.py | 16 +++++++++++---- test_autolens/point/triangles/test_solver.py | 1 + 10 files changed, 53 insertions(+), 35 deletions(-) diff --git a/autolens/analysis/analysis/lens.py b/autolens/analysis/analysis/lens.py index 840907c05..de20520ab 100644 --- a/autolens/analysis/analysis/lens.py +++ b/autolens/analysis/analysis/lens.py @@ -127,8 +127,10 @@ def log_likelihood_penalty_from( try: for positions_likelihood in self.positions_likelihood_list: - log_likelihood_penalty = positions_likelihood.log_likelihood_penalty_from( - instance=instance, analysis=self + log_likelihood_penalty = ( + positions_likelihood.log_likelihood_penalty_from( + instance=instance, analysis=self + ) ) log_likelihood_penalty += log_likelihood_penalty @@ -137,4 +139,4 @@ def log_likelihood_penalty_from( except (ValueError, np.linalg.LinAlgError) as e: raise exc.FitException from e - return log_likelihood_penalty \ No newline at end of file + return log_likelihood_penalty diff --git a/autolens/analysis/positions.py b/autolens/analysis/positions.py index 481c6a7fc..79683e266 100644 --- a/autolens/analysis/positions.py +++ b/autolens/analysis/positions.py @@ -136,7 +136,9 @@ def output_positions_info( ) f.write("") - def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: AnalysisDataset) -> jnp.array: + def log_likelihood_penalty_from( + self, instance: af.ModelInstance, analysis: AnalysisDataset + ) -> jnp.array: """ Returns a log-likelihood penalty used to constrain lens models where multiple image-plane positions do not trace to within a threshold distance of one another in the source-plane. @@ -174,7 +176,7 @@ def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: Anal tracer = analysis.tracer_via_instance_from(instance=instance) if not tracer.has(cls=ag.mp.MassProfile) or len(tracer.planes) == 1: - return jnp.array(0.0), + return (jnp.array(0.0),) positions_fit = SourceMaxSeparation( data=self.positions, @@ -183,11 +185,11 @@ def log_likelihood_penalty_from(self, instance: af.ModelInstance, analysis: Anal plane_redshift=self.plane_redshift, ) - max_separation = jnp.max(positions_fit.furthest_separations_of_plane_positions.array) + max_separation = jnp.max( + positions_fit.furthest_separations_of_plane_positions.array + ) - penalty = self.log_likelihood_penalty_factor * ( - max_separation - self.threshold - ) + penalty = self.log_likelihood_penalty_factor * (max_separation - self.threshold) return jax.lax.cond( max_separation > self.threshold, diff --git a/autolens/analysis/result.py b/autolens/analysis/result.py index a1c15b975..96397b6f7 100644 --- a/autolens/analysis/result.py +++ b/autolens/analysis/result.py @@ -310,9 +310,7 @@ def positions_likelihood_from( mask = np.isfinite(positions.array).all(axis=1) - positions = aa.Grid2DIrregular( - positions[mask] - ) + positions = aa.Grid2DIrregular(positions[mask]) threshold = self.positions_threshold_from( factor=factor, diff --git a/autolens/interferometer/model/analysis.py b/autolens/interferometer/model/analysis.py index e7b062610..9f3118730 100644 --- a/autolens/interferometer/model/analysis.py +++ b/autolens/interferometer/model/analysis.py @@ -152,14 +152,15 @@ def log_likelihood_function(self, instance): """ try: - log_likelihood_penalty = self.log_likelihood_penalty_from( - instance=instance - ) + log_likelihood_penalty = self.log_likelihood_penalty_from(instance=instance) except Exception as e: raise e try: - return self.fit_from(instance=instance).figure_of_merit + log_likelihood_penalty + return ( + self.fit_from(instance=instance).figure_of_merit + + log_likelihood_penalty + ) except ( PixelizationException, exc.PixelizationException, diff --git a/autolens/point/fit/fluxes.py b/autolens/point/fit/fluxes.py index a34c4e56c..898fe6f87 100644 --- a/autolens/point/fit/fluxes.py +++ b/autolens/point/fit/fluxes.py @@ -102,10 +102,12 @@ def model_data(self): are used. """ return aa.ArrayIrregular( - values=jnp.array([ - magnification * self.profile.flux - for magnification in self.magnifications_at_positions - ]) + values=jnp.array( + [ + magnification * self.profile.flux + for magnification in self.magnifications_at_positions + ] + ) ) @property diff --git a/autolens/point/fit/positions/image/pair_all.py b/autolens/point/fit/positions/image/pair_all.py index a21fd39b9..d93e6bbb4 100644 --- a/autolens/point/fit/positions/image/pair_all.py +++ b/autolens/point/fit/positions/image/pair_all.py @@ -107,16 +107,18 @@ def all_permutations_log_likelihoods(self) -> np.ndarray: [ jnp.log( jnp.sum( - jnp.array([ - jnp.exp( - self.log_p( - data_position, - model_position, - sigma, + jnp.array( + [ + jnp.exp( + self.log_p( + data_position, + model_position, + sigma, + ) ) - ) - for model_position in model_data - ]) + for model_position in model_data + ] + ) ) ) for data_position, sigma in zip(self.data, self.noise_map) diff --git a/autolens/point/max_separation.py b/autolens/point/max_separation.py index f17e65de3..a83af2e1e 100644 --- a/autolens/point/max_separation.py +++ b/autolens/point/max_separation.py @@ -42,7 +42,9 @@ def __init__( except TypeError: plane_index = -1 - self.plane_positions = aa.Grid2DIrregular(values=tracer.traced_grid_2d_list_from(grid=data)[plane_index]) + self.plane_positions = aa.Grid2DIrregular( + values=tracer.traced_grid_2d_list_from(grid=data)[plane_index] + ) @property def furthest_separations_of_plane_positions(self) -> aa.ArrayIrregular: diff --git a/autolens/point/solver/point_solver.py b/autolens/point/solver/point_solver.py index 4caee637c..b9bdc5f2b 100644 --- a/autolens/point/solver/point_solver.py +++ b/autolens/point/solver/point_solver.py @@ -62,4 +62,4 @@ def solve( sentinel = jnp.full_like(solution[0], fill_value=jnp.inf) solution = jnp.where(is_nan[:, None], sentinel, solution) - return aa.Grid2DIrregular(solution) \ No newline at end of file + return aa.Grid2DIrregular(solution) diff --git a/test_autolens/lens/test_operate.py b/test_autolens/lens/test_operate.py index d5f7e09fb..ef9936535 100644 --- a/test_autolens/lens/test_operate.py +++ b/test_autolens/lens/test_operate.py @@ -122,10 +122,18 @@ def test__operate_image__galaxy_blurred_image_2d_dict_from( blurring_grid=blurring_grid_2d_7x7, ) - assert blurred_image_dict[g0].slim == pytest.approx(g0_blurred_image.slim.array, 1.0e-4) - assert blurred_image_dict[g1].slim == pytest.approx(g1_blurred_image.slim.array, 1.0e-4) - assert blurred_image_dict[g2].slim == pytest.approx(g2_blurred_image.slim.array, 1.0e-4) - assert blurred_image_dict[g3].slim == pytest.approx(g3_blurred_image.slim.array, 1.0e-4) + assert blurred_image_dict[g0].slim == pytest.approx( + g0_blurred_image.slim.array, 1.0e-4 + ) + assert blurred_image_dict[g1].slim == pytest.approx( + g1_blurred_image.slim.array, 1.0e-4 + ) + assert blurred_image_dict[g2].slim == pytest.approx( + g2_blurred_image.slim.array, 1.0e-4 + ) + assert blurred_image_dict[g3].slim == pytest.approx( + g3_blurred_image.slim.array, 1.0e-4 + ) def test__operate_image__galaxy_visibilities_dict_from_grid_and_transformer( diff --git a/test_autolens/point/triangles/test_solver.py b/test_autolens/point/triangles/test_solver.py index 68813254a..d9fbe996f 100644 --- a/test_autolens/point/triangles/test_solver.py +++ b/test_autolens/point/triangles/test_solver.py @@ -66,6 +66,7 @@ def test_trivial( coordinates = coordinates.array[~np.isnan(coordinates.array).any(axis=1)] assert coordinates[0] == pytest.approx(source_plane_coordinate, abs=1.0e-1) + def test_real_example_jax(grid, tracer): jax_solver = PointSolver.for_grid( grid=grid, From 67a710bb7425f3811b613b04e5ea5c960805714b Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 18 Jun 2025 15:43:51 +0100 Subject: [PATCH 4/4] black --- autolens/interferometer/model/analysis.py | 2 ++ autolens/point/solver/shape_solver.py | 1 + 2 files changed, 3 insertions(+) diff --git a/autolens/interferometer/model/analysis.py b/autolens/interferometer/model/analysis.py index 9f3118730..27910869b 100644 --- a/autolens/interferometer/model/analysis.py +++ b/autolens/interferometer/model/analysis.py @@ -172,6 +172,8 @@ def log_likelihood_function(self, instance): np.linalg.LinAlgError, OverflowError, ) as e: + print(e) + fggdfg raise exc.FitException from e def fit_from( diff --git a/autolens/point/solver/shape_solver.py b/autolens/point/solver/shape_solver.py index f55fbf59f..8f08751e0 100644 --- a/autolens/point/solver/shape_solver.py +++ b/autolens/point/solver/shape_solver.py @@ -301,6 +301,7 @@ def steps( An iterator over the steps of the triangle solver algorithm. """ initial_triangles = self.initial_triangles + for number in range(self.n_steps): plane_triangles = self._plane_triangles( tracer=tracer,