199 changes: 77 additions & 122 deletions src/deepforest/main.py
@@ -3,12 +3,10 @@
import os
import warnings

import geopandas as gpd
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities.exceptions import MisconfigurationException
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor
@@ -19,6 +17,7 @@
from deepforest import evaluate as evaluate_iou
from deepforest import predict, utilities
from deepforest.datasets import prediction, training
from deepforest.metrics import RecallPrecision


class deepforest(pl.LightningModule):
@@ -70,24 +69,15 @@ def __init__(
self.existing_train_dataloader = existing_train_dataloader
self.existing_val_dataloader = existing_val_dataloader

# Metrics
self.iou_metric = IntersectionOverUnion(
class_metrics=True, iou_threshold=self.config.validation.iou_threshold
)
self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval")

# Empty frame accuracy
self.empty_frame_accuracy = BinaryAccuracy()

# Create a default trainer.
self.create_trainer()

self.model = model
self.original_batch_structure = []

if self.model is None:
self.create_model()

# Create a default trainer.
self.create_trainer()

# Add user supplied transforms
if transforms is None:
self.transforms = None
@@ -98,6 +88,25 @@ def __init__(
{"config": OmegaConf.to_container(self.config, resolve=True)}
)

def setup_metrics(self):
# Guard against initialization before a validation csv_file is set
if not self.config.validation.csv_file:
return

# Metrics
self.iou_metric = IntersectionOverUnion(
class_metrics=True, iou_threshold=self.config.validation.iou_threshold
)
self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval")

# Empty frame accuracy
self.empty_frame_accuracy = BinaryAccuracy()

self.precision_recall_metric = RecallPrecision(
csv_file=self.config.validation.csv_file,
label_dict=self.label_dict,
)

def load_model(self, model_name=None, revision=None):
"""Loads a model that has already been pretrained for a specific task,
like tree crown detection.
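
Aside: the three torchmetrics objects built in setup_metrics() can be exercised standalone. A minimal sketch with hand-made boxes (values are illustrative; 0.4 stands in for the configured iou_threshold, and the faster_coco_eval backend must be installed):

import torch
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision

# Same constructor arguments as setup_metrics(); 0.4 is an assumed threshold.
iou_metric = IntersectionOverUnion(class_metrics=True, iou_threshold=0.4)
mAP_metric = MeanAveragePrecision(backend="faster_coco_eval")
empty_frame_accuracy = BinaryAccuracy()

# One predicted box against one ground-truth box (xyxy format).
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([0]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 12.0, 48.0, 52.0]]),
    "labels": torch.tensor([0]),
}]

iou_metric.update(preds, targets)
mAP_metric.update(preds, targets)
print(iou_metric.compute())         # {'iou': tensor(...)} plus per-class keys
print(mAP_metric.compute()["map"])  # overall mAP for this toy pair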
@@ -190,6 +199,10 @@ def create_trainer(self, logger=None, callbacks=None, **kwargs):
callbacks: Optional list of callbacks
**kwargs: Additional trainer arguments
"""

        # Set up metrics, which may need to be rebuilt if the config was modified
self.setup_metrics()

if callbacks is None:
callbacks = []
# If val data is passed, monitor learning rate and setup classification metrics
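
Because create_trainer() now re-runs setup_metrics(), changing validation settings after construction only requires rebuilding the trainer. A usage sketch (file paths are placeholders):

from deepforest import main

m = main.deepforest()
m.config.validation.csv_file = "annotations.csv"  # placeholder path
m.config.validation.root_dir = "images/"          # placeholder directory
m.config.validation.iou_threshold = 0.4
m.create_trainer()  # setup_metrics() now sees the updated config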
@@ -704,15 +717,10 @@ def validation_step(self, batch, batch_idx):
losses = sum(loss_dict.values())

# Log losses
try:
for key, value in loss_dict.items():
self.log(
f"val_{key}", value.detach(), on_epoch=True, batch_size=len(images)
)
for key, value in loss_dict.items():
self.log(f"val_{key}", value.detach(), on_epoch=True, batch_size=len(images))

self.log("val_loss", losses.detach(), on_epoch=True, batch_size=len(images))
except MisconfigurationException:
pass
self.log("val_loss", losses.detach(), on_epoch=True, batch_size=len(images))

        # In eval mode, return predictions to calculate prediction metrics
self.model.eval()
@@ -723,13 +731,29 @@
# Remove empty targets and corresponding predictions
filtered_preds = []
filtered_targets = []

for i, target in enumerate(targets):
if target["boxes"].shape[0] > 0:
# Empty frame accuracy
is_empty_frame = target["boxes"].numel() == 0 or torch.all(
target["boxes"] == 0
)
if is_empty_frame:
                # 0 indicates an empty frame or an empty prediction
device = target["boxes"].device
self.empty_frame_accuracy.update(
torch.tensor([min(len(preds[i]["boxes"]), 1)], device=device),
torch.tensor([0.0], device=device),
)
else:
# Non-empty frames go to all metrics
filtered_preds.append(preds[i])
filtered_targets.append(target)

# IoU and mAP metrics need preds/targets to exist
self.iou_metric.update(filtered_preds, filtered_targets)
self.mAP_metric.update(filtered_preds, filtered_targets)
# Precision recall metric can handle empty frames internally
self.precision_recall_metric.update(preds, image_names)

# Log the predictions if you want to use them for evaluation logs
for i, result in enumerate(preds):
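
The empty-frame branch above reduces each frame to a binary outcome: the target for a truly empty frame is 0, and the "prediction" is 1 if the model returned any boxes at all. A runnable sketch of that reduction on toy tensors:

import torch
from torchmetrics.classification import BinaryAccuracy

empty_frame_accuracy = BinaryAccuracy()

# Ground truth uses an all-zero box as the empty-frame sentinel; here the
# model wrongly produced two boxes for the empty frame.
target_boxes = torch.zeros((1, 4))
pred_boxes = torch.rand((2, 4))

is_empty_frame = target_boxes.numel() == 0 or torch.all(target_boxes == 0)
if is_empty_frame:
    empty_frame_accuracy.update(
        torch.tensor([min(len(pred_boxes), 1)], dtype=torch.float),
        torch.tensor([0.0]),
    )

print(empty_frame_accuracy.compute())  # tensor(0.) -- false positive on an empty frame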
@@ -799,49 +823,38 @@ def calculate_empty_frame_accuracy(self, ground_df, predictions_df):
# Calculate accuracy using metric
self.empty_frame_accuracy.update(predictions, gt)
empty_accuracy = self.empty_frame_accuracy.compute()
self.empty_frame_accuracy.reset()

# Log empty frame accuracy
try:
self.log("empty_frame_accuracy", empty_accuracy)
except MisconfigurationException:
pass
self.log("empty_frame_accuracy", empty_accuracy)

return empty_accuracy

def log_epoch_metrics(self):
def _compute_epoch_metrics(self) -> dict:
"""Compute metrics and returns a Lightning-loggable dictionary.
This function is called automatically at the end of validation.
"""
metrics = {}

# IoU and mAP
if len(self.iou_metric.groundtruth_labels) > 0:
output = self.iou_metric.compute()
metrics.update(self.iou_metric.compute())
# Lightning bug: claims this is a warning but it's not. See issue #16218 in Lightning-AI/pytorch-lightning
try:
self.log_dict(output)
except Exception:
pass

self.iou_metric.reset()
output = self.mAP_metric.compute()

# Keep only overall mAP; drop extra map_* and classes clutter
if isinstance(output, dict):
# Remove classes entry if present
if "classes" in output:
output.pop("classes", None)
# Reduce to only overall 'map' and map_50 if available
output = {k: v for k, v in output.items() if k in ["map", "map_50"]}
try:
self.log_dict(output)
except MisconfigurationException:
pass
self.mAP_metric.reset()

# Log empty frame accuracy if it has been updated
if self.empty_frame_accuracy._update_called:
empty_accuracy = self.empty_frame_accuracy.compute()
# Remove classes from output dict
        output = {key: value for key, value in output.items() if key != "classes"}
metrics.update(output)

# Box recall/precision
metrics.update(self.precision_recall_metric.compute())

# Empty frame accuracy
if self.empty_frame_accuracy.update_called:
metrics["empty_frame_accuracy"] = self.empty_frame_accuracy.compute()

# Log empty frame accuracy
try:
self.log("empty_frame_accuracy", empty_accuracy)
except MisconfigurationException:
pass
return metrics
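
The update_called guard matters because computing a torchmetrics metric that has seen no data (e.g. a validation epoch with no empty frames) warns and returns a value derived from empty state. A small illustration of the guard:

import torch
from torchmetrics.classification import BinaryAccuracy

acc = BinaryAccuracy()
print(acc.update_called)  # False -- no empty frames seen; skip compute()

acc.update(torch.tensor([0.0]), torch.tensor([0.0]))
print(acc.update_called)  # True -- safe to include in the metrics dict
print(acc.compute())      # tensor(1.) -- correctly silent on an empty frame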

def on_validation_epoch_end(self):
"""Compute metrics and predictions at the end of the validation
@@ -850,23 +863,16 @@ def on_validation_epoch_end(self):
return

# Log epoch metrics
self.log_epoch_metrics()

if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0:
if len(self.predictions) > 0:
predictions = pd.concat(self.predictions)
else:
predictions = pd.DataFrame()

results = self.__evaluate__(
self.config.validation.csv_file,
root_dir=self.config.validation.root_dir,
predictions=predictions,
)
metrics = self._compute_epoch_metrics()
self.log_dict(metrics)

self.__evaluation_logs__(results)

return results
# Manual reset. Lightning does not do this automatically
# unless we log the metric objects directly
self.precision_recall_metric.reset()
self.iou_metric.reset()
self.mAP_metric.reset()
self.empty_frame_accuracy.reset()
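
The manual resets are needed because Lightning only auto-resets metrics that are logged as Metric objects; here computed dictionaries are logged instead, so state would otherwise accumulate across epochs. A standalone sketch of the leak the resets prevent:

import torch
from torchmetrics.classification import BinaryAccuracy

acc = BinaryAccuracy()

# "Epoch 1": one correct sample.
acc.update(torch.tensor([1.0]), torch.tensor([1.0]))
print(acc.compute())  # tensor(1.)

# "Epoch 2" without reset(): epoch-1 state still counts.
acc.update(torch.tensor([1.0]), torch.tensor([0.0]))
print(acc.compute())  # tensor(0.5000) -- two epochs blended together

acc.reset()  # what on_validation_epoch_end does for every metric
acc.update(torch.tensor([1.0]), torch.tensor([0.0]))
print(acc.compute())  # tensor(0.) -- epoch 2 in isolation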

def predict_step(self, batch, batch_idx):
"""Predict a batch of images with the deepforest model. If batch is a
@@ -1040,8 +1046,6 @@ def __evaluate__(
empty_accuracy = self.calculate_empty_frame_accuracy(ground_df, predictions)
results["empty_frame_accuracy"] = empty_accuracy

self.__evaluation_logs__(results)

return results

def evaluate(
@@ -1079,52 +1083,3 @@ def evaluate(
root_dir=root_dir,
predictions=predictions,
)

def __evaluation_logs__(self, results):
"""Log metrics from evaluation results."""
# Log metrics
for key, value in results.items():
if type(value) in [
pd.DataFrame,
gpd.GeoDataFrame,
utilities.DeepForest_DataFrame,
]:
pass
elif value is None:
pass
else:
try:
self.log(key, value)
except MisconfigurationException:
pass

# Log each key value pair of the results dict
if results["class_recall"] is not None and self.config.num_classes > 1:
for key, value in results.items():
if key in ["class_recall"]:
for _, row in value.iterrows():
try:
self.log(
"{}_Recall".format(
self.numeric_to_label_dict[row["label"]]
),
row["recall"],
)
self.log(
"{}_Precision".format(
self.numeric_to_label_dict[row["label"]]
),
row["precision"],
)
except MisconfigurationException:
pass
elif key in ["predictions", "results", "ground_df"]:
# Don't log dataframes of predictions or IoU results per epoch
pass
elif value is None:
pass
else:
try:
self.log(key, value)
except MisconfigurationException:
pass
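
The per-class logging that lived in __evaluation_logs__ is now the job of deepforest.metrics.RecallPrecision. The diff only reveals its interface: construction with csv_file and label_dict, then update(preds, image_names) / compute() / reset(). Purely as an illustration of that inferred shape, a hypothetical torchmetrics-style skeleton (the real implementation in deepforest/metrics.py may differ):

import pandas as pd
import torch
from torchmetrics import Metric


class RecallPrecisionSketch(Metric):
    """Hypothetical stand-in mirroring the RecallPrecision call sites above."""

    def __init__(self, csv_file: str, label_dict: dict):
        super().__init__()
        self.ground_df = pd.read_csv(csv_file)  # validation annotations
        self.label_dict = label_dict
        # Registered state so reset() clears it; plain lists are fine
        # single-process (distributed sync would need tensor states).
        self.add_state("preds_per_image", default=[], dist_reduce_fx=None)

    def update(self, preds: list, image_names: list) -> None:
        # Keep every frame, including empty ones, so compute() can count
        # false positives on empty images.
        self.preds_per_image.extend(zip(image_names, preds))

    def compute(self) -> dict:
        # Real logic would match stored predictions against self.ground_df;
        # the matching is elided, this sketch only fixes the interface.
        return {"box_recall": torch.tensor(0.0),
                "box_precision": torch.tensor(0.0)}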