Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions GUI/controllers/AnnotationClusteringController.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ def _extract_from_image(self, img: dict, idx: int) -> list[Annotation]:
if umap is None or logits is None or fname is None:
return out

feats, coords = self._generator.generate_annotations(uncertainty_map=umap, logits=logits)
for c, f in zip(coords, feats):
result = self._generator.generate_annotations(uncertainty_map=umap, logits=logits)
if len(result) == 2:
feats, coords = result
masks = [None] * len(coords)
else:
feats, coords, masks = result
for c, f, m in zip(coords, feats, masks):
if not f.any():
continue
out.append(
Expand All @@ -84,6 +89,8 @@ def _extract_from_image(self, img: dict, idx: int) -> list[Annotation]:
uncertainty=float(umap[tuple(c)]),
cluster_id=None,
model_prediction=CLASS_COMPONENTS.get(int(np.argmax(f)), "None"),
mask_rle=m,
mask_shape=umap.shape[:2],
)
)
return out
Expand Down
5 changes: 4 additions & 1 deletion GUI/controllers/ImageProcessingController.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,20 @@ def _display_processed_annotations(self, annotations_data: List[dict]):
anno = data['annotation']
np_image = data['processed_crop']
coord_pos = data['coord_pos']
mask_crop = data.get('mask_crop')

if np_image is None or coord_pos is None:
logging.warning(f"Missing image or coords for annotation: {anno}")
continue

q_pixmap = self._numpy_to_qpixmap(np_image)
mask_pix = self._numpy_to_qpixmap(mask_crop) if mask_crop is not None else None

sampled_crops.append({
'annotation': anno,
'processed_crop': q_pixmap,
'coord_pos': coord_pos
'coord_pos': coord_pos,
'mask_pixmap': mask_pix,
})

self.crops_ready.emit(sampled_crops)
Expand Down
5 changes: 5 additions & 0 deletions GUI/controllers/MainController.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
LocalMaximaPointAnnotationGenerator,
EquidistantPointAnnotationGenerator,
CenterPointAnnotationGenerator,
LocalMaximaPointAnnotationGenerator,
)
from GUI.models.SuperpixelAnnotationGenerator import SLICSuperpixelGenerator
from GUI.models.UncertaintyPropagator import propagate_for_annotations
from GUI.models.export.Options import ExportOptions
from GUI.models.export.Usecase import ExportAnnotationsUseCase
Expand Down Expand Up @@ -231,6 +233,9 @@ def on_label_generator_method_changed(self, method: str):
elif method == "Image Centre":
self.annotation_generator = CenterPointAnnotationGenerator()
self._use_greedy_nav = False
elif method == "Superpixel Masks":
self.annotation_generator = SLICSuperpixelGenerator()
self._use_greedy_nav = True
else:
self.annotation_generator = LocalMaximaPointAnnotationGenerator(
filter_size=48, gaussian_sigma=4.0, use_gaussian=False
Expand Down
32 changes: 32 additions & 0 deletions GUI/models/Annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class Annotation:
cluster_id: Optional[int] = None
model_prediction: Optional[str] = None
adjusted_uncertainty: Optional[Union[float, np.ndarray]] = None
mask_rle: Optional[list] = None
mask_shape: Optional[Tuple[int, int]] = None

# ---------- core ---------------------------------------------------------------
def __setattr__(self, name, value):
Expand All @@ -28,6 +30,28 @@ def __post_init__(self):
if self.adjusted_uncertainty is None:
self.adjusted_uncertainty = self.uncertainty

# ---------- mask helpers -------------------------------------------------
@staticmethod
def encode_mask(mask: np.ndarray) -> list:
"""Return run-length encoding for ``mask``."""
pixels = mask.astype(np.uint8).flatten(order="F")
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return runs.tolist()

@staticmethod
def decode_mask(rle: list, shape: Tuple[int, int]) -> np.ndarray:
"""Decode RLE ``rle`` back to a binary mask of ``shape``."""
rle = np.asarray(rle, dtype=int)
starts = rle[0::2] - 1
lengths = rle[1::2]
ends = starts + lengths
img = np.zeros(shape[0] * shape[1], dtype=np.uint8)
for s, e in zip(starts, ends):
img[s:e] = 1
return img.reshape(shape, order="F")

# ---------- public helpers -----------------------------------------------------
def reset_uncertainty(self) -> None:
"""Restore posterior to prior."""
Expand Down Expand Up @@ -56,6 +80,8 @@ def to_dict(self) -> dict:
int(self.cluster_id) if self.cluster_id is not None else None
),
"model_prediction": self.model_prediction,
"mask_rle": self.mask_rle,
"mask_shape": list(self.mask_shape) if self.mask_shape else None,
}

