Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion GUI/controllers/AnnotationClusteringController.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ def _extract_from_image(self, img: dict, idx: int) -> list[AnnotationBase]:
if umap is None or logits is None or fname is None:
return out

generated = self._generator.generate_annotations(uncertainty_map=umap, logits=logits)
image = img.get("image")
generated = self._generator.generate_annotations(
uncertainty_map=umap,
logits=logits,
image=image,
)
for ann in generated:
if not ann.logit_features.any():
continue
Expand Down
4 changes: 3 additions & 1 deletion GUI/controllers/ImageProcessingController.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def _display_processed_annotations(self, annotations_data: List[dict]):
anno = data['annotation']
np_image = data['processed_crop']
coord_pos = data['coord_pos']
mask_patch = data.get('mask_patch')

if np_image is None or coord_pos is None:
logging.warning(f"Missing image or coords for annotation: {anno}")
Expand All @@ -135,7 +136,8 @@ def _display_processed_annotations(self, annotations_data: List[dict]):
sampled_crops.append({
'annotation': anno,
'processed_crop': q_pixmap,
'coord_pos': coord_pos
'coord_pos': coord_pos,
'mask_patch': mask_patch,
})

self.crops_ready.emit(sampled_crops)
Expand Down
6 changes: 6 additions & 0 deletions GUI/controllers/MainController.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
LocalMaximaPointAnnotationGenerator,
EquidistantPointAnnotationGenerator,
CenterPointAnnotationGenerator,
SLICSuperpixelAnnotationGenerator,
)
from GUI.models.UncertaintyPropagator import propagate_for_annotations
from GUI.models.export.Options import ExportOptions
Expand Down Expand Up @@ -231,6 +232,11 @@ def on_label_generator_method_changed(self, method: str):
elif method == "Image Centre":
self.annotation_generator = CenterPointAnnotationGenerator()
self._use_greedy_nav = False
elif method == "Superpixels":
self.annotation_generator = SLICSuperpixelAnnotationGenerator(
n_segments=200, compactness=10.0
)
self._use_greedy_nav = True
else:
self.annotation_generator = LocalMaximaPointAnnotationGenerator(
filter_size=48, gaussian_sigma=4.0, use_gaussian=False
Expand Down
20 changes: 15 additions & 5 deletions GUI/models/ImageProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,9 @@ def extract_crop_data(
image: np.ndarray,
coord: Tuple[int, int],
crop_size: int = 256,
zoom_factor: int = 2
) -> Tuple[np.ndarray, Tuple[int, int]]:
zoom_factor: int = 2,
mask: np.ndarray | None = None,
) -> Tuple[np.ndarray, Tuple[int, int], np.ndarray | None]:
"""
Extracts a zoomed-in crop from the RGB image at the specified coordinate without padding.
Handles image data in float (0-1) or uint8 (0-255) formats.
Expand All @@ -248,8 +249,9 @@ def extract_crop_data(
:param coord: A tuple (row, column) indicating the center of the crop.
:param crop_size: Desired size of the crop in pixels (crop_size x crop_size).
:param zoom_factor: Factor by which to zoom the crop.
:return: A tuple containing the processed image as a NumPy array
and the (x, y) position of the coordinate within the zoomed crop.
:param mask: Optional segmentation mask aligned with ``image``.
:return: ``(zoomed_crop, coord_pos, zoomed_mask)`` where ``zoomed_mask``
is ``None`` when *mask* is ``None``.
"""
# Generate a cache key using a hash of the image and other parameters
cache_key = self._generate_cache_key(image, coord, crop_size, zoom_factor)
Expand Down Expand Up @@ -283,6 +285,9 @@ def extract_crop_data(
height_crop = int(height_crop)

crop = image[y_start:y_start + height_crop, x_start:x_start + width_crop]
mask_crop = None
if mask is not None:
mask_crop = mask[y_start:y_start + height_crop, x_start:x_start + width_crop]

# Ensure crop is uint8 for further processing
if np.issubdtype(crop.dtype, np.floating):
Expand All @@ -292,14 +297,19 @@ def extract_crop_data(
new_size = (crop.shape[1] * zoom_factor, crop.shape[0] * zoom_factor)
zoomed_pil = pil_image.resize(new_size, Image.BICUBIC)
zoomed_crop = np.array(zoomed_pil)
zoomed_mask = None
if mask_crop is not None:
mask_pil = Image.fromarray(mask_crop.astype(np.uint8) * 255)
zoomed_mask = mask_pil.resize(new_size, Image.NEAREST)
zoomed_mask = (np.array(zoomed_mask) > 127).astype(np.uint8)

# Calculate the position of the original coordinate within the zoomed crop
arrow_rel_x = col - x_start
arrow_rel_y = row - y_start
pos_x_zoomed = arrow_rel_x * zoom_factor
pos_y_zoomed = arrow_rel_y * zoom_factor

result = (zoomed_crop, (int(pos_x_zoomed), int(pos_y_zoomed)))
result = (zoomed_crop, (int(pos_x_zoomed), int(pos_y_zoomed)), zoomed_mask)

# Store the result in the cache
self.cache.set(cache_key, result)
Expand Down
101 changes: 98 additions & 3 deletions GUI/models/PointAnnotationGenerator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import List, Tuple

from .annotations import AnnotationBase, PointAnnotation
from .annotations import AnnotationBase, PointAnnotation, MaskAnnotation

import numpy as np
from scipy.ndimage import gaussian_filter, maximum_filter
Expand Down Expand Up @@ -64,9 +64,23 @@ def _extract_logit_features(

# --------------------- public dispatcher -------------------------------- #
def generate_annotations(
self, uncertainty_map: np.ndarray, logits: np.ndarray
self,
uncertainty_map: np.ndarray,
logits: np.ndarray,
image: np.ndarray | None = None,
) -> List[AnnotationBase]:
"""Generate :class:`AnnotationBase` objects for the given inputs."""
"""Generate :class:`AnnotationBase` objects for the given inputs.

Parameters
----------
uncertainty_map
Uncertainty heatmap for the image.
logits
Per-pixel class logits ``H×W×C``.
image
Optional RGB image. Subclasses that require it may override the
method and expect a non-``None`` value.
"""
map2d = self._prepare_uncertainty_map(uncertainty_map)
coords = self._generate_coords(map2d)

Expand Down Expand Up @@ -211,3 +225,84 @@ def _generate_coords(self, uncertainty_map: np.ndarray) -> List[Tuple[int, int]]
centre = (r // 2, c // 2)
logger.debug("Centre point at %s", centre)
return [centre]


# -----------------------------------------------------------------------------
#
# SUPERPIXEL IMPLEMENTATION
# -----------------------------------------------------------------------------
class SLICSuperpixelAnnotationGenerator(BasePointAnnotationGenerator):
"""Generate mask annotations from image superpixels around high-uncertainty areas."""

def __init__(self, n_segments: int = 100, compactness: float = 10.0, edge_buffer: int = 64):
super().__init__(edge_buffer=edge_buffer)
if n_segments <= 0:
raise ValueError("n_segments must be positive")
if compactness <= 0:
raise ValueError("compactness must be positive")
self.n_segments = int(n_segments)
self.compactness = float(compactness)
logger.info(
"SLICSuperpixelAnnotationGenerator(segments=%d, compactness=%.1f)",
self.n_segments,
self.compactness,
)

# ------------------------------------------------------------------ #
def generate_annotations(
self,
uncertainty_map: np.ndarray,
logits: np.ndarray,
image: np.ndarray | None = None,
) -> List[AnnotationBase]:
"""Return mask annotations for superpixels covering local maxima."""
if image is None:
raise ValueError("image must be provided for superpixel annotations")

from skimage.segmentation import slic

map2d = self._prepare_uncertainty_map(uncertainty_map)

lm_gen = LocalMaximaPointAnnotationGenerator(
filter_size=48,
gaussian_sigma=4.0,
edge_buffer=self.edge_buffer,
use_gaussian=False,
)
maxima = lm_gen._generate_coords(map2d)

segments = slic(
image,
n_segments=self.n_segments,
compactness=self.compactness,
start_label=0,
channel_axis=-1 if image.ndim == 3 else None,
)

annos: List[AnnotationBase] = []
seen: set[int] = set()
for r, c in maxima:
label = int(segments[r, c])
if label in seen:
continue
seen.add(label)
mask = segments == label
feats = logits[mask].mean(axis=0)
uncert = float(map2d[mask].mean())
annos.append(
MaskAnnotation(
image_index=-1,
filename="",
coord=(int(r), int(c)),
logit_features=feats,
uncertainty=uncert,
mask=mask.astype(np.uint8),
)
)

logger.info(
"%s produced %d annotations.",
self.__class__.__name__,
len(annos),
)
return annos
21 changes: 20 additions & 1 deletion GUI/models/export/ExportService.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,31 @@
from typing import Dict, Iterable, List, Tuple

from GUI.models.annotations import AnnotationBase, PointAnnotation, MaskAnnotation
import numpy as np
from .Options import ExportOptions

__all__ = ["build_grouped_annotations"]

Grouped = Dict[str, List[dict]]


def _rle_encode(mask: np.ndarray) -> List[int]:
"""Return run-length encoding for a binary mask."""
flat = mask.astype(np.uint8).ravel()
counts: List[int] = []
prev = flat[0]
length = 1
for val in flat[1:]:
if val == prev:
length += 1
else:
counts.append(length)
length = 1
prev = val
counts.append(length)
return counts


def _should_include(anno: AnnotationBase, opts: ExportOptions) -> bool:
"""Return *True* when *anno* must be part of the export."""
if anno.class_id in {None, -1, -2}: # unlabeled or unsure
Expand Down Expand Up @@ -53,7 +71,8 @@ def build_grouped_annotations(
"cluster_id": int(cluster_id),
}
if isinstance(anno, MaskAnnotation) and anno.mask is not None:
entry["mask"] = anno.mask.tolist()
entry["mask_rle"] = _rle_encode(anno.mask)
entry["mask_shape"] = list(anno.mask.shape)
entry["coord"] = [int(c) for c in anno.coord]
else:
entry["coord"] = [int(c) for c in anno.coord]
Expand Down
21 changes: 18 additions & 3 deletions GUI/unittests/test_clickable_pixmapitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from PyQt5.QtGui import QPixmap, QMouseEvent
from PyQt5.QtCore import Qt, QPoint, QPointF

from GUI.views.ClickablePixmapItem import ClickablePixmapItem
from GUI.views.ClickablePixmapItem import (
PointClickablePixmapItem,
MaskClickablePixmapItem,
)
from GUI.models.annotations import PointAnnotation
from GUI.configuration.configuration import CLASS_COMPONENTS

Expand All @@ -27,7 +30,7 @@ def qapp():
return app


def make_item(qapp):
def make_item(qapp, cls=PointClickablePixmapItem, **kwargs):
scene = QGraphicsScene()
parent = QWidget()
QGraphicsView(scene, parent)
Expand All @@ -39,7 +42,10 @@ def make_item(qapp):
uncertainty=0.5,
)
pixmap = QPixmap(10, 10)
item = ClickablePixmapItem(annotation=ann, pixmap=pixmap, coord_pos=(2, 3))
if cls is MaskClickablePixmapItem:
item = cls(annotation=ann, pixmap=pixmap, mask_patch=kwargs.get("mask_patch"))
else:
item = cls(annotation=ann, pixmap=pixmap, coord_pos=(2, 3))
scene.addItem(item)
return item, ann, scene, parent

Expand Down Expand Up @@ -174,6 +180,15 @@ def test_paint_draws_overlays_by_default(qapp):
assert sum(1 for c in painter.calls if c[0] == "drawLine") == 4


def test_paint_draws_mask_edges(qapp):
mask = np.zeros((10, 10), dtype=np.uint8)
mask[2:8, 2:8] = 1
item, ann, scene, _ = make_item(qapp, cls=MaskClickablePixmapItem, mask_patch=mask)
painter = DummyPainter()
item.paint(painter, None, None)
assert sum(1 for c in painter.calls if c[0] == "drawPixmap") == 2


def test_paint_skips_overlays_when_hidden(qapp):
item, _, scene, _ = make_item(qapp)
scene.overlays_visible = False
Expand Down
15 changes: 15 additions & 0 deletions GUI/unittests/test_superpixel_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
from GUI.models.PointAnnotationGenerator import SLICSuperpixelAnnotationGenerator
from GUI.models.annotations import MaskAnnotation


def test_superpixel_generator_outputs_masks():
gen = SLICSuperpixelAnnotationGenerator(n_segments=4, compactness=0.5, edge_buffer=0)
uncertainty = np.zeros((10, 10), dtype=np.float32)
uncertainty[2:7, 2:7] = 1.0
logits = np.random.rand(10, 10, 2).astype(np.float32)
rgb = (np.random.rand(10, 10, 3) * 255).astype(np.uint8)
annos = gen.generate_annotations(uncertainty, logits, image=rgb)
assert annos
assert isinstance(annos[0], MaskAnnotation)
assert annos[0].mask.shape == (10, 10)
1 change: 1 addition & 0 deletions GUI/views/AppMenuBar.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _build_actions_menu(self) -> None:
"Local Uncertainty Maxima",
"Equidistant Spots",
"Image Centre",
"Superpixels",
]:
act = QAction(label, self, checkable=True)
grp.addAction(act)
Expand Down
Loading