From c7e3666778b0b1af7de2665ca4af169f29070278 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 13 Nov 2025 10:25:47 +0000 Subject: [PATCH 1/3] imports tested --- autogalaxy/__init__.py | 1 + autogalaxy/analysis/analysis/analysis.py | 6 ++++-- autogalaxy/analysis/analysis/dataset.py | 3 +++ autogalaxy/config/general.yaml | 2 -- autogalaxy/convert.py | 1 - autogalaxy/ellipse/model/analysis.py | 6 ++++-- autogalaxy/fixtures.py | 5 +++-- autogalaxy/imaging/model/analysis.py | 10 ++++++---- autogalaxy/interferometer/model/analysis.py | 10 ++++++---- autogalaxy/profiles/mass/total/power_law_multipole.py | 3 +-- autogalaxy/quantity/model/analysis.py | 5 +++-- .../aggregator/ellipse/test_aggregator_ellipses.py | 4 ++-- .../aggregator/ellipse/test_aggregator_multipoles.py | 4 ++-- .../aggregator/imaging/test_aggregator_imaging.py | 2 +- .../test_aggregator_fit_interferometer.py | 6 +++--- .../interferometer/test_aggregator_interferometer.py | 2 +- test_autogalaxy/aggregator/test_aggregator_galaxies.py | 4 ++-- test_autogalaxy/analysis/analysis/test_analysis.py | 2 +- .../analysis/analysis/test_analysis_dataset.py | 2 +- test_autogalaxy/ellipse/model/test_analysis_ellipse.py | 2 +- test_autogalaxy/imaging/model/test_analysis_imaging.py | 4 ++-- .../model/test_analysis_interferometer.py | 4 ++-- .../quantity/model/test_analysis_quantity.py | 4 ++-- 23 files changed, 51 insertions(+), 41 deletions(-) diff --git a/autogalaxy/__init__.py b/autogalaxy/__init__.py index 467ad79b3..f70a2a67f 100644 --- a/autogalaxy/__init__.py +++ b/autogalaxy/__init__.py @@ -1,3 +1,4 @@ +from autoconf import jax_wrapper from autoconf.dictable import register_parser from autofit import conf diff --git a/autogalaxy/analysis/analysis/analysis.py b/autogalaxy/analysis/analysis/analysis.py index d927ba195..f97b965bc 100644 --- a/autogalaxy/analysis/analysis/analysis.py +++ b/autogalaxy/analysis/analysis/analysis.py @@ -1,4 +1,5 @@ import logging +import numpy as np from typing import List, Optional import autofit as af @@ -15,7 +16,7 @@ class Analysis(af.Analysis): def __init__( - self, cosmology: LensingCosmology = None, preloads: aa.Preloads = None, **kwargs + self, cosmology: LensingCosmology = None, preloads: aa.Preloads = None, use_jax : bool = True, **kwargs ): """ Fits a model to a dataset via a non-linear search. @@ -35,7 +36,8 @@ def __init__( self.cosmology = cosmology or Planck15() self.preloads = preloads - self.kwargs = kwargs + + super().__init__(use_jax=use_jax, **kwargs) def galaxies_via_instance_from( self, diff --git a/autogalaxy/analysis/analysis/dataset.py b/autogalaxy/analysis/analysis/dataset.py index f9e73c3bf..b71ea0948 100644 --- a/autogalaxy/analysis/analysis/dataset.py +++ b/autogalaxy/analysis/analysis/dataset.py @@ -1,4 +1,5 @@ import logging +import numpy as np from typing import Optional, Union from autoconf.dictable import to_dict, output_to_json @@ -26,6 +27,7 @@ def __init__( settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, title_prefix: str = None, + use_jax : bool = True, **kwargs, ): """ @@ -54,6 +56,7 @@ def __init__( super().__init__( cosmology=cosmology, preloads=preloads, + use_jax=use_jax, **kwargs, ) diff --git a/autogalaxy/config/general.yaml b/autogalaxy/config/general.yaml index ed8da8d05..968da2d93 100644 --- a/autogalaxy/config/general.yaml +++ b/autogalaxy/config/general.yaml @@ -1,5 +1,3 @@ -jax: - use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy. fits: flip_for_ds9: true psf: diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index 5bda71120..0696c9129 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -68,7 +68,6 @@ def axis_ratio_and_angle_from( fac = xp.sqrt(ell_comps[1] ** 2 + ell_comps[0] ** 2) if xp.__name__.startswith("jax"): import jax - fac = jax.lax.min(fac, 0.999) else: # NumPy fac = np.minimum(fac, 0.999) diff --git a/autogalaxy/ellipse/model/analysis.py b/autogalaxy/ellipse/model/analysis.py index 5635d8c57..8a402477c 100644 --- a/autogalaxy/ellipse/model/analysis.py +++ b/autogalaxy/ellipse/model/analysis.py @@ -20,7 +20,7 @@ class AnalysisEllipse(af.Analysis): Result = ResultEllipse Visualizer = VisualizerEllipse - def __init__(self, dataset: aa.Imaging, title_prefix: str = None): + def __init__(self, dataset: aa.Imaging, title_prefix: str = None, use_jax : bool = False): """ Fits a model made of ellipses to an imaging dataset via a non-linear search. @@ -43,7 +43,9 @@ def __init__(self, dataset: aa.Imaging, title_prefix: str = None): self.dataset = dataset self.title_prefix = title_prefix - def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: + super().__init__(use_jax=use_jax) + + def log_likelihood_function(self, instance: af.ModelInstance) -> float: """ Given an instance of the model, where the model parameters are set via a non-linear search, fit the model instance to the imaging dataset. diff --git a/autogalaxy/fixtures.py b/autogalaxy/fixtures.py index 970631c4a..ee1b54d53 100644 --- a/autogalaxy/fixtures.py +++ b/autogalaxy/fixtures.py @@ -264,6 +264,7 @@ def make_analysis_imaging_7x7(): analysis = ag.AnalysisImaging( dataset=make_masked_imaging_7x7(), settings_inversion=aa.SettingsInversion(use_w_tilde=False), + use_jax=False ) analysis._adapt_images = make_adapt_images_7x7() return analysis @@ -271,7 +272,7 @@ def make_analysis_imaging_7x7(): def make_analysis_interferometer_7(): analysis = ag.AnalysisInterferometer( - dataset=make_interferometer_7(), + dataset=make_interferometer_7(), use_jax=False ) analysis._adapt_images = make_adapt_images_7x7() return analysis @@ -279,6 +280,6 @@ def make_analysis_interferometer_7(): def make_analysis_ellipse_7x7(): analysis = ag.AnalysisEllipse( - dataset=make_masked_imaging_7x7(), + dataset=make_masked_imaging_7x7(), use_jax=False ) return analysis diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index 1620e4a76..67e4d07ab 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -24,6 +24,7 @@ def __init__( settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, title_prefix: str = None, + use_jax : bool = True, ): """ Fits a galaxy model to an imaging dataset via a non-linear search. @@ -62,6 +63,7 @@ def __init__( settings_inversion=settings_inversion, preloads=preloads, title_prefix=title_prefix, + use_jax=use_jax, ) @property @@ -91,7 +93,7 @@ def modify_before_fit(self, paths: af.DirectoryPaths, model: af.Collection): return self - def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: + def log_likelihood_function(self, instance: af.ModelInstance) -> float: """ Given an instance of the model, where the model parameters are set via a non-linear search, fit the model instance to the imaging dataset. @@ -128,9 +130,9 @@ def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: float The log likelihood indicating how well this model instance fitted the imaging data. """ - return self.fit_from(instance=instance, xp=xp).figure_of_merit + return self.fit_from(instance=instance).figure_of_merit - def fit_from(self, instance: af.ModelInstance, xp=np) -> FitImaging: + def fit_from(self, instance: af.ModelInstance) -> FitImaging: """ Given a model instance create a `FitImaging` object. @@ -165,7 +167,7 @@ def fit_from(self, instance: af.ModelInstance, xp=np) -> FitImaging: dataset_model=dataset_model, adapt_images=adapt_images, settings_inversion=self.settings_inversion, - xp=xp, + xp=self._xp, ) def save_attributes(self, paths: af.DirectoryPaths): diff --git a/autogalaxy/interferometer/model/analysis.py b/autogalaxy/interferometer/model/analysis.py index 3961ff99b..fac192015 100644 --- a/autogalaxy/interferometer/model/analysis.py +++ b/autogalaxy/interferometer/model/analysis.py @@ -31,6 +31,7 @@ def __init__( settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, title_prefix: str = None, + use_jax : bool = True, ): """ Fits a galaxy model to an interferometer dataset via a non-linear search. @@ -69,6 +70,7 @@ def __init__( settings_inversion=settings_inversion, preloads=preloads, title_prefix=title_prefix, + use_jax=use_jax ) @property @@ -98,7 +100,7 @@ def modify_before_fit(self, paths: af.DirectoryPaths, model: af.Collection): return self - def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: + def log_likelihood_function(self, instance: af.ModelInstance) -> float: """ Given an instance of the model, where the model parameters are set via a non-linear search, fit the model instance to the interferometer dataset. @@ -134,9 +136,9 @@ def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: float The log likelihood indicating how well this model instance fitted the interferometer data. """ - return self.fit_from(instance=instance, xp=xp).figure_of_merit + return self.fit_from(instance=instance).figure_of_merit - def fit_from(self, instance: af.ModelInstance, xp=np) -> FitInterferometer: + def fit_from(self, instance: af.ModelInstance) -> FitInterferometer: """ Given a model instance create a `FitInterferometer` object. @@ -167,7 +169,7 @@ def fit_from(self, instance: af.ModelInstance, xp=np) -> FitInterferometer: galaxies=galaxies, adapt_images=adapt_images, settings_inversion=self.settings_inversion, - xp=xp, + xp=self._xp, ) def save_attributes(self, paths: af.DirectoryPaths): diff --git a/autogalaxy/profiles/mass/total/power_law_multipole.py b/autogalaxy/profiles/mass/total/power_law_multipole.py index 52ab98c60..37c8a01bd 100644 --- a/autogalaxy/profiles/mass/total/power_law_multipole.py +++ b/autogalaxy/profiles/mass/total/power_law_multipole.py @@ -1,4 +1,3 @@ -import jax.numpy as jnp import numpy as np from typing import Tuple @@ -249,7 +248,7 @@ def convergence_2d_from( / 2.0 * (self.einstein_radius / r) ** (self.slope - 1) * self.k_m - * jnp.cos(self.m * (angle - self.angle_m)) + * xp.cos(self.m * (angle - self.angle_m)) ) @aa.grid_dec.to_array diff --git a/autogalaxy/quantity/model/analysis.py b/autogalaxy/quantity/model/analysis.py index 3e5710692..a78a256b2 100644 --- a/autogalaxy/quantity/model/analysis.py +++ b/autogalaxy/quantity/model/analysis.py @@ -22,6 +22,7 @@ def __init__( func_str: str, cosmology: LensingCosmology = None, title_prefix: str = None, + use_jax : bool = True, ): """ Fits a galaxy model to a quantity dataset via a non-linear search. @@ -56,13 +57,13 @@ def __init__( A string that is added before the title of all figures output by visualization, for example to put the name of the dataset and galaxy in the title. """ - super().__init__(cosmology=cosmology) + super().__init__(cosmology=cosmology, use_jax=use_jax) self.dataset = dataset self.func_str = func_str self.title_prefix = title_prefix - def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float: + def log_likelihood_function(self, instance: af.ModelInstance) -> float: """ Given an instance of the model, where the model parameters are set via a non-linear search, fit the model instance to the quantity's dataset. diff --git a/test_autogalaxy/aggregator/ellipse/test_aggregator_ellipses.py b/test_autogalaxy/aggregator/ellipse/test_aggregator_ellipses.py index 51454b6c6..5a5e995a1 100644 --- a/test_autogalaxy/aggregator/ellipse/test_aggregator_ellipses.py +++ b/test_autogalaxy/aggregator/ellipse/test_aggregator_ellipses.py @@ -12,7 +12,7 @@ def test__ellipses_randomly_drawn_via_pdf_gen_from( ): clean(database_file=database_file) - analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7) + analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False) agg = aggregator_from( database_file=database_file, @@ -48,7 +48,7 @@ def test__ellipses_all_above_weight_gen( ): clean(database_file=database_file) - analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7) + analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False) agg = aggregator_from( database_file=database_file, diff --git a/test_autogalaxy/aggregator/ellipse/test_aggregator_multipoles.py b/test_autogalaxy/aggregator/ellipse/test_aggregator_multipoles.py index 704fdaeb9..f4a75e561 100644 --- a/test_autogalaxy/aggregator/ellipse/test_aggregator_multipoles.py +++ b/test_autogalaxy/aggregator/ellipse/test_aggregator_multipoles.py @@ -12,7 +12,7 @@ def test__multipoles_randomly_drawn_via_pdf_gen_from( ): clean(database_file=database_file) - analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7) + analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False) agg = aggregator_from( database_file=database_file, @@ -49,7 +49,7 @@ def test__multipoles_all_above_weight_gen( ): clean(database_file=database_file) - analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7) + analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False) agg = aggregator_from( database_file=database_file, diff --git a/test_autogalaxy/aggregator/imaging/test_aggregator_imaging.py b/test_autogalaxy/aggregator/imaging/test_aggregator_imaging.py index 371e38dd6..14e225292 100644 --- a/test_autogalaxy/aggregator/imaging/test_aggregator_imaging.py +++ b/test_autogalaxy/aggregator/imaging/test_aggregator_imaging.py @@ -18,7 +18,7 @@ def test__dataset_generator_from_aggregator__analysis_has_single_dataset( masked_imaging_7x7 = imaging.apply_mask(mask=mask_2d_7x7) - analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7) + analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7, use_jax=False) agg = aggregator_from( database_file=database_file, diff --git a/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py b/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py index b737568fe..02f2e5603 100644 --- a/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py +++ b/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py @@ -11,7 +11,7 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from( model, ): analysis = ag.AnalysisInterferometer( - dataset=interferometer_7, + dataset=interferometer_7, use_jax=False ) agg = aggregator_from( @@ -78,7 +78,7 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from( def test__fit_interferometer_all_above_weight_gen(interferometer_7, samples, model): clean(database_file=database_file) - analysis = ag.AnalysisInterferometer(dataset=interferometer_7) + analysis = ag.AnalysisInterferometer(dataset=interferometer_7, use_jax=False) agg = aggregator_from( database_file=database_file, @@ -116,7 +116,7 @@ def test__fit_interferometer__adapt_images( adapt_images_7x7, ): analysis = ag.AnalysisInterferometer( - dataset=interferometer_7, + dataset=interferometer_7, use_jax=False ) analysis._adapt_images = adapt_images_7x7 diff --git a/test_autogalaxy/aggregator/interferometer/test_aggregator_interferometer.py b/test_autogalaxy/aggregator/interferometer/test_aggregator_interferometer.py index aa8e5c5eb..debb99cf5 100644 --- a/test_autogalaxy/aggregator/interferometer/test_aggregator_interferometer.py +++ b/test_autogalaxy/aggregator/interferometer/test_aggregator_interferometer.py @@ -21,7 +21,7 @@ def test__interferometer_generator_from_aggregator__analysis_has_single_dataset( transformer_class=ag.TransformerDFT, ) - analysis = ag.AnalysisInterferometer(dataset=interferometer_7) + analysis = ag.AnalysisInterferometer(dataset=interferometer_7, use_jax=False) agg = aggregator_from( database_file=database_file, diff --git a/test_autogalaxy/aggregator/test_aggregator_galaxies.py b/test_autogalaxy/aggregator/test_aggregator_galaxies.py index e0a7fd22d..61a459c5b 100644 --- a/test_autogalaxy/aggregator/test_aggregator_galaxies.py +++ b/test_autogalaxy/aggregator/test_aggregator_galaxies.py @@ -12,7 +12,7 @@ def test__galaxies_randomly_drawn_via_pdf_gen_from( ): clean(database_file=database_file) - analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7) + analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7, use_jax=False) agg = aggregator_from( database_file=database_file, @@ -46,7 +46,7 @@ def test__galaxies_all_above_weight_gen( ): clean(database_file=database_file) - analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7) + analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7, use_jax=False) agg = aggregator_from( database_file=database_file, diff --git a/test_autogalaxy/analysis/analysis/test_analysis.py b/test_autogalaxy/analysis/analysis/test_analysis.py index 387e50c8f..813cecd4e 100644 --- a/test_autogalaxy/analysis/analysis/test_analysis.py +++ b/test_autogalaxy/analysis/analysis/test_analysis.py @@ -18,7 +18,7 @@ def test__galaxies_via_instance(masked_imaging_7x7): extra_galaxies=af.Collection(extra_galaxy_0=extra_galaxy), ) - analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7) + analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7, use_jax=False) instance = model.instance_from_unit_vector([]) diff --git a/test_autogalaxy/analysis/analysis/test_analysis_dataset.py b/test_autogalaxy/analysis/analysis/test_analysis_dataset.py index 00d496a62..5bad3cbf9 100644 --- a/test_autogalaxy/analysis/analysis/test_analysis_dataset.py +++ b/test_autogalaxy/analysis/analysis/test_analysis_dataset.py @@ -29,7 +29,7 @@ def test__instance_with_associated_adapt_images_from(masked_imaging_7x7): galaxy_name_image_dict=adapt_galaxy_name_image_dict, ) - analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7) + analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7, use_jax=False) analysis._adapt_images = adapt_images adapt_images = analysis.adapt_images_via_instance_from(instance=instance) diff --git a/test_autogalaxy/ellipse/model/test_analysis_ellipse.py b/test_autogalaxy/ellipse/model/test_analysis_ellipse.py index 29cb4aead..edda075bf 100644 --- a/test_autogalaxy/ellipse/model/test_analysis_ellipse.py +++ b/test_autogalaxy/ellipse/model/test_analysis_ellipse.py @@ -56,7 +56,7 @@ def test__figure_of_merit( model = af.Collection(ellipses=ellipse_list, multipoles=multipole_list) - analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7) + analysis = ag.AnalysisEllipse(dataset=masked_imaging_7x7, use_jax=False) instance = model.instance_from_prior_medians() fit_figure_of_merit = analysis.log_likelihood_function(instance=instance) diff --git a/test_autogalaxy/imaging/model/test_analysis_imaging.py b/test_autogalaxy/imaging/model/test_analysis_imaging.py index 7d3173ec9..d65b64a37 100644 --- a/test_autogalaxy/imaging/model/test_analysis_imaging.py +++ b/test_autogalaxy/imaging/model/test_analysis_imaging.py @@ -11,7 +11,7 @@ def test__make_result__result_imaging_is_returned(masked_imaging_7x7): model = af.Collection(galaxies=af.Collection(galaxy_0=ag.Galaxy(redshift=0.5))) - analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7) + analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7, use_jax=False) search = ag.m.MockSearch(name="test_search") @@ -27,7 +27,7 @@ def test__figure_of_merit__matches_correct_fit_given_galaxy_profiles( model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) - analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7) + analysis = ag.AnalysisImaging(dataset=masked_imaging_7x7, use_jax=False) instance = model.instance_from_unit_vector([]) fit_figure_of_merit = analysis.log_likelihood_function(instance=instance) diff --git a/test_autogalaxy/interferometer/model/test_analysis_interferometer.py b/test_autogalaxy/interferometer/model/test_analysis_interferometer.py index 362d1e799..16673f50b 100644 --- a/test_autogalaxy/interferometer/model/test_analysis_interferometer.py +++ b/test_autogalaxy/interferometer/model/test_analysis_interferometer.py @@ -12,7 +12,7 @@ def test__make_result__result_interferometer_is_returned(interferometer_7): model = af.Collection(galaxies=af.Collection(galaxy_0=ag.Galaxy(redshift=0.5))) - analysis = ag.AnalysisInterferometer(dataset=interferometer_7) + analysis = ag.AnalysisInterferometer(dataset=interferometer_7, use_jax=False) search = ag.m.MockSearch(name="test_search") @@ -28,7 +28,7 @@ def test__fit_figure_of_merit__matches_correct_fit_given_galaxy_profiles( model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) - analysis = ag.AnalysisInterferometer(dataset=interferometer_7) + analysis = ag.AnalysisInterferometer(dataset=interferometer_7, use_jax=False) instance = model.instance_from_unit_vector([]) fit_figure_of_merit = analysis.log_likelihood_function(instance=instance) diff --git a/test_autogalaxy/quantity/model/test_analysis_quantity.py b/test_autogalaxy/quantity/model/test_analysis_quantity.py index da304b96f..da51e8a01 100644 --- a/test_autogalaxy/quantity/model/test_analysis_quantity.py +++ b/test_autogalaxy/quantity/model/test_analysis_quantity.py @@ -15,7 +15,7 @@ def test__make_result__result_quantity_is_returned( model = af.Collection(galaxies=af.Collection(galaxy_0=ag.Galaxy(redshift=0.5))) analysis = ag.AnalysisQuantity( - dataset=dataset_quantity_7x7_array_2d, func_str="convergence_2d_from" + dataset=dataset_quantity_7x7_array_2d, func_str="convergence_2d_from", use_jax=False ) search = ag.m.MockSearch(name="test_search") @@ -32,7 +32,7 @@ def test__figure_of_merit__matches_correct_fit_given_galaxy_profiles( model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) analysis = ag.AnalysisQuantity( - dataset=dataset_quantity_7x7_array_2d, func_str="convergence_2d_from" + dataset=dataset_quantity_7x7_array_2d, func_str="convergence_2d_from", use_jax=False ) instance = model.instance_from_unit_vector([]) From 9d7507505b3a4cfef745393c0eccc88bf9e636c0 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 13 Nov 2025 10:27:08 +0000 Subject: [PATCH 2/3] black --- autogalaxy/analysis/analysis/analysis.py | 6 +++++- autogalaxy/analysis/analysis/dataset.py | 2 +- autogalaxy/convert.py | 1 + autogalaxy/ellipse/model/analysis.py | 4 +++- autogalaxy/fixtures.py | 10 +++------- autogalaxy/imaging/model/analysis.py | 2 +- autogalaxy/interferometer/model/analysis.py | 4 ++-- autogalaxy/quantity/model/analysis.py | 2 +- .../test_aggregator_fit_interferometer.py | 8 ++------ .../quantity/model/test_analysis_quantity.py | 8 ++++++-- 10 files changed, 25 insertions(+), 22 deletions(-) diff --git a/autogalaxy/analysis/analysis/analysis.py b/autogalaxy/analysis/analysis/analysis.py index f97b965bc..438ac013c 100644 --- a/autogalaxy/analysis/analysis/analysis.py +++ b/autogalaxy/analysis/analysis/analysis.py @@ -16,7 +16,11 @@ class Analysis(af.Analysis): def __init__( - self, cosmology: LensingCosmology = None, preloads: aa.Preloads = None, use_jax : bool = True, **kwargs + self, + cosmology: LensingCosmology = None, + preloads: aa.Preloads = None, + use_jax: bool = True, + **kwargs, ): """ Fits a model to a dataset via a non-linear search. diff --git a/autogalaxy/analysis/analysis/dataset.py b/autogalaxy/analysis/analysis/dataset.py index b71ea0948..7ab645374 100644 --- a/autogalaxy/analysis/analysis/dataset.py +++ b/autogalaxy/analysis/analysis/dataset.py @@ -27,7 +27,7 @@ def __init__( settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, title_prefix: str = None, - use_jax : bool = True, + use_jax: bool = True, **kwargs, ): """ diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index 0696c9129..5bda71120 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -68,6 +68,7 @@ def axis_ratio_and_angle_from( fac = xp.sqrt(ell_comps[1] ** 2 + ell_comps[0] ** 2) if xp.__name__.startswith("jax"): import jax + fac = jax.lax.min(fac, 0.999) else: # NumPy fac = np.minimum(fac, 0.999) diff --git a/autogalaxy/ellipse/model/analysis.py b/autogalaxy/ellipse/model/analysis.py index 8a402477c..8451f12bb 100644 --- a/autogalaxy/ellipse/model/analysis.py +++ b/autogalaxy/ellipse/model/analysis.py @@ -20,7 +20,9 @@ class AnalysisEllipse(af.Analysis): Result = ResultEllipse Visualizer = VisualizerEllipse - def __init__(self, dataset: aa.Imaging, title_prefix: str = None, use_jax : bool = False): + def __init__( + self, dataset: aa.Imaging, title_prefix: str = None, use_jax: bool = False + ): """ Fits a model made of ellipses to an imaging dataset via a non-linear search. diff --git a/autogalaxy/fixtures.py b/autogalaxy/fixtures.py index ee1b54d53..32e492fc0 100644 --- a/autogalaxy/fixtures.py +++ b/autogalaxy/fixtures.py @@ -264,22 +264,18 @@ def make_analysis_imaging_7x7(): analysis = ag.AnalysisImaging( dataset=make_masked_imaging_7x7(), settings_inversion=aa.SettingsInversion(use_w_tilde=False), - use_jax=False + use_jax=False, ) analysis._adapt_images = make_adapt_images_7x7() return analysis def make_analysis_interferometer_7(): - analysis = ag.AnalysisInterferometer( - dataset=make_interferometer_7(), use_jax=False - ) + analysis = ag.AnalysisInterferometer(dataset=make_interferometer_7(), use_jax=False) analysis._adapt_images = make_adapt_images_7x7() return analysis def make_analysis_ellipse_7x7(): - analysis = ag.AnalysisEllipse( - dataset=make_masked_imaging_7x7(), use_jax=False - ) + analysis = ag.AnalysisEllipse(dataset=make_masked_imaging_7x7(), use_jax=False) return analysis diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index 67e4d07ab..75d573123 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -24,7 +24,7 @@ def __init__( settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, title_prefix: str = None, - use_jax : bool = True, + use_jax: bool = True, ): """ Fits a galaxy model to an imaging dataset via a non-linear search. diff --git a/autogalaxy/interferometer/model/analysis.py b/autogalaxy/interferometer/model/analysis.py index fac192015..e07aaf183 100644 --- a/autogalaxy/interferometer/model/analysis.py +++ b/autogalaxy/interferometer/model/analysis.py @@ -31,7 +31,7 @@ def __init__( settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, title_prefix: str = None, - use_jax : bool = True, + use_jax: bool = True, ): """ Fits a galaxy model to an interferometer dataset via a non-linear search. @@ -70,7 +70,7 @@ def __init__( settings_inversion=settings_inversion, preloads=preloads, title_prefix=title_prefix, - use_jax=use_jax + use_jax=use_jax, ) @property diff --git a/autogalaxy/quantity/model/analysis.py b/autogalaxy/quantity/model/analysis.py index a78a256b2..3537710f7 100644 --- a/autogalaxy/quantity/model/analysis.py +++ b/autogalaxy/quantity/model/analysis.py @@ -22,7 +22,7 @@ def __init__( func_str: str, cosmology: LensingCosmology = None, title_prefix: str = None, - use_jax : bool = True, + use_jax: bool = True, ): """ Fits a galaxy model to a quantity dataset via a non-linear search. diff --git a/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py b/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py index 02f2e5603..f62b59fee 100644 --- a/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py +++ b/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py @@ -10,9 +10,7 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from( samples, model, ): - analysis = ag.AnalysisInterferometer( - dataset=interferometer_7, use_jax=False - ) + analysis = ag.AnalysisInterferometer(dataset=interferometer_7, use_jax=False) agg = aggregator_from( database_file=database_file, @@ -115,9 +113,7 @@ def test__fit_interferometer__adapt_images( model, adapt_images_7x7, ): - analysis = ag.AnalysisInterferometer( - dataset=interferometer_7, use_jax=False - ) + analysis = ag.AnalysisInterferometer(dataset=interferometer_7, use_jax=False) analysis._adapt_images = adapt_images_7x7 agg = aggregator_from( diff --git a/test_autogalaxy/quantity/model/test_analysis_quantity.py b/test_autogalaxy/quantity/model/test_analysis_quantity.py index da51e8a01..c6d0f3cac 100644 --- a/test_autogalaxy/quantity/model/test_analysis_quantity.py +++ b/test_autogalaxy/quantity/model/test_analysis_quantity.py @@ -15,7 +15,9 @@ def test__make_result__result_quantity_is_returned( model = af.Collection(galaxies=af.Collection(galaxy_0=ag.Galaxy(redshift=0.5))) analysis = ag.AnalysisQuantity( - dataset=dataset_quantity_7x7_array_2d, func_str="convergence_2d_from", use_jax=False + dataset=dataset_quantity_7x7_array_2d, + func_str="convergence_2d_from", + use_jax=False, ) search = ag.m.MockSearch(name="test_search") @@ -32,7 +34,9 @@ def test__figure_of_merit__matches_correct_fit_given_galaxy_profiles( model = af.Collection(galaxies=af.Collection(galaxy=galaxy)) analysis = ag.AnalysisQuantity( - dataset=dataset_quantity_7x7_array_2d, func_str="convergence_2d_from", use_jax=False + dataset=dataset_quantity_7x7_array_2d, + func_str="convergence_2d_from", + use_jax=False, ) instance = model.instance_from_unit_vector([]) From f3786e588b466e24d15c838397285ad38bbb9e93 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 13 Nov 2025 15:22:46 +0000 Subject: [PATCH 3/3] finish --- autogalaxy/galaxy/galaxy.py | 2 +- autogalaxy/operate/deflections.py | 38 ++++++++++----------- autogalaxy/operate/image.py | 6 ++-- autogalaxy/profiles/mass/total/power_law.py | 5 +-- 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/autogalaxy/galaxy/galaxy.py b/autogalaxy/galaxy/galaxy.py index 2558b5410..7126d4988 100644 --- a/autogalaxy/galaxy/galaxy.py +++ b/autogalaxy/galaxy/galaxy.py @@ -337,7 +337,7 @@ def potential_2d_from( if self.has(cls=MassProfile): return sum( map( - lambda p: p.potential_2d_from(grid=grid), + lambda p: p.potential_2d_from(grid=grid, xp=xp), self.cls_list_from(cls=MassProfile), ) ) diff --git a/autogalaxy/operate/deflections.py b/autogalaxy/operate/deflections.py index ceb7bbb41..36fe2c7c5 100644 --- a/autogalaxy/operate/deflections.py +++ b/autogalaxy/operate/deflections.py @@ -242,7 +242,7 @@ def magnification_2d_from(self, grid) -> aa.Array2D: return aa.Array2D(values=1 / det_jacobian, mask=grid.mask) - def hessian_from(self, grid, buffer: float = 0.01, deflections_func=None) -> Tuple: + def hessian_from(self, grid, buffer: float = 0.01, deflections_func=None, xp=np) -> Tuple: """ Returns the Hessian of the lensing object, where the Hessian is the second partial derivatives of the potential (see equation 55 https://inspirehep.net/literature/419263): @@ -270,26 +270,26 @@ def hessian_from(self, grid, buffer: float = 0.01, deflections_func=None) -> Tup if deflections_func is None: deflections_func = self.deflections_yx_2d_from - grid_shift_y_up = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_y_up[:, 0] = grid[:, 0] + buffer - grid_shift_y_up[:, 1] = grid[:, 1] + grid_shift_y_up = aa.Grid2DIrregular( + values=xp.stack([grid[:, 0] + buffer, grid[:, 1]], axis=1) + ) - grid_shift_y_down = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_y_down[:, 0] = grid[:, 0] - buffer - grid_shift_y_down[:, 1] = grid[:, 1] + grid_shift_y_down = aa.Grid2DIrregular( + values=xp.stack([grid[:, 0] - buffer, grid[:, 1]], axis=1) + ) - grid_shift_x_left = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_x_left[:, 0] = grid[:, 0] - grid_shift_x_left[:, 1] = grid[:, 1] - buffer + grid_shift_x_left = aa.Grid2DIrregular( + values=xp.stack([grid[:, 0], grid[:, 1] - buffer], axis=1) + ) - grid_shift_x_right = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_x_right[:, 0] = grid[:, 0] - grid_shift_x_right[:, 1] = grid[:, 1] + buffer + grid_shift_x_right = aa.Grid2DIrregular( + values=xp.stack([grid[:, 0], grid[:, 1] + buffer], axis=1) + ) - deflections_up = deflections_func(grid=grid_shift_y_up) - deflections_down = deflections_func(grid=grid_shift_y_down) - deflections_left = deflections_func(grid=grid_shift_x_left) - deflections_right = deflections_func(grid=grid_shift_x_right) + deflections_up = deflections_func(grid=grid_shift_y_up, xp=xp) + deflections_down = deflections_func(grid=grid_shift_y_down, xp=xp) + deflections_left = deflections_func(grid=grid_shift_x_left, xp=xp) + deflections_right = deflections_func(grid=grid_shift_x_right, xp=xp) hessian_yy = 0.5 * (deflections_up[:, 0] - deflections_down[:, 0]) / buffer hessian_xy = 0.5 * (deflections_up[:, 1] - deflections_down[:, 1]) / buffer @@ -373,7 +373,7 @@ def shear_yx_2d_via_hessian_from( return ShearYX2DIrregular(values=shear_yx_2d, grid=grid) def magnification_2d_via_hessian_from( - self, grid, buffer: float = 0.01, deflections_func=None + self, grid, buffer: float = 0.01, deflections_func=None, xp=np ) -> aa.ArrayIrregular: """ Returns the 2D magnification map of lensing object, which is computed from the 2D deflection angle map @@ -395,7 +395,7 @@ def magnification_2d_via_hessian_from( The 2D grid of (y,x) arc-second coordinates the deflection angles and magnification map are computed on. """ hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from( - grid=grid, buffer=buffer, deflections_func=deflections_func + grid=grid, buffer=buffer, deflections_func=deflections_func, xp=xp ) det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx diff --git a/autogalaxy/operate/image.py b/autogalaxy/operate/image.py index 7b95e126e..6a3fe668b 100644 --- a/autogalaxy/operate/image.py +++ b/autogalaxy/operate/image.py @@ -193,8 +193,8 @@ def visibilities_from( if self.has(cls=LightProfile) or isinstance(self, LightProfile): - image_2d = self.image_2d_from(grid=grid) - return transformer.visibilities_from(image=image_2d) + image_2d = self.image_2d_from(grid=grid, xp=xp) + return transformer.visibilities_from(image=image_2d, xp=xp) return aa.Visibilities.zeros(shape_slim=(transformer.uv_wavelengths.shape[0],)) @@ -345,7 +345,7 @@ def visibilities_list_from( shape_slim=(transformer.uv_wavelengths.shape[0],) ) else: - visibilities = transformer.visibilities_from(image=image_2d) + visibilities = transformer.visibilities_from(image=image_2d, xp=xp) visibilities_list.append(visibilities) diff --git a/autogalaxy/profiles/mass/total/power_law.py b/autogalaxy/profiles/mass/total/power_law.py index 3742bd23f..6c8beb3d3 100644 --- a/autogalaxy/profiles/mass/total/power_law.py +++ b/autogalaxy/profiles/mass/total/power_law.py @@ -39,7 +39,8 @@ def __init__( @aa.grid_dec.to_array def potential_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): - alpha = self.deflections_yx_2d_from(aa.Grid2DIrregular(grid), **kwargs) + + alpha = self.deflections_yx_2d_from(grid=aa.Grid2DIrregular(grid), xp=xp, **kwargs) alpha_x = alpha[:, 1] alpha_y = alpha[:, 0] @@ -87,7 +88,7 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs): + grid.array[:, 0] ** 2 + 1e-16 ) - zh = omega(z, slope, factor, n_terms=20, xp=np) + zh = omega(z, slope, factor, n_terms=20, xp=xp) complex_angle = ( 2.0 * b / (1.0 + self.axis_ratio(xp)) * (b / R) ** (slope - 1.0) * zh