@staticmethod
Expand All @@ -81,6 +107,10 @@ def from_dict(data: dict) -> "Annotation":
cluster_id_val = int(cluster_id_raw) if cluster_id_raw is not None else None

is_manual_val = data.get("is_manual", False)
mask_rle_val = data.get("mask_rle")
mask_shape_val = (
tuple(data.get("mask_shape")) if data.get("mask_shape") else None
)

# --- construct -------------------------------------------------------------
return Annotation(
Expand All @@ -94,4 +124,6 @@ def from_dict(data: dict) -> "Annotation":
class_id=int(data.get("class_id", -1)),
cluster_id=cluster_id_val,
model_prediction=data.get("model_prediction", None),
mask_rle=mask_rle_val,
mask_shape=mask_shape_val,
)
5 changes: 5 additions & 0 deletions GUI/models/ImageProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ def create_annotation_overlay(
if ann.class_id == -1:
continue
colour = self.class_color_map.get(ann.class_id, (255, 255, 255))
if ann.mask_rle and ann.mask_shape:
mask = Annotation.decode_mask(ann.mask_rle, ann.mask_shape)
mimg = Image.fromarray(mask * 255).convert("L")
overlay = Image.new("RGBA", pil.size, colour + (100,))
pil.paste(overlay, mask=mimg)
y, x = map(int, ann.coord)
bbox = [x - radius, y - radius, x + radius, y + radius]
draw.ellipse(bbox, fill=colour, outline=colour)
Expand Down
57 changes: 57 additions & 0 deletions GUI/models/SuperpixelAnnotationGenerator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import logging
from typing import List, Tuple

import numpy as np
from skimage.segmentation import slic
from skimage.measure import regionprops

from .PointAnnotationGenerator import BasePointAnnotationGenerator
from .Annotation import Annotation

logger = logging.getLogger(__name__)


class SLICSuperpixelGenerator(BasePointAnnotationGenerator):
"""Generate superpixel regions ranked by uncertainty."""

def __init__(self, n_segments: int = 250, compactness: float = 0.1, edge_buffer: int = 64):
super().__init__(edge_buffer=edge_buffer)
self.n_segments = n_segments
self.compactness = compactness
logger.info(
"SLICSuperpixelGenerator(n_segments=%d, compactness=%.2f)",
n_segments,
compactness,
)

def generate_annotations(
self, uncertainty_map: np.ndarray, logits: np.ndarray
) -> Tuple[np.ndarray, List[Tuple[int, int]], List[list]]:
map2d = self._prepare_uncertainty_map(uncertainty_map)
segments = slic(map2d, n_segments=self.n_segments, compactness=self.compactness, start_label=1)
props = regionprops(segments, intensity_image=map2d)
props = sorted(props, key=lambda p: p.mean_intensity, reverse=True)

coords: List[Tuple[int, int]] = []
feats: List[np.ndarray] = []
masks: List[list] = []
for p in props:
cy, cx = map(int, p.centroid)
if (
cy < self.edge_buffer
or cy >= map2d.shape[0] - self.edge_buffer
or cx < self.edge_buffer
or cx >= map2d.shape[1] - self.edge_buffer
):
continue
mask = segments == p.label
coords.append((cy, cx))
feats.append(logits[mask].mean(axis=0))
masks.append(Annotation.encode_mask(mask))

if not coords:
logger.warning("SLICSuperpixelGenerator produced no annotations")
return np.empty((0, logits.shape[-1]), dtype=np.float32), [], []

return np.stack(feats).astype(np.float32), coords, masks

2 changes: 2 additions & 0 deletions GUI/models/export/ExportService.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def build_grouped_annotations(
"coord": [int(c) for c in anno.coord],
"class_id": int(anno.class_id),
"cluster_id": int(cluster_id),
"mask_rle": anno.mask_rle,
"mask_shape": list(anno.mask_shape) if anno.mask_shape else None,
}
)

