Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ def create_training_transforms(
def create_validation_transforms(
skeleton_targets: Optional[List[str]] = None,
skeleton_ignore_values: Optional[Dict[str, int]] = None,
cache_dir: Optional[str] = None,
enable_disk_cache: bool = False,
) -> Optional[ComposeTransforms]:
"""
Create minimal transforms for validation.
Expand All @@ -489,17 +491,5 @@ def create_validation_transforms(
Optional[ComposeTransforms]
The composed validation transforms, or None if no transforms needed.
"""
if not skeleton_targets:
return None

from vesuvius.models.augmentation.transforms.utils.skeleton_transform import MedialSurfaceTransform

transforms = [
MedialSurfaceTransform(
do_tube=False,
target_keys=skeleton_targets,
ignore_values=skeleton_ignore_values or None,
)
]

return ComposeTransforms(transforms)
# Deterministic skeleton targets are generated in the dataset before augmentation.
return None
Comment on lines +494 to +495
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Restore validation skeleton transform for non-Zarr datasets

create_validation_transforms now always returns None, but MutexAffinityDataset._initialize_transforms still invokes this helper when skeleton losses are configured. In that path, validation no longer produces *_skel tensors, and BaseTrainer._compute_loss_value will call skeleton losses without the required skel argument, which raises for DC_SkelREC_and_CE_loss/SoftSkeletonRecallLoss. This breaks mutex-affinity validation whenever skeleton-supervised losses are enabled.

Useful? React with 👍 / 👎.

Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import hashlib
import json
import os
from collections import OrderedDict
from pathlib import Path
from typing import Optional, Sequence

import torch
Expand All @@ -13,7 +17,10 @@ def __init__(self,
do_open: bool = False,
do_close: bool = True,
target_keys: Optional[Sequence[str]] = None,
ignore_values: Optional[dict] = None,):
ignore_values: Optional[dict] = None,
cache_dir: Optional[str] = None,
enable_disk_cache: bool = False,
memory_cache_size: int = 128):
"""
Calculates the medial surface skeleton of the segmentation (plus an optional 2 px tube around it)
and adds it to the dict with the key "skel"
Expand All @@ -24,6 +31,130 @@ def __init__(self,
self.do_close = do_close
self.target_keys = tuple(target_keys) if target_keys else None
self.ignore_values = dict(ignore_values or {})
self.cache_dir = Path(cache_dir) if cache_dir else None
self.enable_disk_cache = bool(enable_disk_cache and self.cache_dir is not None)
self.memory_cache_size = max(1, int(memory_cache_size))
self._cache: OrderedDict[str, tuple[tuple[tuple[int, int], ...], np.ndarray]] = OrderedDict()

@staticmethod
def _bbox_slices(mask: np.ndarray, margin: int = 0):
if not np.any(mask):
return None
coords = np.where(mask)
slices = []
for axis, axis_coords in enumerate(coords):
start = max(int(axis_coords.min()) - margin, 0)
stop = min(int(axis_coords.max()) + margin + 1, mask.shape[axis])
slices.append(slice(start, stop))
return tuple(slices)

@staticmethod
def _roi_tuple_from_slices(roi_slices):
return tuple((int(slc.start), int(slc.stop)) for slc in roi_slices)

@staticmethod
def _roi_slices_from_tuple(roi_tuple):
return tuple(slice(start, stop) for start, stop in roi_tuple)

def _cache_key(self, patch_info, target_key: str, ignore_value):
if not patch_info:
return None
volume_name = patch_info.get("volume_name")
position = patch_info.get("position")
patch_size = patch_info.get("patch_size")
scale = patch_info.get("scale")
if volume_name is None or position is None or patch_size is None:
return None
payload = {
"version": "v1",
"target": target_key,
"volume_name": volume_name,
"position": list(position),
"patch_size": list(patch_size),
"scale": scale,
"ignore_value": repr(ignore_value),
"do_tube": self.do_tube,
"do_open": self.do_open,
"do_close": self.do_close,
}
return hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest()

def _cache_get(self, cache_key: Optional[str]):
if cache_key is None:
return None
cached = self._cache.get(cache_key)
if cached is not None:
self._cache.move_to_end(cache_key)
return cached
if not self.enable_disk_cache or self.cache_dir is None:
return None

cache_path = self.cache_dir / cache_key[:2] / f"{cache_key}.npz"
if not cache_path.exists():
return None
try:
with np.load(cache_path, allow_pickle=False) as payload:
roi_tuple = tuple(tuple(int(v) for v in pair) for pair in payload["roi"].tolist())
roi_values = np.ascontiguousarray(payload["values"])
except Exception:
return None

self._cache_put(cache_key, roi_tuple, roi_values)
return roi_tuple, roi_values

def _cache_put(self, cache_key: Optional[str], roi_tuple, roi_values: np.ndarray):
if cache_key is None or roi_tuple is None:
return
self._cache[cache_key] = (roi_tuple, np.ascontiguousarray(roi_values))
self._cache.move_to_end(cache_key)
while len(self._cache) > self.memory_cache_size:
self._cache.popitem(last=False)

if not self.enable_disk_cache or self.cache_dir is None:
return

cache_path = self.cache_dir / cache_key[:2] / f"{cache_key}.npz"
cache_path.parent.mkdir(parents=True, exist_ok=True)
if cache_path.exists():
return
np.savez(cache_path, roi=np.asarray(roi_tuple, dtype=np.int64), values=np.ascontiguousarray(roi_values))

def _compute_skeleton(self, seg_processed: np.ndarray) -> np.ndarray:
    """Compute a per-channel skeleton of ``seg_processed``.

    Each channel is binarized (> 0), skeletonized inside its foreground
    bounding box (slice-by-slice along the first spatial axis for 3D input,
    i.e. a stack of 2D medial axes rather than a true 3D skeleton),
    optionally dilated into a ~2 px tube / opened / closed, and then
    multiplied back by the original channel values.  Returns a float32
    array with the same shape as the input; channels with no foreground
    stay all-zero.
    """
    bin_seg = seg_processed > 0
    seg_all_skel = np.zeros_like(seg_processed, dtype=np.float32)
    # Morphology can grow the skeleton, so pad the ROI to keep it inside the crop.
    margin = 2 if (self.do_tube or self.do_open or self.do_close) else 0

    for c in range(bin_seg.shape[0]):
        seg_c = bin_seg[c]
        if seg_c.sum() == 0:
            # Empty channel: leave its skeleton all-zero.
            continue

        # Restrict skeletonization to the foreground bounding box for speed.
        roi_slices = self._bbox_slices(seg_c, margin=margin)
        if roi_slices is None:
            continue
        seg_roi = seg_c[roi_slices]

        if seg_roi.ndim == 3:
            # Per-z-slice 2D skeletonization of the 3D ROI.
            skel = np.zeros_like(seg_roi, dtype=bool)
            for z in range(seg_roi.shape[0]):
                skel[z] |= skeletonize(seg_roi[z])
        elif seg_roi.ndim == 2:
            skel = skeletonize(seg_roi)
        else:
            raise ValueError(f"Unsupported segmentation dimensionality {seg_roi.ndim} for skeletonization")

        if self.do_tube:
            # Two dilations produce roughly a 2 px tube around the skeleton.
            skel = dilation(dilation(skel))
        if self.do_open:
            skel = opening(skel)
        if self.do_close:
            skel = closing(skel)

        # Write back only inside the ROI.  Multiplying by seg_processed both
        # restores the original label values and zeroes skeleton voxels that
        # morphology pushed outside the segmentation.
        seg_all_skel[(c, *roi_slices)] = (
            skel.astype(np.float32) * seg_processed[(c, *roi_slices)].astype(np.float32)
        )

    return seg_all_skel

def apply(self, data_dict, **params):
# Collect regression keys to avoid processing continuous aux targets
Expand All @@ -42,6 +173,7 @@ def apply(self, data_dict, **params):
target_keys = candidate_keys

# Process each target
patch_info = data_dict.get("patch_info", {}) or {}
for target_key in target_keys:
t = data_dict[target_key]
orig_device = t.device
Expand All @@ -53,31 +185,21 @@ def apply(self, data_dict, **params):
else:
seg_processed = seg_all

bin_seg = seg_processed > 0
seg_all_skel = np.zeros_like(seg_processed, dtype=np.float32)

for c in range(bin_seg.shape[0]):
seg_c = bin_seg[c]
if seg_c.sum() == 0:
continue

if seg_c.ndim == 3:
skel = np.zeros_like(seg_c, dtype=bool)
for z in range(seg_c.shape[0]):
skel[z] |= skeletonize(seg_c[z])
elif seg_c.ndim == 2:
skel = skeletonize(seg_c)
else:
raise ValueError(f"Unsupported segmentation dimensionality {seg_c.ndim} for skeletonization")

if self.do_tube:
skel = dilation(dilation(skel))
if self.do_open:
skel = opening(skel)
if self.do_close:
skel = closing(skel)

seg_all_skel[c] = (skel.astype(np.float32) * seg_processed[c].astype(np.float32))
cache_key = self._cache_key(patch_info, target_key, ignore_value)
cached = self._cache_get(cache_key)
if cached is not None:
Comment on lines +188 to +190
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Disable skeleton cache for augmented training samples

The new cache key is based only on static patch_info (volume/position/patch size), but this transform is appended after stochastic augmentations in the training pipeline. That means repeated patches can reuse a cached skeleton computed for a different augmented variant, so {target}_skel can diverge from the current target tensor and silently corrupt skeleton-supervised training on ZarrDataset runs.

Useful? React with 👍 / 👎.

roi_tuple, roi_values = cached
seg_all_skel = np.zeros_like(seg_processed, dtype=np.float32)
seg_all_skel[(slice(None), *self._roi_slices_from_tuple(roi_tuple))] = roi_values
else:
seg_all_skel = self._compute_skeleton(seg_processed)
roi_slices = self._bbox_slices(np.any(seg_all_skel != 0, axis=0), margin=0)
if roi_slices is not None:
self._cache_put(
cache_key,
self._roi_tuple_from_slices(roi_slices),
seg_all_skel[(slice(None), *roi_slices)],
)

data_dict[f"{target_key}_skel"] = torch.from_numpy(seg_all_skel).to(orig_device)

Expand Down
59 changes: 16 additions & 43 deletions vesuvius/src/vesuvius/models/configuration/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ def load_config(self, config_path):
attr_name = "lejepa_lambda" if key == "lambda" else key
setattr(self, attr_name, value)

# Load optional EMA config used by the base trainer.
self.ema_config = deepcopy(config.get("ema", {}) or {})

self._init_attributes()

if self.auxiliary_tasks and self.targets:
Expand All @@ -95,7 +92,6 @@ def load_config(self, config_path):
return config

def _resolve_config_relative_path(self, raw_path):
"""Resolve a config path relative to the config file location."""
if raw_path in (None, ""):
return None

Expand Down Expand Up @@ -129,14 +125,6 @@ def _make_unique_volume_id(self, base_name, used_names):
next_index += 1

def get_explicit_volume_specs(self):
"""
Normalize dataset_config.volumes into explicit image/label path specs.

