Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name = "deepforest"
version = "2.0.1dev0"
description = "Platform for individual detection from airborne remote sensing including trees, birds, and livestock. Supports multiple detection models, adding models for species classification, and easy fine tuning to particular ecosystems."
readme = "README.md"
requires-python = ">=3.10,<3.13"
requires-python = ">=3.10,<3.14"
license = {text = "MIT"}
keywords = ["deep-learning", "forest", "ecology", "computer-vision"]
classifiers = [
Expand Down
206 changes: 155 additions & 51 deletions src/deepforest/datasets/prediction.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import warnings

import numpy as np
import pandas as pd
Expand All @@ -14,6 +15,80 @@
from deepforest.utilities import format_geometry, read_file


def _load_image_array(
    image_path: str | None = None, image: np.ndarray | Image.Image | None = None
) -> np.ndarray:
    """Resolve an image argument to a numpy array.

    Args:
        image_path: Path to an image file; used only when ``image`` is None.
        image: In-memory image. An ``np.ndarray`` is returned unchanged and
            uncopied; a PIL image is converted with ``np.asarray``.

    Returns:
        np.ndarray: RGB pixel data when loaded from disk, otherwise the
        caller's array/converted PIL data.

    Raises:
        ValueError: If neither ``image_path`` nor ``image`` is provided.
    """
    if image is not None:
        return image if isinstance(image, np.ndarray) else np.asarray(image)

    if image_path is None:
        raise ValueError("Either image_path or image must be provided")

    # Use a context manager so the file handle is closed deterministically;
    # relying on PIL to close it implicitly after the lazy load can leak the
    # descriptor for some formats. convert() forces the load, and np.asarray
    # materializes the pixels before the file is closed.
    with Image.open(image_path) as img:
        return np.asarray(img.convert("RGB"))


def _ensure_rgb_chw(image: np.ndarray) -> np.ndarray:
"""Return 3-channel RGB in CHW order (no normalization).

Raises if grayscale or wrong shape.
"""
if image.ndim == 2:
raise ValueError("Grayscale images are not supported (expected 3-channel RGB)")
if image.ndim != 3:
raise ValueError(f"Expected 3D image array, got shape {image.shape}")

# Ensure channels-first (C, H, W)
if image.shape[0] == 3:
chw = image
elif image.shape[-1] == 3:
chw = np.moveaxis(image, -1, 0)
else:
raise ValueError(f"Expected image with 3 channels, got shape {image.shape}")

return np.ascontiguousarray(chw)


def _ensure_rgb_chw_float32(image: np.ndarray) -> np.ndarray:
    """Convert an image array to contiguous RGB CHW float32 in [0, 1].

    Accepts HWC or CHW layouts with uint8 or floating dtype. uint8 data is
    scaled by 1/255. Floating data is assumed already normalized unless its
    maximum exceeds 1.0, in which case it is treated as a [0, 255] image and
    scaled down (whole-image heuristic).

    Raises:
        ValueError: For invalid shapes (via ``_ensure_rgb_chw``), negative
            float values, float values above 255, or integer dtypes other
            than uint8.
    """
    chw = _ensure_rgb_chw(image)
    dtype = chw.dtype

    if dtype == np.uint8:
        # astype yields a fresh array, so the in-place scale is safe.
        scaled = chw.astype(np.float32)
        scaled /= 255.0
        return np.ascontiguousarray(scaled)

    if not np.issubdtype(dtype, np.floating):
        # Integers other than uint8 are ambiguous; be explicit.
        raise ValueError(f"Unsupported image dtype {chw.dtype}. Expected uint8 or float.")

    out = chw if dtype == np.float32 else chw.astype(np.float32)

    lo = float(out.min())
    hi = float(out.max())
    if lo < 0:
        raise ValueError(f"Expected float image in [0, 1] or [0, 255], got min {lo}")
    if hi > 255.0:
        raise ValueError(f"Expected float image in [0, 1] or [0, 255], got max {hi}")

    # Values in (1, 255] look like 0-255 floats: normalize. Values already in
    # [0, 1] pass through untouched.
    if hi > 1.0:
        # Copy first when the buffer aliases the caller's array, so the
        # in-place divide never mutates the input.
        if np.shares_memory(out, image):
            out = out.copy()
        out /= 255.0

    return np.ascontiguousarray(out)


# Base prediction class
class PredictionDataset(Dataset):
"""Base class for prediction datasets. Defines the common interface and
Expand Down Expand Up @@ -48,32 +123,9 @@ def __init__(
def load_and_preprocess_image(
self, image_path: str = None, image: np.ndarray | Image.Image = None
):
if image is None:
if image_path is None:
raise ValueError("Either image_path or image must be provided")
image = np.array(Image.open(image_path).convert("RGB"))
else:
image = np.array(image)
# If dtype is not float32, convert to float32
if image.dtype != "float32":
image = image.astype("float32")

# If image is not normalized, normalize to [0, 1]
if image.max() > 1 or image.min() < 0:
image = image / 255.0

# If image is not in CHW format, convert to CHW
if image.shape[0] != 3:
if image.shape[-1] != 3:
raise ValueError(
f"Expected 3 channel image, got image shape {image.shape}"
)
else:
image = np.rollaxis(image, 2, 0)

image = torch.from_numpy(image)

return image
image_arr = _load_image_array(image_path=image_path, image=image)
image_arr = _ensure_rgb_chw_float32(image_arr)
return torch.from_numpy(image_arr)

def prepare_items(self):
"""Prepare the items for the dataset.
Expand Down Expand Up @@ -169,9 +221,13 @@ def __init__(self, path=None, image=None, patch_size=400, patch_overlap=0):
)

