diff --git a/autolens/__init__.py b/autolens/__init__.py index cee1b5bb9..d65d30736 100644 --- a/autolens/__init__.py +++ b/autolens/__init__.py @@ -54,7 +54,9 @@ from autogalaxy import cosmology as cosmo from autogalaxy.analysis.adapt_images.adapt_images import AdaptImages -from autogalaxy.analysis.adapt_images.adapt_image_maker import AdaptImageMaker +from autogalaxy.analysis.adapt_images.adapt_images import ( + galaxy_name_image_dict_via_result_from, +) from autogalaxy.gui.clicker import Clicker from autogalaxy.gui.scribbler import Scribbler from autogalaxy.galaxy.galaxy import Galaxy diff --git a/autolens/analysis/analysis/dataset.py b/autolens/analysis/analysis/dataset.py index 6afa2781f..9bf87b7f3 100644 --- a/autolens/analysis/analysis/dataset.py +++ b/autolens/analysis/analysis/dataset.py @@ -28,7 +28,7 @@ def __init__( self, dataset, positions_likelihood_list: Optional[List[PositionsLH]] = None, - adapt_image_maker: Optional[ag.AdaptImageMaker] = None, + adapt_images: Optional[ag.AdaptImages] = None, cosmology: ag.cosmo.LensingCosmology = None, settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, @@ -72,7 +72,7 @@ def __init__( super().__init__( dataset=dataset, - adapt_image_maker=adapt_image_maker, + adapt_images=adapt_images, cosmology=cosmology, settings_inversion=settings_inversion, preloads=preloads, diff --git a/autolens/fixtures.py b/autolens/fixtures.py index 093a372c7..030af1b67 100644 --- a/autolens/fixtures.py +++ b/autolens/fixtures.py @@ -131,10 +131,26 @@ def make_adapt_galaxy_name_image_dict_7x7(): return adapt_galaxy_name_image_dict +def make_adapt_galaxy_name_image_plane_mesh_grid_dict_7x7(): + image_plane_mesh_grid_0 = ag.Grid2DIrregular( + values=[(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)] + ) + + image_plane_mesh_grid_1 = ag.Grid2DIrregular( + values=[(3.0, 3.0), (4.0, 4.0), (5.0, 5.0)] + ) + + adapt_galaxy_name_image_plane_mesh_grid_dict = { + str(("galaxies", "lens")): image_plane_mesh_grid_0, + str(("galaxies", "source")): image_plane_mesh_grid_1, + } + + return adapt_galaxy_name_image_plane_mesh_grid_dict def make_adapt_images_7x7(): return ag.AdaptImages( galaxy_name_image_dict=make_adapt_galaxy_name_image_dict_7x7(), + galaxy_name_image_plane_mesh_grid_dict=make_adapt_galaxy_name_image_plane_mesh_grid_dict_7x7(), ) @@ -142,14 +158,17 @@ def make_analysis_imaging_7x7(): analysis = al.AnalysisImaging( dataset=make_masked_imaging_7x7(), use_jax=False, + adapt_images=make_adapt_images_7x7(), ) - analysis._adapt_images = make_adapt_images_7x7() return analysis def make_analysis_interferometer_7(): - analysis = al.AnalysisInterferometer(dataset=make_interferometer_7(), use_jax=False) - analysis._adapt_images = make_adapt_images_7x7() + analysis = al.AnalysisInterferometer( + dataset=make_interferometer_7(), + adapt_images=make_adapt_images_7x7(), + use_jax=False, + ) return analysis diff --git a/autolens/interferometer/model/analysis.py b/autolens/interferometer/model/analysis.py index eaa0ad328..0a51132e8 100644 --- a/autolens/interferometer/model/analysis.py +++ b/autolens/interferometer/model/analysis.py @@ -25,7 +25,7 @@ def __init__( self, dataset, positions_likelihood_list: Optional[PositionsLH] = None, - adapt_image_maker: Optional[ag.AdaptImageMaker] = None, + adapt_images: Optional[ag.AdaptImages] = None, cosmology: ag.cosmo.LensingCosmology = None, settings_inversion: aa.SettingsInversion = None, preloads: aa.Preloads = None, @@ -77,7 +77,7 @@ def __init__( super().__init__( dataset=dataset, positions_likelihood_list=positions_likelihood_list, - adapt_image_maker=adapt_image_maker, + adapt_images=adapt_images, cosmology=cosmology, settings_inversion=settings_inversion, preloads=preloads, diff --git a/autolens/lens/to_inversion.py b/autolens/lens/to_inversion.py index 973d24375..a80dfabdc 100644 --- a/autolens/lens/to_inversion.py +++ b/autolens/lens/to_inversion.py @@ -261,7 +261,7 @@ def adapt_galaxy_image_pg_list(self) -> List[List[np.ndarray]]: for galaxy in galaxies_with_pixelization_list: try: image = self.adapt_images.galaxy_image_dict[galaxy] - except (AttributeError, KeyError): + except (AttributeError, KeyError, TypeError): image = None # Bug fix whereby for certain models the galaxy doesnt pair correctly. diff --git a/test_autolens/aggregator/test_aggregator_fit_imaging.py b/test_autolens/aggregator/test_aggregator_fit_imaging.py index 4687ea903..8dfacfd7a 100644 --- a/test_autolens/aggregator/test_aggregator_fit_imaging.py +++ b/test_autolens/aggregator/test_aggregator_fit_imaging.py @@ -124,6 +124,15 @@ def test__fit_imaging__adapt_images( == list(adapt_images_7x7.galaxy_name_image_dict.values())[0] ).all() + assert ( + list( + fit_list[0].adapt_images.galaxy_image_plane_mesh_grid_dict.values() + )[0] + == list( + adapt_images_7x7.galaxy_name_image_plane_mesh_grid_dict.values() + )[0] + ).all() + assert i == 2 clean(database_file=database_file) diff --git a/test_autolens/aggregator/test_aggregator_fit_interferometer.py b/test_autolens/aggregator/test_aggregator_fit_interferometer.py index 760fdd0bc..003e010e2 100644 --- a/test_autolens/aggregator/test_aggregator_fit_interferometer.py +++ b/test_autolens/aggregator/test_aggregator_fit_interferometer.py @@ -130,6 +130,15 @@ def test__fit_interferometer__adapt_images( == list(adapt_images_7x7.galaxy_name_image_dict.values())[0] ).all() + assert ( + list( + fit_list[0].adapt_images.galaxy_image_plane_mesh_grid_dict.values() + )[0] + == list( + adapt_images_7x7.galaxy_name_image_plane_mesh_grid_dict.values() + )[0] + ).all() + assert i == 2 clean(database_file=database_file) diff --git a/test_autolens/lens/test_to_inversion.py b/test_autolens/lens/test_to_inversion.py index 082423658..2f1c8f4b4 100644 --- a/test_autolens/lens/test_to_inversion.py +++ b/test_autolens/lens/test_to_inversion.py @@ -234,18 +234,22 @@ def test__adapt_galaxy_image_pg_list(masked_imaging_7x7, grid_2d_7x7): def test__image_plane_mesh_grid_pg_list(masked_imaging_7x7): # Test Correct - pixelization = al.m.MockPixelization( - image_mesh=al.m.MockImageMesh(image_plane_mesh_grid=np.array([[1.0, 1.0]])) - ) + image_plane_mesh_grid_0 = np.array([[1.0, 1.0]]) - galaxy_pix = al.Galaxy(redshift=1.0, pixelization=pixelization) + galaxy_pix = al.Galaxy(redshift=1.0, pixelization=al.m.MockPixelization()) galaxy_no_pix = al.Galaxy(redshift=0.5) + adapt_images = al.AdaptImages( + galaxy_image_dict={galaxy_pix: 2}, + galaxy_image_plane_mesh_grid_dict={galaxy_pix: image_plane_mesh_grid_0}, + ) + tracer = al.Tracer(galaxies=[galaxy_no_pix, galaxy_pix]) tracer_to_inversion = al.TracerToInversion( dataset=masked_imaging_7x7, tracer=tracer, + adapt_images=adapt_images, ) mesh_grids = tracer_to_inversion.image_plane_mesh_grid_pg_list @@ -255,22 +259,28 @@ def test__image_plane_mesh_grid_pg_list(masked_imaging_7x7): # Test for extra galaxies - galaxy_pix0 = al.Galaxy(redshift=1.0, pixelization=pixelization) + galaxy_pix_0 = al.Galaxy(redshift=1.0, pixelization=al.m.MockPixelization()) - pixelization = al.m.MockPixelization( - image_mesh=al.m.MockImageMesh(image_plane_mesh_grid=np.array([[2.0, 2.0]])) - ) + image_plane_mesh_grid_1 = np.array([[2.0, 2.0]]) - galaxy_pix1 = al.Galaxy(redshift=2.0, pixelization=pixelization) + galaxy_pix_1 = al.Galaxy(redshift=2.0, pixelization=al.m.MockPixelization()) galaxy_no_pix_0 = al.Galaxy(redshift=0.25) galaxy_no_pix_1 = al.Galaxy(redshift=0.5) galaxy_no_pix_2 = al.Galaxy(redshift=1.5) + adapt_images = al.AdaptImages( + galaxy_image_dict={galaxy_pix_0: 2, galaxy_pix_1: 3}, + galaxy_image_plane_mesh_grid_dict={ + galaxy_pix_0: image_plane_mesh_grid_0, + galaxy_pix_1: image_plane_mesh_grid_1, + }, + ) + tracer = al.Tracer( galaxies=[ - galaxy_pix0, - galaxy_pix1, + galaxy_pix_0, + galaxy_pix_1, galaxy_no_pix_0, galaxy_no_pix_1, galaxy_no_pix_2, @@ -280,6 +290,7 @@ def test__image_plane_mesh_grid_pg_list(masked_imaging_7x7): tracer_to_inversion = al.TracerToInversion( dataset=masked_imaging_7x7, tracer=tracer, + adapt_images=adapt_images, ) mesh_grids = tracer_to_inversion.image_plane_mesh_grid_pg_list @@ -303,26 +314,26 @@ def test__traced_mesh_grid_pg_list(masked_imaging_7x7): values=[[[1.0, 0.0]]], pixel_scales=(1.0, 1.0) ) - pixelization_0 = al.m.MockPixelization( - image_mesh=al.m.MockImageMesh(image_plane_mesh_grid=image_plane_mesh_grid_0) - ) - - galaxy_pix_0 = al.Galaxy(redshift=1.0, pixelization=pixelization_0) + galaxy_pix_0 = al.Galaxy(redshift=1.0, pixelization=al.m.MockPixelization()) image_plane_mesh_grid_1 = al.Grid2D.no_mask( values=[[[2.0, 0.0]]], pixel_scales=(1.0, 1.0) ) - pixelization_1 = al.m.MockPixelization( - image_mesh=al.m.MockImageMesh(image_plane_mesh_grid=image_plane_mesh_grid_1) - ) - - galaxy_pix_1 = al.Galaxy(redshift=1.0, pixelization=pixelization_1) + galaxy_pix_1 = al.Galaxy(redshift=1.0, pixelization=al.m.MockPixelization()) tracer = al.Tracer(galaxies=[galaxy_no_pix, galaxy_pix_0, galaxy_pix_1]) + adapt_images = al.AdaptImages( + galaxy_image_dict={galaxy_pix_0: 2, galaxy_pix_1: 3}, + galaxy_image_plane_mesh_grid_dict={ + galaxy_pix_0: image_plane_mesh_grid_0, + galaxy_pix_1: image_plane_mesh_grid_1, + }, + ) + tracer_to_inversion = al.TracerToInversion( - dataset=masked_imaging_7x7, tracer=tracer + dataset=masked_imaging_7x7, tracer=tracer, adapt_images=adapt_images ) traced_mesh_grids_list_of_planes = tracer_to_inversion.traced_mesh_grid_pg_list @@ -337,9 +348,16 @@ def test__traced_mesh_grid_pg_list(masked_imaging_7x7): # Test Extra Galaxies - galaxy_pix_0 = al.Galaxy(redshift=1.0, pixelization=pixelization_0) + galaxy_pix_0 = al.Galaxy(redshift=1.0, pixelization=al.m.MockPixelization()) + galaxy_pix_1 = al.Galaxy(redshift=2.0, pixelization=al.m.MockPixelization()) - galaxy_pix_1 = al.Galaxy(redshift=2.0, pixelization=pixelization_1) + adapt_images = al.AdaptImages( + galaxy_image_dict={galaxy_pix_0: 2, galaxy_pix_1: 3}, + galaxy_image_plane_mesh_grid_dict={ + galaxy_pix_0: image_plane_mesh_grid_0, + galaxy_pix_1: image_plane_mesh_grid_1, + }, + ) galaxy_no_pix_0 = al.Galaxy( redshift=0.25, @@ -359,7 +377,7 @@ def test__traced_mesh_grid_pg_list(masked_imaging_7x7): ) tracer_to_inversion = al.TracerToInversion( - dataset=masked_imaging_7x7, tracer=tracer + dataset=masked_imaging_7x7, tracer=tracer, adapt_images=adapt_images ) traced_mesh_grids_list_of_planes = tracer_to_inversion.traced_mesh_grid_pg_list