From da65138a6c852c987075e152ab735727642973d6 Mon Sep 17 00:00:00 2001 From: Muhammad Saqlain <2mesaqlain@gmail.com> Date: Fri, 26 Dec 2025 01:09:49 +0500 Subject: [PATCH 1/5] Fix: Relax augmentation schema to allow list of dicts --- src/deepforest/conf/schema.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index 4e66e6394..1bdfff274 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field - from omegaconf import DictConfig, OmegaConf - +from typing import Any @dataclass class ModelConfig: @@ -67,7 +66,7 @@ class TrainConfig: epochs: int = 1 fast_dev_run: bool = False preload_images: bool = False - augmentations: list[str] | None = field(default_factory=lambda: ["HorizontalFlip"]) + augmentations: list[Any] | None = field(default_factory=lambda: ["HorizontalFlip"]) @dataclass @@ -86,7 +85,7 @@ class ValidationConfig: iou_threshold: float = 0.4 val_accuracy_interval: int = 20 lr_plateau_target: str = "val_loss" - augmentations: list[str] | None = field(default_factory=lambda: []) + augmentations: list[Any] | None = field(default_factory=lambda: []) @dataclass From 40d05da272a81d41757fcda0e84aef10ae5a8a63 Mon Sep 17 00:00:00 2001 From: Muhammad Saqlain <2mesaqlain@gmail.com> Date: Fri, 26 Dec 2025 01:25:33 +0500 Subject: [PATCH 2/5] Style: Apply automated formatting fixes --- src/deepforest/conf/schema.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/deepforest/conf/schema.py b/src/deepforest/conf/schema.py index 1bdfff274..50f9e6ddd 100644 --- a/src/deepforest/conf/schema.py +++ b/src/deepforest/conf/schema.py @@ -1,7 +1,9 @@ from dataclasses import dataclass, field -from omegaconf import DictConfig, OmegaConf from typing import Any +from omegaconf import DictConfig, OmegaConf + + @dataclass class ModelConfig: """Model configuration that defines the repository ID on HuggingFace and From f768cdec8d3dba7940ab6a2ceff1e481a0e99f2b Mon Sep 17 00:00:00 2001 From: Muhammad Saqlain <2mesaqlain@gmail.com> Date: Thu, 8 Jan 2026 05:09:16 +0500 Subject: [PATCH 3/5] Test: Add verification for dictionary-based augmentation config --- tests/test_augmentations.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 3e5ae86fa..47df66527 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -347,6 +347,22 @@ def test_geometric_augmentation_filters_boxes(): # Labels should match box count assert len(labels) == len(boxes), f"Label count mismatch: {len(labels)} labels, {len(boxes)} boxes" +def test_augmentation_schema_validation(m): + """ + Test that the schema accepts a list of dictionaries for augmentations. + """ + # Define the complex dictionary + complex_augmentations = [ + {"RandomResizedCrop": {"size": (800, 800), "scale": (0.5, 1.0), "p": 0.3}}, + {"Rotate": {"degrees": 15, "p": 0.5}}, + {"HorizontalFlip": {"p": 0.5}} + ] + + # assigning it to the config + m.config.train.augmentations = complex_augmentations + + # Verifying that it was stored correctly + assert m.config.train.augmentations == complex_augmentations if __name__ == "__main__": pytest.main([__file__]) From 1448db0714b3b0ff5c435d9119c176b9b281ef87 Mon Sep 17 00:00:00 2001 From: Muhammad Saqlain <2mesaqlain@gmail.com> Date: Thu, 8 Jan 2026 05:39:12 +0500 Subject: [PATCH 4/5] Test: Update schema validation to use config_args during init --- tests/test_augmentations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 47df66527..eb1c5a137 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -347,7 +347,7 @@ def test_geometric_augmentation_filters_boxes(): # Labels should match box count assert len(labels) == len(boxes), f"Label count mismatch: {len(labels)} labels, {len(boxes)} boxes" -def test_augmentation_schema_validation(m): +def test_augmentation_schema_validation(): """ Test that the schema accepts a list of dictionaries for augmentations. """ @@ -358,10 +358,10 @@ def test_augmentation_schema_validation(m): {"HorizontalFlip": {"p": 0.5}} ] - # assigning it to the config - m.config.train.augmentations = complex_augmentations + # Pass it directly to the model constructor via config_args + m = main.deepforest(config_args={"train": {"augmentations": complex_augmentations}}) - # Verifying that it was stored correctly + # Verify it was stored correctly assert m.config.train.augmentations == complex_augmentations if __name__ == "__main__": From d375f0db4c34bcb5aa7d5af72004b4951e10ec10 Mon Sep 17 00:00:00 2001 From: Muhammad Saqlain <2mesaqlain@gmail.com> Date: Sat, 10 Jan 2026 00:15:12 +0500 Subject: [PATCH 5/5] Refactor: Robust mixed-type testing and revert schema to List[Any] due to OmegaConf limitation --- tests/test_augmentations.py | 45 +++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index eb1c5a137..f1a9ebf52 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -41,6 +41,34 @@ def test_load_dataset_without_augmentations(): image, target, path = next(iter(train_ds)) assert len(train_ds.dataset.transform) == 0 +def test_augmentation_schema_validation(): + """ + Test that the schema accepts a mixed list of dictionaries and strings for augmentations, + and that they are correctly applied to the dataset pipeline. + """ + augmentations = [ + {"RandomResizedCrop": {"size": (800, 800), "scale": (0.5, 1.0), "p": 0.3}}, + "HorizontalFlip" + ] + + m = main.deepforest(config_args={"train": {"augmentations": augmentations}}) + + # Verify Schema stored it correctly + assert m.config.train.augmentations == augmentations + + csv_file = get_data("example.csv") + root_dir = os.path.dirname(csv_file) + + train_ds = m.load_dataset(csv_file, root_dir=root_dir, augmentations=augmentations) + + transforms = train_ds.dataset.transform + + has_resized_crop = any(isinstance(t, K.RandomResizedCrop) for t in transforms) + has_hflip = any(isinstance(t, K.RandomHorizontalFlip) for t in transforms) + + assert has_resized_crop + assert has_hflip + """ Augmentation parsing tests: """ @@ -347,22 +375,5 @@ def test_geometric_augmentation_filters_boxes(): # Labels should match box count assert len(labels) == len(boxes), f"Label count mismatch: {len(labels)} labels, {len(boxes)} boxes" -def test_augmentation_schema_validation(): - """ - Test that the schema accepts a list of dictionaries for augmentations. - """ - # Define the complex dictionary - complex_augmentations = [ - {"RandomResizedCrop": {"size": (800, 800), "scale": (0.5, 1.0), "p": 0.3}}, - {"Rotate": {"degrees": 15, "p": 0.5}}, - {"HorizontalFlip": {"p": 0.5}} - ] - - # Pass it directly to the model constructor via config_args - m = main.deepforest(config_args={"train": {"augmentations": complex_augmentations}}) - - # Verify it was stored correctly - assert m.config.train.augmentations == complex_augmentations - if __name__ == "__main__": pytest.main([__file__])