def prepare_items(self):
self.image = self.load_and_preprocess_image(self.path, image=self.image)
image_arr = _load_image_array(image_path=self.path, image=self.image)
# Normalize full image once so all crops share consistent treatment
# (uniform across dtype, float [0,1] vs [0,255], and dark vs bright regions).
image_norm = _ensure_rgb_chw_float32(image_arr)
self.image = torch.from_numpy(image_norm)
self.windows = preprocess.compute_windows(
self.image, self.patch_size, self.patch_overlap
image_norm, self.patch_size, self.patch_overlap
)

def __len__(self):
Expand All @@ -182,8 +238,9 @@ def window_list(self):

def get_crop(self, idx):
crop = self.image[self.windows[idx].indices()]

return crop
# Clone to avoid in-place modification corrupting self.image when crop
# is a view (overlapping windows).
return crop.clone()

def get_image_basename(self, idx):
if self.path is not None:
Expand Down Expand Up @@ -433,33 +490,47 @@ class TiledRaster(PredictionDataset):
def __init__(self, path, patch_size, patch_overlap):
if path is None:
raise ValueError("path is required for a memory raster dataset")
self._src = None
super().__init__(path=path, patch_size=patch_size, patch_overlap=patch_overlap)

def prepare_items(self):
# Get raster shape without keeping file open
with rio.open(self.path) as src:
width = src.shape[0]
height = src.shape[1]

# Check is tiled
if not src.is_tiled:
raise ValueError(
"Out-of-memory dataset is selected, but raster is not tiled, "
"leading to entire raster being read into memory and defeating "
"the purpose of an out-of-memory dataset. "
"\nPlease run: "
"\ngdal_translate -of GTiff -co TILED=YES <input> <output> "
"to create a tiled raster"
)
# Open once; workers=0 is enforced by caller for this dataset.
self._src = rio.open(self.path)
height = self._src.height
width = self._src.width

# Warn on non-tiled rasters: window reads may still be efficient (strip-based),
# but performance can degrade depending on driver/strip layout.
if not self._src.is_tiled:
warnings.warn(
"dataloader_strategy='window' is selected, but raster is not tiled. "
"Windowed reads may be slower depending on file layout. If needed, "
"create a tiled GeoTIFF with: "
"gdal_translate -of GTiff -co TILED=YES <input> <output>",
stacklevel=2,
)

