diff --git a/src/deepforest/augmentations.py b/src/deepforest/augmentations.py index a6577068b..92c23624a 100644 --- a/src/deepforest/augmentations.py +++ b/src/deepforest/augmentations.py @@ -244,6 +244,12 @@ def get_transform( aug_transform = _create_augmentation(aug_name, aug_params) transforms_list.append(aug_transform) + # PadIfNeeded (a K.PadTo instance) must be the last entry, if present: + for i, t in enumerate(transforms_list): + if isinstance(t, K.PadTo): + transforms_list.append(transforms_list.pop(i)) + break + # Create a sequential container for all transforms return K.AugmentationSequential( *transforms_list, data_keys=[DataKey.IMAGE, DataKey.BBOX_XYXY] diff --git a/src/deepforest/datasets/training.py b/src/deepforest/datasets/training.py index b3bdf085e..d1c448b49 100644 --- a/src/deepforest/datasets/training.py +++ b/src/deepforest/datasets/training.py @@ -151,25 +151,25 @@ def _validate_coordinates(self): if errors: raise ValueError("\n".join(errors)) - def filter_boxes(self, boxes, labels, image_shape, min_size=1): + def filter_boxes(self, boxes, labels, width, height, min_size=1): """Clamp boxes to image bounds and filter by minimum dimension. Args: boxes (torch.Tensor): Bounding boxes of shape (N, 4) in xyxy format. labels (torch.Tensor): Labels of shape (N,). - image_shape (tuple): Image shape as (C, H, W). + width (int): Image width in pixels. + height (int): Image height in pixels. min_size (int): Minimum box width/height in pixels. Defaults to 1. 
Returns: tuple: A tuple of (filtered_boxes, filtered_labels) """ - _, H, W = image_shape # Clamp boxes to image bounds - boxes[:, 0] = torch.clamp(boxes[:, 0], min=0, max=W) # x1 - boxes[:, 1] = torch.clamp(boxes[:, 1], min=0, max=H) # y1 - boxes[:, 2] = torch.clamp(boxes[:, 2], min=0, max=W) # x2 - boxes[:, 3] = torch.clamp(boxes[:, 3], min=0, max=H) # y2 + boxes[:, 0] = torch.clamp(boxes[:, 0], min=0, max=width) # x1 + boxes[:, 1] = torch.clamp(boxes[:, 1], min=0, max=height) # y1 + boxes[:, 2] = torch.clamp(boxes[:, 2], min=0, max=width) # x2 + boxes[:, 3] = torch.clamp(boxes[:, 3], min=0, max=height) # y2 # Filter boxes with minimum size width = boxes[:, 2] - boxes[:, 0] @@ -264,7 +264,15 @@ def __getitem__(self, idx): labels = torch.from_numpy(targets["labels"].astype(np.int64)) # Filter invalid boxes after augmentation - boxes, labels = self.filter_boxes(boxes, labels, image.shape) + # Since the augmentation operation may change image size, we take the smaller + # of the source and transformed dimensions, assuming that padding is always the + # last operation in the pipeline. + boxes, labels = self.filter_boxes( + boxes, + labels, + width=min(image_tensor.shape[3], image.shape[2]), + height=min(image_tensor.shape[2], image.shape[1]), + ) # Edge case if all labels were augmented away, keep the image if len(boxes) == 0: diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index 3e5ae86fa..91db2fd82 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -295,10 +295,9 @@ def test_filter_boxes(): [0., 0., 0., 10.], # Too small ]) labels = torch.tensor([0, 1, 2, 3, 4, 5]) - image_shape = (3, 200, 200) dataset = BoxDataset.__new__(BoxDataset) - filtered_boxes, filtered_labels = dataset.filter_boxes(boxes, labels, image_shape) + filtered_boxes, filtered_labels = dataset.filter_boxes(boxes, labels, width=200, height=200, min_size=1) assert filtered_boxes.shape[0] == 3 assert torch.equal(filtered_labels, torch.tensor([2, 3, 4]))