Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions ultralytics/models/sam/amg.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,16 +181,32 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if not small_regions:
sizes = stats[1:, -1] # Row 0 is background label

# Use np.flatnonzero for efficiency
small_regions_idx = np.flatnonzero(sizes < area_thresh)
if small_regions_idx.size == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
small_regions = small_regions_idx + 1

if correct_holes:
fill_labels = np.concatenate(([0], small_regions))
mask_out = np.isin(regions, fill_labels)
return mask_out, True
else:
# If every region is below threshold, keep largest
fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
fill_labels_set = set(range(n_labels)) - set(np.concatenate(([0], small_regions)))
if not fill_labels_set:
# All regions below threshold, keep largest
fill_labels = [int(np.argmax(sizes)) + 1]
else:
fill_labels = list(fill_labels_set)
# Use an efficient lookup if fill_labels is large
if len(fill_labels) == 1:
mask_out = regions == fill_labels[0]
else:
mask_out = np.isin(regions, fill_labels)
return mask_out, True


def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
Expand Down