diff --git a/autolens/aggregator/tracer.py b/autolens/aggregator/tracer.py index e6febfe1f..96eaba956 100644 --- a/autolens/aggregator/tracer.py +++ b/autolens/aggregator/tracer.py @@ -2,9 +2,14 @@ from typing import List, Optional import autofit as af +import autoarray as aa +import autogalaxy as ag from autolens.lens.tracer import Tracer +from autogalaxy.aggregator import agg_util +from autolens.lens import tracer_util + logger = logging.getLogger(__name__) @@ -18,16 +23,15 @@ def _tracer_from( attributes of the fit: - The model and its best fit parameters (e.g. `model.json`). - - The adapt images associated with adaptive galaxy features (`adapt` folder). Each individual attribute can be loaded from the database via the `fit.value()` method. - This method combines all of these attributes and returns a `Tracer` object for a given non-linear search sample - (e.g. the maximum likelihood model). This includes associating adapt images with their respective galaxies. + This method combines this attributesand returns a `Tracer` object for a given non-linear search sample + (e.g. the maximum likelihood model). - If multiple `Tracer` objects were fitted simultaneously via analysis summing, the `fit.child_values()` method - is instead used to load lists of Tracers. This is necessary if each Tracer has different galaxies (e.g. certain - parameters vary across each dataset and `Analysis` object). + If multiple `Tracer` objects were fitted simultaneously via multiple analysis, the instance is iterated over as + a list such that a list of `Tracer` objects with parameters updated for each analysis are returned. This means + fits using a single analysis are wrapped in a list to prodcue a consistent API. Parameters ---------- @@ -39,41 +43,46 @@ def _tracer_from( randomly from the PDF). """ - if instance is not None: + instance_list = agg_util.instance_list_from(fit=fit, instance=instance) + + tracer_list = [] + + for instance in instance_list: + galaxies = instance.galaxies if hasattr(instance, "extra_galaxies"): - if fit.instance.extra_galaxies is not None: - galaxies = galaxies + fit.instance.extra_galaxies - - else: - galaxies = fit.instance.galaxies - - if hasattr(fit.instance, "extra_galaxies"): - if fit.instance.extra_galaxies is not None: - galaxies = galaxies + fit.instance.extra_galaxies - - try: - cosmology = instance.cosmology - except AttributeError: - cosmology = fit.value(name="cosmology") - - tracer = Tracer(galaxies=galaxies, cosmology=cosmology) - - if fit.children is not None: - if len(fit.children) > 0: - logger.info( - """ - Using database for a fit with multiple summed Analysis objects. - - Tracer objects do not fully support this yet (e.g. model parameters which vary over analyses may be incorrect) - so proceed with caution! - """ + if instance.extra_galaxies is not None: + galaxies = galaxies + instance.extra_galaxies + + try: + cosmology = instance.cosmology + except AttributeError: + cosmology = fit.value(name="cosmology") + + if cosmology is None: + cosmology = ag.cosmo.Planck15() + + # TODO : These are ugly as hell (>_<) + + if hasattr(instance, "perturb"): + galaxies.subhalo = instance.perturb + + if hasattr(instance.galaxies, "subhalo"): + subhalo_centre = tracer_util.grid_2d_at_redshift_from( + galaxies=instance.galaxies, + redshift=instance.galaxies.subhalo.redshift, + grid=aa.Grid2DIrregular(values=[instance.galaxies.subhalo.mass.centre]), + cosmology=cosmology, ) - return [tracer] * len(fit.children) + galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0]) + + tracer = Tracer(galaxies=galaxies, cosmology=cosmology) + + tracer_list.append(tracer) - return [tracer] + return tracer_list class TracerAgg(af.AggBase): diff --git a/test_autolens/aggregator/test_aggregator_fit_imaging.py b/test_autolens/aggregator/test_aggregator_fit_imaging.py index 20aee69ce..d2bd588f5 100644 --- a/test_autolens/aggregator/test_aggregator_fit_imaging.py +++ b/test_autolens/aggregator/test_aggregator_fit_imaging.py @@ -35,36 +35,38 @@ def test__fit_imaging_randomly_drawn_via_pdf_gen_from( clean(database_file=database_file) -def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi( - analysis_imaging_7x7, samples, model -): - agg = aggregator_from( - database_file=database_file, - analysis=analysis_imaging_7x7 + analysis_imaging_7x7, - model=model, - samples=samples, - ) - - fit_agg = al.agg.FitImagingAgg(aggregator=agg) - fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2) - - i = 0 - - for fit_gen in fit_pdf_gen: - for fit_list in fit_gen: - i += 1 - - assert fit_list[0].tracer.galaxies[0].redshift == 0.5 - assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0) - assert fit_list[0].tracer.galaxies[1].redshift == 1.0 - - assert fit_list[1].tracer.galaxies[0].redshift == 0.5 - assert fit_list[1].tracer.galaxies[0].light.centre == (10.0, 10.0) - assert fit_list[1].tracer.galaxies[1].redshift == 1.0 - - assert i == 2 - - clean(database_file=database_file) +# TODO : These need to use FactorGraphModel + +# def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi( +# analysis_imaging_7x7, samples, model +# ): +# agg = aggregator_from( +# database_file=database_file, +# analysis=analysis_imaging_7x7 + analysis_imaging_7x7, +# model=model, +# samples=samples, +# ) +# +# fit_agg = al.agg.FitImagingAgg(aggregator=agg) +# fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2) +# +# i = 0 +# +# for fit_gen in fit_pdf_gen: +# for fit_list in fit_gen: +# i += 1 +# +# assert fit_list[0].tracer.galaxies[0].redshift == 0.5 +# assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0) +# assert fit_list[0].tracer.galaxies[1].redshift == 1.0 +# +# assert fit_list[1].tracer.galaxies[0].redshift == 0.5 +# assert fit_list[1].tracer.galaxies[0].light.centre == (10.0, 10.0) +# assert fit_list[1].tracer.galaxies[1].redshift == 1.0 +# +# assert i == 2 +# +# clean(database_file=database_file) def test__fit_imaging_all_above_weight_gen(analysis_imaging_7x7, samples, model): diff --git a/test_autolens/aggregator/test_aggregator_fit_interferometer.py b/test_autolens/aggregator/test_aggregator_fit_interferometer.py index ec05f3393..760fdd0bc 100644 --- a/test_autolens/aggregator/test_aggregator_fit_interferometer.py +++ b/test_autolens/aggregator/test_aggregator_fit_interferometer.py @@ -35,6 +35,8 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from( clean(database_file=database_file) +# TODO : These need to use FactorGraphModel + # def test__fit_interferometer_randomly_drawn_via_pdf_gen_from__analysis_multi(analysis_interferometer_7, samples, model): # # agg = aggregator_from(