diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 4749b27..0000000 --- a/.flake8 +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) European Space Agency, 2025. -# -# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which -# is part of this source code package. No part of the package, including -# this file, may be copied, modified, propagated, or distributed except according to -# the terms contained in the file 'LICENCE.txt'. -[flake8] -max-line-length = 100 -extend-ignore = E203, E501, W503, W293, W291, F541, F841 -exclude = - .git, - __pycache__, - docs/source/conf.py, - old, - build, - examples, - setup.py - dist, - *.egg-info, - .venv, - venv, - env, - .pytest_cache, - htmlcov, - logs, - tmp, - screenshots, - workflow_*.json, - tracking*.json, - conftest.py, - playwright.config.py, - .csv, - paper_scripts \ No newline at end of file diff --git a/.github/workflows/check_headers.yml b/.github/workflows/check_headers.yml index a1f6dde..7a21618 100644 --- a/.github/workflows/check_headers.yml +++ b/.github/workflows/check_headers.yml @@ -61,21 +61,20 @@ jobs: *.py) echo "$PYTHON_HEADER" ;; *.yml|*.yaml) echo "$YML_HEADER" ;; *.md) echo "$MARKDOWN_HEADER" ;; - *.ini|.gitignore|.flake8) echo "$INI_HEADER" ;; + *.ini|.gitignore) echo "$INI_HEADER" ;; *) echo "$INI_HEADER" ;; # Default to # comments esac } # Find all relevant files excluding specified directories PYTHON_FILES=$(find . -name "*.py" -not -path "./datalabs_setup/*" -not -path "./utility_scripts/*" -not -path "./.git/*" -not -path "./__pycache__/*" -not -path "./.*/__pycache__/*") - YML_FILES=$(find . -name "*.yml" -o -name "*.yaml" | grep -v "./.git/") - MARKDOWN_FILES=$(find . -name "*.md" -not -path "./.github/ISSUE_TEMPLATE/*" | grep -v "./.git/") - INI_FILES=$(find . -name "*.ini" | grep -v "./.git/") - GITIGNORE_FILES=$(find . -name ".gitignore" | grep -v "./.git/") - FLAKE8_FILES=$(find . -name ".flake8" | grep -v "./.git/") - + YML_FILES=$(find . 
-name "*.yml" -o -name "*.yaml" | grep -v "./.git/" || true) + MARKDOWN_FILES=$(find . -name "*.md" -not -path "./.github/ISSUE_TEMPLATE/*" | grep -v "./.git/" || true) + INI_FILES=$(find . -name "*.ini" | grep -v "./.git/" || true) + GITIGNORE_FILES=$(find . -name ".gitignore" | grep -v "./.git/" || true) + # Combine all files - ALL_FILES="$PYTHON_FILES $YML_FILES $MARKDOWN_FILES $INI_FILES $GITIGNORE_FILES $FLAKE8_FILES" + ALL_FILES="$PYTHON_FILES $YML_FILES $MARKDOWN_FILES $INI_FILES $GITIGNORE_FILES" MISSING_HEADERS=() @@ -118,7 +117,7 @@ jobs: done echo "" echo "Please add the correct license header to these files." - echo "File types checked: .py, .yml/.yaml, .md, .ini, .gitignore, .flake8" + echo "File types checked: .py, .yml/.yaml, .md, .ini, .gitignore" echo "Note: Python files exclude datalabs_setup/ and utility_scripts/ directories" exit 1 fi \ No newline at end of file diff --git a/.github/workflows/dead_code.yml b/.github/workflows/dead_code.yml index 5f1d511..e167a4d 100644 --- a/.github/workflows/dead_code.yml +++ b/.github/workflows/dead_code.yml @@ -23,7 +23,7 @@ jobs: - name: Run vulture (100% confidence) run: | echo "Running vulture dead code detection (100% confidence - blocking)..." - vulture anomaly_match/ .vulture_whitelist.py --min-confidence 100 + vulture anomaly_match/ anomaly_match_ui/ .vulture_whitelist.py --min-confidence 100 vulture-warnings: name: Vulture (60% confidence - not required) @@ -41,4 +41,4 @@ jobs: echo "Running vulture dead code detection (60% confidence)..." echo "This check fails if potential dead code is found, but is not required to pass." 
echo "" - vulture anomaly_match/ .vulture_whitelist.py --min-confidence 60 + vulture anomaly_match/ anomaly_match_ui/ .vulture_whitelist.py --min-confidence 60 diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index b673ba3..31fd4e6 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -4,7 +4,7 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -name: Black Format Check +name: Ruff Lint and Format Check on: [pull_request] @@ -13,9 +13,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - - uses: psf/black@stable + - uses: astral-sh/ruff-action@v3 with: - options: "--check --line-length 100" + args: "check" src: "." - jupyter: false - version: "24.8.0" \ No newline at end of file + - uses: astral-sh/ruff-action@v3 + with: + args: "format --check" + src: "." \ No newline at end of file diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index e8a92b6..d118ab9 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -15,19 +15,6 @@ on: pull_request: workflow_dispatch: jobs: - lint_flake8: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.11 - uses: actions/setup-python@v2 - with: - python-version: 3.11 - - name: Lint with flake8 - run: | - pip install flake8 - flake8 . 
--count --show-source --statistics --max-line-length=127 --ignore=E402,W503,E203 build: runs-on: ubuntu-latest timeout-minutes: 10 @@ -36,32 +23,32 @@ jobs: contents: read id-token: write steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python 3.11 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: 3.11 - - name: Search for severe code errors with flake8 + - name: Search for severe code errors with ruff run: | # stop the build if there are Python syntax errors or undefined names - pip install flake8 - flake8 . --count --select=E9,F63,F7,F82,F541 --show-source --statistics --max-line-length=127 + pip install ruff + ruff check . --select=E9,F63,F7,F82,F541 --output-format=full - name: provision-with-micromamba uses: mamba-org/setup-micromamba@v2 with: environment-file: environment_CI.yml environment-name: am cache-downloads: true + cache-environment: true - name: Test with pytest shell: bash -l {0} run: | set -e set -o pipefail micromamba activate am - micromamba install pytest pytest-cov - pytest --junitxml=pytest.xml --cov-report=xml:coverage.xml --cov-report=term --cov=anomaly_match tests/ | tee pytest-coverage.txt + pytest --durations=20 --junitxml=pytest.xml --cov-report=xml:coverage.xml --cov-report=term --cov=anomaly_match tests/ | tee pytest-coverage.txt echo "===== Coverage report =====" - cat pytest-coverage.txt + cat pytest-coverage.txt - name: Pytest coverage comment uses: MishaKav/pytest-coverage-comment@main if: always() && github.event_name == 'pull_request' @@ -75,4 +62,4 @@ jobs: create-new-comment: false hide-comment: false report-only-changed-files: false - junitxml-path: ./pytest.xml \ No newline at end of file + junitxml-path: ./pytest.xml diff --git a/.gitignore b/.gitignore index 0270407..356818a 100644 --- a/.gitignore +++ b/.gitignore @@ -197,3 +197,9 @@ pytest-coverage.txt pytest.xml # IDE and editor settings .vscode/ +.claude/ +pr.md +# Local Cutana checkout for development 
+Cutana/ +# Claude Code temp directories +tmpclaude-* diff --git a/.vulture_whitelist.py b/.vulture_whitelist.py index 2682d5d..af316bf 100644 --- a/.vulture_whitelist.py +++ b/.vulture_whitelist.py @@ -21,6 +21,9 @@ # FixMatch class attributes requires_grad # noqa - PyTorch tensor property set to disable gradient for EMA model +# TestCNN - nn.Module.forward() called implicitly by PyTorch +TestCNN.forward # noqa - Called via model(x) in FixMatch training loop + # AnomalyDetectionDataset methods used in tests (tests/dataset_test.py) _read_and_resize_image # noqa - Used in test_read_and_resize_different_formats unlabeled_filepaths # noqa - Used in test_anomaly_detection_dataset_properties @@ -60,5 +63,5 @@ _.benchmark # noqa - torch.backends.cudnn.benchmark attribute # Image processing functions used in prediction scripts (root level, excluded from scan) -process_single_wrapper # noqa - Used in prediction_process_hdf5.py, prediction_process_zarr.py +process_single_wrapper # noqa - Used in prediction_utils.py, prediction_process_hdf5.py _.n_expected_channels # noqa - fitsbolt config attribute set dynamically diff --git a/CHANGELOG.MD b/CHANGELOG.MD index 2af4657..460fa58 100644 --- a/CHANGELOG.MD +++ b/CHANGELOG.MD @@ -5,6 +5,42 @@ [//]: # (this file, may be copied, modified, propagated, or distributed except according to) [//]: # (the terms contained in the file 'LICENCE.txt'.) 
+## [v1.3.0] – 2026-02-11 + +### Added +- **Multispectral image support** for arbitrary channel count images with configurable `channel_combination` matrices (#255) +- **UI separation** into standalone `anomaly_match_ui` package for cleaner architecture (#257) +- **Flux conversion configuration** (`apply_flux_conversion`) ensuring training/prediction consistency (#272) +- **timm model backend** replacing efficientnet-specific packages for broader model support (#268) +- **`test-cnn` model** for fast unit/integration testing without heavy model downloads (#265) + +### Changed +- **Replaced black and flake8 with ruff** for linting and formatting (#258) +- **Restructured test suite** into unit/integration/e2e/ui directories with pytest markers and CI caching (#265) +- **Deduplicated prediction code** into shared `prediction_utils.py` module (#264) +- **Auto-inference of `n_output_channels`** from `channel_combination` matrix or FITS extension count (#276, #278) +- **PIL resize for CONVERSION_ONLY** normalisation achieving up to 73x faster image loading (#259) +- **Faster catalogue validation** by skipping per-chunk FITS existence checks and using parquet metadata (#277) + +### Fixed +- **Double normalisation** in cutana streaming pipeline (#271) +- **Normalisation consistency** between cutana and training paths with channel_weights passthrough (#274) +- **Session logging** with eager directory creation and per-session log files (#260) +- **Gallery not updating** after prediction chunks complete (#262) +- **Cutana source ID handling** for non-string int64 source_ids (#261) +- **Albumentations 2.0 compatibility** renaming deprecated `mode` to `border_mode` (#276) +- **Prediction progress bar** with phase tracking for better user feedback (#276) +- **Identity channel_combination** auto-creation for multi-extension FITS configs (#278) +- **ASinh parameters** missing in cutana format config (#279) +- **Filter name resolution** from catalogue for cutana streaming (#281) +- 
**Primary HDU validation** raising ValueError when no image data found (#271) + +### Documentation +- **Normalisation README** with improved channel_combination and flux conversion documentation (#275) +- **Auto-inference documentation** updating multispectral config examples (#284) + +--- + ## [v1.2.0] – 2025-01-13 ### Added diff --git a/README.md b/README.md index a5f3954..3f5faf3 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ High-performance semi-supervised anomaly detection with active learning - [AnomalyMatch](#anomalymatch) - [Table of Contents](#table-of-contents) - [Overview](#overview) + - [Ecosystem](#ecosystem) - [Requirements](#requirements) - [Installation](#installation) - [Session Tracking](#session-tracking) @@ -22,7 +23,10 @@ High-performance semi-supervised anomaly detection with active learning - [Zarr File Requirements](#zarr-file-requirements) - [Creating Zarr Files](#creating-zarr-files) - [Zarr Configuration](#zarr-configuration) + - [Multiple Zarr Files for Prediction](#multiple-zarr-files-for-prediction) - [FITS File Handling](#fits-file-handling) + - [Multispectral Support](#multispectral-support) + - [Cutana Streaming Integration](#cutana-streaming-integration) - [Normalisation and Stretching](#normalisation-and-stretching) - [Key Config Parameters](#key-config-parameters) - [Advanced CFG Parameters](#advanced-cfg-parameters) @@ -34,8 +38,8 @@ High-performance semi-supervised anomaly detection with active learning ## Overview -This package uses a FixMatch pipeline built on EfficientNet models and provides a mechanism -for active learning to detect anomalies in images. It also offers a GUI via ipywidgets for labelling and managing the detection process, including the ability to unlabel previously labelled images. +This package uses a FixMatch pipeline built on EfficientNet models (via [timm](https://github.com/huggingface/pytorch-image-models)) and provides a mechanism +for active learning to detect anomalies in images. 
It also offers a GUI (in the separate `anomaly_match_ui` package) for labelling and managing the detection process, including the ability to unlabel previously labelled images. AnomalyMatch is available plug-and-play on GPUs in [ESA Datalabs](https://datalabs.esa.int/), providing seamless access to high-performance computing resources for large-scale anomaly detection tasks. @@ -43,6 +47,30 @@ For detailed information about the method and its applications, see our papers: - [AnomalyMatch: Discovering Rare Objects of Interest with Semi-supervised and Active Learning](https://arxiv.org/abs/2505.03509) - describing the method in detail - [Identifying Astrophysical Anomalies in 99.6 Million Cutouts from the Hubble Legacy Archive Using AnomalyMatch](https://arxiv.org/abs/2505.03508) - describing a scaled-up search through 100M cutouts European Space Agency, 2025.) +### Ecosystem + +AnomalyMatch relies on two companion libraries for image loading and normalisation: + +``` + ┌──────────────┐ + │ AnomalyMatch │ + └──┬───────┬───┘ + Training/ │ │ Streaming + Prediction │ │ Prediction + ▼ ▼ + ┌─────────┐ ┌────────┐ + │ fitsbolt │ │ cutana │ + └─────────┘ └───┬────┘ + ▲ │ + └───────────┘ + (normalisation) +``` + +- **[fitsbolt](https://github.com/esa/fitsbolt)** handles FITS/image loading and normalisation (stretching, channel combination, dtype conversion). +- **[Cutana](https://github.com/esa/cutana)** orchestrates cutout extraction from FITS tiles and delegates normalisation to fitsbolt. + +Because both the training and Cutana streaming paths use fitsbolt for normalisation, results are guaranteed to be consistent. + ## Requirements Dependencies are listed in the `environment.yml` file. To leverage the full capabilities of this package (especially training on large images or predicting over large image datasets), a GPU is strongly recommended. 
@@ -239,6 +267,62 @@ prediction_search_dir/ When working with FITS files containing multiple images or data products, specify which extension(s) to use in the configuration. +### Multispectral Support + +AnomalyMatch supports training and prediction on images with arbitrary channel counts (1 to N channels), not just RGB (3 channels). This is useful for multispectral astronomical data. + +**Configuration for N-channel images:** + +```python +import anomaly_match as am + +cfg = am.get_default_cfg() + +# For FITS files with multiple extensions as channels +cfg.normalisation.fits_extension = ["VIS", "NIR-H", "NIR-J", "NIR-Y"] + +# Asinh normalisation parameters (one per channel) +cfg.normalisation.norm_asinh_scale = [0.7, 0.7, 0.7, 0.7] +cfg.normalisation.norm_asinh_clip = [99.8, 99.8, 99.8, 99.8] +``` + +`n_output_channels` and `channel_combination` are automatically inferred from `fits_extension`: when multiple extensions are specified without an explicit `channel_combination`, an identity matrix is created and `n_output_channels` is set to match. Per-channel asinh parameters are also extended automatically if needed. + +**Combining extensions into fewer channels with `channel_combination`:** + +When you have more FITS extensions than desired output channels, use `channel_combination` to define a linear mapping. It is a NumPy array of shape `(n_output_channels, n_extensions)`. `n_output_channels` is automatically inferred from the matrix shape: + +```python +import numpy as np + +# 4 FITS extensions → 3 RGB output channels +cfg.normalisation.fits_extension = ["VIS", "NIR-H", "NIR-J", "NIR-Y"] +cfg.normalisation.channel_combination = np.array([ + [1, 0, 0, 0], # R = VIS + [0, 0.5, 0.5, 0], # G = average of NIR-H and NIR-J + [0, 0, 0, 1], # B = NIR-Y +]) +``` + +Each row defines one output channel as a weighted sum of the input extensions. `n_output_channels` is set to the number of rows in the matrix. 
When `channel_combination` is `None` (default), an identity matrix is created automatically for multi-extension configs. + +**Supported formats for N-channel data:** +- **NumPy arrays (`.npy`)**: Shape `(H, W, C)` where C is the number of channels +- **FITS files**: Multiple extensions combined as channels +- **HDF5/Zarr**: Arrays with shape `(N, H, W, C)` + +**UI Channel Mapping:** + +For images with more than 3 channels, the UI provides RGB mapping dropdowns to select which 3 channels to display as Red, Green, and Blue. This allows visual inspection of different channel combinations without affecting the training data. + +**Model Architecture:** + +When using pretrained models (default), AnomalyMatch automatically adapts the first convolutional layer for N-channel input: +- The first 3 channels use the pretrained RGB weights +- Additional channels are initialized with averaged RGB weights + +This approach preserves the benefit of pretrained features while supporting arbitrary channel counts. + ### Cutana Streaming Integration AnomalyMatch supports streaming predictions via [Cutana](https://github.com/esa/cutana), which enables on-the-fly cutout extraction from FITS tiles. This is particularly useful for Euclid mission data, which Cutana primarily targets. @@ -251,6 +335,10 @@ AnomalyMatch supports streaming predictions via [Cutana](https://github.com/esa/ **FITS extension configuration:** When using Cutana streaming, ensure `cfg.normalisation.fits_extension` matches the FITS extensions referenced in your catalogue. For multi-band Euclid data, this might be `["VIS", "NIR-H", "NIR-J"]` or similar, depending on your catalogue structure. +**Normalisation consistency:** AnomalyMatch automatically passes the same fitsbolt normalisation configuration to Cutana, so training and streaming prediction produce identically normalised images. If `channel_combination` is set, it is automatically translated to Cutana's expected format. 
+ +**Flux conversion:** Set `cfg.normalisation.apply_flux_conversion = True` when working with Euclid data to convert pixel values to flux density in Jansky using the AB zeropoint (`MAGZERO`) from FITS headers. This is applied consistently in both the training and Cutana prediction paths, before normalisation. + For more details on catalogue format and Cutana configuration, see the [Cutana documentation](https://github.com/esa/cutana). ## Normalisation and Stretching @@ -264,6 +352,10 @@ For more details on catalogue format and Cutana configuration, see the [Cutana d - It currently allows an enum from [NormalisationMethod](anomaly_match/image_processing/NormalisationMethod.py) - Selecting a new [normalisation](anomaly_match/image_processing/normalisation.py) in the dropdown will apply it when training or predicting. For further detail see [Normalisation-Readme](anomaly_match/image_processing/Normalisationreadme.md) +**Normalisation Consistency:** Normalisation settings (method, channel combination, flux conversion, etc.) are saved in the model checkpoint during training. During prediction, these settings are loaded automatically from the checkpoint — there is no need to re-specify them. Both training and Cutana streaming use fitsbolt for normalisation, guaranteeing identical results. + +**Flux Conversion:** For Euclid data, set `cfg.normalisation.apply_flux_conversion = True` to convert pixel values to flux density in Jansky using the AB zeropoint (`MAGZERO`) from FITS headers. This is applied consistently in both training and prediction paths. + ## Key Config Parameters - `save_dir`: Path to store the trained model output. - `data_dir`: Location of the training data (*.jpeg, *.jpg, *.png, *.tif, or *.tiff). 
@@ -305,7 +397,7 @@ The following advanced parameters can be configured: ### Additional Parameters - `fits_extension`: Extension(s) to use for FITS files, can be int, string, or list of int/string (default: None) -- `fits_combination`: Dictonary with keys `R`,`G`,`B` of lists of length of `fits_extension` denoting how the specified fits_extensions are (linearly) mapped to the R,G,B channels. +- `channel_combination`: NumPy array of shape `(n_output_channels, n_extensions)` defining how FITS extensions are linearly combined into output channels. When `None` (default), an identity matrix is auto-created for multi-extension configs. `n_output_channels` is inferred from the matrix shape when provided. - `interpolation_order`: 0-5 corresponding to [skimage resize interpolation orders](https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.warp) (default: 1 (Bi-linear)) - `normalisation_method`: Normalisation method to be applied during file loading. Can also be selected in the UI dropdown. 
Correspons to an entry from the class NormalisationMethod (default: `NormalisationMethod.CONVERSION_ONLY`) diff --git a/StarterNotebook.ipynb b/StarterNotebook.ipynb index 9703d1c..add0df9 100644 --- a/StarterNotebook.ipynb +++ b/StarterNotebook.ipynb @@ -91,7 +91,10 @@ "cfg.N_to_load = 100 # Number of unlabeled images loaded into the training dataset at once\n", "\n", "# Set the image size\n", - "cfg.normalisation.image_size = [64, 64] # Dimensions to which images are resized (below 96x96 is not recommended)\n", + "cfg.normalisation.image_size = [\n", + " 64,\n", + " 64,\n", + "] # Dimensions to which images are resized (below 96x96 is not recommended)\n", "\n", "# Set the logger level (options: \"trace\",\"debug\", \"info\", \"warning\", \"error\", \"critical\")\n", "logger_level = \"info\"\n", @@ -108,7 +111,9 @@ "outputs": [], "source": [ "# Start the UI\n", - "session.start_UI()" + "from anomaly_match_ui import start_ui\n", + "\n", + "start_ui(session)" ] }, { @@ -189,4 +194,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/anomaly_match/__init__.py b/anomaly_match/__init__.py index 5d73855..6b08686 100644 --- a/anomaly_match/__init__.py +++ b/anomaly_match/__init__.py @@ -5,13 +5,14 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod + +from .data_io.SessionIOHandler import print_session from .pipeline.session import Session from .utils.get_default_cfg import get_default_cfg -from .utils.set_log_level import set_log_level from .utils.print_cfg import print_cfg -from .data_io.SessionIOHandler import print_session +from .utils.set_log_level import set_log_level -__version__ = "1.2.0" +__version__ = "1.3.0" __all__ = [ "get_default_cfg", diff --git a/anomaly_match/data_io/SessionIOHandler.py b/anomaly_match/data_io/SessionIOHandler.py index ad875d5..86fc0c3 100644 --- a/anomaly_match/data_io/SessionIOHandler.py +++ b/anomaly_match/data_io/SessionIOHandler.py @@ -6,17 +6,17 @@ # the terms contained in the file 'LICENCE.txt'. import json -import pickle import os +import pickle from pathlib import Path -from typing import Dict, Any, Optional, List +from typing import Any, Dict, List, Optional + import pandas as pd -from loguru import logger import torch +from loguru import logger -from anomaly_match.pipeline.SessionTracker import SessionTracker -from anomaly_match.pipeline.SessionTracker import IterationInfo from anomaly_match.data_io.save_config import save_config_toml +from anomaly_match.pipeline.SessionTracker import IterationInfo, SessionTracker class SessionIOHandler: @@ -283,6 +283,7 @@ def save_model(self, model, cfg, session_tracker: SessionTracker = None) -> str: "best_it": model.best_it, "last_normalisation_method": getattr(model, "last_normalisation_method", None), "normalisation_method": cfg.normalisation.normalisation_method, + "num_channels": cfg.num_channels, "fitsbolt_cfg": fitsbolt_cfg, } @@ -400,6 +401,15 @@ def load_model(self, model, cfg, model_path: str = None) -> bool: f"Model loaded with normalisation method: {cfg.normalisation.normalisation_method.name}" ) + # Warn if model was trained with a different number of channels + saved_channels = checkpoint.get("num_channels") + if saved_channels is not None 
and saved_channels != cfg.num_channels: + logger.warning( + f"Channel mismatch: model was trained with {saved_channels} channels " + f"but current dataset has {cfg.num_channels} channels. " + f"This will likely cause errors." + ) + # Load fitsbolt config if present in checkpoint (DotMap pickles directly) if "fitsbolt_cfg" in checkpoint and checkpoint["fitsbolt_cfg"] is not None: cfg.fitsbolt_cfg = checkpoint["fitsbolt_cfg"] @@ -691,6 +701,9 @@ def update_config_paths_for_session(self, cfg, session_tracker: SessionTracker) """ session_path = self.get_session_save_path(session_tracker) + # Create session directory immediately so logs and outputs have a home + session_path.mkdir(parents=True, exist_ok=True) + # Update model path to session directory only if not already set by user if cfg.model_path is None: cfg.model_path = str(session_path / "model.pth") @@ -701,6 +714,14 @@ def update_config_paths_for_session(self, cfg, session_tracker: SessionTracker) # Update save directory to session directory cfg.save_dir = str(session_path) + # Add session-specific log file + logger.add( + str(session_path / "session.log"), + rotation="10 MB", + format="{time:YYYY-MM-DD HH:mm:ss}|{level}|{message}", + level="DEBUG", + ) + logger.debug(f"Updated config paths to use session directory: {session_path}") diff --git a/anomaly_match/data_io/find_images_in_folder.py b/anomaly_match/data_io/find_images_in_folder.py index 0c04665..7a57243 100644 --- a/anomaly_match/data_io/find_images_in_folder.py +++ b/anomaly_match/data_io/find_images_in_folder.py @@ -7,10 +7,12 @@ """ Functions to retrieve image filenames from folders. 
""" + import os from pathlib import Path -from loguru import logger + from fitsbolt import SUPPORTED_IMAGE_EXTENSIONS +from loguru import logger def get_image_names_from_folder(folder_path, recursive=False, extensions=None): diff --git a/anomaly_match/data_io/load_images.py b/anomaly_match/data_io/load_images.py index e561eab..a32337e 100644 --- a/anomaly_match/data_io/load_images.py +++ b/anomaly_match/data_io/load_images.py @@ -8,9 +8,90 @@ Functions for loading and processing images as a wrapper around fitsbolt """ +import os + +import numpy as np +from astropy.io import fits as astropy_fits +from cutana.flux_conversion import convert_mosaic_to_flux from dotmap import DotMap -from fitsbolt.image_loader import load_and_process_images, _process_image from fitsbolt.cfg.create_config import create_config as fb_create_cfg +from fitsbolt.image_loader import _process_image, load_and_process_images +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from loguru import logger +from PIL import Image as PILImage + +# Exact PIL equivalents for skimage interpolation orders. +# Only orders with a true PIL match are included — others would silently +# produce different results than skimage, causing train/predict divergence. +_PIL_INTERPOLATION = { + 0: PILImage.NEAREST, + 1: PILImage.BILINEAR, + 3: PILImage.BICUBIC, +} + + +def apply_fits_flux_conversion(filepath, zeropoint_keyword="MAGZERO"): + """Read a FITS file and convert pixel values to flux density in Jansky. + + Uses the AB zeropoint from the FITS header and cutana's + ``convert_mosaic_to_flux`` to ensure the training path applies the + exact same conversion as the cutana prediction path. + + Args: + filepath: Path to the FITS file. + zeropoint_keyword: FITS header keyword for the AB zeropoint. + + Returns: + np.ndarray: Flux-converted 2-D image in Jansky (float32). + + Raises: + KeyError: If the zeropoint keyword is missing from the header. 
+ ValueError: If the primary HDU contains no image data. + """ + with astropy_fits.open(filepath) as hdul: + if hdul[0].data is None: + raise ValueError( + f"Primary HDU of '{filepath}' contains no image data. " + f"Some observatories store image data in a different extension " + f"(e.g. hdul[1]). This is not yet supported." + ) + data = hdul[0].data.astype(np.float32) + zeropoint = float(hdul[0].header[zeropoint_keyword]) + return convert_mosaic_to_flux(data, zeropoint) + + +def _pil_resize(image, target_size, interpolation_order=1): + """Resize an HWC or HW numpy array using PIL (much faster than skimage). + + Args: + image: numpy array of shape (H, W) or (H, W, C) + target_size: [height, width] + interpolation_order: 0=nearest, 1=bilinear, 3=bicubic + + Returns: + Resized numpy array with same dtype. + + Raises: + ValueError: If interpolation_order has no exact PIL equivalent. + """ + h, w = target_size + if image.shape[0] == h and image.shape[1] == w: + return image + if interpolation_order not in _PIL_INTERPOLATION: + raise ValueError( + f"interpolation_order={interpolation_order} has no exact PIL equivalent. " + f"Supported orders for PIL resize: {sorted(_PIL_INTERPOLATION.keys())}. " + f"Use one of these or switch to a normalisation method that uses skimage resize." 
+ ) + original_dtype = image.dtype + resample = _PIL_INTERPOLATION[interpolation_order] + pil_img = PILImage.fromarray(image) + # PIL.resize takes (width, height) + pil_img = pil_img.resize((w, h), resample) + result = np.array(pil_img) + if result.dtype != original_dtype: + result = result.astype(original_dtype) + return result def get_fitsbolt_config(cfg, size_override="default"): @@ -36,7 +117,7 @@ def get_fitsbolt_config(cfg, size_override="default"): n_output_channels=cfg.normalisation.n_output_channels, normalisation_method=cfg.normalisation.normalisation_method, channel_combination=cfg.normalisation.channel_combination, - num_workers=cfg.num_workers, + num_workers=max(cfg.num_workers, 1), norm_maximum_value=cfg.normalisation.norm_maximum_value, norm_minimum_value=cfg.normalisation.norm_minimum_value, norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, @@ -50,17 +131,45 @@ def get_fitsbolt_config(cfg, size_override="default"): def load_and_process_wrapper(filepaths, cfg, desc="Loading images", show_progress=True): + target_size = cfg.normalisation.image_size + norm_method = cfg.normalisation.normalisation_method + + # When flux conversion is enabled, load FITS files manually, apply the + # conversion, and normalise via process_single_wrapper per image. 
+ if cfg.normalisation.apply_flux_conversion: + keyword = cfg.normalisation.flux_conversion_zeropoint_keyword + cfg_with_fb = get_fitsbolt_config(cfg) + images_list = [] + for fp in filepaths: + ext = os.path.splitext(fp)[1].lower() + if ext in (".fits", ".fit", ".fts"): + converted = apply_fits_flux_conversion(fp, zeropoint_keyword=keyword) + img = process_single_wrapper(converted, cfg_with_fb, desc=desc) + else: + # Non-FITS files: fall through to standard fitsbolt loading + img = load_and_process_images( + [fp], + cfg=cfg_with_fb.fitsbolt_cfg, + desc=desc, + show_progress=False, + )[0] + images_list.append(img) + return [(fp, img) for fp, img in zip(filepaths, images_list)] + + # PIL resize is only safe for CONVERSION_ONLY: other normalisation methods + # depend on the float64 dtype that fitsbolt's skimage resize produces. + use_pil_resize = norm_method == NormalisationMethod.CONVERSION_ONLY images_list = load_and_process_images( filepaths, cfg=None, output_dtype=cfg.normalisation.output_dtype, - size=cfg.normalisation.image_size, + size=None if use_pil_resize else target_size, fits_extension=cfg.normalisation.fits_extension, interpolation_order=cfg.normalisation.interpolation_order, normalisation_method=cfg.normalisation.normalisation_method, channel_combination=cfg.normalisation.channel_combination, n_output_channels=cfg.normalisation.n_output_channels, - num_workers=cfg.num_workers, + num_workers=max(cfg.num_workers, 1), norm_maximum_value=cfg.normalisation.norm_maximum_value, norm_minimum_value=cfg.normalisation.norm_minimum_value, norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, @@ -70,6 +179,10 @@ def load_and_process_wrapper(filepaths, cfg, desc="Loading images", show_progres desc=desc, show_progress=show_progress, ) + # Resize with PIL (much faster than fitsbolt's skimage resize) + if use_pil_resize and target_size is not None: + interp = cfg.normalisation.interpolation_order + images_list = [_pil_resize(img, 
target_size, interp) for img in images_list] # return a list of tuples (filename, image) # check that images_list has same length as filepaths if len(images_list) != len(filepaths): @@ -145,3 +258,37 @@ def process_single_wrapper(image, cfg, desc="source"): fitsbolt_cfg.n_expected_channels = image.shape[-1] fitsbolt_cfg.num_workers = 1 # for single image processing, force to 1 return _process_image(image, fitsbolt_cfg, image_source=desc) + + +def detect_num_channels(root_dir, filenames): + """Detect the number of image channels from a sample image file. + + For FITS files, returns None since channel count depends on + fits_extension and channel_combination settings. + + Args: + root_dir: Directory containing the images. + filenames: List of image filenames. + + Returns: + Detected number of channels, or None if detection is not applicable or fails. + """ + if not filenames: + return None + + sample_file = filenames[0] + sample_path = os.path.join(root_dir, sample_file) + ext = os.path.splitext(sample_file)[1].lower() + + # FITS channel count depends on extension/combination config + if ext in (".fits", ".fit", ".fts"): + return None + + try: + with PILImage.open(sample_path) as img: # close the lazily-opened file handle + num_channels = len(img.getbands()) + logger.debug(f"Detected {num_channels} channels from {sample_file}") + return num_channels + except Exception as e: + logger.warning(f"Could not detect channels from {sample_file}: {e}") + return None diff --git a/anomaly_match/data_io/metadata_handler.py b/anomaly_match/data_io/metadata_handler.py index 6676a94..159c0a3 100644 --- a/anomaly_match/data_io/metadata_handler.py +++ b/anomaly_match/data_io/metadata_handler.py @@ -6,12 +6,12 @@ # the terms contained in the file 'LICENCE.txt'. 
import os -import pandas as pd -import numpy as np -from loguru import logger -from astropy.coordinates import SkyCoord +import numpy as np +import pandas as pd from astropy import units as u +from astropy.coordinates import SkyCoord +from loguru import logger class MetadataHandler: diff --git a/anomaly_match/data_io/save_config.py b/anomaly_match/data_io/save_config.py index 92d0aca..2014694 100644 --- a/anomaly_match/data_io/save_config.py +++ b/anomaly_match/data_io/save_config.py @@ -5,13 +5,13 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import toml from pathlib import Path -from typing import Union, Dict, Any -from dotmap import DotMap -from loguru import logger +from typing import Any, Dict, Union +import toml +from dotmap import DotMap from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from loguru import logger def _critical_optional_fields() -> list: diff --git a/anomaly_match/datasets/AnomalyDetectionDataset.py b/anomaly_match/datasets/AnomalyDetectionDataset.py index 452ba56..d999a74 100644 --- a/anomaly_match/datasets/AnomalyDetectionDataset.py +++ b/anomaly_match/datasets/AnomalyDetectionDataset.py @@ -4,24 +4,26 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import numpy as np import os + import h5py +import numpy as np import pandas as pd -from PIL import Image -from sklearn.model_selection import train_test_split import torch from loguru import logger +from PIL import Image +from sklearn.model_selection import train_test_split -from .Label import Label - +from anomaly_match.data_io.find_images_in_folder import get_image_names_from_folder from anomaly_match.data_io.load_images import ( + detect_num_channels, load_and_process_single_wrapper, load_and_process_wrapper, ) -from anomaly_match.data_io.find_images_in_folder import get_image_names_from_folder from anomaly_match.data_io.metadata_handler import MetadataHandler +from .Label import Label + class AnomalyDetectionDataset(torch.utils.data.Dataset): """AnomalyDetectionDataset for binary classification of normal vs anomaly images.""" @@ -55,7 +57,6 @@ def __init__( # Initialize key variables self.seed = cfg.seed self.size = cfg.normalisation.image_size - self.num_channels = 3 self.root_dir = cfg.data_dir self.transform = transform self.cfg = cfg @@ -81,6 +82,27 @@ def __init__( # Get all filenames first self.all_filenames = get_image_names_from_folder(self.root_dir, recursive=False) + # Auto-detect channel count from images: update config when images have + # more channels than configured to avoid silently discarding channels + detected_channels = detect_num_channels(self.root_dir, self.all_filenames) + if ( + detected_channels is not None + and detected_channels > cfg.normalisation.n_output_channels + ): + logger.info( + f"Detected {detected_channels} channels from images " + f"(config had n_output_channels={cfg.normalisation.n_output_channels}), updating" + ) + old_channels = cfg.normalisation.n_output_channels + cfg.normalisation.n_output_channels = detected_channels + cfg.num_channels = detected_channels + # Extend asinh normalisation lists to match new channel count + for attr in ("norm_asinh_scale", "norm_asinh_clip"): + val = getattr(cfg.normalisation, attr, 
None) + if isinstance(val, list) and len(val) == old_channels: + cfg.normalisation[attr] = val + [val[-1]] * (detected_channels - old_channels) + self.num_channels = cfg.normalisation.n_output_channels + # If we have fewer than N_to_load images, set N_to_load to the number of images self.N_to_load = min(self.N_to_load, len(self.all_filenames)) @@ -129,10 +151,9 @@ def _load_csv_and_apply_labels(self): assert col in labeled_data.columns, f"CSV file must contain column '{col}'" # Check that labels are valid - assert set(labeled_data["label"].unique()) <= set( - ["normal", "anomaly", "removed"] - ), "Labels should be either 'normal', 'anomaly' or 'removed' but found" + str( - set(labeled_data["label"].unique()) + assert set(labeled_data["label"].unique()) <= set(["normal", "anomaly", "removed"]), ( + "Labels should be either 'normal', 'anomaly' or 'removed' but found" + + str(set(labeled_data["label"].unique())) ) # Label distribution in the new CSV @@ -449,7 +470,9 @@ def load_from_hdf5(self, hdf5_path): # Load data from the compound dataset for entry in f["data"]: filename = entry["filename"].decode("utf-8") # Decode bytes to string - image = np.array(entry["image"]).reshape(self.size + [3]) # Reshape back to image + image = np.array(entry["image"]).reshape( + self.size + [self.num_channels] + ) # Reshape back to image self.data_dict[filename] = (image, Label.UNKNOWN) # Load mean and std if they exist @@ -468,7 +491,7 @@ def _load_labeled_from_hdf5(self): with h5py.File(self.labeled_hdf5, "r") as f: for entry in f["data"]: filename = entry["filename"].decode("utf-8") - image = np.array(entry["image"]).reshape(self.size + [3]) + image = np.array(entry["image"]).reshape(self.size + [self.num_channels]) label = entry["label"] self.data_dict[filename] = (image, label) @@ -487,7 +510,7 @@ def _load_current_unlabeled_batch(self): filename = entry["filename"].decode("utf-8") # Skip if file is now labeled if filename not in self.data_dict: - image = 
np.array(entry["image"]).reshape(self.size + [3]) + image = np.array(entry["image"]).reshape(self.size + [self.num_channels]) self.data_dict[filename] = (image, Label.UNKNOWN) def get_metadata_for_file(self, filename): diff --git a/anomaly_match/datasets/BasicDataset.py b/anomaly_match/datasets/BasicDataset.py index bdc8255..12cd28c 100644 --- a/anomaly_match/datasets/BasicDataset.py +++ b/anomaly_match/datasets/BasicDataset.py @@ -4,12 +4,11 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -from torchvision import transforms -from torch.utils.data import Dataset - -from PIL import Image import numpy as np import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms from anomaly_match.image_processing.transforms import get_strong_transforms @@ -31,6 +30,7 @@ def __init__( transform=None, use_strong_transform=False, strong_transform=None, + num_channels=3, ): """ Args @@ -71,12 +71,13 @@ def __init__( self.targets = targets.clone().detach() self.num_classes = num_classes + self.num_channels = num_channels self.use_strong_transform = use_strong_transform self.transform = transform if use_strong_transform: if strong_transform is None: - self.strong_transform = get_strong_transforms() + self.strong_transform = get_strong_transforms(num_channels=num_channels) else: self.strong_transform = strong_transform @@ -95,7 +96,19 @@ def __getitem__(self, idx): if self.transform is None: return transforms.ToTensor()(img), target, self.filenames[idx] else: - img = Image.fromarray(img.numpy()) if isinstance(img, torch.Tensor) else img + # Convert to numpy if tensor + if isinstance(img, torch.Tensor): + img_np = img.numpy() + else: + img_np = img + + # For RGB images (3 channels), use PIL-based transforms + # For N-channel images, pass numpy array directly to transforms + if 
self.num_channels == 3: + img = Image.fromarray(img_np) + else: + img = img_np + img_w = self.transform(img) if not self.use_strong_transform: diff --git a/anomaly_match/datasets/SSL_Dataset.py b/anomaly_match/datasets/SSL_Dataset.py index d848632..c788240 100644 --- a/anomaly_match/datasets/SSL_Dataset.py +++ b/anomaly_match/datasets/SSL_Dataset.py @@ -7,13 +7,14 @@ import torch from loguru import logger -from .BasicDataset import BasicDataset -from .AnomalyDetectionDataset import AnomalyDetectionDataset from anomaly_match.image_processing.transforms import ( - get_weak_transforms, get_prediction_transforms, + get_weak_transforms, ) +from .AnomalyDetectionDataset import AnomalyDetectionDataset +from .BasicDataset import BasicDataset + class SSL_Dataset: """ @@ -65,11 +66,11 @@ def get_data(self): if self.train: filenames, imgs, targets = self.dset.train_data unlabeled, unlabeled_filenames = self.dset.unlabeled - self.transform = get_weak_transforms() + self.transform = get_weak_transforms(num_channels=self.num_channels) else: filenames, imgs, targets = self.dset.test_data unlabeled, unlabeled_filenames = None, None # no unlabeled data in test - self.transform = get_prediction_transforms() + self.transform = get_prediction_transforms(num_channels=self.num_channels) return imgs, targets, unlabeled, filenames, unlabeled_filenames @@ -96,6 +97,7 @@ def get_dset(self, use_strong_transform=False, strong_transform=None): self.transform, use_strong_transform, strong_transform, + num_channels=self.num_channels, ) def get_ssl_dset( @@ -146,6 +148,7 @@ def get_ssl_dset( self.transform, use_strong_transform=False, strong_transform=None, + num_channels=self.num_channels, ) ulb_dset = BasicDataset( @@ -156,6 +159,7 @@ def get_ssl_dset( self.transform, use_strong_transform, strong_transform, + num_channels=self.num_channels, ) return lb_dset, ulb_dset @@ -216,6 +220,7 @@ def update_dsets( self.transform, use_strong_transform=False, strong_transform=None, + 
num_channels=self.num_channels, ) ulb_dset = BasicDataset( @@ -226,6 +231,7 @@ def update_dsets( self.transform, use_strong_transform, strong_transform, + num_channels=self.num_channels, ) return lb_dset, ulb_dset diff --git a/anomaly_match/datasets/augmentation/randaugment.py b/anomaly_match/datasets/augmentation/randaugment.py index f6abd1c..26efcdf 100644 --- a/anomaly_match/datasets/augmentation/randaugment.py +++ b/anomaly_match/datasets/augmentation/randaugment.py @@ -6,12 +6,12 @@ # the terms contained in the file 'LICENCE.txt'. import random -import PIL -import PIL.ImageOps -import PIL.ImageEnhance -import PIL.ImageDraw import albumentations as A import numpy as np +import PIL +import PIL.ImageDraw +import PIL.ImageEnhance +import PIL.ImageOps def AutoContrast(img, _): diff --git a/anomaly_match/datasets/augmentation/randaugment_multispectral.py b/anomaly_match/datasets/augmentation/randaugment_multispectral.py new file mode 100644 index 0000000..08149bf --- /dev/null +++ b/anomaly_match/datasets/augmentation/randaugment_multispectral.py @@ -0,0 +1,281 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Multispectral-compatible augmentations for arbitrary channel counts. + +This module provides augmentation operations that work with images having +any number of channels (not just RGB). It excludes PIL-dependent operations +like AutoContrast, Brightness, Color, and Contrast that assume RGB images. + +Based on the approach from DistMSMatch/MSMatch. +""" + +import random + +import albumentations as A +import numpy as np +from scipy.ndimage import uniform_filter + + +def Identity(img, v): + """Return the image unchanged. 
+ + Args: + img: numpy array (H, W, C) to be processed + v: Unused parameter + + Returns: + The original image + """ + return img + + +def Posterize(img, v): + """Reduce the number of bits for each channel. + + Args: + img: numpy array (H, W, C) to be processed + v: Number of bits to keep for each channel (range [4, 8]) + + Returns: + numpy array with reduced color depth + """ + v = int(v) + v = max(1, min(8, v)) + shift = 8 - v + return ((img >> shift) << shift).astype(np.uint8) + + +def Rotate(img, v): + """Rotate the image by v degrees. + + Args: + img: numpy array (H, W, C) to be processed + v: Rotation angle in degrees (range [-30, 30]) + + Returns: + Rotated numpy array + """ + transform = A.Rotate(limit=(v, v), border_mode=0, p=1.0) + return transform(image=img)["image"] + + +def Sharpness(img, v): + """Adjust the sharpness of the image using unsharp masking. + + Args: + img: numpy array (H, W, C) to be processed + v: Sharpness factor (range [0.05, 0.95]) + + Returns: + numpy array with adjusted sharpness + """ + # Apply per-channel sharpening using unsharp mask + blurred = uniform_filter(img.astype(np.float32), size=(3, 3, 1)) + sharpened = img.astype(np.float32) + v * (img.astype(np.float32) - blurred) + return np.clip(sharpened, 0, 255).astype(np.uint8) + + +def ShearX(img, v): + """Apply horizontal shear to the image. + + Args: + img: numpy array (H, W, C) to be processed + v: Shear factor (range [-0.3, 0.3]) + + Returns: + numpy array with horizontal shear applied + """ + transform = A.Affine(shear={"x": (v * 45, v * 45), "y": (0, 0)}, border_mode=0, p=1.0) + return transform(image=img)["image"] + + +def ShearY(img, v): + """Apply vertical shear to the image. 
+ + Args: + img: numpy array (H, W, C) to be processed + v: Shear factor (range [-0.3, 0.3]) + + Returns: + numpy array with vertical shear applied + """ + transform = A.Affine(shear={"x": (0, 0), "y": (v * 45, v * 45)}, border_mode=0, p=1.0) + return transform(image=img)["image"] + + +def TranslateX(img, v): + """Translate the image horizontally by a percentage of its width. + + Args: + img: numpy array (H, W, C) to be processed + v: Translation factor as a percentage of image width (range [-0.3, 0.3]) + + Returns: + numpy array with horizontal translation applied + """ + transform = A.Affine(translate_percent={"x": (v, v), "y": (0, 0)}, border_mode=0, p=1.0) + return transform(image=img)["image"] + + +def TranslateY(img, v): + """Translate the image vertically by a percentage of its height. + + Args: + img: numpy array (H, W, C) to be processed + v: Translation factor as a percentage of image height (range [-0.3, 0.3]) + + Returns: + numpy array with vertical translation applied + """ + transform = A.Affine(translate_percent={"x": (0, 0), "y": (v, v)}, border_mode=0, p=1.0) + return transform(image=img)["image"] + + +def Solarize(img, v): + """Invert all pixel values above a threshold. + + Args: + img: numpy array (H, W, C) to be processed + v: Threshold for solarization (range [0, 256]) + + Returns: + Solarized numpy array + """ + v = int(v) + return np.where(img >= v, 255 - img, img).astype(np.uint8) + + +def Cutout(img, v, num_channels=None): + """Apply cutout augmentation to the image. + + Creates a square mask at a random location in the image. 
+ + Args: + img: numpy array (H, W, C) to be processed + v: Size of the cutout as a percentage of image size (range [0.0, 0.5]) + num_channels: Number of channels (inferred from img if not provided) + + Returns: + numpy array with cutout applied + """ + if v <= 0.0: + return img + + h, w = img.shape[:2] + if num_channels is None: + num_channels = img.shape[2] if len(img.shape) > 2 else 1 + + cutout_size = int(v * min(h, w)) + if cutout_size <= 0: + return img + + # Random position for cutout + x0 = np.random.randint(0, max(1, w - cutout_size + 1)) + y0 = np.random.randint(0, max(1, h - cutout_size + 1)) + + # Apply cutout with gray fill (128 for all channels) + img_out = img.copy() + img_out[y0 : y0 + cutout_size, x0 : x0 + cutout_size, :] = 128 + + return img_out + + +def multispectral_augment_list(): + """Return a list of available augmentation operations with their value ranges. + + These operations work with arbitrary channel counts, excluding PIL-dependent + operations (AutoContrast, Brightness, Color, Contrast, Equalize) that assume RGB. + + Returns: + List of tuples (operation, min_value, max_value) for each augmentation + """ + return [ + (Identity, 0, 1), + (Posterize, 4, 8), + (Rotate, -30, 30), + (Sharpness, 0.05, 0.95), + (ShearX, -0.3, 0.3), + (ShearY, -0.3, 0.3), + (Solarize, 0, 256), + (TranslateX, -0.3, 0.3), + (TranslateY, -0.3, 0.3), + ] + + +class MultispectralRandAugment: + """Random augmentation pipeline for multispectral images. + + Randomly applies a series of channel-agnostic image transformations + that work with arbitrary channel counts. + + Attributes: + n: Number of augmentation operations to apply + m: Magnitude parameter (unused, kept for API compatibility) + num_channels: Number of image channels + augment_list: List of available augmentation operations + """ + + def __init__(self, n, m, num_channels=4): + """Initialize the MultispectralRandAugment pipeline. 
+ + Args: + n: Number of augmentation operations to apply + m: Magnitude parameter (unused, kept for API compatibility) + num_channels: Number of channels in the images + """ + self.n = n + self.m = m + self.num_channels = num_channels + self.augment_list = multispectral_augment_list() + + def __call__(self, img): + """Apply random augmentations to the input image. + + Args: + img: numpy array (H, W, C) to be augmented. Will be converted to + uint8 if necessary, as operations like Posterize and Solarize + use bit-shift arithmetic that assumes 0-255 integer values. + + Returns: + Augmented numpy array (uint8) + """ + # Ensure input is numpy array + if not isinstance(img, np.ndarray): + img = np.array(img) + + # Convert to uint8 — operations like Posterize and Solarize use + # bit-shift arithmetic that requires integer 0-255 values. + if img.dtype != np.uint8: + img = np.clip(img, 0, 255).astype(np.uint8) + + # Apply n random augmentations + ops = random.choices(self.augment_list, k=self.n) + for op, min_val, max_val in ops: + val = min_val + float(max_val - min_val) * random.random() + img = op(img, val) + + # Apply cutout (for FixMatch) + cutout_val = random.random() * 0.5 + img = Cutout(img, cutout_val, num_channels=self.num_channels) + + return img + + +if __name__ == "__main__": + # Test with 4-channel image + randaug = MultispectralRandAugment(3, 5, num_channels=4) + test_img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + print(f"Input shape: {test_img.shape}") + + for op, min_val, max_val in randaug.augment_list: + val = min_val + float(max_val - min_val) * random.random() + result = op(test_img.copy(), val) + print(f"{op.__name__}: input {test_img.shape} -> output {result.shape}") + + # Test full pipeline + result = randaug(test_img.copy()) + print(f"Full pipeline: input {test_img.shape} -> output {result.shape}") diff --git a/anomaly_match/datasets/data_utils.py b/anomaly_match/datasets/data_utils.py index 532ea91..5f94d0e 100644 --- 
a/anomaly_match/datasets/data_utils.py +++ b/anomaly_match/datasets/data_utils.py @@ -4,16 +4,16 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import torch -from torch.utils.data import sampler, DataLoader -from torch.utils.data.sampler import BatchSampler, WeightedRandomSampler import numpy as np - +import torch from loguru import logger -from .BasicDataset import BasicDataset +from torch.utils.data import DataLoader, sampler +from torch.utils.data.sampler import BatchSampler, WeightedRandomSampler from anomaly_match.image_processing.transforms import get_prediction_transforms +from .BasicDataset import BasicDataset + def get_prediction_dataloader(dset, batch_size=None, num_workers=4, pin_memory=True): """Create a DataLoader for making predictions on unlabeled data. @@ -28,9 +28,10 @@ def get_prediction_dataloader(dset, batch_size=None, num_workers=4, pin_memory=T DataLoader: PyTorch DataLoader for the unlabeled data """ unlabeled, unlabeled_filenames = dset.unlabeled + num_channels = dset.num_channels # Basic transform for prediction - just convert to tensor - transform = get_prediction_transforms() + transform = get_prediction_transforms(num_channels=num_channels) # Create dataset with dummy labels (-1) ulb_dset = BasicDataset( @@ -41,6 +42,7 @@ def get_prediction_dataloader(dset, batch_size=None, num_workers=4, pin_memory=T transform=transform, use_strong_transform=False, strong_transform=transform, + num_channels=num_channels, ) return DataLoader( diff --git a/anomaly_match/image_processing/Normalisationreadme.md b/anomaly_match/image_processing/Normalisationreadme.md index ddc9b6e..ff2d938 100644 --- a/anomaly_match/image_processing/Normalisationreadme.md +++ b/anomaly_match/image_processing/Normalisationreadme.md @@ -51,6 +51,26 @@ Applies an asinh stretch - Minimum and Maximum are determined 
based on the config parameters described below +## Channel Combination + +When loading FITS files with multiple extensions, `cfg.normalisation.channel_combination` defines how extensions are linearly combined into output channels before normalisation is applied. + +- **Type:** NumPy array of shape `(n_output_channels, n_extensions)`, or `None` +- **When needed:** When you have more FITS extensions than desired output channels (e.g. 4 extensions -> 3 RGB channels) +- **Default:** `None` — extensions map directly to channels one-to-one +- **Order of operations:** Channel combination is applied first, then normalisation + +Each row of the array defines one output channel as a weighted sum of the input extensions. For example, with 4 extensions and 3 output channels: + +```python +import numpy as np +cfg.normalisation.channel_combination = np.array([ + [1, 0, 0, 0], # Channel 0 = extension 0 + [0, 0.5, 0.5, 0], # Channel 1 = average of extensions 1 and 2 + [0, 0, 0, 1], # Channel 2 = extension 3 +]) +``` + ## Normalisation settings (optional) - cfg.normalisation.maximum_value (None, float, default: None) - set the upper clipping value for normalisation, overwrites other settings diff --git a/anomaly_match/image_processing/display_transforms.py b/anomaly_match/image_processing/display_transforms.py deleted file mode 100644 index 6b2f769..0000000 --- a/anomaly_match/image_processing/display_transforms.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) European Space Agency, 2025. -# -# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which -# is part of this source code package. No part of the package, including -# this file, may be copied, modified, propagated, or distributed except according to -# the terms contained in the file 'LICENCE.txt'. 
-import numpy as np -from PIL import ImageOps, ImageEnhance, ImageFilter, Image -from skimage.util import img_as_ubyte - - -def display_image_normalisation(img): - """Normalises the image for display. - - Args: - img (np.ndarray): The input image array. - - Returns: - PIL.Image.Image: The normalised image. - """ - # Handle NaN/inf values - if not np.isfinite(img).all(): - img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0) - - img = img - np.min(img) - img_max = np.max(img) - - # Avoid division by zero - if img_max > 0: - img = img / img_max - else: - img = np.zeros_like(img) # Handle constant images - - img = img_as_ubyte(img) - if img.shape[-1] == 1: # Convert grayscale to RGB if necessary - img = np.repeat(img, 3, axis=-1) - return Image.fromarray(img) - - -# from utility_functions -def apply_transforms_ui( - img, - invert, - brightness, - contrast, - unsharp_mask_applied, - show_r, - show_g, - show_b, -): - """ - Applies the requested transformations to the given PIL Image. - - Args: - img (PIL.Image.Image): The original image. - invert (bool): Whether to invert colors. - brightness (float): Brightness factor. - contrast (float): Contrast factor. - unsharp_mask_applied (bool): Whether to apply an unsharp mask. - show_r (bool): Whether to show the red channel. - show_g (bool): Whether to show the green channel. - show_b (bool): Whether to show the blue channel. - - Returns: - PIL.Image.Image: The transformed image. 
- """ - # Apply inversion - if invert: - img = ImageOps.invert(img) - - # Apply brightness - if brightness != 1.0: - enhancer = ImageEnhance.Brightness(img) - img = enhancer.enhance(brightness) - - # Apply contrast - if contrast != 1.0: - enhancer = ImageEnhance.Contrast(img) - img = enhancer.enhance(contrast) - - # Apply unsharp mask if enabled - if unsharp_mask_applied: - img = img.filter(ImageFilter.UnsharpMask()) - - # Apply channel toggling - if not (show_r and show_g and show_b): - # Convert PIL image to numpy array - img_array = np.array(img) - - # Create a mask for RGB channels - channels_mask = [show_r, show_g, show_b] - - # Apply masking to the image array (zero out disabled channels) - for i, show_channel in enumerate(channels_mask): - if not show_channel: - img_array[:, :, i] = 0 - - # Convert back to PIL image - img = Image.fromarray(img_array) - - return img diff --git a/anomaly_match/image_processing/transforms.py b/anomaly_match/image_processing/transforms.py index 3a8c443..eb8a72f 100644 --- a/anomaly_match/image_processing/transforms.py +++ b/anomaly_match/image_processing/transforms.py @@ -4,56 +4,151 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. +import numpy as np +import torch from torchvision import transforms from anomaly_match.datasets.augmentation.randaugment import RandAugment +from anomaly_match.datasets.augmentation.randaugment_multispectral import ( + MultispectralRandAugment, +) -def get_weak_transforms(): +class NumpyToTensor: + """Convert numpy array to torch tensor for N-channel images. + + Handles HWC to CHW conversion for arbitrary channel counts. 
+ """ + + def __call__(self, img): + """Convert numpy array (H, W, C) to tensor (C, H, W).""" + if isinstance(img, np.ndarray): + # HWC to CHW conversion + img = np.transpose(img, (2, 0, 1)) + img = torch.from_numpy(img.copy()).float() / 255.0 + return img + + +class NumpyRandomHorizontalFlip: + """Random horizontal flip for numpy arrays.""" + + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img): + if isinstance(img, np.ndarray): + if np.random.random() < self.p: + return np.ascontiguousarray(img[:, ::-1, :]) + return img + elif isinstance(img, torch.Tensor): + if np.random.random() < self.p: + return torch.flip(img, dims=[2]) # Flip width dimension (CHW format) + return img + return img + + +class NumpyRandomTranslate: + """Random translation for numpy arrays.""" + + def __init__(self, translate=(0, 0.125)): + self.translate = translate + + def __call__(self, img): + if isinstance(img, np.ndarray): + h, w = img.shape[:2] + max_dx = self.translate[1] * w + max_dy = self.translate[0] * h + dx = np.random.uniform(-max_dx, max_dx) + dy = np.random.uniform(-max_dy, max_dy) + + # Use numpy roll for translation + img = np.roll(img, int(dx), axis=1) + img = np.roll(img, int(dy), axis=0) + return img + elif isinstance(img, torch.Tensor): + # For tensor (CHW format) + _, h, w = img.shape + max_dx = int(self.translate[1] * w) + max_dy = int(self.translate[0] * h) + dx = np.random.randint(-max_dx, max_dx + 1) if max_dx > 0 else 0 + dy = np.random.randint(-max_dy, max_dy + 1) if max_dy > 0 else 0 + return torch.roll(torch.roll(img, dx, dims=2), dy, dims=1) + return img + + +def get_weak_transforms(num_channels=3): """Get weak augmentation transforms. Args: - train (bool, optional): Whether training, in test only normalization is applied. + num_channels: Number of image channels (3 for RGB, other for multispectral) Returns: torchvision.transforms.Compose: transforms. 
""" - return transforms.Compose( - [ - transforms.ToTensor(), - transforms.RandomHorizontalFlip(), - transforms.RandomAffine(0, translate=(0, 0.125)), - ] - ) + if num_channels == 3: + # Use PIL-based transforms for RGB + return transforms.Compose( + [ + transforms.ToTensor(), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(0, translate=(0, 0.125)), + ] + ) + else: + # Use numpy-based transforms for multispectral + return transforms.Compose( + [ + NumpyRandomHorizontalFlip(), + NumpyRandomTranslate(translate=(0, 0.125)), + NumpyToTensor(), + ] + ) -def get_prediction_transforms(): +def get_prediction_transforms(num_channels=3): """Get the standard image transform. Args: - None + num_channels: Number of image channels (3 for RGB, other for multispectral) Returns: torchvision.transforms.Compose: transforms. - with an empty transform """ - return transforms.Compose([transforms.ToTensor()]) + if num_channels == 3: + return transforms.Compose([transforms.ToTensor()]) + else: + return transforms.Compose([NumpyToTensor()]) -def get_strong_transforms(): +def get_strong_transforms(num_channels=3): """Get strong augmentations for FixMatch. Includes RandAugment followed by the same transforms as weak (ToTensor, RandomHorizontalFlip, RandomAffine). + Args: + num_channels: Number of image channels (3 for RGB, other for multispectral) + Returns: torchvision.transforms.Compose: Strong augmentation pipeline. 
""" - return transforms.Compose( - [ - RandAugment(3, 5), - transforms.ToTensor(), - transforms.RandomHorizontalFlip(), - transforms.RandomAffine(0, translate=(0, 0.125)), - ] - ) + if num_channels == 3: + # Use PIL-based transforms for RGB + return transforms.Compose( + [ + RandAugment(3, 5), + transforms.ToTensor(), + transforms.RandomHorizontalFlip(), + transforms.RandomAffine(0, translate=(0, 0.125)), + ] + ) + else: + # Use numpy-based transforms for multispectral + return transforms.Compose( + [ + MultispectralRandAugment(3, 5, num_channels=num_channels), + NumpyRandomHorizontalFlip(), + NumpyRandomTranslate(translate=(0, 0.125)), + NumpyToTensor(), + ] + ) diff --git a/anomaly_match/models/FixMatch.py b/anomaly_match/models/FixMatch.py index 555c7ad..cbc680a 100644 --- a/anomaly_match/models/FixMatch.py +++ b/anomaly_match/models/FixMatch.py @@ -4,19 +4,18 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import torch -import torch.nn.functional as F -from sklearn.metrics import roc_auc_score, precision_recall_curve, auc - import sys -from tqdm.auto import tqdm +import torch +import torch.nn.functional as F from loguru import logger +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score +from tqdm.auto import tqdm +from anomaly_match.datasets.data_utils import get_data_loader +from anomaly_match.utils.accuracy import accuracy from anomaly_match.utils.consistency_loss import consistency_loss from anomaly_match.utils.cross_entropy_loss import cross_entropy_loss -from anomaly_match.utils.accuracy import accuracy -from anomaly_match.datasets.data_utils import get_data_loader class FixMatch: @@ -29,7 +28,6 @@ def __init__( T, p_cutoff, lambda_u, - hard_label=True, logger=None, session_tracker=None, ): @@ -46,7 +44,6 @@ def __init__( T: Temperature parameter for sharpening predictions p_cutoff: Confidence threshold for pseudo-labeling lambda_u: Weight for unsupervised loss component - hard_label: If True, uses hard pseudo-labels, otherwise soft labels logger: Logger instance for outputting information session_tracker: Optional session tracker for recording training progress """ diff --git a/anomaly_match/pipeline/SessionTracker.py b/anomaly_match/pipeline/SessionTracker.py index 47354f3..bbe53df 100644 --- a/anomaly_match/pipeline/SessionTracker.py +++ b/anomaly_match/pipeline/SessionTracker.py @@ -6,9 +6,10 @@ # the terms contained in the file 'LICENCE.txt'. import datetime -from typing import Dict, List, Optional, Any -import pandas as pd from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import pandas as pd from loguru import logger diff --git a/anomaly_match/pipeline/session.py b/anomaly_match/pipeline/session.py index 83adf92..77c60df 100644 --- a/anomaly_match/pipeline/session.py +++ b/anomaly_match/pipeline/session.py @@ -4,41 +4,39 @@ # is part of this source code package. 
No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import os import datetime -import time -import sys +import os +import pickle import subprocess - -from loguru import logger -import torch -import pandas as pd -import numpy as np +import sys +import time from contextlib import nullcontext -import pickle +from pathlib import Path + import h5py +import numpy as np +import pandas as pd +import torch import zarr -from pathlib import Path from fitsbolt import SUPPORTED_IMAGE_EXTENSIONS +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from loguru import logger - -from anomaly_match.datasets.SSL_Dataset import SSL_Dataset +from anomaly_match.data_io.load_images import get_fitsbolt_config +from anomaly_match.data_io.SessionIOHandler import SessionIOHandler from anomaly_match.datasets.data_utils import get_prediction_dataloader +from anomaly_match.datasets.SSL_Dataset import SSL_Dataset from anomaly_match.models.FixMatch import FixMatch - -from anomaly_match.utils.print_cfg import print_cfg -from anomaly_match.utils.set_log_level import set_log_level -from anomaly_match.utils.get_net_builder import get_net_builder +from anomaly_match.pipeline.SessionTracker import SessionTracker from anomaly_match.utils.cutana_stream_utils import ( cutana_buffer_generator, cutana_validate_files_and_count_sources, ) +from anomaly_match.utils.get_net_builder import get_net_builder from anomaly_match.utils.get_optimizer import get_optimizer +from anomaly_match.utils.print_cfg import print_cfg +from anomaly_match.utils.set_log_level import set_log_level from anomaly_match.utils.validate_config import validate_config -from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod -from anomaly_match.pipeline.SessionTracker import SessionTracker -from anomaly_match.data_io.SessionIOHandler import SessionIOHandler -from 
anomaly_match.data_io.load_images import get_fitsbolt_config class Session: @@ -48,7 +46,6 @@ class Session: unlabeled_train_dataset = None test_dataset = None - widget = None model: FixMatch = None active_learning_df = pd.DataFrame(columns=["filename", "label"]) @@ -115,7 +112,6 @@ def _init_model(self): T=self.cfg.temperature, p_cutoff=self.cfg.p_cutoff, lambda_u=self.cfg.ulb_loss_ratio, - hard_label=True, logger=logger, session_tracker=self.session_tracker, ) @@ -174,7 +170,7 @@ def _load_datasets(self): pin_memory=self.cfg.pin_memory, ) - def update_predictions(self): + def update_predictions(self, progress_callback=None): """Updates the predictions using the current model and datasets.""" with self.out if self.out is not None else nullcontext(): logger.debug("Updating predictions") @@ -186,12 +182,6 @@ def update_predictions(self): pin_memory=self.cfg.pin_memory, ) - def progress_callback(current, total): - if self.widget is not None and self.widget.ui["progress_bar"] is not None: - self.widget.ui["progress_bar"].value = current / total - - if self.widget is not None: - self.widget.ui["train_label"].value = "Updating predictions..." scores, imgs, filenames, _ = self.model.get_scored_binary_unlabeled_samples( self.prediction_dataloader, target_class=1, @@ -205,11 +195,9 @@ def progress_callback(current, total): if self.cfg.test_ratio > 0: logger.debug("Predictions updated, evaluating model") - if self.widget is not None: - self.widget.ui["train_label"].value = "Evaluating model..." 
self.eval_performance = self.model.evaluate( cfg=self.cfg, - progress_callback=lambda current, total: progress_callback(current, total), + progress_callback=progress_callback, ) def sort_by_anomalous(self): @@ -596,18 +584,6 @@ def get_active_learning_counts(self): self._active_learning_counts_cache = (new_normal, new_anomalous) return self._active_learning_counts_cache - def start_UI(self): - """Starts the user interface for the session.""" - from anomaly_match.ui.Widget import Widget - - if self.widget is None: - logger.info("Starting new UI... (this may compute furiously for a few seconds)") - self.widget = Widget(self) - self.widget.start() - else: - logger.debug("UI already running, restarting") - self.widget.start() - def get_label(self, idx): """Gets the label for the image at the given index. @@ -717,19 +693,18 @@ def run_pipeline(self, temp_config_path, input_path, top_N, file_type=None): with open(temp_file_list, "w") as f: f.write(input_path) - # Call prediction_process.py with the file list - subprocess.run( - [ - sys.executable, - script_path, - temp_config_path, - temp_file_list, - str(top_N), - ] - ) + cmd = [sys.executable, script_path, temp_config_path, temp_file_list, str(top_N)] else: - # For hdf5 and zarr files, pass the file path directly - subprocess.run([sys.executable, script_path, temp_config_path, input_path, str(top_N)]) + cmd = [sys.executable, script_path, temp_config_path, input_path, str(top_N)] + + logger.info(f"Launching prediction subprocess: {script}") + result = subprocess.run(cmd) + + if result.returncode != 0: + logger.error( + f"Prediction subprocess failed (exit code {result.returncode}). " + f"Check prediction.log in {self.cfg.output_dir} for details." + ) # Reset logger to old level set_log_level(self.cfg.log_level, self.cfg) @@ -748,8 +723,6 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): "Please train and save a model before running predictions." 
) logger.error(error_msg) - if self.widget is not None: - self.widget.ui["train_label"].value = "Error: Model not found!" raise FileNotFoundError(error_msg) # Auto-detect file type based on prediction_search_dir @@ -760,8 +733,6 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): "images, HDF5 files, Zarr files, or Cutana buffer files." ) logger.error(error_msg) - if self.widget is not None: - self.widget.ui["train_label"].value = "Error: No prediction directory!" raise ValueError(error_msg) detected_file_type = self._auto_detect_prediction_file_type(self.cfg.prediction_search_dir) @@ -774,8 +745,6 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): "Please use CONVERSION_ONLY, LOG, ZSCALE, or ASINH." ) logger.error(error_msg) - if self.widget is not None: - self.widget.ui["train_label"].value = "Error: MIDTONES not supported!" raise ValueError(error_msg) # Define supported file extensions @@ -876,7 +845,6 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): # Creating a generator that loads the csv/parquet in chunks and saves to a temporary file elif detected_file_type == "stream": - # Files are read in chunks and saved into this intermediate buffer cutana_buffer_path = Path("tmp") / ".cutana_buffer.parquet" input_files = cutana_buffer_generator( @@ -984,7 +952,7 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): logger.info("Loading updated results from output files") df = pd.read_csv(output_csv_path) filenames = df["Filename"].values - self.filenames = np.array([os.path.basename(f) for f in filenames]) + self.filenames = np.array([os.path.basename(str(f)) for f in filenames]) self.scores = df["Score"].values # Load images using consistent format handling (same as load_top_files) @@ -1011,16 +979,21 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): else: self.img_catalog = self.img_catalog.clip(0, 255).astype(np.uint8) - # Update UI if available - if self.widget is not 
None: - self.widget.display_top_files_scores() + # Notify UI that results are available for display + if progress_callback: + progress_callback( + file_idx + 1, + num_files, + results_updated=True, + ) + else: logger.error( "Output files not found. Prediction process might have failed. On Datalabs, the process may have exceeded the RAM allocation. Please check logs in the folder ." # noqa: E501 ) # Log statistics - if len(self.scores) > 0: + if self.scores is not None and len(self.scores) > 0: logger.debug( f"File {file_idx} processed, scores mean={np.mean(self.scores):.4f}, " f"std={np.std(self.scores):.4f}, min={np.min(self.scores):.4f}, " @@ -1058,7 +1031,10 @@ def evaluate_all_images(self, top_N=1000, progress_callback=None): logger.warning("No images were processed or processing time was too short") logger.info(f"Processed {num_files} files with {detected_file_type} format") - logger.debug(f"Total images scored: {len(self.scores)}") + if self.scores is not None: + logger.debug(f"Total images scored: {len(self.scores)}") + else: + logger.warning("No scores were loaded - all prediction subprocesses may have failed") def load_top_files(self, top_N): """Loads the top files from the output directory using consistent image processing.""" @@ -1069,8 +1045,8 @@ def load_top_files(self, top_N): logger.info("Loading updated results from output files") df = pd.read_csv(output_csv_path) filenames = df["Filename"].values - # Convert to basename - self.filenames = np.array([os.path.basename(f) for f in filenames]) + # Convert to basename (str() handles cutana int64 source_ids) + self.filenames = np.array([os.path.basename(str(f)) for f in filenames]) self.scores = df["Score"].values # Load images using consistent format handling @@ -1105,11 +1081,6 @@ def load_top_files(self, top_N): logger.debug( f"Image catalog shape: {self.img_catalog.shape}, dtype: {self.img_catalog.dtype}" ) - - # Call Widget's display_top_files_scores to update the UI - if self.widget is not None: - 
logger.debug("Displaying top files and scores") - self.widget.display_top_files_scores() else: logger.error( f"Output files not found at {output_csv_path} and {output_npy_path}. \n Note that you may need to rename the" diff --git a/anomaly_match/utils/cross_entropy_loss.py b/anomaly_match/utils/cross_entropy_loss.py index f9ee713..923e2e7 100644 --- a/anomaly_match/utils/cross_entropy_loss.py +++ b/anomaly_match/utils/cross_entropy_loss.py @@ -4,8 +4,8 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import torch.nn.functional as F import torch +import torch.nn.functional as F def cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="none"): @@ -31,9 +31,9 @@ def cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="none"): return F.cross_entropy(logits, targets.long(), reduction=reduction) else: # KL divergence style loss with probability distributions as targets - assert ( - logits.shape == targets.shape - ), "Logits and targets must have the same shape when using soft labels" + assert logits.shape == targets.shape, ( + "Logits and targets must have the same shape when using soft labels" + ) log_pred = F.log_softmax(logits, dim=-1) # Negative KL divergence (equivalent to cross-entropy for soft targets) nll_loss = torch.sum(-targets * log_pred, dim=1) diff --git a/anomaly_match/utils/cutana_stream_utils.py b/anomaly_match/utils/cutana_stream_utils.py index 5b4514c..3b11a00 100644 --- a/anomaly_match/utils/cutana_stream_utils.py +++ b/anomaly_match/utils/cutana_stream_utils.py @@ -5,19 +5,23 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
+import math import warnings from pathlib import Path -import pyarrow.parquet as pq import pandas as pd +import pyarrow.parquet as pq +from cutana.catalogue_preprocessor import validate_catalogue_columns from loguru import logger -from cutana.catalogue_preprocessor import validate_catalogue_columns, check_fits_files_exist def cutana_validate_files_and_count_sources( files: list[Path | str], chunk_size: int = 100_000 ) -> tuple[list[Path], int, int]: - """Validate catalogue files for cutana compatibility and count total number sources and total number of chunks to process. + """Validate catalogue files for cutana compatibility and count sources. + + Only checks column schema (first chunk of first file). Row counts are + read from parquet metadata when possible to avoid scanning every row. Args: files (list[Path | str]): list of file paths to validate (CSV or Parquet). @@ -27,72 +31,61 @@ def cutana_validate_files_and_count_sources( tuple[list[Path], int, int]: valid files, total number of sources, and total number of chunks. """ - def _validate_against_cutana(index: int, dataframe: pd.DataFrame) -> bool: - # Check header once - if index == 0: - errors = validate_catalogue_columns(dataframe) - if errors: - return errors - - errors, _ = check_fits_files_exist(dataframe) - if errors: - return errors - return [] - + # Schema is validated once from the first valid file (all catalogue files + # are expected to share the same column layout). 
+ columns_validated = False valid_files = [] total_sources = 0 total_chunks = 0 for file in files: - - is_file_valid = True - - current_file_sources = 0 - current_file_chunks = 0 - - if isinstance(file, Path): - file_type = file.name.split(".")[-1] + file_str = str(file) + file_type = file_str.rsplit(".", 1)[-1].lower() + + if file_type == "parquet": + try: + parquet_file = pq.ParquetFile(file_str) + + # Validate columns once from the first file's schema + if not columns_validated: + first_batch = next(parquet_file.iter_batches(batch_size=1)) + errors = validate_catalogue_columns(first_batch.to_pandas()) + if errors: + msg = f"File {file} did not pass cutana column check ({errors})" + logger.warning(msg) + warnings.warn(msg, RuntimeWarning) + continue + columns_validated = True + + num_rows = parquet_file.metadata.num_rows + total_sources += num_rows + total_chunks += math.ceil(num_rows / chunk_size) + valid_files.append(file) + except Exception as e: + logger.warning(f"Could not read parquet file {file}: {e}") + + elif file_type == "csv": + current_sources = 0 + current_chunks = 0 + is_valid = True + for i, df in enumerate(pd.read_csv(file_str, chunksize=chunk_size)): + if not columns_validated: + errors = validate_catalogue_columns(df) + if errors: + msg = f"File {file} did not pass cutana column check ({errors})" + logger.warning(msg) + warnings.warn(msg, RuntimeWarning) + is_valid = False + break + columns_validated = True + current_sources += len(df) + current_chunks += 1 + if is_valid: + total_sources += current_sources + total_chunks += current_chunks + valid_files.append(file) else: - file_type = file.split(".")[-1] - - if file_type == "csv": - for i, df in enumerate(pd.read_csv(file, chunksize=chunk_size)): - - errors = _validate_against_cutana(i, df) - if errors: - current_file_sources = 0 - current_file_chunks = 0 - is_file_valid = False - msg = f"File {file} did not pass cutana compatibility check and will be skipped ({errors})" - logger.warning(msg) - 
warnings.warn(msg, RuntimeWarning) - break - current_file_sources += len(df) - current_file_chunks += 1 - - elif file_type == "parquet": - parquet_file = pq.ParquetFile(file) - for i, batch in enumerate(parquet_file.iter_batches(batch_size=chunk_size)): - df = batch.to_pandas() - - errors = _validate_against_cutana(i, df) - if errors: - current_file_sources = 0 - current_file_chunks = 0 - is_file_valid = False - msg = f"File {file} did not pass cutana compatibility check and will be skipped ({errors})" - logger.warning(msg) - warnings.warn(msg, RuntimeWarning) - break - current_file_sources += len(df) - current_file_chunks += 1 - else: - is_file_valid = False - - total_sources += current_file_sources - total_chunks += current_file_chunks - if is_file_valid: - valid_files.append(file) + logger.warning(f"Unsupported file type '{file_type}' for {file}, skipping") return valid_files, total_sources, total_chunks @@ -111,11 +104,10 @@ def cutana_buffer_generator(files: list[Path | str], buffer_path: Path, chunk_si buffer_path.parent.mkdir(parents=True, exist_ok=True) for file in files: - if isinstance(file, Path): - file_type = file.name.split(".")[-1] + file_type = file.name.split(".")[-1].lower() else: - file_type = file.split(".")[-1] + file_type = file.split(".")[-1].lower() if file_type == "csv": for df in pd.read_csv(file, chunksize=chunk_size): diff --git a/anomaly_match/utils/get_cosine_schedule_with_warmup.py b/anomaly_match/utils/get_cosine_schedule_with_warmup.py index b1d32cc..8c2aba5 100644 --- a/anomaly_match/utils/get_cosine_schedule_with_warmup.py +++ b/anomaly_match/utils/get_cosine_schedule_with_warmup.py @@ -4,10 +4,10 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-from torch.optim.lr_scheduler import LambdaLR - import math +from torch.optim.lr_scheduler import LambdaLR + def get_cosine_schedule_with_warmup( optimizer, diff --git a/anomaly_match/utils/get_default_cfg.py b/anomaly_match/utils/get_default_cfg.py index 59520c3..5ac0bc9 100644 --- a/anomaly_match/utils/get_default_cfg.py +++ b/anomaly_match/utils/get_default_cfg.py @@ -4,13 +4,13 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -from dotmap import DotMap - import os + import numpy as np +from dotmap import DotMap +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from .create_model_string import create_model_string -from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod def get_default_cfg(): @@ -53,6 +53,7 @@ def get_default_cfg(): cfg.normalisation.output_dtype = np.uint8 # output dtype of the images # NOTE: image_size has no default - user must explicitly set it cfg.normalisation.n_output_channels = 3 # number of output channels (e.g. 3 for RGB) + cfg.num_channels = cfg.normalisation.n_output_channels # set from dataset at runtime # FITS file handling settings # fits_extension: Extension(s) to use when loading FITS files @@ -89,6 +90,12 @@ def get_default_cfg(): ] # end of fitsbolt settings + # Flux conversion (Euclid): convert pixel values to flux density in Jansky + # using the AB zeropoint (MAGZERO) from FITS headers. When True, must be + # applied in both training (load_and_process_wrapper) and prediction (cutana) paths. 
+ cfg.normalisation.apply_flux_conversion = False + cfg.normalisation.flux_conversion_zeropoint_keyword = "MAGZERO" + # FixMatch settings cfg.ema_m = 0.99 cfg.hard_label = True diff --git a/anomaly_match/utils/get_net_builder.py b/anomaly_match/utils/get_net_builder.py index 7007849..777cac4 100644 --- a/anomaly_match/utils/get_net_builder.py +++ b/anomaly_match/utils/get_net_builder.py @@ -4,24 +4,103 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -from efficientnet_pytorch import EfficientNet -import efficientnet_lite_pytorch -from efficientnet_lite0_pytorch_model import EfficientnetLite0ModelFile - +import timm +import torch.nn as nn from loguru import logger +# Mapping from AnomalyMatch net names to timm model identifiers. +# efficientnet-lite variants use the tf_ prefix for TF-style same-padding, +# which is backward compatible with the previous efficientnet_lite_pytorch package. 
+_TIMM_MODEL_MAP = { + "efficientnet-lite0": "tf_efficientnet_lite0", + "efficientnet-lite1": "tf_efficientnet_lite1", + "efficientnet-lite2": "tf_efficientnet_lite2", + "efficientnet-lite3": "tf_efficientnet_lite3", + "efficientnet-lite4": "tf_efficientnet_lite4", + "efficientnet-b0": "efficientnet_b0", + "efficientnet-b1": "efficientnet_b1", + "efficientnet-b2": "efficientnet_b2", + "efficientnet-b3": "efficientnet_b3", + "efficientnet-b4": "efficientnet_b4", + "efficientnet-b5": "efficientnet_b5", + "efficientnet-b6": "efficientnet_b6", + "efficientnet-b7": "efficientnet_b7", +} + +# Pretrained tag for timm (appended to model name for pretrained loading) +_TIMM_PRETRAINED_TAG = { + "tf_efficientnet_lite0": "tf_efficientnet_lite0.in1k", + "efficientnet_b0": "efficientnet_b0.ra_in1k", + "efficientnet_b1": "efficientnet_b1.ft_in1k", + "efficientnet_b2": "efficientnet_b2.ra_in1k", + "efficientnet_b3": "efficientnet_b3.ra2_in1k", + "efficientnet_b4": "efficientnet_b4.ra2_in1k", +} + + +class TestCNN(nn.Module): + """Minimal CNN for fast testing. Not for production use.""" + + def __init__(self, num_classes=2, in_channels=3): + super().__init__() + self._conv_stem = nn.Conv2d(in_channels, 8, 3, stride=2, padding=1) + self.features = nn.Sequential( + nn.ReLU(), + nn.AdaptiveAvgPool2d(1), + ) + self._fc = nn.Linear(8, num_classes) + + def forward(self, x): + x = self._conv_stem(x) + x = self.features(x) + x = x.flatten(1) + return self._fc(x) + + +def _resolve_timm_name(net_name, pretrained): + """Resolve AnomalyMatch net name to timm model identifier. + + Args: + net_name: AnomalyMatch-style net name (e.g. 
"efficientnet-lite0") + pretrained: Whether pretrained weights are requested + + Returns: + str: timm model name (with pretrained tag if applicable) + + Raises: + ValueError: If net_name is not supported + """ + timm_base = _TIMM_MODEL_MAP.get(net_name) + if timm_base is None: + supported = list(_TIMM_MODEL_MAP.keys()) + raise ValueError( + f"Unsupported network architecture: {net_name}. Supported architectures: {supported}" + ) + + if pretrained: + timm_name = _TIMM_PRETRAINED_TAG.get(timm_base) + if timm_name is None: + logger.warning( + f"No pretrained weights available for {net_name}. Using random initialization." + ) + return timm_base, False + return timm_name, True + + return timm_base, False + def get_net_builder(net_name, pretrained=False, in_channels=3): """Create a neural network builder function for the specified architecture. This function returns a builder function that creates a neural network with the specified architecture when called with num_classes and in_channels parameters. - Currently supports various EfficientNet variants. + Uses timm (pytorch-image-models) as the backend for all EfficientNet variants. Args: net_name (str): Name of the network architecture, supported values: - - efficientnet-lite0, efficientnet-lite1, etc. - - efficientnet-b0, efficientnet-b1, etc. + - efficientnet-lite0 through efficientnet-lite4 + - efficientnet-b0 through efficientnet-b7 + - test-cnn (for testing only) pretrained (bool, optional): If True, loads pretrained weights. Default is False. in_channels (int, optional): Number of input channels. Default is 3. 
@@ -31,49 +110,26 @@ def get_net_builder(net_name, pretrained=False, in_channels=3): Raises: ValueError: If an unsupported network architecture is specified """ - if "efficientnet-lite" in net_name: - if pretrained: - if net_name == "efficientnet-lite0": - logger.debug(f"Using pretrained {net_name} model") - weights_path = EfficientnetLite0ModelFile.get_model_file_path() - - return lambda num_classes, in_channels: efficientnet_lite_pytorch.EfficientNet.from_pretrained( - "efficientnet-lite0", - weights_path=weights_path, - num_classes=num_classes, - in_channels=in_channels, - ) - else: - logger.warning( - f"Only efficientnet-lite0 pretrained is supported. Using non-pretrained {net_name} instead." - ) - return lambda num_classes, in_channels: efficientnet_lite_pytorch.EfficientNet.from_name( - net_name, num_classes=num_classes, in_channels=in_channels - ) - else: - logger.debug(f"Using non-pretrained {net_name} model") - return ( - lambda num_classes, in_channels: efficientnet_lite_pytorch.EfficientNet.from_name( - net_name, num_classes=num_classes, in_channels=in_channels - ) - ) + if net_name == "test-cnn": + logger.debug("Using test-cnn model (for testing only)") - elif "efficientnet" in net_name: - if pretrained: - logger.debug(f"Using pretrained {net_name} model") - return lambda num_classes, in_channels: EfficientNet.from_pretrained( - net_name, num_classes=num_classes, in_channels=in_channels - ) + def build_test_cnn(num_classes, in_channels): + return TestCNN(num_classes=num_classes, in_channels=in_channels) - else: - logger.debug(f"Using non-pretrained {net_name} model") - return lambda num_classes, in_channels: EfficientNet.from_name( - net_name, num_classes=num_classes, in_channels=in_channels - ) - else: - error_msg = ( - f"Unsupported network architecture: {net_name}. 
" - f"Supported architectures: efficientnet-b[0-7] and efficientnet-lite[0-4]" + return build_test_cnn + + timm_name, use_pretrained = _resolve_timm_name(net_name, pretrained) + logger.debug( + f"Using {'pretrained' if use_pretrained else 'non-pretrained'} {net_name} " + f"(timm: {timm_name})" + ) + + def build_model(num_classes, in_channels, _timm_name=timm_name, _pretrained=use_pretrained): + return timm.create_model( + _timm_name, + pretrained=_pretrained, + num_classes=num_classes, + in_chans=in_channels, ) - logger.error(error_msg) - raise ValueError(error_msg) + + return build_model diff --git a/anomaly_match/utils/get_optimizer.py b/anomaly_match/utils/get_optimizer.py index 43e4944..7051a11 100644 --- a/anomaly_match/utils/get_optimizer.py +++ b/anomaly_match/utils/get_optimizer.py @@ -50,7 +50,6 @@ def get_optimizer( nesterov=nesterov, ) elif name == "Adam" or name == "ADAM": - if lr > 0.005: raise ValueError("Learning rate is " + str(lr) + ". That is too high for ADAM.") diff --git a/anomaly_match/utils/print_cfg.py b/anomaly_match/utils/print_cfg.py index 947fe00..8c81f28 100644 --- a/anomaly_match/utils/print_cfg.py +++ b/anomaly_match/utils/print_cfg.py @@ -4,9 +4,9 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -from dotmap import DotMap -import numpy as np import ipywidgets as widgets +import numpy as np +from dotmap import DotMap def print_cfg(cfg: DotMap): diff --git a/anomaly_match/utils/set_log_level.py b/anomaly_match/utils/set_log_level.py index f8f993a..88d4362 100644 --- a/anomaly_match/utils/set_log_level.py +++ b/anomaly_match/utils/set_log_level.py @@ -4,10 +4,11 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-from loguru import logger -import sys import os +import sys + from dotmap import DotMap +from loguru import logger def set_log_level(log_level: str, cfg: DotMap, log_to_file: bool = True): @@ -25,9 +26,9 @@ def set_log_level(log_level: str, cfg: DotMap, log_to_file: bool = True): valid_log_levels = ["TRACE", "DEBUG", "INFO", "SUCCESS", "WARNING", "ERROR", "CRITICAL"] # Assert that the provided log_level is valid - assert ( - log_level.upper() in valid_log_levels - ), f"Invalid log level: {log_level}. Expected one of {valid_log_levels}." + assert log_level.upper() in valid_log_levels, ( + f"Invalid log level: {log_level}. Expected one of {valid_log_levels}." + ) # Create logs directory in project root (two levels up from utils) logs_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "logs") diff --git a/anomaly_match/utils/set_seeds.py b/anomaly_match/utils/set_seeds.py index cdda134..6814c9d 100644 --- a/anomaly_match/utils/set_seeds.py +++ b/anomaly_match/utils/set_seeds.py @@ -5,9 +5,10 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. import random + import numpy as np -import torch.backends.cudnn as cudnn import torch +import torch.backends.cudnn as cudnn def set_seeds(seed: int, deterministic: bool = False) -> None: diff --git a/anomaly_match/utils/validate_config.py b/anomaly_match/utils/validate_config.py index f3b37da..c74df77 100644 --- a/anomaly_match/utils/validate_config.py +++ b/anomaly_match/utils/validate_config.py @@ -5,11 +5,12 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-from dotmap import DotMap -from loguru import logger - import os + +import numpy as np +from dotmap import DotMap from fitsbolt.cfg.create_config import create_config as fb_create_cfg +from loguru import logger def _return_required_and_optional_keys(): @@ -40,7 +41,7 @@ def _return_required_and_optional_keys(): # Required numeric parameter" "seed": [float, None, None, False, None], # accepts int or float # Required positive integers - "num_workers": [int, 1, None, False, None], + "num_workers": [int, 0, None, False, None], "uratio": [int, 1, None, False, None], "batch_size": [int, 1, None, False, None], "num_train_iter": [int, 1, None, False, None], @@ -66,7 +67,28 @@ def _return_required_and_optional_keys(): "pretrained": [bool, None, None, False, None], # Required parameters with allowed values "opt": [str, None, None, False, ["SGD", "Adam"]], - "net": [str, None, None, False, ["efficientnet-lite0"]], + "net": [ + str, + None, + None, + False, + [ + "efficientnet-lite0", + "efficientnet-lite1", + "efficientnet-lite2", + "efficientnet-lite3", + "efficientnet-lite4", + "efficientnet-b0", + "efficientnet-b1", + "efficientnet-b2", + "efficientnet-b3", + "efficientnet-b4", + "efficientnet-b5", + "efficientnet-b6", + "efficientnet-b7", + "test-cnn", + ], + ], "log_level": [ str, None, @@ -79,6 +101,7 @@ def _return_required_and_optional_keys(): # Optional directory parameters "prediction_search_dir": ["directory", None, None, True, None], "N_batch_prediction": [int, 1, None, True, None], + "num_channels": [int, 1, None, False, None], # fitsbolt config parameters - only validate that it's a DotMap and check size "normalisation": ["special_fitsbolt", None, None, False, None], "normalisation.image_size": ["special_size", None, None, False, None], @@ -177,7 +200,7 @@ def _format_constraints(): return f" ({', '.join(constraints)})" if constraints else "" # Validate based on data type - if dtype == str: + if dtype is str: if not isinstance(value, str): raise ValueError( 
f"{param_name} must be a string, got {type(value).__name__}{_format_constraints()}" @@ -208,7 +231,7 @@ def _format_constraints(): if check_paths and not os.path.isfile(value): raise ValueError(f"{param_name} file does not exist: {value}") - elif dtype == int: + elif dtype is int: if not isinstance(value, int): raise ValueError( f"{param_name} must be an integer, got {type(value).__name__}{_format_constraints()}" @@ -224,7 +247,7 @@ def _format_constraints(): if allowed_values is not None and value not in allowed_values: raise ValueError(f"{param_name} must be one of {allowed_values}, got {value}") - elif dtype == float: + elif dtype is float: if not isinstance(value, (int, float)): raise ValueError( f"{param_name} must be a number, got {type(value).__name__}{_format_constraints()}" @@ -240,7 +263,7 @@ def _format_constraints(): if allowed_values is not None and value not in allowed_values: raise ValueError(f"{param_name} must be one of {allowed_values}, got {value}") - elif dtype == bool: + elif dtype is bool: if not isinstance(value, bool): raise ValueError(f"{param_name} must be a boolean, got {type(value).__name__}") @@ -268,6 +291,52 @@ def _format_constraints(): # Also validate normalisation configuration with its own validation function if possible if hasattr(cfg, "normalisation"): + cc = cfg.normalisation.channel_combination + fits_ext = cfg.normalisation.fits_extension + + # Infer n_output_channels from channel_combination matrix if provided + if cc is not None and hasattr(cc, "shape") and len(cc.shape) == 2: + inferred = cc.shape[0] + if inferred != cfg.normalisation.n_output_channels: + logger.info( + f"Setting n_output_channels to {inferred} " + f"from channel_combination shape {cc.shape}" + ) + cfg.normalisation.n_output_channels = inferred + + # Auto-create identity channel_combination for multiple FITS extensions + # when no explicit matrix is provided. 
+ elif fits_ext is not None and isinstance(fits_ext, (list, tuple)) and len(fits_ext) > 1: + n_ext = len(fits_ext) + cfg.normalisation.channel_combination = np.eye(n_ext) + cfg.normalisation.n_output_channels = n_ext + logger.info( + f"Auto-created {n_ext}x{n_ext} identity channel_combination " + f"for {n_ext} FITS extensions (n_output_channels set to {n_ext})" + ) + + # Guard against n_output_channels being None (e.g. user unset it + # expecting auto-inference but only has a single FITS extension). + if cfg.normalisation.n_output_channels is None: + raise ValueError( + "n_output_channels is None and could not be inferred. " + "Set normalisation.n_output_channels explicitly or provide " + "multiple fits_extension entries or a channel_combination matrix." + ) + + # Keep cfg.num_channels in sync with n_output_channels + cfg.num_channels = cfg.normalisation.n_output_channels + + # Ensure per-channel normalisation lists match n_output_channels + n_out = cfg.normalisation.n_output_channels + for attr in ("norm_asinh_scale", "norm_asinh_clip"): + val = cfg.normalisation[attr] + if isinstance(val, list) and len(val) != n_out and len(val) != 1: + if n_out < len(val): + cfg.normalisation[attr] = val[:n_out] + else: + cfg.normalisation[attr] = val + [val[-1]] * (n_out - len(val)) + try: # Use fitsbolt's own validation by calling its create_config function _ = fb_create_cfg( @@ -275,9 +344,10 @@ def _format_constraints(): size=cfg.normalisation.image_size, fits_extension=cfg.normalisation.fits_extension, interpolation_order=cfg.normalisation.interpolation_order, + n_output_channels=cfg.normalisation.n_output_channels, normalisation_method=cfg.normalisation.normalisation_method, channel_combination=cfg.normalisation.channel_combination, - num_workers=cfg.num_workers, + num_workers=max(cfg.num_workers, 1), norm_maximum_value=cfg.normalisation.norm_maximum_value, norm_minimum_value=cfg.normalisation.norm_minimum_value, 
norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, @@ -302,6 +372,8 @@ def _format_constraints(): "normalisation.norm_crop_for_maximum_value", "normalisation.norm_asinh_scale", "normalisation.norm_asinh_clip", + "normalisation.apply_flux_conversion", + "normalisation.flux_conversion_zeropoint_keyword", ] ) except Exception as e: diff --git a/anomaly_match_ui/__init__.py b/anomaly_match_ui/__init__.py new file mode 100644 index 0000000..dfc0846 --- /dev/null +++ b/anomaly_match_ui/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +""" +AnomalyMatch UI Package - Jupyter notebook interface for anomaly detection. + +This package provides the UI components for AnomalyMatch, separated from the core +backend functionality to allow headless operation of the backend. +""" + +from anomaly_match_ui.app import start_ui +from anomaly_match_ui.utils.backend_interface import BackendInterface + +__all__ = ["start_ui", "BackendInterface"] diff --git a/anomaly_match_ui/app.py b/anomaly_match_ui/app.py new file mode 100644 index 0000000..e04acc6 --- /dev/null +++ b/anomaly_match_ui/app.py @@ -0,0 +1,53 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +""" +Entry point for the AnomalyMatch UI. + +This module provides the main entry function for starting the UI with a Session. 
+""" + +from loguru import logger + +from anomaly_match_ui.utils.backend_interface import BackendInterface +from anomaly_match_ui.widget import Widget + + +def start_ui(session): + """Start the AnomalyMatch UI with the given session. + + This is the main entry point for launching the UI. It sets up the + BackendInterface with the session and creates/displays the Widget. + + Args: + session: An AnomalyMatch Session instance that has been initialized + and has predictions ready for display. + + Returns: + Widget: The created Widget instance. + + Example: + >>> import anomaly_match as am + >>> from anomaly_match_ui import start_ui + >>> + >>> cfg = am.get_default_cfg() + >>> cfg.data_dir = "path/to/images" + >>> session = am.Session(cfg) + >>> session.train(cfg) + >>> widget = start_ui(session) + """ + logger.debug("Starting AnomalyMatch UI...") + + # Set up the backend interface with the session + BackendInterface.set_session(session) + + # Create and start the widget + widget = Widget() + widget.start() + + logger.debug("AnomalyMatch UI started successfully.") + + return widget diff --git a/anomaly_match/ui/memory_monitor.py b/anomaly_match_ui/memory_monitor.py similarity index 99% rename from anomaly_match/ui/memory_monitor.py rename to anomaly_match_ui/memory_monitor.py index 565b648..ecb88ef 100644 --- a/anomaly_match/ui/memory_monitor.py +++ b/anomaly_match_ui/memory_monitor.py @@ -4,14 +4,15 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import psutil -import ipywidgets as widgets import asyncio -import torch -from loguru import logger from contextlib import nullcontext from datetime import datetime +import ipywidgets as widgets +import psutil +import torch +from loguru import logger + class MemoryMonitor: """Monitors system and GPU memory usage and displays it in a widget.""" diff --git a/anomaly_match/ui/preview_widget.py b/anomaly_match_ui/preview_widget.py similarity index 75% rename from anomaly_match/ui/preview_widget.py rename to anomaly_match_ui/preview_widget.py index 0d9df5c..393e8d3 100644 --- a/anomaly_match/ui/preview_widget.py +++ b/anomaly_match_ui/preview_widget.py @@ -7,31 +7,30 @@ """ PreviewWidget: A self-contained widget for displaying and manipulating preview images. """ + import os -import numpy as np + import ipywidgets as widgets +import numpy as np from ipywidgets import VBox from loguru import logger from anomaly_match.data_io.load_images import load_and_process_single_wrapper -from anomaly_match.image_processing.display_transforms import ( +from anomaly_match_ui.utils.backend_interface import BackendInterface +from anomaly_match_ui.utils.display_transforms import ( apply_transforms_ui, display_image_normalisation, ) -from anomaly_match.utils.numpy_to_byte_stream import numpy_array_to_byte_stream +from anomaly_match_ui.utils.image_utils import numpy_array_to_byte_stream class PreviewWidget: """A widget component for displaying images with transformation controls.""" - def __init__(self, session): - """ - Initialize the preview widget. - - Args: - session: The session object providing image data and configuration. 
- """ - self.session = session + def __init__(self): + """Initialize the preview widget.""" + # Get number of channels from backend + self.num_channels = BackendInterface.get_num_channels() # Create UI elements self.filename_text = widgets.HTML( @@ -61,9 +60,18 @@ def __init__(self, session): self.brightness = 1.0 self.contrast = 1.0 self.unsharp_mask_applied = False + + # Channel visibility - support both RGB and N-channel modes self.show_r = True self.show_g = True self.show_b = True + + # For N-channel images: which channels to display as RGB + # Default: first 3 channels [0, 1, 2] + self.rgb_mapping = [0, 1, 2] if self.num_channels >= 3 else list(range(self.num_channels)) + # Channel visibility for N-channel mode + self.channel_visibility = [True] * self.num_channels + self.full_resolution_mode = False # Image data @@ -83,18 +91,22 @@ def set_index(self, index): def update_display(self): """Updates the display of the current image.""" - filename = self.session.filenames[self.current_index] - score = self.session.scores[self.current_index] - filepath = os.path.join(self.session.cfg.data_dir, filename) + filenames = BackendInterface.get_filenames() + scores = BackendInterface.get_scores() + cfg = BackendInterface.get_config() + + filename = filenames[self.current_index] + score = scores[self.current_index] + filepath = os.path.join(cfg.data_dir, filename) # Determine if we need to reload from disk needs_reload = ( self.full_resolution_mode - or self.session.cfg.normalisation.normalisation_method - != self.session.cached_image_normalisation_enum + or cfg.normalisation.normalisation_method + != BackendInterface.get_cached_normalisation_method() ) - if needs_reload: + if needs_reload and os.path.exists(filepath): try: size_override = None if self.full_resolution_mode else "default" logger.debug( @@ -103,19 +115,23 @@ def update_display(self): img = load_and_process_single_wrapper( filepath, - self.session.cfg, + cfg, desc="widget loading image", show_progress=False, 
size_override=size_override, ) - self.original_image = display_image_normalisation(img) + # For N-channel images, use rgb_mapping to convert to displayable RGB + rgb_map = self.rgb_mapping if self.num_channels > 3 else None + self.original_image = display_image_normalisation(img, rgb_mapping=rgb_map) except Exception as e: logger.error(f"Error loading image {filepath}: {e}") return else: - img = self.session.img_catalog[self.current_index] - self.original_image = display_image_normalisation(img) + img = BackendInterface.get_image_at_index(self.current_index) + # For N-channel images, use rgb_mapping to convert to displayable RGB + rgb_map = self.rgb_mapping if self.num_channels > 3 else None + self.original_image = display_image_normalisation(img, rgb_mapping=rgb_map) # Apply transforms self.modified_image = apply_transforms_ui( @@ -127,6 +143,7 @@ def update_display(self): show_r=self.show_r, show_g=self.show_g, show_b=self.show_b, + channel_visibility=self.channel_visibility if self.num_channels > 3 else None, ) self._display_image(self.modified_image, filename, score) @@ -141,7 +158,7 @@ def _update_label(self, filename=None, score=None): """Updates the UI label with the current image's filename, score, and label.""" label_color = "white" label_text = "None" - label = self.session.get_label(self.current_index) + label = BackendInterface.get_label(self.current_index) if label == "anomaly": label_color = "red" label_text = "Anomalous" @@ -150,13 +167,16 @@ def _update_label(self, filename=None, score=None): label_text = "Nominal" # Get counts for anomalies and nominal samples - normal_count, anomalous_count = self.session.get_label_distribution() + normal_count, anomalous_count = BackendInterface.get_label_distribution() # Calculate newly annotated samples using cached method - new_nominal, new_anomalous = self.session.get_active_learning_counts() + new_nominal, new_anomalous = BackendInterface.get_active_learning_counts() # Format the file name (shortened 
version) - fname = self.session.filenames[self.current_index] + filenames = BackendInterface.get_filenames() + scores = BackendInterface.get_scores() + + fname = filenames[self.current_index] fname_short = os.path.basename(fname) if len(fname_short) > 57: fname_short = ( @@ -166,8 +186,8 @@ def _update_label(self, filename=None, score=None): + "." + fname_short.split(".")[-1] ) - sc = self.session.scores[self.current_index] - total_len = len(self.session.img_catalog) - 1 + sc = scores[self.current_index] + total_len = BackendInterface.get_image_count() - 1 self.filename_text.value = ( f'' @@ -224,13 +244,21 @@ def set_contrast(self, value): self._apply_transforms_and_display() def set_rgb_channels(self, r=None, g=None, b=None): - """Sets RGB channel visibility.""" + """Sets RGB channel visibility for both RGB and N-channel modes.""" if r is not None: self.show_r = r if g is not None: self.show_g = g if b is not None: self.show_b = b + + # Keep channel_visibility in sync for N-channel mode, where + # the first 3 entries control the RGB display channels. 
+ if self.num_channels > 3: + for i, visible in enumerate([self.show_r, self.show_g, self.show_b]): + if i < len(self.channel_visibility): + self.channel_visibility[i] = visible + self._apply_transforms_and_display() def toggle_full_resolution(self): @@ -267,5 +295,6 @@ def _apply_transforms_and_display(self): show_r=self.show_r, show_g=self.show_g, show_b=self.show_b, + channel_visibility=self.channel_visibility if self.num_channels > 3 else None, ) self._display_image(self.modified_image) diff --git a/anomaly_match/ui/ui_elements.py b/anomaly_match_ui/ui_elements.py similarity index 96% rename from anomaly_match/ui/ui_elements.py rename to anomaly_match_ui/ui_elements.py index cce2bfa..cb65ffe 100644 --- a/anomaly_match/ui/ui_elements.py +++ b/anomaly_match_ui/ui_elements.py @@ -5,13 +5,13 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. import ipywidgets as widgets -from ipywidgets import Button, HBox, VBox +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from IPython.display import HTML +from ipywidgets import Button, HBox, VBox from loguru import logger -from anomaly_match import __version__ -from anomaly_match.ui.memory_monitor import MemoryMonitor -from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from anomaly_match_ui.memory_monitor import MemoryMonitor +from anomaly_match_ui.utils.backend_interface import BackendInterface HTML_setup = HTML( """ @@ -90,9 +90,7 @@ def create_ui_elements(): - """ - Creates and returns the necessary UI widgets as a dictionary. 
- """ + """Creates and returns the necessary UI widgets as a dictionary.""" logger.debug("Creating UI elements") # HTML widget @@ -253,7 +251,7 @@ def create_ui_elements(): ), ) - # Create a slider container with both sliders and the RGB controls + # Create a slider container with both sliders slider_row = HBox( [brightness_slider, contrast_slider], layout=widgets.Layout(background_color="black"), @@ -349,7 +347,7 @@ def create_ui_elements(): # Create version and memory monitor display with more explicit styling version_text = widgets.HTML( - value=f'
Version: {__version__}
', + value=f'
Version: {BackendInterface.get_version()}
', layout=widgets.Layout( background_color="black", padding="3px", @@ -425,8 +423,11 @@ def create_ui_elements(): ], layout=widgets.Layout(background_color="black"), ) # Add new row with channels, normalisation, and remove label button + # Include RGB mapping controls if available (for N-channel images) + bottom_row2_items = [channel_controls] + bottom_row2_items.extend([normalisation_dropdown, remove_label_button]) bottom_row2 = HBox( - [channel_controls, normalisation_dropdown, remove_label_button], + bottom_row2_items, layout=widgets.Layout(background_color="black"), ) @@ -521,11 +522,11 @@ def attach_click_listeners(widget): # Decision buttons def mark_anomalous(_): - widget.session.label_image(widget.preview.current_index, "anomaly") + BackendInterface.label_image(widget.preview.current_index, "anomaly") widget.update_image_UI_label() def mark_nominal(_): - widget.session.label_image(widget.preview.current_index, "normal") + BackendInterface.label_image(widget.preview.current_index, "normal") widget.update_image_UI_label() widget.ui["decision_buttons"][0].on_click(mark_anomalous) diff --git a/anomaly_match_ui/utils/__init__.py b/anomaly_match_ui/utils/__init__.py new file mode 100644 index 0000000..fdde815 --- /dev/null +++ b/anomaly_match_ui/utils/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Utilities for the AnomalyMatch UI package.""" + +from anomaly_match_ui.utils.backend_interface import BackendInterface +from anomaly_match_ui.utils.display_transforms import ( + apply_transforms_ui, + display_image_normalisation, + prepare_for_display, +) +from anomaly_match_ui.utils.image_utils import numpy_array_to_byte_stream + +__all__ = [ + "BackendInterface", + "apply_transforms_ui", + "display_image_normalisation", + "prepare_for_display", + "numpy_array_to_byte_stream", +] diff --git a/anomaly_match_ui/utils/backend_interface.py b/anomaly_match_ui/utils/backend_interface.py new file mode 100644 index 0000000..09e2ec8 --- /dev/null +++ b/anomaly_match_ui/utils/backend_interface.py @@ -0,0 +1,355 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +""" +BackendInterface: Single interface for UI-backend communication. + +This module provides a static interface for the UI components to interact with +the backend Session without direct imports, enabling clean separation between +the UI and backend packages. +""" + +from typing import Any, Callable, Dict, Optional, Tuple + +import numpy as np + +import anomaly_match + + +class BackendInterface: + """Single interface for UI-backend communication. + + This static class wraps all Session methods needed by the UI components, + providing a clean interface that decouples the UI from the backend implementation. + """ + + _session = None + + # ========== Session Lifecycle ========== + + @staticmethod + def set_session(session) -> None: + """Set the session instance to use. + + Args: + session: The Session instance to interact with. 
+ """ + BackendInterface._session = session + + @staticmethod + def get_session(): + """Get the current session instance. + + Returns: + The current Session instance, or None if not set. + """ + return BackendInterface._session + + @staticmethod + def _check_session() -> None: + """Check that a session is set, raise RuntimeError if not.""" + if BackendInterface._session is None: + raise RuntimeError("No session set. Call BackendInterface.set_session() first.") + + # ========== Configuration ========== + + @staticmethod + def get_config(): + """Get the session configuration. + + Returns: + The session configuration object (DotMap). + """ + BackendInterface._check_session() + return BackendInterface._session.cfg + + @staticmethod + def get_num_channels() -> int: + """Get the number of channels in the images. + + Returns: + int: Number of image channels. + """ + BackendInterface._check_session() + return getattr(BackendInterface._session.cfg, "num_channels", 3) + + @staticmethod + def set_normalisation_method(method) -> None: + """Set the normalisation method. + + Args: + method (NormalisationMethod): The new normalisation method to apply. + """ + BackendInterface._check_session() + BackendInterface._session.set_normalisation_method(method) + + @staticmethod + def get_cached_normalisation_method(): + """Get the cached normalisation method (what images are currently loaded with). + + Returns: + NormalisationMethod: The cached normalisation method enum value. + """ + BackendInterface._check_session() + return BackendInterface._session.cached_image_normalisation_enum + + # ========== Image Data ========== + + @staticmethod + def get_image_at_index(index: int) -> np.ndarray: + """Get the image at the given index from the catalog. + + Args: + index (int): The index of the image. + + Returns: + np.ndarray: The image array. 
+ """ + BackendInterface._check_session() + return BackendInterface._session.img_catalog[index] + + @staticmethod + def get_image_count() -> int: + """Get the total number of images in the catalog. + + Returns: + int: Number of images. + """ + BackendInterface._check_session() + if BackendInterface._session.img_catalog is None: + return 0 + return len(BackendInterface._session.img_catalog) + + @staticmethod + def get_scores() -> np.ndarray: + """Get the anomaly scores for all images. + + Returns: + np.ndarray: Array of anomaly scores. + """ + BackendInterface._check_session() + return BackendInterface._session.scores + + @staticmethod + def get_filenames() -> np.ndarray: + """Get the filenames for all images. + + Returns: + np.ndarray: Array of filenames. + """ + BackendInterface._check_session() + return BackendInterface._session.filenames + + # ========== Labeling ========== + + @staticmethod + def label_image(index: int, label: str) -> None: + """Label an image at the given index. + + Args: + index (int): Index of the image to label. + label (str): Label to assign ("normal" or "anomaly"). + """ + BackendInterface._check_session() + BackendInterface._session.label_image(index, label) + + @staticmethod + def unlabel_image(index: int) -> None: + """Remove the label from an image at the given index. + + Args: + index (int): Index of the image to unlabel. + """ + BackendInterface._check_session() + BackendInterface._session.unlabel_image(index) + + @staticmethod + def get_label(index: int) -> str: + """Get the label for an image at the given index. + + Args: + index (int): Index of the image. + + Returns: + str: The label ("normal", "anomaly", or "None"). + """ + BackendInterface._check_session() + return BackendInterface._session.get_label(index) + + @staticmethod + def get_label_distribution() -> Tuple[int, int]: + """Get the distribution of labels. 
+ + Returns: + tuple: (normal_count, anomalous_count) + """ + BackendInterface._check_session() + return BackendInterface._session.get_label_distribution() + + @staticmethod + def get_active_learning_counts() -> Tuple[int, int]: + """Get the count of newly annotated samples in active learning. + + Returns: + tuple: (new_normal_count, new_anomalous_count) + """ + BackendInterface._check_session() + return BackendInterface._session.get_active_learning_counts() + + @staticmethod + def save_labels() -> None: + """Save the current labels to a file.""" + BackendInterface._check_session() + BackendInterface._session.save_labels() + + # ========== Sorting ========== + + @staticmethod + def sort_by_anomalous() -> None: + """Sort images by anomalous scores (most anomalous first).""" + BackendInterface._check_session() + BackendInterface._session.sort_by_anomalous() + + @staticmethod + def sort_by_nominal() -> None: + """Sort images by nominal scores (most nominal first).""" + BackendInterface._check_session() + BackendInterface._session.sort_by_nominal() + + @staticmethod + def sort_by_mean() -> None: + """Sort images by distance to mean score.""" + BackendInterface._check_session() + BackendInterface._session.sort_by_mean() + + @staticmethod + def sort_by_median() -> None: + """Sort images by distance to median score.""" + BackendInterface._check_session() + BackendInterface._session.sort_by_median() + + # ========== Model Operations ========== + + @staticmethod + def train(cfg, progress_callback: Optional[Callable] = None) -> None: + """Train the model. + + Args: + cfg: Configuration for training. + progress_callback: Optional callback for progress updates. 
+ """ + BackendInterface._check_session() + BackendInterface._session.train(cfg, progress_callback=progress_callback) + + @staticmethod + def save_model() -> None: + """Save the current model state.""" + BackendInterface._check_session() + BackendInterface._session.save_model() + + @staticmethod + def load_model() -> None: + """Load the model from the saved state.""" + BackendInterface._check_session() + BackendInterface._session.load_model() + + @staticmethod + def reset_model() -> None: + """Reset the model and reinitialize.""" + BackendInterface._check_session() + BackendInterface._session.reset_model() + + @staticmethod + def get_model(): + """Get the current model instance. + + Returns: + The FixMatch model instance. + """ + BackendInterface._check_session() + return BackendInterface._session.model + + # ========== Prediction/Evaluation ========== + + @staticmethod + def update_predictions(progress_callback: Optional[Callable] = None) -> None: + """Update predictions using the current model. + + Args: + progress_callback: Optional callback for progress updates. + """ + BackendInterface._check_session() + BackendInterface._session.update_predictions(progress_callback=progress_callback) + + @staticmethod + def evaluate_all_images(top_n: int, progress_callback: Optional[Callable] = None) -> None: + """Evaluate all images in the prediction search directory. + + Args: + top_n (int): Number of top images to keep. + progress_callback: Optional callback for progress updates. + """ + BackendInterface._check_session() + BackendInterface._session.evaluate_all_images( + top_N=top_n, progress_callback=progress_callback + ) + + @staticmethod + def load_next_batch() -> None: + """Load the next batch of data and update predictions.""" + BackendInterface._check_session() + BackendInterface._session.load_next_batch() + + @staticmethod + def load_top_files(top_n: int) -> None: + """Load the top files from the output directory. 
+ + Args: + top_n (int): Number of top files to load. + """ + BackendInterface._check_session() + BackendInterface._session.load_top_files(top_n) + + @staticmethod + def get_eval_performance() -> Optional[Dict[str, Any]]: + """Get the evaluation performance metrics. + + Returns: + dict: Evaluation performance metrics, or None if not available. + """ + BackendInterface._check_session() + return getattr(BackendInterface._session, "eval_performance", None) + + # ========== Utilities ========== + + @staticmethod + def set_terminal_output(output_widget) -> None: + """Set the terminal output widget for logging. + + Args: + output_widget: The output widget for terminal logging. + """ + BackendInterface._check_session() + BackendInterface._session.set_terminal_out(output_widget) + + @staticmethod + def remember_current_file(filename: str) -> None: + """Remember the current file by appending it to a CSV. + + Args: + filename (str): The filename to remember. + """ + BackendInterface._check_session() + BackendInterface._session.remember_current_file(filename) + + @staticmethod + def get_version() -> str: + """Get the anomaly_match version. + + Returns: + str: The version string. + """ + return anomaly_match.__version__ diff --git a/anomaly_match_ui/utils/display_transforms.py b/anomaly_match_ui/utils/display_transforms.py new file mode 100644 index 0000000..8e75bf7 --- /dev/null +++ b/anomaly_match_ui/utils/display_transforms.py @@ -0,0 +1,178 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+import numpy as np +from PIL import Image, ImageEnhance, ImageFilter, ImageOps +from skimage.util import img_as_ubyte + + +def prepare_for_display(img, rgb_mapping=None): + """Prepare N-channel image for RGB display. + + Converts images with arbitrary channel counts to 3-channel RGB for display. + + Args: + img (np.ndarray): Input image array with shape (H, W, C) + rgb_mapping (list, optional): List of 3 channel indices to use as RGB. + Defaults to [0, 1, 2] (first 3 channels). + + Returns: + np.ndarray: 3-channel RGB image with shape (H, W, 3) and dtype uint8 + """ + # Handle different input formats + if isinstance(img, Image.Image): + img = np.array(img) + + # Ensure we have a valid array + if not isinstance(img, np.ndarray): + raise ValueError(f"Expected numpy array or PIL Image, got {type(img)}") + + # Handle 2D images (grayscale without channel dimension) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + + channels = img.shape[-1] + + # Convert based on channel count + if channels == 1: + # Grayscale: repeat to RGB + result = np.repeat(img, 3, axis=-1) + elif channels == 2: + # 2 channels: average to grayscale, then repeat + gray = img.mean(axis=-1, keepdims=True) + result = np.repeat(gray, 3, axis=-1) + elif channels == 3: + # RGB: use as-is + result = img + else: + # N-channels (4+): extract specified channels for RGB mapping + if rgb_mapping is None: + rgb_mapping = [0, 1, 2] + + # Validate rgb_mapping + if len(rgb_mapping) != 3: + raise ValueError(f"rgb_mapping must have 3 elements, got {len(rgb_mapping)}") + if any(i >= channels for i in rgb_mapping): + raise ValueError(f"rgb_mapping indices {rgb_mapping} exceed channel count {channels}") + + result = img[:, :, rgb_mapping] + + # Ensure uint8 output + if result.dtype != np.uint8: + if result.max() <= 1.0: + result = (result * 255).astype(np.uint8) + else: + result = np.clip(result, 0, 255).astype(np.uint8) + + return result + + +def display_image_normalisation(img, rgb_mapping=None): + 
"""Normalises the image for display. + + Args: + img (np.ndarray): The input image array. + rgb_mapping (list, optional): For N-channel images, which channels to display as RGB. + + Returns: + PIL.Image.Image: The normalised image. + """ + # Handle NaN/inf values + if not np.isfinite(img).all(): + img = np.nan_to_num(img, nan=0.0, posinf=1.0, neginf=0.0) + + img = img - np.min(img) + img_max = np.max(img) + + # Avoid division by zero + if img_max > 0: + img = img / img_max + else: + img = np.zeros_like(img) # Handle constant images + + img = img_as_ubyte(img) + + # Convert to displayable RGB + img = prepare_for_display(img, rgb_mapping=rgb_mapping) + + return Image.fromarray(img) + + +# from utility_functions +def apply_transforms_ui( + img, + invert, + brightness, + contrast, + unsharp_mask_applied, + show_r=True, + show_g=True, + show_b=True, + channel_visibility=None, +): + """ + Applies the requested transformations to the given PIL Image. + + Args: + img (PIL.Image.Image): The original image. + invert (bool): Whether to invert colors. + brightness (float): Brightness factor. + contrast (float): Contrast factor. + unsharp_mask_applied (bool): Whether to apply an unsharp mask. + show_r (bool): Whether to show the red channel (for RGB mode). + show_g (bool): Whether to show the green channel (for RGB mode). + show_b (bool): Whether to show the blue channel (for RGB mode). + channel_visibility (list, optional): For N-channel mode, list of booleans + indicating which channels to show. If provided, overrides show_r/g/b. + + Returns: + PIL.Image.Image: The transformed image. 
+ """ + # Apply inversion + if invert: + img = ImageOps.invert(img) + + # Apply brightness + if brightness != 1.0: + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness) + + # Apply contrast + if contrast != 1.0: + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast) + + # Apply unsharp mask if enabled + if unsharp_mask_applied: + img = img.filter(ImageFilter.UnsharpMask()) + + # Apply channel toggling + # Determine which channels to show + if channel_visibility is not None: + # N-channel mode: use first 3 values for RGB display + channels_mask = ( + channel_visibility[:3] if len(channel_visibility) >= 3 else channel_visibility + ) + # Pad with True if less than 3 values + while len(channels_mask) < 3: + channels_mask.append(True) + else: + # RGB mode: use individual channel flags + channels_mask = [show_r, show_g, show_b] + + if not all(channels_mask): + # Convert PIL image to numpy array + img_array = np.array(img) + + # Apply masking to the image array (zero out disabled channels) + for i, show_channel in enumerate(channels_mask): + if not show_channel and i < img_array.shape[-1]: + img_array[:, :, i] = 0 + + # Convert back to PIL image + img = Image.fromarray(img_array) + + return img diff --git a/anomaly_match/utils/numpy_to_byte_stream.py b/anomaly_match_ui/utils/image_utils.py similarity index 99% rename from anomaly_match/utils/numpy_to_byte_stream.py rename to anomaly_match_ui/utils/image_utils.py index d16f0ef..9cea43e 100644 --- a/anomaly_match/utils/numpy_to_byte_stream.py +++ b/anomaly_match_ui/utils/image_utils.py @@ -4,10 +4,11 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-from PIL import Image -import numpy as np import io +import numpy as np +from PIL import Image + def numpy_array_to_byte_stream(numpy_array, normalize=True): """Convert a numpy array to a byte stream. diff --git a/anomaly_match/ui/Widget.py b/anomaly_match_ui/widget.py similarity index 76% rename from anomaly_match/ui/Widget.py rename to anomaly_match_ui/widget.py index 86b3865..81396cb 100644 --- a/anomaly_match/ui/Widget.py +++ b/anomaly_match_ui/widget.py @@ -4,28 +4,29 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. +import datetime import os -import numpy as np -from ipywidgets import HBox +import time + import matplotlib.pyplot as plt -from PIL import Image +import numpy as np from IPython.display import display +from ipywidgets import HBox from loguru import logger -from sklearn.metrics import roc_curve -import time -import datetime - +from PIL import Image from skimage.util import img_as_ubyte +from sklearn.metrics import roc_curve from anomaly_match.data_io.load_images import load_and_process_single_wrapper -from anomaly_match.ui.preview_widget import PreviewWidget +from anomaly_match_ui.preview_widget import PreviewWidget # Import the newly created UI elements -from anomaly_match.ui.ui_elements import ( - create_ui_elements, - attach_click_listeners, +from anomaly_match_ui.ui_elements import ( HTML_setup, + attach_click_listeners, + create_ui_elements, ) +from anomaly_match_ui.utils.backend_interface import BackendInterface def shorten_filename(filename: str, max_length: int = 25) -> str: @@ -74,19 +75,13 @@ def shorten_filename(filename: str, max_length: int = 25) -> str: class Widget: """A widget-based user interface for interacting with the anomaly detection session.""" - def __init__(self, session): - """ - Initializes the UI widget with the given session. 
- - Args: - session (Session): The session to interact with. - """ - # Store session - self.session = session - self.cfg = session.cfg + def __init__(self): + """Initializes the UI widget using the BackendInterface.""" + # Get config from backend + self.cfg = BackendInterface.get_config() # Create the preview widget (handles image display and transforms) - self.preview = PreviewWidget(session) + self.preview = PreviewWidget() # Create all UI elements (moved from the big monolithic class) self.ui = create_ui_elements() @@ -96,8 +91,8 @@ def __init__(self, session): self.ui["filename_text"] = self.preview.filename_text # Rebuild center_row with preview widget's components - from ipywidgets import VBox import ipywidgets as widgets + from ipywidgets import VBox self.ui["center_row"] = VBox( [self.preview.filename_text, self.preview.image_widget], @@ -119,7 +114,7 @@ def __init__(self, session): self.preview.set_full_res_button(self.ui["transform_buttons"]["full_res"]) # Attach the output widget so the session logs go there - session.set_terminal_out(self.ui["out"]) + BackendInterface.set_terminal_output(self.ui["out"]) # Also attach it to the memory monitor self.ui["memory_monitor"].set_output_widget(self.ui["out"]) @@ -127,7 +122,7 @@ def __init__(self, session): self.ui["batch_size_slider"].value = self.cfg.N_to_load # Set initial normalization dropdown value from session cached value - self.ui["normalisation_dropdown"].value = session.cached_image_normalisation_enum + self.ui["normalisation_dropdown"].value = BackendInterface.get_cached_normalisation_method() # Attach all click listeners / slider observers attach_click_listeners(self) @@ -145,7 +140,8 @@ def __init__(self, session): self.update() # Only attempt sorting if we have scores - if self.session.scores is not None: + scores = BackendInterface.get_scores() + if scores is not None: self.sort_by_anomalous() # Start memory monitoring @@ -193,10 +189,8 @@ def _pack_layout(self, main_layout, side_display): def 
search_all_files(self): """Searches all files and displays the top N with their scores.""" with self.ui["out"]: - self.session.cfg.N_to_load = self.ui["batch_size_slider"].value - logger.debug( - f"Searching all files for anomalies with batch size: {self.session.cfg.N_to_load}" - ) + self.cfg.N_to_load = self.ui["batch_size_slider"].value + logger.debug(f"Searching all files for anomalies with batch size: {self.cfg.N_to_load}") # Set progress bar color to cyan for 'search_all_files' task self.ui["progress_bar"].style = {"bar_color": "cyan"} @@ -213,16 +207,22 @@ def update_progress( completed=False, total_time_str=None, final_speed=None, + results_updated=False, ): + # Update gallery when new results are loaded (after each chunk) + if results_updated: + self.display_top_files_scores() + return + # Always update progress bar if batch info provided if batch is not None and num_batches is not None: self.ui["progress_bar"].value = batch / num_batches # Handle completion message if completed and total_time_str: - self.ui["train_label"].value = ( - f"Search complete in {total_time_str} ({final_speed:.1f} img/sec)" - ) + self.ui[ + "train_label" + ].value = f"Search complete in {total_time_str} ({final_speed:.1f} img/sec)" return # Handle batch update with ETA information @@ -238,41 +238,41 @@ def update_progress( self.ui["train_label"].value = message else: # Early in the process when ETA isn't available yet - self.ui["train_label"].value = ( - f"Evaluating files... Batch: {batch}/{num_batches}" - ) + self.ui[ + "train_label" + ].value = f"Evaluating files... Batch: {batch}/{num_batches}" else: # Regular evaluation updates (not used in this function but keeping for completeness) if eta_str: - self.ui["train_label"].value = ( - f"Evaluating... {batch}/{num_batches} | ETA: {eta_str}" - ) + self.ui[ + "train_label" + ].value = f"Evaluating... {batch}/{num_batches} | ETA: {eta_str}" else: self.ui["train_label"].value = f"Evaluating... 
{batch}/{num_batches}" - self.session.evaluate_all_images( - top_N=self.cfg.top_N, progress_callback=update_progress - ) - # update models last_normalisation_method only after successful eval - if self.session.model.last_normalisation_method is None: - self.session.model.last_normalisation_method = ( - self.session.cfg.normalisation.normalisation_method - ) - elif ( - self.session.model.last_normalisation_method - != self.session.cfg.normalisation.normalisation_method - ): - logger.warning( - f"Evaluated with a new normalisation {self.session.cfg.normalisation.normalisation_method.name} method " - + f"not previously used with the model: {self.session.model.last_normalisation_method.name}" - ) - self.session.model.last_normalisation_method = ( - self.session.cfg.normalisation.normalisation_method + try: + BackendInterface.evaluate_all_images( + top_n=self.cfg.top_N, progress_callback=update_progress ) - - # Display will be updated by the callback when completed - self.display_top_files_scores() - self.ui["progress_bar"].style = {"bar_color": "green"} + # update models last_normalisation_method only after successful eval + model = BackendInterface.get_model() + if model.last_normalisation_method is None: + model.last_normalisation_method = self.cfg.normalisation.normalisation_method + elif model.last_normalisation_method != self.cfg.normalisation.normalisation_method: + logger.warning( + f"Evaluated with a new normalisation {self.cfg.normalisation.normalisation_method.name} method " + + f"not previously used with the model: {model.last_normalisation_method.name}" + ) + model.last_normalisation_method = self.cfg.normalisation.normalisation_method + except Exception as e: + logger.error(f"Error during evaluation: {e}") + finally: + # Always update display with whatever results are available, + # even if evaluation was interrupted or partially failed + scores = BackendInterface.get_scores() + if scores is not None and len(scores) > 0: + self.display_top_files_scores() + 
self.ui["progress_bar"].style = {"bar_color": "green"} def display_top_files_scores(self): """Displays the top files and their scores.""" @@ -281,10 +281,6 @@ def display_top_files_scores(self): self.ui["progress_bar"].style = {"bar_color": "green"} self.display_gallery() - def update_image_display(self): - """Updates the display of the current image.""" - self.preview.update_display() - def update_image_UI_label(self, filename=None, score=None): """Updates the UI label with the current image's filename, score, and label.""" self.preview.update_label_only() @@ -292,32 +288,32 @@ def update_image_UI_label(self, filename=None, score=None): # ======== Sorting Methods ======== def sort_by_anomalous(self): """Sorts the images by their anomalous scores and updates the display.""" - self.session.sort_by_anomalous() + BackendInterface.sort_by_anomalous() self.preview.set_index(0) self.preview.update_display() def sort_by_nominal(self): """Sorts the images by their nominal scores and updates the display.""" - self.session.sort_by_nominal() + BackendInterface.sort_by_nominal() self.preview.set_index(0) self.preview.update_display() def sort_by_mean(self): """Sorts the images by distance to mean score and updates the display.""" - self.session.sort_by_mean() + BackendInterface.sort_by_mean() self.preview.set_index(0) self.preview.update_display() def sort_by_median(self): """Sorts the images by distance to median score and updates the display.""" - self.session.sort_by_median() + BackendInterface.sort_by_median() self.preview.set_index(0) self.preview.update_display() # ======== Navigation ======== def next_image(self): """Displays the next image in the catalog.""" - new_index = min(len(self.session.img_catalog) - 1, self.preview.current_index + 1) + new_index = min(BackendInterface.get_image_count() - 1, self.preview.current_index + 1) self.preview.reset_full_resolution_mode() self.preview.set_index(new_index) self.preview.update_display() @@ -362,6 +358,9 @@ def 
display_gallery(self): self.ui["gallery"].clear_output(wait=True) try: if self.cfg.test_ratio > 0: + eval_performance = BackendInterface.get_eval_performance() + if eval_performance is None: + return # Show mispredicted images mispredicted_images = [] @@ -369,7 +368,7 @@ def display_gallery(self): # First collect all filenames we want to display display_files = [] - for filename, (pred, label) in self.session.eval_performance[ + for filename, (pred, label) in eval_performance[ "eval/predictions_and_labels" ].items(): pred, label = pred.item(), label.item() @@ -386,12 +385,11 @@ def display_gallery(self): try: img_array = load_and_process_single_wrapper( path, - self.session.cfg, + self.cfg, desc="widget loading image", show_progress=False, ) - img = Image.fromarray(img_array) mispredicted_images.append(img_array) display_name = shorten_filename(filename) @@ -406,7 +404,7 @@ def display_gallery(self): num_images = len(mispredicted_images) if num_images > 0: plt.figure(figsize=(12, 4), facecolor="black") - eval_perf = self.session.eval_performance + eval_perf = eval_performance plt.suptitle( f"Top {num_images} Mispredicted Test Images | " f"Acc: {eval_perf['eval/top-1-acc'] * 100:.1f}% | " @@ -434,7 +432,7 @@ def display_gallery(self): ax1 = plt.subplot(1, 2, 1) labels, probs = eval_perf["eval/roc_data"] fpr, tpr, _ = roc_curve(labels, probs) - ax1.plot(fpr, tpr, "b-", label=f'ROC (AUC={eval_perf["eval/auroc"]:.3f})') + ax1.plot(fpr, tpr, "b-", label=f"ROC (AUC={eval_perf['eval/auroc']:.3f})") ax1.plot([0, 1], [0, 1], "r--") ax1.set_title("ROC Curve", color="white") ax1.grid(True, alpha=0.3) @@ -451,7 +449,7 @@ def display_gallery(self): recall, precision, "g-", - label=f'PRC (AUC={eval_perf["eval/auprc"]:.3f})', + label=f"PRC (AUC={eval_perf['eval/auprc']:.3f})", ) ax2.set_title("Precision-Recall Curve", color="white") ax2.grid(True, alpha=0.3) @@ -467,7 +465,10 @@ def display_gallery(self): else: # Show top 5 anomalous & top 5 nominal - scores = self.session.scores 
+ scores = BackendInterface.get_scores() + img_catalog = BackendInterface.get_session().img_catalog + filenames = BackendInterface.get_filenames() + indices = np.argsort(scores) num_images_to_display = min(5, len(scores) // 2) @@ -478,7 +479,7 @@ def display_gallery(self): image_text = [] for idx in top_anomalous_indices: - img_arr = self.session.img_catalog[idx] + img_arr = img_catalog[idx] img_arr = img_arr - np.min(img_arr) img_arr = img_arr / np.max(img_arr) img_arr = img_as_ubyte(img_arr) @@ -486,13 +487,13 @@ def display_gallery(self): img_arr = np.repeat(img_arr, 3, axis=-1) pil_img = Image.fromarray(img_arr) images.append(pil_img) - filename = self.session.filenames[idx] + filename = filenames[idx] display_name = shorten_filename(filename) score = scores[idx] image_text.append(f"{display_name}\nScore: {score:.4f}") for idx in top_nominal_indices: - img_arr = self.session.img_catalog[idx] + img_arr = img_catalog[idx] img_arr = img_arr - np.min(img_arr) img_arr = img_arr / np.max(img_arr) img_arr = img_as_ubyte(img_arr) @@ -500,7 +501,7 @@ def display_gallery(self): img_arr = np.repeat(img_arr, 3, axis=-1) pil_img = Image.fromarray(img_arr) images.append(pil_img) - filename = self.session.filenames[idx] + filename = filenames[idx] display_name = shorten_filename(filename) score = scores[idx] image_text.append(f"{display_name}\nScore: {score:.4f}") @@ -527,33 +528,33 @@ def display_gallery(self): logger.error(f"Error displaying gallery: {e}") def save_labels(self): - self.session.save_labels() + BackendInterface.save_labels() def remember_current_file(self, _): """Remembers the currently displayed file.""" - self.session.remember_current_file(self.session.filenames[self.preview.current_index]) + filenames = BackendInterface.get_filenames() + BackendInterface.remember_current_file(filenames[self.preview.current_index]) def save_model(self): """Saves the model using the session.""" - self.session.save_model() + BackendInterface.save_model() def load_model(self): 
"""Loads the model using the session.""" with self.ui["out"]: logger.debug( - f"Loading model, cfg norm: {self.session.cfg.normalisation.normalisation_method}, " + f"Loading model, cfg norm: {self.cfg.normalisation.normalisation_method}, " ) - self.session.load_model() + BackendInterface.load_model() # Update the normalization dropdown to match the session's method - self.ui["normalisation_dropdown"].value = ( - self.session.cfg.normalisation.normalisation_method - ) + self.ui["normalisation_dropdown"].value = self.cfg.normalisation.normalisation_method with self.ui["out"]: + model = BackendInterface.get_model() logger.debug( - f"Loaded model, cfg norm: {self.session.cfg.normalisation.normalisation_method}," - + f" model norm: {self.session.model.last_normalisation_method}" + f"Loaded model, cfg norm: {self.cfg.normalisation.normalisation_method}," + + f" model norm: {model.last_normalisation_method}" ) self.update() @@ -562,8 +563,8 @@ def train(self): with self.ui["out"]: logger.debug( - f"Session norm: {self.session.cached_image_normalisation_enum}, " - f"selected norm: {self.session.cfg.normalisation.normalisation_method}" + f"Session norm: {BackendInterface.get_cached_normalisation_method()}, " + f"selected norm: {self.cfg.normalisation.normalisation_method}" ) with self.ui["out"]: @@ -606,57 +607,72 @@ def update_training_progress(iteration, total_iterations): eta_str = str(datetime.timedelta(seconds=int(eta_seconds))) # Update display with iteration count and ETA - self.ui["train_label"].value = ( - f"Training... Iteration {iteration}/{total_iterations} | " - f"ETA: {eta_str}" + self.ui[ + "train_label" + ].value = ( + f"Training... Iteration {iteration}/{total_iterations} | ETA: {eta_str}" ) else: # Early iterations - no reliable ETA yet - self.ui["train_label"].value = ( - f"Training... Iteration {iteration}/{total_iterations}" - ) + self.ui[ + "train_label" + ].value = f"Training... 
Iteration {iteration}/{total_iterations}" last_update_time = current_time last_iteration = iteration - self.session.train(self.cfg, progress_callback=update_training_progress) + BackendInterface.train(self.cfg, progress_callback=update_training_progress) # update models last_normalisation_method after successful training - if self.session.model.last_normalisation_method is None: - self.session.model.last_normalisation_method = ( - self.session.cfg.normalisation.normalisation_method - ) - elif ( - self.session.model.last_normalisation_method - != self.session.cfg.normalisation.normalisation_method - ): + model = BackendInterface.get_model() + if model.last_normalisation_method is None: + model.last_normalisation_method = self.cfg.normalisation.normalisation_method + elif model.last_normalisation_method != self.cfg.normalisation.normalisation_method: logger.warning( - f"Trained with a new normalisation {self.session.cfg.normalisation.normalisation_method.name} method " - + f"not previously used with the model: {self.session.model.last_normalisation_method.name}" - ) - self.session.model.last_normalisation_method = ( - self.session.cfg.normalisation.normalisation_method + f"Trained with a new normalisation {self.cfg.normalisation.normalisation_method.name} method " + + f"not previously used with the model: {model.last_normalisation_method.name}" ) + model.last_normalisation_method = self.cfg.normalisation.normalisation_method # Calculate total time taken total_time = time.time() - start_time time_str = str(datetime.timedelta(seconds=int(total_time))) - self.ui["progress_bar"].style = {"bar_color": "green"} - self.ui["train_label"].value = f"Training complete in {time_str}." + self.ui[ + "train_label" + ].value = f"Training complete in {time_str}. Updating predictions..." 
self.update() self.sort_by_anomalous() def update(self): """Updates the UI components and performs evaluation.""" + self.ui["progress_bar"].value = 0.0 self.ui["progress_bar"].style = {"bar_color": "cyan"} - self.session.update_predictions() + self.ui["train_label"].value = "Scoring unlabelled samples..." + + phase = {"name": "scoring"} + + def update_progress(current, total): + self.ui["progress_bar"].value = current / total + if phase["name"] == "scoring": + self.ui["train_label"].value = f"Scoring unlabelled samples... {current}/{total}" + if current >= total: + phase["name"] = "evaluating" + self.ui["train_label"].value = "Evaluating model..." + self.ui["progress_bar"].value = 0.0 + else: + self.ui["train_label"].value = f"Evaluating model... {current}/{total}" + + BackendInterface.update_predictions(progress_callback=update_progress) self.preview.set_index(0) if self.cfg.test_ratio > 0: - if self.session.eval_performance is not None: - self.ui["train_label"].value = ( - f"Training Complete. Eval Acc: {self.session.eval_performance['eval/top-1-acc'] * 100:.2f}%" + eval_performance = BackendInterface.get_eval_performance() + if eval_performance is not None: + self.ui[ + "train_label" + ].value = ( + f"Training Complete. Eval Acc: {eval_performance['eval/top-1-acc'] * 100:.2f}%" ) else: self.ui["train_label"].value = "Training Complete. No evaluation performed yet." @@ -671,7 +687,7 @@ def update(self): def update_batch_size(self, change): """Updates the batch size in the session config.""" - self.session.cfg.N_to_load = change["new"] + self.cfg.N_to_load = change["new"] def next_batch(self): """Loads the next batch and updates predictions.""" @@ -680,7 +696,7 @@ def next_batch(self): self.ui["progress_bar"].style = {"bar_color": "orange"} self.ui["train_label"].value = "Predicting next batch..." 
- self.session.load_next_batch() + BackendInterface.load_next_batch() self.ui["progress_bar"].style = {"bar_color": "green"} self.ui["train_label"].value = "Batch loading complete." @@ -690,14 +706,14 @@ def next_batch(self): def reset_model(self): """Resets the model in the session.""" with self.ui["out"]: - self.session.reset_model() + BackendInterface.reset_model() self.update() self.sort_by_anomalous() def load_top_files(self): """Loads the top files and updates the display.""" with self.ui["out"]: - self.session.load_top_files(self.cfg.top_N) + BackendInterface.load_top_files(self.cfg.top_N) self.display_top_files_scores() # Add channel toggle methods @@ -717,10 +733,10 @@ def select_normalisation(self, change): """Updates the normalization method when dropdown selection changes.""" new_value = change["new"] if new_value != self.cfg.normalisation.normalisation_method: - self.session.set_normalisation_method(new_value) + BackendInterface.set_normalisation_method(new_value) self.preview.update_display() def unlabel_current_image(self): """Removes the label from the currently displayed image.""" - self.session.unlabel_image(self.preview.current_index) + BackendInterface.unlabel_image(self.preview.current_index) self.preview.update_label_only() diff --git a/environment.yml b/environment.yml index f9cacce..94b3ce4 100644 --- a/environment.yml +++ b/environment.yml @@ -12,21 +12,22 @@ channels: dependencies: - astropy - dotmap - - efficientnet-pytorch - h5py - - imageio - ipykernel - ipywidgets - loguru - matplotlib - numpy - pandas<3 + - psutil + - pyarrow - python=3.11 - pytorch>=2.6 - pytorch-cuda=12.4 - pyturbojpeg - - scikit-learn - scikit-image + - scikit-learn + - scipy - toml - torchvision - tqdm @@ -34,8 +35,7 @@ dependencies: - pip - pip: - albumentations - - efficientnet_lite_pytorch - - efficientnet_lite0_pytorch_model - - opencv-python-headless - - fitsbolt>=0.1.6 - cutana>=0.2.1 + - fitsbolt>=0.2 + - opencv-python-headless + - timm diff --git 
a/environment_CI.yml b/environment_CI.yml index edb3021..d815bc1 100644 --- a/environment_CI.yml +++ b/environment_CI.yml @@ -11,7 +11,6 @@ channels: dependencies: - astropy - dotmap - - efficientnet-pytorch - h5py - imageio - ipykernel @@ -23,6 +22,7 @@ dependencies: - python=3.11 - pytorch>=2.6 - pytest + - pytest-cov - pytest-asyncio>=0.23.0 - pyturbojpeg - scikit-learn @@ -35,7 +35,6 @@ dependencies: - pip: - opencv-python-headless - albumentations - - efficientnet_lite_pytorch - - efficientnet_lite0_pytorch_model - - fitsbolt>=0.1.6 + - timm + - fitsbolt>=0.2 - cutana>=0.2.1 diff --git a/paper_scripts/create_results.py b/paper_scripts/create_results.py index 736d0ea..6bf2250 100644 --- a/paper_scripts/create_results.py +++ b/paper_scripts/create_results.py @@ -17,14 +17,15 @@ python create_results.py """ -import sys -import time +import argparse import datetime +import glob import subprocess +import sys +import time from pathlib import Path -import argparse + import pandas as pd -import glob # ========== CONFIGURATION ========== # Toggle which experiment sets to run diff --git a/paper_scripts/dataset_plot.py b/paper_scripts/dataset_plot.py index dac6293..2a18cee 100644 --- a/paper_scripts/dataset_plot.py +++ b/paper_scripts/dataset_plot.py @@ -12,12 +12,13 @@ """ import os +from pathlib import Path + +import matplotlib.gridspec as gridspec +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec from matplotlib.patches import Rectangle -from pathlib import Path from PIL import Image # Figure settings for paper-quality output diff --git a/paper_scripts/galaxyzoo_multi_seed.py b/paper_scripts/galaxyzoo_multi_seed.py index bd9b0cc..3e60012 100644 --- a/paper_scripts/galaxyzoo_multi_seed.py +++ b/paper_scripts/galaxyzoo_multi_seed.py @@ -46,31 +46,32 @@ └── average_astronomaly_comparison_zoo_class1_n400_ratio0.015_iter{DEFAULT_TRAINING_RUNS}.pdf """ -import sys -import time 
+import argparse import datetime -import subprocess import pickle +import subprocess +import sys +import time from pathlib import Path -import argparse -import pandas as pd -import numpy as np +from typing import Dict, List + import matplotlib.pyplot as plt -from typing import List, Dict +import numpy as np +import pandas as pd # Import plotting constants from paper_plots sys.path.append(str(Path(__file__).parent)) try: from paper_plots import ( BLUE, + DEFAULT_DPI, GREEN, ORANGE, + PERFECT_LINE_STYLE, PURPLE, RED, REFERENCE_LINE_COLOR, REFERENCE_LINE_STYLE, - PERFECT_LINE_STYLE, - DEFAULT_DPI, ) except ImportError: # Fallback colors if import fails @@ -227,8 +228,8 @@ def collect_results_from_seeds(output_dir: Path, seeds: List[int]) -> Dict[str, config_key = f"zoo_class{cls}_n{n_samples}_ratio{anomaly_ratio:.3f}" config_results = [] sub_config_key = ( - f"galaxyzoo_anomaly{cls}_n{n_samples-int(n_samples*anomaly_ratio)}" - + f"_a{int(n_samples*anomaly_ratio)}" + f"galaxyzoo_anomaly{cls}_n{n_samples - int(n_samples * anomaly_ratio)}" + + f"_a{int(n_samples * anomaly_ratio)}" ) for seed in seeds: seed_dir = output_dir / f"seed_{seed}" / "galaxyzoo" / config_key / sub_config_key @@ -331,8 +332,8 @@ def load_plot_data_from_seeds( for n_samples, anomaly_ratio in GALAXYZOO_CONFIGS: config_key = f"zoo_class{cls}_n{n_samples}_ratio{anomaly_ratio:.3f}" sub_config_key = ( - f"galaxyzoo_anomaly{cls}_n{n_samples-int(n_samples*anomaly_ratio)}" - + f"_a{int(n_samples*anomaly_ratio)}" + f"galaxyzoo_anomaly{cls}_n{n_samples - int(n_samples * anomaly_ratio)}" + + f"_a{int(n_samples * anomaly_ratio)}" ) plot_data_list = [] @@ -594,9 +595,9 @@ def main(): # Run experiments for each seed for i, seed in enumerate(args.seeds): - print(f"\n{'='*60}") - print(f"Running experiments for seed {seed} ({i+1}/{len(args.seeds)})") - print(f"{'='*60}") + print(f"\n{'=' * 60}") + print(f"Running experiments for seed {seed} ({i + 1}/{len(args.seeds)})") + print(f"{'=' * 60}") try: seed_output_dir 
= run_single_seed_experiment( @@ -608,9 +609,9 @@ def main(): continue # Collect and analyze results - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("Collecting and analyzing results from all seeds") - print(f"{'='*60}") + print(f"{'=' * 60}") # Create summary directory summary_dir = output_dir / "summary" @@ -636,10 +637,10 @@ def main(): # Final summary total_time = time.time() - start_time - print(f"\n{'='*60}") + print(f"\n{'=' * 60}") print("EXPERIMENT SUMMARY") - print(f"{'='*60}") - print(f"Total execution time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)") + print(f"{'=' * 60}") + print(f"Total execution time: {total_time:.2f} seconds ({total_time / 60:.2f} minutes)") print(f"Seeds processed: {len(args.seeds)}") print(f"Configurations per seed: {len(GALAXYZOO_CONFIGS)}") print(f"Results saved to: {output_dir}") diff --git a/paper_scripts/get_example_images.py b/paper_scripts/get_example_images.py index 6938efe..9ec453d 100644 --- a/paper_scripts/get_example_images.py +++ b/paper_scripts/get_example_images.py @@ -14,24 +14,23 @@ import os import sys +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt +import torchvision.transforms as transforms from matplotlib.patches import Rectangle -from pathlib import Path from PIL import Image -import torchvision.transforms as transforms - sys.path.append("/media/home/AnomalyMatch") sys.path.append("../") from anomaly_match.image_processing.transforms import ( - get_weak_transforms, get_strong_transforms, + get_weak_transforms, ) - # Constants HOURGLASS_CLASS_IDX = 57 # Class ID for hourglass images (anomaly) NUM_EXAMPLES = 10 # Number of examples to generate for each class diff --git a/paper_scripts/get_hst_example_images.py b/paper_scripts/get_hst_example_images.py index f0c7055..6ee1c5a 100644 --- a/paper_scripts/get_hst_example_images.py +++ b/paper_scripts/get_hst_example_images.py @@ -13,14 +13,15 @@ """ import os 
+import random import sys -import numpy as np +from pathlib import Path + import matplotlib.pyplot as plt +import numpy as np +import torchvision.transforms as transforms from matplotlib.patches import Rectangle -from pathlib import Path from PIL import Image -import torchvision.transforms as transforms -import random sys.path.append("/media/home/AnomalyMatch") sys.path.append("../") diff --git a/paper_scripts/paper_benchmark.py b/paper_scripts/paper_benchmark.py index 318e60b..bb0b18a 100644 --- a/paper_scripts/paper_benchmark.py +++ b/paper_scripts/paper_benchmark.py @@ -235,7 +235,7 @@ def get_prediction_scores(session, labeled_filenames, hdf5_path, progress_bar=No if batch_idx % 5 == 0 or batch_idx == num_batches - 1: logger.info( - f"Batch {batch_idx+1}/{num_batches}: Processed {batch_size_actual} images " + f"Batch {batch_idx + 1}/{num_batches}: Processed {batch_size_actual} images " f"in {batch_time:.2f}s ({images_per_sec:.1f} img/s)" ) @@ -267,7 +267,7 @@ def get_prediction_scores(session, labeled_filenames, hdf5_path, progress_bar=No avg_time_per_image = total_time / processed_images if processed_images > 0 else 0 logger.info( f"Processed {processed_images} images in {total_time:.2f}s " - f"({processed_images/total_time:.1f} img/s, {avg_time_per_image*1000:.2f}ms/img)" + f"({processed_images / total_time:.1f} img/s, {avg_time_per_image * 1000:.2f}ms/img)" ) # Handle mismatched length between scores and filenames @@ -504,11 +504,11 @@ def run_benchmark(args): # Run training and evaluation loop for iteration in range(args.training_runs): # Create iteration-specific directory - iter_dir = os.path.join(run_dir, f"iteration_{iteration+1}") + iter_dir = os.path.join(run_dir, f"iteration_{iteration + 1}") iter_plots_dir = os.path.join(iter_dir, "plots") logger.info( - f"\n======= Starting training iteration {iteration+1}/{args.training_runs} =======" + f"\n======= Starting training iteration {iteration + 1}/{args.training_runs} =======" ) # Train model @@ -516,7 
+516,7 @@ def run_benchmark(args): train_with_progress_bar(session, cfg) # Save model after training - model_save_path = os.path.join(model_dir, f"model_iter{iteration+1}.pth") + model_save_path = os.path.join(model_dir, f"model_iter{iteration + 1}.pth") iter_model_path = os.path.join(iter_dir, f"model.pth") # session.session_tracker = None session.cfg.model_path = model_save_path @@ -562,12 +562,12 @@ def run_benchmark(args): # Log top percentile metrics if "top_0.1pct_anomalies_found" in metrics: logger.info( - f"Iter {iteration+1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, " + f"Iter {iteration + 1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, " f"Precision: {metrics['top_0.1pct_precision']:.2f}%" ) if "top_1.0pct_anomalies_found" in metrics: logger.info( - f"Iter {iteration+1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, " + f"Iter {iteration + 1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, " f"Precision: {metrics['top_1.0pct_precision']:.2f}%" ) @@ -868,9 +868,9 @@ def run_multi_class_benchmark(args): ) for anomaly_class in args.anomaly_classes: - logger.info(f"\n\n{'='*50}") + logger.info(f"\n\n{'=' * 50}") logger.info(f"Starting benchmark for anomaly class {anomaly_class}") - logger.info(f"{'='*50}\n") + logger.info(f"{'=' * 50}\n") # Update args for this specific class args.anomaly_class = anomaly_class @@ -1045,11 +1045,11 @@ def run_multi_class_benchmark(args): # Run training and evaluation loop for iteration in range(args.training_runs): # Create iteration-specific directory - iter_dir = os.path.join(run_dir, f"iteration_{iteration+1}") + iter_dir = os.path.join(run_dir, f"iteration_{iteration + 1}") iter_plots_dir = os.path.join(iter_dir, "plots") logger.info( - f"\n======= Starting training iteration {iteration+1}/{args.training_runs} for anomaly class {anomaly_class} =======" + f"\n======= Starting training iteration {iteration + 
1}/{args.training_runs} for anomaly class {anomaly_class} =======" ) # Train model with progress bar @@ -1057,7 +1057,7 @@ def run_multi_class_benchmark(args): train_with_progress_bar(session, cfg) # Save model after training - model_save_path = os.path.join(model_dir, f"model_iter{iteration+1}.pth") + model_save_path = os.path.join(model_dir, f"model_iter{iteration + 1}.pth") iter_model_path = os.path.join(iter_dir, f"model.pth") # session.session_tracker = None session.cfg.model_path = model_save_path @@ -1101,12 +1101,12 @@ def run_multi_class_benchmark(args): # Log top percentile metrics if "top_0.1pct_anomalies_found" in metrics: logger.info( - f"Iter {iteration+1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, " + f"Iter {iteration + 1} - Anomalies in top 0.1%: {metrics['top_0.1pct_anomalies_found']:.2f}%, " f"Precision: {metrics['top_0.1pct_precision']:.2f}%" ) if "top_1.0pct_anomalies_found" in metrics: logger.info( - f"Iter {iteration+1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, " + f"Iter {iteration + 1} - Anomalies in top 1.0%: {metrics['top_1.0pct_anomalies_found']:.2f}%, " f"Precision: {metrics['top_1.0pct_precision']:.2f}%" ) diff --git a/paper_scripts/paper_plots.py b/paper_scripts/paper_plots.py index 305c49b..1c8f756 100644 --- a/paper_scripts/paper_plots.py +++ b/paper_scripts/paper_plots.py @@ -12,38 +12,40 @@ """ import os + +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt import seaborn as sns -from sklearn.metrics import roc_curve from loguru import logger from paper_utils import save_plot_data +from sklearn.metrics import roc_curve + +from paper_scripts.create_results import GALAXYZOO_THRESHOLDS from paper_scripts.plot_colors import ( + ANOMALY_COLOR, BLUE, - RED, + COLORMAP_NAME, GREEN, + HIST_ALPHA, + HLINE_ALPHA, + HLINE_COLOR, + HLINE_STYLE, + LAST_ITER_COLOR, + NORMAL_COLOR, ORANGE, - PURPLE, + PERFECT_LINE_ALPHA, 
PERFECT_LINE_COLOR, PERFECT_LINE_STYLE, - PERFECT_LINE_ALPHA, + PURPLE, + RED, + REFERENCE_LINE_ALPHA, REFERENCE_LINE_COLOR, REFERENCE_LINE_STYLE, - REFERENCE_LINE_ALPHA, + VLINE_ALPHA, VLINE_COLOR, VLINE_STYLE, - VLINE_ALPHA, - HLINE_COLOR, - HLINE_STYLE, - HLINE_ALPHA, - COLORMAP_NAME, - LAST_ITER_COLOR, - NORMAL_COLOR, - ANOMALY_COLOR, - HIST_ALPHA, ) -from paper_scripts.create_results import GALAXYZOO_THRESHOLDS # Scaling factor for all font sizes (adjust this to make all text larger or smaller) FONT_SCALE = 1.75 @@ -199,7 +201,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir): y_scores = np.concatenate([metrics["anomaly_scores"], metrics["normal_scores"]]) fpr, tpr, _ = roc_curve(y_true, y_scores) # 1. ROC Curve plt.figure(figsize=(8, 8)) - plt.plot(fpr, tpr, color=BLUE, linewidth=2, label=f'AUROC = {metrics["auroc"]:.3f}') + plt.plot(fpr, tpr, color=BLUE, linewidth=2, label=f"AUROC = {metrics['auroc']:.3f}") plt.plot( [0, 1], [0, 1], @@ -229,7 +231,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir): metrics["precision"], color=RED, linewidth=2, - label=f'AUPRC = {metrics["auprc"]:.3f}', + label=f"AUPRC = {metrics['auprc']:.3f}", ) # Note: We're removing the baseline from the PR curve as requested @@ -251,7 +253,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir): # 3. 
Combined figure (side by side) fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5)) # Plot ROC curve - ax1.plot(fpr, tpr, color=BLUE, linewidth=2, label=f'AUROC = {metrics["auroc"]:.3f}') + ax1.plot(fpr, tpr, color=BLUE, linewidth=2, label=f"AUROC = {metrics['auroc']:.3f}") ax1.plot( [0, 1], [0, 1], @@ -274,7 +276,7 @@ def plot_roc_prc_curves(metrics, iteration, plots_dir): metrics["precision"], color=RED, linewidth=2, - label=f'AUPRC = {metrics["auprc"]:.3f}', + label=f"AUPRC = {metrics['auprc']:.3f}", ) ax2.set_xlabel("Recall") ax2.set_ylabel("Precision") @@ -1418,9 +1420,9 @@ def create_grid_plot(merged_df, n_grid, data_dir, plots_dir, iteration, suffix, suffix: Suffix for the output filename fig_title: Optional title for the figure """ - from PIL import Image import matplotlib.gridspec as gridspec import matplotlib.patches as patches + from PIL import Image # Number of user score bins (fewer than ML score bins) n_user_grid = 6 @@ -1863,7 +1865,7 @@ def create_grid_plot(merged_df, n_grid, data_dir, plots_dir, iteration, suffix, # Add a legend for the visual indicators in the bottom-left corner ax_legend = plt.subplot(gs[n_user_grid + 1, 0]) - legend_text = "+/-n: AM ranks\n" "higher/lower\n" "score than GZ\n" + legend_text = "+/-n: AM ranks\nhigher/lower\nscore than GZ\n" ax_legend.text( 0.35, 0.5, diff --git a/paper_scripts/paper_utils.py b/paper_scripts/paper_utils.py index 9aabffc..de4ecba 100644 --- a/paper_scripts/paper_utils.py +++ b/paper_scripts/paper_utils.py @@ -6,16 +6,15 @@ # the terms contained in the file 'LICENCE.txt'. 
import argparse import os - +import sys from pathlib import Path + +import h5py import ipywidgets as widgets import numpy as np import pandas as pd -import h5py from loguru import logger -from sklearn.metrics import roc_auc_score, precision_recall_curve, auc - -import sys +from sklearn.metrics import auc, precision_recall_curve, roc_auc_score sys.path.append("/media/home/AnomalyMatch") sys.path.append("../") @@ -478,8 +477,9 @@ def train_with_progress_bar(session, cfg): session (am.Session): The AnomalyMatch session cfg (DotMap): Configuration for training """ - import time import datetime + import time + from tqdm import tqdm # Create a tqdm progress bar @@ -566,8 +566,8 @@ def save_plot_data(data, plot_type, iteration, output_dir): iteration: Iteration number (0 for baseline) output_dir: Directory to save data to """ - import pickle import os + import pickle # Create plot_data directory if it doesn't exist plot_data_dir = os.path.join(output_dir, "plot_data") diff --git a/paper_scripts/prepare_datasets.py b/paper_scripts/prepare_datasets.py index 8e23371..0a6e7a8 100644 --- a/paper_scripts/prepare_datasets.py +++ b/paper_scripts/prepare_datasets.py @@ -22,19 +22,20 @@ - GalaxyZoo images and training_solutions_am_2.csv in datasets/ folder """ -import os -import io import argparse +import concurrent.futures +import gc # Add garbage collection +import io +import os + +import h5py import numpy as np import pandas as pd -import h5py -from PIL import Image +import pyarrow.parquet as pq import torch -from tqdm import tqdm from loguru import logger -import concurrent.futures -import pyarrow.parquet as pq -import gc # Add garbage collection +from PIL import Image +from tqdm import tqdm # Configure basic logging logger.remove() diff --git a/paper_scripts/recreate_plots.py b/paper_scripts/recreate_plots.py index 1aae0c4..80f812c 100644 --- a/paper_scripts/recreate_plots.py +++ b/paper_scripts/recreate_plots.py @@ -18,42 +18,43 @@ Date: April 14, 2025 """ -import os -import 
sys import argparse import glob +import os import pickle +import sys +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt import seaborn as sns -from pathlib import Path from loguru import logger # Add parent directory to path to import paper_plots sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from paper_scripts.paper_plots import ( - plot_score_histogram, - plot_roc_prc_curves, - plot_top_n_anomaly_detection, - plot_metrics_over_time, + FONT_SCALE, + plot_astronomaly_comparison, plot_combined_anomaly_detection, - plot_top_n_with_thresholds, + plot_metrics_over_time, + plot_roc_prc_curves, plot_roc_with_thresholds, - plot_astronomaly_comparison, + plot_score_histogram, plot_score_vs_user_score_grid, - FONT_SCALE, + plot_top_n_anomaly_detection, + plot_top_n_with_thresholds, ) from paper_scripts.plot_colors import ( BLUE, ORANGE, - PURPLE, + PERFECT_LINE_ALPHA, PERFECT_LINE_COLOR, PERFECT_LINE_STYLE, - PERFECT_LINE_ALPHA, + PURPLE, + VLINE_ALPHA, VLINE_COLOR, VLINE_STYLE, - VLINE_ALPHA, ) # Set matplotlib parameters for publication-quality plots (similar to paper_plots.py) diff --git a/paper_scripts/results_analysis.py b/paper_scripts/results_analysis.py index b709a4f..113a6f5 100644 --- a/paper_scripts/results_analysis.py +++ b/paper_scripts/results_analysis.py @@ -4,9 +4,9 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import pandas as pd from pathlib import Path +import pandas as pd MINIIMAGENET_CLASS_NAMES = {48: "Guitar", 57: "Hourglass", 68: "Printer", 85: "Piano", 95: "Orange"} diff --git a/paper_scripts/test_plots.py b/paper_scripts/test_plots.py index 47e57a9..b22f3ee 100644 --- a/paper_scripts/test_plots.py +++ b/paper_scripts/test_plots.py @@ -13,22 +13,23 @@ python test_plots.py --reload # Test loading and recreating plots from saved data """ -import os import argparse +import os + import numpy as np import pandas as pd -from PIL import Image from paper_plots import ( - plot_score_histogram, + plot_astronomaly_comparison, + plot_combined_anomaly_detection, plot_metrics_over_time, plot_roc_prc_curves, - plot_top_n_anomaly_detection, - plot_combined_anomaly_detection, - plot_astronomaly_comparison, plot_roc_with_thresholds, - plot_top_n_with_thresholds, + plot_score_histogram, plot_score_vs_user_score_grid, + plot_top_n_anomaly_detection, + plot_top_n_with_thresholds, ) +from PIL import Image # Create output directory for test plots output_dir = "test_plots_output" @@ -304,9 +305,9 @@ def test_all_plots(): def test_reload_plots(plot_data_dir): """Test loading saved plot data and recreating plots from it.""" - import pickle - import os import glob + import os + import pickle # Directory for reloaded plots reload_dir = os.path.join(output_dir, "reloaded_plots") diff --git a/prediction_process.py b/prediction_process.py index 5efb448..28417c8 100644 --- a/prediction_process.py +++ b/prediction_process.py @@ -5,32 +5,27 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
import argparse -import os -import pickle -import sys +import time +from concurrent.futures import ThreadPoolExecutor -from dotmap import DotMap -import torch import numpy as np +import torch from loguru import logger -from concurrent.futures import ThreadPoolExecutor from tqdm import tqdm -import time from anomaly_match.data_io.load_images import ( load_and_process_single_wrapper, ) - +from anomaly_match.image_processing.transforms import ( + get_prediction_transforms, +) from prediction_utils import ( + clear_gpu_cache_if_needed, load_model, - save_results, + load_prediction_config, process_batch_predictions, - estimate_batch_size, - clear_gpu_cache_if_needed, -) - -from anomaly_match.image_processing.transforms import ( - get_prediction_transforms, + save_results, + setup_prediction_logging, ) @@ -55,6 +50,12 @@ def evaluate_files(file_list, cfg, top_n=1000, batch_size=1000, max_workers=1): """Evaluate files in batches and return top N scores. file list is a list of cfg.prediction_search_dir+filename """ + if not file_list: + raise FileNotFoundError( + f"No files to evaluate. The prediction search directory " + f"'{cfg.prediction_search_dir}' is empty or contains no supported image files." 
+ ) + logger.trace(f"{len(file_list)} unlabeled images remain.") # Load model first - this loads the fitsbolt config from the checkpoint @@ -130,21 +131,7 @@ def main(): parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep") args = parser.parse_args() - logger.info(f"Loading config from {args.config_path}") - # Load cfg from pkl - try: - with open(args.config_path, "rb") as f: - cfg = pickle.load(f) - cfg = DotMap(cfg) - except Exception as e: - logger.error(f"Failed to load config from {args.config_path}: {e}") - sys.exit(1) - - logger.info("Setting batch size") - batch_size = ( - estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction - ) - logger.info(f"Batch size set to: {batch_size}") + cfg, batch_size = load_prediction_config(args.config_path) logger.info(f"Loading file list from {args.file_list_path}") with open(args.file_list_path, "r") as f: @@ -167,14 +154,5 @@ def main(): if __name__ == "__main__": - # Configure logging - logs_dir = os.path.join(os.path.dirname(__file__), "logs") - os.makedirs(logs_dir, exist_ok=True) - logger.remove() - logger.add( - os.path.join(logs_dir, "prediction_thread_{time}.log"), - rotation="1 MB", - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", - level="DEBUG", - ) + setup_prediction_logging("prediction_thread") main() diff --git a/prediction_process_cutana.py b/prediction_process_cutana.py index 7a661e2..b248721 100644 --- a/prediction_process_cutana.py +++ b/prediction_process_cutana.py @@ -5,70 +5,84 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
import argparse +import glob import os -import sys -import pickle +import time -from dotmap import DotMap -import torch +import cutana import numpy as np +import pandas as pd +import torch +from cutana.catalogue_preprocessor import extract_filter_name, parse_fits_file_paths +from dotmap import DotMap from loguru import logger -from concurrent.futures import ThreadPoolExecutor -import time from tqdm import tqdm -import cutana - -from anomaly_match.data_io.load_images import process_single_wrapper +from anomaly_match.image_processing.transforms import ( + get_prediction_transforms, +) from prediction_utils import ( + clear_gpu_cache_if_needed, + convert_cutana_cutout, + create_cutana_format_cfg, load_model, - save_results, + load_prediction_config, process_batch_predictions, - estimate_batch_size, - clear_gpu_cache_if_needed, -) - -from anomaly_match.image_processing.transforms import ( - get_prediction_transforms, + save_results, + setup_prediction_logging, ) -def read_and_preprocess_image_from_zarr(image_data, cfg): - """Read and preprocess image data from Zarr array using standardized functions.""" - try: - # Convert Zarr data to numpy array if it's not already - if not isinstance(image_data, np.ndarray): - image_data = np.array(image_data) - - # Check if we need to transpose based on the shape - # If last dimension is 3 (RGB channels), data is already in HWC format - # If first dimension is 3, data is in CHW format and needs transposing - if image_data.shape[0] == cfg.normalisation.n_output_channels: - # In CHW format, convert to HWC - image = image_data.transpose(1, 2, 0) - else: - # Assume HWC format if neither first nor last dimension is 3 - # This handles grayscale or other formats - image = image_data - - # Use the centralized processing function - this handles RGB conversion, - # normalization, and resizing efficiently without temporary files - processed_image = process_single_wrapper(image, cfg, desc="zarr") - return processed_image +def 
_resolve_filter_names_from_catalogue(catalogue_path, n_extensions): + """Resolve filter/band names from the first source in a cutana catalogue. - except Exception as e: - logger.error(f"Error processing image from Zarr: {e}") - raise + Cutana streaming predictions operate on large mosaic tiles (separate FITS + files per band) referenced by the catalogue's ``fits_file_paths`` column. + Filter names are extracted from file paths using Euclid naming conventions. + Raises: + ValueError: If filter names cannot be determined — this indicates the + catalogue uses non-Euclid file naming. In that case users must set + ``cfg.normalisation.fits_extension`` to explicit filter name strings. + """ + # catalogue_path may be a single file (buffer parquet) or a directory + if os.path.isfile(catalogue_path): + first_file = catalogue_path + else: + cat_files = sorted(glob.glob(os.path.join(catalogue_path, "*.parquet"))) + if not cat_files: + cat_files = sorted(glob.glob(os.path.join(catalogue_path, "*.csv"))) + if not cat_files: + raise FileNotFoundError(f"No catalogue files found in {catalogue_path}") + first_file = cat_files[0] + + # Read just the first row to get fits_file_paths + if first_file.endswith(".parquet"): + df = pd.read_parquet(first_file, columns=["fits_file_paths"]).head(1) + else: + df = pd.read_csv(first_file, usecols=["fits_file_paths"], nrows=1) + + fits_paths = parse_fits_file_paths(df["fits_file_paths"].iloc[0]) + if len(fits_paths) != n_extensions: + raise ValueError( + f"Catalogue has {len(fits_paths)} FITS files per source but " + f"cfg.normalisation.fits_extension specifies {n_extensions} extensions." + ) -def load_and_preprocess_zarr(args): - """Load and preprocess a single image from Zarr. + filter_names = [extract_filter_name(p) for p in fits_paths] + unknown = [p for p, name in zip(fits_paths, filter_names) if name == "UNKNOWN"] + if unknown: + raise ValueError( + f"Could not determine filter names from catalogue FITS file paths. 
" + f"Cutana streaming predictions currently only support Euclid data with " + f"standard file naming conventions (VIS, NIR-H, NIR-Y, NIR-J). " + f"Unrecognised files: {unknown}. " + f"If using non-Euclid data, set cfg.normalisation.fits_extension to " + f"explicit filter name strings instead of integer indices." + ) - Note: Returns numpy array, not tensor. Tensor conversion is done on main - thread to avoid CUDA context issues in ThreadPoolExecutor. - """ - image_data, cfg = args - return read_and_preprocess_image_from_zarr(image_data, cfg) + logger.info(f"Resolved filter names from catalogue: {filter_names}") + return filter_names def evaluate_images_from_cutana( @@ -81,37 +95,104 @@ def evaluate_images_from_cutana( cutana_config.target_resolution = cfg.normalisation.image_size[0] cutana_config.source_catalogue = cutana_sources_path - # Configure FITS extensions from AM config, default to PRIMARY if not specified - # fits_extension can be: None, str/int, list of str/int, or list of tuples (name, ext_type) + # Configure FITS extensions for cutana. + # + # AnomalyMatch's fits_extension uses integer HDU indices (for multi-extension + # cutout files loaded by fitsbolt). Cutana operates on large mosaic tiles + # referenced in the catalogue — each source has separate FITS files per band. + # Cutana identifies bands by filter name (e.g. "VIS", "NIR-H") extracted from + # the file paths, so we must resolve integer indices to filter names here. + # + # NOTE: filter name extraction currently relies on Euclid naming conventions + # (via cutana.catalogue_preprocessor.extract_filter_name). If your catalogue + # uses non-Euclid file naming, set cfg.normalisation.fits_extension to + # explicit filter name strings instead of integer indices. 
fits_ext = cfg.normalisation.fits_extension if fits_ext is None: fits_ext = ["PRIMARY"] elif isinstance(fits_ext, (str, int)): fits_ext = [fits_ext] - # Build selected_extensions - handle both simple names and (name, ext_type) tuples - selected_extensions = [] - extension_names = [] - for ext in fits_ext: - if isinstance(ext, tuple): - name, ext_type = ext - selected_extensions.append({"name": str(name), "ext": ext_type}) - extension_names.append(name) + # When fits_extension contains integers, resolve to filter names from the + # catalogue's fits_file_paths column. + has_integer_indices = any(isinstance(e, int) for e in fits_ext) + if has_integer_indices: + if len(fits_ext) > 1: + extension_names = _resolve_filter_names_from_catalogue( + cutana_sources_path, len(fits_ext) + ) else: - selected_extensions.append({"name": str(ext), "ext": "PrimaryHDU"}) - extension_names.append(ext) + # Single integer index (e.g. [0]) maps to the PRIMARY HDU + extension_names = ["PRIMARY"] + else: + extension_names = [str(e) for e in fits_ext] + + # Build selected_extensions for cutana. + # For multi-file catalogues (separate FITS per band), each file has only a + # PRIMARY HDU, so fits_extensions must be ["PRIMARY"]. The filter names go + # into channel_weights and selected_extensions for channel identification. + if has_integer_indices: + cutana_config.fits_extensions = ["PRIMARY"] + else: + cutana_config.fits_extensions = extension_names - cutana_config.fits_extensions = extension_names + selected_extensions = [] + for name in extension_names: + selected_extensions.append({"name": name, "ext": "PRIMARY"}) cutana_config.selected_extensions = selected_extensions - # Pass channel combination - required for multi-extension data + # Build channel_weights dict for cutana from AM's channel configuration. + # Cutana expects {"ext_name": [weight_per_output_channel, ...], ...}. 
+ # Channel combination must happen BEFORE normalisation (cutana's pipeline + # ensures this) so that ZSCALE/ASINH see the same data shape as training. + n_out = cfg.normalisation.n_output_channels if cfg.normalisation.channel_combination is not None: - cutana_config.channel_weights = cfg.normalisation.channel_combination + # Multi-extension: convert numpy matrix (n_out x n_in) to cutana dict + combo = cfg.normalisation.channel_combination + channel_weights = {} + for j, ext_name in enumerate(extension_names): + channel_weights[str(ext_name)] = combo[:, j].tolist() + cutana_config.channel_weights = channel_weights elif len(fits_ext) > 1: raise ValueError( "cfg.normalisation.channel_combination must be set when using multiple FITS extensions. " "This defines how extensions are combined into RGB channels." ) + else: + # Single extension: replicate to n_output_channels (e.g. 1→3 for RGB) + cutana_config.channel_weights = {str(extension_names[0]): [1.0] * n_out} + + # Verify channel configuration consistency + if len(extension_names) > 1: + combo = cfg.normalisation.channel_combination + n_in = combo.shape[1] if hasattr(combo, "shape") else len(extension_names) + if len(extension_names) != n_in: + raise ValueError( + f"Number of resolved filter names ({len(extension_names)}) does not match " + f"channel_combination input dimension ({n_in}). 
" + f"Filter names: {extension_names}, matrix shape: {combo.shape}" + ) + if combo.shape[0] != n_out: + raise ValueError( + f"channel_combination output dimension ({combo.shape[0]}) does not match " + f"n_output_channels ({n_out})" + ) + # For non-diagonal matrices, verify all input channels contribute + # (a zero column means an extension is loaded but never used) + for j, ext_name in enumerate(extension_names): + col_sum = abs(combo[:, j]).sum() + if col_sum == 0: + logger.warning( + f"Extension '{ext_name}' (column {j}) has zero weight in " + f"channel_combination — this channel will be loaded but ignored" + ) + logger.info( + f"Channel configuration: {len(extension_names)} inputs -> {n_out} outputs, " + f"filter order: {extension_names}" + ) + + # Flux conversion: must match the training path setting + cutana_config.apply_flux_conversion = cfg.normalisation.apply_flux_conversion # Pass AnomalyMatch's fitsbolt_cfg directly to cutana for normalization # This ensures cutana uses the exact same normalization settings as training @@ -142,7 +223,7 @@ def evaluate_images_from_cutana( model = load_model(cfg) model.eval() - transform = get_prediction_transforms() + transform = get_prediction_transforms(num_channels=n_out) # Process images in batches scores_list = [] @@ -164,13 +245,15 @@ def evaluate_images_from_cutana( ) logger.debug("Using fitsbolt config loaded from model checkpoint") + # CONVERSION_ONLY config for format conversion (created once, reused per cutout) + format_cfg = create_cutana_format_cfg(cfg) + batches_count = cutana_orchestrator.get_batch_count() num_images = 0 filenames = [] for batch_idx in tqdm(range(batches_count), desc="Processing batches"): - loaded_batch = cutana_orchestrator.next_batch() batch_data = loaded_batch["cutouts"] @@ -194,12 +277,12 @@ def evaluate_images_from_cutana( batch_filenames = (source["source_id"] for source in loaded_batch["metadata"]) filenames.extend(batch_filenames) - # I/O and preprocessing in ThreadPool (returns 
numpy arrays) - # CUDA operations are kept on main thread to prevent memory fragmentation + # Cutana already normalised the cutouts via external_fitsbolt_cfg. + # Only format conversion is needed (CHW→HWC, dtype, channel replication). batch_process_start = time.time() - with ThreadPoolExecutor(max_workers=max_workers) as executor: - batch_args = [(batch_data[i], cfg) for i in range(batch_size_actual)] - numpy_images = list(executor.map(load_and_preprocess_zarr, batch_args)) + numpy_images = [ + convert_cutana_cutout(batch_data[i], format_cfg) for i in range(batch_size_actual) + ] # Tensor conversion on main thread (not in ThreadPool) to avoid CUDA context issues stack_start = time.time() @@ -261,36 +344,7 @@ def main(): parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep") args = parser.parse_args() - logger.info(f"Loading config from {args.config_path}") - # Load cfg from pkl - try: - with open(args.config_path, "rb") as f: - cfg = pickle.load(f) - cfg = DotMap(cfg) - except Exception as e: - logger.error(f"Failed to load config from {args.config_path}: {e}") - sys.exit(1) - - logger.info("Setting batch size") - batch_size = ( - estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction - ) - logger.info(f"Batch size set to: {batch_size}") - - # Log key configuration parameters - logger.debug("Configuration loaded with parameters:") - logger.debug(f" Save file: {cfg.save_file}") - logger.debug(f" Save path: {cfg.save_path}") - logger.debug(f" Model path: {cfg.model_path}") - logger.debug(f" Output directory: {cfg.output_dir}") - logger.debug(f" Image size: {cfg.normalisation.image_size}") - - # Log full configuration - logger.debug("Full configuration:") - logger.debug(f"{cfg.toDict()}") - - # Create output directory if it doesn't exist - os.makedirs(cfg.output_dir, exist_ok=True) + cfg, batch_size = load_prediction_config(args.config_path) logger.info(f"Streaming from directory: 
{args.cutana_sources_path}") @@ -306,18 +360,5 @@ def main(): if __name__ == "__main__": - - # Configure logging - logs_dir = os.path.join(os.path.dirname(__file__), "logs") - os.makedirs(logs_dir, exist_ok=True) - - # Remove default handler and set up file logging - logger.remove() - script_logger_id = logger.add( - os.path.join(logs_dir, "prediction_cutana_{time}.log"), - rotation="1 MB", - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", - level="DEBUG", - ) - logger.add(sys.stderr, level="INFO") + setup_prediction_logging("prediction_cutana") main() diff --git a/prediction_process_hdf5.py b/prediction_process_hdf5.py index de87f33..112c744 100644 --- a/prediction_process_hdf5.py +++ b/prediction_process_hdf5.py @@ -5,32 +5,27 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. import argparse -import os -import pickle -import sys +import time +from concurrent.futures import ThreadPoolExecutor -from dotmap import DotMap -import torch +import h5py import numpy as np +import torch from loguru import logger -from concurrent.futures import ThreadPoolExecutor -import time -import h5py from tqdm import tqdm from anomaly_match.data_io.load_images import process_single_wrapper - +from anomaly_match.image_processing.transforms import ( + get_prediction_transforms, +) from prediction_utils import ( - load_model, - save_results, - process_batch_predictions, clear_gpu_cache_if_needed, jpeg_decoder, - estimate_batch_size, -) - -from anomaly_match.image_processing.transforms import ( - get_prediction_transforms, + load_model, + load_prediction_config, + process_batch_predictions, + save_results, + setup_prediction_logging, ) @@ -48,9 +43,10 @@ def read_and_decode_image_from_hdf5(image_data, cfg): image = image[:, :, [2, 1, 0]] except Exception: # If TurboJPEG fails, fall back to PIL - from PIL import Image import io + from PIL import Image + image = 
np.array(Image.open(io.BytesIO(image_bytes))) processed_image = process_single_wrapper(image, cfg, desc="hdf5") @@ -172,32 +168,7 @@ def main(): parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep") args = parser.parse_args() - logger.info(f"Loading config from {args.config_path}") - # Load cfg from pkl - try: - with open(args.config_path, "rb") as f: - cfg = pickle.load(f) - cfg = DotMap(cfg) - except Exception as e: - logger.error(f"Failed to load config from {args.config_path}: {e}") - sys.exit(1) - - logger.info("Setting batch size") - batch_size = ( - estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction - ) - logger.info(f"Batch size set to: {batch_size}") - - # Log key configuration parameters - logger.debug("Configuration loaded with parameters:") - logger.debug(f" Save file: {cfg.save_file}") - logger.debug(f" Save path: {cfg.save_path}") - logger.debug(f" Model path: {cfg.model_path}") - logger.debug(f" Output directory: {cfg.output_dir}") - logger.debug(f" Image size: {cfg.normalisation.image_size}") - - # Create output directory if it doesn't exist - os.makedirs(cfg.output_dir, exist_ok=True) + cfg, batch_size = load_prediction_config(args.config_path) logger.info(f"Processing HDF5 file: {args.hdf5_path}") @@ -211,18 +182,5 @@ def main(): if __name__ == "__main__": - - # Configure logging - logs_dir = os.path.join(os.path.dirname(__file__), "logs") - os.makedirs(logs_dir, exist_ok=True) - - # Remove default handler and set up file logging - logger.remove() - logger.add( - os.path.join(logs_dir, "prediction_thread_{time}.log"), - rotation="1 MB", - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", - level="DEBUG", - ) - logger.add(sys.stderr, level="INFO") + setup_prediction_logging("prediction_hdf5") main() diff --git a/prediction_process_zarr.py b/prediction_process_zarr.py index e375b32..f4a1c26 100644 --- a/prediction_process_zarr.py +++ b/prediction_process_zarr.py @@ 
-5,74 +5,31 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. import argparse -import os -import sys -import pickle +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path -from dotmap import DotMap -import torch import numpy as np -from loguru import logger -from concurrent.futures import ThreadPoolExecutor -import time -import zarr import pandas as pd -from pathlib import Path +import torch +import zarr +from loguru import logger from tqdm import tqdm -from anomaly_match.data_io.load_images import process_single_wrapper - +from anomaly_match.image_processing.transforms import ( + get_prediction_transforms, +) from prediction_utils import ( + clear_gpu_cache_if_needed, + load_and_preprocess_zarr, load_model, - save_results, + load_prediction_config, process_batch_predictions, - estimate_batch_size, - clear_gpu_cache_if_needed, -) - -from anomaly_match.image_processing.transforms import ( - get_prediction_transforms, + save_results, + setup_prediction_logging, ) -def read_and_preprocess_image_from_zarr(image_data, cfg): - """Read and preprocess image data from Zarr array using standardized functions.""" - try: - # Convert Zarr data to numpy array if it's not already - if not isinstance(image_data, np.ndarray): - image_data = np.array(image_data) - - # Check if we need to transpose based on the shape - # If last dimension is 3 (RGB channels), data is already in HWC format - # If first dimension is 3, data is in CHW format and needs transposing - if image_data.shape[0] == cfg.normalisation.n_output_channels: - # In CHW format, convert to HWC - image = image_data.transpose(1, 2, 0) - else: - # Assume HWC format if neither first nor last dimension is 3 - # This handles grayscale or other formats - image = image_data - - # Use the centralized processing function - this handles RGB conversion, - # normalization, and resizing efficiently without 
temporary files - processed_image = process_single_wrapper(image, cfg, desc="zarr") - return processed_image - - except Exception as e: - logger.error(f"Error processing image from Zarr: {e}") - raise - - -def load_and_preprocess_zarr(args): - """Load and preprocess a single image from Zarr. - - Note: Returns numpy array, not tensor. Tensor conversion is done on main - thread to avoid CUDA context issues in ThreadPoolExecutor. - """ - image_data, cfg = args - return read_and_preprocess_image_from_zarr(image_data, cfg) - - def evaluate_images_in_zarr(zarr_path, cfg, top_n=1000, batch_size=1000, max_workers=4): """Evaluate images inside a Zarr file and return top N scores.""" logger.info(f"Opening Zarr file {zarr_path}") @@ -244,36 +201,7 @@ def main(): parser.add_argument("top_n", type=int, default=1000, help="Number of top scores to keep") args = parser.parse_args() - logger.info(f"Loading config from {args.config_path}") - # Load cfg from pkl - try: - with open(args.config_path, "rb") as f: - cfg = pickle.load(f) - cfg = DotMap(cfg) - except Exception as e: - logger.error(f"Failed to load config from {args.config_path}: {e}") - sys.exit(1) - - logger.info("Setting batch size") - batch_size = ( - estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction - ) - logger.info(f"Batch size set to: {batch_size}") - - # Log key configuration parameters - logger.debug("Configuration loaded with parameters:") - logger.debug(f" Save file: {cfg.save_file}") - logger.debug(f" Save path: {cfg.save_path}") - logger.debug(f" Model path: {cfg.model_path}") - logger.debug(f" Output directory: {cfg.output_dir}") - logger.debug(f" Image size: {cfg.normalisation.image_size}") - - # Log full configuration - logger.debug("Full configuration:") - logger.debug(f"{cfg.toDict()}") - - # Create output directory if it doesn't exist - os.makedirs(cfg.output_dir, exist_ok=True) + cfg, batch_size = load_prediction_config(args.config_path) logger.info(f"Processing Zarr 
file: {args.zarr_path}") @@ -287,17 +215,5 @@ def main(): if __name__ == "__main__": - # Configure logging - logs_dir = os.path.join(os.path.dirname(__file__), "logs") - os.makedirs(logs_dir, exist_ok=True) - - # Remove default handler and set up file logging - logger.remove() - logger.add( - os.path.join(logs_dir, "prediction_zarr_{time}.log"), - rotation="1 MB", - format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", - level="DEBUG", - ) - logger.add(sys.stderr, level="INFO") + setup_prediction_logging("prediction_zarr") main() diff --git a/prediction_utils.py b/prediction_utils.py index 55226d0..6ce6348 100644 --- a/prediction_utils.py +++ b/prediction_utils.py @@ -13,13 +13,19 @@ """ import os -import torch +import pickle +import sys + import numpy as np import pandas as pd - +import torch +from dotmap import DotMap +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from loguru import logger from turbojpeg import TurboJPEG +from anomaly_match.data_io.load_images import get_fitsbolt_config, process_single_wrapper +from anomaly_match.utils.get_default_cfg import get_default_cfg # Initialize TurboJPEG jpeg_decoder = TurboJPEG() @@ -124,13 +130,7 @@ def estimate_batch_size( # usable_vram - c = B * (a * S² + b) # B = (usable_vram - c) / (a * S² + b) S2 = cfg.normalisation.image_size[0] * cfg.normalisation.image_size[1] - # Use num_channels if set, otherwise fall back to normalisation.n_output_channels - num_channels = ( - cfg.num_channels - if isinstance(cfg.num_channels, int) - else cfg.normalisation.n_output_channels - ) - denominator = coef["a"] * S2 * num_channels + coef["b"] + denominator = coef["a"] * S2 * cfg.num_channels + coef["b"] if denominator <= 0: logger.warning("Invalid memory model parameters, returning minimum batch size") @@ -174,18 +174,12 @@ def load_model(cfg): from anomaly_match.utils.get_net_builder import get_net_builder - # Use num_channels if set, otherwise fall back to normalisation.n_output_channels - 
num_channels = ( - cfg.num_channels - if isinstance(cfg.num_channels, int) - else cfg.normalisation.n_output_channels - ) net_builder = get_net_builder( cfg.net, pretrained=cfg.pretrained, - in_channels=num_channels, + in_channels=cfg.num_channels, ) - model = net_builder(num_classes=2, in_channels=num_channels) + model = net_builder(num_classes=2, in_channels=cfg.num_channels) if torch.cuda.is_available(): gpu_device = getattr(cfg, "gpu", 0) # Default to 0 if not set @@ -449,6 +443,174 @@ def _ensure_consistent_image_format(images): return images +def setup_prediction_logging(log_name): + """Set up logging for prediction scripts. + + Configures file logging with rotation and stderr output. Also adds + session-specific logging if a config path is available in sys.argv. + + Args: + log_name: Name used for the log file (e.g. "prediction_thread", + "prediction_zarr", "prediction_cutana"). + """ + logs_dir = os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), "logs") + os.makedirs(logs_dir, exist_ok=True) + + logger.remove() + logger.add( + os.path.join(logs_dir, f"{log_name}_{{time}}.log"), + rotation="1 MB", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", + level="DEBUG", + ) + logger.add(sys.stderr, level="INFO") + + # Also log to session output directory if available + if len(sys.argv) > 1: + try: + with open(sys.argv[1], "rb") as _f: + _pre_cfg = DotMap(pickle.load(_f)) + if _pre_cfg.output_dir: + os.makedirs(_pre_cfg.output_dir, exist_ok=True) + logger.add( + os.path.join(_pre_cfg.output_dir, "prediction.log"), + rotation="10 MB", + format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", + level="DEBUG", + ) + except Exception: + pass + + +def load_prediction_config(config_path): + """Load prediction config from pickle file and compute batch size. + + Args: + config_path: Path to the pickled config file. 
+ + Returns: + tuple: (cfg, batch_size) where cfg is a DotMap config object + and batch_size is the computed or configured batch size. + """ + logger.info(f"Loading config from {config_path}") + try: + with open(config_path, "rb") as f: + cfg = pickle.load(f) + cfg = DotMap(cfg) + except Exception as e: + logger.error(f"Failed to load config from {config_path}: {e}") + sys.exit(1) + + logger.info("Setting batch size") + batch_size = ( + estimate_batch_size(cfg) if cfg.N_batch_prediction is None else cfg.N_batch_prediction + ) + logger.info(f"Batch size set to: {batch_size}") + + # Log key configuration parameters + logger.debug("Configuration loaded with parameters:") + logger.debug(f" Save file: {cfg.save_file}") + logger.debug(f" Save path: {cfg.save_path}") + logger.debug(f" Model path: {cfg.model_path}") + logger.debug(f" Output directory: {cfg.output_dir}") + logger.debug(f" Image size: {cfg.normalisation.image_size}") + + # Log full configuration + logger.debug("Full configuration:") + logger.debug(f"{cfg.toDict()}") + + # Create output directory if it doesn't exist + os.makedirs(cfg.output_dir, exist_ok=True) + + return cfg, batch_size + + +def create_cutana_format_cfg(cfg): + """Create a CONVERSION_ONLY fitsbolt config for cutana format conversion. + + This config is used by convert_cutana_cutout to handle dtype and channel + conversion without re-applying normalisation. Callers should create this + once and pass it to convert_cutana_cutout for each image. 
+ """ + format_cfg = get_default_cfg() + format_cfg.normalisation.image_size = cfg.normalisation.image_size + format_cfg.normalisation.n_output_channels = cfg.normalisation.n_output_channels + format_cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY + format_cfg.normalisation.norm_asinh_scale = cfg.normalisation.norm_asinh_scale + format_cfg.normalisation.norm_asinh_clip = cfg.normalisation.norm_asinh_clip + format_cfg.num_workers = 0 + return get_fitsbolt_config(format_cfg) + + +def convert_cutana_cutout(image_data, format_cfg): + """Convert a cutana-normalised cutout to the format expected by the model. + + Cutana already applies normalisation via external_fitsbolt_cfg, so this + function only performs format conversion (dtype, channel replication) using + fitsbolt's CONVERSION_ONLY mode — no normalisation stretch is applied. + + Args: + image_data: Cutana cutout array (already normalised). + format_cfg: CONVERSION_ONLY config from create_cutana_format_cfg. + + Returns: + np.ndarray: Image in HWC uint8 format ready for model inference. + """ + if not isinstance(image_data, np.ndarray): + image_data = np.array(image_data) + + # CHW → HWC (cutana may deliver CHW depending on config) + if image_data.ndim == 3 and image_data.shape[0] <= 4 and image_data.shape[2] > 4: + image_data = image_data.transpose(1, 2, 0) + + # Delegate dtype conversion and channel replication to fitsbolt via + # process_single_wrapper with a CONVERSION_ONLY config so the already- + # normalised pixel values are preserved (only dtype + channels change). + return process_single_wrapper(image_data, format_cfg, desc="cutana") + + +def read_and_preprocess_image_from_zarr(image_data, cfg): + """Read and preprocess raw image data from a Zarr array. + + Handles CHW/HWC format detection and uses fitsbolt for normalization. + Used by the zarr prediction script for data that has NOT been normalised yet. 
+ """ + try: + # Convert Zarr data to numpy array if it's not already + if not isinstance(image_data, np.ndarray): + image_data = np.array(image_data) + + # Check if we need to transpose based on the shape + # If last dimension is 3 (RGB channels), data is already in HWC format + # If first dimension is 3, data is in CHW format and needs transposing + if image_data.shape[0] == cfg.normalisation.n_output_channels: + # In CHW format, convert to HWC + image = image_data.transpose(1, 2, 0) + else: + # Assume HWC format if neither first nor last dimension is 3 + # This handles grayscale or other formats + image = image_data + + # Use the centralized processing function - this handles RGB conversion, + # normalization, and resizing efficiently without temporary files + processed_image = process_single_wrapper(image, cfg, desc="zarr") + return processed_image + + except Exception as e: + logger.error(f"Error processing image from Zarr: {e}") + raise + + +def load_and_preprocess_zarr(args): + """Load and preprocess a single image from Zarr. + + Note: Returns numpy array, not tensor. Tensor conversion is done on main + thread to avoid CUDA context issues in ThreadPoolExecutor. + """ + image_data, cfg = args + return read_and_preprocess_image_from_zarr(image_data, cfg) + + def process_batch_predictions(model, images, original_images=None): """Process a batch of images through the model to get anomaly scores. 
diff --git a/pyproject.toml b/pyproject.toml index 031d175..3e25184 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "anomaly_match" -version = "1.2.0" +version = "1.3.0" description = "A tool for anomaly detection in images using semi-supervised and active learning with a GUI" readme = "README.md" license = { file = "LICENSE.txt" } @@ -32,51 +32,66 @@ classifiers = [ dependencies = [ "albumentations", "astropy", + "cutana>=0.2.1", "dotmap", - "efficientnet-pytorch", - "efficientnet_lite_pytorch", - "efficientnet_lite0_pytorch_model", + "fitsbolt>=0.2", "h5py", "ipykernel", "ipywidgets", - "imageio", "loguru", "matplotlib", "numpy", "opencv-python-headless", "pandas<3", + "psutil", + "pyarrow", "pyturbojpeg", - "scikit-learn", "scikit-image", - "fitsbolt>=0.1.6", + "scikit-learn", + "scipy", + "timm", "toml", - "torch>=2.6", + "torch", "torchvision", "tqdm", "zarr>=3.0.0b0", - "cutana>=0.2.1", ] [project.optional-dependencies] -dev = ["pytest", "pytest-cov", "black", "flake8", "mypy", "vulture>=2.10"] +dev = ["pytest", "pytest-asyncio", "pytest-cov", "ruff", "mypy", "vulture>=2.10"] -[tool.setuptools] -packages = ["anomaly_match"] +[tool.setuptools.packages.find] +include = ["anomaly_match*"] [tool.setuptools.package-dir] "" = "." 
-[tool.black] +[tool.ruff] line-length = 100 -target-version = ['py311'] +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "W", "I"] +ignore = ["E402", "E203", "E501"] + +[tool.ruff.lint.isort] +known-first-party = ["anomaly_match", "anomaly_match_ui"] [tool.pytest.ini_options] +minversion = "7.0" testpaths = ["tests"] +pythonpath = ["."] python_files = ["test_*.py"] python_classes = ["Test*"] python_functions = ["test_*"] +markers = [ + "slow: tests that train models or take >10s (skip with -m 'not slow')", + "ui: tests requiring ipywidgets/Jupyter", +] +asyncio_mode = "strict" +asyncio_default_fixture_loop_scope = "function" [tool.vulture] min_confidence = 60 -paths = ["anomaly_match"] +paths = ["anomaly_match", "anomaly_match_ui"] exclude = ["tests/", "docs/", "examples/", "paper_scripts/"] diff --git a/scripts/generate_4ch_test_data.py b/scripts/generate_4ch_test_data.py new file mode 100644 index 0000000..118df14 --- /dev/null +++ b/scripts/generate_4ch_test_data.py @@ -0,0 +1,128 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Generate 4-channel test images with distinct per-channel content. 
+ +Each channel has a different geometric pattern so multispectral images +are visually distinguishable from grayscale: + Ch0: Concentric circles (radial gradient) + Ch1: Diagonal stripes + Ch2: Checkerboard pattern + Ch3: Gaussian blob (random position per image) +""" + +import os + +import numpy as np +import pandas as pd +import tifffile + + +def make_radial_gradient(size, center_x, center_y, scale=1.0): + """Concentric circles centered at (center_x, center_y).""" + y, x = np.mgrid[0:size, 0:size] + dist = np.sqrt((x - center_x) ** 2 + (y - center_y) ** 2) + return np.clip((np.cos(dist * scale * 0.1) * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8) + + +def make_diagonal_stripes(size, frequency=0.05, angle=0.0): + """Diagonal stripe pattern.""" + y, x = np.mgrid[0:size, 0:size] + val = np.sin(2 * np.pi * frequency * (x * np.cos(angle) + y * np.sin(angle))) + return np.clip((val * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8) + + +def make_checkerboard(size, block_size=16): + """Checkerboard pattern.""" + y, x = np.mgrid[0:size, 0:size] + checker = ((x // block_size) + (y // block_size)) % 2 + return (checker * 255).astype(np.uint8) + + +def make_gaussian_blob(size, cx, cy, sigma=20, intensity=200): + """Gaussian blob at (cx, cy).""" + y, x = np.mgrid[0:size, 0:size] + blob = intensity * np.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma**2)) + return np.clip(blob, 0, 255).astype(np.uint8) + + +def main(): + script_dir = os.path.dirname(os.path.abspath(__file__)) + repo_root = os.path.dirname(script_dir) + ms_subdir = os.path.join(repo_root, "tests", "test_data", "multispectral_4ch") + os.makedirs(ms_subdir, exist_ok=True) + + # Clean old files + for f in os.listdir(ms_subdir): + os.remove(os.path.join(ms_subdir, f)) + + size = 64 + n_images = 10 + np.random.seed(42) + + source_names = [ + "Abell2390_VIS_2", + "Abell2390_VIS_3", + "Abell2390_VIS_5", + "Abell2390_VIS_6", + "Abell2390_VIS_2021", + "Abell2390_VIS_2172", + "Abell2390_VIS_4989", + "Abell2390_VIS_5260", + 
"Abell2390_VIS_5783", + "Abell2390_VIS_6212", + ] + + generated_filenames = [] + + for i, name in enumerate(source_names[:n_images]): + # Each image gets slightly different parameters for variety + cx = np.random.randint(size // 4, 3 * size // 4) + cy = np.random.randint(size // 4, 3 * size // 4) + angle = np.random.uniform(0, np.pi) + block = np.random.choice([8, 12, 16]) + + ch0 = make_radial_gradient(size, cx, cy, scale=1.0 + i * 0.3) + ch1 = make_diagonal_stripes(size, frequency=0.04 + i * 0.01, angle=angle) + ch2 = make_checkerboard(size, block_size=block) + ch3 = make_gaussian_blob(size, cx, cy, sigma=10 + i * 2) + + ms_img = np.stack([ch0, ch1, ch2, ch3], axis=-1) # (H, W, 4) + + filename = f"{name}_4ch.tiff" + tifffile.imwrite(os.path.join(ms_subdir, filename), ms_img) + generated_filenames.append(filename) + print(f" Created {filename} with shape {ms_img.shape}") + + # Create labeled_data.csv - label first 6, leave 4 unlabeled + labeled_filenames = generated_filenames[:6] + labels = ["anomaly"] * 3 + ["normal"] * 3 + df = pd.DataFrame({"filename": labeled_filenames, "label": labels}) + df.to_csv(os.path.join(ms_subdir, "labeled_data.csv"), index=False) + print(f"Created labeled_data.csv with {len(labeled_filenames)} labeled entries") + + # Create metadata.csv for all images + metadata_rows = [] + base_ra, base_dec = 328.4034, 17.6950 + for i, filename in enumerate(generated_filenames): + metadata_rows.append( + { + "filename": filename, + "sourceID": f"MS_{i:05d}", + "ra": base_ra + i * 0.002, + "dec": base_dec + i * 0.001, + "custom_metadata": f"4ch test source {i}", + } + ) + meta_df = pd.DataFrame(metadata_rows) + meta_df.to_csv(os.path.join(ms_subdir, "metadata.csv"), index=False) + print(f"Created metadata.csv with {len(metadata_rows)} entries") + + print(f"\nDone! 
Test data saved to: {ms_subdir}") + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..43e2e91 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,76 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Shared pytest fixtures for AnomalyMatch tests.""" + +import os + +import pytest + + +@pytest.fixture(scope="function") +def multispectral_test_data(): + """Provide 4-channel test images from permanent test data. + + Uses pre-generated 4-channel TIFF images from tests/test_data/multispectral_4ch/ + + Yields: + tuple: (data_dir, image_filenames, labeled_data_path, num_channels) + - data_dir: Path to directory containing 4-channel test images + - image_filenames: List of image filenames + - labeled_data_path: Path to the labeled_data.csv file + - num_channels: Number of channels (4) + """ + # Get path to 4-channel test images + test_data_dir = os.path.join(os.path.dirname(__file__), "test_data", "multispectral_4ch") + + if not os.path.exists(test_data_dir): + pytest.skip( + "4-channel test data not found. Run 'python scripts/generate_4ch_test_data.py' first." 
+ ) + + # Get all .tiff files (the format that supports 4+ channels) + tiff_images = [f for f in os.listdir(test_data_dir) if f.endswith(".tiff")] + + if not tiff_images: + pytest.skip("No .tiff test images found in tests/test_data/multispectral_4ch/") + + csv_path = os.path.join(test_data_dir, "labeled_data.csv") + if not os.path.exists(csv_path): + pytest.skip("labeled_data.csv not found in tests/test_data/multispectral_4ch/") + + yield test_data_dir, sorted(tiff_images), csv_path, 4 + + +@pytest.fixture(scope="function") +def multispectral_config(multispectral_test_data): + """Create a configuration for multispectral testing. + + Args: + multispectral_test_data: The multispectral test data fixture. + + Yields: + DotMap: Configuration object set up for 4-channel images. + """ + import anomaly_match as am + + data_dir, filenames, csv_path, num_channels = multispectral_test_data + + cfg = am.get_default_cfg() + cfg.data_dir = data_dir + cfg.normalisation.image_size = [64, 64] + cfg.net = "test-cnn" + cfg.pretrained = False + cfg.num_train_iter = 2 + cfg.num_workers = 0 + cfg.test_ratio = 0.3 + cfg.N_to_load = 10 + cfg.normalisation.fits_extension = None + cfg.label_file = csv_path + cfg.seed = 42 + # n_output_channels is auto-detected from images by AnomalyDetectionDataset + + yield cfg diff --git a/pytest.ini b/tests/e2e/conftest.py similarity index 85% rename from pytest.ini rename to tests/e2e/conftest.py index 53d01ae..5d45ca4 100644 --- a/pytest.ini +++ b/tests/e2e/conftest.py @@ -4,7 +4,4 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -[pytest] -minversion = 7.0 -testpaths = tests -pythonpath = . 
\ No newline at end of file +"""Shared fixtures for end-to-end tests.""" diff --git a/tests/e2e/test_normalisation_consistency.py b/tests/e2e/test_normalisation_consistency.py new file mode 100644 index 0000000..7d32171 --- /dev/null +++ b/tests/e2e/test_normalisation_consistency.py @@ -0,0 +1,248 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Tests verifying that cutana's prediction pipeline normalises images identically to AM's +training pipeline. A mismatch would cause silent model failure in production (data drift). + +The cutana prediction path (prediction_process_cutana.py) passes AM's fitsbolt_cfg to cutana +via external_fitsbolt_cfg so that cutana handles normalisation. This test verifies that +the output matches what AM's load_and_process_wrapper produces for the same raw data. + +Test approach: +1. Extract raw (unnormalised) cutouts from a test FITS tile using cutana's + do_only_cutout_extraction mode, then save as clean FITS files. +2. Prediction path: run cutana with external_fitsbolt_cfg + channel_weights + (expanding 1->n_output_channels before normalisation) -> convert_cutana_cutout +3. Training path: load_and_process_wrapper on the same raw FITS files +4. Compare pixel-for-pixel (+/-1 tolerance for WCS reprojection float rounding). 
+""" + +import csv +import os + +import cutana +import numpy as np +import pytest +from astropy.io import fits +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod + +from anomaly_match.data_io.load_images import ( + get_fitsbolt_config, + load_and_process_wrapper, +) +from anomaly_match.utils.get_default_cfg import get_default_cfg +from prediction_utils import convert_cutana_cutout, create_cutana_format_cfg + +# Paths to pre-generated test data +_TEST_DATA_DIR = os.path.join( + os.path.dirname(__file__), os.pardir, "test_data", "normalisation_consistency" +) +_FITS_TILE = os.path.join( + _TEST_DATA_DIR, + "EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits", +) +_CSV_CATALOGUE = os.path.join(_TEST_DATA_DIR, "mock_sources.csv") + +_TARGET_RESOLUTION = 150 +_MAX_SOURCES = 2 + + +def _rewrite_csv_with_absolute_paths(csv_in, csv_out, fits_tile_abs): + """Rewrite the catalogue CSV with absolute paths, limited to _MAX_SOURCES.""" + with open(csv_in) as f_in, open(csv_out, "w", newline="") as f_out: + reader = csv.DictReader(f_in) + writer = csv.DictWriter(f_out, fieldnames=reader.fieldnames) + writer.writeheader() + for i, row in enumerate(reader): + if i >= _MAX_SOURCES: + break + row["fits_file_paths"] = str([fits_tile_abs]) + writer.writerow(row) + + +def _make_cutana_config(csv_path, n_output_channels=3): + """Create a base cutana config for the test tile. + + Sets channel_weights so cutana expands single-extension data to + n_output_channels before normalisation — matching the training path's + channel_combine -> normalise order in fitsbolt's _process_image. 
+ """ + config = cutana.get_default_config() + config.target_resolution = _TARGET_RESOLUTION + config.source_catalogue = csv_path + config.fits_extensions = ["PRIMARY"] + config.selected_extensions = [{"name": "PRIMARY", "ext": "PrimaryHDU"}] + config.apply_flux_conversion = False + config.channel_weights = {"PRIMARY": [1.0] * n_output_channels} + return config + + +def _run_cutana_normalised(csv_path, fitsbolt_cfg, n_output_channels=3): + """Run cutana with external_fitsbolt_cfg and return normalised cutout arrays.""" + config = _make_cutana_config(csv_path, n_output_channels) + config.external_fitsbolt_cfg = fitsbolt_cfg + + orchestrator = cutana.StreamingOrchestrator(config) + orchestrator.init_streaming(batch_size=10, write_to_disk=False, synchronised_loading=False) + + all_cutouts = [] + for _ in range(orchestrator.get_batch_count()): + batch = orchestrator.next_batch() + cutouts = batch["cutouts"] + if isinstance(cutouts, list) and len(cutouts) == 0: + continue + if isinstance(cutouts, list): + cutouts = np.array(cutouts) + for i in range(cutouts.shape[0]): + all_cutouts.append(np.array(cutouts[i])) + + orchestrator.cleanup() + return all_cutouts + + +def _extract_raw_cutouts(csv_path, output_dir): + """Extract raw (unnormalised) cutouts as FITS files using cutana.""" + config = _make_cutana_config(csv_path, n_output_channels=1) + config.do_only_cutout_extraction = True + config.output_format = "fits" + config.write_to_disk = True + config.output_dir = output_dir + + orchestrator = cutana.StreamingOrchestrator(config) + orchestrator.init_streaming(batch_size=10, write_to_disk=True, synchronised_loading=False) + for _ in range(orchestrator.get_batch_count()): + orchestrator.next_batch() + orchestrator.cleanup() + + # Raw cutout data is in HDU[1] of each file + return sorted( + os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith(".fits") + ) + + +@pytest.fixture(scope="session") +def cutana_test_data(tmp_path_factory): + """Extract raw 
cutouts from the test FITS tile using cutana. + + Session-scoped: raw cutout extraction is expensive and identical + across all normalisation methods. + + Returns: + tuple: (clean_fits_paths, rewritten_csv_path) + - clean_fits_paths: list of FITS file paths with raw cutout data in HDU[0] + - rewritten_csv_path: path to CSV with absolute FITS tile paths + """ + if not os.path.exists(_CSV_CATALOGUE) or not os.path.exists(_FITS_TILE): + pytest.skip("Normalisation consistency test data not found") + + tmp_path = tmp_path_factory.mktemp("normalisation_consistency") + + # Rewrite CSV with absolute paths for this environment + rewritten_csv = str(tmp_path / "sources.csv") + _rewrite_csv_with_absolute_paths(_CSV_CATALOGUE, rewritten_csv, os.path.abspath(_FITS_TILE)) + + # Extract raw cutouts (do_only_cutout_extraction writes FITS with data in HDU[1]) + raw_dir = str(tmp_path / "raw_cutouts") + os.makedirs(raw_dir, exist_ok=True) + raw_fits_paths = _extract_raw_cutouts(rewritten_csv, raw_dir) + + if not raw_fits_paths: + pytest.skip("Cutana did not produce any cutouts from the test tile") + + # Save as clean FITS files with data in HDU[0] for the training path + clean_fits_paths = [] + for i, raw_path in enumerate(raw_fits_paths): + with fits.open(raw_path) as hdul: + raw_data = hdul[1].data + clean_path = str(tmp_path / f"cutout_{i}.fits") + fits.PrimaryHDU(raw_data.astype(np.float32)).writeto(clean_path, overwrite=True) + clean_fits_paths.append(clean_path) + + return clean_fits_paths, rewritten_csv + + +def _make_cfg(normalisation_method): + """Create an AM config with fitsbolt_cfg for the given normalisation method.""" + cfg = get_default_cfg() + cfg.normalisation.image_size = [_TARGET_RESOLUTION, _TARGET_RESOLUTION] + cfg.normalisation.n_output_channels = 3 + cfg.normalisation.normalisation_method = normalisation_method + cfg.num_channels = 3 + cfg.num_workers = 0 + cfg = get_fitsbolt_config(cfg) + return cfg + + +_ALL_METHODS = [ + NormalisationMethod.CONVERSION_ONLY, 
+ NormalisationMethod.LOG, + NormalisationMethod.ZSCALE, + NormalisationMethod.ASINH, +] + + +def test_cutana_vs_training_normalisation(cutana_test_data): + """Verify that cutana's normalisation matches AM's training normalisation. + + Tests all normalisation methods in a single function to avoid repeated + cutana orchestrator initialisation overhead (~3s per method). + + For each normalisation method: + 1. Training path: raw FITS file -> load_and_process_wrapper() -> normalised image + 2. Prediction path: FITS tile -> cutana with external_fitsbolt_cfg -> + convert_cutana_cutout -> normalised image + + Both paths must produce matching output for the same source data. + A tolerance of +/-1 (uint8) is allowed because the two separate cutana runs + (raw extraction vs normalised) may produce tiny float32 differences in WCS + reprojection, which non-linear stretches (ASINH) can amplify to +/-1/255. + """ + clean_fits_paths, rewritten_csv = cutana_test_data + failures = [] + + for method in _ALL_METHODS: + cfg = _make_cfg(method) + + # --- Prediction path: cutana normalisation + format conversion --- + cutana_normalised = _run_cutana_normalised( + rewritten_csv, cfg.fitsbolt_cfg, n_output_channels=3 + ) + if len(cutana_normalised) != len(clean_fits_paths): + failures.append( + f"{method.name}: cutana returned {len(cutana_normalised)} cutouts " + f"but {len(clean_fits_paths)} raw cutouts were extracted" + ) + continue + + format_cfg = create_cutana_format_cfg(cfg) + prediction_images = [convert_cutana_cutout(c, format_cfg) for c in cutana_normalised] + + # --- Training path: load raw FITS via fitsbolt --- + training_pairs = load_and_process_wrapper(clean_fits_paths, cfg, show_progress=False) + + # --- Compare --- + for i, (pred_img, (_, train_img)) in enumerate(zip(prediction_images, training_pairs)): + if pred_img.shape != train_img.shape: + failures.append( + f"{method.name} cutout {i}: shape mismatch — " + f"prediction {pred_img.shape} vs training {train_img.shape}" + 
) + continue + if pred_img.dtype != train_img.dtype: + failures.append( + f"{method.name} cutout {i}: dtype mismatch — " + f"prediction {pred_img.dtype} vs training {train_img.dtype}" + ) + continue + + diff = np.abs(pred_img.astype(np.int16) - train_img.astype(np.int16)) + max_diff = int(diff.max()) + if max_diff > 1: + failures.append( + f"{method.name} cutout {i}: max abs diff = {max_diff} (tolerance: 1)" + ) + + assert not failures, "Normalisation mismatches:\n" + "\n".join(failures) diff --git a/tests/pipeline_test.py b/tests/e2e/test_pipeline.py similarity index 89% rename from tests/pipeline_test.py rename to tests/e2e/test_pipeline.py index e1137ca..41730d6 100644 --- a/tests/pipeline_test.py +++ b/tests/e2e/test_pipeline.py @@ -6,10 +6,13 @@ # the terms contained in the file 'LICENCE.txt'. """Testing a mininal pipeline""" -import pytest import ipywidgets as widgets +import pytest + import anomaly_match as am +pytestmark = pytest.mark.slow + @pytest.fixture(scope="module") def pipeline_config(): @@ -25,7 +28,10 @@ def pipeline_config(): cfg.data_dir = "tests/test_data/" cfg.normalisation.image_size = [64, 64] cfg.normalisation.n_output_channels = 3 - cfg.num_train_iter = 10 + cfg.net = "test-cnn" + cfg.pretrained = False + cfg.num_train_iter = 2 + cfg.num_workers = 0 return cfg, out diff --git a/tests/test_prediction_process.py b/tests/e2e/test_prediction_process.py similarity index 96% rename from tests/test_prediction_process.py rename to tests/e2e/test_prediction_process.py index 32897f9..799d366 100644 --- a/tests/test_prediction_process.py +++ b/tests/e2e/test_prediction_process.py @@ -4,33 +4,34 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import pytest import csv import os -import numpy as np -import h5py -import zarr import tempfile -from PIL import Image + +import h5py +import numpy as np import pandas as pd +import pytest import torch -from loguru import logger +import zarr + +pytestmark = pytest.mark.slow from astropy.io import fits -from astropy.wcs import WCS from astropy.table import Table +from astropy.wcs import WCS +from fitsbolt.cfg.create_config import create_config as fb_create_cfg +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from loguru import logger +from PIL import Image +from anomaly_match.utils.get_default_cfg import get_default_cfg from prediction_process import evaluate_files +from prediction_process_cutana import evaluate_images_from_cutana from prediction_process_hdf5 import evaluate_images_in_hdf5 from prediction_process_zarr import evaluate_images_in_zarr -from prediction_process_cutana import evaluate_images_from_cutana from prediction_utils import save_results -from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod -from fitsbolt.cfg.create_config import create_config as fb_create_cfg -from anomaly_match.utils.get_default_cfg import get_default_cfg - - @pytest.fixture def test_config(): cfg = get_default_cfg() @@ -49,6 +50,7 @@ def test_config(): cfg.test_ratio = 0.0 # Add test ratio cfg.save_dir = tempfile.mkdtemp() # Add save directory cfg.data_dir = "tests/test_data/" # Add data directory + cfg.num_workers = 0 # Use main process for data loading (avoids spawn overhead) # Create fb_cfg for fitsbolt cfg.fitsbolt_cfg = fb_create_cfg( @@ -58,7 +60,7 @@ def test_config(): interpolation_order=cfg.normalisation.interpolation_order, normalisation_method=cfg.normalisation.normalisation_method, channel_combination=cfg.normalisation.channel_combination, - num_workers=cfg.num_workers, + num_workers=max(cfg.num_workers, 1), norm_maximum_value=cfg.normalisation.norm_maximum_value, 
norm_minimum_value=cfg.normalisation.norm_minimum_value, norm_log_calculate_minimum_value=cfg.normalisation.norm_log_calculate_minimum_value, @@ -554,14 +556,19 @@ def test_prediction_file_type_cutana_malformed_header(test_config, test_cutana_m with pytest.warns( RuntimeWarning, - match=r"File .* did not pass cutana compatibility check and will be skipped \(.*\)", + match=r"File .* did not pass cutana column check \(.*\)", ): with pytest.raises(RuntimeError, match="All found files are not compatible with cutana"): session.evaluate_all_images() def test_prediction_file_type_cutana_missing_images(test_config, test_cutana_missing_images): - """Test for meaningful exception when streaming from cutana and images are missing.""" + """Test that catalogues with missing FITS images still pass validation. + + FITS existence is no longer checked during validation (only column schema + is verified). Errors from missing files surface later during cutana + processing. + """ from anomaly_match.pipeline.session import Session from anomaly_match.utils.get_default_cfg import get_default_cfg @@ -572,12 +579,11 @@ def test_prediction_file_type_cutana_missing_images(test_config, test_cutana_mis session = Session(cfg) - with pytest.warns( - RuntimeWarning, - match=r"File .* did not pass cutana compatibility check and will be skipped \(.*\)", - ): - with pytest.raises(RuntimeError, match="All found files are not compatible with cutana"): - session.evaluate_all_images() + # Validation passes (columns are valid), but the cutana subprocess will + # fail when it tries to open missing FITS files. The session logs a + # warning instead of raising, so we just verify no scores are loaded. 
+ session.evaluate_all_images() + assert session.scores is None def test_stream_file_type_detection_csv_and_parquet(tmp_path): @@ -698,8 +704,8 @@ def mock_save_results(cfg, all_scores, all_imgs, all_filenames, top_n): def test_load_and_preprocess_multiple_formats(test_config, mixed_format_images): """Test the load_and_preprocess function can handle multiple formats.""" - from prediction_process import load_and_preprocess from anomaly_match.image_processing.transforms import get_prediction_transforms + from prediction_process import load_and_preprocess image_paths, _ = mixed_format_images transform = get_prediction_transforms() @@ -815,9 +821,9 @@ def mock_process_batch(model, images): f"Image array size ({len(final_images)}) doesn't match CSV size ({len(final_scores)}). " f"This indicates a bug in image accumulation logic." ) - assert ( - final_images.shape[0] == top_n - ), f"Expected {top_n} images, got {final_images.shape[0]}" + assert final_images.shape[0] == top_n, ( + f"Expected {top_n} images, got {final_images.shape[0]}" + ) def test_all_predictions_accumulation(test_config, monkeypatch): @@ -988,9 +994,9 @@ def test_top_images_preservation_across_batches(test_config, tmp_path): logger.info(f"All stored filenames: {all_stored_filenames}") # Should have 4 total predictions (2 from each batch) assert len(all_stored_scores) == 4, f"Expected 4 total scores, got {len(all_stored_scores)}" - assert ( - len(all_stored_filenames) == 4 - ), f"Expected 4 total filenames, got {len(all_stored_filenames)}" + assert len(all_stored_filenames) == 4, ( + f"Expected 4 total filenames, got {len(all_stored_filenames)}" + ) # Verify that all scores from both batches are present expected_all_scores = [0.9, 0.8, 0.6, 0.5] # batch1: 0.9, 0.8; batch2: 0.6, 0.5 @@ -1006,12 +1012,12 @@ def test_top_images_preservation_across_batches(test_config, tmp_path): sorted_scores = all_stored_scores[sorted_indices] sorted_filenames = all_stored_filenames[sorted_indices] - assert np.allclose( - 
sorted_scores, expected_all_scores - ), f"Expected all scores {expected_all_scores}, got {sorted_scores}" - assert ( - list(sorted_filenames) == expected_all_filenames - ), f"Expected all filenames {expected_all_filenames}, got {list(sorted_filenames)}" + assert np.allclose(sorted_scores, expected_all_scores), ( + f"Expected all scores {expected_all_scores}, got {sorted_scores}" + ) + assert list(sorted_filenames) == expected_all_filenames, ( + f"Expected all filenames {expected_all_filenames}, got {list(sorted_filenames)}" + ) logger.info("✅ All predictions file correctly contains data from both batches") @@ -1022,8 +1028,8 @@ def test_image_directory_processing(test_config, mixed_format_images): _, directory_path = mixed_format_images # Create a file list with all image paths in the directory - from pathlib import Path import tempfile + from pathlib import Path # Create a temporary file list with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as file_list: @@ -1065,9 +1071,10 @@ def test_image_directory_processing(test_config, mixed_format_images): def test_prediction_file_type_image(test_config, monkeypatch, mixed_format_images): """Test that the correct prediction process is called for the 'image' file type.""" import os + import subprocess # Create a directory with mixed format images import sys + from anomaly_match.pipeline.session import Session - import subprocess # Create a directory with mixed format images image_paths, directory_path = mixed_format_images @@ -1147,7 +1154,7 @@ def mock_run_pipeline(self, temp_config_path, input_path, top_N): with open(group_file, "r") as f: group_content = f.read().strip().split("\n") assert set(group_content) == set(image_paths), ( - f"Wrong content in group file: {group_content}, " f"expected {image_paths}" + f"Wrong content in group file: {group_content}, expected {image_paths}" ) # Verify that the file list exists and points to the group file @@ -1155,7 +1162,7 @@ def mock_run_pipeline(self, 
temp_config_path, input_path, top_N): with open(temp_file_list, "r") as f: content = f.read().strip() assert content == group_file, ( - f"Wrong content in file list: '{content}', " f"expected '{group_file}'" + f"Wrong content in file list: '{content}', expected '{group_file}'" ) finally: @@ -1208,14 +1215,14 @@ def test_image_channel_order_rgb(test_config, tmp_path): f"HDF5: Red channel not highest. R={np.mean(loaded_hdf5[:, :, 0])}, " f"G={np.mean(loaded_hdf5[:, :, 1])}, B={np.mean(loaded_hdf5[:, :, 2])}" ) - assert ( - np.mean(loaded_hdf5[:, :, 0]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance - ), f"HDF5: Red channel not highest vs blue. R={np.mean(loaded_hdf5[:, :, 0])}, B={np.mean(loaded_hdf5[:, :, 2])}" + assert np.mean(loaded_hdf5[:, :, 0]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance, ( + f"HDF5: Red channel not highest vs blue. R={np.mean(loaded_hdf5[:, :, 0])}, B={np.mean(loaded_hdf5[:, :, 2])}" + ) # Check green channel (should be middle) - assert ( - np.mean(loaded_hdf5[:, :, 1]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance - ), f"HDF5: Green channel not higher than blue. G={np.mean(loaded_hdf5[:, :, 1])}, B={np.mean(loaded_hdf5[:, :, 2])}" + assert np.mean(loaded_hdf5[:, :, 1]) > np.mean(loaded_hdf5[:, :, 2]) + tolerance, ( + f"HDF5: Green channel not higher than blue. 
G={np.mean(loaded_hdf5[:, :, 1])}, B={np.mean(loaded_hdf5[:, :, 2])}" + ) logger.info( f"HDF5 RGB values: R={np.mean(loaded_hdf5[:, :, 0]):.1f}, " @@ -1226,9 +1233,10 @@ def test_image_channel_order_rgb(test_config, tmp_path): def test_prediction_file_type_zarr(test_config, monkeypatch, test_zarr): """Test that the correct prediction process is called for the 'zarr' file type.""" import os + import subprocess import sys + from anomaly_match.pipeline.session import Session - import subprocess # Create config based on default from anomaly_match.utils.get_default_cfg import get_default_cfg @@ -1289,9 +1297,9 @@ def mock_run_pipeline(self, temp_config_path, input_path, top_N): assert script_path == "prediction_process_zarr.py", f"Wrong script called: {script_path}" # Verify the zarr file path is passed correctly - assert ( - called_processes[0][3] == test_zarr - ), f"Wrong zarr file path: {called_processes[0][3]}" + assert called_processes[0][3] == test_zarr, ( + f"Wrong zarr file path: {called_processes[0][3]}" + ) finally: # Clean up any temporary files @@ -1301,7 +1309,7 @@ def mock_run_pipeline(self, temp_config_path, input_path, top_N): def test_zarr_image_processing_consistency(test_config, test_zarr): """Test that zarr image processing produces consistent results with standard methods.""" - from prediction_process_zarr import read_and_preprocess_image_from_zarr + from prediction_utils import read_and_preprocess_image_from_zarr # Open the zarr file and get a sample image root = zarr.open_group(test_zarr, mode="r") @@ -1566,9 +1574,9 @@ def test_zarr_fallback_filenames_have_prefix(tmp_path, test_config): for batch_idx, filenames in enumerate(batch_filenames): sample_filename = filenames[0] # Should have format: __image_000000 - assert ( - "__image_" in sample_filename - ), f"Batch {batch_idx} fallback filename doesn't have expected format. 
Got: {sample_filename}" + assert "__image_" in sample_filename, ( + f"Batch {batch_idx} fallback filename doesn't have expected format. Got: {sample_filename}" + ) # Verify no collision between batches set_0 = set(batch_filenames[0]) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..3a800e5 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,7 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Shared fixtures for integration tests.""" diff --git a/tests/test_fitsbolt_config_persistence.py b/tests/integration/test_fitsbolt_config_persistence.py similarity index 99% rename from tests/test_fitsbolt_config_persistence.py rename to tests/integration/test_fitsbolt_config_persistence.py index 179b54f..ee58584 100644 --- a/tests/test_fitsbolt_config_persistence.py +++ b/tests/integration/test_fitsbolt_config_persistence.py @@ -18,7 +18,8 @@ import numpy as np import torch from dotmap import DotMap -from fitsbolt.cfg.create_config import create_config as fb_create_cfg, validate_config +from fitsbolt.cfg.create_config import create_config as fb_create_cfg +from fitsbolt.cfg.create_config import validate_config from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from anomaly_match.data_io.load_images import get_fitsbolt_config diff --git a/tests/test_model_io_integration.py b/tests/integration/test_model_io_integration.py similarity index 80% rename from tests/test_model_io_integration.py rename to tests/integration/test_model_io_integration.py index 6c7f6a6..eb8e0e3 100644 --- a/tests/test_model_io_integration.py +++ b/tests/integration/test_model_io_integration.py @@ -5,16 +5,19 @@ # 
this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import tempfile import shutil +import tempfile from pathlib import Path + +import pytest import torch import torch.nn as nn from dotmap import DotMap +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from anomaly_match.data_io.SessionIOHandler import SessionIOHandler from anomaly_match.pipeline.SessionTracker import SessionTracker -from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from anomaly_match.utils.get_net_builder import get_net_builder class MockModel(nn.Module): @@ -165,3 +168,36 @@ def test_load_model_checkpoint_nonexistent(self): checkpoint = self.session_io.load_model_checkpoint(str(self.temp_dir / "nonexistent.pkl")) assert checkpoint is None + + +TEST_MODEL_PATH = Path(__file__).parent.parent / "test_data" / "test_model.pth" + + +@pytest.mark.skipif(not TEST_MODEL_PATH.exists(), reason="test_model.pth not available") +class TestStoredModelLoading: + """Regression tests for loading the stored test_model.pth checkpoint. + + These tests verify that the checked-in test model remains compatible + with the current model architecture (timm-based EfficientNet). + """ + + def test_stored_model_has_expected_keys(self): + """Verify the stored checkpoint contains expected top-level keys.""" + checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu") + + assert "eval_model" in checkpoint, ( + f"Checkpoint missing 'eval_model' key. 
Found: {list(checkpoint.keys())}" + ) + assert "train_model" in checkpoint + + def test_stored_model_loads_into_efficientnet_lite0(self): + """Verify stored model state_dict is compatible with the current architecture.""" + checkpoint = torch.load(str(TEST_MODEL_PATH), weights_only=False, map_location="cpu") + + net_builder = get_net_builder("efficientnet-lite0", pretrained=False, in_channels=3) + model = net_builder(num_classes=2, in_channels=3) + + # This will raise RuntimeError if keys don't match (the exact regression + # that would occur if the model was saved with a different architecture) + model.load_state_dict(checkpoint["eval_model"]) + model.load_state_dict(checkpoint["train_model"]) diff --git a/tests/integration/test_multispectral_training.py b/tests/integration/test_multispectral_training.py new file mode 100644 index 0000000..b5dba50 --- /dev/null +++ b/tests/integration/test_multispectral_training.py @@ -0,0 +1,65 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Integration tests for multispectral (N-channel) model training.""" + +import pytest + +from anomaly_match.datasets.SSL_Dataset import SSL_Dataset + + +@pytest.mark.slow +class TestMultispectralTraining: + """Tests for multispectral training pipeline.""" + + def test_model_training_4_channels(self, multispectral_config): + """Test that FixMatch can be set up with 4-channel data.""" + from anomaly_match.models.FixMatch import FixMatch + from anomaly_match.utils.get_net_builder import get_net_builder + + # Load datasets + ssl_dataset = SSL_Dataset(cfg=multispectral_config, train=True) + labeled_dset, unlabeled_dset = ssl_dataset.get_ssl_dset() + + # Verify datasets have correct channel count + assert labeled_dset.num_channels == 4 + assert unlabeled_dset.num_channels == 4 + + # Build network + net_builder = get_net_builder( + multispectral_config.net, + pretrained=multispectral_config.pretrained, + in_channels=4, + ) + + # Create model + model = FixMatch( + net_builder=net_builder, + num_classes=2, + in_channels=4, + ema_m=multispectral_config.ema_m, + T=multispectral_config.temperature, + p_cutoff=multispectral_config.p_cutoff, + lambda_u=multispectral_config.ulb_loss_ratio, + ) + + # Set up data loaders + model.set_data_loader( + cfg=multispectral_config, + lb_dset=labeled_dset, + ulb_dset=unlabeled_dset, + eval_dset=None, + ) + + # Verify model is set up correctly for 4-channel input + assert model.train_model is not None + assert model.eval_model is not None + # The first conv layer should accept 4 channels + # TestCNN uses _conv_stem, timm models use conv_stem + train_model = model.train_model + first_conv = getattr(train_model, "conv_stem", getattr(train_model, "_conv_stem", None)) + assert first_conv is not None + assert first_conv.in_channels == 4 diff --git a/tests/test_run_label_migration.py b/tests/integration/test_run_label_migration.py similarity index 99% rename from tests/test_run_label_migration.py rename to 
tests/integration/test_run_label_migration.py index dfc2fce..a484cf7 100644 --- a/tests/test_run_label_migration.py +++ b/tests/integration/test_run_label_migration.py @@ -5,12 +5,13 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import tempfile import os +import tempfile +from unittest.mock import Mock + +import pandas as pd import pytest import torch -import pandas as pd -from unittest.mock import Mock from anomaly_match.data_io.SessionIOHandler import SessionIOHandler from anomaly_match.pipeline.SessionTracker import SessionTracker diff --git a/tests/session_test.py b/tests/integration/test_session.py similarity index 98% rename from tests/session_test.py rename to tests/integration/test_session.py index 624a6cc..e267371 100644 --- a/tests/session_test.py +++ b/tests/integration/test_session.py @@ -4,12 +4,16 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import pytest import os + +import ipywidgets as widgets import numpy as np import pandas as pd +import pytest import torch -import ipywidgets as widgets + +pytestmark = pytest.mark.slow + import anomaly_match as am from anomaly_match.pipeline.session import Session @@ -27,7 +31,10 @@ def base_config(): cfg.data_dir = "tests/test_data/" cfg.normalisation.image_size = [64, 64] cfg.normalisation.n_output_channels = 3 + cfg.net = "test-cnn" + cfg.pretrained = False cfg.num_train_iter = 2 + cfg.num_workers = 0 cfg.test_ratio = 0.5 cfg.output_dir = "tests/test_output" return cfg, out @@ -212,9 +219,9 @@ def test_load_top_files(trained_session): # Verify that the transpose from CHW to HWC was applied correctly # Convert the original CHW back to HWC for comparison expected_images_hwc = test_images_chw.transpose(0, 2, 3, 1) - assert np.array_equal( - session.img_catalog, expected_images_hwc - ), "CHW to HWC conversion should be correct" + assert np.array_equal(session.img_catalog, expected_images_hwc), ( + "CHW to HWC conversion should be correct" + ) # Test loading with images already in HWC format test_images_hwc = np.random.randint(0, 255, (top_N, 64, 64, 3), dtype=np.uint8) @@ -235,9 +242,9 @@ def test_load_top_files(trained_session): # Verify that float images were converted to uint8 assert session.img_catalog.dtype == np.uint8, "Float images should be converted to uint8" expected_uint8 = (test_images_float * 255.0).clip(0, 255).astype(np.uint8) - assert np.array_equal( - session.img_catalog, expected_uint8 - ), "Float to uint8 conversion should be correct" + assert np.array_equal(session.img_catalog, expected_uint8), ( + "Float to uint8 conversion should be correct" + ) # Clean up test files if os.path.exists(output_csv_path): @@ -705,9 +712,9 @@ def test_iteration_scores_saved_after_training(base_config): # Verify score mapping: check a few samples match between CSV and session for idx, (filename, score) in enumerate(zip(session.filenames[:5], session.scores[:5])): 
csv_score = unlabelled_df[unlabelled_df["filename"] == filename]["score"].values[0] - assert ( - abs(csv_score - score) < 1e-6 - ), f"Score mismatch for {filename}: {csv_score} vs {score}" + assert abs(csv_score - score) < 1e-6, ( + f"Score mismatch for {filename}: {csv_score} vs {score}" + ) # If test_ratio > 0, verify test scores were also saved if cfg.test_ratio > 0: diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_2021_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_2021_4ch.tiff new file mode 100644 index 0000000..6699872 Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_2021_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_2172_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_2172_4ch.tiff new file mode 100644 index 0000000..39a3aff Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_2172_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_2_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_2_4ch.tiff new file mode 100644 index 0000000..e276931 Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_2_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_3_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_3_4ch.tiff new file mode 100644 index 0000000..742c851 Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_3_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_4989_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_4989_4ch.tiff new file mode 100644 index 0000000..6287179 Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_4989_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_5260_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_5260_4ch.tiff new file mode 100644 index 0000000..553cc0a Binary files /dev/null and 
b/tests/test_data/multispectral_4ch/Abell2390_VIS_5260_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_5783_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_5783_4ch.tiff new file mode 100644 index 0000000..405afdd Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_5783_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_5_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_5_4ch.tiff new file mode 100644 index 0000000..a3b90a0 Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_5_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_6212_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_6212_4ch.tiff new file mode 100644 index 0000000..d5cd0da Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_6212_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/Abell2390_VIS_6_4ch.tiff b/tests/test_data/multispectral_4ch/Abell2390_VIS_6_4ch.tiff new file mode 100644 index 0000000..d7a2d4f Binary files /dev/null and b/tests/test_data/multispectral_4ch/Abell2390_VIS_6_4ch.tiff differ diff --git a/tests/test_data/multispectral_4ch/labeled_data.csv b/tests/test_data/multispectral_4ch/labeled_data.csv new file mode 100644 index 0000000..ce6e7f2 --- /dev/null +++ b/tests/test_data/multispectral_4ch/labeled_data.csv @@ -0,0 +1,7 @@ +filename,label +Abell2390_VIS_2_4ch.tiff,anomaly +Abell2390_VIS_3_4ch.tiff,anomaly +Abell2390_VIS_5_4ch.tiff,anomaly +Abell2390_VIS_6_4ch.tiff,normal +Abell2390_VIS_2021_4ch.tiff,normal +Abell2390_VIS_2172_4ch.tiff,normal diff --git a/tests/test_data/multispectral_4ch/metadata.csv b/tests/test_data/multispectral_4ch/metadata.csv new file mode 100644 index 0000000..60d78e4 --- /dev/null +++ b/tests/test_data/multispectral_4ch/metadata.csv @@ -0,0 +1,11 @@ +filename,sourceID,ra,dec,custom_metadata +Abell2390_VIS_2_4ch.tiff,MS_00000,328.4034,17.695,4ch test source 0 
+Abell2390_VIS_3_4ch.tiff,MS_00001,328.4054,17.696,4ch test source 1 +Abell2390_VIS_5_4ch.tiff,MS_00002,328.4074,17.697,4ch test source 2 +Abell2390_VIS_6_4ch.tiff,MS_00003,328.40939999999995,17.698,4ch test source 3 +Abell2390_VIS_2021_4ch.tiff,MS_00004,328.41139999999996,17.699,4ch test source 4 +Abell2390_VIS_2172_4ch.tiff,MS_00005,328.41339999999997,17.7,4ch test source 5 +Abell2390_VIS_4989_4ch.tiff,MS_00006,328.4154,17.701,4ch test source 6 +Abell2390_VIS_5260_4ch.tiff,MS_00007,328.4174,17.702,4ch test source 7 +Abell2390_VIS_5783_4ch.tiff,MS_00008,328.4194,17.703,4ch test source 8 +Abell2390_VIS_6212_4ch.tiff,MS_00009,328.42139999999995,17.704,4ch test source 9 diff --git a/tests/test_data/normalisation_consistency/EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits b/tests/test_data/normalisation_consistency/EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits new file mode 100644 index 0000000..79179f6 Binary files /dev/null and b/tests/test_data/normalisation_consistency/EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits differ diff --git a/tests/test_data/normalisation_consistency/mock_sources.csv b/tests/test_data/normalisation_consistency/mock_sources.csv new file mode 100644 index 0000000..2892a73 --- /dev/null +++ b/tests/test_data/normalisation_consistency/mock_sources.csv @@ -0,0 +1,6 @@ +SourceID,RA,Dec,diameter_pixel,fits_file_paths +MockSource_102018211000001,150.1406475476454,2.336233868960458,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits'] +MockSource_102018211000002,150.13846652223168,2.3426457870360125,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits'] 
+MockSource_102018211000003,150.13720611369422,2.3416953968595355,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits'] +MockSource_102018211000004,150.13954247363657,2.3378254467880653,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits'] +MockSource_102018211000005,150.13806143453291,2.3366348110534423,150,['C:\\Arbeit\\Code\\AnomalyMatch\\tests\\test_data\\normalisation_consistency\\EUC_MER_BGSUB-MOSAIC-VIS_TILE102018211-ACBD03_20251124T100053.096Z_00.00.fits'] diff --git a/tests/test_data/test_model.pth b/tests/test_data/test_model.pth index ffd70f0..1ffb3f4 100644 Binary files a/tests/test_data/test_model.pth and b/tests/test_data/test_model.pth differ diff --git a/tests/ui/conftest.py b/tests/ui/conftest.py new file mode 100644 index 0000000..93db83a --- /dev/null +++ b/tests/ui/conftest.py @@ -0,0 +1,7 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Shared fixtures for UI tests.""" diff --git a/tests/ui/test_memory_monitor.py b/tests/ui/test_memory_monitor.py index 800eacc..17ab3bb 100644 --- a/tests/ui/test_memory_monitor.py +++ b/tests/ui/test_memory_monitor.py @@ -4,9 +4,13 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import pytest import asyncio -from anomaly_match.ui.memory_monitor import MemoryMonitor + +import pytest + +from anomaly_match_ui.memory_monitor import MemoryMonitor + +pytestmark = pytest.mark.ui @pytest.mark.asyncio diff --git a/tests/ui_test.py b/tests/ui/test_widget.py similarity index 62% rename from tests/ui_test.py rename to tests/ui/test_widget.py index bbd3070..7a60843 100644 --- a/tests/ui_test.py +++ b/tests/ui/test_widget.py @@ -4,71 +4,23 @@ # is part of this source code package. No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import pytest -import ipywidgets as widgets -from anomaly_match.ui.Widget import Widget, shorten_filename -from anomaly_match.pipeline.session import Session -import anomaly_match as am import os -import numpy as np from unittest.mock import patch + +import ipywidgets as widgets import matplotlib +import numpy as np +import pytest from PIL import Image -matplotlib.use("Agg") # Prevent matplotlib windows from opening +import anomaly_match as am +from anomaly_match.pipeline.session import Session +from anomaly_match_ui.utils.backend_interface import BackendInterface +from anomaly_match_ui.widget import Widget +pytestmark = [pytest.mark.ui, pytest.mark.slow] -class TestShortenFilename: - """Tests for the shorten_filename helper function.""" - - def test_short_filename_unchanged(self): - """Filenames within max length should remain unchanged.""" - assert shorten_filename("short.fits", max_length=25) == "short.fits" - assert shorten_filename("image.jpg", max_length=25) == "image.jpg" - - def test_long_filename_shortened(self): - """Long filenames should be shortened to max_length.""" - long_name = "very_long_filename_that_exceeds_limit.fits" - result = shorten_filename(long_name, max_length=25) - assert len(result) <= 25 - assert result.endswith(".fits") - assert "..." 
in result - - def test_filename_with_multiple_dots(self): - """Filenames with multiple dots should preserve only the extension.""" - name = "image.2024.01.15.observation.fits" - result = shorten_filename(name, max_length=25) - assert len(result) <= 25 - assert result.endswith(".fits") - assert "..." in result - - def test_filename_without_extension(self): - """Filenames without extension should still be shortened correctly.""" - name = "very_long_filename_without_any_extension" - result = shorten_filename(name, max_length=25) - assert len(result) <= 25 - assert "..." in result - - def test_exact_max_length(self): - """Filename exactly at max_length should be unchanged.""" - name = "exactly_25_chars_long.fit" - assert len(name) == 25 - assert shorten_filename(name, max_length=25) == name - - def test_very_short_max_length(self): - """Very short max_length should still produce valid output.""" - name = "some_filename.fits" - result = shorten_filename(name, max_length=10) - assert len(result) <= 10 - assert "..." in result - - def test_preserves_start_and_end(self): - """Shortened name should contain parts of the original start and end.""" - name = "START_middle_content_END.fits" - result = shorten_filename(name, max_length=20) - assert result.startswith("START") - # Should contain some part of the end before the extension - assert "END" in result or "..." 
in result +matplotlib.use("Agg") # Prevent matplotlib windows from opening @pytest.fixture(scope="session") @@ -84,7 +36,10 @@ def base_config(): cfg.data_dir = "tests/test_data/" cfg.normalisation.image_size = [64, 64] cfg.normalisation.n_output_channels = 3 + cfg.net = "test-cnn" + cfg.pretrained = False cfg.num_train_iter = 2 + cfg.num_workers = 0 cfg.test_ratio = 0.5 cfg.output_dir = "tests/test_output" cfg.prediction_search_dir = "tests/test_data/" # Set a default search directory @@ -105,7 +60,9 @@ def session_fixture(base_config): @pytest.fixture(scope="session") def ui_widget(session_fixture): with patch("IPython.display.display"): # Prevent actual display calls - widget = Widget(session_fixture) + # Set up the backend interface + BackendInterface.set_session(session_fixture) + widget = Widget() yield widget # Only close widgets in teardown if they still exist try: @@ -127,13 +84,13 @@ def setup_display_mocks(): # Test classes to organize related tests class TestUIInitialization: def test_ui_initialization(self, ui_widget): - assert ui_widget.session is not None + assert BackendInterface.get_session() is not None assert isinstance(ui_widget.ui["image_widget"], widgets.Image) assert isinstance(ui_widget.ui["filename_text"], widgets.HTML) def test_normalization_dropdown(self, ui_widget): # Get the initial normalization method - initial_method = ui_widget.session.cfg.normalisation.normalisation_method + initial_method = BackendInterface.get_config().normalisation.normalisation_method # Get the dropdown options and find a different method dropdown_options = ui_widget.ui["normalisation_dropdown"].options @@ -148,7 +105,7 @@ def test_normalization_dropdown(self, ui_widget): ui_widget.ui["normalisation_dropdown"].value = new_method # Assert that the session config was updated - assert ui_widget.session.cfg.normalisation.normalisation_method == new_method + assert BackendInterface.get_config().normalisation.normalisation_method == new_method class TestUINavigation: 
@@ -167,25 +124,25 @@ def test_previous_image(self, ui_widget): class TestUISorting: def test_sort_by_anomalous(self, ui_widget): ui_widget.sort_by_anomalous() - assert ui_widget.session.scores[0] >= ui_widget.session.scores[-1] + scores = BackendInterface.get_scores() + assert scores[0] >= scores[-1] def test_sort_by_nominal(self, ui_widget): ui_widget.sort_by_nominal() - assert ui_widget.session.scores[0] <= ui_widget.session.scores[-1] + scores = BackendInterface.get_scores() + assert scores[0] <= scores[-1] def test_sort_by_mean(self, ui_widget): ui_widget.sort_by_mean() - mean_score = ui_widget.session.scores.mean() - assert abs(ui_widget.session.scores[0] - mean_score) <= abs( - ui_widget.session.scores[-1] - mean_score - ) + scores = BackendInterface.get_scores() + mean_score = scores.mean() + assert abs(scores[0] - mean_score) <= abs(scores[-1] - mean_score) def test_sort_by_median(self, ui_widget): ui_widget.sort_by_median() - median_score = np.median(ui_widget.session.scores) - assert abs(ui_widget.session.scores[0] - median_score) <= abs( - ui_widget.session.scores[-1] - median_score - ) + scores = BackendInterface.get_scores() + median_score = np.median(scores) + assert abs(scores[0] - median_score) <= abs(scores[-1] - median_score) class TestUIImageProcessing: @@ -211,46 +168,49 @@ def test_adjust_brightness_contrast(self, ui_widget): class TestUIModelOperations: def test_save_load_model(self, ui_widget): ui_widget.save_model() - assert os.path.exists(ui_widget.session.cfg.model_path) + cfg = BackendInterface.get_config() + assert os.path.exists(cfg.model_path) ui_widget.load_model() - assert ui_widget.session.model is not None + assert BackendInterface.get_model() is not None def test_train(self, ui_widget): ui_widget.train() - assert ui_widget.session.eval_performance is not None + assert BackendInterface.get_eval_performance() is not None def test_reset_model(self, ui_widget): ui_widget.reset_model() - assert ui_widget.session.model is not None + 
assert BackendInterface.get_model() is not None class TestUIBatchOperations: def test_update_batch_size(self, ui_widget): initial_batch_size = ui_widget.ui["batch_size_slider"].value ui_widget.ui["batch_size_slider"].value = initial_batch_size + 500 - assert ui_widget.session.cfg.N_to_load == initial_batch_size + 500 + cfg = BackendInterface.get_config() + assert cfg.N_to_load == initial_batch_size + 500 def test_next_batch(self, ui_widget): ui_widget.next_batch() - assert ui_widget.session.img_catalog is not None + session = BackendInterface.get_session() + assert session.img_catalog is not None def test_search_all_files(self, ui_widget): - with patch("anomaly_match.ui.Widget.display"): + with patch("anomaly_match_ui.widget.display"): # Save model first so evaluate_all_images can find it ui_widget.save_model() + cfg = BackendInterface.get_config() # Ensure test data directory exists and has files - os.makedirs(ui_widget.session.cfg.prediction_search_dir, exist_ok=True) + os.makedirs(cfg.prediction_search_dir, exist_ok=True) # Create a test image if directory is empty - if not os.listdir(ui_widget.session.cfg.prediction_search_dir): + if not os.listdir(cfg.prediction_search_dir): test_img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) - test_img_path = os.path.join( - ui_widget.session.cfg.prediction_search_dir, "test.jpg" - ) + test_img_path = os.path.join(cfg.prediction_search_dir, "test.jpg") Image.fromarray(test_img).save(test_img_path) ui_widget.search_all_files() - assert len(ui_widget.session.img_catalog) > 0 + session = BackendInterface.get_session() + assert len(session.img_catalog) > 0 def test_cleanup(ui_widget): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000..9ff1a91 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,35 @@ +# Copyright (c) European Space Agency, 2025. 
+# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Shared fixtures for unit tests.""" + +import ipywidgets as widgets +import pytest + +import anomaly_match as am + + +@pytest.fixture(scope="module") +def base_config(): + """Base configuration for unit tests (no training, no output dir).""" + out = widgets.Output( + layout=widgets.Layout( + border="1px solid white", height="400px", background_color="black", overflow="auto" + ), + ) + + cfg = am.get_default_cfg() + am.set_log_level("debug", cfg) + cfg.data_dir = "tests/test_data/" + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 + cfg.net = "test-cnn" + cfg.pretrained = False + cfg.num_train_iter = 2 + cfg.num_workers = 0 + cfg.test_ratio = 0.5 + cfg.output_dir = "tests/test_output" + return cfg, out diff --git a/tests/test_batch_size_estimation.py b/tests/unit/test_batch_size_estimation.py similarity index 98% rename from tests/test_batch_size_estimation.py rename to tests/unit/test_batch_size_estimation.py index 746f2f9..c411f66 100644 --- a/tests/test_batch_size_estimation.py +++ b/tests/unit/test_batch_size_estimation.py @@ -6,9 +6,9 @@ # the terms contained in the file 'LICENCE.txt'. 
import pytest -from prediction_utils import estimate_batch_size, MEMORY_COEFFICIENTS -from anomaly_match import get_default_cfg +from anomaly_match import get_default_cfg +from prediction_utils import MEMORY_COEFFICIENTS, estimate_batch_size FAKE_VRAM_BYTES = 16 * 1024**3 # 16GB @@ -152,7 +152,6 @@ def test_estimate_batch_size_invalid_memory(test_config): def test_estimate_batch_size_invalid_model_coefficients(test_config, monkeypatch): - import prediction_utils # Inject invalid coefficients diff --git a/tests/cfg_validation_test.py b/tests/unit/test_config_validation.py similarity index 99% rename from tests/cfg_validation_test.py rename to tests/unit/test_config_validation.py index f4b36da..03d03ef 100644 --- a/tests/cfg_validation_test.py +++ b/tests/unit/test_config_validation.py @@ -7,6 +7,7 @@ import pytest from loguru import logger + from anomaly_match.utils.get_default_cfg import get_default_cfg from anomaly_match.utils.validate_config import validate_config diff --git a/tests/unit/test_cutana_stream_utils.py b/tests/unit/test_cutana_stream_utils.py new file mode 100644 index 0000000..67c0c94 --- /dev/null +++ b/tests/unit/test_cutana_stream_utils.py @@ -0,0 +1,87 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Tests for cutana catalogue validation and source counting.""" + +import pandas as pd +import pytest + +from anomaly_match.utils.cutana_stream_utils import cutana_validate_files_and_count_sources + + +def _make_valid_df(n=50): + """Create a minimal valid cutana catalogue DataFrame.""" + return pd.DataFrame( + { + "SourceID": range(n), + "RA": [180.0] * n, + "Dec": [-30.0] * n, + "fits_file_paths": ["dummy.fits"] * n, + "diameter_pixel": [10] * n, + } + ) + + +@pytest.fixture +def valid_parquet(tmp_path): + """Write a valid parquet catalogue and return its path.""" + path = tmp_path / "valid.parquet" + _make_valid_df(100).to_parquet(path, index=False) + return path + + +@pytest.fixture +def invalid_parquet(tmp_path): + """Write a parquet file with wrong columns.""" + path = tmp_path / "invalid.parquet" + pd.DataFrame({"bad_col": [1, 2, 3]}).to_parquet(path, index=False) + return path + + +class TestCutanaValidateFilesAndCountSources: + def test_valid_parquet_accepted(self, valid_parquet): + files, total, chunks = cutana_validate_files_and_count_sources( + [valid_parquet], chunk_size=50 + ) + assert len(files) == 1 + assert total == 100 + assert chunks == 2 # 100 rows / 50 chunk_size + + def test_row_count_from_metadata(self, tmp_path): + """Row counts should match actual rows without scanning data.""" + paths = [] + for i, n in enumerate([200, 300]): + p = tmp_path / f"cat_{i}.parquet" + _make_valid_df(n).to_parquet(p, index=False) + paths.append(p) + + files, total, chunks = cutana_validate_files_and_count_sources(paths, chunk_size=100) + assert len(files) == 2 + assert total == 500 + assert chunks == 5 # 200/100 + 300/100 + + def test_invalid_columns_rejected(self, invalid_parquet): + files, total, chunks = cutana_validate_files_and_count_sources( + [invalid_parquet], chunk_size=100 + ) + assert len(files) == 0 + assert total == 0 + + def test_mixed_valid_invalid(self, valid_parquet, invalid_parquet): + """Invalid file should be skipped, valid file should be 
counted.""" + files, total, chunks = cutana_validate_files_and_count_sources( + [invalid_parquet, valid_parquet], chunk_size=100 + ) + assert len(files) == 1 + assert total == 100 + + def test_valid_csv_accepted(self, tmp_path): + path = tmp_path / "valid.csv" + _make_valid_df(75).to_csv(path, index=False) + files, total, chunks = cutana_validate_files_and_count_sources([path], chunk_size=50) + assert len(files) == 1 + assert total == 75 + assert chunks == 2 # 50 + 25 diff --git a/tests/unit/test_data_utils.py b/tests/unit/test_data_utils.py new file mode 100644 index 0000000..2aec1e1 --- /dev/null +++ b/tests/unit/test_data_utils.py @@ -0,0 +1,120 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Tests for data loading utilities.""" + +import numpy as np +import pytest +import torch +from torch.utils.data import DataLoader + +from anomaly_match.datasets.BasicDataset import BasicDataset +from anomaly_match.datasets.data_utils import get_data_loader, get_sampler_by_name +from anomaly_match.image_processing.transforms import get_prediction_transforms + + +@pytest.fixture +def simple_dataset(): + """Create a simple BasicDataset for testing.""" + imgs = np.random.randint(0, 255, (20, 64, 64, 3), dtype=np.uint8) + filenames = [f"img_{i}.jpg" for i in range(20)] + targets = [0] * 10 + [1] * 10 + transform = get_prediction_transforms(num_channels=3) + return BasicDataset(imgs, filenames, targets, num_classes=2, transform=transform) + + +class TestGetSamplerByName: + def test_random_sampler(self): + sampler = get_sampler_by_name("RandomSampler") + assert sampler is torch.utils.data.sampler.RandomSampler + + def test_sequential_sampler(self): + sampler = get_sampler_by_name("SequentialSampler") + assert sampler is torch.utils.data.sampler.SequentialSampler + + def test_invalid_sampler_raises(self): + with pytest.raises(AttributeError, match="not found"): + get_sampler_by_name("NonexistentSampler") + + +class TestGetDataLoader: + def test_requires_batch_size(self, simple_dataset): + with pytest.raises(AssertionError, match="Batch size must be specified"): + get_data_loader(simple_dataset, batch_size=None) + + def test_basic_dataloader(self, simple_dataset): + loader = get_data_loader(simple_dataset, batch_size=4, num_workers=0) + assert isinstance(loader, DataLoader) + batch = next(iter(loader)) + assert batch[0].shape[0] == 4 + + def test_shuffle_dataloader(self, simple_dataset): + loader = get_data_loader(simple_dataset, batch_size=4, shuffle=True, num_workers=0) + assert isinstance(loader, DataLoader) + + def test_weighted_sampler(self, simple_dataset): + loader = get_data_loader( + simple_dataset, + batch_size=4, + use_weighted_sampler=True, + num_workers=0, + 
) + assert isinstance(loader, DataLoader) + + def test_weighted_sampler_with_num_iters(self, simple_dataset): + loader = get_data_loader( + simple_dataset, + batch_size=4, + use_weighted_sampler=True, + num_iters=10, + num_workers=0, + ) + assert isinstance(loader, DataLoader) + + def test_weighted_sampler_with_num_epochs(self, simple_dataset): + loader = get_data_loader( + simple_dataset, + batch_size=4, + use_weighted_sampler=True, + num_epochs=2, + num_workers=0, + ) + assert isinstance(loader, DataLoader) + + def test_random_sampler_by_name(self, simple_dataset): + loader = get_data_loader( + simple_dataset, + batch_size=4, + data_sampler="RandomSampler", + num_workers=0, + ) + assert isinstance(loader, DataLoader) + + def test_unsupported_sampler_raises(self, simple_dataset): + with pytest.raises(RuntimeError, match="not fully implemented"): + get_data_loader( + simple_dataset, + batch_size=4, + data_sampler="SequentialSampler", + num_workers=0, + ) + + +class TestWeightedSamplerSingleClass: + def test_single_class_uniform_weights(self): + """Weighted sampler with only one class should use uniform weights.""" + imgs = np.random.randint(0, 255, (10, 64, 64, 3), dtype=np.uint8) + filenames = [f"img_{i}.jpg" for i in range(10)] + targets = [0] * 10 # All same class + transform = get_prediction_transforms(num_channels=3) + dataset = BasicDataset(imgs, filenames, targets, num_classes=2, transform=transform) + loader = get_data_loader( + dataset, + batch_size=4, + use_weighted_sampler=True, + num_workers=0, + ) + assert isinstance(loader, DataLoader) diff --git a/tests/dataset_test.py b/tests/unit/test_dataset.py similarity index 98% rename from tests/dataset_test.py rename to tests/unit/test_dataset.py index 1a93bbc..05dbd2b 100644 --- a/tests/dataset_test.py +++ b/tests/unit/test_dataset.py @@ -4,22 +4,23 @@ # is part of this source code package. 
No part of the package, including # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import pytest +import os +import tempfile + import numpy as np import pandas as pd +import pytest import torch +from PIL import Image + import anomaly_match as am -from anomaly_match.datasets.Label import Label from anomaly_match.datasets.AnomalyDetectionDataset import AnomalyDetectionDataset from anomaly_match.datasets.BasicDataset import BasicDataset +from anomaly_match.datasets.Label import Label from anomaly_match.datasets.SSL_Dataset import SSL_Dataset -import os -import tempfile -from PIL import Image - from anomaly_match.image_processing.transforms import ( - get_weak_transforms, get_prediction_transforms, + get_weak_transforms, ) @@ -112,9 +113,9 @@ def test_multiple_file_extensions_support(multi_extension_dataset, test_config): dataset = AnomalyDetectionDataset(test_config) # Check if all images were found - assert len(dataset.all_filenames) == len( - extensions - ), "Not all images with different extensions were found" + assert len(dataset.all_filenames) == len(extensions), ( + "Not all images with different extensions were found" + ) # Verify that all expected files are included for filename in test_images: diff --git a/tests/unit/test_display_transforms.py b/tests/unit/test_display_transforms.py new file mode 100644 index 0000000..70bdff4 --- /dev/null +++ b/tests/unit/test_display_transforms.py @@ -0,0 +1,198 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Tests for display transform functions.""" + +import numpy as np +import pytest +from PIL import Image + +from anomaly_match_ui.utils.display_transforms import ( + apply_transforms_ui, + display_image_normalisation, + prepare_for_display, +) + + +class TestPrepareForDisplay: + def test_rgb_passthrough(self): + img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = prepare_for_display(img) + assert result.shape == (64, 64, 3) + assert np.array_equal(result, img) + + def test_grayscale_to_rgb(self): + img = np.random.randint(0, 255, (64, 64, 1), dtype=np.uint8) + result = prepare_for_display(img) + assert result.shape == (64, 64, 3) + # All channels should be the same + assert np.array_equal(result[:, :, 0], result[:, :, 1]) + assert np.array_equal(result[:, :, 1], result[:, :, 2]) + + def test_2d_grayscale(self): + img = np.random.randint(0, 255, (64, 64), dtype=np.uint8) + result = prepare_for_display(img) + assert result.shape == (64, 64, 3) + + def test_2_channel_to_rgb(self): + img = np.random.randint(0, 255, (64, 64, 2), dtype=np.uint8) + result = prepare_for_display(img) + assert result.shape == (64, 64, 3) + + def test_4_channel_default_mapping(self): + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + result = prepare_for_display(img) + assert result.shape == (64, 64, 3) + # Default mapping uses first 3 channels + assert np.array_equal(result, img[:, :, :3]) + + def test_4_channel_custom_mapping(self): + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + result = prepare_for_display(img, rgb_mapping=[1, 2, 3]) + assert result.shape == (64, 64, 3) + assert np.array_equal(result[:, :, 0], img[:, :, 1]) + assert np.array_equal(result[:, :, 1], img[:, :, 2]) + assert np.array_equal(result[:, :, 2], img[:, :, 3]) + + def test_invalid_rgb_mapping_length(self): + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + with pytest.raises(ValueError, match="must have 3 elements"): + prepare_for_display(img, 
rgb_mapping=[0, 1]) + + def test_invalid_rgb_mapping_index(self): + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + with pytest.raises(ValueError, match="exceed channel count"): + prepare_for_display(img, rgb_mapping=[0, 1, 5]) + + def test_float_to_uint8_conversion(self): + img = np.random.random((64, 64, 3)).astype(np.float32) + result = prepare_for_display(img) + assert result.dtype == np.uint8 + assert result.shape == (64, 64, 3) + + def test_pil_image_input(self): + pil_img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)) + result = prepare_for_display(pil_img) + assert result.shape == (64, 64, 3) + assert result.dtype == np.uint8 + + def test_invalid_input_type(self): + with pytest.raises(ValueError, match="Expected numpy array or PIL Image"): + prepare_for_display("not_an_image") + + +class TestDisplayImageNormalisation: + def test_basic_normalisation(self): + img = np.random.random((64, 64, 3)).astype(np.float64) * 255 + result = display_image_normalisation(img) + assert isinstance(result, Image.Image) + + def test_constant_image(self): + img = np.full((64, 64, 3), 0.5, dtype=np.float64) + result = display_image_normalisation(img) + assert isinstance(result, Image.Image) + + def test_handles_nan(self): + img = np.random.random((64, 64, 3)) + img[10, 10, 0] = np.nan + result = display_image_normalisation(img) + assert isinstance(result, Image.Image) + + def test_handles_inf(self): + img = np.random.random((64, 64, 3)) + img[10, 10, 0] = np.inf + result = display_image_normalisation(img) + assert isinstance(result, Image.Image) + + +class TestApplyTransformsUI: + @pytest.fixture + def sample_pil_image(self): + return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)) + + def test_no_transforms(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=False, + brightness=1.0, + contrast=1.0, + unsharp_mask_applied=False, + ) + assert isinstance(result, Image.Image) + + def 
test_invert(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=True, + brightness=1.0, + contrast=1.0, + unsharp_mask_applied=False, + ) + assert isinstance(result, Image.Image) + + def test_brightness(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=False, + brightness=1.5, + contrast=1.0, + unsharp_mask_applied=False, + ) + assert isinstance(result, Image.Image) + + def test_contrast(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=False, + brightness=1.0, + contrast=1.5, + unsharp_mask_applied=False, + ) + assert isinstance(result, Image.Image) + + def test_unsharp_mask(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=False, + brightness=1.0, + contrast=1.0, + unsharp_mask_applied=True, + ) + assert isinstance(result, Image.Image) + + def test_channel_toggling_hide_red(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=False, + brightness=1.0, + contrast=1.0, + unsharp_mask_applied=False, + show_r=False, + ) + result_array = np.array(result) + assert np.all(result_array[:, :, 0] == 0) + + def test_channel_visibility_list(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=False, + brightness=1.0, + contrast=1.0, + unsharp_mask_applied=False, + channel_visibility=[True, False, True], + ) + result_array = np.array(result) + assert np.all(result_array[:, :, 1] == 0) + + def test_all_transforms_combined(self, sample_pil_image): + result = apply_transforms_ui( + sample_pil_image, + invert=True, + brightness=1.2, + contrast=0.8, + unsharp_mask_applied=True, + ) + assert isinstance(result, Image.Image) diff --git a/tests/file_io_test.py b/tests/unit/test_file_io.py similarity index 90% rename from tests/file_io_test.py rename to tests/unit/test_file_io.py index d90df1d..c46a01a 100644 --- a/tests/file_io_test.py +++ b/tests/unit/test_file_io.py @@ -7,23 
+7,25 @@ """ Tests for the image IO utility functions. """ + import os -import numpy as np -import pytest import shutil import tempfile -from PIL import Image + +import numpy as np +import pytest from astropy.io import fits +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod +from PIL import Image from anomaly_match.data_io.find_images_in_folder import ( get_image_names_from_folder, get_image_paths_from_folder, ) from anomaly_match.data_io.load_images import ( - load_and_process_wrapper, load_and_process_single_wrapper, + load_and_process_wrapper, ) -from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod def _load_image_with_fitsbolt(filepath, cfg): @@ -305,9 +307,9 @@ def test_load_and_process_images_rgba(self, test_config): transparent_img = _load_image_with_fitsbolt(self.transparent_path, cfg=test_config) assert transparent_img.shape[2] == 3 # Should still be RGB # The green channel should still be present even though the pixels were transparent - assert np.any( - transparent_img[:, :, 1] > 0 - ), "Green channel data lost in transparent image" # Test complex RGBA with gradient alpha + assert np.any(transparent_img[:, :, 1] > 0), ( + "Green channel data lost in transparent image" + ) # Test complex RGBA with gradient alpha complex_rgba_img = _load_image_with_fitsbolt(self.complex_rgba_path, cfg=test_config) assert complex_rgba_img.shape[2] == 3 # Should be RGB # Check that gradients are preserved in RGB channels @@ -352,24 +354,24 @@ def test_rgba_to_rgb_conversion_values(self, test_config): # Test that colors are preserved correctly according to alpha values # Colors with full alpha should be preserved exactly - assert np.all( - rgb_img[: height // 2, : width // 2, 0] == 255 - ), "Red with full alpha should be preserved" - assert np.all( - rgb_img[: height // 2, : width // 2, 1] == 0 - ), "Red channel should have no green" - assert np.all( - rgb_img[: height // 2, : width // 2, 2] == 0 - ), "Red channel should have no 
blue" + assert np.all(rgb_img[: height // 2, : width // 2, 0] == 255), ( + "Red with full alpha should be preserved" + ) + assert np.all(rgb_img[: height // 2, : width // 2, 1] == 0), ( + "Red channel should have no green" + ) + assert np.all(rgb_img[: height // 2, : width // 2, 2] == 0), ( + "Red channel should have no blue" + ) # Colors with partial alpha might be handled differently depending on implementation # We'll just check that they exist and aren't black - assert ( - np.mean(rgb_img[: height // 2, width // 2 :, 1]) > 0 - ), "Green with half alpha should be visible" - assert ( - np.mean(rgb_img[height // 2 :, : width // 2, 2]) > 0 - ), "Blue with quarter alpha should be visible" + assert np.mean(rgb_img[: height // 2, width // 2 :, 1]) > 0, ( + "Green with half alpha should be visible" + ) + assert np.mean(rgb_img[height // 2 :, : width // 2, 2]) > 0, ( + "Blue with quarter alpha should be visible" + ) # The behavior with fully transparent pixels can vary depending on the implementation # Some libraries preserve the color values even with zero alpha, others apply a background color # Instead of asserting they shouldn't be white, we'll just check that the image was loaded successfully @@ -467,9 +469,9 @@ def test_load_image_with_fitsbolt_fits(self, test_config): # Check if normalization preserved the pattern (bright area should be brighter than dark area) bright_area = extreme_img[50:80, 50:80] dark_area = extreme_img[10:40, 10:40] - assert np.mean(bright_area) > np.mean( - dark_area - ), "Normalization failed to preserve contrast" + assert np.mean(bright_area) > np.mean(dark_area), ( + "Normalization failed to preserve contrast" + ) def test_fits_extension_parameter(self, test_config): """Test the fits_extension parameter for FITS files.""" @@ -634,14 +636,14 @@ def test_fits_multiple_extensions(self, test_config): assert combined_img3.shape == (50, 50, 3), "Combined image should have shape (50, 50, 3)" # The images should be different due to different 
ordering - assert not np.array_equal( - combined_img1, combined_img3 - ), "Different extension order should give different results" + assert not np.array_equal(combined_img1, combined_img3), ( + "Different extension order should give different results" + ) # Int indices and string names with same order should give same result - assert np.array_equal( - combined_img1, combined_img2 - ), "Same extension order should give same results" + assert np.array_equal(combined_img1, combined_img2), ( + "Same extension order should give same results" + ) # Create another FITS file with extensions of different shapes to test error handling diff_shapes_path = os.path.join(self.test_dir, "diff_shapes.fits") @@ -668,9 +670,9 @@ def test_fits_multiple_extensions(self, test_config): # Validate the error message contains information about channel mismatch error_message = str(e_info.value) - assert ( - "channel" in error_message.lower() or "extension" in error_message.lower() - ), f"Error should mention channel or extension mismatch: {error_message}" + assert "channel" in error_message.lower() or "extension" in error_message.lower(), ( + f"Error should mention channel or extension mismatch: {error_message}" + ) # could expand test if needed with more extensions def test_load_and_process_images_fits_extension(self, test_config): @@ -759,12 +761,12 @@ def test_load_and_process_images_fits_extension(self, test_config): img_combined = results_combined[0][1] # The images should have different content because they used different extensions - assert not np.array_equal( - img_ext0, img_ext1 - ), "Different extensions should produce different images" - assert not np.array_equal( - img_ext0, img_combined - ), "Combined extensions should differ from single extension" + assert not np.array_equal(img_ext0, img_ext1), ( + "Different extensions should produce different images" + ) + assert not np.array_equal(img_ext0, img_combined), ( + "Combined extensions should differ from single extension" + ) # 
Also test with string extension names _update_config(test_config, fits_extension=["PRIMARY", "EXT1", "EXT2"]) @@ -774,9 +776,9 @@ def test_load_and_process_images_fits_extension(self, test_config): img_named = results_named[0][1] # Should be identical to using numeric indices [0, 1, 2] - assert np.array_equal( - img_combined, img_named - ), "String extension names should produce same result as numeric indices" + assert np.array_equal(img_combined, img_named), ( + "String extension names should produce same result as numeric indices" + ) def test_image_normalisation(self, test_config): """Test that different normalisation methods are correctly applied during image loading.""" @@ -797,9 +799,9 @@ def test_image_normalisation(self, test_config): # Test with no normalisation (default) _update_config(test_config, normalisation_method=NormalisationMethod.CONVERSION_ONLY) img_none = _load_image_with_fitsbolt(test_path, cfg=test_config) - assert np.array_equal( - img_none, test_values - ), "NONE normalisation should preserve original values" + assert np.array_equal(img_none, test_values), ( + "NONE normalisation should preserve original values" + ) # Test with LOG normalisation _update_config(test_config, normalisation_method=NormalisationMethod.LOG) @@ -815,9 +817,9 @@ def test_image_normalisation(self, test_config): # Test with ZSCALE normalisation _update_config(test_config, normalisation_method=NormalisationMethod.ZSCALE) img_zscale = _load_image_with_fitsbolt(test_path, cfg=test_config) - assert not np.array_equal( - img_zscale, test_values - ), "ZSCALE normalisation should modify values" + assert not np.array_equal(img_zscale, test_values), ( + "ZSCALE normalisation should modify values" + ) # ZScale should produce values with reasonable contrast assert np.min(img_zscale) < np.max(img_zscale), "ZSCALE should preserve contrast" @@ -829,9 +831,9 @@ def test_image_normalisation(self, test_config): # Test that all normalised outputs preserve image dimensions assert 
img_none.shape == test_values.shape, "NONE normalisation should preserve dimensions" assert img_log.shape == test_values.shape, "LOG normalisation should preserve dimensions" - assert ( - img_zscale.shape == test_values.shape - ), "ZSCALE normalisation should preserve dimensions" + assert img_zscale.shape == test_values.shape, ( + "ZSCALE normalisation should preserve dimensions" + ) def test_image_interpolation_orders(self, test_config): """Test different interpolation orders when resizing images. @@ -892,9 +894,9 @@ def test_image_interpolation_orders(self, test_config): 100, 3, ), f"Resized small image should be 100x100 with order {order}" - assert ( - resized_small.dtype == np.uint8 - ), f"Resized small image should be uint8 with order {order}" + assert resized_small.dtype == np.uint8, ( + f"Resized small image should be uint8 with order {order}" + ) # Resize large image (200x200 → 100x100) - downsampling resized_large = _load_image_with_fitsbolt(large_path, cfg=test_config) @@ -903,9 +905,9 @@ def test_image_interpolation_orders(self, test_config): 100, 3, ), f"Resized large image should be 100x100 with order {order}" - assert ( - resized_large.dtype == np.uint8 - ), f"Resized large image should be uint8 with order {order}" + assert resized_large.dtype == np.uint8, ( + f"Resized large image should be uint8 with order {order}" + ) # Check the center pixel of each quadrant for both resized images # Allow for some variation (±20%) in color values due to interpolation differences @@ -931,9 +933,9 @@ def test_image_interpolation_orders(self, test_config): ) else: # For zero values, small absolute threshold - assert ( - small_color[c] <= 50 - ), f"Small image order {order}, quadrant {idx}, channel {c}: expected ~0, got {small_color[c]}" + assert small_color[c] <= 50, ( + f"Small image order {order}, quadrant {idx}, channel {c}: expected ~0, got {small_color[c]}" + ) # Check large image downsampled large_color = resized_large[y, x] @@ -948,9 +950,9 @@ def 
test_image_interpolation_orders(self, test_config): ) else: # For zero values, small absolute threshold - assert ( - large_color[c] <= 50 - ), f"Large image order {order}, quadrant {idx}, channel {c}: expected ~0, got {large_color[c]}" + assert large_color[c] <= 50, ( + f"Large image order {order}, quadrant {idx}, channel {c}: expected ~0, got {large_color[c]}" + ) # Additional check for sharp transitions with order 0 (nearest neighbor) if order == 0: @@ -962,16 +964,16 @@ def test_image_interpolation_orders(self, test_config): # For small image left_of_boundary_small = resized_small[boundary_y, boundary_x - 1] right_of_boundary_small = resized_small[boundary_y, boundary_x + 1] - assert not np.array_equal( - left_of_boundary_small, right_of_boundary_small - ), "Small image order 0 should have sharp transitions at boundary" + assert not np.array_equal(left_of_boundary_small, right_of_boundary_small), ( + "Small image order 0 should have sharp transitions at boundary" + ) # For large image left_of_boundary_large = resized_large[boundary_y, boundary_x - 1] right_of_boundary_large = resized_large[boundary_y, boundary_x + 1] - assert not np.array_equal( - left_of_boundary_large, right_of_boundary_large - ), "Large image order 0 should have sharp transitions at boundary" + assert not np.array_equal(left_of_boundary_large, right_of_boundary_large), ( + "Large image order 0 should have sharp transitions at boundary" + ) # Higher order interpolation (order > 1) should lead to smoother transitions # This is difficult to quantify precisely, but we can check for values between the extremes for upscaling @@ -986,21 +988,21 @@ def test_image_interpolation_orders(self, test_config): unique_values_small = np.unique(boundary_region_small) # Higher order interpolation should have more unique values in the boundary region when upscaling - assert ( - len(unique_values_small) > 4 - ), f"Small image order {order} should have intermediate values at boundaries" + assert 
len(unique_values_small) > 4, ( + f"Small image order {order} should have intermediate values at boundaries" + ) # Compare results between different interpolation orders to verify they're not identical # We'll compare order 0 (nearest neighbor) with orders 1, 3, and 5 # These should produce visibly different results for i, upscaled_im in enumerate(upscaled_results): if i != 0: - assert not np.array_equal( - upscaled_results[0], upscaled_results[i] - ), "Order 0 and order {i} interpolation should produce different results" - assert not np.array_equal( - upscaled_results[i - 1], upscaled_results[i] - ), f"Order {i - 1} and order {i} interpolation should produce different results" + assert not np.array_equal(upscaled_results[0], upscaled_results[i]), ( + f"Order 0 and order {i} interpolation should produce different results" + ) + assert not np.array_equal(upscaled_results[i - 1], upscaled_results[i]), ( + f"Order {i - 1} and order {i} interpolation should produce different results" + ) def test_fits_combination_configurations(self, test_config): """Test different configurations of the fits_combination dictionary.""" @@ -1063,3 +1065,35 @@ def test_fits_combination_configurations(self, test_config): assert np.any(img_two[:, :, 0] > 0) # Red channel should have data assert np.any(img_two[:, :, 1] > 0) # Green channel should have data assert np.any(img_two[:, :, 2] > 0) # Blue channel should have data + + @pytest.mark.parametrize( + "norm_method", + [ + NormalisationMethod.CONVERSION_ONLY, + NormalisationMethod.LINEAR, + NormalisationMethod.MIDTONES, + ], + ) + def test_load_and_process_wrapper_normalisation_methods(self, test_config, norm_method): + """Regression test: load_and_process_wrapper must work with all normalisation methods. + + The PIL resize optimization (size=None to fitsbolt) is only safe for + CONVERSION_ONLY. Other methods need fitsbolt's skimage resize to keep + images as float64 for their normalisation pipelines.
+ """ + target_size = [64, 64] + _update_config(test_config, size=target_size) + test_config.normalisation.normalisation_method = norm_method + test_config.num_workers = 1 + + # Use RGBA gradient images (non-constant channels needed for MIDTONES) + filepaths = [self.complex_rgba_path, self.complex_rgba_path] + results = _load_multiple_images_with_fitsbolt(filepaths, test_config) + + assert len(results) == 2 + for fp, img in results: + assert isinstance(img, np.ndarray) + assert img.shape[:2] == tuple(target_size) + assert img.shape[2] == 3 + assert img.dtype == np.uint8 + assert np.isfinite(img).all() diff --git a/tests/fixmatch_test.py b/tests/unit/test_fixmatch.py similarity index 99% rename from tests/fixmatch_test.py rename to tests/unit/test_fixmatch.py index c426da5..7237b92 100644 --- a/tests/fixmatch_test.py +++ b/tests/unit/test_fixmatch.py @@ -6,6 +6,7 @@ # the terms contained in the file 'LICENCE.txt'. import pytest import torch + from anomaly_match.models.FixMatch import FixMatch @@ -26,7 +27,6 @@ def __getitem__(self, idx): class TestFixMatch: - @pytest.fixture def net_builder(self): """Simple CNN network builder for testing.""" @@ -58,7 +58,6 @@ def fixmatch_model(self, net_builder): T=0.5, p_cutoff=0.95, lambda_u=1.0, - hard_label=True, ) # Set optimizer optimizer = torch.optim.SGD(model.train_model.parameters(), lr=0.01) diff --git a/tests/test_image_io.py b/tests/unit/test_image_io.py similarity index 88% rename from tests/test_image_io.py rename to tests/unit/test_image_io.py index 5b280fb..2092060 100644 --- a/tests/test_image_io.py +++ b/tests/unit/test_image_io.py @@ -5,19 +5,20 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import pytest +import copy import tempfile -import numpy as np -import h5py from pathlib import Path -from PIL import Image + +import h5py +import numpy as np +import pytest import torch -import copy +from PIL import Image from anomaly_match.data_io.load_images import ( + get_fitsbolt_config, load_and_process_single_wrapper, process_single_wrapper, - get_fitsbolt_config, ) from anomaly_match.utils.get_default_cfg import get_default_cfg from prediction_utils import save_results @@ -288,8 +289,8 @@ def test_image_pipeline_integration(self, test_image, test_config): def test_prediction_process_integration(self, test_image, test_config): """Test integration with prediction processes for different image formats.""" import tempfile - from pathlib import Path import time + from pathlib import Path with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -341,6 +342,7 @@ def test_prediction_process_integration(self, test_image, test_config): # Save config with pickle import pickle + from dotmap import DotMap with open(config_path, "wb") as f: @@ -352,47 +354,47 @@ def test_prediction_process_integration(self, test_image, test_config): reloaded_config = DotMap(reloaded_config) # Check that critical fields are present - assert hasattr( - reloaded_config, "normalisation" - ), "normalisation field missing from reloaded config" - assert hasattr( - reloaded_config.normalisation, "fits_extension" - ), "fits_extension field missing from reloaded config.normalisation" - assert hasattr( - reloaded_config.normalisation, "size" - ), "size field missing from reloaded config.normalisation" - assert hasattr( - reloaded_config.normalisation, "normalisation_method" - ), "normalisation_method field missing from reloaded config.normalisation" + assert hasattr(reloaded_config, "normalisation"), ( + "normalisation field missing from reloaded config" + ) + assert hasattr(reloaded_config.normalisation, "fits_extension"), ( + "fits_extension field missing from reloaded 
config.normalisation" + ) + assert hasattr(reloaded_config.normalisation, "size"), ( + "size field missing from reloaded config.normalisation" + ) + assert hasattr(reloaded_config.normalisation, "normalisation_method"), ( + "normalisation_method field missing from reloaded config.normalisation" + ) # Test image loading with reloaded config for test_file in test_files: try: loaded_image = _load_image_with_fitsbolt(test_file, reloaded_config) - assert ( - loaded_image is not None - ), f"Failed to load {test_file} with reloaded config" - assert isinstance( - loaded_image, np.ndarray - ), f"Loaded image from {test_file} is not a numpy array" - assert ( - loaded_image.ndim == 3 - ), f"Loaded image from {test_file} should be 3D (HWC)" - assert ( - loaded_image.shape[2] == 3 - ), f"Loaded image from {test_file} should have 3 channels" # Ensure no NaN/inf values - assert np.isfinite( - loaded_image - ).all(), f"Image from {test_file} contains NaN or inf values" + assert loaded_image is not None, ( + f"Failed to load {test_file} with reloaded config" + ) + assert isinstance(loaded_image, np.ndarray), ( + f"Loaded image from {test_file} is not a numpy array" + ) + assert loaded_image.ndim == 3, ( + f"Loaded image from {test_file} should be 3D (HWC)" + ) + assert loaded_image.shape[2] == 3, ( + f"Loaded image from {test_file} should have 3 channels" + ) # Ensure no NaN/inf values + assert np.isfinite(loaded_image).all(), ( + f"Image from {test_file} contains NaN or inf values" + ) except Exception as e: pytest.fail(f"Failed to load {test_file} with reloaded config: {e}") def test_image_formats_comprehensive(self, test_image, test_config): """Test comprehensive image format support.""" + import gc import tempfile - from pathlib import Path import time - import gc + from pathlib import Path with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) @@ -427,24 +429,24 @@ def test_image_formats_comprehensive(self, test_image, test_config): if should_succeed: 
assert loaded_image is not None, f"Failed to load {filename}" - assert isinstance( - loaded_image, np.ndarray - ), f"Loaded {filename} is not a numpy array" + assert isinstance(loaded_image, np.ndarray), ( + f"Loaded {filename} is not a numpy array" + ) # Check dimensions - assert ( - loaded_image.ndim >= 2 - ), f"Loaded {filename} has insufficient dimensions" + assert loaded_image.ndim >= 2, ( + f"Loaded {filename} has insufficient dimensions" + ) # Check data integrity - assert np.isfinite( - loaded_image - ).all(), f"Loaded {filename} contains NaN or inf values" + assert np.isfinite(loaded_image).all(), ( + f"Loaded {filename} contains NaN or inf values" + ) # Check data type - assert ( - loaded_image.dtype == np.uint8 - ), f"Loaded {filename} has wrong dtype: {loaded_image.dtype}" + assert loaded_image.dtype == np.uint8, ( + f"Loaded {filename} has wrong dtype: {loaded_image.dtype}" + ) except Exception as e: if should_succeed: @@ -455,7 +457,7 @@ def test_image_formats_comprehensive(self, test_image, test_config): def test_numpy_to_byte_stream_nan_inf_handling(self): """Test that numpy_to_byte_stream handles NaN and inf values properly.""" - from anomaly_match.utils.numpy_to_byte_stream import numpy_array_to_byte_stream + from anomaly_match_ui.utils.image_utils import numpy_array_to_byte_stream # Test array with NaN values array_with_nan = np.array([[1.0, 2.0, np.nan], [4.0, 5.0, 6.0]], dtype=np.float32) diff --git a/tests/import_test.py b/tests/unit/test_import.py similarity index 100% rename from tests/import_test.py rename to tests/unit/test_import.py diff --git a/tests/unit/test_label.py b/tests/unit/test_label.py new file mode 100644 index 0000000..5df1f34 --- /dev/null +++ b/tests/unit/test_label.py @@ -0,0 +1,33 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. 
No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Tests for the Label enum.""" + +from anomaly_match.datasets.Label import Label + + +class TestLabel: + def test_label_values(self): + assert Label.UNKNOWN == -1 + assert Label.NORMAL == 0 + assert Label.ANOMALY == 1 + + def test_label_is_int(self): + assert isinstance(Label.NORMAL, int) + assert isinstance(Label.ANOMALY, int) + assert isinstance(Label.UNKNOWN, int) + + def test_label_from_value(self): + assert Label(-1) == Label.UNKNOWN + assert Label(0) == Label.NORMAL + assert Label(1) == Label.ANOMALY + + def test_label_members(self): + members = list(Label) + assert len(members) == 3 + assert Label.UNKNOWN in members + assert Label.NORMAL in members + assert Label.ANOMALY in members diff --git a/tests/unit/test_losses.py b/tests/unit/test_losses.py new file mode 100644 index 0000000..6b9343a --- /dev/null +++ b/tests/unit/test_losses.py @@ -0,0 +1,137 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Tests for accuracy, cross_entropy_loss, and consistency_loss.""" + +import pytest +import torch + +from anomaly_match.utils.accuracy import accuracy +from anomaly_match.utils.consistency_loss import consistency_loss +from anomaly_match.utils.cross_entropy_loss import cross_entropy_loss + + +class TestAccuracy: + def test_perfect_predictions(self): + output = torch.tensor([[10.0, -10.0], [-10.0, 10.0], [10.0, -10.0]]) + target = torch.tensor([0, 1, 0]) + [top1] = accuracy(output, target, topk=(1,)) + assert top1.item() == 100.0 + + def test_zero_accuracy(self): + output = torch.tensor([[10.0, -10.0], [-10.0, 10.0]]) + target = torch.tensor([1, 0]) + [top1] = accuracy(output, target, topk=(1,)) + assert top1.item() == 0.0 + + def test_partial_accuracy(self): + output = torch.tensor([[10.0, -10.0], [-10.0, 10.0], [10.0, -10.0], [-10.0, 10.0]]) + target = torch.tensor([0, 1, 1, 0]) + [top1] = accuracy(output, target, topk=(1,)) + assert top1.item() == 50.0 + + def test_batch_size_one(self): + output = torch.tensor([[5.0, 1.0]]) + target = torch.tensor([0]) + [top1] = accuracy(output, target, topk=(1,)) + assert top1.item() == 100.0 + + +class TestCrossEntropyLoss: + def test_hard_labels_shape(self): + logits = torch.randn(4, 2) + targets = torch.tensor([0, 1, 0, 1]) + loss = cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="none") + assert loss.shape == (4,) + + def test_hard_labels_mean_reduction(self): + logits = torch.randn(4, 2) + targets = torch.tensor([0, 1, 0, 1]) + loss = cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="mean") + assert loss.ndim == 0 # scalar + + def test_hard_labels_non_negative(self): + logits = torch.randn(4, 2) + targets = torch.tensor([0, 1, 0, 1]) + loss = cross_entropy_loss(logits, targets, use_hard_labels=True, reduction="none") + assert (loss >= 0).all() + + def test_soft_labels_shape(self): + logits = torch.randn(4, 2) + targets = torch.tensor([[0.9, 0.1], [0.1, 0.9], [0.7, 0.3], [0.3, 0.7]]) 
+ loss = cross_entropy_loss(logits, targets, use_hard_labels=False) + assert loss.shape == (4,) + + def test_soft_labels_non_negative(self): + logits = torch.randn(4, 2) + targets = torch.softmax(torch.randn(4, 2), dim=-1) + loss = cross_entropy_loss(logits, targets, use_hard_labels=False) + assert (loss >= 0).all() + + def test_soft_labels_shape_mismatch_raises(self): + logits = torch.randn(4, 2) + targets = torch.randn(4, 3) + with pytest.raises(AssertionError): + cross_entropy_loss(logits, targets, use_hard_labels=False) + + +class TestConsistencyLoss: + def test_l2_loss(self): + logits_w = torch.randn(4, 2) + logits_s = torch.randn(4, 2) + loss = consistency_loss(logits_w, logits_s, name="L2") + assert loss.ndim == 0 # scalar + assert loss >= 0 + + def test_ce_loss_returns_tuple(self): + logits_w = torch.randn(4, 2) + logits_s = torch.randn(4, 2) + result = consistency_loss(logits_w, logits_s, name="ce", p_cutoff=0.0) + assert isinstance(result, tuple) + assert len(result) == 2 + masked_loss, mask_ratio = result + assert masked_loss.ndim == 0 + assert mask_ratio.ndim == 0 + + def test_ce_hard_labels(self): + logits_w = torch.randn(4, 2) + logits_s = torch.randn(4, 2) + masked_loss, mask_ratio = consistency_loss( + logits_w, logits_s, name="ce", use_hard_labels=True, p_cutoff=0.0 + ) + assert masked_loss >= 0 + assert mask_ratio == 1.0 # p_cutoff=0 means all samples pass + + def test_ce_soft_labels(self): + logits_w = torch.randn(4, 2) + logits_s = torch.randn(4, 2) + masked_loss, mask_ratio = consistency_loss( + logits_w, logits_s, name="ce", use_hard_labels=False, p_cutoff=0.0 + ) + assert masked_loss >= 0 + assert mask_ratio == 1.0 + + def test_ce_high_cutoff_masks_all(self): + # Uniform logits = 0.5 probability, cutoff at 0.99 should mask all + logits_w = torch.zeros(4, 2) + logits_s = torch.randn(4, 2) + masked_loss, mask_ratio = consistency_loss(logits_w, logits_s, name="ce", p_cutoff=0.99) + assert mask_ratio == 0.0 + + def 
test_ce_detaches_weak_logits(self): + logits_w = torch.randn(4, 2, requires_grad=True) + logits_s = torch.randn(4, 2, requires_grad=True) + masked_loss, _ = consistency_loss(logits_w, logits_s, name="ce", p_cutoff=0.0) + masked_loss.backward() + # Gradient should only flow through logits_s, not logits_w + assert logits_w.grad is None + assert logits_s.grad is not None + + def test_invalid_loss_name_raises(self): + logits_w = torch.randn(4, 2) + logits_s = torch.randn(4, 2) + with pytest.raises(AssertionError): + consistency_loss(logits_w, logits_s, name="invalid") diff --git a/tests/metadata_test.py b/tests/unit/test_metadata.py similarity index 99% rename from tests/metadata_test.py rename to tests/unit/test_metadata.py index f9c381f..0245a72 100644 --- a/tests/metadata_test.py +++ b/tests/unit/test_metadata.py @@ -8,11 +8,12 @@ import os import shutil import tempfile + +import numpy as np import pandas as pd import pytest -import numpy as np -from PIL import Image from dotmap import DotMap +from PIL import Image from anomaly_match.datasets.AnomalyDetectionDataset import AnomalyDetectionDataset from anomaly_match.pipeline.session import Session diff --git a/tests/metadata_handler_test.py b/tests/unit/test_metadata_handler.py similarity index 99% rename from tests/metadata_handler_test.py rename to tests/unit/test_metadata_handler.py index 9e201b8..b0f56ea 100644 --- a/tests/metadata_handler_test.py +++ b/tests/unit/test_metadata_handler.py @@ -7,6 +7,7 @@ import os import tempfile + import numpy as np from anomaly_match.data_io.metadata_handler import MetadataHandler diff --git a/tests/unit/test_multispectral.py b/tests/unit/test_multispectral.py new file mode 100644 index 0000000..cc295af --- /dev/null +++ b/tests/unit/test_multispectral.py @@ -0,0 +1,274 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. 
No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +"""Unit tests for multispectral (N-channel) image support.""" + +import numpy as np +import torch + +from anomaly_match.datasets.AnomalyDetectionDataset import AnomalyDetectionDataset +from anomaly_match.datasets.BasicDataset import BasicDataset +from anomaly_match.datasets.SSL_Dataset import SSL_Dataset +from anomaly_match.image_processing.transforms import ( + get_prediction_transforms, + get_strong_transforms, + get_weak_transforms, +) + + +class TestMultispectralDataset: + """Tests for multispectral dataset loading and handling.""" + + def test_dataset_loads_4_channel_images(self, multispectral_config): + """Test that AnomalyDetectionDataset correctly loads 4-channel images.""" + dataset = AnomalyDetectionDataset(multispectral_config) + + assert dataset is not None + assert dataset.num_channels == 4, f"Expected 4 channels, got {dataset.num_channels}" + assert dataset.size == multispectral_config.normalisation.image_size + + # Check that images in data_dict have correct shape + for filename, (image, label) in dataset.data_dict.items(): + assert image.shape[-1] == 4, f"Image {filename} has wrong channels: {image.shape}" + assert image.shape[:2] == tuple(dataset.size), ( + f"Image {filename} has wrong size: {image.shape}" + ) + + def test_channel_auto_detection(self, multispectral_config): + """Test that n_output_channels is auto-detected from 4-channel images.""" + # Config starts at default (3), should be updated after dataset load + assert multispectral_config.normalisation.n_output_channels == 3 + + dataset = AnomalyDetectionDataset(multispectral_config) + + # After loading, config should be updated to 4 + assert multispectral_config.normalisation.n_output_channels == 4 + assert dataset.num_channels == 4 + + def test_dataset_mean_std_4_channels(self, multispectral_config): + """Test that mean 
and std are computed correctly for 4 channels.""" + dataset = AnomalyDetectionDataset(multispectral_config) + + assert len(dataset.mean) == 4, f"Mean should have 4 values, got {len(dataset.mean)}" + assert len(dataset.std) == 4, f"Std should have 4 values, got {len(dataset.std)}" + + # Check that values are reasonable (between 0 and 1 for normalized) + assert all(0 <= m <= 1 for m in dataset.mean), f"Mean values out of range: {dataset.mean}" + assert all(0 <= s <= 1 for s in dataset.std), f"Std values out of range: {dataset.std}" + + def test_ssl_dataset_4_channels(self, multispectral_config): + """Test SSL_Dataset initialization with 4-channel data.""" + ssl_dataset = SSL_Dataset(cfg=multispectral_config, train=True) + assert ssl_dataset.num_classes == 2 + + # Auto-detection happens when dataset is loaded (lazy) + labeled_dataset, unlabeled_dataset = ssl_dataset.get_ssl_dset() + assert labeled_dataset is not None + assert unlabeled_dataset is not None + assert multispectral_config.normalisation.n_output_channels == 4 + + +class TestMultispectralAugmentation: + """Tests for multispectral augmentation compatibility.""" + + def test_basic_dataset_4_channel_no_transform(self, multispectral_config): + """Test BasicDataset works with 4-channel data without transforms.""" + # Create sample 4-channel data + imgs = np.random.randint(0, 255, (5, 64, 64, 4), dtype=np.uint8) + filenames = [f"img_{i}.npy" for i in range(5)] + targets = [0, 1, 0, 1, 0] + + transform = get_prediction_transforms(num_channels=4) + dataset = BasicDataset( + imgs, filenames, targets, num_classes=2, transform=transform, num_channels=4 + ) + + img, target, filename = dataset[0] + + assert isinstance(img, torch.Tensor) + assert img.shape[0] == 4, f"Expected 4 channels, got shape {img.shape}" + assert img.shape[1:] == (64, 64), f"Expected (64, 64) spatial, got {img.shape}" + + def test_basic_dataset_4_channel_weak_transform(self, multispectral_config): + """Test BasicDataset works with 4-channel data 
and weak transforms.""" + imgs = np.random.randint(0, 255, (5, 64, 64, 4), dtype=np.uint8) + filenames = [f"img_{i}.npy" for i in range(5)] + targets = [0, 1, 0, 1, 0] + + transform = get_weak_transforms(num_channels=4) + dataset = BasicDataset( + imgs, filenames, targets, num_classes=2, transform=transform, num_channels=4 + ) + + img, target, filename = dataset[0] + + assert isinstance(img, torch.Tensor) + assert img.shape[0] == 4, f"Expected 4 channels after weak transform, got {img.shape}" + + def test_basic_dataset_4_channel_strong_transform(self, multispectral_config): + """Test BasicDataset works with 4-channel data and strong transforms.""" + imgs = np.random.randint(0, 255, (5, 64, 64, 4), dtype=np.uint8) + filenames = [f"img_{i}.npy" for i in range(5)] + targets = [0, 1, 0, 1, 0] + + weak_transform = get_weak_transforms(num_channels=4) + strong_transform = get_strong_transforms(num_channels=4) + + dataset = BasicDataset( + imgs, + filenames, + targets, + num_classes=2, + transform=weak_transform, + use_strong_transform=True, + strong_transform=strong_transform, + num_channels=4, + ) + + weak_img, strong_img, target = dataset[0] + + assert isinstance(weak_img, torch.Tensor) + assert isinstance(strong_img, torch.Tensor) + assert weak_img.shape[0] == 4, f"Weak augmented should have 4 channels: {weak_img.shape}" + assert strong_img.shape[0] == 4, ( + f"Strong augmented should have 4 channels: {strong_img.shape}" + ) + + +class TestMultispectralModel: + """Tests for multispectral model architecture.""" + + def test_network_builder_4_channels(self): + """Test that network builder creates model accepting 4 channels.""" + from anomaly_match.utils.get_net_builder import get_net_builder + + net_builder = get_net_builder("efficientnet-lite0", pretrained=True, in_channels=4) + model = net_builder(num_classes=2, in_channels=4) + + # Set to eval mode to avoid batch norm issues with batch size 1 + model.eval() + + # Create dummy input with 4 channels + dummy_input = 
torch.randn(1, 4, 64, 64) + with torch.no_grad(): + output = model(dummy_input) + + assert output.shape == (1, 2), f"Expected output shape (1, 2), got {output.shape}" + + def test_network_builder_pretrained_weight_transfer(self): + """Test that pretrained weights are adapted for 4-channel input by timm.""" + from anomaly_match.utils.get_net_builder import get_net_builder + + # Get 4-channel pretrained model + net_builder_4ch = get_net_builder("efficientnet-lite0", pretrained=True, in_channels=4) + model_4ch = net_builder_4ch(num_classes=2, in_channels=4) + + # Verify conv_stem was adapted to 4 input channels + assert model_4ch.conv_stem.weight.data.shape[1] == 4, ( + "conv_stem should have 4 input channels" + ) + + # Verify weights are not all zeros (pretrained weights were adapted, not just zero-padded) + assert model_4ch.conv_stem.weight.data.abs().sum() > 0, "Adapted weights should be non-zero" + + +class TestMultispectralPrediction: + """Tests for multispectral prediction pipeline.""" + + def test_model_prediction_4_channels(self, multispectral_config): + """Test model prediction with 4-channel images.""" + from anomaly_match.models.FixMatch import FixMatch + from anomaly_match.utils.get_net_builder import get_net_builder + + # Build and initialize model + net_builder = get_net_builder( + multispectral_config.net, + pretrained=multispectral_config.pretrained, + in_channels=4, + ) + model = FixMatch( + net_builder=net_builder, + num_classes=2, + in_channels=4, + ema_m=multispectral_config.ema_m, + T=multispectral_config.temperature, + p_cutoff=multispectral_config.p_cutoff, + lambda_u=multispectral_config.ulb_loss_ratio, + ) + + # Create test batch + test_batch = torch.randn(4, 4, 64, 64) # (batch, channels, H, W) + + # Get predictions + model.eval_model.eval() + with torch.no_grad(): + predictions = model.eval_model(test_batch) + + assert predictions.shape == (4, 2), f"Expected (4, 2), got {predictions.shape}" + # Check predictions are valid probabilities after 
softmax + probs = torch.softmax(predictions, dim=1) + assert torch.allclose(probs.sum(dim=1), torch.ones(4)), "Probabilities should sum to 1" + + +class TestMultispectralDisplay: + """Tests for multispectral display functionality.""" + + def test_prepare_4_channel_for_display(self): + """Test converting 4-channel image to RGB for display.""" + from anomaly_match_ui.utils.display_transforms import prepare_for_display + + # Create 4-channel image + img_4ch = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + + # Convert to displayable RGB + rgb_img = prepare_for_display(img_4ch) + + assert rgb_img.shape == (64, 64, 3), f"Expected RGB output, got {rgb_img.shape}" + assert rgb_img.dtype == np.uint8 + + def test_prepare_4_channel_with_rgb_mapping(self): + """Test converting 4-channel image with custom RGB channel mapping.""" + from anomaly_match_ui.utils.display_transforms import prepare_for_display + + # Create 4-channel image + img_4ch = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + + # Use channels 0, 2, 3 as RGB + rgb_mapping = [0, 2, 3] + rgb_img = prepare_for_display(img_4ch, rgb_mapping=rgb_mapping) + + assert rgb_img.shape == (64, 64, 3) + # Verify mapping is correct + assert np.array_equal(rgb_img[:, :, 0], img_4ch[:, :, 0]) + assert np.array_equal(rgb_img[:, :, 1], img_4ch[:, :, 2]) + assert np.array_equal(rgb_img[:, :, 2], img_4ch[:, :, 3]) + + +class TestMultispectralHDF5: + """Tests for HDF5 save/load with multispectral data.""" + + def test_hdf5_save_load_4_channels(self, multispectral_config, tmp_path): + """Test saving and loading 4-channel dataset to/from HDF5.""" + dataset = AnomalyDetectionDataset(multispectral_config) + + # Save to HDF5 + hdf5_path = tmp_path / "test_4ch.hdf5" + dataset.save_as_hdf5(str(hdf5_path)) + + assert hdf5_path.exists() + + # Load back + new_dataset = AnomalyDetectionDataset(multispectral_config) + new_dataset.load_from_hdf5(str(hdf5_path)) + + # Verify data integrity + assert len(new_dataset.data_dict) == 
len(dataset.data_dict) + for filename in dataset.data_dict: + if filename in new_dataset.data_dict: + orig_img, _ = dataset.data_dict[filename] + loaded_img, _ = new_dataset.data_dict[filename] + assert orig_img.shape == loaded_img.shape + assert np.array_equal(orig_img, loaded_img) diff --git a/tests/unit/test_prediction_utils.py b/tests/unit/test_prediction_utils.py new file mode 100644 index 0000000..97f024a --- /dev/null +++ b/tests/unit/test_prediction_utils.py @@ -0,0 +1,117 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. + +import numpy as np +import pytest +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod + +from anomaly_match.utils.get_default_cfg import get_default_cfg +from prediction_utils import convert_cutana_cutout, create_cutana_format_cfg + + +@pytest.fixture +def format_cfg(): + """Create a CONVERSION_ONLY format config for testing.""" + cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 + cfg.normalisation.normalisation_method = NormalisationMethod.CONVERSION_ONLY + cfg.num_workers = 0 + return create_cutana_format_cfg(cfg) + + +class TestCreateCutanaFormatCfg: + """Tests for create_cutana_format_cfg.""" + + def test_returns_config_with_conversion_only(self): + cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 + cfg.num_workers = 0 + result = create_cutana_format_cfg(cfg) + + assert result.fitsbolt_cfg.normalisation_method == NormalisationMethod.CONVERSION_ONLY + + def test_preserves_image_size(self): + cfg = get_default_cfg() + cfg.normalisation.image_size = [128, 128] + cfg.normalisation.n_output_channels = 3 
+ cfg.num_workers = 0 + result = create_cutana_format_cfg(cfg) + + assert result.fitsbolt_cfg.size == [128, 128] + + def test_preserves_output_channels(self): + cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] + cfg.normalisation.n_output_channels = 3 + cfg.num_workers = 0 + result = create_cutana_format_cfg(cfg) + + assert result.fitsbolt_cfg.n_output_channels == 3 + + +class TestConvertCutanaCutout: + """Tests for convert_cutana_cutout CHW/HWC handling and format conversion.""" + + def test_hwc_input_preserved(self, format_cfg): + """HWC input should pass through without transpose.""" + image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = convert_cutana_cutout(image, format_cfg) + + assert result.shape[0] == 64 # H + assert result.shape[1] == 64 # W + + def test_chw_input_transposed_to_hwc(self, format_cfg): + """CHW input should be transposed to HWC before processing.""" + # Create a CHW image (3, 64, 64) with distinct channel values + image = np.zeros((3, 64, 64), dtype=np.uint8) + image[0] = 100 # R channel + image[1] = 150 # G channel + image[2] = 200 # B channel + + result = convert_cutana_cutout(image, format_cfg) + + # Result should be HWC + assert result.ndim == 3 + assert result.shape[2] == 3 # channels last + + def test_single_channel_chw_transposed(self, format_cfg): + """Single-channel CHW (1, H, W) should be detected and transposed.""" + image = np.random.randint(0, 255, (1, 64, 64), dtype=np.uint8) + result = convert_cutana_cutout(image, format_cfg) + + assert result.ndim == 3 + assert result.shape[2] == 3 # replicated to 3 channels by fitsbolt + + def test_list_input_converted_to_array(self, format_cfg): + """Non-ndarray input should be converted before processing.""" + image_list = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8).tolist() + result = convert_cutana_cutout(image_list, format_cfg) + + assert isinstance(result, np.ndarray) + + def test_float_input_converted_to_uint8(self, format_cfg): + 
"""Float input (from cutana normalisation) should be converted to uint8.""" + image = np.random.random((64, 64, 3)).astype(np.float32) + result = convert_cutana_cutout(image, format_cfg) + + assert result.dtype == np.uint8 + + def test_output_matches_configured_size(self): + """Output should be resized to the configured image_size.""" + cfg = get_default_cfg() + cfg.normalisation.image_size = [32, 32] + cfg.normalisation.n_output_channels = 3 + cfg.num_workers = 0 + small_cfg = create_cutana_format_cfg(cfg) + + image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = convert_cutana_cutout(image, small_cfg) + + assert result.shape[0] == 32 + assert result.shape[1] == 32 diff --git a/tests/test_session_io_handler.py b/tests/unit/test_session_io_handler.py similarity index 99% rename from tests/test_session_io_handler.py rename to tests/unit/test_session_io_handler.py index 0d6e47b..f0f4e1e 100644 --- a/tests/test_session_io_handler.py +++ b/tests/unit/test_session_io_handler.py @@ -5,15 +5,16 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. -import pytest -import tempfile -import shutil import json import pickle -import pandas as pd +import shutil +import tempfile from pathlib import Path from unittest.mock import patch +import pandas as pd +import pytest + from anomaly_match.data_io.SessionIOHandler import SessionIOHandler, print_session from anomaly_match.pipeline.SessionTracker import SessionTracker diff --git a/tests/test_session_tracker.py b/tests/unit/test_session_tracker.py similarity index 99% rename from tests/test_session_tracker.py rename to tests/unit/test_session_tracker.py index 7379e04..594c49c 100644 --- a/tests/test_session_tracker.py +++ b/tests/unit/test_session_tracker.py @@ -5,12 +5,13 @@ # this file, may be copied, modified, propagated, or distributed except according to # the terms contained in the file 'LICENCE.txt'. 
-import pytest import datetime from unittest.mock import patch + import pandas as pd +import pytest -from anomaly_match.pipeline.SessionTracker import SessionTracker, IterationInfo +from anomaly_match.pipeline.SessionTracker import IterationInfo, SessionTracker class TestIterationInfo: diff --git a/tests/unit/test_shorten_filename.py b/tests/unit/test_shorten_filename.py new file mode 100644 index 0000000..61ca24e --- /dev/null +++ b/tests/unit/test_shorten_filename.py @@ -0,0 +1,60 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. +from anomaly_match_ui.widget import shorten_filename + + +class TestShortenFilename: + """Tests for the shorten_filename helper function.""" + + def test_short_filename_unchanged(self): + """Filenames within max length should remain unchanged.""" + assert shorten_filename("short.fits", max_length=25) == "short.fits" + assert shorten_filename("image.jpg", max_length=25) == "image.jpg" + + def test_long_filename_shortened(self): + """Long filenames should be shortened to max_length.""" + long_name = "very_long_filename_that_exceeds_limit.fits" + result = shorten_filename(long_name, max_length=25) + assert len(result) <= 25 + assert result.endswith(".fits") + assert "..." in result + + def test_filename_with_multiple_dots(self): + """Filenames with multiple dots should preserve only the extension.""" + name = "image.2024.01.15.observation.fits" + result = shorten_filename(name, max_length=25) + assert len(result) <= 25 + assert result.endswith(".fits") + assert "..." 
in result + + def test_filename_without_extension(self): + """Filenames without extension should still be shortened correctly.""" + name = "very_long_filename_without_any_extension" + result = shorten_filename(name, max_length=25) + assert len(result) <= 25 + assert "..." in result + + def test_exact_max_length(self): + """Filename exactly at max_length should be unchanged.""" + name = "exactly_25_chars_long.fit" + assert len(name) == 25 + assert shorten_filename(name, max_length=25) == name + + def test_very_short_max_length(self): + """Very short max_length should still produce valid output.""" + name = "some_filename.fits" + result = shorten_filename(name, max_length=10) + assert len(result) <= 10 + assert "..." in result + + def test_preserves_start_and_end(self): + """Shortened name should contain parts of the original start and end.""" + name = "START_middle_content_END.fits" + result = shorten_filename(name, max_length=20) + assert result.startswith("START") + # Should contain some part of the end before the extension + assert "END" in result or "..." in result diff --git a/tests/test_toml_config.py b/tests/unit/test_toml_config.py similarity index 97% rename from tests/test_toml_config.py rename to tests/unit/test_toml_config.py index b057a57..fe9a4d5 100644 --- a/tests/test_toml_config.py +++ b/tests/unit/test_toml_config.py @@ -6,15 +6,16 @@ # the terms contained in the file 'LICENCE.txt'. 
import tempfile -import toml from pathlib import Path + +import toml from dotmap import DotMap +from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod from anomaly_match.data_io.save_config import ( - save_config_toml, _convert_enum_to_string, + save_config_toml, ) -from fitsbolt.normalisation.NormalisationMethod import NormalisationMethod class TestTOMLConfigUtils: @@ -134,6 +135,6 @@ def test_config_types_and_structure(self): assert isinstance(config.name, str) # Verify image_size is NOT in default config (user must set it) - assert ( - "image_size" not in config.normalisation - ), "image_size should not have a default value" + assert "image_size" not in config.normalisation, ( + "image_size should not have a default value" + ) diff --git a/tests/unit/test_transforms.py b/tests/unit/test_transforms.py new file mode 100644 index 0000000..f1de876 --- /dev/null +++ b/tests/unit/test_transforms.py @@ -0,0 +1,139 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Tests for image transformation pipelines.""" + +import numpy as np +import torch + +from anomaly_match.image_processing.transforms import ( + NumpyRandomHorizontalFlip, + NumpyRandomTranslate, + NumpyToTensor, + get_prediction_transforms, + get_strong_transforms, + get_weak_transforms, +) + + +class TestNumpyToTensor: + def test_converts_hwc_to_chw(self): + img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + tensor = NumpyToTensor()(img) + assert isinstance(tensor, torch.Tensor) + assert tensor.shape == (3, 64, 64) + + def test_normalizes_to_float(self): + img = np.full((32, 32, 3), 255, dtype=np.uint8) + tensor = NumpyToTensor()(img) + assert tensor.dtype == torch.float32 + assert torch.allclose(tensor, torch.ones(3, 32, 32)) + + def test_handles_4_channels(self): + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + tensor = NumpyToTensor()(img) + assert tensor.shape == (4, 64, 64) + + def test_passthrough_non_numpy(self): + tensor = torch.randn(3, 64, 64) + result = NumpyToTensor()(tensor) + assert torch.equal(result, tensor) + + +class TestNumpyRandomHorizontalFlip: + def test_output_shape_preserved(self): + img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = NumpyRandomHorizontalFlip(p=1.0)(img) + assert result.shape == img.shape + + def test_flip_with_p_one(self): + img = np.zeros((4, 4, 1), dtype=np.uint8) + img[0, 0, 0] = 255 + result = NumpyRandomHorizontalFlip(p=1.0)(img) + assert result[0, 3, 0] == 255 + + def test_no_flip_with_p_zero(self): + img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = NumpyRandomHorizontalFlip(p=0.0)(img) + assert np.array_equal(result, img) + + def test_handles_tensor(self): + tensor = torch.randn(3, 64, 64) + result = NumpyRandomHorizontalFlip(p=1.0)(tensor) + assert isinstance(result, torch.Tensor) + assert result.shape == tensor.shape + + +class TestNumpyRandomTranslate: + def test_output_shape_preserved(self): + img = np.random.randint(0, 255, (64, 64, 3), 
dtype=np.uint8) + result = NumpyRandomTranslate()(img) + assert result.shape == img.shape + + def test_handles_tensor(self): + tensor = torch.randn(3, 64, 64) + result = NumpyRandomTranslate()(tensor) + assert isinstance(result, torch.Tensor) + assert result.shape == tensor.shape + + +class TestGetWeakTransforms: + def test_rgb_returns_compose(self): + transform = get_weak_transforms(num_channels=3) + assert transform is not None + + def test_rgb_output_tensor(self): + transform = get_weak_transforms(num_channels=3) + img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = transform(img) + assert isinstance(result, torch.Tensor) + assert result.shape[0] == 3 + + def test_multispectral_output_tensor(self): + transform = get_weak_transforms(num_channels=4) + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + result = transform(img) + assert isinstance(result, torch.Tensor) + assert result.shape[0] == 4 + + +class TestGetPredictionTransforms: + def test_rgb_output_tensor(self): + transform = get_prediction_transforms(num_channels=3) + img = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) + result = transform(img) + assert isinstance(result, torch.Tensor) + assert result.shape == (3, 64, 64) + + def test_multispectral_output_tensor(self): + transform = get_prediction_transforms(num_channels=4) + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + result = transform(img) + assert isinstance(result, torch.Tensor) + assert result.shape == (4, 64, 64) + + +class TestGetStrongTransforms: + def test_rgb_returns_compose(self): + transform = get_strong_transforms(num_channels=3) + assert transform is not None + + def test_rgb_output_tensor(self): + from PIL import Image + + transform = get_strong_transforms(num_channels=3) + # RandAugment for RGB expects PIL Image input + img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)) + result = transform(img) + assert isinstance(result, torch.Tensor) + assert 
result.shape[0] == 3 + + def test_multispectral_output_tensor(self): + transform = get_strong_transforms(num_channels=4) + img = np.random.randint(0, 255, (64, 64, 4), dtype=np.uint8) + result = transform(img) + assert isinstance(result, torch.Tensor) + assert result.shape[0] == 4 diff --git a/tests/utils_test.py b/tests/unit/test_utils.py similarity index 99% rename from tests/utils_test.py rename to tests/unit/test_utils.py index e80d8eb..ab7aac3 100644 --- a/tests/utils_test.py +++ b/tests/unit/test_utils.py @@ -6,9 +6,10 @@ # the terms contained in the file 'LICENCE.txt'. import torch import torch.optim as optim + from anomaly_match.utils.get_cosine_schedule_with_warmup import get_cosine_schedule_with_warmup -from anomaly_match.utils.get_optimizer import get_optimizer from anomaly_match.utils.get_net_builder import get_net_builder +from anomaly_match.utils.get_optimizer import get_optimizer from anomaly_match.utils.set_seeds import set_seeds diff --git a/tests/unit/test_validate_config.py b/tests/unit/test_validate_config.py new file mode 100644 index 0000000..0af0adc --- /dev/null +++ b/tests/unit/test_validate_config.py @@ -0,0 +1,285 @@ +# Copyright (c) European Space Agency, 2025. +# +# This file is subject to the terms and conditions defined in file 'LICENCE.txt', which +# is part of this source code package. No part of the package, including +# this file, may be copied, modified, propagated, or distributed except according to +# the terms contained in the file 'LICENCE.txt'. 
+"""Tests for configuration validation edge cases.""" + +import pytest +from loguru import logger + +from anomaly_match.utils.get_default_cfg import get_default_cfg +from anomaly_match.utils.validate_config import ( + _get_all_keys, + _get_nested_value, + validate_config, +) + + +@pytest.fixture +def caplog(caplog): + """Configure loguru to use the caplog handler.""" + handler_id = logger.add(caplog.handler) + yield caplog + logger.remove(handler_id) + + +@pytest.fixture +def valid_cfg(): + """Return a valid default config with image_size set.""" + cfg = get_default_cfg() + cfg.normalisation.image_size = [64, 64] + return cfg + + +class TestGetNestedValue: + def test_simple_key(self, valid_cfg): + assert _get_nested_value(valid_cfg, "seed") == 42 + + def test_nested_key(self, valid_cfg): + assert _get_nested_value(valid_cfg, "normalisation.image_size") == [64, 64] + + def test_missing_key_raises(self, valid_cfg): + with pytest.raises(ValueError, match="Missing key in config"): + _get_nested_value(valid_cfg, "nonexistent.key") + + +class TestGetAllKeys: + def test_returns_top_level_keys(self, valid_cfg): + keys = _get_all_keys(valid_cfg) + assert "seed" in keys + assert "batch_size" in keys + assert "net" in keys + + def test_returns_nested_keys(self, valid_cfg): + keys = _get_all_keys(valid_cfg) + assert "normalisation" in keys + assert "normalisation.image_size" in keys + assert "normalisation.n_output_channels" in keys + + +class TestValidateConfigRequired: + def test_missing_required_string(self, valid_cfg): + del valid_cfg["name"] + with pytest.raises(ValueError, match="Missing required parameter"): + validate_config(valid_cfg) + + def test_missing_required_integer(self, valid_cfg): + del valid_cfg["batch_size"] + with pytest.raises(ValueError, match="Missing required parameter"): + validate_config(valid_cfg) + + +class TestValidateConfigTypes: + def test_string_type_mismatch(self, valid_cfg): + valid_cfg.name = 123 + with pytest.raises(ValueError, match="must 
be a string"): + validate_config(valid_cfg) + + def test_int_type_mismatch(self, valid_cfg): + valid_cfg.batch_size = "not_an_int" + with pytest.raises(ValueError, match="must be an integer"): + validate_config(valid_cfg) + + def test_float_type_mismatch(self, valid_cfg): + valid_cfg.test_ratio = "not_a_float" + with pytest.raises(ValueError, match="must be a number"): + validate_config(valid_cfg) + + def test_bool_type_mismatch(self, valid_cfg): + valid_cfg.pin_memory = "not_a_bool" + with pytest.raises(ValueError, match="must be a boolean"): + validate_config(valid_cfg) + + +class TestValidateConfigRanges: + def test_int_below_minimum(self, valid_cfg): + valid_cfg.batch_size = 0 + with pytest.raises(ValueError, match="must be >= 1"): + validate_config(valid_cfg) + + def test_float_below_minimum(self, valid_cfg): + valid_cfg.test_ratio = -0.1 + with pytest.raises(ValueError, match="must be >= 0.0"): + validate_config(valid_cfg) + + def test_float_above_maximum(self, valid_cfg): + valid_cfg.test_ratio = 1.5 + with pytest.raises(ValueError, match="must be <= 1.0"): + validate_config(valid_cfg) + + def test_n_to_load_below_minimum(self, valid_cfg): + valid_cfg.N_to_load = 5 + with pytest.raises(ValueError, match="must be >= 10"): + validate_config(valid_cfg) + + +class TestValidateConfigAllowedValues: + def test_invalid_optimizer(self, valid_cfg): + valid_cfg.opt = "RMSProp" + with pytest.raises(ValueError, match="must be one of"): + validate_config(valid_cfg) + + def test_invalid_net(self, valid_cfg): + valid_cfg.net = "resnet50" + with pytest.raises(ValueError, match="must be one of"): + validate_config(valid_cfg) + + def test_valid_optimizer_sgd(self, valid_cfg): + valid_cfg.opt = "SGD" + validate_config(valid_cfg) + + def test_valid_optimizer_adam(self, valid_cfg): + valid_cfg.opt = "Adam" + validate_config(valid_cfg) + + +class TestValidateConfigSpecialTypes: + def test_invalid_image_size_not_list(self, valid_cfg): + valid_cfg.normalisation.image_size = 64 + 
with pytest.raises(ValueError, match="must be a list or tuple of length 2"): + validate_config(valid_cfg) + + def test_invalid_image_size_wrong_length(self, valid_cfg): + valid_cfg.normalisation.image_size = [64, 64, 64] + with pytest.raises(ValueError, match="must be a list or tuple of length 2"): + validate_config(valid_cfg) + + def test_invalid_eval_iter(self, valid_cfg): + valid_cfg.num_eval_iter = 0 + with pytest.raises(ValueError, match="must be an integer > 0 or -1"): + validate_config(valid_cfg) + + def test_valid_eval_iter_negative_one(self, valid_cfg): + valid_cfg.num_eval_iter = -1 + validate_config(valid_cfg) + + def test_valid_eval_iter_positive(self, valid_cfg): + valid_cfg.num_eval_iter = 10 + validate_config(valid_cfg) + + def test_normalisation_not_dotmap(self, valid_cfg): + valid_cfg.normalisation = "not_a_dotmap" + with pytest.raises(ValueError, match="must be a DotMap"): + validate_config(valid_cfg) + + +class TestValidateConfigPaths: + def test_skip_path_checks(self, valid_cfg): + valid_cfg.data_dir = "/nonexistent/path" + validate_config(valid_cfg, check_paths=False) + + def test_invalid_directory_path(self, valid_cfg): + valid_cfg.data_dir = "/definitely/nonexistent/path" + with pytest.raises(ValueError, match="directory does not exist"): + validate_config(valid_cfg, check_paths=True) + + def test_invalid_file_path(self, valid_cfg): + valid_cfg.label_file = "/nonexistent/file.csv" + with pytest.raises(ValueError, match="file does not exist"): + validate_config(valid_cfg, check_paths=True) + + +class TestValidateConfigOptional: + def test_optional_none_metadata_file(self, valid_cfg): + valid_cfg.metadata_file = None + validate_config(valid_cfg) + + def test_optional_none_prediction_search_dir(self, valid_cfg): + valid_cfg.prediction_search_dir = None + validate_config(valid_cfg) + + +class TestFitsExtensionChannelAutoAdjust: + """Regression tests for auto-adjusting n_output_channels from fits_extension.""" + + def 
test_fits_extension_4_auto_adjusts_n_output_channels(self, valid_cfg): + """fits_extension=[0,1,2,3] with default n_output_channels=3 should auto-adjust to 4.""" + import numpy as np + + valid_cfg.normalisation.fits_extension = [0, 1, 2, 3] + assert valid_cfg.normalisation.n_output_channels == 3 + validate_config(valid_cfg) + assert valid_cfg.normalisation.n_output_channels == 4 + assert valid_cfg.num_channels == 4 + np.testing.assert_array_equal(valid_cfg.normalisation.channel_combination, np.eye(4)) + + def test_fits_extension_matching_channels_creates_identity(self, valid_cfg): + """fits_extension=[0,1,2] with n_output_channels=3 should create identity matrix.""" + import numpy as np + + valid_cfg.normalisation.fits_extension = [0, 1, 2] + validate_config(valid_cfg) + assert valid_cfg.normalisation.n_output_channels == 3 + np.testing.assert_array_equal(valid_cfg.normalisation.channel_combination, np.eye(3)) + + def test_fits_extension_single_no_adjustment(self, valid_cfg): + """Single fits_extension should not trigger adjustment.""" + valid_cfg.normalisation.fits_extension = [0] + validate_config(valid_cfg) + assert valid_cfg.normalisation.n_output_channels == 3 + + def test_channel_combination_provided_no_adjustment(self, valid_cfg): + """Explicit channel_combination should prevent auto-adjustment.""" + import numpy as np + + valid_cfg.normalisation.fits_extension = [0, 1, 2, 3] + valid_cfg.normalisation.channel_combination = np.eye(3, 4) + validate_config(valid_cfg) + assert valid_cfg.normalisation.n_output_channels == 3 + + def test_asinh_params_extended_on_adjustment(self, valid_cfg): + """Per-channel asinh params should be extended when channels are added.""" + valid_cfg.normalisation.fits_extension = [0, 1, 2, 3] + valid_cfg.normalisation.norm_asinh_scale = [0.7, 0.7, 0.7] + valid_cfg.normalisation.norm_asinh_clip = [99.8, 99.8, 99.8] + validate_config(valid_cfg) + assert len(valid_cfg.normalisation.norm_asinh_scale) == 4 + assert 
len(valid_cfg.normalisation.norm_asinh_clip) == 4 + + def test_n_output_channels_inferred_from_channel_combination(self, valid_cfg): + """n_output_channels should be inferred from channel_combination shape.""" + import numpy as np + + valid_cfg.normalisation.fits_extension = [0, 1, 2, 3] + # User provides 1x4 matrix but doesn't set n_output_channels (defaults to 3) + valid_cfg.normalisation.channel_combination = np.array([[0.25, 0.25, 0.25, 0.25]]) + validate_config(valid_cfg) + assert valid_cfg.normalisation.n_output_channels == 1 + assert valid_cfg.num_channels == 1 + assert len(valid_cfg.normalisation.norm_asinh_scale) == 1 + assert len(valid_cfg.normalisation.norm_asinh_clip) == 1 + + def test_n_output_channels_inferred_3x4_matrix(self, valid_cfg): + """3x4 channel_combination with default n_output_channels=3 should not change.""" + import numpy as np + + valid_cfg.normalisation.fits_extension = [0, 1, 2, 3] + valid_cfg.normalisation.channel_combination = np.eye(3, 4) + validate_config(valid_cfg) + assert valid_cfg.normalisation.n_output_channels == 3 + + def test_n_output_channels_none_with_multi_ext_auto_infers(self, valid_cfg): + """n_output_channels=None with multiple fits_extension should auto-infer.""" + import numpy as np + + valid_cfg.normalisation.fits_extension = [0, 1, 2, 3] + valid_cfg.normalisation.n_output_channels = None + validate_config(valid_cfg) + assert valid_cfg.normalisation.n_output_channels == 4 + np.testing.assert_array_equal(valid_cfg.normalisation.channel_combination, np.eye(4)) + + def test_n_output_channels_none_single_ext_raises(self, valid_cfg): + """n_output_channels=None with single fits_extension should raise ValueError.""" + valid_cfg.normalisation.fits_extension = [0] + valid_cfg.normalisation.n_output_channels = None + with pytest.raises(ValueError, match="n_output_channels is None"): + validate_config(valid_cfg) + + +class TestValidateConfigUnexpectedKeys: + def test_warns_on_unexpected_keys(self, valid_cfg, caplog): + 
valid_cfg.unexpected_key = "some_value" + validate_config(valid_cfg) + assert "Found unexpected keys in config" in caplog.text