diff --git a/GUI/controllers/AnnotationClusteringController.py b/GUI/controllers/AnnotationClusteringController.py
index 9f04cfc..1aca360 100644
--- a/GUI/controllers/AnnotationClusteringController.py
+++ b/GUI/controllers/AnnotationClusteringController.py
@@ -70,8 +70,13 @@ def _extract_from_image(self, img: dict, idx: int) -> list[Annotation]:
         if umap is None or logits is None or fname is None:
             return out
 
-        feats, coords = self._generator.generate_annotations(uncertainty_map=umap, logits=logits)
-        for c, f in zip(coords, feats):
+        result = self._generator.generate_annotations(uncertainty_map=umap, logits=logits)
+        if len(result) == 2:
+            feats, coords = result
+            masks = [None] * len(coords)
+        else:
+            feats, coords, masks = result
+        for c, f, m in zip(coords, feats, masks):
             if not f.any():
                 continue
             out.append(
@@ -84,6 +89,8 @@ def _extract_from_image(self, img: dict, idx: int) -> list[Annotation]:
                     uncertainty=float(umap[tuple(c)]),
                     cluster_id=None,
                     model_prediction=CLASS_COMPONENTS.get(int(np.argmax(f)), "None"),
+                    mask_rle=m,
+                    mask_shape=umap.shape[:2],
                 )
             )
         return out
diff --git a/GUI/controllers/ImageProcessingController.py b/GUI/controllers/ImageProcessingController.py
index d92cc05..19bc334 100644
--- a/GUI/controllers/ImageProcessingController.py
+++ b/GUI/controllers/ImageProcessingController.py
@@ -125,17 +125,20 @@ def _display_processed_annotations(self, annotations_data: List[dict]):
             anno = data['annotation']
             np_image = data['processed_crop']
             coord_pos = data['coord_pos']
+            mask_crop = data.get('mask_crop')
 
             if np_image is None or coord_pos is None:
                 logging.warning(f"Missing image or coords for annotation: {anno}")
                 continue
 
             q_pixmap = self._numpy_to_qpixmap(np_image)
+            mask_pix = self._numpy_to_qpixmap(mask_crop) if mask_crop is not None else None
 
             sampled_crops.append({
                 'annotation': anno,
                 'processed_crop': q_pixmap,
-                'coord_pos': coord_pos
+                'coord_pos': coord_pos,
+                'mask_pixmap': mask_pix,
             })
 
         self.crops_ready.emit(sampled_crops)
diff --git a/GUI/controllers/MainController.py b/GUI/controllers/MainController.py
index 61abac3..633a4d0 100644
--- a/GUI/controllers/MainController.py
+++ b/GUI/controllers/MainController.py
@@ -17,7 +17,8 @@
     LocalMaximaPointAnnotationGenerator,
     EquidistantPointAnnotationGenerator,
     CenterPointAnnotationGenerator,
 )
+from GUI.models.SuperpixelAnnotationGenerator import SLICSuperpixelGenerator
 from GUI.models.UncertaintyPropagator import propagate_for_annotations
 from GUI.models.export.Options import ExportOptions
 from GUI.models.export.Usecase import ExportAnnotationsUseCase
@@ -231,6 +232,9 @@ def on_label_generator_method_changed(self, method: str):
         elif method == "Image Centre":
             self.annotation_generator = CenterPointAnnotationGenerator()
             self._use_greedy_nav = False
+        elif method == "Superpixel Masks":
+            self.annotation_generator = SLICSuperpixelGenerator()
+            self._use_greedy_nav = True
         else:
             self.annotation_generator = LocalMaximaPointAnnotationGenerator(
                 filter_size=48, gaussian_sigma=4.0, use_gaussian=False
diff --git a/GUI/models/Annotation.py b/GUI/models/Annotation.py
index 7594d96..6da7f29 100644
--- a/GUI/models/Annotation.py
+++ b/GUI/models/Annotation.py
@@ -17,6 +17,8 @@ class Annotation:
     cluster_id: Optional[int] = None
     model_prediction: Optional[str] = None
    adjusted_uncertainty: Optional[Union[float, np.ndarray]] = None
+    mask_rle: Optional[list] = None
+    mask_shape: Optional[Tuple[int, int]] = None
 
     # ---------- core ---------------------------------------------------------------
     def __setattr__(self, name, value):
@@ -28,6 +30,28 @@ def __post_init__(self):
         if self.adjusted_uncertainty is None:
             self.adjusted_uncertainty = self.uncertainty
 
+    # ---------- mask helpers -------------------------------------------------
+    @staticmethod
+    def encode_mask(mask: np.ndarray) -> list:
+        """Return run-length encoding for ``mask``."""
+        pixels = mask.astype(np.uint8).flatten(order="F")
+        pixels = np.concatenate([[0], pixels, [0]])
+        runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
+        runs[1::2] -= runs[::2]
+        return runs.tolist()
+
+    @staticmethod
+    def decode_mask(rle: list, shape: Tuple[int, int]) -> np.ndarray:
+        """Decode RLE ``rle`` back to a binary mask of ``shape``."""
+        rle = np.asarray(rle, dtype=int)
+        starts = rle[0::2] - 1
+        lengths = rle[1::2]
+        ends = starts + lengths
+        img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
+        for s, e in zip(starts, ends):
+            img[s:e] = 1
+        return img.reshape(shape, order="F")
+
     # ---------- public helpers -----------------------------------------------------
     def reset_uncertainty(self) -> None:
         """Restore posterior to prior."""
@@ -56,6 +80,8 @@ def to_dict(self) -> dict:
                 int(self.cluster_id) if self.cluster_id is not None else None
             ),
             "model_prediction": self.model_prediction,
+            "mask_rle": self.mask_rle,
+            "mask_shape": list(self.mask_shape) if self.mask_shape else None,
         }
 
     @staticmethod
@@ -81,6 +107,10 @@ def from_dict(data: dict) -> "Annotation":
         cluster_id_val = int(cluster_id_raw) if cluster_id_raw is not None else None
         is_manual_val = data.get("is_manual", False)
 
+        mask_rle_val = data.get("mask_rle")
+        mask_shape_val = (
+            tuple(data.get("mask_shape")) if data.get("mask_shape") else None
+        )
 
         # --- construct -------------------------------------------------------------
         return Annotation(
@@ -94,4 +124,6 @@
             class_id=int(data.get("class_id", -1)),
             cluster_id=cluster_id_val,
             model_prediction=data.get("model_prediction", None),
+            mask_rle=mask_rle_val,
+            mask_shape=mask_shape_val,
         )
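The helpers above implement a column-major (order="F") run-length encoding with 1-indexed run starts. A quick round-trip sketch (illustrative; assumes numpy is installed and the GUI package is on the import path):

import numpy as np
from GUI.models.Annotation import Annotation

# tiny binary mask with a single 2x2 blob at rows 1-2, cols 2-3
mask = np.zeros((4, 5), dtype=np.uint8)
mask[1:3, 2:4] = 1

rle = Annotation.encode_mask(mask)       # -> [10, 2, 14, 2]: 1-indexed (start, length) pairs, column-major
restored = Annotation.decode_mask(rle, mask.shape)
assert np.array_equal(mask, restored)    # encode/decode is lossless for binary masks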
diff --git a/GUI/models/ImageProcessor.py b/GUI/models/ImageProcessor.py
index 64e5720..4c031c4 100644
--- a/GUI/models/ImageProcessor.py
+++ b/GUI/models/ImageProcessor.py
@@ -154,6 +154,11 @@ def create_annotation_overlay(
             if ann.class_id == -1:
                 continue
             colour = self.class_color_map.get(ann.class_id, (255, 255, 255))
+            if ann.mask_rle and ann.mask_shape:
+                mask = Annotation.decode_mask(ann.mask_rle, ann.mask_shape)
+                mimg = Image.fromarray(mask * 255).convert("L")
+                overlay = Image.new("RGBA", pil.size, colour + (100,))
+                pil.paste(overlay, mask=mimg)
             y, x = map(int, ann.coord)
             bbox = [x - radius, y - radius, x + radius, y + radius]
             draw.ellipse(bbox, fill=colour, outline=colour)
diff --git a/GUI/models/SuperpixelAnnotationGenerator.py b/GUI/models/SuperpixelAnnotationGenerator.py
new file mode 100644
index 0000000..9c1b726
--- /dev/null
+++ b/GUI/models/SuperpixelAnnotationGenerator.py
@@ -0,0 +1,57 @@
+import logging
+from typing import List, Tuple
+
+import numpy as np
+from skimage.segmentation import slic
+from skimage.measure import regionprops
+
+from .PointAnnotationGenerator import BasePointAnnotationGenerator
+from .Annotation import Annotation
+
+logger = logging.getLogger(__name__)
+
+
+class SLICSuperpixelGenerator(BasePointAnnotationGenerator):
+    """Generate superpixel regions ranked by uncertainty."""
+
+    def __init__(self, n_segments: int = 250, compactness: float = 0.1, edge_buffer: int = 64):
+        super().__init__(edge_buffer=edge_buffer)
+        self.n_segments = n_segments
+        self.compactness = compactness
+        logger.info(
+            "SLICSuperpixelGenerator(n_segments=%d, compactness=%.2f)",
+            n_segments,
+            compactness,
+        )
+
+    def generate_annotations(
+        self, uncertainty_map: np.ndarray, logits: np.ndarray
+    ) -> Tuple[np.ndarray, List[Tuple[int, int]], List[list]]:
+        map2d = self._prepare_uncertainty_map(uncertainty_map)
+        segments = slic(map2d, n_segments=self.n_segments, compactness=self.compactness, start_label=1)
+        props = regionprops(segments, intensity_image=map2d)
+        props = sorted(props, key=lambda p: p.mean_intensity, reverse=True)
+
+        coords: List[Tuple[int, int]] = []
+        feats: List[np.ndarray] = []
+        masks: List[list] = []
+        for p in props:
+            cy, cx = map(int, p.centroid)
+            if (
+                cy < self.edge_buffer
+                or cy >= map2d.shape[0] - self.edge_buffer
+                or cx < self.edge_buffer
+                or cx >= map2d.shape[1] - self.edge_buffer
+            ):
+                continue
+            mask = segments == p.label
+            coords.append((cy, cx))
+            feats.append(logits[mask].mean(axis=0))
+            masks.append(Annotation.encode_mask(mask))
+
+        if not coords:
+            logger.warning("SLICSuperpixelGenerator produced no annotations")
+            return np.empty((0, logits.shape[-1]), dtype=np.float32), [], []
+
+        return np.stack(feats).astype(np.float32), coords, masks
+
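The generator slots in where the point generators are used today, but returns a third value with per-segment RLE masks. An illustrative call with synthetic inputs; it assumes the GUI package is importable, that the base class's _prepare_uncertainty_map passes a 2-D H x W map through, and that logits are laid out H x W x C:

import numpy as np
from GUI.models.SuperpixelAnnotationGenerator import SLICSuperpixelGenerator

umap = np.random.rand(256, 256).astype(np.float32)       # synthetic uncertainty map (H, W)
logits = np.random.rand(256, 256, 5).astype(np.float32)  # synthetic per-pixel class logits (H, W, C)

gen = SLICSuperpixelGenerator(n_segments=100, compactness=0.1, edge_buffer=16)
feats, coords, masks = gen.generate_annotations(uncertainty_map=umap, logits=logits)

# feats: (N, C) mean logits per kept superpixel, highest-uncertainty segments first
# coords: list of (row, col) centroids, each at least edge_buffer pixels from the border
# masks: run-length encodings compatible with Annotation.decode_mask(rle, umap.shape)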
diff --git a/GUI/models/export/ExportService.py b/GUI/models/export/ExportService.py
index f998ca2..a09308c 100644
--- a/GUI/models/export/ExportService.py
+++ b/GUI/models/export/ExportService.py
@@ -53,6 +53,8 @@ def build_grouped_annotations(
                 "coord": [int(c) for c in anno.coord],
                 "class_id": int(anno.class_id),
                 "cluster_id": int(cluster_id),
+                "mask_rle": anno.mask_rle,
+                "mask_shape": list(anno.mask_shape) if anno.mask_shape else None,
             }
         )
 
diff --git a/GUI/models/io/Persistence.py b/GUI/models/io/Persistence.py
index e167c9d..fdcdeea 100644
--- a/GUI/models/io/Persistence.py
+++ b/GUI/models/io/Persistence.py
@@ -53,6 +53,8 @@ class AnnotationJSON(BaseModel):
     class_id: int
     cluster_id: Optional[int] = None
     model_prediction: Optional[str] = None
+    mask_rle: Optional[List[int]] = None
+    mask_shape: Optional[List[int]] = None
 
 
 class ProjectState(BaseModel):
diff --git a/GUI/unittests/test_annotation_overlay.py b/GUI/unittests/test_annotation_overlay.py
index 382db7f..a3230da 100644
--- a/GUI/unittests/test_annotation_overlay.py
+++ b/GUI/unittests/test_annotation_overlay.py
@@ -16,6 +16,8 @@ def test_create_annotation_overlay_draws_crosshair():
         logit_features=np.array([0.0]),
         uncertainty=0.5,
         class_id=1,
+        mask_rle=None,
+        mask_shape=None,
     )
     out = proc.create_annotation_overlay(img, [ann], radius=2, show_labels=False)
     arr = np.array(out.convert("RGB"))
@@ -36,6 +38,8 @@ def test_create_annotation_overlay_draws_label_box():
         logit_features=np.array([0.0]),
         uncertainty=0.5,
         class_id=1,
+        mask_rle=None,
+        mask_shape=None,
     )
     out = proc.create_annotation_overlay(
         img, [ann], radius=3, crosshair=False, show_labels=True
diff --git a/GUI/unittests/test_clickable_pixmapitem.py b/GUI/unittests/test_clickable_pixmapitem.py
index 1b7792c..8a1b86f 100644
--- a/GUI/unittests/test_clickable_pixmapitem.py
+++ b/GUI/unittests/test_clickable_pixmapitem.py
@@ -37,6 +37,8 @@ def make_item(qapp):
         coord=(0, 0),
         logit_features=np.array([], dtype=np.float32),
         uncertainty=0.5,
+        mask_rle=None,
+        mask_shape=None,
     )
     pixmap = QPixmap(10, 10)
     item = ClickablePixmapItem(annotation=ann, pixmap=pixmap, coord_pos=(2, 3))
diff --git a/GUI/unittests/test_main_controller.py b/GUI/unittests/test_main_controller.py
index b9b0072..990972b 100644
--- a/GUI/unittests/test_main_controller.py
+++ b/GUI/unittests/test_main_controller.py
@@ -166,8 +166,8 @@ def test_visible_crops_complete():
     from GUI.models.Annotation import Annotation
 
     view = DummyView()
-    anno1 = Annotation(0, "a", (0, 0), [], 0.5, class_id=1)
-    anno2 = Annotation(0, "b", (1, 1), [], 0.5, class_id=-1)
+    anno1 = Annotation(0, "a", (0, 0), [], 0.5, class_id=1, mask_rle=None, mask_shape=None)
+    anno2 = Annotation(0, "b", (1, 1), [], 0.5, class_id=-1, mask_rle=None, mask_shape=None)
     view.selected_crops = [{"annotation": anno1}]
     ctrl = build_controller(view)
     assert ctrl._visible_crops_complete()
@@ -180,7 +180,7 @@ def test_propagate_labeling_changes(monkeypatch):
     view = DummyView()
     ctrl = build_controller(view)
 
-    ann = Annotation(0, "a", (0, 0), [], 0.5)
+    ann = Annotation(0, "a", (0, 0), [], 0.5, mask_rle=None, mask_shape=None)
     ctrl.clustering_controller.get_clusters = lambda: {1: [ann]}
 
     called = {}
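With build_grouped_annotations and the AnnotationJSON schema above, the mask round-trips through export and project persistence as plain JSON. A grouped-annotation entry then looks roughly like the following (illustrative values; keys other than the two new ones are unchanged):

entry = {
    "coord": [101, 50],                 # (row, col) centroid of the superpixel
    "class_id": 2,
    "cluster_id": 1,
    "mask_rle": [12101, 3, 12341, 3],   # a 3x2 blob at rows 100-102, cols 50-51 of a 240x240 map
    "mask_shape": [240, 240],           # (height, width) the RLE was encoded against
}

Annotation.decode_mask(entry["mask_rle"], tuple(entry["mask_shape"])) reconstructs the binary mask; point-style annotations simply carry None for both new fields.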
diff --git a/GUI/unittests/test_persistence.py b/GUI/unittests/test_persistence.py
index eb65815..01a182b 100644
--- a/GUI/unittests/test_persistence.py
+++ b/GUI/unittests/test_persistence.py
@@ -26,6 +26,8 @@ def example_state() -> ProjectState:
         "class_id": -1,
         "cluster_id": 1,
         "model_prediction": None,
+        "mask_rle": None,
+        "mask_shape": None,
     }
     return ProjectState(
         schema_version=LATEST_SCHEMA_VERSION,
diff --git a/GUI/unittests/test_uncertainty_propagation.py b/GUI/unittests/test_uncertainty_propagation.py
index f0be1fa..c28fc3e 100644
--- a/GUI/unittests/test_uncertainty_propagation.py
+++ b/GUI/unittests/test_uncertainty_propagation.py
@@ -14,6 +14,8 @@ def make_annotation(x, uncertainty=1.0, class_id=-1):
         logit_features=np.array([x], dtype=np.float32),
         uncertainty=uncertainty,
         class_id=class_id,
+        mask_rle=None,
+        mask_shape=None,
     )
 
 
diff --git a/GUI/views/AppMenuBar.py b/GUI/views/AppMenuBar.py
index 4030179..fbbd733 100644
--- a/GUI/views/AppMenuBar.py
+++ b/GUI/views/AppMenuBar.py
@@ -76,6 +76,7 @@ def _build_actions_menu(self) -> None:
             "Local Uncertainty Maxima",
             "Equidistant Spots",
             "Image Centre",
+            "Superpixel Masks",
         ]:
             act = QAction(label, self, checkable=True)
             grp.addAction(act)
diff --git a/GUI/views/ClickablePixmapItem.py b/GUI/views/ClickablePixmapItem.py
index 224dfb1..c82e2c3 100644
--- a/GUI/views/ClickablePixmapItem.py
+++ b/GUI/views/ClickablePixmapItem.py
@@ -16,11 +16,20 @@ class ClickablePixmapItem(QGraphicsObject):
     """
     class_label_changed = pyqtSignal(dict, int)
 
-    def __init__(self, annotation: Annotation, pixmap: QPixmap, coord_pos: Tuple[int, int], *args, **kwargs):
+    def __init__(
+        self,
+        annotation: Annotation,
+        pixmap: QPixmap,
+        coord_pos: Tuple[int, int],
+        mask_pixmap: QPixmap | None = None,
+        *args,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.annotation = annotation
         self.pixmap = pixmap
         self.coord_pos = coord_pos
+        self.mask_pixmap = mask_pixmap
         self.class_id = annotation.class_id
         self.model_prediction = annotation.model_prediction
 
@@ -65,6 +74,10 @@ def paint(self, painter: QPainter, option, widget):
         painter.save()
         painter.scale(self.scale_factor, self.scale_factor)
         painter.drawPixmap(0, 0, self.pixmap)
+        if self.mask_pixmap:
+            painter.setOpacity(0.4)
+            painter.drawPixmap(0, 0, self.mask_pixmap)
+            painter.setOpacity(1.0)
         painter.restore()
 
         # 2) Compute scaled dimensions
diff --git a/GUI/views/ClusteredCropsView.py b/GUI/views/ClusteredCropsView.py
index d963d8e..17ba7fe 100644
--- a/GUI/views/ClusteredCropsView.py
+++ b/GUI/views/ClusteredCropsView.py
@@ -669,7 +669,12 @@ def arrange_crops(self):
                 logging.warning(f"Invalid QPixmap for image index {annotation.image_index}. Skipping.")
                 continue
 
-            pixmap_item = ClickablePixmapItem(annotation=annotation, pixmap=pixmap, coord_pos=crop_data['coord_pos'])
+            pixmap_item = ClickablePixmapItem(
+                annotation=annotation,
+                pixmap=pixmap,
+                coord_pos=crop_data['coord_pos'],
+                mask_pixmap=crop_data.get('mask_pixmap'),
+            )
             pixmap_item.setFlag(QGraphicsItem.ItemIsSelectable, True)
             pixmap_item.class_label_changed.connect(self.crop_label_changed.emit)
 
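The paint() change above draws the mask pixmap over the crop at 40% opacity. A standalone sketch of that compositing step (illustrative; assumes PyQt5 and the offscreen Qt platform, with solid-colour pixmaps standing in for the real crop and mask):

import os
os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")  # allow running without a display

from PyQt5.QtWidgets import QApplication
from PyQt5.QtGui import QPixmap, QPainter, QColor

app = QApplication([])                 # QPixmap needs a running Q(Gui)Application

crop = QPixmap(64, 64)
crop.fill(QColor("gray"))              # stand-in for the processed crop
mask = QPixmap(64, 64)
mask.fill(QColor("red"))               # stand-in for the mask pixmap

canvas = QPixmap(64, 64)
canvas.fill(QColor("white"))
painter = QPainter(canvas)
painter.drawPixmap(0, 0, crop)         # base crop, fully opaque
painter.setOpacity(0.4)                # same 40% opacity used in paint()
painter.drawPixmap(0, 0, mask)         # translucent mask overlay
painter.end()
canvas.save("overlay_preview.png")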
Skipping.") continue - pixmap_item = ClickablePixmapItem(annotation=annotation, pixmap=pixmap, coord_pos=crop_data['coord_pos']) + pixmap_item = ClickablePixmapItem( + annotation=annotation, + pixmap=pixmap, + coord_pos=crop_data['coord_pos'], + mask_pixmap=crop_data.get('mask_pixmap'), + ) pixmap_item.setFlag(QGraphicsItem.ItemIsSelectable, True) pixmap_item.class_label_changed.connect(self.crop_label_changed.emit) diff --git a/GUI/workers/ImageProcessingWorker.py b/GUI/workers/ImageProcessingWorker.py index af1af7f..354b934 100644 --- a/GUI/workers/ImageProcessingWorker.py +++ b/GUI/workers/ImageProcessingWorker.py @@ -3,6 +3,9 @@ import logging from typing import List, Dict, Any, Optional +import numpy as np +from PIL import Image + from PyQt5.QtCore import QRunnable, QObject, pyqtSignal from GUI.models.Annotation import Annotation @@ -98,6 +101,7 @@ def _process_single_annotation(self, annotation: Annotation) -> Optional[Dict[st cached_result = self.cache.get(cache_key) if cached_result: processed_crop, coord_pos = cached_result + mask_crop = None else: image_data = self.image_data_model.get_image_data(annotation.image_index) image_array = image_data.get('image') @@ -108,10 +112,30 @@ def _process_single_annotation(self, annotation: Annotation) -> Optional[Dict[st processed_crop, coord_pos = self.image_processor.extract_crop_data( image_array, coord, crop_size=self.crop_size, zoom_factor=self.zoom_factor ) + mask_crop = None + if annotation.mask_rle and annotation.mask_shape: + mask = Annotation.decode_mask(annotation.mask_rle, annotation.mask_shape) + row, col = map(int, coord) + original_height, original_width = mask.shape + half_crop = self.crop_size // 2 + x_start = max(0, col - half_crop) + y_start = max(0, row - half_crop) + if x_start + self.crop_size > original_width: + x_start = original_width - self.crop_size + if y_start + self.crop_size > original_height: + y_start = original_height - self.crop_size + width_crop = min(self.crop_size, original_width - x_start) + height_crop = min(self.crop_size, original_height - y_start) + mask_crop = mask[y_start:y_start + height_crop, x_start:x_start + width_crop] + pil_mask = Image.fromarray(mask_crop * 255) + new_size = (mask_crop.shape[1] * self.zoom_factor, mask_crop.shape[0] * self.zoom_factor) + mask_crop = np.array(pil_mask.resize(new_size, Image.NEAREST)) + self.cache.set(cache_key, (processed_crop, coord_pos)) return { 'annotation': annotation, 'processed_crop': processed_crop, - 'coord_pos': coord_pos + 'coord_pos': coord_pos, + 'mask_crop': mask_crop, } diff --git a/environment.yml b/environment.yml index b59bdf0..727236a 100644 --- a/environment.yml +++ b/environment.yml @@ -27,3 +27,4 @@ dependencies: - pydantic - pytest - pyinstaller + - scikit-image