Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
e85682c
move jax serialise/deserialise test util to conftest
rhayes777 Feb 14, 2025
4fe5260
pytree methods for FactorGraphModel
rhayes777 Feb 14, 2025
d0fd761
Merge branch 'main' into feature/graphical_pytrees
rhayes777 Feb 14, 2025
f3e49fe
pytree methods for AnalysisFactor
rhayes777 Feb 14, 2025
a7cafb3
Merge branch 'main' into feature/graphical_pytrees
rhayes777 Feb 14, 2025
dd41024
tree flatten for LogUniformPrior
rhayes777 Feb 14, 2025
39b8e8a
prior ids as children when creating pytrees
rhayes777 Feb 14, 2025
679f05a
Merge branch 'main' into feature/graphical_pytrees
rhayes777 Feb 21, 2025
0221785
Merge branch 'main' into feature/graphical_pytrees
rhayes777 Mar 5, 2025
6476093
Merge branch 'main' into feature/graphical_pytrees
Jammy2211 Mar 13, 2025
2e72373
Merge branch 'main' into feature/graphical_pytrees
Jammy2211 Mar 14, 2025
f31fc8c
Merge branch 'main' into feature/graphical_pytrees
Jammy2211 Mar 17, 2025
528ef9c
main merge
Jammy2211 Mar 17, 2025
b990e12
Merge branch 'main' into feature/graphical_pytrees
Jammy2211 Mar 24, 2025
a68537f
specific function which generates initial samples without pool
Jammy2211 Apr 1, 2025
fba57fb
just ignore NaNs currently, asked for help
Jammy2211 Apr 1, 2025
3f0fe22
nans now handled by np.nan
Jammy2211 Apr 1, 2025
34d11d0
Merge branch 'feature/jax_initial_samples' into feature/jax_assert_wi…
Jammy2211 Apr 1, 2025
413d2c2
Merge pull request #1125 from rhayes777/feature/jax_initial_samples
Jammy2211 Apr 2, 2025
154e6fc
Merge branch 'main' of https://github.com/rhayes777/PyAutoFit into main
Jammy2211 Apr 8, 2025
a9dbc65
Merge branch 'main' into feature/graphical_pytrees
rhayes777 Apr 9, 2025
038385c
Merge branch 'main' into feature/graphical_pytrees
Jammy2211 Apr 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions autofit/graphical/declarative/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from autofit.non_linear.samples.summary import SamplesSummary
from autofit.non_linear.analysis.combined import CombinedResult

from autofit.jax_wrapper import register_pytree_node_class


