Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion autofit/config/output.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ start_point: true
# manually after completing a successful model-fit. This will save computational run time by not computing latent
# variables during any model-fit which is unsuccessful.

latent_during_fit: true # Whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files during the fit when it performs on-the-fly output.
latent_during_fit: false # Whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files during the fit when it performs on-the-fly output.
latent_after_fit: true # If `latent_during_fit` is False, whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files after the fit is complete.
latent_draw_via_pdf : true # Whether to draw latent variable values via the PDF of every sample, which uses fewer samples to estimate latent variable errors. If False, latent variable values are drawn from every sample.
latent_draw_via_pdf_size : 100 # The number of samples drawn to estimate latent variable errors if `latent_draw_via_pdf` is True.
latent_csv: true # Whether to output the `latent.csv` file.
latent_results: true # Whether to output the `latent.results` file.

Expand Down
24 changes: 17 additions & 7 deletions autofit/example/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ class Analysis(af.Analysis):
It has been extended, based on the model that is input into the analysis, to include a
property `max_log_likelihood_model_data`, which is the model data of the best-fit model.
"""

Result = ResultExample

LATENT_KEYS = ["fwhm"]

def __init__(self, data: np.ndarray, noise_map: np.ndarray):
"""
In this example the `Analysis` object only contains the data and noise-map. It can be easily extended,
Expand Down Expand Up @@ -204,7 +205,7 @@ def make_result(
analysis=self,
)

def compute_latent_variables(self, instance) -> Dict[str, float]:
def compute_latent_variables(self, parameters, model) -> Dict[str, float]:
"""
A latent variable is not a model parameter but can be derived from the model. Its value and errors may be
of interest and aid in the interpretation of a model-fit.
Expand All @@ -218,8 +219,15 @@ def compute_latent_variables(self, instance) -> Dict[str, float]:
In the example below, the `latent.csv` file will contain one column with the FWHM of every Gaussian model
sampled by the non-linear search.

This function is called for every non-linear search sample, where the `instance` passed in corresponds to
each sample.
This function is called at the end of the search, following one of two schemes depending on the settings in
`output.yaml`:

1) Call for every search sample, which produces a complete `latent/samples.csv` which mirrors the normal
`samples.csv` file but takes a long time to compute.

2) Call only for N random draws from the posterior inferred at the end of the search, which only produces a
`latent/latent_summary.json` file with the median and 1 and 3 sigma errors of the latent variables but is
fast to compute.

Parameters
----------
Expand All @@ -230,10 +238,12 @@ def compute_latent_variables(self, instance) -> Dict[str, float]:
-------

"""
instance = model.instance_from_vector(vector=parameters)

try:
return {"fwhm": instance.fwhm}
return (instance.fwhm, )
except AttributeError:
try:
return {"gaussian.fwhm": instance[0].fwhm}
return (instance[0].fwhm,)
except AttributeError:
return {"gaussian.fwhm": instance[0].gaussian.fwhm}
return (instance[0].gaussian.fwhm,)
99 changes: 80 additions & 19 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import inspect
import logging
from abc import ABC
import functools
import numpy as np
import jax
import jax.numpy as jnp
from typing import Optional, Dict

from autofit.jax_wrapper import use_jax

from autofit.mapper.prior_model.abstract import AbstractPriorModel
from autofit.non_linear.paths.abstract import AbstractPaths
from autofit.non_linear.samples.summary import SamplesSummary
Expand All @@ -27,6 +33,8 @@ class Analysis(ABC):
Result = Result
Visualizer = Visualizer

LATENT_KEYS = []

def __getattr__(self, item: str):
"""
If a method starts with 'visualize_' then we assume it is associated with
Expand All @@ -51,44 +59,94 @@ def method(*args, **kwargs):

def compute_latent_samples(self, samples: Samples) -> Optional[Samples]:
"""
Internal method that manages computation of latent samples from samples.
Compute latent variables from a model instance.

A latent variable is not itself a free parameter of the model but can be derived from it.
Latent variables may provide physically meaningful quantities that help interpret a model
fit, and their values (with errors) are stored in `latent.csv` in parallel with `samples.csv`.

This implementation is designed to be compatible with both NumPy and JAX:

