[API Update] PyAutoFit PR #1172: Add expanded model mapping unit tests #11

@Jammy2211

Description

API Update Required

PyAutoFit PR #1172 has been merged with API changes that may affect workspace scripts.

PR Title

Add expanded model mapping unit tests

Changed Files (Python only)

autofit/mapper/prior/abstract.py (modified)

@@ -61,6 +61,18 @@ def unit_value_for(self, physical_value: float) -> float:
         return self.message.cdf(physical_value)
 
     def with_message(self, message):
+        """Return a copy of this prior with a different message (distribution).
+
+        Parameters
+        ----------
+        message
+            The new message object defining the prior's distribution.
+
+        Returns
+        -------
+        Prior
+            A copy of this prior using the new message.
+        """
         new = copy(self)
         new.message = message
         return new
@@ -88,6 +100,23 @@ def factor(self):
 
     @staticmethod
     def for_class_and_attribute_name(cls, attribute_name):
+        """Create a prior from the configuration for a given class and attribute.
+
+        Looks up the prior type and parameters in the prior config files
+        for the specified class and attribute name.
+
+        Parameters
+        ----------
+        cls
+            The model class whose config is looked up.
+        attribute_name
+            The name of the attribute on that class.
+
+        Returns
+        -------
+        Prior
+            A prior instance constructed from the config entry.
+        """
         prior_dict = conf.instance.prior_config.for_class_and_suffix_path(
             cls, [attribute_name]
         )
@@ -129,10 +158,31 @@ def instance_for_arguments(
         arguments,
         ignore_assertions=False,
     ):
+        """Look up this prior's value in an arguments dictionary.
+
+        Parameters
+        ----------
+        arguments
+            A dictionary mapping Prior objects to physical values.
+        ignore_assertions
+            Unused for priors (present for interface compatibility).
+        """
         _ = ignore_assertions
         return arguments[self]
 
     def project(self, samples, weights):
+        """Project this prior given samples and log weights from a search.
+
+        Returns a copy of this prior whose message has been updated to
+        reflect the posterior information from the samples.
+
+        Parameters
+        ----------
+        samples
+            Array of sample values for this parameter.
+        weights
+            Log weights for each sample.
+        """
         result = copy(self)
         result.message = self.message.project(
             samples=samples,
@@ -170,6 +220,11 @@ def __str__(self):
     @property
     @abstractmethod
     def parameter_string(self) -> str:
+        """A human-readable string summarizing this prior's parameters.
+
+        Subclasses must implement this to return a description such as
+        ``"mean = 0.0, sigma = 1.0"`` or ``"lower_limit = 0.0, upper_limit = 1.0"``.
+        """
         pass
 
     def __float__(self):
@@ -254,7 +309,22 @@ def name_of_class(cls) -> str:
 
     @property
     def limits(self) -> Tuple[float, float]:
+        """The (lower, upper) bounds of this prior.
+
+        Returns (-inf, inf) by default. Subclasses with finite bounds
+        (e.g. UniformPrior) override this.
+        """
         return (float("-inf"), float("inf"))
 
     def gaussian_prior_model_for_arguments(self, arguments):
+        """Look up this prior in an arguments dict and return the mapped value.
+
+        Used during prior replacement workflows where each prior is mapped
+        to a new prior or fixed value via an arguments dictionary.
+
+        Parameters
+        ----------
+        arguments
+            A dictionary mapping Prior objects to their replacement values.
+        """
         return arguments[self]
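
The docstrings added to `autofit/mapper/prior/abstract.py` describe two behaviours that are easy to check in isolation: `with_message` returns a copy rather than mutating the prior, and `instance_for_arguments` is a plain dictionary lookup keyed by the prior itself. Below is a minimal sketch of both. `SketchPrior` is a hypothetical stand-in, not the PyAutoFit `Prior` class, and the string messages are placeholders for real message objects.

```python
from copy import copy


class SketchPrior:
    """Hypothetical stand-in for a prior; NOT the PyAutoFit Prior class."""

    def __init__(self, message):
        # In the real library the message is a distribution object;
        # a plain string is used here purely for illustration.
        self.message = message

    def with_message(self, message):
        # Same shape as the method in the diff: shallow-copy, then swap.
        new = copy(self)
        new.message = message
        return new

    def instance_for_arguments(self, arguments, ignore_assertions=False):
        # ignore_assertions is accepted only for interface compatibility.
        _ = ignore_assertions
        return arguments[self]


original = SketchPrior(message="gaussian(0, 1)")
updated = original.with_message("gaussian(2, 0.5)")

# The prior object itself is the dictionary key.
value = original.instance_for_arguments({original: 3.5})
```

Workspace scripts that rely on the original prior being left untouched after `with_message` should see the same behaviour as in this sketch.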

autofit/mapper/prior/gaussian.py (modified)

@@ -57,6 +57,13 @@ def __init__(
         )
 
     def tree_flatten(self):
+        """Flatten this prior into a JAX-compatible PyTree representation.
+
+        Returns
+        -------
+        tuple
+            A (children, aux_data) pair where children are (mean, sigma, id).
+        """
         return (self.mean, self.sigma, self.id), ()
 
     @classmethod
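
The `tree_flatten` docstring above refers to the `(children, aux_data)` pair that JAX expects from classes registered with `jax.tree_util.register_pytree_node_class`. The round trip can be sketched without JAX, as below. `GaussianSketch` is a hypothetical class, and its `tree_unflatten` body is an assumption, since the diff shows only the `@classmethod` decorator and not the implementation.

```python
class GaussianSketch:
    """Hypothetical stand-in for a Gaussian prior; NOT the PyAutoFit class."""

    def __init__(self, mean, sigma, id_=0):
        self.mean = mean
        self.sigma = sigma
        self.id = id_

    def tree_flatten(self):
        # children: values JAX may trace; aux_data: static metadata (empty here).
        return (self.mean, self.sigma, self.id), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Assumption: rebuild by passing the children back to the constructor.
        return cls(*children)


prior = GaussianSketch(mean=1.0, sigma=2.0, id_=7)
children, aux_data = prior.tree_flatten()
rebuilt = GaussianSketch.tree_unflatten(aux_data, children)
```

Because `id` travels with the children in this layout, a flatten and unflatten round trip preserves it.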

autofit/mapper/prior/uniform.py (modified)

@@ -64,23 +64,50 @@ def __init__(
         )
 
     def tree_flatten(self):
+        """Flatten this prior into a JAX-compatible PyTree representation.
+
+        Returns
+        -------
+        tuple
+            A (children, aux_data) pair where children are (lower_limit, upper_limit, id).
+        """
         return (self.lower_limit, self.upper_limit, self.id), ()
 
     @property
     def width(self):
+        """The width of the uniform distribution (upper_limit - lower_limit)."""
         return self.upper_limit - self.lower_limit
 
     def with_limits(
         self,
         lower_limit: float,
         upper_limit: float,
     ) -> "Prior":
+        """Create a new UniformPrior with different bounds.
+
+        Parameters
+        ----------
+        lower_limit
+            The new lower bound.
+        upper_limit
+            The new upper bound.
+        """
         return UniformPrior(
             lower_limit=lower_limit,
             upper_limit=upper_limit,
         )
 
     def logpdf(self, x):
+        """Compute the log probability density at x.
+
+        Adjusts boundary values by epsilon to avoid evaluating exactly at
+        the distribution edges where the PDF is undefined.
+
+        Parameters
+        ----------
+        x
+            The value at which to evaluate the log PDF.
+        """
         # TODO: handle x as a numpy array
         if x == self.lower_limit:
             x += epsilon
@@ -102,6 +129,7 @@ def dict(self) -> dict:
 
     @property
     def parameter_string(self) -> str:
+        """A human-readable string summarizing the prior's lower and upper limits."""
         return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}"
 
     def value_for(self, unit: float) -> float:
@@ -142,4 +170,5 @@ def log_prior_from_value(self, value, xp=np):
 
     @property
     def limits(self) -> Tuple[float, float]:
+        """The (lower_limit, upper_limit) bounds of this uniform prior."""
         return self.lower_limit, self.upper_limit
\ No newline at end of file
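
The properties documented in `uniform.py` (`width`, `limits`, `parameter_string`, `with_limits`) fit together as in the sketch below. `UniformSketch` is a hypothetical class, not the PyAutoFit `UniformPrior`. The linear `value_for` mapping is an assumption: it agrees with the `test_vector_from_unit_vector` expectations in the new test file, but the method body is not part of this diff.

```python
from typing import Tuple


class UniformSketch:
    """Hypothetical stand-in for a uniform prior; NOT the PyAutoFit class."""

    def __init__(self, lower_limit: float, upper_limit: float):
        self.lower_limit = lower_limit
        self.upper_limit = upper_limit

    @property
    def width(self) -> float:
        return self.upper_limit - self.lower_limit

    @property
    def limits(self) -> Tuple[float, float]:
        return self.lower_limit, self.upper_limit

    @property
    def parameter_string(self) -> str:
        return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}"

    def with_limits(self, lower_limit: float, upper_limit: float) -> "UniformSketch":
        # Returns a brand-new object; the original is left unchanged.
        return UniformSketch(lower_limit=lower_limit, upper_limit=upper_limit)

    def value_for(self, unit: float) -> float:
        # Assumption: unit values in [0, 1] map linearly onto [lower, upper].
        return self.lower_limit + unit * self.width


prior = UniformSketch(0.0, 10.0)
narrowed = prior.with_limits(2.0, 4.0)
midpoint = prior.value_for(0.5)
```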

autofit/mapper/prior_model/abstract.py (modified)

@@ -41,6 +41,20 @@
 class Limits:
     @staticmethod
     def for_class_and_attributes_name(cls, attribute_name):
+        """Look up the (lower, upper) limits for a class attribute from config.
+
+        Parameters
+        ----------
+        cls
+            The model class.
+        attribute_name
+            The name of the attribute on that class.
+
+        Returns
+        -------
+        tuple
+            A (lower, upper) pair of limit values.
+        """
         limit_dict = conf.instance.prior_config.for_class_and_suffix_path(
             cls, [attribute_name, "limits"]
         )
@@ -165,6 +179,11 @@ def __init__(self, label=None):
 
     @property
     def assertions(self):
+        """The list of assertion constraints attached to this model.
+
+        Assertions are checked when creating instances; a failed assertion
+        raises FitException, causing the non-linear search to resample.
+        """
         return self._assertions
 
     @assertions.setter
@@ -441,11 +460,28 @@ def add_assertion(self, assertion, name=""):
 
     @property
     def name(self):
+        """The class name of this prior model (e.g. ``"Model"`` or ``"Collection"``)."""
         return self.__class__.__name__
 
     # noinspection PyUnusedLocal
     @staticmethod
     def from_object(t, *args, **kwargs):
+        """Convert an arbitrary object into an appropriate prior model representation.
+
+        - Classes become ``Model`` instances.
+        - Lists and dicts become ``Collection`` instances.
+        - Floats become ``Constant`` instances.
+        - Existing prior models and other objects are returned as-is.
+
+        Parameters
+        ----------
+        t
+            A class, list, dict, float, or existing prior model.
+
+        Returns
+        -------
+        An AbstractPriorModel, Constant, or the original object.
+        """
         if inspect.isclass(t):
             from .prior_model import Model
 
@@ -606,6 +642,7 @@ def prior_tuples_ordered_by_id(self):
 
     @property
     def priors_ordered_by_id(self):
+        """Unique priors sorted by their id, defining the canonical parameter ordering."""
         return [prior for _, prior in self.prior_tuples_ordered_by_id]
 
     def vector_from_unit_vector(self, unit_vector):
@@ -836,6 +873,16 @@ def is_only_model(self, cls) -> bool:
         return len(cls_models) > 0 and len(cls_models) == len(other_models)
 
     def replacing(self, arguments):
+        """Return a new model with some priors replaced.
+
+        This is a convenience alias for ``mapper_from_partial_prior_arguments``.
+        Priors not present in the arguments dict are kept unchanged.
+
+        Parameters
+        ----------
+        arguments : dict
+            A dictionary mapping existing Prior objects to new priors or fixed values.
+        """
         return self.mapper_from_partial_prior_arguments(arguments)
 
     @classmethod
@@ -1211,6 +1258,10 @@ def from_instance(
         return result
 
     def items(self):
+        """Return (name, value) pairs for all public, non-internal attributes.
+
+        Excludes private attributes (prefixed with ``_``), ``cls``, and ``id``.
+        """
         return [
             (key, value)
             for key, value in self.__dict__.items()
@@ -1225,23 +1276,27 @@ def direct_prior_tuples(self):
     @property
     @cast_collection(InstanceNameValue)
     def direct_instance_tuples(self):
+        """(name, value) tuples for direct float and Constant attributes."""
         return self.direct_tuples_with_type(float) + self.direct_tuples_with_type(
             Constant
         )
 
     @property
     @cast_collection(PriorModelNameValue)
     def prior_model_tuples(self):
+        """(name, prior_model) tuples for direct child AbstractPriorModel attributes."""
         return self.direct_tuples_with_type(AbstractPriorModel)
 
     @property
     @cast_collection(PriorModelNameValue)
     def direct_prior_model_tuples(self):
+        """(name, prior_model) tuples for immediate child prior models (non-recursive)."""
         return self.direct_tuples_with_type(AbstractPriorModel)
 
     @property
     @cast_collection(PriorModelNameValue)
     def direct_tuple_priors(self):
+        """(name, tuple_prior) tuples for direct TuplePrior attributes."""
         return self.direct_tuples_with_type(TuplePrior)
 
     @property
@@ -1267,6 +1322,7 @@ def direct_prior_tuples(self):
     @property
     @cast_collection(DeferredNameValue)
     def direct_deferred_tuples(self):
+        """(name, deferred_argument) tuples for direct DeferredArgument attributes."""
         return self.direct_tuples_with_type(DeferredArgument)
 
     @property
@@ -1298,6 +1354,11 @@ def instance_tuples(self):
 
     @property
     def prior_class_dict(self):
+        """Map each prior to the class it will produce when instantiated.
+
+        Direct priors on this model map to ``self.cls``. Child prior models
+        contribute their own mappings recursively.
+        """
         from autofit.mapper.prior_model.annotation import AnnotationPriorModel
 
         d = {prior[1]: self.cls for prior in self.prior_tuples}
@@ -1467,16 +1528,49 @@ def total_free_parameters(self) -> int:
 
     @property
     def priors(self):
+        """A list of all Prior objects in this model (may contain duplicates for shared priors)."""
         return [prior_tuple.prior for prior_tuple in self.prior_tuples]
 
     @property
     def _prior_id_map(self):
         return {prior.id: prior for prior in self.priors}
 
     def prior_with_id(self, prior_id):
+        """Retrieve a prior by its unique integer id.
+
+        Parameters
+        ----------
+        prior_id : int
+            The id of the prior to find.
+
+        Returns
+        -------
+        Prior
+            The prior with the matching id.
+
+        Raises
+        ------
+        KeyError
+            If no prior with the given id exists in this model.
+        """
         return self._prior_id_map[prior_id]
 
     def name_for_prior(self, prior):
+        """Get the underscore-separated name for a prior in this model.
+
+        Searches child prior models recursively. Returns None if the prior
+        is not found.
+
+        Parameters
+        ----------
+        prior : Prior
+            The prior to find.
+
+        Returns
+        -------
+        str or None
+            The name path joined by underscores, e.g. ``"galaxy_centre"``.
+        """
         for prior_model_name, prior_model in self.direct_prior_model_tuples:
             prior_name = prior_model.name_for_prior(prior)
             if prior_name is not None:
@@ -1525,6 +1619,11 @@ def copy_with_fixed_priors(self, instance, excluded_classes=tuple()):
 
     @property
     def path_priors_tuples(self) -> List[Tuple[Path, Prior]]:
+        """All (path, prior) tuples in this model, sorted by prior id.
+
+        Unlike ``unique_path_prior_tuples``, this includes duplicate entries
+        when a prior appears at multiple paths (shared priors).
+        """
         path_priors_tuples = self.path_instance_tuples_for_class(Prior)
         return sorted(path_priors_tuples, key=lambda item: item[1].id)
 
@@ -1546,6 +1645,10 @@ def paths_formatted(self) -> List[Path]:
 
     @property
     def composition(self):
+        """A list of dot-separated path strings for each prior, ordered by prior id.
+
+        For example: ``["galaxy.centre", "galaxy.normalization", "galaxy.sigma"]``.
+        """
         return [".".join(path) for path in self.paths]
 
     def sort_priors_alphabetically(self, priors: Iterable[Prior]) -> List[Prior]:
@@ -1617,14 +1720,20 @@ def all_paths_for_prior(self, prior: Prior) -> Generator[Path, None, None]:
 
     @property
     def path_float_tuples(self):
+        """(path, float) tuples for all fixed float values, excluding Prior objects."""
         return self.path_instance_tuples_for_class(float, ignore_class=Prior)
 
     @property
     def unique_prior_paths(self):
+        """Paths to each unique prior (deduplicated for shared priors), ordered by id."""
         return [item[0] for item in self.unique_path_prior_tuples]
 
     @property
     def unique_path_prior_tuples(self):
+        """(path, prior) tuples deduplicated by prior identity, ordered by id.
+
+        When a prior is shared across multiple paths, only one path is kept.
+        """
         unique = {item[1]: item for item in self.path_priors_tuples}.values()
         return sorted(unique, key=lambda item: item[1].id)
 
@@ -1645,6 +1754,18 @@ def prior_prior_model_dict(self):
         }
 
     def log_prior_list_from(self, parameter_lists: List[List]) -> List:
+        """Compute the total log prior for each parameter vector in a list.
+
+        Parameters
+        ----------
+        parameter_lists
+            A list of physical parameter vectors.
+
+        Returns
+        -------
+        list
+            The summed log prior for each vector.
+        """
         return [
             sum(self.log_prior_list_from_vector(vector=vector))
             for vector in parameter_lists
@@ -1809,6 +1930,7 @@ def model_component_and_parameter_names(self) -> List[str]:
 
     @property
     def joined_paths(self) -> List[str]:
+        """Dot-joined path strings for each unique prior, ordered by id."""
         prior_paths = self.unique_prior_paths
 
         return [".".join(path) for path in prior_paths]
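
The `unique_path_prior_tuples` docstring above says that only one path is kept when a prior is shared. A minimal sketch of that dictionary-based deduplication follows, using the same two expressions as the diff. `SketchPrior` and the hand-written paths are hypothetical, and the sketch assumes priors are hashable by `id`, which the diff does not show.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchPrior:
    """Hypothetical stand-in for a prior, hashable by its id."""

    id: int


shared = SketchPrior(id=0)  # linked at two paths, like model.a.one = model.b.one
other = SketchPrior(id=1)

# Equivalent of path_priors_tuples: every (path, prior) pair, sorted by prior id.
path_priors_tuples = sorted(
    [
        (("a", "one"), shared),
        (("b", "one"), shared),
        (("a", "two"), other),
    ],
    key=lambda item: item[1].id,
)

# Same expressions as in the diff: key on the prior, then re-sort by id.
unique = {item[1]: item for item in path_priors_tuples}.values()
unique_path_prior_tuples = sorted(unique, key=lambda item: item[1].id)

# Equivalent of joined_paths.
joined_paths = [".".join(path) for path, _ in unique_path_prior_tuples]
```

In this sketch the path seen last for the shared prior (`b.one`) is the one that survives, because later dictionary entries overwrite earlier ones stored under the same key.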

autofit/mapper/prior_model/array.py (modified)

@@ -197,6 +197,7 @@ def tree_unflatten(cls, aux_data, children):
 
     @property
     def prior_class_dict(self):
+        """Map each prior to the class it produces (np.ndarray for direct priors)."""
         return {
             **{
                 prior: cls

autofit/mapper/prior_model/collection.py (modified)

@@ -31,11 +31,28 @@ def name_for_prior(self, prior: Prior) -> str:
                 return name
 
     def tree_flatten(self):
+        """Flatten this collection into a JAX-compatible PyTree representation.
+
+        Returns
+        -------
+        tuple
+            A (children, aux_data) pair where children are the values and
+            aux_data are the corresponding keys.
+        """
         keys, values = zip(*self.items())
         return values, keys
 
     @classmethod
     def tree_unflatten(cls, aux_data, children):
+        """Reconstruct a Collection from a flattened PyTree.
+
+        Parameters
+        ----------
+        aux_data
+            The keys of the collection items.
+        children
+            The values of the collection items.
+        """
         instance = cls()
 
         for key, value in zip(aux_data, children):
@@ -46,6 +63,14 @@ def __contains__(self, item):
         return item in self._dict or item in self._dict.values()
 
     def __getitem__(self, item):
+        """Retrieve an item by string key or integer index.
+
+        Parameters
+        ----------
+        item : str or int
+            A string key for dict-style access, or an integer index
+            for positional access into the values list.
+        """
         if isinstance(item, str):
             return self._dict[item]
         return self.values[item]
@@ -64,9 +89,11 @@ def __repr__(self):
 
     @property
     def values(self):
+        """The model components in this collection as a list."""
         return list(self._dict.values())
 
     def items(self):
+        """The (key, model_component) pairs in this collection."""
         return self._dict.items()
 
     def with_prefix(self, prefix: str):
@@ -79,6 +106,11 @@ def with_prefix(self, prefix: str):
         )
 
     def as_model(self):
+        """Convert all prior models in this collection to Model instances.
+
+        Returns a new Collection where each AbstractPriorModel child has
+        been converted via its own as_model() method.
+        """
         return Collection(
             {
                 key: value.as_model()
@@ -162,6 +194,13 @@ def __init__(
 
     @assert_not_frozen
     def add_dict_items(self, item_dict):
+        """Add all entries from a dictionary, converting values to prior models.
+
+        Parameters
+        ----------
+        item_dict
+            A dictionary mapping string keys to classes, instances, or prior models.
+        """
         for key, value in item_dict.items():
             if isinstance(key, tuple):
                 key = ".".join(key)
@@ -179,11 +218,20 @@ def __eq__(self, other):
 
     @assert_not_frozen
     def append(self, item):
+        """Append an item to the collection with an auto-incremented numeric key.
+
+        The item is converted to an AbstractPriorModel if it is not already one.
+        """
         setattr(self, str(self.item_number), AbstractPriorModel.from_object(item))
         self.item_number += 1
 
     @assert_not_frozen
     def __setitem__(self, key, value):
+        """Set an item by key, converting the value to a prior model.
+
+        Preserves the id of any existing item at the same key so that
+        prior identity is maintained across replacements.
+        """
         obj = AbstractPriorModel.from_object(value)
         try:
             obj.id = getattr(self, str(key)).id
@@ -193,6 +241,12 @@ def __setitem__(self, key, value):
 
     @assert_not_frozen
     def __setattr__(self, key, value):
+        """Set an attribute, automatically converting values to prior models.
+
+        Private attributes (starting with ``_``) are set directly. All other
+        values are wrapped via ``AbstractPriorModel.from_object`` so that
+        plain classes become ``Model`` instances and floats become fixed values.
+        """
         if key.startswith("_"):
             super().__setattr__(key, value)
         else:
@@ -202,6 +256,14 @@ def __setattr__(self, key, value):
                 pass
 
     def remove(self, item):
+        """Remove an item from the collection by value equality.
+
+        Parameters
+        ----------
+        item
+            The item to remove. All entries whose value equals this item
+            are deleted.
+        """
         for key, value in self.__dict__.copy().items():
             if value == item:
                 del self.__dict__[key]
@@ -271,6 +333,11 @@ def gaussian_prior_model_for_arguments(self, arguments):
 
     @property
     def prior_class_dict(self):
+        """Map each prior to the class it will produce when instantiated.
+
+        For child prior models, delegates to their own prior_class_dict.
+        Direct priors on the collection itself map to ModelInstance.
+        """
         return {
             **{
                 prior: cls
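
The `collection.py` docstrings describe string-or-integer indexing, `append` with auto-incremented numeric keys, and a keys-as-`aux_data` PyTree layout. The sketch below combines those three behaviours. `CollectionSketch` is a hypothetical class that stores items in a plain dictionary. The real `Collection` sets attributes and converts values through `AbstractPriorModel.from_object`, and both are deliberately omitted here.

```python
class CollectionSketch:
    """Hypothetical stand-in for a collection; NOT the PyAutoFit Collection."""

    def __init__(self):
        self._dict = {}
        self.item_number = 0

    def append(self, item):
        # Auto-incremented numeric key, stored as a string.
        self._dict[str(self.item_number)] = item
        self.item_number += 1

    def __setitem__(self, key, value):
        self._dict[str(key)] = value

    def __getitem__(self, item):
        # String keys use dict-style access; anything else indexes the values list.
        if isinstance(item, str):
            return self._dict[item]
        return self.values[item]

    @property
    def values(self):
        return list(self._dict.values())

    def items(self):
        return self._dict.items()

    def tree_flatten(self):
        # children are the values, aux_data are the matching keys.
        keys, values = zip(*self.items())
        return values, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        instance = cls()
        for key, value in zip(aux_data, children):
            instance[key] = value
        return instance


collection = CollectionSketch()
collection.append("first")
collection["named"] = "second"

children, aux_data = collection.tree_flatten()
rebuilt = CollectionSketch.tree_unflatten(aux_data, children)
```

In this sketch, flattening an empty collection raises `ValueError`, because `zip(*...)` over no items yields nothing to unpack into `keys, values`. The diff does not show whether the library guards against that case.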

test_autofit/mapper/test_model_mapping_expanded.py (added)

@@ -0,0 +1,631 @@
+"""
+Expanded tests for the model mapping API, covering gaps identified in:
+- Collection composition and instance creation
+- Shared (linked) priors across model types
+- Direct use of instance_for_arguments with argument dicts
+- Model tree navigation (object_for_path, path_for_prior, name_for_prior)
+- Edge cases (empty models, deeply nested models, single-parameter models)
+- Model subsetting (with_paths, without_paths)
+- Freezing behavior
+- Assertion checking
+- from_instance round-trips
+- mapper_from_prior_arguments and mapper_from_partial_prior_arguments
+"""
+import copy
+
+import numpy as np
+import pytest
+
+import autofit as af
+from autofit import exc
+from autofit.mapper.prior.abstract import Prior
+
+
+# ---------------------------------------------------------------------------
+# Collection: composition, nesting, instance creation, iteration
+# ---------------------------------------------------------------------------
+class TestCollectionComposition:
+    def test_collection_from_dict(self):
+        model = af.Collection(
+            one=af.Model(af.m.MockClassx2),
+            two=af.Model(af.m.MockClassx2),
+        )
+        assert model.prior_count == 4
+
+    def test_collection_from_list(self):
+        model = af.Collection([af.m.MockClassx2, af.m.MockClassx2])
+        assert model.prior_count == 4
+
+    def test_collection_from_generator(self):
+        model = af.Collection(af.Model(af.m.MockClassx2) for _ in range(3))
+        assert model.prior_count == 6
+
+    def test_nested_collection(self):
+        inner = af.Collection(a=af.m.MockClassx2)
+        outer = af.Collection(inner=inner, extra=af.m.MockClassx2)
+        assert outer.prior_count == 4
+
+    def test_deeply_nested_collection(self):
+        model = af.Collection(
+            level1=af.Collection(
+                level2=af.Collection(
+                    leaf=af.m.MockClassx2,
+                )
+            )
+        )
+        assert model.prior_count == 2
+
+    def test_collection_instance_attribute_access(self):
+        model = af.Collection(gaussian=af.m.MockClassx2, exp=af.m.MockClassx2)
+        instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0])
+        assert instance.gaussian.one == 1.0
+        assert instance.gaussian.two == 2.0
+        assert instance.exp.one == 3.0
+        assert instance.exp.two == 4.0
+
+    def test_collection_instance_index_access(self):
+        model = af.Collection([af.m.MockClassx2, af.m.MockClassx2])
+        instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0])
+        assert instance[0].one == 1.0
+        assert instance[1].one == 3.0
+
+    def test_collection_len(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert len(model) == 2
+
+    def test_collection_contains(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert "a" in model
+        assert "c" not in model
+
+    def test_collection_items(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        keys = [k for k, v in model.items()]
+        assert "a" in keys
+        assert "b" in keys
+
+    def test_collection_getitem_string(self):
+        model = af.Collection(a=af.m.MockClassx2)
+        assert isinstance(model["a"], af.Model)
+
+    def test_collection_append(self):
+        model = af.Collection()
+        model.append(af.m.MockClassx2)
+        model.append(af.m.MockClassx2)
+        assert model.prior_count == 4
+
+    def test_collection_mixed_model_and_fixed(self):
+        """Collection with a single free model (no fixed instance is added here)."""
+        model = af.Collection(
+            free=af.Model(af.m.MockClassx2),
+        )
+        assert model.prior_count == 2
+
+    def test_empty_collection(self):
+        model = af.Collection()
+        assert model.prior_count == 0
+
+
+# ---------------------------------------------------------------------------
+# Shared (linked) priors
+# ---------------------------------------------------------------------------
+class TestSharedPriors:
+    def test_link_within_model(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.prior_count == 1
+        instance = model.instance_from_vector([5.0])
+        assert instance.one == instance.two == 5.0
+
+    def test_link_across_collection_children(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        model.a.one = model.b.one  # Link a.one to b.one
+        assert model.prior_count == 3  # 4 - 1 shared
+
+    def test_linked_priors_same_value(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        model.a.one = model.b.one
+        instance = model.instance_from_vector([10.0, 20.0, 30.0])
+        assert instance.a.one == instance.b.one
+
+    def test_link_reduces_unique_prior_count(self):
+        model = af.Model(af.m.MockClassx2)
+        original_count = len(model.unique_prior_tuples)
+        model.one = model.two
+        assert len(model.unique_prior_tuples) == original_count - 1
+
+    def test_linked_prior_identity(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.one is model.two
+
+
+# ---------------------------------------------------------------------------
+# instance_for_arguments (direct argument dict usage)
+# ---------------------------------------------------------------------------
+class TestInstanceForArguments:
+    def test_model_instance_for_arguments(self):
+        model = af.Model(af.m.MockClassx2)
+        args = {model.one: 10.0, model.two: 20.0}
+        instance = model.instance_for_arguments(args)
+        assert instance.one == 10.0
+        assert instance.two == 20.0
+
+    def test_collection_instance_for_arguments(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        args = {}
+        for name, prior in model.prior_tuples_ordered_by_id:
+            args[prior] = 1.0
+        instance = model.instance_for_arguments(args)
+        assert instance.a.one == 1.0
+        assert instance.b.two == 1.0
+
+    def test_shared_prior_in_arguments(self):
+        """When priors are linked, only one entry is needed in the arguments dict."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        shared_prior = model.one
+        args = {shared_prior: 42.0}
+        instance = model.instance_for_arguments(args)
+        assert instance.one == 42.0
+        assert instance.two == 42.0
+
+    def test_missing_prior_raises(self):
+        model = af.Model(af.m.MockClassx2)
+        args = {model.one: 10.0}  # missing model.two
+        with pytest.raises(KeyError):
+            model.instance_for_arguments(args)
+
+
+# ---------------------------------------------------------------------------
+# Vector and unit vector mapping
+# ---------------------------------------------------------------------------
+class TestVectorMapping:
+    def test_instance_from_vector_basic(self):
+        model = af.Model(af.m.MockClassx2)
+        instance = model.instance_from_vector([3.0, 4.0])
+        assert instance.one == 3.0
+        assert instance.two == 4.0
+
+    def test_vector_length_mismatch_raises(self):
+        model = af.Model(af.m.MockClassx2)
+        with pytest.raises(AssertionError):
+            model.instance_from_vector([1.0])
+
+    def test_unit_vector_length_mismatch_raises(self):
+        model = af.Model(af.m.MockClassx2)
+        with pytest.raises(AssertionError):
+            model.instance_from_unit_vector([0.5])
+
+    def test_vector_from_unit_vector(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        physical = model.vector_from_unit_vector([0.0, 1.0])
+        assert physical[0] == pytest.approx(0.0, abs=1e-6)
+        assert physical[1] == pytest.approx(10.0, abs=1e-6)
+
+    def test_instance_from_prior_medians(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        instance = model.instance_from_prior_medians()
+        assert instance.one == pytest.approx(50.0)
+        assert instance.two == pytest.approx(50.0)
+
+    def test_random_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+        instance = model.random_instance()
+        assert 0.0 <= instance.one <= 1.0
+        assert 0.0 <= instance.two <= 1.0
+
+
+# ---------------------------------------------------------------------------
+# Model tree navigation
+# ---------------------------------------------------------------------------
+class TestModelTreeNavigation:
+    def test_object_for_path_child_model(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        child = model.object_for_path(("g",))
+        assert isinstance(child, af.Model)
+
+    def test_object_for_path_prior(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        prior = model.object_for_path(("g", "one"))
+        assert isinstance(prior, Prior)
+
+    def test_paths_matches_prior_count(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert len(model.paths) == model.prior_count
+
+    def test_path_for_prior(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        prior = model.g.one
+        path = model.path_for_prior(prior)
+        assert path == ("g", "one")
+
+    def test_name_for_prior(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        prior = model.g.one
+        name = model.name_for_prior(prior)
+        assert name == "g_one"
+
+    def test_path_instance_tuples_for_class(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        tuples = model.path_instance_tuples_for_class(Prior)
+        paths = [t[0] for t in tuples]
+        assert ("g", "one") in paths
+        assert ("g", "two") in paths
+
+    def test_deeply_nested_path(self):
+        inner_model = af.Model(af.m.MockClassx2)
+        inner_collection = af.Collection(leaf=inner_model)
+        outer = af.Collection(branch=inner_collection)
+
+        prior = outer.branch.leaf.one
+        path = outer.path_for_prior(prior)
+        assert path == ("branch", "leaf", "one")
+
+    def test_direct_vs_recursive_prior_tuples(self):
+        model = af.Collection(a=af.m.MockClassx2)
+        assert len(model.direct_prior_tuples) == 0  # Collection has no direct priors
+        assert len(model.prior_tuples) == 2  # But has 2 recursive priors
+
+    def test_direct_prior_model_tuples(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert len(model.direct_prior_model_tuples) == 2
+
+
+# ---------------------------------------------------------------------------
+# instance_from_path_arguments and instance_from_prior_name_arguments
+# ---------------------------------------------------------------------------
+class TestPathAndNameArguments:
+    def test_instance_from_path_arguments(self):
+        model = af.Collection(g=af.m.MockClassx2)
+        instance = model.instance_from_path_arguments(
+            {("g", "one"): 10.0, ("g", "two"): 20.0}
+        )
+        assert instance.g.one == 10.0
+        assert instance.g.two == 20.0
+
+    def test_instance_from_prior_name_arguments(self):
+        model = af.Collection(g=af.m.MockClassx2)
+        instance = model.instance_from_prior_name_arguments(
+            {"g_one": 10.0, "g_two": 20.0}
+        )
+        assert instance.g.one == 10.0
+        assert instance.g.two == 20.0
+
+
+# ---------------------------------------------------------------------------
+# Assertions
+# ---------------------------------------------------------------------------
+class TestAssertions:
+    def test_assertion_passes(self):
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(model.one > model.two)
+        # one=10 > two=5 should pass
+        instance = model.instance_from_vector([10.0, 5.0])
+        assert instance.one == 10.0
+
+    def test_assertion_fails(self):
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(model.one > model.two)
+        with pytest.raises(exc.FitException):
+            model.instance_from_vector([1.0, 10.0])
+
+    def test_ignore_assertions(self):
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(model.one > model.two)
+        # Should not raise even though assertion fails
+        instance = model.instance_from_vector([1.0, 10.0], ignore_assertions=True)
+        assert instance.one == 1.0
+
+    def test_multiple_assertions(self):
+        model = af.Model(af.m.MockClassx4)
+        model.add_assertion(model.one > model.two)
+        model.add_assertion(model.three > model.four)
+        # Both pass
+        instance = model.instance_from_vector([10.0, 5.0, 10.0, 5.0])
+        assert instance.one == 10.0
+        # First fails
+        with pytest.raises(exc.FitException):
+            model.instance_from_vector([1.0, 10.0, 10.0, 5.0])
+
+    def test_true_assertion_ignored(self):
+        """Adding True as an assertion should be a no-op."""
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(True)
+        assert len(model.assertions) == 0
+
+
+# ---------------------------------------------------------------------------
+# Model subsetting (with_paths, without_paths)
+# ---------------------------------------------------------------------------
+class TestModelSubsetting:
+    def test_with_paths_single_child(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        subset = model.with_paths([("a",)])
+        assert subset.prior_count == 2
+
+    def test_without_paths_single_child(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        subset = model.without_paths([("a",)])
+        assert subset.prior_count == 2
+
+    def test_with_paths_specific_prior(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        subset = model.with_paths([("a", "one")])
+        assert subset.prior_count == 1
+
+    def test_with_prefix(self):
+        model = af.Collection(ab_one=af.m.MockClassx2, cd_two=af.m.MockClassx2)
+        subset = model.with_prefix("ab")
+        assert subset.prior_count == 2
+
+
+# ---------------------------------------------------------------------------
+# Freezing behavior
+# ---------------------------------------------------------------------------
+class TestFreezing:
+    def test_freeze_prevents_modification(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        with pytest.raises(AssertionError):
+            model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+
+    def test_unfreeze_allows_modification(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        model.unfreeze()
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+        assert isinstance(model.one, af.UniformPrior)
+
+    def test_frozen_model_still_creates_instances(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        instance = model.instance_from_vector([1.0, 2.0])
+        assert instance.one == 1.0
+
+    def test_freeze_propagates_to_children(self):
+        model = af.Collection(a=af.m.MockClassx2)
+        model.freeze()
+        with pytest.raises(AssertionError):
+            model.a.one = 1.0
+
+    def test_cached_results_consistent(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        result1 = model.prior_tuples_ordered_by_id
+        result2 = model.prior_tuples_ordered_by_id
+        assert result1 == result2
+
+
+# ---------------------------------------------------------------------------
+# mapper_from_prior_arguments and related
+# ---------------------------------------------------------------------------
+class TestMapperFromPriorArguments:
+    def test_replace_all_priors(self):
+        model = af.Model(af.m.MockClassx2)
+        new_one = af.GaussianPrior(mean=0.0, sigma=1.0)
+        new_two = af.GaussianPrior(mean=5.0, sigma=2.0)
+        new_model = model.mapper_from_prior_arguments(
+            {model.one: new_one, model.two: new_two}
+        )
+        assert new_model.prior_count == 2
+        assert isinstance(new_model.one, af.GaussianPrior)
+
+    def test_partial_replacement(self):
+        model = af.Model(af.m.MockClassx2)
+        new_one = af.GaussianPrior(mean=0.0, sigma=1.0)
+        new_model = model.mapper_from_partial_prior_arguments(
+            {model.one: new_one}
+        )
+        assert new_model.prior_count == 2
+        assert isinstance(new_model.one, af.GaussianPrior)
+        # two should retain its original prior type
+        assert new_model.two is not None
+
+    def test_fix_via_mapper_from_prior_arguments(self):
+        """Replacing a prior with a float effectively fixes that parameter."""
+        model = af.Model(af.m.MockClassx2)
+        new_model = model.mapper_from_prior_arguments(
+            {model.one: 5.0, model.two: model.two}
+        )
+        assert new_model.prior_count == 1
+
+    def test_with_limits(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        new_model = model.with_limits([(10.0, 20.0), (30.0, 40.0)])
+        assert new_model.prior_count == 2
+
+
+# ---------------------------------------------------------------------------
+# from_instance round trips
+# ---------------------------------------------------------------------------
+class TestFromInstance:
+    def test_from_simple_instance(self):
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(instance)
+        assert model.prior_count == 0
+
+    def test_from_instance_as_model(self):
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(instance)
+        free_model = model.as_model()
+        assert free_model.prior_count == 2
+
+    def test_from_instance_with_model_classes(self):
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(
+            instance, model_classes=(af.m.MockClassx2,)
+        )
+        assert model.prior_count == 2
+
+    def test_from_list_instance(self):
+        instance_list = [af.m.MockClassx2(1.0, 2.0), af.m.MockClassx2(3.0, 4.0)]
+        model = af.AbstractPriorModel.from_instance(instance_list)
+        assert model.prior_count == 0
+
+    def test_from_dict_instance(self):
+        instance_dict = {
+            "one": af.m.MockClassx2(1.0, 2.0),
+            "two": af.m.MockClassx2(3.0, 4.0),
+        }
+        model = af.AbstractPriorModel.from_instance(instance_dict)
+        assert model.prior_count == 0
+
+
+# ---------------------------------------------------------------------------
+# Fixing parameters and Constant values
+# ---------------------------------------------------------------------------
+class TestFixedParameters:
+    def test_fix_reduces_prior_count(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        assert model.prior_count == 1
+
+    def test_fixed_value_in_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        instance = model.instance_from_vector([10.0])
+        assert instance.one == 5.0
+        assert instance.two == 10.0
+
+    def test_fix_all_parameters(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        model.two = 10.0
+        assert model.prior_count == 0
+        instance = model.instance_from_vector([])
+        assert instance.one == 5.0
+        assert instance.two == 10.0
+
+
+# ---------------------------------------------------------------------------
+# take_attributes (prior passing)
+# ---------------------------------------------------------------------------
+class TestTakeAttributes:
+    def test_take_from_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        source = af.m.MockClassx2(10.0, 20.0)
+        model.take_attributes(source)
+        assert model.prior_count == 0
+
+    def test_take_from_model(self):
+        """Taking attributes from another model copies priors."""
+        source_model = af.Model(af.m.MockClassx2)
+        source_model.one = af.GaussianPrior(mean=5.0, sigma=1.0)
+        source_model.two = af.GaussianPrior(mean=10.0, sigma=2.0)
+
+        target_model = af.Model(af.m.MockClassx2)
+        target_model.take_attributes(source_model)
+        assert isinstance(target_model.one, af.GaussianPrior)
+
+
+# ---------------------------------------------------------------------------
+# Serialization (dict / from_dict)
+# ---------------------------------------------------------------------------
+class TestSerialization:
+    def test_model_dict_roundtrip(self):
+        model = af.Model(af.m.MockClassx2)
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == model.prior_count
+
+    def test_collection_dict_roundtrip(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == model.prior_count
+
+    def test_fixed_parameter_survives_roundtrip(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == 1
+
+    def test_linked_prior_survives_roundtrip(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.prior_count == 1
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == 1
+
+
+# ---------------------------------------------------------------------------
+# Log prior computation
+# ---------------------------------------------------------------------------
+class TestLogPrior:
+    def test_log_prior_within_bounds(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        log_priors = model.log_prior_list_from_vector([5.0, 5.0])
+        assert all(np.isfinite(lp) for lp in log_priors)
+
+    def test_log_prior_outside_bounds(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        log_priors = model.log_prior_list_from_vector([15.0, 5.0])
+        # Out-of-bounds value should have a lower (or zero) log prior than in-bounds
+        assert log_priors[0] <= log_priors[1]
+
+
+# ---------------------------------------------------------------------------
+# Edge cases
+# ---------------------------------------------------------------------------
+class TestEdgeCases:
+    def test_single_parameter_model(self):
+        """A model with a single free parameter using explicit prior."""
+        model = af.Model(af.m.MockClassx2)
+        model.two = 5.0  # Fix one parameter
+        assert model.prior_count == 1
+        instance = model.instance_from_vector([42.0])
+        assert instance.one == 42.0
+        assert instance.two == 5.0
+
+    def test_model_copy_preserves_priors(self):
+        model = af.Model(af.m.MockClassx2)
+        copied = model.copy()
+        assert copied.prior_count == model.prior_count
+        # Priors are independent copies (different objects)
+        assert copied.one is not model.one
+
+    def test_model_copy_linked_priors_independent(self):
+        """Copying a model with linked priors preserves the link in the copy."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.prior_count == 1
+        copied = model.copy()
+        assert copied.prior_count == 1
+        # The copy's internal link is preserved
+        assert copied.one is copied.two
+
+    def test_prior_ordering_is_deterministic(self):
+        """prior_tuples_ordered_by_id should be stable across calls."""
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        order1 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id]
+        order2 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id]
+        assert order1 == order2
+
+    def test_prior_count_equals_total_free_parameters(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx4)
+        assert model.prior_count == model.total_free_parameters
+
+    def test_has_model(self):
+        model = af.Collection(a=af.Model(af.m.MockClassx2))
+        assert model.has_model(af.m.MockClassx2)
+        assert not model.has_model(af.m.MockClassx4)
+
+    def test_has_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        assert model.has_instance(Prior)
+        assert not model.has_instance(af.m.MockClassx4)

Instructions

  1. Read the diff above carefully. Identify every renamed, moved, removed, or changed public API (functions, classes, parameters, imports).
  2. Search all files in scripts/ for usages of the old API patterns.
  3. Update each usage to the new API. Preserve the script's existing behaviour and docstrings.
  4. Test your changes by running: bash run_scripts.sh
  5. Fix any test failures and re-run until the test suite is clean.
  6. If a script cannot be fixed automatically (e.g. a missing upstream dependency, an ambiguous API change), leave it unchanged and list it in your PR description under a "Could not update" section with the reason.
  7. After all scripts pass, regenerate notebooks by running:
    pip install ipynb-py-convert
    git clone https://github.com/Jammy2211/PyAutoBuild.git ../PyAutoBuild
    PYTHONPATH=../PyAutoBuild/autobuild python3 ../PyAutoBuild/autobuild/generate.py autofit
    
  8. Commit the regenerated notebooks alongside the script changes.
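
As a rough aid for step 2, the search over scripts/ could be sketched as below. This is a minimal sketch and not part of the PyAutoFit tooling: the helper `find_api_usages` is hypothetical, and the identifier list is purely illustrative (method names that appear in the tests in the diff above). Replace it with whichever names you find to have actually changed in step 1.

```python
import re
from pathlib import Path


def find_api_usages(root, names):
    """Return (path, line_number, line) tuples for every line in a *.py file
    under `root` that mentions one of `names` as a whole identifier."""
    if not names:
        return []
    # \b treats "_" as a word character, so "with_paths" does not match
    # inside a longer identifier such as "my_with_paths_helper".
    pattern = re.compile(
        r"\b(?:" + "|".join(re.escape(name) for name in names) + r")\b"
    )
    hits = []
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for number, line in enumerate(text.splitlines(), start=1):
            if pattern.search(line):
                hits.append((str(path), number, line.strip()))
    return hits


if __name__ == "__main__":
    # Illustrative names only (assumption): substitute the identifiers
    # that step 1 shows were renamed, moved, removed, or changed.
    names = [
        "mapper_from_prior_arguments",
        "instance_from_path_arguments",
        "with_paths",
    ]
    for path, number, line in find_api_usages("scripts", names):
        print(f"{path}:{number}: {line}")
```

Run from the workspace root, this prints one `path:line: source` entry per match, which gives a checklist of call sites to review before running bash run_scripts.sh. A plain-text match like this will also flag mentions in comments and docstrings, so treat each hit as a candidate rather than a confirmed usage.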

Branch

Use branch name: feature/feature/model_docs
