
[API Update] PyAutoFit PR #1172: Add expanded model mapping unit tests #13

@Jammy2211

Description


API Update Required

PyAutoFit PR #1172 has been merged with API changes that may affect workspace scripts.

PR Title

Add expanded model mapping unit tests

Changed Files (Python only)


autofit/mapper/prior/abstract.py (modified)

@@ -61,6 +61,18 @@ def unit_value_for(self, physical_value: float) -> float:
         return self.message.cdf(physical_value)
 
     def with_message(self, message):
+        """Return a copy of this prior with a different message (distribution).
+
+        Parameters
+        ----------
+        message
+            The new message object defining the prior's distribution.
+
+        Returns
+        -------
+        Prior
+            A copy of this prior using the new message.
+        """
         new = copy(self)
         new.message = message
         return new
@@ -88,6 +100,23 @@ def factor(self):
 
     @staticmethod
     def for_class_and_attribute_name(cls, attribute_name):
+        """Create a prior from the configuration for a given class and attribute.
+
+        Looks up the prior type and parameters in the prior config files
+        for the specified class and attribute name.
+
+        Parameters
+        ----------
+        cls
+            The model class whose config is looked up.
+        attribute_name
+            The name of the attribute on that class.
+
+        Returns
+        -------
+        Prior
+            A prior instance constructed from the config entry.
+        """
         prior_dict = conf.instance.prior_config.for_class_and_suffix_path(
             cls, [attribute_name]
         )
@@ -129,10 +158,31 @@ def instance_for_arguments(
         arguments,
         ignore_assertions=False,
     ):
+        """Look up this prior's value in an arguments dictionary.
+
+        Parameters
+        ----------
+        arguments
+            A dictionary mapping Prior objects to physical values.
+        ignore_assertions
+            Unused for priors (present for interface compatibility).
+        """
         _ = ignore_assertions
         return arguments[self]
 
     def project(self, samples, weights):
+        """Project this prior given samples and log weights from a search.
+
+        Returns a copy of this prior whose message has been updated to
+        reflect the posterior information from the samples.
+
+        Parameters
+        ----------
+        samples
+            Array of sample values for this parameter.
+        weights
+            Log weights for each sample.
+        """
         result = copy(self)
         result.message = self.message.project(
             samples=samples,
@@ -170,6 +220,11 @@ def __str__(self):
     @property
     @abstractmethod
     def parameter_string(self) -> str:
+        """A human-readable string summarizing this prior's parameters.
+
+        Subclasses must implement this to return a description such as
+        ``"mean = 0.0, sigma = 1.0"`` or ``"lower_limit = 0.0, upper_limit = 1.0"``.
+        """
         pass
 
     def __float__(self):
@@ -254,7 +309,22 @@ def name_of_class(cls) -> str:
 
     @property
     def limits(self) -> Tuple[float, float]:
+        """The (lower, upper) bounds of this prior.
+
+        Returns (-inf, inf) by default. Subclasses with finite bounds
+        (e.g. UniformPrior) override this.
+        """
         return (float("-inf"), float("inf"))
 
     def gaussian_prior_model_for_arguments(self, arguments):
+        """Look up this prior in an arguments dict and return the mapped value.
+
+        Used during prior replacement workflows where each prior is mapped
+        to a new prior or fixed value via an arguments dictionary.
+
+        Parameters
+        ----------
+        arguments
+            A dictionary mapping Prior objects to their replacement values.
+        """
         return arguments[self]
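The `with_message` and `instance_for_arguments` docstrings above describe two small behaviours. Swapping a prior's distribution returns a copy and leaves the original alone, and a prior resolves to whatever value an arguments dictionary holds for it. The sketch below shows both using a hypothetical `ToyPrior` class, which is not PyAutoFit's `Prior`. The toy hashes by object identity, which may differ from how the real class defines equality, and plain strings stand in for message objects.

```python
from copy import copy


class ToyPrior:
    """Hypothetical stand-in for a prior (not the PyAutoFit class)."""

    def __init__(self, message):
        self.message = message  # a plain label standing in for a distribution object

    def with_message(self, message):
        # Shallow-copy, then swap the message: the original prior is left untouched.
        new = copy(self)
        new.message = message
        return new

    def instance_for_arguments(self, arguments, ignore_assertions=False):
        # Priors carry no assertions; the flag only keeps the signature uniform.
        _ = ignore_assertions
        return arguments[self]


centre = ToyPrior("Uniform(0, 1)")
narrowed = centre.with_message("Uniform(0, 0.5)")
print(centre.message, "|", narrowed.message)  # Uniform(0, 1) | Uniform(0, 0.5)

arguments = {centre: 0.25}
print(centre.instance_for_arguments(arguments))  # 0.25

try:
    narrowed.instance_for_arguments(arguments)
except KeyError:
    # In this toy the copy is a distinct dict key, so the lookup fails.
    print("narrowed prior has no entry in arguments")
```

A missing entry surfaces as a plain `KeyError`, which is the same exception the new `test_missing_prior_raises` test expects from a model with an incomplete arguments dict.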

autofit/mapper/prior/gaussian.py (modified)

@@ -57,6 +57,13 @@ def __init__(
         )
 
     def tree_flatten(self):
+        """Flatten this prior into a JAX-compatible PyTree representation.
+
+        Returns
+        -------
+        tuple
+            A (children, aux_data) pair where children are (mean, sigma, id).
+        """
         return (self.mean, self.sigma, self.id), ()
 
     @classmethod
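The `tree_flatten` docstrings refer to JAX's PyTree protocol. An object splits into `children`, the leaves JAX can trace, and static `aux_data`. A `tree_unflatten` classmethod then rebuilds the object from those two parts. This hunk does not show the matching `tree_unflatten` for `GaussianPrior`, so the inverse below is an assumption. `ToyGaussianPrior` is hypothetical, and the snippet runs without JAX installed.

```python
class ToyGaussianPrior:
    """Hypothetical sketch of the JAX PyTree protocol; not the PyAutoFit class."""

    def __init__(self, mean, sigma, id_=None):
        self.mean = mean
        self.sigma = sigma
        self.id = id_

    def tree_flatten(self):
        # children: leaves JAX may trace; aux_data: static metadata (none here).
        return (self.mean, self.sigma, self.id), ()

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Assumed inverse: rebuild the prior from the children in the same order.
        mean, sigma, id_ = children
        return cls(mean=mean, sigma=sigma, id_=id_)


prior = ToyGaussianPrior(mean=0.0, sigma=1.0, id_=7)
children, aux_data = prior.tree_flatten()
rebuilt = ToyGaussianPrior.tree_unflatten(aux_data, children)
print(children, aux_data)  # (0.0, 1.0, 7) ()
print(rebuilt.mean, rebuilt.sigma, rebuilt.id)  # 0.0 1.0 7
```

If JAX is available, you can register a class with this shape by decorating it with `jax.tree_util.register_pytree_node_class`. The quoted `tree_flatten` returns `id` among the children rather than in `aux_data`, so JAX treats it as a leaf alongside `mean` and `sigma`.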

autofit/mapper/prior/uniform.py (modified)

@@ -64,23 +64,50 @@ def __init__(
         )
 
     def tree_flatten(self):
+        """Flatten this prior into a JAX-compatible PyTree representation.
+
+        Returns
+        -------
+        tuple
+            A (children, aux_data) pair where children are (lower_limit, upper_limit, id).
+        """
         return (self.lower_limit, self.upper_limit, self.id), ()
 
     @property
     def width(self):
+        """The width of the uniform distribution (upper_limit - lower_limit)."""
         return self.upper_limit - self.lower_limit
 
     def with_limits(
         self,
         lower_limit: float,
         upper_limit: float,
     ) -> "Prior":
+        """Create a new UniformPrior with different bounds.
+
+        Parameters
+        ----------
+        lower_limit
+            The new lower bound.
+        upper_limit
+            The new upper bound.
+        """
         return UniformPrior(
             lower_limit=lower_limit,
             upper_limit=upper_limit,
         )
 
     def logpdf(self, x):
+        """Compute the log probability density at x.
+
+        Adjusts boundary values by epsilon to avoid evaluating exactly at
+        the distribution edges where the PDF is undefined.
+
+        Parameters
+        ----------
+        x
+            The value at which to evaluate the log PDF.
+        """
         # TODO: handle x as a numpy array
         if x == self.lower_limit:
             x += epsilon
@@ -102,6 +129,7 @@ def dict(self) -> dict:
 
     @property
     def parameter_string(self) -> str:
+        """A human-readable string summarizing the prior's lower and upper limits."""
         return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}"
 
     def value_for(self, unit: float) -> float:
@@ -142,4 +170,5 @@ def log_prior_from_value(self, value, xp=np):
 
     @property
     def limits(self) -> Tuple[float, float]:
+        """The (lower_limit, upper_limit) bounds of this uniform prior."""
         return self.lower_limit, self.upper_limit
\ No newline at end of file
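The `UniformPrior` docstrings cover `width`, `limits`, and an epsilon nudge in `logpdf`. Below is a minimal sketch built on a hypothetical `ToyUniformPrior` whose density is defined only on the open interval. Under that assumption, the nudge is what keeps exact boundary values finite. The hunk only shows the lower-limit branch. Three further details are assumptions: that the upper limit is nudged inward by the same amount, the `EPSILON` value, and the inverse-CDF form of `value_for`.

```python
import math

EPSILON = 1e-14  # hypothetical nudge size; the real module defines its own `epsilon`


class ToyUniformPrior:
    """Hypothetical sketch of a bounded uniform prior; not the PyAutoFit class."""

    def __init__(self, lower_limit, upper_limit):
        self.lower_limit = lower_limit
        self.upper_limit = upper_limit

    @property
    def width(self):
        return self.upper_limit - self.lower_limit

    @property
    def limits(self):
        return self.lower_limit, self.upper_limit

    def value_for(self, unit):
        # Inverse CDF of a uniform distribution: unit 0 -> lower, unit 1 -> upper.
        return self.lower_limit + unit * self.width

    def logpdf(self, x):
        # Nudge exact boundary values inward (upper-edge handling is assumed).
        if x == self.lower_limit:
            x += EPSILON
        elif x == self.upper_limit:
            x -= EPSILON
        # This toy density lives on the open interval, so edges would give -inf unnudged.
        if self.lower_limit < x < self.upper_limit:
            return -math.log(self.width)
        return -math.inf


prior = ToyUniformPrior(lower_limit=0.0, upper_limit=2.0)
print(prior.width, prior.limits)  # 2.0 (0.0, 2.0)
print(prior.value_for(0.25))  # 0.5
print(prior.logpdf(0.0) == prior.logpdf(1.0))  # True
print(prior.logpdf(3.0))  # -inf
```

In this sketch the nudge only helps while `EPSILON` exceeds the floating-point spacing at the limits. With very large limits, subtracting a tiny epsilon can leave the value unchanged.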

autofit/mapper/prior_model/abstract.py (modified)

@@ -41,6 +41,20 @@
 class Limits:
     @staticmethod
     def for_class_and_attributes_name(cls, attribute_name):
+        """Look up the (lower, upper) limits for a class attribute from config.
+
+        Parameters
+        ----------
+        cls
+            The model class.
+        attribute_name
+            The name of the attribute on that class.
+
+        Returns
+        -------
+        tuple
+            A (lower, upper) pair of limit values.
+        """
         limit_dict = conf.instance.prior_config.for_class_and_suffix_path(
             cls, [attribute_name, "limits"]
         )
@@ -165,6 +179,11 @@ def __init__(self, label=None):
 
     @property
     def assertions(self):
+        """The list of assertion constraints attached to this model.
+
+        Assertions are checked when creating instances; a failed assertion
+        raises FitException, causing the non-linear search to resample.
+        """
         return self._assertions
 
     @assertions.setter
@@ -441,11 +460,28 @@ def add_assertion(self, assertion, name=""):
 
     @property
     def name(self):
+        """The class name of this prior model (e.g. ``"Model"`` or ``"Collection"``)."""
         return self.__class__.__name__
 
     # noinspection PyUnusedLocal
     @staticmethod
     def from_object(t, *args, **kwargs):
+        """Convert an arbitrary object into an appropriate prior model representation.
+
+        - Classes become ``Model`` instances.
+        - Lists and dicts become ``Collection`` instances.
+        - Floats become ``Constant`` instances.
+        - Existing prior models and other objects are returned as-is.
+
+        Parameters
+        ----------
+        t
+            A class, list, dict, float, or existing prior model.
+
+        Returns
+        -------
+        An AbstractPriorModel, Constant, or the original object.
+        """
         if inspect.isclass(t):
             from .prior_model import Model
 
@@ -606,6 +642,7 @@ def prior_tuples_ordered_by_id(self):
 
     @property
     def priors_ordered_by_id(self):
+        """Unique priors sorted by their id, defining the canonical parameter ordering."""
         return [prior for _, prior in self.prior_tuples_ordered_by_id]
 
     def vector_from_unit_vector(self, unit_vector):
@@ -836,6 +873,16 @@ def is_only_model(self, cls) -> bool:
         return len(cls_models) > 0 and len(cls_models) == len(other_models)
 
     def replacing(self, arguments):
+        """Return a new model with some priors replaced.
+
+        This is a convenience alias for ``mapper_from_partial_prior_arguments``.
+        Priors not present in the arguments dict are kept unchanged.
+
+        Parameters
+        ----------
+        arguments : dict
+            A dictionary mapping existing Prior objects to new priors or fixed values.
+        """
         return self.mapper_from_partial_prior_arguments(arguments)
 
     @classmethod
@@ -1211,6 +1258,10 @@ def from_instance(
         return result
 
     def items(self):
+        """Return (name, value) pairs for all public, non-internal attributes.
+
+        Excludes private attributes (prefixed with ``_``), ``cls``, and ``id``.
+        """
         return [
             (key, value)
             for key, value in self.__dict__.items()
@@ -1225,23 +1276,27 @@ def direct_prior_tuples(self):
     @property
     @cast_collection(InstanceNameValue)
     def direct_instance_tuples(self):
+        """(name, value) tuples for direct float and Constant attributes."""
         return self.direct_tuples_with_type(float) + self.direct_tuples_with_type(
             Constant
         )
 
     @property
     @cast_collection(PriorModelNameValue)
     def prior_model_tuples(self):
+        """(name, prior_model) tuples for direct child AbstractPriorModel attributes."""
         return self.direct_tuples_with_type(AbstractPriorModel)
 
     @property
     @cast_collection(PriorModelNameValue)
     def direct_prior_model_tuples(self):
+        """(name, prior_model) tuples for immediate child prior models (non-recursive)."""
         return self.direct_tuples_with_type(AbstractPriorModel)
 
     @property
     @cast_collection(PriorModelNameValue)
     def direct_tuple_priors(self):
+        """(name, tuple_prior) tuples for direct TuplePrior attributes."""
         return self.direct_tuples_with_type(TuplePrior)
 
     @property
@@ -1267,6 +1322,7 @@ def direct_prior_tuples(self):
     @property
     @cast_collection(DeferredNameValue)
     def direct_deferred_tuples(self):
+        """(name, deferred_argument) tuples for direct DeferredArgument attributes."""
         return self.direct_tuples_with_type(DeferredArgument)
 
     @property
@@ -1298,6 +1354,11 @@ def instance_tuples(self):
 
     @property
     def prior_class_dict(self):
+        """Map each prior to the class it will produce when instantiated.
+
+        Direct priors on this model map to ``self.cls``. Child prior models
+        contribute their own mappings recursively.
+        """
         from autofit.mapper.prior_model.annotation import AnnotationPriorModel
 
         d = {prior[1]: self.cls for prior in self.prior_tuples}
@@ -1467,16 +1528,49 @@ def total_free_parameters(self) -> int:
 
     @property
     def priors(self):
+        """A list of all Prior objects in this model (may contain duplicates for shared priors)."""
         return [prior_tuple.prior for prior_tuple in self.prior_tuples]
 
     @property
     def _prior_id_map(self):
         return {prior.id: prior for prior in self.priors}
 
     def prior_with_id(self, prior_id):
+        """Retrieve a prior by its unique integer id.
+
+        Parameters
+        ----------
+        prior_id : int
+            The id of the prior to find.
+
+        Returns
+        -------
+        Prior
+            The prior with the matching id.
+
+        Raises
+        ------
+        KeyError
+            If no prior with the given id exists in this model.
+        """
         return self._prior_id_map[prior_id]
 
     def name_for_prior(self, prior):
+        """Get the underscore-separated name for a prior in this model.
+
+        Searches child prior models recursively. Returns None if the prior
+        is not found.
+
+        Parameters
+        ----------
+        prior : Prior
+            The prior to find.
+
+        Returns
+        -------
+        str or None
+            The name path joined by underscores, e.g. ``"galaxy_centre"``.
+        """
         for prior_model_name, prior_model in self.direct_prior_model_tuples:
             prior_name = prior_model.name_for_prior(prior)
             if prior_name is not None:
@@ -1525,6 +1619,11 @@ def copy_with_fixed_priors(self, instance, excluded_classes=tuple()):
 
     @property
     def path_priors_tuples(self) -> List[Tuple[Path, Prior]]:
+        """All (path, prior) tuples in this model, sorted by prior id.
+
+        Unlike ``unique_path_prior_tuples``, this includes duplicate entries
+        when a prior appears at multiple paths (shared priors).
+        """
         path_priors_tuples = self.path_instance_tuples_for_class(Prior)
         return sorted(path_priors_tuples, key=lambda item: item[1].id)
 
@@ -1546,6 +1645,10 @@ def paths_formatted(self) -> List[Path]:
 
     @property
     def composition(self):
+        """A list of dot-separated path strings for each prior, ordered by prior id.
+
+        For example: ``["galaxy.centre", "galaxy.normalization", "galaxy.sigma"]``.
+        """
         return [".".join(path) for path in self.paths]
 
     def sort_priors_alphabetically(self, priors: Iterable[Prior]) -> List[Prior]:
@@ -1617,14 +1720,20 @@ def all_paths_for_prior(self, prior: Prior) -> Generator[Path, None, None]:
 
     @property
     def path_float_tuples(self):
+        """(path, float) tuples for all fixed float values, excluding Prior objects."""
         return self.path_instance_tuples_for_class(float, ignore_class=Prior)
 
     @property
     def unique_prior_paths(self):
+        """Paths to each unique prior (deduplicated for shared priors), ordered by id."""
         return [item[0] for item in self.unique_path_prior_tuples]
 
     @property
     def unique_path_prior_tuples(self):
+        """(path, prior) tuples deduplicated by prior identity, ordered by id.
+
+        When a prior is shared across multiple paths, only one path is kept.
+        """
         unique = {item[1]: item for item in self.path_priors_tuples}.values()
         return sorted(unique, key=lambda item: item[1].id)
 
@@ -1645,6 +1754,18 @@ def prior_prior_model_dict(self):
         }
 
     def log_prior_list_from(self, parameter_lists: List[List]) -> List:
+        """Compute the total log prior for each parameter vector in a list.
+
+        Parameters
+        ----------
+        parameter_lists
+            A list of physical parameter vectors.
+
+        Returns
+        -------
+        list
+            The summed log prior for each vector.
+        """
         return [
             sum(self.log_prior_list_from_vector(vector=vector))
             for vector in parameter_lists
@@ -1809,6 +1930,7 @@ def model_component_and_parameter_names(self) -> List[str]:
 
     @property
     def joined_paths(self) -> List[str]:
+        """Dot-joined path strings for each unique prior, ordered by id."""
         prior_paths = self.unique_prior_paths
 
         return [".".join(path) for path in prior_paths]
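Several of the new docstrings distinguish `path_priors_tuples` from `unique_path_prior_tuples`. The first has one entry per path, so a shared prior appears more than once; the second has one entry per prior. The sketch below applies the dictionary-comprehension expression from the diff to hypothetical toy data. It also shows the two string forms: dot-joined paths as in `joined_paths`, and underscore-joined names in the style the `name_for_prior` docstring gives (`"galaxy_centre"`). `ToyPrior` and the helper functions are illustrative and are not PyAutoFit API.

```python
from typing import List, Tuple

Path = Tuple[str, ...]


class ToyPrior:
    """Hypothetical prior carrying only an integer id."""

    def __init__(self, id_):
        self.id = id_


def unique_path_prior_tuples(path_priors_tuples: List[Tuple[Path, ToyPrior]]):
    # Same expression as in the diff: keyed by prior, so a shared prior keeps one entry.
    # Later (path, prior) pairs overwrite earlier ones, so the last path seen wins.
    unique = {item[1]: item for item in path_priors_tuples}.values()
    return sorted(unique, key=lambda item: item[1].id)


def joined_paths(path_prior_tuples):
    # Dot-joined path strings, as in `joined_paths`.
    return [".".join(path) for path, _ in path_prior_tuples]


def prior_names(path_prior_tuples):
    # Underscore-joined names, in the style documented for `name_for_prior`.
    return ["_".join(path) for path, _ in path_prior_tuples]


shared, sigma = ToyPrior(1), ToyPrior(2)
# "a.one" and "b.one" hold the same prior object, as after `model.a.one = model.b.one`.
tuples = sorted(
    [(("a", "one"), shared), (("a", "two"), sigma), (("b", "one"), shared)],
    key=lambda item: item[1].id,
)
unique = unique_path_prior_tuples(tuples)
print(len(tuples), len(unique))  # 3 2
print(joined_paths(unique))  # ['b.one', 'a.two']
print(prior_names(unique))  # ['b_one', 'a_two']
```

Later assignments in the comprehension overwrite earlier ones. As a result, the path kept for a shared prior is the last one encountered, here `b.one` rather than `a.one`.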

autofit/mapper/prior_model/array.py (modified)

@@ -197,6 +197,7 @@ def tree_unflatten(cls, aux_data, children):
 
     @property
     def prior_class_dict(self):
+        """Map each prior to the class it produces (np.ndarray for direct priors)."""
         return {
             **{
                 prior: cls
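The `prior_class_dict` docstrings on the abstract model, the array model, and `Collection` describe one recursive rule. Priors owned directly by a node map to that node's class, and child models contribute their own entries. The simplified sketch below uses a hypothetical `ToyModel`. Strings stand in for prior objects, and `ModelInstanceStandIn` stands in for `ModelInstance`.

```python
class ToyModel:
    """Hypothetical model node: some direct priors plus named child models."""

    def __init__(self, cls, direct_priors=(), children=None):
        self.cls = cls
        self.direct_priors = list(direct_priors)
        self.children = dict(children or {})

    def prior_class_dict(self):
        # Direct priors produce this node's class ...
        mapping = {prior: self.cls for prior in self.direct_priors}
        # ... while each child contributes its own mapping recursively.
        for child in self.children.values():
            mapping.update(child.prior_class_dict())
        return mapping


class Gaussian:  # hypothetical user model class
    pass


class ModelInstanceStandIn:  # stands in for what a collection's direct priors produce
    pass


gaussian = ToyModel(Gaussian, direct_priors=["centre", "sigma"])
collection = ToyModel(
    ModelInstanceStandIn, direct_priors=["offset"], children={"g": gaussian}
)
print({name: cls.__name__ for name, cls in collection.prior_class_dict().items()})
# {'offset': 'ModelInstanceStandIn', 'centre': 'Gaussian', 'sigma': 'Gaussian'}
```

For an array model, the docstring indicates the node's class would be `np.ndarray`. The quoted abstract implementation starts from `{prior[1]: self.cls for prior in self.prior_tuples}`, whereas this sketch uses direct priors only.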

autofit/mapper/prior_model/collection.py (modified)

@@ -31,11 +31,28 @@ def name_for_prior(self, prior: Prior) -> str:
                 return name
 
     def tree_flatten(self):
+        """Flatten this collection into a JAX-compatible PyTree representation.
+
+        Returns
+        -------
+        tuple
+            A (children, aux_data) pair where children are the values and
+            aux_data are the corresponding keys.
+        """
         keys, values = zip(*self.items())
         return values, keys
 
     @classmethod
     def tree_unflatten(cls, aux_data, children):
+        """Reconstruct a Collection from a flattened PyTree.
+
+        Parameters
+        ----------
+        aux_data
+            The keys of the collection items.
+        children
+            The values of the collection items.
+        """
         instance = cls()
 
         for key, value in zip(aux_data, children):
@@ -46,6 +63,14 @@ def __contains__(self, item):
         return item in self._dict or item in self._dict.values()
 
     def __getitem__(self, item):
+        """Retrieve an item by string key or integer index.
+
+        Parameters
+        ----------
+        item : str or int
+            A string key for dict-style access, or an integer index
+            for positional access into the values list.
+        """
         if isinstance(item, str):
             return self._dict[item]
         return self.values[item]
@@ -64,9 +89,11 @@ def __repr__(self):
 
     @property
     def values(self):
+        """The model components in this collection as a list."""
         return list(self._dict.values())
 
     def items(self):
+        """The (key, model_component) pairs in this collection."""
         return self._dict.items()
 
     def with_prefix(self, prefix: str):
@@ -79,6 +106,11 @@ def with_prefix(self, prefix: str):
         )
 
     def as_model(self):
+        """Convert all prior models in this collection to Model instances.
+
+        Returns a new Collection where each AbstractPriorModel child has
+        been converted via its own as_model() method.
+        """
         return Collection(
             {
                 key: value.as_model()
@@ -162,6 +194,13 @@ def __init__(
 
     @assert_not_frozen
     def add_dict_items(self, item_dict):
+        """Add all entries from a dictionary, converting values to prior models.
+
+        Parameters
+        ----------
+        item_dict
+            A dictionary mapping string keys to classes, instances, or prior models.
+        """
         for key, value in item_dict.items():
             if isinstance(key, tuple):
                 key = ".".join(key)
@@ -179,11 +218,20 @@ def __eq__(self, other):
 
     @assert_not_frozen
     def append(self, item):
+        """Append an item to the collection with an auto-incremented numeric key.
+
+        The item is converted to an AbstractPriorModel if it is not already one.
+        """
         setattr(self, str(self.item_number), AbstractPriorModel.from_object(item))
         self.item_number += 1
 
     @assert_not_frozen
     def __setitem__(self, key, value):
+        """Set an item by key, converting the value to a prior model.
+
+        Preserves the id of any existing item at the same key so that
+        prior identity is maintained across replacements.
+        """
         obj = AbstractPriorModel.from_object(value)
         try:
             obj.id = getattr(self, str(key)).id
@@ -193,6 +241,12 @@ def __setitem__(self, key, value):
 
     @assert_not_frozen
     def __setattr__(self, key, value):
+        """Set an attribute, automatically converting values to prior models.
+
+        Private attributes (starting with ``_``) are set directly. All other
+        values are wrapped via ``AbstractPriorModel.from_object`` so that
+        plain classes become ``Model`` instances and floats become fixed values.
+        """
         if key.startswith("_"):
             super().__setattr__(key, value)
         else:
@@ -202,6 +256,14 @@ def __setattr__(self, key, value):
                 pass
 
     def remove(self, item):
+        """Remove an item from the collection by value equality.
+
+        Parameters
+        ----------
+        item
+            The item to remove. All entries whose value equals this item
+            are deleted.
+        """
         for key, value in self.__dict__.copy().items():
             if value == item:
                 del self.__dict__[key]
@@ -271,6 +333,11 @@ def gaussian_prior_model_for_arguments(self, arguments):
 
     @property
     def prior_class_dict(self):
+        """Map each prior to the class it will produce when instantiated.
+
+        For child prior models, delegates to their own prior_class_dict.
+        Direct priors on the collection itself map to ModelInstance.
+        """
         return {
             **{
                 prior: cls
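The `Collection` docstrings describe five behaviours: dual string/integer indexing, auto-numbered `append`, id-preserving `__setitem__`, removal by equality, and PyTree flattening with keys as `aux_data`. The sketch below reproduces those behaviours on a hypothetical `ToyCollection`. Unlike the real class, it does not wrap values with `AbstractPriorModel.from_object`, and it stores items only in `_dict`.

```python
import itertools

_ids = itertools.count()


class ToyComponent:
    """Hypothetical model component with an auto-assigned id."""

    def __init__(self, label):
        self.label = label
        self.id = next(_ids)


class ToyCollection:
    """Hypothetical sketch of the documented Collection behaviours."""

    def __init__(self):
        self._dict = {}
        self.item_number = 0

    def __getitem__(self, item):
        # String -> dict-style lookup; otherwise -> positional index into the values.
        if isinstance(item, str):
            return self._dict[item]
        return list(self._dict.values())[item]

    def append(self, item):
        # Auto-incremented numeric keys: "0", "1", ...
        self._dict[str(self.item_number)] = item
        self.item_number += 1

    def __setitem__(self, key, value):
        # Keep the id of whatever previously lived at this key.
        existing = self._dict.get(str(key))
        if existing is not None:
            value.id = existing.id
        self._dict[str(key)] = value

    def remove(self, item):
        # Delete every entry whose value equals `item` (identity for ToyComponent).
        for key, value in list(self._dict.items()):
            if value == item:
                del self._dict[key]

    def tree_flatten(self):
        # children are the values, aux_data the keys, as in the diff.
        keys, values = zip(*self._dict.items())
        return values, keys

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        instance = cls()
        for key, value in zip(aux_data, children):
            instance._dict[key] = value
        return instance


a, b = ToyComponent("gaussian"), ToyComponent("exponential")
collection = ToyCollection()
collection.append(a)
collection.append(b)
print(collection["0"].label, collection[1].label)  # gaussian exponential

replacement = ToyComponent("sersic")
collection["1"] = replacement
print(replacement.id == b.id)  # True

children, aux_data = collection.tree_flatten()
rebuilt = ToyCollection.tree_unflatten(aux_data, children)
print(aux_data, rebuilt["1"].label)  # ('0', '1') sersic

collection.remove(a)
print(list(collection._dict))  # ['1']
```

As in the quoted `tree_flatten`, unpacking `zip(*...)` raises `ValueError` for an empty collection, because there are no items to split into keys and values.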

test_autofit/mapper/test_model_mapping_expanded.py (added)

@@ -0,0 +1,631 @@
+"""
+Expanded tests for the model mapping API, covering gaps identified in:
+- Collection composition and instance creation
+- Shared (linked) priors across model types
+- Direct use of instance_for_arguments with argument dicts
+- Model tree navigation (object_for_path, path_for_prior, name_for_prior)
+- Edge cases (empty models, deeply nested models, single-parameter models)
+- Model subsetting (with_paths, without_paths)
+- Freezing behavior
+- Assertion checking
+- from_instance round-trips
+- mapper_from_prior_arguments and mapper_from_partial_prior_arguments
+"""
+import copy
+
+import numpy as np
+import pytest
+
+import autofit as af
+from autofit import exc
+from autofit.mapper.prior.abstract import Prior
+
+
+# ---------------------------------------------------------------------------
+# Collection: composition, nesting, instance creation, iteration
+# ---------------------------------------------------------------------------
+class TestCollectionComposition:
+    def test_collection_from_dict(self):
+        model = af.Collection(
+            one=af.Model(af.m.MockClassx2),
+            two=af.Model(af.m.MockClassx2),
+        )
+        assert model.prior_count == 4
+
+    def test_collection_from_list(self):
+        model = af.Collection([af.m.MockClassx2, af.m.MockClassx2])
+        assert model.prior_count == 4
+
+    def test_collection_from_generator(self):
+        model = af.Collection(af.Model(af.m.MockClassx2) for _ in range(3))
+        assert model.prior_count == 6
+
+    def test_nested_collection(self):
+        inner = af.Collection(a=af.m.MockClassx2)
+        outer = af.Collection(inner=inner, extra=af.m.MockClassx2)
+        assert outer.prior_count == 4
+
+    def test_deeply_nested_collection(self):
+        model = af.Collection(
+            level1=af.Collection(
+                level2=af.Collection(
+                    leaf=af.m.MockClassx2,
+                )
+            )
+        )
+        assert model.prior_count == 2
+
+    def test_collection_instance_attribute_access(self):
+        model = af.Collection(gaussian=af.m.MockClassx2, exp=af.m.MockClassx2)
+        instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0])
+        assert instance.gaussian.one == 1.0
+        assert instance.gaussian.two == 2.0
+        assert instance.exp.one == 3.0
+        assert instance.exp.two == 4.0
+
+    def test_collection_instance_index_access(self):
+        model = af.Collection([af.m.MockClassx2, af.m.MockClassx2])
+        instance = model.instance_from_vector([1.0, 2.0, 3.0, 4.0])
+        assert instance[0].one == 1.0
+        assert instance[1].one == 3.0
+
+    def test_collection_len(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert len(model) == 2
+
+    def test_collection_contains(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert "a" in model
+        assert "c" not in model
+
+    def test_collection_items(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        keys = [k for k, v in model.items()]
+        assert "a" in keys
+        assert "b" in keys
+
+    def test_collection_getitem_string(self):
+        model = af.Collection(a=af.m.MockClassx2)
+        assert isinstance(model["a"], af.Model)
+
+    def test_collection_append(self):
+        model = af.Collection()
+        model.append(af.m.MockClassx2)
+        model.append(af.m.MockClassx2)
+        assert model.prior_count == 4
+
+    def test_collection_mixed_model_and_fixed(self):
+        """Collection with one free model and one fixed instance."""
+        model = af.Collection(
+            free=af.Model(af.m.MockClassx2),
+        )
+        assert model.prior_count == 2
+
+    def test_empty_collection(self):
+        model = af.Collection()
+        assert model.prior_count == 0
+
+
+# ---------------------------------------------------------------------------
+# Shared (linked) priors
+# ---------------------------------------------------------------------------
+class TestSharedPriors:
+    def test_link_within_model(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.prior_count == 1
+        instance = model.instance_from_vector([5.0])
+        assert instance.one == instance.two == 5.0
+
+    def test_link_across_collection_children(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        model.a.one = model.b.one  # Link a.one to b.one
+        assert model.prior_count == 3  # 4 - 1 shared
+
+    def test_linked_priors_same_value(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        model.a.one = model.b.one
+        instance = model.instance_from_vector([10.0, 20.0, 30.0])
+        assert instance.a.one == instance.b.one
+
+    def test_link_reduces_unique_prior_count(self):
+        model = af.Model(af.m.MockClassx2)
+        original_count = len(model.unique_prior_tuples)
+        model.one = model.two
+        assert len(model.unique_prior_tuples) == original_count - 1
+
+    def test_linked_prior_identity(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.one is model.two
+
+
+# ---------------------------------------------------------------------------
+# instance_for_arguments (direct argument dict usage)
+# ---------------------------------------------------------------------------
+class TestInstanceForArguments:
+    def test_model_instance_for_arguments(self):
+        model = af.Model(af.m.MockClassx2)
+        args = {model.one: 10.0, model.two: 20.0}
+        instance = model.instance_for_arguments(args)
+        assert instance.one == 10.0
+        assert instance.two == 20.0
+
+    def test_collection_instance_for_arguments(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        args = {}
+        for name, prior in model.prior_tuples_ordered_by_id:
+            args[prior] = 1.0
+        instance = model.instance_for_arguments(args)
+        assert instance.a.one == 1.0
+        assert instance.b.two == 1.0
+
+    def test_shared_prior_in_arguments(self):
+        """When priors are linked, only one entry is needed in the arguments dict."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        shared_prior = model.one
+        args = {shared_prior: 42.0}
+        instance = model.instance_for_arguments(args)
+        assert instance.one == 42.0
+        assert instance.two == 42.0
+
+    def test_missing_prior_raises(self):
+        model = af.Model(af.m.MockClassx2)
+        args = {model.one: 10.0}  # missing model.two
+        with pytest.raises(KeyError):
+            model.instance_for_arguments(args)
+
+
+# ---------------------------------------------------------------------------
+# Vector and unit vector mapping
+# ---------------------------------------------------------------------------
+class TestVectorMapping:
+    def test_instance_from_vector_basic(self):
+        model = af.Model(af.m.MockClassx2)
+        instance = model.instance_from_vector([3.0, 4.0])
+        assert instance.one == 3.0
+        assert instance.two == 4.0
+
+    def test_vector_length_mismatch_raises(self):
+        model = af.Model(af.m.MockClassx2)
+        with pytest.raises(AssertionError):
+            model.instance_from_vector([1.0])
+
+    def test_unit_vector_length_mismatch_raises(self):
+        model = af.Model(af.m.MockClassx2)
+        with pytest.raises(AssertionError):
+            model.instance_from_unit_vector([0.5])
+
+    def test_vector_from_unit_vector(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        physical = model.vector_from_unit_vector([0.0, 1.0])
+        assert physical[0] == pytest.approx(0.0, abs=1e-6)
+        assert physical[1] == pytest.approx(10.0, abs=1e-6)
+
+    def test_instance_from_prior_medians(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        instance = model.instance_from_prior_medians()
+        assert instance.one == pytest.approx(50.0)
+        assert instance.two == pytest.approx(50.0)
+
+    def test_random_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+        instance = model.random_instance()
+        assert 0.0 <= instance.one <= 1.0
+        assert 0.0 <= instance.two <= 1.0
+
+
+# ---------------------------------------------------------------------------
+# Model tree navigation
+# ---------------------------------------------------------------------------
+class TestModelTreeNavigation:
+    def test_object_for_path_child_model(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        child = model.object_for_path(("g",))
+        assert isinstance(child, af.Model)
+
+    def test_object_for_path_prior(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        prior = model.object_for_path(("g", "one"))
+        assert isinstance(prior, Prior)
+
+    def test_paths_matches_prior_count(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert len(model.paths) == model.prior_count
+
+    def test_path_for_prior(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        prior = model.g.one
+        path = model.path_for_prior(prior)
+        assert path == ("g", "one")
+
+    def test_name_for_prior(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        prior = model.g.one
+        name = model.name_for_prior(prior)
+        assert name == "g_one"
+
+    def test_path_instance_tuples_for_class(self):
+        model = af.Collection(g=af.Model(af.m.MockClassx2))
+        tuples = model.path_instance_tuples_for_class(Prior)
+        paths = [t[0] for t in tuples]
+        assert ("g", "one") in paths
+        assert ("g", "two") in paths
+
+    def test_deeply_nested_path(self):
+        inner_model = af.Model(af.m.MockClassx2)
+        inner_collection = af.Collection(leaf=inner_model)
+        outer = af.Collection(branch=inner_collection)
+
+        prior = outer.branch.leaf.one
+        path = outer.path_for_prior(prior)
+        assert path == ("branch", "leaf", "one")
+
+    def test_direct_vs_recursive_prior_tuples(self):
+        model = af.Collection(a=af.m.MockClassx2)
+        assert len(model.direct_prior_tuples) == 0  # Collection has no direct priors
+        assert len(model.prior_tuples) == 2  # But has 2 recursive priors
+
+    def test_direct_prior_model_tuples(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        assert len(model.direct_prior_model_tuples) == 2
+
+
+# ---------------------------------------------------------------------------
+# instance_from_path_arguments and instance_from_prior_name_arguments
+# ---------------------------------------------------------------------------
+class TestPathAndNameArguments:
+    def test_instance_from_path_arguments(self):
+        model = af.Collection(g=af.m.MockClassx2)
+        instance = model.instance_from_path_arguments(
+            {("g", "one"): 10.0, ("g", "two"): 20.0}
+        )
+        assert instance.g.one == 10.0
+        assert instance.g.two == 20.0
+
+    def test_instance_from_prior_name_arguments(self):
+        model = af.Collection(g=af.m.MockClassx2)
+        instance = model.instance_from_prior_name_arguments(
+            {"g_one": 10.0, "g_two": 20.0}
+        )
+        assert instance.g.one == 10.0
+        assert instance.g.two == 20.0
+
+
+# ---------------------------------------------------------------------------
+# Assertions
+# ---------------------------------------------------------------------------
+class TestAssertions:
+    def test_assertion_passes(self):
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(model.one > model.two)
+        # one=10 > two=5 should pass
+        instance = model.instance_from_vector([10.0, 5.0])
+        assert instance.one == 10.0
+
+    def test_assertion_fails(self):
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(model.one > model.two)
+        with pytest.raises(exc.FitException):
+            model.instance_from_vector([1.0, 10.0])
+
+    def test_ignore_assertions(self):
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(model.one > model.two)
+        # Should not raise even though assertion fails
+        instance = model.instance_from_vector([1.0, 10.0], ignore_assertions=True)
+        assert instance.one == 1.0
+
+    def test_multiple_assertions(self):
+        model = af.Model(af.m.MockClassx4)
+        model.add_assertion(model.one > model.two)
+        model.add_assertion(model.three > model.four)
+        # Both pass
+        instance = model.instance_from_vector([10.0, 5.0, 10.0, 5.0])
+        assert instance.one == 10.0
+        # First fails
+        with pytest.raises(exc.FitException):
+            model.instance_from_vector([1.0, 10.0, 10.0, 5.0])
+
+    def test_true_assertion_ignored(self):
+        """Adding True as an assertion should be a no-op."""
+        model = af.Model(af.m.MockClassx2)
+        model.add_assertion(True)
+        assert len(model.assertions) == 0
+
+
+# ---------------------------------------------------------------------------
+# Model subsetting (with_paths, without_paths)
+# ---------------------------------------------------------------------------
+class TestModelSubsetting:
+    def test_with_paths_single_child(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        subset = model.with_paths([("a",)])
+        assert subset.prior_count == 2
+
+    def test_without_paths_single_child(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        subset = model.without_paths([("a",)])
+        assert subset.prior_count == 2
+
+    def test_with_paths_specific_prior(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        subset = model.with_paths([("a", "one")])
+        assert subset.prior_count == 1
+
+    def test_with_prefix(self):
+        model = af.Collection(ab_one=af.m.MockClassx2, cd_two=af.m.MockClassx2)
+        subset = model.with_prefix("ab")
+        assert subset.prior_count == 2
+
+
+# ---------------------------------------------------------------------------
+# Freezing behavior
+# ---------------------------------------------------------------------------
+class TestFreezing:
+    def test_freeze_prevents_modification(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        with pytest.raises(AssertionError):
+            model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+
+    def test_unfreeze_allows_modification(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        model.unfreeze()
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=1.0)
+        assert isinstance(model.one, af.UniformPrior)
+
+    def test_frozen_model_still_creates_instances(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        instance = model.instance_from_vector([1.0, 2.0])
+        assert instance.one == 1.0
+
+    def test_freeze_propagates_to_children(self):
+        model = af.Collection(a=af.m.MockClassx2)
+        model.freeze()
+        with pytest.raises(AssertionError):
+            model.a.one = 1.0
+
+    def test_cached_results_consistent(self):
+        model = af.Model(af.m.MockClassx2)
+        model.freeze()
+        result1 = model.prior_tuples_ordered_by_id
+        result2 = model.prior_tuples_ordered_by_id
+        assert result1 == result2
+
+
+# ---------------------------------------------------------------------------
+# mapper_from_prior_arguments and related
+# ---------------------------------------------------------------------------
+class TestMapperFromPriorArguments:
+    def test_replace_all_priors(self):
+        model = af.Model(af.m.MockClassx2)
+        new_one = af.GaussianPrior(mean=0.0, sigma=1.0)
+        new_two = af.GaussianPrior(mean=5.0, sigma=2.0)
+        new_model = model.mapper_from_prior_arguments(
+            {model.one: new_one, model.two: new_two}
+        )
+        assert new_model.prior_count == 2
+        assert isinstance(new_model.one, af.GaussianPrior)
+
+    def test_partial_replacement(self):
+        model = af.Model(af.m.MockClassx2)
+        new_one = af.GaussianPrior(mean=0.0, sigma=1.0)
+        new_model = model.mapper_from_partial_prior_arguments(
+            {model.one: new_one}
+        )
+        assert new_model.prior_count == 2
+        assert isinstance(new_model.one, af.GaussianPrior)
+        # two should retain its original prior type
+        assert new_model.two is not None
+
+    def test_fix_via_mapper_from_prior_arguments(self):
+        """Replacing a prior with a float effectively fixes that parameter."""
+        model = af.Model(af.m.MockClassx2)
+        new_model = model.mapper_from_prior_arguments(
+            {model.one: 5.0, model.two: model.two}
+        )
+        assert new_model.prior_count == 1
+
+    def test_with_limits(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=100.0)
+        new_model = model.with_limits([(10.0, 20.0), (30.0, 40.0)])
+        assert new_model.prior_count == 2
+
+
+# ---------------------------------------------------------------------------
+# from_instance round trips
+# ---------------------------------------------------------------------------
+class TestFromInstance:
+    def test_from_simple_instance(self):
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(instance)
+        assert model.prior_count == 0
+
+    def test_from_instance_as_model(self):
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(instance)
+        free_model = model.as_model()
+        assert free_model.prior_count == 2
+
+    def test_from_instance_with_model_classes(self):
+        instance = af.m.MockClassx2(1.0, 2.0)
+        model = af.AbstractPriorModel.from_instance(
+            instance, model_classes=(af.m.MockClassx2,)
+        )
+        assert model.prior_count == 2
+
+    def test_from_list_instance(self):
+        instance_list = [af.m.MockClassx2(1.0, 2.0), af.m.MockClassx2(3.0, 4.0)]
+        model = af.AbstractPriorModel.from_instance(instance_list)
+        assert model.prior_count == 0
+
+    def test_from_dict_instance(self):
+        instance_dict = {
+            "one": af.m.MockClassx2(1.0, 2.0),
+            "two": af.m.MockClassx2(3.0, 4.0),
+        }
+        model = af.AbstractPriorModel.from_instance(instance_dict)
+        assert model.prior_count == 0
+
+
+# ---------------------------------------------------------------------------
+# Fixing parameters and Constant values
+# ---------------------------------------------------------------------------
+class TestFixedParameters:
+    def test_fix_reduces_prior_count(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        assert model.prior_count == 1
+
+    def test_fixed_value_in_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        instance = model.instance_from_vector([10.0])
+        assert instance.one == 5.0
+        assert instance.two == 10.0
+
+    def test_fix_all_parameters(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        model.two = 10.0
+        assert model.prior_count == 0
+        instance = model.instance_from_vector([])
+        assert instance.one == 5.0
+        assert instance.two == 10.0
+
+
+# ---------------------------------------------------------------------------
+# take_attributes (prior passing)
+# ---------------------------------------------------------------------------
+class TestTakeAttributes:
+    def test_take_from_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        source = af.m.MockClassx2(10.0, 20.0)
+        model.take_attributes(source)
+        assert model.prior_count == 0
+
+    def test_take_from_model(self):
+        """Taking attributes from another model copies priors."""
+        source_model = af.Model(af.m.MockClassx2)
+        source_model.one = af.GaussianPrior(mean=5.0, sigma=1.0)
+        source_model.two = af.GaussianPrior(mean=10.0, sigma=2.0)
+
+        target_model = af.Model(af.m.MockClassx2)
+        target_model.take_attributes(source_model)
+        assert isinstance(target_model.one, af.GaussianPrior)
+
+
+# ---------------------------------------------------------------------------
+# Serialization (dict / from_dict)
+# ---------------------------------------------------------------------------
+class TestSerialization:
+    def test_model_dict_roundtrip(self):
+        model = af.Model(af.m.MockClassx2)
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == model.prior_count
+
+    def test_collection_dict_roundtrip(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == model.prior_count
+
+    def test_fixed_parameter_survives_roundtrip(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = 5.0
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == 1
+
+    def test_linked_prior_survives_roundtrip(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.prior_count == 1
+        d = model.dict()
+        loaded = af.AbstractPriorModel.from_dict(d)
+        assert loaded.prior_count == 1
+
+
+# ---------------------------------------------------------------------------
+# Log prior computation
+# ---------------------------------------------------------------------------
+class TestLogPrior:
+    def test_log_prior_within_bounds(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        log_priors = model.log_prior_list_from_vector([5.0, 5.0])
+        assert all(np.isfinite(lp) for lp in log_priors)
+
+    def test_log_prior_outside_bounds(self):
+        model = af.Model(af.m.MockClassx2)
+        model.one = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        model.two = af.UniformPrior(lower_limit=0.0, upper_limit=10.0)
+        log_priors = model.log_prior_list_from_vector([15.0, 5.0])
+        # Out-of-bounds value should have a lower (or zero) log prior than in-bounds
+        assert log_priors[0] <= log_priors[1]
+
+
+# ---------------------------------------------------------------------------
+# Edge cases
+# ---------------------------------------------------------------------------
+class TestEdgeCases:
+    def test_single_parameter_model(self):
+        """A model with a single free parameter using explicit prior."""
+        model = af.Model(af.m.MockClassx2)
+        model.two = 5.0  # Fix one parameter
+        assert model.prior_count == 1
+        instance = model.instance_from_vector([42.0])
+        assert instance.one == 42.0
+        assert instance.two == 5.0
+
+    def test_model_copy_preserves_priors(self):
+        model = af.Model(af.m.MockClassx2)
+        copied = model.copy()
+        assert copied.prior_count == model.prior_count
+        # Priors are independent copies (different objects)
+        assert copied.one is not model.one
+
+    def test_model_copy_linked_priors_independent(self):
+        """Copying a model with linked priors preserves the link in the copy."""
+        model = af.Model(af.m.MockClassx2)
+        model.one = model.two
+        assert model.prior_count == 1
+        copied = model.copy()
+        assert copied.prior_count == 1
+        # The copy's internal link is preserved
+        assert copied.one is copied.two
+
+    def test_prior_ordering_is_deterministic(self):
+        """prior_tuples_ordered_by_id should be stable across calls."""
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx2)
+        order1 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id]
+        order2 = [(n, p.id) for n, p in model.prior_tuples_ordered_by_id]
+        assert order1 == order2
+
+    def test_prior_count_equals_total_free_parameters(self):
+        model = af.Collection(a=af.m.MockClassx2, b=af.m.MockClassx4)
+        assert model.prior_count == model.total_free_parameters
+
+    def test_has_model(self):
+        model = af.Collection(a=af.Model(af.m.MockClassx2))
+        assert model.has_model(af.m.MockClassx2)
+        assert not model.has_model(af.m.MockClassx4)
+
+    def test_has_instance(self):
+        model = af.Model(af.m.MockClassx2)
+        assert model.has_instance(Prior)
+        assert not model.has_instance(af.m.MockClassx4)

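The TestModelTreeNavigation and TestPathAndNameArguments tests above pin down two ways of addressing a parameter: a path tuple such as ("g", "one"), accepted by instance_from_path_arguments, and a prior name such as "g_one", accepted by instance_from_prior_name_arguments. If a workspace script needs to switch from one argument style to the other, a minimal sketch follows. It assumes that a name is the path joined with underscores, which test_name_for_prior confirms only for the simple Collection(g=Model(...)) case. The helper name is hypothetical and is not part of PyAutoFit.

```python
def path_arguments_to_name_arguments(path_arguments):
    """Convert {("g", "one"): 10.0} style keys into {"g_one": 10.0} style keys.

    Hypothetical helper. Assumes prior names are underscore-joined paths, as
    test_name_for_prior shows for a simple Collection(g=Model(...)).
    """
    converted = {}
    for path, value in path_arguments.items():
        name = "_".join(path)
        # Different paths can join to the same name when attribute names
        # themselves contain underscores, so refuse to silently overwrite.
        if name in converted:
            raise ValueError(f"paths collide on name {name!r}")
        converted[name] = value
    return converted
```

Only the path-to-name direction is sketched. Going back from names to paths is ambiguous because attribute names can contain underscores (for example the ab_one key in test_with_prefix).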
Instructions

  1. Read the diff above carefully. Identify every renamed, moved, removed, or changed public API (functions, classes, parameters, imports).
  2. Search all files in scripts/ for usages of the old API patterns.
  3. Update each usage to the new API. Preserve the script's existing behaviour and docstrings.
  4. Test your changes by running: bash run_scripts.sh
  5. Fix any test failures and re-run until the test suite is clean.
  6. If a script cannot be fixed automatically (e.g. a missing upstream dependency, an ambiguous API change), leave it unchanged and list it in your PR description under a "Could not update" section with the reason.
  7. After all scripts pass, regenerate notebooks by running:
    pip install ipynb-py-convert
    git clone https://github.com/Jammy2211/PyAutoBuild.git ../PyAutoBuild
    PYTHONPATH=../PyAutoBuild/autobuild python3 ../PyAutoBuild/autobuild/generate.py autofit
    
  8. Commit the regenerated notebooks alongside the script changes.

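For step 2, a whole-word search prevents a name such as old_name from also matching old_names. This is a minimal sketch: find_usages is a hypothetical helper, and whoever reads the diff must supply the list of old names. A plain regex also matches comments and strings, so every hit still needs a human check before it is edited.

```python
import re
from pathlib import Path


def find_usages(root, names):
    """Map each .py file under `root` to the (line number, line) pairs that
    mention any of `names` as a whole word."""
    if not names:
        return {}
    # \b anchors keep "old_name" from matching inside "old_names".
    pattern = re.compile(r"\b(?:" + "|".join(map(re.escape, names)) + r")\b")
    hits = {}
    for path in sorted(Path(root).rglob("*.py")):
        lines = path.read_text(encoding="utf-8").splitlines()
        matches = [
            (number, line)
            for number, line in enumerate(lines, start=1)
            if pattern.search(line)
        ]
        if matches:
            hits[str(path)] = matches
    return hits
```

For example, find_usages("scripts", ["old_name"]) returns a dict keyed by file path. Only .py files are scanned, which matches the scope of step 2.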
Branch

Use branch name: feature/feature/model_docs
