From 53d8019ed220780d575815a85907d4516c217480 Mon Sep 17 00:00:00 2001 From: Giorgio Angelotti <76100950+giorgioangel@users.noreply.github.com> Date: Sat, 28 Mar 2026 09:26:50 +0100 Subject: [PATCH 1/4] Add multiscale surface training configs and BCE+Dice ignore-label support --- .../models/configuration/config_manager.py | 10 + .../surface_resenc_s0_ps128_bs28_bcedice.yaml | 44 +++ .../surface_resenc_s0_ps128_bs28_msr.yaml | 42 +++ .../surface_resenc_s0_ps256_bs3_bcedice.yaml | 44 +++ .../surface_resenc_s0_ps256_bs3_msr.yaml | 42 +++ .../surface_resenc_s2_ps128_bs28_bcedice.yaml | 43 +++ .../surface_resenc_s2_ps128_bs28_msr.yaml | 42 +++ .../surface_resenc_s2_ps256_bs3_bcedice.yaml | 44 +++ .../surface_resenc_s2_ps256_bs3_msr.yaml | 42 +++ .../vesuvius/models/datasets/zarr_dataset.py | 34 ++- .../models/preprocessing/patches/cache.py | 10 +- .../models/preprocessing/patches/generate.py | 20 +- .../vesuvius/models/training/loss/losses.py | 65 ++++- .../src/vesuvius/scripts/probe_surface_fit.py | 269 ++++++++++++++++++ .../test_surface_multiscale_training.py | 91 ++++++ 15 files changed, 837 insertions(+), 5 deletions(-) create mode 100644 vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml create mode 100644 vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_msr.yaml create mode 100644 vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml create mode 100644 vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_msr.yaml create mode 100644 vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml create mode 100644 vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_msr.yaml create mode 100644 vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml create mode 100644 
vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_msr.yaml create mode 100644 vesuvius/src/vesuvius/scripts/probe_surface_fit.py create mode 100644 vesuvius/tests/models/test_surface_multiscale_training.py diff --git a/vesuvius/src/vesuvius/models/configuration/config_manager.py b/vesuvius/src/vesuvius/models/configuration/config_manager.py index e724e825c..b1c70588d 100755 --- a/vesuvius/src/vesuvius/models/configuration/config_manager.py +++ b/vesuvius/src/vesuvius/models/configuration/config_manager.py @@ -273,6 +273,16 @@ def _init_attributes(self): # Chunk-slicing worker configuration self.valid_patch_find_resolution = int(self.dataset_config.get("valid_patch_find_resolution", 1)) + self.ome_zarr_resolution = int(self.dataset_config.get("ome_zarr_resolution", 0)) + if self.ome_zarr_resolution < 0: + raise ValueError( + f"dataset_config.ome_zarr_resolution must be >= 0, got {self.ome_zarr_resolution}" + ) + if self.valid_patch_find_resolution < self.ome_zarr_resolution: + raise ValueError( + "dataset_config.valid_patch_find_resolution must be >= dataset_config.ome_zarr_resolution " + f"(got {self.valid_patch_find_resolution} < {self.ome_zarr_resolution})" + ) self.num_workers = int(self.dataset_config.get("num_workers", 8)) # Worker configuration for image→Zarr pipeline diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml new file mode 100644 index 000000000..b3c2f0e00 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml @@ -0,0 +1,44 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s0_ps128_bs28_bcedice" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [128, 128, 128] + batch_size: 28 + num_dataloader_workers: 14 + +model_config: + 
basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 0 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 1 + valid_patch_value: 1 + activation: "sigmoid" + ignore_label: 2 + losses: + - name: "BinaryBCEAndDiceLoss" + weight: 1.0 + weight_bce: 1.0 + weight_dice: 1.0 + diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_msr.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_msr.yaml new file mode 100644 index 000000000..c5d91d654 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_msr.yaml @@ -0,0 +1,42 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s0_ps128_bs28_msr" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [128, 128, 128] + batch_size: 28 + num_dataloader_workers: 14 + +model_config: + basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + 
separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 0 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 2 + valid_patch_value: 1 + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 + diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml new file mode 100644 index 000000000..9500bdadb --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml @@ -0,0 +1,44 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s0_ps256_bs3_bcedice" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [256, 256, 256] + batch_size: 3 + num_dataloader_workers: 14 + +model_config: + basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 0 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 1 + valid_patch_value: 1 + activation: "sigmoid" + ignore_label: 2 + losses: + - name: "BinaryBCEAndDiceLoss" + weight: 1.0 + weight_bce: 1.0 + weight_dice: 1.0 + diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_msr.yaml 
b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_msr.yaml new file mode 100644 index 000000000..ae81ccf02 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_msr.yaml @@ -0,0 +1,42 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s0_ps256_bs3_msr" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [256, 256, 256] + batch_size: 3 + num_dataloader_workers: 14 + +model_config: + basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 0 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 2 + valid_patch_value: 1 + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 + diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml new file mode 100644 index 000000000..1a43d1381 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml @@ -0,0 +1,43 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s2_ps128_bs28_bcedice" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [128, 128, 128] + batch_size: 28 + 
num_dataloader_workers: 14 + +model_config: + basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 2 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 1 + valid_patch_value: 1 + activation: "sigmoid" + ignore_label: 2 + losses: + - name: "BinaryBCEAndDiceLoss" + weight: 1.0 + weight_bce: 1.0 + weight_dice: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_msr.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_msr.yaml new file mode 100644 index 000000000..22ab9abf1 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_msr.yaml @@ -0,0 +1,42 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s2_ps128_bs28_msr" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [128, 128, 128] + batch_size: 28 + num_dataloader_workers: 14 + +model_config: + basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], 
[2, 2, 2], [2, 2, 2], [2, 2, 2]] + separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 2 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 2 + valid_patch_value: 1 + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 + diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml new file mode 100644 index 000000000..3f00feae6 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml @@ -0,0 +1,44 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s2_ps256_bs3_bcedice" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [256, 256, 256] + batch_size: 3 + num_dataloader_workers: 14 + +model_config: + basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 2 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 1 + valid_patch_value: 1 + activation: "sigmoid" + ignore_label: 2 + losses: + - name: "BinaryBCEAndDiceLoss" + weight: 1.0 + weight_bce: 1.0 + weight_dice: 1.0 + diff --git 
a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_msr.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_msr.yaml new file mode 100644 index 000000000..41717ff89 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_msr.yaml @@ -0,0 +1,42 @@ +tr_setup: + wandb_project: "srf_2um" + wandb_entity: "vesuvius-challenge" + model_name: "surface_resenc_s2_ps256_bs3_msr" + tr_val_split: 0.95 + autoconfigure: false + +tr_config: + patch_size: [256, 256, 256] + batch_size: 3 + num_dataloader_workers: 14 + +model_config: + basic_encoder_block: "BasicBlockD" + bottleneck_block: "BasicBlockD" + basic_decoder_block: "ConvBlock" + norm_op: "nn.InstanceNorm3d" + nonlin: "nn.LeakyReLU" + features_per_stage: [32, 64, 128, 256, 320, 320] + n_stages: 6 + n_blocks_per_stage: [1, 3, 4, 6, 6, 6] + n_conv_per_stage_decoder: [1, 1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + separate_decoders: true + +dataset_config: + data_path: "/ephemeral/datasets" + ome_zarr_resolution: 2 + min_labeled_ratio: 0.001 + min_bbox_percent: 0.35 + valid_patch_find_resolution: 3 + targets: + surface: + out_channels: 2 + valid_patch_value: 1 + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 + diff --git a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py index 99167546a..d164ee684 100644 --- a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py +++ b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py @@ -120,6 +120,7 @@ def __init__( # OME-Zarr parameters self.ome_zarr_resolution = getattr(mgr, 'ome_zarr_resolution', 0) + self.ome_zarr_scale_factor = 2 ** int(self.ome_zarr_resolution) self.valid_patch_find_resolution = getattr(mgr, 
'valid_patch_find_resolution', 1) # Semi-supervised parameters @@ -321,6 +322,26 @@ def _build_patch_index(self) -> None: self._n_labeled_fg, self._n_unlabeled_fg, len(self._patches) ) + @staticmethod + def _cached_position_to_training_level( + position: Tuple[int, ...], + ome_zarr_resolution: int, + ) -> Tuple[int, ...]: + """Convert cached full-resolution coordinates into the selected training level.""" + scale_factor = 2 ** int(ome_zarr_resolution) + if scale_factor == 1: + return tuple(int(v) for v in position) + + converted = [] + for coord in position: + coord = int(coord) + if coord % scale_factor != 0: + raise ValueError( + f"Cached full-resolution position {position} is not divisible by scale factor {scale_factor}" + ) + converted.append(coord // scale_factor) + return tuple(converted) + def _load_from_mapping(self, mapping_file: Path) -> None: """ Load valid patches from samples_mapping.json (packed sparse zarr format). @@ -551,6 +572,7 @@ def _try_load_cache(self): min_labeled_ratio=self.min_labeled_ratio, bbox_threshold=self.min_bbox_percent, valid_patch_find_resolution=self.valid_patch_find_resolution, + ome_zarr_resolution=self.ome_zarr_resolution, valid_patch_value=valid_patch_value, unlabeled_fg_enabled=self.unlabeled_fg_enabled, unlabeled_fg_threshold=self.unlabeled_fg_threshold, @@ -565,10 +587,14 @@ def _load_from_cache(self, cache_data) -> None: # Add FG patches for entry in cache_data.fg_patches: vol_idx = volume_name_to_idx.get(entry.volume_name, entry.volume_idx) + position = self._cached_position_to_training_level( + entry.position, + self.ome_zarr_resolution, + ) self._patches.append(PatchInfo( volume_index=vol_idx, volume_name=entry.volume_name, - position=entry.position, + position=position, patch_size=self.patch_size, is_unlabeled_fg=False, )) @@ -577,10 +603,14 @@ def _load_from_cache(self, cache_data) -> None: # Add unlabeled FG patches for entry in cache_data.unlabeled_fg_patches: vol_idx = volume_name_to_idx.get(entry.volume_name, 
entry.volume_idx) + position = self._cached_position_to_training_level( + entry.position, + self.ome_zarr_resolution, + ) self._patches.append(PatchInfo( volume_index=vol_idx, volume_name=entry.volume_name, - position=entry.position, + position=position, patch_size=self.patch_size, is_unlabeled_fg=True, )) diff --git a/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py b/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py index e91ed67d5..eb27e851d 100644 --- a/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py +++ b/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple -SCHEMA_VERSION = 3 +SCHEMA_VERSION = 4 @dataclass(frozen=True) @@ -30,6 +30,7 @@ class PatchCacheParams: min_labeled_ratio: float bbox_threshold: float valid_patch_find_resolution: int + ome_zarr_resolution: int = 0 valid_patch_value: Optional[float] = None unlabeled_fg_enabled: bool = True unlabeled_fg_threshold: float = 0.05 @@ -45,6 +46,7 @@ def to_dict(self) -> Dict[str, Any]: "min_labeled_ratio": float(self.min_labeled_ratio), "bbox_threshold": float(self.bbox_threshold), "valid_patch_find_resolution": int(self.valid_patch_find_resolution), + "ome_zarr_resolution": int(self.ome_zarr_resolution), "valid_patch_value": self.valid_patch_value, "unlabeled_fg_enabled": bool(self.unlabeled_fg_enabled), "unlabeled_fg_threshold": float(self.unlabeled_fg_threshold), @@ -78,6 +80,7 @@ def build_cache_params( min_labeled_ratio: float, bbox_threshold: float, valid_patch_find_resolution: int, + ome_zarr_resolution: int = 0, valid_patch_value: Optional[float] = None, unlabeled_fg_enabled: bool = True, unlabeled_fg_threshold: float = 0.05, @@ -91,6 +94,7 @@ def build_cache_params( min_labeled_ratio=float(min_labeled_ratio), bbox_threshold=float(bbox_threshold), valid_patch_find_resolution=int(valid_patch_find_resolution), + ome_zarr_resolution=int(ome_zarr_resolution), 
valid_patch_value=valid_patch_value, unlabeled_fg_enabled=bool(unlabeled_fg_enabled), unlabeled_fg_threshold=float(unlabeled_fg_threshold), @@ -223,6 +227,7 @@ def try_load_patch_cache( min_labeled_ratio: float, bbox_threshold: float, valid_patch_find_resolution: int, + ome_zarr_resolution: int = 0, valid_patch_value: Optional[float] = None, unlabeled_fg_enabled: bool = True, unlabeled_fg_threshold: float = 0.05, @@ -247,6 +252,8 @@ def try_load_patch_cache( Minimum bounding box coverage. valid_patch_find_resolution : int Multi-resolution level for patch finding. + ome_zarr_resolution : int + Multi-resolution level used for training data reads. valid_patch_value : Optional[float] Specific label value to match. unlabeled_fg_enabled : bool @@ -268,6 +275,7 @@ def try_load_patch_cache( min_labeled_ratio=min_labeled_ratio, bbox_threshold=bbox_threshold, valid_patch_find_resolution=valid_patch_find_resolution, + ome_zarr_resolution=ome_zarr_resolution, valid_patch_value=valid_patch_value, unlabeled_fg_enabled=unlabeled_fg_enabled, unlabeled_fg_threshold=unlabeled_fg_threshold, diff --git a/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py b/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py index e0c3cebd6..a43b105ff 100644 --- a/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py +++ b/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py @@ -49,6 +49,15 @@ class PatchCacheResult: total_unlabeled_fg_patches: int +def _full_resolution_patch_size( + train_patch_size: Tuple[int, ...], + ome_zarr_resolution: int, +) -> Tuple[int, ...]: + """Convert a training-resolution patch size into full-resolution coordinates.""" + scale_factor = 2 ** int(ome_zarr_resolution) + return tuple(int(v) * scale_factor for v in train_patch_size) + + def generate_patch_caches( config_path: Path, *, @@ -78,6 +87,14 @@ def generate_patch_caches( data_path = Path(mgr.data_path) patch_size = tuple(int(v) for v in mgr.train_patch_size) + 
ome_zarr_resolution = int(getattr(mgr, "ome_zarr_resolution", 0)) + full_res_patch_size = _full_resolution_patch_size(patch_size, ome_zarr_resolution) + logger.info( + "Training resolution level %d uses patch size %s; generating cache with full-resolution patch size %s", + ome_zarr_resolution, + patch_size, + full_res_patch_size, + ) # Resolve target names target_names = _resolve_target_names(mgr) @@ -106,6 +123,7 @@ def generate_patch_caches( min_labeled_ratio=float(getattr(mgr, "min_labeled_ratio", 0.10)), bbox_threshold=float(getattr(mgr, "min_bbox_percent", 0.95)), valid_patch_find_resolution=int(getattr(mgr, "valid_patch_find_resolution", 1)), + ome_zarr_resolution=ome_zarr_resolution, valid_patch_value=_resolve_valid_patch_value(target_names, mgr), unlabeled_fg_enabled=bool(getattr(mgr, "unlabeled_foreground_enabled", False)), unlabeled_fg_threshold=float(getattr(mgr, "unlabeled_foreground_threshold", 0.05)), @@ -171,7 +189,7 @@ def generate_patch_caches( result = find_valid_patches( label_arrays=label_arrays, label_names=label_names, - patch_size=patch_size, + patch_size=full_res_patch_size, bbox_threshold=cache_params.bbox_threshold, label_threshold=cache_params.min_labeled_ratio, valid_patch_find_resolution=cache_params.valid_patch_find_resolution, diff --git a/vesuvius/src/vesuvius/models/training/loss/losses.py b/vesuvius/src/vesuvius/models/training/loss/losses.py index 746cb2899..cca4403df 100644 --- a/vesuvius/src/vesuvius/models/training/loss/losses.py +++ b/vesuvius/src/vesuvius/models/training/loss/losses.py @@ -53,6 +53,57 @@ def forward(self, input, target): return masked_loss.sum() / num_valid +class BinaryBCEAndDiceLoss(nn.Module): + """ + Binary BCE + soft dice loss for single-channel logits with optional ignore label. 
+ """ + + def __init__( + self, + bce_kwargs=None, + weight_bce: float = 1.0, + weight_dice: float = 1.0, + smooth: float = 1e-5, + ignore_label: int | None = None, + ): + super().__init__() + self.bce = nn.BCEWithLogitsLoss(reduction="none", **(bce_kwargs or {})) + self.weight_bce = float(weight_bce) + self.weight_dice = float(weight_dice) + self.smooth = float(smooth) + self.ignore_label = ignore_label + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if input.shape[1] != 1: + raise ValueError( + f"BinaryBCEAndDiceLoss expects a single-channel prediction, got shape {tuple(input.shape)}" + ) + if target.ndim != input.ndim: + target = target.view((target.shape[0], 1, *target.shape[1:])) + + target = target.float() + if self.ignore_label is not None: + valid_mask = target != float(self.ignore_label) + target = torch.where(valid_mask, target, torch.zeros_like(target)) + else: + valid_mask = torch.ones_like(target, dtype=torch.bool) + + bce_map = self.bce(input, target) + valid_mask_f = valid_mask.float() + valid_voxels = valid_mask_f.sum().clamp(min=1.0) + bce_loss = (bce_map * valid_mask_f).sum() / valid_voxels + + probs = torch.sigmoid(input) * valid_mask_f + target_masked = target * valid_mask_f + axes = tuple(range(2, probs.ndim)) + intersection = (probs * target_masked).sum(axes) + pred_sum = probs.sum(axes) + target_sum = target_masked.sum(axes) + dice = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth) + dice_loss = 1.0 - dice.mean() + return self.weight_bce * bce_loss + self.weight_dice * dice_loss + + class SkipLastTargetChannelWrapper(nn.Module): """ Loss wrapper which removes additional target channel @@ -1142,7 +1193,19 @@ def _create_loss(name, loss_config, weight, ignore_index, pos_weight, mgr=None): use_ignore_label=use_ignore_label, dice_class=MemoryEfficientSoftDiceLoss ) - + + elif name == 'BinaryBCEAndDiceLoss': + bce_kwargs = dict(loss_config.get('bce_kwargs', {})) + if pos_weight is not 
None and 'pos_weight' not in bce_kwargs: + bce_kwargs['pos_weight'] = pos_weight + base_loss = BinaryBCEAndDiceLoss( + bce_kwargs=bce_kwargs, + weight_bce=loss_config.get('weight_bce', loss_config.get('weight_ce', 1)), + weight_dice=loss_config.get('weight_dice', 1), + smooth=loss_config.get('smooth', 1e-5), + ignore_label=ignore_index, + ) + elif name == 'MemoryEfficientSoftDiceLoss': # Standalone memory efficient dice loss base_loss = MemoryEfficientSoftDiceLoss( diff --git a/vesuvius/src/vesuvius/scripts/probe_surface_fit.py b/vesuvius/src/vesuvius/scripts/probe_surface_fit.py new file mode 100644 index 000000000..936163c49 --- /dev/null +++ b/vesuvius/src/vesuvius/scripts/probe_surface_fit.py @@ -0,0 +1,269 @@ +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +from pathlib import Path +from types import SimpleNamespace +from typing import Iterable + +import torch + +from vesuvius.models.build.build_network_from_config import NetworkFromConfig + + +FIXED_TARGETS = { + "surface": { + "out_channels": 2, + "activation": "none", + "ignore_label": 2, + } +} + +FIXED_MODEL_CONFIG = { + "features_per_stage": [32, 64, 128, 256, 320, 320], + "n_stages": 6, + "n_blocks_per_stage": [1, 3, 4, 6, 6, 6], + "n_conv_per_stage_decoder": [1, 1, 1, 1, 1], + "kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]], + "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], + "separate_decoders": True, +} + + +def _flatten_tensors(value) -> Iterable[torch.Tensor]: + if isinstance(value, torch.Tensor): + yield value + return + if isinstance(value, dict): + for child in value.values(): + yield from _flatten_tensors(child) + return + if isinstance(value, (list, tuple)): + for child in value: + yield from _flatten_tensors(child) + + +def _build_manager(patch_size: tuple[int, int, int], batch_size: int) -> SimpleNamespace: + return SimpleNamespace( + targets=FIXED_TARGETS, + 
model_name="surface-fit-probe", + train_patch_size=patch_size, + train_batch_size=batch_size, + in_channels=1, + spacing=(1.0, 1.0, 1.0), + autoconfigure=False, + model_config=FIXED_MODEL_CONFIG, + enable_deep_supervision=True, + op_dims=3, + ) + + +def _run_single_trial(batch_size: int, patch_size: tuple[int, int, int]) -> int: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for probe_surface_fit") + + device = torch.device("cuda") + torch.backends.cudnn.benchmark = False + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(device) + + model = None + optimizer = None + inputs = None + loss = None + outputs = None + try: + mgr = _build_manager(patch_size, batch_size) + model = NetworkFromConfig(mgr).to(device) + model.train() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99) + inputs = torch.randn((batch_size, 1, *patch_size), device=device) + + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.float16): + outputs = model(inputs) + loss = None + for tensor in _flatten_tensors(outputs): + term = tensor.float().mean() + loss = term if loss is None else loss + term + if loss is None: + raise RuntimeError("Network produced no tensors during fit probe") + loss.backward() + optimizer.step() + torch.cuda.synchronize(device) + + payload = { + "success": True, + "batch_size": int(batch_size), + "patch_size": list(patch_size), + "peak_memory_allocated": int(torch.cuda.max_memory_allocated(device)), + "total_memory": int(torch.cuda.get_device_properties(device).total_memory), + } + print(json.dumps(payload, sort_keys=True)) + return 0 + except RuntimeError as exc: + message = str(exc) + payload = { + "success": False, + "batch_size": int(batch_size), + "patch_size": list(patch_size), + "error": message, + } + print(json.dumps(payload, sort_keys=True)) + if "out of memory" in message.lower(): + return 3 + return 4 + finally: + del outputs + del loss + del inputs + del 
optimizer + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _invoke_trial(batch_size: int, patch_size: tuple[int, int, int]) -> dict: + cmd = [ + sys.executable, + __file__, + "--mode", + "single-trial", + "--batch-size", + str(batch_size), + "--patch-size", + ",".join(str(v) for v in patch_size), + ] + proc = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, + ) + payload = None + lines = [line for line in proc.stdout.splitlines() if line.strip()] + for line in reversed(lines): + try: + payload = json.loads(line) + break + except json.JSONDecodeError: + continue + if payload is None: + raise RuntimeError( + "Trial produced no JSON payload.\n" + f"stdout:\n{proc.stdout}\n" + f"stderr:\n{proc.stderr}" + ) + payload["returncode"] = proc.returncode + return payload + + +def _candidate_edges(min_edge: int, max_edge: int, step: int) -> list[int]: + if min_edge > max_edge: + raise ValueError("min_edge must be <= max_edge") + if step <= 0: + raise ValueError("step must be > 0") + return list(range(min_edge, max_edge + 1, step)) + + +def _binary_search_max_success(candidates: list[int], evaluator) -> tuple[int, list[dict]]: + attempts: list[dict] = [] + lo = 0 + hi = len(candidates) - 1 + best_idx = None + while lo <= hi: + mid = (lo + hi) // 2 + candidate = candidates[mid] + payload = evaluator(candidate) + attempts.append(payload) + if payload.get("success"): + best_idx = mid + lo = mid + 1 + else: + hi = mid - 1 + if best_idx is None: + raise RuntimeError("No successful fit found in the provided search range") + return candidates[best_idx], attempts + + +def _load_existing(path: Path) -> dict: + if not path.exists(): + return {} + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + +def _write_result(path: Path, key: str, payload: dict) -> None: + existing = _load_existing(path) + existing[key] = payload + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", 
encoding="utf-8") as handle: + json.dump(existing, handle, indent=2, sort_keys=True) + handle.write("\n") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Probe H100 fit limits for surface training.") + parser.add_argument("--mode", required=True, choices=["iso-bs2", "ps128-maxbs", "single-trial"]) + parser.add_argument("--min-edge", type=int, default=128) + parser.add_argument("--max-edge", type=int, default=512) + parser.add_argument("--step", type=int, default=32) + parser.add_argument("--min-batch", type=int, default=2) + parser.add_argument("--max-batch", type=int, default=64) + parser.add_argument("--patch-size", type=str, default=None) + parser.add_argument("--batch-size", type=int, default=None) + parser.add_argument("--out", type=Path, default=None) + args = parser.parse_args() + + if args.mode == "single-trial": + if args.patch_size is None or args.batch_size is None: + raise SystemExit("--mode single-trial requires --patch-size and --batch-size") + patch_size = tuple(int(v) for v in args.patch_size.split(",")) + if len(patch_size) != 3: + raise SystemExit("--patch-size must have exactly 3 comma-separated integers") + raise SystemExit(_run_single_trial(args.batch_size, patch_size)) + + if args.out is None: + raise SystemExit("--out is required for aggregate probe modes") + + if args.mode == "iso-bs2": + candidates = _candidate_edges(args.min_edge, args.max_edge, args.step) + + def evaluator(edge: int) -> dict: + payload = _invoke_trial(2, (edge, edge, edge)) + payload["candidate_edge"] = edge + return payload + + max_edge, attempts = _binary_search_max_success(candidates, evaluator) + result = { + "batch_size": 2, + "max_edge": max_edge, + "patch_size": [max_edge, max_edge, max_edge], + "attempts": attempts, + } + _write_result(args.out, "iso_bs2", result) + print(json.dumps(result, sort_keys=True)) + return + + candidates = list(range(args.min_batch, args.max_batch + 1)) + + def evaluator(batch_size: int) -> dict: + payload = 
_invoke_trial(batch_size, (128, 128, 128)) + payload["candidate_batch_size"] = batch_size + return payload + + max_batch, attempts = _binary_search_max_success(candidates, evaluator) + result = { + "batch_size": max_batch, + "max_batch": max_batch, + "patch_size": [128, 128, 128], + "attempts": attempts, + } + _write_result(args.out, "ps128_maxbs", result) + print(json.dumps(result, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/vesuvius/tests/models/test_surface_multiscale_training.py b/vesuvius/tests/models/test_surface_multiscale_training.py new file mode 100644 index 000000000..c017aaf79 --- /dev/null +++ b/vesuvius/tests/models/test_surface_multiscale_training.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from pathlib import Path + +import yaml +import pytest +import torch + +from vesuvius.models.configuration.config_manager import ConfigManager +from vesuvius.models.datasets.zarr_dataset import ZarrDataset +from vesuvius.models.preprocessing.patches.cache import build_cache_params, cache_filename +from vesuvius.models.preprocessing.patches.generate import _full_resolution_patch_size +from vesuvius.models.training.loss.losses import BinaryBCEAndDiceLoss + + +def _write_config( + tmp_path: Path, + *, + ome_zarr_resolution: int, + valid_patch_find_resolution: int, +) -> Path: + config = { + "tr_setup": { + "model_name": "surface-test", + }, + "tr_config": { + "patch_size": [128, 128, 128], + "batch_size": 2, + }, + "dataset_config": { + "data_path": str(tmp_path), + "ome_zarr_resolution": ome_zarr_resolution, + "valid_patch_find_resolution": valid_patch_find_resolution, + "targets": { + "surface": { + "activation": "none", + "ignore_label": 2, + } + }, + }, + } + config_path = tmp_path / "config.yaml" + config_path.write_text(yaml.safe_dump(config), encoding="utf-8") + return config_path + + +def test_config_manager_loads_ome_zarr_resolution(tmp_path: Path) -> None: + mgr = ConfigManager(verbose=False) + 
mgr.load_config(_write_config(tmp_path, ome_zarr_resolution=2, valid_patch_find_resolution=3)) + assert mgr.ome_zarr_resolution == 2 + assert mgr.valid_patch_find_resolution == 3 + + +def test_config_manager_rejects_patch_find_resolution_below_training_resolution(tmp_path: Path) -> None: + mgr = ConfigManager(verbose=False) + with pytest.raises(ValueError, match="valid_patch_find_resolution"): + mgr.load_config(_write_config(tmp_path, ome_zarr_resolution=2, valid_patch_find_resolution=1)) + + +def test_patch_cache_filename_varies_by_training_resolution(tmp_path: Path) -> None: + common_kwargs = { + "data_path": tmp_path, + "volume_ids": ["sample"], + "patch_size": [128, 128, 128], + "min_labeled_ratio": 0.001, + "bbox_threshold": 0.35, + "valid_patch_find_resolution": 3, + } + scale0 = build_cache_params(ome_zarr_resolution=0, **common_kwargs) + scale2 = build_cache_params(ome_zarr_resolution=2, **common_kwargs) + assert cache_filename(scale0) != cache_filename(scale2) + + +def test_cached_positions_scale_to_training_level() -> None: + assert ZarrDataset._cached_position_to_training_level((256, 128, 64), 2) == (64, 32, 16) + with pytest.raises(ValueError, match="not divisible"): + ZarrDataset._cached_position_to_training_level((258, 128, 64), 2) + + +def test_full_resolution_patch_size_uses_training_scale() -> None: + assert _full_resolution_patch_size((128, 128, 128), 0) == (128, 128, 128) + assert _full_resolution_patch_size((128, 128, 128), 2) == (512, 512, 512) + + +def test_binary_bce_and_dice_loss_ignores_ignore_label() -> None: + loss_fn = BinaryBCEAndDiceLoss(ignore_label=2) + logits = torch.tensor([[[[[0.0, 1.0], [2.0, -1.0]]]]], dtype=torch.float32) + target = torch.tensor([[[[[0.0, 1.0], [2.0, 0.0]]]]], dtype=torch.float32) + loss = loss_fn(logits, target) + assert torch.isfinite(loss) + assert loss.item() >= 0.0 From 08371f9b884ad901db9cee34f92b8847bd73f95c Mon Sep 17 00:00:00 2001 From: Giorgio Angelotti <76100950+giorgioangel@users.noreply.github.com> 
Date: Sat, 28 Mar 2026 10:12:10 +0100 Subject: [PATCH 2/4] Fix BCE validation activation handling --- .../models/build/build_network_from_config.py | 8 +++-- .../surface_resenc_s0_ps128_bs28_bcedice.yaml | 3 +- .../surface_resenc_s0_ps256_bs3_bcedice.yaml | 3 +- .../surface_resenc_s2_ps128_bs28_bcedice.yaml | 2 +- .../surface_resenc_s2_ps256_bs3_bcedice.yaml | 3 +- .../vesuvius/models/evaluation/base_metric.py | 36 ++++++++++++++++++- .../models/evaluation/connected_components.py | 17 ++------- .../vesuvius/models/evaluation/iou_dice.py | 22 ++---------- .../src/vesuvius/models/evaluation/voi.py | 15 ++------ .../src/vesuvius/models/training/train.py | 8 +++-- vesuvius/src/vesuvius/utils/plotting.py | 27 +++++++++++--- .../test_surface_multiscale_training.py | 19 ++++++++++ 12 files changed, 98 insertions(+), 65 deletions(-) diff --git a/vesuvius/src/vesuvius/models/build/build_network_from_config.py b/vesuvius/src/vesuvius/models/build/build_network_from_config.py index d3f5974d3..5e0988566 100644 --- a/vesuvius/src/vesuvius/models/build/build_network_from_config.py +++ b/vesuvius/src/vesuvius/models/build/build_network_from_config.py @@ -849,9 +849,11 @@ def check_input_channels(self, x): return False return True - def forward(self, x, return_mae_mask=False): + def forward(self, x, return_mae_mask=False, apply_activation=None): # Check input channels and warn if mismatch self.check_input_channels(x) + if apply_activation is None: + apply_activation = not self.training # Get features from encoder (works for both U-Net and Primus) # For MAE training with Primus, we need to get the mask @@ -897,7 +899,7 @@ def forward(self, x, return_mae_mask=False): logits = logits[0] logits = self._apply_z_projection(task_name, logits) activation_fn = self.task_activations[task_name] if task_name in self.task_activations else None - if activation_fn is not None and not self.training: + if activation_fn is not None and apply_activation: if isinstance(logits, (list, tuple)): logits 
= type(logits)(activation_fn(l) for l in logits) else: @@ -912,7 +914,7 @@ def forward(self, x, return_mae_mask=False): logits = head(shared_features) logits = self._apply_z_projection(task_name, logits) activation_fn = self.task_activations[task_name] if task_name in self.task_activations else None - if activation_fn is not None and not self.training: + if activation_fn is not None and apply_activation: if isinstance(logits, (list, tuple)): logits = type(logits)(activation_fn(l) for l in logits) else: diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml index b3c2f0e00..d87a38a23 100644 --- a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps128_bs28_bcedice.yaml @@ -34,11 +34,10 @@ dataset_config: surface: out_channels: 1 valid_patch_value: 1 - activation: "sigmoid" + activation: "none" ignore_label: 2 losses: - name: "BinaryBCEAndDiceLoss" weight: 1.0 weight_bce: 1.0 weight_dice: 1.0 - diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml index 9500bdadb..26fbe3f86 100644 --- a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s0_ps256_bs3_bcedice.yaml @@ -34,11 +34,10 @@ dataset_config: surface: out_channels: 1 valid_patch_value: 1 - activation: "sigmoid" + activation: "none" ignore_label: 2 losses: - name: "BinaryBCEAndDiceLoss" weight: 1.0 weight_bce: 1.0 weight_dice: 1.0 - diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml 
b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml index 1a43d1381..8381512bc 100644 --- a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps128_bs28_bcedice.yaml @@ -34,7 +34,7 @@ dataset_config: surface: out_channels: 1 valid_patch_value: 1 - activation: "sigmoid" + activation: "none" ignore_label: 2 losses: - name: "BinaryBCEAndDiceLoss" diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml index 3f00feae6..87d77b083 100644 --- a/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml +++ b/vesuvius/src/vesuvius/models/configuration/single_task/surface_resenc_s2_ps256_bs3_bcedice.yaml @@ -34,11 +34,10 @@ dataset_config: surface: out_channels: 1 valid_patch_value: 1 - activation: "sigmoid" + activation: "none" ignore_label: 2 losses: - name: "BinaryBCEAndDiceLoss" weight: 1.0 weight_bce: 1.0 weight_dice: 1.0 - diff --git a/vesuvius/src/vesuvius/models/evaluation/base_metric.py b/vesuvius/src/vesuvius/models/evaluation/base_metric.py index a2335fbbb..97abf77c0 100644 --- a/vesuvius/src/vesuvius/models/evaluation/base_metric.py +++ b/vesuvius/src/vesuvius/models/evaluation/base_metric.py @@ -4,6 +4,40 @@ import numpy as np +def _sigmoid_if_needed(array: np.ndarray) -> np.ndarray: + arr = array.astype(np.float32, copy=False) + if arr.size == 0: + return arr + if np.nanmin(arr) < 0.0 or np.nanmax(arr) > 1.0: + return 1.0 / (1.0 + np.exp(-arr)) + return arr + + +def prediction_to_discrete_labels(pred_np: np.ndarray) -> np.ndarray: + """Convert model outputs into discrete label maps for evaluation metrics.""" + if pred_np.ndim == 5: + if pred_np.shape[1] > 1: + return np.argmax(pred_np, 
axis=1).astype(np.int32) + probs = _sigmoid_if_needed(np.squeeze(pred_np, axis=1)) + return (probs >= 0.5).astype(np.int32) + + if pred_np.ndim == 4: + if pred_np.shape[1] <= 10: + if pred_np.shape[1] > 1: + return np.argmax(pred_np, axis=1).astype(np.int32) + probs = _sigmoid_if_needed(np.squeeze(pred_np, axis=1)) + return (probs >= 0.5).astype(np.int32) + return pred_np + + if pred_np.ndim == 3 and pred_np.shape[0] <= 10: + if pred_np.shape[0] > 1: + return np.argmax(pred_np, axis=0).astype(np.int32) + probs = _sigmoid_if_needed(np.squeeze(pred_np, axis=0)) + return (probs >= 0.5).astype(np.int32) + + return pred_np + + class BaseMetric(ABC): def __init__(self, name: str): self.name = name @@ -35,4 +69,4 @@ def aggregate(self) -> Dict[str, float]: return aggregated def reset(self): - self.results = [] \ No newline at end of file + self.results = [] diff --git a/vesuvius/src/vesuvius/models/evaluation/connected_components.py b/vesuvius/src/vesuvius/models/evaluation/connected_components.py index ac430000f..2fde39083 100644 --- a/vesuvius/src/vesuvius/models/evaluation/connected_components.py +++ b/vesuvius/src/vesuvius/models/evaluation/connected_components.py @@ -2,7 +2,7 @@ import cc3d import torch from typing import Dict, Optional -from .base_metric import BaseMetric +from .base_metric import BaseMetric, prediction_to_discrete_labels class ConnectedComponentsMetric(BaseMetric): @@ -64,20 +64,7 @@ def _normalize_mask(mask_arr: np.ndarray, target_shape: tuple) -> np.ndarray: else: mask_np = np.asarray(mask) - # Handle different input shapes for predictions - if pred_np.ndim == 5: # (batch, channels, depth, height, width) - if pred_np.shape[1] > 1: # Multi-channel, need argmax - pred_np = np.argmax(pred_np, axis=1) - else: # Single channel, just squeeze - pred_np = pred_np.squeeze(1) - elif pred_np.ndim == 4: # Could be (batch, depth, height, width) or (batch, channels, height, width) - # Check if second dimension is channels (usually small) or spatial dimension - 
if pred_np.shape[1] <= 10: # Likely channels dimension - if pred_np.shape[1] > 1: - pred_np = np.argmax(pred_np, axis=1) - else: - pred_np = pred_np.squeeze(1) - # Otherwise assume it's already (batch, depth, height, width) + pred_np = prediction_to_discrete_labels(pred_np) # Handle different input shapes for ground truth (and align mask) if gt_np.ndim == 3: # (depth, height, width) diff --git a/vesuvius/src/vesuvius/models/evaluation/iou_dice.py b/vesuvius/src/vesuvius/models/evaluation/iou_dice.py index f3ffdeea7..41b418542 100644 --- a/vesuvius/src/vesuvius/models/evaluation/iou_dice.py +++ b/vesuvius/src/vesuvius/models/evaluation/iou_dice.py @@ -1,7 +1,7 @@ import torch import numpy as np from typing import Dict, Optional -from .base_metric import BaseMetric +from .base_metric import BaseMetric, prediction_to_discrete_labels class IOUDiceMetric(BaseMetric): @@ -64,25 +64,7 @@ def _normalize_mask(mask_arr: np.ndarray, target_shape: tuple) -> np.ndarray: else: mask_np = np.asarray(mask) - # Handle different input shapes for predictions - if pred_np.ndim == 5: # (batch, channels, depth, height, width) - if pred_np.shape[1] > 1: # Multi-channel, need argmax - pred_np = np.argmax(pred_np, axis=1) - else: # Single channel, just squeeze - pred_np = pred_np.squeeze(1) - elif pred_np.ndim == 4: # Could be (batch, depth, height, width) or (batch, channels, height, width) - # Check if second dimension is channels (usually small) or spatial dimension - if pred_np.shape[1] <= 10: # Likely channels dimension - if pred_np.shape[1] > 1: - pred_np = np.argmax(pred_np, axis=1) - else: - pred_np = pred_np.squeeze(1) - # Otherwise assume it's already (batch, depth, height, width) or (batch, height, width) - elif pred_np.ndim == 3 and pred_np.shape[0] <= 10: # (channels, height, width) - if pred_np.shape[0] > 1: - pred_np = np.argmax(pred_np, axis=0) - else: - pred_np = pred_np.squeeze(0) + pred_np = prediction_to_discrete_labels(pred_np) # Handle different input shapes for ground 
truth (and align mask if provided) if gt_np.ndim == 3: # (depth, height, width) or (height, width) diff --git a/vesuvius/src/vesuvius/models/evaluation/voi.py b/vesuvius/src/vesuvius/models/evaluation/voi.py index 564479a8d..bd7ccfe50 100644 --- a/vesuvius/src/vesuvius/models/evaluation/voi.py +++ b/vesuvius/src/vesuvius/models/evaluation/voi.py @@ -4,7 +4,7 @@ from typing import Dict, Optional, Tuple from skimage.metrics import variation_of_information -from .base_metric import BaseMetric +from .base_metric import BaseMetric, prediction_to_discrete_labels class VOIMetric(BaseMetric): @@ -172,18 +172,7 @@ def compute_voi( else: mask_np = np.asarray(mask).astype(bool) - # Handle different input shapes for predictions - if pred_np.ndim == 5: # (batch, channels, depth, height, width) - if pred_np.shape[1] > 1: # Multi-channel, need argmax - pred_np = np.argmax(pred_np, axis=1) - else: # Single channel, just squeeze - pred_np = pred_np.squeeze(1) - elif pred_np.ndim == 4: # Could be (batch, depth, height, width) or (batch, channels, height, width) - if pred_np.shape[1] <= 10: # Likely channels dimension - if pred_np.shape[1] > 1: - pred_np = np.argmax(pred_np, axis=1) - else: - pred_np = pred_np.squeeze(1) + pred_np = prediction_to_discrete_labels(pred_np) # Handle different input shapes for ground truth if gt_np.ndim == 3: # (depth, height, width) diff --git a/vesuvius/src/vesuvius/models/training/train.py b/vesuvius/src/vesuvius/models/training/train.py index 6c1b36995..371a537fc 100644 --- a/vesuvius/src/vesuvius/models/training/train.py +++ b/vesuvius/src/vesuvius/models/training/train.py @@ -1,5 +1,6 @@ from pathlib import Path from copy import deepcopy +import inspect import os from datetime import datetime from tqdm import tqdm @@ -1125,8 +1126,11 @@ def _extract_targets(self, data_dict): def _get_model_outputs(self, model, data_dict): inputs = data_dict["image"].to(self.device) targets_dict = self._extract_targets(data_dict) - - outputs = model(inputs) + + if 
"apply_activation" in inspect.signature(model.forward).parameters: + outputs = model(inputs, apply_activation=False) + else: + outputs = model(inputs) # If deep supervision is enabled, prepare lists of downsampled targets if getattr(self.mgr, 'enable_deep_supervision', False): diff --git a/vesuvius/src/vesuvius/utils/plotting.py b/vesuvius/src/vesuvius/utils/plotting.py index 235e30e6b..a1e2e6292 100644 --- a/vesuvius/src/vesuvius/utils/plotting.py +++ b/vesuvius/src/vesuvius/utils/plotting.py @@ -222,6 +222,25 @@ def _apply_activation(array_np: np.ndarray, activation: Optional[str], *, is_sur return array_np +def _resolve_visualization_activation(task_cfg: Dict | None) -> Optional[str]: + task_cfg = task_cfg or {} + override = task_cfg.get("visualization_activation") + if override is not None: + return override + + activation = task_cfg.get("activation", None) + activation_l = str(activation).lower() if activation is not None else "none" + if activation_l not in {"none", "identity", ""}: + return activation + + losses = task_cfg.get("losses") or [] + if task_cfg.get("out_channels") == 1: + for loss_cfg in losses: + if (loss_cfg or {}).get("name") == "BinaryBCEAndDiceLoss": + return "sigmoid" + return activation + + def _vector_to_bgr(vector_3ch): """Map a 3×H×W vector field to a BGR image using directional colouring.""" if vector_3ch.shape[0] != 3: @@ -393,7 +412,7 @@ def save_debug( arr_np = arr_np[0] task_cfg = tasks_dict.get(t_name, {}) if tasks_dict else {} - activation = task_cfg.get("activation", None) + activation = _resolve_visualization_activation(task_cfg) is_surface_frame = arr_np.shape[0] == 9 or t_name.endswith("surface_frame") if apply_activation: @@ -436,7 +455,7 @@ def save_debug( arr_np = arr_np[0] task_cfg = tasks_dict.get(t_name, {}) if tasks_dict else {} - activation = task_cfg.get("activation", None) + activation = _resolve_visualization_activation(task_cfg) is_surface_frame = arr_np.shape[0] == 9 or t_name.endswith("surface_frame") if 
apply_activation: @@ -466,7 +485,7 @@ def save_debug( arr_np = arr_np[0] task_cfg = tasks_dict.get(t_name, {}) if tasks_dict else {} - activation = task_cfg.get("activation", None) + activation = _resolve_visualization_activation(task_cfg) is_surface_frame = arr_np.shape[0] == 9 or t_name.endswith("surface_frame") if apply_activation: @@ -483,7 +502,7 @@ def save_debug( arr_np = arr_np[0] task_cfg = tasks_dict.get(t_name, {}) if tasks_dict else {} - activation = task_cfg.get("activation", None) + activation = _resolve_visualization_activation(task_cfg) is_surface_frame = arr_np.shape[0] == 9 or t_name.endswith("surface_frame") if apply_activation: diff --git a/vesuvius/tests/models/test_surface_multiscale_training.py b/vesuvius/tests/models/test_surface_multiscale_training.py index c017aaf79..c44c3e863 100644 --- a/vesuvius/tests/models/test_surface_multiscale_training.py +++ b/vesuvius/tests/models/test_surface_multiscale_training.py @@ -5,12 +5,15 @@ import yaml import pytest import torch +import numpy as np from vesuvius.models.configuration.config_manager import ConfigManager from vesuvius.models.datasets.zarr_dataset import ZarrDataset +from vesuvius.models.evaluation.base_metric import prediction_to_discrete_labels from vesuvius.models.preprocessing.patches.cache import build_cache_params, cache_filename from vesuvius.models.preprocessing.patches.generate import _full_resolution_patch_size from vesuvius.models.training.loss.losses import BinaryBCEAndDiceLoss +from vesuvius.utils.plotting import _resolve_visualization_activation def _write_config( @@ -89,3 +92,19 @@ def test_binary_bce_and_dice_loss_ignores_ignore_label() -> None: loss = loss_fn(logits, target) assert torch.isfinite(loss) assert loss.item() >= 0.0 + + +def test_prediction_to_discrete_labels_thresholds_single_channel_logits() -> None: + logits = np.array([[[[[-2.0, 2.0], [0.1, -0.1]]]]], dtype=np.float32) + labels = prediction_to_discrete_labels(logits) + assert labels.shape == (1, 1, 2, 2) + 
assert labels.tolist() == [[[[0, 1], [1, 0]]]] + + +def test_visualization_activation_uses_sigmoid_for_binary_bce_dice() -> None: + task_cfg = { + "out_channels": 1, + "activation": "none", + "losses": [{"name": "BinaryBCEAndDiceLoss"}], + } + assert _resolve_visualization_activation(task_cfg) == "sigmoid" From bb5cf4edff0da77110f06cf98cd52aac8fd350f1 Mon Sep 17 00:00:00 2001 From: Giorgio Angelotti <76100950+giorgioangel@users.noreply.github.com> Date: Sat, 28 Mar 2026 14:12:21 +0100 Subject: [PATCH 3/4] Honor ignore label in patch cache lookup --- .../vesuvius/models/datasets/zarr_dataset.py | 16 ++++++++++- .../models/preprocessing/patches/cache.py | 10 ++++++- .../models/preprocessing/patches/generate.py | 27 +++++++++++++++++++ .../test_surface_multiscale_training.py | 14 ++++++++++ 4 files changed, 65 insertions(+), 2 deletions(-) diff --git a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py index d164ee684..edd01fbd3 100644 --- a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py +++ b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py @@ -553,14 +553,27 @@ def _try_load_cache(self): # Get valid_patch_value from target config, with dataset-level fallback valid_patch_value = None + ignore_label = None for target_name in self.target_names: info = self.targets.get(target_name, {}) + if ignore_label is None: + ignore_label = info.get('ignore_label') + if ignore_label is None: + ignore_label = info.get('ignore_index') + if ignore_label is None: + ignore_label = info.get('ignore_value') if 'valid_patch_value' in info: valid_patch_value = info['valid_patch_value'] break + dataset_cfg = getattr(self.mgr, "dataset_config", {}) or {} if valid_patch_value is None: - dataset_cfg = getattr(self.mgr, "dataset_config", {}) or {} valid_patch_value = dataset_cfg.get("valid_patch_value") + if ignore_label is None: + ignore_label = dataset_cfg.get("ignore_label") + if ignore_label is None: + ignore_label = 
dataset_cfg.get("ignore_index") + if ignore_label is None: + ignore_label = dataset_cfg.get("ignore_value") volume_ids = [vol.volume_id for vol in self._volumes] @@ -573,6 +586,7 @@ def _try_load_cache(self): bbox_threshold=self.min_bbox_percent, valid_patch_find_resolution=self.valid_patch_find_resolution, ome_zarr_resolution=self.ome_zarr_resolution, + ignore_label=ignore_label, valid_patch_value=valid_patch_value, unlabeled_fg_enabled=self.unlabeled_fg_enabled, unlabeled_fg_threshold=self.unlabeled_fg_threshold, diff --git a/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py b/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py index eb27e851d..d912b865a 100644 --- a/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py +++ b/vesuvius/src/vesuvius/models/preprocessing/patches/cache.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple -SCHEMA_VERSION = 4 +SCHEMA_VERSION = 5 @dataclass(frozen=True) @@ -31,6 +31,7 @@ class PatchCacheParams: bbox_threshold: float valid_patch_find_resolution: int ome_zarr_resolution: int = 0 + ignore_label: Optional[float] = None valid_patch_value: Optional[float] = None unlabeled_fg_enabled: bool = True unlabeled_fg_threshold: float = 0.05 @@ -47,6 +48,7 @@ def to_dict(self) -> Dict[str, Any]: "bbox_threshold": float(self.bbox_threshold), "valid_patch_find_resolution": int(self.valid_patch_find_resolution), "ome_zarr_resolution": int(self.ome_zarr_resolution), + "ignore_label": self.ignore_label, "valid_patch_value": self.valid_patch_value, "unlabeled_fg_enabled": bool(self.unlabeled_fg_enabled), "unlabeled_fg_threshold": float(self.unlabeled_fg_threshold), @@ -81,6 +83,7 @@ def build_cache_params( bbox_threshold: float, valid_patch_find_resolution: int, ome_zarr_resolution: int = 0, + ignore_label: Optional[float] = None, valid_patch_value: Optional[float] = None, unlabeled_fg_enabled: bool = True, unlabeled_fg_threshold: float = 0.05, @@ -95,6 +98,7 @@ def build_cache_params( 
bbox_threshold=float(bbox_threshold), valid_patch_find_resolution=int(valid_patch_find_resolution), ome_zarr_resolution=int(ome_zarr_resolution), + ignore_label=ignore_label, valid_patch_value=valid_patch_value, unlabeled_fg_enabled=bool(unlabeled_fg_enabled), unlabeled_fg_threshold=float(unlabeled_fg_threshold), @@ -228,6 +232,7 @@ def try_load_patch_cache( bbox_threshold: float, valid_patch_find_resolution: int, ome_zarr_resolution: int = 0, + ignore_label: Optional[float] = None, valid_patch_value: Optional[float] = None, unlabeled_fg_enabled: bool = True, unlabeled_fg_threshold: float = 0.05, @@ -254,6 +259,8 @@ def try_load_patch_cache( Multi-resolution level for patch finding. ome_zarr_resolution : int Multi-resolution level used for training data reads. + ignore_label : Optional[float] + Label value that should be treated as ignored/background for cache generation. valid_patch_value : Optional[float] Specific label value to match. unlabeled_fg_enabled : bool @@ -276,6 +283,7 @@ def try_load_patch_cache( bbox_threshold=bbox_threshold, valid_patch_find_resolution=valid_patch_find_resolution, ome_zarr_resolution=ome_zarr_resolution, + ignore_label=ignore_label, valid_patch_value=valid_patch_value, unlabeled_fg_enabled=unlabeled_fg_enabled, unlabeled_fg_threshold=unlabeled_fg_threshold, diff --git a/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py b/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py index a43b105ff..22d3b1182 100644 --- a/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py +++ b/vesuvius/src/vesuvius/models/preprocessing/patches/generate.py @@ -124,6 +124,7 @@ def generate_patch_caches( bbox_threshold=float(getattr(mgr, "min_bbox_percent", 0.95)), valid_patch_find_resolution=int(getattr(mgr, "valid_patch_find_resolution", 1)), ome_zarr_resolution=ome_zarr_resolution, + ignore_label=_resolve_ignore_label(target_names, mgr), valid_patch_value=_resolve_valid_patch_value(target_names, mgr), 
unlabeled_fg_enabled=bool(getattr(mgr, "unlabeled_foreground_enabled", False)), unlabeled_fg_threshold=float(getattr(mgr, "unlabeled_foreground_threshold", 0.05)), @@ -193,6 +194,11 @@ def generate_patch_caches( bbox_threshold=cache_params.bbox_threshold, label_threshold=cache_params.min_labeled_ratio, valid_patch_find_resolution=cache_params.valid_patch_find_resolution, + ignore_labels=( + [cache_params.ignore_label] * len(label_arrays) + if cache_params.ignore_label is not None + else None + ), valid_patch_values=( [cache_params.valid_patch_value] * len(label_arrays) if cache_params.valid_patch_value is not None @@ -283,6 +289,27 @@ def _resolve_valid_patch_value( return None +def _resolve_ignore_label( + target_names: List[str], + mgr, +) -> Optional[Union[int, float]]: + """Extract ignore label from target config, with dataset-level fallbacks.""" + targets = getattr(mgr, "targets", {}) + fallback_keys = ("ignore_label", "ignore_index", "ignore_value") + for target in target_names: + info = targets.get(target) or {} + for key in fallback_keys: + value = info.get(key) + if value is not None: + return value + dataset_cfg = getattr(mgr, "dataset_config", {}) or {} + for key in fallback_keys: + value = dataset_cfg.get(key) + if value is not None: + return value + return None + + def _discover_volumes( data_path: Path, target_names: List[str], diff --git a/vesuvius/tests/models/test_surface_multiscale_training.py b/vesuvius/tests/models/test_surface_multiscale_training.py index c44c3e863..9ba5e9799 100644 --- a/vesuvius/tests/models/test_surface_multiscale_training.py +++ b/vesuvius/tests/models/test_surface_multiscale_training.py @@ -74,6 +74,20 @@ def test_patch_cache_filename_varies_by_training_resolution(tmp_path: Path) -> N assert cache_filename(scale0) != cache_filename(scale2) +def test_patch_cache_filename_varies_by_ignore_label(tmp_path: Path) -> None: + common_kwargs = { + "data_path": tmp_path, + "volume_ids": ["sample"], + "patch_size": [128, 128, 128], + 
"min_labeled_ratio": 0.001, + "bbox_threshold": 0.35, + "valid_patch_find_resolution": 3, + } + ignore0 = build_cache_params(ignore_label=0, **common_kwargs) + ignore2 = build_cache_params(ignore_label=2, **common_kwargs) + assert cache_filename(ignore0) != cache_filename(ignore2) + + def test_cached_positions_scale_to_training_level() -> None: assert ZarrDataset._cached_position_to_training_level((256, 128, 64), 2) == (64, 32, 16) with pytest.raises(ValueError, match="not divisible"): From 8167dfba3079805c0fd7850dd8f5ea242948826d Mon Sep 17 00:00:00 2001 From: Giorgio Angelotti <76100950+giorgioangel@users.noreply.github.com> Date: Sat, 28 Mar 2026 14:34:53 +0100 Subject: [PATCH 4/4] Filter patches by ignore-label z-range --- .../models/datasets/find_valid_patches.py | 119 ++++++++++++++++++ .../vesuvius/models/datasets/zarr_dataset.py | 97 ++++++++++++++ .../models/preprocessing/patches/cache.py | 2 +- .../test_surface_multiscale_training.py | 55 ++++++++ 4 files changed, 272 insertions(+), 1 deletion(-) mode change 100755 => 100644 vesuvius/src/vesuvius/models/datasets/find_valid_patches.py diff --git a/vesuvius/src/vesuvius/models/datasets/find_valid_patches.py b/vesuvius/src/vesuvius/models/datasets/find_valid_patches.py old mode 100755 new mode 100644 index 21b068d5a..b00c1f0e8 --- a/vesuvius/src/vesuvius/models/datasets/find_valid_patches.py +++ b/vesuvius/src/vesuvius/models/datasets/find_valid_patches.py @@ -1,5 +1,6 @@ import logging import time +from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple, Union import numpy as np @@ -169,6 +170,98 @@ def zero_ignore_labels(array: np.ndarray, ignore_label: Union[int, float]) -> np return result +def _filter_z_starts_by_ignore_bounds( + z_starts: Sequence[int], + *, + patch_depth: int, + ignore_bounds: Optional[Tuple[int, int]], +) -> List[int]: + """Keep only patch starts whose z-interval intersects the ignore-label z-support.""" + if ignore_bounds is None: + return list(z_starts) + 
ignore_min_z, ignore_max_z = ignore_bounds + return [ + int(z_start) + for z_start in z_starts + if (int(z_start) + int(patch_depth) - 1) >= ignore_min_z and int(z_start) <= ignore_max_z + ] + + +def _find_ignore_z_bounds_from_stored_chunks( + array_obj, + *, + ignore_label: Union[int, float], +) -> Optional[Tuple[int, int]]: + """ + Find the exact z-range containing ignore labels using only stored chunks. + + Missing chunks are intentionally ignored here. This prevents fill-value based + ignore regions from expanding the usable z-range to the full volume. + """ + store = getattr(array_obj, "store", None) + store_path = getattr(store, "path", None) + if store_path is None: + return None + + chunk_shape = getattr(array_obj, "chunks", None) + spatial_shape = getattr(array_obj, "shape", None) + if chunk_shape is None or spatial_shape is None or len(chunk_shape) < 3 or len(spatial_shape) < 3: + return None + + chunk_dir = Path(store_path) + array_path = getattr(array_obj, "path", "") + if array_path: + chunk_dir = chunk_dir / str(array_path) + if not chunk_dir.exists(): + return None + + z_groups: Dict[int, List[Tuple[int, int, int]]] = {} + for chunk_file in chunk_dir.iterdir(): + if not chunk_file.is_file() or chunk_file.name.startswith("."): + continue + parts = chunk_file.name.split(".") + if len(parts) < 3: + continue + try: + chunk_idx = (int(parts[0]), int(parts[1]), int(parts[2])) + except ValueError: + continue + z_groups.setdefault(chunk_idx[0], []).append(chunk_idx) + + if not z_groups: + return None + + def _chunk_has_ignore(chunk: np.ndarray) -> np.ndarray: + if isinstance(ignore_label, float) and np.isnan(ignore_label): + return np.any(np.isnan(chunk), axis=(1, 2)) + return np.any(chunk == ignore_label, axis=(1, 2)) + + def _scan(sorted_z_indices: Sequence[int], *, from_start: bool) -> Optional[int]: + for z_chunk_idx in sorted_z_indices: + z0 = z_chunk_idx * int(chunk_shape[0]) + z_len = min(int(chunk_shape[0]), int(spatial_shape[0]) - z0) + local_hits = 
np.zeros(z_len, dtype=bool) + for chunk_idx in z_groups[z_chunk_idx]: + slices = tuple( + slice(idx * size, min((idx + 1) * size, shape)) + for idx, size, shape in zip(chunk_idx, chunk_shape[:3], spatial_shape[:3]) + ) + chunk = np.asarray(array_obj[slices]) + local_hits |= _chunk_has_ignore(chunk) + hit_indices = np.flatnonzero(local_hits) + if hit_indices.size: + local_idx = int(hit_indices[0] if from_start else hit_indices[-1]) + return z0 + local_idx + return None + + sorted_z = sorted(z_groups) + first = _scan(sorted_z, from_start=True) + last = _scan(sorted_z[::-1], from_start=False) + if first is None or last is None: + return None + return (first, last) + + def check_patch_chunk( chunk, sheet_label, @@ -858,6 +951,32 @@ def _resolve_resolution(array_obj, level_key): y_starts = list(range(vol_min_y, max(vol_min_y, vol_max_y - dpY + 1), dpY)) x_starts = list(range(vol_min_x, max(vol_min_x, vol_max_x - dpX + 1), dpX)) + ignore_z_bounds = None + if ignore_label is not None: + ignore_z_bounds = _find_ignore_z_bounds_from_stored_chunks( + downsampled_array, + ignore_label=ignore_label, + ) + if ignore_z_bounds is not None: + original_z_count = len(z_starts) + z_starts = _filter_z_starts_by_ignore_bounds( + z_starts, + patch_depth=dpZ, + ignore_bounds=ignore_z_bounds, + ) + logger.info( + "Volume '%s': restrict patch starts to ignore-label z-range %s -> %d/%d z positions", + label_name, + ignore_z_bounds, + len(z_starts), + original_z_count, + ) + else: + logger.warning( + "Volume '%s': could not determine stored ignore-label z-range; leaving z candidates unfiltered", + label_name, + ) + generate_elapsed = time.perf_counter() - position_gen_start candidate_count = ( len(y_starts) * len(x_starts) if is_2d else len(z_starts) * len(y_starts) * len(x_starts) diff --git a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py index edd01fbd3..c330bd863 100644 --- 
a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py +++ b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py @@ -18,6 +18,7 @@ from typing import Dict, List, Optional, Tuple import numpy as np +from scipy import ndimage as ndi import torch import zarr from torch.utils.data import Dataset @@ -146,6 +147,14 @@ def __init__( # Transforms (initialized after normalization) self.transforms = None + self._zero_cc_to_ignore_enabled = bool( + getattr(mgr, 'zero_components_without_fg_to_ignore', True) + ) + self._zero_cc_to_ignore_structure = ndi.generate_binary_structure( + len(self.patch_size), + len(self.patch_size), + ) + self._target_cleanup_rules = self._build_target_cleanup_rules() # Initialize self._discover_and_load_volumes() @@ -292,6 +301,85 @@ def _build_target_volumes_for_intensity(self) -> Dict: }) return {first_target: volumes_list} + def _build_target_cleanup_rules(self) -> Dict[str, Tuple[Optional[float], Optional[float]]]: + """Collect per-target foreground/ignore values used for patch-local cleanup.""" + rules: Dict[str, Tuple[Optional[float], Optional[float]]] = {} + dataset_cfg = getattr(self.mgr, "dataset_config", {}) or {} + for target_name in self.target_names: + info = self.targets.get(target_name, {}) or {} + valid_patch_value = info.get("valid_patch_value") + if valid_patch_value is None: + valid_patch_value = dataset_cfg.get("valid_patch_value") + + ignore_label = info.get("ignore_label") + if ignore_label is None: + ignore_label = info.get("ignore_index") + if ignore_label is None: + ignore_label = info.get("ignore_value") + if ignore_label is None: + ignore_label = dataset_cfg.get("ignore_label") + if ignore_label is None: + ignore_label = dataset_cfg.get("ignore_index") + if ignore_label is None: + ignore_label = dataset_cfg.get("ignore_value") + + rules[target_name] = (valid_patch_value, ignore_label) + return rules + + def _zero_components_without_foreground_to_ignore( + self, + label_data: np.ndarray, + *, + valid_patch_value: 
Optional[float], + ignore_label: Optional[float], + ) -> np.ndarray: + """ + Relabel patch-local zero components that do not touch foreground as ignore. + + This is intentionally patch-local. It keeps zero regions connected to any + foreground voxel unchanged and turns the rest into the configured + ignore label. The propagation is implemented with scipy ndimage to keep + the hot path in compiled code. + """ + if ( + not self._zero_cc_to_ignore_enabled + or valid_patch_value is None + or ignore_label is None + ): + return label_data + + zero_mask = label_data == 0 + if not np.any(zero_mask): + return label_data + + fg_mask = label_data == valid_patch_value + if not np.any(fg_mask): + cleaned = label_data.copy() + cleaned[zero_mask] = ignore_label + return cleaned + + touching_zero = zero_mask & ndi.binary_dilation( + fg_mask, + structure=self._zero_cc_to_ignore_structure, + iterations=1, + ) + if np.any(touching_zero): + zero_connected_to_fg = ndi.binary_propagation( + touching_zero, + structure=self._zero_cc_to_ignore_structure, + mask=zero_mask, + ) + else: + zero_connected_to_fg = np.zeros_like(zero_mask, dtype=bool) + + isolated_zero = zero_mask & ~zero_connected_to_fg + if not np.any(isolated_zero): + return label_data + + cleaned = label_data.copy() + cleaned[isolated_zero] = ignore_label + return cleaned + # ------------------------------------------------------------------------- # Patch Index Building # ------------------------------------------------------------------------- @@ -799,6 +887,15 @@ def load_array( for target_name in self.target_names: label_arr = vol.label_arrays.get(target_name) label_data = load_array(label_arr) + valid_patch_value, ignore_label = self._target_cleanup_rules.get( + target_name, + (None, None), + ) + label_data = self._zero_components_without_foreground_to_ignore( + label_data, + valid_patch_value=valid_patch_value, + ignore_label=ignore_label, + ) if label_arr is not None and np.count_nonzero(label_data) > 0: is_unlabeled = 
def test_zero_components_without_foreground_become_ignore() -> None:
    """Zero pixels cut off from the foreground are rewritten to ignore."""
    ds = ZarrDataset.__new__(ZarrDataset)
    ds._zero_cc_to_ignore_enabled = True
    ds._zero_cc_to_ignore_structure = ndi.generate_binary_structure(2, 2)
    label = np.array(
        [
            [0, 0, 2, 2],
            [0, 1, 2, 0],
            [0, 0, 2, 0],
            [2, 2, 2, 0],
        ],
        dtype=np.float32,
    )
    cleaned = ds._zero_components_without_foreground_to_ignore(
        label,
        valid_patch_value=1,
        ignore_label=2,
    )
    # Only the right-hand zero run is isolated; the left zeros touch the 1.
    expected = label.copy()
    expected[:, 3] = 2
    np.testing.assert_array_equal(cleaned, expected)


def test_all_zero_patch_becomes_ignore_when_no_foreground_present() -> None:
    """A patch with no foreground at all collapses entirely to ignore."""
    ds = ZarrDataset.__new__(ZarrDataset)
    ds._zero_cc_to_ignore_enabled = True
    ds._zero_cc_to_ignore_structure = ndi.generate_binary_structure(3, 3)
    cleaned = ds._zero_components_without_foreground_to_ignore(
        np.zeros((2, 2, 2), dtype=np.float32),
        valid_patch_value=1,
        ignore_label=2,
    )
    np.testing.assert_array_equal(cleaned, np.full((2, 2, 2), 2.0, dtype=np.float32))


def test_filter_z_starts_by_ignore_bounds_keeps_intersecting_patches() -> None:
    """Only z-starts whose patch window overlaps the ignore range survive."""
    kept = _filter_z_starts_by_ignore_bounds(
        [0, 64, 128, 192],
        patch_depth=64,
        ignore_bounds=(96, 159),
    )
    assert kept == [64, 128]


def test_full_resolution_patch_size_uses_training_scale() -> None:
    """Patch size scales by 2**level back up to full resolution."""
    for level, expected in ((0, (128, 128, 128)), (2, (512, 512, 512))):
        assert _full_resolution_patch_size((128, 128, 128), level) == expected