diff --git a/pyproject.toml b/pyproject.toml
index 9f92fb9c2..714fe6c3b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ name = "deepforest"
version = "2.0.1dev0"
description = "Platform for individual detection from airborne remote sensing including trees, birds, and livestock. Supports multiple detection models, adding models for species classification, and easy fine tuning to particular ecosystems."
readme = "README.md"
-requires-python = ">=3.10,<3.13"
+requires-python = ">=3.10,<3.14"
license = {text = "MIT"}
keywords = ["deep-learning", "forest", "ecology", "computer-vision"]
classifiers = [
diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py
index bb5b2d3b5..0f866cc7d 100644
--- a/src/deepforest/datasets/prediction.py
+++ b/src/deepforest/datasets/prediction.py
@@ -1,4 +1,5 @@
import os
+import warnings
import numpy as np
import pandas as pd
@@ -14,6 +15,80 @@
from deepforest.utilities import format_geometry, read_file
+def _load_image_array(
+ image_path: str | None = None, image: np.ndarray | Image.Image | None = None
+) -> np.ndarray:
+ """Load image from path or array; converts to RGB when loading from
+ path."""
+ if image is None:
+ if image_path is None:
+ raise ValueError("Either image_path or image must be provided")
+ return np.asarray(Image.open(image_path).convert("RGB"))
+
+ return image if isinstance(image, np.ndarray) else np.asarray(image)
+
+
+def _ensure_rgb_chw(image: np.ndarray) -> np.ndarray:
+ """Return 3-channel RGB in CHW order (no normalization).
+
+ Raises if grayscale or wrong shape.
+ """
+ if image.ndim == 2:
+ raise ValueError("Grayscale images are not supported (expected 3-channel RGB)")
+ if image.ndim != 3:
+ raise ValueError(f"Expected 3D image array, got shape {image.shape}")
+
+ # Ensure channels-first (C, H, W)
+ if image.shape[0] == 3:
+ chw = image
+ elif image.shape[-1] == 3:
+ chw = np.moveaxis(image, -1, 0)
+ else:
+ raise ValueError(f"Expected image with 3 channels, got shape {image.shape}")
+
+ return np.ascontiguousarray(chw)
+
+
+def _ensure_rgb_chw_float32(image: np.ndarray) -> np.ndarray:
+ """Normalize to RGB CHW float32 in [0, 1].
+
+ Accepts HWC/CHW uint8 or float. Raises if invalid. For float images,
+ uses full-image heuristic (max > 1.0 -> divide by 255).
+ """
+ chw = _ensure_rgb_chw(image)
+
+ # Normalize based primarily on dtype
+ if chw.dtype == np.uint8:
+ chw = chw.astype(np.float32)
+ chw /= 255.0
+ elif np.issubdtype(chw.dtype, np.floating):
+ if chw.dtype != np.float32:
+ chw = chw.astype(np.float32)
+
+ # Allow already-normalized float images.
+ # If values look like 0-255 floats, normalize.
+ max_val = float(chw.max())
+ min_val = float(chw.min())
+ if min_val < 0:
+ raise ValueError(
+ f"Expected float image in [0, 1] or [0, 255], got min {min_val}"
+ )
+ if max_val > 1.0:
+ if max_val <= 255.0:
+ if np.shares_memory(chw, image):
+ chw = chw.copy()
+ chw /= 255.0
+ else:
+ raise ValueError(
+ f"Expected float image in [0, 1] or [0, 255], got max {max_val}"
+ )
+ else:
+ # Integers other than uint8 are ambiguous; be explicit.
+ raise ValueError(f"Unsupported image dtype {chw.dtype}. Expected uint8 or float.")
+
+ return np.ascontiguousarray(chw)
+
+
# Base prediction class
class PredictionDataset(Dataset):
"""Base class for prediction datasets. Defines the common interface and
@@ -48,32 +123,9 @@ def __init__(
def load_and_preprocess_image(
self, image_path: str = None, image: np.ndarray | Image.Image = None
):
- if image is None:
- if image_path is None:
- raise ValueError("Either image_path or image must be provided")
- image = np.array(Image.open(image_path).convert("RGB"))
- else:
- image = np.array(image)
- # If dtype is not float32, convert to float32
- if image.dtype != "float32":
- image = image.astype("float32")
-
- # If image is not normalized, normalize to [0, 1]
- if image.max() > 1 or image.min() < 0:
- image = image / 255.0
-
- # If image is not in CHW format, convert to CHW
- if image.shape[0] != 3:
- if image.shape[-1] != 3:
- raise ValueError(
- f"Expected 3 channel image, got image shape {image.shape}"
- )
- else:
- image = np.rollaxis(image, 2, 0)
-
- image = torch.from_numpy(image)
-
- return image
+ image_arr = _load_image_array(image_path=image_path, image=image)
+ image_arr = _ensure_rgb_chw_float32(image_arr)
+ return torch.from_numpy(image_arr)
def prepare_items(self):
"""Prepare the items for the dataset.
@@ -169,9 +221,13 @@ def __init__(self, path=None, image=None, patch_size=400, patch_overlap=0):
)
def prepare_items(self):
- self.image = self.load_and_preprocess_image(self.path, image=self.image)
+ image_arr = _load_image_array(image_path=self.path, image=self.image)
+ # Normalize full image once so all crops share consistent treatment
+ # (uniform across dtype, float [0,1] vs [0,255], and dark vs bright regions).
+ image_norm = _ensure_rgb_chw_float32(image_arr)
+ self.image = torch.from_numpy(image_norm)
self.windows = preprocess.compute_windows(
- self.image, self.patch_size, self.patch_overlap
+ image_norm, self.patch_size, self.patch_overlap
)
def __len__(self):
@@ -182,8 +238,9 @@ def window_list(self):
def get_crop(self, idx):
crop = self.image[self.windows[idx].indices()]
-
- return crop
+ # Clone to avoid in-place modification corrupting self.image when crop
+ # is a view (overlapping windows).
+ return crop.clone()
def get_image_basename(self, idx):
if self.path is not None:
@@ -433,33 +490,47 @@ class TiledRaster(PredictionDataset):
def __init__(self, path, patch_size, patch_overlap):
if path is None:
raise ValueError("path is required for a memory raster dataset")
+ self._src = None
super().__init__(path=path, patch_size=patch_size, patch_overlap=patch_overlap)
def prepare_items(self):
- # Get raster shape without keeping file open
- with rio.open(self.path) as src:
- width = src.shape[0]
- height = src.shape[1]
-
- # Check is tiled
- if not src.is_tiled:
- raise ValueError(
- "Out-of-memory dataset is selected, but raster is not tiled, "
- "leading to entire raster being read into memory and defeating "
- "the purpose of an out-of-memory dataset. "
- "\nPlease run: "
- "\ngdal_translate -of GTiff -co TILED=YES