Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 44 additions & 35 deletions autolens/aggregator/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
from typing import List, Optional

import autofit as af
import autoarray as aa
import autogalaxy as ag

from autolens.lens.tracer import Tracer

from autogalaxy.aggregator import agg_util
from autolens.lens import tracer_util

logger = logging.getLogger(__name__)


Expand All @@ -18,16 +23,15 @@ def _tracer_from(
attributes of the fit:

- The model and its best fit parameters (e.g. `model.json`).
- The adapt images associated with adaptive galaxy features (`adapt` folder).

Each individual attribute can be loaded from the database via the `fit.value()` method.

This method combines all of these attributes and returns a `Tracer` object for a given non-linear search sample
(e.g. the maximum likelihood model). This includes associating adapt images with their respective galaxies.
This method combines this attributesand returns a `Tracer` object for a given non-linear search sample
(e.g. the maximum likelihood model).

If multiple `Tracer` objects were fitted simultaneously via analysis summing, the `fit.child_values()` method
is instead used to load lists of Tracers. This is necessary if each Tracer has different galaxies (e.g. certain
parameters vary across each dataset and `Analysis` object).
If multiple `Tracer` objects were fitted simultaneously via multiple analysis, the instance is iterated over as
a list such that a list of `Tracer` objects with parameters updated for each analysis are returned. This means
fits using a single analysis are wrapped in a list to prodcue a consistent API.

Parameters
----------
Expand All @@ -39,41 +43,46 @@ def _tracer_from(
randomly from the PDF).
"""

if instance is not None:
instance_list = agg_util.instance_list_from(fit=fit, instance=instance)

tracer_list = []

for instance in instance_list:

galaxies = instance.galaxies

if hasattr(instance, "extra_galaxies"):
if fit.instance.extra_galaxies is not None:
galaxies = galaxies + fit.instance.extra_galaxies

else:
galaxies = fit.instance.galaxies

if hasattr(fit.instance, "extra_galaxies"):
if fit.instance.extra_galaxies is not None:
galaxies = galaxies + fit.instance.extra_galaxies

try:
cosmology = instance.cosmology
except AttributeError:
cosmology = fit.value(name="cosmology")

tracer = Tracer(galaxies=galaxies, cosmology=cosmology)

if fit.children is not None:
if len(fit.children) > 0:
logger.info(
"""
Using database for a fit with multiple summed Analysis objects.

Tracer objects do not fully support this yet (e.g. model parameters which vary over analyses may be incorrect)
so proceed with caution!
"""
if instance.extra_galaxies is not None:
galaxies = galaxies + instance.extra_galaxies

try:
cosmology = instance.cosmology
except AttributeError:
cosmology = fit.value(name="cosmology")

if cosmology is None:
cosmology = ag.cosmo.Planck15()

# TODO : These are ugly as hell (>_<)

if hasattr(instance, "perturb"):
galaxies.subhalo = instance.perturb

if hasattr(instance.galaxies, "subhalo"):
subhalo_centre = tracer_util.grid_2d_at_redshift_from(
galaxies=instance.galaxies,
redshift=instance.galaxies.subhalo.redshift,
grid=aa.Grid2DIrregular(values=[instance.galaxies.subhalo.mass.centre]),
cosmology=cosmology,
)

return [tracer] * len(fit.children)
galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0])

tracer = Tracer(galaxies=galaxies, cosmology=cosmology)

tracer_list.append(tracer)

return [tracer]
return tracer_list


class TracerAgg(af.AggBase):
Expand Down
62 changes: 32 additions & 30 deletions test_autolens/aggregator/test_aggregator_fit_imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,38 @@ def test__fit_imaging_randomly_drawn_via_pdf_gen_from(
clean(database_file=database_file)


def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi(
analysis_imaging_7x7, samples, model
):
agg = aggregator_from(
database_file=database_file,
analysis=analysis_imaging_7x7 + analysis_imaging_7x7,
model=model,
samples=samples,
)

fit_agg = al.agg.FitImagingAgg(aggregator=agg)
fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2)

i = 0

for fit_gen in fit_pdf_gen:
for fit_list in fit_gen:
i += 1

assert fit_list[0].tracer.galaxies[0].redshift == 0.5
assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
assert fit_list[0].tracer.galaxies[1].redshift == 1.0

assert fit_list[1].tracer.galaxies[0].redshift == 0.5
assert fit_list[1].tracer.galaxies[0].light.centre == (10.0, 10.0)
assert fit_list[1].tracer.galaxies[1].redshift == 1.0

assert i == 2

clean(database_file=database_file)
# TODO : These need to use FactorGraphModel

# def test__fit_imaging_randomly_drawn_via_pdf_gen_from__analysis_multi(
# analysis_imaging_7x7, samples, model
# ):
# agg = aggregator_from(
# database_file=database_file,
# analysis=analysis_imaging_7x7 + analysis_imaging_7x7,
# model=model,
# samples=samples,
# )
#
# fit_agg = al.agg.FitImagingAgg(aggregator=agg)
# fit_pdf_gen = fit_agg.randomly_drawn_via_pdf_gen_from(total_samples=2)
#
# i = 0
#
# for fit_gen in fit_pdf_gen:
# for fit_list in fit_gen:
# i += 1
#
# assert fit_list[0].tracer.galaxies[0].redshift == 0.5
# assert fit_list[0].tracer.galaxies[0].light.centre == (10.0, 10.0)
# assert fit_list[0].tracer.galaxies[1].redshift == 1.0
#
# assert fit_list[1].tracer.galaxies[0].redshift == 0.5
# assert fit_list[1].tracer.galaxies[0].light.centre == (10.0, 10.0)
# assert fit_list[1].tracer.galaxies[1].redshift == 1.0
#
# assert i == 2
#
# clean(database_file=database_file)


def test__fit_imaging_all_above_weight_gen(analysis_imaging_7x7, samples, model):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def test__fit_interferometer_randomly_drawn_via_pdf_gen_from(
clean(database_file=database_file)


# TODO : These need to use FactorGraphModel

# def test__fit_interferometer_randomly_drawn_via_pdf_gen_from__analysis_multi(analysis_interferometer_7, samples, model):
#
# agg = aggregator_from(
Expand Down
Loading