Click to expand diff
autofit/mapper/prior/abstract.py (modified)
@@ -61,6 +61,18 @@ def unit_value_for(self, physical_value: float) -> float:
return self.message.cdf(physical_value)
def with_message(self, message):
+ """Return a copy of this prior with a different message (distribution).
+
+ The copy is made with ``copy(self)`` (presumably ``copy.copy``, i.e. a
+ shallow copy) and only ``message`` is replaced, so every other attribute
+ -- presumably including ``id`` -- is carried over from this prior.
+
+ Parameters
+ ----------
+ message
+ The new message object defining the prior's distribution.
+
+ Returns
+ -------
+ Prior
+ A copy of this prior using the new message; ``self`` is not modified.
+ """
new = copy(self)
new.message = message
return new
@@ -88,6 +100,23 @@ def factor(self):
@staticmethod
def for_class_and_attribute_name(cls, attribute_name):
+ """Create a prior from the configuration for a given class and attribute.
+
+ Looks up the prior type and parameters in the prior config files
+ for the specified class and attribute name.
+
+ Parameters
+ ----------
+ cls
+ The model class whose config is looked up.
+ attribute_name
+ The name of the attribute on that class.
+
+ Returns
+ -------
+ Prior
+ A prior instance constructed from the config entry.
+ """
prior_dict = conf.instance.prior_config.for_class_and_suffix_path(
cls, [attribute_name]
)
@@ -129,10 +158,31 @@ def instance_for_arguments(
arguments,
ignore_assertions=False,
):
+ """Look up this prior's value in an arguments dictionary.
+
+ Parameters
+ ----------
+ arguments
+ A dictionary mapping Prior objects to physical values.
+ ignore_assertions
+ Unused for priors (present for interface compatibility).
+ """
_ = ignore_assertions
return arguments[self]
def project(self, samples, weights):
+ """Project this prior given samples and log weights from a search.
+
+ Returns a copy of this prior whose message has been updated to
+ reflect the posterior information from the samples.
+
+ Parameters
+ ----------
+ samples
+ Array of sample values for this parameter.
+ weights
+ Log weights for each sample.
+ """
result = copy(self)
result.message = self.message.project(
samples=samples,
@@ -170,6 +220,11 @@ def __str__(self):
@property
@abstractmethod
def parameter_string(self) -> str:
+ """A human-readable string summarizing this prior's parameters.
+
+ Subclasses must implement this. ``UniformPrior``, for example, returns
+ ``"lower_limit = <lower>, upper_limit = <upper>"``; other subclasses
+ presumably follow the same ``"name = value, ..."`` pattern.
+ """
pass
def __float__(self):
@@ -254,7 +309,22 @@ def name_of_class(cls) -> str:
@property
def limits(self) -> Tuple[float, float]:
+ """The (lower, upper) bounds of this prior.
+
+ The base implementation is unbounded and returns ``(-inf, inf)``.
+ Subclasses with finite support override it; ``UniformPrior`` returns
+ ``(lower_limit, upper_limit)``.
+ """
return (float("-inf"), float("inf"))
def gaussian_prior_model_for_arguments(self, arguments):
+ """Return the value this prior is mapped to in ``arguments``.
+
+ For a bare prior this is the leaf case of the recursive
+ ``gaussian_prior_model_for_arguments`` traversal: a plain lookup keyed
+ on the prior object itself.
+
+ Parameters
+ ----------
+ arguments
+ A dictionary mapping Prior objects to their replacement values
+ (a new prior or a fixed value).
+
+ Returns
+ -------
+ The replacement stored under this prior.
+
+ Raises
+ ------
+ KeyError
+ If this prior has no entry in ``arguments``.
+ """
return arguments[self]
autofit/mapper/prior/gaussian.py (modified)
@@ -57,6 +57,13 @@ def __init__(
)
def tree_flatten(self):
+ """Flatten this prior into a JAX-compatible PyTree representation.
+
+ Returns
+ -------
+ tuple
+ A ``(children, aux_data)`` pair where children are
+ ``(mean, sigma, id)`` and aux_data is an empty tuple, so all the
+ state needed to rebuild the prior travels in the children.
+ """
return (self.mean, self.sigma, self.id), ()
@classmethod
autofit/mapper/prior/uniform.py (modified)
@@ -64,23 +64,50 @@ def __init__(
)
def tree_flatten(self):
+ """Flatten this prior into a JAX-compatible PyTree representation.
+
+ Returns
+ -------
+ tuple
+ A ``(children, aux_data)`` pair where children are
+ ``(lower_limit, upper_limit, id)`` and aux_data is an empty tuple,
+ so all the state needed to rebuild the prior travels in the children.
+ """
return (self.lower_limit, self.upper_limit, self.id), ()
@property
def width(self):
+ """The width of the uniform distribution, ``upper_limit - lower_limit``."""
return self.upper_limit - self.lower_limit
def with_limits(
self,
lower_limit: float,
upper_limit: float,
) -> "Prior":
+ """Create a new UniformPrior with different bounds.
+
+ Only the two limits are passed to the constructor, so nothing else
+ (e.g. ``id``) is copied from this prior.
+
+ Parameters
+ ----------
+ lower_limit
+ The new lower bound.
+ upper_limit
+ The new upper bound.
+
+ Returns
+ -------
+ UniformPrior
+ A fresh prior bounded by ``lower_limit`` and ``upper_limit``;
+ ``self`` is left unchanged.
+ """
return UniformPrior(
lower_limit=lower_limit,
upper_limit=upper_limit,
)
def logpdf(self, x):
+ """Compute the log probability density at x.
+
+ Adjusts boundary values by epsilon to avoid evaluating exactly at
+ the distribution edges where the PDF is undefined.
+
+ Parameters
+ ----------
+ x
+ The value at which to evaluate the log PDF.
+ """
# TODO: handle x as a numpy array
if x == self.lower_limit:
x += epsilon
@@ -102,6 +129,7 @@ def dict(self) -> dict:
@property
def parameter_string(self) -> str:
+ """The limits as ``"lower_limit = <lower>, upper_limit = <upper>"``."""
return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}"
def value_for(self, unit: float) -> float:
@@ -142,4 +170,5 @@ def log_prior_from_value(self, value, xp=np):
@property
def limits(self) -> Tuple[float, float]:
+ """The ``(lower_limit, upper_limit)`` bounds of this uniform prior."""
return self.lower_limit, self.upper_limit
\ No newline at end of file
autofit/mapper/prior_model/abstract.py (modified)
@@ -41,6 +41,20 @@
class Limits:
@staticmethod
def for_class_and_attributes_name(cls, attribute_name):
+ """Look up the (lower, upper) limits for a class attribute from config.
+
+ Parameters
+ ----------
+ cls
+ The model class.
+ attribute_name
+ The name of the attribute on that class.
+
+ Returns
+ -------
+ tuple
+ A (lower, upper) pair of limit values.
+ """
limit_dict = conf.instance.prior_config.for_class_and_suffix_path(
cls, [attribute_name, "limits"]
)
@@ -165,6 +179,11 @@ def __init__(self, label=None):
@property
def assertions(self):
+ """The assertion constraints attached to this model (``self._assertions``).
+
+ Assertions are checked when creating instances, and a failed assertion
+ raises ``exc.FitException``. Presumably a non-linear search catches this
+ and discards the sample (TODO: confirm against the search code).
+ """
return self._assertions
@assertions.setter
@@ -441,11 +460,28 @@ def add_assertion(self, assertion, name=""):
@property
def name(self):
+ """The name of this prior model's own class, e.g. ``"Model"`` or ``"Collection"``.
+
+ This is ``self.__class__.__name__``: the wrapper class, not the class
+ being modelled, unless a subclass overrides this property.
+ """
return self.__class__.__name__
# noinspection PyUnusedLocal
@staticmethod
def from_object(t, *args, **kwargs):
+ """Convert an arbitrary object into an appropriate prior model representation.
+
+ - Classes become ``Model`` instances.
+ - Lists and dicts become ``Collection`` instances.
+ - Floats become ``Constant`` instances.
+ - Existing prior models and other objects are returned as-is.
+
+ Parameters
+ ----------
+ t
+ A class, list, dict, float, or existing prior model.
+
+ Returns
+ -------
+ An AbstractPriorModel, Constant, or the original object.
+ """
if inspect.isclass(t):
from .prior_model import Model
@@ -606,6 +642,7 @@ def prior_tuples_ordered_by_id(self):
@property
def priors_ordered_by_id(self):
+ """The priors from ``prior_tuples_ordered_by_id``, with names dropped.
+
+ Sorted by prior id, which presumably defines the canonical parameter
+ ordering. Whether a shared prior appears once or several times depends
+ on ``prior_tuples_ordered_by_id``.
+ """
return [prior for _, prior in self.prior_tuples_ordered_by_id]
def vector_from_unit_vector(self, unit_vector):
@@ -836,6 +873,16 @@ def is_only_model(self, cls) -> bool:
return len(cls_models) > 0 and len(cls_models) == len(other_models)
def replacing(self, arguments):
+ """Return a new model with some priors replaced.
+
+ This is a convenience alias for ``mapper_from_partial_prior_arguments``:
+ priors not present in ``arguments`` are kept unchanged.
+
+ Parameters
+ ----------
+ arguments : dict
+ A dictionary mapping existing Prior objects to new priors or fixed values.
+
+ Returns
+ -------
+ Whatever ``mapper_from_partial_prior_arguments`` returns for ``arguments``.
+ """
return self.mapper_from_partial_prior_arguments(arguments)
@classmethod
@@ -1211,6 +1258,10 @@ def from_instance(
return result
def items(self):
+ """Return (name, value) pairs for all public, non-internal attributes.
+
+ Excludes private attributes (prefixed with ``_``), ``cls``, and ``id``.
+ """
return [
(key, value)
for key, value in self.__dict__.items()
@@ -1225,23 +1276,27 @@ def direct_prior_tuples(self):
@property
@cast_collection(InstanceNameValue)
def direct_instance_tuples(self):
+ """(name, value) tuples for direct ``float`` attributes followed by direct ``Constant`` attributes."""
return self.direct_tuples_with_type(float) + self.direct_tuples_with_type(
Constant
)
@property
@cast_collection(PriorModelNameValue)
def prior_model_tuples(self):
+ """(name, prior_model) tuples for direct child AbstractPriorModel attributes.
+
+ Non-recursive: the implementation is identical to ``direct_prior_model_tuples``.
+ """
return self.direct_tuples_with_type(AbstractPriorModel)
@property
@cast_collection(PriorModelNameValue)
def direct_prior_model_tuples(self):
+ """(name, prior_model) tuples for immediate child prior models (non-recursive).
+
+ The implementation is identical to ``prior_model_tuples``.
+ """
return self.direct_tuples_with_type(AbstractPriorModel)
@property
@cast_collection(PriorModelNameValue)
def direct_tuple_priors(self):
+ """(name, tuple_prior) tuples for direct TuplePrior attributes."""
return self.direct_tuples_with_type(TuplePrior)
@property
@@ -1267,6 +1322,7 @@ def direct_prior_tuples(self):
@property
@cast_collection(DeferredNameValue)
def direct_deferred_tuples(self):
+ """(name, deferred_argument) tuples for direct DeferredArgument attributes (non-recursive)."""
return self.direct_tuples_with_type(DeferredArgument)
@property
@@ -1298,6 +1354,11 @@ def instance_tuples(self):
@property
def prior_class_dict(self):
+ """Map each prior to the class it will produce when instantiated.
+
+ Direct priors on this model map to ``self.cls``. Child prior models
+ contribute their own mappings recursively.
+ """
from autofit.mapper.prior_model.annotation import AnnotationPriorModel
d = {prior[1]: self.cls for prior in self.prior_tuples}
@@ -1467,16 +1528,49 @@ def total_free_parameters(self) -> int:
@property
def priors(self):
+ """A list of the Prior objects from ``prior_tuples``, with names dropped.
+
+ Whether a shared (linked) prior appears once or several times depends
+ on ``prior_tuples``.
+ """
return [prior_tuple.prior for prior_tuple in self.prior_tuples]
@property
def _prior_id_map(self):
# Map prior id -> prior, rebuilt from ``priors`` on every access. A
# shared prior seen several times maps to the same object, so later
# duplicates simply overwrite earlier ones.
return {prior.id: prior for prior in self.priors}
def prior_with_id(self, prior_id):
+ """Retrieve a prior by its integer id.
+
+ Note that ``_prior_id_map`` is rebuilt on each call, so every lookup
+ walks all priors in the model.
+
+ Parameters
+ ----------
+ prior_id : int
+ The id of the prior to find.
+
+ Returns
+ -------
+ Prior
+ The prior with the matching id.
+
+ Raises
+ ------
+ KeyError
+ If no prior with the given id exists in this model.
+ """
return self._prior_id_map[prior_id]
def name_for_prior(self, prior):
+ """Get the underscore-separated name for a prior in this model.
+
+ Searches child prior models recursively. Returns None if the prior
+ is not found.
+
+ Parameters
+ ----------
+ prior : Prior
+ The prior to find.
+
+ Returns
+ -------
+ str or None
+ The name path joined by underscores, e.g. ``"galaxy_centre"``.
+ """
for prior_model_name, prior_model in self.direct_prior_model_tuples:
prior_name = prior_model.name_for_prior(prior)
if prior_name is not None:
@@ -1525,6 +1619,11 @@ def copy_with_fixed_priors(self, instance, excluded_classes=tuple()):
@property
def path_priors_tuples(self) -> List[Tuple[Path, Prior]]:
+ """All (path, prior) tuples in this model, sorted by prior id.
+
+ Unlike ``unique_path_prior_tuples``, this keeps one entry per path, so
+ a prior shared across several paths appears several times. The sort
+ is stable, so entries with equal ids keep the order produced by
+ ``path_instance_tuples_for_class``.
+ """
path_priors_tuples = self.path_instance_tuples_for_class(Prior)
return sorted(path_priors_tuples, key=lambda item: item[1].id)
@@ -1546,6 +1645,10 @@ def paths_formatted(self) -> List[Path]:
@property
def composition(self):
+ """A list of dot-separated path strings, one per entry in ``self.paths``.
+
+ For example, a prior at path ``("g", "one")`` yields ``"g.one"``. The
+ order is whatever ``self.paths`` provides (presumably prior-id order).
+ Compare ``joined_paths``, which is built from ``unique_prior_paths``.
+ """
return [".".join(path) for path in self.paths]
def sort_priors_alphabetically(self, priors: Iterable[Prior]) -> List[Prior]:
@@ -1617,14 +1720,20 @@ def all_paths_for_prior(self, prior: Prior) -> Generator[Path, None, None]:
@property
def path_float_tuples(self):
+ """(path, float) tuples for all fixed float values, excluding Prior objects."""
return self.path_instance_tuples_for_class(float, ignore_class=Prior)
@property
def unique_prior_paths(self):
+ """One path per unique prior (see ``unique_path_prior_tuples``), ordered by prior id."""
return [item[0] for item in self.unique_path_prior_tuples]
@property
def unique_path_prior_tuples(self):
+ """(path, prior) tuples deduplicated by prior, ordered by prior id.
+
+ When a prior is shared across multiple paths, only one path is kept:
+ the dict below overwrites on repeated keys, so the LAST path for that
+ prior in ``path_priors_tuples`` order wins.
+ """
unique = {item[1]: item for item in self.path_priors_tuples}.values()
return sorted(unique, key=lambda item: item[1].id)
@@ -1645,6 +1754,18 @@ def prior_prior_model_dict(self):
}
def log_prior_list_from(self, parameter_lists: List[List]) -> List:
+ """Compute the total log prior for each parameter vector in a list.
+
+ Parameters
+ ----------
+ parameter_lists
+ A list of physical parameter vectors.
+
+ Returns
+ -------
+ list
+ The summed log prior for each vector.
+ """
return [
sum(self.log_prior_list_from_vector(vector=vector))
for vector in parameter_lists
@@ -1809,6 +1930,7 @@ def model_component_and_parameter_names(self) -> List[str]:
@property
def joined_paths(self) -> List[str]:
+ """Dot-joined path strings, one per unique prior, ordered by prior id.
+
+ Built from ``unique_prior_paths``, so a shared prior contributes a
+ single path.
+ """
prior_paths = self.unique_prior_paths
return [".".join(path) for path in prior_paths]
autofit/mapper/prior_model/array.py (modified)
@@ -197,6 +197,7 @@ def tree_unflatten(cls, aux_data, children):
@property
def prior_class_dict(self):
+ """Map each prior to the class it produces (np.ndarray for direct priors)."""
return {
**{
prior: cls
autofit/mapper/prior_model/collection.py (modified)
@@ -31,11 +31,28 @@ def name_for_prior(self, prior: Prior) -> str:
return name
def tree_flatten(self):
+ """Flatten this collection into a JAX-compatible PyTree representation.
+
+ Returns
+ -------
+ tuple
+ A ``(children, aux_data)`` pair where children are the values and
+ aux_data are the corresponding keys, both in ``self.items()`` order.
+
+ NOTE(review): for an empty collection ``zip(*self.items())`` yields
+ nothing, so the two-name unpacking raises ``ValueError``. Consider
+ guarding with ``if not self._dict: return (), ()``.
+ """
keys, values = zip(*self.items())
return values, keys
@classmethod
def tree_unflatten(cls, aux_data, children):
+ """Reconstruct a Collection from a flattened PyTree.
+
+ Parameters
+ ----------
+ aux_data
+ The keys of the collection items.
+ children
+ The values of the collection items.
+ """
instance = cls()
for key, value in zip(aux_data, children):
@@ -46,6 +63,14 @@ def __contains__(self, item):
return item in self._dict or item in self._dict.values()
def __getitem__(self, item):
+ """Retrieve an item by string key or positional index.
+
+ Parameters
+ ----------
+ item : str or int
+ A string key for dict-style access. Anything else is used to index
+ the ``values`` list, so integers (and slices) are positional.
+
+ Returns
+ -------
+ The stored model component, or a list of them for a slice.
+
+ Raises
+ ------
+ KeyError
+ If a string key is not present.
+ IndexError
+ If an integer index is out of range.
+ """
if isinstance(item, str):
return self._dict[item]
return self.values[item]
@@ -64,9 +89,11 @@ def __repr__(self):
@property
def values(self):
+ """The model components in this collection as a new list, in insertion order of ``_dict``."""
return list(self._dict.values())
def items(self):
+ """The (key, model_component) pairs in this collection, as a live ``_dict.items()`` view."""
return self._dict.items()
def with_prefix(self, prefix: str):
@@ -79,6 +106,11 @@ def with_prefix(self, prefix: str):
)
def as_model(self):
+ """Convert all prior models in this collection to Model instances.
+
+ Returns a new Collection where each AbstractPriorModel child has
+ been converted via its own as_model() method.
+ """
return Collection(
{
key: value.as_model()
@@ -162,6 +194,13 @@ def __init__(
@assert_not_frozen
def add_dict_items(self, item_dict):
+ """Add all entries from a dictionary, converting values to prior models.
+
+ Parameters
+ ----------
+ item_dict
+ A dictionary mapping string keys to classes, instances, or prior models.
+ """
for key, value in item_dict.items():
if isinstance(key, tuple):
key = ".".join(key)
@@ -179,11 +218,20 @@ def __eq__(self, other):
@assert_not_frozen
def append(self, item):
+ """Append an item under the next auto-incremented numeric key.
+
+ The key is ``str(self.item_number)``; ``item_number`` is then
+ incremented. The item is converted via ``AbstractPriorModel.from_object``
+ (classes become models; existing prior models are kept).
+ """
setattr(self, str(self.item_number), AbstractPriorModel.from_object(item))
self.item_number += 1
@assert_not_frozen
def __setitem__(self, key, value):
+ """Set an item by key, converting the value to a prior model.
+
+ Preserves the id of any existing item at the same key so that
+ prior identity is maintained across replacements.
+ """
obj = AbstractPriorModel.from_object(value)
try:
obj.id = getattr(self, str(key)).id
@@ -193,6 +241,12 @@ def __setitem__(self, key, value):
@assert_not_frozen
def __setattr__(self, key, value):
+ """Set an attribute, automatically converting values to prior models.
+
+ Private attributes (starting with ``_``) are set directly. All other
+ values are wrapped via ``AbstractPriorModel.from_object`` so that
+ plain classes become ``Model`` instances and floats become fixed values.
+ """
if key.startswith("_"):
super().__setattr__(key, value)
else:
@@ -202,6 +256,14 @@ def __setattr__(self, key, value):
pass
def remove(self, item):
+ """Remove an item from the collection by value equality.
+
+ Parameters
+ ----------
+ item
+ The item to remove. All entries whose value equals this item
+ are deleted.
+ """
for key, value in self.__dict__.copy().items():
if value == item:
del self.__dict__[key]
@@ -271,6 +333,11 @@ def gaussian_prior_model_for_arguments(self, arguments):
@property
def prior_class_dict(self):
+ """Map each prior to the class it will produce when instantiated.
+
+ For child prior models, delegates to their own prior_class_dict.
+ Direct priors on the collection itself map to ModelInstance.
+ """
return {
**{
prior: cls
test_autofit/mapper/test_model_mapping_expanded.py (added)
@@ -0,0 +1,631 @@
+"""
+Expanded tests for the model mapping API, covering gaps identified in:
+- Collection composition and instance creation
+- Shared (linked) priors across model types
+- Direct use of instance_for_arguments with argument dicts
+- Model tree navigation (object_for_path, path_for_prior, name_for_prior)
+- Edge cases (empty models, deeply nested models, single-parameter models)
+- Model subsetting (with_paths, without_paths)
+- Freezing behavior
+- Assertion checking
+- from_instance round-trips
+- mapper_from_prior_arguments and mapper_from_partial_prior_arguments
+"""
+import copy
+
+import numpy as np
+import pytest
+
+import autofit as af
+from autofit import exc
+from autofit.mapper.prior.abstract import Prior
+
+
+# ---------------------------------------------------------------------------
+# Collection: composition, nesting, instance creation, iteration
+# ---------------------------------------------------------------------------
+class TestCollectionComposition:
+ def test_collection_from_dict(self):
+ model = af.Collection(
+ one=af.Model(af.m.MockClassx2),
+ two=af.Model(af.m.MockClassx2),
+ )
+ assert model.prior_count == 4
+
+ def test_collection_from_list(self):
+ model = af.Collection([af.m.MockClassx2, af.m.MockClassx2])
+ assert model.prior_count == 4
+
+ def test_collection_from_generator(self):
+ model = af.Collection(af.Model(af.m.MockClassx2) for _ in range(3))
+ assert model.prior_count == 6
+
+ def test_nested_collection(self):
+ inner = af.Collection(a=af.m.MockClassx2)
+ outer = af.Collection(inner=inner, extra=af.m.MockClassx2)
+ assert outer.prior_count == 4
+
+ def test_deeply_nested_collection(self):
+ model = af.Collection(
+ level1=af.Collection(
+ level2=af.Collection(
+ leaf=af.m.MockClassx2,
+ )
+ )
+ )
+ assert model.prior_count == 2
+
+ def test_collection_instance_attribute_access(self):
+ model = af.Collection(gaussian=af.m.MockClassx2, exp=af.m.MockClassx2)
+ instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0])
+ assert instance.gaussian.one == 1.0
+ assert instance.gaussian.two == 2.0
+ assert instance.exp.one == 3.0
+ assert instance.exp.two == 4.0
+
+ def test_collection_instance_index_access(self):
+ model = af.Collection([af.m.MockClassx2, af.m.MockClassx2])
+ instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0])
+ assert instance[0].one == 1.0
+ assert instance[1].one == 3.0
+
+ def test_collection_len(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ assert len(model) == 2
+
+ def test_collection_contains(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ assert "a" in model
+ assert "c" not in model
+
+ def test_collection_items(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ keys = [k for k, v in model.items()]
+ assert "a" in keys
+ assert "b" in keys
+
+ def test_collection_getitem_string(self):
+ model = af.Collection(a=af.m.MockClassx2)
+ assert isinstance(model["a"], af.Model)
+
+ def test_collection_append(self):
+ model = af.Collection()
+ model.append(af.m.MockClassx2)
+ model.append(af.m.MockClassx2)
+ assert model.prior_count == 4
+
+ def test_collection_mixed_model_and_fixed(self):
+ """Collection with one free model and one fixed instance."""
+ model = af.Collection(
+ free=af.Model(af.m.MockClassx2),
+ )
+ assert model.prior_count == 2
+
+ def test_empty_collection(self):
+ model = af.Collection()
+ assert model.prior_count == 0
+
+
+# ---------------------------------------------------------------------------
+# Shared (linked) priors
+# ---------------------------------------------------------------------------
+class TestSharedPriors:
+ def test_link_within_model(self):
+ model = af.Model(af.m.MockClassx2)
+ model.one = model.two
+ assert model.prior_count == 1
+ instance = model.instance_from_vector([5.0])
+ assert instance.one == instance.two == 5.0
+
+ def test_link_across_collection_children(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ model.a.one = model.b.one # Link a.one to b.one
+ assert model.prior_count == 3 # 4 - 1 shared
+
+ def test_linked_priors_same_value(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ model.a.one = model.b.one
+ instance = model.instance_from_vector([10.0, 20.0, 30.0])
+ assert instance.a.one == instance.b.one
+
+ def test_link_reduces_unique_prior_count(self):
+ model = af.Model(af.m.MockClassx2)
+ original_count = len(model.unique_prior_tuples)
+ model.one = model.two
+ assert len(model.unique_prior_tuples) == original_count - 1
+
+ def test_linked_prior_identity(self):
+ model = af.Model(af.m.MockClassx2)
+ model.one = model.two
+ assert model.one is model.two
+
+
+# ---------------------------------------------------------------------------
+# instance_for_arguments (direct argument dict usage)
+# ---------------------------------------------------------------------------
+class TestInstanceForArguments:
+ def test_model_instance_for_arguments(self):
+ model = af.Model(af.m.MockClassx2)
+ args = {model.one: 10.0, model.two: 20.0}
+ instance = model.instance_for_arguments(args)
+ assert instance.one == 10.0
+ assert instance.two == 20.0
+
+ def test_collection_instance_for_arguments(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ args = {}
+ for name, prior in model.prior_tuples_ordered_by_id:
+ args[prior] = 1.0
+ instance = model.instance_for_arguments(args)
+ assert instance.a.one == 1.0
+ assert instance.b.two == 1.0
+
+ def test_shared_prior_in_arguments(self):
+ """When priors are linked, only one entry is needed in the arguments dict."""
+ model = af.Model(af.m.MockClassx2)
+ model.one = model.two
+ shared_prior = model.one
+ args = {shared_prior: 42.0}
+ instance = model.instance_for_arguments(args)
+ assert instance.one == 42.0
+ assert instance.two == 42.0
+
+ def test_missing_prior_raises(self):
+ model = af.Model(af.m.MockClassx2)
+ args = {model.one: 10.0} # missing model.two
+ with pytest.raises(KeyError):
+ model.instance_for_arguments(args)
+
+
+# ---------------------------------------------------------------------------
+# Vector and unit vector mapping
+# ---------------------------------------------------------------------------
+class TestVectorMapping:
+ def test_instance_from_vector_basic(self):
+ model = af.Model(af.m.MockClassx2)
+ instance = model.instance_from_vector([3.0, 4.0])
+ assert instance.one == 3.0
+ assert instance.two == 4.0
+
+ def test_vector_length_mismatch_raises(self):
+ model = af.Model(af.m.MockClassx2)
+ with pytest.raises(AssertionError):
+ model.instance_from_vector([1.0])
+
+ def test_unit_vector_length_mismatch_raises(self):
+ model = af.Model(af.m.MockClassx2)
+ with pytest.raises(AssertionError):
+ model.instance_from_unit_vector([0.5])
+
+ def test_vector_from_unit_vector(self):
+ model = af.Model(af.m.MockClassx2)
+ model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+ model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+ physical = model.vector_from_unit_vector([0.0, 1.0])
+ assert physical[0] == pytest.approx(0.0, abs=1e-6)
+ assert physical[1] == pytest.approx(10.0, abs=1e-6)
+
+ def test_instance_from_prior_medians(self):
+ model = af.Model(af.m.MockClassx2)
+ model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+ model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+ instance = model.instance_from_prior_medians()
+ assert instance.one == pytest.approx(50.0)
+ assert instance.two == pytest.approx(50.0)
+
+ def test_random_instance(self):
+ model = af.Model(af.m.MockClassx2)
+ model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+ model.two = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+ instance = model.random_instance()
+ assert 0.0 <= instance.one <= 1.0
+ assert 0.0 <= instance.two <= 1.0
+
+
+# ---------------------------------------------------------------------------
+# Model tree navigation
+# ---------------------------------------------------------------------------
+class TestModelTreeNavigation:
+ def test_object_for_path_child_model(self):
+ model = af.Collection(g=af.Model(af.m.MockClassx2))
+ child = model.object_for_path(("g",))
+ assert isinstance(child, af.Model)
+
+ def test_object_for_path_prior(self):
+ model = af.Collection(g=af.Model(af.m.MockClassx2))
+ prior = model.object_for_path(("g", "one"))
+ assert isinstance(prior, Prior)
+
+ def test_paths_matches_prior_count(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ assert len(model.paths) == model.prior_count
+
+ def test_path_for_prior(self):
+ model = af.Collection(g=af.Model(af.m.MockClassx2))
+ prior = model.g.one
+ path = model.path_for_prior(prior)
+ assert path == ("g", "one")
+
+ def test_name_for_prior(self):
+ model = af.Collection(g=af.Model(af.m.MockClassx2))
+ prior = model.g.one
+ name = model.name_for_prior(prior)
+ assert name == "g_one"
+
+ def test_path_instance_tuples_for_class(self):
+ model = af.Collection(g=af.Model(af.m.MockClassx2))
+ tuples = model.path_instance_tuples_for_class(Prior)
+ paths = [t[0] for t in tuples]
+ assert ("g", "one") in paths
+ assert ("g", "two") in paths
+
+ def test_deeply_nested_path(self):
+ inner_model = af.Model(af.m.MockClassx2)
+ inner_collection = af.Collection(leaf=inner_model)
+ outer = af.Collection(branch=inner_collection)
+
+ prior = outer.branch.leaf.one
+ path = outer.path_for_prior(prior)
+ assert path == ("branch", "leaf", "one")
+
+ def test_direct_vs_recursive_prior_tuples(self):
+ model = af.Collection(a=af.m.MockClassx2)
+ assert len(model.direct_prior_tuples) == 0 # Collection has no direct priors
+ assert len(model.prior_tuples) == 2 # But has 2 recursive priors
+
+ def test_direct_prior_model_tuples(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ assert len(model.direct_prior_model_tuples) == 2
+
+
+# ---------------------------------------------------------------------------
+# instance_from_path_arguments and instance_from_prior_name_arguments
+# ---------------------------------------------------------------------------
+class TestPathAndNameArguments:
+ def test_instance_from_path_arguments(self):
+ model = af.Collection(g=af.m.MockClassx2)
+ instance = model.instance_from_path_arguments(
+ {("g", "one"): 10.0, ("g", "two"): 20.0}
+ )
+ assert instance.g.one == 10.0
+ assert instance.g.two == 20.0
+
+ def test_instance_from_prior_name_arguments(self):
+ model = af.Collection(g=af.m.MockClassx2)
+ instance = model.instance_from_prior_name_arguments(
+ {"g_one": 10.0, "g_two": 20.0}
+ )
+ assert instance.g.one == 10.0
+ assert instance.g.two == 20.0
+
+
+# ---------------------------------------------------------------------------
+# Assertions
+# ---------------------------------------------------------------------------
+class TestAssertions:
+ def test_assertion_passes(self):
+ model = af.Model(af.m.MockClassx2)
+ model.add_assertion(model.one > model.two)
+ # one=10 > two=5 should pass
+ instance = model.instance_from_vector([10.0, 5.0])
+ assert instance.one == 10.0
+
+ def test_assertion_fails(self):
+ model = af.Model(af.m.MockClassx2)
+ model.add_assertion(model.one > model.two)
+ with pytest.raises(exc.FitException):
+ model.instance_from_vector([1.0, 10.0])
+
+ def test_ignore_assertions(self):
+ model = af.Model(af.m.MockClassx2)
+ model.add_assertion(model.one > model.two)
+ # Should not raise even though assertion fails
+ instance = model.instance_from_vector([1.0, 10.0], ignore_assertions=True)
+ assert instance.one == 1.0
+
+ def test_multiple_assertions(self):
+ model = af.Model(af.m.MockClassx4)
+ model.add_assertion(model.one > model.two)
+ model.add_assertion(model.three > model.four)
+ # Both pass
+ instance = model.instance_from_vector([10.0, 5.0, 10.0, 5.0])
+ assert instance.one == 10.0
+ # First fails
+ with pytest.raises(exc.FitException):
+ model.instance_from_vector([1.0, 10.0, 10.0, 5.0])
+
+ def test_true_assertion_ignored(self):
+ """Adding True as an assertion should be a no-op."""
+ model = af.Model(af.m.MockClassx2)
+ model.add_assertion(True)
+ assert len(model.assertions) == 0
+
+
+# ---------------------------------------------------------------------------
+# Model subsetting (with_paths, without_paths)
+# ---------------------------------------------------------------------------
+class TestModelSubsetting:
+ def test_with_paths_single_child(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ subset = model.with_paths([("a",)])
+ assert subset.prior_count == 2
+
+ def test_without_paths_single_child(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ subset = model.without_paths([("a",)])
+ assert subset.prior_count == 2
+
+ def test_with_paths_specific_prior(self):
+ model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+ subset = model.with_paths([("a", "one")])
+ assert subset.prior_count == 1
+
+ def test_with_prefix(self):
+ model = af.Collection(ab_one=af.m.MockClassx2, cd_two=af.m.MockClassx2)
+ subset = model.with_prefix("ab")
+ assert subset.prior_count == 2
+
+
+# ---------------------------------------------------------------------------
+# Freezing behavior
+# ---------------------------------------------------------------------------
+class TestFreezing:
+ def test_freeze_prevents_modification(self):
+ model = af.Model(af.m.MockClassx2)
+ model.freeze()
+ with pytest.raises(AssertionError):
+ model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+
+ def test_unfreeze_allows_modification(self):
+ model = af.Model(af.m.MockClassx2)
+ model.freeze()
+ model.unfreeze()
+ model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+ assert isinstance(model.one, af.UniformPrior)
+
+ def test_frozen_model_still_creates_instances(self):
+ model = af.Model(af.m.MockClassx2)
+ model.freeze()
+ instance = model.instance_from_vector([1.0, 2.0])
+ assert instance.one == 1.0
+
+ def test_freeze_propagates_to_children(self):
+ model = af.Collection(a=af.m.MockClassx2)
+ model.freeze()
+ with pytest.raises(AssertionError):
+ model.a.one = 1.0
+
+ def test_cached_results_consistent(self):
+ model = af.Model(af.m.MockClassx2)
+ model.freeze()
+ result1 = model.prior_tuples_ordered_by_id
+ result2 = model.prior_tuples_ordered_by_id
+ assert result1 == result2
+
+
+# ---------------------------------------------------------------------------
+# mapper_from_prior_arguments and related
+# ---------------------------------------------------------------------------
+class TestMapperFromPriorArguments:
+ def test_replace_all_priors(self):
+ model = af.Model(af.m.MockClassx2)
+ new_one = af.GaussianPrior(mean=0.0, sigma=1.0)
+ new_two = af.GaussianPrior(mean=5.0, sigma=2.0)
+ new_model = model.mapper_from_prior_arguments(
+ {model.one: new_one, model.two: new_two}
+ )
+ assert new_model.prior_count == 2
+ assert isinstance(new_model.one, af.GaussianPrior)
+
+ def test_partial_replacement(self):
+ model = af.Model(af.m.MockClassx2)
+ new_one = af.GaussianPrior(mean=0.0, sigma=1.0)
+ new_model = model.mapper_from_partial_prior_arguments(
+ {model.one: new_one}
+ )
+ assert new_model.prior_count == 2
+ assert isinstance(new_model.one, af.GaussianPrior)
+ # two should retain its original prior type
+ assert new_model.two is not None
+
+ def test_fix_via_mapper_from_prior_arguments(self):
+ """Replacing a prior with a float effectively fixes that parameter."""
+ model = af.Model(af.m.MockClassx2)
+ new_model = model.mapper_from_prior_arguments(
+ {model.one: 5.0, model.two: model.two}
+ )
+ assert new_model.prior_count == 1
+
+ def test_with_limits(self):
+ model = af.Model(af.m.MockClassx2)
+ model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+ model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+ new_model = model.with_limits([(10.0, 20.0), (30.0, 40.0)])
+ assert new_model.prior_count == 2
+
+
+# ---------------------------------------------------------------------------
+# from_instance round trips
+# ---------------------------------------------------------------------------
+class TestFromInstance:
+    """Tests for `AbstractPriorModel.from_instance` on objects, lists and dicts."""
+
+    def test_from_simple_instance(self):
+        """An instance converted directly has no free parameters."""
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(instance)
+        assert model.prior_count == 0
+
+    def test_from_instance_as_model(self):
+        """`as_model()` turns the converted instance's values into free priors."""
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(instance)
+        free_model = model.as_model()
+        assert free_model.prior_count == 2
+
+    def test_from_instance_with_model_classes(self):
+        """Listing the class in `model_classes` makes its attributes free."""
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(
+            instance, model_classes=(af.m.MockClassx2,)
+        )
+        assert model.prior_count == 2
+
+    def test_from_list_instance(self):
+        """A list of instances converts to a model with no free parameters."""
+        instance_list = [af.m.MockClassx2(1.0, 2.0), af.m.MockClassx2(3.0, 4.0)]
+        model = af.AbstractPriorModel.from_instance(instance_list)
+        assert model.prior_count == 0
+
+    def test_from_dict_instance(self):
+        """A dict of instances converts to a model with no free parameters."""
+        instance_dict = {
+            "one": af.m.MockClassx2(1.0, 2.0),
+            "two": af.m.MockClassx2(3.0, 4.0),
+        }
+        model = af.AbstractPriorModel.from_instance(instance_dict)
+        assert model.prior_count == 0
+
+
+# ---------------------------------------------------------------------------
+# Fixing parameters and Constant values
+# ---------------------------------------------------------------------------
+class TestFixedParameters:
+    """Tests for fixing parameters by assigning plain float values."""
+
+    def test_fix_reduces_prior_count(self):
+        """Assigning a float to an attribute removes it from the free priors."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        assert model.prior_count == 1
+
+    def test_fixed_value_in_instance(self):
+        """Fixed values appear in instances; the vector fills only free ones."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        # Only `two` is still free, so the one-element vector maps onto it.
+        instance = model.instance_from_vector([10.0])
+        assert instance.one == 5.0
+        assert instance.two == 10.0
+
+    def test_fix_all_parameters(self):
+        """With every attribute fixed, an empty vector still yields an instance."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        model.two = 10.0
+        assert model.prior_count == 0
+        instance = model.instance_from_vector([])
+        assert instance.one == 5.0
+        assert instance.two == 10.0
+
+
+# ---------------------------------------------------------------------------
+# take_attributes (prior passing)
+# ---------------------------------------------------------------------------
+class TestTakeAttributes:
+    """Tests for `take_attributes`, which passes values/priors into a model."""
+
+    def test_take_from_instance(self):
+        """Taking attributes from an instance fixes them, leaving no free priors."""
+        model = af.Model(af.m.MockClassx2)
+        source = af.m.MockClassx2(10.0, 20.0)
+        model.take_attributes(source)
+        assert model.prior_count == 0
+
+    def test_take_from_model(self):
+        """Taking attributes from another model copies priors."""
+        source_model = af.Model(af.m.MockClassx2)
+        source_model.one = af.GaussianPrior(mean=5.0, sigma=1.0)
+        source_model.two = af.GaussianPrior(mean=10.0, sigma=2.0)
+
+        target_model = af.Model(af.m.MockClassx2)
+        target_model.take_attributes(source_model)
+        # NOTE(review): only `one` is asserted; `two` is not checked here.
+        assert isinstance(target_model.one, af.GaussianPrior)
+
+
+# ---------------------------------------------------------------------------
+# Serialization (dict / from_dict)
+# ---------------------------------------------------------------------------
+class TestSerialization:
+    """Round-trip tests for `dict()` followed by `AbstractPriorModel.from_dict`."""
+
+    def test_model_dict_roundtrip(self):
+        """A model's free-parameter count survives dict serialization."""
+        model = af.Model(af.m.MockClassx2)
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == model.prior_count
+
+    def test_collection_dict_roundtrip(self):
+        """A collection's free-parameter count survives dict serialization."""
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == model.prior_count
+
+    def test_fixed_parameter_survives_roundtrip(self):
+        """A fixed (float) attribute is still fixed after the round trip."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == 1
+
+    def test_linked_prior_survives_roundtrip(self):
+        """Two attributes sharing one prior still count as one after loading."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two  # link: both attributes now share one prior
+        assert model.prior_count == 1
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == 1
+
+
+# ---------------------------------------------------------------------------
+# Log prior computation
+# ---------------------------------------------------------------------------
+class TestLogPrior:
+    """Tests for `log_prior_list_from_vector` with uniform priors."""
+
+    def test_log_prior_within_bounds(self):
+        """In-bounds values give a finite log prior for every parameter."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        log_priors = model.log_prior_list_from_vector([5.0, 5.0])
+        assert all(np.isfinite(lp) for lp in log_priors)
+
+    def test_log_prior_outside_bounds(self):
+        """An out-of-bounds value is not favoured over an in-bounds one."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        log_priors = model.log_prior_list_from_vector([15.0, 5.0])
+        # 15.0 is above `one`'s upper limit of 10.0; 5.0 is inside `two`'s range.
+        # The check only requires the out-of-bounds entry to be no greater than
+        # the in-bounds one, so equal values (e.g. both zero) also pass.
+        assert log_priors[0] <= log_priors[1]
+
+
+# ---------------------------------------------------------------------------
+# Edge cases
+# ---------------------------------------------------------------------------
+class TestEdgeCases:
+    """Miscellaneous model behaviours: copying, ordering and type queries."""
+
+    def test_single_parameter_model(self):
+        """Fixing one of two attributes leaves a single free parameter."""
+        model = af.Model(af.m.MockClassx2)
+        model.two = 5.0  # Fix one parameter
+        assert model.prior_count == 1
+        instance = model.instance_from_vector([42.0])
+        assert instance.one == 42.0
+        assert instance.two == 5.0
+
+    def test_model_copy_preserves_priors(self):
+        """`copy()` keeps the prior count but creates new prior objects."""
+        model = af.Model(af.m.MockClassx2)
+        copied = model.copy()
+        assert copied.prior_count == model.prior_count
+        # Priors are independent copies (different objects)
+        assert copied.one is not model.one
+
+    def test_model_copy_linked_priors_independent(self):
+        """Copying a model with linked priors preserves the link in the copy."""
+        # NOTE(review): despite the `_independent` suffix, this test checks
+        # that the link is kept inside the copy; it does not compare the
+        # copy's priors against the original model's.
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.prior_count == 1
+        copied = model.copy()
+        assert copied.prior_count == 1
+        # The copy's internal link is preserved
+        assert copied.one is copied.two
+
+    def test_prior_ordering_is_deterministic(self):
+        """prior_tuples_ordered_by_id should be stable across calls."""
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        order1 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id]
+        order2 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id]
+        assert order1 == order2
+
+    def test_prior_count_equals_total_free_parameters(self):
+        """`prior_count` and `total_free_parameters` agree for a mixed collection."""
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx4)
+        assert model.prior_count == model.total_free_parameters
+
+    def test_has_model(self):
+        """`has_model` reports only model classes present in the collection."""
+        model = af.Collection(a=af.Model(af.m.MockClassx2))
+        assert model.has_model(af.m.MockClassx2)
+        assert not model.has_model(af.m.MockClassx4)
+
+    def test_has_instance(self):
+        """`has_instance` finds `Prior` attributes but not unrelated classes."""
+        # Assumes `Prior` is imported at module level (outside this view).
+        model = af.Model(af.m.MockClassx2)
+        assert model.has_instance(Prior)
+        assert not model.has_instance(af.m.MockClassx4)
API Update Required
PyAutoFit PR #1172 has been merged with API changes that may affect workspace scripts.
PR Title
Add expanded model mapping unit tests
Changed Files (Python only)
Click to expand diff
autofit/mapper/prior/abstract.py (modified)
autofit/mapper/prior/gaussian.py (modified)
autofit/mapper/prior/uniform.py (modified)
autofit/mapper/prior_model/abstract.py (modified)
autofit/mapper/prior_model/array.py (modified)
autofit/mapper/prior_model/collection.py (modified)
test_autofit/mapper/test_model_mapping_expanded.py (added)
Instructions
Search `scripts/` for usages of the old API patterns, then run `bash run_scripts.sh`.
Branch
Use branch name:
feature/feature/model_docs