Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/deepforest/conf/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Any

from omegaconf import DictConfig, OmegaConf

Expand Down Expand Up @@ -67,7 +68,7 @@ class TrainConfig:
epochs: int = 1
fast_dev_run: bool = False
preload_images: bool = False
augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"])
augmentations: list[Any] | None = field(default_factory=lambda: ["HorizontalFlip"])


@dataclass
Expand All @@ -86,7 +87,7 @@ class ValidationConfig:
iou_threshold: float = 0.4
val_accuracy_interval: int = 20
lr_plateau_target: str = "val_loss"
augmentations: list[str] | None = field(default_factory=lambda: [])
augmentations: list[Any] | None = field(default_factory=lambda: [])


@dataclass
Expand Down
29 changes: 28 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,34 @@ def test_load_dataset_without_augmentations():
image, target, path = next(iter(train_ds))
assert len(train_ds.dataset.transform) == 0

def test_augmentation_schema_validation():
"""
Test that the schema accepts a mixed list of dictionaries and strings for augmentations,
and that they are correctly applied to the dataset pipeline.
"""
augmentations = [
{"RandomResizedCrop": {"size": (800, 800), "scale": (0.5, 1.0), "p": 0.3}},
"HorizontalFlip"
]

m = main.deepforest(config_args={"train": {"augmentations": augmentations}})

# Verify Schema stored it correctly
assert m.config.train.augmentations == augmentations

csv_file = get_data("example.csv")
root_dir = os.path.dirname(csv_file)

train_ds = m.load_dataset(csv_file, root_dir=root_dir, augmentations=augmentations)

transforms = train_ds.dataset.transform

has_resized_crop = any(isinstance(t, K.RandomResizedCrop) for t in transforms)
has_hflip = any(isinstance(t, K.RandomHorizontalFlip) for t in transforms)

assert has_resized_crop
assert has_hflip

"""
Augmentation parsing tests:
"""
Expand Down Expand Up @@ -347,6 +375,5 @@ def test_geometric_augmentation_filters_boxes():
# Labels should match box count
assert len(labels) == len(boxes), f"Label count mismatch: {len(labels)} labels, {len(boxes)} boxes"


if __name__ == "__main__":
pytest.main([__file__])