Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions autofit/mapper/prior/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ def unit_value_for(self, physical_value: float) -> float:
return self.message.cdf(physical_value)

def with_message(self, message):
"""Return a copy of this prior with a different message (distribution).

Parameters
----------
message
The new message object defining the prior's distribution.

Returns
-------
Prior
A copy of this prior using the new message.
"""
new = copy(self)
new.message = message
return new
Expand Down Expand Up @@ -88,6 +100,23 @@ def factor(self):

@staticmethod
def for_class_and_attribute_name(cls, attribute_name):
"""Create a prior from the configuration for a given class and attribute.

Looks up the prior type and parameters in the prior config files
for the specified class and attribute name.

Parameters
----------
cls
The model class whose config is looked up.
attribute_name
The name of the attribute on that class.

Returns
-------
Prior
A prior instance constructed from the config entry.
"""
prior_dict = conf.instance.prior_config.for_class_and_suffix_path(
cls, [attribute_name]
)
Expand Down Expand Up @@ -129,10 +158,31 @@ def instance_for_arguments(
arguments,
ignore_assertions=False,
):
"""Look up this prior's value in an arguments dictionary.

Parameters
----------
arguments
A dictionary mapping Prior objects to physical values.
ignore_assertions
Unused for priors (present for interface compatibility).
"""
_ = ignore_assertions
return arguments[self]

def project(self, samples, weights):
"""Project this prior given samples and log weights from a search.

Returns a copy of this prior whose message has been updated to
reflect the posterior information from the samples.

Parameters
----------
samples
Array of sample values for this parameter.
weights
Log weights for each sample.
"""
result = copy(self)
result.message = self.message.project(
samples=samples,
Expand Down Expand Up @@ -170,6 +220,11 @@ def __str__(self):
@property
@abstractmethod
def parameter_string(self) -> str:
"""A human-readable string summarizing this prior's parameters.

Subclasses must implement this to return a description such as
``"mean = 0.0, sigma = 1.0"`` or ``"lower_limit = 0.0, upper_limit = 1.0"``.
"""
pass

def __float__(self):
Expand Down Expand Up @@ -254,7 +309,22 @@ def name_of_class(cls) -> str:

@property
def limits(self) -> Tuple[float, float]:
"""The (lower, upper) bounds of this prior.

Returns (-inf, inf) by default. Subclasses with finite bounds
(e.g. UniformPrior) override this.
"""
return (float("-inf"), float("inf"))

def gaussian_prior_model_for_arguments(self, arguments):
"""Look up this prior in an arguments dict and return the mapped value.

Used during prior replacement workflows where each prior is mapped
to a new prior or fixed value via an arguments dictionary.

Parameters
----------
arguments
A dictionary mapping Prior objects to their replacement values.
"""
return arguments[self]
7 changes: 7 additions & 0 deletions autofit/mapper/prior/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ def __init__(
)

def tree_flatten(self):
"""Flatten this prior into a JAX-compatible PyTree representation.

Returns
-------
tuple
A (children, aux_data) pair where children are (mean, sigma, id).
"""
return (self.mean, self.sigma, self.id), ()

@classmethod
Expand Down
29 changes: 29 additions & 0 deletions autofit/mapper/prior/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,50 @@ def __init__(
)

def tree_flatten(self):
"""Flatten this prior into a JAX-compatible PyTree representation.

Returns
-------
tuple
A (children, aux_data) pair where children are (lower_limit, upper_limit, id).
"""
return (self.lower_limit, self.upper_limit, self.id), ()

@property
def width(self):
"""The width of the uniform distribution (upper_limit - lower_limit)."""
return self.upper_limit - self.lower_limit

def with_limits(
self,
lower_limit: float,
upper_limit: float,
) -> "Prior":
"""Create a new UniformPrior with different bounds.

Parameters
----------
lower_limit
The new lower bound.
upper_limit
The new upper bound.
"""
return UniformPrior(
lower_limit=lower_limit,
upper_limit=upper_limit,
)

def logpdf(self, x):
"""Compute the log probability density at x.

Adjusts boundary values by epsilon to avoid evaluating exactly at
the distribution edges where the PDF is undefined.

Parameters
----------
x
The value at which to evaluate the log PDF.
"""
# TODO: handle x as a numpy array
if x == self.lower_limit:
x += epsilon
Expand All @@ -102,6 +129,7 @@ def dict(self) -> dict:

@property
def parameter_string(self) -> str:
"""A human-readable string summarizing the prior's lower and upper limits."""
return f"lower_limit = {self.lower_limit}, upper_limit = {self.upper_limit}"

def value_for(self, unit: float) -> float:
Expand Down Expand Up @@ -142,4 +170,5 @@ def log_prior_from_value(self, value, xp=np):

@property
def limits(self) -> Tuple[float, float]:
"""The (lower_limit, upper_limit) bounds of this uniform prior."""
return self.lower_limit, self.upper_limit
Loading
Loading