Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autofit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from autoconf import jax_wrapper
from autoconf.dictable import register_parser
from . import conf

Expand Down
2 changes: 0 additions & 2 deletions autofit/config/general.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
jax:
use_jax: false # If True, PyAutoFit uses JAX internally, whereas False uses normal Numpy.
updates:
iterations_per_quick_update: 1e99 # Non-linear search iterations between every quick update, which just displays the maximum likelihood model fit.
iterations_per_full_update: 1e99 # Non-linear search iterations between every full update, which outputs all visuals and result fits (e.g. model.result, search.summary), this exits the search and can be slow.
Expand Down
15 changes: 7 additions & 8 deletions autofit/example/analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
from typing import Dict, Optional

from autoconf.jax_wrapper import numpy as xp

import autofit as af

from autofit.example.result import ResultExample
Expand Down Expand Up @@ -38,7 +36,7 @@ class Analysis(af.Analysis):

LATENT_KEYS = ["gaussian.fwhm"]

def __init__(self, data: np.ndarray, noise_map: np.ndarray):
def __init__(self, data: np.ndarray, noise_map: np.ndarray, use_jax=False):
"""
In this example the `Analysis` object only contains the data and noise-map. It can be easily extended,
for more complex data-sets and model fitting problems.
Expand All @@ -51,12 +49,12 @@ def __init__(self, data: np.ndarray, noise_map: np.ndarray):
A 1D numpy array containing the noise values of the data, used for computing the goodness of fit
metric.
"""
super().__init__()
super().__init__(use_jax=use_jax)

self.data = data
self.noise_map = noise_map

def log_likelihood_function(self, instance: af.ModelInstance) -> float:
def log_likelihood_function(self, instance: af.ModelInstance, xp=np) -> float:
"""
Determine the log likelihood of a fit of multiple profiles to the dataset.

Expand Down Expand Up @@ -98,14 +96,15 @@ def model_data_1d_from(self, instance: af.ModelInstance) -> np.ndarray:
The model data of the profiles.
"""

xvalues = xp.arange(self.data.shape[0])
model_data_1d = xp.zeros(self.data.shape[0])
xvalues = self._xp.arange(self.data.shape[0])
model_data_1d = self._xp.zeros(self.data.shape[0])

try:
for profile in instance:
try:
model_data_1d += profile.model_data_from(
xvalues=xvalues
xvalues=xvalues,
xp=self._xp
)
except AttributeError:
pass
Expand Down
10 changes: 4 additions & 6 deletions autofit/example/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import numpy as np
from typing import Tuple

from autoconf.jax_wrapper import numpy as xp

"""
The `Gaussian` class in this module is the model components that is fitted to data using a non-linear search. The
inputs of its __init__ constructor are the parameters which can be fitted for.
Expand Down Expand Up @@ -47,7 +45,7 @@ def fwhm(self) -> float:
the free parameters of the model which we are interested and may want to store the full samples information
on (e.g. to create posteriors).
"""
return 2 * xp.sqrt(2 * xp.log(2)) * self.sigma
return 2 * np.sqrt(2 * np.log(2)) * self.sigma

def _tree_flatten(self):
return (self.centre, self.normalization, self.sigma), None
Expand All @@ -64,7 +62,7 @@ def __eq__(self, other):
and self.sigma == other.sigma
)

def model_data_from(self, xvalues: np.ndarray) -> np.ndarray:
def model_data_from(self, xvalues: np.ndarray, xp=np) -> np.ndarray:
"""
Calculate the normalization of the profile on a 1D grid of Cartesian x coordinates.

Expand All @@ -82,7 +80,7 @@ def model_data_from(self, xvalues: np.ndarray) -> np.ndarray:
xp.exp(-0.5 * xp.square(xp.divide(transformed_xvalues, self.sigma))),
)

def f(self, x: float):
def f(self, x: float, xp=np):
return (
self.normalization
/ (self.sigma * xp.sqrt(2 * math.pi))
Expand Down Expand Up @@ -137,7 +135,7 @@ def __init__(
self.normalization = normalization
self.rate = rate

def model_data_from(self, xvalues: np.ndarray) -> np.ndarray:
def model_data_from(self, xvalues: np.ndarray, xp=np) -> np.ndarray:
"""
Calculate the 1D Gaussian profile on a 1D grid of Cartesian x coordinates.

Expand Down
4 changes: 3 additions & 1 deletion autofit/graphical/declarative/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ class AbstractDeclarativeFactor(Analysis, ABC):
optimiser: AbstractFactorOptimiser
_plates: Tuple[Plate, ...] = ()

