Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/deepforest/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ def get_transform(
aug_transform = _create_augmentation(aug_name, aug_params)
transforms_list.append(aug_transform)

# PadTo must be the last entry, if present:
for i, t in enumerate(transforms_list):
if isinstance(t, K.PadTo):
transforms_list.append(transforms_list.pop(i))
break

# Create a sequential container for all transforms
return K.AugmentationSequential(
*transforms_list, data_keys=[DataKey.IMAGE, DataKey.BBOX_XYXY]
Expand Down
24 changes: 16 additions & 8 deletions src/deepforest/datasets/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,25 +151,25 @@ def _validate_coordinates(self):
if errors:
raise ValueError("\n".join(errors))

def filter_boxes(self, boxes, labels, image_shape, min_size=1):
def filter_boxes(self, boxes, labels, width, height, min_size=1):
"""Clamp boxes to image bounds and filter by minimum dimension.

Args:
boxes (torch.Tensor): Bounding boxes of shape (N, 4) in xyxy format.
labels (torch.Tensor): Labels of shape (N,).
image_shape (tuple): Image shape as (C, H, W).
width (int): Image width in pixels.
height (int): Image height in pixels.
min_size (int): Minimum box width/height in pixels. Defaults to 1.

Returns:
tuple: A tuple of (filtered_boxes, filtered_labels)
"""
_, H, W = image_shape

# Clamp boxes to image bounds
boxes[:, 0] = torch.clamp(boxes[:, 0], min=0, max=W) # x1
boxes[:, 1] = torch.clamp(boxes[:, 1], min=0, max=H) # y1
boxes[:, 2] = torch.clamp(boxes[:, 2], min=0, max=W) # x2
boxes[:, 3] = torch.clamp(boxes[:, 3], min=0, max=H) # y2
boxes[:, 0] = torch.clamp(boxes[:, 0], min=0, max=width) # x1
boxes[:, 1] = torch.clamp(boxes[:, 1], min=0, max=height) # y1
boxes[:, 2] = torch.clamp(boxes[:, 2], min=0, max=width) # x2
boxes[:, 3] = torch.clamp(boxes[:, 3], min=0, max=height) # y2

# Filter boxes with minimum size
width = boxes[:, 2] - boxes[:, 0]
Expand Down Expand Up @@ -264,7 +264,15 @@ def __getitem__(self, idx):
labels = torch.from_numpy(targets["labels"].astype(np.int64))

# Filter invalid boxes after augmentation
boxes, labels = self.filter_boxes(boxes, labels, image.shape)
# Since augmentation may change the image size, we take the smaller of the
# source and transformed dimensions; this assumes that padding is always the
# last operation in the pipeline.
boxes, labels = self.filter_boxes(
boxes,
labels,
width=min(image_tensor.shape[3], image.shape[2]),
height=min(image_tensor.shape[2], image.shape[1]),
)

# Edge case if all labels were augmented away, keep the image
if len(boxes) == 0:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,9 @@ def test_filter_boxes():
[0., 0., 0., 10.], # Too small
])
labels = torch.tensor([0, 1, 2, 3, 4, 5])
image_shape = (3, 200, 200)

dataset = BoxDataset.__new__(BoxDataset)
filtered_boxes, filtered_labels = dataset.filter_boxes(boxes, labels, image_shape)
filtered_boxes, filtered_labels = dataset.filter_boxes(boxes, labels, width=200, height=200, min_size=1)

assert filtered_boxes.shape[0] == 3
assert torch.equal(filtered_labels, torch.tensor([2, 3, 4]))
Expand Down