Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autolens/analysis/analysis/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
self=self,
positions_likelihood_list=positions_likelihood_list,
cosmology=cosmology,
use_jax=use_jax
)

self.raise_inversion_positions_likelihood_exception = (
Expand Down
19 changes: 15 additions & 4 deletions autolens/analysis/analysis/lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
self,
positions_likelihood_list: Optional[List[PositionsLH]] = None,
cosmology: ag.cosmo.LensingCosmology = None,
use_jax: bool = True,
):
"""
Analysis classes are used by PyAutoFit to fit a model to a dataset via a non-linear search.
Expand All @@ -44,6 +45,15 @@ def __init__(
self.cosmology = cosmology or Planck15()
self.positions_likelihood_list = positions_likelihood_list

self._use_jax = use_jax

@property
def _xp(self):
if self._use_jax:
import jax.numpy as jnp
return jnp
return np

def tracer_via_instance_from(
self,
instance: af.ModelInstance,
Expand Down Expand Up @@ -72,8 +82,9 @@ def tracer_via_instance_from(
subhalo_centre = tracer_util.grid_2d_at_redshift_from(
galaxies=instance.galaxies,
redshift=instance.galaxies.subhalo.redshift,
grid=aa.Grid2DIrregular(values=[instance.galaxies.subhalo.mass.centre]),
grid=aa.Grid2DIrregular(values=[instance.galaxies.subhalo.mass.centre], xp=self._xp),
cosmology=self.cosmology,
xp=self._xp
)

instance.galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0])
Expand All @@ -95,7 +106,7 @@ def tracer_via_instance_from(
)

def log_likelihood_penalty_from(
self, instance: af.ModelInstance, xp=np
self, instance: af.ModelInstance,
) -> Optional[float]:
"""
Call the positions overwrite log likelihood function, which add a penalty term to the likelihood if the
Expand All @@ -116,7 +127,7 @@ def log_likelihood_penalty_from(
The penalty value of the positions log likelihood, if the positions do not trace close in the source plane,
else a None is returned to indicate there is no penalty.
"""
log_likelihood_penalty = xp.array(0.0)
log_likelihood_penalty = self._xp.array(0.0)

if self.positions_likelihood_list is not None:

Expand All @@ -126,7 +137,7 @@ def log_likelihood_penalty_from(

log_likelihood_penalty = (
positions_likelihood.log_likelihood_penalty_from(
instance=instance, analysis=self, xp=xp
instance=instance, analysis=self, xp=self._xp
)
)

Expand Down
1 change: 0 additions & 1 deletion autolens/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def log_likelihood_function(self, instance: af.ModelInstance) -> float:

log_likelihood_penalty = self.log_likelihood_penalty_from(
instance=instance,
xp=self._xp
)

if self._use_jax:
Expand Down
2 changes: 1 addition & 1 deletion autolens/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def log_likelihood_function(self, instance):
"""

log_likelihood_penalty = self.log_likelihood_penalty_from(
instance=instance, xp=self._xp
instance=instance,
)

return self.fit_from(instance=instance).figure_of_merit - log_likelihood_penalty
Expand Down
5 changes: 3 additions & 2 deletions autolens/lens/tracer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def grid_2d_at_redshift_from(
galaxies: List[ag.Galaxy],
grid: aa.type.Grid2DLike,
cosmology: ag.cosmo.LensingCosmology = None,
xp=np,
) -> aa.type.Grid2DLike:
"""
Returns a ray-traced grid of 2D Cartesian (y,x) coordinates, which accounts for multi-plane ray-tracing, at a
Expand Down Expand Up @@ -237,7 +238,7 @@ def grid_2d_at_redshift_from(

if plane_index_with_redshift:
traced_grid_list = traced_grid_2d_list_from(
planes=planes, grid=grid, cosmology=cosmology
planes=planes, grid=grid, cosmology=cosmology, xp=xp
)

return traced_grid_list[plane_index_with_redshift[0]]
Expand All @@ -249,7 +250,7 @@ def grid_2d_at_redshift_from(
planes.insert(plane_index_insert, [ag.Galaxy(redshift=redshift)])

traced_grid_list = traced_grid_2d_list_from(
planes=planes, grid=grid, cosmology=cosmology
planes=planes, grid=grid, cosmology=cosmology, xp=xp
)

return traced_grid_list[plane_index_insert]
Expand Down
Loading