'
@@ -224,13 +244,21 @@ def set_contrast(self, value):
self._apply_transforms_and_display()
def set_rgb_channels(self, r=None, g=None, b=None):
- """Sets RGB channel visibility."""
+ """Sets RGB channel visibility for both RGB and N-channel modes."""
if r is not None:
self.show_r = r
if g is not None:
self.show_g = g
if b is not None:
self.show_b = b
+
+ # Keep channel_visibility in sync for N-channel mode, where
+ # the first 3 entries control the RGB display channels.
+ if self.num_channels > 3:
+ for i, visible in enumerate([self.show_r, self.show_g, self.show_b]):
+ if i < len(self.channel_visibility):
+ self.channel_visibility[i] = visible
+
self._apply_transforms_and_display()
def toggle_full_resolution(self):
@@ -267,5 +295,6 @@ def _apply_transforms_and_display(self):
show_r=self.show_r,
show_g=self.show_g,
show_b=self.show_b,
+ channel_visibility=self.channel_visibility if self.num_channels > 3 else None,
)
self._display_image(self.modified_image)
diff --git a/anomaly_match/ui/ui_elements.py b/anomaly_match_ui/ui_elements.py
similarity index 96%
rename from anomaly_match/ui/ui_elements.py
rename to anomaly_match_ui/ui_elements.py
index cce2bfa..cb65ffe 100644
--- a/anomaly_match/ui/ui_elements.py
+++ b/anomaly_match_ui/ui_elements.py
@@ -5,13 +5,13 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
import ipywidgets as widgets
-from ipywidgets import Button, HBox, VBox
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
from IPython.display import HTML
+from ipywidgets import Button, HBox, VBox
from loguru import logger
-from anomaly_match import __version__
-from anomaly_match.ui.memory_monitor import MemoryMonitor
-from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
+from anomaly_match_ui.memory_monitor import MemoryMonitor
+from anomaly_match_ui.utils.backend_interface import BackendInterface
HTML_setup = HTML(
"""
@@ -90,9 +90,7 @@
def create_ui_elements():
- """
- Creates and returns the necessary UI widgets as a dictionary.
- """
+ """Creates and returns the necessary UI widgets as a dictionary."""
logger.debug("Creating UI elements")
# HTML widget
@@ -253,7 +251,7 @@ def create_ui_elements():
),
)
- # Create a slider container with both sliders and the RGB controls
+ # Create a slider container with both sliders
slider_row = HBox(
[brightness_slider, contrast_slider],
layout=widgets.Layout(background_color="black"),
@@ -349,7 +347,7 @@ def create_ui_elements():
# Create version and memory monitor display with more explicit styling
version_text = widgets.HTML(
- value=f'Version: {__version__}
',
+ value=f'Version: {BackendInterface.get_version()}
',
layout=widgets.Layout(
background_color="black",
padding="3px",
@@ -425,8 +423,11 @@ def create_ui_elements():
],
layout=widgets.Layout(background_color="black"),
) # Add new row with channels, normalisation, and remove label button
+ # Include RGB mapping controls if available (for N-channel images)
+ bottom_row2_items = [channel_controls]
+ bottom_row2_items.extend([normalisation_dropdown, remove_label_button])
bottom_row2 = HBox(
- [channel_controls, normalisation_dropdown, remove_label_button],
+ bottom_row2_items,
layout=widgets.Layout(background_color="black"),
)
@@ -521,11 +522,11 @@ def attach_click_listeners(widget):
# Decision buttons
def mark_anomalous(_):
- widget.session.label_image(widget.preview.current_index, "anomaly")
+ BackendInterface.label_image(widget.preview.current_index, "anomaly")
widget.update_image_UI_label()
def mark_nominal(_):
- widget.session.label_image(widget.preview.current_index, "normal")
+ BackendInterface.label_image(widget.preview.current_index, "normal")
widget.update_image_UI_label()
widget.ui["decision_buttons"][0].on_click(mark_anomalous)
diff --git a/anomaly_match_ui/utils/__init__.py b/anomaly_match_ui/utils/__init__.py
new file mode 100644
index 0000000..fdde815
--- /dev/null
+++ b/anomaly_match_ui/utils/__init__.py
@@ -0,0 +1,23 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Utilities for the AnomalyMatch UI package."""
+
+from anomaly_match_ui.utils.backend_interface import BackendInterface
+from anomaly_match_ui.utils.display_transforms import (
+ apply_transforms_ui,
+ display_image_normalisation,
+ prepare_for_display,
+)
+from anomaly_match_ui.utils.image_utils import numpy_array_to_byte_stream
+
+__all__ = [
+ "BackendInterface",
+ "apply_transforms_ui",
+ "display_image_normalisation",
+ "prepare_for_display",
+ "numpy_array_to_byte_stream",
+]
diff --git a/anomaly_match_ui/utils/backend_interface.py b/anomaly_match_ui/utils/backend_interface.py
new file mode 100644
index 0000000..09e2ec8
--- /dev/null
+++ b/anomaly_match_ui/utils/backend_interface.py
@@ -0,0 +1,355 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""
+BackendInterface: Single interface for UI-backend communication.
+
+This module provides a static interface for the UI components to interact with
+the backend Session without direct imports, enabling clean separation between
+the UI and backend packages.
+"""
+
+from typing import Any, Callable, Dict, Optional, Tuple
+
+import numpy as np
+
+import anomaly_match
+
+
+class BackendInterface:
+ """Single interface for UI-backend communication.
+
+ This static class wraps all Session methods needed by the UI components,
+ providing a clean interface that decouples the UI from the backend implementation.
+ """
+
+ _session = None
+
+ # ========== Session Lifecycle ==========
+
+ @staticmethod
+ def set_session(session) -> None:
+ """Set the session instance to use.
+
+ Args:
+ session: The Session instance to interact with.
+ """
+ BackendInterface._session = session
+
+ @staticmethod
+ def get_session():
+ """Get the current session instance.
+
+ Returns:
+ The current Session instance, or None if not set.
+ """
+ return BackendInterface._session
+
+ @staticmethod
+ def _check_session() -> None:
+ """Check that a session is set, raise RuntimeError if not."""
+ if BackendInterface._session is None:
+ raise RuntimeError("No session set. Call BackendInterface.set_session() first.")
+
+ # ========== Configuration ==========
+
+ @staticmethod
+ def get_config():
+ """Get the session configuration.
+
+ Returns:
+ The session configuration object (DotMap).
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.cfg
+
+ @staticmethod
+ def get_num_channels() -> int:
+ """Get the number of channels in the images.
+
+ Returns:
+ int: Number of image channels.
+ """
+ BackendInterface._check_session()
+ return getattr(BackendInterface._session.cfg, "num_channels", 3)
+
+ @staticmethod
+ def set_normalisation_method(method) -> None:
+ """Set the normalisation method.
+
+ Args:
+ method (NormalisationMethod): The new normalisation method to apply.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.set_normalisation_method(method)
+
+ @staticmethod
+ def get_cached_normalisation_method():
+ """Get the cached normalisation method (what images are currently loaded with).
+
+ Returns:
+ NormalisationMethod: The cached normalisation method enum value.
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.cached_image_normalisation_enum
+
+ # ========== Image Data ==========
+
+ @staticmethod
+ def get_image_at_index(index: int) -> np.ndarray:
+ """Get the image at the given index from the catalog.
+
+ Args:
+ index (int): The index of the image.
+
+ Returns:
+ np.ndarray: The image array.
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.img_catalog[index]
+
+ @staticmethod
+ def get_image_count() -> int:
+ """Get the total number of images in the catalog.
+
+ Returns:
+ int: Number of images.
+ """
+ BackendInterface._check_session()
+ if BackendInterface._session.img_catalog is None:
+ return 0
+ return len(BackendInterface._session.img_catalog)
+
+ @staticmethod
+ def get_scores() -> np.ndarray:
+ """Get the anomaly scores for all images.
+
+ Returns:
+ np.ndarray: Array of anomaly scores.
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.scores
+
+ @staticmethod
+ def get_filenames() -> np.ndarray:
+ """Get the filenames for all images.
+
+ Returns:
+ np.ndarray: Array of filenames.
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.filenames
+
+ # ========== Labeling ==========
+
+ @staticmethod
+ def label_image(index: int, label: str) -> None:
+ """Label an image at the given index.
+
+ Args:
+ index (int): Index of the image to label.
+ label (str): Label to assign ("normal" or "anomaly").
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.label_image(index, label)
+
+ @staticmethod
+ def unlabel_image(index: int) -> None:
+ """Remove the label from an image at the given index.
+
+ Args:
+ index (int): Index of the image to unlabel.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.unlabel_image(index)
+
+ @staticmethod
+ def get_label(index: int) -> str:
+ """Get the label for an image at the given index.
+
+ Args:
+ index (int): Index of the image.
+
+ Returns:
+ str: The label ("normal", "anomaly", or "None").
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.get_label(index)
+
+ @staticmethod
+ def get_label_distribution() -> Tuple[int, int]:
+ """Get the distribution of labels.
+
+ Returns:
+ tuple: (normal_count, anomalous_count)
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.get_label_distribution()
+
+ @staticmethod
+ def get_active_learning_counts() -> Tuple[int, int]:
+ """Get the count of newly annotated samples in active learning.
+
+ Returns:
+ tuple: (new_normal_count, new_anomalous_count)
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.get_active_learning_counts()
+
+ @staticmethod
+ def save_labels() -> None:
+ """Save the current labels to a file."""
+ BackendInterface._check_session()
+ BackendInterface._session.save_labels()
+
+ # ========== Sorting ==========
+
+ @staticmethod
+ def sort_by_anomalous() -> None:
+ """Sort images by anomalous scores (most anomalous first)."""
+ BackendInterface._check_session()
+ BackendInterface._session.sort_by_anomalous()
+
+ @staticmethod
+ def sort_by_nominal() -> None:
+ """Sort images by nominal scores (most nominal first)."""
+ BackendInterface._check_session()
+ BackendInterface._session.sort_by_nominal()
+
+ @staticmethod
+ def sort_by_mean() -> None:
+ """Sort images by distance to mean score."""
+ BackendInterface._check_session()
+ BackendInterface._session.sort_by_mean()
+
+ @staticmethod
+ def sort_by_median() -> None:
+ """Sort images by distance to median score."""
+ BackendInterface._check_session()
+ BackendInterface._session.sort_by_median()
+
+ # ========== Model Operations ==========
+
+ @staticmethod
+ def train(cfg, progress_callback: Optional[Callable] = None) -> None:
+ """Train the model.
+
+ Args:
+ cfg: Configuration for training.
+ progress_callback: Optional callback for progress updates.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.train(cfg, progress_callback=progress_callback)
+
+ @staticmethod
+ def save_model() -> None:
+ """Save the current model state."""
+ BackendInterface._check_session()
+ BackendInterface._session.save_model()
+
+ @staticmethod
+ def load_model() -> None:
+ """Load the model from the saved state."""
+ BackendInterface._check_session()
+ BackendInterface._session.load_model()
+
+ @staticmethod
+ def reset_model() -> None:
+ """Reset the model and reinitialize."""
+ BackendInterface._check_session()
+ BackendInterface._session.reset_model()
+
+ @staticmethod
+ def get_model():
+ """Get the current model instance.
+
+ Returns:
+ The FixMatch model instance.
+ """
+ BackendInterface._check_session()
+ return BackendInterface._session.model
+
+ # ========== Prediction/Evaluation ==========
+
+ @staticmethod
+ def update_predictions(progress_callback: Optional[Callable] = None) -> None:
+ """Update predictions using the current model.
+
+ Args:
+ progress_callback: Optional callback for progress updates.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.update_predictions(progress_callback=progress_callback)
+
+ @staticmethod
+ def evaluate_all_images(top_n: int, progress_callback: Optional[Callable] = None) -> None:
+ """Evaluate all images in the prediction search directory.
+
+ Args:
+ top_n (int): Number of top images to keep.
+ progress_callback: Optional callback for progress updates.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.evaluate_all_images(
+ top_N=top_n, progress_callback=progress_callback
+ )
+
+ @staticmethod
+ def load_next_batch() -> None:
+ """Load the next batch of data and update predictions."""
+ BackendInterface._check_session()
+ BackendInterface._session.load_next_batch()
+
+ @staticmethod
+ def load_top_files(top_n: int) -> None:
+ """Load the top files from the output directory.
+
+ Args:
+ top_n (int): Number of top files to load.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.load_top_files(top_n)
+
+ @staticmethod
+ def get_eval_performance() -> Optional[Dict[str, Any]]:
+ """Get the evaluation performance metrics.
+
+ Returns:
+ dict: Evaluation performance metrics, or None if not available.
+ """
+ BackendInterface._check_session()
+ return getattr(BackendInterface._session, "eval_performance", None)
+
+ # ========== Utilities ==========
+
+ @staticmethod
+ def set_terminal_output(output_widget) -> None:
+ """Set the terminal output widget for logging.
+
+ Args:
+ output_widget: The output widget for terminal logging.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.set_terminal_out(output_widget)
+
+ @staticmethod
+ def remember_current_file(filename: str) -> None:
+ """Remember the current file by appending it to a CSV.
+
+ Args:
+ filename (str): The filename to remember.
+ """
+ BackendInterface._check_session()
+ BackendInterface._session.remember_current_file(filename)
+
+ @staticmethod
+ def get_version() -> str:
+ """Get the anomaly_match version.
+
+ Returns:
+ str: The version string.
+ """
+ return anomaly_match.__version__
diff --git a/anomaly_match_ui/utils/display_transforms.py b/anomaly_match_ui/utils/display_transforms.py
new file mode 100644
index 0000000..8e75bf7
--- /dev/null
+++ b/anomaly_match_ui/utils/display_transforms.py
@@ -0,0 +1,178 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+import numpy as np
+from PIL import Image, ImageEnhance, ImageFilter, ImageOps
+from skimage.util import img_as_ubyte
+
+
+def prepare_for_display(img, rgb_mapping=None):
+ """Prepare N-channel image for RGB display.
+
+ Converts images with arbitrary channel counts to 3-channel RGB for display.
+
+ Args:
+ img (np.ndarray): Input image array with shape (H, W, C)
+ rgb_mapping (list, optional): List of 3 channel indices to use as RGB.
+ Defaults to [0, 1, 2] (first 3 channels).
+
+ Returns:
+ np.ndarray: 3-channel RGB image with shape (H, W, 3) and dtype uint8
+ """
+ # Handle different input formats
+ if isinstance(img, Image.Image):
+ img = np.array(img)
+
+ # Ensure we have a valid array
+ if not isinstance(img, np.ndarray):
+ raise ValueError(f"Expected numpy array or PIL Image, got {type(img)}")
+
+ # Handle 2D images (grayscale without channel dimension)
+ if len(img.shape) == 2:
+ img = img[:, :, np.newaxis]
+
+ channels = img.shape[-1]
+
+ # Convert based on channel count
+ if channels == 1:
+ # Grayscale: repeat to RGB
+ result = np.repeat(img, 3, axis=-1)
+ elif channels == 2:
+ # 2 channels: average to grayscale, then repeat
+ gray = img.mean(axis=-1, keepdims=True)
+ result = np.repeat(gray, 3, axis=-1)
+ elif channels == 3:
+ # RGB: use as-is
+ result = img
+ else:
+ # N-channels (4+): extract specified channels for RGB mapping
+ if rgb_mapping is None:
+ rgb_mapping = [0, 1, 2]
+
+ # Validate rgb_mapping
+ if len(rgb_mapping) != 3:
+ raise ValueError(f"rgb_mapping must have 3 elements, got {len(rgb_mapping)}")
+ if any(i >= channels for i in rgb_mapping):
+ raise ValueError(f"rgb_mapping indices {rgb_mapping} exceed channel count {channels}")
+
+ result = img[:, :, rgb_mapping]
+
+ # Ensure uint8 output
+ if result.dtype != np.uint8:
+ if result.max() <= 1.0:
+ result = (result * 255).astype(np.uint8)
+ else:
+ result = np.clip(result, 0, 255).astype(np.uint8)
+
+ return result
+
+
+def display_image_normalisation(img, rgb_mapping=None):
+ """Normalises the image for display.
+
+ Args:
+ img (np.ndarray): The input image array.
+ rgb_mapping (list, optional): For N-channel images, which channels to display as RGB.
+
+ Returns:
+ PIL.Image.Image: The normalised image.
+ """
+ # Handle NaN/inf values
+ if not np.isfinite(img).all():
+ img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0)
+
+ img = img - np.min(img)
+ img_max = np.max(img)
+
+ # Avoid division by zero
+ if img_max > 0:
+ img = img / img_max
+ else:
+ img = np.zeros_like(img) # Handle constant images
+
+ img = img_as_ubyte(img)
+
+ # Convert to displayable RGB
+ img = prepare_for_display(img, rgb_mapping=rgb_mapping)
+
+ return Image.fromarray(img)
+
+
+# from utility_functions
+def apply_transforms_ui(
+ img,
+ invert,
+ brightness,
+ contrast,
+ unsharp_mask_applied,
+ show_r=True,
+ show_g=True,
+ show_b=True,
+ channel_visibility=None,
+):
+ """
+ Applies the requested transformations to the given PIL Image.
+
+ Args:
+ img (PIL.Image.Image): The original image.
+ invert (bool): Whether to invert colors.
+ brightness (float): Brightness factor.
+ contrast (float): Contrast factor.
+ unsharp_mask_applied (bool): Whether to apply an unsharp mask.
+ show_r (bool): Whether to show the red channel (for RGB mode).
+ show_g (bool): Whether to show the green channel (for RGB mode).
+ show_b (bool): Whether to show the blue channel (for RGB mode).
+ channel_visibility (list, optional): For N-channel mode, list of booleans
+ indicating which channels to show. If provided, overrides show_r/g/b.
+
+ Returns:
+ PIL.Image.Image: The transformed image.
+ """
+ # Apply inversion
+ if invert:
+ img = ImageOps.invert(img)
+
+ # Apply brightness
+ if brightness != 1.0:
+ enhancer = ImageEnhance.Brightness(img)
+ img = enhancer.enhance(brightness)
+
+ # Apply contrast
+ if contrast != 1.0:
+ enhancer = ImageEnhance.Contrast(img)
+ img = enhancer.enhance(contrast)
+
+ # Apply unsharp mask if enabled
+ if unsharp_mask_applied:
+ img = img.filter(ImageFilter.UnsharpMask())
+
+ # Apply channel toggling
+ # Determine which channels to show
+ if channel_visibility is not None:
+ # N-channel mode: use first 3 values for RGB display
+ channels_mask = (
+ channel_visibility[:3] if len(channel_visibility) >= 3 else channel_visibility
+ )
+ # Pad with True if less than 3 values
+ while len(channels_mask) < 3:
+ channels_mask.append(True)
+ else:
+ # RGB mode: use individual channel flags
+ channels_mask = [show_r, show_g, show_b]
+
+ if not all(channels_mask):
+ # Convert PIL image to numpy array
+ img_array = np.array(img)
+
+ # Apply masking to the image array (zero out disabled channels)
+ for i, show_channel in enumerate(channels_mask):
+ if not show_channel and i < img_array.shape[-1]:
+ img_array[:, :, i] = 0
+
+ # Convert back to PIL image
+ img = Image.fromarray(img_array)
+
+ return img
diff --git a/anomaly_match/utils/numpy_to_byte_stream.py b/anomaly_match_ui/utils/image_utils.py
similarity index 99%
rename from anomaly_match/utils/numpy_to_byte_stream.py
rename to anomaly_match_ui/utils/image_utils.py
index d16f0ef..9cea43e 100644
--- a/anomaly_match/utils/numpy_to_byte_stream.py
+++ b/anomaly_match_ui/utils/image_utils.py
@@ -4,10 +4,11 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-from PIL import Image
-import numpy as np
import io
+import numpy as np
+from PIL import Image
+
def numpy_array_to_byte_stream(numpy_array, normalize=True):
"""Convert a numpy array to a byte stream.
diff --git a/anomaly_match/ui/Widget.py b/anomaly_match_ui/widget.py
similarity index 76%
rename from anomaly_match/ui/Widget.py
rename to anomaly_match_ui/widget.py
index 86b3865..81396cb 100644
--- a/anomaly_match/ui/Widget.py
+++ b/anomaly_match_ui/widget.py
@@ -4,28 +4,29 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
+import datetime
import os
-import numpy as np
-from ipywidgets import HBox
+import time
+
import matplotlib.pyplot as plt
-from PIL import Image
+import numpy as np
from IPython.display import display
+from ipywidgets import HBox
from loguru import logger
-from sklearn.metrics import roc_curve
-import time
-import datetime
-
+from PIL import Image
from skimage.util import img_as_ubyte
+from sklearn.metrics import roc_curve
from anomaly_match.data_io.load_images import load_and_process_single_wrapper
-from anomaly_match.ui.preview_widget import PreviewWidget
+from anomaly_match_ui.preview_widget import PreviewWidget
# Import the newly created UI elements
-from anomaly_match.ui.ui_elements import (
- create_ui_elements,
- attach_click_listeners,
+from anomaly_match_ui.ui_elements import (
HTML_setup,
+ attach_click_listeners,
+ create_ui_elements,
)
+from anomaly_match_ui.utils.backend_interface import BackendInterface
def shorten_filename(filename: str, max_length: int = 25) -> str:
@@ -74,19 +75,13 @@ def shorten_filename(filename: str, max_length: int = 25) -> str:
class Widget:
"""A widget-based user interface for interacting with the anomaly detection session."""
- def __init__(self, session):
- """
- Initializes the UI widget with the given session.
-
- Args:
- session (Session): The session to interact with.
- """
- # Store session
- self.session = session
- self.cfg = session.cfg
+ def __init__(self):
+ """Initializes the UI widget using the BackendInterface."""
+ # Get config from backend
+ self.cfg = BackendInterface.get_config()
# Create the preview widget (handles image display and transforms)
- self.preview = PreviewWidget(session)
+ self.preview = PreviewWidget()
# Create all UI elements (moved from the big monolithic class)
self.ui = create_ui_elements()
@@ -96,8 +91,8 @@ def __init__(self, session):
self.ui["filename_text"] = self.preview.filename_text
# Rebuild center_row with preview widget's components
- from ipywidgets import VBox
import ipywidgets as widgets
+ from ipywidgets import VBox
self.ui["center_row"] = VBox(
[self.preview.filename_text, self.preview.image_widget],
@@ -119,7 +114,7 @@ def __init__(self, session):
self.preview.set_full_res_button(self.ui["transform_buttons"]["full_res"])
# Attach the output widget so the session logs go there
- session.set_terminal_out(self.ui["out"])
+ BackendInterface.set_terminal_output(self.ui["out"])
# Also attach it to the memory monitor
self.ui["memory_monitor"].set_output_widget(self.ui["out"])
@@ -127,7 +122,7 @@ def __init__(self, session):
self.ui["batch_size_slider"].value = self.cfg.N_to_load
# Set initial normalization dropdown value from session cached value
- self.ui["normalisation_dropdown"].value = session.cached_image_normalisation_enum
+ self.ui["normalisation_dropdown"].value = BackendInterface.get_cached_normalisation_method()
# Attach all click listeners / slider observers
attach_click_listeners(self)
@@ -145,7 +140,8 @@ def __init__(self, session):
self.update()
# Only attempt sorting if we have scores
- if self.session.scores is not None:
+ scores = BackendInterface.get_scores()
+ if scores is not None:
self.sort_by_anomalous()
# Start memory monitoring
@@ -193,10 +189,8 @@ def _pack_layout(self, main_layout, side_display):
def search_all_files(self):
"""Searches all files and displays the top N with their scores."""
with self.ui["out"]:
- self.session.cfg.N_to_load = self.ui["batch_size_slider"].value
- logger.debug(
- f"Searching all files for anomalies with batch size: {self.session.cfg.N_to_load}"
- )
+ self.cfg.N_to_load = self.ui["batch_size_slider"].value
+ logger.debug(f"Searching all files for anomalies with batch size: {self.cfg.N_to_load}")
# Set progress bar color to cyan for 'search_all_files' task
self.ui["progress_bar"].style = {"bar_color": "cyan"}
@@ -213,16 +207,22 @@ def update_progress(
completed=False,
total_time_str=None,
final_speed=None,
+ results_updated=False,
):
+ # Update gallery when new results are loaded (after each chunk)
+ if results_updated:
+ self.display_top_files_scores()
+ return
+
# Always update progress bar if batch info provided
if batch is not None and num_batches is not None:
self.ui["progress_bar"].value = batch / num_batches
# Handle completion message
if completed and total_time_str:
- self.ui["train_label"].value = (
- f"Search complete in {total_time_str} ({final_speed:.1f} img/sec)"
- )
+ self.ui[
+ "train_label"
+ ].value = f"Search complete in {total_time_str} ({final_speed:.1f} img/sec)"
return
# Handle batch update with ETA information
@@ -238,41 +238,41 @@ def update_progress(
self.ui["train_label"].value = message
else:
# Early in the process when ETA isn't available yet
- self.ui["train_label"].value = (
- f"Evaluating files... Batch: {batch}/{num_batches}"
- )
+ self.ui[
+ "train_label"
+ ].value = f"Evaluating files... Batch: {batch}/{num_batches}"
else:
# Regular evaluation updates (not used in this function but keeping for completeness)
if eta_str:
- self.ui["train_label"].value = (
- f"Evaluating... {batch}/{num_batches} | ETA: {eta_str}"
- )
+ self.ui[
+ "train_label"
+ ].value = f"Evaluating... {batch}/{num_batches} | ETA: {eta_str}"
else:
self.ui["train_label"].value = f"Evaluating... {batch}/{num_batches}"
- self.session.evaluate_all_images(
- top_N=self.cfg.top_N, progress_callback=update_progress
- )
- # update models last_normalisation_method only after successful eval
- if self.session.model.last_normalisation_method is None:
- self.session.model.last_normalisation_method = (
- self.session.cfg.normalisation.normalisation_method
- )
- elif (
- self.session.model.last_normalisation_method
- != self.session.cfg.normalisation.normalisation_method
- ):
- logger.warning(
- f"Evaluated with a new normalisation {self.session.cfg.normalisation.normalisation_method.name} method "
- + f"not previously used with the model: {self.session.model.last_normalisation_method.name}"
- )
- self.session.model.last_normalisation_method = (
- self.session.cfg.normalisation.normalisation_method
+ try:
+ BackendInterface.evaluate_all_images(
+ top_n=self.cfg.top_N, progress_callback=update_progress
)
-
- # Display will be updated by the callback when completed
- self.display_top_files_scores()
- self.ui["progress_bar"].style = {"bar_color": "green"}
+ # update models last_normalisation_method only after successful eval
+ model = BackendInterface.get_model()
+ if model.last_normalisation_method is None:
+ model.last_normalisation_method = self.cfg.normalisation.normalisation_method
+ elif model.last_normalisation_method != self.cfg.normalisation.normalisation_method:
+ logger.warning(
+ f"Evaluated with a new normalisation {self.cfg.normalisation.normalisation_method.name} method "
+ + f"not previously used with the model: {model.last_normalisation_method.name}"
+ )
+ model.last_normalisation_method = self.cfg.normalisation.normalisation_method
+ except Exception as e:
+ logger.error(f"Error during evaluation: {e}")
+ finally:
+ # Always update display with whatever results are available,
+ # even if evaluation was interrupted or partially failed
+ scores = BackendInterface.get_scores()
+ if scores is not None and len(scores) > 0:
+ self.display_top_files_scores()
+ self.ui["progress_bar"].style = {"bar_color": "green"}
def display_top_files_scores(self):
"""Displays the top files and their scores."""
@@ -281,10 +281,6 @@ def display_top_files_scores(self):
self.ui["progress_bar"].style = {"bar_color": "green"}
self.display_gallery()
- def update_image_display(self):
- """Updates the display of the current image."""
- self.preview.update_display()
-
def update_image_UI_label(self, filename=None, score=None):
"""Updates the UI label with the current image's filename, score, and label."""
self.preview.update_label_only()
@@ -292,32 +288,32 @@ def update_image_UI_label(self, filename=None, score=None):
# ======== Sorting Methods ========
def sort_by_anomalous(self):
"""Sorts the images by their anomalous scores and updates the display."""
- self.session.sort_by_anomalous()
+ BackendInterface.sort_by_anomalous()
self.preview.set_index(0)
self.preview.update_display()
def sort_by_nominal(self):
"""Sorts the images by their nominal scores and updates the display."""
- self.session.sort_by_nominal()
+ BackendInterface.sort_by_nominal()
self.preview.set_index(0)
self.preview.update_display()
def sort_by_mean(self):
"""Sorts the images by distance to mean score and updates the display."""
- self.session.sort_by_mean()
+ BackendInterface.sort_by_mean()
self.preview.set_index(0)
self.preview.update_display()
def sort_by_median(self):
"""Sorts the images by distance to median score and updates the display."""
- self.session.sort_by_median()
+ BackendInterface.sort_by_median()
self.preview.set_index(0)
self.preview.update_display()
# ======== Navigation ========
def next_image(self):
"""Displays the next image in the catalog."""
- new_index = min(len(self.session.img_catalog) - 1, self.preview.current_index + 1)
+ new_index = min(BackendInterface.get_image_count() - 1, self.preview.current_index + 1)
self.preview.reset_full_resolution_mode()
self.preview.set_index(new_index)
self.preview.update_display()
@@ -362,6 +358,9 @@ def display_gallery(self):
self.ui["gallery"].clear_output(wait=True)
try:
if self.cfg.test_ratio > 0:
+ eval_performance = BackendInterface.get_eval_performance()
+ if eval_performance is None:
+ return
# Show mispredicted images
mispredicted_images = []
@@ -369,7 +368,7 @@ def display_gallery(self):
# First collect all filenames we want to display
display_files = []
- for filename, (pred, label) in self.session.eval_performance[
+ for filename, (pred, label) in eval_performance[
"eval/predictions_and_labels"
].items():
pred, label = pred.item(), label.item()
@@ -386,12 +385,11 @@ def display_gallery(self):
try:
img_array = load_and_process_single_wrapper(
path,
- self.session.cfg,
+ self.cfg,
desc="widget loading image",
show_progress=False,
)
- img = Image.fromarray(img_array)
mispredicted_images.append(img_array)
display_name = shorten_filename(filename)
@@ -406,7 +404,7 @@ def display_gallery(self):
num_images = len(mispredicted_images)
if num_images > 0:
plt.figure(figsize=(12, 4), facecolor="black")
- eval_perf = self.session.eval_performance
+ eval_perf = eval_performance
plt.suptitle(
f"Top {num_images} Mispredicted Test Images | "
f"Acc: {eval_perf['eval/top-1-acc'] * 100:.1f}% | "
@@ -434,7 +432,7 @@ def display_gallery(self):
ax1 = plt.subplot(1, 2, 1)
labels, probs = eval_perf["eval/roc_data"]
fpr, tpr, _ = roc_curve(labels, probs)
- ax1.plot(fpr, tpr, "b-", label=f'ROC (AUC={eval_perf["eval/auroc"]:.3f})')
+ ax1.plot(fpr, tpr, "b-", label=f"ROC (AUC={eval_perf['eval/auroc']:.3f})")
ax1.plot([0, 1], [0, 1], "r--")
ax1.set_title("ROC Curve", color="white")
ax1.grid(True, alpha=0.3)
@@ -451,7 +449,7 @@ def display_gallery(self):
recall,
precision,
"g-",
- label=f'PRC (AUC={eval_perf["eval/auprc"]:.3f})',
+ label=f"PRC (AUC={eval_perf['eval/auprc']:.3f})",
)
ax2.set_title("Precision-Recall Curve", color="white")
ax2.grid(True, alpha=0.3)
@@ -467,7 +465,10 @@ def display_gallery(self):
else:
# Show top 5 anomalous & top 5 nominal
- scores = self.session.scores
+ scores = BackendInterface.get_scores()
+ img_catalog = BackendInterface.get_session().img_catalog
+ filenames = BackendInterface.get_filenames()
+
indices = np.argsort(scores)
num_images_to_display = min(5, len(scores) // 2)
@@ -478,7 +479,7 @@ def display_gallery(self):
image_text = []
for idx in top_anomalous_indices:
- img_arr = self.session.img_catalog[idx]
+ img_arr = img_catalog[idx]
img_arr = img_arr - np.min(img_arr)
img_arr = img_arr / np.max(img_arr)
img_arr = img_as_ubyte(img_arr)
@@ -486,13 +487,13 @@ def display_gallery(self):
img_arr = np.repeat(img_arr, 3, axis=-1)
pil_img = Image.fromarray(img_arr)
images.append(pil_img)
- filename = self.session.filenames[idx]
+ filename = filenames[idx]
display_name = shorten_filename(filename)
score = scores[idx]
image_text.append(f"{display_name}\nScore: {score:.4f}")
for idx in top_nominal_indices:
- img_arr = self.session.img_catalog[idx]
+ img_arr = img_catalog[idx]
img_arr = img_arr - np.min(img_arr)
img_arr = img_arr / np.max(img_arr)
img_arr = img_as_ubyte(img_arr)
@@ -500,7 +501,7 @@ def display_gallery(self):
img_arr = np.repeat(img_arr, 3, axis=-1)
pil_img = Image.fromarray(img_arr)
images.append(pil_img)
- filename = self.session.filenames[idx]
+ filename = filenames[idx]
display_name = shorten_filename(filename)
score = scores[idx]
image_text.append(f"{display_name}\nScore: {score:.4f}")
@@ -527,33 +528,33 @@ def display_gallery(self):
logger.error(f"Error displaying gallery: {e}")
def save_labels(self):
- self.session.save_labels()
+ BackendInterface.save_labels()
def remember_current_file(self, _):
"""Remembers the currently displayed file."""
- self.session.remember_current_file(self.session.filenames[self.preview.current_index])
+ filenames = BackendInterface.get_filenames()
+ BackendInterface.remember_current_file(filenames[self.preview.current_index])
def save_model(self):
"""Saves the model using the session."""
- self.session.save_model()
+ BackendInterface.save_model()
def load_model(self):
"""Loads the model using the session."""
with self.ui["out"]:
logger.debug(
- f"Loading model, cfg norm: {self.session.cfg.normalisation.normalisation_method}, "
+ f"Loading model, cfg norm: {self.cfg.normalisation.normalisation_method}, "
)
- self.session.load_model()
+ BackendInterface.load_model()
# Update the normalization dropdown to match the session's method
- self.ui["normalisation_dropdown"].value = (
- self.session.cfg.normalisation.normalisation_method
- )
+ self.ui["normalisation_dropdown"].value = self.cfg.normalisation.normalisation_method
with self.ui["out"]:
+ model = BackendInterface.get_model()
logger.debug(
- f"Loaded model, cfg norm: {self.session.cfg.normalisation.normalisation_method},"
- + f" model norm: {self.session.model.last_normalisation_method}"
+ f"Loaded model, cfg norm: {self.cfg.normalisation.normalisation_method},"
+ + f" model norm: {model.last_normalisation_method}"
)
self.update()
@@ -562,8 +563,8 @@ def train(self):
with self.ui["out"]:
logger.debug(
- f"Session norm: {self.session.cached_image_normalisation_enum}, "
- f"selected norm: {self.session.cfg.normalisation.normalisation_method}"
+ f"Session norm: {BackendInterface.get_cached_normalisation_method()}, "
+ f"selected norm: {self.cfg.normalisation.normalisation_method}"
)
with self.ui["out"]:
@@ -606,57 +607,72 @@ def update_training_progress(iteration, total_iterations):
eta_str = str(datetime.timedelta(seconds=int(eta_seconds)))
# Update display with iteration count and ETA
- self.ui["train_label"].value = (
- f"Training... Iteration {iteration}/{total_iterations} | "
- f"ETA: {eta_str}"
+ self.ui[
+ "train_label"
+ ].value = (
+ f"Training... Iteration {iteration}/{total_iterations} | ETA: {eta_str}"
)
else:
# Early iterations - no reliable ETA yet
- self.ui["train_label"].value = (
- f"Training... Iteration {iteration}/{total_iterations}"
- )
+ self.ui[
+ "train_label"
+ ].value = f"Training... Iteration {iteration}/{total_iterations}"
last_update_time = current_time
last_iteration = iteration
- self.session.train(self.cfg, progress_callback=update_training_progress)
+ BackendInterface.train(self.cfg, progress_callback=update_training_progress)
# update models last_normalisation_method after successful training
- if self.session.model.last_normalisation_method is None:
- self.session.model.last_normalisation_method = (
- self.session.cfg.normalisation.normalisation_method
- )
- elif (
- self.session.model.last_normalisation_method
- != self.session.cfg.normalisation.normalisation_method
- ):
+ model = BackendInterface.get_model()
+ if model.last_normalisation_method is None:
+ model.last_normalisation_method = self.cfg.normalisation.normalisation_method
+ elif model.last_normalisation_method != self.cfg.normalisation.normalisation_method:
logger.warning(
- f"Trained with a new normalisation {self.session.cfg.normalisation.normalisation_method.name} method "
- + f"not previously used with the model: {self.session.model.last_normalisation_method.name}"
- )
- self.session.model.last_normalisation_method = (
- self.session.cfg.normalisation.normalisation_method
+ f"Trained with a new normalisation {self.cfg.normalisation.normalisation_method.name} method "
+ + f"not previously used with the model: {model.last_normalisation_method.name}"
)
+ model.last_normalisation_method = self.cfg.normalisation.normalisation_method
# Calculate total time taken
total_time = time.time() - start_time
time_str = str(datetime.timedelta(seconds=int(total_time)))
- self.ui["progress_bar"].style = {"bar_color": "green"}
- self.ui["train_label"].value = f"Training complete in {time_str}."
+ self.ui[
+ "train_label"
+ ].value = f"Training complete in {time_str}. Updating predictions..."
self.update()
self.sort_by_anomalous()
def update(self):
"""Updates the UI components and performs evaluation."""
+ self.ui["progress_bar"].value = 0.0
self.ui["progress_bar"].style = {"bar_color": "cyan"}
- self.session.update_predictions()
+ self.ui["train_label"].value = "Scoring unlabelled samples..."
+
+ phase = {"name": "scoring"}
+
+ def update_progress(current, total):
+ self.ui["progress_bar"].value = current / total
+ if phase["name"] == "scoring":
+ self.ui["train_label"].value = f"Scoring unlabelled samples... {current}/{total}"
+ if current >= total:
+ phase["name"] = "evaluating"
+ self.ui["train_label"].value = "Evaluating model..."
+ self.ui["progress_bar"].value = 0.0
+ else:
+ self.ui["train_label"].value = f"Evaluating model... {current}/{total}"
+
+ BackendInterface.update_predictions(progress_callback=update_progress)
self.preview.set_index(0)
if self.cfg.test_ratio > 0:
- if self.session.eval_performance is not None:
- self.ui["train_label"].value = (
- f"Training Complete. Eval Acc: {self.session.eval_performance['eval/top-1-acc'] * 100:.2f}%"
+ eval_performance = BackendInterface.get_eval_performance()
+ if eval_performance is not None:
+ self.ui[
+ "train_label"
+ ].value = (
+ f"Training Complete. Eval Acc: {eval_performance['eval/top-1-acc'] * 100:.2f}%"
)
else:
self.ui["train_label"].value = "Training Complete. No evaluation performed yet."
@@ -671,7 +687,7 @@ def update(self):
def update_batch_size(self, change):
"""Updates the batch size in the session config."""
- self.session.cfg.N_to_load = change["new"]
+ self.cfg.N_to_load = change["new"]
def next_batch(self):
"""Loads the next batch and updates predictions."""
@@ -680,7 +696,7 @@ def next_batch(self):
self.ui["progress_bar"].style = {"bar_color": "orange"}
self.ui["train_label"].value = "Predicting next batch..."
- self.session.load_next_batch()
+ BackendInterface.load_next_batch()
self.ui["progress_bar"].style = {"bar_color": "green"}
self.ui["train_label"].value = "Batch loading complete."
@@ -690,14 +706,14 @@ def next_batch(self):
def reset_model(self):
"""Resets the model in the session."""
with self.ui["out"]:
- self.session.reset_model()
+ BackendInterface.reset_model()
self.update()
self.sort_by_anomalous()
def load_top_files(self):
"""Loads the top files and updates the display."""
with self.ui["out"]:
- self.session.load_top_files(self.cfg.top_N)
+ BackendInterface.load_top_files(self.cfg.top_N)
self.display_top_files_scores()
# Add channel toggle methods
@@ -717,10 +733,10 @@ def select_normalisation(self, change):
"""Updates the normalization method when dropdown selection changes."""
new_value = change["new"]
if new_value != self.cfg.normalisation.normalisation_method:
- self.session.set_normalisation_method(new_value)
+ BackendInterface.set_normalisation_method(new_value)
self.preview.update_display()
def unlabel_current_image(self):
"""Removes the label from the currently displayed image."""
- self.session.unlabel_image(self.preview.current_index)
+ BackendInterface.unlabel_image(self.preview.current_index)
self.preview.update_label_only()
diff --git a/environment.yml b/environment.yml
index f9cacce..94b3ce4 100644
--- a/environment.yml
+++ b/environment.yml
@@ -12,21 +12,22 @@ channels:
dependencies:
- astropy
- dotmap
- - efficientnet-pytorch
- h5py
- - imageio
- ipykernel
- ipywidgets
- loguru
- matplotlib
- numpy
- pandas<3
+ - psutil
+ - pyarrow
- python=3.11
- pytorch>=2.6
- pytorch-cuda=12.4
- pyturbojpeg
- - scikit-learn
- scikit-image
+ - scikit-learn
+ - scipy
- toml
- torchvision
- tqdm
@@ -34,8 +35,7 @@ dependencies:
- pip
- pip:
- albumentations
- - efficientnet_lite_pytorch
- - efficientnet_lite0_pytorch_model
- - opencv-python-headless
- - fitsbolt>=0.1.6
- cutana>=0.2.1
+ - fitsbolt>=0.2
+ - opencv-python-headless
+ - timm
diff --git a/environment_CI.yml b/environment_CI.yml
index edb3021..d815bc1 100644
--- a/environment_CI.yml
+++ b/environment_CI.yml
@@ -11,7 +11,6 @@ channels:
dependencies:
- astropy
- dotmap
- - efficientnet-pytorch
- h5py
- imageio
- ipykernel
@@ -23,6 +22,7 @@ dependencies:
- python=3.11
- pytorch>=2.6
- pytest
+ - pytest-cov
- pytest-asyncio>=0.23.0
- pyturbojpeg
- scikit-learn
@@ -35,7 +35,6 @@ dependencies:
- pip:
- opencv-python-headless
- albumentations
- - efficientnet_lite_pytorch
- - efficientnet_lite0_pytorch_model
- - fitsbolt>=0.1.6
+ - timm
+ - fitsbolt>=0.2
- cutana>=0.2.1
diff --git a/paper_scripts/create_results.py b/paper_scripts/create_results.py
index 736d0ea..6bf2250 100644
--- a/paper_scripts/create_results.py
+++ b/paper_scripts/create_results.py
@@ -17,14 +17,15 @@
python create_results.py
"""
-import sys
-import time
+import argparse
import datetime
+import glob
import subprocess
+import sys
+import time
from pathlib import Path
-import argparse
+
import pandas as pd
-import glob
# ========== CONFIGURATION ==========
# Toggle which experiment sets to run
diff --git a/paper_scripts/dataset_plot.py b/paper_scripts/dataset_plot.py
index dac6293..2a18cee 100644
--- a/paper_scripts/dataset_plot.py
+++ b/paper_scripts/dataset_plot.py
@@ -12,12 +12,13 @@
"""
import os
+from pathlib import Path
+
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import matplotlib.pyplot as plt
-import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
-from pathlib import Path
from PIL import Image
# Figure settings for paper-quality output
diff --git a/paper_scripts/galaxyzoo_multi_seed.py b/paper_scripts/galaxyzoo_multi_seed.py
index bd9b0cc..3e60012 100644
--- a/paper_scripts/galaxyzoo_multi_seed.py
+++ b/paper_scripts/galaxyzoo_multi_seed.py
@@ -46,31 +46,32 @@
└── average_astronomaly_comparison_zoo_class1_n400_ratio0.015_iter{DEFAULT_TRAINING_RUNS}.pdf
"""
-import sys
-import time
+import argparse
import datetime
-import subprocess
import pickle
+import subprocess
+import sys
+import time
from pathlib import Path
-import argparse
-import pandas as pd
-import numpy as np
+from typing import Dict, List
+
import matplotlib.pyplot as plt
-from typing import List, Dict
+import numpy as np
+import pandas as pd
# Import plotting constants from paper_plots
sys.path.append(str(Path(__file__).parent))
try:
from paper_plots import (
BLUE,
+ DEFAULT_DPI,
GREEN,
ORANGE,
+ PERFECT_LINE_STYLE,
PURPLE,
RED,
REFERENCE_LINE_COLOR,
REFERENCE_LINE_STYLE,
- PERFECT_LINE_STYLE,
- DEFAULT_DPI,
)
except ImportError:
# Fallback colors if import fails
@@ -227,8 +228,8 @@ def collect_results_from_seeds(output_dir: Path, seeds: List[int]) -> Dict[str,
config_key = f"zoo_class{cls}_n{n_samples}_ratio{anomaly_ratio:.3f}"
config_results = []
sub_config_key = (
- f"galaxyzoo_anomaly{cls}_n{n_samples-int(n_samples*anomaly_ratio)}"
- + f"_a{int(n_samples*anomaly_ratio)}"
+ f"galaxyzoo_anomaly{cls}_n{n_samples - int(n_samples * anomaly_ratio)}"
+ + f"_a{int(n_samples * anomaly_ratio)}"
)
for seed in seeds:
seed_dir = output_dir / f"seed_{seed}" / "galaxyzoo" / config_key / sub_config_key
@@ -331,8 +332,8 @@ def load_plot_data_from_seeds(
for n_samples, anomaly_ratio in GALAXYZOO_CONFIGS:
config_key = f"zoo_class{cls}_n{n_samples}_ratio{anomaly_ratio:.3f}"
sub_config_key = (
- f"galaxyzoo_anomaly{cls}_n{n_samples-int(n_samples*anomaly_ratio)}"
- + f"_a{int(n_samples*anomaly_ratio)}"
+ f"galaxyzoo_anomaly{cls}_n{n_samples - int(n_samples * anomaly_ratio)}"
+ + f"_a{int(n_samples * anomaly_ratio)}"
)
plot_data_list = []
@@ -594,9 +595,9 @@ def main():
# Run experiments for each seed
for i, seed in enumerate(args.seeds):
- print(f"\n{'='*60}")
- print(f"Running experiments for seed {seed} ({i+1}/{len(args.seeds)})")
- print(f"{'='*60}")
+ print(f"\n{'=' * 60}")
+ print(f"Running experiments for seed {seed} ({i + 1}/{len(args.seeds)})")
+ print(f"{'=' * 60}")
try:
seed_output_dir = run_single_seed_experiment(
@@ -608,9 +609,9 @@ def main():
continue
# Collect and analyze results
- print(f"\n{'='*60}")
+ print(f"\n{'=' * 60}")
print("Collecting and analyzing results from all seeds")
- print(f"{'='*60}")
+ print(f"{'=' * 60}")
# Create summary directory
summary_dir = output_dir / "summary"
@@ -636,10 +637,10 @@ def main():
# Final summary
total_time = time.time() - start_time
- print(f"\n{'='*60}")
+ print(f"\n{'=' * 60}")
print("EXPERIMENT SUMMARY")
- print(f"{'='*60}")
- print(f"Total execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")
+ print(f"{'=' * 60}")
+ print(f"Total execution time: {total_time:.2f} seconds ({total_time / 60:.2f} minutes)")
print(f"Seeds processed: {len(args.seeds)}")
print(f"Configurations per seed: {len(GALAXYZOO_CONFIGS)}")
print(f"Results saved to: {output_dir}")
diff --git a/paper_scripts/get_example_images.py b/paper_scripts/get_example_images.py
index 6938efe..9ec453d 100644
--- a/paper_scripts/get_example_images.py
+++ b/paper_scripts/get_example_images.py
@@ -14,24 +14,23 @@
import os
import sys
+from pathlib import Path
+
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import matplotlib.pyplot as plt
+import torchvision.transforms as transforms
from matplotlib.patches import Rectangle
-from pathlib import Path
from PIL import Image
-import torchvision.transforms as transforms
-
sys.path.append("/media/home/AnomalyMatch")
sys.path.append("../")
from anomaly_match.image_processing.transforms import (
- get_weak_transforms,
get_strong_transforms,
+ get_weak_transforms,
)
-
# Constants
HOURGLASS_CLASS_IDX = 57 # Class ID for hourglass images (anomaly)
NUM_EXAMPLES = 10 # Number of examples to generate for each class
diff --git a/paper_scripts/get_hst_example_images.py b/paper_scripts/get_hst_example_images.py
index f0c7055..6ee1c5a 100644
--- a/paper_scripts/get_hst_example_images.py
+++ b/paper_scripts/get_hst_example_images.py
@@ -13,14 +13,15 @@
"""
import os
+import random
import sys
-import numpy as np
+from pathlib import Path
+
import matplotlib.pyplot as plt
+import numpy as np
+import torchvision.transforms as transforms
from matplotlib.patches import Rectangle
-from pathlib import Path
from PIL import Image
-import torchvision.transforms as transforms
-import random
sys.path.append("/media/home/AnomalyMatch")
sys.path.append("../")
diff --git a/paper_scripts/paper_benchmark.py b/paper_scripts/paper_benchmark.py
index 318e60b..bb0b18a 100644
--- a/paper_scripts/paper_benchmark.py
+++ b/paper_scripts/paper_benchmark.py
@@ -235,7 +235,7 @@ def get_prediction_scores(session, labeled_filenames, hdf5_path, progress_bar=No
if batch_idx % 5 == 0 or batch_idx == num_batches - 1:
logger.info(
- f"Batch {batch_idx+1}/{num_batches}: Processed {batch_size_actual} images "
+ f"Batch {batch_idx + 1}/{num_batches}: Processed {batch_size_actual} images "
f"in {batch_time:.2f}s ({images_per_sec:.1f} img/s)"
)
@@ -267,7 +267,7 @@ def get_prediction_scores(session, labeled_filenames, hdf5_path, progress_bar=No
avg_time_per_image = total_time / processed_images if processed_images > 0 else 0
logger.info(
f"Processed {processed_images} images in {total_time:.2f}s "
- f"({processed_images/total_time:.1f} img/s, {avg_time_per_image*1000:.2f}ms/img)"
+ f"({processed_images / total_time:.1f} img/s, {avg_time_per_image * 1000:.2f}ms/img)"
)
# Handle mismatched length between scores and filenames
@@ -504,11 +504,11 @@ def run_benchmark(args):
# Run training and evaluation loop
for iteration in range(args.training_runs):
# Create iteration-specific directory
- iter_dir = os.path.join(run_dir, f"iteration_{iteration+1}")
+ iter_dir = os.path.join(run_dir, f"iteration_{iteration + 1}")
iter_plots_dir = os.path.join(iter_dir, "plots")
logger.info(
- f"\n======= Starting training iteration {iteration+1}/{args.training_runs} ======="
+ f"\n======= Starting training iteration {iteration + 1}/{args.training_runs} ======="
)
# Train model
@@ -516,7 +516,7 @@ def run_benchmark(args):
train_with_progress_bar(session, cfg)
# Save model after training
- model_save_path = os.path.join(model_dir, f"model_iter{iteration+1}.pth")
+ model_save_path = os.path.join(model_dir, f"model_iter{iteration + 1}.pth")
iter_model_path = os.path.join(iter_dir, f"model.pth")
# session.session_tracker = None
session.cfg.model_path = model_save_path
@@ -562,12 +562,12 @@ def run_benchmark(args):
# Log top percentile metrics
if "top_0.1pct_anomalies_found" in metrics:
logger.info(
- f"Iter {iteration+1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, "
+ f"Iter {iteration + 1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, "
f"Precision: {metrics['top_0.1pct_precision']:.2f}%"
)
if "top_1.0pct_anomalies_found" in metrics:
logger.info(
- f"Iter {iteration+1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, "
+ f"Iter {iteration + 1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, "
f"Precision: {metrics['top_1.0pct_precision']:.2f}%"
)
@@ -868,9 +868,9 @@ def run_multi_class_benchmark(args):
)
for anomaly_class in args.anomaly_classes:
- logger.info(f"\n\n{'='*50}")
+ logger.info(f"\n\n{'=' * 50}")
logger.info(f"Starting benchmark for anomaly class {anomaly_class}")
- logger.info(f"{'='*50}\n")
+ logger.info(f"{'=' * 50}\n")
# Update args for this specific class
args.anomaly_class = anomaly_class
@@ -1045,11 +1045,11 @@ def run_multi_class_benchmark(args):
# Run training and evaluation loop
for iteration in range(args.training_runs):
# Create iteration-specific directory
- iter_dir = os.path.join(run_dir, f"iteration_{iteration+1}")
+ iter_dir = os.path.join(run_dir, f"iteration_{iteration + 1}")
iter_plots_dir = os.path.join(iter_dir, "plots")
logger.info(
- f"\n======= Starting training iteration {iteration+1}/{args.training_runs} for anomaly class {anomaly_class} ======="
+ f"\n======= Starting training iteration {iteration + 1}/{args.training_runs} for anomaly class {anomaly_class} ======="
)
# Train model with progress bar
@@ -1057,7 +1057,7 @@ def run_multi_class_benchmark(args):
train_with_progress_bar(session, cfg)
# Save model after training
- model_save_path = os.path.join(model_dir, f"model_iter{iteration+1}.pth")
+ model_save_path = os.path.join(model_dir, f"model_iter{iteration + 1}.pth")
iter_model_path = os.path.join(iter_dir, f"model.pth")
# session.session_tracker = None
session.cfg.model_path = model_save_path
@@ -1101,12 +1101,12 @@ def run_multi_class_benchmark(args):
# Log top percentile metrics
if "top_0.1pct_anomalies_found" in metrics:
logger.info(
- f"Iter {iteration+1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, "
+ f"Iter {iteration + 1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, "
f"Precision: {metrics['top_0.1pct_precision']:.2f}%"
)
if "top_1.0pct_anomalies_found" in metrics:
logger.info(
- f"Iter {iteration+1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, "
+ f"Iter {iteration + 1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, "
f"Precision: {metrics['top_1.0pct_precision']:.2f}%"
)
diff --git a/paper_scripts/paper_plots.py b/paper_scripts/paper_plots.py
index 305c49b..1c8f756 100644
--- a/paper_scripts/paper_plots.py
+++ b/paper_scripts/paper_plots.py
@@ -12,38 +12,40 @@
"""
import os
+
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import matplotlib.pyplot as plt
import seaborn as sns
-from sklearn.metrics import roc_curve
from loguru import logger
from paper_utils import save_plot_data
+from sklearn.metrics import roc_curve
+
+from paper_scripts.create_results import GALAXYZOO_THRESHOLDS
from paper_scripts.plot_colors import (
+ ANOMALY_COLOR,
BLUE,
- RED,
+ COLORMAP_NAME,
GREEN,
+ HIST_ALPHA,
+ HLINE_ALPHA,
+ HLINE_COLOR,
+ HLINE_STYLE,
+ LAST_ITER_COLOR,
+ NORMAL_COLOR,
ORANGE,
- PURPLE,
+ PERFECT_LINE_ALPHA,
PERFECT_LINE_COLOR,
PERFECT_LINE_STYLE,
- PERFECT_LINE_ALPHA,
+ PURPLE,
+ RED,
+ REFERENCE_LINE_ALPHA,
REFERENCE_LINE_COLOR,
REFERENCE_LINE_STYLE,
- REFERENCE_LINE_ALPHA,
+ VLINE_ALPHA,
VLINE_COLOR,
VLINE_STYLE,
- VLINE_ALPHA,
- HLINE_COLOR,
- HLINE_STYLE,
- HLINE_ALPHA,
- COLORMAP_NAME,
- LAST_ITER_COLOR,
- NORMAL_COLOR,
- ANOMALY_COLOR,
- HIST_ALPHA,
)
-from paper_scripts.create_results import GALAXYZOO_THRESHOLDS
# Scaling factor for all font sizes (adjust this to make all text larger or smaller)
FONT_SCALE = 1.75
@@ -199,7 +201,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir):
y_scores = np.concatenate([metrics["anomaly_scores"], metrics["normal_scores"]])
fpr, tpr, _ = roc_curve(y_true, y_scores) # 1. ROC Curve
plt.figure(figsize=(8, 8))
- plt.plot(fpr, tpr, color=BLUE, linewidth=2, label=f'AUROC = {metrics["auroc"]:.3f}')
+ plt.plot(fpr, tpr, color=BLUE, linewidth=2, label=f"AUROC = {metrics['auroc']:.3f}")
plt.plot(
[0, 1],
[0, 1],
@@ -229,7 +231,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir):
metrics["precision"],
color=RED,
linewidth=2,
- label=f'AUPRC = {metrics["auprc"]:.3f}',
+ label=f"AUPRC = {metrics['auprc']:.3f}",
)
# Note: We're removing the baseline from the PR curve as requested
@@ -251,7 +253,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir):
# 3. Combined figure (side by side)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) # Plot ROC curve
- ax1.plot(fpr, tpr, color=BLUE, linewidth=2, label=f'AUROC = {metrics["auroc"]:.3f}')
+ ax1.plot(fpr, tpr, color=BLUE, linewidth=2, label=f"AUROC = {metrics['auroc']:.3f}")
ax1.plot(
[0, 1],
[0, 1],
@@ -274,7 +276,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir):
metrics["precision"],
color=RED,
linewidth=2,
- label=f'AUPRC = {metrics["auprc"]:.3f}',
+ label=f"AUPRC = {metrics['auprc']:.3f}",
)
ax2.set_xlabel("Recall")
ax2.set_ylabel("Precision")
@@ -1418,9 +1420,9 @@ def create_grid_plot(merged_df, n_grid, data_dir, plots_dir, iteration, suffix,
suffix: Suffix for the output filename
fig_title: Optional title for the figure
"""
- from PIL import Image
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
+ from PIL import Image
# Number of user score bins (fewer than ML score bins)
n_user_grid = 6
@@ -1863,7 +1865,7 @@ def create_grid_plot(merged_df, n_grid, data_dir, plots_dir, iteration, suffix,
# Add a legend for the visual indicators in the bottom-left corner
ax_legend = plt.subplot(gs[n_user_grid + 1, 0])
- legend_text = "+/-n: AM ranks\n" "higher/lower\n" "score than GZ\n"
+ legend_text = "+/-n: AM ranks\nhigher/lower\nscore than GZ\n"
ax_legend.text(
0.35,
0.5,
diff --git a/paper_scripts/paper_utils.py b/paper_scripts/paper_utils.py
index 9aabffc..de4ecba 100644
--- a/paper_scripts/paper_utils.py
+++ b/paper_scripts/paper_utils.py
@@ -6,16 +6,15 @@
# the terms contained in the file 'LICENCE.txt'.
import argparse
import os
-
+import sys
from pathlib import Path
+
+import h5py
import ipywidgets as widgets
import numpy as np
import pandas as pd
-import h5py
from loguru import logger
-from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
-
-import sys
+from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
sys.path.append("/media/home/AnomalyMatch")
sys.path.append("../")
@@ -478,8 +477,9 @@ def train_with_progress_bar(session, cfg):
session (am.Session): The AnomalyMatch session
cfg (DotMap): Configuration for training
"""
- import time
import datetime
+ import time
+
from tqdm import tqdm
# Create a tqdm progress bar
@@ -566,8 +566,8 @@ def save_plot_data(data, plot_type, iteration, output_dir):
iteration: Iteration number (0 for baseline)
output_dir: Directory to save data to
"""
- import pickle
import os
+ import pickle
# Create plot_data directory if it doesn't exist
plot_data_dir = os.path.join(output_dir, "plot_data")
diff --git a/paper_scripts/prepare_datasets.py b/paper_scripts/prepare_datasets.py
index 8e23371..0a6e7a8 100644
--- a/paper_scripts/prepare_datasets.py
+++ b/paper_scripts/prepare_datasets.py
@@ -22,19 +22,20 @@
- GalaxyZoo images and training_solutions_am_2.csv in datasets/ folder
"""
-import os
-import io
import argparse
+import concurrent.futures
+import gc # Add garbage collection
+import io
+import os
+
+import h5py
import numpy as np
import pandas as pd
-import h5py
-from PIL import Image
+import pyarrow.parquet as pq
import torch
-from tqdm import tqdm
from loguru import logger
-import concurrent.futures
-import pyarrow.parquet as pq
-import gc # Add garbage collection
+from PIL import Image
+from tqdm import tqdm
# Configure basic logging
logger.remove()
diff --git a/paper_scripts/recreate_plots.py b/paper_scripts/recreate_plots.py
index 1aae0c4..80f812c 100644
--- a/paper_scripts/recreate_plots.py
+++ b/paper_scripts/recreate_plots.py
@@ -18,42 +18,43 @@
Date: April 14, 2025
"""
-import os
-import sys
import argparse
import glob
+import os
import pickle
+import sys
+from pathlib import Path
+
+import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
-import matplotlib.pyplot as plt
import seaborn as sns
-from pathlib import Path
from loguru import logger
# Add parent directory to path to import paper_plots
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from paper_scripts.paper_plots import (
- plot_score_histogram,
- plot_roc_prc_curves,
- plot_top_n_anomaly_detection,
- plot_metrics_over_time,
+ FONT_SCALE,
+ plot_astronomaly_comparison,
plot_combined_anomaly_detection,
- plot_top_n_with_thresholds,
+ plot_metrics_over_time,
+ plot_roc_prc_curves,
plot_roc_with_thresholds,
- plot_astronomaly_comparison,
+ plot_score_histogram,
plot_score_vs_user_score_grid,
- FONT_SCALE,
+ plot_top_n_anomaly_detection,
+ plot_top_n_with_thresholds,
)
from paper_scripts.plot_colors import (
BLUE,
ORANGE,
- PURPLE,
+ PERFECT_LINE_ALPHA,
PERFECT_LINE_COLOR,
PERFECT_LINE_STYLE,
- PERFECT_LINE_ALPHA,
+ PURPLE,
+ VLINE_ALPHA,
VLINE_COLOR,
VLINE_STYLE,
- VLINE_ALPHA,
)
# Set matplotlib parameters for publication-quality plots (similar to paper_plots.py)
diff --git a/paper_scripts/results_analysis.py b/paper_scripts/results_analysis.py
index b709a4f..113a6f5 100644
--- a/paper_scripts/results_analysis.py
+++ b/paper_scripts/results_analysis.py
@@ -4,9 +4,9 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pandas as pd
from pathlib import Path
+import pandas as pd
MINIIMAGENET_CLASS_NAMES = {48: "Guitar", 57: "Hourglass", 68: "Printer", 85: "Piano", 95: "Orange"}
diff --git a/paper_scripts/test_plots.py b/paper_scripts/test_plots.py
index 47e57a9..b22f3ee 100644
--- a/paper_scripts/test_plots.py
+++ b/paper_scripts/test_plots.py
@@ -13,22 +13,23 @@
python test_plots.py --reload # Test loading and recreating plots from saved data
"""
-import os
import argparse
+import os
+
import numpy as np
import pandas as pd
-from PIL import Image
from paper_plots import (
- plot_score_histogram,
+ plot_astronomaly_comparison,
+ plot_combined_anomaly_detection,
plot_metrics_over_time,
plot_roc_prc_curves,
- plot_top_n_anomaly_detection,
- plot_combined_anomaly_detection,
- plot_astronomaly_comparison,
plot_roc_with_thresholds,
- plot_top_n_with_thresholds,
+ plot_score_histogram,
plot_score_vs_user_score_grid,
+ plot_top_n_anomaly_detection,
+ plot_top_n_with_thresholds,
)
+from PIL import Image
# Create output directory for test plots
output_dir = "test_plots_output"
@@ -304,9 +305,9 @@ def test_all_plots():
def test_reload_plots(plot_data_dir):
"""Test loading saved plot data and recreating plots from it."""
- import pickle
- import os
import glob
+ import os
+ import pickle
# Directory for reloaded plots
reload_dir = os.path.join(output_dir, "reloaded_plots")
diff --git a/prediction_process.py b/prediction_process.py
index 5efb448..28417c8 100644
--- a/prediction_process.py
+++ b/prediction_process.py
@@ -5,32 +5,27 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
import argparse
-import os
-import pickle
-import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
-from dotmap import DotMap
-import torch
import numpy as np
+import torch
from loguru import logger
-from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
-import time
from anomaly_match.data_io.load_images import (
load_and_process_single_wrapper,
)
-
+from anomaly_match.image_processing.transforms import (
+ get_prediction_transforms,
+)
from prediction_utils import (
+ clear_gpu_cache_if_needed,
load_model,
- save_results,
+ load_prediction_config,
process_batch_predictions,
- estimate_batch_size,
- clear_gpu_cache_if_needed,
-)
-
-from anomaly_match.image_processing.transforms import (
- get_prediction_transforms,
+ save_results,
+ setup_prediction_logging,
)
@@ -55,6 +50,12 @@ def evaluate_files(file_list, cfg, top_n=1000, batch_size=1000, max_workers=1):
"""Evaluate files in batches and return top N scores.
file list is a list of cfg.prediction_search_dir+filename
"""
+ if not file_list:
+ raise FileNotFoundError(
+ f"No files to evaluate. The prediction search directory "
+ f"'{cfg.prediction_search_dir}' is empty or contains no supported image files."
+ )
+
logger.trace(f"{len(file_list)} unlabeled images remain.")
# Load model first - this loads the fitsbolt config from the checkpoint
@@ -130,21 +131,7 @@ def main():
parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep")
args = parser.parse_args()
- logger.info(f"Loading config from {args.config_path}")
- # Load cfg from pkl
- try:
- with open(args.config_path, "rb") as f:
- cfg = pickle.load(f)
- cfg = DotMap(cfg)
- except Exception as e:
- logger.error(f"Failed to load config from {args.config_path}: {e}")
- sys.exit(1)
-
- logger.info("Setting batch size")
- batch_size = (
- estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction
- )
- logger.info(f"Batch size set to: {batch_size}")
+ cfg, batch_size = load_prediction_config(args.config_path)
logger.info(f"Loading file list from {args.file_list_path}")
with open(args.file_list_path, "r") as f:
@@ -167,14 +154,5 @@ def main():
if __name__ == "__main__":
- # Configure logging
- logs_dir = os.path.join(os.path.dirname(__file__), "logs")
- os.makedirs(logs_dir, exist_ok=True)
- logger.remove()
- logger.add(
- os.path.join(logs_dir, "prediction_thread_{time}.log"),
- rotation="1 MB",
- format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
- level="DEBUG",
- )
+ setup_prediction_logging("prediction_thread")
main()
diff --git a/prediction_process_cutana.py b/prediction_process_cutana.py
index 7a661e2..b248721 100644
--- a/prediction_process_cutana.py
+++ b/prediction_process_cutana.py
@@ -5,70 +5,84 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
import argparse
+import glob
import os
-import sys
-import pickle
+import time
-from dotmap import DotMap
-import torch
+import cutana
import numpy as np
+import pandas as pd
+import torch
+from cutana.catalogue_preprocessor import extract_filter_name, parse_fits_file_paths
+from dotmap import DotMap
from loguru import logger
-from concurrent.futures import ThreadPoolExecutor
-import time
from tqdm import tqdm
-import cutana
-
-from anomaly_match.data_io.load_images import process_single_wrapper
+from anomaly_match.image_processing.transforms import (
+ get_prediction_transforms,
+)
from prediction_utils import (
+ clear_gpu_cache_if_needed,
+ convert_cutana_cutout,
+ create_cutana_format_cfg,
load_model,
- save_results,
+ load_prediction_config,
process_batch_predictions,
- estimate_batch_size,
- clear_gpu_cache_if_needed,
-)
-
-from anomaly_match.image_processing.transforms import (
- get_prediction_transforms,
+ save_results,
+ setup_prediction_logging,
)
-def read_and_preprocess_image_from_zarr(image_data, cfg):
- """Read and preprocess image data from Zarr array using standardized functions."""
- try:
- # Convert Zarr data to numpy array if it's not already
- if not isinstance(image_data, np.ndarray):
- image_data = np.array(image_data)
-
- # Check if we need to transpose based on the shape
- # If last dimension is 3 (RGB channels), data is already in HWC format
- # If first dimension is 3, data is in CHW format and needs transposing
- if image_data.shape[0] == cfg.normalisation.n_output_channels:
- # In CHW format, convert to HWC
- image = image_data.transpose(1, 2, 0)
- else:
- # Assume HWC format if neither first nor last dimension is 3
- # This handles grayscale or other formats
- image = image_data
-
- # Use the centralized processing function - this handles RGB conversion,
- # normalization, and resizing efficiently without temporary files
- processed_image = process_single_wrapper(image, cfg, desc="zarr")
- return processed_image
+def _resolve_filter_names_from_catalogue(catalogue_path, n_extensions):
+ """Resolve filter/band names from the first source in a cutana catalogue.
- except Exception as e:
- logger.error(f"Error processing image from Zarr: {e}")
- raise
+ Cutana streaming predictions operate on large mosaic tiles (separate FITS
+ files per band) referenced by the catalogue's ``fits_file_paths`` column.
+ Filter names are extracted from file paths using Euclid naming conventions.
+ Raises:
+ ValueError: If filter names cannot be determined — this indicates the
+ catalogue uses non-Euclid file naming. In that case users must set
+ ``cfg.normalisation.fits_extension`` to explicit filter name strings.
+ """
+ # catalogue_path may be a single file (buffer parquet) or a directory
+ if os.path.isfile(catalogue_path):
+ first_file = catalogue_path
+ else:
+ cat_files = sorted(glob.glob(os.path.join(catalogue_path, "*.parquet")))
+ if not cat_files:
+ cat_files = sorted(glob.glob(os.path.join(catalogue_path, "*.csv")))
+ if not cat_files:
+ raise FileNotFoundError(f"No catalogue files found in {catalogue_path}")
+ first_file = cat_files[0]
+
+ # Read just the first row to get fits_file_paths
+ if first_file.endswith(".parquet"):
+ df = pd.read_parquet(first_file, columns=["fits_file_paths"]).head(1)
+ else:
+ df = pd.read_csv(first_file, usecols=["fits_file_paths"], nrows=1)
+
+ fits_paths = parse_fits_file_paths(df["fits_file_paths"].iloc[0])
+ if len(fits_paths) != n_extensions:
+ raise ValueError(
+ f"Catalogue has {len(fits_paths)} FITS files per source but "
+ f"cfg.normalisation.fits_extension specifies {n_extensions} extensions."
+ )
-def load_and_preprocess_zarr(args):
- """Load and preprocess a single image from Zarr.
+ filter_names = [extract_filter_name(p) for p in fits_paths]
+ unknown = [p for p, name in zip(fits_paths, filter_names) if name == "UNKNOWN"]
+ if unknown:
+ raise ValueError(
+ f"Could not determine filter names from catalogue FITS file paths. "
+ f"Cutana streaming predictions currently only support Euclid data with "
+ f"standard file naming conventions (VIS, NIR-H, NIR-Y, NIR-J). "
+ f"Unrecognised files: {unknown}. "
+ f"If using non-Euclid data, set cfg.normalisation.fits_extension to "
+ f"explicit filter name strings instead of integer indices."
+ )
- Note: Returns numpy array, not tensor. Tensor conversion is done on main
- thread to avoid CUDA context issues in ThreadPoolExecutor.
- """
- image_data, cfg = args
- return read_and_preprocess_image_from_zarr(image_data, cfg)
+ logger.info(f"Resolved filter names from catalogue: {filter_names}")
+ return filter_names
def evaluate_images_from_cutana(
@@ -81,37 +95,104 @@ def evaluate_images_from_cutana(
cutana_config.target_resolution = cfg.normalisation.image_size[0]
cutana_config.source_catalogue = cutana_sources_path
- # Configure FITS extensions from AM config, default to PRIMARY if not specified
- # fits_extension can be: None, str/int, list of str/int, or list of tuples (name, ext_type)
+ # Configure FITS extensions for cutana.
+ #
+ # AnomalyMatch's fits_extension uses integer HDU indices (for multi-extension
+ # cutout files loaded by fitsbolt). Cutana operates on large mosaic tiles
+ # referenced in the catalogue — each source has separate FITS files per band.
+ # Cutana identifies bands by filter name (e.g. "VIS", "NIR-H") extracted from
+ # the file paths, so we must resolve integer indices to filter names here.
+ #
+ # NOTE: filter name extraction currently relies on Euclid naming conventions
+ # (via cutana.catalogue_preprocessor.extract_filter_name). If your catalogue
+ # uses non-Euclid file naming, set cfg.normalisation.fits_extension to
+ # explicit filter name strings instead of integer indices.
fits_ext = cfg.normalisation.fits_extension
if fits_ext is None:
fits_ext = ["PRIMARY"]
elif isinstance(fits_ext, (str, int)):
fits_ext = [fits_ext]
- # Build selected_extensions - handle both simple names and (name, ext_type) tuples
- selected_extensions = []
- extension_names = []
- for ext in fits_ext:
- if isinstance(ext, tuple):
- name, ext_type = ext
- selected_extensions.append({"name": str(name), "ext": ext_type})
- extension_names.append(name)
+ # When fits_extension contains integers, resolve to filter names from the
+ # catalogue's fits_file_paths column.
+ has_integer_indices = any(isinstance(e, int) for e in fits_ext)
+ if has_integer_indices:
+ if len(fits_ext) > 1:
+ extension_names = _resolve_filter_names_from_catalogue(
+ cutana_sources_path, len(fits_ext)
+ )
else:
- selected_extensions.append({"name": str(ext), "ext": "PrimaryHDU"})
- extension_names.append(ext)
+ # Single integer index (e.g. [0]) maps to the PRIMARY HDU
+ extension_names = ["PRIMARY"]
+ else:
+ extension_names = [str(e) for e in fits_ext]
+
+ # Build selected_extensions for cutana.
+ # For multi-file catalogues (separate FITS per band), each file has only a
+ # PRIMARY HDU, so fits_extensions must be ["PRIMARY"]. The filter names go
+ # into channel_weights and selected_extensions for channel identification.
+ if has_integer_indices:
+ cutana_config.fits_extensions = ["PRIMARY"]
+ else:
+ cutana_config.fits_extensions = extension_names
- cutana_config.fits_extensions = extension_names
+ selected_extensions = []
+ for name in extension_names:
+ selected_extensions.append({"name": name, "ext": "PRIMARY"})
cutana_config.selected_extensions = selected_extensions
- # Pass channel combination - required for multi-extension data
+ # Build channel_weights dict for cutana from AM's channel configuration.
+ # Cutana expects {"ext_name": [weight_per_output_channel, ...], ...}.
+ # Channel combination must happen BEFORE normalisation (cutana's pipeline
+ # ensures this) so that ZSCALE/ASINH see the same data shape as training.
+ n_out = cfg.normalisation.n_output_channels
if cfg.normalisation.channel_combination is not None:
- cutana_config.channel_weights = cfg.normalisation.channel_combination
+ # Multi-extension: convert numpy matrix (n_out x n_in) to cutana dict
+ combo = cfg.normalisation.channel_combination
+ channel_weights = {}
+ for j, ext_name in enumerate(extension_names):
+ channel_weights[str(ext_name)] = combo[:, j].tolist()
+ cutana_config.channel_weights = channel_weights
elif len(fits_ext) > 1:
raise ValueError(
"cfg.normalisation.channel_combination must be set when using multiple FITS extensions. "
"This defines how extensions are combined into RGB channels."
)
+ else:
+ # Single extension: replicate to n_output_channels (e.g. 1→3 for RGB)
+ cutana_config.channel_weights = {str(extension_names[0]): [1.0] * n_out}
+
+ # Verify channel configuration consistency
+ if len(extension_names) > 1:
+ combo = cfg.normalisation.channel_combination
+ n_in = combo.shape[1] if hasattr(combo, "shape") else len(extension_names)
+ if len(extension_names) != n_in:
+ raise ValueError(
+ f"Number of resolved filter names ({len(extension_names)}) does not match "
+ f"channel_combination input dimension ({n_in}). "
+ f"Filter names: {extension_names}, matrix shape: {combo.shape}"
+ )
+ if combo.shape[0] != n_out:
+ raise ValueError(
+ f"channel_combination output dimension ({combo.shape[0]}) does not match "
+ f"n_output_channels ({n_out})"
+ )
+ # For non-diagonal matrices, verify all input channels contribute
+ # (a zero column means an extension is loaded but never used)
+ for j, ext_name in enumerate(extension_names):
+ col_sum = abs(combo[:, j]).sum()
+ if col_sum == 0:
+ logger.warning(
+ f"Extension '{ext_name}' (column {j}) has zero weight in "
+ f"channel_combination — this channel will be loaded but ignored"
+ )
+ logger.info(
+ f"Channel configuration: {len(extension_names)} inputs -> {n_out} outputs, "
+ f"filter order: {extension_names}"
+ )
+
+ # Flux conversion: must match the training path setting
+ cutana_config.apply_flux_conversion = cfg.normalisation.apply_flux_conversion
# Pass AnomalyMatch's fitsbolt_cfg directly to cutana for normalization
# This ensures cutana uses the exact same normalization settings as training
@@ -142,7 +223,7 @@ def evaluate_images_from_cutana(
model = load_model(cfg)
model.eval()
- transform = get_prediction_transforms()
+ transform = get_prediction_transforms(num_channels=n_out)
# Process images in batches
scores_list = []
@@ -164,13 +245,15 @@ def evaluate_images_from_cutana(
)
logger.debug("Using fitsbolt config loaded from model checkpoint")
+ # CONVERSION_ONLY config for format conversion (created once, reused per cutout)
+ format_cfg = create_cutana_format_cfg(cfg)
+
batches_count = cutana_orchestrator.get_batch_count()
num_images = 0
filenames = []
for batch_idx in tqdm(range(batches_count), desc="Processing batches"):
-
loaded_batch = cutana_orchestrator.next_batch()
batch_data = loaded_batch["cutouts"]
@@ -194,12 +277,12 @@ def evaluate_images_from_cutana(
batch_filenames = (source["source_id"] for source in loaded_batch["metadata"])
filenames.extend(batch_filenames)
- # I/O and preprocessing in ThreadPool (returns numpy arrays)
- # CUDA operations are kept on main thread to prevent memory fragmentation
+ # Cutana already normalised the cutouts via external_fitsbolt_cfg.
+ # Only format conversion is needed (CHW→HWC, dtype, channel replication).
batch_process_start = time.time()
- with ThreadPoolExecutor(max_workers=max_workers) as executor:
- batch_args = [(batch_data[i], cfg) for i in range(batch_size_actual)]
- numpy_images = list(executor.map(load_and_preprocess_zarr, batch_args))
+ numpy_images = [
+ convert_cutana_cutout(batch_data[i], format_cfg) for i in range(batch_size_actual)
+ ]
# Tensor conversion on main thread (not in ThreadPool) to avoid CUDA context issues
stack_start = time.time()
@@ -261,36 +344,7 @@ def main():
parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep")
args = parser.parse_args()
- logger.info(f"Loading config from {args.config_path}")
- # Load cfg from pkl
- try:
- with open(args.config_path, "rb") as f:
- cfg = pickle.load(f)
- cfg = DotMap(cfg)
- except Exception as e:
- logger.error(f"Failed to load config from {args.config_path}: {e}")
- sys.exit(1)
-
- logger.info("Setting batch size")
- batch_size = (
- estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction
- )
- logger.info(f"Batch size set to: {batch_size}")
-
- # Log key configuration parameters
- logger.debug("Configuration loaded with parameters:")
- logger.debug(f" Save file: {cfg.save_file}")
- logger.debug(f" Save path: {cfg.save_path}")
- logger.debug(f" Model path: {cfg.model_path}")
- logger.debug(f" Output directory: {cfg.output_dir}")
- logger.debug(f" Image size: {cfg.normalisation.image_size}")
-
- # Log full configuration
- logger.debug("Full configuration:")
- logger.debug(f"{cfg.toDict()}")
-
- # Create output directory if it doesn't exist
- os.makedirs(cfg.output_dir, exist_ok=True)
+ cfg, batch_size = load_prediction_config(args.config_path)
logger.info(f"Streaming from directory: {args.cutana_sources_path}")
@@ -306,18 +360,5 @@ def main():
if __name__ == "__main__":
-
- # Configure logging
- logs_dir = os.path.join(os.path.dirname(__file__), "logs")
- os.makedirs(logs_dir, exist_ok=True)
-
- # Remove default handler and set up file logging
- logger.remove()
- script_logger_id = logger.add(
- os.path.join(logs_dir, "prediction_cutana_{time}.log"),
- rotation="1 MB",
- format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
- level="DEBUG",
- )
- logger.add(sys.stderr, level="INFO")
+ setup_prediction_logging("prediction_cutana")
main()
diff --git a/prediction_process_hdf5.py b/prediction_process_hdf5.py
index de87f33..112c744 100644
--- a/prediction_process_hdf5.py
+++ b/prediction_process_hdf5.py
@@ -5,32 +5,27 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
import argparse
-import os
-import pickle
-import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
-from dotmap import DotMap
-import torch
+import h5py
import numpy as np
+import torch
from loguru import logger
-from concurrent.futures import ThreadPoolExecutor
-import time
-import h5py
from tqdm import tqdm
from anomaly_match.data_io.load_images import process_single_wrapper
-
+from anomaly_match.image_processing.transforms import (
+ get_prediction_transforms,
+)
from prediction_utils import (
- load_model,
- save_results,
- process_batch_predictions,
clear_gpu_cache_if_needed,
jpeg_decoder,
- estimate_batch_size,
-)
-
-from anomaly_match.image_processing.transforms import (
- get_prediction_transforms,
+ load_model,
+ load_prediction_config,
+ process_batch_predictions,
+ save_results,
+ setup_prediction_logging,
)
@@ -48,9 +43,10 @@ def read_and_decode_image_from_hdf5(image_data, cfg):
image = image[:, :, [2, 1, 0]]
except Exception:
# If TurboJPEG fails, fall back to PIL
- from PIL import Image
import io
+ from PIL import Image
+
image = np.array(Image.open(io.BytesIO(image_bytes)))
processed_image = process_single_wrapper(image, cfg, desc="hdf5")
@@ -172,32 +168,7 @@ def main():
parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep")
args = parser.parse_args()
- logger.info(f"Loading config from {args.config_path}")
- # Load cfg from pkl
- try:
- with open(args.config_path, "rb") as f:
- cfg = pickle.load(f)
- cfg = DotMap(cfg)
- except Exception as e:
- logger.error(f"Failed to load config from {args.config_path}: {e}")
- sys.exit(1)
-
- logger.info("Setting batch size")
- batch_size = (
- estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction
- )
- logger.info(f"Batch size set to: {batch_size}")
-
- # Log key configuration parameters
- logger.debug("Configuration loaded with parameters:")
- logger.debug(f" Save file: {cfg.save_file}")
- logger.debug(f" Save path: {cfg.save_path}")
- logger.debug(f" Model path: {cfg.model_path}")
- logger.debug(f" Output directory: {cfg.output_dir}")
- logger.debug(f" Image size: {cfg.normalisation.image_size}")
-
- # Create output directory if it doesn't exist
- os.makedirs(cfg.output_dir, exist_ok=True)
+ cfg, batch_size = load_prediction_config(args.config_path)
logger.info(f"Processing HDF5 file: {args.hdf5_path}")
@@ -211,18 +182,5 @@ def main():
if __name__ == "__main__":
-
- # Configure logging
- logs_dir = os.path.join(os.path.dirname(__file__), "logs")
- os.makedirs(logs_dir, exist_ok=True)
-
- # Remove default handler and set up file logging
- logger.remove()
- logger.add(
- os.path.join(logs_dir, "prediction_thread_{time}.log"),
- rotation="1 MB",
- format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
- level="DEBUG",
- )
- logger.add(sys.stderr, level="INFO")
+ setup_prediction_logging("prediction_hdf5")
main()
diff --git a/prediction_process_zarr.py b/prediction_process_zarr.py
index e375b32..f4a1c26 100644
--- a/prediction_process_zarr.py
+++ b/prediction_process_zarr.py
@@ -5,74 +5,31 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
import argparse
-import os
-import sys
-import pickle
+import time
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
-from dotmap import DotMap
-import torch
import numpy as np
-from loguru import logger
-from concurrent.futures import ThreadPoolExecutor
-import time
-import zarr
import pandas as pd
-from pathlib import Path
+import torch
+import zarr
+from loguru import logger
from tqdm import tqdm
-from anomaly_match.data_io.load_images import process_single_wrapper
-
+from anomaly_match.image_processing.transforms import (
+ get_prediction_transforms,
+)
from prediction_utils import (
+ clear_gpu_cache_if_needed,
+ load_and_preprocess_zarr,
load_model,
- save_results,
+ load_prediction_config,
process_batch_predictions,
- estimate_batch_size,
- clear_gpu_cache_if_needed,
-)
-
-from anomaly_match.image_processing.transforms import (
- get_prediction_transforms,
+ save_results,
+ setup_prediction_logging,
)
-def read_and_preprocess_image_from_zarr(image_data, cfg):
- """Read and preprocess image data from Zarr array using standardized functions."""
- try:
- # Convert Zarr data to numpy array if it's not already
- if not isinstance(image_data, np.ndarray):
- image_data = np.array(image_data)
-
- # Check if we need to transpose based on the shape
- # If last dimension is 3 (RGB channels), data is already in HWC format
- # If first dimension is 3, data is in CHW format and needs transposing
- if image_data.shape[0] == cfg.normalisation.n_output_channels:
- # In CHW format, convert to HWC
- image = image_data.transpose(1, 2, 0)
- else:
- # Assume HWC format if neither first nor last dimension is 3
- # This handles grayscale or other formats
- image = image_data
-
- # Use the centralized processing function - this handles RGB conversion,
- # normalization, and resizing efficiently without temporary files
- processed_image = process_single_wrapper(image, cfg, desc="zarr")
- return processed_image
-
- except Exception as e:
- logger.error(f"Error processing image from Zarr: {e}")
- raise
-
-
-def load_and_preprocess_zarr(args):
- """Load and preprocess a single image from Zarr.
-
- Note: Returns numpy array, not tensor. Tensor conversion is done on main
- thread to avoid CUDA context issues in ThreadPoolExecutor.
- """
- image_data, cfg = args
- return read_and_preprocess_image_from_zarr(image_data, cfg)
-
-
def evaluate_images_in_zarr(zarr_path, cfg, top_n=1000, batch_size=1000, max_workers=4):
"""Evaluate images inside a Zarr file and return top N scores."""
logger.info(f"Opening Zarr file {zarr_path}")
@@ -244,36 +201,7 @@ def main():
parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep")
args = parser.parse_args()
- logger.info(f"Loading config from {args.config_path}")
- # Load cfg from pkl
- try:
- with open(args.config_path, "rb") as f:
- cfg = pickle.load(f)
- cfg = DotMap(cfg)
- except Exception as e:
- logger.error(f"Failed to load config from {args.config_path}: {e}")
- sys.exit(1)
-
- logger.info("Setting batch size")
- batch_size = (
- estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction
- )
- logger.info(f"Batch size set to: {batch_size}")
-
- # Log key configuration parameters
- logger.debug("Configuration loaded with parameters:")
- logger.debug(f" Save file: {cfg.save_file}")
- logger.debug(f" Save path: {cfg.save_path}")
- logger.debug(f" Model path: {cfg.model_path}")
- logger.debug(f" Output directory: {cfg.output_dir}")
- logger.debug(f" Image size: {cfg.normalisation.image_size}")
-
- # Log full configuration
- logger.debug("Full configuration:")
- logger.debug(f"{cfg.toDict()}")
-
- # Create output directory if it doesn't exist
- os.makedirs(cfg.output_dir, exist_ok=True)
+ cfg, batch_size = load_prediction_config(args.config_path)
logger.info(f"Processing Zarr file: {args.zarr_path}")
@@ -287,17 +215,5 @@ def main():
if __name__ == "__main__":
- # Configure logging
- logs_dir = os.path.join(os.path.dirname(__file__), "logs")
- os.makedirs(logs_dir, exist_ok=True)
-
- # Remove default handler and set up file logging
- logger.remove()
- logger.add(
- os.path.join(logs_dir, "prediction_zarr_{time}.log"),
- rotation="1 MB",
- format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
- level="DEBUG",
- )
- logger.add(sys.stderr, level="INFO")
+ setup_prediction_logging("prediction_zarr")
main()
diff --git a/prediction_utils.py b/prediction_utils.py
index 55226d0..6ce6348 100644
--- a/prediction_utils.py
+++ b/prediction_utils.py
@@ -13,13 +13,19 @@
"""
import os
-import torch
+import pickle
+import sys
+
import numpy as np
import pandas as pd
-
+import torch
+from dotmap import DotMap
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
from loguru import logger
from turbojpeg import TurboJPEG
+from anomaly_match.data_io.load_images import get_fitsbolt_config, process_single_wrapper
+from anomaly_match.utils.get_default_cfg import get_default_cfg
# Initialize TurboJPEG
jpeg_decoder = TurboJPEG()
@@ -124,13 +130,7 @@ def estimate_batch_size(
# usable_vram - c = B * (a * S² + b)
# B = (usable_vram - c) / (a * S² + b)
S2 = cfg.normalisation.image_size[0] * cfg.normalisation.image_size[1]
- # Use num_channels if set, otherwise fall back to normalisation.n_output_channels
- num_channels = (
- cfg.num_channels
- if isinstance(cfg.num_channels, int)
- else cfg.normalisation.n_output_channels
- )
- denominator = coef["a"] * S2 * num_channels + coef["b"]
+ denominator = coef["a"] * S2 * cfg.num_channels + coef["b"]
if denominator <= 0:
logger.warning("Invalid memory model parameters, returning minimum batch size")
@@ -174,18 +174,12 @@ def load_model(cfg):
from anomaly_match.utils.get_net_builder import get_net_builder
- # Use num_channels if set, otherwise fall back to normalisation.n_output_channels
- num_channels = (
- cfg.num_channels
- if isinstance(cfg.num_channels, int)
- else cfg.normalisation.n_output_channels
- )
net_builder = get_net_builder(
cfg.net,
pretrained=cfg.pretrained,
- in_channels=num_channels,
+ in_channels=cfg.num_channels,
)
- model = net_builder(num_classes=2, in_channels=num_channels)
+ model = net_builder(num_classes=2, in_channels=cfg.num_channels)
if torch.cuda.is_available():
gpu_device = getattr(cfg, "gpu", 0) # Default to 0 if not set
@@ -449,6 +443,174 @@ def _ensure_consistent_image_format(images):
return images
+def setup_prediction_logging(log_name):
+ """Set up logging for prediction scripts.
+
+ Configures file logging with rotation and stderr output. Also adds
+ session-specific logging if a config path is available in sys.argv.
+
+ Args:
+ log_name: Name used for the log file (e.g. "prediction_thread",
+ "prediction_zarr", "prediction_cutana").
+ """
+ logs_dir = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), "logs")
+ os.makedirs(logs_dir, exist_ok=True)
+
+ logger.remove()
+ logger.add(
+ os.path.join(logs_dir, f"{log_name}_{{time}}.log"),
+ rotation="1 MB",
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
+ level="DEBUG",
+ )
+ logger.add(sys.stderr, level="INFO")
+
+ # Also log to session output directory if available
+ if len(sys.argv) > 1:
+ try:
+ with open(sys.argv[1], "rb") as _f:
+ _pre_cfg = DotMap(pickle.load(_f))
+ if _pre_cfg.output_dir:
+ os.makedirs(_pre_cfg.output_dir, exist_ok=True)
+ logger.add(
+ os.path.join(_pre_cfg.output_dir, "prediction.log"),
+ rotation="10 MB",
+ format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}",
+ level="DEBUG",
+ )
+ except Exception:
+ pass
+
+
+def load_prediction_config(config_path):
+ """Load prediction config from pickle file and compute batch size.
+
+ Args:
+ config_path: Path to the pickled config file.
+
+ Returns:
+ tuple: (cfg, batch_size) where cfg is a DotMap config object
+ and batch_size is the computed or configured batch size.
+ """
+ logger.info(f"Loading config from {config_path}")
+ try:
+ with open(config_path, "rb") as f:
+ cfg = pickle.load(f)
+ cfg = DotMap(cfg)
+ except Exception as e:
+ logger.error(f"Failed to load config from {config_path}: {e}")
+ sys.exit(1)
+
+ logger.info("Setting batch size")
+ batch_size = (
+ estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction
+ )
+ logger.info(f"Batch size set to: {batch_size}")
+
+ # Log key configuration parameters
+ logger.debug("Configuration loaded with parameters:")
+ logger.debug(f" Save file: {cfg.save_file}")
+ logger.debug(f" Save path: {cfg.save_path}")
+ logger.debug(f" Model path: {cfg.model_path}")
+ logger.debug(f" Output directory: {cfg.output_dir}")
+ logger.debug(f" Image size: {cfg.normalisation.image_size}")
+
+ # Log full configuration
+ logger.debug("Full configuration:")
+ logger.debug(f"{cfg.toDict()}")
+
+ # Create output directory if it doesn't exist
+ os.makedirs(cfg.output_dir, exist_ok=True)
+
+ return cfg, batch_size
+
+
+def create_cutana_format_cfg(cfg):
+ """Create a CONVERSION_ONLY fitsbolt config for cutana format conversion.
+
+ This config is used by convert_cutana_cutout to handle dtype and channel
+ conversion without re-applying normalisation. Callers should create this
+ once and pass it to convert_cutana_cutout for each image.
+ """
+ format_cfg = get_default_cfg()
+ format_cfg.normalisation.image_size = cfg.normalisation.image_size
+ format_cfg.normalisation.n_output_channels = cfg.normalisation.n_output_channels
+ format_cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY
+ format_cfg.normalisation.norm_asinh_scale = cfg.normalisation.norm_asinh_scale
+ format_cfg.normalisation.norm_asinh_clip = cfg.normalisation.norm_asinh_clip
+ format_cfg.num_workers = 0
+ return get_fitsbolt_config(format_cfg)
+
+
+def convert_cutana_cutout(image_data, format_cfg):
+ """Convert a cutana-normalised cutout to the format expected by the model.
+
+ Cutana already applies normalisation via external_fitsbolt_cfg, so this
+ function only performs format conversion (dtype, channel replication) using
+ fitsbolt's CONVERSION_ONLY mode — no normalisation stretch is applied.
+
+ Args:
+ image_data: Cutana cutout array (already normalised).
+ format_cfg: CONVERSION_ONLY config from create_cutana_format_cfg.
+
+ Returns:
+ np.ndarray: Image in HWC uint8 format ready for model inference.
+ """
+ if not isinstance(image_data, np.ndarray):
+ image_data = np.array(image_data)
+
+ # CHW → HWC (cutana may deliver CHW depending on config)
+ if image_data.ndim == 3 and image_data.shape[0] <= 4 and image_data.shape[2] > 4:
+ image_data = image_data.transpose(1, 2, 0)
+
+ # Delegate dtype conversion and channel replication to fitsbolt via
+ # process_single_wrapper with a CONVERSION_ONLY config so the already-
+ # normalised pixel values are preserved (only dtype + channels change).
+ return process_single_wrapper(image_data, format_cfg, desc="cutana")
+
+
+def read_and_preprocess_image_from_zarr(image_data, cfg):
+ """Read and preprocess raw image data from a Zarr array.
+
+ Handles CHW/HWC format detection and uses fitsbolt for normalization.
+ Used by the zarr prediction script for data that has NOT been normalised yet.
+ """
+ try:
+ # Convert Zarr data to numpy array if it's not already
+ if not isinstance(image_data, np.ndarray):
+ image_data = np.array(image_data)
+
+ # Check if we need to transpose based on the shape
+ # If last dimension is 3 (RGB channels), data is already in HWC format
+ # If first dimension is 3, data is in CHW format and needs transposing
+ if image_data.shape[0] == cfg.normalisation.n_output_channels:
+ # In CHW format, convert to HWC
+ image = image_data.transpose(1, 2, 0)
+ else:
+ # Assume HWC format if neither first nor last dimension is 3
+ # This handles grayscale or other formats
+ image = image_data
+
+ # Use the centralized processing function - this handles RGB conversion,
+ # normalization, and resizing efficiently without temporary files
+ processed_image = process_single_wrapper(image, cfg, desc="zarr")
+ return processed_image
+
+ except Exception as e:
+ logger.error(f"Error processing image from Zarr: {e}")
+ raise
+
+
+def load_and_preprocess_zarr(args):
+ """Load and preprocess a single image from Zarr.
+
+ Note: Returns numpy array, not tensor. Tensor conversion is done on main
+ thread to avoid CUDA context issues in ThreadPoolExecutor.
+ """
+ image_data, cfg = args
+ return read_and_preprocess_image_from_zarr(image_data, cfg)
+
+
def process_batch_predictions(model, images, original_images=None):
"""Process a batch of images through the model to get anomaly scores.
diff --git a/pyproject.toml b/pyproject.toml
index 031d175..3e25184 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "anomaly_match"
-version = "1.2.0"
+version = "1.3.0"
description = "A tool for anomaly detection in images using semi-supervised and active learning with a GUI"
readme = "README.md"
license = { file = "LICENSE.txt" }
@@ -32,51 +32,66 @@ classifiers = [
dependencies = [
"albumentations",
"astropy",
+ "cutana>=0.2.1",
"dotmap",
- "efficientnet-pytorch",
- "efficientnet_lite_pytorch",
- "efficientnet_lite0_pytorch_model",
+ "fitsbolt>=0.2",
"h5py",
"ipykernel",
"ipywidgets",
- "imageio",
"loguru",
"matplotlib",
"numpy",
"opencv-python-headless",
"pandas<3",
+ "psutil",
+ "pyarrow",
"pyturbojpeg",
- "scikit-learn",
"scikit-image",
- "fitsbolt>=0.1.6",
+ "scikit-learn",
+ "scipy",
+ "timm",
"toml",
- "torch>=2.6",
+ "torch",
"torchvision",
"tqdm",
"zarr>=3.0.0b0",
- "cutana>=0.2.1",
]
[project.optional-dependencies]
-dev = ["pytest", "pytest-cov", "black", "flake8", "mypy", "vulture>=2.10"]
+dev = ["pytest", "pytest-asyncio", "pytest-cov", "ruff", "mypy", "vulture>=2.10"]
-[tool.setuptools]
-packages = ["anomaly_match"]
+[tool.setuptools.packages.find]
+include = ["anomaly_match*"]
[tool.setuptools.package-dir]
"" = "."
-[tool.black]
+[tool.ruff]
line-length = 100
-target-version = ['py311']
+target-version = "py311"
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "I"]
+ignore = ["E402", "E203", "E501"]
+
+[tool.ruff.lint.isort]
+known-first-party = ["anomaly_match", "anomaly_match_ui"]
[tool.pytest.ini_options]
+minversion = "7.0"
testpaths = ["tests"]
+pythonpath = ["."]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
+markers = [
+ "slow: tests that train models or take >10s (skip with -m 'not slow')",
+ "ui: tests requiring ipywidgets/Jupyter",
+]
+asyncio_mode = "strict"
+asyncio_default_fixture_loop_scope = "function"
[tool.vulture]
min_confidence = 60
-paths = ["anomaly_match"]
+paths = ["anomaly_match", "anomaly_match_ui"]
exclude = ["tests/", "docs/", "examples/", "paper_scripts/"]
diff --git a/scripts/generate_4ch_test_data.py b/scripts/generate_4ch_test_data.py
new file mode 100644
index 0000000..118df14
--- /dev/null
+++ b/scripts/generate_4ch_test_data.py
@@ -0,0 +1,128 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Generate 4-channel test images with distinct per-channel content.
+
+Each channel has a different geometric pattern so multispectral images
+are visually distinguishable from grayscale:
+ Ch0: Concentric circles (radial gradient)
+ Ch1: Diagonal stripes
+ Ch2: Checkerboard pattern
+ Ch3: Gaussian blob (random position per image)
+"""
+
+import os
+
+import numpy as np
+import pandas as pd
+import tifffile
+
+
+def make_radial_gradient(size, center_x, center_y, scale=1.0):
+ """Concentric circles centered at (center_x, center_y)."""
+ y, x = np.mgrid[0:size, 0:size]
+ dist = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2)
+ return np.clip((np.cos(dist * scale * 0.1) * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
+
+
+def make_diagonal_stripes(size, frequency=0.05, angle=0.0):
+ """Diagonal stripe pattern."""
+ y, x = np.mgrid[0:size, 0:size]
+ val = np.sin(2 * np.pi * frequency * (x * np.cos(angle) + y * np.sin(angle)))
+ return np.clip((val * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
+
+
+def make_checkerboard(size, block_size=16):
+ """Checkerboard pattern."""
+ y, x = np.mgrid[0:size, 0:size]
+ checker = ((x // block_size) + (y // block_size)) % 2
+ return (checker * 255).astype(np.uint8)
+
+
+def make_gaussian_blob(size, cx, cy, sigma=20, intensity=200):
+ """Gaussian blob at (cx, cy)."""
+ y, x = np.mgrid[0:size, 0:size]
+ blob = intensity * np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma**2))
+ return np.clip(blob, 0, 255).astype(np.uint8)
+
+
+def main():
+ script_dir = os.path.dirname(os.path.abspath(__file__))
+ repo_root = os.path.dirname(script_dir)
+ ms_subdir = os.path.join(repo_root, "tests", "test_data", "multispectral_4ch")
+ os.makedirs(ms_subdir, exist_ok=True)
+
+ # Clean old files
+ for f in os.listdir(ms_subdir):
+ os.remove(os.path.join(ms_subdir, f))
+
+ size = 64
+ n_images = 10
+ np.random.seed(42)
+
+ source_names = [
+ "Abell2390_VIS_2",
+ "Abell2390_VIS_3",
+ "Abell2390_VIS_5",
+ "Abell2390_VIS_6",
+ "Abell2390_VIS_2021",
+ "Abell2390_VIS_2172",
+ "Abell2390_VIS_4989",
+ "Abell2390_VIS_5260",
+ "Abell2390_VIS_5783",
+ "Abell2390_VIS_6212",
+ ]
+
+ generated_filenames = []
+
+ for i, name in enumerate(source_names[:n_images]):
+ # Each image gets slightly different parameters for variety
+ cx = np.random.randint(size // 4, 3 * size // 4)
+ cy = np.random.randint(size // 4, 3 * size // 4)
+ angle = np.random.uniform(0, np.pi)
+ block = np.random.choice([8, 12, 16])
+
+ ch0 = make_radial_gradient(size, cx, cy, scale=1.0 + i * 0.3)
+ ch1 = make_diagonal_stripes(size, frequency=0.04 + i * 0.01, angle=angle)
+ ch2 = make_checkerboard(size, block_size=block)
+ ch3 = make_gaussian_blob(size, cx, cy, sigma=10 + i * 2)
+
+ ms_img = np.stack([ch0, ch1, ch2, ch3], axis=-1) # (H, W, 4)
+
+ filename = f"{name}_4ch.tiff"
+ tifffile.imwrite(os.path.join(ms_subdir, filename), ms_img)
+ generated_filenames.append(filename)
+ print(f" Created {filename} with shape {ms_img.shape}")
+
+ # Create labeled_data.csv - label first 6, leave 4 unlabeled
+ labeled_filenames = generated_filenames[:6]
+ labels = ["anomaly"] * 3 + ["normal"] * 3
+ df = pd.DataFrame({"filename": labeled_filenames, "label": labels})
+ df.to_csv(os.path.join(ms_subdir, "labeled_data.csv"), index=False)
+ print(f"Created labeled_data.csv with {len(labeled_filenames)} labeled entries")
+
+ # Create metadata.csv for all images
+ metadata_rows = []
+ base_ra, base_dec = 328.4034, 17.6950
+ for i, filename in enumerate(generated_filenames):
+ metadata_rows.append(
+ {
+ "filename": filename,
+ "sourceID": f"MS_{i:05d}",
+ "ra": base_ra + i * 0.002,
+ "dec": base_dec + i * 0.001,
+ "custom_metadata": f"4ch test source {i}",
+ }
+ )
+ meta_df = pd.DataFrame(metadata_rows)
+ meta_df.to_csv(os.path.join(ms_subdir, "metadata.csv"), index=False)
+ print(f"Created metadata.csv with {len(metadata_rows)} entries")
+
+ print(f"\nDone! Test data saved to: {ms_subdir}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..43e2e91
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,76 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Shared pytest fixtures for AnomalyMatch tests."""
+
+import os
+
+import pytest
+
+
+@pytest.fixture(scope="function")
+def multispectral_test_data():
+ """Provide 4-channel test images from permanent test data.
+
+ Uses pre-generated 4-channel TIFF images from tests/test_data/multispectral_4ch/
+
+ Yields:
+ tuple: (data_dir, image_filenames, labeled_data_path, num_channels)
+ - data_dir: Path to directory containing 4-channel test images
+ - image_filenames: List of image filenames
+ - labeled_data_path: Path to the labeled_data.csv file
+ - num_channels: Number of channels (4)
+ """
+ # Get path to 4-channel test images
+ test_data_dir = os.path.join(os.path.dirname(__file__), "test_data", "multispectral_4ch")
+
+ if not os.path.exists(test_data_dir):
+ pytest.skip(
+ "4-channel test data not found. Run 'python scripts/generate_4ch_test_data.py' first."
+ )
+
+ # Get all .tiff files (the format that supports 4+ channels)
+ tiff_images = [f for f in os.listdir(test_data_dir) if f.endswith(".tiff")]
+
+ if not tiff_images:
+ pytest.skip("No .tiff test images found in tests/test_data/multispectral_4ch/")
+
+ csv_path = os.path.join(test_data_dir, "labeled_data.csv")
+ if not os.path.exists(csv_path):
+ pytest.skip("labeled_data.csv not found in tests/test_data/multispectral_4ch/")
+
+ yield test_data_dir, sorted(tiff_images), csv_path, 4
+
+
+@pytest.fixture(scope="function")
+def multispectral_config(multispectral_test_data):
+ """Create a configuration for multispectral testing.
+
+ Args:
+ multispectral_test_data: The multispectral test data fixture.
+
+ Yields:
+ DotMap: Configuration object set up for 4-channel images.
+ """
+ import anomaly_match as am
+
+ data_dir, filenames, csv_path, num_channels = multispectral_test_data
+
+ cfg = am.get_default_cfg()
+ cfg.data_dir = data_dir
+ cfg.normalisation.image_size = [64, 64]
+ cfg.net = "test-cnn"
+ cfg.pretrained = False
+ cfg.num_train_iter = 2
+ cfg.num_workers = 0
+ cfg.test_ratio = 0.3
+ cfg.N_to_load = 10
+ cfg.normalisation.fits_extension = None
+ cfg.label_file = csv_path
+ cfg.seed = 42
+ # n_output_channels is auto-detected from images by AnomalyDetectionDataset
+
+ yield cfg
diff --git a/pytest.ini b/tests/e2e/conftest.py
similarity index 85%
rename from pytest.ini
rename to tests/e2e/conftest.py
index 53d01ae..5d45ca4 100644
--- a/pytest.ini
+++ b/tests/e2e/conftest.py
@@ -4,7 +4,4 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-[pytest]
-minversion = 7.0
-testpaths = tests
-pythonpath = .
\ No newline at end of file
+"""Shared fixtures for end-to-end tests."""
diff --git a/tests/e2e/test_normalisation_consistency.py b/tests/e2e/test_normalisation_consistency.py
new file mode 100644
index 0000000..7d32171
--- /dev/null
+++ b/tests/e2e/test_normalisation_consistency.py
@@ -0,0 +1,248 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests verifying that cutana's prediction pipeline normalises images identically to AM's
+training pipeline. A mismatch would cause silent model failure in production (data drift).
+
+The cutana prediction path (prediction_process_cutana.py) passes AM's fitsbolt_cfg to cutana
+via external_fitsbolt_cfg so that cutana handles normalisation. This test verifies that
+the output matches what AM's load_and_process_wrapper produces for the same raw data.
+
+Test approach:
+1. Extract raw (unnormalised) cutouts from a test FITS tile using cutana's
+ do_only_cutout_extraction mode, then save as clean FITS files.
+2. Prediction path: run cutana with external_fitsbolt_cfg + channel_weights
+ (expanding 1->n_output_channels before normalisation) -> convert_cutana_cutout
+3. Training path: load_and_process_wrapper on the same raw FITS files
+4. Compare pixel-for-pixel (+/-1 tolerance for WCS reprojection float rounding).
+"""
+
+import csv
+import os
+
+import cutana
+import numpy as np
+import pytest
+from astropy.io import fits
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
+
+from anomaly_match.data_io.load_images import (
+ get_fitsbolt_config,
+ load_and_process_wrapper,
+)
+from anomaly_match.utils.get_default_cfg import get_default_cfg
+from prediction_utils import convert_cutana_cutout, create_cutana_format_cfg
+
+# Paths to pre-generated test data
+_TEST_DATA_DIR = os.path.join(
+ os.path.dirname(__file__), os.pardir, "test_data", "normalisation_consistency"
+)
+_FITS_TILE = os.path.join(
+ _TEST_DATA_DIR,
+ "EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits",
+)
+_CSV_CATALOGUE = os.path.join(_TEST_DATA_DIR, "mock_sources.csv")
+
+_TARGET_RESOLUTION = 150
+_MAX_SOURCES = 2
+
+
+def _rewrite_csv_with_absolute_paths(csv_in, csv_out, fits_tile_abs):
+ """Rewrite the catalogue CSV with absolute paths, limited to _MAX_SOURCES."""
+ with open(csv_in) as f_in, open(csv_out, "w", newline="") as f_out:
+ reader = csv.DictReader(f_in)
+ writer = csv.DictWriter(f_out, fieldnames=reader.fieldnames)
+ writer.writeheader()
+ for i, row in enumerate(reader):
+ if i >= _MAX_SOURCES:
+ break
+ row["fits_file_paths"] = str([fits_tile_abs])
+ writer.writerow(row)
+
+
+def _make_cutana_config(csv_path, n_output_channels=3):
+ """Create a base cutana config for the test tile.
+
+ Sets channel_weights so cutana expands single-extension data to
+ n_output_channels before normalisation — matching the training path's
+ channel_combine -> normalise order in fitsbolt's _process_image.
+ """
+ config = cutana.get_default_config()
+ config.target_resolution = _TARGET_RESOLUTION
+ config.source_catalogue = csv_path
+ config.fits_extensions = ["PRIMARY"]
+ config.selected_extensions = [{"name": "PRIMARY", "ext": "PrimaryHDU"}]
+ config.apply_flux_conversion = False
+ config.channel_weights = {"PRIMARY": [1.0] * n_output_channels}
+ return config
+
+
+def _run_cutana_normalised(csv_path, fitsbolt_cfg, n_output_channels=3):
+ """Run cutana with external_fitsbolt_cfg and return normalised cutout arrays."""
+ config = _make_cutana_config(csv_path, n_output_channels)
+ config.external_fitsbolt_cfg = fitsbolt_cfg
+
+ orchestrator = cutana.StreamingOrchestrator(config)
+ orchestrator.init_streaming(batch_size=10, write_to_disk=False, synchronised_loading=False)
+
+ all_cutouts = []
+ for _ in range(orchestrator.get_batch_count()):
+ batch = orchestrator.next_batch()
+ cutouts = batch["cutouts"]
+ if isinstance(cutouts, list) and len(cutouts) == 0:
+ continue
+ if isinstance(cutouts, list):
+ cutouts = np.array(cutouts)
+ for i in range(cutouts.shape[0]):
+ all_cutouts.append(np.array(cutouts[i]))
+
+ orchestrator.cleanup()
+ return all_cutouts
+
+
+def _extract_raw_cutouts(csv_path, output_dir):
+ """Extract raw (unnormalised) cutouts as FITS files using cutana."""
+ config = _make_cutana_config(csv_path, n_output_channels=1)
+ config.do_only_cutout_extraction = True
+ config.output_format = "fits"
+ config.write_to_disk = True
+ config.output_dir = output_dir
+
+ orchestrator = cutana.StreamingOrchestrator(config)
+ orchestrator.init_streaming(batch_size=10, write_to_disk=True, synchronised_loading=False)
+ for _ in range(orchestrator.get_batch_count()):
+ orchestrator.next_batch()
+ orchestrator.cleanup()
+
+ # Raw cutout data is in HDU[1] of each file
+ return sorted(
+ os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".fits")
+ )
+
+
+@pytest.fixture(scope="session")
+def cutana_test_data(tmp_path_factory):
+ """Extract raw cutouts from the test FITS tile using cutana.
+
+ Session-scoped: raw cutout extraction is expensive and identical
+ across all normalisation methods.
+
+ Returns:
+ tuple: (clean_fits_paths, rewritten_csv_path)
+ - clean_fits_paths: list of FITS file paths with raw cutout data in HDU[0]
+ - rewritten_csv_path: path to CSV with absolute FITS tile paths
+ """
+ if not os.path.exists(_CSV_CATALOGUE) or not os.path.exists(_FITS_TILE):
+ pytest.skip("Normalisation consistency test data not found")
+
+ tmp_path = tmp_path_factory.mktemp("normalisation_consistency")
+
+ # Rewrite CSV with absolute paths for this environment
+ rewritten_csv = str(tmp_path / "sources.csv")
+ _rewrite_csv_with_absolute_paths(_CSV_CATALOGUE, rewritten_csv, os.path.abspath(_FITS_TILE))
+
+ # Extract raw cutouts (do_only_cutout_extraction writes FITS with data in HDU[1])
+ raw_dir = str(tmp_path / "raw_cutouts")
+ os.makedirs(raw_dir, exist_ok=True)
+ raw_fits_paths = _extract_raw_cutouts(rewritten_csv, raw_dir)
+
+ if not raw_fits_paths:
+ pytest.skip("Cutana did not produce any cutouts from the test tile")
+
+ # Save as clean FITS files with data in HDU[0] for the training path
+ clean_fits_paths = []
+ for i, raw_path in enumerate(raw_fits_paths):
+ with fits.open(raw_path) as hdul:
+ raw_data = hdul[1].data
+ clean_path = str(tmp_path / f"cutout_{i}.fits")
+ fits.PrimaryHDU(raw_data.astype(np.float32)).writeto(clean_path, overwrite=True)
+ clean_fits_paths.append(clean_path)
+
+ return clean_fits_paths, rewritten_csv
+
+
+def _make_cfg(normalisation_method):
+ """Create an AM config with fitsbolt_cfg for the given normalisation method."""
+ cfg = get_default_cfg()
+ cfg.normalisation.image_size = [_TARGET_RESOLUTION, _TARGET_RESOLUTION]
+ cfg.normalisation.n_output_channels = 3
+ cfg.normalisation.normalisation_method = normalisation_method
+ cfg.num_channels = 3
+ cfg.num_workers = 0
+ cfg = get_fitsbolt_config(cfg)
+ return cfg
+
+
+_ALL_METHODS = [
+ NormalisationMethod.CONVERSION_ONLY,
+ NormalisationMethod.LOG,
+ NormalisationMethod.ZSCALE,
+ NormalisationMethod.ASINH,
+]
+
+
+def test_cutana_vs_training_normalisation(cutana_test_data):
+ """Verify that cutana's normalisation matches AM's training normalisation.
+
+ Tests all normalisation methods in a single function to avoid repeated
+ cutana orchestrator initialisation overhead (~3s per method).
+
+ For each normalisation method:
+ 1. Training path: raw FITS file -> load_and_process_wrapper() -> normalised image
+ 2. Prediction path: FITS tile -> cutana with external_fitsbolt_cfg ->
+ convert_cutana_cutout -> normalised image
+
+ Both paths must produce matching output for the same source data.
+ A tolerance of +/-1 (uint8) is allowed because the two separate cutana runs
+ (raw extraction vs normalised) may produce tiny float32 differences in WCS
+ reprojection, which non-linear stretches (ASINH) can amplify to +/-1/255.
+ """
+ clean_fits_paths, rewritten_csv = cutana_test_data
+ failures = []
+
+ for method in _ALL_METHODS:
+ cfg = _make_cfg(method)
+
+ # --- Prediction path: cutana normalisation + format conversion ---
+ cutana_normalised = _run_cutana_normalised(
+ rewritten_csv, cfg.fitsbolt_cfg, n_output_channels=3
+ )
+ if len(cutana_normalised) != len(clean_fits_paths):
+ failures.append(
+ f"{method.name}: cutana returned {len(cutana_normalised)} cutouts "
+ f"but {len(clean_fits_paths)} raw cutouts were extracted"
+ )
+ continue
+
+ format_cfg = create_cutana_format_cfg(cfg)
+ prediction_images = [convert_cutana_cutout(c, format_cfg) for c in cutana_normalised]
+
+ # --- Training path: load raw FITS via fitsbolt ---
+ training_pairs = load_and_process_wrapper(clean_fits_paths, cfg, show_progress=False)
+
+ # --- Compare ---
+ for i, (pred_img, (_, train_img)) in enumerate(zip(prediction_images, training_pairs)):
+ if pred_img.shape != train_img.shape:
+ failures.append(
+ f"{method.name} cutout {i}: shape mismatch — "
+ f"prediction {pred_img.shape} vs training {train_img.shape}"
+ )
+ continue
+ if pred_img.dtype != train_img.dtype:
+ failures.append(
+ f"{method.name} cutout {i}: dtype mismatch — "
+ f"prediction {pred_img.dtype} vs training {train_img.dtype}"
+ )
+ continue
+
+ diff = np.abs(pred_img.astype(np.int16) - train_img.astype(np.int16))
+ max_diff = int(diff.max())
+ if max_diff > 1:
+ failures.append(
+ f"{method.name} cutout {i}: max abs diff = {max_diff} (tolerance: 1)"
+ )
+
+ assert not failures, "Normalisation mismatches:\n" + "\n".join(failures)
diff --git a/tests/pipeline_test.py b/tests/e2e/test_pipeline.py
similarity index 89%
rename from tests/pipeline_test.py
rename to tests/e2e/test_pipeline.py
index e1137ca..41730d6 100644
--- a/tests/pipeline_test.py
+++ b/tests/e2e/test_pipeline.py
@@ -6,10 +6,13 @@
# the terms contained in the file 'LICENCE.txt'.
"""Testing a mininal pipeline"""
-import pytest
import ipywidgets as widgets
+import pytest
+
import anomaly_match as am
+pytestmark = pytest.mark.slow
+
@pytest.fixture(scope="module")
def pipeline_config():
@@ -25,7 +28,10 @@ def pipeline_config():
cfg.data_dir = "tests/test_data/"
cfg.normalisation.image_size = [64, 64]
cfg.normalisation.n_output_channels = 3
- cfg.num_train_iter = 10
+ cfg.net = "test-cnn"
+ cfg.pretrained = False
+ cfg.num_train_iter = 2
+ cfg.num_workers = 0
return cfg, out
diff --git a/tests/test_prediction_process.py b/tests/e2e/test_prediction_process.py
similarity index 96%
rename from tests/test_prediction_process.py
rename to tests/e2e/test_prediction_process.py
index 32897f9..799d366 100644
--- a/tests/test_prediction_process.py
+++ b/tests/e2e/test_prediction_process.py
@@ -4,33 +4,34 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
import csv
import os
-import numpy as np
-import h5py
-import zarr
import tempfile
-from PIL import Image
+
+import h5py
+import numpy as np
import pandas as pd
+import pytest
import torch
-from loguru import logger
+import zarr
+
+pytestmark = pytest.mark.slow
from astropy.io import fits
-from astropy.wcs import WCS
from astropy.table import Table
+from astropy.wcs import WCS
+from fitsbolt.cfg.create_config import create_config as fb_create_cfg
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
+from loguru import logger
+from PIL import Image
+from anomaly_match.utils.get_default_cfg import get_default_cfg
from prediction_process import evaluate_files
+from prediction_process_cutana import evaluate_images_from_cutana
from prediction_process_hdf5 import evaluate_images_in_hdf5
from prediction_process_zarr import evaluate_images_in_zarr
-from prediction_process_cutana import evaluate_images_from_cutana
from prediction_utils import save_results
-from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
-from fitsbolt.cfg.create_config import create_config as fb_create_cfg
-from anomaly_match.utils.get_default_cfg import get_default_cfg
-
-
@pytest.fixture
def test_config():
cfg = get_default_cfg()
@@ -49,6 +50,7 @@ def test_config():
cfg.test_ratio = 0.0 # Add test ratio
cfg.save_dir = tempfile.mkdtemp() # Add save directory
cfg.data_dir = "tests/test_data/" # Add data directory
+ cfg.num_workers = 0 # Use main process for data loading (avoids spawn overhead)
# Create fb_cfg for fitsbolt
cfg.fitsbolt_cfg = fb_create_cfg(
@@ -58,7 +60,7 @@ def test_config():
interpolation_order=cfg.normalisation.interpolation_order,
normalisation_method=cfg.normalisation.normalisation_method,
channel_combination=cfg.normalisation.channel_combination,
- num_workers=cfg.num_workers,
+ num_workers=max(cfg.num_workers, 1),
norm_maximum_value=cfg.normalisation.norm_maximum_value,
norm_minimum_value=cfg.normalisation.norm_minimum_value,
norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value,
@@ -554,14 +556,19 @@ def test_prediction_file_type_cutana_malformed_header(test_config, test_cutana_m
with pytest.warns(
RuntimeWarning,
- match=r"File .* did not pass cutana compatibility check and will be skipped \(.*\)",
+ match=r"File .* did not pass cutana column check \(.*\)",
):
with pytest.raises(RuntimeError, match="All found files are not compatible with cutana"):
session.evaluate_all_images()
def test_prediction_file_type_cutana_missing_images(test_config, test_cutana_missing_images):
- """Test for meaningful exception when streaming from cutana and images are missing."""
+ """Test that catalogues with missing FITS images still pass validation.
+
+ FITS existence is no longer checked during validation (only column schema
+ is verified). Errors from missing files surface later during cutana
+ processing.
+ """
from anomaly_match.pipeline.session import Session
from anomaly_match.utils.get_default_cfg import get_default_cfg
@@ -572,12 +579,11 @@ def test_prediction_file_type_cutana_missing_images(test_config, test_cutana_mis
session = Session(cfg)
- with pytest.warns(
- RuntimeWarning,
- match=r"File .* did not pass cutana compatibility check and will be skipped \(.*\)",
- ):
- with pytest.raises(RuntimeError, match="All found files are not compatible with cutana"):
- session.evaluate_all_images()
+ # Validation passes (columns are valid), but the cutana subprocess will
+ # fail when it tries to open missing FITS files. The session logs a
+ # warning instead of raising, so we just verify no scores are loaded.
+ session.evaluate_all_images()
+ assert session.scores is None
def test_stream_file_type_detection_csv_and_parquet(tmp_path):
@@ -698,8 +704,8 @@ def mock_save_results(cfg, all_scores, all_imgs, all_filenames, top_n):
def test_load_and_preprocess_multiple_formats(test_config, mixed_format_images):
"""Test the load_and_preprocess function can handle multiple formats."""
- from prediction_process import load_and_preprocess
from anomaly_match.image_processing.transforms import get_prediction_transforms
+ from prediction_process import load_and_preprocess
image_paths, _ = mixed_format_images
transform = get_prediction_transforms()
@@ -815,9 +821,9 @@ def mock_process_batch(model, images):
f"Image array size ({len(final_images)}) doesn't match CSV size ({len(final_scores)}). "
f"This indicates a bug in image accumulation logic."
)
- assert (
- final_images.shape[0] == top_n
- ), f"Expected {top_n} images, got {final_images.shape[0]}"
+ assert final_images.shape[0] == top_n, (
+ f"Expected {top_n} images, got {final_images.shape[0]}"
+ )
def test_all_predictions_accumulation(test_config, monkeypatch):
@@ -988,9 +994,9 @@ def test_top_images_preservation_across_batches(test_config, tmp_path):
logger.info(f"All stored filenames: {all_stored_filenames}")
# Should have 4 total predictions (2 from each batch)
assert len(all_stored_scores) == 4, f"Expected 4 total scores, got {len(all_stored_scores)}"
- assert (
- len(all_stored_filenames) == 4
- ), f"Expected 4 total filenames, got {len(all_stored_filenames)}"
+ assert len(all_stored_filenames) == 4, (
+ f"Expected 4 total filenames, got {len(all_stored_filenames)}"
+ )
# Verify that all scores from both batches are present
expected_all_scores = [0.9, 0.8, 0.6, 0.5] # batch1: 0.9, 0.8; batch2: 0.6, 0.5
@@ -1006,12 +1012,12 @@ def test_top_images_preservation_across_batches(test_config, tmp_path):
sorted_scores = all_stored_scores[sorted_indices]
sorted_filenames = all_stored_filenames[sorted_indices]
- assert np.allclose(
- sorted_scores, expected_all_scores
- ), f"Expected all scores {expected_all_scores}, got {sorted_scores}"
- assert (
- list(sorted_filenames) == expected_all_filenames
- ), f"Expected all filenames {expected_all_filenames}, got {list(sorted_filenames)}"
+ assert np.allclose(sorted_scores, expected_all_scores), (
+ f"Expected all scores {expected_all_scores}, got {sorted_scores}"
+ )
+ assert list(sorted_filenames) == expected_all_filenames, (
+ f"Expected all filenames {expected_all_filenames}, got {list(sorted_filenames)}"
+ )
logger.info("✅ All predictions file correctly contains data from both batches")
@@ -1022,8 +1028,8 @@ def test_image_directory_processing(test_config, mixed_format_images):
_, directory_path = mixed_format_images
# Create a file list with all image paths in the directory
- from pathlib import Path
import tempfile
+ from pathlib import Path
# Create a temporary file list
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as file_list:
@@ -1065,9 +1071,10 @@ def test_image_directory_processing(test_config, mixed_format_images):
def test_prediction_file_type_image(test_config, monkeypatch, mixed_format_images):
"""Test that the correct prediction process is called for the 'image' file type."""
import os
+ import subprocess # Create a directory with mixed format images
import sys
+
from anomaly_match.pipeline.session import Session
- import subprocess # Create a directory with mixed format images
image_paths, directory_path = mixed_format_images
@@ -1147,7 +1154,7 @@ def mock_run_pipeline(self, temp_config_path, input_path, top_N):
with open(group_file, "r") as f:
group_content = f.read().strip().split("\n")
assert set(group_content) == set(image_paths), (
- f"Wrong content in group file: {group_content}, " f"expected {image_paths}"
+ f"Wrong content in group file: {group_content}, expected {image_paths}"
)
# Verify that the file list exists and points to the group file
@@ -1155,7 +1162,7 @@ def mock_run_pipeline(self, temp_config_path, input_path, top_N):
with open(temp_file_list, "r") as f:
content = f.read().strip()
assert content == group_file, (
- f"Wrong content in file list: '{content}', " f"expected '{group_file}'"
+ f"Wrong content in file list: '{content}', expected '{group_file}'"
)
finally:
@@ -1208,14 +1215,14 @@ def test_image_channel_order_rgb(test_config, tmp_path):
f"HDF5: Red channel not highest. R={np.mean(loaded_hdf5[:, :, 0])}, "
f"G={np.mean(loaded_hdf5[:, :, 1])}, B={np.mean(loaded_hdf5[:, :, 2])}"
)
- assert (
- np.mean(loaded_hdf5[:, :, 0]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance
- ), f"HDF5: Red channel not highest vs blue. R={np.mean(loaded_hdf5[:, :, 0])}, B={np.mean(loaded_hdf5[:, :, 2])}"
+ assert np.mean(loaded_hdf5[:, :, 0]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance, (
+ f"HDF5: Red channel not highest vs blue. R={np.mean(loaded_hdf5[:, :, 0])}, B={np.mean(loaded_hdf5[:, :, 2])}"
+ )
# Check green channel (should be middle)
- assert (
- np.mean(loaded_hdf5[:, :, 1]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance
- ), f"HDF5: Green channel not higher than blue. G={np.mean(loaded_hdf5[:, :, 1])}, B={np.mean(loaded_hdf5[:, :, 2])}"
+ assert np.mean(loaded_hdf5[:, :, 1]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance, (
+ f"HDF5: Green channel not higher than blue. G={np.mean(loaded_hdf5[:, :, 1])}, B={np.mean(loaded_hdf5[:, :, 2])}"
+ )
logger.info(
f"HDF5 RGB values: R={np.mean(loaded_hdf5[:, :, 0]):.1f}, "
@@ -1226,9 +1233,10 @@ def test_image_channel_order_rgb(test_config, tmp_path):
def test_prediction_file_type_zarr(test_config, monkeypatch, test_zarr):
"""Test that the correct prediction process is called for the 'zarr' file type."""
import os
+ import subprocess
import sys
+
from anomaly_match.pipeline.session import Session
- import subprocess
# Create config based on default
from anomaly_match.utils.get_default_cfg import get_default_cfg
@@ -1289,9 +1297,9 @@ def mock_run_pipeline(self, temp_config_path, input_path, top_N):
assert script_path == "prediction_process_zarr.py", f"Wrong script called: {script_path}"
# Verify the zarr file path is passed correctly
- assert (
- called_processes[0][3] == test_zarr
- ), f"Wrong zarr file path: {called_processes[0][3]}"
+ assert called_processes[0][3] == test_zarr, (
+ f"Wrong zarr file path: {called_processes[0][3]}"
+ )
finally:
# Clean up any temporary files
@@ -1301,7 +1309,7 @@ def mock_run_pipeline(self, temp_config_path, input_path, top_N):
def test_zarr_image_processing_consistency(test_config, test_zarr):
"""Test that zarr image processing produces consistent results with standard methods."""
- from prediction_process_zarr import read_and_preprocess_image_from_zarr
+ from prediction_utils import read_and_preprocess_image_from_zarr
# Open the zarr file and get a sample image
root = zarr.open_group(test_zarr, mode="r")
@@ -1566,9 +1574,9 @@ def test_zarr_fallback_filenames_have_prefix(tmp_path, test_config):
for batch_idx, filenames in enumerate(batch_filenames):
sample_filename = filenames[0]
# Should have format: __image_000000
- assert (
- "__image_" in sample_filename
- ), f"Batch {batch_idx} fallback filename doesn't have expected format. Got: {sample_filename}"
+ assert "__image_" in sample_filename, (
+ f"Batch {batch_idx} fallback filename doesn't have expected format. Got: {sample_filename}"
+ )
# Verify no collision between batches
set_0 = set(batch_filenames[0])
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
new file mode 100644
index 0000000..3a800e5
--- /dev/null
+++ b/tests/integration/conftest.py
@@ -0,0 +1,7 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Shared fixtures for integration tests."""
diff --git a/tests/test_fitsbolt_config_persistence.py b/tests/integration/test_fitsbolt_config_persistence.py
similarity index 99%
rename from tests/test_fitsbolt_config_persistence.py
rename to tests/integration/test_fitsbolt_config_persistence.py
index 179b54f..ee58584 100644
--- a/tests/test_fitsbolt_config_persistence.py
+++ b/tests/integration/test_fitsbolt_config_persistence.py
@@ -18,7 +18,8 @@
import numpy as np
import torch
from dotmap import DotMap
-from fitsbolt.cfg.create_config import create_config as fb_create_cfg, validate_config
+from fitsbolt.cfg.create_config import create_config as fb_create_cfg
+from fitsbolt.cfg.create_config import validate_config
from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
from anomaly_match.data_io.load_images import get_fitsbolt_config
diff --git a/tests/test_model_io_integration.py b/tests/integration/test_model_io_integration.py
similarity index 80%
rename from tests/test_model_io_integration.py
rename to tests/integration/test_model_io_integration.py
index 6c7f6a6..eb8e0e3 100644
--- a/tests/test_model_io_integration.py
+++ b/tests/integration/test_model_io_integration.py
@@ -5,16 +5,19 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import tempfile
import shutil
+import tempfile
from pathlib import Path
+
+import pytest
import torch
import torch.nn as nn
from dotmap import DotMap
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
from anomaly_match.data_io.SessionIOHandler import SessionIOHandler
from anomaly_match.pipeline.SessionTracker import SessionTracker
-from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
+from anomaly_match.utils.get_net_builder import get_net_builder
class MockModel(nn.Module):
@@ -165,3 +168,36 @@ def test_load_model_checkpoint_nonexistent(self):
checkpoint = self.session_io.load_model_checkpoint(str(self.temp_dir / "nonexistent.pkl"))
assert checkpoint is None
+
+
+TEST_MODEL_PATH = Path(__file__).parent.parent / "test_data" / "test_model.pth"
+
+
+@pytest.mark.skipif(not TEST_MODEL_PATH.exists(), reason="test_model.pth not available")
+class TestStoredModelLoading:
+ """Regression tests for loading the stored test_model.pth checkpoint.
+
+ These tests verify that the checked-in test model remains compatible
+ with the current model architecture (timm-based EfficientNet).
+ """
+
+ def test_stored_model_has_expected_keys(self):
+ """Verify the stored checkpoint contains expected top-level keys."""
+ checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu")
+
+ assert "eval_model" in checkpoint, (
+ f"Checkpoint missing 'eval_model' key. Found: {list(checkpoint.keys())}"
+ )
+ assert "train_model" in checkpoint
+
+ def test_stored_model_loads_into_efficientnet_lite0(self):
+ """Verify stored model state_dict is compatible with the current architecture."""
+ checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu")
+
+ net_builder = get_net_builder("efficientnet-lite0", pretrained=False, in_channels=3)
+ model = net_builder(num_classes=2, in_channels=3)
+
+ # This will raise RuntimeError if keys don't match (the exact regression
+ # that would occur if the model was saved with a different architecture)
+ model.load_state_dict(checkpoint["eval_model"])
+ model.load_state_dict(checkpoint["train_model"])
diff --git a/tests/integration/test_multispectral_training.py b/tests/integration/test_multispectral_training.py
new file mode 100644
index 0000000..b5dba50
--- /dev/null
+++ b/tests/integration/test_multispectral_training.py
@@ -0,0 +1,65 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Integration tests for multispectral (N-channel) model training."""
+
+import pytest
+
+from anomaly_match.datasets.SSL_Dataset import SSL_Dataset
+
+
+@pytest.mark.slow
+class TestMultispectralTraining:
+ """Tests for multispectral training pipeline."""
+
+ def test_model_training_4_channels(self, multispectral_config):
+ """Test that FixMatch can be set up with 4-channel data."""
+ from anomaly_match.models.FixMatch import FixMatch
+ from anomaly_match.utils.get_net_builder import get_net_builder
+
+ # Load datasets
+ ssl_dataset = SSL_Dataset(cfg=multispectral_config, train=True)
+ labeled_dset, unlabeled_dset = ssl_dataset.get_ssl_dset()
+
+ # Verify datasets have correct channel count
+ assert labeled_dset.num_channels == 4
+ assert unlabeled_dset.num_channels == 4
+
+ # Build network
+ net_builder = get_net_builder(
+ multispectral_config.net,
+ pretrained=multispectral_config.pretrained,
+ in_channels=4,
+ )
+
+ # Create model
+ model = FixMatch(
+ net_builder=net_builder,
+ num_classes=2,
+ in_channels=4,
+ ema_m=multispectral_config.ema_m,
+ T=multispectral_config.temperature,
+ p_cutoff=multispectral_config.p_cutoff,
+ lambda_u=multispectral_config.ulb_loss_ratio,
+ )
+
+ # Set up data loaders
+ model.set_data_loader(
+ cfg=multispectral_config,
+ lb_dset=labeled_dset,
+ ulb_dset=unlabeled_dset,
+ eval_dset=None,
+ )
+
+ # Verify model is set up correctly for 4-channel input
+ assert model.train_model is not None
+ assert model.eval_model is not None
+ # The first conv layer should accept 4 channels
+ # TestCNN uses _conv_stem, timm models use conv_stem
+ train_model = model.train_model
+ first_conv = getattr(train_model, "conv_stem", getattr(train_model, "_conv_stem", None))
+ assert first_conv is not None
+ assert first_conv.in_channels == 4
diff --git a/tests/test_run_label_migration.py b/tests/integration/test_run_label_migration.py
similarity index 99%
rename from tests/test_run_label_migration.py
rename to tests/integration/test_run_label_migration.py
index dfc2fce..a484cf7 100644
--- a/tests/test_run_label_migration.py
+++ b/tests/integration/test_run_label_migration.py
@@ -5,12 +5,13 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import tempfile
import os
+import tempfile
+from unittest.mock import Mock
+
+import pandas as pd
import pytest
import torch
-import pandas as pd
-from unittest.mock import Mock
from anomaly_match.data_io.SessionIOHandler import SessionIOHandler
from anomaly_match.pipeline.SessionTracker import SessionTracker
diff --git a/tests/session_test.py b/tests/integration/test_session.py
similarity index 98%
rename from tests/session_test.py
rename to tests/integration/test_session.py
index 624a6cc..e267371 100644
--- a/tests/session_test.py
+++ b/tests/integration/test_session.py
@@ -4,12 +4,16 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
import os
+
+import ipywidgets as widgets
import numpy as np
import pandas as pd
+import pytest
import torch
-import ipywidgets as widgets
+
+pytestmark = pytest.mark.slow
+
import anomaly_match as am
from anomaly_match.pipeline.session import Session
@@ -27,7 +31,10 @@ def base_config():
cfg.data_dir = "tests/test_data/"
cfg.normalisation.image_size = [64, 64]
cfg.normalisation.n_output_channels = 3
+ cfg.net = "test-cnn"
+ cfg.pretrained = False
cfg.num_train_iter = 2
+ cfg.num_workers = 0
cfg.test_ratio = 0.5
cfg.output_dir = "tests/test_output"
return cfg, out
@@ -212,9 +219,9 @@ def test_load_top_files(trained_session):
# Verify that the transpose from CHW to HWC was applied correctly
# Convert the original CHW back to HWC for comparison
expected_images_hwc = test_images_chw.transpose(0, 2, 3, 1)
- assert np.array_equal(
- session.img_catalog, expected_images_hwc
- ), "CHW to HWC conversion should be correct"
+ assert np.array_equal(session.img_catalog, expected_images_hwc), (
+ "CHW to HWC conversion should be correct"
+ )
# Test loading with images already in HWC format
test_images_hwc = np.random.randint(0, 255, (top_N, 64, 64, 3), dtype=np.uint8)
@@ -235,9 +242,9 @@ def test_load_top_files(trained_session):
# Verify that float images were converted to uint8
assert session.img_catalog.dtype == np.uint8, "Float images should be converted to uint8"
expected_uint8 = (test_images_float * 255.0).clip(0, 255).astype(np.uint8)
- assert np.array_equal(
- session.img_catalog, expected_uint8
- ), "Float to uint8 conversion should be correct"
+ assert np.array_equal(session.img_catalog, expected_uint8), (
+ "Float to uint8 conversion should be correct"
+ )
# Clean up test files
if os.path.exists(output_csv_path):
@@ -705,9 +712,9 @@ def test_iteration_scores_saved_after_training(base_config):
# Verify score mapping: check a few samples match between CSV and session
for idx, (filename, score) in enumerate(zip(session.filenames[:5], session.scores[:5])):
csv_score = unlabelled_df[unlabelled_df["filename"] == filename]["score"].values[0]
- assert (
- abs(csv_score - score) < 1e-6
- ), f"Score mismatch for {filename}: {csv_score} vs {score}"
+ assert abs(csv_score - score) < 1e-6, (
+ f"Score mismatch for {filename}: {csv_score} vs {score}"
+ )
# If test_ratio > 0, verify test scores were also saved
if cfg.test_ratio > 0:
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_2021_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_2021_4ch.tiff
new file mode 100644
index 0000000..6699872
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_2021_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_2172_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_2172_4ch.tiff
new file mode 100644
index 0000000..39a3aff
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_2172_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_2_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_2_4ch.tiff
new file mode 100644
index 0000000..e276931
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_2_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_3_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_3_4ch.tiff
new file mode 100644
index 0000000..742c851
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_3_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_4989_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_4989_4ch.tiff
new file mode 100644
index 0000000..6287179
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_4989_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_5260_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_5260_4ch.tiff
new file mode 100644
index 0000000..553cc0a
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_5260_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_5783_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_5783_4ch.tiff
new file mode 100644
index 0000000..405afdd
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_5783_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_5_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_5_4ch.tiff
new file mode 100644
index 0000000..a3b90a0
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_5_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_6212_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_6212_4ch.tiff
new file mode 100644
index 0000000..d5cd0da
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_6212_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_6_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_6_4ch.tiff
new file mode 100644
index 0000000..d7a2d4f
Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_6_4ch.tiff differ
diff --git a/tests/test_data/multispectral_4ch/labeled_data.csv b/tests/test_data/multispectral_4ch/labeled_data.csv
new file mode 100644
index 0000000..ce6e7f2
--- /dev/null
+++ b/tests/test_data/multispectral_4ch/labeled_data.csv
@@ -0,0 +1,7 @@
+filename,label
+Abell2390_VIS_2_4ch.tiff,anomaly
+Abell2390_VIS_3_4ch.tiff,anomaly
+Abell2390_VIS_5_4ch.tiff,anomaly
+Abell2390_VIS_6_4ch.tiff,normal
+Abell2390_VIS_2021_4ch.tiff,normal
+Abell2390_VIS_2172_4ch.tiff,normal
diff --git a/tests/test_data/multispectral_4ch/metadata.csv b/tests/test_data/multispectral_4ch/metadata.csv
new file mode 100644
index 0000000..60d78e4
--- /dev/null
+++ b/tests/test_data/multispectral_4ch/metadata.csv
@@ -0,0 +1,11 @@
+filename,sourceID,ra,dec,custom_metadata
+Abell2390_VIS_2_4ch.tiff,MS_00000,328.4034,17.695,4ch test source 0
+Abell2390_VIS_3_4ch.tiff,MS_00001,328.4054,17.696,4ch test source 1
+Abell2390_VIS_5_4ch.tiff,MS_00002,328.4074,17.697,4ch test source 2
+Abell2390_VIS_6_4ch.tiff,MS_00003,328.40939999999995,17.698,4ch test source 3
+Abell2390_VIS_2021_4ch.tiff,MS_00004,328.41139999999996,17.699,4ch test source 4
+Abell2390_VIS_2172_4ch.tiff,MS_00005,328.41339999999997,17.7,4ch test source 5
+Abell2390_VIS_4989_4ch.tiff,MS_00006,328.4154,17.701,4ch test source 6
+Abell2390_VIS_5260_4ch.tiff,MS_00007,328.4174,17.702,4ch test source 7
+Abell2390_VIS_5783_4ch.tiff,MS_00008,328.4194,17.703,4ch test source 8
+Abell2390_VIS_6212_4ch.tiff,MS_00009,328.42139999999995,17.704,4ch test source 9
diff --git a/tests/test_data/normalisation_consistency/EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits b/tests/test_data/normalisation_consistency/EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits
new file mode 100644
index 0000000..79179f6
Binary files /dev/null and b/tests/test_data/normalisation_consistency/EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits differ
diff --git a/tests/test_data/normalisation_consistency/mock_sources.csv b/tests/test_data/normalisation_consistency/mock_sources.csv
new file mode 100644
index 0000000..2892a73
--- /dev/null
+++ b/tests/test_data/normalisation_consistency/mock_sources.csv
@@ -0,0 +1,6 @@
+SourceID,RA,Dec,diameter_pixel,fits_file_paths
+MockSource_102018211000001,150.1406475476454,2.336233868960458,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits']
+MockSource_102018211000002,150.13846652223168,2.3426457870360125,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits']
+MockSource_102018211000003,150.13720611369422,2.3416953968595355,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits']
+MockSource_102018211000004,150.13954247363657,2.3378254467880653,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits']
+MockSource_102018211000005,150.13806143453291,2.3366348110534423,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits']
diff --git a/tests/test_data/test_model.pth b/tests/test_data/test_model.pth
index ffd70f0..1ffb3f4 100644
Binary files a/tests/test_data/test_model.pth and b/tests/test_data/test_model.pth differ
diff --git a/tests/ui/conftest.py b/tests/ui/conftest.py
new file mode 100644
index 0000000..93db83a
--- /dev/null
+++ b/tests/ui/conftest.py
@@ -0,0 +1,7 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Shared fixtures for UI tests."""
diff --git a/tests/ui/test_memory_monitor.py b/tests/ui/test_memory_monitor.py
index 800eacc..17ab3bb 100644
--- a/tests/ui/test_memory_monitor.py
+++ b/tests/ui/test_memory_monitor.py
@@ -4,9 +4,13 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
import asyncio
-from anomaly_match.ui.memory_monitor import MemoryMonitor
+
+import pytest
+
+from anomaly_match_ui.memory_monitor import MemoryMonitor
+
+pytestmark = pytest.mark.ui
@pytest.mark.asyncio
diff --git a/tests/ui_test.py b/tests/ui/test_widget.py
similarity index 62%
rename from tests/ui_test.py
rename to tests/ui/test_widget.py
index bbd3070..7a60843 100644
--- a/tests/ui_test.py
+++ b/tests/ui/test_widget.py
@@ -4,71 +4,23 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
-import ipywidgets as widgets
-from anomaly_match.ui.Widget import Widget, shorten_filename
-from anomaly_match.pipeline.session import Session
-import anomaly_match as am
import os
-import numpy as np
from unittest.mock import patch
+
+import ipywidgets as widgets
import matplotlib
+import numpy as np
+import pytest
from PIL import Image
-matplotlib.use("Agg") # Prevent matplotlib windows from opening
+import anomaly_match as am
+from anomaly_match.pipeline.session import Session
+from anomaly_match_ui.utils.backend_interface import BackendInterface
+from anomaly_match_ui.widget import Widget
+pytestmark = [pytest.mark.ui, pytest.mark.slow]
-class TestShortenFilename:
- """Tests for the shorten_filename helper function."""
-
- def test_short_filename_unchanged(self):
- """Filenames within max length should remain unchanged."""
- assert shorten_filename("short.fits", max_length=25) == "short.fits"
- assert shorten_filename("image.jpg", max_length=25) == "image.jpg"
-
- def test_long_filename_shortened(self):
- """Long filenames should be shortened to max_length."""
- long_name = "very_long_filename_that_exceeds_limit.fits"
- result = shorten_filename(long_name, max_length=25)
- assert len(result) <= 25
- assert result.endswith(".fits")
- assert "..." in result
-
- def test_filename_with_multiple_dots(self):
- """Filenames with multiple dots should preserve only the extension."""
- name = "image.2024.01.15.observation.fits"
- result = shorten_filename(name, max_length=25)
- assert len(result) <= 25
- assert result.endswith(".fits")
- assert "..." in result
-
- def test_filename_without_extension(self):
- """Filenames without extension should still be shortened correctly."""
- name = "very_long_filename_without_any_extension"
- result = shorten_filename(name, max_length=25)
- assert len(result) <= 25
- assert "..." in result
-
- def test_exact_max_length(self):
- """Filename exactly at max_length should be unchanged."""
- name = "exactly_25_chars_long.fit"
- assert len(name) == 25
- assert shorten_filename(name, max_length=25) == name
-
- def test_very_short_max_length(self):
- """Very short max_length should still produce valid output."""
- name = "some_filename.fits"
- result = shorten_filename(name, max_length=10)
- assert len(result) <= 10
- assert "..." in result
-
- def test_preserves_start_and_end(self):
- """Shortened name should contain parts of the original start and end."""
- name = "START_middle_content_END.fits"
- result = shorten_filename(name, max_length=20)
- assert result.startswith("START")
- # Should contain some part of the end before the extension
- assert "END" in result or "..." in result
+matplotlib.use("Agg") # Prevent matplotlib windows from opening
@pytest.fixture(scope="session")
@@ -84,7 +36,10 @@ def base_config():
cfg.data_dir = "tests/test_data/"
cfg.normalisation.image_size = [64, 64]
cfg.normalisation.n_output_channels = 3
+ cfg.net = "test-cnn"
+ cfg.pretrained = False
cfg.num_train_iter = 2
+ cfg.num_workers = 0
cfg.test_ratio = 0.5
cfg.output_dir = "tests/test_output"
cfg.prediction_search_dir = "tests/test_data/" # Set a default search directory
@@ -105,7 +60,9 @@ def session_fixture(base_config):
@pytest.fixture(scope="session")
def ui_widget(session_fixture):
with patch("IPython.display.display"): # Prevent actual display calls
- widget = Widget(session_fixture)
+ # Set up the backend interface
+ BackendInterface.set_session(session_fixture)
+ widget = Widget()
yield widget
# Only close widgets in teardown if they still exist
try:
@@ -127,13 +84,13 @@ def setup_display_mocks():
# Test classes to organize related tests
class TestUIInitialization:
def test_ui_initialization(self, ui_widget):
- assert ui_widget.session is not None
+ assert BackendInterface.get_session() is not None
assert isinstance(ui_widget.ui["image_widget"], widgets.Image)
assert isinstance(ui_widget.ui["filename_text"], widgets.HTML)
def test_normalization_dropdown(self, ui_widget):
# Get the initial normalization method
- initial_method = ui_widget.session.cfg.normalisation.normalisation_method
+ initial_method = BackendInterface.get_config().normalisation.normalisation_method
# Get the dropdown options and find a different method
dropdown_options = ui_widget.ui["normalisation_dropdown"].options
@@ -148,7 +105,7 @@ def test_normalization_dropdown(self, ui_widget):
ui_widget.ui["normalisation_dropdown"].value = new_method
# Assert that the session config was updated
- assert ui_widget.session.cfg.normalisation.normalisation_method == new_method
+ assert BackendInterface.get_config().normalisation.normalisation_method == new_method
class TestUINavigation:
@@ -167,25 +124,25 @@ def test_previous_image(self, ui_widget):
class TestUISorting:
def test_sort_by_anomalous(self, ui_widget):
ui_widget.sort_by_anomalous()
- assert ui_widget.session.scores[0] >= ui_widget.session.scores[-1]
+ scores = BackendInterface.get_scores()
+ assert scores[0] >= scores[-1]
def test_sort_by_nominal(self, ui_widget):
ui_widget.sort_by_nominal()
- assert ui_widget.session.scores[0] <= ui_widget.session.scores[-1]
+ scores = BackendInterface.get_scores()
+ assert scores[0] <= scores[-1]
def test_sort_by_mean(self, ui_widget):
ui_widget.sort_by_mean()
- mean_score = ui_widget.session.scores.mean()
- assert abs(ui_widget.session.scores[0] - mean_score) <= abs(
- ui_widget.session.scores[-1] - mean_score
- )
+ scores = BackendInterface.get_scores()
+ mean_score = scores.mean()
+ assert abs(scores[0] - mean_score) <= abs(scores[-1] - mean_score)
def test_sort_by_median(self, ui_widget):
ui_widget.sort_by_median()
- median_score = np.median(ui_widget.session.scores)
- assert abs(ui_widget.session.scores[0] - median_score) <= abs(
- ui_widget.session.scores[-1] - median_score
- )
+ scores = BackendInterface.get_scores()
+ median_score = np.median(scores)
+ assert abs(scores[0] - median_score) <= abs(scores[-1] - median_score)
class TestUIImageProcessing:
@@ -211,46 +168,49 @@ def test_adjust_brightness_contrast(self, ui_widget):
class TestUIModelOperations:
def test_save_load_model(self, ui_widget):
ui_widget.save_model()
- assert os.path.exists(ui_widget.session.cfg.model_path)
+ cfg = BackendInterface.get_config()
+ assert os.path.exists(cfg.model_path)
ui_widget.load_model()
- assert ui_widget.session.model is not None
+ assert BackendInterface.get_model() is not None
def test_train(self, ui_widget):
ui_widget.train()
- assert ui_widget.session.eval_performance is not None
+ assert BackendInterface.get_eval_performance() is not None
def test_reset_model(self, ui_widget):
ui_widget.reset_model()
- assert ui_widget.session.model is not None
+ assert BackendInterface.get_model() is not None
class TestUIBatchOperations:
def test_update_batch_size(self, ui_widget):
initial_batch_size = ui_widget.ui["batch_size_slider"].value
ui_widget.ui["batch_size_slider"].value = initial_batch_size + 500
- assert ui_widget.session.cfg.N_to_load == initial_batch_size + 500
+ cfg = BackendInterface.get_config()
+ assert cfg.N_to_load == initial_batch_size + 500
def test_next_batch(self, ui_widget):
ui_widget.next_batch()
- assert ui_widget.session.img_catalog is not None
+ session = BackendInterface.get_session()
+ assert session.img_catalog is not None
def test_search_all_files(self, ui_widget):
- with patch("anomaly_match.ui.Widget.display"):
+ with patch("anomaly_match_ui.widget.display"):
# Save model first so evaluate_all_images can find it
ui_widget.save_model()
+ cfg = BackendInterface.get_config()
# Ensure test data directory exists and has files
- os.makedirs(ui_widget.session.cfg.prediction_search_dir, exist_ok=True)
+ os.makedirs(cfg.prediction_search_dir, exist_ok=True)
# Create a test image if directory is empty
- if not os.listdir(ui_widget.session.cfg.prediction_search_dir):
+ if not os.listdir(cfg.prediction_search_dir):
test_img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
- test_img_path = os.path.join(
- ui_widget.session.cfg.prediction_search_dir, "test.jpg"
- )
+ test_img_path = os.path.join(cfg.prediction_search_dir, "test.jpg")
Image.fromarray(test_img).save(test_img_path)
ui_widget.search_all_files()
- assert len(ui_widget.session.img_catalog) > 0
+ session = BackendInterface.get_session()
+ assert len(session.img_catalog) > 0
def test_cleanup(ui_widget):
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 0000000..9ff1a91
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,35 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Shared fixtures for unit tests."""
+
+import ipywidgets as widgets
+import pytest
+
+import anomaly_match as am
+
+
+@pytest.fixture(scope="module")
+def base_config():
+ """Base configuration for unit tests (no training, no output dir)."""
+ out = widgets.Output(
+ layout=widgets.Layout(
+ border="1px solid white", height="400px", background_color="black", overflow="auto"
+ ),
+ )
+
+ cfg = am.get_default_cfg()
+ am.set_log_level("debug", cfg)
+ cfg.data_dir = "tests/test_data/"
+ cfg.normalisation.image_size = [64, 64]
+ cfg.normalisation.n_output_channels = 3
+ cfg.net = "test-cnn"
+ cfg.pretrained = False
+ cfg.num_train_iter = 2
+ cfg.num_workers = 0
+ cfg.test_ratio = 0.5
+ cfg.output_dir = "tests/test_output"
+ return cfg, out
diff --git a/tests/test_batch_size_estimation.py b/tests/unit/test_batch_size_estimation.py
similarity index 98%
rename from tests/test_batch_size_estimation.py
rename to tests/unit/test_batch_size_estimation.py
index 746f2f9..c411f66 100644
--- a/tests/test_batch_size_estimation.py
+++ b/tests/unit/test_batch_size_estimation.py
@@ -6,9 +6,9 @@
# the terms contained in the file 'LICENCE.txt'.
import pytest
-from prediction_utils import estimate_batch_size, MEMORY_COEFFICIENTS
-from anomaly_match import get_default_cfg
+from anomaly_match import get_default_cfg
+from prediction_utils import MEMORY_COEFFICIENTS, estimate_batch_size
FAKE_VRAM_BYTES = 16 * 1024**3 # 16GB
@@ -152,7 +152,6 @@ def test_estimate_batch_size_invalid_memory(test_config):
def test_estimate_batch_size_invalid_model_coefficients(test_config, monkeypatch):
-
import prediction_utils
# Inject invalid coefficients
diff --git a/tests/cfg_validation_test.py b/tests/unit/test_config_validation.py
similarity index 99%
rename from tests/cfg_validation_test.py
rename to tests/unit/test_config_validation.py
index f4b36da..03d03ef 100644
--- a/tests/cfg_validation_test.py
+++ b/tests/unit/test_config_validation.py
@@ -7,6 +7,7 @@
import pytest
from loguru import logger
+
from anomaly_match.utils.get_default_cfg import get_default_cfg
from anomaly_match.utils.validate_config import validate_config
diff --git a/tests/unit/test_cutana_stream_utils.py b/tests/unit/test_cutana_stream_utils.py
new file mode 100644
index 0000000..67c0c94
--- /dev/null
+++ b/tests/unit/test_cutana_stream_utils.py
@@ -0,0 +1,87 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests for cutana catalogue validation and source counting."""
+
+import pandas as pd
+import pytest
+
+from anomaly_match.utils.cutana_stream_utils import cutana_validate_files_and_count_sources
+
+
+def _make_valid_df(n=50):
+ """Create a minimal valid cutana catalogue DataFrame."""
+ return pd.DataFrame(
+ {
+ "SourceID": range(n),
+ "RA": [180.0] * n,
+ "Dec": [-30.0] * n,
+ "fits_file_paths": ["dummy.fits"] * n,
+ "diameter_pixel": [10] * n,
+ }
+ )
+
+
+@pytest.fixture
+def valid_parquet(tmp_path):
+ """Write a valid parquet catalogue and return its path."""
+ path = tmp_path / "valid.parquet"
+ _make_valid_df(100).to_parquet(path, index=False)
+ return path
+
+
+@pytest.fixture
+def invalid_parquet(tmp_path):
+ """Write a parquet file with wrong columns."""
+ path = tmp_path / "invalid.parquet"
+ pd.DataFrame({"bad_col": [1, 2, 3]}).to_parquet(path, index=False)
+ return path
+
+
+class TestCutanaValidateFilesAndCountSources:
+ def test_valid_parquet_accepted(self, valid_parquet):
+ files, total, chunks = cutana_validate_files_and_count_sources(
+ [valid_parquet], chunk_size=50
+ )
+ assert len(files) == 1
+ assert total == 100
+ assert chunks == 2 # 100 rows / 50 chunk_size
+
+ def test_row_count_from_metadata(self, tmp_path):
+ """Row counts should match actual rows without scanning data."""
+ paths = []
+ for i, n in enumerate([200, 300]):
+ p = tmp_path / f"cat_{i}.parquet"
+ _make_valid_df(n).to_parquet(p, index=False)
+ paths.append(p)
+
+ files, total, chunks = cutana_validate_files_and_count_sources(paths, chunk_size=100)
+ assert len(files) == 2
+ assert total == 500
+ assert chunks == 5 # 200/100 + 300/100
+
+ def test_invalid_columns_rejected(self, invalid_parquet):
+ files, total, chunks = cutana_validate_files_and_count_sources(
+ [invalid_parquet], chunk_size=100
+ )
+ assert len(files) == 0
+ assert total == 0
+
+ def test_mixed_valid_invalid(self, valid_parquet, invalid_parquet):
+ """Invalid file should be skipped, valid file should be counted."""
+ files, total, chunks = cutana_validate_files_and_count_sources(
+ [invalid_parquet, valid_parquet], chunk_size=100
+ )
+ assert len(files) == 1
+ assert total == 100
+
+ def test_valid_csv_accepted(self, tmp_path):
+ path = tmp_path / "valid.csv"
+ _make_valid_df(75).to_csv(path, index=False)
+ files, total, chunks = cutana_validate_files_and_count_sources([path], chunk_size=50)
+ assert len(files) == 1
+ assert total == 75
+ assert chunks == 2 # 50 + 25
diff --git a/tests/unit/test_data_utils.py b/tests/unit/test_data_utils.py
new file mode 100644
index 0000000..2aec1e1
--- /dev/null
+++ b/tests/unit/test_data_utils.py
@@ -0,0 +1,120 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests for data loading utilities."""
+
+import numpy as np
+import pytest
+import torch
+from torch.utils.data import DataLoader
+
+from anomaly_match.datasets.BasicDataset import BasicDataset
+from anomaly_match.datasets.data_utils import get_data_loader, get_sampler_by_name
+from anomaly_match.image_processing.transforms import get_prediction_transforms
+
+
+@pytest.fixture
+def simple_dataset():
+ """Create a simple BasicDataset for testing."""
+ imgs = np.random.randint(0, 255, (20, 64, 64, 3), dtype=np.uint8)
+ filenames = [f"img_{i}.jpg" for i in range(20)]
+ targets = [0] * 10 + [1] * 10
+ transform = get_prediction_transforms(num_channels=3)
+ return BasicDataset(imgs, filenames, targets, num_classes=2, transform=transform)
+
+
+class TestGetSamplerByName:
+ def test_random_sampler(self):
+ sampler = get_sampler_by_name("RandomSampler")
+ assert sampler is torch.utils.data.sampler.RandomSampler
+
+ def test_sequential_sampler(self):
+ sampler = get_sampler_by_name("SequentialSampler")
+ assert sampler is torch.utils.data.sampler.SequentialSampler
+
+ def test_invalid_sampler_raises(self):
+ with pytest.raises(AttributeError, match="not found"):
+ get_sampler_by_name("NonexistentSampler")
+
+
+class TestGetDataLoader:
+ def test_requires_batch_size(self, simple_dataset):
+ with pytest.raises(AssertionError, match="Batch size must be specified"):
+ get_data_loader(simple_dataset, batch_size=None)
+
+ def test_basic_dataloader(self, simple_dataset):
+ loader = get_data_loader(simple_dataset, batch_size=4, num_workers=0)
+ assert isinstance(loader, DataLoader)
+ batch = next(iter(loader))
+ assert batch[0].shape[0] == 4
+
+ def test_shuffle_dataloader(self, simple_dataset):
+ loader = get_data_loader(simple_dataset, batch_size=4, shuffle=True, num_workers=0)
+ assert isinstance(loader, DataLoader)
+
+ def test_weighted_sampler(self, simple_dataset):
+ loader = get_data_loader(
+ simple_dataset,
+ batch_size=4,
+ use_weighted_sampler=True,
+ num_workers=0,
+ )
+ assert isinstance(loader, DataLoader)
+
+ def test_weighted_sampler_with_num_iters(self, simple_dataset):
+ loader = get_data_loader(
+ simple_dataset,
+ batch_size=4,
+ use_weighted_sampler=True,
+ num_iters=10,
+ num_workers=0,
+ )
+ assert isinstance(loader, DataLoader)
+
+ def test_weighted_sampler_with_num_epochs(self, simple_dataset):
+ loader = get_data_loader(
+ simple_dataset,
+ batch_size=4,
+ use_weighted_sampler=True,
+ num_epochs=2,
+ num_workers=0,
+ )
+ assert isinstance(loader, DataLoader)
+
+ def test_random_sampler_by_name(self, simple_dataset):
+ loader = get_data_loader(
+ simple_dataset,
+ batch_size=4,
+ data_sampler="RandomSampler",
+ num_workers=0,
+ )
+ assert isinstance(loader, DataLoader)
+
+ def test_unsupported_sampler_raises(self, simple_dataset):
+ with pytest.raises(RuntimeError, match="not fully implemented"):
+ get_data_loader(
+ simple_dataset,
+ batch_size=4,
+ data_sampler="SequentialSampler",
+ num_workers=0,
+ )
+
+
+class TestWeightedSamplerSingleClass:
+ def test_single_class_uniform_weights(self):
+ """Weighted sampler with only one class should use uniform weights."""
+ imgs = np.random.randint(0, 255, (10, 64, 64, 3), dtype=np.uint8)
+ filenames = [f"img_{i}.jpg" for i in range(10)]
+ targets = [0] * 10 # All same class
+ transform = get_prediction_transforms(num_channels=3)
+ dataset = BasicDataset(imgs, filenames, targets, num_classes=2, transform=transform)
+ loader = get_data_loader(
+ dataset,
+ batch_size=4,
+ use_weighted_sampler=True,
+ num_workers=0,
+ )
+ assert isinstance(loader, DataLoader)
diff --git a/tests/dataset_test.py b/tests/unit/test_dataset.py
similarity index 98%
rename from tests/dataset_test.py
rename to tests/unit/test_dataset.py
index 1a93bbc..05dbd2b 100644
--- a/tests/dataset_test.py
+++ b/tests/unit/test_dataset.py
@@ -4,22 +4,23 @@
# is part of this source code package. No part of the package, including
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
+import os
+import tempfile
+
import numpy as np
import pandas as pd
+import pytest
import torch
+from PIL import Image
+
import anomaly_match as am
-from anomaly_match.datasets.Label import Label
from anomaly_match.datasets.AnomalyDetectionDataset import AnomalyDetectionDataset
from anomaly_match.datasets.BasicDataset import BasicDataset
+from anomaly_match.datasets.Label import Label
from anomaly_match.datasets.SSL_Dataset import SSL_Dataset
-import os
-import tempfile
-from PIL import Image
-
from anomaly_match.image_processing.transforms import (
- get_weak_transforms,
get_prediction_transforms,
+ get_weak_transforms,
)
@@ -112,9 +113,9 @@ def test_multiple_file_extensions_support(multi_extension_dataset, test_config):
dataset = AnomalyDetectionDataset(test_config)
# Check if all images were found
- assert len(dataset.all_filenames) == len(
- extensions
- ), "Not all images with different extensions were found"
+ assert len(dataset.all_filenames) == len(extensions), (
+ "Not all images with different extensions were found"
+ )
# Verify that all expected files are included
for filename in test_images:
diff --git a/tests/unit/test_display_transforms.py b/tests/unit/test_display_transforms.py
new file mode 100644
index 0000000..70bdff4
--- /dev/null
+++ b/tests/unit/test_display_transforms.py
@@ -0,0 +1,198 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests for display transform functions."""
+
+import numpy as np
+import pytest
+from PIL import Image
+
+from anomaly_match_ui.utils.display_transforms import (
+ apply_transforms_ui,
+ display_image_normalisation,
+ prepare_for_display,
+)
+
+
+class TestPrepareForDisplay:
+ def test_rgb_passthrough(self):
+ img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = prepare_for_display(img)
+ assert result.shape == (64, 64, 3)
+ assert np.array_equal(result, img)
+
+ def test_grayscale_to_rgb(self):
+ img = np.random.randint(0, 255, (64, 64, 1), dtype=np.uint8)
+ result = prepare_for_display(img)
+ assert result.shape == (64, 64, 3)
+ # All channels should be the same
+ assert np.array_equal(result[:, :, 0], result[:, :, 1])
+ assert np.array_equal(result[:, :, 1], result[:, :, 2])
+
+ def test_2d_grayscale(self):
+ img = np.random.randint(0, 255, (64, 64), dtype=np.uint8)
+ result = prepare_for_display(img)
+ assert result.shape == (64, 64, 3)
+
+ def test_2_channel_to_rgb(self):
+ img = np.random.randint(0, 255, (64, 64, 2), dtype=np.uint8)
+ result = prepare_for_display(img)
+ assert result.shape == (64, 64, 3)
+
+ def test_4_channel_default_mapping(self):
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ result = prepare_for_display(img)
+ assert result.shape == (64, 64, 3)
+ # Default mapping uses first 3 channels
+ assert np.array_equal(result, img[:, :, :3])
+
+ def test_4_channel_custom_mapping(self):
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ result = prepare_for_display(img, rgb_mapping=[1, 2, 3])
+ assert result.shape == (64, 64, 3)
+ assert np.array_equal(result[:, :, 0], img[:, :, 1])
+ assert np.array_equal(result[:, :, 1], img[:, :, 2])
+ assert np.array_equal(result[:, :, 2], img[:, :, 3])
+
+ def test_invalid_rgb_mapping_length(self):
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ with pytest.raises(ValueError, match="must have 3 elements"):
+ prepare_for_display(img, rgb_mapping=[0, 1])
+
+ def test_invalid_rgb_mapping_index(self):
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ with pytest.raises(ValueError, match="exceed channel count"):
+ prepare_for_display(img, rgb_mapping=[0, 1, 5])
+
+ def test_float_to_uint8_conversion(self):
+ img = np.random.random((64, 64, 3)).astype(np.float32)
+ result = prepare_for_display(img)
+ assert result.dtype == np.uint8
+ assert result.shape == (64, 64, 3)
+
+ def test_pil_image_input(self):
+ pil_img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
+ result = prepare_for_display(pil_img)
+ assert result.shape == (64, 64, 3)
+ assert result.dtype == np.uint8
+
+ def test_invalid_input_type(self):
+ with pytest.raises(ValueError, match="Expected numpy array or PIL Image"):
+ prepare_for_display("not_an_image")
+
+
+class TestDisplayImageNormalisation:
+ def test_basic_normalisation(self):
+ img = np.random.random((64, 64, 3)).astype(np.float64) * 255
+ result = display_image_normalisation(img)
+ assert isinstance(result, Image.Image)
+
+ def test_constant_image(self):
+ img = np.full((64, 64, 3), 0.5, dtype=np.float64)
+ result = display_image_normalisation(img)
+ assert isinstance(result, Image.Image)
+
+ def test_handles_nan(self):
+ img = np.random.random((64, 64, 3))
+ img[10, 10, 0] = np.nan
+ result = display_image_normalisation(img)
+ assert isinstance(result, Image.Image)
+
+ def test_handles_inf(self):
+ img = np.random.random((64, 64, 3))
+ img[10, 10, 0] = np.inf
+ result = display_image_normalisation(img)
+ assert isinstance(result, Image.Image)
+
+
+class TestApplyTransformsUI:
+ @pytest.fixture
+ def sample_pil_image(self):
+ return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
+
+ def test_no_transforms(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=False,
+ brightness=1.0,
+ contrast=1.0,
+ unsharp_mask_applied=False,
+ )
+ assert isinstance(result, Image.Image)
+
+ def test_invert(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=True,
+ brightness=1.0,
+ contrast=1.0,
+ unsharp_mask_applied=False,
+ )
+ assert isinstance(result, Image.Image)
+
+ def test_brightness(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=False,
+ brightness=1.5,
+ contrast=1.0,
+ unsharp_mask_applied=False,
+ )
+ assert isinstance(result, Image.Image)
+
+ def test_contrast(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=False,
+ brightness=1.0,
+ contrast=1.5,
+ unsharp_mask_applied=False,
+ )
+ assert isinstance(result, Image.Image)
+
+ def test_unsharp_mask(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=False,
+ brightness=1.0,
+ contrast=1.0,
+ unsharp_mask_applied=True,
+ )
+ assert isinstance(result, Image.Image)
+
+ def test_channel_toggling_hide_red(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=False,
+ brightness=1.0,
+ contrast=1.0,
+ unsharp_mask_applied=False,
+ show_r=False,
+ )
+ result_array = np.array(result)
+ assert np.all(result_array[:, :, 0] == 0)
+
+ def test_channel_visibility_list(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=False,
+ brightness=1.0,
+ contrast=1.0,
+ unsharp_mask_applied=False,
+ channel_visibility=[True, False, True],
+ )
+ result_array = np.array(result)
+ assert np.all(result_array[:, :, 1] == 0)
+
+ def test_all_transforms_combined(self, sample_pil_image):
+ result = apply_transforms_ui(
+ sample_pil_image,
+ invert=True,
+ brightness=1.2,
+ contrast=0.8,
+ unsharp_mask_applied=True,
+ )
+ assert isinstance(result, Image.Image)
diff --git a/tests/file_io_test.py b/tests/unit/test_file_io.py
similarity index 90%
rename from tests/file_io_test.py
rename to tests/unit/test_file_io.py
index d90df1d..c46a01a 100644
--- a/tests/file_io_test.py
+++ b/tests/unit/test_file_io.py
@@ -7,23 +7,25 @@
"""
Tests for the image IO utility functions.
"""
+
import os
-import numpy as np
-import pytest
import shutil
import tempfile
-from PIL import Image
+
+import numpy as np
+import pytest
from astropy.io import fits
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
+from PIL import Image
from anomaly_match.data_io.find_images_in_folder import (
get_image_names_from_folder,
get_image_paths_from_folder,
)
from anomaly_match.data_io.load_images import (
- load_and_process_wrapper,
load_and_process_single_wrapper,
+ load_and_process_wrapper,
)
-from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
def _load_image_with_fitsbolt(filepath, cfg):
@@ -305,9 +307,9 @@ def test_load_and_process_images_rgba(self, test_config):
transparent_img = _load_image_with_fitsbolt(self.transparent_path, cfg=test_config)
assert transparent_img.shape[2] == 3 # Should still be RGB
# The green channel should still be present even though the pixels were transparent
- assert np.any(
- transparent_img[:, :, 1] > 0
- ), "Green channel data lost in transparent image" # Test complex RGBA with gradient alpha
+ assert np.any(transparent_img[:, :, 1] > 0), (
+ "Green channel data lost in transparent image"
+ ) # Test complex RGBA with gradient alpha
complex_rgba_img = _load_image_with_fitsbolt(self.complex_rgba_path, cfg=test_config)
assert complex_rgba_img.shape[2] == 3 # Should be RGB
# Check that gradients are preserved in RGB channels
@@ -352,24 +354,24 @@ def test_rgba_to_rgb_conversion_values(self, test_config):
# Test that colors are preserved correctly according to alpha values
# Colors with full alpha should be preserved exactly
- assert np.all(
- rgb_img[: height // 2, : width // 2, 0] == 255
- ), "Red with full alpha should be preserved"
- assert np.all(
- rgb_img[: height // 2, : width // 2, 1] == 0
- ), "Red channel should have no green"
- assert np.all(
- rgb_img[: height // 2, : width // 2, 2] == 0
- ), "Red channel should have no blue"
+ assert np.all(rgb_img[: height // 2, : width // 2, 0] == 255), (
+ "Red with full alpha should be preserved"
+ )
+ assert np.all(rgb_img[: height // 2, : width // 2, 1] == 0), (
+ "Red channel should have no green"
+ )
+ assert np.all(rgb_img[: height // 2, : width // 2, 2] == 0), (
+ "Red channel should have no blue"
+ )
# Colors with partial alpha might be handled differently depending on implementation
# We'll just check that they exist and aren't black
- assert (
- np.mean(rgb_img[: height // 2, width // 2 :, 1]) > 0
- ), "Green with half alpha should be visible"
- assert (
- np.mean(rgb_img[height // 2 :, : width // 2, 2]) > 0
- ), "Blue with quarter alpha should be visible"
+ assert np.mean(rgb_img[: height // 2, width // 2 :, 1]) > 0, (
+ "Green with half alpha should be visible"
+ )
+ assert np.mean(rgb_img[height // 2 :, : width // 2, 2]) > 0, (
+ "Blue with quarter alpha should be visible"
+ )
# The behavior with fully transparent pixels can vary depending on the implementation
# Some libraries preserve the color values even with zero alpha, others apply a background color
# Instead of asserting they shouldn't be white, we'll just check that the image was loaded successfully
@@ -467,9 +469,9 @@ def test_load_image_with_fitsbolt_fits(self, test_config):
# Check if normalization preserved the pattern (bright area should be brighter than dark area)
bright_area = extreme_img[50:80, 50:80]
dark_area = extreme_img[10:40, 10:40]
- assert np.mean(bright_area) > np.mean(
- dark_area
- ), "Normalization failed to preserve contrast"
+ assert np.mean(bright_area) > np.mean(dark_area), (
+ "Normalization failed to preserve contrast"
+ )
def test_fits_extension_parameter(self, test_config):
"""Test the fits_extension parameter for FITS files."""
@@ -634,14 +636,14 @@ def test_fits_multiple_extensions(self, test_config):
assert combined_img3.shape == (50, 50, 3), "Combined image should have shape (50, 50, 3)"
# The images should be different due to different ordering
- assert not np.array_equal(
- combined_img1, combined_img3
- ), "Different extension order should give different results"
+ assert not np.array_equal(combined_img1, combined_img3), (
+ "Different extension order should give different results"
+ )
# Int indices and string names with same order should give same result
- assert np.array_equal(
- combined_img1, combined_img2
- ), "Same extension order should give same results"
+ assert np.array_equal(combined_img1, combined_img2), (
+ "Same extension order should give same results"
+ )
# Create another FITS file with extensions of different shapes to test error handling
diff_shapes_path = os.path.join(self.test_dir, "diff_shapes.fits")
@@ -668,9 +670,9 @@ def test_fits_multiple_extensions(self, test_config):
# Validate the error message contains information about channel mismatch
error_message = str(e_info.value)
- assert (
- "channel" in error_message.lower() or "extension" in error_message.lower()
- ), f"Error should mention channel or extension mismatch: {error_message}"
+ assert "channel" in error_message.lower() or "extension" in error_message.lower(), (
+ f"Error should mention channel or extension mismatch: {error_message}"
+ )
# could expand test if needed with more extensions
def test_load_and_process_images_fits_extension(self, test_config):
@@ -759,12 +761,12 @@ def test_load_and_process_images_fits_extension(self, test_config):
img_combined = results_combined[0][1]
# The images should have different content because they used different extensions
- assert not np.array_equal(
- img_ext0, img_ext1
- ), "Different extensions should produce different images"
- assert not np.array_equal(
- img_ext0, img_combined
- ), "Combined extensions should differ from single extension"
+ assert not np.array_equal(img_ext0, img_ext1), (
+ "Different extensions should produce different images"
+ )
+ assert not np.array_equal(img_ext0, img_combined), (
+ "Combined extensions should differ from single extension"
+ )
# Also test with string extension names
_update_config(test_config, fits_extension=["PRIMARY", "EXT1", "EXT2"])
@@ -774,9 +776,9 @@ def test_load_and_process_images_fits_extension(self, test_config):
img_named = results_named[0][1]
# Should be identical to using numeric indices [0, 1, 2]
- assert np.array_equal(
- img_combined, img_named
- ), "String extension names should produce same result as numeric indices"
+ assert np.array_equal(img_combined, img_named), (
+ "String extension names should produce same result as numeric indices"
+ )
def test_image_normalisation(self, test_config):
"""Test that different normalisation methods are correctly applied during image loading."""
@@ -797,9 +799,9 @@ def test_image_normalisation(self, test_config):
# Test with no normalisation (default)
_update_config(test_config, normalisation_method=NormalisationMethod.CONVERSION_ONLY)
img_none = _load_image_with_fitsbolt(test_path, cfg=test_config)
- assert np.array_equal(
- img_none, test_values
- ), "NONE normalisation should preserve original values"
+ assert np.array_equal(img_none, test_values), (
+ "NONE normalisation should preserve original values"
+ )
# Test with LOG normalisation
_update_config(test_config, normalisation_method=NormalisationMethod.LOG)
@@ -815,9 +817,9 @@ def test_image_normalisation(self, test_config):
# Test with ZSCALE normalisation
_update_config(test_config, normalisation_method=NormalisationMethod.ZSCALE)
img_zscale = _load_image_with_fitsbolt(test_path, cfg=test_config)
- assert not np.array_equal(
- img_zscale, test_values
- ), "ZSCALE normalisation should modify values"
+ assert not np.array_equal(img_zscale, test_values), (
+ "ZSCALE normalisation should modify values"
+ )
# ZScale should produce values with reasonable contrast
assert np.min(img_zscale) < np.max(img_zscale), "ZSCALE should preserve contrast"
@@ -829,9 +831,9 @@ def test_image_normalisation(self, test_config):
# Test that all normalised outputs preserve image dimensions
assert img_none.shape == test_values.shape, "NONE normalisation should preserve dimensions"
assert img_log.shape == test_values.shape, "LOG normalisation should preserve dimensions"
- assert (
- img_zscale.shape == test_values.shape
- ), "ZSCALE normalisation should preserve dimensions"
+ assert img_zscale.shape == test_values.shape, (
+ "ZSCALE normalisation should preserve dimensions"
+ )
def test_image_interpolation_orders(self, test_config):
"""Test different interpolation orders when resizing images.
@@ -892,9 +894,9 @@ def test_image_interpolation_orders(self, test_config):
100,
3,
), f"Resized small image should be 100x100 with order {order}"
- assert (
- resized_small.dtype == np.uint8
- ), f"Resized small image should be uint8 with order {order}"
+ assert resized_small.dtype == np.uint8, (
+ f"Resized small image should be uint8 with order {order}"
+ )
# Resize large image (200x200 → 100x100) - downsampling
resized_large = _load_image_with_fitsbolt(large_path, cfg=test_config)
@@ -903,9 +905,9 @@ def test_image_interpolation_orders(self, test_config):
100,
3,
), f"Resized large image should be 100x100 with order {order}"
- assert (
- resized_large.dtype == np.uint8
- ), f"Resized large image should be uint8 with order {order}"
+ assert resized_large.dtype == np.uint8, (
+ f"Resized large image should be uint8 with order {order}"
+ )
# Check the center pixel of each quadrant for both resized images
# Allow for some variation (±20%) in color values due to interpolation differences
@@ -931,9 +933,9 @@ def test_image_interpolation_orders(self, test_config):
)
else:
# For zero values, small absolute threshold
- assert (
- small_color[c] <= 50
- ), f"Small image order {order}, quadrant {idx}, channel {c}: expected ~0, got {small_color[c]}"
+ assert small_color[c] <= 50, (
+ f"Small image order {order}, quadrant {idx}, channel {c}: expected ~0, got {small_color[c]}"
+ )
# Check large image downsampled
large_color = resized_large[y, x]
@@ -948,9 +950,9 @@ def test_image_interpolation_orders(self, test_config):
)
else:
# For zero values, small absolute threshold
- assert (
- large_color[c] <= 50
- ), f"Large image order {order}, quadrant {idx}, channel {c}: expected ~0, got {large_color[c]}"
+ assert large_color[c] <= 50, (
+ f"Large image order {order}, quadrant {idx}, channel {c}: expected ~0, got {large_color[c]}"
+ )
# Additional check for sharp transitions with order 0 (nearest neighbor)
if order == 0:
@@ -962,16 +964,16 @@ def test_image_interpolation_orders(self, test_config):
# For small image
left_of_boundary_small = resized_small[boundary_y, boundary_x - 1]
right_of_boundary_small = resized_small[boundary_y, boundary_x + 1]
- assert not np.array_equal(
- left_of_boundary_small, right_of_boundary_small
- ), "Small image order 0 should have sharp transitions at boundary"
+ assert not np.array_equal(left_of_boundary_small, right_of_boundary_small), (
+ "Small image order 0 should have sharp transitions at boundary"
+ )
# For large image
left_of_boundary_large = resized_large[boundary_y, boundary_x - 1]
right_of_boundary_large = resized_large[boundary_y, boundary_x + 1]
- assert not np.array_equal(
- left_of_boundary_large, right_of_boundary_large
- ), "Large image order 0 should have sharp transitions at boundary"
+ assert not np.array_equal(left_of_boundary_large, right_of_boundary_large), (
+ "Large image order 0 should have sharp transitions at boundary"
+ )
# Higher order interpolation (order > 1) should lead to smoother transitions
# This is difficult to quantify precisely, but we can check for values between the extremes for upscaling
@@ -986,21 +988,21 @@ def test_image_interpolation_orders(self, test_config):
unique_values_small = np.unique(boundary_region_small)
# Higher order interpolation should have more unique values in the boundary region when upscaling
- assert (
- len(unique_values_small) > 4
- ), f"Small image order {order} should have intermediate values at boundaries"
+ assert len(unique_values_small) > 4, (
+ f"Small image order {order} should have intermediate values at boundaries"
+ )
# Compare results between different interpolation orders to verify they're not identical
# We'll compare order 0 (nearest neighbor) with orders 1, 3, and 5
# These should produce visibly different results
for i, upscaled_im in enumerate(upscaled_results):
if i != 0:
- assert not np.array_equal(
- upscaled_results[0], upscaled_results[i]
- ), "Order 0 and order {i} interpolation should produce different results"
- assert not np.array_equal(
- upscaled_results[i - 1], upscaled_results[i]
- ), f"Order {i - 1} and order {i} interpolation should produce different results"
+ assert not np.array_equal(upscaled_results[0], upscaled_results[i]), (
+ "Order 0 and order {i} interpolation should produce different results"
+ )
+ assert not np.array_equal(upscaled_results[i - 1], upscaled_results[i]), (
+ f"Order {i - 1} and order {i} interpolation should produce different results"
+ )
def test_fits_combination_configurations(self, test_config):
"""Test different configurations of the fits_combination dictionary."""
@@ -1063,3 +1065,35 @@ def test_fits_combination_configurations(self, test_config):
assert np.any(img_two[:, :, 0] > 0) # Red channel should have data
assert np.any(img_two[:, :, 1] > 0) # Green channel should have data
assert np.any(img_two[:, :, 2] > 0) # Blue channel should have data
+
+ @pytest.mark.parametrize(
+ "norm_method",
+ [
+ NormalisationMethod.CONVERSION_ONLY,
+ NormalisationMethod.LINEAR,
+ NormalisationMethod.MIDTONES,
+ ],
+ )
+ def test_load_and_process_wrapper_normalisation_methods(self, test_config, norm_method):
+ """Regression test: load_and_process_wrapper must work with all normalisation methods.
+
+ The PIL resize optimization (size=None to fitsbolt) is only safe for
+ CONVERSION_ONLY. Other methods need fitsbolt's skimage resize to keep
+ images as float64 for their normalisation pipelines.
+ """
+ target_size = [64, 64]
+ _update_config(test_config, size=target_size)
+ test_config.normalisation.normalisation_method = norm_method
+ test_config.num_workers = 1
+
+ # Use RGBA gradient images (non-constant channels needed for MIDTONES)
+ filepaths = [self.complex_rgba_path, self.complex_rgba_path]
+ results = _load_multiple_images_with_fitsbolt(filepaths, test_config)
+
+ assert len(results) == 2
+ for fp, img in results:
+ assert isinstance(img, np.ndarray)
+ assert img.shape[:2] == tuple(target_size)
+ assert img.shape[2] == 3
+ assert img.dtype == np.uint8
+ assert np.isfinite(img).all()
diff --git a/tests/fixmatch_test.py b/tests/unit/test_fixmatch.py
similarity index 99%
rename from tests/fixmatch_test.py
rename to tests/unit/test_fixmatch.py
index c426da5..7237b92 100644
--- a/tests/fixmatch_test.py
+++ b/tests/unit/test_fixmatch.py
@@ -6,6 +6,7 @@
# the terms contained in the file 'LICENCE.txt'.
import pytest
import torch
+
from anomaly_match.models.FixMatch import FixMatch
@@ -26,7 +27,6 @@ def __getitem__(self, idx):
class TestFixMatch:
-
@pytest.fixture
def net_builder(self):
"""Simple CNN network builder for testing."""
@@ -58,7 +58,6 @@ def fixmatch_model(self, net_builder):
T=0.5,
p_cutoff=0.95,
lambda_u=1.0,
- hard_label=True,
)
# Set optimizer
optimizer = torch.optim.SGD(model.train_model.parameters(), lr=0.01)
diff --git a/tests/test_image_io.py b/tests/unit/test_image_io.py
similarity index 88%
rename from tests/test_image_io.py
rename to tests/unit/test_image_io.py
index 5b280fb..2092060 100644
--- a/tests/test_image_io.py
+++ b/tests/unit/test_image_io.py
@@ -5,19 +5,20 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
+import copy
import tempfile
-import numpy as np
-import h5py
from pathlib import Path
-from PIL import Image
+
+import h5py
+import numpy as np
+import pytest
import torch
-import copy
+from PIL import Image
from anomaly_match.data_io.load_images import (
+ get_fitsbolt_config,
load_and_process_single_wrapper,
process_single_wrapper,
- get_fitsbolt_config,
)
from anomaly_match.utils.get_default_cfg import get_default_cfg
from prediction_utils import save_results
@@ -288,8 +289,8 @@ def test_image_pipeline_integration(self, test_image, test_config):
def test_prediction_process_integration(self, test_image, test_config):
"""Test integration with prediction processes for different image formats."""
import tempfile
- from pathlib import Path
import time
+ from pathlib import Path
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
@@ -341,6 +342,7 @@ def test_prediction_process_integration(self, test_image, test_config):
# Save config with pickle
import pickle
+
from dotmap import DotMap
with open(config_path, "wb") as f:
@@ -352,47 +354,47 @@ def test_prediction_process_integration(self, test_image, test_config):
reloaded_config = DotMap(reloaded_config)
# Check that critical fields are present
- assert hasattr(
- reloaded_config, "normalisation"
- ), "normalisation field missing from reloaded config"
- assert hasattr(
- reloaded_config.normalisation, "fits_extension"
- ), "fits_extension field missing from reloaded config.normalisation"
- assert hasattr(
- reloaded_config.normalisation, "size"
- ), "size field missing from reloaded config.normalisation"
- assert hasattr(
- reloaded_config.normalisation, "normalisation_method"
- ), "normalisation_method field missing from reloaded config.normalisation"
+ assert hasattr(reloaded_config, "normalisation"), (
+ "normalisation field missing from reloaded config"
+ )
+ assert hasattr(reloaded_config.normalisation, "fits_extension"), (
+ "fits_extension field missing from reloaded config.normalisation"
+ )
+ assert hasattr(reloaded_config.normalisation, "size"), (
+ "size field missing from reloaded config.normalisation"
+ )
+ assert hasattr(reloaded_config.normalisation, "normalisation_method"), (
+ "normalisation_method field missing from reloaded config.normalisation"
+ )
# Test image loading with reloaded config
for test_file in test_files:
try:
loaded_image = _load_image_with_fitsbolt(test_file, reloaded_config)
- assert (
- loaded_image is not None
- ), f"Failed to load {test_file} with reloaded config"
- assert isinstance(
- loaded_image, np.ndarray
- ), f"Loaded image from {test_file} is not a numpy array"
- assert (
- loaded_image.ndim == 3
- ), f"Loaded image from {test_file} should be 3D (HWC)"
- assert (
- loaded_image.shape[2] == 3
- ), f"Loaded image from {test_file} should have 3 channels" # Ensure no NaN/inf values
- assert np.isfinite(
- loaded_image
- ).all(), f"Image from {test_file} contains NaN or inf values"
+ assert loaded_image is not None, (
+ f"Failed to load {test_file} with reloaded config"
+ )
+ assert isinstance(loaded_image, np.ndarray), (
+ f"Loaded image from {test_file} is not a numpy array"
+ )
+ assert loaded_image.ndim == 3, (
+ f"Loaded image from {test_file} should be 3D (HWC)"
+ )
+ assert loaded_image.shape[2] == 3, (
+ f"Loaded image from {test_file} should have 3 channels"
+ ) # Ensure no NaN/inf values
+ assert np.isfinite(loaded_image).all(), (
+ f"Image from {test_file} contains NaN or inf values"
+ )
except Exception as e:
pytest.fail(f"Failed to load {test_file} with reloaded config: {e}")
def test_image_formats_comprehensive(self, test_image, test_config):
"""Test comprehensive image format support."""
+ import gc
import tempfile
- from pathlib import Path
import time
- import gc
+ from pathlib import Path
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
@@ -427,24 +429,24 @@ def test_image_formats_comprehensive(self, test_image, test_config):
if should_succeed:
assert loaded_image is not None, f"Failed to load {filename}"
- assert isinstance(
- loaded_image, np.ndarray
- ), f"Loaded {filename} is not a numpy array"
+ assert isinstance(loaded_image, np.ndarray), (
+ f"Loaded {filename} is not a numpy array"
+ )
# Check dimensions
- assert (
- loaded_image.ndim >= 2
- ), f"Loaded {filename} has insufficient dimensions"
+ assert loaded_image.ndim >= 2, (
+ f"Loaded {filename} has insufficient dimensions"
+ )
# Check data integrity
- assert np.isfinite(
- loaded_image
- ).all(), f"Loaded {filename} contains NaN or inf values"
+ assert np.isfinite(loaded_image).all(), (
+ f"Loaded {filename} contains NaN or inf values"
+ )
# Check data type
- assert (
- loaded_image.dtype == np.uint8
- ), f"Loaded {filename} has wrong dtype: {loaded_image.dtype}"
+ assert loaded_image.dtype == np.uint8, (
+ f"Loaded {filename} has wrong dtype: {loaded_image.dtype}"
+ )
except Exception as e:
if should_succeed:
@@ -455,7 +457,7 @@ def test_image_formats_comprehensive(self, test_image, test_config):
def test_numpy_to_byte_stream_nan_inf_handling(self):
"""Test that numpy_to_byte_stream handles NaN and inf values properly."""
- from anomaly_match.utils.numpy_to_byte_stream import numpy_array_to_byte_stream
+ from anomaly_match_ui.utils.image_utils import numpy_array_to_byte_stream
# Test array with NaN values
array_with_nan = np.array([[1.0, 2.0, np.nan], [4.0, 5.0, 6.0]], dtype=np.float32)
diff --git a/tests/import_test.py b/tests/unit/test_import.py
similarity index 100%
rename from tests/import_test.py
rename to tests/unit/test_import.py
diff --git a/tests/unit/test_label.py b/tests/unit/test_label.py
new file mode 100644
index 0000000..5df1f34
--- /dev/null
+++ b/tests/unit/test_label.py
@@ -0,0 +1,33 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests for the Label enum."""
+
+from anomaly_match.datasets.Label import Label
+
+
+class TestLabel:
+ def test_label_values(self):
+ assert Label.UNKNOWN == -1
+ assert Label.NORMAL == 0
+ assert Label.ANOMALY == 1
+
+ def test_label_is_int(self):
+ assert isinstance(Label.NORMAL, int)
+ assert isinstance(Label.ANOMALY, int)
+ assert isinstance(Label.UNKNOWN, int)
+
+ def test_label_from_value(self):
+ assert Label(-1) == Label.UNKNOWN
+ assert Label(0) == Label.NORMAL
+ assert Label(1) == Label.ANOMALY
+
+ def test_label_members(self):
+ members = list(Label)
+ assert len(members) == 3
+ assert Label.UNKNOWN in members
+ assert Label.NORMAL in members
+ assert Label.ANOMALY in members
diff --git a/tests/unit/test_losses.py b/tests/unit/test_losses.py
new file mode 100644
index 0000000..6b9343a
--- /dev/null
+++ b/tests/unit/test_losses.py
@@ -0,0 +1,137 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests for accuracy, cross_entropy_loss, and consistency_loss."""
+
+import pytest
+import torch
+
+from anomaly_match.utils.accuracy import accuracy
+from anomaly_match.utils.consistency_loss import consistency_loss
+from anomaly_match.utils.cross_entropy_loss import cross_entropy_loss
+
+
+class TestAccuracy:
+ def test_perfect_predictions(self):
+ output = torch.tensor([[10.0, -10.0], [-10.0, 10.0], [10.0, -10.0]])
+ target = torch.tensor([0, 1, 0])
+ [top1] = accuracy(output, target, topk=(1,))
+ assert top1.item() == 100.0
+
+ def test_zero_accuracy(self):
+ output = torch.tensor([[10.0, -10.0], [-10.0, 10.0]])
+ target = torch.tensor([1, 0])
+ [top1] = accuracy(output, target, topk=(1,))
+ assert top1.item() == 0.0
+
+ def test_partial_accuracy(self):
+ output = torch.tensor([[10.0, -10.0], [-10.0, 10.0], [10.0, -10.0], [-10.0, 10.0]])
+ target = torch.tensor([0, 1, 1, 0])
+ [top1] = accuracy(output, target, topk=(1,))
+ assert top1.item() == 50.0
+
+ def test_batch_size_one(self):
+ output = torch.tensor([[5.0, 1.0]])
+ target = torch.tensor([0])
+ [top1] = accuracy(output, target, topk=(1,))
+ assert top1.item() == 100.0
+
+
+class TestCrossEntropyLoss:
+ def test_hard_labels_shape(self):
+ logits = torch.randn(4, 2)
+ targets = torch.tensor([0, 1, 0, 1])
+ loss = cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="none")
+ assert loss.shape == (4,)
+
+ def test_hard_labels_mean_reduction(self):
+ logits = torch.randn(4, 2)
+ targets = torch.tensor([0, 1, 0, 1])
+ loss = cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="mean")
+ assert loss.ndim == 0 # scalar
+
+ def test_hard_labels_non_negative(self):
+ logits = torch.randn(4, 2)
+ targets = torch.tensor([0, 1, 0, 1])
+ loss = cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="none")
+ assert (loss >= 0).all()
+
+ def test_soft_labels_shape(self):
+ logits = torch.randn(4, 2)
+ targets = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.7, 0.3], [0.3, 0.7]])
+ loss = cross_entropy_loss(logits, targets, use_hard_labels=False)
+ assert loss.shape == (4,)
+
+ def test_soft_labels_non_negative(self):
+ logits = torch.randn(4, 2)
+ targets = torch.softmax(torch.randn(4, 2), dim=-1)
+ loss = cross_entropy_loss(logits, targets, use_hard_labels=False)
+ assert (loss >= 0).all()
+
+ def test_soft_labels_shape_mismatch_raises(self):
+ logits = torch.randn(4, 2)
+ targets = torch.randn(4, 3)
+ with pytest.raises(AssertionError):
+ cross_entropy_loss(logits, targets, use_hard_labels=False)
+
+
+class TestConsistencyLoss:
+ def test_l2_loss(self):
+ logits_w = torch.randn(4, 2)
+ logits_s = torch.randn(4, 2)
+ loss = consistency_loss(logits_w, logits_s, name="L2")
+ assert loss.ndim == 0 # scalar
+ assert loss >= 0
+
+ def test_ce_loss_returns_tuple(self):
+ logits_w = torch.randn(4, 2)
+ logits_s = torch.randn(4, 2)
+ result = consistency_loss(logits_w, logits_s, name="ce", p_cutoff=0.0)
+ assert isinstance(result, tuple)
+ assert len(result) == 2
+ masked_loss, mask_ratio = result
+ assert masked_loss.ndim == 0
+ assert mask_ratio.ndim == 0
+
+ def test_ce_hard_labels(self):
+ logits_w = torch.randn(4, 2)
+ logits_s = torch.randn(4, 2)
+ masked_loss, mask_ratio = consistency_loss(
+ logits_w, logits_s, name="ce", use_hard_labels=True, p_cutoff=0.0
+ )
+ assert masked_loss >= 0
+ assert mask_ratio == 1.0 # p_cutoff=0 means all samples pass
+
+ def test_ce_soft_labels(self):
+ logits_w = torch.randn(4, 2)
+ logits_s = torch.randn(4, 2)
+ masked_loss, mask_ratio = consistency_loss(
+ logits_w, logits_s, name="ce", use_hard_labels=False, p_cutoff=0.0
+ )
+ assert masked_loss >= 0
+ assert mask_ratio == 1.0
+
+ def test_ce_high_cutoff_masks_all(self):
+ # Uniform logits = 0.5 probability, cutoff at 0.99 should mask all
+ logits_w = torch.zeros(4, 2)
+ logits_s = torch.randn(4, 2)
+ masked_loss, mask_ratio = consistency_loss(logits_w, logits_s, name="ce", p_cutoff=0.99)
+ assert mask_ratio == 0.0
+
+ def test_ce_detaches_weak_logits(self):
+ logits_w = torch.randn(4, 2, requires_grad=True)
+ logits_s = torch.randn(4, 2, requires_grad=True)
+ masked_loss, _ = consistency_loss(logits_w, logits_s, name="ce", p_cutoff=0.0)
+ masked_loss.backward()
+ # Gradient should only flow through logits_s, not logits_w
+ assert logits_w.grad is None
+ assert logits_s.grad is not None
+
+ def test_invalid_loss_name_raises(self):
+ logits_w = torch.randn(4, 2)
+ logits_s = torch.randn(4, 2)
+ with pytest.raises(AssertionError):
+ consistency_loss(logits_w, logits_s, name="invalid")
diff --git a/tests/metadata_test.py b/tests/unit/test_metadata.py
similarity index 99%
rename from tests/metadata_test.py
rename to tests/unit/test_metadata.py
index f9c381f..0245a72 100644
--- a/tests/metadata_test.py
+++ b/tests/unit/test_metadata.py
@@ -8,11 +8,12 @@
import os
import shutil
import tempfile
+
+import numpy as np
import pandas as pd
import pytest
-import numpy as np
-from PIL import Image
from dotmap import DotMap
+from PIL import Image
from anomaly_match.datasets.AnomalyDetectionDataset import AnomalyDetectionDataset
from anomaly_match.pipeline.session import Session
diff --git a/tests/metadata_handler_test.py b/tests/unit/test_metadata_handler.py
similarity index 99%
rename from tests/metadata_handler_test.py
rename to tests/unit/test_metadata_handler.py
index 9e201b8..b0f56ea 100644
--- a/tests/metadata_handler_test.py
+++ b/tests/unit/test_metadata_handler.py
@@ -7,6 +7,7 @@
import os
import tempfile
+
import numpy as np
from anomaly_match.data_io.metadata_handler import MetadataHandler
diff --git a/tests/unit/test_multispectral.py b/tests/unit/test_multispectral.py
new file mode 100644
index 0000000..cc295af
--- /dev/null
+++ b/tests/unit/test_multispectral.py
@@ -0,0 +1,274 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Unit tests for multispectral (N-channel) image support."""
+
+import numpy as np
+import torch
+
+from anomaly_match.datasets.AnomalyDetectionDataset import AnomalyDetectionDataset
+from anomaly_match.datasets.BasicDataset import BasicDataset
+from anomaly_match.datasets.SSL_Dataset import SSL_Dataset
+from anomaly_match.image_processing.transforms import (
+ get_prediction_transforms,
+ get_strong_transforms,
+ get_weak_transforms,
+)
+
+
+class TestMultispectralDataset:
+ """Tests for multispectral dataset loading and handling."""
+
+ def test_dataset_loads_4_channel_images(self, multispectral_config):
+ """Test that AnomalyDetectionDataset correctly loads 4-channel images."""
+ dataset = AnomalyDetectionDataset(multispectral_config)
+
+ assert dataset is not None
+ assert dataset.num_channels == 4, f"Expected 4 channels, got {dataset.num_channels}"
+ assert dataset.size == multispectral_config.normalisation.image_size
+
+ # Check that images in data_dict have correct shape
+ for filename, (image, label) in dataset.data_dict.items():
+ assert image.shape[-1] == 4, f"Image {filename} has wrong channels: {image.shape}"
+ assert image.shape[:2] == tuple(dataset.size), (
+ f"Image {filename} has wrong size: {image.shape}"
+ )
+
+ def test_channel_auto_detection(self, multispectral_config):
+ """Test that n_output_channels is auto-detected from 4-channel images."""
+ # Config starts at default (3), should be updated after dataset load
+ assert multispectral_config.normalisation.n_output_channels == 3
+
+ dataset = AnomalyDetectionDataset(multispectral_config)
+
+ # After loading, config should be updated to 4
+ assert multispectral_config.normalisation.n_output_channels == 4
+ assert dataset.num_channels == 4
+
+ def test_dataset_mean_std_4_channels(self, multispectral_config):
+ """Test that mean and std are computed correctly for 4 channels."""
+ dataset = AnomalyDetectionDataset(multispectral_config)
+
+ assert len(dataset.mean) == 4, f"Mean should have 4 values, got {len(dataset.mean)}"
+ assert len(dataset.std) == 4, f"Std should have 4 values, got {len(dataset.std)}"
+
+ # Check that values are reasonable (between 0 and 1 for normalized)
+ assert all(0 <= m <= 1 for m in dataset.mean), f"Mean values out of range: {dataset.mean}"
+ assert all(0 <= s <= 1 for s in dataset.std), f"Std values out of range: {dataset.std}"
+
+ def test_ssl_dataset_4_channels(self, multispectral_config):
+ """Test SSL_Dataset initialization with 4-channel data."""
+ ssl_dataset = SSL_Dataset(cfg=multispectral_config, train=True)
+ assert ssl_dataset.num_classes == 2
+
+ # Auto-detection happens when dataset is loaded (lazy)
+ labeled_dataset, unlabeled_dataset = ssl_dataset.get_ssl_dset()
+ assert labeled_dataset is not None
+ assert unlabeled_dataset is not None
+ assert multispectral_config.normalisation.n_output_channels == 4
+
+
+class TestMultispectralAugmentation:
+ """Tests for multispectral augmentation compatibility."""
+
+ def test_basic_dataset_4_channel_no_transform(self, multispectral_config):
+ """Test BasicDataset works with 4-channel data without transforms."""
+ # Create sample 4-channel data
+ imgs = np.random.randint(0, 255, (5, 64, 64, 4), dtype=np.uint8)
+ filenames = [f"img_{i}.npy" for i in range(5)]
+ targets = [0, 1, 0, 1, 0]
+
+ transform = get_prediction_transforms(num_channels=4)
+ dataset = BasicDataset(
+ imgs, filenames, targets, num_classes=2, transform=transform, num_channels=4
+ )
+
+ img, target, filename = dataset[0]
+
+ assert isinstance(img, torch.Tensor)
+ assert img.shape[0] == 4, f"Expected 4 channels, got shape {img.shape}"
+ assert img.shape[1:] == (64, 64), f"Expected (64, 64) spatial, got {img.shape}"
+
+ def test_basic_dataset_4_channel_weak_transform(self, multispectral_config):
+ """Test BasicDataset works with 4-channel data and weak transforms."""
+ imgs = np.random.randint(0, 255, (5, 64, 64, 4), dtype=np.uint8)
+ filenames = [f"img_{i}.npy" for i in range(5)]
+ targets = [0, 1, 0, 1, 0]
+
+ transform = get_weak_transforms(num_channels=4)
+ dataset = BasicDataset(
+ imgs, filenames, targets, num_classes=2, transform=transform, num_channels=4
+ )
+
+ img, target, filename = dataset[0]
+
+ assert isinstance(img, torch.Tensor)
+ assert img.shape[0] == 4, f"Expected 4 channels after weak transform, got {img.shape}"
+
+ def test_basic_dataset_4_channel_strong_transform(self, multispectral_config):
+ """Test BasicDataset works with 4-channel data and strong transforms."""
+ imgs = np.random.randint(0, 255, (5, 64, 64, 4), dtype=np.uint8)
+ filenames = [f"img_{i}.npy" for i in range(5)]
+ targets = [0, 1, 0, 1, 0]
+
+ weak_transform = get_weak_transforms(num_channels=4)
+ strong_transform = get_strong_transforms(num_channels=4)
+
+ dataset = BasicDataset(
+ imgs,
+ filenames,
+ targets,
+ num_classes=2,
+ transform=weak_transform,
+ use_strong_transform=True,
+ strong_transform=strong_transform,
+ num_channels=4,
+ )
+
+ weak_img, strong_img, target = dataset[0]
+
+ assert isinstance(weak_img, torch.Tensor)
+ assert isinstance(strong_img, torch.Tensor)
+ assert weak_img.shape[0] == 4, f"Weak augmented should have 4 channels: {weak_img.shape}"
+ assert strong_img.shape[0] == 4, (
+ f"Strong augmented should have 4 channels: {strong_img.shape}"
+ )
+
+
+class TestMultispectralModel:
+ """Tests for multispectral model architecture."""
+
+ def test_network_builder_4_channels(self):
+ """Test that network builder creates model accepting 4 channels."""
+ from anomaly_match.utils.get_net_builder import get_net_builder
+
+ net_builder = get_net_builder("efficientnet-lite0", pretrained=True, in_channels=4)
+ model = net_builder(num_classes=2, in_channels=4)
+
+ # Set to eval mode to avoid batch norm issues with batch size 1
+ model.eval()
+
+ # Create dummy input with 4 channels
+ dummy_input = torch.randn(1, 4, 64, 64)
+ with torch.no_grad():
+ output = model(dummy_input)
+
+ assert output.shape == (1, 2), f"Expected output shape (1, 2), got {output.shape}"
+
+ def test_network_builder_pretrained_weight_transfer(self):
+ """Test that pretrained weights are adapted for 4-channel input by timm."""
+ from anomaly_match.utils.get_net_builder import get_net_builder
+
+ # Get 4-channel pretrained model
+ net_builder_4ch = get_net_builder("efficientnet-lite0", pretrained=True, in_channels=4)
+ model_4ch = net_builder_4ch(num_classes=2, in_channels=4)
+
+ # Verify conv_stem was adapted to 4 input channels
+ assert model_4ch.conv_stem.weight.data.shape[1] == 4, (
+ "conv_stem should have 4 input channels"
+ )
+
+ # Verify weights are not all zeros (pretrained weights were adapted, not just zero-padded)
+ assert model_4ch.conv_stem.weight.data.abs().sum() > 0, "Adapted weights should be non-zero"
+
+
+class TestMultispectralPrediction:
+ """Tests for multispectral prediction pipeline."""
+
+ def test_model_prediction_4_channels(self, multispectral_config):
+ """Test model prediction with 4-channel images."""
+ from anomaly_match.models.FixMatch import FixMatch
+ from anomaly_match.utils.get_net_builder import get_net_builder
+
+ # Build and initialize model
+ net_builder = get_net_builder(
+ multispectral_config.net,
+ pretrained=multispectral_config.pretrained,
+ in_channels=4,
+ )
+ model = FixMatch(
+ net_builder=net_builder,
+ num_classes=2,
+ in_channels=4,
+ ema_m=multispectral_config.ema_m,
+ T=multispectral_config.temperature,
+ p_cutoff=multispectral_config.p_cutoff,
+ lambda_u=multispectral_config.ulb_loss_ratio,
+ )
+
+ # Create test batch
+ test_batch = torch.randn(4, 4, 64, 64) # (batch, channels, H, W)
+
+ # Get predictions
+ model.eval_model.eval()
+ with torch.no_grad():
+ predictions = model.eval_model(test_batch)
+
+ assert predictions.shape == (4, 2), f"Expected (4, 2), got {predictions.shape}"
+ # Check predictions are valid probabilities after softmax
+ probs = torch.softmax(predictions, dim=1)
+ assert torch.allclose(probs.sum(dim=1), torch.ones(4)), "Probabilities should sum to 1"
+
+
+class TestMultispectralDisplay:
+ """Tests for multispectral display functionality."""
+
+ def test_prepare_4_channel_for_display(self):
+ """Test converting 4-channel image to RGB for display."""
+ from anomaly_match_ui.utils.display_transforms import prepare_for_display
+
+ # Create 4-channel image
+ img_4ch = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+
+ # Convert to displayable RGB
+ rgb_img = prepare_for_display(img_4ch)
+
+ assert rgb_img.shape == (64, 64, 3), f"Expected RGB output, got {rgb_img.shape}"
+ assert rgb_img.dtype == np.uint8
+
+ def test_prepare_4_channel_with_rgb_mapping(self):
+ """Test converting 4-channel image with custom RGB channel mapping."""
+ from anomaly_match_ui.utils.display_transforms import prepare_for_display
+
+ # Create 4-channel image
+ img_4ch = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+
+ # Use channels 0, 2, 3 as RGB
+ rgb_mapping = [0, 2, 3]
+ rgb_img = prepare_for_display(img_4ch, rgb_mapping=rgb_mapping)
+
+ assert rgb_img.shape == (64, 64, 3)
+ # Verify mapping is correct
+ assert np.array_equal(rgb_img[:, :, 0], img_4ch[:, :, 0])
+ assert np.array_equal(rgb_img[:, :, 1], img_4ch[:, :, 2])
+ assert np.array_equal(rgb_img[:, :, 2], img_4ch[:, :, 3])
+
+
+class TestMultispectralHDF5:
+ """Tests for HDF5 save/load with multispectral data."""
+
+ def test_hdf5_save_load_4_channels(self, multispectral_config, tmp_path):
+ """Test saving and loading 4-channel dataset to/from HDF5."""
+ dataset = AnomalyDetectionDataset(multispectral_config)
+
+ # Save to HDF5
+ hdf5_path = tmp_path / "test_4ch.hdf5"
+ dataset.save_as_hdf5(str(hdf5_path))
+
+ assert hdf5_path.exists()
+
+ # Load back
+ new_dataset = AnomalyDetectionDataset(multispectral_config)
+ new_dataset.load_from_hdf5(str(hdf5_path))
+
+ # Verify data integrity
+ assert len(new_dataset.data_dict) == len(dataset.data_dict)
+ for filename in dataset.data_dict:
+ if filename in new_dataset.data_dict:
+ orig_img, _ = dataset.data_dict[filename]
+ loaded_img, _ = new_dataset.data_dict[filename]
+ assert orig_img.shape == loaded_img.shape
+ assert np.array_equal(orig_img, loaded_img)
diff --git a/tests/unit/test_prediction_utils.py b/tests/unit/test_prediction_utils.py
new file mode 100644
index 0000000..97f024a
--- /dev/null
+++ b/tests/unit/test_prediction_utils.py
@@ -0,0 +1,117 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+
+import numpy as np
+import pytest
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
+
+from anomaly_match.utils.get_default_cfg import get_default_cfg
+from prediction_utils import convert_cutana_cutout, create_cutana_format_cfg
+
+
+@pytest.fixture
+def format_cfg():
+ """Create a CONVERSION_ONLY format config for testing."""
+ cfg = get_default_cfg()
+ cfg.normalisation.image_size = [64, 64]
+ cfg.normalisation.n_output_channels = 3
+ cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY
+ cfg.num_workers = 0
+ return create_cutana_format_cfg(cfg)
+
+
+class TestCreateCutanaFormatCfg:
+ """Tests for create_cutana_format_cfg."""
+
+ def test_returns_config_with_conversion_only(self):
+ cfg = get_default_cfg()
+ cfg.normalisation.image_size = [64, 64]
+ cfg.normalisation.n_output_channels = 3
+ cfg.num_workers = 0
+ result = create_cutana_format_cfg(cfg)
+
+ assert result.fitsbolt_cfg.normalisation_method == NormalisationMethod.CONVERSION_ONLY
+
+ def test_preserves_image_size(self):
+ cfg = get_default_cfg()
+ cfg.normalisation.image_size = [128, 128]
+ cfg.normalisation.n_output_channels = 3
+ cfg.num_workers = 0
+ result = create_cutana_format_cfg(cfg)
+
+ assert result.fitsbolt_cfg.size == [128, 128]
+
+ def test_preserves_output_channels(self):
+ cfg = get_default_cfg()
+ cfg.normalisation.image_size = [64, 64]
+ cfg.normalisation.n_output_channels = 3
+ cfg.num_workers = 0
+ result = create_cutana_format_cfg(cfg)
+
+ assert result.fitsbolt_cfg.n_output_channels == 3
+
+
+class TestConvertCutanaCutout:
+ """Tests for convert_cutana_cutout CHW/HWC handling and format conversion."""
+
+ def test_hwc_input_preserved(self, format_cfg):
+ """HWC input should pass through without transpose."""
+ image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = convert_cutana_cutout(image, format_cfg)
+
+ assert result.shape[0] == 64 # H
+ assert result.shape[1] == 64 # W
+
+ def test_chw_input_transposed_to_hwc(self, format_cfg):
+ """CHW input should be transposed to HWC before processing."""
+ # Create a CHW image (3, 64, 64) with distinct channel values
+ image = np.zeros((3, 64, 64), dtype=np.uint8)
+ image[0] = 100 # R channel
+ image[1] = 150 # G channel
+ image[2] = 200 # B channel
+
+ result = convert_cutana_cutout(image, format_cfg)
+
+ # Result should be HWC
+ assert result.ndim == 3
+ assert result.shape[2] == 3 # channels last
+
+ def test_single_channel_chw_transposed(self, format_cfg):
+ """Single-channel CHW (1, H, W) should be detected and transposed."""
+ image = np.random.randint(0, 255, (1, 64, 64), dtype=np.uint8)
+ result = convert_cutana_cutout(image, format_cfg)
+
+ assert result.ndim == 3
+ assert result.shape[2] == 3 # replicated to 3 channels by fitsbolt
+
+ def test_list_input_converted_to_array(self, format_cfg):
+ """Non-ndarray input should be converted before processing."""
+ image_list = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8).tolist()
+ result = convert_cutana_cutout(image_list, format_cfg)
+
+ assert isinstance(result, np.ndarray)
+
+ def test_float_input_converted_to_uint8(self, format_cfg):
+ """Float input (from cutana normalisation) should be converted to uint8."""
+ image = np.random.random((64, 64, 3)).astype(np.float32)
+ result = convert_cutana_cutout(image, format_cfg)
+
+ assert result.dtype == np.uint8
+
+ def test_output_matches_configured_size(self):
+ """Output should be resized to the configured image_size."""
+ cfg = get_default_cfg()
+ cfg.normalisation.image_size = [32, 32]
+ cfg.normalisation.n_output_channels = 3
+ cfg.num_workers = 0
+ small_cfg = create_cutana_format_cfg(cfg)
+
+ image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = convert_cutana_cutout(image, small_cfg)
+
+ assert result.shape[0] == 32
+ assert result.shape[1] == 32
diff --git a/tests/test_session_io_handler.py b/tests/unit/test_session_io_handler.py
similarity index 99%
rename from tests/test_session_io_handler.py
rename to tests/unit/test_session_io_handler.py
index 0d6e47b..f0f4e1e 100644
--- a/tests/test_session_io_handler.py
+++ b/tests/unit/test_session_io_handler.py
@@ -5,15 +5,16 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
-import tempfile
-import shutil
import json
import pickle
-import pandas as pd
+import shutil
+import tempfile
from pathlib import Path
from unittest.mock import patch
+import pandas as pd
+import pytest
+
from anomaly_match.data_io.SessionIOHandler import SessionIOHandler, print_session
from anomaly_match.pipeline.SessionTracker import SessionTracker
diff --git a/tests/test_session_tracker.py b/tests/unit/test_session_tracker.py
similarity index 99%
rename from tests/test_session_tracker.py
rename to tests/unit/test_session_tracker.py
index 7379e04..594c49c 100644
--- a/tests/test_session_tracker.py
+++ b/tests/unit/test_session_tracker.py
@@ -5,12 +5,13 @@
# this file, may be copied, modified, propagated, or distributed except according to
# the terms contained in the file 'LICENCE.txt'.
-import pytest
import datetime
from unittest.mock import patch
+
import pandas as pd
+import pytest
-from anomaly_match.pipeline.SessionTracker import SessionTracker, IterationInfo
+from anomaly_match.pipeline.SessionTracker import IterationInfo, SessionTracker
class TestIterationInfo:
diff --git a/tests/unit/test_shorten_filename.py b/tests/unit/test_shorten_filename.py
new file mode 100644
index 0000000..61ca24e
--- /dev/null
+++ b/tests/unit/test_shorten_filename.py
@@ -0,0 +1,60 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+from anomaly_match_ui.widget import shorten_filename
+
+
+class TestShortenFilename:
+ """Tests for the shorten_filename helper function."""
+
+ def test_short_filename_unchanged(self):
+ """Filenames within max length should remain unchanged."""
+ assert shorten_filename("short.fits", max_length=25) == "short.fits"
+ assert shorten_filename("image.jpg", max_length=25) == "image.jpg"
+
+ def test_long_filename_shortened(self):
+ """Long filenames should be shortened to max_length."""
+ long_name = "very_long_filename_that_exceeds_limit.fits"
+ result = shorten_filename(long_name, max_length=25)
+ assert len(result) <= 25
+ assert result.endswith(".fits")
+ assert "..." in result
+
+ def test_filename_with_multiple_dots(self):
+ """Filenames with multiple dots should preserve only the extension."""
+ name = "image.2024.01.15.observation.fits"
+ result = shorten_filename(name, max_length=25)
+ assert len(result) <= 25
+ assert result.endswith(".fits")
+ assert "..." in result
+
+ def test_filename_without_extension(self):
+ """Filenames without extension should still be shortened correctly."""
+ name = "very_long_filename_without_any_extension"
+ result = shorten_filename(name, max_length=25)
+ assert len(result) <= 25
+ assert "..." in result
+
+ def test_exact_max_length(self):
+ """Filename exactly at max_length should be unchanged."""
+ name = "exactly_25_chars_long.fit"
+ assert len(name) == 25
+ assert shorten_filename(name, max_length=25) == name
+
+ def test_very_short_max_length(self):
+ """Very short max_length should still produce valid output."""
+ name = "some_filename.fits"
+ result = shorten_filename(name, max_length=10)
+ assert len(result) <= 10
+ assert "..." in result
+
+ def test_preserves_start_and_end(self):
+ """Shortened name should contain parts of the original start and end."""
+ name = "START_middle_content_END.fits"
+ result = shorten_filename(name, max_length=20)
+ assert result.startswith("START")
+ # Should contain some part of the end before the extension
+ assert "END" in result or "..." in result
diff --git a/tests/test_toml_config.py b/tests/unit/test_toml_config.py
similarity index 97%
rename from tests/test_toml_config.py
rename to tests/unit/test_toml_config.py
index b057a57..fe9a4d5 100644
--- a/tests/test_toml_config.py
+++ b/tests/unit/test_toml_config.py
@@ -6,15 +6,16 @@
# the terms contained in the file 'LICENCE.txt'.
import tempfile
-import toml
from pathlib import Path
+
+import toml
from dotmap import DotMap
+from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
from anomaly_match.data_io.save_config import (
- save_config_toml,
_convert_enum_to_string,
+ save_config_toml,
)
-from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod
class TestTOMLConfigUtils:
@@ -134,6 +135,6 @@ def test_config_types_and_structure(self):
assert isinstance(config.name, str)
# Verify image_size is NOT in default config (user must set it)
- assert (
- "image_size" not in config.normalisation
- ), "image_size should not have a default value"
+ assert "image_size" not in config.normalisation, (
+ "image_size should not have a default value"
+ )
diff --git a/tests/unit/test_transforms.py b/tests/unit/test_transforms.py
new file mode 100644
index 0000000..f1de876
--- /dev/null
+++ b/tests/unit/test_transforms.py
@@ -0,0 +1,139 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests for image transformation pipelines."""
+
+import numpy as np
+import torch
+
+from anomaly_match.image_processing.transforms import (
+ NumpyRandomHorizontalFlip,
+ NumpyRandomTranslate,
+ NumpyToTensor,
+ get_prediction_transforms,
+ get_strong_transforms,
+ get_weak_transforms,
+)
+
+
+class TestNumpyToTensor:
+ def test_converts_hwc_to_chw(self):
+ img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ tensor = NumpyToTensor()(img)
+ assert isinstance(tensor, torch.Tensor)
+ assert tensor.shape == (3, 64, 64)
+
+ def test_normalizes_to_float(self):
+ img = np.full((32, 32, 3), 255, dtype=np.uint8)
+ tensor = NumpyToTensor()(img)
+ assert tensor.dtype == torch.float32
+ assert torch.allclose(tensor, torch.ones(3, 32, 32))
+
+ def test_handles_4_channels(self):
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ tensor = NumpyToTensor()(img)
+ assert tensor.shape == (4, 64, 64)
+
+ def test_passthrough_non_numpy(self):
+ tensor = torch.randn(3, 64, 64)
+ result = NumpyToTensor()(tensor)
+ assert torch.equal(result, tensor)
+
+
+class TestNumpyRandomHorizontalFlip:
+ def test_output_shape_preserved(self):
+ img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = NumpyRandomHorizontalFlip(p=1.0)(img)
+ assert result.shape == img.shape
+
+ def test_flip_with_p_one(self):
+ img = np.zeros((4, 4, 1), dtype=np.uint8)
+ img[0, 0, 0] = 255
+ result = NumpyRandomHorizontalFlip(p=1.0)(img)
+ assert result[0, 3, 0] == 255
+
+ def test_no_flip_with_p_zero(self):
+ img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = NumpyRandomHorizontalFlip(p=0.0)(img)
+ assert np.array_equal(result, img)
+
+ def test_handles_tensor(self):
+ tensor = torch.randn(3, 64, 64)
+ result = NumpyRandomHorizontalFlip(p=1.0)(tensor)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape == tensor.shape
+
+
+class TestNumpyRandomTranslate:
+ def test_output_shape_preserved(self):
+ img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = NumpyRandomTranslate()(img)
+ assert result.shape == img.shape
+
+ def test_handles_tensor(self):
+ tensor = torch.randn(3, 64, 64)
+ result = NumpyRandomTranslate()(tensor)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape == tensor.shape
+
+
+class TestGetWeakTransforms:
+ def test_rgb_returns_compose(self):
+ transform = get_weak_transforms(num_channels=3)
+ assert transform is not None
+
+ def test_rgb_output_tensor(self):
+ transform = get_weak_transforms(num_channels=3)
+ img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = transform(img)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape[0] == 3
+
+ def test_multispectral_output_tensor(self):
+ transform = get_weak_transforms(num_channels=4)
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ result = transform(img)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape[0] == 4
+
+
+class TestGetPredictionTransforms:
+ def test_rgb_output_tensor(self):
+ transform = get_prediction_transforms(num_channels=3)
+ img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ result = transform(img)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape == (3, 64, 64)
+
+ def test_multispectral_output_tensor(self):
+ transform = get_prediction_transforms(num_channels=4)
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ result = transform(img)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape == (4, 64, 64)
+
+
+class TestGetStrongTransforms:
+ def test_rgb_returns_compose(self):
+ transform = get_strong_transforms(num_channels=3)
+ assert transform is not None
+
+ def test_rgb_output_tensor(self):
+ from PIL import Image
+
+ transform = get_strong_transforms(num_channels=3)
+ # RandAugment for RGB expects PIL Image input
+ img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
+ result = transform(img)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape[0] == 3
+
+ def test_multispectral_output_tensor(self):
+ transform = get_strong_transforms(num_channels=4)
+ img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8)
+ result = transform(img)
+ assert isinstance(result, torch.Tensor)
+ assert result.shape[0] == 4
diff --git a/tests/utils_test.py b/tests/unit/test_utils.py
similarity index 99%
rename from tests/utils_test.py
rename to tests/unit/test_utils.py
index e80d8eb..ab7aac3 100644
--- a/tests/utils_test.py
+++ b/tests/unit/test_utils.py
@@ -6,9 +6,10 @@
# the terms contained in the file 'LICENCE.txt'.
import torch
import torch.optim as optim
+
from anomaly_match.utils.get_cosine_schedule_with_warmup import get_cosine_schedule_with_warmup
-from anomaly_match.utils.get_optimizer import get_optimizer
from anomaly_match.utils.get_net_builder import get_net_builder
+from anomaly_match.utils.get_optimizer import get_optimizer
from anomaly_match.utils.set_seeds import set_seeds
diff --git a/tests/unit/test_validate_config.py b/tests/unit/test_validate_config.py
new file mode 100644
index 0000000..0af0adc
--- /dev/null
+++ b/tests/unit/test_validate_config.py
@@ -0,0 +1,285 @@
+# Copyright (c) European Space Agency, 2025.
+#
+# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which
+# is part of this source code package. No part of the package, including
+# this file, may be copied, modified, propagated, or distributed except according to
+# the terms contained in the file 'LICENCE.txt'.
+"""Tests for configuration validation edge cases."""
+
+import pytest
+from loguru import logger
+
+from anomaly_match.utils.get_default_cfg import get_default_cfg
+from anomaly_match.utils.validate_config import (
+ _get_all_keys,
+ _get_nested_value,
+ validate_config,
+)
+
+
+@pytest.fixture
+def caplog(caplog):
+ """Configure loguru to use the caplog handler."""
+ handler_id = logger.add(caplog.handler)
+ yield caplog
+ logger.remove(handler_id)
+
+
+@pytest.fixture
+def valid_cfg():
+ """Return a valid default config with image_size set."""
+ cfg = get_default_cfg()
+ cfg.normalisation.image_size = [64, 64]
+ return cfg
+
+
+class TestGetNestedValue:
+ def test_simple_key(self, valid_cfg):
+ assert _get_nested_value(valid_cfg, "seed") == 42
+
+ def test_nested_key(self, valid_cfg):
+ assert _get_nested_value(valid_cfg, "normalisation.image_size") == [64, 64]
+
+ def test_missing_key_raises(self, valid_cfg):
+ with pytest.raises(ValueError, match="Missing key in config"):
+ _get_nested_value(valid_cfg, "nonexistent.key")
+
+
+class TestGetAllKeys:
+ def test_returns_top_level_keys(self, valid_cfg):
+ keys = _get_all_keys(valid_cfg)
+ assert "seed" in keys
+ assert "batch_size" in keys
+ assert "net" in keys
+
+ def test_returns_nested_keys(self, valid_cfg):
+ keys = _get_all_keys(valid_cfg)
+ assert "normalisation" in keys
+ assert "normalisation.image_size" in keys
+ assert "normalisation.n_output_channels" in keys
+
+
+class TestValidateConfigRequired:
+ def test_missing_required_string(self, valid_cfg):
+ del valid_cfg["name"]
+ with pytest.raises(ValueError, match="Missing required parameter"):
+ validate_config(valid_cfg)
+
+ def test_missing_required_integer(self, valid_cfg):
+ del valid_cfg["batch_size"]
+ with pytest.raises(ValueError, match="Missing required parameter"):
+ validate_config(valid_cfg)
+
+
+class TestValidateConfigTypes:
+ def test_string_type_mismatch(self, valid_cfg):
+ valid_cfg.name = 123
+ with pytest.raises(ValueError, match="must be a string"):
+ validate_config(valid_cfg)
+
+ def test_int_type_mismatch(self, valid_cfg):
+ valid_cfg.batch_size = "not_an_int"
+ with pytest.raises(ValueError, match="must be an integer"):
+ validate_config(valid_cfg)
+
+ def test_float_type_mismatch(self, valid_cfg):
+ valid_cfg.test_ratio = "not_a_float"
+ with pytest.raises(ValueError, match="must be a number"):
+ validate_config(valid_cfg)
+
+ def test_bool_type_mismatch(self, valid_cfg):
+ valid_cfg.pin_memory = "not_a_bool"
+ with pytest.raises(ValueError, match="must be a boolean"):
+ validate_config(valid_cfg)
+
+
+class TestValidateConfigRanges:
+ def test_int_below_minimum(self, valid_cfg):
+ valid_cfg.batch_size = 0
+ with pytest.raises(ValueError, match="must be >= 1"):
+ validate_config(valid_cfg)
+
+ def test_float_below_minimum(self, valid_cfg):
+ valid_cfg.test_ratio = -0.1
+ with pytest.raises(ValueError, match="must be >= 0.0"):
+ validate_config(valid_cfg)
+
+ def test_float_above_maximum(self, valid_cfg):
+ valid_cfg.test_ratio = 1.5
+ with pytest.raises(ValueError, match="must be <= 1.0"):
+ validate_config(valid_cfg)
+
+ def test_n_to_load_below_minimum(self, valid_cfg):
+ valid_cfg.N_to_load = 5
+ with pytest.raises(ValueError, match="must be >= 10"):
+ validate_config(valid_cfg)
+
+
+class TestValidateConfigAllowedValues:
+ def test_invalid_optimizer(self, valid_cfg):
+ valid_cfg.opt = "RMSProp"
+ with pytest.raises(ValueError, match="must be one of"):
+ validate_config(valid_cfg)
+
+ def test_invalid_net(self, valid_cfg):
+ valid_cfg.net = "resnet50"
+ with pytest.raises(ValueError, match="must be one of"):
+ validate_config(valid_cfg)
+
+ def test_valid_optimizer_sgd(self, valid_cfg):
+ valid_cfg.opt = "SGD"
+ validate_config(valid_cfg)
+
+ def test_valid_optimizer_adam(self, valid_cfg):
+ valid_cfg.opt = "Adam"
+ validate_config(valid_cfg)
+
+
+class TestValidateConfigSpecialTypes:
+ def test_invalid_image_size_not_list(self, valid_cfg):
+ valid_cfg.normalisation.image_size = 64
+ with pytest.raises(ValueError, match="must be a list or tuple of length 2"):
+ validate_config(valid_cfg)
+
+ def test_invalid_image_size_wrong_length(self, valid_cfg):
+ valid_cfg.normalisation.image_size = [64, 64, 64]
+ with pytest.raises(ValueError, match="must be a list or tuple of length 2"):
+ validate_config(valid_cfg)
+
+ def test_invalid_eval_iter(self, valid_cfg):
+ valid_cfg.num_eval_iter = 0
+ with pytest.raises(ValueError, match="must be an integer > 0 or -1"):
+ validate_config(valid_cfg)
+
+ def test_valid_eval_iter_negative_one(self, valid_cfg):
+ valid_cfg.num_eval_iter = -1
+ validate_config(valid_cfg)
+
+ def test_valid_eval_iter_positive(self, valid_cfg):
+ valid_cfg.num_eval_iter = 10
+ validate_config(valid_cfg)
+
+ def test_normalisation_not_dotmap(self, valid_cfg):
+ valid_cfg.normalisation = "not_a_dotmap"
+ with pytest.raises(ValueError, match="must be a DotMap"):
+ validate_config(valid_cfg)
+
+
+class TestValidateConfigPaths:
+ def test_skip_path_checks(self, valid_cfg):
+ valid_cfg.data_dir = "/nonexistent/path"
+ validate_config(valid_cfg, check_paths=False)
+
+ def test_invalid_directory_path(self, valid_cfg):
+ valid_cfg.data_dir = "/definitely/nonexistent/path"
+ with pytest.raises(ValueError, match="directory does not exist"):
+ validate_config(valid_cfg, check_paths=True)
+
+ def test_invalid_file_path(self, valid_cfg):
+ valid_cfg.label_file = "/nonexistent/file.csv"
+ with pytest.raises(ValueError, match="file does not exist"):
+ validate_config(valid_cfg, check_paths=True)
+
+
+class TestValidateConfigOptional:
+ def test_optional_none_metadata_file(self, valid_cfg):
+ valid_cfg.metadata_file = None
+ validate_config(valid_cfg)
+
+ def test_optional_none_prediction_search_dir(self, valid_cfg):
+ valid_cfg.prediction_search_dir = None
+ validate_config(valid_cfg)
+
+
+class TestFitsExtensionChannelAutoAdjust:
+ """Regression tests for auto-adjusting n_output_channels from fits_extension."""
+
+ def test_fits_extension_4_auto_adjusts_n_output_channels(self, valid_cfg):
+ """fits_extension=[0,1,2,3] with default n_output_channels=3 should auto-adjust to 4."""
+ import numpy as np
+
+ valid_cfg.normalisation.fits_extension = [0, 1, 2, 3]
+ assert valid_cfg.normalisation.n_output_channels == 3
+ validate_config(valid_cfg)
+ assert valid_cfg.normalisation.n_output_channels == 4
+ assert valid_cfg.num_channels == 4
+ np.testing.assert_array_equal(valid_cfg.normalisation.channel_combination, np.eye(4))
+
+ def test_fits_extension_matching_channels_creates_identity(self, valid_cfg):
+ """fits_extension=[0,1,2] with n_output_channels=3 should create identity matrix."""
+ import numpy as np
+
+ valid_cfg.normalisation.fits_extension = [0, 1, 2]
+ validate_config(valid_cfg)
+ assert valid_cfg.normalisation.n_output_channels == 3
+ np.testing.assert_array_equal(valid_cfg.normalisation.channel_combination, np.eye(3))
+
+ def test_fits_extension_single_no_adjustment(self, valid_cfg):
+ """Single fits_extension should not trigger adjustment."""
+ valid_cfg.normalisation.fits_extension = [0]
+ validate_config(valid_cfg)
+ assert valid_cfg.normalisation.n_output_channels == 3
+
+ def test_channel_combination_provided_no_adjustment(self, valid_cfg):
+ """Explicit channel_combination should prevent auto-adjustment."""
+ import numpy as np
+
+ valid_cfg.normalisation.fits_extension = [0, 1, 2, 3]
+ valid_cfg.normalisation.channel_combination = np.eye(3, 4)
+ validate_config(valid_cfg)
+ assert valid_cfg.normalisation.n_output_channels == 3
+
+ def test_asinh_params_extended_on_adjustment(self, valid_cfg):
+ """Per-channel asinh params should be extended when channels are added."""
+ valid_cfg.normalisation.fits_extension = [0, 1, 2, 3]
+ valid_cfg.normalisation.norm_asinh_scale = [0.7, 0.7, 0.7]
+ valid_cfg.normalisation.norm_asinh_clip = [99.8, 99.8, 99.8]
+ validate_config(valid_cfg)
+ assert len(valid_cfg.normalisation.norm_asinh_scale) == 4
+ assert len(valid_cfg.normalisation.norm_asinh_clip) == 4
+
+ def test_n_output_channels_inferred_from_channel_combination(self, valid_cfg):
+ """n_output_channels should be inferred from channel_combination shape."""
+ import numpy as np
+
+ valid_cfg.normalisation.fits_extension = [0, 1, 2, 3]
+ # User provides 1x4 matrix but doesn't set n_output_channels (defaults to 3)
+ valid_cfg.normalisation.channel_combination = np.array([[0.25, 0.25, 0.25, 0.25]])
+ validate_config(valid_cfg)
+ assert valid_cfg.normalisation.n_output_channels == 1
+ assert valid_cfg.num_channels == 1
+ assert len(valid_cfg.normalisation.norm_asinh_scale) == 1
+ assert len(valid_cfg.normalisation.norm_asinh_clip) == 1
+
+ def test_n_output_channels_inferred_3x4_matrix(self, valid_cfg):
+ """3x4 channel_combination with default n_output_channels=3 should not change."""
+ import numpy as np
+
+ valid_cfg.normalisation.fits_extension = [0, 1, 2, 3]
+ valid_cfg.normalisation.channel_combination = np.eye(3, 4)
+ validate_config(valid_cfg)
+ assert valid_cfg.normalisation.n_output_channels == 3
+
+ def test_n_output_channels_none_with_multi_ext_auto_infers(self, valid_cfg):
+ """n_output_channels=None with multiple fits_extension should auto-infer."""
+ import numpy as np
+
+ valid_cfg.normalisation.fits_extension = [0, 1, 2, 3]
+ valid_cfg.normalisation.n_output_channels = None
+ validate_config(valid_cfg)
+ assert valid_cfg.normalisation.n_output_channels == 4
+ np.testing.assert_array_equal(valid_cfg.normalisation.channel_combination, np.eye(4))
+
+ def test_n_output_channels_none_single_ext_raises(self, valid_cfg):
+ """n_output_channels=None with single fits_extension should raise ValueError."""
+ valid_cfg.normalisation.fits_extension = [0]
+ valid_cfg.normalisation.n_output_channels = None
+ with pytest.raises(ValueError, match="n_output_channels is None"):
+ validate_config(valid_cfg)
+
+
+class TestValidateConfigUnexpectedKeys:
+ def test_warns_on_unexpected_keys(self, valid_cfg, caplog):
+ valid_cfg.unexpected_key = "some_value"
+ validate_config(valid_cfg)
+ assert "Found unexpected keys in config" in caplog.text