Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autogalaxy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from autoconf import jax_wrapper
from autoconf.dictable import register_parser
from autofit import conf

Expand Down
10 changes: 8 additions & 2 deletions autogalaxy/analysis/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import numpy as np
from typing import List, Optional

import autofit as af
Expand All @@ -15,7 +16,11 @@

class Analysis(af.Analysis):
def __init__(
self, cosmology: LensingCosmology = None, preloads: aa.Preloads = None, **kwargs
self,
cosmology: LensingCosmology = None,
preloads: aa.Preloads = None,
use_jax: bool = True,
**kwargs,
):
"""
Fits a model to a dataset via a non-linear search.
Expand All @@ -35,7 +40,8 @@ def __init__(

self.cosmology = cosmology or Planck15()
self.preloads = preloads
self.kwargs = kwargs

super().__init__(use_jax=use_jax, **kwargs)

def galaxies_via_instance_from(
self,
Expand Down
3 changes: 3 additions & 0 deletions autogalaxy/analysis/analysis/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import numpy as np
from typing import Optional, Union

from autoconf.dictable import to_dict, output_to_json
Expand Down Expand Up @@ -26,6 +27,7 @@ def __init__(
settings_inversion: aa.SettingsInversion = None,
preloads: aa.Preloads = None,
title_prefix: str = None,
use_jax: bool = True,
**kwargs,
):
"""
Expand Down Expand Up @@ -54,6 +56,7 @@ def __init__(
super().__init__(
cosmology=cosmology,
preloads=preloads,
use_jax=use_jax,
**kwargs,
)

Expand Down
2 changes: 0 additions & 2 deletions autogalaxy/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
jax:
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
fits:
flip_for_ds9: true
psf:
Expand Down
8 changes: 6 additions & 2 deletions autogalaxy/ellipse/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ class AnalysisEllipse(af.Analysis):
Result = ResultEllipse
Visualizer = VisualizerEllipse

def __init__(self, dataset: aa.Imaging, title_prefix: str = None):
def __init__(
self, dataset: aa.Imaging, title_prefix: str = None, use_jax: bool = False
):
"""
Fits a model made of ellipses to an imaging dataset via a non-linear search.

Expand All @@ -43,7 +45,9 @@ def __init__(self, dataset: aa.Imaging, title_prefix: str = None):
self.dataset = dataset
self.title_prefix = title_prefix

def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
super().__init__(use_jax=use_jax)

def log_likelihood_function(self, instance: af.ModelInstance) -> float:
"""
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model
instance to the imaging dataset.
Expand Down
9 changes: 3 additions & 6 deletions autogalaxy/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,21 +264,18 @@ def make_analysis_imaging_7x7():
analysis = ag.AnalysisImaging(
dataset=make_masked_imaging_7x7(),
settings_inversion=aa.SettingsInversion(use_w_tilde=False),
use_jax=False,
)
analysis._adapt_images = make_adapt_images_7x7()
return analysis


def make_analysis_interferometer_7():
analysis = ag.AnalysisInterferometer(
dataset=make_interferometer_7(),
)
analysis = ag.AnalysisInterferometer(dataset=make_interferometer_7(), use_jax=False)
analysis._adapt_images = make_adapt_images_7x7()
return analysis


def make_analysis_ellipse_7x7():
analysis = ag.AnalysisEllipse(
dataset=make_masked_imaging_7x7(),
)
analysis = ag.AnalysisEllipse(dataset=make_masked_imaging_7x7(), use_jax=False)
return analysis
2 changes: 1 addition & 1 deletion autogalaxy/galaxy/galaxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def potential_2d_from(
if self.has(cls=MassProfile):
return sum(
map(
lambda p: p.potential_2d_from(grid=grid),
lambda p: p.potential_2d_from(grid=grid, xp=xp),
self.cls_list_from(cls=MassProfile),
)
)
Expand Down
10 changes: 6 additions & 4 deletions autogalaxy/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
settings_inversion: aa.SettingsInversion = None,
preloads: aa.Preloads = None,
title_prefix: str = None,
use_jax: bool = True,
):
"""
Fits a galaxy model to an imaging dataset via a non-linear search.
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
settings_inversion=settings_inversion,
preloads=preloads,
title_prefix=title_prefix,
use_jax=use_jax,
)

@property
Expand Down Expand Up @@ -91,7 +93,7 @@ def modify_before_fit(self, paths: af.DirectoryPaths, model: af.Collection):

return self

def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
def log_likelihood_function(self, instance: af.ModelInstance) -> float:
"""
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model
instance to the imaging dataset.
Expand Down Expand Up @@ -128,9 +130,9 @@ def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
float
The log likelihood indicating how well this model instance fitted the imaging data.
"""
return self.fit_from(instance=instance, xp=xp).figure_of_merit
return self.fit_from(instance=instance).figure_of_merit

def fit_from(self, instance: af.ModelInstance, xp=np) -> FitImaging:
def fit_from(self, instance: af.ModelInstance) -> FitImaging:
"""
Given a model instance create a `FitImaging` object.

Expand Down Expand Up @@ -165,7 +167,7 @@ def fit_from(self, instance: af.ModelInstance, xp=np) -> FitImaging:
dataset_model=dataset_model,
adapt_images=adapt_images,
settings_inversion=self.settings_inversion,
xp=xp,
xp=self._xp,
)

