From b182f21ff6e62a7fb1cbc67c59c3c8815b13a17b Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Mar 2025 10:20:47 +0000 Subject: [PATCH 1/3] fixes --- autofit/aggregator/aggregate_images.py | 44 +++++++++++++++++--------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/autofit/aggregator/aggregate_images.py b/autofit/aggregator/aggregate_images.py index c08d29a3f..faac61cff 100644 --- a/autofit/aggregator/aggregate_images.py +++ b/autofit/aggregator/aggregate_images.py @@ -1,6 +1,4 @@ -import re import sys -from enum import Enum from typing import Optional, List, Union, Callable, Type from pathlib import Path @@ -9,6 +7,22 @@ from autofit.aggregator.search_output import SearchOutput from autofit.aggregator.aggregator import Aggregator +import re +from enum import Enum + + +def subplot_filename(subplot: Enum) -> str: + subplot_type = subplot.__class__ + return ( + re.sub( + r"([A-Z])", + r"_\1", + subplot_type.__name__, + ) + .lower() + .lstrip("_") + ) + class SubplotFit(Enum): """ @@ -172,7 +186,7 @@ def output_to_folder( name The attribute of each fit to use as the name of the output file. """ - folder.mkdir(exist_ok=True) + folder.mkdir(exist_ok=True, parents=True) for i, result in enumerate(self._aggregator): image = self._matrix_to_image( @@ -231,30 +245,30 @@ class name but using snake_case. The image for the subplot. """ subplot_type = subplot_.__class__ - name = ( - re.sub( - r"([A-Z])", - r"_\1", - subplot_type.__name__, - ) - .lower() - .lstrip("_") - ) - if subplot_type not in _images: - _images[subplot_type] = SubplotFitImage(result.image(name)) + _images[subplot_type] = SubplotFitImage( + result.image( + subplot_filename(subplot_), + ) + ) return _images[subplot_type] matrix = [] row = [] for subplot in subplots: - if isinstance(subplot, SubplotFit): + if isinstance(subplot, Enum): row.append( get_image(subplot).image_at_coordinates( *subplot.value, ) ) elif isinstance(subplot, list): + if not isinstance(subplot[i], Image.Image): + raise TypeError( + "The subplots must be of type Subplot or a list of " + "images or a function that takes a SearchOutput as an " + "argument." + ) row.append(subplot[i]) else: try: From 922b18f7aa0acb4ec5e207c988974d386ef7e016 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Mar 2025 10:34:47 +0000 Subject: [PATCH 2/3] test custom enums work --- .../aggregator/test_aggregate_images.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test_autofit/aggregator/test_aggregate_images.py b/test_autofit/aggregator/test_aggregate_images.py index df5b05ae9..33cb37056 100644 --- a/test_autofit/aggregator/test_aggregate_images.py +++ b/test_autofit/aggregator/test_aggregate_images.py @@ -1,3 +1,5 @@ +from enum import Enum + import pytest from pathlib import Path @@ -117,3 +119,19 @@ def make_image(output): ) assert result.size == (193, 120) + + +def test_custom_subplot_fit(aggregate): + class SubplotFit(Enum): + """ + The subplots that can be extracted from the subplot_fit image. + + The values correspond to the position of the subplot in the 4x3 grid. + """ + + Data = (0, 0) + + result = aggregate.extract_image( + SubplotFit.Data, + ) + assert result.size == (61, 120) From a04623cd8432734d407c1c48721bd5d5ea8cb578 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 14 Mar 2025 10:37:10 +0000 Subject: [PATCH 3/3] raise a ValueError for an empty aggregator --- autofit/aggregator/aggregate_images.py | 3 +++ test_autofit/aggregator/test_aggregate_images.py | 8 ++++++++ 2 files changed, 11 insertions(+) diff --git a/autofit/aggregator/aggregate_images.py b/autofit/aggregator/aggregate_images.py index faac61cff..e04529df2 100644 --- a/autofit/aggregator/aggregate_images.py +++ b/autofit/aggregator/aggregate_images.py @@ -114,6 +114,9 @@ def __init__( aggregator The aggregator containing the fit results. """ + if len(aggregator) == 0: + raise ValueError("The aggregator is empty.") + self._aggregator = aggregator self._source_images = None diff --git a/test_autofit/aggregator/test_aggregate_images.py b/test_autofit/aggregator/test_aggregate_images.py index 33cb37056..135a7b3b3 100644 --- a/test_autofit/aggregator/test_aggregate_images.py +++ b/test_autofit/aggregator/test_aggregate_images.py @@ -135,3 +135,11 @@ class SubplotFit(Enum): SubplotFit.Data, ) assert result.size == (61, 120) + + +def test_bad_aggregator(): + directory = Path(__file__).parent / "aggregate_summaries" + aggregator = Aggregator.from_directory(directory) + + with pytest.raises(ValueError): + AggregateImages(aggregator)