- It is written to be side-effect free, so it can be JIT-compiled with `jax.jit`.
- It can be vectorized over many parameter sets at once using `jax.vmap`, enabling efficient
batched evaluation of latent variables for multiple samples.
- Returned values should be simple JAX/NumPy scalars or arrays (no Python objects), so they
can be stacked into arrays of shape `(n_samples, n_latents)` for batching.
- Any NaNs introduced (e.g. from invalid model states) can be masked or replaced downstream.

Parameters
----------
samples
The samples from the non-linear search.
parameters : array-like
The parameter vector of the model sample. This will typically come from the non-linear search.
Inside this method it is mapped back to a model instance via `model.instance_from_vector`.
model : Model
The model object defining how the parameter vector is mapped to an instance. Passed explicitly
so that this function can be used inside JAX transforms (`vmap`, `jit`) with `functools.partial`.

Returns
-------
The computed latent samples or None if compute_latent_variables is not implemented.
tuple of (float or jax.numpy scalar)
A tuple containing the latent variables in a fixed order:
`(intensity_total, magnitude, angle)`. Each entry may be NaN if the corresponding component
of the model is not present.
"""
try:

compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model)

if use_jax:
start = time.time()
logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.")
batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model))
logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.")
else:
def batched_compute_latent(x):
return np.array([compute_latent_for_model(xx) for xx in x])

parameter_array = np.array(samples.parameter_lists)
batch_size = 50
latent_samples = []
model = samples.model
for sample in samples.sample_list:
latent_samples.append(
Sample(
log_likelihood=sample.log_likelihood,
log_prior=sample.log_prior,
weight=sample.weight,
kwargs=self.compute_latent_variables(
sample.instance_for_model(
model,
ignore_assertions=True,
)
),

# process in batches
for i in range(0, len(parameter_array), batch_size):

batch = parameter_array[i:i + batch_size]

# batched JAX call on this chunk
latent_values_batch = batched_compute_latent(batch)

if use_jax:
latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents)
mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0)
latent_values_batch = latent_values_batch[:, mask]
else:
mask = np.all(np.isfinite(latent_values_batch), axis=0)
latent_values_batch = latent_values_batch[:, mask]

for sample, values in zip(samples.sample_list[i:i + batch_size], latent_values_batch):


kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)}

latent_samples.append(
Sample(
log_likelihood=sample.log_likelihood,
log_prior=sample.log_prior,
weight=sample.weight,
kwargs=kwargs,
)
)
)

return type(samples)(
sample_list=latent_samples,
model=simple_model_for_kwargs(latent_samples[0].kwargs),
samples_info=samples.samples_info,
)

except NotImplementedError:
return None

def compute_latent_variables(self, instance) -> Dict[str, float]:
def compute_latent_variables(self, parameters, model) -> Dict[str, float]:
"""
Override to compute latent variables from the instance.