def __init__(self, include_prior_factors=False):
def __init__(self, include_prior_factors=False, use_jax : bool = False):
self.include_prior_factors = include_prior_factors

super().__init__(use_jax=use_jax)

@property
@abstractmethod
def name(self):
Expand Down
5 changes: 2 additions & 3 deletions autofit/graphical/declarative/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
from autofit.mapper.model import ModelInstance
from autofit.mapper.prior_model.prior_model import Model

from autoconf.jax_wrapper import register_pytree_node_class
from ...non_linear.combined_result import CombinedResult


@register_pytree_node_class
class FactorGraphModel(AbstractDeclarativeFactor):
def __init__(
self,
*model_factors: Union[AbstractDeclarativeFactor, HierarchicalFactor],
name=None,
include_prior_factors=True,
use_jax : bool = False
):
"""
A collection of factors that describe models, which can be
Expand All @@ -36,6 +34,7 @@ def __init__(
"""
super().__init__(
include_prior_factors=include_prior_factors,
use_jax=use_jax,
)
self._model_factors = list(model_factors)
self._name = name or namer(self.__class__.__name__)
Expand Down
4 changes: 0 additions & 4 deletions autofit/graphical/declarative/factor/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from autofit.non_linear.paths.abstract import AbstractPaths
from .abstract import AbstractModelFactor

from autoconf.jax_wrapper import register_pytree_node_class


