Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion autofit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .aggregator.summary.aggregate_csv import AggregateCSV
from .aggregator.summary.aggregate_csv import ValueType
from .aggregator.summary.aggregate_images import AggregateImages
from .aggregator.summary.aggregate_fits import AggregateFITS, FITSFit
from .aggregator.summary.aggregate_fits import AggregateFITS
from .database.aggregator import Query
from autofit.aggregator.fit_interface import Fit
from .aggregator.search_output import SearchOutput
Expand Down
26 changes: 11 additions & 15 deletions autofit/aggregator/summary/aggregate_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,6 @@ def subplot_filename(subplot: Enum) -> str:
.lstrip("_")
)


class FITSFit(Enum):
"""
The HDUs that can be extracted from the fit.fits file.
"""

ModelData = "MODEL_IMAGE"
ResidualMap = "RESIDUAL_MAP"
NormalizedResidualMap = "NORMALIZED_RESIDUAL_MAP"
ChiSquaredMap = "CHI_SQUARED_MAP"


class AggregateFITS:
def __init__(self, aggregator: Union[Aggregator, List[SearchOutput]]):
"""
Expand All @@ -51,6 +39,7 @@ def __init__(self, aggregator: Union[Aggregator, List[SearchOutput]]):
def _hdus(
result: SearchOutput,
hdus: List[Enum],
extname_prefix = None
) -> "List[fits.ImageHDU]":
"""
Extract the HDUs from a given fits for a given search.
Expand All @@ -72,6 +61,10 @@ def _hdus(
for hdu in hdus:
source = result.value(subplot_filename(hdu))
source_hdu = source[source.index_of(hdu.value)]

if extname_prefix is not None:
source_hdu.header["EXTNAME"] = f"{extname_prefix.upper()}_{source_hdu.header['EXTNAME']}"

row.append(
fits.ImageHDU(
data=source_hdu.data,
Expand All @@ -80,7 +73,7 @@ def _hdus(
)
return row

def extract_fits(self, hdus: List[Enum]) -> "fits.HDUList":
def extract_fits(self, hdus: List[Enum], extname_prefix_list = None) -> "fits.HDUList":
"""
Extract the HDUs from the fits files for every search in the aggregator.

Expand All @@ -98,8 +91,11 @@ def extract_fits(self, hdus: List[Enum]) -> "fits.HDUList":
from astropy.io import fits

output = [fits.PrimaryHDU()]
for result in self.aggregator:
output.extend(self._hdus(result, hdus))
for i, result in enumerate(self.aggregator):

extname_prefix = extname_prefix_list[i] if extname_prefix_list is not None else None

output.extend(self._hdus(result, hdus, extname_prefix))

return fits.HDUList(output)

Expand Down
8 changes: 8 additions & 0 deletions autofit/aggregator/summary/aggregate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def extract_image(
self,
subplots: List[Union[Enum, List[Image.Image], Callable]],
subplot_width: Optional[int] = sys.maxsize,
transpose: bool = False,
) -> Image.Image:
"""
Extract the images at the specified subplots and combine them into
Expand All @@ -143,6 +144,9 @@ def extract_image(
the number of subplots.
If this is less than the number of subplots then it causes the
images to wrap.
transpose
If True the output image is transposed before being returned, else it
is returned as is.

Returns
-------
Expand All @@ -159,6 +163,10 @@ def extract_image(
)
)

if transpose:

matrix = [list(row) for row in zip(*matrix)]

return self._matrix_to_image(matrix)

def output_to_folder(
Expand Down
2 changes: 1 addition & 1 deletion autofit/non_linear/search/mle/pyswarms/search/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _fit(self, model: AbstractPriorModel, analysis):

if iterations > 0:
search_internal.optimize(
objective_func=fitness.call_wrap, iters=iterations
objective_func=fitness.call_wrap, iters=int(iterations)
)

total_iterations += iterations
Expand Down
24 changes: 18 additions & 6 deletions test_autofit/aggregator/summary_files/test_aggregate_fits.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
from enum import Enum
import pytest

import autofit as af
from pathlib import Path


class FITSFit(Enum):
"""
The HDUs that can be extracted from the fit.fits file.
"""

ModelData = "MODEL_IMAGE"
ResidualMap = "RESIDUAL_MAP"
NormalizedResidualMap = "NORMALIZED_RESIDUAL_MAP"
ChiSquaredMap = "CHI_SQUARED_MAP"


@pytest.fixture(name="summary")
def make_summary(aggregator):
return af.AggregateFITS(aggregator)
Expand All @@ -12,8 +24,8 @@ def make_summary(aggregator):
def test_aggregate(summary):
result = summary.extract_fits(
[
af.FITSFit.ModelData,
af.FITSFit.ResidualMap,
FITSFit.ModelData,
FITSFit.ResidualMap,
],
)
assert len(result) == 5
Expand All @@ -25,8 +37,8 @@ def test_output_to_file(summary, output_directory):
folder,
name="id",
hdus=[
af.FITSFit.ModelData,
af.FITSFit.ResidualMap,
FITSFit.ModelData,
FITSFit.ResidualMap,
],
)
assert list(folder.glob("*"))
Expand All @@ -37,8 +49,8 @@ def test_list_of_names(summary, output_directory):
output_directory,
["one", "two"],
[
af.FITSFit.ModelData,
af.FITSFit.ResidualMap,
FITSFit.ModelData,
FITSFit.ResidualMap,
],
)
assert set([path.name for path in Path(output_directory).glob("*.fits")]) == set([
Expand Down
Loading