def save_attributes(self, paths: af.DirectoryPaths):
Expand Down
10 changes: 6 additions & 4 deletions autogalaxy/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
settings_inversion: aa.SettingsInversion = None,
preloads: aa.Preloads = None,
title_prefix: str = None,
use_jax: bool = True,
):
"""
Fits a galaxy model to an interferometer dataset via a non-linear search.
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
settings_inversion=settings_inversion,
preloads=preloads,
title_prefix=title_prefix,
use_jax=use_jax,
)

@property
Expand Down Expand Up @@ -98,7 +100,7 @@ def modify_before_fit(self, paths: af.DirectoryPaths, model: af.Collection):

return self

def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
def log_likelihood_function(self, instance: af.ModelInstance) -> float:
"""
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model
instance to the interferometer dataset.
Expand Down Expand Up @@ -134,9 +136,9 @@ def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
float
The log likelihood indicating how well this model instance fitted the interferometer data.
"""
return self.fit_from(instance=instance, xp=xp).figure_of_merit
return self.fit_from(instance=instance).figure_of_merit

def fit_from(self, instance: af.ModelInstance, xp=np) -> FitInterferometer:
def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
"""
Given a model instance create a `FitInterferometer` object.

Expand Down Expand Up @@ -167,7 +169,7 @@ def fit_from(self, instance: af.ModelInstance, xp=np) -> FitInterferometer:
galaxies=galaxies,
adapt_images=adapt_images,
settings_inversion=self.settings_inversion,
xp=xp,
xp=self._xp,
)

def save_attributes(self, paths: af.DirectoryPaths):
Expand Down
38 changes: 19 additions & 19 deletions autogalaxy/operate/deflections.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def magnification_2d_from(self, grid) -> aa.Array2D:

return aa.Array2D(values=1 / det_jacobian, mask=grid.mask)

