From 3aa935a05588912d8ee9c28e52441428b467ba99 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 2 Oct 2025 13:54:58 +0100 Subject: [PATCH 1/6] cast to float to prevent JAX arrays breaking output --- autofit/non_linear/analysis/analysis.py | 19 ++++++++++++------- .../non_linear/search/nest/nautilus/search.py | 2 +- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 7591449b3..ae06fd31c 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -63,21 +63,26 @@ def compute_latent_samples(self, samples: Samples) -> Optional[Samples]: The computed latent samples or None if compute_latent_variables is not implemented. """ try: + latent_samples = [] model = samples.model for sample in samples.sample_list: + + kwargs = self.compute_latent_variables( + sample.instance_for_model(model, ignore_assertions=True) + ) + + # convert all values to Python floats to remove JAX arrays + kwargs = {k: float(v) if hasattr(v, "__array__") or isinstance(v, (np.generic,)) else v + for k, v in kwargs.items()} + latent_samples.append( Sample( log_likelihood=sample.log_likelihood, log_prior=sample.log_prior, weight=sample.weight, - kwargs=self.compute_latent_variables( - sample.instance_for_model( - model, - ignore_assertions=True, - ) - ), - ) + kwargs=kwargs + ) ) return type(samples)( diff --git a/autofit/non_linear/search/nest/nautilus/search.py b/autofit/non_linear/search/nest/nautilus/search.py index 86ebb91c5..128dc5726 100644 --- a/autofit/non_linear/search/nest/nautilus/search.py +++ b/autofit/non_linear/search/nest/nautilus/search.py @@ -250,7 +250,7 @@ def fit_multiprocessing(self, fitness, model, analysis): """ search_internal = self.sampler_cls( prior=PriorVectorized(model=model), - likelihood=fitness.call_numpy_wrapper, + likelihood=fitness.__call__, n_dim=model.prior_count, prior_kwargs={"model": model}, filepath=self.checkpoint_file, From ee9e284d6d8264527a585b55d76f61407051a45e Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 2 Oct 2025 17:58:57 +0100 Subject: [PATCH 2/6] latent variables now compute in JAX --- autofit/config/output.yaml | 2 +- autofit/non_linear/analysis/analysis.py | 63 +++++++++++++++++-------- autofit/non_linear/fitness.py | 12 ++--- 3 files changed, 51 insertions(+), 26 deletions(-) diff --git a/autofit/config/output.yaml b/autofit/config/output.yaml index 62f9d3030..a2be5648f 100644 --- a/autofit/config/output.yaml +++ b/autofit/config/output.yaml @@ -86,7 +86,7 @@ start_point: true # manually after completing a successful model-fit. This will save computational run time by not computing latent # variables during a any model-fit which is unsuccessful. -latent_during_fit: true # Whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files during the fit when it performs on-the-fly output. +latent_during_fit: false # Whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files during the fit when it performs on-the-fly output. latent_after_fit: true # If `latent_during_fit` is False, whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files after the fit is complete. latent_csv: true # Whether to ouptut the `latent.csv` file. latent_results: true # Whether to output the `latent.results` file. diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index ae06fd31c..69604205e 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -1,8 +1,14 @@ import inspect import logging from abc import ABC +import functools +import numpy as np +import jax +import jax.numpy as jnp from typing import Optional, Dict +from autofit.jax_wrapper import use_jax + from autofit.mapper.prior_model.abstract import AbstractPriorModel from autofit.non_linear.paths.abstract import AbstractPaths from autofit.non_linear.samples.summary import SamplesSummary @@ -27,6 +33,8 @@ class Analysis(ABC): Result = Result Visualizer = Visualizer + LATENT_KEYS = [] + def __getattr__(self, item: str): """ If a method starts with 'visualize_' then we assume it is associated with @@ -64,36 +72,50 @@ def compute_latent_samples(self, samples: Samples) -> Optional[Samples]: """ try: + compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model) + + if use_jax: + batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model)) + else: + def batched_compute_latent(x): + return np.array([compute_latent_for_model(xx) for xx in x]) + + parameter_array = np.array(samples.parameter_lists) + batch_size = 50 latent_samples = [] - model = samples.model - for sample in samples.sample_list: - - kwargs = self.compute_latent_variables( - sample.instance_for_model(model, ignore_assertions=True) - ) - - # convert all values to Python floats to remove JAX arrays - kwargs = {k: float(v) if hasattr(v, "__array__") or isinstance(v, (np.generic,)) else v - for k, v in kwargs.items()} - - latent_samples.append( - Sample( - log_likelihood=sample.log_likelihood, - log_prior=sample.log_prior, - weight=sample.weight, - kwargs=kwargs + + # process in batches + for i in range(0, len(parameter_array), batch_size): + + batch = parameter_array[i:i + batch_size] + + # batched JAX call on this chunk + latent_values_batch = batched_compute_latent(batch) + if use_jax: + latent_values_batch = jnp.stack(latent_values_batch, axis=-1) + + for sample, values in zip(samples.sample_list[i:i + batch_size], latent_values_batch): + kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)} + + latent_samples.append( + Sample( + log_likelihood=sample.log_likelihood, + log_prior=sample.log_prior, + weight=sample.weight, + kwargs=kwargs, ) - ) + ) return type(samples)( sample_list=latent_samples, model=simple_model_for_kwargs(latent_samples[0].kwargs), samples_info=samples.samples_info, ) + except NotImplementedError: return None - def compute_latent_variables(self, instance) -> Dict[str, float]: + def compute_latent_variables(self, parameters, model) -> Dict[str, float]: """ Override to compute latent variables from the instance. @@ -242,3 +264,6 @@ def profile_log_likelihood_function(self, paths: AbstractPaths, instance): The maximum likliehood instance of the model so far in the non-linear search. """ pass + + def latent_lh_dict_from(self, **kwargs): + return None \ No newline at end of file diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index f2db96a8f..204427e82 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -187,26 +187,26 @@ def __call__(self, parameters, *kwargs): @cached_property def _vmap(self): start = time.time() - print("JAX: Applying vmap and jit to likelihood function -- may take a few seconds.") + logger.info("JAX: Applying vmap and jit to likelihood function -- may take a few seconds.") func = jax.vmap(jax.jit(self.call)) - print(f"JAX: vmap and jit applied in {time.time() - start} seconds.") + logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.") return func @cached_property def _call(self): start = time.time() - print("JAX: Applying jit to likelihood function -- may take a few seconds.") + logger.info("JAX: Applying jit to likelihood function -- may take a few seconds.") func = jax_wrapper.jit(self.call) - print(f"JAX: jit applied in {time.time() - start} seconds.") + logger.info(f"JAX: jit applied in {time.time() - start} seconds.") return func @cached_property def _grad(self): start = time.time() - print("JAX: Applying grad to likelihood function -- may take a few seconds.") + logger.info("JAX: Applying grad to likelihood function -- may take a few seconds.") func = jax_wrapper.grad(self._call) - print(f"JAX: grad applied in {time.time() - start} seconds.") + logger.info(f"JAX: grad applied in {time.time() - start} seconds.") return func def grad(self, *args, **kwargs): From 16a400d5de4eb785ebe4bb1881cb97eebe450e50 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Thu, 2 Oct 2025 18:29:18 +0100 Subject: [PATCH 3/6] added samples_drawn_randomly_via_pdf_from --- autofit/non_linear/samples/pdf.py | 31 +++++++++++++++++++ .../non_linear/samples/test_samples.py | 30 ++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/autofit/non_linear/samples/pdf.py b/autofit/non_linear/samples/pdf.py index 48f96e463..6024950ac 100644 --- a/autofit/non_linear/samples/pdf.py +++ b/autofit/non_linear/samples/pdf.py @@ -328,6 +328,37 @@ def draw_randomly_via_pdf(self) -> Union[List, ModelInstance]: return self.parameter_lists[sample_index][:] + def samples_drawn_randomly_via_pdf_from(self, total_draws: int = 100) -> "SamplesPDF": + """ + Draw one or more samples randomly from the PDF, weighted by the sample weights. + + Parameters + ---------- + total_draws : int, optional + The number of samples to draw. Defaults to 100. + + Returns + ------- + SamplesPDF + A new SamplesPDF object containing the drawn samples. + """ + # Normalize weights to sum to 1 + weights = np.asarray(self.weight_list, dtype=float) + weights /= weights.sum() + + sample_indices = np.random.choice( + a=len(self.sample_list), + size=total_draws, + replace=True, + p=weights, + ) + + return SamplesPDF( + model=self.model, + sample_list=[self.sample_list[i] for i in sample_indices], + samples_info=self.samples_info, + ) + @to_instance def offset_values_via_input_values( self, input_vector: List diff --git a/test_autofit/non_linear/samples/test_samples.py b/test_autofit/non_linear/samples/test_samples.py index 40c5f45a5..21ac0eb85 100644 --- a/test_autofit/non_linear/samples/test_samples.py +++ b/test_autofit/non_linear/samples/test_samples.py @@ -164,6 +164,36 @@ def test__samples_above_weight_threshold_from(): assert len(samples_above_weight_threshold) == 3 assert samples_above_weight_threshold.sample_list[0].weight == 1.0 +def test__samples_drawn_randomly_via_pdf_from(): + + model = af.Collection(mock_class=af.m.MockClassx4) + + parameters = [ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + [1.1, 2.1, 3.1, 4.1], + ] + + samples_x5 = af.m.MockSamples( + model=model, + sample_list=af.Sample.from_lists( + model=model, + parameter_lists=parameters, + log_likelihood_list=[0.0, 0.0, 0.0, 0.0, 0.0], + log_prior_list=[0.0, 0.0, 0.0, 0.0, 0.0], + weight_list=[0.2, 0.2, 1.0, 1.0, 1.0], + ), + ) + + samples_drawn_randomly_via_pdf = samples_x5.samples_drawn_randomly_via_pdf_from( + total_draws=3 + ) + + assert len(samples_drawn_randomly_via_pdf) == 3 + assert samples_drawn_randomly_via_pdf.sample_list[0].weight == 1.0 + def test__addition_of_samples(samples_x5): samples = samples_x5 + samples_x5 From 048e8620ca6d611ad35acc2f98c2d09c5768c467 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 3 Oct 2025 09:50:35 +0100 Subject: [PATCH 4/6] fix one latent variable test --- autofit/config/output.yaml | 2 + autofit/example/analysis.py | 9 +++-- autofit/non_linear/analysis/analysis.py | 39 ++++++++++++++++--- autofit/non_linear/search/abstract_search.py | 18 ++++++++- .../analysis/test_latent_variables.py | 21 ++++++---- 5 files changed, 71 insertions(+), 18 deletions(-) diff --git a/autofit/config/output.yaml b/autofit/config/output.yaml index a2be5648f..d36a9e1f3 100644 --- a/autofit/config/output.yaml +++ b/autofit/config/output.yaml @@ -88,6 +88,8 @@ start_point: true latent_during_fit: false # Whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files during the fit when it performs on-the-fly output. latent_after_fit: true # If `latent_during_fit` is False, whether to output the `latent.csv`, `latent.results` and `latent_summary.json` files after the fit is complete. +latent_draw_via_pdf : true # Whether to draw latent variable values via the PDF of every sample, which uses fewer samples to estimate latent variable errors. If False, latent variable values are drawn from every sample. +latent_draw_via_pdf_size : 100 # The number of samples drawn to estimate latent variable errors if `latent_draw_via_pdf` is True. latent_csv: true # Whether to ouptut the `latent.csv` file. latent_results: true # Whether to output the `latent.results` file. diff --git a/autofit/example/analysis.py b/autofit/example/analysis.py index cdc086f64..65891e2d5 100644 --- a/autofit/example/analysis.py +++ b/autofit/example/analysis.py @@ -34,9 +34,10 @@ class Analysis(af.Analysis): It has been extended, based on the model that is input into the analysis, to include a property `max_log_likelihood_model_data`, which is the model data of the best-fit model. """ - Result = ResultExample + LATENT_KEYS = ["fwhm"] + def __init__(self, data: np.ndarray, noise_map: np.ndarray): """ In this example the `Analysis` object only contains the data and noise-map. It can be easily extended, @@ -231,9 +232,9 @@ def compute_latent_variables(self, instance) -> Dict[str, float]: """ try: - return {"fwhm": instance.fwhm} + return (instance.fwhm, ) except AttributeError: try: - return {"gaussian.fwhm": instance[0].fwhm} + return (instance[0].fwhm,) except AttributeError: - return {"gaussian.fwhm": instance[0].gaussian.fwhm} \ No newline at end of file + return (instance[0].gaussian.fwhm,) \ No newline at end of file diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 69604205e..a5a14ac41 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -59,16 +59,36 @@ def method(*args, **kwargs): def compute_latent_samples(self, samples: Samples) -> Optional[Samples]: """ - Internal method that manages computation of latent samples from samples. + Compute latent variables from a model instance. + + A latent variable is not itself a free parameter of the model but can be derived from it. + Latent variables may provide physically meaningful quantities that help interpret a model + fit, and their values (with errors) are stored in `latent.csv` in parallel with `samples.csv`. + + This implementation is designed to be compatible with both NumPy and JAX: + + - It is written to be side-effect free, so it can be JIT-compiled with `jax.jit`. + - It can be vectorized over many parameter sets at once using `jax.vmap`, enabling efficient + batched evaluation of latent variables for multiple samples. + - Returned values should be simple JAX/NumPy scalars or arrays (no Python objects), so they + can be stacked into arrays of shape `(n_samples, n_latents)` for batching. + - Any NaNs introduced (e.g. from invalid model states) can be masked or replaced downstream. Parameters ---------- - samples - The samples from the non-linear search. + parameters : array-like + The parameter vector of the model sample. This will typically come from the non-linear search. + Inside this method it is mapped back to a model instance via `model.instance_from_vector`. + model : Model + The model object defining how the parameter vector is mapped to an instance. Passed explicitly + so that this function can be used inside JAX transforms (`vmap`, `jit`) with `functools.partial`. Returns ------- - The computed latent samples or None if compute_latent_variables is not implemented. + tuple of (float or jax.numpy scalar) + A tuple containing the latent variables in a fixed order: + `(intensity_total, magnitude, angle)`. Each entry may be NaN if the corresponding component + of the model is not present. """ try: @@ -91,10 +111,19 @@ def batched_compute_latent(x): # batched JAX call on this chunk latent_values_batch = batched_compute_latent(batch) + if use_jax: - latent_values_batch = jnp.stack(latent_values_batch, axis=-1) + # latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) + mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) + latent_values_batch = latent_values_batch[:, mask] + else: + # latent_values_batch = np.stack(latent_values_batch, axis=-1) # (batch, n_latents) + mask = np.all(np.isfinite(latent_values_batch), axis=0) + latent_values_batch = latent_values_batch[:, mask] for sample, values in zip(samples.sample_list[i:i + batch_size], latent_values_batch): + + kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)} latent_samples.append( diff --git a/autofit/non_linear/search/abstract_search.py b/autofit/non_linear/search/abstract_search.py index 957b9c9ef..702ebc91f 100644 --- a/autofit/non_linear/search/abstract_search.py +++ b/autofit/non_linear/search/abstract_search.py @@ -958,10 +958,26 @@ def perform_update( if (during_analysis and conf.instance["output"]["latent_during_fit"]) or ( not during_analysis and conf.instance["output"]["latent_after_fit"] ): + + if conf.instance["output"]["latent_draw_via_pdf"]: + + total_draws = conf.instance["output"]["latent_draw_via_pdf_size"] + + logger.info(f"Creating latent samples by drawing {total_draws} from the PDF.") + + latent_samples = samples.samples_drawn_randomly_via_pdf_from(total_draws=total_draws) + + else: + + logger.info(f"Creating latent samples using all samples above the samples weight threshold.") + + latent_samples = samples_save + latent_samples = analysis.compute_latent_samples(samples_save) if latent_samples: - self.paths.save_latent_samples(latent_samples) + if not conf.instance["output"]["latent_draw_via_pdf"]: + self.paths.save_latent_samples(latent_samples) self.paths.save_samples_summary( latent_samples.summary(), "latent/latent_summary", diff --git a/test_autofit/analysis/test_latent_variables.py b/test_autofit/analysis/test_latent_variables.py index 8b1cdd889..a1e305fe3 100644 --- a/test_autofit/analysis/test_latent_variables.py +++ b/test_autofit/analysis/test_latent_variables.py @@ -7,11 +7,14 @@ class Analysis(af.Analysis): + + LATENT_KEYS = ["fwhm"] + def log_likelihood_function(self, instance): return 1.0 - def compute_latent_variables(self, instance): - return {"fwhm": instance.fwhm} + def compute_latent_variables(self, instance, model): + return (instance.fwhm,) @with_config( @@ -100,15 +103,14 @@ def test_info(latent_samples): class ComplexAnalysis(af.Analysis): + + LATENT_KEYS = ["lens.mass", "lens.brightness", "source.brightness"] + def log_likelihood_function(self, instance): return 1.0 - def compute_latent_variables(self, instance): - return { - "lens.mass": 1.0, - "lens.brightness": 2.0, - "source.brightness": 3.0, - } + def compute_latent_variables(self, instance, model): + return (1.0, 2.0, 3.0) def test_complex_model(): @@ -134,6 +136,9 @@ def test_complex_model(): instance = latent_samples.model.instance_from_prior_medians() lens = instance.lens + + print(lens) + assert lens.mass == 1.0 assert lens.brightness == 2.0 From d785b4f39b3baa4ce112140f0acc44e30cff9db8 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 3 Oct 2025 09:51:59 +0100 Subject: [PATCH 5/6] fix all latent tests --- test_autofit/analysis/test_latent_variables.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test_autofit/analysis/test_latent_variables.py b/test_autofit/analysis/test_latent_variables.py index a1e305fe3..8fb5281cd 100644 --- a/test_autofit/analysis/test_latent_variables.py +++ b/test_autofit/analysis/test_latent_variables.py @@ -13,7 +13,10 @@ class Analysis(af.Analysis): def log_likelihood_function(self, instance): return 1.0 - def compute_latent_variables(self, instance, model): + def compute_latent_variables(self, parameters, model): + + instance = model.instance_from_vector(vector=parameters) + return (instance.fwhm,) @@ -109,7 +112,10 @@ class ComplexAnalysis(af.Analysis): def log_likelihood_function(self, instance): return 1.0 - def compute_latent_variables(self, instance, model): + def compute_latent_variables(self, parameters, model): + + instance = model.instance_from_vector(vector=parameters) + return (1.0, 2.0, 3.0) @@ -137,8 +143,6 @@ def test_complex_model(): lens = instance.lens - print(lens) - assert lens.mass == 1.0 assert lens.brightness == 2.0 From 33043afd6cd1bd98948179ab2db3f40526b881f1 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 3 Oct 2025 10:08:47 +0100 Subject: [PATCH 6/6] implementation fully tested --- autofit/example/analysis.py | 15 ++++++++++++--- autofit/non_linear/analysis/analysis.py | 6 ++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/autofit/example/analysis.py b/autofit/example/analysis.py index 65891e2d5..ce9798200 100644 --- a/autofit/example/analysis.py +++ b/autofit/example/analysis.py @@ -205,7 +205,7 @@ def make_result( analysis=self, ) - def compute_latent_variables(self, instance) -> Dict[str, float]: + def compute_latent_variables(self, parameters, model) -> Dict[str, float]: """ A latent variable is not a model parameter but can be derived from the model. Its value and errors may be of interest and aid in the interpretation of a model-fit. @@ -219,8 +219,15 @@ def compute_latent_variables(self, instance) -> Dict[str, float]: In the example below, the `latent.csv` file will contain one column with the FWHM of every Gausian model sampled by the non-linear search. - This function is called for every non-linear search sample, where the `instance` passed in corresponds to - each sample. + This function is called at the end of search, following one of two schemes depending on the settings in + `output.yaml`: + + 1) Call for every search sample, which produces a complete `latent/samples.csv` which mirrors the normal + `samples.csv` file but takes a long time to compute. + + 2) Call only for N random draws from the posterior inferred at the end of the search, which only produces a + `latent/latent_summary.json` file with the median and 1 and 3 sigma errors of the latent variables but is + fast to compute. Parameters ---------- @@ -231,6 +238,8 @@ def compute_latent_variables(self, instance) -> Dict[str, float]: ------- """ + instance = model.instance_from_vector(vector=parameters) + try: return (instance.fwhm, ) except AttributeError: diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index a5a14ac41..aa47f0c7c 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -95,7 +95,10 @@ def compute_latent_samples(self, samples: Samples) -> Optional[Samples]: compute_latent_for_model = functools.partial(self.compute_latent_variables, model=samples.model) if use_jax: + start = time.time() + logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.") batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model)) + logger.info(f"JAX: vmap and jit applied in {time.time() - start} seconds.") else: def batched_compute_latent(x): return np.array([compute_latent_for_model(xx) for xx in x]) @@ -113,11 +116,10 @@ def batched_compute_latent(x): latent_values_batch = batched_compute_latent(batch) if use_jax: - # latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) + latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) latent_values_batch = latent_values_batch[:, mask] else: - # latent_values_batch = np.stack(latent_values_batch, axis=-1) # (batch, n_latents) mask = np.all(np.isfinite(latent_values_batch), axis=0) latent_values_batch = latent_values_batch[:, mask]