Expand Down Expand Up @@ -237,3 +295,6 @@ def profile_log_likelihood_function(self, paths: AbstractPaths, instance):
The maximum likliehood instance of the model so far in the non-linear search.
"""
pass

def latent_lh_dict_from(self, **kwargs):
return None
12 changes: 6 additions & 6 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,26 +187,26 @@ def __call__(self, parameters, *kwargs):
@cached_property
def _vmap(self):
start = time.time()
print("JAX: Applying vmap and jit to likelihood function -- may take a few seconds.")
logger.info("JAX: Applying vmap and jit to likelihood function -- may take a few seconds.")
func = jax.vmap(jax.jit(self.call))
print(f"JAX: vmap and jit applied in {time.time() - start} seconds.")
logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.")
return func

@cached_property
def _call(self):
start = time.time()
print("JAX: Applying jit to likelihood function -- may take a few seconds.")
logger.info("JAX: Applying jit to likelihood function -- may take a few seconds.")
func = jax_wrapper.jit(self.call)
print(f"JAX: jit applied in {time.time() - start} seconds.")
logger.info(f"JAX: jit applied in {time.time() - start} seconds.")
return func


@cached_property
def _grad(self):
start = time.time()
print("JAX: Applying grad to likelihood function -- may take a few seconds.")
logger.info("JAX: Applying grad to likelihood function -- may take a few seconds.")
func = jax_wrapper.grad(self._call)
print(f"JAX: grad applied in {time.time() - start} seconds.")
logger.info(f"JAX: grad applied in {time.time() - start} seconds.")
return func

def grad(self, *args, **kwargs):
Expand Down
31 changes: 31 additions & 0 deletions autofit/non_linear/samples/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,37 @@ def draw_randomly_via_pdf(self) -> Union[List, ModelInstance]:

return self.parameter_lists[sample_index][:]

def samples_drawn_randomly_via_pdf_from(self, total_draws: int = 100) -> "SamplesPDF":
    """
    Draw samples randomly from the PDF, weighted by each sample's weight.

    Sampling is with replacement, so high-weight samples may appear multiple
    times in the returned object.

    Parameters
    ----------
    total_draws
        The number of samples to draw. Defaults to 100.

    Returns
    -------
    SamplesPDF
        A new SamplesPDF object containing the drawn samples.

    Raises
    ------
    ValueError
        If the sample weights sum to zero, in which case no probability
        distribution can be formed to draw from.
    """
    # Copy before normalizing: np.asarray returns the *same* array when
    # `weight_list` is already a float ndarray, so an in-place `/=` would
    # silently mutate the weights stored on this object.
    weights = np.array(self.weight_list, dtype=float)

    total = weights.sum()
    if total == 0.0:
        raise ValueError(
            "Cannot draw samples via the PDF because all sample weights are zero."
        )

    # Normalize weights so they form a probability distribution summing to 1,
    # as required by the `p` argument of np.random.choice.
    weights /= total

    sample_indices = np.random.choice(
        a=len(self.sample_list),
        size=total_draws,
        replace=True,
        p=weights,
    )

    return SamplesPDF(
        model=self.model,
        sample_list=[self.sample_list[i] for i in sample_indices],
        samples_info=self.samples_info,
    )

@to_instance
def offset_values_via_input_values(
self, input_vector: List
Expand Down
18 changes: 17 additions & 1 deletion autofit/non_linear/search/abstract_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,10 +958,26 @@ def perform_update(
if (during_analysis and conf.instance["output"]["latent_during_fit"]) or (
not during_analysis and conf.instance["output"]["latent_after_fit"]
):

if conf.instance["output"]["latent_draw_via_pdf"]:

total_draws = conf.instance["output"]["latent_draw_via_pdf_size"]

logger.info(f"Creating latent samples by drawing {total_draws} from the PDF.")

latent_samples = samples.samples_drawn_randomly_via_pdf_from(total_draws=total_draws)

else:

logger.info(f"Creating latent samples using all samples above the samples weight threshold.")

latent_samples = samples_save

latent_samples = analysis.compute_latent_samples(samples_save)

if latent_samples:
self.paths.save_latent_samples(latent_samples)
if not conf.instance["output"]["latent_draw_via_pdf"]:
self.paths.save_latent_samples(latent_samples)
self.paths.save_samples_summary(
latent_samples.summary(),
"latent/latent_summary",
Expand Down
2 changes: 1 addition & 1 deletion autofit/non_linear/search/nest/nautilus/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def fit_multiprocessing(self, fitness, model, analysis):
"""
search_internal = self.sampler_cls(
prior=PriorVectorized(model=model),
likelihood=fitness.call_numpy_wrapper,
likelihood=fitness.__call__,
n_dim=model.prior_count,
prior_kwargs={"model": model},
filepath=self.checkpoint_file,
Expand Down
25 changes: 17 additions & 8 deletions test_autofit/analysis/test_latent_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@


class Analysis(af.Analysis):

LATENT_KEYS = ["fwhm"]

def log_likelihood_function(self, instance):
return 1.0

def compute_latent_variables(self, instance):
return {"fwhm": instance.fwhm}
def compute_latent_variables(self, parameters, model):

instance = model.instance_from_vector(vector=parameters)

return (instance.fwhm,)


@with_config(
Expand Down Expand Up @@ -100,15 +106,17 @@ def test_info(latent_samples):


class ComplexAnalysis(af.Analysis):

LATENT_KEYS = ["lens.mass", "lens.brightness", "source.brightness"]

def log_likelihood_function(self, instance):
return 1.0

def compute_latent_variables(self, instance):
return {
"lens.mass": 1.0,
"lens.brightness": 2.0,
"source.brightness": 3.0,
}
def compute_latent_variables(self, parameters, model):

instance = model.instance_from_vector(vector=parameters)

return (1.0, 2.0, 3.0)


def test_complex_model():
Expand All @@ -134,6 +142,7 @@ def test_complex_model():
instance = latent_samples.model.instance_from_prior_medians()

lens = instance.lens

assert lens.mass == 1.0
assert lens.brightness == 2.0

Expand Down
Loading
Loading