def hessian_from(self, grid, buffer: float = 0.01, deflections_func=None) -> Tuple:
def hessian_from(self, grid, buffer: float = 0.01, deflections_func=None, xp=np) -> Tuple:
"""
Returns the Hessian of the lensing object, where the Hessian is the second partial derivatives of the
potential (see equation 55 https://inspirehep.net/literature/419263):
Expand Down Expand Up @@ -270,26 +270,26 @@ def hessian_from(self, grid, buffer: float = 0.01, deflections_func=None) -> Tup
if deflections_func is None:
deflections_func = self.deflections_yx_2d_from

grid_shift_y_up = aa.Grid2DIrregular(values=np.zeros(grid.shape))
grid_shift_y_up[:, 0] = grid[:, 0] + buffer
grid_shift_y_up[:, 1] = grid[:, 1]
grid_shift_y_up = aa.Grid2DIrregular(
values=xp.stack([grid[:, 0] + buffer, grid[:, 1]], axis=1)
)

grid_shift_y_down = aa.Grid2DIrregular(values=np.zeros(grid.shape))
grid_shift_y_down[:, 0] = grid[:, 0] - buffer
grid_shift_y_down[:, 1] = grid[:, 1]
grid_shift_y_down = aa.Grid2DIrregular(
values=xp.stack([grid[:, 0] - buffer, grid[:, 1]], axis=1)
)

grid_shift_x_left = aa.Grid2DIrregular(values=np.zeros(grid.shape))
grid_shift_x_left[:, 0] = grid[:, 0]
grid_shift_x_left[:, 1] = grid[:, 1] - buffer
grid_shift_x_left = aa.Grid2DIrregular(
values=xp.stack([grid[:, 0], grid[:, 1] - buffer], axis=1)
)

grid_shift_x_right = aa.Grid2DIrregular(values=np.zeros(grid.shape))
grid_shift_x_right[:, 0] = grid[:, 0]
grid_shift_x_right[:, 1] = grid[:, 1] + buffer
grid_shift_x_right = aa.Grid2DIrregular(
values=xp.stack([grid[:, 0], grid[:, 1] + buffer], axis=1)
)

deflections_up = deflections_func(grid=grid_shift_y_up)
deflections_down = deflections_func(grid=grid_shift_y_down)
deflections_left = deflections_func(grid=grid_shift_x_left)
deflections_right = deflections_func(grid=grid_shift_x_right)
deflections_up = deflections_func(grid=grid_shift_y_up, xp=xp)
deflections_down = deflections_func(grid=grid_shift_y_down, xp=xp)
deflections_left = deflections_func(grid=grid_shift_x_left, xp=xp)
deflections_right = deflections_func(grid=grid_shift_x_right, xp=xp)

hessian_yy = 0.5 * (deflections_up[:, 0] - deflections_down[:, 0]) / buffer
hessian_xy = 0.5 * (deflections_up[:, 1] - deflections_down[:, 1]) / buffer
Expand Down Expand Up @@ -373,7 +373,7 @@ def shear_yx_2d_via_hessian_from(
return ShearYX2DIrregular(values=shear_yx_2d, grid=grid)

def magnification_2d_via_hessian_from(
self, grid, buffer: float = 0.01, deflections_func=None
self, grid, buffer: float = 0.01, deflections_func=None, xp=np
) -> aa.ArrayIrregular:
"""
Returns the 2D magnification map of lensing object, which is computed from the 2D deflection angle map
Expand All @@ -395,7 +395,7 @@ def magnification_2d_via_hessian_from(
The 2D grid of (y,x) arc-second coordinates the deflection angles and magnification map are computed on.
"""
hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from(
grid=grid, buffer=buffer, deflections_func=deflections_func
grid=grid, buffer=buffer, deflections_func=deflections_func, xp=xp
)

det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx
Expand Down
6 changes: 3 additions & 3 deletions autogalaxy/operate/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def visibilities_from(

if self.has(cls=LightProfile) or isinstance(self, LightProfile):

image_2d = self.image_2d_from(grid=grid)
return transformer.visibilities_from(image=image_2d)
image_2d = self.image_2d_from(grid=grid, xp=xp)
return transformer.visibilities_from(image=image_2d, xp=xp)

return aa.Visibilities.zeros(shape_slim=(transformer.uv_wavelengths.shape[0],))

Expand Down Expand Up @@ -345,7 +345,7 @@ def visibilities_list_from(
shape_slim=(transformer.uv_wavelengths.shape[0],)
)
else:
visibilities = transformer.visibilities_from(image=image_2d)
visibilities = transformer.visibilities_from(image=image_2d, xp=xp)

visibilities_list.append(visibilities)

Expand Down
5 changes: 3 additions & 2 deletions autogalaxy/profiles/mass/total/power_law.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def __init__(

@aa.grid_dec.to_array
def potential_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs):
alpha = self.deflections_yx_2d_from(aa.Grid2DIrregular(grid), **kwargs)

alpha = self.deflections_yx_2d_from(grid=aa.Grid2DIrregular(grid), xp=xp, **kwargs)

alpha_x = alpha[:, 1]
alpha_y = alpha[:, 0]
Expand Down Expand Up @@ -87,7 +88,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs):
+ grid.array[:, 0] ** 2
+ 1e-16
)
zh = omega(z, slope, factor, n_terms=20, xp=np)
zh = omega(z, slope, factor, n_terms=20, xp=xp)

complex_angle = (
2.0 * b / (1.0 + self.axis_ratio(xp)) * (b / R) ** (slope - 1.0) * zh
Expand Down
3 changes: 1 addition & 2 deletions autogalaxy/profiles/mass/total/power_law_multipole.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax.numpy as jnp
import numpy as np
from typing import Tuple

Expand Down Expand Up @@ -249,7 +248,7 @@ def convergence_2d_from(
/ 2.0
* (self.einstein_radius / r) ** (self.slope - 1)
* self.k_m
* jnp.cos(self.m * (angle - self.angle_m))
* xp.cos(self.m * (angle - self.angle_m))
)

@aa.grid_dec.to_array
Expand Down
5 changes: 3 additions & 2 deletions autogalaxy/quantity/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(
func_str: str,
cosmology: LensingCosmology = None,
title_prefix: str = None,
use_jax: bool = True,
):
"""
Fits a galaxy model to a quantity dataset via a non-linear search.
Expand Down Expand Up @@ -56,13 +57,13 @@ def __init__(
A string that is added before the title of all figures output by visualization, for example to
put the name of the dataset and galaxy in the title.
"""
super().__init__(cosmology=cosmology)
super().__init__(cosmology=cosmology, use_jax=use_jax)

self.dataset = dataset
self.func_str = func_str
self.title_prefix = title_prefix

def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
def log_likelihood_function(self, instance: af.ModelInstance) -> float:
"""
Given an instance of the model, where the model parameters are set via a non-linear search, fit the model
instance to the quantity's dataset.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test__ellipses_randomly_drawn_via_pdf_gen_from(
):
clean(database_file=database_file)

analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7)
analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False)

agg = aggregator_from(
database_file=database_file,
Expand Down Expand Up @@ -48,7 +48,7 @@ def test__ellipses_all_above_weight_gen(
):
clean(database_file=database_file)

analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7)
analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False)

agg = aggregator_from(
database_file=database_file,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test__multipoles_randomly_drawn_via_pdf_gen_from(
):
clean(database_file=database_file)

analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7)
analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False)

agg = aggregator_from(
database_file=database_file,
Expand Down Expand Up @@ -49,7 +49,7 @@ def test__multipoles_all_above_weight_gen(
):
clean(database_file=database_file)

analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7)
analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False)

agg = aggregator_from(
database_file=database_file,
Expand Down
Loading
Loading