Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 73 additions & 41 deletions autofit/graphical/declarative/collection.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from typing import Union
from typing import Union, Optional

from autofit.graphical.declarative.factor.hierarchical import HierarchicalFactor
from autofit.mapper.model import ModelInstance
from autofit.tools.namer import namer
from .abstract import AbstractDeclarativeFactor
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.non_linear.samples.pdf import SamplesPDF
from autofit.non_linear.samples.summary import SamplesSummary
from autofit.non_linear.analysis.combined import CombinedResult


class FactorGraphModel(AbstractDeclarativeFactor):
def __init__(
self,
*model_factors: Union[
AbstractDeclarativeFactor,
HierarchicalFactor
],
name=None,
include_prior_factors=True,
self,
*model_factors: Union[AbstractDeclarativeFactor, HierarchicalFactor],
name=None,
include_prior_factors=True,
):
"""
A collection of factors that describe models, which can be
Expand All @@ -40,11 +41,10 @@ def prior_model(self):
in each model factor
"""
from autofit.mapper.prior_model.collection import Collection
return Collection({
factor.name: factor.prior_model
for factor
in self.model_factors
})

return Collection(
{factor.name: factor.prior_model for factor in self.model_factors}
)

@property
def optimiser(self):
Expand All @@ -61,21 +61,13 @@ def info(self) -> str:
def name(self):
return self._name

def add(
self,
model_factor: AbstractDeclarativeFactor
):
def add(self, model_factor: AbstractDeclarativeFactor):
"""
Add another factor to this collection.
"""
self._model_factors.append(
model_factor
)
self._model_factors.append(model_factor)

def log_likelihood_function(
self,
instance: ModelInstance
) -> float:
def log_likelihood_function(self, instance: ModelInstance) -> float:
"""
Compute the combined likelihood of each factor from a collection of instances
with the same ordering as the factors.
Expand All @@ -90,29 +82,69 @@ def log_likelihood_function(
The combined likelihood of all factors
"""
log_likelihood = 0
for model_factor, instance_ in zip(
self.model_factors,
instance
):
log_likelihood += model_factor.log_likelihood_function(
instance_
)
for model_factor, instance_ in zip(self.model_factors, instance):
log_likelihood += model_factor.log_likelihood_function(instance_)

return log_likelihood

@property
def model_factors(self):
model_factors = list()
for model_factor in self._model_factors:
if isinstance(
model_factor,
HierarchicalFactor
):
model_factors.extend(
model_factor.factors
)
if isinstance(model_factor, HierarchicalFactor):
model_factors.extend(model_factor.factors)
else:
model_factors.append(
model_factor
)
model_factors.append(model_factor)
return model_factors

def make_result(
self,
samples_summary: SamplesSummary,
paths: AbstractPaths,
samples: Optional[SamplesPDF] = None,
search_internal: Optional[object] = None,
analysis: Optional[object] = None,
) -> CombinedResult:
"""
Make a result from the samples summary and paths.

The top level result accounts for the combined model.
There is one child result for each model factor.

Parameters
----------
samples_summary
A summary of the samples
paths
Handles saving and loading data
samples
The full list of samples
search_internal
analysis

Returns
-------
A result with child results for each model factor
"""
child_results = [
model_factor.analysis.make_result(
samples_summary=samples_summary.subsamples(
model_factor.prior_model,
),
paths=paths,
samples=samples.subsamples(model_factor.prior_model)
if samples
else None,
search_internal=search_internal,
analysis=model_factor,
)
for model_factor in self.model_factors
]
return CombinedResult(
child_results,
samples_summary=samples_summary,
paths=paths,
samples=samples,
search_internal=search_internal,
analysis=analysis,
)
16 changes: 12 additions & 4 deletions autofit/non_linear/analysis/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
logger = logging.getLogger(__name__)


class CombinedResult:
class CombinedResult(Result):
def __init__(
self,
results: List[Result],
samples: Optional[SamplesPDF] = None,
samples_summary: Optional[SamplesSummary] = None,
paths: Optional[AbstractPaths] = None,
search_internal: Optional[object] = None,
analysis: Optional[Analysis] = None,
):
"""
A `Result` object that is composed of multiple `Result` objects. This is used to combine the results of
Expand All @@ -32,9 +35,14 @@ def __init__(
results
The list of `Result` objects that are combined into this `CombinedResult` object.
"""
super().__init__(
samples_summary=samples_summary,
samples=samples,
paths=paths,
search_internal=search_internal,
analysis=analysis,
)
self.child_results = results
self.samples = samples
self.samples_summary = samples_summary

def __getattr__(self, item: str):
"""
Expand Down Expand Up @@ -415,4 +423,4 @@ def with_free_parameters(
return FreeParameterAnalysis(*self.analyses, free_parameters=free_parameters)

def compute_latent_samples(self, samples):
return self.analyses[0].compute_latent_samples(samples=samples)
return self.analyses[0].compute_latent_samples(samples=samples)
33 changes: 33 additions & 0 deletions test_autofit/graphical/test_combined_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import autofit as af
from autofit.non_linear.paths.null import NullPaths


def test_make_result():
model = af.Model(af.Gaussian)
factor_graph_model = af.FactorGraphModel(
af.AnalysisFactor(
model,
af.Analysis(),
)
)
result = factor_graph_model.make_result(
samples_summary=af.SamplesSummary(
max_log_likelihood_sample=af.Sample(
0,
0,
0,
kwargs={
("0", "centre"): 1.0,
("0", "normalization"): 1.0,
("0", "sigma"): 1.0,
},
),
model=af.Collection(model),
),
paths=NullPaths(),
)
assert len(result.child_results) == 1
assert isinstance(result.model, af.Collection)

(child_result,) = result.child_results
assert child_result.model == model
Loading