Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions autogalaxy/aggregator/agg_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,52 @@
from autogalaxy.analysis.adapt_images.adapt_images import AdaptImages


def instance_list_from(
fit: af.Fit, instance: Optional[af.ModelInstance] = None
) -> List[af.ModelInstance]:
"""
Returns the list of instances of the maximum likelihood model, depending on the model composition and whether
multiple `Analysis` objects were fitted simultaneously.

This if loop accounts for 4 scenarios:

- A single `Analysis` object was fitted, in which case the instance is a single object and converted to a list.

- Multiple `Analysis` objects were fitted via a `FactorGraphModel`, in which case the instance is a list of
objects and all but the last object (which is the overall `FactorGraphModel` are returned.

- A single instance is manually input, in which case it is converted to a list.

- Multiple `Analysis` objects were fitted via a `FactorGraphModel`, in which case the instance is a list of
objects and all but the last object (which is the overall `FactorGraphModel` are returned.

Parameters
----------
fit
A `PyAutoFit` `Fit` object which contains the results of a model-fit as an entry which has been loaded from
an output directory or from an sqlite database.
instance
An optional instance that overwrites the max log likelihood instance in fit (e.g. for drawing the instance
randomly from the PDF).

Returns
-------
The list of instances of the maximum likelihood model.
"""

if instance is None:
if len(fit.children) == 0:
return [fit.instance]
return fit.instance[
0:-1
] # [0:-1] excludes the last instance, which is the `FactorGraphModel` object itself.

if isinstance(list(instance.child_items.values())[-1], af.FactorGraphModel):
return list(instance.child_items.values())[0:-1]

return [instance]


def mask_header_from(fit, name="dataset"):
"""
Returns the mask, header and pixel scales of the `PyAutoFit` `Fit` object.
Expand Down
38 changes: 13 additions & 25 deletions autogalaxy/aggregator/dataset_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import autofit as af
import autoarray as aa

from autogalaxy.aggregator import agg_util

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -38,28 +40,14 @@ def _dataset_model_from(
randomly from the PDF).
"""

if instance is not None:
try:
dataset_model = instance.dataset_model
except AttributeError:
dataset_model = None
else:
try:
dataset_model = fit.instance.dataset_model
except AttributeError:
dataset_model = None

if fit.children is not None:
if len(fit.children) > 0:
logger.info(
"""
Using database for a fit with multiple summed Analysis objects.

DatasetModel objects do not fully support this yet (e.g. variables across Analysis objects may not be correct)
so proceed with caution!
"""
)

return [dataset_model] * len(fit.children)

return [dataset_model]
instance_list = agg_util.instance_list_from(fit=fit, instance=instance)

dataset_model_list = []

for instance in instance_list:

dataset_model = instance.dataset_model

dataset_model_list.append(dataset_model)

return dataset_model_list
38 changes: 13 additions & 25 deletions autogalaxy/aggregator/galaxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import autofit as af

from autogalaxy.aggregator import agg_util

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,34 +41,21 @@ def _galaxies_from(fit: af.Fit, instance: af.ModelInstance) -> List[Galaxy]:
randomly from the PDF).
"""

if instance is not None:
instance_list = agg_util.instance_list_from(fit=fit, instance=instance)

galaxies_list = []

for instance in instance_list:

galaxies = instance.galaxies

if hasattr(instance, "extra_galaxies"):
if fit.instance.extra_galaxies is not None:
galaxies = galaxies + fit.instance.extra_galaxies

else:
galaxies = fit.instance.galaxies

if hasattr(fit.instance, "extra_galaxies"):
if fit.instance.extra_galaxies is not None:
galaxies = galaxies + fit.instance.extra_galaxies

if fit.children is not None:
if len(fit.children) > 0:
logger.info(
"""
Using database for a fit with multiple summed Analysis objects.