class FactorCallable:
def __init__(
Expand Down Expand Up @@ -45,8 +43,6 @@ def __call__(self, **kwargs: np.ndarray) -> float:
instance = self.prior_model.instance_for_arguments(arguments)
return self.analysis.log_likelihood_function(instance)


@register_pytree_node_class
class AnalysisFactor(AbstractModelFactor):
@property
def prior_model(self):
Expand Down
3 changes: 2 additions & 1 deletion autofit/graphical/declarative/factor/hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __call__(self, **kwargs):

class _HierarchicalFactor(AbstractModelFactor):
def __init__(
self, distribution_model: HierarchicalFactor, drawn_prior: Prior,
self, distribution_model: HierarchicalFactor, drawn_prior: Prior, use_jax : bool = False
):
"""
A factor that links a variable to a parameterised distribution.
Expand All @@ -159,6 +159,7 @@ def __init__(
"""
self.distribution_model = distribution_model
self.drawn_prior = drawn_prior
self.use_jax = use_jax

prior_variable_dict = {prior.name: prior for prior in distribution_model.priors}

Expand Down
4 changes: 3 additions & 1 deletion autofit/graphical/factor_graphs/factor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from copy import deepcopy
from inspect import getfullargspec
import jax
from typing import Tuple, Dict, Any, Callable, Union, List, Optional, TYPE_CHECKING

import numpy as np
Expand Down Expand Up @@ -285,6 +284,8 @@ def _set_jacobians(
numerical_jacobian=True,
jacfwd=True,
):
import jax

self._vjp = vjp
self._jacfwd = jacfwd
if vjp or factor_vjp:
Expand Down Expand Up @@ -327,6 +328,7 @@ def __call__(self, values: VariableData) -> FactorValue:
return self._cache[key]

def _jax_factor_vjp(self, *args) -> Tuple[Any, Callable]:
import jax
return jax.vjp(self._factor, *args)

_factor_vjp = _jax_factor_vjp
Expand Down
2 changes: 1 addition & 1 deletion autofit/graphical/laplace/newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def take_quasi_newton_step(
) -> Tuple[Optional[float], OptimisationState]:
""" """
state.search_direction = search_direction(state, **(search_direction_kws or {}))
if state.search_direction.vecnorm(np.Inf) == 0:
if state.search_direction.vecnorm(np.inf) == 0:
# if gradient is zero then at maximum already
return 0.0, state

Expand Down
3 changes: 3 additions & 0 deletions autofit/interpolator/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def __init__(
x: np.ndarray,
y: np.ndarray,
inverse_covariance_matrix: np.ndarray,
use_jax : bool = False
):
"""
An analysis class that describes a linear relationship between x and y, y = mx + c
Expand All @@ -30,6 +31,8 @@ def __init__(
The y values. This is a matrix comprising all the variables in the model at each x value
inverse_covariance_matrix
"""
super().__init__(use_jax=use_jax)

self.x = x
self.y = y
self.inverse_covariance_matrix = inverse_covariance_matrix
Expand Down
3 changes: 0 additions & 3 deletions autofit/mapper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from functools import wraps
from typing import Optional, Union, Tuple, List, Iterable, Type, Dict

from autoconf.jax_wrapper import register_pytree_node_class

from autofit.mapper.model_object import ModelObject
from autofit.mapper.prior_model.recursion import DynamicRecursionCache

Expand Down Expand Up @@ -384,7 +382,6 @@ def path_instances_of_class(
return results


@register_pytree_node_class
class ModelInstance(AbstractModel):
"""
An instance of a Collection or Model. This is created by optimisers and correspond
Expand Down
3 changes: 0 additions & 3 deletions autofit/mapper/prior/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Optional

from autoconf.jax_wrapper import register_pytree_node_class

from autofit.messages.normal import NormalMessage
from .abstract import Prior


@register_pytree_node_class
class GaussianPrior(Prior):
__identifier_fields__ = ("mean", "sigma")
__database_args__ = ("mean", "sigma", "id_")
Expand Down
2 changes: 0 additions & 2 deletions autofit/mapper/prior/log_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import numpy as np

from autoconf.jax_wrapper import register_pytree_node_class
from autofit.messages.normal import NormalMessage
from .abstract import Prior
from ...messages.composed_transform import TransformedMessage
from ...messages.transform import log_transform


@register_pytree_node_class
class LogGaussianPrior(Prior):
__identifier_fields__ = ("mean", "sigma")
__database_args__ = ("mean", "sigma", "id_")
Expand Down
2 changes: 0 additions & 2 deletions autofit/mapper/prior/log_uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@

import numpy as np

from autoconf.jax_wrapper import register_pytree_node_class
from autofit.messages.normal import UniformNormalMessage
from autofit.messages.transform import log_10_transform, LinearShiftTransform
from .abstract import Prior
from ...messages.composed_transform import TransformedMessage

from autofit import exc

@register_pytree_node_class
class LogUniformPrior(Prior):
__identifier_fields__ = ("lower_limit", "upper_limit")
__database_args__ = ("lower_limit", "upper_limit", "id_")
Expand Down
3 changes: 0 additions & 3 deletions autofit/mapper/prior/truncated_gaussian.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Optional, Tuple

from autoconf.jax_wrapper import register_pytree_node_class

from autofit.messages.truncated_normal import TruncatedNormalMessage
from .abstract import Prior


@register_pytree_node_class
class TruncatedGaussianPrior(Prior):
__identifier_fields__ = ("mean", "sigma", "lower_limit", "upper_limit")
__database_args__ = ("mean", "sigma", "lower_limit", "upper_limit", "id_")
Expand Down
2 changes: 0 additions & 2 deletions autofit/mapper/prior/uniform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from autoconf.jax_wrapper import register_pytree_node_class
from typing import Optional, Tuple

from autofit.messages.normal import UniformNormalMessage
Expand All @@ -9,7 +8,6 @@

from autofit import exc

@register_pytree_node_class
class UniformPrior(Prior):
__identifier_fields__ = ("lower_limit", "upper_limit")
__database_args__ = ("lower_limit", "upper_limit", "id_")
Expand Down
23 changes: 15 additions & 8 deletions autofit/mapper/prior_model/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@
from autoconf.dictable import from_dict
from .abstract import AbstractPriorModel
from autofit.mapper.prior.abstract import Prior
from autoconf.jax_wrapper import numpy as xp, use_jax
import numpy as np

from autoconf.jax_wrapper import register_pytree_node_class


@register_pytree_node_class
class Array(AbstractPriorModel):
def __init__(
self,
Expand Down Expand Up @@ -77,7 +73,8 @@ def _instance_for_arguments(
-------
The array with the priors replaced.
"""
array = xp.zeros(self.shape)
make_array = True

for index in self.indices:
value = self[index]
try:
Expand All @@ -88,10 +85,20 @@ def _instance_for_arguments(
except AttributeError:
pass

if use_jax:
array = array.at[index].set(value)
else:
if make_array:
if isinstance(value, np.ndarray) or isinstance(value, np.float64):
array = np.zeros(self.shape)
make_array = False
else:
import jax.numpy as jnp
array = jnp.zeros(self.shape)
make_array = False

if isinstance(value, np.ndarray) or isinstance(value, np.float64):
array[index] = value
else:
array = array.at[index].set(value)

return array

def __setitem__(
Expand Down
3 changes: 0 additions & 3 deletions autofit/mapper/prior_model/collection.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from collections.abc import Iterable

from autoconf.jax_wrapper import register_pytree_node_class

from autofit.mapper.model import ModelInstance, assert_not_frozen
from autofit.mapper.prior.abstract import Prior
from autofit.mapper.prior.constant import Constant
from autofit.mapper.prior_model.abstract import AbstractPriorModel


@register_pytree_node_class
class Collection(AbstractPriorModel):
def name_for_prior(self, prior: Prior) -> str:
"""
Expand Down
Loading
Loading