Binary file added docs/figures/bird_prediction_example_1.png
7 changes: 7 additions & 0 deletions docs/user_guide/02_prebuilt.md
@@ -34,6 +34,13 @@ The model was initially described in [Ecological Applications](https://esajourna
Using over 250,000 annotations from 13 projects from around the world, we develop a general bird detection model that achieves over 65% recall and 50% precision on novel aerial data without any local training despite differences in species, habitat, and imaging methodology. Fine-tuning this model with only 1000 local annotations increases these values to an average of 84% recall and 69% precision by building on the general features learned from other data sources.
>

The bird detection model has been updated and retrained from the original `weecology/deepforest-bird` model. The updated model was fine-tuned from the tree detection model (`weecology/deepforest-tree`) and trained on data from both Weinstein et al. 2022 and additional bird detection data from multiple sources, including https://lila.science/. The result is a dataset with over a million bird detections from around the world. Training details and metrics can be viewed on the [Comet dashboard](https://www.comet.com/bw4sz/bird-detector/6181df1ab7ac40f291b863a2a9b86024?&prevPath=%2Fbw4sz%2Fbird-detector%2Fview%2Fnew%2Fexperiments).

### Example Predictions

The following examples show predictions from the updated bird detection model:

![Bird Prediction Example 1](../figures/bird_prediction_example_1.png)

### Citation
> Weinstein, B.G., Garner, L., Saccomanno, V.R., Steinkraus, A., Ortega, A., Brush, K., Yenni, G., McKellar, A.E., Converse, R., Lippitt, C.D., Wegmann, A., Holmes, N.D., Edney, A.J., Hart, T., Jessopp, M.J., Clarke, R.H., Marchowski, D., Senyondo, H., Dotson, R., White, E.P., Frederick, P. and Ernest, S.K.M. (2022), A general deep learning model for bird detection in high resolution airborne imagery. Ecological Applications. Accepted Author Manuscript e2694. https://doi.org/10.1002/eap.2694
66 changes: 63 additions & 3 deletions docs/user_guide/07_scaling.md
@@ -19,6 +19,68 @@ For example on a SLURM cluster, we use the following line to get 5 gpus on a sin
m.create_trainer(logger=comet_logger, accelerator="gpu", strategy="ddp", num_nodes=1, devices=devices)
```

### Complete SLURM Example

Here's a complete, tested example for training on 2 GPUs with SLURM:

**SLURM submission script (`submit_train.sh`):**
```bash
#!/bin/bash
#SBATCH --job-name=train_model
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --gpus=2
#SBATCH --cpus-per-task=10
#SBATCH --mem=200GB
#SBATCH --time=48:00:00
#SBATCH --output=train_%j.out
#SBATCH --error=train_%j.err

# Use srun without explicit task/gpu flags - SLURM handles process spawning
# The --ntasks-per-node and --gpus directives above control resource allocation
srun uv run python train_script.py \
--data_dir /path/to/data \
--batch_size 24 \
--workers 10 \
--epochs 30
```

**Python training script:**
```python
import torch
from deepforest import main

# comet_logger is assumed to be configured elsewhere, e.g.
# from pytorch_lightning.loggers import CometLogger
# comet_logger = CometLogger(project_name="deepforest")

# Initialize model
m = main.deepforest()

# Configure training parameters
m.config["train"]["csv_file"] = "train.csv"
m.config["train"]["root_dir"] = "/path/to/data"
m.config["batch_size"] = 24
m.config["workers"] = 10

# Set devices to total GPU count - PyTorch Lightning DDP will handle
# device assignment per process when used with SLURM process spawning
devices = torch.cuda.device_count() if torch.cuda.is_available() else 0

m.create_trainer(
logger=comet_logger,
devices=devices,
strategy="ddp", # Distributed Data Parallel
fast_dev_run=False,
)

# Train the model
m.trainer.fit(m)
```

**Key points:**
- `--ntasks-per-node=2` in SBATCH directives spawns 2 processes (one per GPU)
- `--gpus=2` in SBATCH directives allocates 2 GPUs total
- Use plain `srun` without `--ntasks` or `--gpus-per-task` flags; let SLURM handle process spawning
- Set `devices=torch.cuda.device_count()` in Python (not `devices=1`); PyTorch Lightning DDP coordinates with SLURM's process management
- Batch size is per-GPU: with `batch_size=24` and 2 GPUs, you get 48 images per forward pass total
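
The per-GPU batch semantics in the last point come down to simple arithmetic; the function below is illustrative, not part of the DeepForest or Lightning APIs:

```python
# Under DDP every process (one per GPU) draws its own batch each step,
# so the effective global batch is the per-GPU size times the GPU count.
def effective_batch_size(per_gpu_batch_size: int, num_gpus: int) -> int:
    return per_gpu_batch_size * num_gpus

print(effective_batch_size(24, 2))  # 48 images per optimizer step on 2 GPUs
```

Remember this when tuning the learning rate: doubling the GPU count doubles the effective batch without any change to the config file.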

While we rarely use multi-node GPU training, PyTorch Lightning supports it at very large scales. We welcome users to share what configurations worked best.

A few notes that can trip up those less accustomed to multi-GPU training. These apply to the default configurations and may vary on a specific system. We use a large university SLURM cluster with 'ddp' distributed data parallel.
@@ -27,9 +89,7 @@ A few notes that can trip up those less used to multi-gpu training. These are fo

2. Each device gets its own portion of the dataset. This means that they do not interact during forward passes.

3. Make sure to use srun when combining with SLURM! This is an easy one to miss and will cause training to hang without error. Documented here

https://lightning.ai/docs/pytorch/latest/clouds/cluster_advanced.html#troubleshooting.
3. Make sure to use `srun` when combining with SLURM! This is critical for proper process spawning and will cause training to hang without error if omitted. Use `--ntasks-per-node` in SBATCH directives (not in srun) to control the number of processes. Documented [here](https://lightning.ai/docs/pytorch/latest/clouds/cluster_advanced.html#troubleshooting).
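
One way to catch the missing-`srun` hang early is an explicit check at the top of the training script. `SLURM_PROCID` is a standard variable that SLURM exports for every task launched via `srun`; the helper function itself is just an illustrative sketch:

```python
import os

def launched_via_srun() -> bool:
    """True when this process was spawned by srun, which exports
    SLURM_PROCID for each task it launches."""
    return "SLURM_PROCID" in os.environ

if not launched_via_srun():
    print("Warning: no SLURM task environment detected; "
          "DDP training may hang waiting for peer processes")
```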


## Prediction
Expand Down
36 changes: 25 additions & 11 deletions src/deepforest/callbacks.py
@@ -39,7 +39,7 @@ def __init__(
prediction_samples=2,
dataset_samples=5,
every_n_epochs=5,
select_random=False,
select_random=True,
color=None,
thickness=2,
):
@@ -58,6 +58,10 @@ def on_train_start(self, trainer, pl_module):
if trainer.fast_dev_run:
return

# Only run on rank 0 to avoid file I/O synchronization issues in DDP
if trainer.global_rank != 0:
return

self.trainer = trainer
self.pl_module = pl_module

@@ -77,6 +81,10 @@ def on_validation_end(self, trainer, pl_module):
if trainer.sanity_checking or trainer.fast_dev_run:
return

# Only run on rank 0 to avoid file I/O synchronization issues in DDP
if trainer.global_rank != 0:
return

if (trainer.current_epoch + 1) % self.every_n_epochs == 0:
pl_module.print("Logging prediction samples")
self._log_last_predictions(trainer, pl_module)
@@ -141,6 +149,10 @@ def _log_last_predictions(self, trainer, pl_module):
else:
df = pd.DataFrame()

# Skip logging if there are no predictions
if df.empty or "image_path" not in df.columns:
return

out_dir = os.path.join(self.savedir, "predictions")
os.makedirs(out_dir, exist_ok=True)

@@ -151,19 +163,21 @@
df["root_dir"] = dataset.root_dir

# Limit to n images, potentially randomly selected
unique_images = df.image_path.unique()
n_samples = min(self.prediction_samples, len(unique_images))
if n_samples == 0:
return
if self.select_random:
selected_images = np.random.choice(
df.image_path.unique(), self.prediction_samples
)
selected_images = np.random.choice(unique_images, n_samples, replace=False)
else:
selected_images = df.image_path.unique()[: self.prediction_samples]
selected_images = unique_images[:n_samples]

# Ensure color is correctly assigned
if self.color is None:
num_classes = len(df["label"].unique())
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
else:
results_color = self.color
# Ensure color is correctly assigned
if self.color is None:
num_classes = len(df["label"].unique())
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
else:
results_color = self.color

for image_name in selected_images:
pred_df = df[df.image_path == image_name]
2 changes: 1 addition & 1 deletion src/deepforest/conf/config.yaml
@@ -2,7 +2,7 @@

# Cpu workers for data loaders
# Dataloaders
workers: 0
workers: 5
devices: auto
accelerator: auto
batch_size: 1
21 changes: 14 additions & 7 deletions src/deepforest/datasets/training.py
@@ -78,7 +78,7 @@ def __init__(
self.preload_images = preload_images

self._validate_labels()
self._validate_coordinates()
#self._validate_coordinates()

# Pin data to memory if desired
if self.preload_images:
@@ -109,14 +109,21 @@ def _validate_coordinates(self):
ValueError: If any bounding box coordinate occurs outside the image
"""
errors = []
# Cache image dimensions to avoid opening the same image multiple times
image_dims = {}
for _idx, row in self.annotations.iterrows():
img_path = os.path.join(self.root_dir, row["image_path"])
try:
with Image.open(img_path) as img:
width, height = img.size
except Exception as e:
errors.append(f"Failed to open image {img_path}: {e}")
continue

# Get image dimensions (cached per unique image)
if img_path not in image_dims:
try:
with Image.open(img_path) as img:
image_dims[img_path] = img.size
except Exception as e:
errors.append(f"Failed to open image {img_path}: {e}")
continue

width, height = image_dims[img_path]

# Extract bounding box
geom = row["geometry"]
29 changes: 17 additions & 12 deletions src/deepforest/main.py
@@ -843,21 +843,26 @@ def on_validation_epoch_end(self):
self.log_epoch_metrics()

if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0:
if len(self.predictions) > 0:
predictions = pd.concat(self.predictions)
else:
predictions = pd.DataFrame()
# Only run evaluate on rank 0 to avoid file I/O synchronization issues in DDP
if self.trainer.global_rank == 0:
if len(self.predictions) > 0:
predictions = pd.concat(self.predictions)
else:
predictions = pd.DataFrame()

results = self.evaluate(
self.config.validation.csv_file,
root_dir=self.config.validation.root_dir,
size=self.config.validation.size,
predictions=predictions,
)
results = self.evaluate(
self.config.validation.csv_file,
root_dir=self.config.validation.root_dir,
size=self.config.validation.size,
predictions=predictions,
)

self.__evaluation_logs__(results)
self.__evaluation_logs__(results)

return results
return results
else:
# Other ranks return None - metrics are already logged via log_epoch_metrics
return None

def predict_step(self, batch, batch_idx):
"""Predict a batch of images with the deepforest model. If batch is a
6 changes: 3 additions & 3 deletions src/deepforest/preprocess.py
@@ -199,15 +199,15 @@ def split_raster(
)

# Convert from channels-last (H x W x C) to channels-first (C x H x W)
if numpy_image.shape[2] in [3, 4]:
if len(numpy_image.shape) == 3 and numpy_image.shape[2] in [3, 4]:
print(
f"Image shape is {numpy_image.shape[2]}, assuming this is channels last, "
"converting to channels first"
)
numpy_image = numpy_image.transpose(2, 0, 1)

# Check that it's 3 bands
bands = numpy_image.shape[2]
# Check that it's 3 bands (after transpose, shape is (C, H, W), so bands is shape[0])
bands = numpy_image.shape[0]
if not bands == 3:
warnings.warn(
f"Input image had non-3 band shape of {numpy_image.shape}, selecting first 3 bands",