Expand Down
2 changes: 2 additions & 0 deletions GUI/models/io/Persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class AnnotationJSON(BaseModel):
class_id: int
cluster_id: Optional[int] = None
model_prediction: Optional[str] = None
mask_rle: Optional[List[int]] = None
mask_shape: Optional[List[int]] = None


class ProjectState(BaseModel):
Expand Down
4 changes: 4 additions & 0 deletions GUI/unittests/test_annotation_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def test_create_annotation_overlay_draws_crosshair():
logit_features=np.array([0.0]),
uncertainty=0.5,
class_id=1,
mask_rle=None,
mask_shape=None,
)
out = proc.create_annotation_overlay(img, [ann], radius=2, show_labels=False)
arr = np.array(out.convert("RGB"))
Expand All @@ -36,6 +38,8 @@ def test_create_annotation_overlay_draws_label_box():
logit_features=np.array([0.0]),
uncertainty=0.5,
class_id=1,
mask_rle=None,
mask_shape=None,
)
out = proc.create_annotation_overlay(
img, [ann], radius=3, crosshair=False, show_labels=True
Expand Down
2 changes: 2 additions & 0 deletions GUI/unittests/test_clickable_pixmapitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def make_item(qapp):
coord=(0, 0),
logit_features=np.array([], dtype=np.float32),
uncertainty=0.5,
mask_rle=None,
mask_shape=None,
)
pixmap = QPixmap(10, 10)
item = ClickablePixmapItem(annotation=ann, pixmap=pixmap, coord_pos=(2, 3))
Expand Down
6 changes: 3 additions & 3 deletions GUI/unittests/test_main_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def test_visible_crops_complete():
from GUI.models.Annotation import Annotation

view = DummyView()
anno1 = Annotation(0, "a", (0, 0), [], 0.5, class_id=1)
anno2 = Annotation(0, "b", (1, 1), [], 0.5, class_id=-1)
anno1 = Annotation(0, "a", (0, 0), [], 0.5, class_id=1, mask_rle=None, mask_shape=None)
anno2 = Annotation(0, "b", (1, 1), [], 0.5, class_id=-1, mask_rle=None, mask_shape=None)
view.selected_crops = [{"annotation": anno1}]
ctrl = build_controller(view)
assert ctrl._visible_crops_complete()
Expand All @@ -180,7 +180,7 @@ def test_propagate_labeling_changes(monkeypatch):

view = DummyView()
ctrl = build_controller(view)
ann = Annotation(0, "a", (0, 0), [], 0.5)
ann = Annotation(0, "a", (0, 0), [], 0.5, mask_rle=None, mask_shape=None)
ctrl.clustering_controller.get_clusters = lambda: {1: [ann]}
called = {}

Expand Down
2 changes: 2 additions & 0 deletions GUI/unittests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def example_state() -> ProjectState:
"class_id": -1,
"cluster_id": 1,
"model_prediction": None,
"mask_rle": None,
"mask_shape": None,
}
return ProjectState(
schema_version=LATEST_SCHEMA_VERSION,
Expand Down
2 changes: 2 additions & 0 deletions GUI/unittests/test_uncertainty_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def make_annotation(x, uncertainty=1.0, class_id=-1):
logit_features=np.array([x], dtype=np.float32),
uncertainty=uncertainty,
class_id=class_id,
mask_rle=None,
mask_shape=None,
)


