From e38d70b9cd117acdb84ad6166f203dea1b3fc209 Mon Sep 17 00:00:00 2001 From: Henry Senyondo Date: Fri, 13 Feb 2026 01:00:02 -0500 Subject: [PATCH 1/7] Improve predict_tile memory use and windowed raster inference - Add helper to convert inputs to RGB CHW float32 in [0, 1] - Validate ndim/channel placement and reject grayscale early - Normalize based on dtype (uint8 -> /255), not max/min heuristics - Ensure contiguous arrays before torch conversion - Refactor SingleImage to keep full image in CHW without forcing full float32 copy; convert/normalize per-window crops in get_crop - Improve TiledRaster/window strategy: reuse single rasterio handle, warn when untiled, close datasets promptly after predict - Simplify predict_step return --- src/deepforest/datasets/prediction.py | 162 ++++++++++++++++++-------- src/deepforest/main.py | 7 +- 2 files changed, 119 insertions(+), 50 deletions(-) diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index bb5b2d3b5..bed0695de 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -1,4 +1,5 @@ import os +import warnings import numpy as np import pandas as pd @@ -14,6 +15,70 @@ from deepforest.utilities import format_geometry, read_file +def _load_image_array( + image_path: str | None = None, image: np.ndarray | Image.Image | None = None +) -> np.ndarray: + """Load image from path or array; converts to RGB when loading from path.""" + if image is None: + if image_path is None: + raise ValueError("Either image_path or image must be provided") + return np.asarray(Image.open(image_path).convert("RGB")) + + return image if isinstance(image, np.ndarray) else np.asarray(image) + + +def _ensure_rgb_chw(image: np.ndarray) -> np.ndarray: + """Return 3-channel RGB in CHW order (no normalization). Raises if grayscale or wrong shape.""" + if image.ndim == 2: + raise ValueError("Grayscale images are not supported (expected 3-channel RGB)") + if image.ndim != 3: + raise ValueError(f"Expected 3D image array, got shape {image.shape}") + + # Ensure channels-first (C, H, W) + if image.shape[0] == 3: + chw = image + elif image.shape[-1] == 3: + chw = np.moveaxis(image, -1, 0) + else: + raise ValueError(f"Expected image with 3 channels, got shape {image.shape}") + + return np.ascontiguousarray(chw) + + +def _ensure_rgb_chw_float32(image: np.ndarray) -> np.ndarray: + """Normalize to RGB CHW float32 in [0, 1]. Accepts HWC/CHW uint8 or float. Raises if invalid.""" + chw = _ensure_rgb_chw(image) + + # Normalize based primarily on dtype + if chw.dtype == np.uint8: + chw = chw.astype(np.float32) + chw /= 255.0 + elif np.issubdtype(chw.dtype, np.floating): + if chw.dtype != np.float32: + chw = chw.astype(np.float32) + + # Allow already-normalized float images. + # If values look like 0-255 floats, normalize. + max_val = float(chw.max()) + min_val = float(chw.min()) + if min_val < 0: + raise ValueError( + f"Expected float image in [0, 1] or [0, 255], got min {min_val}" + ) + if max_val > 1.0: + if max_val <= 255.0: + chw /= 255.0 + else: + raise ValueError( + f"Expected float image in [0, 1] or [0, 255], got max {max_val}" + ) + else: + # Integers other than uint8 are ambiguous; be explicit. + raise ValueError(f"Unsupported image dtype {chw.dtype}. Expected uint8 or float.") + + return np.ascontiguousarray(chw) + + # Base prediction class class PredictionDataset(Dataset): """Base class for prediction datasets. Defines the common interface and @@ -48,32 +113,9 @@ def __init__( def load_and_preprocess_image( self, image_path: str = None, image: np.ndarray | Image.Image = None ): - if image is None: - if image_path is None: - raise ValueError("Either image_path or image must be provided") - image = np.array(Image.open(image_path).convert("RGB")) - else: - image = np.array(image) - # If dtype is not float32, convert to float32 - if image.dtype != "float32": - image = image.astype("float32") - - # If image is not normalized, normalize to [0, 1] - if image.max() > 1 or image.min() < 0: - image = image / 255.0 - - # If image is not in CHW format, convert to CHW - if image.shape[0] != 3: - if image.shape[-1] != 3: - raise ValueError( - f"Expected 3 channel image, got image shape {image.shape}" - ) - else: - image = np.rollaxis(image, 2, 0) - - image = torch.from_numpy(image) - - return image + image_arr = _load_image_array(image_path=image_path, image=image) + image_arr = _ensure_rgb_chw_float32(image_arr) + return torch.from_numpy(image_arr) def prepare_items(self): """Prepare the items for the dataset. @@ -169,7 +211,11 @@ def __init__(self, path=None, image=None, patch_size=400, patch_overlap=0): ) def prepare_items(self): - self.image = self.load_and_preprocess_image(self.path, image=self.image) + image_arr = _load_image_array(image_path=self.path, image=self.image) + image = _ensure_rgb_chw(image_arr) + + # Keep as uint8/float in CHW; normalize per-crop to avoid full-image float copy + self.image = image self.windows = preprocess.compute_windows( self.image, self.patch_size, self.patch_overlap ) @@ -182,8 +228,11 @@ def window_list(self): def get_crop(self, idx): crop = self.image[self.windows[idx].indices()] - - return crop + if crop.dtype != "float32": + crop = crop.astype("float32") + if crop.max() > 1 or crop.min() < 0: + crop /= 255.0 + return torch.from_numpy(crop) def get_image_basename(self, idx): if self.path is not None: @@ -433,24 +482,25 @@ class TiledRaster(PredictionDataset): def __init__(self, path, patch_size, patch_overlap): if path is None: raise ValueError("path is required for a memory raster dataset") + self._src = None super().__init__(path=path, patch_size=patch_size, patch_overlap=patch_overlap) def prepare_items(self): - # Get raster shape without keeping file open - with rio.open(self.path) as src: - width = src.shape[0] - height = src.shape[1] - - # Check is tiled - if not src.is_tiled: - raise ValueError( - "Out-of-memory dataset is selected, but raster is not tiled, " - "leading to entire raster being read into memory and defeating " - "the purpose of an out-of-memory dataset. " - "\nPlease run: " - "\ngdal_translate -of GTiff -co TILED=YES " - "to create a tiled raster" - ) + # Open once; workers=0 is enforced by caller for this dataset. + self._src = rio.open(self.path) + height = self._src.height + width = self._src.width + + # Warn on non-tiled rasters: window reads may still be efficient (strip-based), + # but performance can degrade depending on driver/strip layout. + if not self._src.is_tiled: + warnings.warn( + "dataloader_strategy='window' is selected, but raster is not tiled. " + "Windowed reads may be slower depending on file layout. If needed, " + "create a tiled GeoTIFF with: " + "gdal_translate -of GTiff -co TILED=YES ", + stacklevel=2, + ) # Generate sliding windows self.windows = slidingwindow.generateForSize( @@ -469,12 +519,15 @@ def window_list(self): def get_crop(self, idx): window = self.windows[idx] - with rio.open(self.path) as src: - window_data = src.read(window=Window(window.x, window.y, window.w, window.h)) + assert self._src is not None, "Raster dataset is not open" + window_data = self._src.read( + window=Window(window.x, window.y, window.w, window.h) + ) # Rasterio already returns (C, H, W), just normalize and convert - window_data = window_data.astype("float32") / 255.0 - window_data = torch.from_numpy(window_data).float() + window_data = window_data.astype(np.float32) + window_data /= 255.0 + window_data = torch.from_numpy(window_data) if window_data.shape[0] != 3: raise ValueError( f"Expected 3 channel image, got {window_data.shape[0]} channels" @@ -487,3 +540,16 @@ def get_image_basename(self, idx): def get_crop_bounds(self, idx): return self.window_list()[idx] + + def close(self) -> None: + """Close the underlying raster dataset.""" + if self._src is not None: + self._src.close() + self._src = None + + def __del__(self): + # Best-effort cleanup + try: + self.close() + except Exception: + pass diff --git a/src/deepforest/main.py b/src/deepforest/main.py index af331eb78..6bee743dc 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -600,6 +600,10 @@ def predict_tile( image_results.append(formatted_result) global_window_idx += 1 + # Ensure raster datasets are closed promptly + if hasattr(ds, "close"): + ds.close() + if not image_results: results = pd.DataFrame() else: @@ -895,8 +899,7 @@ def predict_step(self, batch, batch_idx): self.model.eval() with torch.no_grad(): - preds = self.model.forward(images) - return preds + return self.model.forward(images) def predict_batch(self, images, preprocess_fn=None): """Predict a batch of images with the deepforest model. From 863951fb9d97c03bfa25eee265a1d7c1c2149941 Mon Sep 17 00:00:00 2001 From: Henry Senyondo Date: Fri, 13 Feb 2026 01:02:07 -0500 Subject: [PATCH 2/7] Add GPU tests and CUDA fallback warning - Warn when CUDA is available but trainer falls back to CPU - Add test_gpu_inference_uses_cuda: regression test for GPU inference - Add hpc_multi_gpu_train: DDP smoke test for multi-GPU training - Relax requires-python to allow 3.13 --- pyproject.toml | 2 +- src/deepforest/main.py | 20 ++++++++ tests/hpc_multi_gpu_train.py | 74 +++++++++++++++++++++++++++ tests/test_gpu_inference_uses_cuda.py | 39 ++++++++++++++ 4 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 tests/hpc_multi_gpu_train.py create mode 100644 tests/test_gpu_inference_uses_cuda.py diff --git a/pyproject.toml b/pyproject.toml index 9f92fb9c2..714fe6c3b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "deepforest" version = "2.0.1dev0" description = "Platform for individual detection from airborne remote sensing including trees, birds, and livestock. Supports multiple detection models, adding models for species classification, and easy fine tuning to particular ecosystems." readme = "README.md" -requires-python = ">=3.10,<3.13" +requires-python = ">=3.10,<3.14" license = {text = "MIT"} keywords = ["deep-learning", "forest", "ecology", "computer-vision"] classifiers = [ diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 6bee743dc..0968b1ece 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -241,6 +241,26 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs): self.trainer = pl.Trainer(**trainer_args) + # Helpful warning: CUDA visible but trainer not using it. + # This commonly happens if accelerator/devices were overridden to CPU, or + # if the trainer wasn't recreated after changing config. + try: + accel_name = type(self.trainer.accelerator).__name__.lower() + except Exception: + accel_name = "" + + requested_accel = str(trainer_args.get("accelerator", "")).lower() + if torch.cuda.is_available() and requested_accel in {"auto", "gpu", "cuda"}: + if "cuda" not in accel_name and "gpu" not in accel_name: + warnings.warn( + "CUDA appears to be available, but the Lightning trainer is not " + f"using a GPU accelerator (accelerator={trainer_args.get('accelerator')}, " + f"devices={trainer_args.get('devices')}). " + "To force GPU inference, call create_trainer(accelerator='gpu', devices=1) " + "or set config.accelerator='gpu' and config.devices=1, then recreate the trainer.", + stacklevel=2, + ) + def on_fit_start(self): if self.config.train.csv_file is None: raise AttributeError( diff --git a/tests/hpc_multi_gpu_train.py b/tests/hpc_multi_gpu_train.py new file mode 100644 index 000000000..9ea1307b0 --- /dev/null +++ b/tests/hpc_multi_gpu_train.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python +"""HPC-only multi-GPU training smoke test (DDP). + +Run with: + torchrun --nproc_per_node=2 tests/hpc_multi_gpu_train.py +""" +from __future__ import annotations + +import os +import sys + +import torch + +from deepforest import get_data +from deepforest.main import deepforest + + +def _require_hpc() -> None: + if os.environ.get("GITHUB_ACTIONS") or os.environ.get("CI"): + raise SystemExit("CI environment detected; skip HPC-only test.") + if not os.environ.get("HIPERGATOR") and not os.environ.get("SLURM_JOB_ID"): + raise SystemExit( + "This script is intended for HPC use only. " + "Set HIPERGATOR=1 or run under SLURM." + ) + + +def _require_ddp() -> None: + if "LOCAL_RANK" not in os.environ and "RANK" not in os.environ: + raise SystemExit( + "DDP environment not detected. Run with:\n" + " torchrun --nproc_per_node=2 tests/hpc_multi_gpu_train.py" + ) + + +def main() -> int: + _require_hpc() + _require_ddp() + + if torch.cuda.device_count() < 2: + raise SystemExit("Need at least 2 GPUs for this test.") + + m = deepforest() + m.config.workers = 0 + m.config.batch_size = 1 + m.config.num_classes = 1 + m.config.label_dict = {"Tree": 0} + train_csv = get_data("example.csv") + m.config.train.csv_file = train_csv + m.config.train.root_dir = os.path.dirname(train_csv) + m.config.validation.csv_file = train_csv + m.config.validation.root_dir = os.path.dirname(train_csv) + m.create_model(initialize_model=True) + + # Keep this fast but avoid fast_dev_run's zero-length warning in DDP. + m.create_trainer( + accelerator="gpu", + devices=2, + strategy="ddp", + fast_dev_run=False, + max_epochs=1, + limit_train_batches=2, + limit_val_batches=2, + log_every_n_steps=1, + ) + m.trainer.fit(m) + + # Multi-GPU evaluation pass (uses same example.csv) + m.trainer.validate(m) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/test_gpu_inference_uses_cuda.py b/tests/test_gpu_inference_uses_cuda.py new file mode 100644 index 000000000..5cbfc03fa --- /dev/null +++ b/tests/test_gpu_inference_uses_cuda.py @@ -0,0 +1,39 @@ +import os + +import pytest +import torch + +from deepforest import get_data +from deepforest.main import deepforest + + +@pytest.mark.skipif( + not os.environ.get("HIPERGATOR"), + reason="Only run on HIPERGATOR (requires GPU + model downloads).", +) +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available in this test environment.", +) +def test_predict_tile_uses_cuda_when_requested(): + """Ensure predict_tile runs on CUDA when accelerator/devices request GPU. + + This is a regression test to catch silent CPU fallbacks on GPU nodes. + """ + m = deepforest(config_args={"accelerator": "gpu", "devices": 1, "workers": 0}) + m.load_model(model_name="weecology/deepforest-tree", revision="main") + m.create_trainer(accelerator="gpu", devices=1) + + results = m.predict_tile( + path=get_data("OSBS_029.png"), + patch_size=400, + patch_overlap=0.0, + iou_threshold=0.15, + dataloader_strategy="single", + ) + assert results is not None and not results.empty + + # Assert trainer is actually using a GPU accelerator (no silent CPU fallback). + assert m.trainer is not None + accel_name = type(m.trainer.accelerator).__name__.lower() + assert "cuda" in accel_name or "gpu" in accel_name From cbdce347fa13959c39dba07d97676d933890f01a Mon Sep 17 00:00:00 2001 From: henrykironde Date: Sat, 14 Feb 2026 23:38:41 -0500 Subject: [PATCH 3/7] Fix in-place crop normalization corrupts overlapping windows --- src/deepforest/datasets/prediction.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index bed0695de..534da2686 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -18,7 +18,8 @@ def _load_image_array( image_path: str | None = None, image: np.ndarray | Image.Image | None = None ) -> np.ndarray: - """Load image from path or array; converts to RGB when loading from path.""" + """Load image from path or array; converts to RGB when loading from + path.""" if image is None: if image_path is None: raise ValueError("Either image_path or image must be provided") @@ -28,7 +29,10 @@ def _load_image_array( def _ensure_rgb_chw(image: np.ndarray) -> np.ndarray: - """Return 3-channel RGB in CHW order (no normalization). Raises if grayscale or wrong shape.""" + """Return 3-channel RGB in CHW order (no normalization). + + Raises if grayscale or wrong shape. + """ if image.ndim == 2: raise ValueError("Grayscale images are not supported (expected 3-channel RGB)") if image.ndim != 3: @@ -46,7 +50,10 @@ def _ensure_rgb_chw(image: np.ndarray) -> np.ndarray: def _ensure_rgb_chw_float32(image: np.ndarray) -> np.ndarray: - """Normalize to RGB CHW float32 in [0, 1]. Accepts HWC/CHW uint8 or float. Raises if invalid.""" + """Normalize to RGB CHW float32 in [0, 1]. + + Accepts HWC/CHW uint8 or float. Raises if invalid. + """ chw = _ensure_rgb_chw(image) # Normalize based primarily on dtype @@ -228,10 +235,11 @@ def window_list(self): def get_crop(self, idx): crop = self.image[self.windows[idx].indices()] - if crop.dtype != "float32": - crop = crop.astype("float32") - if crop.max() > 1 or crop.min() < 0: - crop /= 255.0 + # Copy to avoid in-place modification corrupting self.image when crop is a + # view (e.g. overlapping windows or float32 input). Reuse dtype-based + # normalization to avoid heuristic edge cases (e.g. uint8 all-0/1 crops). + crop = np.array(crop, copy=True) + crop = _ensure_rgb_chw_float32(crop) return torch.from_numpy(crop) def get_image_basename(self, idx): From e552a5ab9a887146c335ebeb21b71e40dad29bf7 Mon Sep 17 00:00:00 2001 From: henrykironde Date: Sun, 15 Feb 2026 01:40:53 -0500 Subject: [PATCH 4/7] Add TiledRaster close on failure and pickle for multiprocessing --- src/deepforest/datasets/prediction.py | 13 ++++++++++ src/deepforest/main.py | 35 ++++++++++++++------------- 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index 534da2686..46c1d3a40 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -555,6 +555,19 @@ def close(self) -> None: self._src.close() self._src = None + def __getstate__(self) -> dict: + """Make picklable for multiprocessing; rasterio handles are not + serializable.""" + state = self.__dict__.copy() + state["_src"] = None # Exclude handle; __setstate__ will reopen in new process + return state + + def __setstate__(self, state: dict) -> None: + """Restore after unpickle; reopen raster since handle was excluded.""" + self.__dict__.update(state) + if self._src is None and self.path is not None: + self._src = rio.open(self.path) + def __del__(self): # Best-effort cleanup try: diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 0968b1ece..d942b0251 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -606,23 +606,24 @@ def predict_tile( patch_size=patch_size, ) - dataloader = self.predict_dataloader(ds) - batched_results = self.trainer.predict(self, dataloader) - - # Flatten list from batched prediction - # Track global window index across batches - global_window_idx = 0 - for _idx, batch in enumerate(batched_results): - for _window_idx, window_result in enumerate(batch): - formatted_result = ds.postprocess( - window_result, global_window_idx - ) - image_results.append(formatted_result) - global_window_idx += 1 - - # Ensure raster datasets are closed promptly - if hasattr(ds, "close"): - ds.close() + try: + dataloader = self.predict_dataloader(ds) + batched_results = self.trainer.predict(self, dataloader) + + # Flatten list from batched prediction + # Track global window index across batches + global_window_idx = 0 + for _idx, batch in enumerate(batched_results): + for _window_idx, window_result in enumerate(batch): + formatted_result = ds.postprocess( + window_result, global_window_idx + ) + image_results.append(formatted_result) + global_window_idx += 1 + finally: + # Ensure raster datasets are closed even if predict/postprocess raises + if hasattr(ds, "close"): + ds.close() if not image_results: results = pd.DataFrame() From 65a42a33f0b55a7dec456c80aa6fe6388fd038d1 Mon Sep 17 00:00:00 2001 From: henrykironde Date: Sun, 15 Feb 2026 03:10:48 -0500 Subject: [PATCH 5/7] fix: avoid in-place mutation of caller array in _ensure_rgb_chw_float32 Copy only when chw may alias the input (np.shares_memory), then use in-place division. Preserves caller's data without memory spike from regular division creating temporary arrays. Co-authored-by: Cursor --- src/deepforest/datasets/prediction.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index 46c1d3a40..8551343ca 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -74,6 +74,8 @@ def _ensure_rgb_chw_float32(image: np.ndarray) -> np.ndarray: ) if max_val > 1.0: if max_val <= 255.0: + if np.shares_memory(chw, image): + chw = chw.copy() chw /= 255.0 else: raise ValueError( From 525ef0398ae2e8de347416dc02aa6673203d6d6d Mon Sep 17 00:00:00 2001 From: henrykironde Date: Mon, 16 Feb 2026 00:32:57 -0500 Subject: [PATCH 6/7] Avoid ZeroDivisionError edge case --- src/deepforest/datasets/prediction.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index 8551343ca..f773b9c0b 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -513,13 +513,26 @@ def prepare_items(self): ) # Generate sliding windows - self.windows = slidingwindow.generateForSize( + all_windows = slidingwindow.generateForSize( height, width, dimOrder=slidingwindow.DimOrder.ChannelHeightWidth, maxWindowSize=self.patch_size, overlapPercent=self.patch_overlap, ) + # Filter out invalid windows: zero-size or extending past raster bounds. + # Rasterio returns (C,0,W) or (C,H,0) for out-of-bounds reads, which breaks RetinaNet. + self.windows = [ + w + for w in all_windows + if w.w > 0 and w.h > 0 and w.x + w.w <= width and w.y + w.h <= height + ] + n_filtered = len(all_windows) - len(self.windows) + if n_filtered > 0: + warnings.warn( + f"Filtered {n_filtered} window(s) extending past raster bounds or zero-size.", + stacklevel=2, + ) def __len__(self): return len(self.windows) @@ -534,7 +547,11 @@ def get_crop(self, idx): window=Window(window.x, window.y, window.w, window.h) ) - # Rasterio already returns (C, H, W), just normalize and convert + if window_data.shape[1] == 0 or window_data.shape[2] == 0: + raise ValueError( + f"Window {idx} returned zero-size array (shape={window_data.shape}). " + "RetinaNet cannot process images with zero height or width." + ) window_data = window_data.astype(np.float32) window_data /= 255.0 window_data = torch.from_numpy(window_data) From d847701d5c380e8d785bd63beda8a4accec9ad7c Mon Sep 17 00:00:00 2001 From: Henry Senyondo Date: Mon, 16 Feb 2026 01:56:00 -0500 Subject: [PATCH 7/7] fix: normalize full image in SingleImage.prepare_items for consistent crops Apply normalization to the entire image once before prediction. Co-authored-by: Cursor --- src/deepforest/datasets/prediction.py | 22 ++++++++++----------- tests/test_datasets_prediction.py | 28 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index f773b9c0b..0f866cc7d 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -52,7 +52,8 @@ def _ensure_rgb_chw(image: np.ndarray) -> np.ndarray: def _ensure_rgb_chw_float32(image: np.ndarray) -> np.ndarray: """Normalize to RGB CHW float32 in [0, 1]. - Accepts HWC/CHW uint8 or float. Raises if invalid. + Accepts HWC/CHW uint8 or float. Raises if invalid. For float images, + uses full-image heuristic (max > 1.0 -> divide by 255). """ chw = _ensure_rgb_chw(image) @@ -221,12 +222,12 @@ def __init__(self, path=None, image=None, patch_size=400, patch_overlap=0): def prepare_items(self): image_arr = _load_image_array(image_path=self.path, image=self.image) - image = _ensure_rgb_chw(image_arr) - - # Keep as uint8/float in CHW; normalize per-crop to avoid full-image float copy - self.image = image + # Normalize full image once so all crops share consistent treatment + # (uniform across dtype, float [0,1] vs [0,255], and dark vs bright regions). + image_norm = _ensure_rgb_chw_float32(image_arr) + self.image = torch.from_numpy(image_norm) self.windows = preprocess.compute_windows( - self.image, self.patch_size, self.patch_overlap + image_norm, self.patch_size, self.patch_overlap ) def __len__(self): @@ -237,12 +238,9 @@ def window_list(self): def get_crop(self, idx): crop = self.image[self.windows[idx].indices()] - # Copy to avoid in-place modification corrupting self.image when crop is a - # view (e.g. overlapping windows or float32 input). Reuse dtype-based - # normalization to avoid heuristic edge cases (e.g. uint8 all-0/1 crops). - crop = np.array(crop, copy=True) - crop = _ensure_rgb_chw_float32(crop) - return torch.from_numpy(crop) + # Clone to avoid in-place modification corrupting self.image when crop + # is a view (overlapping windows). + return crop.clone() def get_image_basename(self, idx): if self.path is not None: diff --git a/tests/test_datasets_prediction.py b/tests/test_datasets_prediction.py index f98ab3f72..642b64f83 100644 --- a/tests/test_datasets_prediction.py +++ b/tests/test_datasets_prediction.py @@ -47,6 +47,34 @@ def test_valid_array(): test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8) SingleImage(image=test_data) + +def test_single_image_float32_0_255_consistent_normalization(): + """Float32 [0, 255] crops must be normalized uniformly from full-image decision. + + A dark crop (all pixels <= 1.0) would be misclassified as [0, 1] by the + per-crop heuristic; with the fix, we use the full-image max to decide once. + """ + # Image: left half dark (0.5), right half bright (128). Full max > 1. + h, w = 200, 400 + img = np.zeros((h, w, 3), dtype=np.float32) + img[:, : w // 2, :] = 0.5 # Dark region + img[:, w // 2 :, :] = 128.0 # Bright region + # CHW for preprocess.compute_windows + img = np.moveaxis(img, -1, 0) + + ds = SingleImage(image=img, patch_size=100, patch_overlap=0) + assert len(ds) >= 2 + + # First crop(s) from dark region: max=0.5. Should be divided by 255 -> max ~0.002 + dark_crop = ds[0] + assert dark_crop.shape == (3, 100, 100) + assert dark_crop.max().item() < 0.01, "Dark crop should be /255, not left as [0,1]" + + # Crop from bright region: max=128. Should be divided by 255 -> max ~0.5 + bright_idx = len(ds) - 1 + bright_crop = ds[bright_idx] + assert bright_crop.max().item() > 0.4 + def test_MultiImage(): ds = MultiImage(paths=[get_data("OSBS_029.png"), get_data("OSBS_029.png")], patch_size=300,