Galaxy objects do not fully support this yet (e.g. variables across Analysis objects may not be correct)
so proceed with caution!
"""
)

return [galaxies] * len(fit.children)

return [galaxies]
if instance.extra_galaxies is not None:
galaxies = galaxies + instance.extra_galaxies

galaxies_list.append(galaxies)

return galaxies_list


class GalaxiesAgg(af.AggBase):
Expand Down
7 changes: 4 additions & 3 deletions autogalaxy/aggregator/imaging/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,10 @@ def values_from(hdu: int, cls):
data = values_from(hdu=1, cls=aa.Array2D)
noise_map = values_from(hdu=2, cls=aa.Array2D)

psf = values_from(hdu=3, cls=aa.Kernel2D)

print(psf)
try:
psf = values_from(hdu=3, cls=aa.Kernel2D)
except (TypeError, IndexError):
psf = None

dataset = aa.Imaging(
data=data,
Expand Down
122 changes: 93 additions & 29 deletions test_autogalaxy/aggregator/imaging/test_aggregator_fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,58 +39,122 @@ def make_agg_7x7(samples, model, analysis_imaging_7x7):
return agg


def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_has_single_dataset(
def test__fit_imaging__max_log_likelihood__analysis_has_single_dataset(
agg_7x7,
):
fit_agg = ag.agg.FitImagingAgg(aggregator=agg_7x7)
fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2)

i = 0

for fit_gen in fit_pdf_gen:
for fit_list in fit_gen:
i += 1
fit_max_lh_gen = fit_agg.max_log_likelihood_gen_from()

assert fit_list[0].galaxies[0].redshift == 0.5
assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0)
for (
fit_list
) in (
fit_max_lh_gen
): # Only Max LH sample so fit_list contains 1 lists of a single fit.

assert fit_list[0].dataset_model.background_sky_level == 10.0
assert fit_list[0].galaxies[0].redshift == 0.5
assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0)

assert i == 2
assert fit_list[0].dataset_model.background_sky_level == 10.0

clean(database_file=database_file)


def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi(
analysis_imaging_7x7, samples, model
def test__fit_imaging__randomly_drawn_via_pdf_gen_from__analysis_has_single_dataset(
agg_7x7,
):
agg = aggregator_from(
database_file=database_file,
analysis=analysis_imaging_7x7 + analysis_imaging_7x7,
model=model,
samples=samples,
)

fit_agg = ag.agg.FitImagingAgg(aggregator=agg)
fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2)
fit_agg = ag.agg.FitImagingAgg(aggregator=agg_7x7)
fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=3)

i = 0

for fit_gen in fit_pdf_gen:
for fit_list in fit_gen:
for fit_list_gen in fit_pdf_gen: # 1 Dataset so just one fit
for (
fit_list
) in (
fit_list_gen
): # Iterate over each total_samples=3, each with two fits for each analysis.

i += 1

# Check fit for each `Analysis` so take first and only dataset.
assert fit_list[0].galaxies[0].redshift == 0.5
assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0)
assert fit_list[0].dataset_model.background_sky_level == 10.0

assert fit_list[1].galaxies[0].redshift == 0.5
assert fit_list[1].galaxies[0].light.centre == (10.0, 10.0)

assert i == 2
assert i == 3

clean(database_file=database_file)


# TODO : These need to use FactorGraphModel


# def test__fit_imaging__max_log_likelihood_gen_from__analysis_multi(
# analysis_imaging_7x7, samples, model
# ):
#
# analysis_factor_list = []
#
# for i, analysis in enumerate([analysis_imaging_7x7, analysis_imaging_7x7]):
#
# model_analysis = model.copy()
# analysis_factor = af.AnalysisFactor(prior_model=model_analysis, analysis=analysis)
#
# analysis_factor_list.append(analysis_factor)
#
# factor_graph = af.FactorGraphModel(*analysis_factor_list)
#
# agg = aggregator_from(
# database_file=database_file,
# analysis=factor_graph,
# model=factor_graph.global_prior_model,
# samples=samples,
# )
#
# fit_agg = ag.agg.FitImagingAgg(aggregator=agg)
#
# fit_max_lh_gen = fit_agg.max_log_likelihood_gen_from()
#
# for fit_list in fit_max_lh_gen: # Only Max LH sample so fit_list contains 1 lists of the 2 fit (one for each analysis).
#
# assert fit_list[0].galaxies[0].redshift == 0.5
# assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0)
#
# assert fit_list[1].galaxies[0].redshift == 0.5
# assert fit_list[1].galaxies[0].light.centre == (10.0, 10.0)
#
# clean(database_file=database_file)
#
# def test__fit_imaging__randomly_drawn_via_pdf_gen_from__analysis_multi(
# analysis_imaging_7x7, samples, model
# ):
# agg = aggregator_from(
# database_file=database_file,
# analysis=analysis_imaging_7x7 + analysis_imaging_7x7,
# model=model,
# samples=samples,
# )
#
# fit_agg = ag.agg.FitImagingAgg(aggregator=agg)
# fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=3)
#
# i = 0
#
# for fit_gen in fit_pdf_gen: # 1 Dataset so just one fit
# for fit_list in fit_gen: # Iterate over each total_samples=3, each with two fits for each analysis.
# i += 1
#
# assert fit_list[0].galaxies[0].redshift == 0.5
# assert fit_list[0].galaxies[0].light.centre == (10.0, 10.0)
#
# assert fit_list[1].galaxies[0].redshift == 0.5
# assert fit_list[1].galaxies[0].light.centre == (10.0, 10.0)
#
# assert i == 3
#
# clean(database_file=database_file)


def test__fit_imaging_all_above_weight_gen(agg_7x7):
fit_agg = ag.agg.FitImagingAgg(aggregator=agg_7x7)
fit_pdf_gen = fit_agg.all_above_weight_gen_from(minimum_weight=-1.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from(
clean(database_file=database_file)


# TODO : These need to use FactorGraphModel

# def test__fit_interferometer_randomly_drawn_via_pdf_gen_from__analysis_multi(
# analysis_interferometer_7, samples, model
# ):
Expand Down
Loading