diff --git a/autofit/__init__.py b/autofit/__init__.py index f9682793a..949baad80 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -31,7 +31,7 @@ from .aggregator.summary.aggregate_csv import AggregateCSV from .aggregator.summary.aggregate_csv import ValueType from .aggregator.summary.aggregate_images import AggregateImages -from .aggregator.summary.aggregate_fits import AggregateFITS, FITSFit +from .aggregator.summary.aggregate_fits import AggregateFITS from .database.aggregator import Query from autofit.aggregator.fit_interface import Fit from .aggregator.search_output import SearchOutput diff --git a/autofit/aggregator/summary/aggregate_fits.py b/autofit/aggregator/summary/aggregate_fits.py index 64634ad4c..1f8267a98 100644 --- a/autofit/aggregator/summary/aggregate_fits.py +++ b/autofit/aggregator/summary/aggregate_fits.py @@ -20,18 +20,6 @@ def subplot_filename(subplot: Enum) -> str: .lstrip("_") ) - -class FITSFit(Enum): - """ - The HDUs that can be extracted from the fit.fits file. - """ - - ModelData = "MODEL_IMAGE" - ResidualMap = "RESIDUAL_MAP" - NormalizedResidualMap = "NORMALIZED_RESIDUAL_MAP" - ChiSquaredMap = "CHI_SQUARED_MAP" - - class AggregateFITS: def __init__(self, aggregator: Union[Aggregator, List[SearchOutput]]): """ @@ -51,6 +39,7 @@ def __init__(self, aggregator: Union[Aggregator, List[SearchOutput]]): def _hdus( result: SearchOutput, hdus: List[Enum], + extname_prefix = None ) -> "List[fits.ImageHDU]": """ Extract the HDUs from a given fits for a given search. @@ -72,6 +61,10 @@ def _hdus( for hdu in hdus: source = result.value(subplot_filename(hdu)) source_hdu = source[source.index_of(hdu.value)] + + if extname_prefix is not None: + source_hdu.header["EXTNAME"] = f"{extname_prefix.upper()}_{source_hdu.header['EXTNAME']}" + row.append( fits.ImageHDU( data=source_hdu.data, @@ -80,7 +73,7 @@ def _hdus( ) return row - def extract_fits(self, hdus: List[Enum]) -> "fits.HDUList": + def extract_fits(self, hdus: List[Enum], extname_prefix_list = None) -> "fits.HDUList": """ Extract the HDUs from the fits files for every search in the aggregator. @@ -98,8 +91,11 @@ def extract_fits(self, hdus: List[Enum]) -> "fits.HDUList": from astropy.io import fits output = [fits.PrimaryHDU()] - for result in self.aggregator: - output.extend(self._hdus(result, hdus)) + for i, result in enumerate(self.aggregator): + + extname_prefix = extname_prefix_list[i] if extname_prefix_list is not None else None + + output.extend(self._hdus(result, hdus, extname_prefix)) return fits.HDUList(output) diff --git a/autofit/aggregator/summary/aggregate_images.py b/autofit/aggregator/summary/aggregate_images.py index 9adaadfd4..da9b179f9 100644 --- a/autofit/aggregator/summary/aggregate_images.py +++ b/autofit/aggregator/summary/aggregate_images.py @@ -124,6 +124,7 @@ def extract_image( self, subplots: List[Union[Enum, List[Image.Image], Callable]], subplot_width: Optional[int] = sys.maxsize, + transpose: bool = False, ) -> Image.Image: """ Extract the images at the specified subplots and combine them into @@ -143,6 +144,9 @@ def extract_image( the number of subplots. If this is less than the number of subplots then it causes the images to wrap. + transpose + If True the output image is transposed before being returned, else it + is returned as is. Returns ------- @@ -159,6 +163,10 @@ def extract_image( ) ) + if transpose: + + matrix = [list(row) for row in zip(*matrix)] + return self._matrix_to_image(matrix) def output_to_folder( diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 696cf2322..be0e63403 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -159,8 +159,17 @@ def call(self, parameters): # Get instance from model instance = self.model.instance_from_vector(vector=parameters) - # Evaluate log likelihood (must be side-effect free and exception-free) - log_likelihood = self.analysis.log_likelihood_function(instance=instance) + if self._xp.__name__.startswith("jax"): + + # Evaluate log likelihood (must be side-effect free and exception-free) + log_likelihood = self.analysis.log_likelihood_function(instance=instance) + + else: + + try: + log_likelihood = self.analysis.log_likelihood_function(instance=instance) + except exc.FitException: + return self.resample_figure_of_merit # Penalize NaNs in the log-likelihood log_likelihood = self._xp.where(self._xp.isnan(log_likelihood), self.resample_figure_of_merit, log_likelihood) @@ -288,7 +297,7 @@ def manage_quick_update(self, parameters, log_likelihood): best_parameters = parameters[best_idx] total_updates = log_likelihood.shape[0] - except (AttributeError, IndexError): + except (AttributeError, IndexError, TypeError): best_log_likelihood = log_likelihood best_parameters = parameters diff --git a/autofit/non_linear/search/mle/pyswarms/search/abstract.py b/autofit/non_linear/search/mle/pyswarms/search/abstract.py index 1a2150126..c92b5d044 100644 --- a/autofit/non_linear/search/mle/pyswarms/search/abstract.py +++ b/autofit/non_linear/search/mle/pyswarms/search/abstract.py @@ -212,7 +212,7 @@ def _fit(self, model: AbstractPriorModel, analysis): if iterations > 0: search_internal.optimize( - objective_func=fitness.call_wrap, iters=iterations + objective_func=fitness.call_wrap, iters=int(iterations) ) total_iterations += iterations diff --git a/test_autofit/aggregator/summary_files/test_aggregate_fits.py b/test_autofit/aggregator/summary_files/test_aggregate_fits.py index 7416c6c87..ba4aeb303 100644 --- a/test_autofit/aggregator/summary_files/test_aggregate_fits.py +++ b/test_autofit/aggregator/summary_files/test_aggregate_fits.py @@ -1,9 +1,21 @@ +from enum import Enum import pytest import autofit as af from pathlib import Path +class FITSFit(Enum): + """ + The HDUs that can be extracted from the fit.fits file. + """ + + ModelData = "MODEL_IMAGE" + ResidualMap = "RESIDUAL_MAP" + NormalizedResidualMap = "NORMALIZED_RESIDUAL_MAP" + ChiSquaredMap = "CHI_SQUARED_MAP" + + @pytest.fixture(name="summary") def make_summary(aggregator): return af.AggregateFITS(aggregator) @@ -12,8 +24,8 @@ def make_summary(aggregator): def test_aggregate(summary): result = summary.extract_fits( [ - af.FITSFit.ModelData, - af.FITSFit.ResidualMap, + FITSFit.ModelData, + FITSFit.ResidualMap, ], ) assert len(result) == 5 @@ -25,8 +37,8 @@ def test_output_to_file(summary, output_directory): folder, name="id", hdus=[ - af.FITSFit.ModelData, - af.FITSFit.ResidualMap, + FITSFit.ModelData, + FITSFit.ResidualMap, ], ) assert list(folder.glob("*")) @@ -37,8 +49,8 @@ def test_list_of_names(summary, output_directory): output_directory, ["one", "two"], [ - af.FITSFit.ModelData, - af.FITSFit.ResidualMap, + FITSFit.ModelData, + FITSFit.ResidualMap, ], ) assert set([path.name for path in Path(output_directory).glob("*.fits")]) == set([