Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions autofit/aggregator/summary/aggregate_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, aggregator: Aggregator):
@staticmethod
def _hdus(
result: SearchOutput,
*hdus: Enum,
hdus: List[Enum],
) -> List[fits.ImageHDU]:
"""
Extract the HDUs from a given fits for a given search.
Expand Down Expand Up @@ -79,7 +79,7 @@ def _hdus(
)
return row

def extract_fits(self, *hdus: Enum) -> List[fits.HDUList]:
def extract_fits(self, hdus: List[Enum]) -> List[fits.HDUList]:
"""
Extract the HDUs from the fits files for every search in the aggregator.

Expand All @@ -96,15 +96,15 @@ def extract_fits(self, *hdus: Enum) -> List[fits.HDUList]:
"""
output = [fits.PrimaryHDU()]
for result in self.aggregator:
output.extend(self._hdus(result, *hdus))
output.extend(self._hdus(result, hdus))

return fits.HDUList(output)

def output_to_folder(
self,
folder: Path,
*hdus: Enum,
name: Union[str, List[str]],
hdus: List[Enum],
):
"""
Output the fits files for every search in the aggregator to a folder.
Expand Down Expand Up @@ -133,7 +133,7 @@ def output_to_folder(
[fits.PrimaryHDU()]
+ self._hdus(
result,
*hdus,
hdus,
)
)
with open(folder / f"{output_name}.fits", "wb") as file:
Expand Down
15 changes: 9 additions & 6 deletions autofit/aggregator/summary/aggregate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(

def extract_image(
self,
*subplots: Union[Enum, List[Image.Image], Callable],
subplots: List[Union[Enum, List[Image.Image], Callable]],
subplot_width: Optional[int] = sys.maxsize,
) -> Image.Image:
"""
Expand Down Expand Up @@ -154,7 +154,7 @@ def extract_image(
self._matrix_for_result(
i,
result,
*subplots,
subplots,
subplot_width=subplot_width,
)
)
Expand All @@ -164,9 +164,9 @@ def extract_image(
def output_to_folder(
self,
folder: Path,
*subplots: Union[SubplotFit, List[Image.Image], Callable],
subplot_width: Optional[int] = sys.maxsize,
name: Union[str, List[str]],
subplots: List[Union[SubplotFit, List[Image.Image], Callable]],
subplot_width: Optional[int] = sys.maxsize,
):
"""
Output one subplot image for each fit in the aggregator.
Expand All @@ -190,14 +190,17 @@ def output_to_folder(
The attribute of each fit to use as the name of the output file.
OR a list of names, one for each fit.
"""
if len(subplots) == 0:
raise ValueError("At least one subplot must be provided.")

folder.mkdir(exist_ok=True, parents=True)

for i, result in enumerate(self._aggregator):
image = self._matrix_to_image(
self._matrix_for_result(
i,
result,
*subplots,
subplots,
subplot_width=subplot_width,
)
)
Expand All @@ -213,7 +216,7 @@ def output_to_folder(
def _matrix_for_result(
i: int,
result: SearchOutput,
*subplots: Union[SubplotFit, List[Image.Image], Callable],
subplots: List[Union[SubplotFit, List[Image.Image], Callable]],
subplot_width: int = sys.maxsize,
) -> List[List[Image.Image]]:
"""
Expand Down
22 changes: 14 additions & 8 deletions test_autofit/aggregator/summary_files/test_aggregate_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ def make_summary(aggregator):

def test_aggregate(summary):
result = summary.extract_fits(
af.FitFITS.ModelImage,
af.FitFITS.ResidualMap,
[
af.FitFITS.ModelImage,
af.FitFITS.ResidualMap,
],
)
assert len(result) == 5

Expand All @@ -21,19 +23,23 @@ def test_output_to_file(summary, output_directory):
folder = output_directory / "fits"
summary.output_to_folder(
folder,
af.FitFITS.ModelImage,
af.FitFITS.ResidualMap,
name="id",
hdus=[
af.FitFITS.ModelImage,
af.FitFITS.ResidualMap,
],
)
assert len((list(folder.glob("*")))) == 2
assert list(folder.glob("*"))


def test_list_of_names(summary, output_directory):
summary.output_to_folder(
output_directory,
af.FitFITS.ModelImage,
af.FitFITS.ResidualMap,
name=["one", "two"],
["one", "two"],
[
af.FitFITS.ModelImage,
af.FitFITS.ResidualMap,
],
)
assert [path.name for path in Path(output_directory).glob("*.fits")] == [
"one.fits",
Expand Down
94 changes: 59 additions & 35 deletions test_autofit/aggregator/summary_files/test_aggregate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,53 @@ def aggregate(aggregator):

def test(aggregate):
result = aggregate.extract_image(
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
],
)
assert result.size == (122, 120)
assert result == aggregate.extract_image(
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
],
)


def test_different_plots(aggregate):
assert aggregate.extract_image(
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
],
) != aggregate.extract_image(
SubplotFit.SourcePlaneZoomed,
SubplotFit.Data,
[
SubplotFit.SourcePlaneZoomed,
SubplotFit.Data,
],
)


