From 038ec807589506c0f371a218aaf3fc729be90e90 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 14 Dec 2025 17:46:05 +0000 Subject: [PATCH] complete --- autolens/analysis/analysis/dataset.py | 1 + autolens/analysis/analysis/lens.py | 19 +++++++++++++++---- autolens/imaging/model/analysis.py | 1 - autolens/interferometer/model/analysis.py | 2 +- autolens/lens/tracer_util.py | 5 +++-- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/autolens/analysis/analysis/dataset.py b/autolens/analysis/analysis/dataset.py index 681cbb625..6f29dce1a 100644 --- a/autolens/analysis/analysis/dataset.py +++ b/autolens/analysis/analysis/dataset.py @@ -85,6 +85,7 @@ def __init__( self=self, positions_likelihood_list=positions_likelihood_list, cosmology=cosmology, + use_jax=use_jax ) self.raise_inversion_positions_likelihood_exception = ( diff --git a/autolens/analysis/analysis/lens.py b/autolens/analysis/analysis/lens.py index 9f00b6278..8b2a16770 100644 --- a/autolens/analysis/analysis/lens.py +++ b/autolens/analysis/analysis/lens.py @@ -23,6 +23,7 @@ def __init__( self, positions_likelihood_list: Optional[List[PositionsLH]] = None, cosmology: ag.cosmo.LensingCosmology = None, + use_jax: bool = True, ): """ Analysis classes are used by PyAutoFit to fit a model to a dataset via a non-linear search. @@ -44,6 +45,15 @@ def __init__( self.cosmology = cosmology or Planck15() self.positions_likelihood_list = positions_likelihood_list + self._use_jax = use_jax + + @property + def _xp(self): + if self._use_jax: + import jax.numpy as jnp + return jnp + return np + def tracer_via_instance_from( self, instance: af.ModelInstance, @@ -72,8 +82,9 @@ def tracer_via_instance_from( subhalo_centre = tracer_util.grid_2d_at_redshift_from( galaxies=instance.galaxies, redshift=instance.galaxies.subhalo.redshift, - grid=aa.Grid2DIrregular(values=[instance.galaxies.subhalo.mass.centre]), + grid=aa.Grid2DIrregular(values=[instance.galaxies.subhalo.mass.centre], xp=self._xp), cosmology=self.cosmology, + xp=self._xp ) instance.galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0]) @@ -95,7 +106,7 @@ def tracer_via_instance_from( ) def log_likelihood_penalty_from( - self, instance: af.ModelInstance, xp=np + self, instance: af.ModelInstance, ) -> Optional[float]: """ Call the positions overwrite log likelihood function, which add a penalty term to the likelihood if the @@ -116,7 +127,7 @@ def log_likelihood_penalty_from( The penalty value of the positions log likelihood, if the positions do not trace close in the source plane, else a None is returned to indicate there is no penalty. """ - log_likelihood_penalty = xp.array(0.0) + log_likelihood_penalty = self._xp.array(0.0) if self.positions_likelihood_list is not None: @@ -126,7 +137,7 @@ def log_likelihood_penalty_from( log_likelihood_penalty = ( positions_likelihood.log_likelihood_penalty_from( - instance=instance, analysis=self, xp=xp + instance=instance, analysis=self, xp=self._xp ) ) diff --git a/autolens/imaging/model/analysis.py b/autolens/imaging/model/analysis.py index acd2de877..8d7464c29 100644 --- a/autolens/imaging/model/analysis.py +++ b/autolens/imaging/model/analysis.py @@ -59,7 +59,6 @@ def log_likelihood_function(self, instance: af.ModelInstance) -> float: log_likelihood_penalty = self.log_likelihood_penalty_from( instance=instance, - xp=self._xp ) if self._use_jax: diff --git a/autolens/interferometer/model/analysis.py b/autolens/interferometer/model/analysis.py index 0a51132e8..4bb3ba50c 100644 --- a/autolens/interferometer/model/analysis.py +++ b/autolens/interferometer/model/analysis.py @@ -130,7 +130,7 @@ def log_likelihood_function(self, instance): """ log_likelihood_penalty = self.log_likelihood_penalty_from( - instance=instance, xp=self._xp + instance=instance, ) return self.fit_from(instance=instance).figure_of_merit - log_likelihood_penalty diff --git a/autolens/lens/tracer_util.py b/autolens/lens/tracer_util.py index 0700396dc..c14cc20a3 100644 --- a/autolens/lens/tracer_util.py +++ b/autolens/lens/tracer_util.py @@ -179,6 +179,7 @@ def grid_2d_at_redshift_from( galaxies: List[ag.Galaxy], grid: aa.type.Grid2DLike, cosmology: ag.cosmo.LensingCosmology = None, + xp=np, ) -> aa.type.Grid2DLike: """ Returns a ray-traced grid of 2D Cartesian (y,x) coordinates, which accounts for multi-plane ray-tracing, at a @@ -237,7 +238,7 @@ def grid_2d_at_redshift_from( if plane_index_with_redshift: traced_grid_list = traced_grid_2d_list_from( - planes=planes, grid=grid, cosmology=cosmology + planes=planes, grid=grid, cosmology=cosmology, xp=xp ) return traced_grid_list[plane_index_with_redshift[0]] @@ -249,7 +250,7 @@ def grid_2d_at_redshift_from( planes.insert(plane_index_insert, [ag.Galaxy(redshift=redshift)]) traced_grid_list = traced_grid_2d_list_from( - planes=planes, grid=grid, cosmology=cosmology + planes=planes, grid=grid, cosmology=cosmology, xp=xp ) return traced_grid_list[plane_index_insert]