diff --git a/ultralytics/models/sam/amg.py b/ultralytics/models/sam/amg.py index e5c577c0bd2..e7d0b79f696 100644 --- a/ultralytics/models/sam/amg.py +++ b/ultralytics/models/sam/amg.py @@ -181,16 +181,32 @@ def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> Tup correct_holes = mode == "holes" working_mask = (correct_holes ^ mask).astype(np.uint8) n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) - sizes = stats[:, -1][1:] # Row 0 is background label - small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] - if not small_regions: + sizes = stats[1:, -1] # Row 0 is background label + + # Use np.flatnonzero for efficiency + small_regions_idx = np.flatnonzero(sizes < area_thresh) + if small_regions_idx.size == 0: return mask, False - fill_labels = [0] + small_regions - if not correct_holes: + small_regions = small_regions_idx + 1 + + if correct_holes: + fill_labels = np.concatenate(([0], small_regions)) + mask_out = np.isin(regions, fill_labels) + return mask_out, True + else: # If every region is below threshold, keep largest - fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1] - mask = np.isin(regions, fill_labels) - return mask, True + fill_labels_set = set(range(n_labels)) - set(np.concatenate(([0], small_regions))) + if not fill_labels_set: + # All regions below threshold, keep largest + fill_labels = [int(np.argmax(sizes)) + 1] + else: + fill_labels = list(fill_labels_set) + # Use an efficient lookup if fill_labels is large + if len(fill_labels) == 1: + mask_out = regions == fill_labels[0] + else: + mask_out = np.isin(regions, fill_labels) + return mask_out, True def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: