diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index 4e66e6394..50f9e6ddd 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any from omegaconf import DictConfig, OmegaConf @@ -67,7 +68,7 @@ class TrainConfig: epochs: int = 1 fast_dev_run: bool = False preload_images: bool = False - augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"]) + augmentations: list[Any] | None = field(default_factory=lambda: ["HorizontalFlip"]) @dataclass @@ -86,7 +87,7 @@ class ValidationConfig: iou_threshold: float = 0.4 val_accuracy_interval: int = 20 lr_plateau_target: str = "val_loss" - augmentations: list[str] | None = field(default_factory=lambda: []) + augmentations: list[Any] | None = field(default_factory=lambda: []) @dataclass diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 3e5ae86fa..f1a9ebf52 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -41,6 +41,34 @@ def test_load_dataset_without_augmentations(): image, target, path = next(iter(train_ds)) assert len(train_ds.dataset.transform) == 0 +def test_augmentation_schema_validation(): + """ + Test that the schema accepts a mixed list of dictionaries and strings for augmentations, + and that they are correctly applied to the dataset pipeline. + """ + augmentations = [ + {"RandomResizedCrop": {"size": (800, 800), "scale": (0.5, 1.0), "p": 0.3}}, + "HorizontalFlip" + ] + + m = main.deepforest(config_args={"train": {"augmentations": augmentations}}) + + # Verify Schema stored it correctly + assert m.config.train.augmentations == augmentations + + csv_file = get_data("example.csv") + root_dir = os.path.dirname(csv_file) + + train_ds = m.load_dataset(csv_file, root_dir=root_dir, augmentations=augmentations) + + transforms = train_ds.dataset.transform + + has_resized_crop = any(isinstance(t, K.RandomResizedCrop) for t in transforms) + has_hflip = any(isinstance(t, K.RandomHorizontalFlip) for t in transforms) + + assert has_resized_crop + assert has_hflip + """ Augmentation parsing tests: """ @@ -347,6 +375,5 @@ def test_geometric_augmentation_filters_boxes(): # Labels should match box count assert len(labels) == len(boxes), f"Label count mismatch: {len(labels)} labels, {len(boxes)} boxes" - if __name__ == "__main__": pytest.main([__file__])