diff --git a/autofit/graphical/declarative/collection.py b/autofit/graphical/declarative/collection.py index 9bad8fb29..341e51c24 100644 --- a/autofit/graphical/declarative/collection.py +++ b/autofit/graphical/declarative/collection.py @@ -1,20 +1,21 @@ -from typing import Union +from typing import Union, Optional from autofit.graphical.declarative.factor.hierarchical import HierarchicalFactor from autofit.mapper.model import ModelInstance from autofit.tools.namer import namer from .abstract import AbstractDeclarativeFactor +from autofit.non_linear.paths.abstract import AbstractPaths +from autofit.non_linear.samples.pdf import SamplesPDF +from autofit.non_linear.samples.summary import SamplesSummary +from autofit.non_linear.analysis.combined import CombinedResult class FactorGraphModel(AbstractDeclarativeFactor): def __init__( - self, - *model_factors: Union[ - AbstractDeclarativeFactor, - HierarchicalFactor - ], - name=None, - include_prior_factors=True, + self, + *model_factors: Union[AbstractDeclarativeFactor, HierarchicalFactor], + name=None, + include_prior_factors=True, ): """ A collection of factors that describe models, which can be @@ -40,11 +41,10 @@ def prior_model(self): in each model factor """ from autofit.mapper.prior_model.collection import Collection - return Collection({ - factor.name: factor.prior_model - for factor - in self.model_factors - }) + + return Collection( + {factor.name: factor.prior_model for factor in self.model_factors} + ) @property def optimiser(self): @@ -61,21 +61,13 @@ def info(self) -> str: def name(self): return self._name - def add( - self, - model_factor: AbstractDeclarativeFactor - ): + def add(self, model_factor: AbstractDeclarativeFactor): """ Add another factor to this collection. """ - self._model_factors.append( - model_factor - ) + self._model_factors.append(model_factor) - def log_likelihood_function( - self, - instance: ModelInstance - ) -> float: + def log_likelihood_function(self, instance: ModelInstance) -> float: """ Compute the combined likelihood of each factor from a collection of instances with the same ordering as the factors. @@ -90,13 +82,8 @@ def log_likelihood_function( The combined likelihood of all factors """ log_likelihood = 0 - for model_factor, instance_ in zip( - self.model_factors, - instance - ): - log_likelihood += model_factor.log_likelihood_function( - instance_ - ) + for model_factor, instance_ in zip(self.model_factors, instance): + log_likelihood += model_factor.log_likelihood_function(instance_) return log_likelihood @@ -104,15 +91,60 @@ def log_likelihood_function( def model_factors(self): model_factors = list() for model_factor in self._model_factors: - if isinstance( - model_factor, - HierarchicalFactor - ): - model_factors.extend( - model_factor.factors - ) + if isinstance(model_factor, HierarchicalFactor): + model_factors.extend(model_factor.factors) else: - model_factors.append( - model_factor - ) + model_factors.append(model_factor) return model_factors + + def make_result( + self, + samples_summary: SamplesSummary, + paths: AbstractPaths, + samples: Optional[SamplesPDF] = None, + search_internal: Optional[object] = None, + analysis: Optional[object] = None, + ) -> CombinedResult: + """ + Make a result from the samples summary and paths. + + The top level result accounts for the combined model. + There is one child result for each model factor. + + Parameters + ---------- + samples_summary + A summary of the samples + paths + Handles saving and loading data + samples + The full list of samples + search_internal + analysis + + Returns + ------- + A result with child results for each model factor + """ + child_results = [ + model_factor.analysis.make_result( + samples_summary=samples_summary.subsamples( + model_factor.prior_model, + ), + paths=paths, + samples=samples.subsamples(model_factor.prior_model) + if samples + else None, + search_internal=search_internal, + analysis=model_factor, + ) + for model_factor in self.model_factors + ] + return CombinedResult( + child_results, + samples_summary=samples_summary, + paths=paths, + samples=samples, + search_internal=search_internal, + analysis=analysis, + ) diff --git a/autofit/non_linear/analysis/combined.py b/autofit/non_linear/analysis/combined.py index 0717c7fdb..9fd6fd778 100644 --- a/autofit/non_linear/analysis/combined.py +++ b/autofit/non_linear/analysis/combined.py @@ -15,12 +15,15 @@ logger = logging.getLogger(__name__) -class CombinedResult: +class CombinedResult(Result): def __init__( self, results: List[Result], samples: Optional[SamplesPDF] = None, samples_summary: Optional[SamplesSummary] = None, + paths: Optional[AbstractPaths] = None, + search_internal: Optional[object] = None, + analysis: Optional[Analysis] = None, ): """ A `Result` object that is composed of multiple `Result` objects. This is used to combine the results of @@ -32,9 +35,14 @@ def __init__( results The list of `Result` objects that are combined into this `CombinedResult` object. """ + super().__init__( + samples_summary=samples_summary, + samples=samples, + paths=paths, + search_internal=search_internal, + analysis=analysis, + ) self.child_results = results - self.samples = samples - self.samples_summary = samples_summary def __getattr__(self, item: str): """ @@ -415,4 +423,4 @@ def with_free_parameters( return FreeParameterAnalysis(*self.analyses, free_parameters=free_parameters) def compute_latent_samples(self, samples): - return self.analyses[0].compute_latent_samples(samples=samples) \ No newline at end of file + return self.analyses[0].compute_latent_samples(samples=samples) diff --git a/test_autofit/graphical/test_combined_analysis.py b/test_autofit/graphical/test_combined_analysis.py new file mode 100644 index 000000000..b19fa8b21 --- /dev/null +++ b/test_autofit/graphical/test_combined_analysis.py @@ -0,0 +1,33 @@ +import autofit as af +from autofit.non_linear.paths.null import NullPaths + + +def test_make_result(): + model = af.Model(af.Gaussian) + factor_graph_model = af.FactorGraphModel( + af.AnalysisFactor( + model, + af.Analysis(), + ) + ) + result = factor_graph_model.make_result( + samples_summary=af.SamplesSummary( + max_log_likelihood_sample=af.Sample( + 0, + 0, + 0, + kwargs={ + ("0", "centre"): 1.0, + ("0", "normalization"): 1.0, + ("0", "sigma"): 1.0, + }, + ), + model=af.Collection(model), + ), + paths=NullPaths(), + ) + assert len(result.child_results) == 1 + assert isinstance(result.model, af.Collection) + + (child_result,) = result.child_results + assert child_result.model == model