Supported entries include:
- {image: /path/to/image.zarr, label: /path/to/label.zarr, scale: 2}
- {image: ..., labels: {ink: /path/to/ink.zarr, surface: /path/to/surface.zarr}}
- /path/to/image_only_volume.zarr
"""
volumes_cfg = self.dataset_config.get("volumes")
if not volumes_cfg:
return []
Expand Down Expand Up @@ -195,25 +183,15 @@ def get_explicit_volume_specs(self):

scale_raw = spec.get("scale", spec.get("ome_zarr_resolution"))
if scale_raw not in (None, ""):
try:
scale = int(scale_raw)
except (TypeError, ValueError) as exc:
raise ValueError(
f"Explicit volume '{volume_id}' has invalid scale {scale_raw!r}"
) from exc
scale = int(scale_raw)
if scale < 0:
raise ValueError(
f"Explicit volume '{volume_id}' scale must be >= 0"
)

dilate_raw = spec.get("dilate")
if dilate_raw not in (None, ""):
try:
dilate = float(dilate_raw)
except (TypeError, ValueError) as exc:
raise ValueError(
f"Explicit volume '{volume_id}' has invalid dilate value {dilate_raw!r}"
) from exc
dilate = float(dilate_raw)
if dilate < 0:
raise ValueError(
f"Explicit volume '{volume_id}' dilate must be >= 0"
Expand Down Expand Up @@ -302,30 +280,26 @@ def _init_attributes(self):
self.max_steps_per_epoch = int(self.tr_configs.get("max_steps_per_epoch", 250))
self.max_val_steps_per_epoch = int(self.tr_configs.get("max_val_steps_per_epoch", 50))
self.train_num_dataloader_workers = int(self.tr_configs.get("num_dataloader_workers", 8))
self.val_num_dataloader_workers = int(
self.tr_configs.get("val_num_dataloader_workers", self.train_num_dataloader_workers)
)
self.train_prefetch_factor = int(self.tr_configs.get("prefetch_factor", 2))
self.val_prefetch_factor = int(self.tr_configs.get("val_prefetch_factor", self.train_prefetch_factor))
self.persistent_workers = bool(self.tr_configs.get("persistent_workers", True))
self.val_persistent_workers = bool(
self.tr_configs.get("val_persistent_workers", self.persistent_workers)
)
self.log_every_n_steps = max(1, int(self.tr_configs.get("log_every_n_steps", 10)))
self.max_epoch = int(self.tr_configs.get("max_epoch", 5000))
self.val_every_n = int(self.tr_configs.get("val_every_n", 1))
self.early_stopping_patience = int(self.tr_configs.get("early_stopping_patience", 0))
self.optimizer = self.tr_configs.get("optimizer", "SGD")
self.initial_lr = float(self.tr_configs.get("initial_lr", 0.01))
self.weight_decay = float(self.tr_configs.get("weight_decay", 0.00003))

ema_cfg = deepcopy(getattr(self, "ema_config", {}) or {})
self.ema_enabled = bool(ema_cfg.get("enabled", False))
self.ema_decay = float(ema_cfg.get("decay", 0.999))
self.ema_start_step = int(ema_cfg.get("start_step", 0))
self.ema_update_every_steps = max(1, int(ema_cfg.get("update_every_steps", 1)))
self.ema_validate = bool(ema_cfg.get("validate", self.ema_enabled))
self.ema_save_in_checkpoint = bool(
ema_cfg.get("save_in_checkpoint", self.ema_enabled)
)
self.ema_config = {
"enabled": self.ema_enabled,
"decay": self.ema_decay,
"start_step": self.ema_start_step,
"update_every_steps": self.ema_update_every_steps,
"validate": self.ema_validate,
"save_in_checkpoint": self.ema_save_in_checkpoint,
}
self.numa_pin = str(self.tr_configs.get("numa_pin", "auto")).lower()
self.debug_visualization_every_n = int(self.tr_configs.get("debug_visualization_every_n", 0))
self.validation_preview_pool_size = int(self.tr_configs.get("validation_preview_pool_size", 32))
self.log_validation_preview = bool(self.tr_configs.get("log_validation_preview", True))

### Dataset config ###
self.min_labeled_ratio = float(self.dataset_config.get("min_labeled_ratio", 0.10))
Expand Down Expand Up @@ -1007,7 +981,6 @@ def convert_to_dict(self):
"model_config": model_config,
"dataset_config": dataset_config,
"inference_config": inference_config,
"ema": deepcopy(getattr(self, "ema_config", {})),
}

return combined_config
Expand Down
Loading
Loading