def test_longer(aggregate):
result = aggregate.extract_image(
SubplotFit.NormalizedResidualMap,
SubplotFit.SourcePlaneNoZoom,
SubplotFit.SourceModelImage,
[
SubplotFit.NormalizedResidualMap,
SubplotFit.SourcePlaneNoZoom,
SubplotFit.SourceModelImage,
],
)

assert result.size == (183, 120)


def test_subplot_width(aggregate):
result = aggregate.extract_image(
SubplotFit.NormalizedResidualMap,
SubplotFit.SourcePlaneNoZoom,
SubplotFit.SourceModelImage,
[
SubplotFit.NormalizedResidualMap,
SubplotFit.SourcePlaneNoZoom,
SubplotFit.SourceModelImage,
],
subplot_width=2,
)

Expand All @@ -60,21 +72,25 @@ def test_subplot_width(aggregate):
def test_output_to_folder(aggregate, output_directory):
aggregate.output_to_folder(
output_directory,
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
name="name",
"id",
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
],
)
assert list(Path(output_directory).glob("*.png"))


def test_list_of_names(aggregate, output_directory):
aggregate.output_to_folder(
output_directory,
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
name=["one", "two"],
["one", "two"],
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
],
)
assert [path.name for path in Path(output_directory).glob("*.png")] == [
"two.png",
Expand All @@ -89,10 +105,12 @@ def test_output_to_folder_name(
):
aggregate.output_to_folder(
output_directory,
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
name="id",
"id",
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
],
)

id_ = next(iter(aggregator)).id
Expand All @@ -107,10 +125,12 @@ def test_custom_images(
images = [image for _ in aggregator]

result = aggregate.extract_image(
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
images,
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
images,
]
)

assert result.size == (193, 120)
Expand All @@ -121,10 +141,12 @@ def make_image(output):
return Image.new("RGB", (10, 10), "white")

result = aggregate.extract_image(
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
make_image,
[
SubplotFit.Data,
SubplotFit.SourcePlaneZoomed,
SubplotFit.SourceModelImage,
make_image,
]
)

assert result.size == (193, 120)
Expand All @@ -141,7 +163,9 @@ class SubplotFit(Enum):
Data = (0, 0)

result = aggregate.extract_image(
SubplotFit.Data,
[
SubplotFit.Data,
]
)
assert result.size == (61, 120)

Expand Down
2 changes: 1 addition & 1 deletion test_autofit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def remove_output():
return remove_output


@pytest.fixture(autouse=True, scope="session")
@pytest.fixture(autouse=True)
def do_remove_output(output_directory, remove_output):
yield
remove_output()
Expand Down
Loading