106 changes: 105 additions & 1 deletion docs/user_guide/03_cropmodels.md
@@ -18,10 +18,114 @@ While that approach is certainly valid, there are a few key benefits to using Cr
- **Simpler and Extendable**: CropModels decouple detection and classification workflows, allowing separate handling of challenges like class imbalance and incomplete labels, without reducing the quality of the detections. Two-stage object detection models can be finicky with similar classes and often require expertise in managing learning rates.
- **New Data and Multi-sensor Learning**: In many applications, the data needed for detection and classification may differ. The CropModel concept provides an extendable piece that allows for advanced pipelines.

(spatial-temporal-metadata)=
## Spatial-Temporal Metadata

In biodiversity monitoring, species distributions vary by location and season. A bird common in Florida may be rare in Alaska, and migratory species shift seasonally. The CropModel supports an optional spatial-temporal metadata embedding that provides location and date context alongside image features to improve classification.

The metadata signal is intentionally "gentle": it contributes only ~1.5% of the combined feature vector (a 32-dimensional embedding alongside 2048 image features). This means the model still classifies primarily from visual appearance but can use location and season as a soft prior. When metadata is not provided at inference time, the model gracefully degrades to image-only classification.

### How It Works

When `use_metadata=True`, the CropModel:

1. Encodes `(lat, lon, day_of_year)` using sinusoidal features (smooth, periodic representation)
2. Projects the 6 sinusoidal features through a small MLP to a 32-dim embedding
3. Concatenates this with the 2048-dim ResNet image features
4. Classifies from the combined 2080-dim vector
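
The first step can be sketched in plain Python. The exact scaling constants are an assumption for illustration (only the feature counts of 6, 32, 2048, and 2080 are fixed by the description above); the sin/cos pairing is what makes day 365 and day 1 near neighbors:

```python
import math

def sinusoidal_metadata_features(lat, lon, day_of_year):
    """Step 1: encode (lat, lon, day_of_year) as 6 smooth periodic features."""
    lat_t = math.pi * lat / 90.0                  # latitude spans [-90, 90]
    lon_t = math.pi * lon / 180.0                 # longitude spans [-180, 180]
    doy_t = 2.0 * math.pi * day_of_year / 365.0   # season wraps around the year
    return [
        math.sin(lat_t), math.cos(lat_t),
        math.sin(lon_t), math.cos(lon_t),
        math.sin(doy_t), math.cos(doy_t),
    ]

features = sinusoidal_metadata_features(35.2, -120.4, 167.0)
print(len(features))   # 6 features feed the MLP in step 2
print(2048 + 32)       # 2080: the classifier input size in step 4
```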

### Inference with Metadata

Pass a `metadata` dict to `predict_tile`:

```python
from deepforest import main
from deepforest.model import CropModel

m = main.deepforest()
m.create_trainer()

crop_model = CropModel(config_args={"use_metadata": True})
crop_model.load_from_disk(train_dir="path/to/train", val_dir="path/to/val",
metadata_csv="metadata.csv")
crop_model.create_trainer(max_epochs=10)
crop_model.trainer.fit(crop_model)

result = m.predict_tile(
path="image.tif",
crop_model=crop_model,
metadata={"lat": 35.2, "lon": -120.4, "date": "2024-06-15"}
)
```

All detected crops in the tile share the same metadata. If `metadata` is omitted, the model falls back to image-only classification.
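
Internally, the ISO `date` string is reduced to a day-of-year value before encoding. A minimal sketch of that conversion (the helper name is illustrative, not part of the API):

```python
import datetime

def to_day_of_year(date_str):
    """Convert an ISO date string (YYYY-MM-DD) to a float day of year."""
    parsed = datetime.datetime.strptime(date_str, "%Y-%m-%d")
    return float(parsed.timetuple().tm_yday)

print(to_day_of_year("2024-06-15"))  # 167.0 (2024 is a leap year)
```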

### Training with Metadata

Training requires a CSV sidecar file that maps each crop image filename to its spatial-temporal metadata:

```text
filename,lat,lon,date
bird_001.png,35.2,-120.4,2024-06-15
bird_002.png,35.2,-120.4,2024-06-15
mammal_001.png,40.1,-105.3,2024-07-20
```

- `filename` matches the image basename inside the ImageFolder class directories
- `date` is an ISO format string, converted to day-of-year internally
- One CSV covers both train and val sets (filenames are unique)
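
One way to build the sidecar file is with pandas; the filenames and coordinates below are the illustrative values from the table above:

```python
import pandas as pd

# One row per crop image; crops from the same tile share location and date.
records = [
    {"filename": "bird_001.png", "lat": 35.2, "lon": -120.4, "date": "2024-06-15"},
    {"filename": "bird_002.png", "lat": 35.2, "lon": -120.4, "date": "2024-06-15"},
    {"filename": "mammal_001.png", "lat": 40.1, "lon": -105.3, "date": "2024-07-20"},
]
pd.DataFrame(records).to_csv("metadata.csv", index=False)
```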

The existing ImageFolder directory structure is unchanged:

```
train/
Bird/
bird_001.png
bird_002.png
Mammal/
mammal_001.png
```

Pass the CSV when loading data:

```python
from deepforest.model import CropModel

crop_model = CropModel(config_args={"use_metadata": True})
crop_model.load_from_disk(
train_dir="path/to/train",
val_dir="path/to/val",
metadata_csv="metadata.csv"
)
crop_model.create_trainer(max_epochs=10)
crop_model.trainer.fit(crop_model)
```

### Configuration

The metadata embedding is controlled by three config parameters:

```python
crop_model = CropModel(config_args={
"use_metadata": True, # Enable metadata fusion (default: False)
"metadata_dim": 32, # Embedding dimension (default: 32)
"metadata_dropout": 0.5, # Dropout on metadata path (default: 0.5)
})
```

Or in `config.yaml`:

```yaml
cropmodel:
use_metadata: True
metadata_dim: 32
metadata_dropout: 0.5
```

## Considerations

- **Efficiency**: Using a CropModel will be slower, as for each detection, the sensor data needs to be cropped and passed to the classifier. This is less efficient than using a combined classification/detection system like multi-class detection models. While modern GPUs mitigate this to some extent, it is still something to be mindful of.
- **Lack of Spatial Awareness**: The model knows only about the pixels inside the crop and cannot use features outside the bounding box. This lack of spatial awareness can be a major limitation. It is possible, but untested, that multi-class detection models might perform better in such tasks. A box attention mechanism, like in [this paper](https://arxiv.org/abs/2111.13087), could be a better approach. See the {ref}`spatial-temporal-metadata` section for an optional way to incorporate location and season information.

## Single Crop Model

12 changes: 12 additions & 0 deletions docs/user_guide/09_configuration_file.md
@@ -319,3 +319,15 @@ crop_model = CropModel()
# Or use custom resize dimensions
crop_model = CropModel(config_args={"resize": [300, 300]})
```

### use_metadata

Boolean flag to enable spatial-temporal metadata fusion. When `True`, the model accepts `(lat, lon, date)` alongside image crops and learns a small embedding that is concatenated with image features. Default is `False`. See {ref}`spatial-temporal-metadata` for usage details.

### metadata_dim

Dimension of the metadata embedding vector. A smaller value makes the metadata signal gentler relative to the 2048-dim image features. Default is `32`.

### metadata_dropout

Dropout rate applied to the metadata embedding path. Higher values reduce the model's reliance on location/date information. Default is `0.5`.
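
Together, these parameters control how much weight the classifier can put on location and date. A quick check of the metadata embedding's share of the classifier input, assuming the 2048-dim ResNet backbone described in the user guide:

```python
def metadata_share(metadata_dim, image_dim=2048):
    """Fraction of the combined feature vector taken by the metadata embedding."""
    return metadata_dim / (image_dim + metadata_dim)

for dim in (16, 32, 64):
    print(f"metadata_dim={dim}: {metadata_share(dim):.1%}")
# metadata_dim=32, the default, gives ~1.5%
```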
5 changes: 5 additions & 0 deletions src/deepforest/conf/config.yaml
@@ -119,3 +119,8 @@ cropmodel:
- 224
# Number of pixels to expand bbox crop windows for better prediction context.
expand: 0
# Spatial-temporal metadata fusion (optional).
# When True, the model accepts (lat, lon, date) alongside image crops.
use_metadata: False
metadata_dim: 32
metadata_dropout: 0.5
3 changes: 3 additions & 0 deletions src/deepforest/conf/schema.py
@@ -108,6 +108,9 @@ class CropModelConfig:
balance_classes: bool = False
resize: list[int] = field(default_factory=lambda: [224, 224])
expand: int = 0
use_metadata: bool = False
metadata_dim: int = 32
metadata_dropout: float = 0.5


@dataclass
11 changes: 11 additions & 0 deletions src/deepforest/datasets/cropmodel.py
@@ -8,6 +8,7 @@

import numpy as np
import rasterio as rio
import torch
from rasterio.windows import Window
from torch.utils.data import Dataset
from torchvision import transforms
@@ -65,6 +66,7 @@ def __init__(
augmentations=None,
resize=None,
expand: int = 0,
metadata=None,
):
self.df = df

@@ -80,6 +82,10 @@
raise ValueError("expand must be >= 0")
self.expand = int(expand)

# Optional spatial-temporal metadata per crop.
# Dict mapping crop index to (lat, lon, day_of_year).
self.metadata = metadata

unique_image = self.df["image_path"].unique()
assert len(unique_image) == 1, (
"There should be only one unique image for this class object"
@@ -129,4 +135,9 @@ def __getitem__(self, idx):
else:
image = box

if self.metadata is not None:
lat, lon, doy = self.metadata[idx]
meta_tensor = torch.tensor([lat, lon, doy], dtype=torch.float32)
return image, meta_tensor

return image
64 changes: 64 additions & 0 deletions src/deepforest/datasets/training.py
@@ -1,10 +1,12 @@
"""Dataset model for object detection tasks."""

import datetime
import math
import os

import kornia.augmentation as K
import numpy as np
import pandas as pd
import shapely
import torch
from PIL import Image
@@ -346,3 +348,65 @@ def _classes_in(root):
)

return train_ds, val_ds


class MetadataImageFolder(Dataset):
"""Wrapper that adds spatial-temporal metadata to an ImageFolder dataset.

Expects a CSV sidecar file with columns: filename, lat, lon, date.
The date column should be an ISO format string (e.g., "2024-06-15")
and will be converted to day_of_year internally.

Args:
image_folder: A FixedClassImageFolder (or ImageFolder) dataset.
metadata_csv: Path to CSV with columns filename, lat, lon, date.

Returns per sample:
(image, label, metadata_tensor) where metadata_tensor is shape (3,)
containing [lat, lon, day_of_year].
"""

def __init__(self, image_folder, metadata_csv):
self.image_folder = image_folder
metadata_df = pd.read_csv(metadata_csv)
self._meta_lookup = {}
for _, row in metadata_df.iterrows():
date = datetime.datetime.strptime(str(row["date"]), "%Y-%m-%d")
doy = float(date.timetuple().tm_yday)
self._meta_lookup[row["filename"]] = (
float(row["lat"]),
float(row["lon"]),
doy,
)

def __len__(self):
return len(self.image_folder)

def __getitem__(self, idx):
image, label = self.image_folder[idx]
filepath = self.image_folder.samples[idx][0]
filename = os.path.basename(filepath)

if filename in self._meta_lookup:
lat, lon, doy = self._meta_lookup[filename]
else:
lat, lon, doy = 0.0, 0.0, 1.0

metadata = torch.tensor([lat, lon, doy], dtype=torch.float32)
return image, label, metadata

@property
def targets(self):
return self.image_folder.targets

@property
def class_to_idx(self):
return self.image_folder.class_to_idx

@property
def samples(self):
return self.image_folder.samples

@property
def imgs(self):
return self.image_folder.imgs
33 changes: 32 additions & 1 deletion src/deepforest/main.py
@@ -1,4 +1,5 @@
# entry point for deepforest model
import datetime
import importlib
import os
import warnings
@@ -523,6 +524,7 @@ def predict_tile(
iou_threshold=0.15,
dataloader_strategy="single",
crop_model=None,
metadata=None,
):
"""For images too large to input into the model, predict_tile cuts the
image into overlapping windows, predicts trees on each window and
@@ -539,6 +541,10 @@
- "batch" loads the entire image into GPU memory and creates views of the image as a batch; requires the entire tile to fit into GPU memory. CPU parallelization is possible for loading images.
- "window" loads only the desired window of the image from the raster dataset. Most memory efficient option, but cannot parallelize across windows.
crop_model: a deepforest.model.CropModel object to predict on crops
metadata: Optional dict with keys "lat", "lon", "date" for
spatial-temporal context. "date" should be an ISO format
string (e.g., "2024-06-15"). Used by CropModel when
use_metadata=True in config.

Returns:
pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple
@@ -664,6 +670,20 @@
root_dir = None

if crop_model is not None:
# Build per-crop metadata from image-level metadata dict
if metadata is not None:
date_str = metadata.get("date", None)
if date_str is not None:
doy = float(
datetime.datetime.strptime(str(date_str), "%Y-%m-%d")
.timetuple()
.tm_yday
)
else:
doy = 1.0
lat = float(metadata.get("lat", 0.0))
lon = float(metadata.get("lon", 0.0))

cropmodel_results = []
for path in paths:
image_result = mosaic_results[
@@ -672,8 +692,19 @@
if image_result.empty:
continue
image_result.root_dir = os.path.dirname(path)

# Create per-crop metadata dict if metadata was provided
per_crop_metadata = None
if metadata is not None:
per_crop_metadata = dict.fromkeys(
range(len(image_result)), (lat, lon, doy)
)

cropmodel_result = predict._crop_models_wrapper_(
crop_model, self.trainer, image_result
crop_model,
self.trainer,
image_result,
metadata=per_crop_metadata,
)
cropmodel_results.append(cropmodel_result)
cropmodel_results = pd.concat(cropmodel_results)