Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 32 additions & 79 deletions autofit/graphical/declarative/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,7 @@ def priors(self) -> Set[Prior]:
A set of all priors encompassed by the contained likelihood models
"""
return {
prior
for model
in self.model_factors
for prior
in model.prior_model.priors
prior for model in self.model_factors for prior in model.prior_model.priors
}

@property
Expand All @@ -67,8 +63,7 @@ def prior_counts(self) -> List[Tuple[Prior, int]]:
counter[prior] += 1
return [
(prior, count + 1 if self.include_prior_factors else count)
for prior, count
in counter.items()
for prior, count in counter.items()
]

@property
Expand All @@ -91,29 +86,21 @@ def message_dict(self) -> Dict[Prior, NormalMessage]:
"""
return {
prior: prior.message ** (1 / (count - 1)) if count > 1 else prior.message
for prior, count
in self.prior_counts
for prior, count in self.prior_counts
}

@property
def graph(self) -> DeclarativeFactorGraph:
"""
The complete graph made by combining all factors and priors
"""
factors = [
model
for model
in self.model_factors
]
factors = [model for model in self.model_factors]
if self.include_prior_factors:
factors += self.prior_factors
# noinspection PyTypeChecker
return DeclarativeFactorGraph(factors)

def draw_graph(
self,
**kwargs
):
def draw_graph(self, **kwargs):
"""
Visualise the graph.

Expand All @@ -128,28 +115,23 @@ def draw_graph(
graph = self.graph

factor_labels = {
factor: factor.name
if factor.label is None
else factor.label
factor: factor.name if factor.label is None else factor.label
for factor in graph.factors
}
variable_labels = {
variable: variable.name
if variable.label is None
else variable.label
variable: variable.name if variable.label is None else variable.label
for variable in graph.all_variables
}

import matplotlib.pyplot as plt

if "draw_labels" not in kwargs:
kwargs["draw_labels"] = True
if "variable_labels" not in kwargs:
kwargs["variable_labels"] = variable_labels
if "factor_labels" not in kwargs:
kwargs["factor_labels"] = factor_labels
graph.draw_graph(
**kwargs
)
graph.draw_graph(**kwargs)
plt.show()
plt.close()

Expand All @@ -161,16 +143,13 @@ def mean_field_approximation(self) -> EPMeanField:
"""
Returns a EPMeanField of the factor graph
"""
return EPMeanField.from_approx_dists(
self.graph,
self.message_dict
)
return EPMeanField.from_approx_dists(self.graph, self.message_dict)

def _make_ep_optimiser(
self,
optimiser: AbstractFactorOptimiser,
paths: Optional[AbstractPaths] = None,
ep_history: Optional = None,
self,
optimiser: AbstractFactorOptimiser,
paths: Optional[AbstractPaths] = None,
ep_history: Optional = None,
) -> EPOptimiser:
return EPOptimiser(
self.graph,
Expand All @@ -181,15 +160,15 @@ def _make_ep_optimiser(
if factor.optimiser is not None
},
ep_history=ep_history,
paths=paths
paths=paths,
)

def optimise(
self,
optimiser: AbstractFactorOptimiser,
paths: Optional[AbstractPaths] = None,
ep_history: Optional = None,
**kwargs
self,
optimiser: AbstractFactorOptimiser,
paths: Optional[AbstractPaths] = None,
ep_history: Optional = None,
**kwargs
):
"""
Use an EP Optimiser to optimise the graph associated with this collection
Expand All @@ -209,15 +188,9 @@ def optimise(
A collection of prior models
"""
from autofit.graphical.declarative.result import EPResult
opt = self._make_ep_optimiser(
optimiser,
paths=paths,
ep_history=ep_history
)
updated_ep_mean_field = opt.run(
self.mean_field_approximation(),
**kwargs
)

opt = self._make_ep_optimiser(optimiser, paths=paths, ep_history=ep_history)
updated_ep_mean_field = opt.run(self.mean_field_approximation(), **kwargs)

return EPResult(
ep_history=opt.ep_history,
Expand All @@ -228,10 +201,7 @@ def optimise(
# TODO : Visualize method before fit?

def visualize(
self,
paths: AbstractPaths,
instance: ModelInstance,
during_analysis: bool
self, paths: AbstractPaths, instance: ModelInstance, during_analysis: bool
):
"""
Visualise the instances provided using each factor.
Expand All @@ -247,21 +217,9 @@ def visualize(
during_analysis
Is this visualisation during analysis?
"""
for model_factor, instance in zip(
self.model_factors,
instance
):
model_factor.visualize(
paths,
instance,
during_analysis
)
model_factor.visualize_combined(
None,
paths,
instance,
during_analysis
)
for model_factor, instance in zip(self.model_factors, instance):
model_factor.visualize(paths, instance, during_analysis)
model_factor.visualize_combined(None, paths, instance, during_analysis)

@property
def global_prior_model(self) -> Collection:
Expand All @@ -272,10 +230,7 @@ def global_prior_model(self) -> Collection:


class GlobalPriorModel(Collection):
def __init__(
self,
factor: AbstractDeclarativeFactor
):
def __init__(self, factor: AbstractDeclarativeFactor):
"""
A global model comprising all factors which can be used to compare
results between global optimisation and expectation propagation.
Expand All @@ -285,15 +240,13 @@ def __init__(
factor
A factor comprising one or more factors, usually a graph
"""
super().__init__([
model_factor.prior_model
for model_factor
in factor.model_factors
])
super().__init__(
[model_factor.prior_model for model_factor in factor.model_factors]
)
self.factor = factor

@property
def info(self) -> str:
def graph_info(self) -> str:
"""
A string describing the collection of factors in the graphical style
"""
Expand Down
6 changes: 6 additions & 0 deletions autofit/non_linear/paths/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,12 @@ def _save_model_info(self, model):
with open_(self.output_path / "model.info", "w+") as f:
f.write(model.info)

try:
with open_(self.output_path / "model.graph", "w+") as f:
f.write( model.graph_info)
except AttributeError:
pass

def _save_model_start_point(self, info):
"""
Save the model.start file, which summarizes the start point of every parameter.
Expand Down
2 changes: 1 addition & 1 deletion test_autofit/graphical/global/test_global.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def reset_ids():

def test_info(model_factor):
assert (
model_factor.global_prior_model.info
model_factor.global_prior_model.graph_info
== """PriorFactors

PriorFactor0 (AnalysisFactor0.one) UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0
Expand Down
29 changes: 28 additions & 1 deletion test_autofit/graphical/global/test_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,36 @@ def reset_ids():
af.Prior._ids = itertools.count()


def test_info(model):
def test_model_info(model):
assert (
model.info
== """Total Free Parameters = 4

model GlobalPriorModel (N=4)
0 - 1 Collection (N=3)
distribution_model GaussianPrior (N=2)
2 - 3 Collection (N=1)

0 - 1
distribution_model
mean GaussianPrior [2], mean = 0.5, sigma = 0.1
sigma GaussianPrior [3], mean = 1.0, sigma = 0.01
lower_limit -inf
upper_limit inf
0
drawn_prior UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0
1
drawn_prior UniformPrior [1], lower_limit = 0.0, upper_limit = 1.0
2 - 3
one UniformPrior [0], lower_limit = 0.0, upper_limit = 1.0
factor
include_prior_factors True"""
)


def test_graph_info(model):
assert (
model.graph_info
== """PriorFactors

PriorFactor0 (HierarchicalFactor0) GaussianPrior [3], mean = 1.0, sigma = 0.01
Expand Down
Loading