# Generate sliding windows
self.windows = slidingwindow.generateForSize(
all_windows = slidingwindow.generateForSize(
height,
width,
dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
maxWindowSize=self.patch_size,
overlapPercent=self.patch_overlap,
)
# Filter out invalid windows: zero-size or extending past raster bounds.
# Rasterio returns (C,0,W) or (C,H,0) for out-of-bounds reads, which breaks RetinaNet.
self.windows = [
w
for w in all_windows
if w.w > 0 and w.h > 0 and w.x + w.w <= width and w.y + w.h <= height
]
n_filtered = len(all_windows) - len(self.windows)
if n_filtered > 0:
warnings.warn(
f"Filtered {n_filtered} window(s) extending past raster bounds or zero-size.",
stacklevel=2,
)

def __len__(self):
return len(self.windows)
Expand All @@ -469,12 +540,19 @@ def window_list(self):

def get_crop(self, idx):
window = self.windows[idx]
with rio.open(self.path) as src:
window_data = src.read(window=Window(window.x, window.y, window.w, window.h))
assert self._src is not None, "Raster dataset is not open"
window_data = self._src.read(
window=Window(window.x, window.y, window.w, window.h)
)

# Rasterio already returns (C, H, W), just normalize and convert
window_data = window_data.astype("float32") / 255.0
window_data = torch.from_numpy(window_data).float()
if window_data.shape[1] == 0 or window_data.shape[2] == 0:
raise ValueError(
f"Window {idx} returned zero-size array (shape={window_data.shape}). "
"RetinaNet cannot process images with zero height or width."
)
window_data = window_data.astype(np.float32)
window_data /= 255.0
window_data = torch.from_numpy(window_data)
if window_data.shape[0] != 3:
raise ValueError(
f"Expected 3 channel image, got {window_data.shape[0]} channels"
Expand All @@ -487,3 +565,29 @@ def get_image_basename(self, idx):

def get_crop_bounds(self, idx):
return self.window_list()[idx]

def close(self) -> None:
    """Close the underlying raster dataset, if one is open.

    Safe to call repeatedly; subsequent calls are no-ops.
    """
    # Detach the handle before closing so the attribute never points at a
    # closed dataset, even if close() itself raises.
    src, self._src = self._src, None
    if src is not None:
        src.close()

def __getstate__(self) -> dict:
    """Return picklable state for multiprocessing workers.

    Rasterio dataset handles cannot be serialized, so the open handle is
    dropped here; ``__setstate__`` reopens the raster in the new process.
    """
    state = dict(self.__dict__)
    state["_src"] = None  # exclude the unpicklable handle
    return state

def __setstate__(self, state: dict) -> None:
    """Restore pickled state and reopen the raster handle.

    The handle was excluded by ``__getstate__``; reopen it from ``path`` so
    the unpickled dataset is immediately usable in the worker process.
    """
    self.__dict__.update(state)
    needs_reopen = self._src is None and self.path is not None
    if needs_reopen:
        self._src = rio.open(self.path)

def __del__(self):
    """Best-effort release of the raster handle at garbage collection."""
    try:
        self.close()
    except Exception:
        # Deliberately swallowed: the interpreter may already be tearing
        # down, and __del__ must never propagate exceptions.
        pass
54 changes: 39 additions & 15 deletions src/deepforest/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,26 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs):

self.trainer = pl.Trainer(**trainer_args)

# Helpful warning: CUDA visible but trainer not using it.
# This commonly happens if accelerator/devices were overridden to CPU, or
# if the trainer wasn't recreated after changing config.
try:
accel_name = type(self.trainer.accelerator).__name__.lower()
except Exception:
accel_name = ""

requested_accel = str(trainer_args.get("accelerator", "")).lower()
if torch.cuda.is_available() and requested_accel in {"auto", "gpu", "cuda"}:
if "cuda" not in accel_name and "gpu" not in accel_name:
warnings.warn(
"CUDA appears to be available, but the Lightning trainer is not "
f"using a GPU accelerator (accelerator={trainer_args.get('accelerator')}, "
f"devices={trainer_args.get('devices')}). "
"To force GPU inference, call create_trainer(accelerator='gpu', devices=1) "
"or set config.accelerator='gpu' and config.devices=1, then recreate the trainer.",
stacklevel=2,
)

def on_fit_start(self):
if self.config.train.csv_file is None:
raise AttributeError(
Expand Down Expand Up @@ -586,19 +606,24 @@ def predict_tile(
patch_size=patch_size,
)

dataloader = self.predict_dataloader(ds)
batched_results = self.trainer.predict(self, dataloader)

# Flatten list from batched prediction
# Track global window index across batches
global_window_idx = 0
for _idx, batch in enumerate(batched_results):
for _window_idx, window_result in enumerate(batch):
formatted_result = ds.postprocess(
window_result, global_window_idx
)
image_results.append(formatted_result)
global_window_idx += 1
try:
dataloader = self.predict_dataloader(ds)
batched_results = self.trainer.predict(self, dataloader)

# Flatten list from batched prediction
# Track global window index across batches
global_window_idx = 0
for _idx, batch in enumerate(batched_results):
for _window_idx, window_result in enumerate(batch):
formatted_result = ds.postprocess(
window_result, global_window_idx
)
image_results.append(formatted_result)
global_window_idx += 1
finally:
# Ensure raster datasets are closed even if predict/postprocess raises
if hasattr(ds, "close"):
ds.close()

if not image_results:
results = pd.DataFrame()
Expand Down Expand Up @@ -895,8 +920,7 @@ def predict_step(self, batch, batch_idx):

self.model.eval()
with torch.no_grad():
preds = self.model.forward(images)
return preds
return self.model.forward(images)

def predict_batch(self, images, preprocess_fn=None):
"""Predict a batch of images with the deepforest model.
Expand Down
Loading