diff --git a/docs/user_guide/03_cropmodels.md b/docs/user_guide/03_cropmodels.md index 947854c7f..d06e06723 100644 --- a/docs/user_guide/03_cropmodels.md +++ b/docs/user_guide/03_cropmodels.md @@ -18,10 +18,114 @@ While that approach is certainly valid, there are a few key benefits to using Cr - **Simpler and Extendable**: CropModels decouple detection and classification workflows, allowing separate handling of challenges like class imbalance and incomplete labels, without reducing the quality of the detections. Two-stage object detection models can be finicky with similar classes and often require expertise in managing learning rates. - **New Data and Multi-sensor Learning**: In many applications, the data needed for detection and classification may differ. The CropModel concept provides an extendable piece that allows for advanced pipelines. +(spatial-temporal-metadata)= +## Spatial-Temporal Metadata + +In biodiversity monitoring, species distributions vary by location and season. A bird common in Florida may be rare in Alaska, and migratory species shift seasonally. The CropModel supports an optional spatial-temporal metadata embedding that provides location and date context alongside image features to improve classification. + +The metadata signal is intentionally "gentle" — it contributes only ~1.5% of the feature vector (32 dimensions vs. 2048 image features). This means the model still classifies primarily from visual appearance but can use location/season as a soft prior. When metadata is not provided at inference time, the model gracefully degrades to image-only classification. + +### How It Works + +When `use_metadata=True`, the CropModel: + +1. Encodes `(lat, lon, day_of_year)` using sinusoidal features (smooth, periodic representation) +2. Projects the 6 sinusoidal features through a small MLP to a 32-dim embedding +3. Concatenates this with the 2048-dim ResNet image features +4. Classifies from the combined 2080-dim vector + +### Inference with Metadata + +Pass a `metadata` dict to `predict_tile`: + +```python +from deepforest import main +from deepforest.model import CropModel + +m = main.deepforest() +m.create_trainer() + +crop_model = CropModel(config_args={"use_metadata": True}) +crop_model.load_from_disk(train_dir="path/to/train", val_dir="path/to/val", + metadata_csv="metadata.csv") +crop_model.create_trainer(max_epochs=10) +crop_model.trainer.fit(crop_model) + +result = m.predict_tile( + path="image.tif", + crop_model=crop_model, + metadata={"lat": 35.2, "lon": -120.4, "date": "2024-06-15"} +) +``` + +All detected crops in the tile share the same metadata. If `metadata` is omitted, the model falls back to image-only classification. + +### Training with Metadata + +Training requires a CSV sidecar file that maps each crop image filename to its spatial-temporal metadata: + +```text +filename,lat,lon,date +bird_001.png,35.2,-120.4,2024-06-15 +bird_002.png,35.2,-120.4,2024-06-15 +mammal_001.png,40.1,-105.3,2024-07-20 +``` + +- `filename` matches the image basename inside the ImageFolder class directories +- `date` is an ISO format string, converted to day-of-year internally +- One CSV covers both train and val sets (filenames are unique) + +The existing ImageFolder directory structure is unchanged: + +``` +train/ + Bird/ + bird_001.png + bird_002.png + Mammal/ + mammal_001.png +``` + +Pass the CSV when loading data: + +```python +from deepforest.model import CropModel + +crop_model = CropModel(config_args={"use_metadata": True}) +crop_model.load_from_disk( + train_dir="path/to/train", + val_dir="path/to/val", + metadata_csv="metadata.csv" +) +crop_model.create_trainer(max_epochs=10) +crop_model.trainer.fit(crop_model) +``` + +### Configuration + +The metadata embedding is controlled by three config parameters: + +```python +crop_model = CropModel(config_args={ + "use_metadata": True, # Enable metadata fusion (default: False) + "metadata_dim": 32, # Embedding dimension (default: 32) + "metadata_dropout": 0.5, # Dropout on metadata path (default: 0.5) +}) +``` + +Or in `config.yaml`: + +```yaml +cropmodel: + use_metadata: True + metadata_dim: 32 + metadata_dropout: 0.5 +``` + ## Considerations - **Efficiency**: Using a CropModel will be slower, as for each detection, the sensor data needs to be cropped and passed to the detector. This is less efficient than using a combined classification/detection system like multi-class detection models. While modern GPUs mitigate this to some extent, it is still something to be mindful of. -- **Lack of Spatial Awareness**: The model knows only about the pixels inside the crop and cannot use features outside the bounding box. This lack of spatial awareness can be a major limitation. It is possible, but untested, that multi-class detection models might perform better in such tasks. A box attention mechanism, like in [this paper](https://arxiv.org/abs/2111.13087), could be a better approach. +- **Lack of Spatial Awareness**: The model knows only about the pixels inside the crop and cannot use features outside the bounding box. This lack of spatial awareness can be a major limitation. It is possible, but untested, that multi-class detection models might perform better in such tasks. A box attention mechanism, like in [this paper](https://arxiv.org/abs/2111.13087), could be a better approach. See the {ref}`spatial-temporal-metadata` section for an optional way to incorporate location and season information. ## Single Crop Model diff --git a/docs/user_guide/09_configuration_file.md b/docs/user_guide/09_configuration_file.md index bda0de7ba..c655b5f77 100644 --- a/docs/user_guide/09_configuration_file.md +++ b/docs/user_guide/09_configuration_file.md @@ -319,3 +319,15 @@ crop_model = CropModel() # Or use custom resize dimensions crop_model = CropModel(config_args={"resize": [300, 300]}) ``` + +### use_metadata + +Boolean flag to enable spatial-temporal metadata fusion. When `True`, the model accepts `(lat, lon, date)` alongside image crops and learns a small embedding that is concatenated with image features. Default is `False`. See {ref}`spatial-temporal-metadata` for usage details. + +### metadata_dim + +Dimension of the metadata embedding vector. A smaller value makes the metadata signal more gentle relative to the 2048-dim image features. Default is `32`. + +### metadata_dropout + +Dropout rate applied to the metadata embedding path. Higher values reduce the model's reliance on location/date information. Default is `0.5`. diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index 8655f9281..834141093 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -119,3 +119,8 @@ cropmodel: - 224 # Number of pixels to expand bbox crop windows for better prediction context. expand: 0 + # Spatial-temporal metadata fusion (optional). + # When True, the model accepts (lat, lon, date) alongside image crops. + use_metadata: False + metadata_dim: 32 + metadata_dropout: 0.5 diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index d0aa84806..6dd40b42c 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -108,6 +108,9 @@ class CropModelConfig: balance_classes: bool = False resize: list[int] = field(default_factory=lambda: [224, 224]) expand: int = 0 + use_metadata: bool = False + metadata_dim: int = 32 + metadata_dropout: float = 0.5 @dataclass diff --git a/src/deepforest/datasets/cropmodel.py b/src/deepforest/datasets/cropmodel.py index 7ae0e0399..d5cb9f1cb 100644 --- a/src/deepforest/datasets/cropmodel.py +++ b/src/deepforest/datasets/cropmodel.py @@ -8,6 +8,7 @@ import numpy as np import rasterio as rio +import torch from rasterio.windows import Window from torch.utils.data import Dataset from torchvision import transforms @@ -65,6 +66,7 @@ def __init__( augmentations=None, resize=None, expand: int = 0, + metadata=None, ): self.df = df @@ -80,6 +82,10 @@ def __init__( raise ValueError("expand must be >= 0") self.expand = int(expand) + # Optional spatial-temporal metadata per crop. + # Dict mapping crop index to (lat, lon, day_of_year). + self.metadata = metadata + unique_image = self.df["image_path"].unique() assert len(unique_image) == 1, ( "There should be only one unique image for this class object" @@ -129,4 +135,9 @@ def __getitem__(self, idx): else: image = box + if self.metadata is not None: + lat, lon, doy = self.metadata[idx] + meta_tensor = torch.tensor([lat, lon, doy], dtype=torch.float32) + return image, meta_tensor + return image diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index b3bdf085e..c4c998d32 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -1,10 +1,12 @@ """Dataset model for object detection tasks.""" +import datetime import math import os import kornia.augmentation as K import numpy as np +import pandas as pd import shapely import torch from PIL import Image @@ -346,3 +348,65 @@ def _classes_in(root): ) return train_ds, val_ds + + +class MetadataImageFolder(Dataset): + """Wrapper that adds spatial-temporal metadata to an ImageFolder dataset. + + Expects a CSV sidecar file with columns: filename, lat, lon, date. + The date column should be an ISO format string (e.g., "2024-06-15") + and will be converted to day_of_year internally. + + Args: + image_folder: A FixedClassImageFolder (or ImageFolder) dataset. + metadata_csv: Path to CSV with columns filename, lat, lon, date. + + Returns per sample: + (image, label, metadata_tensor) where metadata_tensor is shape (3,) + containing [lat, lon, day_of_year]. + """ + + def __init__(self, image_folder, metadata_csv): + self.image_folder = image_folder + metadata_df = pd.read_csv(metadata_csv) + self._meta_lookup = {} + for _, row in metadata_df.iterrows(): + date = datetime.datetime.strptime(str(row["date"]), "%Y-%m-%d") + doy = float(date.timetuple().tm_yday) + self._meta_lookup[row["filename"]] = ( + float(row["lat"]), + float(row["lon"]), + doy, + ) + + def __len__(self): + return len(self.image_folder) + + def __getitem__(self, idx): + image, label = self.image_folder[idx] + filepath = self.image_folder.samples[idx][0] + filename = os.path.basename(filepath) + + if filename in self._meta_lookup: + lat, lon, doy = self._meta_lookup[filename] + else: + lat, lon, doy = 0.0, 0.0, 1.0 + + metadata = torch.tensor([lat, lon, doy], dtype=torch.float32) + return image, label, metadata + + @property + def targets(self): + return self.image_folder.targets + + @property + def class_to_idx(self): + return self.image_folder.class_to_idx + + @property + def samples(self): + return self.image_folder.samples + + @property + def imgs(self): + return self.image_folder.imgs diff --git a/src/deepforest/main.py b/src/deepforest/main.py index f74c24228..73933892a 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -1,4 +1,5 @@ # entry point for deepforest model +import datetime import importlib import os import warnings @@ -523,6 +524,7 @@ def predict_tile( iou_threshold=0.15, dataloader_strategy="single", crop_model=None, + metadata=None, ): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and @@ -539,6 +541,10 @@ def predict_tile( - "batch" loads the entire image into GPU memory and creates views of an image as batch, requires in the entire tile to fit into GPU memory. CPU parallelization is possible for loading images. - "window" loads only the desired window of the image from the raster dataset. Most memory efficient option, but cannot parallelize across windows. crop_model: a deepforest.model.CropModel object to predict on crops + metadata: Optional dict with keys "lat", "lon", "date" for + spatial-temporal context. "date" should be an ISO format + string (e.g., "2024-06-15"). Used by CropModel when + use_metadata=True in config. Returns: pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple @@ -664,6 +670,20 @@ def predict_tile( root_dir = None if crop_model is not None: + # Build per-crop metadata from image-level metadata dict + if metadata is not None: + date_str = metadata.get("date", None) + if date_str is not None: + doy = float( + datetime.datetime.strptime(str(date_str), "%Y-%m-%d") + .timetuple() + .tm_yday + ) + else: + doy = 1.0 + lat = float(metadata.get("lat", 0.0)) + lon = float(metadata.get("lon", 0.0)) + cropmodel_results = [] for path in paths: image_result = mosaic_results[ @@ -672,8 +692,19 @@ def predict_tile( if image_result.empty: continue image_result.root_dir = os.path.dirname(path) + + # Create per-crop metadata dict if metadata was provided + per_crop_metadata = None + if metadata is not None: + per_crop_metadata = dict.fromkeys( + range(len(image_result)), (lat, lon, doy) + ) + cropmodel_result = predict._crop_models_wrapper_( - crop_model, self.trainer, image_result + crop_model, + self.trainer, + image_result, + metadata=per_crop_metadata, ) cropmodel_results.append(cropmodel_result) cropmodel_results = pd.concat(cropmodel_results) diff --git a/src/deepforest/model.py b/src/deepforest/model.py index 9637fab55..d2184a9cd 100644 --- a/src/deepforest/model.py +++ b/src/deepforest/model.py @@ -1,5 +1,6 @@ # Model - common class import json +import math import os import numpy as np @@ -81,6 +82,69 @@ def simple_resnet_50(num_classes: int = 2) -> torch.nn.Module: return m +def resnet50_backbone(): + """Create a ResNet-50 backbone that outputs 2048-dim feature vectors. + + Returns: + tuple: (backbone, feature_dim) where backbone is the model and + feature_dim is the output dimension (2048). + """ + m = models.resnet50(weights=models.ResNet50_Weights.DEFAULT) + feature_dim = m.fc.in_features + m.fc = torch.nn.Identity() + return m, feature_dim + + +class SpatialTemporalEncoder(torch.nn.Module): + """Encode (lat, lon, day_of_year) into a fixed-size embedding. + + Uses sinusoidal features for smooth, periodic representation of + geographic coordinates and seasonality, followed by a small MLP. + + Args: + embed_dim: Output embedding dimension. Default 32. + dropout: Dropout rate on the embedding. Default 0.5. + + Input: + metadata: tensor of shape (batch, 3) with [lat, lon, day_of_year]. + lat in [-90, 90], lon in [-180, 180], day_of_year in [1, 366]. + + Output: + tensor of shape (batch, embed_dim). + """ + + def __init__(self, embed_dim: int = 32, dropout: float = 0.5): + super().__init__() + self.mlp = torch.nn.Sequential( + torch.nn.Linear(6, embed_dim), + torch.nn.ReLU(), + torch.nn.Dropout(dropout), + ) + + def forward(self, metadata): + lat = metadata[:, 0:1] + lon = metadata[:, 1:2] + doy = metadata[:, 2:3] + + lat_norm = lat / 90.0 + lon_norm = lon / 180.0 + doy_norm = (doy - 1) / 365.0 + + features = torch.cat( + [ + torch.sin(math.pi * lat_norm), + torch.cos(math.pi * lat_norm), + torch.sin(math.pi * lon_norm), + torch.cos(math.pi * lon_norm), + torch.sin(2 * math.pi * doy_norm), + torch.cos(2 * math.pi * doy_norm), + ], + dim=1, + ) + + return self.mlp(features) + + class CropModel(LightningModule, PyTorchModelHubMixin): """A PyTorch Lightning module for classifying image crops from object detection models. @@ -112,6 +176,9 @@ def __init__( super().__init__() self.model = model + self.backbone = None + self.metadata_encoder = None + self.classifier = None # Set the argument as the self.config, this way when reloading the checkpoint, self.config exists and is not overwritten. self.config = config if self.config is None: @@ -162,18 +229,39 @@ def create_model(self, num_classes): } ) - self.model = simple_resnet_50(num_classes=num_classes) + use_metadata = self.config["cropmodel"].get("use_metadata", False) + + if use_metadata: + metadata_dim = self.config["cropmodel"].get("metadata_dim", 32) + metadata_dropout = self.config["cropmodel"].get("metadata_dropout", 0.5) + + backbone, feature_dim = resnet50_backbone() + self.backbone = backbone + self.metadata_encoder = SpatialTemporalEncoder( + embed_dim=metadata_dim, dropout=metadata_dropout + ) + self.classifier = torch.nn.Linear(feature_dim + metadata_dim, num_classes) + self.model = None + else: + self.backbone = None + self.metadata_encoder = None + self.classifier = None + self.model = simple_resnet_50(num_classes=num_classes) def create_trainer(self, **kwargs): """Create a pytorch lightning trainer object.""" self.trainer = Trainer(**kwargs) - def load_from_disk(self, train_dir, val_dir): + def load_from_disk(self, train_dir, val_dir, metadata_csv=None): """Load the training and validation datasets from disk. Args: train_dir (str): The directory containing the training dataset. val_dir (str): The directory containing the validation dataset. + metadata_csv (str, optional): Path to a CSV file mapping image + filenames to spatial-temporal metadata. The CSV should have + columns: filename, lat, lon, date. Required when + use_metadata=True in config. Defaults to None. Returns: None @@ -184,6 +272,15 @@ def load_from_disk(self, train_dir, val_dir): transform_train=self.get_transform(augmentations=["HorizontalFlip"]), transform_val=self.get_transform(augmentations=None), ) + + if metadata_csv is not None and self.config["cropmodel"].get( + "use_metadata", False + ): + from deepforest.datasets.training import MetadataImageFolder + + self.train_ds = MetadataImageFolder(self.train_ds, metadata_csv) + self.val_ds = MetadataImageFolder(self.val_ds, metadata_csv) + self.label_dict = self.train_ds.class_to_idx # Create a reverse mapping from numeric indices to class labels @@ -191,7 +288,7 @@ def load_from_disk(self, train_dir, val_dir): self.num_classes = len(self.label_dict) - if self.model is None: + if self.model is None and self.backbone is None: self.create_model(self.num_classes) def get_transform(self, augmentations): @@ -309,14 +406,24 @@ def write_crops(self, root_dir, images, boxes, labels, savedir): def normalize(self): return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) - def forward(self, x): - if self.model is None: + def forward(self, x, metadata=None): + if self.backbone is not None: + image_features = self.backbone(x) + if metadata is not None: + meta_features = self.metadata_encoder(metadata) + else: + meta_dim = self.classifier.in_features - image_features.shape[1] + meta_features = torch.zeros( + x.shape[0], meta_dim, device=x.device, dtype=x.dtype + ) + combined = torch.cat([image_features, meta_features], dim=1) + return self.classifier(combined) + elif self.model is not None: + return self.model(x) + else: raise AttributeError( "CropModel is not initialized. Provide 'num_classes' or load from a checkpoint." ) - output = self.model(x) - - return output def train_dataloader(self): """Train data loader.""" @@ -371,20 +478,27 @@ def val_dataloader(self): return val_loader def training_step(self, batch, batch_idx): - x, y = batch - outputs = self.forward(x) + if len(batch) == 3: + x, y, metadata = batch + else: + x, y = batch + metadata = None + outputs = self.forward(x, metadata=metadata) loss = F.cross_entropy(outputs, y) self.log("train_loss", loss) return loss def predict_step(self, batch, batch_idx): - # Check if batch is a tuple for validation_dataloader - if isinstance(batch, list): - x, y = batch + # Inference: batch may be (images, metadata), (images, labels, metadata), or a single images tensor. + if isinstance(batch, (list, tuple)) and len(batch) == 3: + images, _labels, metadata = batch + elif isinstance(batch, (list, tuple)) and len(batch) == 2: + images, metadata = batch else: - x = batch - outputs = self.forward(x) + images = batch + metadata = None + outputs = self.forward(images, metadata=metadata) yhat = F.softmax(outputs, 1) return yhat @@ -398,8 +512,12 @@ def postprocess_predictions(self, predictions): return label, score def validation_step(self, batch, batch_idx): - x, y = batch - outputs = self(x) + if len(batch) == 3: + x, y, metadata = batch + else: + x, y = batch + metadata = None + outputs = self(x, metadata=metadata) loss = F.cross_entropy(outputs, y) self.log("val_loss", loss) diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index 1be09dae2..90378eca6 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -254,6 +254,7 @@ def _predict_crop_model_( augmentations=None, model_index=0, is_single_model=False, + metadata=None, ): """Predicts crop model on a raster file. @@ -290,6 +291,7 @@ def _predict_crop_model_( augmentations=augmentations, resize=resize, expand=expand, + metadata=metadata, ) # Create dataloader @@ -325,7 +327,7 @@ def _predict_crop_model_( def _crop_models_wrapper_( - crop_models, trainer, results, transform=None, augmentations=None + crop_models, trainer, results, transform=None, augmentations=None, metadata=None ): if crop_models is not None and not isinstance(crop_models, list): crop_models = [crop_models] @@ -348,6 +350,7 @@ def _crop_models_wrapper_( transform=transform, augmentations=augmentations, is_single_model=is_single_model, + metadata=metadata, ) crop_results.append(crop_result) diff --git a/tests/test_metadata_cropmodel.py b/tests/test_metadata_cropmodel.py new file mode 100644 index 000000000..af431b5a8 --- /dev/null +++ b/tests/test_metadata_cropmodel.py @@ -0,0 +1,366 @@ +"""Tests for spatial-temporal metadata embeddings in CropModel.""" + +import os + +import numpy as np +import pandas as pd +import pytest +import torch + +from deepforest import get_data, model +from deepforest.datasets.cropmodel import BoundingBoxDataset +from deepforest.model import CropModel, SpatialTemporalEncoder + + +# --- SpatialTemporalEncoder unit tests --- + + +def test_spatial_temporal_encoder_output_shape(): + enc = SpatialTemporalEncoder(embed_dim=32, dropout=0.0) + meta = torch.tensor([[35.0, -120.0, 145.0], [0.0, 0.0, 1.0]]) + out = enc(meta) + assert out.shape == (2, 32) + + +def test_spatial_temporal_encoder_custom_dim(): + enc = SpatialTemporalEncoder(embed_dim=64, dropout=0.0) + meta = torch.tensor([[35.0, -120.0, 145.0]]) + out = enc(meta) + assert out.shape == (1, 64) + + +def test_spatial_temporal_encoder_zeros(): + enc = SpatialTemporalEncoder(embed_dim=32, dropout=0.0) + meta = torch.zeros(3, 3) + out = enc(meta) + assert out.shape == (3, 32) + + +def test_spatial_temporal_encoder_boundary_values(): + """Test extreme lat/lon/doy values.""" + enc = SpatialTemporalEncoder(embed_dim=32, dropout=0.0) + meta = torch.tensor([ + [90.0, 180.0, 366.0], # Max values + [-90.0, -180.0, 1.0], # Min values + ]) + out = enc(meta) + assert out.shape == (2, 32) + assert torch.isfinite(out).all() + + +# --- CropModel with metadata: forward pass --- + + +def test_crop_model_metadata_forward(): + cm = CropModel(config_args={"use_metadata": True, "metadata_dim": 32}) + cm.create_model(num_classes=5) + x = torch.rand(4, 3, 224, 224) + meta = torch.tensor([[35.0, -120.0, 145.0]] * 4) + out = cm.forward(x, metadata=meta) + assert out.shape == (4, 5) + + +def test_crop_model_metadata_none_graceful_degradation(): + """When use_metadata=True but metadata=None, model should still predict.""" + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=5) + x = torch.rand(4, 3, 224, 224) + out = cm.forward(x, metadata=None) + assert out.shape == (4, 5) + + +def test_crop_model_metadata_custom_dim(): + cm = CropModel(config_args={"use_metadata": True, "metadata_dim": 16}) + cm.create_model(num_classes=3) + x = torch.rand(2, 3, 224, 224) + meta = torch.tensor([[40.0, -105.0, 200.0]] * 2) + out = cm.forward(x, metadata=meta) + assert out.shape == (2, 3) + + +# --- Backward compatibility --- + + +def test_crop_model_no_metadata_backward_compat(): + cm = CropModel() + cm.create_model(num_classes=2) + x = torch.rand(4, 3, 224, 224) + out = cm.forward(x) + assert out.shape == (4, 2) + assert cm.backbone is None + assert cm.metadata_encoder is None + assert cm.classifier is None + + +# --- Training/validation/predict steps --- + + +def test_training_step_with_metadata(): + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=3) + x = torch.rand(4, 3, 224, 224) + y = torch.tensor([0, 1, 2, 0]) + meta = torch.rand(4, 3) + batch = (x, y, meta) + loss = cm.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 + + +def test_training_step_without_metadata(): + cm = CropModel() + cm.create_model(num_classes=2) + x = torch.rand(4, 3, 224, 224) + y = torch.tensor([0, 1, 0, 1]) + batch = (x, y) + loss = cm.training_step(batch, 0) + assert isinstance(loss, torch.Tensor) + + +def test_validation_step_with_metadata(): + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=3) + cm.label_dict = {"A": 0, "B": 1, "C": 2} + cm.numeric_to_label_dict = {0: "A", 1: "B", 2: "C"} + x = torch.rand(4, 3, 224, 224) + y = torch.tensor([0, 1, 2, 0]) + meta = torch.rand(4, 3) + batch = (x, y, meta) + loss = cm.validation_step(batch, 0) + assert isinstance(loss, torch.Tensor) + + +def test_predict_step_with_metadata(): + cm = CropModel(config_args={"use_metadata": True}) + cm.create_model(num_classes=3) + x = torch.rand(4, 3, 224, 224) + meta = torch.rand(4, 3) + batch = (x, meta) + yhat = cm.predict_step(batch, 0) + assert yhat.shape == (4, 3) + # Softmax output should sum to ~1 + assert torch.allclose(yhat.sum(dim=1), torch.ones(4), atol=1e-5) + + +def test_predict_step_image_only(): + """BoundingBoxDataset returns just image tensor when no metadata.""" + cm = CropModel() + cm.create_model(num_classes=2) + x = torch.rand(4, 3, 224, 224) + yhat = cm.predict_step(x, 0) + assert yhat.shape == (4, 2) + + +# --- BoundingBoxDataset with metadata --- + + +@pytest.fixture() +def bbox_df(): + """Create a simple DataFrame with bounding boxes for testing.""" + df = pd.read_csv(get_data("testfile_multi.csv")) + # Get boxes for a single image + single_image = df.image_path.unique()[0] + return df[df.image_path == single_image].reset_index(drop=True) + + +def test_bounding_box_dataset_with_metadata(bbox_df): + root_dir = os.path.dirname(get_data("SOAP_061.png")) + n = len(bbox_df) + metadata = {i: (35.0, -120.0, 145.0) for i in range(n)} + ds = BoundingBoxDataset(bbox_df, root_dir=root_dir, metadata=metadata) + item = ds[0] + assert isinstance(item, tuple) + assert len(item) == 2 + assert item[0].shape[0] == 3 # channels + assert item[1].shape == (3,) + assert item[1][0] == 35.0 # lat + assert item[1][1] == -120.0 # lon + assert item[1][2] == 145.0 # doy + + +def test_bounding_box_dataset_no_metadata(bbox_df): + root_dir = os.path.dirname(get_data("SOAP_061.png")) + ds = BoundingBoxDataset(bbox_df, root_dir=root_dir) + item = ds[0] + assert isinstance(item, torch.Tensor) + assert item.shape[0] == 3 # channels + + +# --- MetadataImageFolder --- + + +def test_metadata_image_folder(tmp_path): + """Test MetadataImageFolder wrapping an ImageFolder.""" + from deepforest.datasets.training import MetadataImageFolder + from torchvision.datasets import ImageFolder + from PIL import Image + + # Create ImageFolder structure + for cls in ["A", "B"]: + cls_dir = tmp_path / cls + cls_dir.mkdir() + for i in range(3): + img = Image.fromarray(np.random.randint(0, 255, (10, 10, 3), dtype=np.uint8)) + img.save(cls_dir / f"{cls}_{i}.png") + + # Create metadata CSV + rows = [] + for cls in ["A", "B"]: + for i in range(3): + rows.append({ + "filename": f"{cls}_{i}.png", + "lat": 35.0 + i, + "lon": -120.0 + i, + "date": "2024-06-15", + }) + metadata_csv = tmp_path / "metadata.csv" + pd.DataFrame(rows).to_csv(metadata_csv, index=False) + + # Create dataset + base_ds = ImageFolder(str(tmp_path)) + meta_ds = MetadataImageFolder(base_ds, str(metadata_csv)) + + assert len(meta_ds) == 6 + image, label, metadata = meta_ds[0] + assert isinstance(image, (torch.Tensor, np.ndarray, Image.Image)) + assert isinstance(label, int) + assert metadata.shape == (3,) + + # Check day_of_year was computed correctly (June 15 = day 167) + # Find an entry where we know the doy + found = False + for i in range(len(meta_ds)): + _, _, meta = meta_ds[i] + if meta[2].item() == 167.0: + found = True + break + assert found, "day_of_year should be 167 for 2024-06-15" + + +def test_metadata_image_folder_missing_file(tmp_path): + """Files not in the metadata CSV get default zeros.""" + from deepforest.datasets.training import MetadataImageFolder + from torchvision.datasets import ImageFolder + from PIL import Image + + cls_dir = tmp_path / "A" + cls_dir.mkdir() + img = Image.fromarray(np.random.randint(0, 255, (10, 10, 3), dtype=np.uint8)) + img.save(cls_dir / "missing.png") + + # Empty metadata CSV (no matching filenames) + metadata_csv = tmp_path / "metadata.csv" + pd.DataFrame({"filename": [], "lat": [], "lon": [], "date": []}).to_csv( + metadata_csv, index=False + ) + + base_ds = ImageFolder(str(tmp_path)) + meta_ds = MetadataImageFolder(base_ds, str(metadata_csv)) + _, _, metadata = meta_ds[0] + assert metadata[0] == 0.0 # lat fallback + assert metadata[1] == 0.0 # lon fallback + assert metadata[2] == 1.0 # doy fallback + + +# --- Integration: full train cycle with metadata --- + + +def test_full_metadata_training_cycle(tmp_path): + """Integration test: create crop data, train with metadata.""" + from PIL import Image + + # Create crop data in ImageFolder structure + for cls in ["Bird", "Mammal"]: + cls_dir = tmp_path / cls + cls_dir.mkdir() + for i in range(4): + img = Image.fromarray( + np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + ) + img.save(cls_dir / f"{cls}_{i}.png") + + # Create metadata CSV + rows = [] + for cls in ["Bird", "Mammal"]: + for i in range(4): + rows.append({ + "filename": f"{cls}_{i}.png", + "lat": 35.0, + "lon": -120.0, + "date": "2024-06-15", + }) + metadata_csv = tmp_path / "metadata.csv" + pd.DataFrame(rows).to_csv(metadata_csv, index=False) + + # Create and train model + cm = CropModel(config_args={"use_metadata": True, "metadata_dim": 16}) + cm.load_from_disk( + train_dir=str(tmp_path), + val_dir=str(tmp_path), + metadata_csv=str(metadata_csv), + ) + cm.create_trainer(fast_dev_run=True, default_root_dir=str(tmp_path)) + cm.trainer.fit(cm) + + +# --- Checkpoint save/load with metadata --- + + +def test_checkpoint_save_load_metadata(tmp_path): + """Test that metadata models can be saved and loaded from checkpoint.""" + from PIL import Image + + # Create minimal data + for cls in ["A", "B"]: + cls_dir = tmp_path / "data" / cls + cls_dir.mkdir(parents=True) + for i in range(3): + img = Image.fromarray( + np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + ) + img.save(cls_dir / f"{cls}_{i}.png") + + rows = [] + for cls in ["A", "B"]: + for i in range(3): + rows.append({ + "filename": f"{cls}_{i}.png", + "lat": 40.0, + "lon": -100.0, + "date": "2024-01-15", + }) + metadata_csv = tmp_path / "metadata.csv" + pd.DataFrame(rows).to_csv(metadata_csv, index=False) + + data_dir = str(tmp_path / "data") + + # Train and save + cm = CropModel(config_args={"use_metadata": True, "metadata_dim": 16}) + cm.create_trainer( + fast_dev_run=False, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + default_root_dir=str(tmp_path / "logs"), + ) + cm.load_from_disk( + train_dir=data_dir, val_dir=data_dir, metadata_csv=str(metadata_csv) + ) + cm.create_model(num_classes=len(cm.label_dict)) + cm.trainer.fit(cm) + + checkpoint_path = str(tmp_path / "test.ckpt") + cm.trainer.save_checkpoint(checkpoint_path) + + # Load from checkpoint + loaded = CropModel.load_from_checkpoint(checkpoint_path) + assert loaded.backbone is not None + assert loaded.metadata_encoder is not None + assert loaded.classifier is not None + assert loaded.label_dict == cm.label_dict + + # Forward pass should work + x = torch.rand(2, 3, 224, 224) + meta = torch.tensor([[40.0, -100.0, 15.0]] * 2) + out = loaded(x, metadata=meta) + assert out.shape == (2, len(cm.label_dict))