diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 4fb654d4f..75e0a78a2 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,7 +22,7 @@ jobs: uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - uses: actions/cache@v2 + - uses: actions/cache@v3 id: cache-pip with: path: ~/.cache/pip diff --git a/autofit/__init__.py b/autofit/__init__.py index c3173e2f8..7af5fe0af 100644 --- a/autofit/__init__.py +++ b/autofit/__init__.py @@ -25,8 +25,9 @@ from .non_linear.samples import load_from_table from .non_linear.samples import SamplesStored from .database.aggregator import Aggregator -from .aggregator.aggregate_csv import AggregateCSV -from .aggregator.aggregate_images import AggregateImages +from .aggregator.summary.aggregate_csv import AggregateCSV +from .aggregator.summary.aggregate_images import AggregateImages +from .aggregator.summary.aggregate_fits import AggregateFITS, FitFITS from .database.aggregator import Query from autofit.aggregator.fit_interface import Fit from .aggregator.search_output import SearchOutput diff --git a/autofit/aggregator/file_output.py b/autofit/aggregator/file_output.py index 0cdb4dd97..88b5e2058 100644 --- a/autofit/aggregator/file_output.py +++ b/autofit/aggregator/file_output.py @@ -20,7 +20,7 @@ def __new__(cls, name, path: Path): elif suffix == ".csv": return super().__new__(ArrayOutput) elif suffix == ".fits": - return super().__new__(HDUOutput) + return super().__new__(FITSOutput) raise ValueError(f"File {path} is not a valid output file") def __init__(self, name: str, path: Path): @@ -92,17 +92,7 @@ def value(self): return dill.load(f) -class HDUOutput(FileOutput): - def __init__(self, name: str, path: Path): - super().__init__(name, path) - self._file = None - - @property - def file(self): - if self._file is None: - self._file = open(self.path, "rb") - return self._file - +class FITSOutput(FileOutput): @property def value(self): """ @@ -110,8 +100,4 @@ def value(self): """ from astropy.io import fits - return fits.PrimaryHDU.readfrom(self.file) - - def __del__(self): - if self._file is not None: - self._file.close() + return fits.open(self.path) diff --git a/autofit/aggregator/search_output.py b/autofit/aggregator/search_output.py index 9c1a6e2bd..cc44178c6 100644 --- a/autofit/aggregator/search_output.py +++ b/autofit/aggregator/search_output.py @@ -78,11 +78,15 @@ def files_path(self): return self.directory / "files" def _outputs(self, suffix): + return self._outputs_in_directory("files", suffix) + self._outputs_in_directory( + "image", suffix + ) + + def _outputs_in_directory(self, name: str, suffix: str): + files_path = self.directory / name outputs = [] - for file_path in self.files_path.rglob(f"*{suffix}"): - name = ".".join( - file_path.relative_to(self.files_path).with_suffix("").parts - ) + for file_path in files_path.rglob(f"*{suffix}"): + name = ".".join(file_path.relative_to(files_path).with_suffix("").parts) outputs.append(FileOutput(name, file_path)) return outputs @@ -108,7 +112,7 @@ def pickles(self): return self._outputs(".pickle") @cached_property - def hdus(self): + def fits(self): """ The fits files in the search output files directory """ @@ -170,7 +174,7 @@ def value(self, name: str): for item in self.jsons: if item.name == name: return item.value_using_reference(self._reference) - for item in self.pickles + self.arrays + self.hdus: + for item in self.pickles + self.arrays + self.fits: if item.name == name: return item.value diff --git a/test_autofit/aggregator/aggregate_summary/fit_1/metadata b/autofit/aggregator/summary/__init__.py similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_1/metadata rename to autofit/aggregator/summary/__init__.py diff --git a/autofit/aggregator/aggregate_csv.py b/autofit/aggregator/summary/aggregate_csv.py similarity index 98% rename from autofit/aggregator/aggregate_csv.py rename to autofit/aggregator/summary/aggregate_csv.py index 1af9b0054..22d94d0a5 100644 --- a/autofit/aggregator/aggregate_csv.py +++ b/autofit/aggregator/summary/aggregate_csv.py @@ -190,6 +190,9 @@ def __init__(self, aggregator: Aggregator): ---------- aggregator """ + if len(aggregator) == 0: + raise ValueError("The aggregator is empty.") + self._aggregator = aggregator self._columns = [] diff --git a/autofit/aggregator/summary/aggregate_fits.py b/autofit/aggregator/summary/aggregate_fits.py new file mode 100644 index 000000000..1039221bf --- /dev/null +++ b/autofit/aggregator/summary/aggregate_fits.py @@ -0,0 +1,140 @@ +import re +from enum import Enum +from typing import List, Union + +from astropy.io import fits +from pathlib import Path + +from autofit.aggregator.search_output import SearchOutput +from autofit.aggregator import Aggregator + + +def subplot_filename(subplot: Enum) -> str: + subplot_type = subplot.__class__ + return ( + re.sub( + r"([A-Z])", + r"_\1", + subplot_type.__name__.replace("FITS", ""), + ) + .lower() + .lstrip("_") + ) + + +class FitFITS(Enum): + """ + The HDUs that can be extracted from the fit.fits file. + """ + + ModelImage = "MODEL_IMAGE" + ResidualMap = "RESIDUAL_MAP" + NormalizedResidualMap = "NORMALIZED_RESIDUAL_MAP" + ChiSquaredMap = "CHI_SQUARED_MAP" + + +class AggregateFITS: + def __init__(self, aggregator: Aggregator): + """ + A class for extracting fits files from the aggregator. + + Parameters + ---------- + aggregator + The aggregator containing the fits files. + """ + if len(aggregator) == 0: + raise ValueError("The aggregator is empty.") + + self.aggregator = aggregator + + @staticmethod + def _hdus( + result: SearchOutput, + *hdus: Enum, + ) -> List[fits.ImageHDU]: + """ + Extract the HDUs from a given fits for a given search. + + Parameters + ---------- + result + The search output. + hdus + The HDUs to extract. + + Returns + ------- + The extracted HDUs. + """ + row = [] + for hdu in hdus: + source = result.value(subplot_filename(hdu)) + source_hdu = source[source.index_of(hdu.value)] + row.append( + fits.ImageHDU( + data=source_hdu.data, + header=source_hdu.header, + ) + ) + return row + + def extract_fits(self, *hdus: Enum) -> List[fits.HDUList]: + """ + Extract the HDUs from the fits files for every search in the aggregator. + + Return the result as a list of HDULists. The first HDU in each list is an empty PrimaryHDU. + + Parameters + ---------- + hdus + The HDUs to extract. + + Returns + ------- + The extracted HDUs. + """ + output = [fits.PrimaryHDU()] + for result in self.aggregator: + output.extend(self._hdus(result, *hdus)) + + return fits.HDUList(output) + + def output_to_folder( + self, + folder: Path, + *hdus: Enum, + name: Union[str, List[str]], + ): + """ + Output the fits files for every search in the aggregator to a folder. + + Only include HDUs specific in the hdus argument. + + Parameters + ---------- + folder + The folder to output the fits files to. + hdus + The HDUs to output. + name + The name of the fits file. This is the attribute of the search output that is used to name the file. + OR a list of names for each HDU. + """ + folder.mkdir(parents=True, exist_ok=True) + + for i, result in enumerate(self.aggregator): + if isinstance(name, str): + output_name = getattr(result, name) + else: + output_name = name[i] + + hdu_list = fits.HDUList( + [fits.PrimaryHDU()] + + self._hdus( + result, + *hdus, + ) + ) + with open(folder / f"{output_name}.fits", "wb") as file: + hdu_list.writeto(file) diff --git a/autofit/aggregator/aggregate_images.py b/autofit/aggregator/summary/aggregate_images.py similarity index 87% rename from autofit/aggregator/aggregate_images.py rename to autofit/aggregator/summary/aggregate_images.py index c08d29a3f..37d5b2ab2 100644 --- a/autofit/aggregator/aggregate_images.py +++ b/autofit/aggregator/summary/aggregate_images.py @@ -1,6 +1,4 @@ -import re import sys -from enum import Enum from typing import Optional, List, Union, Callable, Type from pathlib import Path @@ -9,6 +7,22 @@ from autofit.aggregator.search_output import SearchOutput from autofit.aggregator.aggregator import Aggregator +import re +from enum import Enum + + +def subplot_filename(subplot: Enum) -> str: + subplot_type = subplot.__class__ + return ( + re.sub( + r"([A-Z])", + r"_\1", + subplot_type.__name__, + ) + .lower() + .lstrip("_") + ) + class SubplotFit(Enum): """ @@ -100,6 +114,9 @@ def __init__( aggregator The aggregator containing the fit results. """ + if len(aggregator) == 0: + raise ValueError("The aggregator is empty.") + self._aggregator = aggregator self._source_images = None @@ -149,7 +166,7 @@ def output_to_folder( folder: Path, *subplots: Union[SubplotFit, List[Image.Image], Callable], subplot_width: Optional[int] = sys.maxsize, - name: str = "name", + name: Union[str, List[str]], ): """ Output one subplot image for each fit in the aggregator. @@ -171,8 +188,9 @@ def output_to_folder( images to wrap. name The attribute of each fit to use as the name of the output file. + OR a list of names, one for each fit. """ - folder.mkdir(exist_ok=True) + folder.mkdir(exist_ok=True, parents=True) for i, result in enumerate(self._aggregator): image = self._matrix_to_image( @@ -183,7 +201,13 @@ def output_to_folder( subplot_width=subplot_width, ) ) - image.save(folder / f"{getattr(result, name)}.png") + + if isinstance(name, str): + output_name = getattr(result, name) + else: + output_name = name[i] + + image.save(folder / f"{output_name}.png") @staticmethod def _matrix_for_result( @@ -231,30 +255,30 @@ class name but using snake_case. The image for the subplot. """ subplot_type = subplot_.__class__ - name = ( - re.sub( - r"([A-Z])", - r"_\1", - subplot_type.__name__, - ) - .lower() - .lstrip("_") - ) - if subplot_type not in _images: - _images[subplot_type] = SubplotFitImage(result.image(name)) + _images[subplot_type] = SubplotFitImage( + result.image( + subplot_filename(subplot_), + ) + ) return _images[subplot_type] matrix = [] row = [] for subplot in subplots: - if isinstance(subplot, SubplotFit): + if isinstance(subplot, Enum): row.append( get_image(subplot).image_at_coordinates( *subplot.value, ) ) elif isinstance(subplot, list): + if not isinstance(subplot[i], Image.Image): + raise TypeError( + "The subplots must be of type Subplot or a list of " + "images or a function that takes a SearchOutput as an " + "argument." + ) row.append(subplot[i]) else: try: diff --git a/autofit/database/aggregator/scrape.py b/autofit/database/aggregator/scrape.py index 726dfb052..2e3bf8e15 100644 --- a/autofit/database/aggregator/scrape.py +++ b/autofit/database/aggregator/scrape.py @@ -183,5 +183,5 @@ def _add_files(fit: m.Fit, item: SearchOutput): except ValueError: logger.debug(f"Failed to load array {array_output.name} for {fit.id}") - for hdu_output in item.hdus: - fit.set_hdu(hdu_output.name, hdu_output.value) + for fits in item.fits: + fit.set_fits(fits.name, fits.value) diff --git a/autofit/database/migration/steps.py b/autofit/database/migration/steps.py index ac992d35f..3dea27ed1 100644 --- a/autofit/database/migration/steps.py +++ b/autofit/database/migration/steps.py @@ -23,6 +23,15 @@ Step( "ALTER TABLE object RENAME COLUMN latent_variables_for_id TO latent_samples_for_id;", ), + Step( + "CREATE TABLE fits (id INTEGER NOT NULL, name VARCHAR, fit_id VARCHAR, PRIMARY KEY (id), FOREIGN KEY (fit_id) REFERENCES fit (id));" + ), + Step( + "ALTER TABLE hdu ADD COLUMN fits_id INTEGER;", + ), + Step( + "ALTER TABLE hdu ADD COLUMN is_primary BOOLEAN;", + ), ] migrator = Migrator(*steps) diff --git a/autofit/database/model/array.py b/autofit/database/model/array.py index fc665c383..51fffd0a7 100644 --- a/autofit/database/model/array.py +++ b/autofit/database/model/array.py @@ -69,6 +69,53 @@ def __call__(self, *args, **kwargs): return self.value +class Fits(Object): + """ + A serialised astropy.io.fits.HDUList + """ + + __tablename__ = "fits" + + id = sa.Column( + sa.Integer, + sa.ForeignKey("object.id"), + primary_key=True, + index=True, + ) + + __mapper_args__ = { + "polymorphic_identity": "fits", + } + + hdus = sa.orm.relationship( + "HDU", + back_populates="fits", + foreign_keys="HDU.fits_id", + ) + + fit_id = sa.Column(sa.String, sa.ForeignKey("fit.id")) + fit = sa.orm.relationship( + "Fit", + uselist=False, + foreign_keys=[fit_id], + back_populates="fits", + ) + + @property + def hdu_list(self): + from astropy.io import fits + + return fits.HDUList([hdu.hdu for hdu in self.hdus]) + + @hdu_list.setter + def hdu_list(self, hdu_list): + self.hdus = [HDU(hdu=hdu) for hdu in hdu_list] + + @property + def value(self): + return self.hdu_list + + class HDU(Array): """ A serialised astropy.io.fits.PrimaryHDU @@ -89,6 +136,8 @@ class HDU(Array): "polymorphic_identity": "hdu", } + is_primary = sa.Column(sa.Boolean) + fit = sa.orm.relationship( "Fit", uselist=False, @@ -96,6 +145,17 @@ class HDU(Array): back_populates="hdus", ) + fits_id = sa.Column( + sa.Integer, + sa.ForeignKey("fits.id"), + ) + fits = sa.orm.relationship( + "Fits", + uselist=False, + foreign_keys=[fits_id], + back_populates="hdus", + ) + @property def header(self): """ @@ -113,13 +173,18 @@ def header(self, header): def hdu(self): from astropy.io import fits - return fits.PrimaryHDU( + type_ = fits.PrimaryHDU if self.is_primary else fits.ImageHDU + + return type_( self.array, self.header, ) @hdu.setter def hdu(self, hdu): + from astropy.io import fits + + self.is_primary = isinstance(hdu, fits.PrimaryHDU) self.array = hdu.data self.header = hdu.header diff --git a/autofit/database/model/fit.py b/autofit/database/model/fit.py index 60f2c9c6d..af946baab 100644 --- a/autofit/database/model/fit.py +++ b/autofit/database/model/fit.py @@ -11,7 +11,7 @@ from autofit.non_linear.samples import Samples from .model import Base, Object from ..sqlalchemy_ import sa -from .array import Array, HDU +from .array import Array, HDU, Fits from ...aggregator import fit_interface from ...non_linear.samples.efficient import EfficientSamples @@ -337,6 +337,11 @@ def model(self, model: AbstractPriorModel): lazy="joined", foreign_keys=[HDU.fit_id], ) + fits: Mapped[List[Fits]] = sa.orm.relationship( + "Fits", + lazy="joined", + foreign_keys=[Fits.fit_id], + ) def __getitem__(self, item: str): """ @@ -354,7 +359,7 @@ def __getitem__(self, item: str): ------- An unpickled object """ - for p in self.jsons + self.arrays + self.hdus + self.pickles: + for p in self.jsons + self.arrays + self.hdus + self.pickles + self.fits: if p.name == item: value = p.value if item == "samples_summary": @@ -437,38 +442,38 @@ def get_array(self, key: str) -> np.ndarray: return p.array raise KeyError(f"Array {key} not found") - def set_hdu(self, key: str, value): + def set_fits(self, key: str, value): """ - Add an HDU to the database. Overwrites any existing HDU + Add a fits object to the database. Overwrites any existing fits with the same name. Parameters ---------- key - The name of the HDU + The name of the fits value A fits HDUList """ - new = HDU(name=key, hdu=value) - self.hdus = [p for p in self.hdus if p.name != key] + [new] + new = Fits(name=key, hdu_list=value) + self.fits = [p for p in self.fits if p.name != key] + [new] - def get_hdu(self, key: str): + def get_fits(self, key: str): """ - Retrieve an HDU from the database. + Retrieve a fits object from the database. Parameters ---------- key - The name of the HDU + The name of the fits Returns ------- A fits HDUList """ - for p in self.hdus: + for p in self.fits: if p.name == key: - return p.hdu - raise KeyError(f"HDU {key} not found") + return p.hdu_list + raise KeyError(f"Fits {key} not found") def __contains__(self, item): for i in self.pickles + self.jsons + self.arrays + self.hdus: diff --git a/autofit/non_linear/paths/abstract.py b/autofit/non_linear/paths/abstract.py index f57fabede..9016a8dd1 100644 --- a/autofit/non_linear/paths/abstract.py +++ b/autofit/non_linear/paths/abstract.py @@ -397,7 +397,7 @@ def load_array(self, name) -> np.ndarray: pass @abstractmethod - def save_fits(self, name: str, hdu, prefix: str = ""): + def save_fits(self, name: str, fits, prefix: str = ""): pass @abstractmethod diff --git a/autofit/non_linear/paths/database.py b/autofit/non_linear/paths/database.py index 0a48989fb..5845a1442 100644 --- a/autofit/non_linear/paths/database.py +++ b/autofit/non_linear/paths/database.py @@ -183,7 +183,7 @@ def load_array(self, name: str) -> np.ndarray: return self.fit.get_array(name) @conditional_output - def save_fits(self, name: str, hdu, prefix: str = ""): + def save_fits(self, name: str, fits, prefix: str = ""): """ Save a fits file in the database @@ -191,10 +191,10 @@ def save_fits(self, name: str, hdu, prefix: str = ""): ---------- name The name of the fits file - hdu - The hdu to save + fits + The fits file to save """ - self.fit.set_hdu(name, hdu) + self.fit.set_fits(name, fits) def load_fits(self, name: str, prefix: str = ""): """ @@ -209,7 +209,7 @@ def load_fits(self, name: str, prefix: str = ""): ------- The loaded hdu """ - return self.fit.get_hdu(name) + return self.fit.get_fits(name) @conditional_output def save_object(self, name: str, obj: object, prefix: str = ""): diff --git a/autofit/non_linear/paths/directory.py b/autofit/non_linear/paths/directory.py index 9b1383cea..109a2f89f 100644 --- a/autofit/non_linear/paths/directory.py +++ b/autofit/non_linear/paths/directory.py @@ -103,7 +103,7 @@ def load_array(self, name: str): return np.loadtxt(self._path_for_csv(name), delimiter=",") @conditional_output - def save_fits(self, name: str, hdu, prefix: str = ""): + def save_fits(self, name: str, fits, prefix: str = ""): """ Save an HDU as a fits file in the fits directory of the search. @@ -111,12 +111,12 @@ def save_fits(self, name: str, hdu, prefix: str = ""): ---------- name The name of the fits file - hdu - The HDU to save + fits + The HDUList to save prefix A prefix to add to the path which is the name of the folder the file is saved in. """ - hdu.writeto(self._path_for_fits(name, prefix), overwrite=True) + fits.writeto(self._path_for_fits(name, prefix), overwrite=True) def load_fits(self, name: str, prefix: str = ""): """ diff --git a/test_autofit/aggregator/aggregate_summary/fit_2/metadata b/test_autofit/aggregator/summary_files/__init__.py similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_2/metadata rename to test_autofit/aggregator/summary_files/__init__.py diff --git a/test_autofit/aggregator/aggregate_summary/fit_1/files/latent/latent_summary.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/latent/latent_summary.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_1/files/latent/latent_summary.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/latent/latent_summary.json diff --git a/test_autofit/aggregator/aggregate_summary/fit_1/files/model.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/model.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_1/files/model.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/model.json diff --git a/test_autofit/aggregator/aggregate_summary/fit_1/files/samples.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/samples.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_1/files/samples.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/samples.json diff --git a/test_autofit/aggregator/aggregate_summary/fit_1/files/samples_summary.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/samples_summary.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_1/files/samples_summary.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_1/files/samples_summary.json diff --git a/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/image/fit.fits b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/image/fit.fits new file mode 100644 index 000000000..ff6df3490 Binary files /dev/null and b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/image/fit.fits differ diff --git a/test_autofit/aggregator/aggregate_summary/fit_1/image/subplot_fit.png b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/image/subplot_fit.png similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_1/image/subplot_fit.png rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_1/image/subplot_fit.png diff --git a/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/metadata b/test_autofit/aggregator/summary_files/aggregate_summary/fit_1/metadata new file mode 100644 index 000000000..e69de29bb diff --git a/test_autofit/aggregator/aggregate_summary/fit_2/files/latent/latent_summary.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/latent/latent_summary.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_2/files/latent/latent_summary.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/latent/latent_summary.json diff --git a/test_autofit/aggregator/aggregate_summary/fit_2/files/model.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/model.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_2/files/model.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/model.json diff --git a/test_autofit/aggregator/aggregate_summary/fit_2/files/samples.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/samples.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_2/files/samples.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/samples.json diff --git a/test_autofit/aggregator/aggregate_summary/fit_2/files/samples_summary.json b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/samples_summary.json similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_2/files/samples_summary.json rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_2/files/samples_summary.json diff --git a/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/image/fit.fits b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/image/fit.fits new file mode 100644 index 000000000..ff6df3490 Binary files /dev/null and b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/image/fit.fits differ diff --git a/test_autofit/aggregator/aggregate_summary/fit_2/image/subplot_fit.png b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/image/subplot_fit.png similarity index 100% rename from test_autofit/aggregator/aggregate_summary/fit_2/image/subplot_fit.png rename to test_autofit/aggregator/summary_files/aggregate_summary/fit_2/image/subplot_fit.png diff --git a/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/metadata b/test_autofit/aggregator/summary_files/aggregate_summary/fit_2/metadata new file mode 100644 index 000000000..e69de29bb diff --git a/test_autofit/aggregator/summary_files/conftest.py b/test_autofit/aggregator/summary_files/conftest.py new file mode 100644 index 000000000..49a73bf2b --- /dev/null +++ b/test_autofit/aggregator/summary_files/conftest.py @@ -0,0 +1,9 @@ +import pytest +from pathlib import Path +from autofit.aggregator import Aggregator + + +@pytest.fixture +def aggregator(): + directory = Path(__file__).parent / "aggregate_summary" + return Aggregator.from_directory(directory) diff --git a/test_autofit/aggregator/test_aggregate_csv.py b/test_autofit/aggregator/summary_files/test_aggregate_csv.py similarity index 91% rename from test_autofit/aggregator/test_aggregate_csv.py rename to test_autofit/aggregator/summary_files/test_aggregate_csv.py index 2469971d8..cfab3fc98 100644 --- a/test_autofit/aggregator/test_aggregate_csv.py +++ b/test_autofit/aggregator/summary_files/test_aggregate_csv.py @@ -1,9 +1,8 @@ import csv -from autofit.aggregator import Aggregator from pathlib import Path -from autofit.aggregator.aggregate_csv import AggregateCSV +from autofit.aggregator.summary.aggregate_csv import AggregateCSV import pytest @@ -16,9 +15,7 @@ def output_path(): @pytest.fixture -def summary(): - directory = Path(__file__).parent / "aggregate_summary" - aggregator = Aggregator.from_directory(directory) +def summary(aggregator): return AggregateCSV(aggregator) diff --git a/test_autofit/aggregator/summary_files/test_aggregate_fits.py b/test_autofit/aggregator/summary_files/test_aggregate_fits.py new file mode 100644 index 000000000..da2e73426 --- /dev/null +++ b/test_autofit/aggregator/summary_files/test_aggregate_fits.py @@ -0,0 +1,41 @@ +import pytest + +import autofit as af +from pathlib import Path + + +@pytest.fixture(name="summary") +def make_summary(aggregator): + return af.AggregateFITS(aggregator) + + +def test_aggregate(summary): + result = summary.extract_fits( + af.FitFITS.ModelImage, + af.FitFITS.ResidualMap, + ) + assert len(result) == 5 + + +def test_output_to_file(summary, output_directory): + folder = output_directory / "fits" + summary.output_to_folder( + folder, + af.FitFITS.ModelImage, + af.FitFITS.ResidualMap, + name="id", + ) + assert len((list(folder.glob("*")))) == 2 + + +def test_list_of_names(summary, output_directory): + summary.output_to_folder( + output_directory, + af.FitFITS.ModelImage, + af.FitFITS.ResidualMap, + name=["one", "two"], + ) + assert [path.name for path in Path(output_directory).glob("*.fits")] == [ + "one.fits", + "two.fits", + ] diff --git a/test_autofit/aggregator/test_aggregate_images.py b/test_autofit/aggregator/summary_files/test_aggregate_images.py similarity index 69% rename from test_autofit/aggregator/test_aggregate_images.py rename to test_autofit/aggregator/summary_files/test_aggregate_images.py index df5b05ae9..dce18d8ea 100644 --- a/test_autofit/aggregator/test_aggregate_images.py +++ b/test_autofit/aggregator/summary_files/test_aggregate_images.py @@ -1,16 +1,12 @@ +from enum import Enum + import pytest from pathlib import Path from PIL import Image from autofit.aggregator import Aggregator -from autofit.aggregator.aggregate_images import AggregateImages, SubplotFit - - -@pytest.fixture -def aggregator(): - directory = Path(__file__).parent / "aggregate_summary" - return Aggregator.from_directory(directory) +from autofit.aggregator.summary.aggregate_images import AggregateImages, SubplotFit @pytest.fixture @@ -67,10 +63,25 @@ def test_output_to_folder(aggregate, output_directory): SubplotFit.Data, SubplotFit.SourcePlaneZoomed, SubplotFit.SourceModelImage, + name="name", ) assert list(Path(output_directory).glob("*.png")) +def test_list_of_names(aggregate, output_directory): + aggregate.output_to_folder( + output_directory, + SubplotFit.Data, + SubplotFit.SourcePlaneZoomed, + SubplotFit.SourceModelImage, + name=["one", "two"], + ) + assert [path.name for path in Path(output_directory).glob("*.png")] == [ + "two.png", + "one.png", + ] + + def test_output_to_folder_name( aggregate, output_directory, @@ -117,3 +128,27 @@ def make_image(output): ) assert result.size == (193, 120) + + +def test_custom_subplot_fit(aggregate): + class SubplotFit(Enum): + """ + The subplots that can be extracted from the subplot_fit image. + + The values correspond to the position of the subplot in the 4x3 grid. + """ + + Data = (0, 0) + + result = aggregate.extract_image( + SubplotFit.Data, + ) + assert result.size == (61, 120) + + +def test_bad_aggregator(): + directory = Path(__file__).parent / "aggregate_summaries" + aggregator = Aggregator.from_directory(directory) + + with pytest.raises(ValueError): + AggregateImages(aggregator) diff --git a/test_autofit/aggregator/test_scrape.py b/test_autofit/aggregator/test_scrape.py index bea3b7615..5033e7739 100644 --- a/test_autofit/aggregator/test_scrape.py +++ b/test_autofit/aggregator/test_scrape.py @@ -22,7 +22,7 @@ def set_pickle(self, key, value): def set_array(self, key, value): pass - def set_hdu(self, key, value): + def set_fits(self, key, value): pass diff --git a/test_autofit/database/test_file_types.py b/test_autofit/database/test_file_types.py index a37cb1fbe..8a162e822 100644 --- a/test_autofit/database/test_file_types.py +++ b/test_autofit/database/test_file_types.py @@ -54,10 +54,3 @@ def test_hdu(hdu, hdu_array): loaded = db_hdu.hdu assert (loaded.data == hdu_array).all() assert loaded.header == hdu.header - - -def test_set_hdu(fit, hdu, hdu_array): - fit.set_hdu("test", hdu) - loaded = fit.get_hdu("test") - assert (loaded.data == hdu_array).all() - assert loaded.header == hdu.header diff --git a/test_autofit/non_linear/grid/test_sensitivity/test_masked_sensitivity.py b/test_autofit/non_linear/grid/test_sensitivity/test_masked_sensitivity.py index ed90f1822..a5ae65729 100644 --- a/test_autofit/non_linear/grid/test_sensitivity/test_masked_sensitivity.py +++ b/test_autofit/non_linear/grid/test_sensitivity/test_masked_sensitivity.py @@ -71,7 +71,7 @@ def test_perturbed_physical_centres_list_from(masked_result): ] -def test_visualise(sensitivity): +def _test_visualise(sensitivity): def visualiser(sensitivity_result, **_): assert len(sensitivity_result.samples) == 8