Expand Down
1 change: 1 addition & 0 deletions GUI/views/AppMenuBar.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _build_actions_menu(self) -> None:
"Local Uncertainty Maxima",
"Equidistant Spots",
"Image Centre",
"Superpixel Masks",
]:
act = QAction(label, self, checkable=True)
grp.addAction(act)
Expand Down
15 changes: 14 additions & 1 deletion GUI/views/ClickablePixmapItem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@ class ClickablePixmapItem(QGraphicsObject):
"""
class_label_changed = pyqtSignal(dict, int)

def __init__(self, annotation: Annotation, pixmap: QPixmap, coord_pos: Tuple[int, int], *args, **kwargs):
def __init__(
self,
annotation: Annotation,
pixmap: QPixmap,
coord_pos: Tuple[int, int],
mask_pixmap: QPixmap | None = None,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.annotation = annotation
self.pixmap = pixmap
self.coord_pos = coord_pos
self.mask_pixmap = mask_pixmap
self.class_id = annotation.class_id
self.model_prediction = annotation.model_prediction

Expand Down Expand Up @@ -65,6 +74,10 @@ def paint(self, painter: QPainter, option, widget):
painter.save()
painter.scale(self.scale_factor, self.scale_factor)
painter.drawPixmap(0, 0, self.pixmap)
if self.mask_pixmap:
painter.setOpacity(0.4)
painter.drawPixmap(0, 0, self.mask_pixmap)
painter.setOpacity(1.0)
painter.restore()

# 2) Compute scaled dimensions
Expand Down
7 changes: 6 additions & 1 deletion GUI/views/ClusteredCropsView.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,12 @@ def arrange_crops(self):
logging.warning(f"Invalid QPixmap for image index {annotation.image_index}. Skipping.")
continue

pixmap_item = ClickablePixmapItem(annotation=annotation, pixmap=pixmap, coord_pos=crop_data['coord_pos'])
pixmap_item = ClickablePixmapItem(
annotation=annotation,
pixmap=pixmap,
coord_pos=crop_data['coord_pos'],
mask_pixmap=crop_data.get('mask_pixmap'),
)
pixmap_item.setFlag(QGraphicsItem.ItemIsSelectable, True)
pixmap_item.class_label_changed.connect(self.crop_label_changed.emit)

Expand Down
26 changes: 25 additions & 1 deletion GUI/workers/ImageProcessingWorker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import logging
from typing import List, Dict, Any, Optional

import numpy as np
from PIL import Image

from PyQt5.QtCore import QRunnable, QObject, pyqtSignal

from GUI.models.Annotation import Annotation
Expand Down Expand Up @@ -98,6 +101,7 @@ def _process_single_annotation(self, annotation: Annotation) -> Optional[Dict[st
cached_result = self.cache.get(cache_key)
if cached_result:
processed_crop, coord_pos = cached_result
mask_crop = None
else:
image_data = self.image_data_model.get_image_data(annotation.image_index)
image_array = image_data.get('image')
Expand All @@ -108,10 +112,30 @@ def _process_single_annotation(self, annotation: Annotation) -> Optional[Dict[st
processed_crop, coord_pos = self.image_processor.extract_crop_data(
image_array, coord, crop_size=self.crop_size, zoom_factor=self.zoom_factor
)
mask_crop = None
if annotation.mask_rle and annotation.mask_shape:
mask = Annotation.decode_mask(annotation.mask_rle, annotation.mask_shape)
row, col = map(int, coord)
original_height, original_width = mask.shape
half_crop = self.crop_size // 2
x_start = max(0, col - half_crop)
y_start = max(0, row - half_crop)
if x_start + self.crop_size > original_width:
x_start = original_width - self.crop_size
if y_start + self.crop_size > original_height:
y_start = original_height - self.crop_size
width_crop = min(self.crop_size, original_width - x_start)
height_crop = min(self.crop_size, original_height - y_start)
mask_crop = mask[y_start:y_start + height_crop, x_start:x_start + width_crop]
pil_mask = Image.fromarray(mask_crop * 255)
new_size = (mask_crop.shape[1] * self.zoom_factor, mask_crop.shape[0] * self.zoom_factor)
mask_crop = np.array(pil_mask.resize(new_size, Image.NEAREST))

self.cache.set(cache_key, (processed_crop, coord_pos))

return {
'annotation': annotation,
'processed_crop': processed_crop,
'coord_pos': coord_pos
'coord_pos': coord_pos,
'mask_crop': mask_crop,
}
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ dependencies:
- pydantic
- pytest
- pyinstaller
- scikit-image