diff --git a/docs/figures/bird_prediction_example_1.png b/docs/figures/bird_prediction_example_1.png new file mode 100644 index 000000000..69382eb56 Binary files /dev/null and b/docs/figures/bird_prediction_example_1.png differ diff --git a/docs/user_guide/02_prebuilt.md b/docs/user_guide/02_prebuilt.md index a3c97182d..3f423944b 100644 --- a/docs/user_guide/02_prebuilt.md +++ b/docs/user_guide/02_prebuilt.md @@ -34,6 +34,13 @@ The model was initially described in [Ecological Applications](https://esajourna Using over 250,000 annotations from 13 projects from around the world, we develop a general bird detection model that achieves over 65% recall and 50% precision on novel aerial data without any local training despite differences in species, habitat, and imaging methodology. Fine-tuning this model with only 1000 local annotations increases these values to an average of 84% recall and 69% precision by building on the general features learned from other data sources. > +The bird detection model has been updated and retrained from the original `weecology/deepforest-bird` model. The updated model was fine-tuned starting from the tree detection model (`weecology/deepforest-tree`) and trained on data from Weinstein et al. 2022 as well as additional bird detection data from multiple sources, including https://lila.science/. The result is a dataset with over a million bird detections from around the world. Training details and metrics can be viewed on the [Comet dashboard](https://www.comet.com/bw4sz/bird-detector/6181df1ab7ac40f291b863a2a9b86024?&prevPath=%2Fbw4sz%2Fbird-detector%2Fview%2Fnew%2Fexperiments). 
+ +### Example Predictions + +The following examples show predictions from the updated bird detection model: + +![Bird Prediction Example 1](../figures/bird_prediction_example_1.png) ### Citation > Weinstein, B.G., Garner, L., Saccomanno, V.R., Steinkraus, A., Ortega, A., Brush, K., Yenni, G., McKellar, A.E., Converse, R., Lippitt, C.D., Wegmann, A., Holmes, N.D., Edney, A.J., Hart, T., Jessopp, M.J., Clarke, R.H., Marchowski, D., Senyondo, H., Dotson, R., White, E.P., Frederick, P. and Ernest, S.K.M. (2022), A general deep learning model for bird detection in high resolution airborne imagery. Ecological Applications. Accepted Author Manuscript e2694. https://doi-org.lp.hscl.ufl.edu/10.1002/eap.2694 diff --git a/docs/user_guide/07_scaling.md b/docs/user_guide/07_scaling.md index 067ddfa0e..0a9374820 100644 --- a/docs/user_guide/07_scaling.md +++ b/docs/user_guide/07_scaling.md @@ -19,6 +19,68 @@ For example on a SLURM cluster, we use the following line to get 5 gpus on a sin m.create_trainer(logger=comet_logger, accelerator="gpu", strategy="ddp", num_nodes=1, devices=devices) ``` +### Complete SLURM Example + +Here's a complete, tested example for training on 2 GPUs with SLURM: + +**SLURM submission script (`submit_train.sh`):** +```bash +#!/bin/bash +#SBATCH --job-name=train_model +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gpus=2 +#SBATCH --cpus-per-task=10 +#SBATCH --mem=200GB +#SBATCH --time=48:00:00 +#SBATCH --output=train_%j.out +#SBATCH --error=train_%j.err + +# Use srun without explicit task/gpu flags - SLURM handles process spawning +# The --ntasks-per-node and --gpus directives above control resource allocation +srun uv run python train_script.py \ + --data_dir /path/to/data \ + --batch_size 24 \ + --workers 10 \ + --epochs 30 +``` + +**Python training script:** +```python +import torch +from deepforest import main + +# Initialize model +m = main.deepforest() + +# Configure training parameters 
+m.config["train"]["csv_file"] = "train.csv" +m.config["train"]["root_dir"] = "/path/to/data" +m.config["batch_size"] = 24 +m.config["workers"] = 10 + +# Set devices to total GPU count - PyTorch Lightning DDP will handle +# device assignment per process when used with SLURM process spawning +devices = torch.cuda.device_count() if torch.cuda.is_available() else 0 + +m.create_trainer( + logger=comet_logger,  # assumes a Lightning logger was created earlier; omit if not logging + devices=devices, + strategy="ddp",  # Distributed Data Parallel + fast_dev_run=False, +) + +# Train the model +m.trainer.fit(m) +``` + +**Key points:** +- `--ntasks-per-node=2` in SBATCH directives spawns 2 processes (one per GPU) +- `--gpus=2` in SBATCH directives allocates 2 GPUs total +- Use plain `srun` without `--ntasks` or `--gpus-per-task` flags - let SLURM handle process spawning +- Set `devices=torch.cuda.device_count()` in Python (not `devices=1`) - PyTorch Lightning DDP coordinates with SLURM's process management +- Batch size is per-GPU: with `batch_size=24` and 2 GPUs, you get 48 images per forward pass total + While we rarely use multi-node GPU's, pytorch lightning has functionality at very large scales. We welcome users to share what configurations worked best. A few notes that can trip up those less used to multi-gpu training. These are for the default configurations and may vary on a specific system. We use a large University SLURM cluster with 'ddp' distributed data parallel. @@ -27,9 +89,7 @@ A few notes that can trip up those less used to multi-gpu training. These are fo 2. Each device gets its own portion of the dataset. This means that they do not interact during forward passes. -3. Make sure to use srun when combining with SLURM! This is an easy one to miss and will cause training to hang without error. Documented here - -https://lightning.ai/docs/pytorch/latest/clouds/cluster_advanced.html#troubleshooting. +3. Make sure to use `srun` when combining with SLURM! 
This is required for proper process spawning; if omitted, training will hang without error. Use `--ntasks-per-node` in SBATCH directives (not in srun) to control the number of processes. Documented [here](https://lightning.ai/docs/pytorch/latest/clouds/cluster_advanced.html#troubleshooting). ## Prediction diff --git a/src/deepforest/callbacks.py b/src/deepforest/callbacks.py index d747aae39..226ba8323 100644 --- a/src/deepforest/callbacks.py +++ b/src/deepforest/callbacks.py @@ -39,7 +39,7 @@ def __init__( prediction_samples=2, dataset_samples=5, every_n_epochs=5, - select_random=False, + select_random=True, color=None, thickness=2, ): @@ -58,6 +58,10 @@ def on_train_start(self, trainer, pl_module): if trainer.fast_dev_run: return + + # Only run on rank 0 to avoid file I/O synchronization issues in DDP + if trainer.global_rank != 0: + return + self.trainer = trainer self.pl_module = pl_module @@ -77,6 +81,10 @@ def on_validation_end(self, trainer, pl_module): if trainer.sanity_checking or trainer.fast_dev_run: return + + # Only run on rank 0 to avoid file I/O synchronization issues in DDP + if trainer.global_rank != 0: + return + if (trainer.current_epoch + 1) % self.every_n_epochs == 0: pl_module.print("Logging prediction samples") self._log_last_predictions(trainer, pl_module) @@ -141,6 +149,10 @@ def _log_last_predictions(self, trainer, pl_module): else: df = pd.DataFrame() + + # Skip logging if there are no predictions + if df.empty or "image_path" not in df.columns: + return + out_dir = os.path.join(self.savedir, "predictions") os.makedirs(out_dir, exist_ok=True) @@ -151,19 +163,21 @@ def _log_last_predictions(self, trainer, pl_module): df["root_dir"] = dataset.root_dir # Limit to n images, potentially randomly selected + unique_images = df.image_path.unique() + n_samples = min(self.prediction_samples, len(unique_images)) + if n_samples == 0: + return if self.select_random: - selected_images = np.random.choice( - df.image_path.unique(), 
self.prediction_samples - ) + selected_images = np.random.choice(unique_images, n_samples, replace=False) else: - selected_images = df.image_path.unique()[: self.prediction_samples] + selected_images = unique_images[:n_samples] - # Ensure color is correctly assigned - if self.color is None: - num_classes = len(df["label"].unique()) - results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes) - else: - results_color = self.color + # Ensure color is correctly assigned + if self.color is None: + num_classes = len(df["label"].unique()) + results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes) + else: + results_color = self.color for image_name in selected_images: pred_df = df[df.image_path == image_name] diff --git a/src/deepforest/conf/config.yaml b/src/deepforest/conf/config.yaml index 6f7c1bb2c..c70bb69b5 100644 --- a/src/deepforest/conf/config.yaml +++ b/src/deepforest/conf/config.yaml @@ -2,7 +2,7 @@ # Cpu workers for data loaders # Dataloaders -workers: 0 +workers: 5 devices: auto accelerator: auto batch_size: 1 diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index b3bdf085e..735469f39 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -78,7 +78,7 @@ def __init__( self.preload_images = preload_images self._validate_labels() - self._validate_coordinates() + #self._validate_coordinates() # Pin data to memory if desired if self.preload_images: @@ -109,14 +109,21 @@ def _validate_coordinates(self): ValueError: If any bounding box coordinate occurs outside the image """ errors = [] + # Cache image dimensions to avoid opening the same image multiple times + image_dims = {} for _idx, row in self.annotations.iterrows(): img_path = os.path.join(self.root_dir, row["image_path"]) - try: - with Image.open(img_path) as img: - width, height = img.size - except Exception as e: - errors.append(f"Failed to open image {img_path}: {e}") - continue + + # Get image 
dimensions (cached per unique image) + if img_path not in image_dims: + try: + with Image.open(img_path) as img: + image_dims[img_path] = img.size + except Exception as e: + errors.append(f"Failed to open image {img_path}: {e}") + continue + + width, height = image_dims[img_path] # Extract bounding box geom = row["geometry"] diff --git a/src/deepforest/main.py b/src/deepforest/main.py index c2e3f44e9..475433a0b 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -843,21 +843,26 @@ def on_validation_epoch_end(self): self.log_epoch_metrics() if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0: - if len(self.predictions) > 0: - predictions = pd.concat(self.predictions) - else: - predictions = pd.DataFrame() + # Only run evaluate on rank 0 to avoid file I/O synchronization issues in DDP + if self.trainer.global_rank == 0: + if len(self.predictions) > 0: + predictions = pd.concat(self.predictions) + else: + predictions = pd.DataFrame() - results = self.evaluate( - self.config.validation.csv_file, - root_dir=self.config.validation.root_dir, - size=self.config.validation.size, - predictions=predictions, - ) + results = self.evaluate( + self.config.validation.csv_file, + root_dir=self.config.validation.root_dir, + size=self.config.validation.size, + predictions=predictions, + ) - self.__evaluation_logs__(results) + self.__evaluation_logs__(results) - return results + return results + else: + # Other ranks return None - metrics are already logged via log_epoch_metrics + return None def predict_step(self, batch, batch_idx): """Predict a batch of images with the deepforest model. 
If batch is a diff --git a/src/deepforest/preprocess.py b/src/deepforest/preprocess.py index d8bb51692..d5fa42d49 100644 --- a/src/deepforest/preprocess.py +++ b/src/deepforest/preprocess.py @@ -199,15 +199,15 @@ def split_raster( ) # Convert from channels-last (H x W x C) to channels-first (C x H x W) - if numpy_image.shape[2] in [3, 4]: + if len(numpy_image.shape) == 3 and numpy_image.shape[2] in [3, 4]: print( f"Image shape is {numpy_image.shape[2]}, assuming this is channels last, " "converting to channels first" ) numpy_image = numpy_image.transpose(2, 0, 1) - # Check that it's 3 bands - bands = numpy_image.shape[2] + # Check that it's 3 bands (after transpose, shape is (C, H, W), so bands is shape[0]) + bands = numpy_image.shape[0] if not bands == 3: warnings.warn( f"Input image had non-3 band shape of {numpy_image.shape}, selecting first 3 bands", diff --git a/src/deepforest/scripts/compare_bird_models.py b/src/deepforest/scripts/compare_bird_models.py new file mode 100644 index 000000000..3ffd021e1 --- /dev/null +++ b/src/deepforest/scripts/compare_bird_models.py @@ -0,0 +1,517 @@ +"""Compare retrained bird model checkpoint with pretrained weecology/deepforest-bird model. + +This script evaluates both models on the same test dataset and prints a comparison +of performance metrics. It also evaluates the checkpoint model at multiple score +thresholds and generates a precision-recall curve. + +Example usage: + python compare_bird_models.py --checkpoint_path /path/to/checkpoint.ckpt --data_dir /path/to/data +""" + +import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from deepforest import main + + +def compare_models(checkpoint_path, data_dir, iou_threshold=0.4): + """Compare checkpoint model with pretrained weecology/deepforest-bird model. 
+ + Args: + checkpoint_path: Path to the checkpoint file + data_dir: Directory containing test.csv and images + iou_threshold: IoU threshold for evaluation (default: 0.4) + + Returns: + dict: Dictionary containing results for both models + """ + test_csv = os.path.join(data_dir, "test.csv") + + # Read the test set and keep a tiny subset: the first 10 annotations from BDA images + test_df = pd.read_csv(test_csv) + test_df = test_df[test_df.image_path.str.startswith("BDA")].head(10) + test_csv = os.path.join(data_dir, "test_subset.csv") + test_df.to_csv(test_csv, index=False) + + print("=" * 80) + print("Bird Detection Model Comparison") + print("=" * 80) + print(f"\nTest dataset: {test_csv}") + print(f"IoU threshold: {iou_threshold}\n") + + results = {} + + # Evaluate checkpoint model + print("-" * 80) + print("Evaluating retrained checkpoint model...") + print(f"Checkpoint: {checkpoint_path}") + print("-" * 80) + checkpoint_model = main.deepforest.load_from_checkpoint(checkpoint_path) + checkpoint_model.config.score_thresh = 0.25 + checkpoint_model.model.score_thresh = 0.25 + + # Set up validation configuration + checkpoint_model.config.validation.csv_file = test_csv + checkpoint_model.config.validation.root_dir = data_dir + checkpoint_model.config.validation.iou_threshold = iou_threshold + checkpoint_model.config.validation.val_accuracy_interval = 1 + checkpoint_model.create_trainer() + + # Evaluate using trainer.validate() + print("\n1. 
trainer.validate() results:") + validation_results = checkpoint_model.trainer.validate(checkpoint_model) + checkpoint_validate = validation_results[0] if validation_results else {} + results["checkpoint_validate"] = checkpoint_validate + + checkpoint_precision_validate = checkpoint_validate.get('box_precision') + checkpoint_recall_validate = checkpoint_validate.get('box_recall') + if checkpoint_precision_validate is not None: + print(f" Box Precision: {checkpoint_precision_validate:.4f}") + else: + print(" Box Precision: N/A") + if checkpoint_recall_validate is not None: + print(f" Box Recall: {checkpoint_recall_validate:.4f}") + else: + print(" Box Recall: N/A") + print(f" Empty Frame Accuracy: {checkpoint_validate.get('empty_frame_accuracy', 'N/A')}") + + # Evaluate using main.evaluate() + print("\n2. main.evaluate() results:") + checkpoint_evaluate = checkpoint_model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + results["checkpoint_evaluate"] = checkpoint_evaluate + + checkpoint_precision_evaluate = checkpoint_evaluate.get('box_precision') + checkpoint_recall_evaluate = checkpoint_evaluate.get('box_recall') + if checkpoint_precision_evaluate is not None: + print(f" Box Precision: {checkpoint_precision_evaluate:.4f}") + else: + print(" Box Precision: N/A") + if checkpoint_recall_evaluate is not None: + print(f" Box Recall: {checkpoint_recall_evaluate:.4f}") + else: + print(" Box Recall: N/A") + print(f" Empty Frame Accuracy: {checkpoint_evaluate.get('empty_frame_accuracy', 'N/A')}") + + # Store both for backward compatibility + results["checkpoint"] = checkpoint_validate + + # Evaluate pretrained model + print("\n" + "-" * 80) + print("Evaluating pretrained weecology/deepforest-bird model...") + print("-" * 80) + pretrained_model = main.deepforest() + pretrained_model.load_model("weecology/deepforest-bird") + pretrained_model.config.score_thresh = 0.25 + pretrained_model.model.score_thresh = 0.25 + + # Set label 
dictionaries to match + pretrained_model.label_dict = {"Bird": 0} + pretrained_model.numeric_to_label_dict = {0: "Bird"} + pretrained_model.config.label_dict = {"Bird": 0} + pretrained_model.config.num_classes = 1 + + # Set up validation configuration + pretrained_model.config.validation.csv_file = test_csv + pretrained_model.config.validation.root_dir = data_dir + pretrained_model.config.validation.iou_threshold = iou_threshold + pretrained_model.config.validation.val_accuracy_interval = 1 + pretrained_model.create_trainer() + + # Evaluate using trainer.validate() + print("\n1. trainer.validate() results:") + validation_results = pretrained_model.trainer.validate(pretrained_model) + pretrained_validate = validation_results[0] if validation_results else {} + results["pretrained_validate"] = pretrained_validate + + pretrained_precision_validate = pretrained_validate.get('box_precision') + pretrained_recall_validate = pretrained_validate.get('box_recall') + if pretrained_precision_validate is not None: + print(f" Box Precision: {pretrained_precision_validate:.4f}") + else: + print(" Box Precision: N/A") + if pretrained_recall_validate is not None: + print(f" Box Recall: {pretrained_recall_validate:.4f}") + else: + print(" Box Recall: N/A") + print(f" Empty Frame Accuracy: {pretrained_validate.get('empty_frame_accuracy', 'N/A')}") + + # Evaluate using main.evaluate() + print("\n2. 
main.evaluate() results:") + pretrained_evaluate = pretrained_model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + results["pretrained_evaluate"] = pretrained_evaluate + + pretrained_precision_evaluate = pretrained_evaluate.get('box_precision') + pretrained_recall_evaluate = pretrained_evaluate.get('box_recall') + if pretrained_precision_evaluate is not None: + print(f" Box Precision: {pretrained_precision_evaluate:.4f}") + else: + print(" Box Precision: N/A") + if pretrained_recall_evaluate is not None: + print(f" Box Recall: {pretrained_recall_evaluate:.4f}") + else: + print(" Box Recall: N/A") + print(f" Empty Frame Accuracy: {pretrained_evaluate.get('empty_frame_accuracy', 'N/A')}") + + # Store both for backward compatibility + results["pretrained"] = pretrained_validate + + # Print comparison + print("\n" + "=" * 80) + print("COMPARISON SUMMARY") + print("=" * 80) + + # Comparison using trainer.validate() results + print("\n" + "-" * 80) + print("Using trainer.validate() results:") + print("-" * 80) + + checkpoint_precision_validate = checkpoint_validate.get("box_precision") + checkpoint_recall_validate = checkpoint_validate.get("box_recall") + pretrained_precision_validate = pretrained_validate.get("box_precision") + pretrained_recall_validate = pretrained_validate.get("box_recall") + + if checkpoint_precision_validate is not None and pretrained_precision_validate is not None: + precision_diff = checkpoint_precision_validate - pretrained_precision_validate + print(f"\nBox Precision:") + print(f" Checkpoint: {checkpoint_precision_validate:.4f}") + print(f" Pretrained: {pretrained_precision_validate:.4f}") + if pretrained_precision_validate != 0: + print(f" Difference: {precision_diff:+.4f} ({precision_diff/pretrained_precision_validate*100:+.2f}%)") + else: + print(f" Difference: {precision_diff:+.4f} (N/A%)") + else: + print(f"\nBox Precision: Unable to compute (missing values)") + + if checkpoint_recall_validate is 
not None and pretrained_recall_validate is not None: + recall_diff = checkpoint_recall_validate - pretrained_recall_validate + print(f"\nBox Recall:") + print(f" Checkpoint: {checkpoint_recall_validate:.4f}") + print(f" Pretrained: {pretrained_recall_validate:.4f}") + if pretrained_recall_validate != 0: + print(f" Difference: {recall_diff:+.4f} ({recall_diff/pretrained_recall_validate*100:+.2f}%)") + else: + print(f" Difference: {recall_diff:+.4f} (N/A%)") + else: + print(f"\nBox Recall: Unable to compute (missing values)") + + if "empty_frame_accuracy" in checkpoint_validate and "empty_frame_accuracy" in pretrained_validate: + checkpoint_empty = checkpoint_validate["empty_frame_accuracy"] + pretrained_empty = pretrained_validate["empty_frame_accuracy"] + if checkpoint_empty is not None and pretrained_empty is not None: + empty_diff = checkpoint_empty - pretrained_empty + print(f"\nEmpty Frame Accuracy:") + print(f" Checkpoint: {checkpoint_empty:.4f}") + print(f" Pretrained: {pretrained_empty:.4f}") + print(f" Difference: {empty_diff:+.4f}") + else: + print(f"\nEmpty Frame Accuracy: Unable to compute (missing values)") + print(f" Checkpoint: {checkpoint_empty}") + print(f" Pretrained: {pretrained_empty}") + + # Comparison using main.evaluate() results + print("\n" + "-" * 80) + print("Using main.evaluate() results:") + print("-" * 80) + + checkpoint_precision_evaluate = checkpoint_evaluate.get("box_precision") + checkpoint_recall_evaluate = checkpoint_evaluate.get("box_recall") + pretrained_precision_evaluate = pretrained_evaluate.get("box_precision") + pretrained_recall_evaluate = pretrained_evaluate.get("box_recall") + + if checkpoint_precision_evaluate is not None and pretrained_precision_evaluate is not None: + precision_diff = checkpoint_precision_evaluate - pretrained_precision_evaluate + print(f"\nBox Precision:") + print(f" Checkpoint: {checkpoint_precision_evaluate:.4f}") + print(f" Pretrained: {pretrained_precision_evaluate:.4f}") + if 
pretrained_precision_evaluate != 0: + print(f" Difference: {precision_diff:+.4f} ({precision_diff/pretrained_precision_evaluate*100:+.2f}%)") + else: + print(f" Difference: {precision_diff:+.4f} (N/A%)") + else: + print(f"\nBox Precision: Unable to compute (missing values)") + + if checkpoint_recall_evaluate is not None and pretrained_recall_evaluate is not None: + recall_diff = checkpoint_recall_evaluate - pretrained_recall_evaluate + print(f"\nBox Recall:") + print(f" Checkpoint: {checkpoint_recall_evaluate:.4f}") + print(f" Pretrained: {pretrained_recall_evaluate:.4f}") + if pretrained_recall_evaluate != 0: + print(f" Difference: {recall_diff:+.4f} ({recall_diff/pretrained_recall_evaluate*100:+.2f}%)") + else: + print(f" Difference: {recall_diff:+.4f} (N/A%)") + else: + print(f"\nBox Recall: Unable to compute (missing values)") + + if "empty_frame_accuracy" in checkpoint_evaluate and "empty_frame_accuracy" in pretrained_evaluate: + checkpoint_empty = checkpoint_evaluate["empty_frame_accuracy"] + pretrained_empty = pretrained_evaluate["empty_frame_accuracy"] + if checkpoint_empty is not None and pretrained_empty is not None: + empty_diff = checkpoint_empty - pretrained_empty + print(f"\nEmpty Frame Accuracy:") + print(f" Checkpoint: {checkpoint_empty:.4f}") + print(f" Pretrained: {pretrained_empty:.4f}") + print(f" Difference: {empty_diff:+.4f}") + else: + print(f"\nEmpty Frame Accuracy: Unable to compute (missing values)") + print(f" Checkpoint: {checkpoint_empty}") + print(f" Pretrained: {pretrained_empty}") + + print("\n" + "=" * 80) + + return results + + +def evaluate_multiple_thresholds( + checkpoint_path, data_dir, iou_threshold=0.4, thresholds=None, output_path=None +): + """Evaluate checkpoint model at multiple score thresholds. 
+ + Args: + checkpoint_path: Path to the checkpoint file + data_dir: Directory containing test.csv and images + iou_threshold: IoU threshold for evaluation (default: 0.4) + thresholds: List of score thresholds to evaluate (default: 0.1 to 0.5 in 0.05 steps) + output_path: Path to save the plot (default: data_dir/precision_recall_curve.png) + + Returns: + dict: Dictionary with thresholds, precision, and recall arrays + """ + if thresholds is None: + thresholds = np.arange(0.1, 0.55, 0.05).round(2).tolist() + + test_csv = os.path.join(data_dir, "test.csv") + + if not os.path.exists(test_csv): + raise FileNotFoundError(f"Test CSV not found: {test_csv}") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + print("=" * 80) + print("Evaluating Checkpoint Model at Multiple Score Thresholds") + print("=" * 80) + print(f"\nTest dataset: {test_csv}") + print(f"IoU threshold: {iou_threshold}") + print(f"Score thresholds: {thresholds}\n") + + # Load model once + print("Loading checkpoint model...") + model = main.deepforest.load_from_checkpoint(checkpoint_path) + + precision_scores_validate = [] + recall_scores_validate = [] + precision_scores_evaluate = [] + recall_scores_evaluate = [] + + # Set up validation configuration once + model.config.validation.csv_file = test_csv + model.config.validation.root_dir = data_dir + model.config.validation.iou_threshold = iou_threshold + model.config.validation.val_accuracy_interval = 1 + + print("\nEvaluating at each threshold:") + print("-" * 80) + for i, threshold in enumerate(thresholds): + print(f"\n[{i+1}/{len(thresholds)}] Evaluating at score threshold: {threshold:.2f}") + model.config.score_thresh = threshold + model.model.score_thresh = threshold + + # Evaluate using trainer.validate() + model.create_trainer() + validation_results = model.trainer.validate(model) + validate_results = validation_results[0] if validation_results else {} + + precision_validate = 
validate_results.get("box_precision", 0.0) + recall_validate = validate_results.get("box_recall", 0.0) + + precision_scores_validate.append(precision_validate) + recall_scores_validate.append(recall_validate) + + print(f" trainer.validate() - Precision: {precision_validate:.4f}, Recall: {recall_validate:.4f}") + + # Evaluate using main.evaluate() + evaluate_results = model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + + precision_evaluate = evaluate_results.get("box_precision", 0.0) + recall_evaluate = evaluate_results.get("box_recall", 0.0) + + precision_scores_evaluate.append(precision_evaluate) + recall_scores_evaluate.append(recall_evaluate) + + print(f" main.evaluate() - Precision: {precision_evaluate:.4f}, Recall: {recall_evaluate:.4f}") + + # Create results dictionary + threshold_results = { + "thresholds": thresholds, + "precision_validate": precision_scores_validate, + "recall_validate": recall_scores_validate, + "precision_evaluate": precision_scores_evaluate, + "recall_evaluate": recall_scores_evaluate, + } + + # Print summary table + print("\n" + "=" * 80) + print("SUMMARY TABLE - trainer.validate()") + print("=" * 80) + print(f"\n{'Threshold':<12} {'Precision':<12} {'Recall':<12}") + print("-" * 40) + for thresh, prec, rec in zip(thresholds, precision_scores_validate, recall_scores_validate): + print(f"{thresh:<12.2f} {prec:<12.4f} {rec:<12.4f}") + + print("\n" + "=" * 80) + print("SUMMARY TABLE - main.evaluate()") + print("=" * 80) + print(f"\n{'Threshold':<12} {'Precision':<12} {'Recall':<12}") + print("-" * 40) + for thresh, prec, rec in zip(thresholds, precision_scores_evaluate, recall_scores_evaluate): + print(f"{thresh:<12.2f} {prec:<12.4f} {rec:<12.4f}") + + # Generate plot + if output_path is None: + output_path = os.path.join(data_dir, "precision_recall_curve.png") + + print(f"\nGenerating plot: {output_path}") + plt.figure(figsize=(14, 8)) + + # Plot trainer.validate() results + plt.plot(thresholds, 
precision_scores_validate, "o-", label="Precision (trainer.validate())", linewidth=2, markersize=8, color='blue') + plt.plot(thresholds, recall_scores_validate, "s-", label="Recall (trainer.validate())", linewidth=2, markersize=8, color='blue', linestyle='--') + + # Plot main.evaluate() results + plt.plot(thresholds, precision_scores_evaluate, "o-", label="Precision (main.evaluate())", linewidth=2, markersize=8, color='red') + plt.plot(thresholds, recall_scores_evaluate, "s-", label="Recall (main.evaluate())", linewidth=2, markersize=8, color='red', linestyle='--') + + plt.xlabel("Score Threshold", fontsize=12) + plt.ylabel("Score", fontsize=12) + plt.title("Precision and Recall vs Score Threshold\n(Retrained Bird Detection Model - Both Methods)", fontsize=14) + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + plt.xlim(min(thresholds) - 0.02, max(thresholds) + 0.02) + max_score = max( + max(precision_scores_validate) if precision_scores_validate else 0, + max(recall_scores_validate) if recall_scores_validate else 0, + max(precision_scores_evaluate) if precision_scores_evaluate else 0, + max(recall_scores_evaluate) if recall_scores_evaluate else 0, + ) + plt.ylim(0, max_score * 1.1 if max_score > 0 else 1.0) + + # Add value labels on points for trainer.validate() + for thresh, prec, rec in zip(thresholds, precision_scores_validate, recall_scores_validate): + plt.annotate( + f"{prec:.3f}", + (thresh, prec), + textcoords="offset points", + xytext=(0, 10), + ha="center", + fontsize=7, + color='blue', + ) + plt.annotate( + f"{rec:.3f}", + (thresh, rec), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=7, + color='blue', + ) + + # Add value labels on points for main.evaluate() + for thresh, prec, rec in zip(thresholds, precision_scores_evaluate, recall_scores_evaluate): + plt.annotate( + f"{prec:.3f}", + (thresh, prec), + textcoords="offset points", + xytext=(0, 20), + ha="center", + fontsize=7, + color='red', + ) + plt.annotate( + 
f"{rec:.3f}", + (thresh, rec), + textcoords="offset points", + xytext=(0, -25), + ha="center", + fontsize=7, + color='red', + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=300, bbox_inches="tight") + print(f"Plot saved to: {output_path}") + + return threshold_results + + +def run(): + """Main function.""" + parser = argparse.ArgumentParser( + description="Compare retrained bird model with pretrained model" + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint file", + ) + parser.add_argument( + "--data_dir", + type=str, + required=True, + help="Directory containing test.csv and images", + ) + parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation (default: 0.4)", + ) + parser.add_argument( + "--evaluate_thresholds", + action="store_true", + help="Evaluate checkpoint model at multiple score thresholds (0.1-0.5) and generate plot", + ) + parser.add_argument( + "--plot_output", + type=str, + default=None, + help="Path to save the precision-recall plot (default: data_dir/precision_recall_curve.png)", + ) + + args = parser.parse_args() + + # Run comparison + compare_models( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + iou_threshold=args.iou_threshold, + ) + + # Evaluate at multiple thresholds if requested + if args.evaluate_thresholds: + evaluate_multiple_thresholds( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + iou_threshold=args.iou_threshold, + output_path=args.plot_output, + ) + + +if __name__ == "__main__": + run() + diff --git a/src/deepforest/scripts/evaluate_deepwater_horizon.py b/src/deepforest/scripts/evaluate_deepwater_horizon.py new file mode 100644 index 000000000..392bab2e8 --- /dev/null +++ b/src/deepforest/scripts/evaluate_deepwater_horizon.py @@ -0,0 +1,478 @@ +"""Evaluate bird detection models on DeepWater Horizon imagery. + +This script: +1. 
Loads shapefiles from the DeepWater Horizon monitoring program +2. Creates a test.csv file +3. Evaluates both the old and new bird detection models +4. Generates visualization comparisons +""" + +import os +import glob + +import pandas as pd +from deepforest import main as df_main +from deepforest.preprocess import split_raster +from deepforest.utilities import read_file +from deepforest.visualize import plot_results +import geopandas as gpd + + +def load_shapefiles_and_create_test_csv(data_dir, output_csv="test.csv", output_dir=None): + """Load all shapefiles and create a test.csv file. + + Args: + data_dir: Directory containing shapefiles and images + output_csv: Name of output CSV file + output_dir: Directory to write CSV file (default: tries data_dir, falls back to current directory) + + Returns: + Path to the created CSV file + """ + output_path = os.path.join(data_dir, output_csv) + + # Check if CSV already exists + if os.path.exists(output_path): + print(f"Test CSV already exists at {output_path}, skipping creation.") + # Verify it's readable + try: + existing_df = pd.read_csv(output_path) + print(f"Found existing {output_csv} with {len(existing_df)} annotations from {len(existing_df['image_path'].unique())} images") + except Exception as e: + print(f"Warning: Could not read existing CSV: {e}. 
Recreating...") + else: + return output_path + + # Find all shapefiles + shapefiles = glob.glob(os.path.join(data_dir, "*_annotated.shp")) + print(f"Found {len(shapefiles)} shapefiles") + + all_annotations = [] + + for shp_path in shapefiles: + # Extract base name to find corresponding image + base_name = os.path.basename(shp_path).replace("_annotated.shp", "") + + # Find corresponding image file + image_files = glob.glob(os.path.join(data_dir, f"{base_name}*.jpg")) + if not image_files: + print(f"Warning: No image found for {base_name}") + continue + + image_path = image_files[0] + image_filename = os.path.basename(image_path) + + print(f"Processing {base_name}: {image_filename}") + + # Read shapefile directly (coordinates are already in image space) + gdf = gpd.read_file(shp_path) + gdf.geometry = gdf.geometry.scale(xfact=1, yfact=-1, origin=(0, 0)) + + # Set image_path + gdf["image_path"] = image_filename + gdf.crs = None + gdf["label"] = "Bird" + gdf = gdf[gdf.geometry.notna()] + gdf = read_file(gdf, root_dir=data_dir) + + all_annotations.append(gdf) + + # Combine all annotations + combined_df = pd.concat(all_annotations, ignore_index=True) + + combined_df.to_csv(output_path, index=False) + print(f"\nCreated {output_csv} with {len(combined_df)} annotations from {len(combined_df['image_path'].unique())} images") + print(f"Saved to: {output_path}") + + return output_path + + +def split_test_images_for_evaluation(test_csv, data_dir, patch_size=800, patch_overlap=0, output_dir=None, split_csv_name=None): + """Split test images into smaller patches for evaluation using split_raster. 
+ + Args: + test_csv: Path to test CSV file with full image annotations + data_dir: Directory containing test images + patch_size: Size of patches for splitting (default: 800) + patch_overlap: Overlap between patches (default: 0) + output_dir: Directory to save split images (default: test_splits subdirectory of test_csv location) + split_csv_name: Name of the output CSV file (default: test_split.csv) + + Returns: + Tuple of (split_csv_path, split_dir) where split_csv_path is the path to + the CSV file with split image annotations and split_dir is the directory + containing the split images + """ + # Create output directory for split images + if output_dir is None: + output_dir = os.path.join(os.path.dirname(test_csv), "test_splits") + os.makedirs(output_dir, exist_ok=True) + + # Set default CSV name if not provided + if split_csv_name is None: + split_csv_name = "test_split.csv" + + # Read the test CSV + test_df = read_file(test_csv) + unique_images = test_df["image_path"].unique() + + print(f"\nSplitting {len(unique_images)} test images into {patch_size}-pixel patches...") + print(f"Output directory: {output_dir}") + + all_split_annotations = [] + + for image_name in unique_images: + image_path = os.path.join(data_dir, image_name) + if not os.path.exists(image_path): + print(f"Warning: Image not found: {image_path}") + continue + + print(f"Processing {image_name}...") + + # Get annotations for this image + image_annotations = test_df[test_df["image_path"] == image_name].copy() + + # Create temporary CSV file for this image's annotations + temp_annotations_file = os.path.join(output_dir, f"temp_{image_name}_annotations.csv") + image_annotations.to_csv(temp_annotations_file, index=False) + + # Use split_raster to create crops + split_df = split_raster( + annotations_file=temp_annotations_file, + path_to_raster=image_path, + root_dir=os.path.dirname(temp_annotations_file), + patch_size=patch_size, + patch_overlap=patch_overlap, + allow_empty=False, + 
save_dir=output_dir, + ) + + if not split_df.empty: + all_split_annotations.append(split_df) + print(f" Created {len(split_df['image_path'].unique())} patches with {len(split_df)} annotations") + else: + print(f" Warning: No patches created for {image_name}") + + # Clean up temporary annotations file + if os.path.exists(temp_annotations_file): + os.remove(temp_annotations_file) + + # Combine all split annotations + if all_split_annotations: + combined_split_df = pd.concat(all_split_annotations, ignore_index=True) + + # Save split CSV + split_csv_path = os.path.join(output_dir, split_csv_name) + combined_split_df.to_csv(split_csv_path, index=False) + + print(f"\nCreated split test CSV with {len(combined_split_df)} annotations from {len(combined_split_df['image_path'].unique())} patches") + print(f"Saved to: {split_csv_path}") + + return split_csv_path, output_dir + else: + raise ValueError("No split annotations were created. Check that images exist and contain valid annotations.") + + +def evaluate_models(checkpoint_path, data_dir, test_csv, split_dir, iou_threshold=0.4): + """Evaluate both old and new bird detection models. 
+ + Args: + checkpoint_path: Path to the new checkpoint model + data_dir: Directory containing original test data (not used for evaluation) + test_csv: Path to split test CSV file (for evaluation) + split_dir: Directory containing split images (for evaluation) + iou_threshold: IoU threshold for evaluation + + Returns: + Dictionary with evaluation results for both models + """ + results = {} + + # Evaluate new checkpoint model + print("\n" + "=" * 80) + print("Evaluating NEW checkpoint model...") + print("=" * 80) + checkpoint_model = df_main.deepforest.load_from_checkpoint(checkpoint_path) + checkpoint_model.config.score_thresh = 0.25 + checkpoint_model.model.score_thresh = 0.25 + + # Set up validation configuration using split CSV and split directory + checkpoint_model.config.validation.csv_file = test_csv + checkpoint_model.config.validation.root_dir = split_dir + checkpoint_model.config.validation.iou_threshold = iou_threshold + checkpoint_model.config.validation.val_accuracy_interval = 1 + checkpoint_model.config.workers = 0 + checkpoint_model.create_trainer() + + validation_results = checkpoint_model.trainer.validate(checkpoint_model) + checkpoint_validate = validation_results[0] if validation_results else {} + results["checkpoint"] = checkpoint_validate + + print(f"Box Precision: {checkpoint_validate.get('box_precision', 'N/A')}") + print(f"Box Recall: {checkpoint_validate.get('box_recall', 'N/A')}") + print(f"Empty Frame Accuracy: {checkpoint_validate.get('empty_frame_accuracy', 'N/A')}") + + # Evaluate old pretrained model + print("\n" + "=" * 80) + print("Evaluating OLD pretrained model (weecology/deepforest-bird)...") + print("=" * 80) + pretrained_model = df_main.deepforest() + pretrained_model.load_model("weecology/deepforest-bird") + pretrained_model.config.score_thresh = 0.25 + pretrained_model.model.score_thresh = 0.25 + + # Set label dictionaries to match + pretrained_model.label_dict = {"Bird": 0} + pretrained_model.numeric_to_label_dict = {0: 
"Bird"} + pretrained_model.config.label_dict = {"Bird": 0} + pretrained_model.config.num_classes = 1 + + # Set up validation configuration using split CSV and split directory + pretrained_model.config.validation.csv_file = test_csv + pretrained_model.config.validation.root_dir = split_dir + pretrained_model.config.validation.iou_threshold = iou_threshold + pretrained_model.config.validation.val_accuracy_interval = 1 + pretrained_model.config.workers = 0 + pretrained_model.create_trainer() + + validation_results = pretrained_model.trainer.validate(pretrained_model) + pretrained_validate = validation_results[0] if validation_results else {} + results["pretrained"] = pretrained_validate + + print(f"Box Precision: {pretrained_validate.get('box_precision', 'N/A')}") + print(f"Box Recall: {pretrained_validate.get('box_recall', 'N/A')}") + print(f"Empty Frame Accuracy: {pretrained_validate.get('empty_frame_accuracy', 'N/A')}") + + return results, checkpoint_model, pretrained_model + + +def generate_visualizations( + checkpoint_model, + pretrained_model, + data_dir, + test_csv, + output_dir, + num_images=2, +): + """Generate side-by-side visualizations comparing old and new models. 
+ + Args: + checkpoint_model: New checkpoint model + pretrained_model: Old pretrained model + data_dir: Directory containing test data + test_csv: Path to test CSV file + output_dir: Directory to save visualizations + num_images: Number of images to visualize + """ + import matplotlib.pyplot as plt + from deepforest.utilities import read_file + + os.makedirs(output_dir, exist_ok=True) + + # Read test CSV to get image list + test_df = read_file(test_csv) + unique_images = test_df["image_path"].unique()[:num_images] + + print(f"\nGenerating visualizations for {len(unique_images)} images...") + + for image_name in unique_images: + image_path = os.path.join(data_dir, image_name) + if not os.path.exists(image_path): + print(f"Warning: Image not found: {image_path}") + continue + + print(f"Processing {image_name}...") + + # Get ground truth + ground_truth = test_df[test_df["image_path"] == image_name].copy() + + # Predict with new model + checkpoint_predictions = checkpoint_model.predict_tile(path=image_path, patch_size=800, patch_overlap=0) + + # Predict with old model + pretrained_predictions = pretrained_model.predict_tile(path=image_path, patch_size=800, patch_overlap=0) + + # Create side-by-side comparison using savedir approach + # Save individual plots to output_dir first, then combine + base_name = os.path.splitext(image_name)[0] + plots_dir = output_dir + os.makedirs(plots_dir, exist_ok=True) + + # Plot new model (predict_tile may return None when there are no detections) + if checkpoint_predictions is not None and len(checkpoint_predictions) > 0: + plot_results( + checkpoint_predictions, + ground_truth=ground_truth, + image=image_path, + savedir=plots_dir, + basename=f"{base_name}_new", + show=False, + ) + else: + # Create empty plot + fig, ax = plt.subplots(figsize=(10, 10)) + ax.text(0.5, 0.5, "No predictions", ha="center", va="center", fontsize=16) + ax.set_title("New Retrained Model - No Predictions", fontsize=14) + plt.savefig(os.path.join(plots_dir, f"{base_name}_new.png"), dpi=300, 
bbox_inches="tight") + plt.close(fig) + + # Plot old model (predict_tile may return None when there are no detections) + if pretrained_predictions is not None and len(pretrained_predictions) > 0: + plot_results( + pretrained_predictions, + ground_truth=ground_truth, + image=image_path, + savedir=plots_dir, + basename=f"{base_name}_old", + show=False, + ) + else: + # Create empty plot + fig, ax = plt.subplots(figsize=(10, 10)) + ax.text(0.5, 0.5, "No predictions", ha="center", va="center", fontsize=16) + ax.set_title("Original Pretrained Model - No Predictions", fontsize=14) + plt.savefig(os.path.join(plots_dir, f"{base_name}_old.png"), dpi=300, bbox_inches="tight") + plt.close(fig) + + # Combine the two images side by side + from PIL import Image as PILImage + img1 = PILImage.open(os.path.join(plots_dir, f"{base_name}_new.png")) + img2 = PILImage.open(os.path.join(plots_dir, f"{base_name}_old.png")) + + # Resize to same height + height = max(img1.height, img2.height) + img1 = img1.resize((int(img1.width * height / img1.height), height), PILImage.Resampling.LANCZOS) + img2 = img2.resize((int(img2.width * height / img2.height), height), PILImage.Resampling.LANCZOS) + + # Combine + combined = PILImage.new("RGB", (img1.width + img2.width, height)) + combined.paste(img1, (0, 0)) + combined.paste(img2, (img1.width, 0)) + + # Save + output_path = os.path.join(plots_dir, f"{base_name}_comparison.png") + combined.save(output_path, dpi=(300, 300)) + + print(f"Saved: {output_path}") + + +def main(): + """Main function.""" + import argparse + + parser = argparse.ArgumentParser( + description="Evaluate bird detection models on Deepwater Horizon imagery" + ) + parser.add_argument( + "--data_dir", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated", + help="Directory containing shapefiles and images", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/2022paper/checkpoints/f92a9384135f4481b7372b85d1da5b5f.ckpt", + help="Path to checkpoint file", + ) + 
parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save visualizations (default: data_dir/visualizations)", + ) + parser.add_argument( + "--num_images", + type=int, + default=2, + help="Number of images to visualize", + ) + parser.add_argument( + "--patch_size", + type=int, + default=800, + help="Patch size for splitting images during evaluation (default: 800)", + ) + parser.add_argument( + "--patch_overlap", + type=float, + default=0.0, + help="Patch overlap for splitting images during evaluation (default: 0.0)", + ) + + args = parser.parse_args() + + # Set default output directory (use current directory to avoid permission issues) + if args.output_dir is None: + args.output_dir = os.path.join(os.getcwd(), "visualizations") + + # Step 1: Load shapefiles and create test.csv + print("=" * 80) + print("Step 1: Loading shapefiles and creating test.csv") + print("=" * 80) + test_csv = load_shapefiles_and_create_test_csv(args.data_dir) + + # Step 2: Split test images for evaluation + print("\n" + "=" * 80) + print("Step 2: Splitting test images for evaluation") + print("=" * 80) + split_csv, split_dir = split_test_images_for_evaluation( + test_csv=test_csv, + data_dir=args.data_dir, + patch_size=args.patch_size, + patch_overlap=args.patch_overlap, + ) + + # Step 3: Evaluate models using split images + print("\n" + "=" * 80) + print("Step 3: Evaluating models") + print("=" * 80) + results, checkpoint_model, pretrained_model = evaluate_models( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + test_csv=split_csv, + split_dir=split_dir, + iou_threshold=args.iou_threshold, + ) + + # Step 4: Generate visualizations using full images + print("\n" + "=" * 80) + print("Step 4: Generating visualizations") + print("=" * 80) + generate_visualizations( + checkpoint_model=checkpoint_model, + 
pretrained_model=pretrained_model, + data_dir=args.data_dir, + test_csv=test_csv, + output_dir=args.output_dir, + num_images=args.num_images, + ) + + # Print summary + print("\n" + "=" * 80) + print("EVALUATION SUMMARY") + print("=" * 80) + print("\nNew Checkpoint Model:") + print(f" Box Precision: {results['checkpoint'].get('box_precision', 'N/A')}") + print(f" Box Recall: {results['checkpoint'].get('box_recall', 'N/A')}") + print(f" Empty Frame Accuracy: {results['checkpoint'].get('empty_frame_accuracy', 'N/A')}") + + print("\nOriginal Pretrained Model:") + print(f" Box Precision: {results['pretrained'].get('box_precision', 'N/A')}") + print(f" Box Recall: {results['pretrained'].get('box_recall', 'N/A')}") + print(f" Empty Frame Accuracy: {results['pretrained'].get('empty_frame_accuracy', 'N/A')}") + + print(f"\nVisualizations saved to: {args.output_dir}") + print(f"Original test CSV saved to: {test_csv}") + print(f"Split test CSV saved to: {split_csv}") + print(f"Split images directory: {split_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/deepforest/scripts/evaluate_patch_size_sensitivity.py b/src/deepforest/scripts/evaluate_patch_size_sensitivity.py new file mode 100644 index 000000000..0dd694e7a --- /dev/null +++ b/src/deepforest/scripts/evaluate_patch_size_sensitivity.py @@ -0,0 +1,276 @@ +"""Evaluate sensitivity of box_recall and box_precision to patch_size. + +This script wraps evaluate_deepwater_horizon.py to evaluate multiple patch sizes +for both checkpoint and pretrained models, and generate a sensitivity plot showing +how metrics vary with patch size for comparison. 
+""" + +import os +import importlib.util +import argparse +import matplotlib.pyplot as plt +import pandas as pd + +# Import from evaluate_deepwater_horizon in the same directory +_script_dir = os.path.dirname(os.path.abspath(__file__)) +_eval_module_path = os.path.join(_script_dir, "evaluate_deepwater_horizon.py") +spec = importlib.util.spec_from_file_location("evaluate_deepwater_horizon", _eval_module_path) +eval_module = importlib.util.module_from_spec(spec) +spec.loader.exec_module(eval_module) + +load_shapefiles_and_create_test_csv = eval_module.load_shapefiles_and_create_test_csv +split_test_images_for_evaluation = eval_module.split_test_images_for_evaluation +evaluate_models = eval_module.evaluate_models + + +def evaluate_patch_size_sensitivity( + data_dir, + checkpoint_path, + patch_sizes, + iou_threshold=0.4, + patch_overlap=0.0, +): + """Evaluate both checkpoint and pretrained models across multiple patch sizes and collect results. + + Args: + data_dir: Directory containing shapefiles and images + checkpoint_path: Path to checkpoint file + patch_sizes: List of patch sizes to evaluate + iou_threshold: IoU threshold for evaluation + patch_overlap: Overlap between patches + + Returns: + DataFrame with patch_size and metrics for both checkpoint and pretrained models + """ + # Step 1: Load shapefiles and create test.csv (only once) + print("=" * 80) + print("Step 1: Loading shapefiles and creating test.csv") + print("=" * 80) + test_csv = load_shapefiles_and_create_test_csv(data_dir) + + results = [] + + for patch_size in patch_sizes: + print("\n" + "=" * 80) + print(f"Evaluating patch size: {patch_size}") + print("=" * 80) + + # Create patch-size-specific output directory + base_output_dir = os.path.join(os.path.dirname(test_csv), "test_splits") + patch_output_dir = os.path.join(base_output_dir, f"patch_{patch_size}") + split_csv_name = f"test_split_patch_{patch_size}.csv" + + # Check if split CSV already exists + split_csv_path = 
os.path.join(patch_output_dir, split_csv_name) + if os.path.exists(split_csv_path): + print(f"Found existing split CSV at {split_csv_path}, skipping splitting...") + split_dir = patch_output_dir + else: + # Step 2: Split test images for this patch size + print(f"\nSplitting test images for patch size {patch_size}...") + split_csv_path, split_dir = split_test_images_for_evaluation( + test_csv=test_csv, + data_dir=data_dir, + patch_size=patch_size, + patch_overlap=patch_overlap, + output_dir=patch_output_dir, + split_csv_name=split_csv_name, + ) + + # Step 3: Evaluate both models + print(f"\nEvaluating both models for patch size {patch_size}...") + eval_results, _, _ = evaluate_models( + checkpoint_path=checkpoint_path, + data_dir=data_dir, + test_csv=split_csv_path, + split_dir=split_dir, + iou_threshold=iou_threshold, + ) + + # Extract results for both models + checkpoint_results = eval_results.get("checkpoint", {}) + pretrained_results = eval_results.get("pretrained", {}) + + checkpoint_precision = checkpoint_results.get("box_precision", None) + checkpoint_recall = checkpoint_results.get("box_recall", None) + pretrained_precision = pretrained_results.get("box_precision", None) + pretrained_recall = pretrained_results.get("box_recall", None) + + results.append( + { + "patch_size": patch_size, + "checkpoint_precision": checkpoint_precision, + "checkpoint_recall": checkpoint_recall, + "pretrained_precision": pretrained_precision, + "pretrained_recall": pretrained_recall, + } + ) + + print(f"Patch size {patch_size}:") + print(f" Checkpoint - Precision={checkpoint_precision}, Recall={checkpoint_recall}") + print(f" Pretrained - Precision={pretrained_precision}, Recall={pretrained_recall}") + + return pd.DataFrame(results) + + +def plot_sensitivity(results_df, output_path): + """Create a plot showing sensitivity of metrics to patch size for both models. 
+ + Args: + results_df: DataFrame with patch_size and metrics for both checkpoint and pretrained models + output_path: Path to save the plot + """ + fig, ax = plt.subplots(figsize=(12, 7)) + + # Plot checkpoint model (solid lines) + ax.plot( + results_df["patch_size"], + results_df["checkpoint_precision"], + marker="o", + label="Checkpoint Precision", + linewidth=2, + markersize=8, + linestyle="-", + color="C0", + ) + ax.plot( + results_df["patch_size"], + results_df["checkpoint_recall"], + marker="s", + label="Checkpoint Recall", + linewidth=2, + markersize=8, + linestyle="-", + color="C1", + ) + + # Plot pretrained model (dashed lines) + ax.plot( + results_df["patch_size"], + results_df["pretrained_precision"], + marker="o", + label="Pretrained Precision", + linewidth=2, + markersize=8, + linestyle="--", + color="C0", + alpha=0.7, + ) + ax.plot( + results_df["patch_size"], + results_df["pretrained_recall"], + marker="s", + label="Pretrained Recall", + linewidth=2, + markersize=8, + linestyle="--", + color="C1", + alpha=0.7, + ) + + ax.set_xlabel("Patch Size (pixels)", fontsize=12) + ax.set_ylabel("Metric Value", fontsize=12) + ax.set_title("Sensitivity of Box Precision and Recall to Patch Size\n(Checkpoint vs Pretrained Model)", fontsize=14) + ax.legend(fontsize=10, loc="best") + ax.grid(True, alpha=0.3) + ax.set_xlim(left=0) + + # Ensure y-axis shows full range + all_metrics = pd.concat([ + results_df["checkpoint_precision"], + results_df["checkpoint_recall"], + results_df["pretrained_precision"], + results_df["pretrained_recall"], + ]) + y_min = all_metrics.min() + y_max = all_metrics.max() + y_range = y_max - y_min + ax.set_ylim( + max(0, y_min - 0.1 * y_range), + min(1.0, y_max + 0.1 * y_range), + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=300, bbox_inches="tight") + print(f"\nSaved sensitivity plot to: {output_path}") + plt.close(fig) + + +def main(): + """Main function.""" + parser = argparse.ArgumentParser( + description="Evaluate sensitivity 
of metrics to patch size" + ) + parser.add_argument( + "--data_dir", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated", + help="Directory containing shapefiles and images", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + default="/blue/ewhite/b.weinstein/bird_detector_retrain/data/checkpoints/6181df1ab7ac40f291b863a2a9b86024.ckpt", + help="Path to checkpoint file", + ) + parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation", + ) + parser.add_argument( + "--patch_overlap", + type=float, + default=0.0, + help="Patch overlap for splitting images (default: 0.0)", + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Directory to save plots (default: data_dir/plots)", + ) + parser.add_argument( + "--patch_sizes", + type=int, + nargs="+", + default=[200, 400, 600, 800, 1000, 1500, 2000], + help="List of patch sizes to evaluate (default: 200 400 600 800 1000 1500 2000)", + ) + + args = parser.parse_args() + + # Set default output directory + if args.output_dir is None: + args.output_dir = os.path.join(args.data_dir, "plots") + os.makedirs(args.output_dir, exist_ok=True) + + # Evaluate across patch sizes (both models) + results_df = evaluate_patch_size_sensitivity( + data_dir=args.data_dir, + checkpoint_path=args.checkpoint_path, + patch_sizes=args.patch_sizes, + iou_threshold=args.iou_threshold, + patch_overlap=args.patch_overlap, + ) + + # Save results to CSV + results_csv = os.path.join(args.output_dir, "patch_size_sensitivity_results.csv") + results_df.to_csv(results_csv, index=False) + print(f"\nSaved results to: {results_csv}") + + # Create and save plot + plot_path = os.path.join(args.output_dir, "patch_size_sensitivity.png") + plot_sensitivity(results_df, plot_path) + + # Print summary + print("\n" + "=" * 80) + print("SENSITIVITY ANALYSIS SUMMARY") + print("=" * 80) + print(results_df.to_string(index=False)) + 
print(f"\nPlot saved to: {plot_path}") + + +if __name__ == "__main__": + main() diff --git a/src/deepforest/scripts/evaluate_thresholds.py b/src/deepforest/scripts/evaluate_thresholds.py new file mode 100644 index 000000000..aa11c136c --- /dev/null +++ b/src/deepforest/scripts/evaluate_thresholds.py @@ -0,0 +1,180 @@ +"""Evaluate bird detection model at multiple score thresholds. + +This script evaluates a checkpoint model at multiple score thresholds and +generates a precision-recall curve. + +Example usage: + python evaluate_thresholds.py --checkpoint_path /path/to/checkpoint.ckpt --data_dir /path/to/data +""" + +import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np + +from deepforest import main + + +def evaluate_thresholds( + checkpoint_path, data_dir, iou_threshold=0.4, thresholds=None, output_path=None +): + """Evaluate checkpoint model at multiple score thresholds. + + Args: + checkpoint_path: Path to the checkpoint file + data_dir: Directory containing test.csv and images + iou_threshold: IoU threshold for evaluation (default: 0.4) + thresholds: List of score thresholds to evaluate (default: 0.1 to 0.5 in 0.05 steps) + output_path: Path to save the plot (default: data_dir/precision_recall_curve.png) + + Returns: + dict: Dictionary with thresholds, precision, and recall arrays + """ + if thresholds is None: + thresholds = np.arange(0.1, 0.55, 0.05).round(2).tolist() + + test_csv = os.path.join(data_dir, "test.csv") + + if not os.path.exists(test_csv): + raise FileNotFoundError(f"Test CSV not found: {test_csv}") + if not os.path.exists(checkpoint_path): + raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}") + + print("=" * 80) + print("Evaluating Checkpoint Model at Multiple Score Thresholds") + print("=" * 80) + print(f"\nTest dataset: {test_csv}") + print(f"IoU threshold: {iou_threshold}") + print(f"Score thresholds: {thresholds}\n") + + # Load model once + print("Loading checkpoint model...") + model = 
main.deepforest.load_from_checkpoint(checkpoint_path) + + precision_scores = [] + recall_scores = [] + + print("\nEvaluating at each threshold:") + print("-" * 80) + for i, threshold in enumerate(thresholds): + print(f"\n[{i+1}/{len(thresholds)}] Evaluating at score threshold: {threshold:.2f}") + model.config.score_thresh = threshold + model.model.score_thresh = threshold + + results = model.evaluate( + csv_file=test_csv, + root_dir=data_dir, + iou_threshold=iou_threshold, + ) + + precision = results["box_precision"] + recall = results["box_recall"] + + precision_scores.append(precision) + recall_scores.append(recall) + + print(f" Precision: {precision:.4f}") + print(f" Recall: {recall:.4f}") + + # Create results dictionary + threshold_results = { + "thresholds": thresholds, + "precision": precision_scores, + "recall": recall_scores, + } + + # Print summary table + print("\n" + "=" * 80) + print("SUMMARY TABLE") + print("=" * 80) + print(f"\n{'Threshold':<12} {'Precision':<12} {'Recall':<12}") + print("-" * 40) + for thresh, prec, rec in zip(thresholds, precision_scores, recall_scores): + print(f"{thresh:<12.2f} {prec:<12.4f} {rec:<12.4f}") + + # Generate plot + if output_path is None: + output_path = os.path.join(data_dir, "precision_recall_curve.png") + + print(f"\nGenerating plot: {output_path}") + plt.figure(figsize=(10, 6)) + plt.plot(thresholds, precision_scores, "o-", label="Precision", linewidth=2, markersize=8) + plt.plot(thresholds, recall_scores, "s-", label="Recall", linewidth=2, markersize=8) + plt.xlabel("Score Threshold", fontsize=12) + plt.ylabel("Score", fontsize=12) + plt.title("Precision and Recall vs Score Threshold\n(Retrained Bird Detection Model)", fontsize=14) + plt.legend(fontsize=11) + plt.grid(True, alpha=0.3) + plt.xlim(min(thresholds) - 0.02, max(thresholds) + 0.02) + plt.ylim(0, max(max(precision_scores), max(recall_scores)) * 1.1) + + # Add value labels on points + for thresh, prec, rec in zip(thresholds, precision_scores, 
recall_scores): + plt.annotate( + f"{prec:.3f}", + (thresh, prec), + textcoords="offset points", + xytext=(0, 10), + ha="center", + fontsize=8, + ) + plt.annotate( + f"{rec:.3f}", + (thresh, rec), + textcoords="offset points", + xytext=(0, -15), + ha="center", + fontsize=8, + ) + + plt.tight_layout() + plt.savefig(output_path, dpi=300, bbox_inches="tight") + print(f"Plot saved to: {output_path}") + + return threshold_results + + +def run(): + """Main function.""" + parser = argparse.ArgumentParser( + description="Evaluate checkpoint model at multiple score thresholds" + ) + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, + help="Path to the checkpoint file", + ) + parser.add_argument( + "--data_dir", + type=str, + required=True, + help="Directory containing test.csv and images", + ) + parser.add_argument( + "--iou_threshold", + type=float, + default=0.4, + help="IoU threshold for evaluation (default: 0.4)", + ) + parser.add_argument( + "--plot_output", + type=str, + default=None, + help="Path to save the precision-recall plot (default: data_dir/precision_recall_curve.png)", + ) + + args = parser.parse_args() + + evaluate_thresholds( + checkpoint_path=args.checkpoint_path, + data_dir=args.data_dir, + iou_threshold=args.iou_threshold, + output_path=args.plot_output, + ) + + +if __name__ == "__main__": + run() + diff --git a/src/deepforest/scripts/prepare_birds.py b/src/deepforest/scripts/prepare_birds.py new file mode 100644 index 000000000..7edb49ba0 --- /dev/null +++ b/src/deepforest/scripts/prepare_birds.py @@ -0,0 +1,666 @@ +"""Prepare bird detection training data from multiple sources. + +This script collects annotations from multiple data sources, maps labels to "Bird", +creates symlinks to a single output directory, and generates train/test splits. +This is a documentation/example script - users should adapt paths to their own data. + +Example paths are hardcoded below (actual data not publicly available). 
+""" + +import argparse +import json +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from PIL import Image +from sklearn.model_selection import train_test_split + +from deepforest.preprocess import split_raster +from deepforest.utilities import read_file + + +# Data source file paths (adapt these to your own data locations) +DATA_SOURCES = [ + "/orange/ewhite/b.weinstein/Drones_for_Ducks/uas-imagery-of-migratory-waterfowl/crowdsourced/20240220_dronesforducks_zooniverse_refined.json", + "/orange/ewhite/b.weinstein/izembek-lagoon-waterfowl/izembek-lagoon-birds-metadata.json", + "/orange/ewhite/b.weinstein/bird_detector/generalization/crops/training_annotations.csv", + "/blue/ewhite/b.weinstein/BOEM/UBFAI Images with Detection Data/crops/train.csv", +] + +# Path to existing test dataset (if provided, skip train/test split and put all new data in train) +EXISTING_TEST = None # e.g., "/path/to/existing/test.csv" + +# Nuisance labels to exclude (matched case-insensitively and filtered out) +NUISANCE_LABELS = {"buoy", "buoys", "trash", "boat", "sargassum"} + +def load_coco_with_bboxes(json_file): + """Load COCO format JSON file with bounding boxes (bbox) instead of segmentation. 
+ + Args: + json_file: Path to COCO JSON file with bbox annotations + + Returns: + DataFrame with image_path, xmin, ymin, xmax, ymax, label columns + """ + with open(json_file) as f: + coco_data = json.load(f) + + # Create mapping from image_id to file_name + image_ids = {image["id"]: image["file_name"] for image in coco_data["images"]} + + # Create mapping from category_id to category name (if available) + category_ids = {} + if "categories" in coco_data: + category_ids = {cat["id"]: cat.get("name", f"category_{cat['id']}") for cat in coco_data["categories"]} + + annotations = [] + for annotation in coco_data["annotations"]: + # Skip if image_id doesn't exist in images + image_id = annotation["image_id"] + if image_id not in image_ids: + continue + + # COCO bbox format: [x, y, width, height] where (x, y) is top-left corner + try: + bbox = annotation["bbox"] + except KeyError: + continue + + x = bbox[0] + y = bbox[1] + width = bbox[2] + height = bbox[3] + + # Convert to DeepForest format: xmin, ymin, xmax, ymax + xmin = x + ymin = y + xmax = x + width + ymax = y + height + + # Get category label + category_id = annotation.get("category_id", 1) + label = category_ids.get(category_id, "Bird") + + annotations.append({ + "image_path": image_ids[image_id], + "xmin": xmin, + "ymin": ymin, + "xmax": xmax, + "ymax": ymax, + "label": label, + }) + + return pd.DataFrame(annotations) + + +def load_annotations_from_source(source_path): + """Load annotations from a data source file. 
+
+    Args:
+        source_path: Path to annotation file (CSV or JSON)
+
+    Returns:
+        DataFrame with annotations and root_dir attribute
+    """
+    if not os.path.exists(source_path):
+        raise FileNotFoundError(f"Source file does not exist: {source_path}")
+
+    if source_path.endswith(".csv"):
+        df = read_file(source_path)
+    elif source_path.endswith(".json"):
+        df = load_coco_with_bboxes(source_path)
+    else:
+        raise ValueError(f"Unsupported file type: {source_path}")
+
+    # Add root_dir attribute (directory containing the annotation file)
+    df.root_dir = os.path.dirname(source_path)
+
+    return df
+
+
+def map_labels_to_bird(df):
+    """Map all labels to "Bird" except nuisance labels, which are filtered out.
+
+    Args:
+        df: DataFrame with label column
+
+    Returns:
+        DataFrame with labels mapped to "Bird" and nuisance labels removed
+    """
+    # Filter out nuisance labels (case-insensitive)
+    if "label" in df.columns:
+        mask = ~df["label"].str.lower().isin([n.lower() for n in NUISANCE_LABELS])
+        df = df[mask].copy()
+
+    # Map all remaining labels to "Bird"
+    df["label"] = "Bird"
+
+    return df
+
+
+def create_blank_images(output_dir, num_images=100, image_size=(400, 400)):
+    """Create blank white images with empty annotations.
+
+    Args:
+        output_dir: Directory to save images and annotations
+        num_images: Number of blank images to create
+        image_size: Tuple of (width, height) for images
+
+    Returns:
+        DataFrame with empty annotations for blank images
+    """
+    blank_annotations = []
+
+    for i in range(num_images):
+        # Create blank white image
+        blank_image = Image.new("RGB", image_size, color="white")
+        image_filename = f"blank_image_{i:03d}.png"
+        image_path = os.path.join(output_dir, image_filename)
+        blank_image.save(image_path)
+
+        # Create empty annotation (0,0,0,0 coordinates indicate an empty frame)
+        blank_annotations.append(
+            {
+                "image_path": image_filename,
+                "xmin": 0,
+                "ymin": 0,
+                "xmax": 0,
+                "ymax": 0,
+                "label": "Bird",
+            }
+        )
+
+    return pd.DataFrame(blank_annotations)
+
+
+def create_symlink(source, target):
+    """Create a symlink, handling existing files.
+
+    Args:
+        source: Source file path
+        target: Target symlink path
+    """
+    # Remove target if it exists
+    if os.path.exists(target) or os.path.islink(target):
+        os.remove(target)
+
+    # Create parent directory if needed
+    os.makedirs(os.path.dirname(target), exist_ok=True)
+
+    # Create symlink
+    os.symlink(source, target)
+
+
+def check_negative_coordinates(df):
+    """Check for negative bounding box coordinates.
+
+    Args:
+        df: DataFrame with xmin, ymin, xmax, ymax columns
+
+    Returns:
+        DataFrame with rows that have negative coordinates
+    """
+    required_cols = ["xmin", "ymin", "xmax", "ymax"]
+    for col in required_cols:
+        if col not in df.columns:
+            return pd.DataFrame()
+
+    # Find rows with any negative coordinates
+    negative_mask = (
+        (df["xmin"] < 0) | (df["ymin"] < 0) | (df["xmax"] < 0) | (df["ymax"] < 0)
+    )
+    return df[negative_mask].copy()
+
+
+def clip_boxes_to_image_bounds(df, image_dir):
+    """Clip bounding box coordinates to image boundaries.
+
+    Clips negative coordinates to 0 and coordinates beyond image dimensions
+    to the image edges. Ensures boxes remain valid (xmax > xmin, ymax > ymin).
+
+    Args:
+        df: DataFrame with image_path, xmin, ymin, xmax, ymax columns
+        image_dir: Directory containing the images
+
+    Returns:
+        DataFrame with clipped coordinates
+    """
+    df = df.copy()
+    required_cols = ["image_path", "xmin", "ymin", "xmax", "ymax"]
+    for col in required_cols:
+        if col not in df.columns:
+            return df
+
+    # Track how many boxes were clipped
+    clipped_count = 0
+    invalid_count = 0
+
+    # Process each unique image
+    unique_images = df["image_path"].unique()
+    for img_path in unique_images:
+        # Get full image path
+        full_img_path = os.path.join(image_dir, img_path)
+
+        if not os.path.exists(full_img_path):
+            continue
+
+        try:
+            # Load image to get dimensions
+            img = Image.open(full_img_path)
+            img_width, img_height = img.size
+
+            # Get annotations for this image
+            img_mask = df["image_path"] == img_path
+            img_indices = df[img_mask].index
+
+            for idx in img_indices:
+                original_xmin = df.at[idx, "xmin"]
+                original_ymin = df.at[idx, "ymin"]
+                original_xmax = df.at[idx, "xmax"]
+                original_ymax = df.at[idx, "ymax"]
+
+                # Clip coordinates to image boundaries
+                xmin = max(0, min(original_xmin, img_width - 1))
+                ymin = max(0, min(original_ymin, img_height - 1))
+                xmax = max(xmin + 1, min(original_xmax, img_width))
+                ymax = max(ymin + 1, min(original_ymax, img_height))
+
+                # Check if clipping occurred
+                if (
+                    xmin != original_xmin
+                    or ymin != original_ymin
+                    or xmax != original_xmax
+                    or ymax != original_ymax
+                ):
+                    clipped_count += 1
+                    df.at[idx, "xmin"] = xmin
+                    df.at[idx, "ymin"] = ymin
+                    df.at[idx, "xmax"] = xmax
+                    df.at[idx, "ymax"] = ymax
+
+                # Check if box is still valid
+                if xmax <= xmin or ymax <= ymin:
+                    invalid_count += 1
+
+        except Exception as e:
+            print(f"    Warning: Error processing image {img_path}: {e}")
+            continue
+
+    if clipped_count > 0:
+        print(f"    Clipped {clipped_count} bounding boxes to image boundaries")
+    if invalid_count > 0:
+        print(f"    Warning: {invalid_count} boxes became invalid after clipping")
+
+    return df
+
+
+def process_izembek_with_splitting(df, root_dir, output_dir, image_files_map):
+    """Process Izembek dataset by splitting images into 2000-pixel crops.
+
+    Args:
+        df: DataFrame with annotations
+        root_dir: Root directory for images
+        output_dir: Output directory for crops
+        image_files_map: Map from original path to symlink name
+
+    Returns:
+        DataFrame with crop annotations
+    """
+    # Create directory for crops
+    crops_dir = os.path.join(output_dir, "izembek_crops")
+    os.makedirs(crops_dir, exist_ok=True)
+
+    crop_annotations_list = []
+    unique_images = df["image_path"].unique()
+
+    print(f"    Splitting {len(unique_images)} images into 2000-pixel crops...")
+
+    for img_path in unique_images:
+        # Construct full source path
+        if os.path.isabs(img_path):
+            source_img_path = img_path
+        else:
+            source_img_path = os.path.join(root_dir, img_path)
+
+        if not os.path.exists(source_img_path):
+            # Try alternative locations
+            alt_paths = [
+                os.path.join(root_dir, os.path.basename(img_path)),
+            ]
+            found = False
+            for alt_path in alt_paths:
+                if os.path.exists(alt_path):
+                    source_img_path = alt_path
+                    found = True
+                    break
+            if not found:
+                print(f"    Warning: Image not found: {source_img_path}")
+                continue
+
+        # Get image basename for matching with annotations
+        image_basename = os.path.basename(source_img_path)
+
+        # Filter annotations for this image and update image_path to basename
+        img_annotations = df[df["image_path"] == img_path].copy()
+        if img_annotations.empty:
+            continue
+
+        # Update image_path to basename for split_raster matching
+        img_annotations["image_path"] = image_basename
+
+        # Save temporary annotations file for this image
+        temp_annotations_file = os.path.join(crops_dir, f"temp_{image_basename}_annotations.csv")
+        img_annotations.to_csv(temp_annotations_file, index=False)
+
+        try:
+            # Use split_raster to create crops
+            crop_df = split_raster(
+                annotations_file=temp_annotations_file,
+                path_to_raster=source_img_path,
+                root_dir=os.path.dirname(temp_annotations_file),
+                patch_size=2000,
+                patch_overlap=0,
+                allow_empty=False,
+                save_dir=crops_dir,
+            )
+
+            # Process each crop
+            for crop_img_path in crop_df["image_path"].unique():
+                crop_full_path = os.path.join(crops_dir, crop_img_path)
+
+                if not os.path.exists(crop_full_path):
+                    continue
+
+                # Create unique symlink name
+                crop_basename = crop_img_path
+                symlink_name = crop_basename
+                counter = 1
+                while symlink_name in image_files_map.values():
+                    name, ext = os.path.splitext(crop_basename)
+                    symlink_name = f"{name}_{counter}{ext}"
+                    counter += 1
+
+                # Create symlink to crop
+                target_path = os.path.join(output_dir, symlink_name)
+                try:
+                    create_symlink(crop_full_path, target_path)
+                    image_files_map[crop_img_path] = symlink_name
+                except Exception as e:
+                    print(f"    Warning: Failed to create symlink for {crop_img_path}: {e}")
+                    continue
+
+                # Update image paths in crop dataframe to use symlink name
+                crop_df.loc[crop_df["image_path"] == crop_img_path, "image_path"] = symlink_name
+
+            crop_annotations_list.append(crop_df)
+
+        except Exception as e:
+            print(f"    Warning: Failed to split image {img_path}: {e}")
+            continue
+        finally:
+            # Clean up temporary annotations file
+            if os.path.exists(temp_annotations_file):
+                os.remove(temp_annotations_file)
+
+    if crop_annotations_list:
+        return pd.concat(crop_annotations_list, ignore_index=True)
+    else:
+        return pd.DataFrame()
+
+
+def filter_small_boxes(df, min_area=1, epsilon=1e-6):
+    """Filter out bounding boxes with zero or single-pixel area.
+
+    Args:
+        df: DataFrame with xmin, ymin, xmax, ymax columns
+        min_area: Minimum area (in pixels) for a box to be kept (default: 1)
+        epsilon: Small value for floating point comparison (default: 1e-6)
+
+    Returns:
+        DataFrame with small boxes removed
+    """
+    df = df.copy()
+    required_cols = ["xmin", "ymin", "xmax", "ymax"]
+    for col in required_cols:
+        if col not in df.columns:
+            return df
+
+    # Calculate width, height, and area
+    width = df["xmax"] - df["xmin"]
+    height = df["ymax"] - df["ymin"]
+    area = width * height
+
+    # Round area to handle floating point precision issues
+    # Single-pixel boxes (width=1, height=1) should have area=1.0
+    area_rounded = np.round(area, decimals=6)
+
+    # Filter out boxes with invalid dimensions (width <= 0, height <= 0)
+    # or rounded area <= min_area; this catches single-pixel boxes
+    valid_mask = (width > epsilon) & (height > epsilon) & (area_rounded > min_area)
+
+    removed_count = (~valid_mask).sum()
+    if removed_count > 0:
+        print(f"    Removed {removed_count} bounding boxes with area <= {min_area} pixel(s) or single-pixel dimensions")
+
+    return df[valid_mask].copy()
+
+
+def main():
+    """Main function to prepare bird detection training data."""
+    parser = argparse.ArgumentParser(
+        description="Prepare bird detection training data from multiple sources"
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        required=True,
+        help="Output directory for prepared data (images and CSV files)",
+    )
+    parser.add_argument(
+        "--test_size",
+        type=float,
+        default=0.1,
+        help="Fraction of data to use for testing (default: 0.1)",
+    )
+    parser.add_argument(
+        "--num_blank_images",
+        type=int,
+        default=100,
+        help="Number of blank white images to generate (default: 100)",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for train/test split (default: 42)",
+    )
+
+    args = parser.parse_args()
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    print("Loading annotations from multiple sources...")
+    all_annotations = []
+    image_files_map = {}  # Map from original path to symlink name
+
+    # Load annotations from all sources
+    for source_path in DATA_SOURCES:
+        print(f"\nProcessing source: {source_path}")
+        df = load_annotations_from_source(source_path)
+        df = map_labels_to_bird(df)
+
+        if df.empty:
+            print(f"    No annotations after filtering for {source_path}")
+            continue
+
+        # Get root directory for images
+        root_dir = df.root_dir if hasattr(df, 'root_dir') else os.path.dirname(source_path)
+
+        # Special case: Drones for Ducks images are in /images subdirectory
+        if "drones_for_ducks" in source_path.lower():
+            root_dir = os.path.join(root_dir, "images")
+
+        # Ensure required columns exist
+        required_cols = ["image_path", "xmin", "ymin", "xmax", "ymax", "label"]
+        missing_cols = [col for col in required_cols if col not in df.columns]
+        if missing_cols:
+            print(f"    Warning: Missing columns {missing_cols}, skipping...")
+            continue
+
+        # Special case: Izembek dataset - split into 2000-pixel crops
+        if "izembek" in source_path.lower():
+            print("    Using split_raster to create 2000-pixel crops with allow_empty=False")
+            df = process_izembek_with_splitting(df, root_dir, args.output_dir, image_files_map)
+            if df.empty:
+                print(f"    No crop annotations generated for {source_path}")
+                continue
+            all_annotations.append(df)
+            print(f"    Loaded {len(df)} crop annotations from {df['image_path'].nunique()} crop images")
+            continue
+
+        # Handle image paths - create symlinks
+        unique_images = df["image_path"].unique()
+        for img_path in unique_images:
+            # Construct full source path
+            if os.path.isabs(img_path):
+                source_img_path = img_path
+            else:
+                source_img_path = os.path.join(root_dir, img_path)
+
+            if not os.path.exists(source_img_path):
+                # Try alternative locations
+                alt_paths = [
+                    os.path.join(root_dir, os.path.basename(img_path)),
+                    os.path.join(os.path.dirname(source_path), img_path),
+                ]
+                found = False
+                for alt_path in alt_paths:
+                    if os.path.exists(alt_path):
+                        source_img_path = alt_path
+                        found = True
+                        break
+                if not found:
+                    print(f"    Warning: Image not found: {source_img_path}")
+                    continue
+
+            # Create unique symlink name
+            img_basename = os.path.basename(img_path)
+            symlink_name = img_basename
+            counter = 1
+            while symlink_name in image_files_map.values():
+                name, ext = os.path.splitext(img_basename)
+                symlink_name = f"{name}_{counter}{ext}"
+                counter += 1
+
+            # Create symlink
+            target_path = os.path.join(args.output_dir, symlink_name)
+            try:
+                create_symlink(source_img_path, target_path)
+                image_files_map[img_path] = symlink_name
+            except Exception as e:
+                print(f"    Warning: Failed to create symlink for {img_path}: {e}")
+                continue
+
+            # Update image paths in dataframe to use symlink name
+            df.loc[df["image_path"] == img_path, "image_path"] = symlink_name
+
+        all_annotations.append(df)
+        print(f"    Loaded {len(df)} annotations from {len(unique_images)} images")
+
+    if not all_annotations:
+        raise ValueError("No annotations were loaded from any source!")
+
+    # Combine all annotations
+    combined_df = pd.concat(all_annotations, ignore_index=True)
+
+    # Check for negative coordinates before clipping
+    print("\nChecking for negative bounding box coordinates...")
+    negative_coords_df = check_negative_coordinates(combined_df)
+    if not negative_coords_df.empty:
+        print(f"    Found {len(negative_coords_df)} annotations with negative coordinates")
+        print(f"    Affected images: {negative_coords_df['image_path'].nunique()}")
+        print("\n    Summary of negative coordinates:")
+        print(f"      xmin < 0: {(combined_df['xmin'] < 0).sum()}")
+        print(f"      ymin < 0: {(combined_df['ymin'] < 0).sum()}")
+        print(f"      xmax < 0: {(combined_df['xmax'] < 0).sum()}")
+        print(f"      ymax < 0: {(combined_df['ymax'] < 0).sum()}")
+
+        # Clip boxes to image boundaries
+        print("\nClipping bounding boxes to image boundaries...")
+        combined_df = clip_boxes_to_image_bounds(combined_df, args.output_dir)
+
+        # Verify clipping worked
+        negative_after = check_negative_coordinates(combined_df)
+        if negative_after.empty:
+            print("    All negative coordinates have been clipped.")
+        else:
+            print(f"    Warning: {len(negative_after)} annotations still have negative coordinates after clipping")
+    else:
+        print("    No negative coordinates found.")
+
+    # Filter out boxes with zero or single-pixel area
+    print("\nFiltering out boxes with zero or single-pixel area...")
+    initial_count = len(combined_df)
+    combined_df = filter_small_boxes(combined_df, min_area=1)
+    removed_count = initial_count - len(combined_df)
+    if removed_count > 0:
+        print(f"    Removed {removed_count} boxes (kept {len(combined_df)} boxes)")
+
+    # Add blank images
+    print(f"\nGenerating {args.num_blank_images} blank white images...")
+    blank_df = create_blank_images(args.output_dir, args.num_blank_images)
+    combined_df = pd.concat([combined_df, blank_df], ignore_index=True)
+
+    print(f"\nTotal annotations: {len(combined_df)}")
+    print(f"Total unique images: {combined_df['image_path'].nunique()}")
+
+    # Save CSV files
+    train_csv = os.path.join(args.output_dir, "train.csv")
+    test_csv = os.path.join(args.output_dir, "test.csv")
+
+    # Ensure required columns are present and in correct order
+    required_cols = ["image_path", "xmin", "ymin", "xmax", "ymax", "label"]
+
+    # If existing test dataset is provided, skip split and put all new data in train
+    if EXISTING_TEST and os.path.exists(EXISTING_TEST):
+        print(f"\nUsing existing test dataset: {EXISTING_TEST}")
+        print("Putting all new data in training set...")
+        train_df = combined_df[required_cols].copy()
+        train_df.to_csv(train_csv, index=False)
+        print(f"\nSaved training annotations: {train_csv} ({len(train_df)} annotations, {train_df['image_path'].nunique()} images)")
+    else:
+        # Split into train/test by image_path (to avoid data leakage)
+        print(f"\nSplitting into train/test ({1 - args.test_size:.0%}/{args.test_size:.0%})...")
+        unique_images = combined_df["image_path"].unique()
+        train_images, test_images = train_test_split(
+            unique_images, test_size=args.test_size, random_state=args.seed
+        )
+
+        train_df = combined_df[combined_df["image_path"].isin(train_images)].copy()
+        test_df = combined_df[combined_df["image_path"].isin(test_images)].copy()
+
+        train_df = train_df[required_cols]
+        test_df = test_df[required_cols]
+
+        train_df.to_csv(train_csv, index=False)
+        test_df.to_csv(test_csv, index=False)
+
+        print(f"\nSaved training annotations: {train_csv} ({len(train_df)} annotations, {len(train_images)} images)")
+        print(f"Saved test annotations: {test_csv} ({len(test_df)} annotations, {len(test_images)} images)")
+
+    print(f"\nOutput directory: {args.output_dir}")
+    print("\nData preparation complete!")
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/src/deepforest/scripts/push_bird_model_to_hf.py b/src/deepforest/scripts/push_bird_model_to_hf.py
new file mode 100644
index 000000000..b7c243923
--- /dev/null
+++ b/src/deepforest/scripts/push_bird_model_to_hf.py
@@ -0,0 +1,87 @@
+"""Push trained bird detection model to HuggingFace Hub via PR.
+
+This script loads a trained model checkpoint and creates a pull request
+on HuggingFace Hub to update the weecology/deepforest-bird model.
+
+Example usage:
+    python push_bird_model_to_hf.py --checkpoint path/to/checkpoint.ckpt
+"""
+
+import argparse
+import os
+from pathlib import Path
+
+from dotenv import load_dotenv
+from huggingface_hub import login
+
+from deepforest import main
+
+
+def run():
+    """Main function to push model to HuggingFace via PR."""
+    parser = argparse.ArgumentParser(
+        description="Push trained bird model to HuggingFace Hub via PR"
+    )
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the model checkpoint file (.ckpt)",
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default="weecology/deepforest-bird",
+        help="HuggingFace repository ID (default: weecology/deepforest-bird)",
+    )
+    parser.add_argument(
+        "--commit-message",
+        type=str,
+        default="Update model weights",
+        help="Commit message for the PR",
+    )
+
+    args = parser.parse_args()
+
+    # Load HF token from .env
+    load_dotenv()
+    hf_token = os.getenv("HF_TOKEN")
+    if not hf_token:
+        raise ValueError(
+            "HF_TOKEN not found in .env file. Please add your HuggingFace token to .env"
+        )
+
+    # Login to HuggingFace
+    login(token=hf_token)
+
+    # Verify checkpoint exists
+    checkpoint_path = Path(args.checkpoint)
+    if not checkpoint_path.exists():
+        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+    print(f"Loading model from checkpoint: {checkpoint_path}")
+    # Load model from checkpoint
+    model = main.deepforest.load_from_checkpoint(str(checkpoint_path))
+
+    # Ensure label_dict is set (should be loaded from checkpoint, but verify)
+    if not hasattr(model, "label_dict") or model.label_dict is None:
+        print("Warning: label_dict not found in checkpoint, setting default Bird label")
+        model.label_dict = {"Bird": 0}
+        model.numeric_to_label_dict = {0: "Bird"}
+
+    print(f"Model loaded with label_dict: {model.label_dict}")
+
+    # Push to HuggingFace Hub - this will automatically create a PR
+    print(f"Pushing model to {args.repo_id} and creating PR...")
+    model.model.push_to_hub(
+        args.repo_id,
+        commit_message=args.commit_message,
+        create_pr=True,
+    )
+
+    print(f"\nSuccessfully created PR to update {args.repo_id}!")
+
+
+if __name__ == "__main__":
+    run()
+
diff --git a/src/deepforest/scripts/submit_train_birds.sh b/src/deepforest/scripts/submit_train_birds.sh
new file mode 100755
index 000000000..fa1e729f4
--- /dev/null
+++ b/src/deepforest/scripts/submit_train_birds.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+#SBATCH --job-name=train_birds    # Job name
+#SBATCH --mail-type=END           # Mail events
+#SBATCH --mail-user=benweinstein2010@gmail.com  # Where to send mail
+#SBATCH --account=ewhite
+#SBATCH --nodes=1                 # Number of MPI ranks
+#SBATCH --cpus-per-task=10
+#SBATCH --mem=200GB
+#SBATCH --time=48:00:00           # Time limit hrs:min:sec
+#SBATCH --output=/home/b.weinstein/logs/train_birds%j.out  # Standard output and error log
+#SBATCH --error=/home/b.weinstein/logs/train_birds%j.err
+#SBATCH --partition=hpg-b200
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus=1
+
+# Example usage:
+# First prepare the data:
+# uv run python src/deepforest/scripts/prepare_birds.py --output_dir /blue/ewhite/b.weinstein/bird_detector_retrain/data/
+
+srun uv run python src/deepforest/scripts/train_birds.py \
+    --data_dir /blue/ewhite/b.weinstein/bird_detector_retrain/data/ \
+    --batch_size 32 \
+    --workers 10 \
+    --epochs 40
\ No newline at end of file
diff --git a/src/deepforest/scripts/train_birds.py b/src/deepforest/scripts/train_birds.py
new file mode 100644
index 000000000..30ba79839
--- /dev/null
+++ b/src/deepforest/scripts/train_birds.py
@@ -0,0 +1,248 @@
+"""Train DeepForest bird detection model.
+
+This script trains a bird detection model using the weecology/deepforest-tree
+pretrained model as a starting point.
+
+Example usage:
+    python train_birds.py --data_dir /path/to/prepared/data --batch_size 12 --workers 5
+"""
+
+import argparse
+import os
+
+import pandas as pd
+import torch
+from pytorch_lightning.loggers import CometLogger
+from omegaconf import OmegaConf
+
+from deepforest import main, callbacks
+
+
+def run():
+    """Main training function."""
+    parser = argparse.ArgumentParser(description="Train DeepForest bird detection model")
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        required=True,
+        help="Directory containing train.csv, test.csv, and images",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=12,
+        help="Batch size for training (default: 12)",
+    )
+    parser.add_argument(
+        "--workers",
+        type=int,
+        default=5,
+        help="Number of workers for data loading (default: 5)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=12,
+        help="Number of training epochs (default: 12)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=0.001,
+        help="Learning rate (default: 0.001)",
+    )
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=str,
+        default=None,
+        help="Directory to save model checkpoints (default: data_dir/checkpoints)",
+    )
+    parser.add_argument(
+        "--fast_dev_run",
+        action="store_true",
+        help="Run a fast development run with a single batch",
+    )
+
+    args = parser.parse_args()
+
+    # Set matmul precision to high for faster training on Tensor Core GPUs
+    if torch.cuda.is_available():
+        torch.set_float32_matmul_precision('high')
+        print("Set torch.float32_matmul_precision to 'high' for faster training")
+
+    # Set up paths
+    train_csv = os.path.join(args.data_dir, "train.csv")
+    test_csv = os.path.join(args.data_dir, "test.csv")
+
+    if not os.path.exists(train_csv):
+        raise FileNotFoundError(f"Training CSV not found: {train_csv}")
+    if not os.path.exists(test_csv):
+        raise FileNotFoundError(f"Test CSV not found: {test_csv}")
+
+    if args.checkpoint_dir is None:
+        checkpoint_dir = os.path.join(args.data_dir, "checkpoints")
+    else:
+        checkpoint_dir = args.checkpoint_dir
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    print("Initializing DeepForest model...")
+    # Initialize DeepForest model
+    m = main.deepforest()
+
+    # Load the pretrained tree model as a starting point
+    print("Loading pretrained tree model: weecology/deepforest-tree")
+    m.load_model("weecology/deepforest-tree")
+
+    # Set label dictionaries for single "Bird" class
+    m.label_dict = {"Bird": 0}
+    m.numeric_to_label_dict = {0: "Bird"}
+    m.config.label_dict = {"Bird": 0}
+    m.config.num_classes = 1
+
+    m.config.score_thresh = 0.25
+    m.model.score_thresh = 0.25
+
+    # Configure training data paths
+    m.config["train"]["csv_file"] = train_csv
+    m.config["train"]["root_dir"] = args.data_dir
+    m.config["train"]["fast_dev_run"] = args.fast_dev_run
+    m.config["train"]["epochs"] = args.epochs
+    m.config["train"]["lr"] = args.lr
+    m.config["train"]["scheduler"]["params"]["patience"] = 3
+
+    # Configure validation data paths
+    m.config["validation"]["csv_file"] = test_csv
+    m.config["validation"]["root_dir"] = args.data_dir
+    m.config["validation"]["val_accuracy_interval"] = 1
+    m.config["validation"]["size"] = 800
+
+    # Configure data loading
+    m.config["batch_size"] = args.batch_size
+    m.config["workers"] = args.workers
+
+    # Configure augmentations with modern options:
+    # zoom (RandomResizedCrop), rotations, flips, and padding.
+    # Use OmegaConf merge with struct mode disabled to bypass strict type validation
+    augmentations_config = OmegaConf.create({
+        "train": {
+            "augmentations": [
+                {"RandomResizedCrop": {"size": (800, 800), "scale": (0.3, 1.0), "p": 0.5}},
+                {"Rotate": {"degrees": 15, "p": 0.5}},
+                {"HorizontalFlip": {"p": 0.5}},
+                {"VerticalFlip": {"p": 0.3}},
+                {"PadIfNeeded": {"size": (1000, 1000)}},
+                # {"RandomBrightnessContrast": {"brightness": 0.2, "contrast": 0.2, "p": 0.5}},
+                # {"HueSaturationValue": {"hue": 0.1, "saturation": 0.1, "p": 0.3}},
+                # {"ZoomBlur": {"max_factor": (1.0, 1.03), "step_factor": (0.01, 0.02), "p": 0.3}},
+            ]
+        }
+    })
+    OmegaConf.set_struct(m.config, False)
+    m.config = OmegaConf.merge(m.config, augmentations_config)
+    OmegaConf.set_struct(m.config, True)
+
+    # Configure scheduler (similar to BOEM script)
+    m.config["train"]["scheduler"]["params"]["eps"] = 0
+
+    # Count available devices before logger setup so training works without Comet
+    devices = torch.cuda.device_count() if torch.cuda.is_available() else 1
+
+    # Set up Comet logger (optional, will skip if not configured)
+    comet_logger = None
+    try:
+        comet_logger = CometLogger()
+        comet_logger.experiment.add_tag("bird-detection")
+
+        # Log training and test set sizes
+        train_df = pd.read_csv(train_csv)
+        test_df = pd.read_csv(test_csv)
+        comet_logger.experiment.log_table("train.csv", train_df)
+        comet_logger.experiment.log_table("test.csv", test_df)
+
+        # Log training parameters
+        comet_logger.experiment.log_parameter("devices", devices)
+        comet_logger.experiment.log_parameter("workers", m.config["workers"])
+        comet_logger.experiment.log_parameter("batch_size", m.config["batch_size"])
+        comet_logger.experiment.log_parameter("train_size", len(train_df))
+        comet_logger.experiment.log_parameter("test_size", len(test_df))
+        comet_logger.experiment.log_parameter("epochs", args.epochs)
+        comet_logger.experiment.log_parameter("learning_rate", args.lr)
+
+        print(f"Comet logging enabled: {comet_logger.experiment.get_key()}")
+    except Exception as e:
+        print(f"Warning: Could not initialize Comet logger: {e}")
+        print("Continuing without Comet logging...")
+        comet_logger = None
+
+    # Set up image callback for validation visualization
+    images_dir = os.path.join(checkpoint_dir, "images")
+    os.makedirs(images_dir, exist_ok=True)
+    im_callback = callbacks.ImagesCallback(
+        save_dir=images_dir,
+        prediction_samples=20,  # Number of validation images to log
+        dataset_samples=20,  # Number of dataset samples to log at start
+        every_n_epochs=1,  # Log predictions every epoch
+    )
+
+    # Create trainer with GPU support
+    print("Creating trainer...")
+    # For DDP, each process uses 1 device; PyTorch Lightning handles process spawning
+    m.create_trainer(
+        logger=comet_logger,
+        callbacks=[im_callback],
+        devices=devices,
+        strategy="ddp",
+        precision="16-mixed",  # Use mixed precision training for faster performance
+        fast_dev_run=args.fast_dev_run,
+        enable_progress_bar=True,
+    )
+
+    # Train the model
+    print("\nStarting training...")
+    m.trainer.fit(m)
+    m.trainer.validate(m)
+
+    # Save the model checkpoint (fall back to a fixed name if Comet is unavailable)
+    checkpoint_name = comet_logger.experiment.id if comet_logger else "bird_model"
+    checkpoint_path = os.path.join(checkpoint_dir, f"{checkpoint_name}.ckpt")
+    print(f"\nSaving checkpoint to: {checkpoint_path}")
+    m.trainer.save_checkpoint(checkpoint_path)
+
+    # Evaluate on zero-shot dataset
+    print("\n" + "=" * 80)
+    print("Evaluating on zero-shot dataset (Deepwater Horizon)")
+    print("=" * 80)
+
+    # Update validation config for zero-shot dataset
+    m.config.validation.csv_file = "/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated/test_splits/test_split_patch_600.csv"
+    m.config.validation.root_dir = "/blue/ewhite/b.weinstein/bird_detector_retrain/zero_shot/avian_images_annotated/test_splits/patch_600"
+    m.config.validation.iou_threshold = 0.4
+
+    # Create new trainer for zero-shot evaluation
+    m.create_trainer()
+
+    # Evaluate on zero-shot dataset
+    zero_shot_results = m.trainer.validate(m)
+    zero_shot_metrics = zero_shot_results[0] if zero_shot_results else {}
+
+    print("\nZero-shot evaluation results:")
+    print(f"    Box Precision: {zero_shot_metrics.get('box_precision', 'N/A')}")
+    print(f"    Box Recall: {zero_shot_metrics.get('box_recall', 'N/A')}")
+    print(f"    Empty Frame Accuracy: {zero_shot_metrics.get('empty_frame_accuracy', 'N/A')}")
+
+    # Log the zero-shot evaluation results to the Comet logger
+    if comet_logger:
+        comet_logger.experiment.log_metric("zero_shot_box_precision", zero_shot_metrics.get('box_precision'))
+        comet_logger.experiment.log_metric("zero_shot_box_recall", zero_shot_metrics.get('box_recall'))
+        comet_logger.experiment.log_metric("zero_shot_empty_frame_accuracy", zero_shot_metrics.get('empty_frame_accuracy'))
+
+        # Log global steps
+        comet_logger.experiment.log_metric("global_steps", m.trainer.global_step)
+
+    print("\nTraining complete!")
+
+
+if __name__ == "__main__":
+    run()
+