From 0712b92127c57893b2d1aeec0c0c1c3e53af02b3 Mon Sep 17 00:00:00 2001 From: Giorgio Angelotti <76100950+giorgioangel@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:04:11 +0200 Subject: [PATCH 1/3] Optimize Vesuvius training runtime and validation --- .../pipelines/training_transforms.py | 18 +- .../transforms/utils/skeleton_transform.py | 174 +++++- .../models/configuration/config_manager.py | 59 +- .../vesuvius/models/datasets/zarr_dataset.py | 262 +++++++-- vesuvius/src/vesuvius/models/training/cli.py | 54 +- .../models/training/save_checkpoint.py | 12 +- .../src/vesuvius/models/training/train.py | 535 ++++++++++-------- .../vesuvius/models/utilities/cli_utils.py | 64 ++- vesuvius/src/vesuvius/utils/plotting.py | 100 +++- 9 files changed, 861 insertions(+), 417 deletions(-) diff --git a/vesuvius/src/vesuvius/models/augmentation/pipelines/training_transforms.py b/vesuvius/src/vesuvius/models/augmentation/pipelines/training_transforms.py index 7a2ab2e82..f12260f95 100644 --- a/vesuvius/src/vesuvius/models/augmentation/pipelines/training_transforms.py +++ b/vesuvius/src/vesuvius/models/augmentation/pipelines/training_transforms.py @@ -470,6 +470,8 @@ def create_training_transforms( def create_validation_transforms( skeleton_targets: Optional[List[str]] = None, skeleton_ignore_values: Optional[Dict[str, int]] = None, + cache_dir: Optional[str] = None, + enable_disk_cache: bool = False, ) -> Optional[ComposeTransforms]: """ Create minimal transforms for validation. @@ -489,17 +491,5 @@ def create_validation_transforms( Optional[ComposeTransforms] The composed validation transforms, or None if no transforms needed. 
""" - if not skeleton_targets: - return None - - from vesuvius.models.augmentation.transforms.utils.skeleton_transform import MedialSurfaceTransform - - transforms = [ - MedialSurfaceTransform( - do_tube=False, - target_keys=skeleton_targets, - ignore_values=skeleton_ignore_values or None, - ) - ] - - return ComposeTransforms(transforms) + # Deterministic skeleton targets are generated in the dataset before augmentation. + return None diff --git a/vesuvius/src/vesuvius/models/augmentation/transforms/utils/skeleton_transform.py b/vesuvius/src/vesuvius/models/augmentation/transforms/utils/skeleton_transform.py index 4f413cc97..d8aaa77d5 100644 --- a/vesuvius/src/vesuvius/models/augmentation/transforms/utils/skeleton_transform.py +++ b/vesuvius/src/vesuvius/models/augmentation/transforms/utils/skeleton_transform.py @@ -1,3 +1,7 @@ +import hashlib +import json +from collections import OrderedDict +from pathlib import Path from typing import Optional, Sequence import torch @@ -13,7 +17,10 @@ def __init__(self, do_open: bool = False, do_close: bool = True, target_keys: Optional[Sequence[str]] = None, - ignore_values: Optional[dict] = None,): + ignore_values: Optional[dict] = None, + cache_dir: Optional[str] = None, + enable_disk_cache: bool = False, + memory_cache_size: int = 128): """ Calculates the medial surface skeleton of the segmentation (plus an optional 2 px tube around it) and adds it to the dict with the key "skel" @@ -24,6 +31,130 @@ def __init__(self, self.do_close = do_close self.target_keys = tuple(target_keys) if target_keys else None self.ignore_values = dict(ignore_values or {}) + self.cache_dir = Path(cache_dir) if cache_dir else None + self.enable_disk_cache = bool(enable_disk_cache and self.cache_dir is not None) + self.memory_cache_size = max(1, int(memory_cache_size)) + self._cache: OrderedDict[str, tuple[tuple[tuple[int, int], ...], np.ndarray]] = OrderedDict() + + @staticmethod + def _bbox_slices(mask: np.ndarray, margin: int = 0): + if not 
np.any(mask): + return None + coords = np.where(mask) + slices = [] + for axis, axis_coords in enumerate(coords): + start = max(int(axis_coords.min()) - margin, 0) + stop = min(int(axis_coords.max()) + margin + 1, mask.shape[axis]) + slices.append(slice(start, stop)) + return tuple(slices) + + @staticmethod + def _roi_tuple_from_slices(roi_slices): + return tuple((int(slc.start), int(slc.stop)) for slc in roi_slices) + + @staticmethod + def _roi_slices_from_tuple(roi_tuple): + return tuple(slice(start, stop) for start, stop in roi_tuple) + + def _cache_key(self, patch_info, target_key: str, ignore_value): + if not patch_info: + return None + volume_name = patch_info.get("volume_name") + position = patch_info.get("position") + patch_size = patch_info.get("patch_size") + scale = patch_info.get("scale") + if volume_name is None or position is None or patch_size is None: + return None + payload = { + "version": "v1", + "target": target_key, + "volume_name": volume_name, + "position": list(position), + "patch_size": list(patch_size), + "scale": scale, + "ignore_value": repr(ignore_value), + "do_tube": self.do_tube, + "do_open": self.do_open, + "do_close": self.do_close, + } + return hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest() + + def _cache_get(self, cache_key: Optional[str]): + if cache_key is None: + return None + cached = self._cache.get(cache_key) + if cached is not None: + self._cache.move_to_end(cache_key) + return cached + if not self.enable_disk_cache or self.cache_dir is None: + return None + + cache_path = self.cache_dir / cache_key[:2] / f"{cache_key}.npz" + if not cache_path.exists(): + return None + try: + with np.load(cache_path, allow_pickle=False) as payload: + roi_tuple = tuple(tuple(int(v) for v in pair) for pair in payload["roi"].tolist()) + roi_values = np.ascontiguousarray(payload["values"]) + except Exception: + return None + + self._cache_put(cache_key, roi_tuple, roi_values) + return roi_tuple, roi_values + + def 
_cache_put(self, cache_key: Optional[str], roi_tuple, roi_values: np.ndarray): + if cache_key is None or roi_tuple is None: + return + self._cache[cache_key] = (roi_tuple, np.ascontiguousarray(roi_values)) + self._cache.move_to_end(cache_key) + while len(self._cache) > self.memory_cache_size: + self._cache.popitem(last=False) + + if not self.enable_disk_cache or self.cache_dir is None: + return + + cache_path = self.cache_dir / cache_key[:2] / f"{cache_key}.npz" + cache_path.parent.mkdir(parents=True, exist_ok=True) + if cache_path.exists(): + return + np.savez(cache_path, roi=np.asarray(roi_tuple, dtype=np.int64), values=np.ascontiguousarray(roi_values)) + + def _compute_skeleton(self, seg_processed: np.ndarray) -> np.ndarray: + bin_seg = seg_processed > 0 + seg_all_skel = np.zeros_like(seg_processed, dtype=np.float32) + margin = 2 if (self.do_tube or self.do_open or self.do_close) else 0 + + for c in range(bin_seg.shape[0]): + seg_c = bin_seg[c] + if seg_c.sum() == 0: + continue + + roi_slices = self._bbox_slices(seg_c, margin=margin) + if roi_slices is None: + continue + seg_roi = seg_c[roi_slices] + + if seg_roi.ndim == 3: + skel = np.zeros_like(seg_roi, dtype=bool) + for z in range(seg_roi.shape[0]): + skel[z] |= skeletonize(seg_roi[z]) + elif seg_roi.ndim == 2: + skel = skeletonize(seg_roi) + else: + raise ValueError(f"Unsupported segmentation dimensionality {seg_roi.ndim} for skeletonization") + + if self.do_tube: + skel = dilation(dilation(skel)) + if self.do_open: + skel = opening(skel) + if self.do_close: + skel = closing(skel) + + seg_all_skel[(c, *roi_slices)] = ( + skel.astype(np.float32) * seg_processed[(c, *roi_slices)].astype(np.float32) + ) + + return seg_all_skel def apply(self, data_dict, **params): # Collect regression keys to avoid processing continuous aux targets @@ -42,6 +173,7 @@ def apply(self, data_dict, **params): target_keys = candidate_keys # Process each target + patch_info = data_dict.get("patch_info", {}) or {} for target_key in 
target_keys: t = data_dict[target_key] orig_device = t.device @@ -53,31 +185,21 @@ def apply(self, data_dict, **params): else: seg_processed = seg_all - bin_seg = seg_processed > 0 - seg_all_skel = np.zeros_like(seg_processed, dtype=np.float32) - - for c in range(bin_seg.shape[0]): - seg_c = bin_seg[c] - if seg_c.sum() == 0: - continue - - if seg_c.ndim == 3: - skel = np.zeros_like(seg_c, dtype=bool) - for z in range(seg_c.shape[0]): - skel[z] |= skeletonize(seg_c[z]) - elif seg_c.ndim == 2: - skel = skeletonize(seg_c) - else: - raise ValueError(f"Unsupported segmentation dimensionality {seg_c.ndim} for skeletonization") - - if self.do_tube: - skel = dilation(dilation(skel)) - if self.do_open: - skel = opening(skel) - if self.do_close: - skel = closing(skel) - - seg_all_skel[c] = (skel.astype(np.float32) * seg_processed[c].astype(np.float32)) + cache_key = self._cache_key(patch_info, target_key, ignore_value) + cached = self._cache_get(cache_key) + if cached is not None: + roi_tuple, roi_values = cached + seg_all_skel = np.zeros_like(seg_processed, dtype=np.float32) + seg_all_skel[(slice(None), *self._roi_slices_from_tuple(roi_tuple))] = roi_values + else: + seg_all_skel = self._compute_skeleton(seg_processed) + roi_slices = self._bbox_slices(np.any(seg_all_skel != 0, axis=0), margin=0) + if roi_slices is not None: + self._cache_put( + cache_key, + self._roi_tuple_from_slices(roi_slices), + seg_all_skel[(slice(None), *roi_slices)], + ) data_dict[f"{target_key}_skel"] = torch.from_numpy(seg_all_skel).to(orig_device) diff --git a/vesuvius/src/vesuvius/models/configuration/config_manager.py b/vesuvius/src/vesuvius/models/configuration/config_manager.py index 473422eaf..3da5f10de 100755 --- a/vesuvius/src/vesuvius/models/configuration/config_manager.py +++ b/vesuvius/src/vesuvius/models/configuration/config_manager.py @@ -84,9 +84,6 @@ def load_config(self, config_path): attr_name = "lejepa_lambda" if key == "lambda" else key setattr(self, attr_name, value) - # Load 
optional EMA config used by the base trainer. - self.ema_config = deepcopy(config.get("ema", {}) or {}) - self._init_attributes() if self.auxiliary_tasks and self.targets: @@ -95,7 +92,6 @@ def load_config(self, config_path): return config def _resolve_config_relative_path(self, raw_path): - """Resolve a config path relative to the config file location.""" if raw_path in (None, ""): return None @@ -129,14 +125,6 @@ def _make_unique_volume_id(self, base_name, used_names): next_index += 1 def get_explicit_volume_specs(self): - """ - Normalize dataset_config.volumes into explicit image/label path specs. - - Supported entries include: - - {image: /path/to/image.zarr, label: /path/to/label.zarr, scale: 2} - - {image: ..., labels: {ink: /path/to/ink.zarr, surface: /path/to/surface.zarr}} - - /path/to/image_only_volume.zarr - """ volumes_cfg = self.dataset_config.get("volumes") if not volumes_cfg: return [] @@ -195,12 +183,7 @@ def get_explicit_volume_specs(self): scale_raw = spec.get("scale", spec.get("ome_zarr_resolution")) if scale_raw not in (None, ""): - try: - scale = int(scale_raw) - except (TypeError, ValueError) as exc: - raise ValueError( - f"Explicit volume '{volume_id}' has invalid scale {scale_raw!r}" - ) from exc + scale = int(scale_raw) if scale < 0: raise ValueError( f"Explicit volume '{volume_id}' scale must be >= 0" @@ -208,12 +191,7 @@ def get_explicit_volume_specs(self): dilate_raw = spec.get("dilate") if dilate_raw not in (None, ""): - try: - dilate = float(dilate_raw) - except (TypeError, ValueError) as exc: - raise ValueError( - f"Explicit volume '{volume_id}' has invalid dilate value {dilate_raw!r}" - ) from exc + dilate = float(dilate_raw) if dilate < 0: raise ValueError( f"Explicit volume '{volume_id}' dilate must be >= 0" @@ -302,30 +280,26 @@ def _init_attributes(self): self.max_steps_per_epoch = int(self.tr_configs.get("max_steps_per_epoch", 250)) self.max_val_steps_per_epoch = int(self.tr_configs.get("max_val_steps_per_epoch", 50)) 
self.train_num_dataloader_workers = int(self.tr_configs.get("num_dataloader_workers", 8)) + self.val_num_dataloader_workers = int( + self.tr_configs.get("val_num_dataloader_workers", self.train_num_dataloader_workers) + ) + self.train_prefetch_factor = int(self.tr_configs.get("prefetch_factor", 2)) + self.val_prefetch_factor = int(self.tr_configs.get("val_prefetch_factor", self.train_prefetch_factor)) + self.persistent_workers = bool(self.tr_configs.get("persistent_workers", True)) + self.val_persistent_workers = bool( + self.tr_configs.get("val_persistent_workers", self.persistent_workers) + ) + self.log_every_n_steps = max(1, int(self.tr_configs.get("log_every_n_steps", 10))) self.max_epoch = int(self.tr_configs.get("max_epoch", 5000)) self.val_every_n = int(self.tr_configs.get("val_every_n", 1)) self.early_stopping_patience = int(self.tr_configs.get("early_stopping_patience", 0)) self.optimizer = self.tr_configs.get("optimizer", "SGD") self.initial_lr = float(self.tr_configs.get("initial_lr", 0.01)) self.weight_decay = float(self.tr_configs.get("weight_decay", 0.00003)) - - ema_cfg = deepcopy(getattr(self, "ema_config", {}) or {}) - self.ema_enabled = bool(ema_cfg.get("enabled", False)) - self.ema_decay = float(ema_cfg.get("decay", 0.999)) - self.ema_start_step = int(ema_cfg.get("start_step", 0)) - self.ema_update_every_steps = max(1, int(ema_cfg.get("update_every_steps", 1))) - self.ema_validate = bool(ema_cfg.get("validate", self.ema_enabled)) - self.ema_save_in_checkpoint = bool( - ema_cfg.get("save_in_checkpoint", self.ema_enabled) - ) - self.ema_config = { - "enabled": self.ema_enabled, - "decay": self.ema_decay, - "start_step": self.ema_start_step, - "update_every_steps": self.ema_update_every_steps, - "validate": self.ema_validate, - "save_in_checkpoint": self.ema_save_in_checkpoint, - } + self.numa_pin = str(self.tr_configs.get("numa_pin", "auto")).lower() + self.debug_visualization_every_n = int(self.tr_configs.get("debug_visualization_every_n", 0)) + 
self.validation_preview_pool_size = int(self.tr_configs.get("validation_preview_pool_size", 32)) + self.log_validation_preview = bool(self.tr_configs.get("log_validation_preview", True)) ### Dataset config ### self.min_labeled_ratio = float(self.dataset_config.get("min_labeled_ratio", 0.10)) @@ -1007,7 +981,6 @@ def convert_to_dict(self): "model_config": model_config, "dataset_config": dataset_config, "inference_config": inference_config, - "ema": deepcopy(getattr(self, "ema_config", {})), } return combined_config diff --git a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py index a3053c86e..221426af3 100644 --- a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py +++ b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py @@ -12,7 +12,9 @@ from __future__ import annotations import json +import hashlib import logging +from collections import OrderedDict from dataclasses import dataclass, field from pathlib import Path from typing import Dict, List, Optional, Tuple @@ -28,10 +30,18 @@ from ..training.normalization import get_normalization from ..augmentation.pipelines import create_training_transforms from ..augmentation.transforms.utils.perf import collect_augmentation_names +from ..augmentation.transforms.utils.skeleton_transform import MedialSurfaceTransform + +try: + import edt as fast_edt +except Exception: + fast_edt = None logger = logging.getLogger(__name__) +DERIVED_CACHE_VERSION = "v1" + @dataclass class VolumeInfo: @@ -113,6 +123,10 @@ def __init__( name: self._resolve_ignore_label(self.targets.get(name) or {}) for name in self.target_names } + self._binary_fast_path_targets = { + name for name in self.target_names + if str(name).lower() == "surface" + } # Determine 2D vs 3D self.is_2d = len(self.patch_size) == 2 @@ -139,6 +153,17 @@ def __init__( # Caching self.cache_enabled = getattr(mgr, 'cache_valid_patches', True) self.cache_dir = self.data_path / '.patches_cache' + derived_cache_root = 
getattr(mgr, 'derived_patch_cache_dir', None) + if derived_cache_root: + self.derived_cache_dir = Path(derived_cache_root) + elif Path('/ephemeral').exists(): + self.derived_cache_dir = Path('/ephemeral/vesuvius_patch_cache') + else: + self.derived_cache_dir = self.data_path / '.derived_patch_cache' + self._derived_cache_disk_enabled = not self.is_training + self._derived_cache_limit = 128 if self.is_training else 256 + self._derived_cache: OrderedDict[str, tuple[tuple[tuple[int, int], ...], np.ndarray]] = OrderedDict() + self._precomputed_skeleton_transform = None # Initialize storage self._volumes: List[VolumeInfo] = [] @@ -366,38 +391,69 @@ def _ignore_mask(values: np.ndarray, ignore_label: Optional[float]) -> np.ndarra except TypeError: return np.zeros(values.shape, dtype=bool) + @staticmethod + def _bbox_slices(mask: np.ndarray, margin: int = 0) -> Optional[Tuple[slice, ...]]: + if not np.any(mask): + return None + coords = np.where(mask) + slices = [] + for axis, axis_coords in enumerate(coords): + start = max(int(axis_coords.min()) - margin, 0) + stop = min(int(axis_coords.max()) + margin + 1, mask.shape[axis]) + slices.append(slice(start, stop)) + return tuple(slices) + @classmethod def _dilate_label_region( cls, values: np.ndarray, distance: float, ignore_label: Optional[float], - ) -> np.ndarray: + *, + binary_fast_path: bool = False, + ) -> tuple[np.ndarray, Optional[Tuple[slice, ...]]]: if distance <= 0: - return values + return values, None arr = np.asarray(values) ignore_mask = cls._ignore_mask(arr, ignore_label) source_mask = (arr != 0) & ~ignore_mask if not np.any(source_mask): - return arr + return arr, None - fill_mask = (arr == 0) - if not np.any(fill_mask): - return arr + roi_slices = cls._bbox_slices(source_mask, margin=int(np.ceil(distance))) + if roi_slices is None: + return arr, None - distances, nearest_indices = distance_transform_edt( - ~source_mask, - return_indices=True, - ) - fill_mask &= distances <= float(distance) + roi = 
arr[roi_slices] + roi_ignore_mask = cls._ignore_mask(roi, ignore_label) + roi_source_mask = (roi != 0) & ~roi_ignore_mask + fill_mask = roi == 0 if not np.any(fill_mask): - return arr + return arr, roi_slices - result = arr.copy() - nearest_values = arr[tuple(nearest_indices[axis][fill_mask] for axis in range(arr.ndim))] - result[fill_mask] = nearest_values - return result + result = np.array(arr, copy=True) + roi_result = np.array(roi, copy=True) + + if binary_fast_path and fast_edt is not None: + distances = fast_edt.edt((~roi_source_mask).astype(np.uint8), parallel=1) + fill_mask &= distances <= float(distance) + if not np.any(fill_mask): + return arr, roi_slices + roi_result[fill_mask] = 1 + else: + distances, nearest_indices = distance_transform_edt( + ~roi_source_mask, + return_indices=True, + ) + fill_mask &= distances <= float(distance) + if not np.any(fill_mask): + return arr, roi_slices + nearest_values = roi[tuple(nearest_indices[axis][fill_mask] for axis in range(roi.ndim))] + roi_result[fill_mask] = nearest_values + + result[roi_slices] = roi_result + return result, roi_slices @classmethod def _dilate_label_patch( @@ -406,21 +462,31 @@ def _dilate_label_patch( distance: Optional[float], ignore_label: Optional[float], original_shape: Tuple[int, ...], - ) -> np.ndarray: + *, + binary_fast_path: bool = False, + ) -> tuple[np.ndarray, Optional[Tuple[slice, ...]]]: if distance in (None, 0): - return values + return values, None valid_slices = tuple(slice(0, max(0, min(int(o), int(s)))) for o, s in zip(original_shape, values.shape)) if any(slc.stop == 0 for slc in valid_slices): - return values + return values, None result = np.array(values, copy=True) - result[valid_slices] = cls._dilate_label_region( + dilated_region, roi_slices = cls._dilate_label_region( result[valid_slices], float(distance), ignore_label, + binary_fast_path=binary_fast_path, ) - return result + result[valid_slices] = dilated_region + if roi_slices is None: + return result, None + 
absolute_roi = tuple( + slice(valid_slices[axis].start + roi_slice.start, valid_slices[axis].start + roi_slice.stop) + for axis, roi_slice in enumerate(roi_slices) + ) + return result, absolute_roi # ------------------------------------------------------------------------- # Normalization @@ -458,6 +524,97 @@ def _build_target_volumes_for_intensity(self) -> Dict: }) return {first_target: volumes_list} + @staticmethod + def _roi_tuple_from_slices(roi_slices: Tuple[slice, ...]) -> tuple[tuple[int, int], ...]: + return tuple((int(slc.start), int(slc.stop)) for slc in roi_slices) + + @staticmethod + def _roi_slices_from_tuple(roi_tuple: tuple[tuple[int, int], ...]) -> Tuple[slice, ...]: + return tuple(slice(start, stop) for start, stop in roi_tuple) + + def _derived_cache_key( + self, + *, + derivation: str, + patch: PatchInfo, + target_name: str, + extra: Dict[str, object], + ) -> str: + vol = self._volumes[patch.volume_index] + payload = { + "version": DERIVED_CACHE_VERSION, + "derivation": derivation, + "volume_name": patch.volume_name, + "position": list(patch.position), + "patch_size": list(patch.patch_size), + "position_scale_factor": int(getattr(patch, "position_scale_factor", 1) or 1), + "scale": getattr(vol, "scale", self.ome_zarr_resolution), + "target_name": target_name, + "extra": {key: repr(value) for key, value in extra.items()}, + } + return hashlib.md5(json.dumps(payload, sort_keys=True).encode()).hexdigest() + + def _derived_cache_path(self, derivation: str, cache_key: str) -> Path: + return self.derived_cache_dir / derivation / cache_key[:2] / f"{cache_key}.npz" + + def _derived_cache_get(self, cache_key: str): + entry = self._derived_cache.get(cache_key) + if entry is not None: + self._derived_cache.move_to_end(cache_key) + return entry + return None + + def _derived_cache_put(self, cache_key: str, roi_tuple, roi_values: np.ndarray) -> None: + self._derived_cache[cache_key] = (roi_tuple, np.ascontiguousarray(roi_values)) + 
self._derived_cache.move_to_end(cache_key) + while len(self._derived_cache) > self._derived_cache_limit: + self._derived_cache.popitem(last=False) + + def _load_cached_roi(self, derivation: str, cache_key: str): + cached = self._derived_cache_get(cache_key) + if cached is not None: + return cached + if not self._derived_cache_disk_enabled: + return None + + cache_path = self._derived_cache_path(derivation, cache_key) + if not cache_path.exists(): + return None + try: + with np.load(cache_path, allow_pickle=False) as payload: + roi_tuple = tuple(tuple(int(v) for v in pair) for pair in payload["roi"].tolist()) + roi_values = np.ascontiguousarray(payload["values"]) + except Exception: + return None + + self._derived_cache_put(cache_key, roi_tuple, roi_values) + return roi_tuple, roi_values + + def _store_cached_roi(self, derivation: str, cache_key: str, roi_slices, values: np.ndarray) -> None: + if roi_slices is None: + return + + roi_tuple = self._roi_tuple_from_slices(roi_slices) + roi_values = np.ascontiguousarray(values[roi_slices]) + self._derived_cache_put(cache_key, roi_tuple, roi_values) + + if not self._derived_cache_disk_enabled: + return + + cache_path = self._derived_cache_path(derivation, cache_key) + cache_path.parent.mkdir(parents=True, exist_ok=True) + if cache_path.exists(): + return + np.savez(cache_path, roi=np.asarray(roi_tuple, dtype=np.int64), values=roi_values) + + def _apply_cached_roi(self, base_values: np.ndarray, cached_entry): + if cached_entry is None: + return base_values + roi_tuple, roi_values = cached_entry + result = np.array(base_values, copy=True) + result[self._roi_slices_from_tuple(roi_tuple)] = roi_values + return result + # ------------------------------------------------------------------------- # Patch Index Building # ------------------------------------------------------------------------- @@ -831,6 +988,7 @@ def _get_skeleton_targets(self) -> tuple[List[str], Dict[str, int]]: def _initialize_transforms(self) -> None: 
"""Initialize augmentation transforms.""" skeleton_targets, skeleton_ignore_values = self._get_skeleton_targets() + self._precomputed_skeleton_transform = None if self.is_training: no_spatial = getattr(self.mgr, 'no_spatial_augmentation', False) @@ -846,14 +1004,17 @@ def _initialize_transforms(self) -> None: if self._profile_augmentations: self._augmentation_names = collect_augmentation_names(self.transforms) elif skeleton_targets: - # Validation: only apply skeleton generation (no augmentation) - from vesuvius.models.augmentation.pipelines.training_transforms import create_validation_transforms - self.transforms = create_validation_transforms( - skeleton_targets=skeleton_targets, - skeleton_ignore_values=skeleton_ignore_values if skeleton_ignore_values else None, + self._precomputed_skeleton_transform = MedialSurfaceTransform( + do_tube=False, + target_keys=skeleton_targets, + ignore_values=skeleton_ignore_values or None, + cache_dir=str(self.derived_cache_dir / "medial_surface"), + enable_disk_cache=self._derived_cache_disk_enabled, + memory_cache_size=512, ) - if self._profile_augmentations: - self._augmentation_names = collect_augmentation_names(self.transforms) + self.transforms = None + else: + self.transforms = None # ------------------------------------------------------------------------- # Dataset Interface @@ -920,25 +1081,53 @@ def load_array( 'patch_info': { 'volume_name': vol.volume_id, 'position': patch.position, + 'patch_size': patch.patch_size, + 'scale': getattr(vol, "scale", self.ome_zarr_resolution), }, } - is_unlabeled = True + can_assume_labeled = ( + not self.allow_unlabeled_data + and not patch.is_unlabeled_fg + and all(vol.label_arrays.get(target_name) is not None for target_name in self.target_names) + ) + is_unlabeled = False if can_assume_labeled else True for target_name in self.target_names: label_arr = vol.label_arrays.get(target_name) label_data, label_shape = load_array(label_arr, return_original_shape=True) - label_data = 
self._dilate_label_patch( - label_data, - vol.dilate, - self.target_ignore_labels.get(target_name), - label_shape, - ) - if label_arr is not None and np.count_nonzero(label_data) > 0: + dilate_distance = getattr(vol, "dilate", None) + if dilate_distance not in (None, 0): + cache_key = self._derived_cache_key( + derivation="dilated_label", + patch=patch, + target_name=target_name, + extra={ + "dilate": dilate_distance, + "ignore_label": self.target_ignore_labels.get(target_name), + "shape": label_shape, + }, + ) + cached = self._load_cached_roi("dilated_label", cache_key) + if cached is not None: + label_data = self._apply_cached_roi(label_data, cached) + else: + label_data, roi_slices = self._dilate_label_patch( + label_data, + dilate_distance, + self.target_ignore_labels.get(target_name), + label_shape, + binary_fast_path=(target_name in self._binary_fast_path_targets), + ) + self._store_cached_roi("dilated_label", cache_key, roi_slices, label_data) + if not can_assume_labeled and label_arr is not None and np.count_nonzero(label_data) > 0: is_unlabeled = False result[target_name] = torch.from_numpy(label_data[np.newaxis, ...]) result['is_unlabeled'] = is_unlabeled or patch.is_unlabeled_fg + if self._precomputed_skeleton_transform is not None and not self.is_training: + result = self._precomputed_skeleton_transform.apply(result) + if self.transforms is not None: if self._profile_augmentations and self._augmentation_names: result['_aug_perf'] = {name: 0.0 for name in self._augmentation_names} @@ -947,7 +1136,6 @@ def load_array( return result def _scale_patch_position_for_array(self, patch: PatchInfo) -> Tuple[int, ...]: - """Convert cached full-resolution coordinates into the active array scale.""" factor = int(getattr(patch, "position_scale_factor", 1) or 1) if factor <= 1: return patch.position diff --git a/vesuvius/src/vesuvius/models/training/cli.py b/vesuvius/src/vesuvius/models/training/cli.py index db7a237bc..3175b6d18 100644 --- 
a/vesuvius/src/vesuvius/models/training/cli.py +++ b/vesuvius/src/vesuvius/models/training/cli.py @@ -35,11 +35,6 @@ def main(argv=None): description="Train Vesuvius neural networks for ink detection and segmentation", formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.set_defaults( - ema_enabled=None, - ema_validate=None, - ema_save_in_checkpoint=None, - ) grp_required = parser.add_argument_group("Required") grp_paths = parser.add_argument_group("Paths & Format") @@ -109,20 +104,40 @@ def main(argv=None): help="Type of pooling in encoder ('conv' = strided conv)") # Training Control - grp_train.add_argument("--max-epoch", type=int, default=None, + grp_train.add_argument("--max-epoch", type=int, default=1000, help="Maximum number of epochs") - grp_train.add_argument("--max-steps-per-epoch", type=int, default=None, + grp_train.add_argument("--max-steps-per-epoch", type=int, default=250, help="Max training steps per epoch (use all data if unset)") - grp_train.add_argument("--max-val-steps-per-epoch", type=int, default=None, + grp_train.add_argument("--max-val-steps-per-epoch", type=int, default=50, help="Max validation steps per epoch (use all data if unset)") grp_train.add_argument("--full-epoch", action="store_true", help="Iterate over entire train/val set per epoch (overrides max-steps)") - grp_train.add_argument("--early-stopping-patience", type=int, default=None, + grp_train.add_argument("--early-stopping-patience", type=int, default=0, help="Epochs to wait for val loss improvement (0 disables)") grp_train.add_argument("--ddp", action="store_true", help="Enable DistributedDataParallel (use with torchrun)") - grp_train.add_argument("--val-every-n", dest="val_every_n", type=int, default=None, + grp_train.add_argument("--val-every-n", dest="val_every_n", type=int, default=1, help="Perform validation every N epochs (1=every epoch)") + grp_train.add_argument("--num-dataloader-workers", type=int, + help="Training dataloader worker count") + 
grp_train.add_argument("--val-num-dataloader-workers", type=int, + help="Validation dataloader worker count") + grp_train.add_argument("--persistent-workers", dest="persistent_workers", action="store_true", + help="Enable persistent dataloader workers") + grp_train.add_argument("--no-persistent-workers", dest="persistent_workers", action="store_false", + help="Disable persistent dataloader workers") + grp_train.add_argument("--prefetch-factor", type=int, + help="Training dataloader prefetch factor") + grp_train.add_argument("--val-prefetch-factor", type=int, + help="Validation dataloader prefetch factor") + grp_train.add_argument("--debug-visualization-every-n", type=int, + help="Save debug GIF/PNG media every N validation epochs (0 disables media saves)") + grp_train.add_argument("--validation-preview-pool-size", type=int, + help="Number of globally ordered validation patches to rotate through for W&B previews") + grp_train.add_argument("--log-every-n-steps", type=int, + help="Log training metrics to W&B every N optimizer steps") + grp_train.add_argument("--numa-pin", type=str, choices=["auto", "off"], + help="NUMA affinity mode for CUDA DDP workers") grp_train.add_argument("--gpus", type=str, default=None, help="Comma-separated GPU device IDs to use, e.g. '0,1,3'. 
With DDP, length must equal WORLD_SIZE") grp_train.add_argument("--nproc-per-node", type=int, default=None, @@ -131,6 +146,7 @@ def main(argv=None): help="Master address for DDP when spawning without torchrun") grp_train.add_argument("--master-port", type=int, default=None, help="Master port for DDP when spawning without torchrun (default: auto)") + grp_train.set_defaults(persistent_workers=None) # Optimization grp_optim.add_argument("--optimizer", type=str, @@ -143,24 +159,6 @@ def main(argv=None): help="Autocast dtype when AMP is enabled (float16 uses GradScaler; bfloat16 skips scaling)") grp_optim.add_argument("--no-amp", action="store_true", help="Disable Automatic Mixed Precision (AMP)") - grp_optim.add_argument("--ema", dest="ema_enabled", action="store_true", - help="Enable EMA weights tracking for the base trainer") - grp_optim.add_argument("--no-ema", dest="ema_enabled", action="store_false", - help="Disable EMA weights tracking") - grp_optim.add_argument("--ema-decay", type=float, - help="EMA decay factor") - grp_optim.add_argument("--ema-start-step", type=int, - help="Optimizer step at which EMA updates begin") - grp_optim.add_argument("--ema-update-every-steps", type=int, - help="Update EMA weights every N optimizer steps") - grp_optim.add_argument("--ema-validate", dest="ema_validate", action="store_true", - help="Use the EMA model for validation when EMA is enabled") - grp_optim.add_argument("--no-ema-validate", dest="ema_validate", action="store_false", - help="Validate with the student model even when EMA is enabled") - grp_optim.add_argument("--ema-save-in-checkpoint", dest="ema_save_in_checkpoint", action="store_true", - help="Save EMA weights in checkpoints") - grp_optim.add_argument("--no-ema-save-in-checkpoint", dest="ema_save_in_checkpoint", action="store_false", - help="Do not save EMA weights in checkpoints") # Scheduler grp_sched.add_argument("--scheduler", type=str, diff --git a/vesuvius/src/vesuvius/models/training/save_checkpoint.py 
b/vesuvius/src/vesuvius/models/training/save_checkpoint.py index 1293fceac..6bab679c3 100644 --- a/vesuvius/src/vesuvius/models/training/save_checkpoint.py +++ b/vesuvius/src/vesuvius/models/training/save_checkpoint.py @@ -113,7 +113,9 @@ def manage_checkpoint_history(checkpoint_history, best_checkpoints, epoch, checkpoint_path = Path(checkpoint_path) checkpoint_dir = Path(checkpoint_dir) - checkpoint_history.append((epoch, str(checkpoint_path))) + checkpoint_record = (epoch, str(checkpoint_path)) + if checkpoint_record not in checkpoint_history: + checkpoint_history.append(checkpoint_record) if epoch in [e for e, _ in checkpoint_history]: ckpt_path = next(p for e, p in checkpoint_history if e == epoch) @@ -242,9 +244,8 @@ def cleanup_old_configs(model_ckpt_dir, model_name, keep_latest=1): def save_final_checkpoint(model, optimizer, scheduler, max_epoch, - model_ckpt_dir, model_name, - model_config=None, train_dataset=None, - additional_data=None): + model_ckpt_dir, model_name, + model_config=None, train_dataset=None): """ Save the final model checkpoint at the end of training. 
@@ -281,8 +282,7 @@ def save_final_checkpoint(model, optimizer, scheduler, max_epoch, epoch=max_epoch - 1, checkpoint_path=final_model_path, model_config=model_config, - train_dataset=train_dataset, - additional_data=additional_data, + train_dataset=train_dataset ) print(f"Final model saved to {final_model_path}") diff --git a/vesuvius/src/vesuvius/models/training/train.py b/vesuvius/src/vesuvius/models/training/train.py index f1251f9ef..826071cef 100644 --- a/vesuvius/src/vesuvius/models/training/train.py +++ b/vesuvius/src/vesuvius/models/training/train.py @@ -9,7 +9,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP import torch.nn.functional as F from vesuvius.models.training.lr_schedulers import get_scheduler -from torch.utils.data import DataLoader, SubsetRandomSampler, Subset, WeightedRandomSampler +from torch.utils.data import DataLoader, Sampler, SubsetRandomSampler, Subset, WeightedRandomSampler from torch.utils.data.distributed import DistributedSampler from vesuvius.models.utils import InitWeights_He from vesuvius.models.datasets import ZarrDataset @@ -37,6 +37,27 @@ from vesuvius.models.evaluation.voi import VOIMetric from contextlib import nullcontext from collections import deque, defaultdict +import gc + +from vesuvius.models.training.numa import apply_numa_affinity + + +class DistributedValidationSampler(Sampler[int]): + def __init__(self, dataset, num_replicas: int, rank: int, max_samples: int | None = None): + self.dataset = dataset + self.num_replicas = max(1, int(num_replicas)) + self.rank = int(rank) + total_size = len(dataset) + if max_samples is not None and max_samples > 0: + total_size = min(total_size, int(max_samples)) + self.global_indices = list(range(total_size)) + self.indices = self.global_indices[self.rank::self.num_replicas] + + def __iter__(self): + return iter(self.indices) + + def __len__(self): + return len(self.indices) @@ -117,6 +138,11 @@ def __init__(self, f"In DDP, number of GPUs in --gpus 
({len(self.gpu_ids)}) must equal WORLD_SIZE ({self.world_size})." ) + self.numa_affinity = apply_numa_affinity( + getattr(self.mgr, 'numa_pin', 'auto'), + self.assigned_gpu_id, + ) + # Friendly prints if self.is_distributed and (not self.rank or self.rank == 0): if torch.cuda.is_available(): @@ -130,6 +156,12 @@ def __init__(self, else: print(f"Using GPU {self.gpu_ids[0]}") + if self.numa_affinity is not None: + print( + f"Rank {self.rank} pinned to {self.numa_affinity['cpu_count']} CPUs " + f"for GPU {self.numa_affinity['gpu_id']}" + ) + # Default AMP dtype; resolved during training initialization self.amp_dtype = torch.float16 self.amp_dtype_str = 'float16' @@ -137,29 +169,6 @@ def __init__(self, self._augmentation_names = None self._epoch_aug_time = None self._epoch_aug_count = None - ema_cfg = getattr(self.mgr, 'ema_config', {}) or {} - self.ema_enabled = bool(getattr(self.mgr, 'ema_enabled', ema_cfg.get('enabled', False))) - self.ema_decay = float(getattr(self.mgr, 'ema_decay', ema_cfg.get('decay', 0.999))) - self.ema_start_step = int(getattr(self.mgr, 'ema_start_step', ema_cfg.get('start_step', 0))) - self.ema_update_every_steps = max( - 1, - int(getattr(self.mgr, 'ema_update_every_steps', ema_cfg.get('update_every_steps', 1))) - ) - self.ema_validate = bool( - getattr(self.mgr, 'ema_validate', ema_cfg.get('validate', self.ema_enabled)) - ) - self.ema_save_in_checkpoint = bool( - getattr( - self.mgr, - 'ema_save_in_checkpoint', - ema_cfg.get('save_in_checkpoint', self.ema_enabled), - ) - ) - self.ema_model = None - self._ema_optimizer_step = 0 - self._checkpoint_ema_state = None - self._checkpoint_ema_optimizer_step = None - self._printed_ema_validation_mode = False # --- build model --- # def _build_model(self): @@ -183,99 +192,8 @@ def _get_additional_checkpoint_data(self): Subclasses can override this to save extra state (e.g., EMA model). Returns a dict that will be merged into the checkpoint. 
""" - if self.ema_model is not None and self.ema_save_in_checkpoint: - return { - 'ema_model': self.ema_model.state_dict(), - 'ema_optimizer_step': int(self._ema_optimizer_step), - } return {} - def _unwrap_model(self, model): - if hasattr(model, 'module'): - model = model.module - if hasattr(model, '_orig_mod'): - try: - model = model._orig_mod - except Exception: - pass - return model - - def _wrap_model_for_distributed_training(self, model): - if not self.is_distributed: - return model - - ddp_kwargs = {"find_unused_parameters": True} - if self.device.type == 'cuda': - ddp_kwargs.update( - device_ids=[self.assigned_gpu_id], - output_device=self.assigned_gpu_id, - ) - return DDP(model, **ddp_kwargs) - - def _create_ema_model(self, model): - ema_model = deepcopy(self._unwrap_model(model)) - ema_model = ema_model.to(self.device) - ema_model.eval() - for parameter in ema_model.parameters(): - parameter.requires_grad_(False) - return ema_model - - def _initialize_ema_model(self, model): - if not self.ema_enabled: - self.ema_model = None - return None - - self.ema_model = self._create_ema_model(model) - if self._checkpoint_ema_state is not None: - try: - self.ema_model.load_state_dict(self._checkpoint_ema_state) - self._ema_optimizer_step = int(self._checkpoint_ema_optimizer_step or 0) - print( - "Restored EMA model from checkpoint " - f"(optimizer_step={self._ema_optimizer_step})" - ) - except Exception as exc: - print(f"Warning: Failed to restore EMA model from checkpoint: {exc}") - print("Using freshly initialized EMA model") - self._ema_optimizer_step = 0 - finally: - self._checkpoint_ema_state = None - self._checkpoint_ema_optimizer_step = None - else: - print( - "Created EMA model " - f"(decay={self.ema_decay}, start_step={self.ema_start_step}, " - f"update_every_steps={self.ema_update_every_steps})" - ) - return self.ema_model - - def _update_ema_model(self, model): - if self.ema_model is None: - return - - self._ema_optimizer_step += 1 - if 
self._ema_optimizer_step < self.ema_start_step: - return - if ((self._ema_optimizer_step - self.ema_start_step) % self.ema_update_every_steps) != 0: - return - - ema_state = self.ema_model.state_dict() - for name, model_value in self._unwrap_model(model).state_dict().items(): - ema_value = ema_state[name] - model_value = model_value.detach() - if torch.is_floating_point(ema_value): - ema_value.lerp_(model_value.to(dtype=ema_value.dtype), 1.0 - self.ema_decay) - else: - ema_value.copy_(model_value) - - def _get_validation_model(self, model): - if self.ema_model is not None and self.ema_validate: - if not self._printed_ema_validation_mode: - print("Validation will use the EMA model") - self._printed_ema_validation_mode = True - return self.ema_model - return model - # --- configure dataset --- # def _configure_dataset(self, is_training=True): dataset = self._build_dataset_for_mgr(self.mgr, is_training=is_training) @@ -789,10 +707,14 @@ def _initialize_evaluation_metrics(self): if target_ignore_value is not None: task_metrics.append(IOUDiceMetric(num_classes=num_classes, ignore_index=target_ignore_value)) - task_metrics.append(VOIMetric(ignore_index=target_ignore_value)) + # task_metrics.append(VOIMetric(ignore_index=target_ignore_value)) + # VOI is intentionally disabled for distributed online validation because + # it is too expensive and not needed for runtime monitoring. else: task_metrics.append(IOUDiceMetric(num_classes=num_classes)) - task_metrics.append(VOIMetric()) + # task_metrics.append(VOIMetric()) + # VOI is intentionally disabled for distributed online validation because + # it is too expensive and not needed for runtime monitoring. 
# task_metrics.append(SkeletonBranchPointsMetric(num_classes=num_classes)) # task_metrics.append(HausdorffDistanceMetric(num_classes=num_classes)) metrics[task_name] = task_metrics @@ -836,6 +758,87 @@ def _autocast_context(self, use_amp: bool): return torch.amp.autocast(self.device.type, dtype=self.amp_dtype) return torch.amp.autocast(self.device.type) + def _should_save_debug_media(self, epoch: int) -> bool: + every_n = int(getattr(self.mgr, 'debug_visualization_every_n', 0)) + return every_n > 0 and ((epoch + 1) % every_n == 0) + + def _validation_preview_target(self, epoch: int): + if not getattr(self.mgr, 'log_validation_preview', True): + return None + preview_pool_size = max(0, int(getattr(self.mgr, 'validation_preview_pool_size', 32))) + global_indices = tuple(getattr(self, '_val_global_indices', ())) + if preview_pool_size == 0 or not global_indices: + return None + preview_pool = global_indices[:min(preview_pool_size, len(global_indices))] + if not preview_pool: + return None + return int(preview_pool[epoch % len(preview_pool)]) + + def _reduce_validation_task_losses(self, local_sums, local_counts): + task_names = list(self.mgr.targets.keys()) + reduce_device = self.device if self.device.type == 'cuda' else torch.device('cpu') + sum_tensor = torch.tensor( + [float(local_sums.get(task_name, 0.0)) for task_name in task_names], + device=reduce_device, + dtype=torch.float64, + ) + count_tensor = torch.tensor( + [float(local_counts.get(task_name, 0.0)) for task_name in task_names], + device=reduce_device, + dtype=torch.float64, + ) + if self.is_distributed: + dist.all_reduce(sum_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) + + reduced = {} + for idx, task_name in enumerate(task_names): + count = float(count_tensor[idx].item()) + reduced[task_name] = float(sum_tensor[idx].item() / count) if count > 0 else 0.0 + return reduced + + def _merge_metric_payloads(self, gathered_payloads): + merged = {} + for payload in 
gathered_payloads: + if not payload: + continue + for task_name, metrics in payload.items(): + task_accumulator = merged.setdefault(task_name, {}) + for metric_name, metric_payload in metrics.items(): + count = int(metric_payload.get("count", 0)) + if count <= 0: + continue + metric_accumulator = task_accumulator.setdefault(metric_name, {}) + for key, value in metric_payload.get("values", {}).items(): + slot = metric_accumulator.setdefault( + key, + {"weighted_sum": 0.0, "count": 0}, + ) + slot["weighted_sum"] += float(value) * count + slot["count"] += count + + reduced = {} + for task_name, metrics in merged.items(): + reduced[task_name] = {} + for metric_name, metric_values in metrics.items(): + reduced[task_name][metric_name] = {} + for key, accumulator in metric_values.items(): + count = max(1, accumulator["count"]) + reduced[task_name][metric_name][key] = accumulator["weighted_sum"] / count + return reduced + + def _gather_preview_image(self, preview_image): + if not self.is_distributed: + return preview_image + gathered = [None for _ in range(self.world_size)] + dist.all_gather_object(gathered, preview_image) + if self.rank != 0: + return None + for image in gathered: + if image is not None: + return image + return None + # --- dataloaders --- # def _configure_dataloaders(self, train_dataset, val_dataset=None): @@ -923,13 +926,15 @@ def _configure_dataloaders(self, train_dataset, val_dataset=None): val_base = val_dataset if val_dataset is not None else train_dataset train_subset = Subset(train_base, train_indices) val_subset = Subset(val_base, val_indices) + if getattr(self.mgr, 'max_val_steps_per_epoch', None) is not None and self.mgr.max_val_steps_per_epoch > 0: + global_val_budget = min(len(val_subset), int(self.mgr.max_val_steps_per_epoch)) + else: + global_val_budget = len(val_subset) if self.is_distributed: train_sampler = DistributedSampler( train_subset, num_replicas=self.world_size, rank=self.rank, shuffle=True, drop_last=False ) - # For validation 
we only run on rank 0; sampler unused there, but keep a sequential sampler for completeness - val_sampler = None else: if hasattr(train_base, 'patch_weights') and isinstance(getattr(train_base, 'patch_weights', None), list): if train_base.patch_weights and len(train_base.patch_weights) >= len(train_base): @@ -968,12 +973,25 @@ def _configure_dataloaders(self, train_dataset, val_dataset=None): train_sampler = SubsetRandomSampler(list(range(len(train_subset)))) else: train_sampler = SubsetRandomSampler(list(range(len(train_subset)))) - val_sampler = SubsetRandomSampler(list(range(len(val_subset)))) + + val_sampler = DistributedValidationSampler( + val_subset, + num_replicas=self.world_size if self.is_distributed else 1, + rank=self.rank if self.is_distributed else 0, + max_samples=global_val_budget, + ) + self._val_global_indices = tuple(val_sampler.global_indices) pin_mem = True if self.device.type == 'cuda' else False - dl_kwargs = {} + train_dl_kwargs = {} if self.mgr.train_num_dataloader_workers and self.mgr.train_num_dataloader_workers > 0: - dl_kwargs['prefetch_factor'] = 2 + train_dl_kwargs['prefetch_factor'] = max(1, int(getattr(self.mgr, 'train_prefetch_factor', 2))) + train_dl_kwargs['persistent_workers'] = bool(getattr(self.mgr, 'persistent_workers', True)) + + val_dl_kwargs = {} + if self.mgr.val_num_dataloader_workers and self.mgr.val_num_dataloader_workers > 0: + val_dl_kwargs['prefetch_factor'] = max(1, int(getattr(self.mgr, 'val_prefetch_factor', 2))) + val_dl_kwargs['persistent_workers'] = bool(getattr(self.mgr, 'val_persistent_workers', True)) train_dataloader = DataLoader( train_subset, @@ -982,18 +1000,17 @@ def _configure_dataloaders(self, train_dataset, val_dataset=None): shuffle=False, pin_memory=pin_mem, num_workers=self.mgr.train_num_dataloader_workers, - **dl_kwargs + **train_dl_kwargs ) - # Validation dataloader will only be iterated on rank 0 in DDP val_dataloader = DataLoader( val_subset, batch_size=1, sampler=val_sampler, shuffle=False, 
pin_memory=pin_mem, - num_workers=self.mgr.train_num_dataloader_workers, - **dl_kwargs + num_workers=self.mgr.val_num_dataloader_workers, + **val_dl_kwargs ) return train_dataloader, val_dataloader, train_indices, val_indices @@ -1091,7 +1108,11 @@ def _initialize_training(self): val_dataset) # Wrap model with DDP if distributed - model = self._wrap_model_for_distributed_training(model) + if self.is_distributed: + if self.device.type == 'cuda': + model = DDP(model, device_ids=[self.assigned_gpu_id], output_device=self.assigned_gpu_id, find_unused_parameters=True) + else: + model = DDP(model, find_unused_parameters=True) os.makedirs(self.mgr.ckpt_out_base, exist_ok=True) model_ckpt_dir = os.path.join(self.mgr.ckpt_out_base, self.mgr.model_name) os.makedirs(model_ckpt_dir, exist_ok=True) @@ -1099,7 +1120,7 @@ def _initialize_training(self): now = datetime.now() date_str = now.strftime('%m%d%y') time_str = now.strftime('%H%M') - ckpt_dir = os.path.join('checkpoints', f"{self.mgr.model_name}_{date_str}{time_str}") + ckpt_dir = os.path.join(self.mgr.ckpt_out_base, f"{self.mgr.model_name}_{date_str}{time_str}") os.makedirs(ckpt_dir, exist_ok=True) loss_overrides = self._capture_loss_overrides() @@ -1124,7 +1145,6 @@ def _initialize_training(self): ckpt = torch.load(self.mgr.checkpoint_path, map_location=self.device, weights_only=False) if isinstance(ckpt, dict) and 'ema_model' in ckpt: self._checkpoint_ema_state = ckpt['ema_model'] - self._checkpoint_ema_optimizer_step = int(ckpt.get('ema_optimizer_step', 0)) print("Found EMA model state in checkpoint") del ckpt except Exception: @@ -1143,7 +1163,6 @@ def _initialize_training(self): self._ds_scales = None self._ds_weights = None loss_fns = self._build_loss() - self._initialize_ema_model(model) if self.device.type == 'cuda': try: @@ -1182,6 +1201,9 @@ def _initialize_wandb(self, train_dataset, val_dataset, train_indices, val_indic train_val_splits = save_train_val_filenames(self, train_dataset, val_dataset, train_indices, 
val_indices) save_dir = ckpt_dir if ckpt_dir else os.getcwd() + wandb_dir = os.path.join(str(self.mgr.ckpt_out_base), "wandb") + os.makedirs(wandb_dir, exist_ok=True) + os.environ["WANDB_DIR"] = wandb_dir timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") splits_filename = f"train_val_splits_{self.mgr.model_name}_{timestamp}.json" @@ -1202,6 +1224,7 @@ def _initialize_wandb(self, train_dataset, val_dataset, train_indices, val_indic "project": self.mgr.wandb_project, "group": self.mgr.model_name, "config": mgr_config, + "dir": wandb_dir, } wandb_resume = getattr(self.mgr, "wandb_resume", None) if wandb_resume: @@ -1426,9 +1449,9 @@ def _on_epoch_end(self, epoch, model, optimizer, scheduler, train_dataset, additional_data=self._get_additional_checkpoint_data() ) - checkpoint_history.append((epoch, ckpt_path)) - del checkpoint_data + if self.device.type == 'cuda': + torch.cuda.empty_cache() # Manage checkpoint history checkpoint_history, best_checkpoints = manage_checkpoint_history( @@ -1653,10 +1676,7 @@ def train(self): train_sample_outputs[t_name] = p_val[b_idx: b_idx + 1] if optimizer_stepped and is_per_iteration_scheduler: - self._update_ema_model(model) scheduler.step() - elif optimizer_stepped: - self._update_ema_model(model) if pbar is not None: loss_str = " | ".join([f"{t}: {np.mean(epoch_losses[t][-100:]):.4f}" @@ -1666,7 +1686,16 @@ def train(self): current_lr = optimizer.param_groups[0]['lr'] - if self.mgr.wandb_project and (not self.is_distributed or self.rank == 0): + should_log_train_metrics = ( + self.mgr.wandb_project + and (not self.is_distributed or self.rank == 0) + and ( + global_step == 1 + or (global_step % max(1, int(getattr(self.mgr, 'log_every_n_steps', 10)))) == 0 + or i == (num_iters - 1) + ) + ) + if should_log_train_metrics: metrics = self._prepare_metrics_for_logging( epoch=epoch, step=global_step, @@ -1686,6 +1715,10 @@ def train(self): if not is_per_iteration_scheduler and not step_scheduler_at_epoch_begin: scheduler.step() + 
gc.collect() + if self.device.type == 'cuda': + torch.cuda.empty_cache() + # Report the effective learning rate(s) after all scheduler updates for this epoch. current_lrs = [group['lr'] for group in optimizer.param_groups] @@ -1708,30 +1741,30 @@ def train(self): # ---- validation ----- # val_every_n = int(getattr(self.mgr, 'val_every_n', 1)) do_validate = ((epoch + 1) % max(1, val_every_n) == 0) - if do_validate and (not self.is_distributed or self.rank == 0): - validation_model = self._get_validation_model(model) + stop_training_early = False + if do_validate: # For MAE training, don't set to eval mode to keep patch dropping active if not hasattr(self, '_is_mae_training'): model.eval() - if validation_model is not model or not hasattr(self, '_is_mae_training'): - validation_model.eval() with torch.no_grad(): - val_losses = {t_name: [] for t_name in self.mgr.targets} - debug_preview_image = None - + local_val_sums = {t_name: 0.0 for t_name in self.mgr.targets} + local_val_counts = {t_name: 0 for t_name in self.mgr.targets} + local_preview_image = None + local_metric_payload = {} + should_save_debug_media = self._should_save_debug_media(epoch) and (not self.is_distributed or self.rank == 0) + # Initialize evaluation metrics evaluation_metrics = self._initialize_evaluation_metrics() val_dataloader_iter = iter(val_dataloader) + num_val_iters = len(val_dataloader) + val_sampler = getattr(val_dataloader, 'sampler', None) + local_val_indices = list(getattr(val_sampler, 'indices', range(num_val_iters))) + preview_target = self._validation_preview_target(epoch) - if hasattr(self.mgr, 'max_val_steps_per_epoch') and self.mgr.max_val_steps_per_epoch is not None and self.mgr.max_val_steps_per_epoch > 0: - num_val_iters = min(len(val_indices), self.mgr.max_val_steps_per_epoch) - else: - num_val_iters = len(val_indices) - - val_pbar = tqdm(range(num_val_iters), desc=f'Validation {epoch + 1}') + val_pbar = tqdm(range(num_val_iters), desc=f'Validation {epoch + 1}') if (not 
self.is_distributed or self.rank == 0) else None - for i in val_pbar: + for i in range(num_val_iters): try: data_dict = next(val_dataloader_iter) except StopIteration: @@ -1739,7 +1772,7 @@ def train(self): data_dict = next(val_dataloader_iter) task_losses, inputs, targets_dict, outputs = self._validation_step( - model=validation_model, + model=model, data_dict=data_dict, loss_fns=loss_fns, use_amp=use_amp @@ -1747,9 +1780,11 @@ def train(self): for t_name, loss_value in task_losses.items(): # Ensure we have a slot for dynamically introduced tasks (e.g., 'mae') - if t_name not in val_losses: - val_losses[t_name] = [] - val_losses[t_name].append(loss_value) + if t_name not in local_val_sums: + local_val_sums[t_name] = 0.0 + local_val_counts[t_name] = 0 + local_val_sums[t_name] += float(loss_value) + local_val_counts[t_name] += 1 # Compute evaluation metrics for each task (handle deep supervision lists) for t_name in self.mgr.targets: @@ -1769,7 +1804,8 @@ def train(self): continue metric.update(pred=pred_val, gt=gt_val, mask=mask_tensor) - if i == 0: + current_val_index = local_val_indices[i] if i < len(local_val_indices) else None + if preview_target is not None and current_val_index == preview_target: # Find first non-zero sample for debug visualization, but save even if all zeros b_idx = 0 found_non_zero = False @@ -1846,66 +1882,95 @@ def train(self): train_outputs_dict=train_sample_outputs, skeleton_dict=skeleton_dict, train_skeleton_dict=train_skeleton_dict, + save_media=should_save_debug_media, unlabeled_input=unlabeled_input, unlabeled_pseudo_dict=unlabeled_pseudo, unlabeled_outputs_dict=unlabeled_pred ) - debug_gif_history.append((epoch, debug_img_path)) - - loss_str = " | ".join([f"{t}: {np.mean(val_losses[t]):.4f}" - for t in self.mgr.targets if len(val_losses[t]) > 0]) - val_pbar.set_postfix_str(loss_str) + local_preview_image = debug_preview_image + if should_save_debug_media: + debug_gif_history.append((epoch, debug_img_path)) + + if val_pbar is not 
None: + loss_str = " | ".join([ + f"{t}: {local_val_sums[t] / max(1, local_val_counts[t]):.4f}" + for t in self.mgr.targets if local_val_counts.get(t, 0) > 0 + ]) + val_pbar.set_postfix_str(loss_str) + val_pbar.update(1) del outputs, inputs, targets_dict - print(f"\n[Validation] Epoch {epoch + 1} summary:") - total_val_loss = 0.0 - for t_name in self.mgr.targets: - val_avg = np.mean(val_losses[t_name]) if val_losses[t_name] else 0 - print(f" Task '{t_name}': Avg validation loss = {val_avg:.4f}") - total_val_loss += val_avg + if val_pbar is not None: + val_pbar.close() + + reduced_val_losses = self._reduce_validation_task_losses(local_val_sums, local_val_counts) - avg_val_loss = total_val_loss / len(self.mgr.targets) if self.mgr.targets else 0 - val_loss_history[epoch] = avg_val_loss - - print("\n[Validation Metrics]") - metric_results = {} for t_name in self.mgr.targets: - if t_name in evaluation_metrics: - print(f" Task '{t_name}':") - for metric in evaluation_metrics[t_name]: - aggregated = metric.aggregate() - for metric_name, value in aggregated.items(): - full_metric_name = f"{t_name}_{metric_name}" - metric_results[full_metric_name] = value - display_name = f"{metric.name}_{metric_name}" - print(f" {display_name}: {value:.4f}") - - if self.mgr.wandb_project: - val_metrics = {"epoch": epoch, "step": global_step} - for t_name in self.mgr.targets: - if t_name in val_losses and len(val_losses[t_name]) > 0: - val_metrics[f"val_loss_{t_name}"] = np.mean(val_losses[t_name]) - val_metrics["val_loss_total"] = avg_val_loss - - # Add evaluation metrics to wandb - for metric_name, value in metric_results.items(): - val_metrics[f"val_{metric_name}"] = value + metric_group = {} + for metric in evaluation_metrics.get(t_name, []): + aggregated = metric.aggregate() + if aggregated: + metric_group[metric.name] = { + "count": len(metric.results), + "values": aggregated, + } + if metric_group: + local_metric_payload[t_name] = metric_group + + if self.is_distributed: + 
gathered_metric_payloads = [None for _ in range(self.world_size)] + dist.all_gather_object(gathered_metric_payloads, local_metric_payload) + else: + gathered_metric_payloads = [local_metric_payload] - import wandb + debug_preview_image = self._gather_preview_image(local_preview_image) - if debug_preview_image is not None: - preview_to_log = debug_preview_image - if preview_to_log.ndim == 3 and preview_to_log.shape[2] == 3: - # Convert BGR (OpenCV) to RGB for wandb - preview_to_log = preview_to_log[..., ::-1] - preview_to_log = np.ascontiguousarray(preview_to_log) - val_metrics["debug_image"] = wandb.Image(preview_to_log) + if not self.is_distributed or self.rank == 0: + print(f"\n[Validation] Epoch {epoch + 1} summary:") + total_val_loss = 0.0 + for t_name in self.mgr.targets: + val_avg = reduced_val_losses.get(t_name, 0.0) + print(f" Task '{t_name}': Avg validation loss = {val_avg:.4f}") + total_val_loss += val_avg - wandb.log(val_metrics) + avg_val_loss = total_val_loss / len(self.mgr.targets) if self.mgr.targets else 0 + val_loss_history[epoch] = avg_val_loss + + print("\n[Validation Metrics]") + metric_results = {} + merged_metrics = self._merge_metric_payloads(gathered_metric_payloads) + for t_name in self.mgr.targets: + if t_name in merged_metrics: + print(f" Task '{t_name}':") + for metric_name, aggregated in merged_metrics[t_name].items(): + for metric_key, value in aggregated.items(): + full_metric_name = f"{t_name}_{metric_key}" + metric_results[full_metric_name] = value + print(f" {metric_name}_{metric_key}: {value:.4f}") + + if self.mgr.wandb_project: + val_metrics = {"epoch": epoch, "step": global_step} + for t_name in self.mgr.targets: + val_metrics[f"val_loss_{t_name}"] = reduced_val_losses.get(t_name, 0.0) + val_metrics["val_loss_total"] = avg_val_loss + + for metric_name, value in metric_results.items(): + val_metrics[f"val_{metric_name}"] = value + + import wandb + + if debug_preview_image is not None: + preview_to_log = debug_preview_image + if 
preview_to_log.ndim == 3 and preview_to_log.shape[2] == 3: + preview_to_log = preview_to_log[..., ::-1] + preview_to_log = np.ascontiguousarray(preview_to_log) + val_metrics["debug_image"] = wandb.Image(preview_to_log) + + wandb.log(val_metrics) # Early stopping check - if early_stopping_patience > 0: + if (not self.is_distributed or self.rank == 0) and early_stopping_patience > 0: if avg_val_loss < best_val_loss: best_val_loss = avg_val_loss patience_counter = 0 @@ -1918,36 +1983,51 @@ def train(self): print(f"\n[Early Stopping] Validation loss did not improve for {early_stopping_patience} epochs.") print(f"Best validation loss: {best_val_loss:.4f}") print("Stopping training early.") - break + stop_training_early = True # Handle epoch end operations (checkpointing, cleanup) - checkpoint_history, best_checkpoints, ckpt_path = self._on_epoch_end( - epoch=epoch, - model=model, - optimizer=optimizer, - scheduler=scheduler, - train_dataset=train_dataset, - ckpt_dir=ckpt_dir, - model_ckpt_dir=model_ckpt_dir, - checkpoint_history=checkpoint_history, - best_checkpoints=best_checkpoints, - avg_val_loss=avg_val_loss - ) - - # Manage debug videos - if epoch in [e for e, _ in debug_gif_history]: - debug_gif_history, best_debug_gifs = manage_debug_gifs( - debug_gif_history=debug_gif_history, - best_debug_gifs=best_debug_gifs, + if (not self.is_distributed or self.rank == 0) and not stop_training_early: + checkpoint_history, best_checkpoints, ckpt_path = self._on_epoch_end( epoch=epoch, - gif_path=next(p for e, p in debug_gif_history if e == epoch), - validation_loss=avg_val_loss, - checkpoint_dir=ckpt_dir, - model_name=self.mgr.model_name, - max_recent=3, - max_best=2 + model=model, + optimizer=optimizer, + scheduler=scheduler, + train_dataset=train_dataset, + ckpt_dir=ckpt_dir, + model_ckpt_dir=model_ckpt_dir, + checkpoint_history=checkpoint_history, + best_checkpoints=best_checkpoints, + avg_val_loss=avg_val_loss ) + if epoch in [e for e, _ in debug_gif_history]: + 
debug_gif_history, best_debug_gifs = manage_debug_gifs( + debug_gif_history=debug_gif_history, + best_debug_gifs=best_debug_gifs, + epoch=epoch, + gif_path=next(p for e, p in debug_gif_history if e == epoch), + validation_loss=avg_val_loss, + checkpoint_dir=ckpt_dir, + model_name=self.mgr.model_name, + max_recent=3, + max_best=2 + ) + + if self.is_distributed: + dist.barrier() + + if self.is_distributed: + stop_tensor = torch.tensor( + 1 if stop_training_early else 0, + device=self.device if self.device.type == 'cuda' else torch.device('cpu'), + dtype=torch.int32, + ) + dist.all_reduce(stop_tensor, op=dist.ReduceOp.MAX) + stop_training_early = bool(stop_tensor.item()) + + if stop_training_early: + break + # Synchronize all ranks before finalization if self.is_distributed: dist.barrier() @@ -1963,8 +2043,7 @@ def train(self): model_ckpt_dir=model_ckpt_dir, model_name=self.mgr.model_name, model_config=getattr(model, 'final_config', None), - train_dataset=train_dataset, - additional_data=self._get_additional_checkpoint_data(), + train_dataset=train_dataset ) # Clean up DDP process group diff --git a/vesuvius/src/vesuvius/models/utilities/cli_utils.py b/vesuvius/src/vesuvius/models/utilities/cli_utils.py index f921b9686..811f7ec35 100644 --- a/vesuvius/src/vesuvius/models/utilities/cli_utils.py +++ b/vesuvius/src/vesuvius/models/utilities/cli_utils.py @@ -82,6 +82,44 @@ def update_config_from_args(mgr, args): mgr.max_val_steps_per_epoch = args.max_val_steps_per_epoch mgr.tr_configs["max_val_steps_per_epoch"] = args.max_val_steps_per_epoch + if getattr(args, 'num_dataloader_workers', None) is not None: + mgr.train_num_dataloader_workers = int(args.num_dataloader_workers) + mgr.tr_configs["num_dataloader_workers"] = int(args.num_dataloader_workers) + + if getattr(args, 'val_num_dataloader_workers', None) is not None: + mgr.val_num_dataloader_workers = int(args.val_num_dataloader_workers) + mgr.tr_configs["val_num_dataloader_workers"] = 
int(args.val_num_dataloader_workers) + + if getattr(args, 'persistent_workers', None) is not None: + mgr.persistent_workers = bool(args.persistent_workers) + mgr.val_persistent_workers = bool(args.persistent_workers) + mgr.tr_configs["persistent_workers"] = bool(args.persistent_workers) + mgr.tr_configs["val_persistent_workers"] = bool(args.persistent_workers) + + if getattr(args, 'prefetch_factor', None) is not None: + mgr.train_prefetch_factor = int(args.prefetch_factor) + mgr.tr_configs["prefetch_factor"] = int(args.prefetch_factor) + + if getattr(args, 'val_prefetch_factor', None) is not None: + mgr.val_prefetch_factor = int(args.val_prefetch_factor) + mgr.tr_configs["val_prefetch_factor"] = int(args.val_prefetch_factor) + + if getattr(args, 'debug_visualization_every_n', None) is not None: + mgr.debug_visualization_every_n = int(args.debug_visualization_every_n) + mgr.tr_configs["debug_visualization_every_n"] = int(args.debug_visualization_every_n) + + if getattr(args, 'validation_preview_pool_size', None) is not None: + mgr.validation_preview_pool_size = int(args.validation_preview_pool_size) + mgr.tr_configs["validation_preview_pool_size"] = int(args.validation_preview_pool_size) + + if getattr(args, 'log_every_n_steps', None) is not None: + mgr.log_every_n_steps = max(1, int(args.log_every_n_steps)) + mgr.tr_configs["log_every_n_steps"] = mgr.log_every_n_steps + + if getattr(args, 'numa_pin', None) is not None: + mgr.numa_pin = str(args.numa_pin).lower() + mgr.tr_configs["numa_pin"] = mgr.numa_pin + if getattr(args, 'profile_augmentations', False): mgr.profile_augmentations = True mgr.tr_configs["profile_augmentations"] = True @@ -232,32 +270,6 @@ def update_config_from_args(mgr, args): if mgr.verbose: print(f"Set gradient clipping: {mgr.gradient_clip}") - ema_updates = {} - if getattr(args, 'ema_enabled', None) is not None: - mgr.ema_enabled = bool(args.ema_enabled) - ema_updates["enabled"] = mgr.ema_enabled - if getattr(args, 'ema_decay', None) is not 
None: - mgr.ema_decay = float(args.ema_decay) - ema_updates["decay"] = mgr.ema_decay - if getattr(args, 'ema_start_step', None) is not None: - mgr.ema_start_step = int(args.ema_start_step) - ema_updates["start_step"] = mgr.ema_start_step - if getattr(args, 'ema_update_every_steps', None) is not None: - mgr.ema_update_every_steps = max(1, int(args.ema_update_every_steps)) - ema_updates["update_every_steps"] = mgr.ema_update_every_steps - if getattr(args, 'ema_validate', None) is not None: - mgr.ema_validate = bool(args.ema_validate) - ema_updates["validate"] = mgr.ema_validate - if getattr(args, 'ema_save_in_checkpoint', None) is not None: - mgr.ema_save_in_checkpoint = bool(args.ema_save_in_checkpoint) - ema_updates["save_in_checkpoint"] = mgr.ema_save_in_checkpoint - if ema_updates: - if not hasattr(mgr, 'ema_config') or mgr.ema_config is None: - mgr.ema_config = {} - mgr.ema_config.update(ema_updates) - if mgr.verbose: - print(f"Updated EMA config: {mgr.ema_config}") - # Gradient accumulation steps if hasattr(args, 'gradient_accumulation') and args.gradient_accumulation is not None: if args.gradient_accumulation < 1: diff --git a/vesuvius/src/vesuvius/utils/plotting.py b/vesuvius/src/vesuvius/utils/plotting.py index 235e30e6b..40d652824 100644 --- a/vesuvius/src/vesuvius/utils/plotting.py +++ b/vesuvius/src/vesuvius/utils/plotting.py @@ -323,6 +323,71 @@ def convert_slice_to_bgr( raise ValueError(f"Expected 2D or 3D array, got shape {slice_2d_or_3d.shape}") +def _target_volume_for_preview(arr_np: np.ndarray, is_2d_run: bool) -> Optional[np.ndarray]: + if is_2d_run: + return None + if arr_np.ndim == 4: + if arr_np.shape[0] == 1: + return arr_np[0] + if arr_np.shape[0] == 2: + return arr_np[1] + return np.argmax(arr_np, axis=0) + if arr_np.ndim == 3: + return arr_np + return None + + +def _prediction_volume_for_preview(arr_np: np.ndarray, is_2d_run: bool) -> Optional[np.ndarray]: + if is_2d_run: + return None + if arr_np.ndim == 4: + if arr_np.shape[0] == 1: + 
return arr_np[0] + if arr_np.shape[0] >= 2: + return arr_np[1:].sum(axis=0) + if arr_np.ndim == 3: + return arr_np + return None + + +def _choose_preview_slice_index( + input_array: np.ndarray, + targets_np: Dict[str, np.ndarray], + preds_np: Dict[str, np.ndarray], + *, + is_2d_run: bool, +) -> Optional[int]: + if is_2d_run: + return None + + if input_array.ndim == 3: + z_dim = input_array.shape[0] + elif input_array.ndim == 4: + z_dim = input_array.shape[1] + else: + return None + + gt_scores = np.zeros(z_dim, dtype=np.float64) + for arr_np in targets_np.values(): + volume = _target_volume_for_preview(arr_np, is_2d_run) + if volume is None or volume.ndim != 3: + continue + gt_scores += np.count_nonzero(volume > 0, axis=(1, 2)) + if np.any(gt_scores > 0): + return int(gt_scores.argmax()) + + pred_scores = np.zeros(z_dim, dtype=np.float64) + for arr_np in preds_np.values(): + volume = _prediction_volume_for_preview(arr_np, is_2d_run) + if volume is None or volume.ndim != 3: + continue + pred_scores += volume.astype(np.float32, copy=False).sum(axis=(1, 2)) + if np.any(pred_scores > 0): + return int(pred_scores.argmax()) + + return int(max(z_dim // 2, 0)) + + def save_debug( input_volume: torch.Tensor, # shape [1, C, Z, H, W] for 3D or [1, C, H, W] for 2D targets_dict: dict, # e.g. 
{"sheet": tensor([1, Z, H, W]), "normals": tensor([3, Z, H, W])} @@ -338,6 +403,8 @@ def save_debug( skeleton_dict: dict = None, # Optional skeleton data for visualization train_skeleton_dict: dict = None, # Optional train skeleton data apply_activation: bool = True, # Whether to apply activation functions + save_media: bool = True, # Whether to write GIF/PNG media to disk + preview_slice_index: Optional[int] = None, # Preferred 3D slice for W&B preview # Unlabeled sample visualization for semi-supervised training unlabeled_input: torch.Tensor = None, # Optional unlabeled sample input unlabeled_pseudo_dict: dict = None, # Teacher predictions (pseudo-labels) @@ -704,11 +771,11 @@ def _pad_rows_to_uniform_width(rows_list: list[np.ndarray]) -> list[np.ndarray]: # Stack rows and save rows = _pad_rows_to_uniform_width(rows) final_img = np.vstack(rows) - out_dir = Path(save_path).parent - out_dir.mkdir(parents=True, exist_ok=True) - print(f"[Epoch {epoch}] Saving PNG to: {save_path}") - # Use PIL for saving - Image.fromarray(final_img).save(save_path) + if save_media: + out_dir = Path(save_path).parent + out_dir.mkdir(parents=True, exist_ok=True) + print(f"[Epoch {epoch}] Saving PNG to: {save_path}") + Image.fromarray(final_img).save(save_path) preview_img = np.ascontiguousarray(final_img, dtype=np.uint8) return None, preview_img @@ -718,9 +785,19 @@ def _pad_rows_to_uniform_width(rows_list: list[np.ndarray]) -> list[np.ndarray]: frames = [] preview_frame = None z_dim = inp_np.shape[0] if inp_np.ndim == 3 else inp_np.shape[1] - mid_z_idx = max(z_dim // 2, 0) + if preview_slice_index is None: + preview_slice_index = _choose_preview_slice_index( + inp_np, + targets_np, + preds_np, + is_2d_run=is_2d, + ) + if preview_slice_index is None: + preview_slice_index = max(z_dim // 2, 0) + preview_slice_index = int(max(0, min(preview_slice_index, z_dim - 1))) + z_indices = range(z_dim) if save_media else [preview_slice_index] - for z_idx in range(z_dim): + for z_idx in z_indices: 
rows = [] # Get slices @@ -882,9 +959,14 @@ def _pad_rows_to_uniform_width(rows_list: list[np.ndarray]) -> list[np.ndarray]: frame = np.ascontiguousarray(frame, dtype=np.uint8) frames.append(frame) - if z_idx == mid_z_idx: + if z_idx == preview_slice_index: preview_frame = frame.copy() - + + if not save_media: + if preview_frame is None and frames: + preview_frame = frames[len(frames) // 2].copy() + return None, preview_frame + # Save GIF in a subprocess to avoid crashing main training process on encoder segfaults out_dir = Path(save_path).parent out_dir.mkdir(parents=True, exist_ok=True) From ca224a3bf2c82b07b6496eda3ca0c851ed1cf3a8 Mon Sep 17 00:00:00 2001 From: Giorgio Angelotti <76100950+giorgioangel@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:04:43 +0200 Subject: [PATCH 2/3] Add NUMA and regression tests for Vesuvius runtime changes --- vesuvius/src/vesuvius/models/training/numa.py | 82 +++++++++++++++++++ .../tests/models/test_validation_preview.py | 38 +++++++++ .../models/test_zarr_dataset_dilation.py | 56 +++++++++++++ 3 files changed, 176 insertions(+) create mode 100644 vesuvius/src/vesuvius/models/training/numa.py create mode 100644 vesuvius/tests/models/test_validation_preview.py create mode 100644 vesuvius/tests/models/test_zarr_dataset_dilation.py diff --git a/vesuvius/src/vesuvius/models/training/numa.py b/vesuvius/src/vesuvius/models/training/numa.py new file mode 100644 index 000000000..48236aef2 --- /dev/null +++ b/vesuvius/src/vesuvius/models/training/numa.py @@ -0,0 +1,82 @@ +import os +import re +import subprocess +from functools import lru_cache +from typing import Optional + + +_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;]*m") +_GPU_ROW_RE = re.compile( + r"^GPU(?P<gpu>\d+)\s+.*?\s+(?P<cpu>[0-9,\-]+)\s+(?P<numa>\S+)\s+(?P<gpu_numa>\S+)\s*$" +) + + +def _strip_ansi(text: str) -> str: + return _ANSI_ESCAPE_RE.sub("", text) + + +def _parse_cpu_affinity(spec: str) -> tuple[int, ...]: + cpus: set[int] = set() + for token in spec.split(","): + token = token.strip() +
if not token: + continue + if "-" in token: + start_str, end_str = token.split("-", 1) + start = int(start_str) + end = int(end_str) + cpus.update(range(start, end + 1)) + else: + cpus.add(int(token)) + return tuple(sorted(cpus)) + + +@lru_cache(maxsize=1) +def get_gpu_cpu_affinity_map() -> dict[int, tuple[int, ...]]: + try: + topo = subprocess.run( + ["nvidia-smi", "topo", "-m"], + check=True, + capture_output=True, + text=True, + ).stdout + except Exception: + return {} + + mapping: dict[int, tuple[int, ...]] = {} + for raw_line in topo.splitlines(): + line = _strip_ansi(raw_line).strip() + if not line.startswith("GPU"): + continue + match = _GPU_ROW_RE.match(line) + if not match: + continue + gpu_id = int(match.group("gpu")) + cpu_affinity = _parse_cpu_affinity(match.group("cpu")) + if cpu_affinity: + mapping[gpu_id] = cpu_affinity + return mapping + + +def apply_numa_affinity( + mode: str, + assigned_gpu_id: Optional[int], +) -> Optional[dict[str, object]]: + if str(mode).lower() != "auto": + return None + if assigned_gpu_id is None: + return None + if not hasattr(os, "sched_setaffinity"): + return None + + affinity_map = get_gpu_cpu_affinity_map() + cpu_affinity = affinity_map.get(int(assigned_gpu_id)) + if not cpu_affinity: + return None + + os.sched_setaffinity(0, set(cpu_affinity)) + return { + "gpu_id": int(assigned_gpu_id), + "cpu_count": len(cpu_affinity), + "cpu_affinity": cpu_affinity, + } diff --git a/vesuvius/tests/models/test_validation_preview.py b/vesuvius/tests/models/test_validation_preview.py new file mode 100644 index 000000000..f52beb8e7 --- /dev/null +++ b/vesuvius/tests/models/test_validation_preview.py @@ -0,0 +1,38 @@ +import numpy as np + +from vesuvius.utils.plotting import _choose_preview_slice_index + + +def test_choose_preview_slice_prefers_ground_truth_foreground(): + input_volume = np.zeros((4, 8, 8), dtype=np.float32) + targets = {"surface": np.zeros((1, 4, 8, 8), dtype=np.float32)} + preds = {"surface": np.zeros((2, 4, 8, 8), 
dtype=np.float32)} + + targets["surface"][0, 2, 2:6, 2:6] = 1.0 + preds["surface"][1, 1, 1:7, 1:7] = 0.75 + + preview_idx = _choose_preview_slice_index( + input_volume, + targets, + preds, + is_2d_run=False, + ) + + assert preview_idx == 2 + + +def test_choose_preview_slice_falls_back_to_prediction_mass(): + input_volume = np.zeros((5, 6, 6), dtype=np.float32) + targets = {"surface": np.zeros((1, 5, 6, 6), dtype=np.float32)} + preds = {"surface": np.zeros((2, 5, 6, 6), dtype=np.float32)} + + preds["surface"][1, 4, :, :] = 0.5 + + preview_idx = _choose_preview_slice_index( + input_volume, + targets, + preds, + is_2d_run=False, + ) + + assert preview_idx == 4 diff --git a/vesuvius/tests/models/test_zarr_dataset_dilation.py b/vesuvius/tests/models/test_zarr_dataset_dilation.py new file mode 100644 index 000000000..e8ca973f3 --- /dev/null +++ b/vesuvius/tests/models/test_zarr_dataset_dilation.py @@ -0,0 +1,56 @@ +import numpy as np +from scipy.ndimage import distance_transform_edt + +from vesuvius.models.datasets.zarr_dataset import ZarrDataset + + +def _reference_dilate(values: np.ndarray, distance: float, ignore_label=None) -> np.ndarray: + arr = np.array(values, copy=True) + ignore_mask = np.zeros(arr.shape, dtype=bool) if ignore_label is None else (arr == ignore_label) + source_mask = (arr != 0) & ~ignore_mask + fill_mask = arr == 0 + if not np.any(source_mask) or not np.any(fill_mask): + return arr + + distances, nearest_indices = distance_transform_edt( + ~source_mask, + return_indices=True, + ) + fill_mask &= distances <= float(distance) + arr[fill_mask] = arr[tuple(nearest_indices[axis][fill_mask] for axis in range(arr.ndim))] + return arr + + +def test_roi_dilate_matches_full_patch_result(): + values = np.zeros((12, 12, 12), dtype=np.float32) + values[4:6, 4:6, 4:6] = 1.0 + values[5, 5, 7] = 2.0 + + expected = _reference_dilate(values, distance=2.0) + result, roi_slices = ZarrDataset._dilate_label_patch( + values, + distance=2.0, + ignore_label=None, + 
original_shape=values.shape, + ) + + np.testing.assert_array_equal(result, expected) + assert roi_slices is not None + + +def test_binary_fast_path_matches_full_patch_result(): + values = np.zeros((16, 16, 16), dtype=np.float32) + values[6:9, 6:9, 6:9] = 1.0 + values[7, 7, 10] = 2.0 + + expected = _reference_dilate(values, distance=2.0, ignore_label=2.0) + result, roi_slices = ZarrDataset._dilate_label_patch( + values, + distance=2.0, + ignore_label=2.0, + original_shape=values.shape, + binary_fast_path=True, + ) + + np.testing.assert_array_equal(result, expected) + assert roi_slices is not None From e9f706a32fedc07e7d3af9811eccdbfb580b9455 Mon Sep 17 00:00:00 2001 From: Giorgio Angelotti <76100950+giorgioangel@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:14:32 +0200 Subject: [PATCH 3/3] Honor config epoch and step defaults in Vesuvius CLI --- vesuvius/src/vesuvius/models/training/cli.py | 414 ------------------- 1 file changed, 414 deletions(-) diff --git a/vesuvius/src/vesuvius/models/training/cli.py b/vesuvius/src/vesuvius/models/training/cli.py index 3175b6d18..e69de29bb 100644 --- a/vesuvius/src/vesuvius/models/training/cli.py +++ b/vesuvius/src/vesuvius/models/training/cli.py @@ -1,414 +0,0 @@ -import argparse -import multiprocessing -import os -import socket -import subprocess -import sys -from pathlib import Path - -import torch - -from vesuvius.models.configuration.config_manager import ConfigManager -from vesuvius.models.datasets.intensity_properties import load_intensity_props_formatted -from vesuvius.models.utilities.cli_utils import update_config_from_args - - -def _maybe_set_spawn_start_method(argv): - # s3fs/fsspec can misbehave with fork; force spawn if s3/config is present. 
- if not argv: - return - if any('s3://' in str(arg) for arg in argv) or '--config-path' in argv or '--config' in argv: - try: - multiprocessing.set_start_method('spawn', force=True) - except RuntimeError: - pass - - -def main(argv=None): - """Main entry point for the training script.""" - if argv is None: - argv = sys.argv[1:] - - _maybe_set_spawn_start_method(argv) - - parser = argparse.ArgumentParser( - description="Train Vesuvius neural networks for ink detection and segmentation", - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - grp_required = parser.add_argument_group("Required") - grp_paths = parser.add_argument_group("Paths & Format") - grp_data = parser.add_argument_group("Data & Splits") - grp_model = parser.add_argument_group("Model") - grp_train = parser.add_argument_group("Training Control") - grp_optim = parser.add_argument_group("Optimization") - grp_sched = parser.add_argument_group("Scheduler") - grp_trainer = parser.add_argument_group("Trainer Selection") - grp_logging = parser.add_argument_group("Logging & Tracking") - - # Required - grp_required.add_argument("-i", "--input", - help="Input directory containing images/ and labels/ subdirectories.") - grp_required.add_argument("--config", "--config-path", dest="config_path", type=str, required=True, - help="Path to configuration YAML file") - - # Paths & Format - grp_paths.add_argument("-o", "--output", default="checkpoints", - help="Output directory for saving checkpoints and configs") - grp_paths.add_argument("--format", choices=["image", "zarr", "napari"], - help="Data format (auto-detected if omitted)") - grp_paths.add_argument("--val-dir", type=str, - help="Optional validation directory with images/ and labels/") - grp_paths.add_argument("--checkpoint", "--checkpoint-path", dest="checkpoint_path", type=str, - help="Path to checkpoint (.pt/.pth) or weights-only state_dict file") - grp_paths.add_argument("--load-weights-only", action="store_true", - help="Load only model weights 
from checkpoint; ignore optimizer/scheduler and allow partial load") - grp_paths.add_argument("--rebuild-from-ckpt-config", action="store_true", - help="Rebuild model from checkpoint's model_config before loading weights") - grp_paths.add_argument("--intensity-properties-json", type=str, default=None, - help="nnU-Net style intensity properties JSON for CT normalization") - grp_paths.add_argument("--skip-image-checks", action="store_true", - help="Skip expensive image/zarr existence checks and conversions; assumes images.zarr/labels.zarr already exist") - - # Data & Splits - grp_data.add_argument("--batch-size", type=int, - help="Training batch size") - grp_data.add_argument("--patch-size", type=str, - help="Patch size CSV, e.g. '192,192,192' (3D) or '256,256' (2D)") - grp_data.add_argument("--loss", type=str, - help="Loss functions, e.g. '[SoftDiceLoss, BCEWithLogitsLoss]' or CSV") - grp_data.add_argument("--train-split", type=float, - help="Training/validation split ratio in [0,1]") - grp_data.add_argument("--seed", type=int, default=42, - help="Random seed for split/initialization") - grp_data.add_argument("--skip-intensity-sampling", dest="skip_intensity_sampling", - action="store_true", default=True, - help="Skip intensity sampling during dataset init") - grp_data.add_argument("--no-skip-intensity-sampling", dest="skip_intensity_sampling", - action="store_false", - help="Enable intensity sampling during dataset init") - grp_data.add_argument("--no-spatial", action="store_true", - help="Disable spatial/geometric augmentations") - grp_data.add_argument("--rotation-axes", type=str, - help="Comma-separated axes (subset of x,y,z / width,height,depth) that may be rotated; e.g. 
'z' keeps the depth axis upright") - - # Model - grp_model.add_argument("--model-name", type=str, - help="Model name for checkpoints and logging") - grp_model.add_argument("--nonlin", type=str, choices=["LeakyReLU", "ReLU", "SwiGLU", "swiglu", "GLU", "glu"], - help="Activation function") - grp_model.add_argument("--se", action="store_true", help="Enable squeeze and excitation modules in the encoder") - grp_model.add_argument("--se-reduction-ratio", type=float, default=0.0625, - help="Squeeze excitation reduction ratio") - grp_model.add_argument("--pool-type", type=str, choices=["avg", "max", "conv"], - help="Type of pooling in encoder ('conv' = strided conv)") - - # Training Control - grp_train.add_argument("--max-epoch", type=int, default=1000, - help="Maximum number of epochs") - grp_train.add_argument("--max-steps-per-epoch", type=int, default=250, - help="Max training steps per epoch (use all data if unset)") - grp_train.add_argument("--max-val-steps-per-epoch", type=int, default=50, - help="Max validation steps per epoch (use all data if unset)") - grp_train.add_argument("--full-epoch", action="store_true", - help="Iterate over entire train/val set per epoch (overrides max-steps)") - grp_train.add_argument("--early-stopping-patience", type=int, default=0, - help="Epochs to wait for val loss improvement (0 disables)") - grp_train.add_argument("--ddp", action="store_true", - help="Enable DistributedDataParallel (use with torchrun)") - grp_train.add_argument("--val-every-n", dest="val_every_n", type=int, default=1, - help="Perform validation every N epochs (1=every epoch)") - grp_train.add_argument("--num-dataloader-workers", type=int, - help="Training dataloader worker count") - grp_train.add_argument("--val-num-dataloader-workers", type=int, - help="Validation dataloader worker count") - grp_train.add_argument("--persistent-workers", dest="persistent_workers", action="store_true", - help="Enable persistent dataloader workers") - 
grp_train.add_argument("--no-persistent-workers", dest="persistent_workers", action="store_false", - help="Disable persistent dataloader workers") - grp_train.add_argument("--prefetch-factor", type=int, - help="Training dataloader prefetch factor") - grp_train.add_argument("--val-prefetch-factor", type=int, - help="Validation dataloader prefetch factor") - grp_train.add_argument("--debug-visualization-every-n", type=int, - help="Save debug GIF/PNG media every N validation epochs (0 disables media saves)") - grp_train.add_argument("--validation-preview-pool-size", type=int, - help="Number of globally ordered validation patches to rotate through for W&B previews") - grp_train.add_argument("--log-every-n-steps", type=int, - help="Log training metrics to W&B every N optimizer steps") - grp_train.add_argument("--numa-pin", type=str, choices=["auto", "off"], - help="NUMA affinity mode for CUDA DDP workers") - grp_train.add_argument("--gpus", type=str, default=None, - help="Comma-separated GPU device IDs to use, e.g. '0,1,3'. 
With DDP, length must equal WORLD_SIZE") - grp_train.add_argument("--nproc-per-node", type=int, default=None, - help="Number of processes to spawn locally for DDP (use instead of torchrun)") - grp_train.add_argument("--master-addr", type=str, default="127.0.0.1", - help="Master address for DDP when spawning without torchrun") - grp_train.add_argument("--master-port", type=int, default=None, - help="Master port for DDP when spawning without torchrun (default: auto)") - grp_train.set_defaults(persistent_workers=None) - - # Optimization - grp_optim.add_argument("--optimizer", type=str, - help="Optimizer (see models/optimizers.py)") - grp_optim.add_argument("--grad-accum", "--gradient-accumulation", dest="gradient_accumulation", type=int, default=None, - help="Number of steps to accumulate gradients before optimizer.step()") - grp_optim.add_argument("--grad-clip", type=float, default=12.0, - help="Gradient clipping value") - grp_optim.add_argument("--amp-dtype", type=str, choices=["float16", "bfloat16"], default="float16", - help="Autocast dtype when AMP is enabled (float16 uses GradScaler; bfloat16 skips scaling)") - grp_optim.add_argument("--no-amp", action="store_true", - help="Disable Automatic Mixed Precision (AMP)") - - # Scheduler - grp_sched.add_argument("--scheduler", type=str, - help="Learning rate scheduler (default: from config or 'poly')") - grp_sched.add_argument("--warmup-steps", type=int, - help="Number of warmup steps for cosine_warmup scheduler") - - # Trainer Selection - grp_trainer.add_argument("--trainer", "--tr", type=str, default="base", - help="Trainer: base, surface_frame, mean_teacher, uncertainty_aware_mean_teacher, primus_mae, unet_mae, finetune_mae_unet") - grp_trainer.add_argument("--ssl-warmup", type=int, default=None, - help="Semi-supervised: epochs to ignore EMA consistency loss (0 disables)") - # Semi-supervised sampling controls (used by mean_teacher/uncertainty_aware_mean_teacher) - grp_trainer.add_argument("--labeled-ratio", 
type=float, default=None, - help="Fraction of labeled patches to use (0-1). If set, overrides trainer default") - grp_trainer.add_argument("--num-labeled", type=int, default=None, - help="Absolute number of labeled patches to use (overrides --labeled-ratio if provided)") - grp_trainer.add_argument("--labeled-batch-size", type=int, default=None, - help="Number of labeled patches per batch (rest are unlabeled) for two-stream sampler") - - # Only valid for finetune_mae_unet: path to the pretrained MAE checkpoint to initialize from - grp_trainer.add_argument("--pretrained_checkpoint", type=str, default=None, - help="Pretrained MAE checkpoint path (required when --trainer finetune_mae_unet). Invalid for other trainers.") - - # Logging & Tracking - grp_logging.add_argument("--wandb-project", type=str, default=None, - help="Weights & Biases project (omit to disable wandb)") - grp_logging.add_argument("--wandb-entity", type=str, default=None, - help="Weights & Biases team/username") - grp_logging.add_argument("--wandb-run-name", type=str, default=None, - help="Optional custom name for the Weights & Biases run") - grp_logging.add_argument("--wandb-resume", nargs='?', const='allow', default=None, - help="Weights & Biases resume mode or run id. 
Provide a resume policy ('allow', 'auto', 'must', 'never') or a run id (defaults to 'allow' if flag used without value).") - grp_logging.add_argument("--profile-augmentations", action="store_true", - help="Collect per-augmentation timing and report per-epoch totals") - grp_logging.add_argument("--verbose", action="store_true", - help="Enable verbose debug output") - - args = parser.parse_args(argv) - - mgr = ConfigManager(verbose=args.verbose) - - if not Path(args.config_path).exists(): - print(f"\nError: Config file does not exist: {args.config_path}") - print("\nPlease provide a valid configuration file.") - print("\nExample usage:") - print(" vesuvius.train --config path/to/config.yaml --input path/to/data --output path/to/output") - print("\nFor more options, use: vesuvius.train --help") - sys.exit(1) - - mgr.load_config(args.config_path) - print(f"Loaded configuration from: {args.config_path}") - - # Resolve input path: CLI arg takes precedence, else use config's data_path - if args.input is not None: - input_path = Path(args.input) - elif mgr.data_path is not None: - input_path = mgr.data_path - args.input = str(input_path) # Update args so downstream code sees it - else: - print("\nError: No input directory specified.") - print("Provide --input on the command line OR set data_path in your YAML config.") - sys.exit(1) - - if not input_path.exists(): - raise ValueError(f"Input directory does not exist: {input_path}") - - if args.val_dir is not None and not Path(args.val_dir).exists(): - raise ValueError(f"Validation directory does not exist: {args.val_dir}") - - Path(args.output).mkdir(parents=True, exist_ok=True) - - update_config_from_args(mgr, args) - - # Validation frequency - if hasattr(args, 'val_every_n') and args.val_every_n is not None: - if int(args.val_every_n) < 1: - raise ValueError(f"--val-every-n must be >= 1, got {args.val_every_n}") - setattr(mgr, 'val_every_n', int(args.val_every_n)) - mgr.tr_configs["val_every_n"] = int(args.val_every_n) - 
if args.verbose: - print(f"Validate every {args.val_every_n} epoch(s)") - - # Enable DDP if requested or if torchrun sets WORLD_SIZE>1 - if getattr(args, 'ddp', False) or int(os.environ.get('WORLD_SIZE', '1')) > 1: - setattr(mgr, 'use_ddp', True) - # In DDP, --batch-size is per-GPU; no extra adjustment needed. - - # Parse GPUs selection if provided - if getattr(args, 'gpus', None): - try: - gpu_ids = [int(x) for x in str(args.gpus).split(',') if x.strip() != ''] - except ValueError as exc: - raise ValueError("--gpus must be a comma-separated list of integers, e.g. '0,1,3'") from exc - setattr(mgr, 'gpu_ids', gpu_ids) - - if args.val_dir is not None: - mgr.val_data_path = Path(args.val_dir) - - # If user supplies intensity properties JSON, load and inject into config for CT normalization - if args.intensity_properties_json is not None: - ip_path = Path(args.intensity_properties_json) - if not ip_path.exists(): - raise ValueError(f"Intensity properties JSON not found: {ip_path}") - props = load_intensity_props_formatted(ip_path, channel=0) - if not props: - raise ValueError(f"Failed to parse intensity properties JSON: {ip_path}") - if hasattr(mgr, 'update_config'): - mgr.update_config(normalization_scheme='ct', intensity_properties=props) - else: - mgr.dataset_config = getattr(mgr, 'dataset_config', {}) - mgr.dataset_config['normalization_scheme'] = 'ct' - mgr.dataset_config['intensity_properties'] = props - setattr(mgr, 'skip_intensity_sampling', True) - print("Using provided intensity properties for CT normalization. 
Sampling disabled.") - - # If DDP is requested but not launched with torchrun, optionally self-spawn processes - if getattr(mgr, 'use_ddp', False) and int(os.environ.get('WORLD_SIZE', '1')) == 1: - # Determine process count - nproc = args.nproc_per_node - if nproc is None: - # Default to number of requested GPUs, else CUDA device count, else 1 - gpu_ids = getattr(mgr, 'gpu_ids', None) - if gpu_ids: - nproc = len(gpu_ids) - elif torch.cuda.is_available(): - try: - nproc = torch.cuda.device_count() - except Exception: - nproc = 1 - else: - nproc = 1 - - if nproc > 1: - # Validate GPU mapping length if provided - gpu_ids = getattr(mgr, 'gpu_ids', None) - if gpu_ids and len(gpu_ids) != nproc: - raise ValueError(f"--gpus specifies {len(gpu_ids)} GPUs but --nproc-per-node is {nproc}. They must match.") - - # Find a free port if not provided - master_port = args.master_port - if master_port is None: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind((args.master_addr, 0)) - master_port = s.getsockname()[1] - - print(f"Spawning {nproc} DDP processes (master {args.master_addr}:{master_port}) without torchrun...") - - # Rebuild argv without the spawn-only flags; children don't need them - skip_next = False - child_argv = [] - for a in argv: - if skip_next: - skip_next = False - continue - if a in ("--nproc-per-node", "--master-addr", "--master-port"): - skip_next = True - continue - child_argv.append(a) - - procs = [] - for rank in range(nproc): - env = os.environ.copy() - env.update({ - 'RANK': str(rank), - 'LOCAL_RANK': str(rank), - 'WORLD_SIZE': str(nproc), - 'MASTER_ADDR': args.master_addr, - 'MASTER_PORT': str(master_port), - }) - cmd = [sys.executable, sys.argv[0], *child_argv] - # Use unbuffered -u for timely logs on Windows/Unix - if '-u' not in cmd: - cmd.insert(1, '-u') - procs.append(subprocess.Popen(cmd, env=env)) - - exit_code = 0 - for p in procs: - ret = p.wait() - if ret != 0: - exit_code = ret - sys.exit(exit_code) - else: - print("DDP 
requested but only one process determined; proceeding single-process.") - - trainer_name = args.trainer.lower() - mgr.trainer_class = trainer_name - - # Enforce usage of --pretrained_checkpoint only for the MAE finetune trainer, and require it there - if getattr(args, 'pretrained_checkpoint', None): - if trainer_name != "finetune_mae_unet": - raise ValueError("--pretrained_checkpoint is only valid when using --trainer finetune_mae_unet") - # Stash onto mgr so the finetune trainer can load it - setattr(mgr, 'pretrained_mae_checkpoint', args.pretrained_checkpoint) - mgr.tr_info["pretrained_mae_checkpoint"] = args.pretrained_checkpoint - elif trainer_name == "finetune_mae_unet": - # For finetune trainer the pretrained checkpoint is mandatory - raise ValueError("Missing --pretrained_checkpoint: required for --trainer finetune_mae_unet") - - if trainer_name == "uncertainty_aware_mean_teacher": - mgr.allow_unlabeled_data = True - from vesuvius.models.training.trainers.semi_supervised.train_uncertainty_aware_mean_teacher import TrainUncertaintyAwareMeanTeacher - trainer = TrainUncertaintyAwareMeanTeacher(mgr=mgr, verbose=args.verbose) - print("Using Uncertainty-Aware Mean Teacher Trainer for semi-supervised 3D training") - elif trainer_name == "mean_teacher": - mgr.allow_unlabeled_data = True - from vesuvius.models.training.trainers.semi_supervised.train_mean_teacher import TrainMeanTeacher - trainer = TrainMeanTeacher(mgr=mgr, verbose=args.verbose) - print("Using Regular Mean Teacher Trainer for semi-supervised training") - elif trainer_name == "primus_mae": - mgr.allow_unlabeled_data = True - from vesuvius.models.training.trainers.self_supervised.train_eva_mae import TrainEVAMAE - trainer = TrainEVAMAE(mgr=mgr, verbose=args.verbose) - print("Using EVA (Primus) Architecture for MAE Pretraining") - elif trainer_name == "unet_mae": - mgr.allow_unlabeled_data = True - from vesuvius.models.training.trainers.self_supervised.train_unet_mae import TrainUNetMAE - trainer = 
TrainUNetMAE(mgr=mgr, verbose=args.verbose) - print("Using UNet-style MAE Trainer (NetworkFromConfig)") - elif trainer_name == "finetune_mae_unet": - from vesuvius.models.training.trainers.self_supervised.train_finetune_mae_unet import TrainFineTuneMAEUNet - trainer = TrainFineTuneMAEUNet(mgr=mgr, verbose=args.verbose) - print("Using Fine-Tune MAE->UNet Trainer (NetworkFromConfig)") - elif trainer_name == "lejepa": - mgr.allow_unlabeled_data = True - from vesuvius.models.training.trainers.self_supervised.train_lejepa import TrainLeJEPA - trainer = TrainLeJEPA(mgr=mgr, verbose=args.verbose) - print("Using LeJEPA Trainer (Primus + SIGReg) for unsupervised pretraining") - elif trainer_name == "mutex_affinity": - from vesuvius.models.training.trainers.mutex_affinity_trainer import MutexAffinityTrainer - trainer = MutexAffinityTrainer(mgr=mgr, verbose=args.verbose) - print("Using Mutex Affinity Trainer") - elif trainer_name == "base": - from vesuvius.models.training.train import BaseTrainer - trainer = BaseTrainer(mgr=mgr, verbose=args.verbose) - print("Using Base Trainer for supervised training") - elif trainer_name == "surface_frame": - from vesuvius.models.training.trainers.surface_frame_trainer import SurfaceFrameTrainer - trainer = SurfaceFrameTrainer(mgr=mgr, verbose=args.verbose) - print("Using Surface Frame Trainer") - else: - raise ValueError( - "Unknown trainer: {trainer}. Available options: base, surface_frame, mutex_affinity, mean_teacher, " - "uncertainty_aware_mean_teacher, primus_mae, unet_mae, finetune_mae_unet, lejepa".format(trainer=trainer_name) - ) - - print("Starting training...") - trainer.train() - print("Training completed!") - - -if __name__ == '__main__': - main()