@register_pytree_node_class
class FactorGraphModel(AbstractDeclarativeFactor):
def __init__(
self,
Expand All @@ -34,6 +37,20 @@ def __init__(
self._model_factors = list(model_factors)
self._name = name or namer(self.__class__.__name__)

def tree_flatten(self):
    """
    Flatten this factor graph model into a JAX pytree.

    Returns
    -------
    A (children, aux_data) pair: the model factors are the dynamic
    children, while the name and the include_prior_factors flag are
    carried as static auxiliary data.
    """
    children = (self._model_factors,)
    aux_data = (self._name, self.include_prior_factors)
    return children, aux_data

@classmethod
def tree_unflatten(cls, aux_data, children):
return cls(
*children[0],
name=aux_data[0],
include_prior_factors=aux_data[1],
)

@property
def prior_model(self):
"""
Expand Down
23 changes: 22 additions & 1 deletion autofit/graphical/declarative/factor/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from autofit.non_linear.paths.abstract import AbstractPaths
from .abstract import AbstractModelFactor

from autofit.jax_wrapper import register_pytree_node_class


class FactorCallable:
def __init__(
Expand Down Expand Up @@ -43,6 +45,7 @@ def __call__(self, **kwargs: np.ndarray) -> float:
return self.analysis.log_likelihood_function(instance)


@register_pytree_node_class
class AnalysisFactor(AbstractModelFactor):
@property
def prior_model(self):
Expand Down Expand Up @@ -85,7 +88,25 @@ def __init__(
prior_variable_dict=prior_variable_dict,
name=name,
)
print(name)

def tree_flatten(self):
    """
    Flatten this analysis factor into a JAX pytree.

    Returns
    -------
    A (children, aux_data) pair: the prior model is the single dynamic
    child; the analysis, optimiser and name are static auxiliary data.
    """
    aux_data = (self.analysis, self.optimiser, self.name)
    return (self.prior_model,), aux_data

@classmethod
def tree_unflatten(cls, aux_data, children):
    """
    Rebuild an analysis factor from its pytree parts.

    Parameters
    ----------
    aux_data
        Static data produced by tree_flatten: (analysis, optimiser, name).
    children
        Dynamic children produced by tree_flatten: a 1-tuple holding the
        prior model.

    Returns
    -------
    An instance of this class equivalent to the flattened one.
    """
    analysis, optimiser, name = aux_data
    return cls(
        children[0],
        analysis=analysis,
        optimiser=optimiser,
        name=name,
    )

def __getstate__(self):
return self.__dict__
Expand Down
2 changes: 1 addition & 1 deletion autofit/mapper/prior/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def tree_unflatten(cls, aux_data, children):
-------
An instance of this class
"""
return cls(*children, id_=aux_data[0])
return cls(*children)

@property
def lower_unit_limit(self) -> float:
Expand Down
2 changes: 1 addition & 1 deletion autofit/mapper/prior/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
)

def tree_flatten(self):
return (self.mean, self.sigma, self.lower_limit, self.upper_limit), (self.id,)
return (self.mean, self.sigma, self.lower_limit, self.upper_limit, self.id), ()

@classmethod
def with_limits(cls, lower_limit: float, upper_limit: float) -> "GaussianPrior":
Expand Down
11 changes: 11 additions & 0 deletions autofit/mapper/prior/log_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import numpy as np

from autofit.jax_wrapper import register_pytree_node_class
from autofit.messages.normal import NormalMessage
from .abstract import Prior
from ...messages.composed_transform import TransformedMessage
from ...messages.transform import log_transform


@register_pytree_node_class
class LogGaussianPrior(Prior):
__identifier_fields__ = ("lower_limit", "upper_limit", "mean", "sigma")
__database_args__ = ("mean", "sigma", "lower_limit", "upper_limit", "id_")
Expand Down Expand Up @@ -71,6 +73,15 @@ def __init__(
id_=id_,
)

def tree_flatten(self):
    """
    Flatten this prior into a JAX pytree.

    Returns
    -------
    A (children, aux_data) pair. All numeric attributes — including the
    prior id, so the base-class tree_unflatten (``cls(*children)``) can
    reconstruct an identical prior — are children; aux_data is empty.
    """
    children = (
        self.mean,
        self.sigma,
        self.lower_limit,
        self.upper_limit,
        self.id,
    )
    return children, ()

@classmethod
def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogGaussianPrior":
"""
Expand Down
9 changes: 9 additions & 0 deletions autofit/mapper/prior/log_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import numpy as np

from autofit.jax_wrapper import register_pytree_node_class
from autofit import exc
from autofit.messages.normal import UniformNormalMessage
from autofit.messages.transform import log_10_transform, LinearShiftTransform
from .abstract import Prior
from ...messages.composed_transform import TransformedMessage


@register_pytree_node_class
class LogUniformPrior(Prior):
def __init__(
self,
Expand Down Expand Up @@ -67,6 +69,13 @@ def __init__(
id_=id_,
)

def tree_flatten(self):
    """
    Flatten this prior into a JAX pytree.

    Returns
    -------
    A (children, aux_data) pair. The limits and the prior id are the
    children — the id is included so the base-class tree_unflatten
    (``cls(*children)``) can rebuild an identical prior — and aux_data
    is empty.
    """
    children = (self.lower_limit, self.upper_limit, self.id)
    return children, ()

@classmethod
def with_limits(cls, lower_limit: float, upper_limit: float) -> "LogUniformPrior":
"""
Expand Down
2 changes: 1 addition & 1 deletion autofit/mapper/prior/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
)

def tree_flatten(self):
return (self.lower_limit, self.upper_limit), (self.id,)
return (self.lower_limit, self.upper_limit, self.id), ()

