Skip to content
Merged

fixes #1116

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions autofit/aggregator/aggregate_images.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import re
import sys
from enum import Enum
from typing import Optional, List, Union, Callable, Type
from pathlib import Path

Expand All @@ -9,6 +7,22 @@
from autofit.aggregator.search_output import SearchOutput
from autofit.aggregator.aggregator import Aggregator

import re
from enum import Enum


def subplot_filename(subplot: Enum) -> str:
subplot_type = subplot.__class__
return (
re.sub(
r"([A-Z])",
r"_\1",
subplot_type.__name__,
)
.lower()
.lstrip("_")
)


class SubplotFit(Enum):
"""
Expand Down Expand Up @@ -100,6 +114,9 @@ def __init__(
aggregator
The aggregator containing the fit results.
"""
if len(aggregator) == 0:
raise ValueError("The aggregator is empty.")

self._aggregator = aggregator
self._source_images = None

Expand Down Expand Up @@ -172,7 +189,7 @@ def output_to_folder(
name
The attribute of each fit to use as the name of the output file.
"""
folder.mkdir(exist_ok=True)
folder.mkdir(exist_ok=True, parents=True)

for i, result in enumerate(self._aggregator):
image = self._matrix_to_image(
Expand Down Expand Up @@ -231,30 +248,30 @@ class name but using snake_case.
The image for the subplot.
"""
subplot_type = subplot_.__class__
name = (
re.sub(
r"([A-Z])",
r"_\1",
subplot_type.__name__,
)
.lower()
.lstrip("_")
)

if subplot_type not in _images:
_images[subplot_type] = SubplotFitImage(result.image(name))
_images[subplot_type] = SubplotFitImage(
result.image(
subplot_filename(subplot_),
)
)
return _images[subplot_type]

matrix = []
row = []
for subplot in subplots:
if isinstance(subplot, SubplotFit):
if isinstance(subplot, Enum):
row.append(
get_image(subplot).image_at_coordinates(
*subplot.value,
)
)
elif isinstance(subplot, list):
if not isinstance(subplot[i], Image.Image):
raise TypeError(
"The subplots must be of type Subplot or a list of "
"images or a function that takes a SearchOutput as an "
"argument."
)
row.append(subplot[i])
else:
try:
Expand Down
26 changes: 26 additions & 0 deletions test_autofit/aggregator/test_aggregate_images.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from enum import Enum

import pytest
from pathlib import Path

Expand Down Expand Up @@ -117,3 +119,27 @@ def make_image(output):
)

assert result.size == (193, 120)


def test_custom_subplot_fit(aggregate):
class SubplotFit(Enum):
"""
The subplots that can be extracted from the subplot_fit image.

The values correspond to the position of the subplot in the 4x3 grid.
"""

Data = (0, 0)

result = aggregate.extract_image(
SubplotFit.Data,
)
assert result.size == (61, 120)


def test_bad_aggregator():
directory = Path(__file__).parent / "aggregate_summaries"
aggregator = Aggregator.from_directory(directory)

with pytest.raises(ValueError):
AggregateImages(aggregator)
Loading