diff --git a/autogalaxy/aggregator/agg_util.py b/autogalaxy/aggregator/agg_util.py index 035f3947d..78b88eec6 100644 --- a/autogalaxy/aggregator/agg_util.py +++ b/autogalaxy/aggregator/agg_util.py @@ -11,6 +11,52 @@ from autogalaxy.analysis.adapt_images.adapt_images import AdaptImages +def instance_list_from( + fit: af.Fit, instance: Optional[af.ModelInstance] = None +) -> List[af.ModelInstance]: + """ + Returns the list of instances of the maximum likelihood model, depending on the model composition and whether + multiple `Analysis` objects were fitted simultaneously. + + This if loop accounts for 4 scenarios: + + - A single `Analysis` object was fitted, in which case the instance is a single object and converted to a list. + + - Multiple `Analysis` objects were fitted via a `FactorGraphModel`, in which case the instance is a list of + objects and all but the last object (which is the overall `FactorGraphModel` are returned. + + - A single instance is manually input, in which case it is converted to a list. + + - Multiple `Analysis` objects were fitted via a `FactorGraphModel`, in which case the instance is a list of + objects and all but the last object (which is the overall `FactorGraphModel` are returned. + + Parameters + ---------- + fit + A `PyAutoFit` `Fit` object which contains the results of a model-fit as an entry which has been loaded from + an output directory or from an sqlite database. + instance + An optional instance that overwrites the max log likelihood instance in fit (e.g. for drawing the instance + randomly from the PDF). + + Returns + ------- + The list of instances of the maximum likelihood model. + """ + + if instance is None: + if len(fit.children) == 0: + return [fit.instance] + return fit.instance[ + 0:-1 + ] # [0:-1] excludes the last instance, which is the `FactorGraphModel` object itself. + + if isinstance(list(instance.child_items.values())[-1], af.FactorGraphModel): + return list(instance.child_items.values())[0:-1] + + return [instance] + + def mask_header_from(fit, name="dataset"): """ Returns the mask, header and pixel scales of the `PyAutoFit` `Fit` object. diff --git a/autogalaxy/aggregator/dataset_model.py b/autogalaxy/aggregator/dataset_model.py index 130b00abd..f924e87ac 100644 --- a/autogalaxy/aggregator/dataset_model.py +++ b/autogalaxy/aggregator/dataset_model.py @@ -5,6 +5,8 @@ import autofit as af import autoarray as aa +from autogalaxy.aggregator import agg_util + logger = logging.getLogger(__name__) @@ -38,28 +40,14 @@ def _dataset_model_from( randomly from the PDF). """ - if instance is not None: - try: - dataset_model = instance.dataset_model - except AttributeError: - dataset_model = None - else: - try: - dataset_model = fit.instance.dataset_model - except AttributeError: - dataset_model = None - - if fit.children is not None: - if len(fit.children) > 0: - logger.info( - """ - Using database for a fit with multiple summed Analysis objects. - - DatasetModel objects do not fully support this yet (e.g. variables across Analysis objects may not be correct) - so proceed with caution! - """ - ) - - return [dataset_model] * len(fit.children) - - return [dataset_model] + instance_list = agg_util.instance_list_from(fit=fit, instance=instance) + + dataset_model_list = [] + + for instance in instance_list: + + dataset_model = instance.dataset_model + + dataset_model_list.append(dataset_model) + + return dataset_model_list diff --git a/autogalaxy/aggregator/galaxies.py b/autogalaxy/aggregator/galaxies.py index 60dbd3779..492464723 100644 --- a/autogalaxy/aggregator/galaxies.py +++ b/autogalaxy/aggregator/galaxies.py @@ -7,6 +7,7 @@ import autofit as af +from autogalaxy.aggregator import agg_util logger = logging.getLogger(__name__) @@ -40,34 +41,21 @@ def _galaxies_from(fit: af.Fit, instance: af.ModelInstance) -> List[Galaxy]: randomly from the PDF). """ - if instance is not None: + instance_list = agg_util.instance_list_from(fit=fit, instance=instance) + + galaxies_list = [] + + for instance in instance_list: + galaxies = instance.galaxies if hasattr(instance, "extra_galaxies"): - if fit.instance.extra_galaxies is not None: - galaxies = galaxies + fit.instance.extra_galaxies - - else: - galaxies = fit.instance.galaxies - - if hasattr(fit.instance, "extra_galaxies"): - if fit.instance.extra_galaxies is not None: - galaxies = galaxies + fit.instance.extra_galaxies - - if fit.children is not None: - if len(fit.children) > 0: - logger.info( - """ - Using database for a fit with multiple summed Analysis objects. - - Galaxy objects do not fully support this yet (e.g. variables across Analysis objects may not be correct) - so proceed with caution! - """ - ) - - return [galaxies] * len(fit.children) - - return [galaxies] + if instance.extra_galaxies is not None: + galaxies = galaxies + instance.extra_galaxies + + galaxies_list.append(galaxies) + + return galaxies_list class GalaxiesAgg(af.AggBase): diff --git a/autogalaxy/aggregator/imaging/imaging.py b/autogalaxy/aggregator/imaging/imaging.py index a4d727108..dae4ed718 100644 --- a/autogalaxy/aggregator/imaging/imaging.py +++ b/autogalaxy/aggregator/imaging/imaging.py @@ -58,9 +58,10 @@ def values_from(hdu: int, cls): data = values_from(hdu=1, cls=aa.Array2D) noise_map = values_from(hdu=2, cls=aa.Array2D) - psf = values_from(hdu=3, cls=aa.Kernel2D) - - print(psf) + try: + psf = values_from(hdu=3, cls=aa.Kernel2D) + except (TypeError, IndexError): + psf = None dataset = aa.Imaging( data=data, diff --git a/test_autogalaxy/aggregator/imaging/test_aggregator_fit_imaging.py b/test_autogalaxy/aggregator/imaging/test_aggregator_fit_imaging.py index 216f9d105..9c7db3ead 100644 --- a/test_autogalaxy/aggregator/imaging/test_aggregator_fit_imaging.py +++ b/test_autogalaxy/aggregator/imaging/test_aggregator_fit_imaging.py @@ -39,58 +39,122 @@ def make_agg_7x7(samples, model, analysis_imaging_7x7): return agg -def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_has_single_dataset( +def test__fit_imaging__max_log_likelihood__analysis_has_single_dataset( agg_7x7, ): fit_agg = ag.agg.FitImagingAgg(aggregator=agg_7x7) - fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2) - - i = 0 - - for fit_gen in fit_pdf_gen: - for fit_list in fit_gen: - i += 1 + fit_max_lh_gen = fit_agg.max_log_likelihood_gen_from() - assert fit_list[0].galaxies[0].redshift == 0.5 - assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0) + for ( + fit_list + ) in ( + fit_max_lh_gen + ): # Only Max LH sample so fit_list contains 1 lists of a single fit. - assert fit_list[0].dataset_model.background_sky_level == 10.0 + assert fit_list[0].galaxies[0].redshift == 0.5 + assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0) - assert i == 2 + assert fit_list[0].dataset_model.background_sky_level == 10.0 clean(database_file=database_file) -def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi( - analysis_imaging_7x7, samples, model +def test__fit_imaging__randomly_drawn_via_pdf_gen_from__analysis_has_single_dataset( + agg_7x7, ): - agg = aggregator_from( - database_file=database_file, - analysis=analysis_imaging_7x7 + analysis_imaging_7x7, - model=model, - samples=samples, - ) - - fit_agg = ag.agg.FitImagingAgg(aggregator=agg) - fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2) + fit_agg = ag.agg.FitImagingAgg(aggregator=agg_7x7) + fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=3) i = 0 - for fit_gen in fit_pdf_gen: - for fit_list in fit_gen: + for fit_list_gen in fit_pdf_gen: # 1 Dataset so just one fit + for ( + fit_list + ) in ( + fit_list_gen + ): # Iterate over each total_samples=3, each with two fits for each analysis. + i += 1 + # Check fit for each `Analysis` so take first and only dataset. assert fit_list[0].galaxies[0].redshift == 0.5 assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0) + assert fit_list[0].dataset_model.background_sky_level == 10.0 - assert fit_list[1].galaxies[0].redshift == 0.5 - assert fit_list[1].galaxies[0].light.centre == (10.0, 10.0) - - assert i == 2 + assert i == 3 clean(database_file=database_file) +# TODO : These need to use FactorGraphModel + + +# def test__fit_imaging__max_log_likelihood_gen_from__analysis_multi( +# analysis_imaging_7x7, samples, model +# ): +# +# analysis_factor_list = [] +# +# for i, analysis in enumerate([analysis_imaging_7x7, analysis_imaging_7x7]): +# +# model_analysis = model.copy() +# analysis_factor = af.AnalysisFactor(prior_model=model_analysis, analysis=analysis) +# +# analysis_factor_list.append(analysis_factor) +# +# factor_graph = af.FactorGraphModel(*analysis_factor_list) +# +# agg = aggregator_from( +# database_file=database_file, +# analysis=factor_graph, +# model=factor_graph.global_prior_model, +# samples=samples, +# ) +# +# fit_agg = ag.agg.FitImagingAgg(aggregator=agg) +# +# fit_max_lh_gen = fit_agg.max_log_likelihood_gen_from() +# +# for fit_list in fit_max_lh_gen: # Only Max LH sample so fit_list contains 1 lists of the 2 fit (one for each analysis). +# +# assert fit_list[0].galaxies[0].redshift == 0.5 +# assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0) +# +# assert fit_list[1].galaxies[0].redshift == 0.5 +# assert fit_list[1].galaxies[0].light.centre == (10.0, 10.0) +# +# clean(database_file=database_file) +# +# def test__fit_imaging__randomly_drawn_via_pdf_gen_from__analysis_multi( +# analysis_imaging_7x7, samples, model +# ): +# agg = aggregator_from( +# database_file=database_file, +# analysis=analysis_imaging_7x7 + analysis_imaging_7x7, +# model=model, +# samples=samples, +# ) +# +# fit_agg = ag.agg.FitImagingAgg(aggregator=agg) +# fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=3) +# +# i = 0 +# +# for fit_gen in fit_pdf_gen: # 1 Dataset so just one fit +# for fit_list in fit_gen: # Iterate over each total_samples=3, each with two fits for each analysis. +# i += 1 +# +# assert fit_list[0].galaxies[0].redshift == 0.5 +# assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0) +# +# assert fit_list[1].galaxies[0].redshift == 0.5 +# assert fit_list[1].galaxies[0].light.centre == (10.0, 10.0) +# +# assert i == 3 +# +# clean(database_file=database_file) + + def test__fit_imaging_all_above_weight_gen(agg_7x7): fit_agg = ag.agg.FitImagingAgg(aggregator=agg_7x7) fit_pdf_gen = fit_agg.all_above_weight_gen_from(minimum_weight=-1.0) diff --git a/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py b/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py index 93b00ff46..b737568fe 100644 --- a/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py +++ b/test_autogalaxy/aggregator/interferometer/test_aggregator_fit_interferometer.py @@ -40,6 +40,8 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from( clean(database_file=database_file) +# TODO : These need to use FactorGraphModel + # def test__fit_interferometer_randomly_drawn_via_pdf_gen_from__analysis_multi( # analysis_interferometer_7, samples, model # ):