def with_limits(
self,
Expand Down
8 changes: 4 additions & 4 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np

import os
from typing import Optional

from autoconf import conf

from autofit import exc

from autofit.jax_wrapper import numpy as np

from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.non_linear.analysis import Analysis
Expand Down Expand Up @@ -154,9 +156,7 @@ def __call__(self, parameters, *kwargs):
try:
instance = self.model.instance_from_vector(vector=parameters)
log_likelihood = self.log_likelihood_function(instance=instance)

if np.isnan(log_likelihood):
return self.resample_figure_of_merit
log_likelihood = np.where(np.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood)

except exc.FitException:
return self.resample_figure_of_merit
Expand Down
114 changes: 95 additions & 19 deletions autofit/non_linear/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.parallel import SneakyPool

from autofit import jax_wrapper

logger = logging.getLogger(__name__)


Expand All @@ -39,14 +41,14 @@ def figure_of_metric(args) -> Optional[float]:
return None

def samples_from_model(
self,
total_points: int,
model: AbstractPriorModel,
fitness,
paths: AbstractPaths,
use_prior_medians: bool = False,
test_mode_samples: bool = True,
n_cores: int = 1,
self,
total_points: int,
model: AbstractPriorModel,
fitness,
paths: AbstractPaths,
use_prior_medians: bool = False,
test_mode_samples: bool = True,
n_cores: int = 1,
):
"""
Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform
Expand All @@ -64,6 +66,14 @@ def samples_from_model(
if os.environ.get("PYAUTOFIT_TEST_MODE") == "1" and test_mode_samples:
return self.samples_in_test_mode(total_points=total_points, model=model)

if jax_wrapper.use_jax:
return self.samples_jax(
total_points=total_points,
model=model,
fitness=fitness,
use_prior_medians=use_prior_medians
)

unit_parameter_lists = []
parameter_lists = []
figures_of_merit_list = []
Expand Down Expand Up @@ -92,30 +102,95 @@ def samples_from_model(
unit_parameter_lists_.append(unit_parameter_list)

for figure_of_merit, unit_parameter_list, parameter_list in zip(
sneaky_pool.map(
function=self.figure_of_metric,
args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_],
log_info=False
),
unit_parameter_lists_,
parameter_lists_,
sneaky_pool.map(
function=self.figure_of_metric,
args_list=[(fitness, parameter_list) for parameter_list in parameter_lists_],
log_info=False
),
unit_parameter_lists_,
parameter_lists_,
):
if figure_of_merit is not None:
unit_parameter_lists.append(unit_parameter_list)
parameter_lists.append(parameter_list)
figures_of_merit_list.append(figure_of_merit)

if total_points > 1 and np.allclose(
a=figures_of_merit_list[0], b=figures_of_merit_list[1:]
a=figures_of_merit_list[0], b=figures_of_merit_list[1:]
):
raise exc.InitializerException(
"""
The initial samples all have the same figure of merit (e.g. log likelihood values).

The non-linear search will therefore not progress correctly.

Possible causes for this behaviour are:


- The `log_likelihood_function` of the analysis class is defined incorrectly.
- The model parameterization creates numerically inaccurate log likelihoods.
- The`log_likelihood_function` is always returning `nan` values.
"""
)

logger.info(f"Initial samples generated, starting non-linear search")

return unit_parameter_lists, parameter_lists, figures_of_merit_list

def samples_jax(
self,
total_points: int,
model: AbstractPriorModel,
fitness,
use_prior_medians: bool = False,
):
"""
Generate the initial points of the non-linear search, by randomly drawing unit values from a uniform
distribution between the ball_lower_limit and ball_upper_limit values.

Parameters
----------
total_points
The number of points in non-linear paramemter space which initial points are created for.
model
An object that represents possible instances of some model with a given dimensionality which is the number
of free dimensions of the model.
"""

unit_parameter_lists = []
parameter_lists = []
figures_of_merit_list = []

logger.info(f"Generating initial samples of model using JAX LH Function cores")

while len(figures_of_merit_list) < total_points:

if not use_prior_medians:
unit_parameter_list = self._generate_unit_parameter_list(model)
else:
unit_parameter_list = [0.5] * model.prior_count

parameter_list = model.vector_from_unit_vector(
unit_vector=unit_parameter_list
)

figure_of_merit = self.figure_of_metric((fitness, parameter_list))

if figure_of_merit is not None:
unit_parameter_lists.append(unit_parameter_list)
parameter_lists.append(parameter_list)
figures_of_merit_list.append(figure_of_merit)

if total_points > 1 and np.allclose(
a=figures_of_merit_list[0], b=figures_of_merit_list[1:]
):
raise exc.InitializerException(
"""
The initial samples all have the same figure of merit (e.g. log likelihood values).

The non-linear search will therefore not progress correctly.

Possible causes for this behaviour are:

- The `log_likelihood_function` of the analysis class is defined incorrectly.
- The model parameterization creates numerically inaccurate log likelihoods.
- The`log_likelihood_function` is always returning `nan` values.
Expand Down Expand Up @@ -321,6 +396,7 @@ def info_value_from(self, value : Tuple[float, float]) -> float:
"""
return (value[1] + value[0]) / 2.0


class Initializer(AbstractInitializer):
def __init__(self, lower_limit: float, upper_limit: float):
"""
Expand Down
12 changes: 12 additions & 0 deletions test_autofit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
directory = Path(__file__).parent


@pytest.fixture(name="recreate")
def recreate():
    """
    Fixture returning a helper that round-trips an object through its
    registered JAX pytree flatten/unflatten functions.

    Used to check that tree_flatten/tree_unflatten implementations
    reconstruct an equivalent object. Skips the test if jax is not
    installed.
    """
    jax = pytest.importorskip("jax")

    def _recreate(o):
        # NOTE(review): reaches into jax private internals
        # (jax._src.tree_util._registry) to look up the flatten/unflatten
        # pair registered for type(o) — this may break across jax
        # versions; consider a public API if one exists.
        flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)]
        children, aux_data = flatten_func(o)
        # Rebuild the object from its flattened parts: one-level
        # flatten, then unflatten with the same aux_data and children.
        return unflatten_func(aux_data, children)

    return _recreate


@pytest.fixture(autouse=True)
def turn_off_gc(monkeypatch):
monkeypatch.setattr(abstract_search, "gc", MagicMock())
Expand Down
Loading
Loading