diff --git a/vesuvius/docs/guided_dinovol.md b/vesuvius/docs/guided_dinovol.md new file mode 100644 index 000000000..ad749b1d7 --- /dev/null +++ b/vesuvius/docs/guided_dinovol.md @@ -0,0 +1,101 @@ +# Guided Dinovol Training + +## Environment + +From `villa/vesuvius`: + +```bash +uv sync --extra models --extra tests +``` + +## Example Guided Training + +Use the example config: + +```bash +uv run --extra models python -m vesuvius.models.training.train \ + --config src/vesuvius/models/configuration/single_task/ps128_guided_dinovol_ink.yaml \ + --input /path/to/data +``` + +The guided config uses the local volumetric DINO checkpoint at: + +```text +/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt +``` + +It also enables: + +```yaml +model_config: + guide_tokenbook_tokens: 256 +``` + +This is an opt-in speed setting for the example config only. The model default remains full-grid TokenBook prototypes when the key is omitted. + +Ps256 guided configs are also available: + +```text +src/vesuvius/models/configuration/single_task/ps256_guided_medial.yaml +src/vesuvius/models/configuration/single_task/ps256_guided_dicece.yaml +``` + +For large guided runs, generate patch caches before training: + +```bash +uv run --extra models vesuvius.find_patches \ + --config src/vesuvius/models/configuration/single_task/ps256_guided_medial.yaml +``` + +## Tests + +Run the guided coverage: + +```bash +uv run --extra models --extra tests python -m pytest \ + tests/models/configuration/test_ps256_config_compat.py \ + tests/models/build/test_dinovol_local_backbone.py \ + tests/models/build/test_guided_network.py \ + tests/models/training/test_guided_trainer.py -q +``` + +## Benchmark + +Profile unguided vs guided input gating on the local RTX 4090: + +```bash +uv run --extra models python -m vesuvius.models.benchmarks.benchmark_guided_dinovol \ + --guide-checkpoint /home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt \ + --patch-size 
64,64,64 \ + --device cuda +``` + +For faster prototype-count sweeps without compile variants: + +```bash +uv run --extra models python -m vesuvius.models.benchmarks.benchmark_guided_dinovol \ + --guide-checkpoint /home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt \ + --patch-size 64,64,64 \ + --device cuda \ + --guide-tokenbook-tokens 256 \ + --skip-compile-variants \ + --skip-stage-breakdown +``` + +## Current Operational Guidance + +- Large guided runs should precompute patch caches with `vesuvius.find_patches` before training. +- Guided training now supports `tr_config.compile_policy`: + - `auto`: guided models compile the inner module, unguided models keep the legacy DDP-wrapper compile path + - `module`: compile the inner module before DDP wrapping + - `ddp_wrapper`: preserve the legacy `torch.compile(DDP(model))` path + - `off`: eager mode +- The guided backbone path is explicitly excluded from compiler capture because full-token guided `ps256` hit an Inductor `BackendCompilerFailed` crash during the first compiled train step. +- Do not expose or rely on a public `channels_last_3d` toggle; the measured gain was negligible relative to plain compile. +- Prefer capped TokenBook prototypes for large training patches; the example config uses `256` for `128^3`. +- The trainer now logs the guide validation visualization twice when guidance is enabled: + - embedded inside the composite `debug_image` + - separately as `debug_guide_image` +- Set `tr_config.startup_timing: true` when debugging slow startup or first-step stalls. This logs dataset init, model build, compile, first batch fetch, first forward/backward, and first optimizer step timings. +- On the local RTX 4090, full `256^3` forward inference for both unguided and guided ps256 configs hit OOM, so practical timing comparisons should use smaller patches or larger-memory GPUs. 
+- The guide panel/render overhead is small relative to model runtime, so it is reasonable to keep both composite and separate guide-image logging enabled. diff --git a/vesuvius/pyproject.toml b/vesuvius/pyproject.toml index 4b29348e8..993c166cb 100644 --- a/vesuvius/pyproject.toml +++ b/vesuvius/pyproject.toml @@ -70,6 +70,7 @@ models = [ "typed-argument-parser>=1.11.0", "wandb[media]>=0.22.0", "psutil>=7.1.0", + "nest-asyncio>=1.6.0", # Data pipeline / I/O and compression "aiohttp>=3.12.15", diff --git a/vesuvius/src/vesuvius/models/benchmarks/__init__.py b/vesuvius/src/vesuvius/models/benchmarks/__init__.py new file mode 100644 index 000000000..6748f90d2 --- /dev/null +++ b/vesuvius/src/vesuvius/models/benchmarks/__init__.py @@ -0,0 +1 @@ +"""Benchmark helpers for model profiling.""" diff --git a/vesuvius/src/vesuvius/models/benchmarks/benchmark_guided_dinovol.py b/vesuvius/src/vesuvius/models/benchmarks/benchmark_guided_dinovol.py new file mode 100644 index 000000000..e643af52a --- /dev/null +++ b/vesuvius/src/vesuvius/models/benchmarks/benchmark_guided_dinovol.py @@ -0,0 +1,672 @@ +from __future__ import annotations + +import argparse +import json +import time +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from types import SimpleNamespace + +import torch +import torch.nn.functional as F + +from vesuvius.models.build.build_network_from_config import NetworkFromConfig + + +def _parse_patch_size(text: str) -> tuple[int, int, int]: + values = tuple(int(v) for v in text.split(",")) + if len(values) != 3: + raise ValueError(f"Expected 3 comma-separated values, got {text!r}") + return values + + +def _sync_if_needed(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def _to_device(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: + non_blocking = device.type == "cuda" + return tensor.to(device=device, non_blocking=non_blocking) + + +def 
_autocast_context(device: torch.device): + if device.type == "cuda": + return torch.amp.autocast("cuda") + return nullcontext() + + +def _make_mgr( + *, + patch_size: tuple[int, int, int], + guide_checkpoint: str | None, + guide_tokenbook_tokens: int | None = None, + guide_fusion_stage: str | None = None, + guide_tokenbook_prototype_weighting: str = "mean", + guide_tokenbook_weight_mlp_hidden: int | None = None, +) -> SimpleNamespace: + benchmark_features_per_stage = [16, 32, 64, 128, 192] + guided_config = {} + if guide_checkpoint is not None: + guided_config = { + "guide_backbone": str(guide_checkpoint), + "guide_freeze": True, + } + if guide_fusion_stage is not None: + guided_config["guide_fusion_stage"] = str(guide_fusion_stage) + if guide_fusion_stage != "feature_skip_concat": + guided_config["guide_tokenbook_sample_rate"] = 1.0 + guided_config["guide_tokenbook_prototype_weighting"] = str(guide_tokenbook_prototype_weighting) + if guide_tokenbook_tokens is not None: + guided_config["guide_tokenbook_tokens"] = int(guide_tokenbook_tokens) + if guide_tokenbook_weight_mlp_hidden is not None: + guided_config["guide_tokenbook_weight_mlp_hidden"] = int(guide_tokenbook_weight_mlp_hidden) + else: + guided_config["guide_skip_concat_projector_channels"] = [max(1, int(ch) // 4) for ch in benchmark_features_per_stage] + + return SimpleNamespace( + targets={"ink": {"out_channels": 2, "activation": "none"}}, + train_patch_size=patch_size, + train_batch_size=1, + in_channels=1, + autoconfigure=False, + model_name="guided_benchmark", + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + model_config={ + "features_per_stage": benchmark_features_per_stage, + "n_stages": 5, + "n_blocks_per_stage": [1, 1, 1, 1, 1], + "n_conv_per_stage_decoder": [1, 1, 1, 1], + "kernel_sizes": [[3, 3, 3]] * 5, + "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], + "basic_encoder_block": "ConvBlock", + "basic_decoder_block": "ConvBlock", + "bottleneck_block": "BasicBlockD", + 
"separate_decoders": True, + "input_shape": list(patch_size), + **guided_config, + }, + ) + + +def _label_name(tokenbook_tokens: int | None) -> str: + return "full" if tokenbook_tokens is None else f"k{int(tokenbook_tokens)}" + + +def _build_model( + *, + patch_size: tuple[int, int, int], + device: torch.device, + guide_checkpoint: str | None, + guide_tokenbook_tokens: int | None, + guide_fusion_stage: str | None, + guide_tokenbook_prototype_weighting: str, + guide_tokenbook_weight_mlp_hidden: int | None, + compile_model: bool, + channels_last_3d: bool, +): + mgr = _make_mgr( + patch_size=patch_size, + guide_checkpoint=guide_checkpoint, + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage=guide_fusion_stage, + guide_tokenbook_prototype_weighting=guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=guide_tokenbook_weight_mlp_hidden, + ) + model = NetworkFromConfig(mgr).to(device) + if channels_last_3d: + model = model.to(memory_format=torch.channels_last_3d) + if compile_model: + model = torch.compile(model, mode="default") + return model + + +def _make_cpu_batch( + patch_size: tuple[int, int, int], + *, + pin_memory: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + image = torch.randn(1, 1, *patch_size) + labels = torch.randint(0, 2, (1, *patch_size)) + if pin_memory: + image = image.pin_memory() + labels = labels.pin_memory() + return image, labels + + +def _compute_loss(logits_dict: dict[str, torch.Tensor], aux: dict[str, torch.Tensor], labels: torch.Tensor, *, guide_fusion_stage: str | None) -> torch.Tensor: + logits = logits_dict["ink"].float() + loss = F.cross_entropy(logits, labels) + if "guide_mask" in aux and guide_fusion_stage in {"input", "input_gating"}: + guide_target = (labels > 0).float().unsqueeze(1) + guide_target = F.interpolate(guide_target, size=aux["guide_mask"].shape[2:], mode="nearest") + loss = loss + 0.5 * F.binary_cross_entropy( + aux["guide_mask"].float().clamp(1e-6, 1.0 - 1e-6), + guide_target, + ) + 
return loss + + +def _timed_call(fn, *, device: torch.device): + _sync_if_needed(device) + start = time.perf_counter() + result = fn() + _sync_if_needed(device) + return result, (time.perf_counter() - start) * 1000.0 + + +def _run_variant_benchmark( + model, + *, + device: torch.device, + cpu_inputs: torch.Tensor, + cpu_labels: torch.Tensor, + iterations: int, +) -> dict[str, float | None]: + if device.type == "cuda": + torch.cuda.reset_peak_memory_stats(device) + + def _first_forward(): + inputs = _to_device(cpu_inputs, device) + with _autocast_context(device): + outputs = model(inputs) + logits = outputs["ink"].to(dtype=torch.float16) + _ = logits.cpu() + return outputs + + try: + _result, startup_ms = _timed_call(_first_forward, device=device) + except Exception as exc: # pragma: no cover - benchmark fallback path + return { + "startup_forward_ms": None, + "steady_state_forward_ms_mean": None, + "steady_state_train_step_ms_mean": None, + "max_memory_mb": ( + torch.cuda.max_memory_allocated(device) / (1024 ** 2) + if device.type == "cuda" + else None + ), + "error": f"{type(exc).__name__}: {str(exc).splitlines()[0]}", + } + + forward_times = [] + train_step_times = [] + optimizer = torch.optim.AdamW([param for param in model.parameters() if param.requires_grad], lr=1e-4) + model.train() + for _ in range(iterations): + inputs = _to_device(cpu_inputs, device) + labels = _to_device(cpu_labels, device) + + optimizer.zero_grad(set_to_none=True) + def _forward_only(): + with _autocast_context(device): + return model(inputs) + _, forward_ms = _timed_call( + _forward_only, + device=device, + ) + forward_times.append(forward_ms) + + def _train_step(): + with _autocast_context(device): + outputs = model(inputs, return_aux=True) if getattr(model, "guide_enabled", False) else model(inputs) + if getattr(model, "guide_enabled", False): + logits_dict, aux = outputs + else: + logits_dict, aux = outputs, {} + loss = _compute_loss( + logits_dict, + aux, + labels, + 
guide_fusion_stage=getattr(model, "guide_fusion_stage", None), + ) + loss.backward() + optimizer.step() + + try: + _, train_ms = _timed_call(_train_step, device=device) + train_step_times.append(train_ms) + except Exception as exc: # pragma: no cover - benchmark fallback path + train_step_times = [] + train_step_error = f"{type(exc).__name__}: {str(exc).splitlines()[0]}" + break + else: + train_step_error = None + + model.eval() + result = { + "startup_forward_ms": startup_ms, + "steady_state_forward_ms_mean": sum(forward_times) / len(forward_times), + "steady_state_train_step_ms_mean": ( + sum(train_step_times) / len(train_step_times) if train_step_times else None + ), + "max_memory_mb": ( + torch.cuda.max_memory_allocated(device) / (1024 ** 2) + if device.type == "cuda" + else None + ), + } + if train_step_error is not None: + result["train_step_error"] = train_step_error + return result + + +def _profile_guided_stages( + model: NetworkFromConfig, + *, + device: torch.device, + cpu_inputs: torch.Tensor, + iterations: int, +) -> dict[str, float]: + if not model.guide_enabled: + return {} + + timings = defaultdict(list) + model.eval() + + for _ in range(iterations): + inputs, h2d_ms = _timed_call(lambda: _to_device(cpu_inputs, device), device=device) + timings["host_to_device_ms"].append(h2d_ms) + + def _guide_backbone(): + with torch.no_grad(): + with _autocast_context(device): + return model.guide_backbone(inputs)[0] + + guide_features, guide_backbone_ms = _timed_call(_guide_backbone, device=device) + timings["guide_backbone_ms"].append(guide_backbone_ms) + + token_mask = model._sample_guide_token_mask(guide_features) + if model.guide_fusion_stage in {"input", "input_gating"}: + guide_mask, tokenbook_ms = _timed_call( + lambda: model.guide_tokenbook(guide_features, token_mask=token_mask), + device=device, + ) + timings["tokenbook_ms"].append(tokenbook_ms) + + def _upsample(): + with _autocast_context(device): + return F.interpolate( + guide_mask, + 
size=inputs.shape[2:], + mode="trilinear", + align_corners=False, + ) + + guide_for_input, upsample_ms = _timed_call(_upsample, device=device) + timings["guide_upsample_ms"].append(upsample_ms) + + def _segmentation(): + with _autocast_context(device): + guided_input = inputs * guide_for_input + features = model.shared_encoder(guided_input) + logits = model.task_decoders["ink"](features) + return logits + + elif model.guide_fusion_stage == "feature_encoder": + stage_shapes = model._get_encoder_stage_shapes(inputs.shape[2:]) + + def _feature_gates(): + with _autocast_context(device): + feature_gates = {} + for stage_idx, stage_key in enumerate(model.guide_stage_keys or []): + stage_guide = model.guide_stage_tokenbooks[stage_key]( + guide_features, + token_mask=token_mask, + ) + target_shape = stage_shapes[stage_idx] + if stage_guide.shape[2:] != target_shape: + stage_guide = F.interpolate( + stage_guide, + size=target_shape, + mode="trilinear", + align_corners=False, + ) + feature_gates[stage_key] = stage_guide + return feature_gates + + feature_gates, tokenbook_ms = _timed_call(_feature_gates, device=device) + timings["encoder_guide_heads_ms"].append(tokenbook_ms) + + def _segmentation(): + with _autocast_context(device): + features = model.shared_encoder(inputs, feature_gates=feature_gates) + logits = model.task_decoders["ink"](features) + return logits + + elif model.guide_fusion_stage == "feature_skip_concat": + stage_shapes = model._get_encoder_stage_shapes(inputs.shape[2:]) + + def _projected_features(): + with _autocast_context(device): + projected_features = [] + for stage_idx, stage_key in enumerate(model.guide_stage_keys or []): + projected = model.guide_skip_projectors[stage_key](guide_features) + if projected.shape[2:] != stage_shapes[stage_idx]: + projected = F.interpolate( + projected, + size=stage_shapes[stage_idx], + mode="trilinear", + align_corners=False, + ) + projected_features.append(projected) + return projected_features + + projected_features, 
projector_ms = _timed_call(_projected_features, device=device) + timings["skip_projectors_ms"].append(projector_ms) + + def _segmentation(): + with _autocast_context(device): + features = model.shared_encoder(inputs) + features = [ + torch.cat([stage_feature, projected_feature], dim=1) + for stage_feature, projected_feature in zip(features, projected_features) + ] + logits = model.task_decoders["ink"](features) + return logits + + elif model.guide_fusion_stage == "direct_segmentation": + def _guide_mask(): + with _autocast_context(device): + return model.guide_tokenbook(guide_features, token_mask=token_mask) + + guide_mask, tokenbook_ms = _timed_call(_guide_mask, device=device) + timings["tokenbook_ms"].append(tokenbook_ms) + + def _segmentation(): + with _autocast_context(device): + guide_probability = F.interpolate( + guide_mask, + size=inputs.shape[2:], + mode="trilinear", + align_corners=False, + ).clamp_(1e-6, 1.0 - 1e-6) + fg_logit = torch.logit(guide_probability) + bg_logit = torch.zeros_like(fg_logit) + return torch.cat([bg_logit, fg_logit], dim=1) + + else: # pragma: no cover - defensive path for future guide modes + raise ValueError(f"Unsupported guided benchmark mode: {model.guide_fusion_stage!r}") + + logits, segmentation_ms = _timed_call(_segmentation, device=device) + timings["segmentation_trunk_ms"].append(segmentation_ms) + + def _device_to_host(): + return logits.to(dtype=torch.float16).cpu() + + _host_logits, d2h_ms = _timed_call(_device_to_host, device=device) + timings["device_to_host_ms"].append(d2h_ms) + + return {name: sum(values) / len(values) for name, values in timings.items()} + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark baseline and guided volumetric Dinovol modes.") + parser.add_argument( + "--guide-checkpoint", + type=str, + default="/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt", + help="Local volumetric DINO checkpoint path.", + ) + parser.add_argument( + 
"--patch-size", + type=str, + action="append", + default=None, + help="3D patch size as z,y,x. Repeat to benchmark multiple sizes.", + ) + parser.add_argument("--warmup", type=int, default=1, help="Reserved for compatibility; startup is measured explicitly.") + parser.add_argument("--iterations", type=int, default=5) + parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--guide-tokenbook-tokens", type=str, default="full", help="'full' or integer prototype count.") + parser.add_argument( + "--guide-tokenbook-prototype-weighting", + type=str, + default="mean", + choices=["mean", "token_mlp"], + help="Prototype aggregation mode for guided variants.", + ) + parser.add_argument( + "--guide-tokenbook-weight-mlp-hidden", + type=int, + default=None, + help="Hidden width for token-conditioned prototype weighting MLP.", + ) + parser.add_argument("--skip-compile-variants", action="store_true") + parser.add_argument("--skip-stage-breakdown", action="store_true") + parser.add_argument("--output", type=str, default=None) + args = parser.parse_args() + + device = torch.device(args.device) + patch_sizes = args.patch_size or ["64,64,64"] + patch_sizes = [_parse_patch_size(text) for text in patch_sizes] + + if args.guide_tokenbook_tokens.strip().lower() == "full": + guide_tokenbook_tokens = None + else: + guide_tokenbook_tokens = int(args.guide_tokenbook_tokens) + if guide_tokenbook_tokens <= 0: + raise ValueError("--guide-tokenbook-tokens must be positive or 'full'") + + guide_checkpoint = Path(args.guide_checkpoint) + if not guide_checkpoint.exists(): + raise FileNotFoundError(f"Guide checkpoint not found: {guide_checkpoint}") + + summary: dict[str, object] = { + "device": str(device), + "warmup": args.warmup, + "iterations": args.iterations, + "guide_tokenbook_tokens": _label_name(guide_tokenbook_tokens), + "guide_tokenbook_prototype_weighting": args.guide_tokenbook_prototype_weighting, + 
"guide_tokenbook_weight_mlp_hidden": args.guide_tokenbook_weight_mlp_hidden, + "results": {}, + } + + pin_memory = device.type == "cuda" + for patch_size in patch_sizes: + patch_label = "x".join(str(v) for v in patch_size) + cpu_inputs, cpu_labels = _make_cpu_batch(patch_size, pin_memory=pin_memory) + guided_eager_model = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="input", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=False, + channels_last_3d=False, + ) + feature_encoder_eager_model = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="feature_encoder", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=False, + channels_last_3d=False, + ) + feature_skip_concat_eager_model = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="feature_skip_concat", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=False, + channels_last_3d=False, + ) + direct_segmentation_eager_model = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="direct_segmentation", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=False, + 
channels_last_3d=False, + ) + + patch_summary: dict[str, object] = {} + variants = { + "baseline_eager": _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=None, + guide_tokenbook_tokens=None, + guide_fusion_stage=None, + guide_tokenbook_prototype_weighting="mean", + guide_tokenbook_weight_mlp_hidden=None, + compile_model=False, + channels_last_3d=False, + ), + "guided_input_eager": guided_eager_model, + "guided_feature_encoder_eager": feature_encoder_eager_model, + "guided_feature_skip_concat_eager": feature_skip_concat_eager_model, + "guided_direct_segmentation_eager": direct_segmentation_eager_model, + } + if not args.skip_compile_variants: + variants["guided_input_compile"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="input", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=False, + ) + variants["guided_input_compile_channels_last_3d"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="input", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=True, + ) + variants["guided_feature_encoder_compile"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="feature_encoder", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=False, + ) + 
variants["guided_feature_encoder_compile_channels_last_3d"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="feature_encoder", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=True, + ) + variants["guided_feature_skip_concat_compile"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="feature_skip_concat", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=False, + ) + variants["guided_feature_skip_concat_compile_channels_last_3d"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="feature_skip_concat", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=True, + ) + variants["guided_direct_segmentation_compile"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="direct_segmentation", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=False, + ) + variants["guided_direct_segmentation_compile_channels_last_3d"] = _build_model( + patch_size=patch_size, + device=device, + guide_checkpoint=str(guide_checkpoint), + 
guide_tokenbook_tokens=guide_tokenbook_tokens, + guide_fusion_stage="direct_segmentation", + guide_tokenbook_prototype_weighting=args.guide_tokenbook_prototype_weighting, + guide_tokenbook_weight_mlp_hidden=args.guide_tokenbook_weight_mlp_hidden, + compile_model=True, + channels_last_3d=True, + ) + + if not args.skip_stage_breakdown: + patch_summary["guided_input_eager_stage_breakdown_ms"] = _profile_guided_stages( + guided_eager_model, + device=device, + cpu_inputs=cpu_inputs.detach().clone(), + iterations=args.iterations, + ) + patch_summary["guided_feature_encoder_eager_stage_breakdown_ms"] = _profile_guided_stages( + feature_encoder_eager_model, + device=device, + cpu_inputs=cpu_inputs.detach().clone(), + iterations=args.iterations, + ) + patch_summary["guided_feature_skip_concat_eager_stage_breakdown_ms"] = _profile_guided_stages( + feature_skip_concat_eager_model, + device=device, + cpu_inputs=cpu_inputs.detach().clone(), + iterations=args.iterations, + ) + patch_summary["guided_direct_segmentation_eager_stage_breakdown_ms"] = _profile_guided_stages( + direct_segmentation_eager_model, + device=device, + cpu_inputs=cpu_inputs.detach().clone(), + iterations=args.iterations, + ) + + for variant_name, model in variants.items(): + variant_inputs = cpu_inputs + if variant_name.endswith("channels_last_3d"): + variant_inputs = cpu_inputs.contiguous(memory_format=torch.channels_last_3d) + patch_summary[variant_name] = _run_variant_benchmark( + model, + device=device, + cpu_inputs=variant_inputs, + cpu_labels=cpu_labels, + iterations=args.iterations, + ) + summary["results"][patch_label] = patch_summary + + payload = json.dumps(summary, indent=2, sort_keys=True) + if args.output is not None: + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(payload + "\n", encoding="utf-8") + print(payload) + + +if __name__ == "__main__": + main() diff --git a/vesuvius/src/vesuvius/models/benchmarks/benchmark_mednext.py 
b/vesuvius/src/vesuvius/models/benchmarks/benchmark_mednext.py new file mode 100644 index 000000000..e67257777 --- /dev/null +++ b/vesuvius/src/vesuvius/models/benchmarks/benchmark_mednext.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import argparse +import io +import json +import subprocess +import sys +import time +from contextlib import nullcontext, redirect_stdout +from pathlib import Path +from types import SimpleNamespace + +import torch +import torch.nn.functional as F + +from vesuvius.models.build.build_network_from_config import NetworkFromConfig + + +def _parse_patch_sizes(values: list[str]) -> list[tuple[int, int, int]]: + patch_sizes = [] + for value in values: + dims = tuple(int(v) for v in value.split(",")) + if len(dims) != 3: + raise ValueError(f"Expected 3 comma-separated values, got {value!r}") + patch_sizes.append(dims) + return patch_sizes + + +def _sync(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def _autocast(device: torch.device): + if device.type == "cuda": + return torch.amp.autocast("cuda") + return nullcontext() + + +def _make_unet_mgr(patch_size: tuple[int, int, int]) -> SimpleNamespace: + return SimpleNamespace( + targets={"surface": {"out_channels": 2, "activation": "none"}}, + train_patch_size=patch_size, + train_batch_size=1, + in_channels=1, + autoconfigure=False, + model_name="mednext_benchmark_unet", + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + model_config={ + "features_per_stage": [32, 64, 128, 256, 320], + "n_stages": 5, + "n_blocks_per_stage": [1, 2, 2, 2, 2], + "n_conv_per_stage_decoder": [1, 1, 1, 1], + "kernel_sizes": [[3, 3, 3]] * 5, + "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]], + "basic_encoder_block": "ConvBlock", + "basic_decoder_block": "ConvBlock", + "bottleneck_block": "BasicBlockD", + "separate_decoders": True, + }, + ) + + +def _make_mednext_mgr( + patch_size: tuple[int, int, int], + *, + architecture_type: 
str, + model_id: str, + width_factor: int = 1, +) -> SimpleNamespace: + return SimpleNamespace( + targets={"surface": {"out_channels": 2, "activation": "none"}}, + train_patch_size=patch_size, + train_batch_size=1, + in_channels=1, + autoconfigure=False, + model_name=f"mednext_benchmark_{architecture_type}_{model_id.lower()}_w{width_factor}", + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + model_config={ + "architecture_type": architecture_type, + "mednext_model_id": model_id, + "mednext_kernel_size": 3, + "mednext_width_factor": width_factor, + }, + ) + + +def _build_model( + variant: str, + patch_size: tuple[int, int, int], + model_id: str, + width_factor: int, + device: torch.device, +) -> torch.nn.Module: + if variant == "unet": + mgr = _make_unet_mgr(patch_size) + elif variant in {"mednext_v1", "mednext_v2"}: + mgr = _make_mednext_mgr( + patch_size, + architecture_type=variant, + model_id=model_id, + width_factor=width_factor, + ) + else: + raise ValueError(f"Unknown benchmark variant {variant!r}") + return NetworkFromConfig(mgr).to(device) + + +def _timed_call(fn, *, device: torch.device): + _sync(device) + start = time.perf_counter() + result = fn() + _sync(device) + return result, (time.perf_counter() - start) * 1000.0 + + +def _run_variant( + variant: str, + *, + patch_size: tuple[int, int, int], + model_id: str, + width_factor: int, + device: torch.device, + iterations: int, +) -> dict[str, float | str | None]: + model = _build_model(variant, patch_size, model_id, width_factor, device) + inputs = torch.randn(1, 1, *patch_size, device=device) + labels = torch.randint(0, 2, (1, *patch_size), device=device) + optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4) + + if device.type == "cuda": + torch.cuda.reset_peak_memory_stats(device) + + try: + _, startup_ms = _timed_call(lambda: model(inputs), device=device) + except Exception as exc: # pragma: no cover + return {"startup_forward_ms": None, 
"steady_state_forward_ms_mean": None, "steady_state_train_step_ms_mean": None, "max_memory_mb": None, "error": f"{type(exc).__name__}: {str(exc).splitlines()[0]}"} + + forward_times: list[float] = [] + train_times: list[float] = [] + train_step_error: str | None = None + + model.train() + for _ in range(iterations): + def _forward_only(): + with _autocast(device): + return model(inputs) + + _, forward_ms = _timed_call(_forward_only, device=device) + forward_times.append(forward_ms) + + def _train_step(): + optimizer.zero_grad(set_to_none=True) + with _autocast(device): + logits = model(inputs)["surface"] + loss = F.cross_entropy(logits.float(), labels) + loss.backward() + optimizer.step() + + try: + _, train_ms = _timed_call(_train_step, device=device) + train_times.append(train_ms) + except torch.OutOfMemoryError as exc: # pragma: no cover + train_step_error = f"{type(exc).__name__}: {str(exc).splitlines()[0]}" + break + + result = { + "startup_forward_ms": startup_ms, + "steady_state_forward_ms_mean": sum(forward_times) / len(forward_times) if forward_times else None, + "steady_state_train_step_ms_mean": sum(train_times) / len(train_times) if train_times else None, + "max_memory_mb": torch.cuda.max_memory_allocated(device) / (1024 ** 2) if device.type == "cuda" else None, + } + if train_step_error is not None: + result["train_step_error"] = train_step_error + return result + + +def main() -> None: + parser = argparse.ArgumentParser(description="Benchmark UNet vs MedNeXt whole-model performance") + parser.add_argument("--patch-size", action="append", default=["128,128,128", "192,192,192"], help="3 comma-separated ints; can be passed multiple times") + parser.add_argument("--iterations", type=int, default=3) + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + parser.add_argument("--mednext-model-id", default="B") + parser.add_argument("--output-json", type=Path, default=None) + parser.add_argument("--single-variant", 
default=None) + parser.add_argument("--single-patch", default=None) + parser.add_argument("--single-width-factor", type=int, default=1) + args = parser.parse_args() + + device = torch.device(args.device) + + if args.single_variant is not None: + if args.single_patch is None: + raise ValueError("--single-patch is required when --single-variant is set") + patch_size = _parse_patch_sizes([args.single_patch])[0] + capture = io.StringIO() + with redirect_stdout(capture): + result = _run_variant( + args.single_variant, + patch_size=patch_size, + model_id=args.mednext_model_id, + width_factor=args.single_width_factor, + device=device, + iterations=args.iterations, + ) + print(json.dumps(result)) + return + + patch_sizes = _parse_patch_sizes(args.patch_size) + variants = [ + ("unet", "B", 1, "unet"), + ("mednext_v1", "B", 1, "mednext_v1_b"), + ("mednext_v2", "L", 1, "mednext_v2_l"), + ("mednext_v2", "L", 2, "mednext_v2_l_width2"), + ("mednext_v2", "B", 1, "mednext_v2_b"), + ] + + results = {} + for patch_size in patch_sizes: + patch_key = "x".join(str(v) for v in patch_size) + patch_arg = ",".join(str(v) for v in patch_size) + patch_results = {} + for variant, model_id, width_factor, label in variants: + cmd = [ + sys.executable, + "-m", + "vesuvius.models.benchmarks.benchmark_mednext", + "--device", + str(device), + "--iterations", + str(args.iterations), + "--single-variant", + variant, + "--single-patch", + patch_arg, + "--mednext-model-id", + model_id, + "--single-width-factor", + str(width_factor), + ] + completed = subprocess.run(cmd, check=True, capture_output=True, text=True) + patch_results[label] = json.loads(completed.stdout) + results[patch_key] = patch_results + + payload = { + "device": str(device), + "results": results, + } + print(json.dumps(payload, indent=2)) + if args.output_json is not None: + args.output_json.write_text(json.dumps(payload, indent=2), encoding="utf-8") + + +if __name__ == "__main__": + main() diff --git 
a/vesuvius/src/vesuvius/models/build/build_network_from_config.py b/vesuvius/src/vesuvius/models/build/build_network_from_config.py index d3f5974d3..b6232756d 100644 --- a/vesuvius/src/vesuvius/models/build/build_network_from_config.py +++ b/vesuvius/src/vesuvius/models/build/build_network_from_config.py @@ -43,13 +43,21 @@ https://github.com/MIC-DKFZ/nnUNet """ +import math + import torch import torch.nn as nn -from ..utilities.utils import get_pool_and_conv_props, get_n_blocks_per_stage +import torch.nn.functional as F +from ..utilities.utils import get_pool_and_conv_props, get_n_blocks_per_stage, pad_shape from .encoder import Encoder from .decoder import Decoder from .activations import SwiGLUBlock, GLUBlock +from .simple_conv_blocks import ConvDropoutNormReLU +from .guidance import TokenBook3D from .primus_wrapper import PrimusEncoder, PrimusDecoder +from .mednext.factory import get_mednext_v1_config +from .mednext_wrapper import MedNeXtEncoder, MedNeXtDecoder +from .pretrained_backbones.dinov2 import build_dinov2_backbone, build_dinov2_decoder class LearnedMLPZProjection(nn.Module): @@ -243,9 +251,78 @@ def __init__(self, mgr): self.save_config = False self.architecture_type = model_config.get("architecture_type", "unet") + self.pretrained_backbone = model_config.get("pretrained_backbone") + self.guide_backbone_name = model_config.get("guide_backbone") + self.guide_backbone_config_path = model_config.get("guide_backbone_config_path") + self.guide_fusion_stage = str(model_config.get("guide_fusion_stage", "input")).strip().lower() + self.guide_enabled = bool(self.guide_backbone_name) + self.guide_backbone = None + self.guide_tokenbook = None + self.guide_stage_tokenbooks = nn.ModuleDict() + self.guide_skip_projectors = nn.ModuleDict() + self.guide_stage_keys = None + self.guide_patch_grid = None + self.guide_skip_concat_projector_channels = None + self.guide_feature_gate_alpha = 1.0 + self.guide_direct_output_mode = None + self.direct_segmentation_target_name = 
None + self.freeze_encoder = bool(model_config.get("freeze_encoder", False)) + self.guide_freeze = bool(model_config.get("guide_freeze", True)) + guide_compile_policy = str(model_config.get("guide_compile_policy", "off")).strip().lower() + if guide_compile_policy not in {"off", "backbone_only", "tokenbook_only", "all_guidance"}: + raise ValueError( + "guide_compile_policy must be one of " + "{'off', 'backbone_only', 'tokenbook_only', 'all_guidance'}" + ) + self.guide_compile_policy = guide_compile_policy # Determine if deep supervision is requested ds_enabled = bool(getattr(mgr, 'enable_deep_supervision', False)) + if self.guide_enabled: + if self.op_dims != 3: + raise ValueError("guide_backbone is only supported for 3D models in v1") + if self.pretrained_backbone: + raise ValueError("guide_backbone cannot be combined with pretrained_backbone in v1") + if self.architecture_type.lower().startswith("primus"): + raise ValueError("guide_backbone is not supported with primus architectures in v1") + if self.architecture_type.lower().startswith("mednext"): + raise ValueError("guide_backbone is not supported with mednext architectures in v1") + if self.guide_fusion_stage not in {"input", "input_gating", "feature_encoder", "feature_skip_concat", "direct_segmentation"}: + raise ValueError( + "Unsupported guide_fusion_stage. Expected one of " + "{'input', 'input_gating', 'feature_encoder', 'feature_skip_concat', 'direct_segmentation'}." 
+ ) + if self.guide_fusion_stage == "feature_encoder": + gate_alpha = float(model_config.get("guide_feature_gate_alpha", 1.0)) + if not (0.0 <= gate_alpha <= 1.0): + raise ValueError( + f"guide_feature_gate_alpha must be in [0, 1], got {gate_alpha}" + ) + self.guide_feature_gate_alpha = gate_alpha + if self.guide_fusion_stage == "direct_segmentation" and ds_enabled: + raise ValueError("direct_segmentation does not support deep supervision") + + if self.architecture_type.lower().startswith("mednext") and self.pretrained_backbone: + raise ValueError("pretrained_backbone cannot be combined with mednext architectures in v1") + + if self.pretrained_backbone: + if ds_enabled: + print( + "Warning: Deep supervision is enabled but the selected pretrained backbone path does not " + "support multi-scale logits. Disabling deep supervision for this run." + ) + setattr(mgr, "enable_deep_supervision", False) + self._init_pretrained_backbone(mgr, model_config) + return + + if self._guide_uses_direct_segmentation(): + self._init_direct_segmentation(model_config) + return + + if self.architecture_type.lower().startswith("mednext"): + self._init_mednext(mgr, model_config) + return + # Primus decoders do not emit multi-scale logits; block DS to avoid silent misconfiguration if self.architecture_type.lower().startswith("primus"): if ds_enabled: @@ -553,6 +630,8 @@ def __init__(self, mgr): squeeze_excitation_add_maxpool=model_config.get("squeeze_excitation_add_maxpool", False), pool_type=model_config.get("pool_type", "conv") ) + if self._guide_uses_skip_feature_concat(): + self.guide_skip_concat_projector_channels = self._resolve_skip_concat_projector_channels(model_config) self.task_decoders = nn.ModuleDict() self.task_activations = nn.ModuleDict() self.task_heads = nn.ModuleDict() @@ -593,11 +672,13 @@ def __init__(self, mgr): # If at least one task uses shared, build a single shared decoder trunk (features-only) if len(tasks_using_shared) > 0: + shared_decoder_skip_channels = 
self._decoder_skip_channels() self.shared_decoder = Decoder( encoder=self.shared_encoder, basic_block=model_config.get("basic_decoder_block", "ConvBlock"), num_classes=None, # features-only mode n_conv_per_stage=model_config.get("n_conv_per_stage_decoder", [1] * (self.num_stages - 1)), + skip_channels=shared_decoder_skip_channels, deep_supervision=False ) # Heads map from decoder feature channels at highest resolution to task outputs @@ -613,16 +694,20 @@ def __init__(self, mgr): for target_name in sorted(tasks_using_separate): out_channels = self.targets[target_name]["out_channels"] activation_str = self.targets[target_name].get("activation", "none") + decoder_skip_channels = self._decoder_skip_channels() self.task_decoders[target_name] = Decoder( encoder=self.shared_encoder, basic_block=model_config.get("basic_decoder_block", "ConvBlock"), num_classes=out_channels, n_conv_per_stage=model_config.get("n_conv_per_stage_decoder", [1] * (self.num_stages - 1)), + skip_channels=decoder_skip_channels, deep_supervision=False ) self.task_activations[target_name] = get_activation_module(activation_str) print(f"Task '{target_name}' configured with separate decoder ({out_channels} channels)") + self._init_guidance(model_config) + # -------------------------------------------------------------------- # Build final configuration snapshot. 
# -------------------------------------------------------------------- @@ -666,13 +751,567 @@ def __init__(self, mgr): "target_z_projection": self.task_z_projection_cfg, "separate_decoders": len(tasks_using_separate) > 0, "num_pool_per_axis": getattr(self, 'num_pool_per_axis', None), - "must_be_divisible_by": getattr(self, 'must_be_divisible_by', None) + "must_be_divisible_by": getattr(self, 'must_be_divisible_by', None), + "guide_backbone": self.guide_backbone_name, + "guide_backbone_config_path": self.guide_backbone_config_path, + "guide_freeze": self.guide_freeze, + "guide_patch_grid": self.guide_patch_grid, + "guide_stage_keys": list(self.guide_stage_keys) if self.guide_stage_keys is not None else None, + "guide_skip_concat_projector_channels": ( + list(self.guide_skip_concat_projector_channels) + if self._guide_uses_skip_feature_concat() and self.guide_skip_concat_projector_channels is not None + else None + ), + "guide_feature_gate_alpha": self.guide_feature_gate_alpha if self._guide_uses_encoder_feature_gating() else None, + "guide_tokenbook_tokens": getattr(self, "guide_tokenbook_tokens", None) if not self._guide_uses_skip_feature_concat() else None, + "guide_tokenbook_prototype_weighting": getattr(self, "guide_tokenbook_prototype_weighting", "mean") if not self._guide_uses_skip_feature_concat() else None, + "guide_tokenbook_weight_mlp_hidden": getattr(self, "guide_tokenbook_weight_mlp_hidden", None) if not self._guide_uses_skip_feature_concat() else None, + "guide_compile_policy": self.guide_compile_policy if self.guide_enabled else "off", + "guide_fusion_stage": self.guide_fusion_stage if self.guide_enabled else None, } print("NetworkFromConfig initialized with final configuration:") for k, v in self.final_config.items(): print(f" {k}: {v}") + + def _guide_uses_input_gating(self) -> bool: + return self.guide_enabled and self.guide_fusion_stage in {"input", "input_gating"} + + def _guide_uses_encoder_feature_gating(self) -> bool: + return 
self.guide_enabled and self.guide_fusion_stage == "feature_encoder" + + def _guide_uses_skip_feature_concat(self) -> bool: + return self.guide_enabled and self.guide_fusion_stage == "feature_skip_concat" + + def _guide_uses_direct_segmentation(self) -> bool: + return self.guide_enabled and self.guide_fusion_stage == "direct_segmentation" + + def _resolve_skip_concat_projector_channels(self, model_config): + native_skip_channels = [int(ch) for ch in self.shared_encoder.output_channels] + configured = model_config.get("guide_skip_concat_projector_channels") + if configured is None: + return native_skip_channels + + projector_channels = [int(ch) for ch in configured] + if len(projector_channels) != len(native_skip_channels): + raise ValueError( + "guide_skip_concat_projector_channels must have as many entries as encoder stages, " + f"got {len(projector_channels)} and {len(native_skip_channels)}." + ) + if any(ch <= 0 for ch in projector_channels): + raise ValueError( + "guide_skip_concat_projector_channels entries must be > 0, " + f"got {projector_channels}." 
+ ) + return projector_channels + + def _decoder_skip_channels(self): + if self._guide_uses_skip_feature_concat(): + projector_channels = self.guide_skip_concat_projector_channels + if projector_channels is None: + projector_channels = [int(ch) for ch in self.shared_encoder.output_channels] + return [ + int(native_ch) + int(projected_ch) + for native_ch, projected_ch in zip(self.shared_encoder.output_channels, projector_channels) + ] + return list(self.shared_encoder.output_channels) + + def _init_guidance(self, model_config): + if not self.guide_enabled: + return + + input_shape = tuple(model_config.get("input_shape", self.patch_size)) + self.guide_backbone = build_dinov2_backbone( + self.guide_backbone_name, + input_channels=self.in_channels, + input_shape=input_shape, + config_path=self.guide_backbone_config_path, + ) + if getattr(self.guide_backbone, "ndim", None) != 3: + raise ValueError("guide_backbone must resolve to a 3D Dinovol backbone") + + patch_embed_size = tuple(int(v) for v in self.guide_backbone.patch_embed_size) + if any(size % patch != 0 for size, patch in zip(input_shape, patch_embed_size)): + raise ValueError( + f"Configured input_shape {input_shape} must be divisible by guide patch size {patch_embed_size}" + ) + + self.guide_patch_grid = tuple(size // patch for size, patch in zip(input_shape, patch_embed_size)) + self.guide_tokenbook_tokens = None + self.guide_tokenbook_prototype_weighting = "mean" + self.guide_tokenbook_weight_mlp_hidden = None + if self._guide_uses_input_gating() or self._guide_uses_direct_segmentation(): + default_token_count = int(math.prod(self.guide_patch_grid)) + configured_token_count = model_config.get("guide_tokenbook_tokens") + if configured_token_count is None: + token_count = default_token_count + else: + token_count = int(configured_token_count) + if token_count <= 0: + raise ValueError(f"guide_tokenbook_tokens must be > 0, got {token_count}") + self.guide_tokenbook_tokens = token_count + prototype_weighting = 
str( + model_config.get("guide_tokenbook_prototype_weighting", "mean") + ).strip().lower() + self.guide_tokenbook_prototype_weighting = prototype_weighting + self.guide_tokenbook_weight_mlp_hidden = model_config.get("guide_tokenbook_weight_mlp_hidden") + tokenbook_kwargs = { + "n_tokens": token_count, + "embed_dim": int(self.guide_backbone.embed_dim), + "dropout": float(model_config.get("guide_tokenbook_dropout", 0.0)), + "ema_decay": model_config.get("guide_tokenbook_ema_decay"), + "use_ema": bool(model_config.get("guide_tokenbook_use_ema", False)), + "prototype_weighting": prototype_weighting, + "weight_mlp_hidden": self.guide_tokenbook_weight_mlp_hidden, + } + self.guide_tokenbook = TokenBook3D(**tokenbook_kwargs) + self.guide_stage_keys = None + elif self._guide_uses_encoder_feature_gating(): + default_token_count = int(math.prod(self.guide_patch_grid)) + configured_token_count = model_config.get("guide_tokenbook_tokens") + if configured_token_count is None: + token_count = default_token_count + else: + token_count = int(configured_token_count) + if token_count <= 0: + raise ValueError(f"guide_tokenbook_tokens must be > 0, got {token_count}") + self.guide_tokenbook_tokens = token_count + prototype_weighting = str( + model_config.get("guide_tokenbook_prototype_weighting", "mean") + ).strip().lower() + self.guide_tokenbook_prototype_weighting = prototype_weighting + self.guide_tokenbook_weight_mlp_hidden = model_config.get("guide_tokenbook_weight_mlp_hidden") + tokenbook_kwargs = { + "n_tokens": token_count, + "embed_dim": int(self.guide_backbone.embed_dim), + "dropout": float(model_config.get("guide_tokenbook_dropout", 0.0)), + "ema_decay": model_config.get("guide_tokenbook_ema_decay"), + "use_ema": bool(model_config.get("guide_tokenbook_use_ema", False)), + "prototype_weighting": prototype_weighting, + "weight_mlp_hidden": self.guide_tokenbook_weight_mlp_hidden, + } + self.guide_stage_keys = [f"enc_{stage_idx}" for stage_idx in range(self.num_stages)] + 
self.guide_stage_tokenbooks = nn.ModuleDict( + { + stage_key: TokenBook3D(**tokenbook_kwargs) + for stage_key in self.guide_stage_keys + } + ) + elif self._guide_uses_skip_feature_concat(): + self.guide_stage_keys = [f"enc_{stage_idx}" for stage_idx in range(self.num_stages)] + self.guide_skip_projectors = nn.ModuleDict( + { + stage_key: ConvDropoutNormReLU( + conv_op=self.conv_op, + input_channels=int(self.guide_backbone.embed_dim), + output_channels=int(self.guide_skip_concat_projector_channels[stage_idx]), + kernel_size=1, + stride=1, + conv_bias=False, + norm_op=self.norm_op, + norm_op_kwargs=self.norm_op_kwargs, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=self.nonlin, + nonlin_kwargs=self.nonlin_kwargs, + ) + for stage_idx, stage_key in enumerate(self.guide_stage_keys) + } + ) + self.guide_tokenbook_sample_rate = float(model_config.get("guide_tokenbook_sample_rate", 1.0)) + + if self.guide_freeze: + for parameter in self.guide_backbone.parameters(): + parameter.requires_grad_(False) + self.guide_backbone.eval() + + def _sample_guide_token_mask(self, guide_features): + if self.guide_tokenbook_sample_rate >= 1.0: + return None + + batch_size = guide_features.shape[0] + token_count = int(guide_features.shape[2] * guide_features.shape[3] * guide_features.shape[4]) + mask = torch.rand(batch_size, token_count, device=guide_features.device) < self.guide_tokenbook_sample_rate + if mask.sum(dim=1).min().item() == 0: + random_indices = torch.randint(0, token_count, (batch_size,), device=guide_features.device) + mask[torch.arange(batch_size, device=guide_features.device), random_indices] = True + return mask + + @staticmethod + def _compile_module_in_place(module: nn.Module) -> nn.Module: + compile_method = getattr(module, "compile", None) + if callable(compile_method): + compile_method() + return module + return torch.compile(module) + + def _compile_guidance_submodules(self, *, device_type: str) -> list[str]: + if not self.guide_enabled or device_type != 
"cuda": + return [] + + compiled_modules: list[str] = [] + compile_backbone = self.guide_compile_policy in {"backbone_only", "all_guidance"} + compile_tokenbooks = self.guide_compile_policy in {"tokenbook_only", "all_guidance"} + + if compile_backbone and self.guide_backbone is not None: + self.guide_backbone = self._compile_module_in_place(self.guide_backbone) + compiled_modules.append("guide_backbone") + + if compile_tokenbooks and self.guide_tokenbook is not None: + self.guide_tokenbook = self._compile_module_in_place(self.guide_tokenbook) + compiled_modules.append("guide_tokenbook") + + if compile_tokenbooks and len(self.guide_stage_tokenbooks) > 0: + for stage_key in list(self.guide_stage_tokenbooks.keys()): + self.guide_stage_tokenbooks[stage_key] = self._compile_module_in_place( + self.guide_stage_tokenbooks[stage_key] + ) + compiled_modules.append(f"guide_tokenbook:{stage_key}") + if compile_tokenbooks and len(self.guide_skip_projectors) > 0: + for stage_key in list(self.guide_skip_projectors.keys()): + self.guide_skip_projectors[stage_key] = self._compile_module_in_place( + self.guide_skip_projectors[stage_key] + ) + compiled_modules.append(f"guide_projector:{stage_key}") + return compiled_modules + + @torch.compiler.disable(reason="guided backbone compile failure") + def _compute_guide_features(self, x): + if self.guide_freeze: + with torch.inference_mode(): + frozen_features = self.guide_backbone(x)[0] + return frozen_features.clone() + return self.guide_backbone(x)[0] + + def _build_direct_segmentation_outputs(self, x): + if not self._guide_uses_direct_segmentation(): + return {}, {} + + guide_features = self._compute_guide_features(x) + token_mask = self._sample_guide_token_mask(guide_features) + guide_mask = self.guide_tokenbook(guide_features, token_mask=token_mask) + guide_probability = F.interpolate( + guide_mask, + size=x.shape[2:], + mode="trilinear", + align_corners=False, + ).clamp_(1e-6, 1.0 - 1e-6) + fg_logit = torch.logit(guide_probability) 
+ bg_logit = torch.zeros_like(fg_logit) + logits = torch.cat([bg_logit, fg_logit], dim=1) + logits = self._apply_z_projection(self.direct_segmentation_target_name, logits) + + results = {self.direct_segmentation_target_name: logits} + if self.direct_segmentation_target_name in self.task_activations and not self.training: + activation_fn = self.task_activations[self.direct_segmentation_target_name] + if activation_fn is not None: + results[self.direct_segmentation_target_name] = activation_fn(logits) + return results, {"guide_mask": guide_mask} + + @torch.compiler.disable(reason="guided backbone compile failure") + def _apply_input_guidance(self, x): + if not self.guide_enabled: + return x, {} + + guide_features = self._compute_guide_features(x) + token_mask = self._sample_guide_token_mask(guide_features) + guide_mask = self.guide_tokenbook(guide_features, token_mask=token_mask) + guide_for_input = F.interpolate( + guide_mask, + size=x.shape[2:], + mode="trilinear", + align_corners=False, + ) + return x * guide_for_input, {"guide_mask": guide_mask} + + def _get_encoder_stage_shapes(self, input_spatial_shape): + current_shape = [int(v) for v in input_spatial_shape] + stage_shapes = [] + for stride in self.shared_encoder.strides: + current_shape = [ + max(1, math.ceil(size / int(step))) + for size, step in zip(current_shape, stride) + ] + stage_shapes.append(tuple(current_shape)) + return stage_shapes + + @torch.compiler.disable(reason="guided backbone compile failure") + def _build_encoder_feature_guidance(self, x): + if not self._guide_uses_encoder_feature_gating(): + return {}, {} + + guide_features = self._compute_guide_features(x) + token_mask = self._sample_guide_token_mask(guide_features) + stage_shapes = self._get_encoder_stage_shapes(x.shape[2:]) + + feature_gates = {} + aux_outputs = {} + for stage_idx, stage_key in enumerate(self.guide_stage_keys or []): + stage_guide = self.guide_stage_tokenbooks[stage_key](guide_features, token_mask=token_mask) + 
target_shape = stage_shapes[stage_idx] + if stage_guide.shape[2:] != target_shape: + stage_guide = F.interpolate( + stage_guide, + size=target_shape, + mode="trilinear", + align_corners=False, + ) + feature_gates[stage_key] = stage_guide + aux_outputs[stage_key] = stage_guide + return feature_gates, aux_outputs + + @torch.compiler.disable(reason="guided backbone compile failure") + def _build_skip_concat_features(self, x, encoder_features): + if not self._guide_uses_skip_feature_concat(): + return encoder_features, {} + + guide_features = self._compute_guide_features(x) + augmented_features = [] + for stage_idx, (stage_key, stage_feature) in enumerate(zip(self.guide_stage_keys or [], encoder_features)): + projected = self.guide_skip_projectors[stage_key](guide_features) + if projected.shape[2:] != stage_feature.shape[2:]: + projected = F.interpolate( + projected, + size=stage_feature.shape[2:], + mode="trilinear", + align_corners=False, + ) + augmented_features.append(torch.cat([stage_feature, projected], dim=1)) + return augmented_features, {} + def _init_pretrained_backbone(self, mgr, model_config): + print(f"--- Initializing pretrained backbone '{self.pretrained_backbone}' ---") + input_shape = tuple(model_config.get("input_shape", self.patch_size)) + decoder_type = model_config.get("pretrained_decoder_type", "primus_patch_decode") + config_path = model_config.get("pretrained_backbone_config_path") + + self.shared_encoder = build_dinov2_backbone( + self.pretrained_backbone, + input_channels=self.in_channels, + input_shape=input_shape, + config_path=config_path, + ) + if self.freeze_encoder: + for parameter in self.shared_encoder.parameters(): + parameter.requires_grad_(False) + self.shared_encoder.eval() + print("Encoder frozen — weights will not be updated during training") + self.task_decoders = nn.ModuleDict() + self.task_activations = nn.ModuleDict() + self.task_heads = nn.ModuleDict() + + separate_decoders_default = model_config.get("separate_decoders", 
True) + decoder_head_channels = model_config.get("decoder_head_channels", 32) + tasks_using_shared, tasks_using_separate = set(), set() + + for target_name, target_info in self.targets.items(): + if 'out_channels' in target_info: + out_channels = target_info['out_channels'] + elif 'channels' in target_info: + out_channels = target_info['channels'] + else: + out_channels = self.in_channels + print(f"No channel specification found for task '{target_name}', defaulting to {out_channels} channels") + target_info["out_channels"] = out_channels + self._register_target_projection(target_name, target_info, model_config) + + use_separate = target_info.get("separate_decoder", separate_decoders_default) + if use_separate: + tasks_using_separate.add(target_name) + else: + tasks_using_shared.add(target_name) + + if len(tasks_using_shared) > 0: + self.shared_decoder = build_dinov2_decoder(decoder_type, self.shared_encoder, decoder_head_channels) + head_conv = nn.Conv2d if self.shared_encoder.ndim == 2 else nn.Conv3d + for target_name in sorted(tasks_using_shared): + out_ch = self.targets[target_name]["out_channels"] + self.task_heads[target_name] = head_conv( + decoder_head_channels, out_ch, kernel_size=1, stride=1, padding=0, bias=True + ) + activation_str = self.targets[target_name].get("activation", "none") + self.task_activations[target_name] = get_activation_module(activation_str) + print( + f"Pretrained task '{target_name}' configured with shared {decoder_type} decoder + head ({out_ch} channels)" + ) + + for target_name in sorted(tasks_using_separate): + out_channels = self.targets[target_name]["out_channels"] + activation_str = self.targets[target_name].get("activation", "none") + self.task_decoders[target_name] = build_dinov2_decoder( + decoder_type, + self.shared_encoder, + out_channels, + ) + self.task_activations[target_name] = get_activation_module(activation_str) + print(f"Pretrained task '{target_name}' configured with separate {decoder_type} decoder ({out_channels} 
channels)") + + self.final_config = { + "model_name": self.mgr.model_name, + "architecture_type": self.architecture_type, + "pretrained_backbone": self.pretrained_backbone, + "pretrained_decoder_type": decoder_type, + "freeze_encoder": self.freeze_encoder, + "input_shape": input_shape, + "in_channels": self.in_channels, + "targets": self.targets, + "target_z_projection": self.task_z_projection_cfg, + "separate_decoders": len(tasks_using_separate) > 0, + "decoder_head_channels": decoder_head_channels, + } + + print("Pretrained backbone network initialized with configuration:") + for k, v in self.final_config.items(): + print(f" {k}: {v}") + + def _init_mednext(self, mgr, model_config): + if self.op_dims != 3: + raise ValueError("MedNeXt support is currently limited to 3D models in vesuvius") + + architecture_name = str(self.architecture_type).strip().lower() + if architecture_name not in {"mednext_v1", "mednext_v2"}: + raise ValueError( + f"Unsupported mednext architecture_type {self.architecture_type!r}. " + "Expected one of {'mednext_v1', 'mednext_v2'}." 
+ ) + + model_id = model_config.get("mednext_model_id") + if model_id in (None, ""): + raise ValueError("mednext_model_id must be explicitly set for mednext architectures") + + base_cfg = get_mednext_v1_config(model_id) + width_factor = int(model_config.get("mednext_width_factor", 1)) + if width_factor not in {1, 2}: + raise ValueError(f"mednext_width_factor must be 1 or 2, got {width_factor}") + if architecture_name == "mednext_v1" and width_factor != 1: + raise ValueError("mednext_width_factor is only supported for mednext_v2") + + if architecture_name == "mednext_v2": + if "mednext_kernel_size" in model_config and int(model_config["mednext_kernel_size"]) != 3: + raise ValueError("mednext_v2 fixes the MedNeXt kernel size to 3") + kernel_size = 3 + norm_type = "instance" + use_grn = True + else: + kernel_size = int(model_config.get("mednext_kernel_size", 3)) + norm_type = "group" + use_grn = False + checkpoint_style = model_config.get("mednext_checkpoint_style", base_cfg["checkpoint_style"]) + if checkpoint_style == "": + checkpoint_style = None + + resolved_cfg = { + "model_id": base_cfg["model_id"], + "base_channels": int(base_cfg["n_channels"]) * width_factor, + "exp_r": list(base_cfg["exp_r"]) if isinstance(base_cfg["exp_r"], (list, tuple)) else [int(base_cfg["exp_r"])] * 9, + "block_counts": list(base_cfg["block_counts"]), + "kernel_size": kernel_size, + "checkpoint_style": checkpoint_style, + "norm_type": norm_type, + "grn": use_grn, + "width_factor": width_factor, + } + + self.shared_encoder = MedNeXtEncoder( + input_channels=self.in_channels, + n_channels=resolved_cfg["base_channels"], + exp_r=resolved_cfg["exp_r"], + block_counts=resolved_cfg["block_counts"], + kernel_size=resolved_cfg["kernel_size"], + checkpoint_style=resolved_cfg["checkpoint_style"], + norm_type=resolved_cfg["norm_type"], + grn=resolved_cfg["grn"], + do_res=True, + do_res_up_down=True, + ) + self.task_decoders = nn.ModuleDict() + self.task_activations = nn.ModuleDict() + 
self.task_heads = nn.ModuleDict() + + ds_enabled = bool(getattr(mgr, "enable_deep_supervision", False)) + separate_decoders_default = model_config.get("separate_decoders", ds_enabled) + tasks_using_separate = set() + tasks_using_shared = set() + for target_name, target_info in self.targets.items(): + if "out_channels" in target_info: + out_channels = target_info["out_channels"] + elif "channels" in target_info: + out_channels = target_info["channels"] + else: + out_channels = self.in_channels + print(f"No channel specification found for task '{target_name}', defaulting to {out_channels} channels") + target_info["out_channels"] = out_channels + self._register_target_projection(target_name, target_info, model_config) + use_separate = target_info.get("separate_decoder", separate_decoders_default) + # Persist the resolved per-target decoder layout so strict checkpoint + # reloads can rebuild mixed shared/separate MedNeXt models exactly. + target_info["separate_decoder"] = bool(use_separate) + if use_separate: + tasks_using_separate.add(target_name) + else: + tasks_using_shared.add(target_name) + + if ds_enabled and len(tasks_using_shared) > 0: + print( + "Deep supervision enabled: switching shared-decoder tasks to separate decoders for DS support:", + ", ".join(sorted(tasks_using_shared)), + ) + tasks_using_separate.update(tasks_using_shared) + tasks_using_shared.clear() + + if len(tasks_using_shared) > 0: + self.shared_decoder = MedNeXtDecoder( + encoder=self.shared_encoder, + num_classes=None, + deep_supervision=False, + ) + head_in_ch = self.shared_encoder.output_channels[0] + for target_name in sorted(tasks_using_shared): + out_ch = self.targets[target_name]["out_channels"] + self.task_heads[target_name] = nn.Conv3d(head_in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True) + activation_str = self.targets[target_name].get("activation", "none") + self.task_activations[target_name] = get_activation_module(activation_str) + print(f"MedNeXt task '{target_name}' 
configured with shared decoder + head ({out_ch} channels)") + else: + self.shared_decoder = None + + for target_name in sorted(tasks_using_separate): + out_channels = self.targets[target_name]["out_channels"] + activation_str = self.targets[target_name].get("activation", "none") + self.task_decoders[target_name] = MedNeXtDecoder( + encoder=self.shared_encoder, + num_classes=out_channels, + deep_supervision=ds_enabled, + ) + self.task_activations[target_name] = get_activation_module(activation_str) + print(f"MedNeXt task '{target_name}' configured with separate decoder ({out_channels} channels)") + + self.final_config = { + "model_name": self.mgr.model_name, + "architecture_type": self.architecture_type, + "mednext_model_id": resolved_cfg["model_id"], + "mednext_base_channels": resolved_cfg["base_channels"], + "mednext_exp_r": resolved_cfg["exp_r"], + "mednext_block_counts": resolved_cfg["block_counts"], + "mednext_kernel_size": resolved_cfg["kernel_size"], + "mednext_checkpoint_style": resolved_cfg["checkpoint_style"], + "mednext_norm_type": resolved_cfg["norm_type"], + "mednext_grn": resolved_cfg["grn"], + "mednext_width_factor": resolved_cfg["width_factor"], + "patch_size": self.patch_size, + "batch_size": self.batch_size, + "in_channels": self.in_channels, + "autoconfigure": self.autoconfigure, + "targets": self.targets, + "target_z_projection": self.task_z_projection_cfg, + "separate_decoders": len(tasks_using_separate) > 0, + "enable_deep_supervision": ds_enabled, + } + + print("MedNeXt network initialized with configuration:") + for k, v in self.final_config.items(): + print(f" {k}: {v}") + def _init_primus(self, mgr, model_config): """ Initialize Primus transformer architecture. 
@@ -849,9 +1488,32 @@ def check_input_channels(self, x): return False return True - def forward(self, x, return_mae_mask=False): + def train(self, mode: bool = True): + super().train(mode) + if self.pretrained_backbone and self.freeze_encoder and self.shared_encoder is not None: + self.shared_encoder.eval() + if self.guide_enabled and self.guide_freeze and self.guide_backbone is not None: + self.guide_backbone.eval() + return self + + def forward(self, x, return_mae_mask=False, return_aux=False): # Check input channels and warn if mismatch self.check_input_channels(x) + aux_outputs = {} + encoder_feature_gates = None + + if self._guide_uses_direct_segmentation(): + if return_mae_mask: + raise RuntimeError("direct_segmentation does not support return_mae_mask") + results, aux_outputs = self._build_direct_segmentation_outputs(x) + if return_aux: + return results, aux_outputs + return results + + if self._guide_uses_input_gating(): + x, aux_outputs = self._apply_input_guidance(x) + elif self._guide_uses_encoder_feature_gating(): + encoder_feature_gates, aux_outputs = self._build_encoder_feature_guidance(x) # Get features from encoder (works for both U-Net and Primus) # For MAE training with Primus, we need to get the mask @@ -881,7 +1543,16 @@ def forward(self, x, return_mae_mask=False): ) else: # Standard forward pass - features = self.shared_encoder(x) + if encoder_feature_gates is not None: + features = self.shared_encoder( + x, + feature_gates=encoder_feature_gates, + feature_gate_alpha=self.guide_feature_gate_alpha, + ) + else: + features = self.shared_encoder(x) + if self._guide_uses_skip_feature_concat(): + features, aux_outputs = self._build_skip_concat_features(x, features) restoration_mask = None results = {} @@ -921,5 +1592,71 @@ def forward(self, x, return_mae_mask=False): # Return MAE mask if requested (for MAE training) if return_mae_mask: + if return_aux: + return results, restoration_mask, aux_outputs return results, restoration_mask + if return_aux: + 
return results, aux_outputs return results + + def _init_direct_segmentation(self, model_config): + primary_targets = [ + name for name, info in self.targets.items() + if not (info or {}).get("auxiliary_task", False) + ] + if len(primary_targets) != 1: + raise ValueError( + "direct_segmentation supports exactly one non-auxiliary target in v1, " + f"got {primary_targets}." + ) + + target_name = primary_targets[0] + target_info = self.targets[target_name] + configured_out_channels = target_info.get("out_channels", target_info.get("channels", 2)) + if int(configured_out_channels) != 2: + raise ValueError( + "direct_segmentation expects a binary target configured with out_channels=2, " + f"got {configured_out_channels} for target '{target_name}'." + ) + target_info["out_channels"] = 2 + self._register_target_projection(target_name, target_info, model_config) + activation_str = target_info.get("activation", "none") + + self.shared_encoder = None + self.shared_decoder = None + self.task_decoders = nn.ModuleDict() + self.task_heads = nn.ModuleDict() + self.task_activations = nn.ModuleDict({target_name: get_activation_module(activation_str)}) + self.direct_segmentation_target_name = target_name + self.guide_direct_output_mode = "two_channel_logits" + + self._init_guidance(model_config) + + self.final_config = { + "model_name": self.mgr.model_name, + "architecture_type": "guide_only", + "patch_size": self.patch_size, + "batch_size": self.batch_size, + "in_channels": self.in_channels, + "autoconfigure": self.autoconfigure, + "targets": self.targets, + "target_z_projection": self.task_z_projection_cfg, + "guide_backbone": self.guide_backbone_name, + "guide_backbone_config_path": self.guide_backbone_config_path, + "guide_freeze": self.guide_freeze, + "guide_patch_grid": self.guide_patch_grid, + "guide_stage_keys": None, + "guide_skip_concat_projector_channels": None, + "guide_feature_gate_alpha": None, + "guide_tokenbook_tokens": getattr(self, "guide_tokenbook_tokens", None), 
+ "guide_tokenbook_prototype_weighting": getattr(self, "guide_tokenbook_prototype_weighting", "mean"), + "guide_tokenbook_weight_mlp_hidden": getattr(self, "guide_tokenbook_weight_mlp_hidden", None), + "guide_compile_policy": self.guide_compile_policy, + "guide_fusion_stage": self.guide_fusion_stage, + "guide_direct_output_mode": self.guide_direct_output_mode, + "direct_segmentation_target_name": self.direct_segmentation_target_name, + } + + print("NetworkFromConfig initialized with final configuration:") + for k, v in self.final_config.items(): + print(f" {k}: {v}") diff --git a/vesuvius/src/vesuvius/models/build/decoder.py b/vesuvius/src/vesuvius/models/build/decoder.py index 68d4633c0..ae4d87a2c 100644 --- a/vesuvius/src/vesuvius/models/build/decoder.py +++ b/vesuvius/src/vesuvius/models/build/decoder.py @@ -20,6 +20,7 @@ def __init__(self, num_classes: Union[int, None], n_conv_per_stage: Union[int, Tuple[int, ...], List[int]], deep_supervision, + skip_channels: Union[None, Tuple[int, ...], List[int]] = None, nonlin_first: bool = False, norm_op: Union[None, Type[nn.Module]] = None, norm_op_kwargs: dict = None, @@ -53,6 +54,21 @@ def __init__(self, object.__setattr__(self, '_encoder_ref', encoder) self.num_classes = num_classes if num_classes is not None else 0 n_stages_encoder = len(encoder.output_channels) + if skip_channels is None: + skip_channels = list(encoder.output_channels) + elif isinstance(skip_channels, tuple): + skip_channels = list(skip_channels) + else: + skip_channels = list(skip_channels) + assert len(skip_channels) == n_stages_encoder, ( + "skip_channels must have as many entries as encoder.output_channels, " + f"got {len(skip_channels)} and {n_stages_encoder}" + ) + self.skip_channels = [int(v) for v in skip_channels] + self.native_encoder_channels = [int(v) for v in encoder.output_channels] + self.bottleneck_input_channels = int(self.skip_channels[-1]) + self.decoder_stage_input_channels = [] + self.decoder_stage_output_channels = [] if 
isinstance(n_conv_per_stage, int): n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1) assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \ @@ -74,19 +90,20 @@ def __init__(self, transpconvs = [] seg_layers = [] for s in range(1, n_stages_encoder): - input_features_below = encoder.output_channels[-s] - input_features_skip = encoder.output_channels[-(s + 1)] + input_features_below = self.skip_channels[-1] if s == 1 else encoder.output_channels[-s] + input_features_skip_native = encoder.output_channels[-(s + 1)] + input_features_skip = self.skip_channels[-(s + 1)] stride_for_transpconv = encoder.strides[-s] transpconvs.append(transpconv_op( - input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv, + input_features_below, input_features_skip_native, stride_for_transpconv, stride_for_transpconv, bias=encoder.conv_bias )) - # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output) + conv_input_channels = input_features_skip_native + input_features_skip stages.append(StackedResidualBlocks( n_blocks=n_conv_per_stage[s - 1], conv_op=encoder.conv_op, - input_channels=2 * input_features_skip, - output_channels=input_features_skip, + input_channels=conv_input_channels, + output_channels=input_features_skip_native, kernel_size=encoder.kernel_sizes[-(s + 1)], initial_stride=1, conv_bias=conv_bias, @@ -97,29 +114,32 @@ def __init__(self, nonlin=nonlin, nonlin_kwargs=nonlin_kwargs, )) + self.decoder_stage_input_channels.append(int(conv_input_channels)) + self.decoder_stage_output_channels.append(int(input_features_skip_native)) # Only build segmentation layers if we are not in features-only mode if not self.return_features_only: # we always build the deep supervision outputs so that we can always load parameters. 
If we don't do this # then a model trained with deep_supervision=True could not easily be loaded at inference time where # deep supervision is not needed. It's just a convenience thing - seg_layers.append(encoder.conv_op(input_features_skip, self.num_classes, 1, 1, 0, bias=True)) + seg_layers.append(encoder.conv_op(input_features_skip_native, self.num_classes, 1, 1, 0, bias=True)) if basic_block == 'ConvBlock': stages = [] transpconvs = [] seg_layers = [] for s in range(1, n_stages_encoder): - input_features_below = encoder.output_channels[-s] - input_features_skip = encoder.output_channels[-(s + 1)] + input_features_below = self.skip_channels[-1] if s == 1 else encoder.output_channels[-s] + input_features_skip_native = encoder.output_channels[-(s + 1)] + input_features_skip = self.skip_channels[-(s + 1)] stride_for_transpconv = encoder.strides[-s] transpconvs.append(transpconv_op( - input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv, + input_features_below, input_features_skip_native, stride_for_transpconv, stride_for_transpconv, bias=conv_bias )) - # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output) + conv_input_channels = input_features_skip_native + input_features_skip stages.append(StackedConvBlocks( - n_conv_per_stage[s - 1], encoder.conv_op, 2 * input_features_skip, input_features_skip, + n_conv_per_stage[s - 1], encoder.conv_op, conv_input_channels, input_features_skip_native, encoder.kernel_sizes[-(s + 1)], 1, conv_bias, norm_op, @@ -130,13 +150,15 @@ def __init__(self, nonlin_kwargs, nonlin_first )) + self.decoder_stage_input_channels.append(int(conv_input_channels)) + self.decoder_stage_output_channels.append(int(input_features_skip_native)) # Only build segmentation layers if we are not in features-only mode if not self.return_features_only: # we always build the deep supervision outputs so that we can always load parameters. 
If we don't do this # then a model trained with deep_supervision=True could not easily be loaded at inference time where # deep supervision is not needed. It's just a convenience thing - seg_layers.append(encoder.conv_op(input_features_skip, self.num_classes, 1, 1, 0, bias=True)) + seg_layers.append(encoder.conv_op(input_features_skip_native, self.num_classes, 1, 1, 0, bias=True)) self.stages = nn.ModuleList(stages) self.transpconvs = nn.ModuleList(transpconvs) diff --git a/vesuvius/src/vesuvius/models/build/encoder.py b/vesuvius/src/vesuvius/models/build/encoder.py index 76b785250..a5e1ac088 100644 --- a/vesuvius/src/vesuvius/models/build/encoder.py +++ b/vesuvius/src/vesuvius/models/build/encoder.py @@ -152,13 +152,20 @@ def __init__(self, self.conv_bias = conv_bias self.kernel_sizes = kernel_sizes - def forward(self, x): + def forward(self, x, feature_gates=None, feature_gate_alpha: float = 1.0): if self.stem is not None: x = self.stem(x) ret = [] - for s in self.stages: + alpha = float(feature_gate_alpha) + for stage_idx, s in enumerate(self.stages): x = s(x) - ret.append(x) + skip_x = x + if feature_gates is not None: + stage_key = f"enc_{stage_idx}" + gate = feature_gates.get(stage_key) + if gate is not None: + skip_x = x * ((1.0 - alpha) + alpha * gate) + ret.append(skip_x) if self.return_skips: return ret else: @@ -175,5 +182,3 @@ def compute_conv_feature_map_size(self, input_size): input_size = [i // j for i, j in zip(input_size, self.strides[s])] return output - - diff --git a/vesuvius/src/vesuvius/models/build/guidance.py b/vesuvius/src/vesuvius/models/build/guidance.py new file mode 100644 index 000000000..fcbd8b070 --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/guidance.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TokenBook3D(nn.Module): + """Convert a 3D feature grid into a scalar spatial guide mask.""" + + def __init__( + self, + *, + n_tokens: int, 
+ embed_dim: int, + dropout: float = 0.0, + ema_decay: float | None = None, + use_ema: bool = False, + prototype_weighting: str = "mean", + weight_mlp_hidden: int | None = None, + ) -> None: + super().__init__() + self.input_conv = nn.Conv3d(embed_dim, embed_dim, kernel_size=1, bias=True) + self.book = nn.Parameter(torch.randn(int(n_tokens), int(embed_dim))) + nn.init.normal_(self.book, mean=0.0, std=float(embed_dim) ** -0.5) + self.dropout = float(dropout) + self.ema_decay = ema_decay + self.use_ema = bool(use_ema) + if ema_decay is not None: + self.register_buffer("book_ema", self.book.detach().clone()) + else: + self.book_ema = None + + self.prototype_weighting = str(prototype_weighting).strip().lower() + if self.prototype_weighting not in {"mean", "token_mlp"}: + raise ValueError( + "prototype_weighting must be one of {'mean', 'token_mlp'}" + ) + self.prototype_weight_mlp = None + self.weight_mlp_hidden = None + if self.prototype_weighting == "token_mlp": + hidden = int(weight_mlp_hidden or embed_dim) + if hidden <= 0: + raise ValueError(f"weight_mlp_hidden must be > 0, got {hidden}") + self.weight_mlp_hidden = hidden + self.prototype_weight_mlp = nn.Sequential( + nn.Linear(int(embed_dim), hidden), + nn.GELU(), + nn.Linear(hidden, int(n_tokens)), + ) + + @torch.no_grad() + def _ema_update(self) -> None: + if self.ema_decay is None or self.book_ema is None: + return + decay = float(self.ema_decay) + self.book_ema.mul_(decay).add_(self.book.detach(), alpha=1.0 - decay) + + def forward( + self, + x: torch.Tensor, + *, + token_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + if x.ndim != 5: + raise ValueError(f"TokenBook3D expects [B, C, D, H, W], got {tuple(x.shape)}") + + if self.training: + self._ema_update() + + batch_size, _channels, depth, height, width = x.shape + book = self.book_ema if self.use_ema and self.book_ema is not None else self.book + + tokens = self.input_conv(x).flatten(2).transpose(1, 2) + tokens = F.normalize(tokens, dim=-1) + 
prototypes = F.normalize(book, dim=-1).unsqueeze(0).expand(batch_size, -1, -1) + + similarities = torch.matmul(tokens, prototypes.transpose(1, 2)) + if self.dropout > 0.0: + similarities = F.dropout(similarities, p=self.dropout, training=self.training) + + if self.prototype_weighting == "token_mlp": + prototype_weights = F.softmax(self.prototype_weight_mlp(tokens), dim=-1) + similarities = (similarities * prototype_weights).sum(dim=-1) + else: + similarities = similarities.mean(dim=-1) + + if token_mask is not None: + token_mask = token_mask.to(device=similarities.device, dtype=similarities.dtype) + similarities = similarities * token_mask + denom = token_mask.sum(dim=1, keepdim=True).clamp_min(1.0) + similarities = similarities / denom + + guide = similarities.reshape(batch_size, 1, depth, height, width) + guide = guide * 0.5 + 0.5 + return guide.clamp_(0.0, 1.0) diff --git a/vesuvius/src/vesuvius/models/build/mednext/__init__.py b/vesuvius/src/vesuvius/models/build/mednext/__init__.py new file mode 100644 index 000000000..c2abc3ba1 --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/mednext/__init__.py @@ -0,0 +1,17 @@ +""" +Minimal vendored MedNeXt implementation for Vesuvius. + +Source adapted from the official MIC-DKFZ MedNeXt repository: +https://github.com/MIC-DKFZ/MedNeXt + +The initial landing intentionally scopes this package to 3D usage only. +""" + +from .factory import create_mednext_v1, get_mednext_v1_config +from .v1 import MedNeXtV1 + +__all__ = [ + "MedNeXtV1", + "create_mednext_v1", + "get_mednext_v1_config", +] diff --git a/vesuvius/src/vesuvius/models/build/mednext/blocks.py b/vesuvius/src/vesuvius/models/build/mednext/blocks.py new file mode 100644 index 000000000..c98af8e2e --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/mednext/blocks.py @@ -0,0 +1,205 @@ +""" +3D MedNeXt blocks adapted from the official MIC-DKFZ MedNeXt repository. 
+ +This file intentionally keeps the stage-1 implementation close to upstream +MedNeXt v1 semantics while restricting support to 3D. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LayerNorm3D(nn.Module): + """channels_first LayerNorm variant used by upstream MedNeXt blocks.""" + + def __init__(self, normalized_shape: int, eps: float = 1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.ndim != 5: + raise ValueError(f"LayerNorm3D expects [B, C, D, H, W], got {tuple(x.shape)}") + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + return self.weight[:, None, None, None] * x + self.bias[:, None, None, None] + + +def _make_norm(norm_type: str, num_channels: int) -> nn.Module: + norm_type = str(norm_type).strip().lower() + if norm_type == "group": + return nn.GroupNorm(num_groups=num_channels, num_channels=num_channels) + if norm_type == "layer": + return LayerNorm3D(num_channels) + if norm_type == "instance": + return nn.InstanceNorm3d(num_channels, affine=True) + raise ValueError(f"Unsupported MedNeXt norm_type {norm_type!r}") + + +class MedNeXtBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + exp_r: int = 4, + kernel_size: int = 7, + do_res: bool = True, + norm_type: str = "group", + n_groups: int | None = None, + grn: bool = False, + ): + super().__init__() + self.do_res = bool(do_res) + self.grn = bool(grn) + + self.conv1 = nn.Conv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + groups=in_channels if n_groups is None else int(n_groups), + ) + self.norm = _make_norm(norm_type, in_channels) + self.conv2 = nn.Conv3d( + 
in_channels=in_channels, + out_channels=exp_r * in_channels, + kernel_size=1, + stride=1, + padding=0, + ) + self.act = nn.GELU() + self.conv3 = nn.Conv3d( + in_channels=exp_r * in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + ) + + if self.grn: + expanded = exp_r * in_channels + self.grn_beta = nn.Parameter(torch.zeros(1, expanded, 1, 1, 1), requires_grad=True) + self.grn_gamma = nn.Parameter(torch.zeros(1, expanded, 1, 1, 1), requires_grad=True) + + def forward(self, x: torch.Tensor, dummy_tensor: torch.Tensor | None = None) -> torch.Tensor: + del dummy_tensor + residual = x + x = self.conv1(x) + x = self.act(self.conv2(self.norm(x))) + if self.grn: + gx = torch.norm(x, p=2, dim=(-3, -2, -1), keepdim=True) + nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6) + x = self.grn_gamma * (x * nx) + self.grn_beta + x + x = self.conv3(x) + if self.do_res: + x = residual + x + return x + + +class MedNeXtDownBlock3D(MedNeXtBlock3D): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + exp_r: int = 4, + kernel_size: int = 7, + do_res: bool = False, + norm_type: str = "group", + grn: bool = False, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + exp_r=exp_r, + kernel_size=kernel_size, + do_res=False, + norm_type=norm_type, + grn=grn, + ) + self.resample_do_res = bool(do_res) + if self.resample_do_res: + self.res_conv = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=2, + ) + self.conv1 = nn.Conv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x: torch.Tensor, dummy_tensor: torch.Tensor | None = None) -> torch.Tensor: + y = super().forward(x, dummy_tensor=dummy_tensor) + if self.resample_do_res: + y = y + self.res_conv(x) + return y + + +class MedNeXtUpBlock3D(MedNeXtBlock3D): + def __init__( + self, + in_channels: int, + 
out_channels: int, + *, + exp_r: int = 4, + kernel_size: int = 7, + do_res: bool = False, + norm_type: str = "group", + grn: bool = False, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + exp_r=exp_r, + kernel_size=kernel_size, + do_res=False, + norm_type=norm_type, + grn=grn, + ) + self.resample_do_res = bool(do_res) + if self.resample_do_res: + self.res_conv = nn.ConvTranspose3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=2, + ) + self.conv1 = nn.ConvTranspose3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=2, + padding=kernel_size // 2, + groups=in_channels, + ) + + def forward(self, x: torch.Tensor, dummy_tensor: torch.Tensor | None = None) -> torch.Tensor: + y = super().forward(x, dummy_tensor=dummy_tensor) + y = F.pad(y, (1, 0, 1, 0, 1, 0)) + if self.resample_do_res: + res = F.pad(self.res_conv(x), (1, 0, 1, 0, 1, 0)) + y = y + res + return y + + +class OutBlock3D(nn.Module): + def __init__(self, in_channels: int, num_classes: int): + super().__init__() + self.conv_out = nn.ConvTranspose3d(in_channels, num_classes, kernel_size=1) + + def forward(self, x: torch.Tensor, dummy_tensor: torch.Tensor | None = None) -> torch.Tensor: + del dummy_tensor + return self.conv_out(x) diff --git a/vesuvius/src/vesuvius/models/build/mednext/factory.py b/vesuvius/src/vesuvius/models/build/mednext/factory.py new file mode 100644 index 000000000..b807324c2 --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/mednext/factory.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from copy import deepcopy + +from .v1 import MedNeXtV1 + + +_MEDNEXT_V1_PRESETS: dict[str, dict] = { + "S": { + "n_channels": 32, + "exp_r": 2, + "block_counts": [2, 2, 2, 2, 2, 2, 2, 2, 2], + "checkpoint_style": None, + }, + "B": { + "n_channels": 32, + "exp_r": [2, 3, 4, 4, 4, 4, 4, 3, 2], + "block_counts": [2, 2, 2, 2, 2, 2, 2, 2, 2], + "checkpoint_style": None, + }, + "M": { + 
"n_channels": 32, + "exp_r": [2, 3, 4, 4, 4, 4, 4, 3, 2], + "block_counts": [3, 4, 4, 4, 4, 4, 4, 4, 3], + "checkpoint_style": "outside_block", + }, + "L": { + "n_channels": 32, + "exp_r": [3, 4, 8, 8, 8, 8, 8, 4, 3], + "block_counts": [3, 4, 8, 8, 8, 8, 8, 4, 3], + "checkpoint_style": "outside_block", + }, +} + + +def get_mednext_v1_config(model_id: str) -> dict: + key = str(model_id).strip().upper() + if key not in _MEDNEXT_V1_PRESETS: + raise ValueError( + f"Unknown mednext_model_id {model_id!r}. Expected one of {sorted(_MEDNEXT_V1_PRESETS)}" + ) + cfg = deepcopy(_MEDNEXT_V1_PRESETS[key]) + cfg["model_id"] = key + return cfg + + +def create_mednext_v1( + num_input_channels: int, + num_classes: int, + model_id: str, + *, + kernel_size: int = 3, + deep_supervision: bool = False, + checkpoint_style: str | None = None, +) -> MedNeXtV1: + cfg = get_mednext_v1_config(model_id) + if checkpoint_style is None: + checkpoint_style = cfg["checkpoint_style"] + return MedNeXtV1( + in_channels=num_input_channels, + n_channels=cfg["n_channels"], + n_classes=num_classes, + exp_r=cfg["exp_r"], + kernel_size=int(kernel_size), + deep_supervision=bool(deep_supervision), + do_res=True, + do_res_up_down=True, + checkpoint_style=checkpoint_style, + block_counts=cfg["block_counts"], + norm_type="group", + grn=False, + ) diff --git a/vesuvius/src/vesuvius/models/build/mednext/v1.py b/vesuvius/src/vesuvius/models/build/mednext/v1.py new file mode 100644 index 000000000..24a052611 --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/mednext/v1.py @@ -0,0 +1,311 @@ +""" +3D MedNeXt v1 architecture adapted from the official MIC-DKFZ MedNeXt repository. 
+""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from .blocks import MedNeXtBlock3D, MedNeXtDownBlock3D, MedNeXtUpBlock3D, OutBlock3D + + +class MedNeXtV1(nn.Module): + def __init__( + self, + in_channels: int, + n_channels: int, + n_classes: int, + *, + exp_r: int | list[int] = 4, + kernel_size: int = 7, + enc_kernel_size: int | None = None, + dec_kernel_size: int | None = None, + deep_supervision: bool = False, + do_res: bool = False, + do_res_up_down: bool = False, + checkpoint_style: str | None = None, + block_counts: list[int] | tuple[int, ...] = (2, 2, 2, 2, 2, 2, 2, 2, 2), + norm_type: str = "group", + grn: bool = False, + ): + super().__init__() + + self.do_ds = bool(deep_supervision) + if checkpoint_style not in {None, "outside_block"}: + raise ValueError( + f"Unsupported mednext_checkpoint_style {checkpoint_style!r}. " + "Expected None or 'outside_block'." + ) + self.outside_block_checkpointing = checkpoint_style == "outside_block" + + if kernel_size is not None: + enc_kernel_size = int(kernel_size) + dec_kernel_size = int(kernel_size) + enc_kernel_size = int(enc_kernel_size or 3) + dec_kernel_size = int(dec_kernel_size or 3) + block_counts = list(int(v) for v in block_counts) + if len(block_counts) != 9: + raise ValueError(f"MedNeXt expects 9 block_counts entries, got {block_counts}") + + if isinstance(exp_r, int): + exp_r = [int(exp_r)] * len(block_counts) + else: + exp_r = [int(v) for v in exp_r] + if len(exp_r) != len(block_counts): + raise ValueError(f"MedNeXt expects 9 exp_r entries, got {exp_r}") + + self.stem = nn.Conv3d(in_channels, n_channels, kernel_size=1) + + self.enc_block_0 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=n_channels, + out_channels=n_channels, + exp_r=exp_r[0], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[0]) + ]) + self.down_0 = MedNeXtDownBlock3D( + in_channels=n_channels, 
+ out_channels=2 * n_channels, + exp_r=exp_r[1], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + grn=grn, + ) + + self.enc_block_1 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=2 * n_channels, + out_channels=2 * n_channels, + exp_r=exp_r[1], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[1]) + ]) + self.down_1 = MedNeXtDownBlock3D( + in_channels=2 * n_channels, + out_channels=4 * n_channels, + exp_r=exp_r[2], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + grn=grn, + ) + + self.enc_block_2 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=4 * n_channels, + out_channels=4 * n_channels, + exp_r=exp_r[2], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[2]) + ]) + self.down_2 = MedNeXtDownBlock3D( + in_channels=4 * n_channels, + out_channels=8 * n_channels, + exp_r=exp_r[3], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + grn=grn, + ) + + self.enc_block_3 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=8 * n_channels, + out_channels=8 * n_channels, + exp_r=exp_r[3], + kernel_size=enc_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[3]) + ]) + self.down_3 = MedNeXtDownBlock3D( + in_channels=8 * n_channels, + out_channels=16 * n_channels, + exp_r=exp_r[4], + kernel_size=enc_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + grn=grn, + ) + + self.bottleneck = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=16 * n_channels, + out_channels=16 * n_channels, + exp_r=exp_r[4], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[4]) + ]) + + self.up_3 = MedNeXtUpBlock3D( + in_channels=16 * n_channels, + out_channels=8 * n_channels, + exp_r=exp_r[5], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + 
norm_type=norm_type, + grn=grn, + ) + self.dec_block_3 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=8 * n_channels, + out_channels=8 * n_channels, + exp_r=exp_r[5], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[5]) + ]) + + self.up_2 = MedNeXtUpBlock3D( + in_channels=8 * n_channels, + out_channels=4 * n_channels, + exp_r=exp_r[6], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + grn=grn, + ) + self.dec_block_2 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=4 * n_channels, + out_channels=4 * n_channels, + exp_r=exp_r[6], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[6]) + ]) + + self.up_1 = MedNeXtUpBlock3D( + in_channels=4 * n_channels, + out_channels=2 * n_channels, + exp_r=exp_r[7], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + grn=grn, + ) + self.dec_block_1 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=2 * n_channels, + out_channels=2 * n_channels, + exp_r=exp_r[7], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[7]) + ]) + + self.up_0 = MedNeXtUpBlock3D( + in_channels=2 * n_channels, + out_channels=n_channels, + exp_r=exp_r[8], + kernel_size=dec_kernel_size, + do_res=do_res_up_down, + norm_type=norm_type, + grn=grn, + ) + self.dec_block_0 = nn.Sequential(*[ + MedNeXtBlock3D( + in_channels=n_channels, + out_channels=n_channels, + exp_r=exp_r[8], + kernel_size=dec_kernel_size, + do_res=do_res, + norm_type=norm_type, + grn=grn, + ) for _ in range(block_counts[8]) + ]) + + self.out_0 = OutBlock3D(n_channels, n_classes) + self.dummy_tensor = nn.Parameter(torch.tensor([1.0]), requires_grad=True) + if self.do_ds: + self.out_1 = OutBlock3D(2 * n_channels, n_classes) + self.out_2 = OutBlock3D(4 * n_channels, n_classes) + self.out_3 = OutBlock3D(8 * n_channels, n_classes) + self.out_4 = 
OutBlock3D(16 * n_channels, n_classes) + + self.block_counts = block_counts + + def iterative_checkpoint(self, sequential_block: nn.Sequential, x: torch.Tensor) -> torch.Tensor: + for layer in sequential_block: + x = checkpoint.checkpoint(layer, x, self.dummy_tensor) + return x + + def forward(self, x: torch.Tensor): + x = self.stem(x) + if self.outside_block_checkpointing: + x_res_0 = self.iterative_checkpoint(self.enc_block_0, x) + x = checkpoint.checkpoint(self.down_0, x_res_0, self.dummy_tensor) + x_res_1 = self.iterative_checkpoint(self.enc_block_1, x) + x = checkpoint.checkpoint(self.down_1, x_res_1, self.dummy_tensor) + x_res_2 = self.iterative_checkpoint(self.enc_block_2, x) + x = checkpoint.checkpoint(self.down_2, x_res_2, self.dummy_tensor) + x_res_3 = self.iterative_checkpoint(self.enc_block_3, x) + x = checkpoint.checkpoint(self.down_3, x_res_3, self.dummy_tensor) + x = self.iterative_checkpoint(self.bottleneck, x) + if self.do_ds: + x_ds_4 = checkpoint.checkpoint(self.out_4, x, self.dummy_tensor) + x_up_3 = checkpoint.checkpoint(self.up_3, x, self.dummy_tensor) + x = self.iterative_checkpoint(self.dec_block_3, x_res_3 + x_up_3) + if self.do_ds: + x_ds_3 = checkpoint.checkpoint(self.out_3, x, self.dummy_tensor) + x_up_2 = checkpoint.checkpoint(self.up_2, x, self.dummy_tensor) + x = self.iterative_checkpoint(self.dec_block_2, x_res_2 + x_up_2) + if self.do_ds: + x_ds_2 = checkpoint.checkpoint(self.out_2, x, self.dummy_tensor) + x_up_1 = checkpoint.checkpoint(self.up_1, x, self.dummy_tensor) + x = self.iterative_checkpoint(self.dec_block_1, x_res_1 + x_up_1) + if self.do_ds: + x_ds_1 = checkpoint.checkpoint(self.out_1, x, self.dummy_tensor) + x_up_0 = checkpoint.checkpoint(self.up_0, x, self.dummy_tensor) + x = self.iterative_checkpoint(self.dec_block_0, x_res_0 + x_up_0) + x = checkpoint.checkpoint(self.out_0, x, self.dummy_tensor) + else: + x_res_0 = self.enc_block_0(x) + x = self.down_0(x_res_0) + x_res_1 = self.enc_block_1(x) + x = 
self.down_1(x_res_1) + x_res_2 = self.enc_block_2(x) + x = self.down_2(x_res_2) + x_res_3 = self.enc_block_3(x) + x = self.down_3(x_res_3) + x = self.bottleneck(x) + if self.do_ds: + x_ds_4 = self.out_4(x) + x = self.dec_block_3(x_res_3 + self.up_3(x)) + if self.do_ds: + x_ds_3 = self.out_3(x) + x = self.dec_block_2(x_res_2 + self.up_2(x)) + if self.do_ds: + x_ds_2 = self.out_2(x) + x = self.dec_block_1(x_res_1 + self.up_1(x)) + if self.do_ds: + x_ds_1 = self.out_1(x) + x = self.dec_block_0(x_res_0 + self.up_0(x)) + x = self.out_0(x) + + if self.do_ds: + return [x, x_ds_1, x_ds_2, x_ds_3, x_ds_4] + return x diff --git a/vesuvius/src/vesuvius/models/build/mednext_wrapper.py b/vesuvius/src/vesuvius/models/build/mednext_wrapper.py new file mode 100644 index 000000000..53c110b9d --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/mednext_wrapper.py @@ -0,0 +1,230 @@ +""" +Wrapper classes to integrate MedNeXt with NetworkFromConfig. + +The wrappers keep the shared-encoder / task-decoder contract used by the rest +of vesuvius while preserving MedNeXt's additive skip behavior. +""" + +from __future__ import annotations + +from typing import Sequence + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from .mednext.blocks import MedNeXtBlock3D, MedNeXtDownBlock3D, MedNeXtUpBlock3D, OutBlock3D + + +class MedNeXtEncoder(nn.Module): + def __init__( + self, + *, + input_channels: int, + n_channels: int, + exp_r: Sequence[int], + block_counts: Sequence[int], + kernel_size: int, + checkpoint_style: str | None, + norm_type: str, + grn: bool, + do_res: bool = True, + do_res_up_down: bool = True, + ): + super().__init__() + if checkpoint_style not in {None, "outside_block"}: + raise ValueError( + f"Unsupported mednext_checkpoint_style {checkpoint_style!r}. " + "Expected None or 'outside_block'." 
class MedNeXtEncoder(nn.Module):
    """Standalone MedNeXt encoder.

    Emits one skip feature per resolution plus the bottleneck feature —
    exactly the list :class:`MedNeXtDecoder` consumes.  Channel widths double
    at every downsampling stage: ``c, 2c, 4c, 8c, 16c``.
    """

    def __init__(
        self,
        *,
        input_channels: int,
        n_channels: int,
        exp_r: Sequence[int],
        block_counts: Sequence[int],
        kernel_size: int,
        checkpoint_style: str | None,
        norm_type: str,
        grn: bool,
        do_res: bool = True,
        do_res_up_down: bool = True,
    ):
        super().__init__()
        if checkpoint_style not in {None, "outside_block"}:
            raise ValueError(
                f"Unsupported mednext_checkpoint_style {checkpoint_style!r}. "
                "Expected None or 'outside_block'."
            )
        self.outside_block_checkpointing = checkpoint_style == "outside_block"
        self.n_channels = int(n_channels)
        self.exp_r = [int(v) for v in exp_r]
        self.block_counts = [int(v) for v in block_counts]
        self.kernel_size = int(kernel_size)
        self.checkpoint_style = checkpoint_style
        self.norm_type = str(norm_type)
        self.grn = bool(grn)
        self.do_res = bool(do_res)
        self.do_res_up_down = bool(do_res_up_down)
        self.ndim = 3

        # Metadata consumed by the decoder/network builders elsewhere.
        width = self.n_channels
        self.output_channels = [width * factor for factor in (1, 2, 4, 8, 16)]
        self.strides = [(step, step, step) for step in (1, 2, 4, 8, 16)]
        self.conv_op = nn.Conv3d
        self.norm_op = nn.InstanceNorm3d if self.norm_type == "instance" else None
        self.norm_op_kwargs = {"affine": True, "eps": 1e-5}
        self.nonlin = nn.GELU
        self.nonlin_kwargs = {}
        self.dropout_op = None
        self.dropout_op_kwargs = None
        self.conv_bias = True
        self.kernel_sizes = [[self.kernel_size] * 3 for _ in self.output_channels]

        self.stem = nn.Conv3d(input_channels, width, kernel_size=1)
        self.enc_block_0 = self._make_stage(width, self.exp_r[0], self.block_counts[0])
        self.down_0 = self._make_down(width, 2 * width, self.exp_r[1])
        self.enc_block_1 = self._make_stage(2 * width, self.exp_r[1], self.block_counts[1])
        self.down_1 = self._make_down(2 * width, 4 * width, self.exp_r[2])
        self.enc_block_2 = self._make_stage(4 * width, self.exp_r[2], self.block_counts[2])
        self.down_2 = self._make_down(4 * width, 8 * width, self.exp_r[3])
        self.enc_block_3 = self._make_stage(8 * width, self.exp_r[3], self.block_counts[3])
        self.down_3 = self._make_down(8 * width, 16 * width, self.exp_r[4])
        self.bottleneck = self._make_stage(16 * width, self.exp_r[4], self.block_counts[4])
        # Grad-carrying dummy threaded through checkpoint() so checkpointing
        # works even when the activations themselves do not require grad.
        self.dummy_tensor = nn.Parameter(torch.tensor([1.0]), requires_grad=True)

    def _make_stage(self, channels: int, expansion: int, depth: int) -> nn.Sequential:
        # Stack of residual MedNeXt blocks at constant width.
        return nn.Sequential(*[
            MedNeXtBlock3D(
                channels,
                channels,
                exp_r=expansion,
                kernel_size=self.kernel_size,
                do_res=self.do_res,
                norm_type=self.norm_type,
                grn=self.grn,
            )
            for _ in range(depth)
        ])

    def _make_down(self, in_channels: int, out_channels: int, expansion: int) -> "MedNeXtDownBlock3D":
        # Strided MedNeXt block: halves resolution, doubles width.
        return MedNeXtDownBlock3D(
            in_channels,
            out_channels,
            exp_r=expansion,
            kernel_size=self.kernel_size,
            do_res=self.do_res_up_down,
            norm_type=self.norm_type,
            grn=self.grn,
        )

    def _run_sequential(self, block: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
        if self.outside_block_checkpointing:
            for layer in block:
                x = checkpoint.checkpoint(layer, x, self.dummy_tensor)
            return x
        return block(x)

    def _run_module(self, module: nn.Module, x: torch.Tensor) -> torch.Tensor:
        if self.outside_block_checkpointing:
            return checkpoint.checkpoint(module, x, self.dummy_tensor)
        return module(x)

    def forward(self, x: torch.Tensor):
        """Return ``[skip0, skip1, skip2, skip3, bottleneck]`` features."""
        x = self.stem(x)
        features = []
        for stage, down in (
            (self.enc_block_0, self.down_0),
            (self.enc_block_1, self.down_1),
            (self.enc_block_2, self.down_2),
            (self.enc_block_3, self.down_3),
        ):
            skip = self._run_sequential(stage, x)
            features.append(skip)
            x = self._run_module(down, skip)
        features.append(self._run_sequential(self.bottleneck, x))
        return features
class MedNeXtDecoder(nn.Module):
    """MedNeXt decoder over the additive skips of :class:`MedNeXtEncoder`.

    With ``deep_supervision`` enabled the forward pass returns one logits map
    per decoder resolution, finest first; with ``num_classes=None`` it returns
    raw decoder features and builds no logits heads.
    """

    def __init__(
        self,
        *,
        encoder: MedNeXtEncoder,
        num_classes: int | None,
        deep_supervision: bool,
    ):
        super().__init__()
        if num_classes is None and deep_supervision:
            raise ValueError("MedNeXt shared decoder cannot emit deep supervision outputs without logits heads")
        # Bypass nn.Module attribute handling on purpose: a plain reference
        # keeps the encoder out of this module's submodule/state-dict tree.
        object.__setattr__(self, "_encoder_ref", encoder)
        self.num_classes = num_classes
        self.do_ds = bool(deep_supervision and num_classes is not None)
        width = encoder.n_channels
        expansions = encoder.exp_r
        depths = encoder.block_counts
        self.outside_block_checkpointing = encoder.outside_block_checkpointing
        self.dummy_tensor = nn.Parameter(torch.tensor([1.0]), requires_grad=True)

        # Decoder consumes exp_r/block_counts entries 5..8 (encoder used 0..4).
        self.up_3 = self._make_up(encoder, 16 * width, 8 * width, expansions[5])
        self.dec_block_3 = self._make_stage(encoder, 8 * width, expansions[5], depths[5])
        self.up_2 = self._make_up(encoder, 8 * width, 4 * width, expansions[6])
        self.dec_block_2 = self._make_stage(encoder, 4 * width, expansions[6], depths[6])
        self.up_1 = self._make_up(encoder, 4 * width, 2 * width, expansions[7])
        self.dec_block_1 = self._make_stage(encoder, 2 * width, expansions[7], depths[7])
        self.up_0 = self._make_up(encoder, 2 * width, width, expansions[8])
        self.dec_block_0 = self._make_stage(encoder, width, expansions[8], depths[8])

        if self.num_classes is None:
            self.out_0 = None
            self.out_1 = None
            self.out_2 = None
            self.out_3 = None
            self.out_4 = None
        else:
            self.out_0 = OutBlock3D(width, self.num_classes)
            if self.do_ds:
                self.out_1 = OutBlock3D(2 * width, self.num_classes)
                self.out_2 = OutBlock3D(4 * width, self.num_classes)
                self.out_3 = OutBlock3D(8 * width, self.num_classes)
                self.out_4 = OutBlock3D(16 * width, self.num_classes)
            else:
                self.out_1 = self.out_2 = self.out_3 = self.out_4 = None

    @staticmethod
    def _make_stage(encoder: MedNeXtEncoder, channels: int, expansion: int, depth: int) -> nn.Sequential:
        # Residual MedNeXt block stack at constant width.
        return nn.Sequential(*[
            MedNeXtBlock3D(
                channels,
                channels,
                exp_r=expansion,
                kernel_size=encoder.kernel_size,
                do_res=encoder.do_res,
                norm_type=encoder.norm_type,
                grn=encoder.grn,
            )
            for _ in range(depth)
        ])

    @staticmethod
    def _make_up(encoder: MedNeXtEncoder, in_channels: int, out_channels: int, expansion: int) -> "MedNeXtUpBlock3D":
        # Transposed MedNeXt block: doubles resolution, halves width.
        return MedNeXtUpBlock3D(
            in_channels,
            out_channels,
            exp_r=expansion,
            kernel_size=encoder.kernel_size,
            do_res=encoder.do_res_up_down,
            norm_type=encoder.norm_type,
            grn=encoder.grn,
        )

    @property
    def encoder(self) -> MedNeXtEncoder:
        return self._encoder_ref

    def _run_sequential(self, block: nn.Sequential, x: torch.Tensor) -> torch.Tensor:
        if self.outside_block_checkpointing:
            for layer in block:
                x = checkpoint.checkpoint(layer, x, self.dummy_tensor)
            return x
        return block(x)

    def _run_module(self, module: nn.Module, x: torch.Tensor) -> torch.Tensor:
        if self.outside_block_checkpointing:
            return checkpoint.checkpoint(module, x, self.dummy_tensor)
        return module(x)

    def forward(self, features):
        skip_0, skip_1, skip_2, skip_3, x = features
        ds_outputs = []
        # Coarsest stage first; each DS head taps x *before* its up-block.
        for head, up, blocks, skip in (
            (self.out_4, self.up_3, self.dec_block_3, skip_3),
            (self.out_3, self.up_2, self.dec_block_2, skip_2),
            (self.out_2, self.up_1, self.dec_block_1, skip_1),
            (self.out_1, self.up_0, self.dec_block_0, skip_0),
        ):
            if self.do_ds:
                ds_outputs.append(self._run_module(head, x))
            x = self._run_sequential(blocks, skip + self._run_module(up, x))
        if self.num_classes is None:
            return x

        x = self._run_module(self.out_0, x)
        if self.do_ds:
            # Collected coarse-to-fine; callers expect fine-to-coarse.
            ds_outputs.reverse()
            return [x, *ds_outputs]
        return x
_DINO_V2_PRETRAINED = {
    "dinov2_ps8": {
        "repo_id": "scrollprize/dinov2_ps8",
        "checkpoint": "dinov2_ps8.pt",
        "config": "config.json",
    },
    "dinov2_ps16_crop256": {
        "repo_id": "scrollprize/dinov2_ps16_crop256",
        "checkpoint": "dinov2_ps16_crop256.pt",
        "config": "config_ps256.json",
    },
}


class Dinov2Backbone(nn.Module):
    """Adapter exposing a DINOv2-style ViT as a single-stage dense encoder.

    Reshapes the backbone's flat ``x_norm_patchtokens`` sequence into a
    channels-first feature map and advertises ``output_channels`` /
    ``strides`` metadata for downstream decoder builders.
    """

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.patch_embed_size = tuple(int(v) for v in backbone.patch_size)
        self.ndim = len(self.patch_embed_size)
        self.embed_dim = int(backbone.embed_dim)
        # Single feature stage at patch-size stride.
        self.output_channels = [self.embed_dim]
        self.strides = [self.patch_embed_size]

    def forward(self, x):
        spatial = tuple(int(v) for v in x.shape[2:])
        if len(spatial) != self.ndim:
            raise ValueError(
                f"DINOv2 backbone expected {self.ndim} spatial dims, got input shape {tuple(x.shape)}"
            )
        grid = []
        for size, patch in zip(spatial, self.patch_embed_size):
            if size % patch:
                raise ValueError(
                    f"Input shape {spatial} must be divisible by patch size {self.patch_embed_size}"
                )
            grid.append(size // patch)
        tokens = self.backbone(x)["x_norm_patchtokens"]
        # (B, N, C) tokens -> (B, C, *grid) dense feature map.
        dense = tokens.permute(0, 2, 1).reshape(tokens.shape[0], tokens.shape[-1], *grid)
        return [dense.contiguous()]


class MinimalDinov2Decoder(nn.Module):
    """1x1 projection to logits plus one transposed conv back to input resolution."""

    def __init__(self, encoder, num_classes):
        super().__init__()
        is_2d = encoder.ndim == 2
        conv_cls = nn.Conv2d if is_2d else nn.Conv3d
        deconv_cls = nn.ConvTranspose2d if is_2d else nn.ConvTranspose3d
        self.project = conv_cls(encoder.embed_dim, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
        # Only upsample when the backbone actually reduced resolution.
        if any(step > 1 for step in encoder.patch_embed_size):
            self.upsample = deconv_cls(
                num_classes,
                num_classes,
                kernel_size=encoder.patch_embed_size,
                stride=encoder.patch_embed_size,
                bias=True,
            )
        else:
            self.upsample = None

    def forward(self, features):
        x = features[0] if isinstance(features, list) else features
        x = self.project(x)
        return x if self.upsample is None else self.upsample(x)


class PrimusPatchDecodeDinov2Decoder(nn.Module):
    """Primus-style PatchDecode head mapping token features back to voxel logits."""

    def __init__(self, encoder, num_classes):
        super().__init__()
        self.patch_decoder = PatchDecode(
            patch_size=encoder.patch_embed_size,
            embed_dim=encoder.embed_dim,
            out_channels=num_classes,
            norm=LayerNormNd,
            activation=nn.GELU,
        )

    def forward(self, features):
        x = features[0] if isinstance(features, list) else features
        return self.patch_decoder(x)
+ scale_factors = tuple(int(v) for v in strides[stage_idx]) + expansion_channels = int(next_channels) * int(torch.tensor(scale_factors).prod().item()) + stage_ops = [ + conv(channels[stage_idx], expansion_channels, kernel_size=1, stride=1, padding=0, bias=True), + PixelShuffle3D(scale_factors), + conv(next_channels, next_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.GroupNorm(_resolve_group_norm_groups(next_channels), next_channels), + nn.GELU(), + ] + stages.append(nn.Sequential(*stage_ops)) + self.decode = nn.Sequential(*stages) + self.final_refine = nn.Sequential( + conv(final_hidden_channels, final_hidden_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.GroupNorm(_resolve_group_norm_groups(final_hidden_channels), final_hidden_channels), + nn.GELU(), + conv(final_hidden_channels, final_hidden_channels, kernel_size=3, stride=1, padding=1, bias=True), + conv(final_hidden_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True), + ) + + def forward(self, features): + x = features[0] if isinstance(features, list) else features + x = self.decode(x) + return self.final_refine(x) + + +class PixelShuffleConvHeadBigDinov2Decoder(nn.Module): + def __init__(self, encoder, num_classes): + super().__init__() + if encoder.ndim != 3: + raise ValueError( + f"pixelshuffle_convhead_big currently only supports 3D pretrained backbones, got ndim={encoder.ndim}" + ) + + num_stages, strides, channels = _compute_patch_decode_plan( + encoder.patch_embed_size, + embed_dim=encoder.embed_dim, + out_channels=num_classes, + ) + conv = nn.Conv3d + final_hidden_channels = max( + 8, + channels[-2] if num_stages > 1 else min(64, max(8, encoder.embed_dim // 8)), + ) + stages = [] + for stage_idx in range(num_stages): + next_channels = channels[stage_idx + 1] if stage_idx < num_stages - 1 else final_hidden_channels + scale_factors = tuple(int(v) for v in strides[stage_idx]) + expansion_channels = int(next_channels) * 
int(torch.tensor(scale_factors).prod().item()) + stage_ops = [ + conv(channels[stage_idx], expansion_channels, kernel_size=1, stride=1, padding=0, bias=True), + PixelShuffle3D(scale_factors), + conv(next_channels, next_channels, kernel_size=3, stride=1, padding=1, bias=False), + nn.GroupNorm(_resolve_group_norm_groups(next_channels), next_channels), + nn.GELU(), + ] + stages.append(nn.Sequential(*stage_ops)) + self.decode = nn.Sequential(*stages) + self.final_refine = nn.Sequential( + conv(final_hidden_channels, final_hidden_channels, kernel_size=7, stride=1, padding=3, bias=False), + nn.GroupNorm(_resolve_group_norm_groups(final_hidden_channels), final_hidden_channels), + nn.GELU(), + conv(final_hidden_channels, final_hidden_channels, kernel_size=3, stride=1, padding=1, bias=True), + conv(final_hidden_channels, num_classes, kernel_size=1, stride=1, padding=0, bias=True), + ) + + def forward(self, features): + x = features[0] if isinstance(features, list) else features + x = self.decode(x) + return self.final_refine(x) + + +def _resolve_group_norm_groups(num_channels: int, max_groups: int = 8) -> int: + if num_channels < 1: + raise ValueError(f"GroupNorm requires num_channels >= 1, got {num_channels}") + return max(1, math.gcd(int(num_channels), int(max_groups))) + + +def _load_model_config(spec, input_channels): + checkpoint_path = Path(hf_hub_download(repo_id=spec["repo_id"], filename=spec["checkpoint"])) + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + config = _resolve_checkpoint_config(spec, checkpoint_path, checkpoint) + model_config = dict(config.get("model") or {}) + dataset_config = config.get("dataset") or {} + if "global_crops_size" not in model_config and "global_crop_size" in dataset_config: + model_config["global_crops_size"] = dataset_config["global_crop_size"] + if "local_crops_size" not in model_config and "local_crop_size" in dataset_config: + model_config["local_crops_size"] = dataset_config["local_crop_size"] 
+ model_config["input_channels"] = int(input_channels) + return model_config, checkpoint_path, checkpoint + + +def _load_json_config(path): + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + +def _resolve_local_checkpoint_config( + checkpoint_path: Path, + checkpoint, + *, + config_path: Optional[str | Path] = None, +): + embedded_config = checkpoint.get("config") + if isinstance(embedded_config, dict) and isinstance(embedded_config.get("model"), dict): + return embedded_config + + candidate_paths: list[Path] = [] + if config_path is not None: + candidate_paths.append(Path(config_path).expanduser()) + candidate_paths.append(checkpoint_path.with_name("config.json")) + candidate_paths.append(checkpoint_path.with_suffix(".json")) + + seen = set() + for candidate in candidate_paths: + resolved = candidate.resolve() if candidate.exists() else candidate + key = str(resolved) + if key in seen: + continue + seen.add(key) + if candidate.exists(): + return _load_json_config(candidate) + + raise ValueError( + "Local Dinovol checkpoint does not contain an embedded config and no JSON config was found. 
" + f"checkpoint={checkpoint_path}" + ) + + +def _resolve_checkpoint_config(spec, checkpoint_path, checkpoint): + embedded_config = checkpoint.get("config") + if isinstance(embedded_config, dict) and isinstance(embedded_config.get("model"), dict): + return embedded_config + + sidecar_path = checkpoint_path.with_name("config.json") + if sidecar_path.exists(): + return _load_json_config(sidecar_path) + + config_path = hf_hub_download(repo_id=spec["repo_id"], filename=spec["config"]) + return _load_json_config(config_path) + + +def _load_teacher_backbone_state(checkpoint): + if isinstance(checkpoint, (str, Path)): + checkpoint = torch.load(checkpoint, map_location="cpu", weights_only=False) + for branch_name in ("teacher", "student"): + branch_state = checkpoint.get(branch_name) + if not isinstance(branch_state, dict): + continue + backbone_state = { + key.replace("backbone.", "", 1): value + for key, value in branch_state.items() + if key.startswith("backbone.") + } + if backbone_state: + return backbone_state + raise ValueError("Checkpoint does not contain teacher.backbone or student.backbone weights") + + +def _load_local_model_config(checkpoint_path, input_channels, *, config_path=None): + checkpoint_path = Path(checkpoint_path).expanduser() + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Local Dinovol checkpoint not found: {checkpoint_path}") + + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + config = _resolve_local_checkpoint_config(checkpoint_path, checkpoint, config_path=config_path) + model_config = dict(config.get("model") or {}) + dataset_config = config.get("dataset") or {} + if "global_crops_size" not in model_config and "global_crop_size" in dataset_config: + model_config["global_crops_size"] = dataset_config["global_crop_size"] + if "local_crops_size" not in model_config and "local_crop_size" in dataset_config: + model_config["local_crops_size"] = dataset_config["local_crop_size"] + 
model_config["input_channels"] = int(input_channels) + return model_config, checkpoint_path, checkpoint + + +def build_dinov2_backbone(name, input_channels, input_shape, *, config_path=None): + del input_shape + if name in _DINO_V2_PRETRAINED: + spec = _DINO_V2_PRETRAINED[name] + model_config, checkpoint_path, checkpoint = _load_model_config(spec, input_channels) + else: + model_config, checkpoint_path, checkpoint = _load_local_model_config( + name, + input_channels, + config_path=config_path, + ) + backbone = build_dinovol_2_backbone(model_config) + backbone.load_pretrained_weights(_load_teacher_backbone_state(checkpoint), unchunk=True) + return Dinov2Backbone(backbone) + + +def build_dinov2_decoder(name, encoder, num_classes): + if name == "minimal": + return MinimalDinov2Decoder(encoder, num_classes) + if name == "primus_patch_decode": + return PrimusPatchDecodeDinov2Decoder(encoder, num_classes) + if name == "pixelshuffle_conv": + return PixelShuffleConvDinov2Decoder(encoder, num_classes) + if name == "pixelshuffle_convhead_big": + return PixelShuffleConvHeadBigDinov2Decoder(encoder, num_classes) + raise ValueError( + "Unknown pretrained_decoder_type " + f"{name!r}. 
_BACKBONE_DEFAULTS = {
    "input_channels": 1,
    "global_crops_size": (256, 256, 256),
    "local_crops_size": None,
    "embed_dim": 864,
    "patch_size": (16, 16, 16),
    "embedding_type": "default",
    "deeper_embed_patch_chunk_size": None,
    "deeper_embed_batch_chunk_size": 1,
    "depth": 24,
    "num_heads": 16,
    "qkv_bias": True,
    "qkv_fused": False,
    "mlp_ratio": 8.0 / 3.0,
    "swiglu_mlp": True,
    "scale_mlp": True,
    "scale_attn_inner": False,
    "proj_drop_rate": 0.0,
    "attn_drop_rate": 0.0,
    "drop_path_rate": 0.0,
    "drop_path_uniform": False,
    "init_values": None,
    "use_abs_pos_emb": False,
    "use_rot_pos_emb": True,
    "num_reg_tokens": 4,
    "grad_checkpointing": False,
    "block_chunks": 0,
}
# The v2 recipe only changes fused QKV, stochastic depth, and LayerScale init;
# derive it from the base table instead of duplicating every entry.
_BACKBONE_DEFAULTS_V2 = {
    **_BACKBONE_DEFAULTS,
    "qkv_fused": True,
    "drop_path_rate": 0.3,
    "init_values": 1e-5,
}
# Every RoPE tuning knob follows one naming scheme: `rope_<key>` in current
# configs and `pos_embed_rope_<key>` in legacy ones — derive both maps.
_ROPE_KWARG_KEYS = (
    "base",
    "min_period",
    "max_period",
    "normalize_coords",
    "shift_coords",
    "jitter_coords",
    "rescale_coords",
    "dtype",
)
_ROPE_KWARG_CONFIG_KEYS = {key: f"rope_{key}" for key in _ROPE_KWARG_KEYS}
_ROPE_KWARG_FALLBACK_KEYS = {key: f"pos_embed_rope_{key}" for key in _ROPE_KWARG_KEYS}
_ROPE_KWARG_DEFAULTS = {
    "base": 100.0,
    "normalize_coords": "separate",
    "rescale_coords": 2.0,
}
_ROPE_DTYPE_ALIASES = {
    "fp32": torch.float32,
    "float32": torch.float32,
    "fp16": torch.float16,
    "float16": torch.float16,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
}
_ROPE_IMPL_ALIASES = {
    "axial": RopePositionEmbedding,
    "default": RopePositionEmbedding,
    "mixed": MixedRopePositionEmbedding,
    "mixed_learnable": MixedRopePositionEmbedding,
}


def _is_v2_config(config: Mapping[str, Any]) -> bool:
    # `model_type: v2` (any casing / surrounding whitespace) selects v2.
    model_type = config.get("model_type")
    return isinstance(model_type, str) and model_type.strip().lower() == "v2"


def _resolve_backbone_defaults(config: Mapping[str, Any]) -> Mapping[str, Any]:
    """Pick the default hyperparameter table for the configured model_type."""
    return _BACKBONE_DEFAULTS_V2 if _is_v2_config(config) else _BACKBONE_DEFAULTS


def _resolve_default_rope_type(config: Mapping[str, Any]) -> str:
    """v2 models default to mixed RoPE; v1 models to axial RoPE."""
    return "mixed" if _is_v2_config(config) else "axial"
got {len(result)}: {result}") + return result + + +def _config_value( + config: Mapping[str, Any], + key: str, + default: Any, + *, + fallback_key: Optional[str] = None, +) -> Any: + if key in config: + return config[key] + if fallback_key is not None and fallback_key in config: + return config[fallback_key] + return default + + +def _resolve_rope_dtype(value: Any) -> Any: + if value is None or isinstance(value, torch.dtype): + return value + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in _ROPE_DTYPE_ALIASES: + return _ROPE_DTYPE_ALIASES[normalized] + raise ValueError(f"unsupported rope dtype value: {value!r}") + + +def _resolve_rope_kwargs(config: Mapping[str, Any]) -> dict[str, Any]: + rope_kwargs = dict(config.get("rope_kwargs") or {}) + for rope_key, config_key in _ROPE_KWARG_CONFIG_KEYS.items(): + if rope_key in rope_kwargs: + continue + fallback_key = _ROPE_KWARG_FALLBACK_KEYS[rope_key] + if config_key in config or fallback_key in config: + rope_kwargs[rope_key] = _config_value(config, config_key, None, fallback_key=fallback_key) + + for rope_key, default_value in _ROPE_KWARG_DEFAULTS.items(): + rope_kwargs.setdefault(rope_key, default_value) + + if "dtype" in rope_kwargs: + rope_kwargs["dtype"] = _resolve_rope_dtype(rope_kwargs["dtype"]) + + return rope_kwargs + + +def _resolve_rope_impl( + config: Mapping[str, Any], + rope_kwargs: Mapping[str, Any], +) -> tuple[type[nn.Module], dict[str, Any]]: + resolved_kwargs = dict(rope_kwargs) + rope_type = _config_value(config, "rope_type", None, fallback_key="pos_embed_rope_type") + if rope_type is None: + rope_type = resolved_kwargs.pop("type", resolved_kwargs.pop("rope_type", _resolve_default_rope_type(config))) + + if isinstance(rope_type, str): + normalized = rope_type.strip().lower() + if normalized not in _ROPE_IMPL_ALIASES: + raise ValueError( + f"unsupported rope_type={rope_type!r}; expected one of {sorted(_ROPE_IMPL_ALIASES)}" + ) + return _ROPE_IMPL_ALIASES[normalized], 
resolved_kwargs + + if isinstance(rope_type, type) and issubclass(rope_type, nn.Module): + return rope_type, resolved_kwargs + + raise ValueError(f"unsupported rope_type value: {rope_type!r}") + + +def build_dinovol_2_backbone(config: Mapping[str, Any]) -> Eva: + backbone_defaults = _resolve_backbone_defaults(config) + backbone_config = {key: config.get(key, default) for key, default in backbone_defaults.items()} + if "num_reg_tokens" not in config and "num_register_tokens" in config: + backbone_config["num_reg_tokens"] = int(config["num_register_tokens"]) + + rope_kwargs = _resolve_rope_kwargs(config) + rope_impl, rope_kwargs = _resolve_rope_impl(config, rope_kwargs) + + global_crops_size = _as_3tuple(backbone_config["global_crops_size"]) + local_crop_value = backbone_config["local_crops_size"] or global_crops_size + local_crops_size = _as_3tuple(local_crop_value) + + block_chunks = int(backbone_config["block_chunks"]) + backbone_cls = EvaWithChunking if block_chunks > 0 else Eva + kwargs = dict(backbone_config) + kwargs.update( + { + "global_crops_size": global_crops_size, + "local_crops_size": local_crops_size, + "patch_size": _as_3tuple(backbone_config["patch_size"]), + "rope_impl": rope_impl, + "rope_kwargs": dict(rope_kwargs), + } + ) + kwargs.pop("block_chunks") + if backbone_cls is EvaWithChunking: + kwargs["block_chunks"] = block_chunks + return backbone_cls(**kwargs) diff --git a/vesuvius/src/vesuvius/models/build/pretrained_backbones/dinovol_2_eva.py b/vesuvius/src/vesuvius/models/build/pretrained_backbones/dinovol_2_eva.py new file mode 100644 index 000000000..9128f9062 --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/pretrained_backbones/dinovol_2_eva.py @@ -0,0 +1,919 @@ +from collections import OrderedDict +import math +from typing import Callable, Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from timm.layers import trunc_normal_, use_fused_attn, DropPath, SwiGLU, GluMlp, Mlp +from torch import nn +from 
class InitWeights_He(object):
    """Kaiming-He initializer callable for conv layers (``model.apply`` style)."""

    def __init__(self, neg_slope: float = 1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        # kaiming_normal_/constant_ mutate in place and return their argument,
        # so these re-assignments keep the original Parameter objects.
        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)


class EvaAttention(nn.Module):
    """EVA multi-head self-attention with optional fused QKV and RoPE.

    Only the fused ``scaled_dot_product_attention`` path is supported; the
    non-fused path raises.

    Args:
        dim: token embedding dimension.
        num_heads: number of attention heads.
        qkv_bias: add bias to Q and V projections (K bias is kept at zero).
        qkv_fused: use a single packed QKV projection instead of three.
        num_prefix_tokens: leading tokens (cls/registers) excluded from RoPE.
        qkv_bias_separate: add the bias in a separate op (quantization-friendly).
        attn_drop: attention dropout probability.
        proj_drop: output projection dropout probability.
        attn_head_dim: override per-head dimension (default ``dim // num_heads``).
        norm_layer: optional norm applied to attention output before projection.
    """

    fused_attn: torch.jit.Final[bool]

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qkv_fused: bool = True,
        num_prefix_tokens: int = 1,
        qkv_bias_separate: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        attn_head_dim: Optional[int] = None,
        norm_layer: Optional[Callable] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        if attn_head_dim is None and dim % num_heads != 0:
            raise ValueError(f"dim must be divisible by num_heads, got dim={dim}, num_heads={num_heads}")
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5
        self.num_prefix_tokens = num_prefix_tokens
        self.fused_attn = use_fused_attn()
        self.qkv_bias_separate = qkv_bias_separate

        if qkv_fused:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
            self.q_proj = self.k_proj = self.v_proj = None
            if qkv_bias:
                # EVA convention: learnable Q/V bias, fixed all-zero K bias.
                self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
                self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
                self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
            else:
                self.q_bias = self.k_bias = self.v_bias = None
        else:
            self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
            self.qkv = None
            self.q_bias = self.k_bias = self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity()
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        x,
        rope: Optional[RopeEmbedding] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ):
        B, N, C = x.shape

        if self.qkv is not None:
            if self.q_bias is None:
                qkv = self.qkv(x)
            else:
                qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
                if self.qkv_bias_separate:
                    qkv = self.qkv(x)
                    qkv += qkv_bias
                else:
                    qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
        else:
            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, C
            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

        if rope is not None:
            q = apply_rotary_embedding(q, rope, prefix_tokens=self.num_prefix_tokens).type_as(v)
            k = apply_rotary_embedding(k, rope, prefix_tokens=self.num_prefix_tokens).type_as(v)

        if not self.fused_attn:
            # NOTE(review): the original carried a manual softmax-attention
            # implementation *after* this raise — unreachable dead code, now
            # removed.  Only the fused SDPA path is supported.
            raise RuntimeError("Fused attention should be used.")
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
else nn.Identity() + + self.norm2 = norm_layer(dim) + hidden_features = int(dim * mlp_ratio) + if swiglu_mlp: + if scale_mlp: + # when norm in SwiGLU used, an impl with separate fc for gate & x is used + self.mlp = SwiGLU( + in_features=dim, + hidden_features=hidden_features, + norm_layer=norm_layer if scale_mlp else None, + drop=proj_drop, + ) + else: + # w/o any extra norm, an impl with packed weights is used, matches existing GluMLP + self.mlp = GluMlp( + in_features=dim, + hidden_features=hidden_features * 2, + norm_layer=norm_layer if scale_mlp else None, + act_layer=nn.SiLU, + gate_last=False, + drop=proj_drop, + ) + else: + self.mlp = Mlp( + in_features=dim, + hidden_features=hidden_features, + act_layer=act_layer, + norm_layer=norm_layer if scale_mlp else None, + drop=proj_drop, + ) + self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None + self.drop_path2 = DropPath(drop_path, drop_path_scale) if drop_path > 0. else nn.Identity() + + def forward( + self, + x, + rope: Optional[RopeEmbedding] = None, + attn_mask: Optional[torch.Tensor] = None, + rope_shape: Optional[Tuple[int, ...]] = None, + ): + if rope is None and self.rope_embed is not None: + if rope_shape is None: + raise ValueError("rope_shape must be provided when using per-block RoPE") + rope = self.rope_embed.get_embed(rope_shape) + if self.gamma_1 is None: + x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)) + x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class Eva(nn.Module): + """ Eva Vision Transformer w/ Abs & Rotary Pos Embed + + This class implements the EVA and EVA02 models that were based on the BEiT ViT variant + * EVA - abs pos embed, global avg pool + * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP 
(ala normformer) + + + """ + + @staticmethod + def _assert_patch_aligned( + spatial_shape: Tuple[int, ...], + patch_size: Tuple[int, ...], + *, + context: str, + ) -> None: + remainders = [int(size) % int(patch) for size, patch in zip(spatial_shape, patch_size)] + if any(remainders): + raise AssertionError( + f"{context} must be divisible by patch_size for PatchEmbedDeeper, " + f"got spatial_shape={tuple(spatial_shape)} and patch_size={tuple(patch_size)}" + ) + + @staticmethod + def _normalize_optional_chunk_shape( + value: Optional[int | Tuple[int, ...]], + *, + ndim: int, + name: str, + ) -> Optional[Tuple[int, ...]]: + if value is None: + return None + if isinstance(value, int): + if value <= 0: + raise ValueError(f"{name} must be positive, got {value}") + return tuple([int(value)] * ndim) + result = tuple(int(v) for v in value) + if len(result) != ndim: + raise ValueError(f"{name} must provide {ndim} values, got {result}") + if any(v <= 0 for v in result): + raise ValueError(f"{name} must be positive, got {result}") + return result + + @staticmethod + def _normalize_optional_positive_int( + value: Optional[int], + *, + name: str, + ) -> Optional[int]: + if value is None: + return None + value = int(value) + if value <= 0: + raise ValueError(f"{name} must be positive, got {value}") + return value + + def __init__( + self, + input_channels: int = 1, + global_crops_size: Tuple[int, ...] = None, + local_crops_size: Tuple[int, ...] = None, + embed_dim: int = 864, + patch_size: Tuple[int, ...] = (8, 8, 8), + embedding_type: str = "default", + depth: int = 24, + num_heads: int = 12, + qkv_bias: bool = True, + qkv_fused: bool = False, + mlp_ratio: float = 4 * 2 / 3, + swiglu_mlp: bool = True, + scale_mlp: bool = True, + scale_attn_inner: bool = False, + pos_drop_rate: float = 0., + proj_drop_rate: float = 0., + # drops out things related to the projection. 
That is in the MLP and at the end of EVA attention + attn_drop_rate: float = 0., + # drops attention, meaning connections between patches may bebroken up at random + drop_path_rate: float = 0., + # drops computations (multihead attention, mlp), Implementation of scaling might be useless here because this is not batch normed + drop_path_uniform: bool = False, + norm_layer: Callable = LayerNorm, + init_values: Optional[float] = None, + class_token: bool = True, + use_abs_pos_emb: bool = False, + use_rot_pos_emb: bool = True, + dynamic_img_size: bool = False, + num_reg_tokens: int = 0, + drop_path_scale: bool = True, + rope_impl=RopePositionEmbedding, + rope_kwargs=None, + grad_checkpointing=False, + deeper_embed_patch_chunk_size: Optional[int | Tuple[int, ...]] = None, + deeper_embed_batch_chunk_size: Optional[int] = None, + ): + """ + Diff to timm implementation + + - removed patch embedding, we expect embeded patches + - removed classification token, we use features at the end + - removed head + - dynamic image size is not supported, but left in for future stuff + - self.cls_token removed + - removed postnorm block support + """ + super().__init__() + + self.input_channels = input_channels + self.patch_size = [patch_size] * 3 if isinstance(patch_size, int) else patch_size + self.ndim = len(self.patch_size) + self.embedding_type = str(embedding_type).lower() + self.global_crops_size = [global_crops_size] * 3 if isinstance(global_crops_size, int) else global_crops_size + self.local_crops_size = [local_crops_size] * 3 if isinstance(local_crops_size, int) else local_crops_size + self.global_input_size = tuple(int(size) for size in self.global_crops_size) + self.local_input_size = tuple(int(size) for size in self.local_crops_size) + self.deeper_embed_patch_halo = tuple(0 for _ in range(self.ndim)) + self.deeper_embed_patch_chunk_size = self._normalize_optional_chunk_shape( + deeper_embed_patch_chunk_size, + ndim=self.ndim, + name="deeper_embed_patch_chunk_size", + ) + 
self.deeper_embed_batch_chunk_size = self._normalize_optional_positive_int( + deeper_embed_batch_chunk_size, + name="deeper_embed_batch_chunk_size", + ) + + if self.embedding_type == "deeper": + self._assert_patch_aligned( + tuple(self.global_crops_size), + tuple(self.patch_size), + context="global_crops_size", + ) + self._assert_patch_aligned( + tuple(self.local_crops_size), + tuple(self.patch_size), + context="local_crops_size", + ) + + self.global_ref_feat_shape = tuple([i // ds for i, ds in zip(self.global_crops_size, self.patch_size)]) + self.local_ref_feat_shape = tuple([i // ds for i, ds in zip(self.local_crops_size, self.patch_size)]) + + if self.embedding_type == "default": + self.down_projection = PatchEmbed( + patch_size=tuple(self.patch_size), + input_channels=input_channels, + embed_dim=embed_dim, + ) + elif self.embedding_type == "deeper": + self.down_projection = PatchEmbedDeeper( + patch_size=tuple(self.patch_size), + input_channels=input_channels, + embed_dim=embed_dim, + ) + else: + raise ValueError( + f"unsupported embedding_type={embedding_type!r}; expected 'default' or 'deeper'" + ) + + if rope_kwargs is None: + rope_kwargs = {} + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.dynamic_img_size = dynamic_img_size + self.grad_checkpointing = grad_checkpointing + + self.num_reg_tokens = num_reg_tokens + self.num_class_tokens = (1 if class_token else 0) + self.num_prefix_tokens = self.num_class_tokens + self.num_reg_tokens + + num_patches = np.prod(self.global_ref_feat_shape) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + self.reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim)) if num_reg_tokens else None + self.cls_embed = class_token and self.reg_token is None + + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_class_tokens, embed_dim)) if use_abs_pos_emb else None + self.pos_drop = 
nn.Dropout(p=pos_drop_rate) + self.use_per_block_rope = bool(use_rot_pos_emb and rope_impl is MixedRopePositionEmbedding) + if use_rot_pos_emb and embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim must be divisible by num_heads, got embed_dim={embed_dim}, num_heads={num_heads}" + ) + head_dim = embed_dim // num_heads + if use_rot_pos_emb and head_dim % (2 * self.ndim) != 0: + raise ValueError( + f"RoPE requires head_dim divisible by 2 * ndim, got head_dim={head_dim}, ndim={self.ndim}" + ) + if use_rot_pos_emb and not self.use_per_block_rope: + self.rope_embed = rope_impl( + head_dim, + ndim=self.ndim, + num_heads=num_heads, + **rope_kwargs, + ) + else: + self.rope_embed = None + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + block_fn = EvaBlock + self.blocks = nn.ModuleList([ + block_fn( + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qkv_fused=qkv_fused, + mlp_ratio=mlp_ratio, + swiglu_mlp=swiglu_mlp, + scale_mlp=scale_mlp, + scale_attn_inner=scale_attn_inner, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + init_values=init_values, + num_prefix_tokens=self.num_prefix_tokens, + drop_path_scale=drop_path_scale, + rope_impl=rope_impl if self.use_per_block_rope else None, + rope_kwargs=rope_kwargs if self.use_per_block_rope else None, + ndim=self.ndim, + ) + for i in range(depth)]) + + self.norm = norm_layer(embed_dim) + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + + self._init_weights() + + def _init_weights(self): + def init_fn(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + self.apply(init_fn) + self.down_projection.apply(InitWeights_He(1e-2)) + + if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + 
trunc_normal_(self.cls_token, std=.02) + if self.reg_token is not None: + trunc_normal_(self.reg_token, std=.02) + if self.mask_token is not None: + trunc_normal_(self.mask_token, std=.02) + + # Inline fix_init_weight + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + if hasattr(layer.attn.proj, 'weight'): + rescale(layer.attn.proj.weight.data, layer_id + 1) + if hasattr(layer.mlp.fc2, 'weight'): + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + @torch.jit.ignore + def no_weight_decay(self): + nwd = {'pos_embed', 'cls_token'} + return nwd + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^cls_token|pos_embed|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))], + ) + return matcher + + def _pos_embed(self, x, *spatial) -> Tuple[torch.Tensor, Optional[RopeEmbedding]]: + """ + Computes positional embeddings with interpolation if needed. + + Args: + x (torch.Tensor): Input tensor after patch embedding, shape (B, N, C). + spatial: Spatial dimensions of the original image before patch embedding. + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: Positionally encoded input. 
+ """ + pos_embed = self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + source_size = tuple(self.global_ref_feat_shape) + target_size = tuple(dim // patch for dim, patch in zip(spatial, self.patch_size)) + + # If needed, interpolate only patch embeddings + if source_size != target_size: + if pos_embed is not None: + pos_embed = self.interpolate_pos_encoding_nd( + pos_embed, + source_size=source_size, + target_size=target_size, + num_prefix_tokens=self.num_class_tokens + ) + rot_pos_embed = self._get_rot_pos_embed(target_size) + + # Add interpolated positional embeddings + if pos_embed is not None: + x = x + pos_embed + + if self.reg_token is not None: + reg_tokens = self.reg_token.expand(x.shape[0], -1, -1) + if self.cls_token is not None: + x = torch.cat((x[:, :1], reg_tokens, x[:, 1:]), dim=1) + else: + x = torch.cat((reg_tokens, x), dim=1) + + x = self.pos_drop(x) + + return x, rot_pos_embed + + def _get_rot_pos_embed(self, target_size: Tuple[int, ...]) -> Optional[RopeEmbedding]: + if self.rope_embed is None: + return None + return self.rope_embed.get_embed(target_size) + + def interpolate_pos_encoding_nd( + self, + pos_embed: torch.Tensor, + source_size: tuple, + target_size: tuple, + num_prefix_tokens: int = 1, + ) -> torch.Tensor: + """ + Interpolates positional embeddings to match a new spatial size. + + Args: + pos_embed (torch.Tensor): Positional embeddings (1, N, D). + source_size: Original source grid size. + target_size: New target grid size. + num_prefix_tokens (int): Number of special tokens (e.g., CLS, registers). + + Returns: + torch.Tensor: Rescaled positional embeddings (1, N_new, D). 
+ """ + _, N, C = pos_embed.shape + N = N - num_prefix_tokens # Remove prefix tokens + + previous_dtype = pos_embed.dtype + pos_embed = pos_embed.float() + + if num_prefix_tokens > 0: + pos_prefix, pos_embed = pos_embed[:, :num_prefix_tokens], pos_embed[:, num_prefix_tokens:] + else: + pos_prefix = None + + ndim = len(source_size) + if ndim not in (2, 3): + raise ValueError(f"Only 2D and 3D positional interpolation are supported, got ndim={ndim}") + + interpolation_mode = "bilinear" if ndim == 2 else "trilinear" + + # Reshape from (1, N, C) -> (1, C, *source_size) + pos_embed = pos_embed.reshape(1, *source_size, C) + permute_order = (0, ndim + 1, *range(1, ndim + 1)) + pos_embed = pos_embed.permute(*permute_order) + + # Interpolate to the new spatial grid. + pos_embed = F.interpolate(pos_embed, size=target_size, mode=interpolation_mode, align_corners=False) + + # Reshape back to (1, N, C) + inverse_permute = (0, *range(2, ndim + 2), 1) + pos_embed = pos_embed.permute(*inverse_permute).reshape(1, -1, C) + + # Reattach prefix tokens + if pos_prefix is not None: + pos_embed = torch.cat([pos_prefix, pos_embed], dim=1) + + pos_embed = pos_embed.to(previous_dtype) + + return pos_embed + + def _resolve_target_spatial_shape(self, spatial: tuple[int, ...], *, view_kind: str) -> tuple[int, ...]: + spatial = tuple(int(dim) for dim in spatial) + if view_kind == "global": + target = tuple(int(dim) for dim in self.global_crops_size) + input_size = tuple(int(dim) for dim in self.global_input_size) + elif view_kind == "local": + target = tuple(int(dim) for dim in self.local_crops_size) + input_size = tuple(int(dim) for dim in self.local_input_size) + else: + raise ValueError(f"unknown view_kind={view_kind!r}") + + if spatial == input_size or spatial == target: + return target + raise ValueError( + f"unexpected input shape {spatial} for embedding_type={self.embedding_type!r} and view_kind={view_kind!r}; " + f"expected {input_size} or {target}" + ) + + def _crop_embedded_grid(self, 
x: torch.Tensor, target_spatial: tuple[int, ...]) -> torch.Tensor: + target_patch_shape = tuple(int(size) // int(patch) for size, patch in zip(target_spatial, self.patch_size)) + current_patch_shape = tuple(int(dim) for dim in x.shape[2:]) + if current_patch_shape == target_patch_shape: + return x + + starts = [] + for current, target in zip(current_patch_shape, target_patch_shape): + delta = current - target + if delta < 0 or delta % 2 != 0: + raise ValueError( + f"cannot center-crop embedded grid from {current_patch_shape} to {target_patch_shape}" + ) + starts.append(delta // 2) + slices = tuple(slice(start, start + size) for start, size in zip(starts, target_patch_shape)) + return x[(slice(None), slice(None), *slices)] + + def prepare_tokens_with_masks(self, x, masks=None, *, view_kind: str = "global"): + spatial = tuple(x.shape[2:]) + self._assert_patch_aligned(spatial, tuple(self.patch_size), context="input shape") + if self.embedding_type == "deeper": + target_spatial = self._resolve_target_spatial_shape(spatial, view_kind=view_kind) + target_patch_shape = tuple(int(size) // int(patch) for size, patch in zip(target_spatial, self.patch_size)) + if ( + self.deeper_embed_patch_chunk_size is not None + or self.deeper_embed_batch_chunk_size is not None + ): + x = self.down_projection.forward_tiled( + x, + target_patch_shape=target_patch_shape, + patch_chunk_size=self.deeper_embed_patch_chunk_size, + batch_chunk_size=self.deeper_embed_batch_chunk_size, + ) + else: + x = self.down_projection(x) + x = self._crop_embedded_grid(x, target_spatial) + else: + x = self.down_projection(x) + target_spatial = spatial + if self.ndim == 2: + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + else: + x = rearrange(x, 'b c d h w -> b (d h w) c').contiguous() + + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype), x) + + target_patch_shape = tuple(dim // patch for dim, patch in zip(target_spatial, self.patch_size)) + x, rot_pos_embed = 
self._pos_embed(x, *target_spatial) + + return x, rot_pos_embed, target_patch_shape + + def forward_features_list(self, x_list, masks_list, *, view_kind: str = "global"): + if not isinstance(x_list, list): + return self.forward_features(x_list, masks_list, view_kind=view_kind) + output = [] + for x, masks in zip(x_list, masks_list): + x_out = self.forward_features(x, masks, view_kind=view_kind) + output.append(x_out) + return output + + def forward_features(self, x, masks=None, *, view_kind: str = "global"): + x, rot_pos_embed, rope_shape = self.prepare_tokens_with_masks(x, masks, view_kind=view_kind) + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed, rope_shape=rope_shape) + else: + x = blk(x, rope=rot_pos_embed, rope_shape=rope_shape) + x = self.norm(x) + outputs = { + "x_norm_clstoken": x[:, 0] if self.num_class_tokens > 0 else None, + "x_norm_regtokens": x[:, self.num_class_tokens:self.num_prefix_tokens], + "x_norm_patchtokens": x[:, self.num_prefix_tokens:], + "x_prenorm": x, + "masks": masks, + } + return outputs + + def forward(self, x, masks=None, is_training=True, *, view_kind: str = "global"): + return self.forward_features_list(x, masks, view_kind=view_kind) + + def load_pretrained_weights(self, state_dict, backbone_only=False, unchunk=False): + if isinstance(state_dict, str): + state_dict = torch.load(state_dict, map_location="cpu", weights_only=False)['teacher'] + new_state_dict = {} + for k, v in state_dict.items(): + if not k.startswith("backbone."): + continue + new_key = k.replace("backbone.", "", 1) + new_state_dict[new_key] = v + state_dict = new_state_dict + + if unchunk: + state_dict = self.unchunk_state_dict(state_dict) + state_dict = dict(state_dict) + if self.rope_embed is None: + for key in [key for key in state_dict if key.startswith("rope_embed.")]: + state_dict.pop(key) + load_result = self.load_state_dict(state_dict, strict=False) + + missing_keys = 
list(load_result.missing_keys) + unexpected_keys = list(load_result.unexpected_keys) + allowed_missing_keys = set() + + if self.rope_embed is not None and hasattr(self.rope_embed, "reset_mixed_frequencies_to_axial"): + allowed_missing_keys.add("rope_embed.mix_frequencies") + if "rope_embed.mix_frequencies" in missing_keys: + self.rope_embed.reset_mixed_frequencies_to_axial() + + for module_name, module in self.named_modules(): + if not module_name or not hasattr(module, "rope_embed") or module.rope_embed is None: + continue + prefix = f"{module_name}.rope_embed" + if hasattr(module.rope_embed, "reset_mixed_frequencies_to_axial"): + allowed_missing_keys.add(f"{prefix}.mix_frequencies") + if f"{prefix}.mix_frequencies" in missing_keys: + module.rope_embed.reset_mixed_frequencies_to_axial() + allowed_missing_keys.add(f"{prefix}.periods") + + disallowed_missing = [key for key in missing_keys if key not in allowed_missing_keys] + if disallowed_missing or unexpected_keys: + details = [] + if disallowed_missing: + details.append(f"missing keys: {disallowed_missing}") + if unexpected_keys: + details.append(f"unexpected keys: {unexpected_keys}") + raise RuntimeError("failed to load pretrained backbone weights cleanly: " + "; ".join(details)) + + return load_result + + def unchunk_state_dict(self, state_dict): + """ + Convert a state_dict from EvaWithChunking (nested blocks) + to Eva (flat blocks). + """ + if not any([key.startswith("blocks.0.0") for key in state_dict.keys()]): + return state_dict + + new_state_dict = OrderedDict() + for key, val in state_dict.items(): + if key.startswith("blocks."): + parts = key.split(".") + # e.g. 
"blocks.0.1.attn.qkv.weight" + # parts[1] = chunk idx, parts[2] = inner idx + if parts[2].isdigit(): + chunk_idx = int(parts[1]) + inner_idx = int(parts[2]) + # compute new flat index + flat_idx = chunk_idx * 9999 + inner_idx # temporary large stride + # rewrite key + new_key = ".".join(["blocks", str(flat_idx)] + parts[3:]) + new_state_dict[new_key] = val + else: + # already a normal block key (no extra index) + new_state_dict[key] = val + else: + new_state_dict[key] = val + + # Fix flat indices back to consecutive 0..N + # because above we used a stride + mapping = {old: new for new, old in enumerate(sorted(set( + int(k.split(".")[1]) for k in new_state_dict if k.startswith("blocks.") + )))} + final_state_dict = OrderedDict() + for key, val in new_state_dict.items(): + if key.startswith("blocks."): + parts = key.split(".") + parts[1] = str(mapping[int(parts[1])]) + final_state_dict[".".join(parts)] = val + else: + final_state_dict[key] = val + + return final_state_dict + + +class BlockChunk(nn.ModuleList): + def forward(self, x, rope=None, attn_mask=None, rope_shape=None): + for blk in self: + x = blk(x, rope=rope, attn_mask=attn_mask, rope_shape=rope_shape) + return x + + +class EvaWithChunking(Eva): + def __init__(self, *args, block_chunks: int = 1, **kwargs): + super().__init__(*args, **kwargs) + + self.block_chunks = block_chunks + self.chunked_blocks = block_chunks > 0 and block_chunks < len(self.blocks) + + if self.chunked_blocks: + self._apply_block_chunking() + + def _apply_block_chunking(self): + depth = len(self.blocks) + chunksize = depth // self.block_chunks + chunks = [] + for i in range(0, depth, chunksize): + block_chunk = BlockChunk(self.blocks[i: i + chunksize]) + chunks.append(block_chunk) + self.blocks = nn.ModuleList(chunks) + + def forward_features(self, x, masks=None, *, view_kind: str = "global"): + x, rot_pos_embed, rope_shape = self.prepare_tokens_with_masks(x, masks, view_kind=view_kind) + + if self.chunked_blocks: + for chunk in 
self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(chunk, x, rope=rot_pos_embed, rope_shape=rope_shape) + else: + x = chunk(x, rope=rot_pos_embed, rope_shape=rope_shape) + else: + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, rope=rot_pos_embed, rope_shape=rope_shape) + else: + x = blk(x, rope=rot_pos_embed, rope_shape=rope_shape) + + x = self.norm(x) + outputs = { + "x_norm_clstoken": x[:, 0] if self.num_class_tokens > 0 else None, + "x_norm_regtokens": x[:, self.num_class_tokens:self.num_prefix_tokens], + "x_norm_patchtokens": x[:, self.num_prefix_tokens:], + "x_prenorm": x, + "masks": masks, + } + return outputs + + def forward(self, x, masks=None, is_training=True, *, view_kind: str = "global"): + return self.forward_features_list(x, masks, view_kind=view_kind) + + +class Dinov2PrimusEncL(Eva): + def __init__(self, + input_channels, + input_shape): + super().__init__( + input_channels=input_channels, + global_crops_size=96, + local_crops_size=input_shape, + embed_dim=864, + patch_size=(8, 8, 8), + depth=24, + num_heads=16, + mlp_ratio=2.66666666, + attn_drop_rate=0.2, + drop_path_rate=0.2 + ) diff --git a/vesuvius/src/vesuvius/models/build/pretrained_backbones/patch_encode_decode.py b/vesuvius/src/vesuvius/models/build/pretrained_backbones/patch_encode_decode.py new file mode 100644 index 000000000..305852580 --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/pretrained_backbones/patch_encode_decode.py @@ -0,0 +1,988 @@ +from itertools import product +from typing import Literal, Optional, Tuple, List, Union, Type + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd +from torch.nn.modules.dropout import _DropoutNd + +block_type = Literal["basic", "bottleneck"] +block_style = Literal["residual", "conv"] + + +def convert_dim_to_conv_op(dimension: int) -> 
Type[_ConvNd]: + if dimension == 1: + return nn.Conv1d + if dimension == 2: + return nn.Conv2d + if dimension == 3: + return nn.Conv3d + raise ValueError("Unknown dimension. Only 1, 2 and 3 are supported") + + +def maybe_convert_scalar_to_list(conv_op: Type[_ConvNd], scalar): + if not isinstance(scalar, (tuple, list, np.ndarray)): + if conv_op == nn.Conv1d: + return [scalar] + if conv_op == nn.Conv2d: + return [scalar] * 2 + if conv_op == nn.Conv3d: + return [scalar] * 3 + raise RuntimeError(f"Invalid conv op: {conv_op}") + return scalar + + +def get_matching_pool_op(conv_op: Type[_ConvNd], *, pool_type: str = "avg") -> Type[nn.Module]: + if pool_type != "avg": + raise ValueError(f"unsupported pool_type={pool_type!r}") + mapping = { + nn.Conv1d: nn.AvgPool1d, + nn.Conv2d: nn.AvgPool2d, + nn.Conv3d: nn.AvgPool3d, + } + return mapping[conv_op] + + +class LayerNormNd(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + idx = (None, slice(None), *([None] * (x.ndim - 2))) + return self.weight[idx] * x + self.bias[idx] + + +def _is_power_of_two(value: int) -> bool: + return value > 0 and (value & (value - 1)) == 0 + + +class ConvDropoutNormReLU(nn.Module): + def __init__( + self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + padding_mode: str = "zeros", + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: Optional[dict] = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: Optional[dict] = None, + nonlin: Union[None, Type[nn.Module]] 
= None, + nonlin_kwargs: Optional[dict] = None, + nonlin_first: bool = False, + ) -> None: + super().__init__() + self.input_channels = input_channels + self.output_channels = output_channels + stride = maybe_convert_scalar_to_list(conv_op, stride) + self.stride = stride + kernel_size = maybe_convert_scalar_to_list(conv_op, kernel_size) + if norm_op_kwargs is None: + norm_op_kwargs = {} + if dropout_op_kwargs is None: + dropout_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + ops = [] + self.conv = conv_op( + input_channels, + output_channels, + kernel_size, + stride, + padding=[(k - 1) // 2 for k in kernel_size], + dilation=1, + padding_mode=padding_mode, + bias=conv_bias, + ) + ops.append(self.conv) + if dropout_op is not None: + self.dropout = dropout_op(**dropout_op_kwargs) + ops.append(self.dropout) + if norm_op is not None: + self.norm = norm_op(output_channels, **norm_op_kwargs) + ops.append(self.norm) + if nonlin is not None: + self.nonlin = nonlin(**nonlin_kwargs) + ops.append(self.nonlin) + if nonlin_first and norm_op is not None and nonlin is not None: + ops[-1], ops[-2] = ops[-2], ops[-1] + self.all_modules = nn.Sequential(*ops) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.all_modules(x) + + +class StackedConvBlocks(nn.Module): + def __init__( + self, + num_convs: int, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: Union[int, List[int], Tuple[int, ...]], + kernel_size: Union[int, List[int], Tuple[int, ...]], + initial_stride: Union[int, List[int], Tuple[int, ...]], + padding_mode: str = "zeros", + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: Optional[dict] = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: Optional[dict] = None, + nonlin: Union[None, Type[nn.Module]] = None, + nonlin_kwargs: Optional[dict] = None, + nonlin_first: bool = False, + ) -> None: + super().__init__() + if not isinstance(output_channels, 
(tuple, list)): + output_channels = [output_channels] * num_convs + self.convs = nn.Sequential( + ConvDropoutNormReLU( + conv_op, + input_channels, + output_channels[0], + kernel_size, + initial_stride, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + nonlin_first, + ), + *[ + ConvDropoutNormReLU( + conv_op, + output_channels[i - 1], + output_channels[i], + kernel_size, + 1, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + nonlin_first, + ) + for i in range(1, num_convs) + ], + ) + self.output_channels = output_channels[-1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.convs(x) + + +class BasicBlockD(nn.Module): + def __init__( + self, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + padding_mode: str = "zeros", + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: Optional[dict] = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: Optional[dict] = None, + nonlin: Union[None, Type[nn.Module]] = None, + nonlin_kwargs: Optional[dict] = None, + ) -> None: + super().__init__() + stride = maybe_convert_scalar_to_list(conv_op, stride) + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + + self.conv1 = ConvDropoutNormReLU( + conv_op, + input_channels, + output_channels, + kernel_size, + stride, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + ) + self.conv2 = ConvDropoutNormReLU( + conv_op, + output_channels, + output_channels, + kernel_size, + 1, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + None, + None, + None, + None, + ) + self.nonlin2 = nonlin(**nonlin_kwargs) if nonlin 
is not None else nn.Identity() + + has_stride = any(step != 1 for step in stride) + requires_projection = input_channels != output_channels + if has_stride or requires_projection: + ops = [] + if has_stride: + pool_op = get_matching_pool_op(conv_op, pool_type="avg") + ops.append(pool_op(stride, stride)) + if requires_projection: + ops.append( + ConvDropoutNormReLU( + conv_op, + input_channels, + output_channels, + 1, + 1, + "zeros", + False, + norm_op, + norm_op_kwargs, + None, + None, + None, + None, + ) + ) + self.skip = nn.Sequential(*ops) + else: + self.skip = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = self.skip(x) + out = self.conv2(self.conv1(x)) + out = out + residual + return self.nonlin2(out) + + +class BottleneckD(nn.Module): + def __init__( + self, + conv_op: Type[_ConvNd], + input_channels: int, + bottleneck_channels: int, + output_channels: int, + kernel_size: Union[int, List[int], Tuple[int, ...]], + stride: Union[int, List[int], Tuple[int, ...]], + padding_mode: str = "zeros", + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: Optional[dict] = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: Optional[dict] = None, + nonlin: Union[None, Type[nn.Module]] = None, + nonlin_kwargs: Optional[dict] = None, + ) -> None: + super().__init__() + stride = maybe_convert_scalar_to_list(conv_op, stride) + if norm_op_kwargs is None: + norm_op_kwargs = {} + if nonlin_kwargs is None: + nonlin_kwargs = {} + self.conv1 = ConvDropoutNormReLU( + conv_op, + input_channels, + bottleneck_channels, + 1, + 1, + "zeros", + conv_bias, + norm_op, + norm_op_kwargs, + None, + None, + nonlin, + nonlin_kwargs, + ) + self.conv2 = ConvDropoutNormReLU( + conv_op, + bottleneck_channels, + bottleneck_channels, + kernel_size, + stride, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + ) + self.conv3 = 
ConvDropoutNormReLU( + conv_op, + bottleneck_channels, + output_channels, + 1, + 1, + "zeros", + conv_bias, + norm_op, + norm_op_kwargs, + None, + None, + None, + None, + ) + self.nonlin3 = nonlin(**nonlin_kwargs) if nonlin is not None else nn.Identity() + + has_stride = any(step != 1 for step in stride) + requires_projection = input_channels != output_channels + if has_stride or requires_projection: + ops = [] + if has_stride: + pool_op = get_matching_pool_op(conv_op, pool_type="avg") + ops.append(pool_op(stride, stride)) + if requires_projection: + ops.append( + ConvDropoutNormReLU( + conv_op, + input_channels, + output_channels, + 1, + 1, + "zeros", + False, + norm_op, + norm_op_kwargs, + None, + None, + None, + None, + ) + ) + self.skip = nn.Sequential(*ops) + else: + self.skip = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = self.skip(x) + out = self.conv3(self.conv2(self.conv1(x))) + out = out + residual + return self.nonlin3(out) + + +class StackedResidualBlocks(nn.Module): + def __init__( + self, + n_blocks: int, + conv_op: Type[_ConvNd], + input_channels: int, + output_channels: Union[int, List[int], Tuple[int, ...]], + kernel_size: Union[int, List[int], Tuple[int, ...]], + initial_stride: Union[int, List[int], Tuple[int, ...]], + padding_mode: str = "zeros", + conv_bias: bool = False, + norm_op: Union[None, Type[nn.Module]] = None, + norm_op_kwargs: Optional[dict] = None, + dropout_op: Union[None, Type[_DropoutNd]] = None, + dropout_op_kwargs: Optional[dict] = None, + nonlin: Union[None, Type[nn.Module]] = None, + nonlin_kwargs: Optional[dict] = None, + block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD, + bottleneck_channels: Union[int, List[int], Tuple[int, ...], None] = None, + ) -> None: + super().__init__() + assert n_blocks > 0 + assert block in [BasicBlockD, BottleneckD] + if not isinstance(output_channels, (tuple, list)): + output_channels = [output_channels] * n_blocks + if not 
isinstance(bottleneck_channels, (tuple, list)): + bottleneck_channels = [bottleneck_channels] * n_blocks + + if block == BasicBlockD: + blocks = nn.Sequential( + block( + conv_op, + input_channels, + output_channels[0], + kernel_size, + initial_stride, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + ), + *[ + block( + conv_op, + output_channels[n - 1], + output_channels[n], + kernel_size, + 1, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + ) + for n in range(1, n_blocks) + ], + ) + else: + blocks = nn.Sequential( + block( + conv_op, + input_channels, + bottleneck_channels[0], + output_channels[0], + kernel_size, + initial_stride, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + ), + *[ + block( + conv_op, + output_channels[n - 1], + bottleneck_channels[n], + output_channels[n], + kernel_size, + 1, + padding_mode, + conv_bias, + norm_op, + norm_op_kwargs, + dropout_op, + dropout_op_kwargs, + nonlin, + nonlin_kwargs, + ) + for n in range(1, n_blocks) + ], + ) + self.blocks = blocks + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.blocks(x) + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + Loosely inspired by https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L364 + + """ + + def __init__( + self, + patch_size: Tuple[int, ...] = (16, 16, 16), + input_channels: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + patch_size (Tuple): patch size. + padding (Tuple): padding size of the projection layer. + input_channels (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. 
+ """ + super().__init__() + + self.proj = convert_dim_to_conv_op(len(patch_size))( + input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0 + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Returns a patch grid of shape (B, embed_dim, D, H, W) for 3D or + (B, embed_dim, H, W) for 2D, where (D, H, W) = input_shape / patch_size. + This output will need to be rearranged to whatever your transformer expects. + """ + x = self.proj(x) + return x + + +class PatchEmbedDeeper(nn.Module): + """Dynamic-network-architectures-style patch embed with patch-size-derived stage strides.""" + + def __init__( + self, + patch_size: Tuple[int, ...] = (16, 16, 16), + input_channels: int = 3, + embed_dim: int = 768, + base_features: int = 32, + depth_per_level: Optional[Tuple[int, ...]] = None, + embed_proj_3x3x3: bool = False, + embed_block_type: block_type = "basic", + embed_block_style: block_style = "residual", + ) -> None: + super().__init__() + self.patch_size = tuple(int(p) for p in patch_size) + self.ndim = len(self.patch_size) + if self.ndim not in (2, 3): + raise ValueError(f"PatchEmbedDeeper only supports 2D or 3D, got ndim={self.ndim}") + if any(not _is_power_of_two(size) for size in self.patch_size): + raise ValueError( + f"PatchEmbedDeeper requires power-of-two patch sizes, got patch_size={self.patch_size}" + ) + + max_patch = max(self.patch_size) + num_stages = int(np.log2(max_patch)) if max_patch > 1 else 0 + if depth_per_level is None: + depth_per_level = (1,) * num_stages + if len(depth_per_level) != num_stages: + raise ValueError( + f"depth_per_level must match the number of downsampling stages ({num_stages}), got {depth_per_level}" + ) + + conv_op = convert_dim_to_conv_op(self.ndim) + norm_op = LayerNormNd + kernel_size = [3] * self.ndim + block = BottleneckD if embed_block_type == "bottleneck" else BasicBlockD + nonlin = nn.LeakyReLU if embed_block_type == "bottleneck" else nn.ReLU + norm_op_kwargs = {"eps": 1e-5} + 
nonlin_kwargs = {"inplace": True} + + if embed_block_style == "residual": + self.stem = StackedResidualBlocks( + 1, + conv_op, + input_channels, + base_features, + kernel_size, + 1, + padding_mode="reflect", + conv_bias=True, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + block=block, + ) + elif embed_block_style == "conv": + self.stem = StackedConvBlocks( + 1, + conv_op, + input_channels, + base_features, + kernel_size, + 1, + padding_mode="reflect", + conv_bias=True, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + ) + else: + raise ValueError( + f"Unknown embed_block_style: {embed_block_style}. Must be 'residual' or 'conv'." + ) + + stage_strides = [ + tuple(2 if (patch / (2 ** stage_index)) % 2 == 0 else 1 for patch in self.patch_size) + for stage_index in range(num_stages) + ] + + self.stages = nn.ModuleList() + stage_in_channels = base_features + for stage_index, (depth, stride) in enumerate(zip(depth_per_level, stage_strides)): + stage_out_channels = base_features * (2 ** stage_index) + bottleneck_channels = stage_out_channels // 4 if embed_block_type == "bottleneck" else None + if embed_block_style == "residual": + stage = StackedResidualBlocks( + n_blocks=depth, + conv_op=conv_op, + input_channels=stage_in_channels, + output_channels=stage_out_channels, + kernel_size=kernel_size, + initial_stride=stride, + padding_mode="reflect", + conv_bias=False, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + block=block, + bottleneck_channels=bottleneck_channels, + ) + else: + stage = StackedConvBlocks( + num_convs=depth, + conv_op=conv_op, + input_channels=stage_in_channels, + output_channels=stage_out_channels, + kernel_size=kernel_size, + initial_stride=stride, + 
padding_mode="reflect", + conv_bias=False, + norm_op=norm_op, + norm_op_kwargs=norm_op_kwargs, + dropout_op=None, + dropout_op_kwargs=None, + nonlin=nonlin, + nonlin_kwargs=nonlin_kwargs, + ) + self.stages.append(stage) + stage_in_channels = stage_out_channels + + final_proj_kernel = [3] * self.ndim if embed_proj_3x3x3 else [1] * self.ndim + final_pad = [1] * self.ndim if embed_proj_3x3x3 else [0] * self.ndim + self.final_proj = conv_op( + stage_in_channels, + embed_dim, + kernel_size=final_proj_kernel, + stride=[1] * self.ndim, + padding=final_pad, + padding_mode="reflect", + ) + self._patch_halo: tuple[int, ...] | None = None + + def _normalize_patch_chunk_size( + self, + patch_chunk_size: Optional[Union[int, tuple[int, ...], list[int]]], + *, + full_patch_shape: tuple[int, ...], + ) -> tuple[int, ...]: + if patch_chunk_size is None: + return full_patch_shape + if isinstance(patch_chunk_size, int): + if patch_chunk_size <= 0: + raise ValueError(f"patch_chunk_size must be positive, got {patch_chunk_size}") + normalized = (int(patch_chunk_size),) * self.ndim + else: + normalized = tuple(int(v) for v in patch_chunk_size) + if len(normalized) != self.ndim: + raise ValueError( + f"patch_chunk_size must provide {self.ndim} values, got {normalized}" + ) + if any(v <= 0 for v in normalized): + raise ValueError(f"patch_chunk_size must be positive, got {normalized}") + return tuple(min(chunk, full) for chunk, full in zip(normalized, full_patch_shape)) + + def _avg_pool_support( + self, + support: torch.Tensor, + *, + kernel_size: Union[int, tuple[int, ...], list[int]], + stride: Union[int, tuple[int, ...], list[int]], + padding: Union[int, tuple[int, ...], list[int]], + ) -> torch.Tensor: + pool = F.avg_pool2d if self.ndim == 2 else F.avg_pool3d + return pool( + support, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=False, + count_include_pad=True, + ) + + def _support_from_conv(self, support: torch.Tensor, conv: _ConvNd) -> torch.Tensor: + 
return self._avg_pool_support( + support, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + ) + + def _support_from_module(self, module: nn.Module, support: torch.Tensor) -> torch.Tensor: + if isinstance(module, ConvDropoutNormReLU): + return self._support_from_conv(support, module.conv) + if isinstance(module, StackedConvBlocks): + for block in module.convs: + support = self._support_from_module(block, support) + return support + if isinstance(module, BasicBlockD): + residual = self._support_from_module(module.skip, support) + main = self._support_from_module(module.conv1, support) + main = self._support_from_module(module.conv2, main) + return 0.5 * (main + residual) + if isinstance(module, BottleneckD): + residual = self._support_from_module(module.skip, support) + main = self._support_from_module(module.conv1, support) + main = self._support_from_module(module.conv2, main) + main = self._support_from_module(module.conv3, main) + return 0.5 * (main + residual) + if isinstance(module, StackedResidualBlocks): + for block in module.blocks: + support = self._support_from_module(block, support) + return support + if isinstance(module, (nn.AvgPool2d, nn.AvgPool3d)): + return self._avg_pool_support( + support, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + ) + if isinstance(module, (nn.Identity, LayerNormNd, nn.ReLU, nn.LeakyReLU)): + return support + if isinstance(module, nn.Sequential): + for submodule in module: + support = self._support_from_module(submodule, support) + return support + raise TypeError(f"Unsupported support propagation module: {type(module).__name__}") + + @torch.no_grad() + def patch_halo(self) -> tuple[int, ...]: + if self._patch_halo is None: + probe_patch_shape = tuple(max(13, 2 * (len(self.stages) + 2) + 1) for _ in range(self.ndim)) + probe_spatial_shape = tuple( + patch_count * patch_size for patch_count, patch_size in zip(probe_patch_shape, self.patch_size) + ) + support = 
torch.ones((1, 1, *probe_spatial_shape), dtype=torch.float32) + support = self._support_from_module(self.stem, support) + for stage in self.stages: + support = self._support_from_module(stage, support) + support = self._support_from_conv(support, self.final_proj) + support = support / support.amax().clamp(min=torch.finfo(support.dtype).eps) + full_support = support[0, 0] >= (1.0 - 1e-6) + indices = torch.nonzero(full_support, as_tuple=False) + if indices.numel() == 0: + raise RuntimeError("PatchEmbedDeeper could not determine a full-support interior patch region.") + + halo: list[int] = [] + for axis, axis_size in enumerate(full_support.shape): + axis_min = int(indices[:, axis].min().item()) + axis_max = int(indices[:, axis].max().item()) + halo.append(min(axis_min, axis_size - axis_max - 1)) + self._patch_halo = tuple(halo) + return self._patch_halo + + def forward_tiled( + self, + x: torch.Tensor, + *, + target_patch_shape: Optional[Union[int, tuple[int, ...], list[int]]] = None, + patch_chunk_size: Optional[Union[int, tuple[int, ...], list[int]]] = None, + batch_chunk_size: Optional[int] = None, + ) -> torch.Tensor: + full_patch_shape = tuple(int(dim) // int(patch) for dim, patch in zip(x.shape[2:], self.patch_size)) + if any(int(dim) % int(patch) != 0 for dim, patch in zip(x.shape[2:], self.patch_size)): + raise ValueError( + f"PatchEmbedDeeper input shape must be divisible by patch_size, got shape={tuple(x.shape[2:])} " + f"and patch_size={self.patch_size}" + ) + + if batch_chunk_size is None: + batch_chunk_size = x.shape[0] + batch_chunk_size = int(batch_chunk_size) + if batch_chunk_size <= 0: + raise ValueError(f"batch_chunk_size must be positive, got {batch_chunk_size}") + + if target_patch_shape is None: + target_patch_shape_normalized = full_patch_shape + elif isinstance(target_patch_shape, int): + if target_patch_shape <= 0: + raise ValueError(f"target_patch_shape must be positive, got {target_patch_shape}") + target_patch_shape_normalized = 
(int(target_patch_shape),) * self.ndim + else: + target_patch_shape_normalized = tuple(int(v) for v in target_patch_shape) + if len(target_patch_shape_normalized) != self.ndim: + raise ValueError( + f"target_patch_shape must provide {self.ndim} values, got {target_patch_shape_normalized}" + ) + if any(v <= 0 for v in target_patch_shape_normalized): + raise ValueError( + f"target_patch_shape must be positive, got {target_patch_shape_normalized}" + ) + if any(target > full for target, full in zip(target_patch_shape_normalized, full_patch_shape)): + raise ValueError( + f"target_patch_shape={target_patch_shape_normalized} cannot exceed full patch shape {full_patch_shape}" + ) + + patch_chunk_shape = self._normalize_patch_chunk_size( + patch_chunk_size, + full_patch_shape=target_patch_shape_normalized, + ) + if ( + batch_chunk_size >= x.shape[0] + and patch_chunk_shape == target_patch_shape_normalized + and target_patch_shape_normalized == full_patch_shape + ): + return self.forward(x) + + crop_starts = [] + for full, target in zip(full_patch_shape, target_patch_shape_normalized): + delta = full - target + if delta < 0 or delta % 2 != 0: + raise ValueError( + f"cannot center-crop patch grid from {full_patch_shape} to {target_patch_shape_normalized}" + ) + crop_starts.append(delta // 2) + + crop_starts = tuple(crop_starts) + halo = self.patch_halo() + if patch_chunk_size is not None and target_patch_shape_normalized == full_patch_shape: + raise ValueError( + "PatchEmbedDeeper spatial chunking requires a smaller centered target_patch_shape; " + "set patch_chunk_size=None to use batch-only chunking." 
+ ) + if patch_chunk_size is not None: + required_halo = tuple(min(start, full - (start + target)) for start, target, full in zip( + crop_starts, + target_patch_shape_normalized, + full_patch_shape, + )) + if any(available < needed for available, needed in zip(required_halo, halo)): + raise ValueError( + "PatchEmbedDeeper spatial chunking requires centered overscan at least as large as patch_halo; " + f"got available={required_halo} and required={halo}" + ) + + if patch_chunk_size is None: + outputs = [] + crop_stops = tuple(start + size for start, size in zip(crop_starts, target_patch_shape_normalized)) + output_slices = tuple(slice(start, stop) for start, stop in zip(crop_starts, crop_stops)) + for batch_start in range(0, x.shape[0], batch_chunk_size): + batch_end = min(batch_start + batch_chunk_size, x.shape[0]) + batch_y = self.forward(x[batch_start:batch_end]) + outputs.append(batch_y[(slice(None), slice(None), *output_slices)]) + return torch.cat(outputs, dim=0) + + outputs: list[torch.Tensor] = [] + tile_starts = [range(0, size, chunk) for size, chunk in zip(target_patch_shape_normalized, patch_chunk_shape)] + for batch_start in range(0, x.shape[0], batch_chunk_size): + batch_end = min(batch_start + batch_chunk_size, x.shape[0]) + batch_x = x[batch_start:batch_end] + batch_output: torch.Tensor | None = None + + for starts in product(*tile_starts): + tile_shape = tuple( + min(chunk, full - start) + for start, chunk, full in zip(starts, patch_chunk_shape, target_patch_shape_normalized) + ) + target_starts = tuple(crop_start + start for crop_start, start in zip(crop_starts, starts)) + context_starts = tuple(target_start - axis_halo for target_start, axis_halo in zip(target_starts, halo)) + context_stops = tuple( + target_start + size + axis_halo + for target_start, size, axis_halo in zip(target_starts, tile_shape, halo) + ) + input_slices = tuple( + slice(start * patch, stop * patch) + for start, stop, patch in zip(context_starts, context_stops, 
self.patch_size) + ) + tile_x = batch_x[(slice(None), slice(None), *input_slices)] + tile_y = self.forward(tile_x) + + output_slices = tuple( + slice(target_start - context_start, target_start - context_start + size) + for target_start, context_start, size in zip(target_starts, context_starts, tile_shape) + ) + tile_output = tile_y[(slice(None), slice(None), *output_slices)] + if batch_output is None: + batch_output = torch.empty( + (batch_x.shape[0], tile_output.shape[1], *target_patch_shape_normalized), + device=tile_output.device, + dtype=tile_output.dtype, + ) + destination_slices = tuple(slice(start, start + size) for start, size in zip(starts, tile_shape)) + batch_output[(slice(None), slice(None), *destination_slices)] = tile_output + + if batch_output is None: + raise RuntimeError("PatchEmbedDeeper.forward_tiled produced no output tiles.") + outputs.append(batch_output) + + return torch.cat(outputs, dim=0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.stem(x) + for stage in self.stages: + x = stage(x) + x = self.final_proj(x) + return x + + +class PatchDecode(nn.Module): + """ + Loosely inspired by SAM decoder + https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py#L53 + + Supports both 2D and 3D inputs based on patch_size dimensionality. + """ + + def __init__( + self, + patch_size, + embed_dim: int, + out_channels: int, + norm=LayerNormNd, + activation=nn.GELU, + ): + """ + patch size must be 2^x, so 2, 4, 8, 16, 32, etc. 
Otherwise we die + + Args: + patch_size: Tuple of (H, W) for 2D or (D, H, W) for 3D + """ + super().__init__() + + # Determine dimensionality from patch_size + self.ndim = len(patch_size) + conv_transpose_op = nn.ConvTranspose2d if self.ndim == 2 else nn.ConvTranspose3d + + def _round_to_8(inp): + return int(max(8, np.round((inp + 1e-6) / 8) * 8)) + + num_stages = int(np.log(max(patch_size)) / np.log(2)) + strides = [[2 if (p / 2**n) % 2 == 0 else 1 for p in patch_size] for n in range(num_stages)][::-1] + dim_red = (embed_dim / (2 * out_channels)) ** (1 / num_stages) + + # don't question me + channels = [embed_dim] + [_round_to_8(embed_dim / dim_red ** (x + 1)) for x in range(num_stages)] + channels[-1] = out_channels + + stages = [] + for s in range(num_stages - 1): + stages.append( + nn.Sequential( + conv_transpose_op(channels[s], channels[s + 1], kernel_size=strides[s], stride=strides[s]), + norm(channels[s + 1]), + activation(), + ) + ) + stages.append(conv_transpose_op(channels[-2], channels[-1], kernel_size=strides[-1], stride=strides[-1])) + self.decode = nn.Sequential(*stages) + + def forward(self, x): + """ + Expects input of shape (B, embed_dim, D, H, W) for 3D or (B, embed_dim, H, W) for 2D. + This will require you to reshape the output of your transformer. 
+ """ + return self.decode(x) diff --git a/vesuvius/src/vesuvius/models/build/pretrained_backbones/rope.py b/vesuvius/src/vesuvius/models/build/pretrained_backbones/rope.py new file mode 100644 index 000000000..31371478b --- /dev/null +++ b/vesuvius/src/vesuvius/models/build/pretrained_backbones/rope.py @@ -0,0 +1,331 @@ +import math +from typing import Literal, Sequence + +import torch +from torch import Tensor, nn + + +RopeEmbedding = tuple[Tensor, Tensor] + + +def rope_rotate_half(x: Tensor) -> Tensor: + x_first, x_second = x.chunk(2, dim=-1) + return torch.cat((-x_second, x_first), dim=-1) + + +def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: + return (x * cos) + (rope_rotate_half(x) * sin) + + +def apply_rotary_embedding(x: Tensor, rope: RopeEmbedding, *, prefix_tokens: int = 0) -> Tensor: + sin, cos = rope + if prefix_tokens < 0: + raise ValueError(f"prefix_tokens must be non-negative, got {prefix_tokens}") + if x.shape[-2] - prefix_tokens != sin.shape[-2]: + raise ValueError( + "rope token count does not match the rotated suffix: " + f"x.shape[-2]={x.shape[-2]}, prefix_tokens={prefix_tokens}, rope_tokens={sin.shape[-2]}" + ) + + x_dtype = x.dtype + rope_dtype = sin.dtype + x_prefix = x[..., :prefix_tokens, :] + x_suffix = x[..., prefix_tokens:, :].to(dtype=rope_dtype) + x_suffix = rope_apply(x_suffix, sin, cos).to(dtype=x_dtype) + if prefix_tokens == 0: + return x_suffix + return torch.cat((x_prefix, x_suffix), dim=-2) + + +class _BaseRopePositionEmbedding(nn.Module): + def __init__( + self, + head_dim: int, + *, + ndim: int, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = 2.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> None: + super().__init__() + if ndim not in (2, 3): + raise 
ValueError(f"RopePositionEmbedding only supports 2D or 3D inputs, got ndim={ndim}") + + both_periods = min_period is not None and max_period is not None + if (base is None and not both_periods) or (base is not None and both_periods): + raise ValueError("Either `base` or `min_period`+`max_period` must be provided.") + + self.head_dim = int(head_dim) + self.ndim = int(ndim) + self.base = base + self.min_period = min_period + self.max_period = max_period + self.normalize_coords = normalize_coords + self.shift_coords = shift_coords + self.jitter_coords = jitter_coords + self.rescale_coords = rescale_coords + self.dtype = dtype or torch.float32 + + def _normalized_axes(self, shape: tuple[int, ...]) -> list[Tensor]: + device = self.periods.device + dtype = self.periods.dtype + if self.normalize_coords == "max": + denominator = max(shape) + axes = [torch.arange(0.5, size, device=device, dtype=dtype) / denominator for size in shape] + elif self.normalize_coords == "min": + denominator = min(shape) + axes = [torch.arange(0.5, size, device=device, dtype=dtype) / denominator for size in shape] + elif self.normalize_coords == "separate": + axes = [torch.arange(0.5, size, device=device, dtype=dtype) / size for size in shape] + else: + raise ValueError(f"Unknown normalize_coords={self.normalize_coords!r}") + return axes + + def _apply_coord_augmentations(self, coords: Tensor) -> Tensor: + if self.training and self.shift_coords is not None: + shift = torch.empty(self.ndim, device=coords.device, dtype=coords.dtype).uniform_( + -self.shift_coords, self.shift_coords + ) + coords = coords + shift[None, :] + + if self.training and self.jitter_coords is not None: + jitter_log = math.log(self.jitter_coords) + jitter = torch.empty(self.ndim, device=coords.device, dtype=coords.dtype).uniform_( + -jitter_log, jitter_log + ).exp() + coords = coords * jitter[None, :] + + if self.training and self.rescale_coords is not None: + rescale_log = math.log(self.rescale_coords) + rescale = 
torch.empty(1, device=coords.device, dtype=coords.dtype).uniform_( + -rescale_log, rescale_log + ).exp() + coords = coords * rescale + + return coords + + def _build_periods(self, count: int, denominator_dim: int, *, device: torch.device | None = None) -> Tensor: + device = device or torch.device("cpu") + if self.base is not None: + return self.base ** ( + 2 * torch.arange(count, device=device, dtype=self.dtype) / denominator_dim + ) + + assert self.min_period is not None + assert self.max_period is not None + base = self.max_period / self.min_period + exponents = torch.linspace(0, 1, count, device=device, dtype=self.dtype) + periods = base**exponents + periods = periods / base + periods = periods * self.max_period + return periods + + def _get_coords(self, shape: Sequence[int]) -> Tensor: + spatial_shape = tuple(int(size) for size in shape) + if len(spatial_shape) != self.ndim: + raise ValueError(f"expected a {self.ndim}D shape and got {spatial_shape}") + if any(size <= 0 for size in spatial_shape): + raise ValueError(f"shape dimensions must be positive, got {spatial_shape}") + + axes = self._normalized_axes(spatial_shape) + coords = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, self.ndim) + coords = (2.0 * coords) - 1.0 + coords = self._apply_coord_augmentations(coords) + return coords + + +class RopePositionEmbedding(_BaseRopePositionEmbedding): + def __init__( + self, + head_dim: int, + *, + ndim: int, + num_heads: int | None = None, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = 2.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> None: + if head_dim % (2 * ndim) != 0: + raise ValueError( + f"head_dim must be divisible by 2 * ndim for RoPE, got head_dim={head_dim}, ndim={ndim}" + 
) + super().__init__( + head_dim, + ndim=ndim, + base=base, + min_period=min_period, + max_period=max_period, + normalize_coords=normalize_coords, + shift_coords=shift_coords, + jitter_coords=jitter_coords, + rescale_coords=rescale_coords, + dtype=dtype, + device=device, + ) + + self.freqs_per_axis = self.head_dim // (2 * self.ndim) + self.axis_dim = self.head_dim // self.ndim + self.register_buffer( + "periods", + torch.empty(self.freqs_per_axis, device=device, dtype=self.dtype), + persistent=True, + ) + self._init_weights() + + def _init_weights(self) -> None: + periods = self._build_periods( + self.freqs_per_axis, + self.axis_dim, + device=self.periods.device, + ) + self.periods.data.copy_(periods) + + def get_embed(self, shape: Sequence[int]) -> RopeEmbedding: + coords = self._get_coords(shape) + angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] + angles = angles.flatten(1, 2).tile(2) + cos = torch.cos(angles) + sin = torch.sin(angles) + return sin, cos + + def forward(self, shape: Sequence[int]) -> RopeEmbedding: + return self.get_embed(shape) + + +class MixedRopePositionEmbedding(_BaseRopePositionEmbedding): + def __init__( + self, + head_dim: int, + *, + ndim: int, + num_heads: int, + base: float | None = 100.0, + min_period: float | None = None, + max_period: float | None = None, + normalize_coords: Literal["min", "max", "separate"] = "separate", + shift_coords: float | None = None, + jitter_coords: float | None = None, + rescale_coords: float | None = 2.0, + dtype: torch.dtype | None = None, + device: torch.device | None = None, + ) -> None: + if head_dim % (2 * ndim) != 0: + raise ValueError( + f"mixed RoPE warm-start requires head_dim divisible by 2 * ndim, got head_dim={head_dim}, ndim={ndim}" + ) + if num_heads <= 0: + raise ValueError(f"num_heads must be positive for mixed RoPE, got {num_heads}") + super().__init__( + head_dim, + ndim=ndim, + base=base, + min_period=min_period, + max_period=max_period, + 
normalize_coords=normalize_coords, + shift_coords=shift_coords, + jitter_coords=jitter_coords, + rescale_coords=rescale_coords, + dtype=dtype, + device=device, + ) + + self.num_heads = int(num_heads) + self.freqs_per_axis = self.head_dim // (2 * self.ndim) + self.axis_dim = self.head_dim // self.ndim + self.num_pairs = self.head_dim // 2 + self.register_buffer( + "periods", + torch.empty(self.freqs_per_axis, device=device, dtype=self.dtype), + persistent=True, + ) + self.mix_frequencies = nn.Parameter( + torch.empty(self.num_heads, self.num_pairs, self.ndim, device=device, dtype=self.dtype) + ) + self._init_weights() + + def _init_weights(self) -> None: + periods = self._build_periods( + self.freqs_per_axis, + self.axis_dim, + device=self.periods.device, + ) + self.periods.data.copy_(periods) + self.reset_mixed_frequencies_to_random_oriented() + + @torch.no_grad() + def reset_mixed_frequencies_to_axial(self) -> None: + mix_frequencies = torch.zeros_like(self.mix_frequencies) + inv_periods = self.periods.reciprocal() + for axis in range(self.ndim): + start = axis * self.freqs_per_axis + end = start + self.freqs_per_axis + mix_frequencies[:, start:end, axis] = inv_periods + self.mix_frequencies.copy_(mix_frequencies) + + @torch.no_grad() + def reset_mixed_frequencies_to_random_oriented(self) -> None: + mix_frequencies = torch.empty_like(self.mix_frequencies) + inv_periods = self.periods.reciprocal() + + if self.ndim == 2: + angles = torch.empty(self.num_heads, device=self.periods.device, dtype=self.periods.dtype).uniform_( + 0.0, 2.0 * math.pi + ) + cos = torch.cos(angles) + sin = torch.sin(angles) + basis = torch.stack( + ( + torch.stack((cos, sin), dim=-1), + torch.stack((-sin, cos), dim=-1), + ), + dim=-1, + ) + else: + basis = torch.randn( + self.num_heads, + self.ndim, + self.ndim, + device=self.periods.device, + dtype=self.periods.dtype, + ) + basis, r = torch.linalg.qr(basis) + diag = torch.diagonal(r, dim1=-2, dim2=-1) + signs = torch.where(diag < 0, 
-torch.ones_like(diag), torch.ones_like(diag)) + basis = basis * signs.unsqueeze(-2) + negative_det = torch.linalg.det(basis) < 0 + if negative_det.any(): + basis[negative_det, :, 0] *= -1 + + for axis in range(self.ndim): + start = axis * self.freqs_per_axis + end = start + self.freqs_per_axis + mix_frequencies[:, start:end, :] = inv_periods[None, :, None] * basis[:, None, :, axis] + + self.mix_frequencies.copy_(mix_frequencies) + + @torch.jit.ignore + def no_weight_decay(self) -> set[str]: + return {"mix_frequencies"} + + def get_embed(self, shape: Sequence[int]) -> RopeEmbedding: + coords = self._get_coords(shape) + angles = 2 * math.pi * torch.einsum("td,hpd->htp", coords, self.mix_frequencies) + angles = angles.tile(2) + cos = torch.cos(angles) + sin = torch.sin(angles) + return sin, cos + + def forward(self, shape: Sequence[int]) -> RopeEmbedding: + return self.get_embed(shape) diff --git a/vesuvius/src/vesuvius/models/build/transformers/patch_encode_decode.py b/vesuvius/src/vesuvius/models/build/transformers/patch_encode_decode.py index 41641cf61..90c70784e 100644 --- a/vesuvius/src/vesuvius/models/build/transformers/patch_encode_decode.py +++ b/vesuvius/src/vesuvius/models/build/transformers/patch_encode_decode.py @@ -66,6 +66,58 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + +def _compute_patch_decode_plan( + patch_size, + *, + embed_dim: int, + out_channels: int, +): + patch_size = tuple(int(p) for p in patch_size) + if len(patch_size) not in (2, 3): + raise ValueError(f"Patch decoding only supports 2D or 3D patch sizes, got {patch_size}") + if min(patch_size) < 1: + raise ValueError(f"patch_size entries must be >= 1, got {patch_size}") + + def _round_to_8(inp): + return int(max(8, np.round((inp + 1e-6) / 8) * 8)) + + num_stages = int(np.log(max(patch_size)) / np.log(2)) + strides = [[2 if (p / 2**n) % 2 == 0 else 1 for p in patch_size] for n in range(num_stages)][::-1] + dim_red = (embed_dim / (2 * out_channels)) ** (1 / num_stages) + 
channels = [embed_dim] + [_round_to_8(embed_dim / dim_red ** (x + 1)) for x in range(num_stages)] + channels[-1] = out_channels + return num_stages, strides, channels + + +class PixelShuffle3D(nn.Module): + def __init__(self, upscale_factor): + super().__init__() + if isinstance(upscale_factor, int): + upscale_factor = (upscale_factor, upscale_factor, upscale_factor) + self.upscale_factor = tuple(int(v) for v in upscale_factor) + if len(self.upscale_factor) != 3: + raise ValueError(f"PixelShuffle3D expects 3 scale factors, got {self.upscale_factor}") + if any(v < 1 for v in self.upscale_factor): + raise ValueError(f"PixelShuffle3D scale factors must be >= 1, got {self.upscale_factor}") + self.scale_volume = int(np.prod(self.upscale_factor)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.ndim != 5: + raise ValueError(f"PixelShuffle3D expects [B, C, D, H, W], got {tuple(x.shape)}") + + b, c, d, h, w = x.shape + rz, ry, rx = self.upscale_factor + if c % self.scale_volume != 0: + raise ValueError( + f"PixelShuffle3D input channels {c} must be divisible by product of scale factors {self.scale_volume}" + ) + + out_channels = c // self.scale_volume + x = x.reshape(b, out_channels, rz, ry, rx, d, h, w) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + return x.reshape(b, out_channels, d * rz, h * ry, w * rx) + + class PatchDecode(nn.Module): """ Loosely inspired by SAM decoder @@ -93,17 +145,11 @@ def __init__( # Determine dimensionality from patch_size self.ndim = len(patch_size) conv_transpose_op = nn.ConvTranspose2d if self.ndim == 2 else nn.ConvTranspose3d - - def _round_to_8(inp): - return int(max(8, np.round((inp + 1e-6) / 8) * 8)) - - num_stages = int(np.log(max(patch_size)) / np.log(2)) - strides = [[2 if (p / 2**n) % 2 == 0 else 1 for p in patch_size] for n in range(num_stages)][::-1] - dim_red = (embed_dim / (2 * out_channels)) ** (1 / num_stages) - - # don't question me - channels = [embed_dim] + [_round_to_8(embed_dim / dim_red ** (x + 1)) 
for x in range(num_stages)] - channels[-1] = out_channels + num_stages, strides, channels = _compute_patch_decode_plan( + patch_size, + embed_dim=embed_dim, + out_channels=out_channels, + ) stages = [] for s in range(num_stages - 1): diff --git a/vesuvius/src/vesuvius/models/configuration/config_manager.py b/vesuvius/src/vesuvius/models/configuration/config_manager.py index e724e825c..554656435 100755 --- a/vesuvius/src/vesuvius/models/configuration/config_manager.py +++ b/vesuvius/src/vesuvius/models/configuration/config_manager.py @@ -117,6 +117,16 @@ def _explicit_volumes_require_unlabeled(self) -> bool: labels = spec.get("labels") if labels is None: labels = spec.get("label_paths") + if labels is None and "label" in spec: + singular_label = spec.get("label") + if singular_label in (None, ""): + labels = {} + elif len(target_names) == 1: + labels = {target_names[0]: singular_label} + else: + raise ValueError( + "Explicit volume entries that use singular 'label' require exactly one configured target" + ) if labels is None or labels == {}: return True if not isinstance(labels, dict): @@ -164,9 +174,65 @@ def _init_attributes(self): self.max_val_steps_per_epoch = int(self.tr_configs.get("max_val_steps_per_epoch", 50)) self.train_num_dataloader_workers = int(self.tr_configs.get("num_dataloader_workers", 8)) self.max_epoch = int(self.tr_configs.get("max_epoch", 5000)) + self.val_every_n = int(self.tr_configs.get("val_every_n", 1)) + self.early_stopping_patience = int(self.tr_configs.get("early_stopping_patience", 0)) + self.save_gifs = bool(self.tr_configs.get("save_gifs", True)) + compile_policy = str(self.tr_configs.get("compile_policy", "auto")).strip().lower() + if compile_policy not in {"auto", "off", "module", "ddp_wrapper"}: + raise ValueError( + "tr_config.compile_policy must be one of " + "{'auto', 'off', 'module', 'ddp_wrapper'}" + ) + self.compile_policy = compile_policy + self.startup_timing = bool(self.tr_configs.get("startup_timing", False)) + 
ddp_find_unused_parameters = self.tr_configs.get("ddp_find_unused_parameters", "auto") + if isinstance(ddp_find_unused_parameters, str): + ddp_find_unused_parameters_norm = ddp_find_unused_parameters.strip().lower() + if ddp_find_unused_parameters_norm not in {"auto", "true", "false"}: + raise ValueError( + "tr_config.ddp_find_unused_parameters must be one of " + "{'auto', true, false}" + ) + self.ddp_find_unused_parameters = ddp_find_unused_parameters_norm + else: + self.ddp_find_unused_parameters = bool(ddp_find_unused_parameters) + ddp_static_graph = self.tr_configs.get("ddp_static_graph", "auto") + if isinstance(ddp_static_graph, str): + ddp_static_graph_norm = ddp_static_graph.strip().lower() + if ddp_static_graph_norm not in {"auto", "true", "false"}: + raise ValueError( + "tr_config.ddp_static_graph must be one of " + "{'auto', true, false}" + ) + self.ddp_static_graph = ddp_static_graph_norm + else: + self.ddp_static_graph = bool(ddp_static_graph) + self.ddp_gradient_as_bucket_view = bool( + self.tr_configs.get("ddp_gradient_as_bucket_view", False) + ) self.optimizer = self.tr_configs.get("optimizer", "SGD") self.initial_lr = float(self.tr_configs.get("initial_lr", 0.01)) self.weight_decay = float(self.tr_configs.get("weight_decay", 0.00003)) + self.guide_loss_weight = float(self.tr_configs.get("guide_loss_weight", 0.0)) + self.guide_supervision_target = self.tr_configs.get("guide_supervision_target", None) + + ema_cfg = deepcopy(getattr(self, "ema_config", {}) or {}) + self.ema_enabled = bool(ema_cfg.get("enabled", False)) + self.ema_decay = float(ema_cfg.get("decay", 0.999)) + self.ema_start_step = int(ema_cfg.get("start_step", 0)) + self.ema_update_every_steps = max(1, int(ema_cfg.get("update_every_steps", 1))) + self.ema_validate = bool(ema_cfg.get("validate", self.ema_enabled)) + self.ema_save_in_checkpoint = bool( + ema_cfg.get("save_in_checkpoint", self.ema_enabled) + ) + self.ema_config = { + "enabled": self.ema_enabled, + "decay": self.ema_decay, + 
"start_step": self.ema_start_step, + "update_every_steps": self.ema_update_every_steps, + "validate": self.ema_validate, + "save_in_checkpoint": self.ema_save_in_checkpoint, + } ### Dataset config ### self.min_labeled_ratio = float(self.dataset_config.get("min_labeled_ratio", 0.10)) diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_dinovol_ink.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_dinovol_ink.yaml new file mode 100644 index 000000000..75503b2a9 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_dinovol_ink.yaml @@ -0,0 +1,47 @@ +tr_setup: + model_name: "ps128_guided_dinovol_ink" + +tr_config: + patch_size: [128, 128, 128] + batch_size: 1 + enable_deep_supervision: false + guide_loss_weight: 0.25 + guide_supervision_target: "ink" + +model_config: + features_per_stage: [32, 64, 128, 256, 320] + n_stages: 5 + n_blocks_per_stage: [1, 1, 1, 1, 1] + n_conv_per_stage_decoder: [1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + basic_encoder_block: "ConvBlock" + basic_decoder_block: "ConvBlock" + bottleneck_block: "BasicBlockD" + separate_decoders: true + input_shape: [128, 128, 128] + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_freeze: true + guide_tokenbook_tokens: 256 + guide_tokenbook_dropout: 0.0 + guide_tokenbook_sample_rate: 1.0 + guide_tokenbook_ema_decay: null + guide_tokenbook_use_ema: false + +dataset_config: + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + skip_patch_validation: false + normalization_scheme: "zscore" + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + ink: + out_channels: 2 + activation: "none" + ignore_label: 2 + losses: + - name: "nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git 
a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_direct_segmentation_ink.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_direct_segmentation_ink.yaml new file mode 100644 index 000000000..2eece8a8c --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_direct_segmentation_ink.yaml @@ -0,0 +1,33 @@ +tr_setup: + model_name: "ps128_guided_direct_segmentation_ink" + +tr_config: + patch_size: [128, 128, 128] + batch_size: 1 + enable_deep_supervision: false + guide_loss_weight: 0.0 + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "direct_segmentation" + guide_compile_policy: "all_guidance" + guide_freeze: true + input_shape: [128, 128, 128] + +dataset_config: + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + skip_patch_validation: false + normalization_scheme: "zscore" + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + ink: + out_channels: 2 + activation: "none" + ignore_label: 2 + losses: + - name: "nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_feature_encoder_ink.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_feature_encoder_ink.yaml new file mode 100644 index 000000000..f2d7403c4 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_feature_encoder_ink.yaml @@ -0,0 +1,49 @@ +tr_setup: + model_name: "ps128_guided_feature_encoder_ink" + +tr_config: + patch_size: [128, 128, 128] + batch_size: 1 + enable_deep_supervision: false + guide_loss_weight: 0.0 + +model_config: + features_per_stage: [32, 64, 128, 256, 320] + n_stages: 5 + n_blocks_per_stage: [1, 1, 1, 1, 1] + n_conv_per_stage_decoder: [1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], 
[2, 2, 2]] + basic_encoder_block: "ConvBlock" + basic_decoder_block: "ConvBlock" + bottleneck_block: "BasicBlockD" + separate_decoders: true + input_shape: [128, 128, 128] + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "feature_encoder" + guide_feature_gate_alpha: 0.9 + guide_freeze: true + guide_tokenbook_tokens: 256 + guide_tokenbook_prototype_weighting: "token_mlp" + guide_tokenbook_dropout: 0.0 + guide_tokenbook_sample_rate: 1.0 + guide_tokenbook_ema_decay: null + guide_tokenbook_use_ema: false + +dataset_config: + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + skip_patch_validation: false + normalization_scheme: "zscore" + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + ink: + out_channels: 2 + activation: "none" + ignore_label: 2 + losses: + - name: "nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_feature_skip_concat_ink.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_feature_skip_concat_ink.yaml new file mode 100644 index 000000000..4593e67f1 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_guided_feature_skip_concat_ink.yaml @@ -0,0 +1,44 @@ +tr_setup: + model_name: "ps128_guided_feature_skip_concat_ink" + +tr_config: + patch_size: [128, 128, 128] + batch_size: 1 + enable_deep_supervision: false + guide_loss_weight: 0.0 + +model_config: + features_per_stage: [32, 64, 128, 256, 320] + n_stages: 5 + n_blocks_per_stage: [1, 1, 1, 1, 1] + n_conv_per_stage_decoder: [1, 1, 1, 1] + kernel_sizes: [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] + strides: [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]] + basic_encoder_block: "ConvBlock" + basic_decoder_block: "ConvBlock" + bottleneck_block: "BasicBlockD" + separate_decoders: true + input_shape: [128, 128, 128] + guide_backbone: 
"/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "feature_skip_concat" + guide_compile_policy: "all_guidance" + guide_skip_concat_projector_channels: [8, 16, 32, 64, 80] + guide_freeze: true + +dataset_config: + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + skip_patch_validation: false + normalization_scheme: "zscore" + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + ink: + out_channels: 2 + activation: "none" + ignore_label: 2 + losses: + - name: "nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v1_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v1_medial.yaml new file mode 100644 index 000000000..0353dac58 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v1_medial.yaml @@ -0,0 +1,27 @@ +tr_config: + patch_size: [128,128,128] + enable_deep_supervision: false + +model_config: + architecture_type: "mednext_v1" + mednext_model_id: "B" + mednext_kernel_size: 3 + +dataset_config: + valid_patch_find_resolution: 2 + valid_patch_value: 1 + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + normalization_scheme: "zscore" + skip_patch_validation: false + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + surface: + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_b_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_b_medial.yaml new file mode 100644 index 000000000..32d6e2721 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_b_medial.yaml @@ -0,0 +1,26 @@ +tr_config: + patch_size: [128,128,128] + enable_deep_supervision: false + +model_config: + architecture_type: "mednext_v2" + 
mednext_model_id: "B" + +dataset_config: + valid_patch_find_resolution: 2 + valid_patch_value: 1 + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + normalization_scheme: "zscore" + skip_patch_validation: false + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + surface: + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_medial.yaml new file mode 100644 index 000000000..a4d87b016 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_medial.yaml @@ -0,0 +1,26 @@ +tr_config: + patch_size: [128,128,128] + enable_deep_supervision: false + +model_config: + architecture_type: "mednext_v2" + mednext_model_id: "L" + +dataset_config: + valid_patch_find_resolution: 2 + valid_patch_value: 1 + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + normalization_scheme: "zscore" + skip_patch_validation: false + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + surface: + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_width2_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_width2_medial.yaml new file mode 100644 index 000000000..de77166d1 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_width2_medial.yaml @@ -0,0 +1,27 @@ +tr_config: + patch_size: [128,128,128] + enable_deep_supervision: false + +model_config: + architecture_type: "mednext_v2" + mednext_model_id: "L" + mednext_width_factor: 2 + +dataset_config: + valid_patch_find_resolution: 2 + valid_patch_value: 1 + min_labeled_ratio: 0.01 + min_bbox_percent: 0.15 + 
normalization_scheme: "zscore" + skip_patch_validation: false + ome_zarr_resolution: 0 + bg_sampling_enabled: true + bg_to_fg_ratio: 0.10 + + targets: + surface: + activation: "none" + ignore_label: 2 + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps128_pretrained_dino_pixelshuffle_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_pretrained_dino_pixelshuffle_medial.yaml new file mode 100644 index 000000000..6cf6eb39d --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps128_pretrained_dino_pixelshuffle_medial.yaml @@ -0,0 +1,23 @@ +tr_config: + patch_size: [128, 128, 128] + batch_size: 1 + enable_deep_supervision: false + +model_config: + pretrained_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + pretrained_decoder_type: "pixelshuffle_conv" + freeze_encoder: true + input_shape: [128, 128, 128] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + out_channels: 2 + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps192_mednext_v2_l_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps192_mednext_v2_l_medial.yaml new file mode 100644 index 000000000..ed5461cc6 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps192_mednext_v2_l_medial.yaml @@ -0,0 +1,19 @@ +tr_config: + patch_size: [192,192,192] + enable_deep_supervision: false + +model_config: + architecture_type: "mednext_v2" + mednext_model_id: "L" + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git 
a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_dicece.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_dicece.yaml new file mode 100644 index 000000000..4099499e3 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_dicece.yaml @@ -0,0 +1,27 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + guide_loss_weight: 0.25 + guide_supervision_target: "surface" + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_freeze: true + guide_tokenbook_tokens: 256 + guide_tokenbook_dropout: 0.0 + guide_tokenbook_sample_rate: 1.0 + guide_tokenbook_ema_decay: null + guide_tokenbook_use_ema: false + input_shape: [256, 256, 256] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_dicece.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_dicece.yaml new file mode 100644 index 000000000..45a0e1812 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_dicece.yaml @@ -0,0 +1,25 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + enable_deep_supervision: false + guide_loss_weight: 0.0 + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "direct_segmentation" + guide_compile_policy: "all_guidance" + guide_freeze: true + input_shape: [256, 256, 256] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + out_channels: 2 + activation: "none" + losses: + - name: 
"nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_medial.yaml new file mode 100644 index 000000000..cb6111bee --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_medial.yaml @@ -0,0 +1,25 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + enable_deep_supervision: false + guide_loss_weight: 0.0 + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "direct_segmentation" + guide_compile_policy: "all_guidance" + guide_freeze: true + input_shape: [256, 256, 256] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + out_channels: 2 + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_dicece.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_dicece.yaml new file mode 100644 index 000000000..1f66cf940 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_dicece.yaml @@ -0,0 +1,29 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + guide_loss_weight: 0.0 + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "feature_encoder" + guide_feature_gate_alpha: 0.9 + guide_freeze: true + guide_tokenbook_tokens: 256 + guide_tokenbook_prototype_weighting: "token_mlp" + guide_tokenbook_dropout: 0.0 + guide_tokenbook_sample_rate: 1.0 + guide_tokenbook_ema_decay: null + guide_tokenbook_use_ema: false + input_shape: [256, 256, 256] + +dataset_config: + 
min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_medial.yaml new file mode 100644 index 000000000..7df6ba5af --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_medial.yaml @@ -0,0 +1,29 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + guide_loss_weight: 0.0 + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "feature_encoder" + guide_feature_gate_alpha: 0.9 + guide_freeze: true + guide_tokenbook_tokens: 256 + guide_tokenbook_prototype_weighting: "token_mlp" + guide_tokenbook_dropout: 0.0 + guide_tokenbook_sample_rate: 1.0 + guide_tokenbook_ema_decay: null + guide_tokenbook_use_ema: false + input_shape: [256, 256, 256] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_dicece.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_dicece.yaml new file mode 100644 index 000000000..ff2d8eb49 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_dicece.yaml @@ -0,0 +1,24 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + guide_loss_weight: 0.0 + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: 
"feature_skip_concat" + guide_compile_policy: "all_guidance" + guide_skip_concat_projector_channels: [8, 16, 32, 64, 80, 80, 80] + guide_freeze: true + input_shape: [256, 256, 256] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "nnUNet_DC_and_CE_loss" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_medial.yaml new file mode 100644 index 000000000..830844c8e --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_medial.yaml @@ -0,0 +1,24 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + guide_loss_weight: 0.0 + +model_config: + guide_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_fusion_stage: "feature_skip_concat" + guide_compile_policy: "all_guidance" + guide_skip_concat_projector_channels: [8, 16, 32, 64, 80, 80, 80] + guide_freeze: true + input_shape: [256, 256, 256] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_medial.yaml new file mode 100644 index 000000000..fb13fafe1 --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_guided_medial.yaml @@ -0,0 +1,27 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + guide_loss_weight: 0.25 + guide_supervision_target: "surface" + +model_config: + guide_backbone: 
"/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + guide_freeze: true + guide_tokenbook_tokens: 256 + guide_tokenbook_dropout: 0.0 + guide_tokenbook_sample_rate: 1.0 + guide_tokenbook_ema_decay: null + guide_tokenbook_use_ema: false + input_shape: [256, 256, 256] + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_mednext_v1_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_mednext_v1_medial.yaml new file mode 100644 index 000000000..6006a692a --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_mednext_v1_medial.yaml @@ -0,0 +1,20 @@ +tr_config: + patch_size: [256,256,256] + enable_deep_supervision: false + +model_config: + architecture_type: "mednext_v1" + mednext_model_id: "B" + mednext_kernel_size: 3 + +dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/configuration/single_task/ps256_pretrained_dino_pixelshuffle_medial.yaml b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_pretrained_dino_pixelshuffle_medial.yaml new file mode 100644 index 000000000..73e59579a --- /dev/null +++ b/vesuvius/src/vesuvius/models/configuration/single_task/ps256_pretrained_dino_pixelshuffle_medial.yaml @@ -0,0 +1,23 @@ +tr_config: + patch_size: [256, 256, 256] + batch_size: 1 + enable_deep_supervision: false + +model_config: + pretrained_backbone: "/home/giorgio/Projects/dino-vesuvius/dino-checkpoints/checkpoint_step_342500.pt" + pretrained_decoder_type: "pixelshuffle_conv" + freeze_encoder: true + input_shape: [256, 256, 256] + 
+dataset_config: + min_labeled_ratio: 0 + min_bbox_percent: 0 + normalization_scheme: "zscore" + skip_patch_validation: true + targets: + surface: + out_channels: 2 + activation: "none" + losses: + - name: "MedialSurfaceRecall" + weight: 1.0 diff --git a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py index 99167546a..39abff859 100644 --- a/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py +++ b/vesuvius/src/vesuvius/models/datasets/zarr_dataset.py @@ -166,6 +166,11 @@ def __init__( def _discover_and_load_volumes(self) -> None: """Discover and load all OME-Zarr volumes.""" + volumes_cfg = getattr(self.mgr, "dataset_config", {}).get("volumes") + if volumes_cfg: + self._load_explicit_volumes(volumes_cfg) + return + images_dir = self.data_path / "images" labels_dir = self.data_path / "labels" @@ -222,6 +227,128 @@ def _discover_and_load_volumes(self) -> None: volume_id, spatial_shape, has_any_label ) + def _load_explicit_volumes(self, volumes_cfg) -> None: + """Load explicit volume specs from dataset_config.volumes.""" + if isinstance(volumes_cfg, dict): + volume_specs = list(volumes_cfg.values()) + elif isinstance(volumes_cfg, (list, tuple)): + volume_specs = list(volumes_cfg) + else: + raise ValueError("dataset_config.volumes must be a list or mapping") + + image_id_counts: Dict[str, int] = {} + logger.info("Loading %d explicit volume specs", len(volume_specs)) + + for volume_idx, spec in enumerate(volume_specs): + if isinstance(spec, (str, Path)): + spec = {"image": spec} + if not isinstance(spec, dict): + raise ValueError("Each explicit volume entry must be a mapping or path string") + + image_path = self._resolve_explicit_volume_path( + spec.get("image") or spec.get("image_path") + ) + if image_path is None: + raise ValueError("Explicit volume entries must define 'image' or 'image_path'") + if not image_path.exists(): + raise FileNotFoundError(f"Explicit image volume not found: {image_path}") + + 
label_mapping = self._resolve_explicit_label_mapping(spec) + label_paths: Dict[str, Optional[Path]] = {} + label_arrays: Dict[str, Optional[zarr.Array]] = {} + has_any_label = False + for target in self.target_names: + label_path = self._resolve_explicit_volume_path(label_mapping.get(target)) + if label_path is not None: + if not label_path.exists(): + raise FileNotFoundError(f"Explicit label volume not found: {label_path}") + label_paths[target] = label_path + label_arrays[target] = self._open_zarr(label_path) + has_any_label = True + else: + if not self.allow_unlabeled_data: + raise FileNotFoundError( + f"Missing explicit label for target '{target}' in volume spec #{volume_idx}" + ) + label_paths[target] = None + label_arrays[target] = None + + image_array = self._open_zarr(image_path) + spatial_shape = self._get_spatial_shape(image_array) + volume_id = self._derive_explicit_volume_id( + spec=spec, + image_path=image_path, + label_paths=label_paths, + image_id_counts=image_id_counts, + ) + + self._volumes.append( + VolumeInfo( + volume_id=volume_id, + image_path=image_path, + image_array=image_array, + label_paths=label_paths, + label_arrays=label_arrays, + spatial_shape=spatial_shape, + has_labels=has_any_label, + ) + ) + + logger.info( + "Loaded explicit volume '%s': image=%s shape=%s has_labels=%s", + volume_id, + image_path, + spatial_shape, + has_any_label, + ) + + def _resolve_explicit_label_mapping(self, spec: dict) -> Dict[str, Optional[str]]: + labels = spec.get("labels") + if labels is None: + labels = spec.get("label_paths") + if labels is not None: + if not isinstance(labels, dict): + raise ValueError("Explicit volume labels must be a mapping of target -> path") + return {str(k): v for k, v in labels.items()} + + singular_label = spec.get("label") + if singular_label in (None, ""): + return {} + if len(self.target_names) != 1: + raise ValueError( + "Explicit volume entries that use singular 'label' require exactly one active target" + ) + return 
{self.target_names[0]: singular_label} + + def _resolve_explicit_volume_path(self, value) -> Optional[Path]: + if value in (None, ""): + return None + path = Path(value) + if path.is_absolute(): + return path + config_path = getattr(self.mgr, "_config_path", None) + base_dir = Path(config_path).parent.resolve() if config_path is not None else self.data_path + return (base_dir / path).resolve() + + def _derive_explicit_volume_id( + self, + *, + spec: dict, + image_path: Path, + label_paths: Dict[str, Optional[Path]], + image_id_counts: Dict[str, int], + ) -> str: + configured_id = spec.get("volume_id") or spec.get("name") + if configured_id not in (None, ""): + return str(configured_id) + + base_id = image_path.stem + duplicate_idx = image_id_counts.get(base_id, 0) + image_id_counts[base_id] = duplicate_idx + 1 + if duplicate_idx == 0: + return base_id + return f"{base_id}_{duplicate_idx}" + def _open_zarr(self, path: Path) -> zarr.Array: """Open a zarr array at the configured resolution level.""" store = zarr.open(path, mode="r") @@ -312,7 +439,7 @@ def _build_patch_index(self) -> None: # No cache found - enumerate all patches without validation logger.warning( "No patch cache found. Enumerating all patches without validation. " - "Run `vesuvius.build_patch_cache --config ` to generate cache." + "Run `vesuvius.find_patches --config ` to generate cache." 
) self._enumerate_all_patches() diff --git a/vesuvius/src/vesuvius/models/run/inference.py b/vesuvius/src/vesuvius/models/run/inference.py index b8a8d331a..5d52e7b76 100644 --- a/vesuvius/src/vesuvius/models/run/inference.py +++ b/vesuvius/src/vesuvius/models/run/inference.py @@ -21,6 +21,27 @@ from vesuvius.utils.k8s import get_tqdm_kwargs +class _InferenceDeepSupervisionWrapper(torch.nn.Module): + """Collapse multi-scale train-time outputs to highest-resolution inference outputs.""" + + def __init__(self, network: torch.nn.Module): + super().__init__() + self.network = network + self.final_config = getattr(network, "final_config", None) + + def _collapse(self, output): + if isinstance(output, dict): + return {key: self._collapse(value) for key, value in output.items()} + if isinstance(output, (list, tuple)): + if not output: + return output + return self._collapse(output[0]) + return output + + def forward(self, *args, **kwargs): + return self._collapse(self.network(*args, **kwargs)) + + class Inferer(): def __init__(self, model_path: str = None, @@ -328,6 +349,7 @@ def __init__(self, model_config): self.train_batch_size = model_config.get('train_batch_size', model_config.get('batch_size', 2)) self.in_channels = model_config.get('in_channels', 1) self.autoconfigure = model_config.get('autoconfigure', False) + self.enable_deep_supervision = bool(model_config.get('enable_deep_supervision', False)) self.model_name = model_config.get('model_name', 'Model') # Set spacing based on patch size dimensions @@ -364,6 +386,11 @@ def strip_key(k: str) -> str: if self.verbose: print("Model weights loaded successfully") + if getattr(mgr, "enable_deep_supervision", False): + model = _InferenceDeepSupervisionWrapper(model) + if self.verbose: + print("Wrapped deep-supervision model for plain inference outputs") + # Compile model for CUDA inference (provides 10-30% speedup via kernel fusion) # Note: 'reduce-overhead' mode uses CUDA graphs which can cause tensor reuse issues # when 
outputs are accessed after subsequent runs. Using 'default' mode instead. @@ -671,19 +698,22 @@ def write_patch(write_index, patch_data): except Exception as e: raise RuntimeError(f"Failed to write patch at index {write_index}: {str(e)}") + def to_device(tensor): + return tensor.to(self.device, non_blocking=(self.device.type == 'cuda')) + with tqdm(total=self.num_total_patches, desc=f"Inferring Part {self.part_id}", **get_tqdm_kwargs()) as pbar: with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for batch_data in self.dataloader: if isinstance(batch_data, dict): - input_batch = batch_data['data'].to(self.device) + input_batch = to_device(batch_data['data']) is_empty_flags = batch_data.get('is_empty', [False] * input_batch.shape[0]) elif isinstance(batch_data, (list, tuple)): - input_batch = batch_data[0].to(self.device) + input_batch = to_device(batch_data[0]) is_empty_flags = [False] * input_batch.shape[0] else: - input_batch = batch_data.to(self.device) + input_batch = to_device(batch_data) is_empty_flags = [False] * input_batch.shape[0] # the case that the batch is empty is valid, e.g. 
when the input volume is smaller than the patch size @@ -694,7 +724,8 @@ def write_patch(write_index, patch_data): batch_size = input_batch.shape[0] output_shape = (batch_size, self.num_classes, *self.patch_size) - output_batch = torch.zeros(output_shape, device=self.device, dtype=input_batch.dtype) + output_dtype = torch.float16 if self.device.type == 'cuda' else input_batch.dtype + output_batch = torch.zeros(output_shape, device=self.device, dtype=output_dtype) # Find non-empty patches that need model inference non_empty_indices = [i for i, is_empty in enumerate(is_empty_flags) if not is_empty] @@ -703,7 +734,7 @@ def write_patch(write_index, patch_data): if non_empty_indices: non_empty_input = input_batch[non_empty_indices] - with torch.no_grad(), torch.amp.autocast('cuda'): + with torch.inference_mode(), torch.amp.autocast('cuda'): if self.do_tta: non_empty_output = infer_with_tta( self.model, @@ -716,6 +747,7 @@ def write_patch(write_index, patch_data): non_empty_output = self.model(non_empty_input) if self.is_multi_task: non_empty_output = self._concat_multi_task_outputs(non_empty_output) + non_empty_output = non_empty_output.to(dtype=output_batch.dtype) # Place non-empty patch outputs in the correct positions for idx, original_idx in enumerate(non_empty_indices): @@ -725,7 +757,7 @@ def write_patch(write_index, patch_data): if self.verbose: print("Batch contains only empty patches, skipping model inference") - output_np = output_batch.cpu().numpy().astype(np.float16) + output_np = self._finalize_output_batch(output_batch) current_batch_size = output_np.shape[0] patch_indices = batch_data.get('index', list(range(current_batch_size))) @@ -763,6 +795,11 @@ def write_patch(write_index, patch_data): if self.current_patch_write_index != self.num_total_patches: print(f"Warning: Expected {self.num_total_patches} patches, but wrote {self.current_patch_write_index}.") + def _finalize_output_batch(self, output_batch): + if output_batch.dtype != torch.float16: + 
output_batch = output_batch.to(dtype=torch.float16) + return output_batch.cpu().numpy() + def _run_inference(self): if self.verbose: print("Loading model...") self.model = self._load_model() diff --git a/vesuvius/src/vesuvius/models/training/optimizers.py b/vesuvius/src/vesuvius/models/training/optimizers.py index a20909f43..d0f5710bd 100644 --- a/vesuvius/src/vesuvius/models/training/optimizers.py +++ b/vesuvius/src/vesuvius/models/training/optimizers.py @@ -8,6 +8,7 @@ def create_optimizer(optimizer_config, model): optim_name = optimizer_config.get('name', 'Adam') learning_rate = optimizer_config.get('learning_rate', 1e-3) weight_decay = optimizer_config.get('weight_decay', 0) + eps = optimizer_config.get('eps', None) optim_name = optim_name.lower() if optim_name == 'adadelta': @@ -20,15 +21,23 @@ def create_optimizer(optimizer_config, model): weight_decay=weight_decay) elif optim_name == 'adamw': betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) - optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=betas, - weight_decay=weight_decay) + adamw_kwargs = { + 'lr': learning_rate, + 'betas': betas, + 'weight_decay': weight_decay, + } + if eps is not None: + adamw_kwargs['eps'] = eps + optimizer = optim.AdamW(model.parameters(), **adamw_kwargs) elif optim_name == 'sparseadam': betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) optimizer = optim.SparseAdam(model.parameters(), lr=learning_rate, betas=betas) elif optim_name == 'adamax': betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) - optimizer = optim.Adamax(model.parameters(), lr=learning_rate, betas=betas, - weight_decay=weight_decay) + adamax_kwargs = {'lr': learning_rate, 'betas': betas, 'weight_decay': weight_decay} + if eps is not None: + adamax_kwargs['eps'] = eps + optimizer = optim.Adamax(model.parameters(), **adamax_kwargs) elif optim_name == 'asgd': lambd = optimizer_config.get('lambd', 0.0001) alpha = optimizer_config.get('alpha', 0.75) @@ -52,8 +61,10 @@ def 
create_optimizer(optimizer_config, model): weight_decay=weight_decay) elif optim_name == 'radam': betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) - optimizer = optim.RAdam(model.parameters(), lr=learning_rate, betas=betas, - weight_decay=weight_decay) + radam_kwargs = {'lr': learning_rate, 'betas': betas, 'weight_decay': weight_decay} + if eps is not None: + radam_kwargs['eps'] = eps + optimizer = optim.RAdam(model.parameters(), **radam_kwargs) elif optim_name == 'rmsprop': alpha = optimizer_config.get('alpha', 0.99) optimizer = optim.RMSprop(model.parameters(), lr=learning_rate, alpha=alpha, @@ -87,7 +98,9 @@ def create_optimizer(optimizer_config, model): else: # Adam is default betas = tuple(optimizer_config.get('betas', (0.9, 0.999))) - optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=betas, - weight_decay=weight_decay) + adam_kwargs = {'lr': learning_rate, 'betas': betas, 'weight_decay': weight_decay} + if eps is not None: + adam_kwargs['eps'] = eps + optimizer = optim.Adam(model.parameters(), **adam_kwargs) return optimizer diff --git a/vesuvius/src/vesuvius/models/training/train.py b/vesuvius/src/vesuvius/models/training/train.py index 6c1b36995..409256888 100644 --- a/vesuvius/src/vesuvius/models/training/train.py +++ b/vesuvius/src/vesuvius/models/training/train.py @@ -1,7 +1,9 @@ from pathlib import Path from copy import deepcopy import os +from time import perf_counter from datetime import datetime +from PIL import Image from tqdm import tqdm import numpy as np import torch @@ -13,7 +15,7 @@ from torch.utils.data.distributed import DistributedSampler from vesuvius.models.utils import InitWeights_He from vesuvius.models.datasets import ZarrDataset -from vesuvius.utils.plotting import save_debug +from vesuvius.utils.plotting import save_debug, convert_slice_to_bgr, _compute_display_value_range, add_text_label from vesuvius.models.build.build_network_from_config import NetworkFromConfig from vesuvius.models.training.loss.losses 
import _create_loss @@ -138,6 +140,46 @@ def __init__(self, self._augmentation_names = None self._epoch_aug_time = None self._epoch_aug_count = None + self._on_device_transforms = None + ema_cfg = getattr(self.mgr, 'ema_config', {}) or {} + self.ema_enabled = bool(getattr(self.mgr, 'ema_enabled', ema_cfg.get('enabled', False))) + self.ema_decay = float(getattr(self.mgr, 'ema_decay', ema_cfg.get('decay', 0.999))) + self.ema_start_step = int(getattr(self.mgr, 'ema_start_step', ema_cfg.get('start_step', 0))) + self.ema_update_every_steps = max( + 1, + int(getattr(self.mgr, 'ema_update_every_steps', ema_cfg.get('update_every_steps', 1))) + ) + self.ema_validate = bool( + getattr(self.mgr, 'ema_validate', ema_cfg.get('validate', self.ema_enabled)) + ) + self.ema_save_in_checkpoint = bool( + getattr( + self.mgr, + 'ema_save_in_checkpoint', + ema_cfg.get('save_in_checkpoint', self.ema_enabled), + ) + ) + self.ema_model = None + self._ema_optimizer_step = 0 + self._checkpoint_ema_state = None + self._checkpoint_ema_optimizer_step = None + self._printed_ema_validation_mode = False + self._volume_dilate_by_name = {} + self.guide_loss_weight = float(getattr(self.mgr, "guide_loss_weight", 0.0)) + self.guide_supervision_target = getattr(self.mgr, "guide_supervision_target", None) + model_config = getattr(self.mgr, "model_config", {}) or {} + self.guide_fusion_stage = str(model_config.get("guide_fusion_stage", "input")).strip().lower() + if self.guide_fusion_stage in {"feature_encoder", "feature_skip_concat", "direct_segmentation"} and self.guide_loss_weight > 0.0: + raise ValueError( + "guide_loss_weight must be 0.0 when model_config.guide_fusion_stage is " + "'feature_encoder', 'feature_skip_concat', or 'direct_segmentation'" + ) + self._current_aux_outputs = {} + self.compile_policy = str(getattr(self.mgr, "compile_policy", "auto")).strip().lower() + self.startup_timing = bool(getattr(self.mgr, "startup_timing", False)) + self._startup_timing_records = [] + 
self._startup_timing_logged = False + self._first_optimizer_timing_recorded = False # --- build model --- # def _build_model(self): @@ -163,6 +205,184 @@ def _get_additional_checkpoint_data(self): """ return {} + def _unwrap_model(self, model): + if hasattr(model, 'module'): + model = model.module + if hasattr(model, '_orig_mod'): + try: + model = model._orig_mod + except Exception: + pass + return model + + def _wrap_model_for_distributed_training(self, model): + if not self.is_distributed: + return model + + raw_model = self._unwrap_model(model) + guide_enabled = bool(getattr(raw_model, "guide_enabled", False)) + + find_unused_cfg = getattr(self.mgr, "ddp_find_unused_parameters", "auto") + if isinstance(find_unused_cfg, str): + find_unused_norm = find_unused_cfg.strip().lower() + if find_unused_norm == "auto": + find_unused_parameters = not guide_enabled + else: + find_unused_parameters = find_unused_norm == "true" + else: + find_unused_parameters = bool(find_unused_cfg) + + static_graph_cfg = getattr(self.mgr, "ddp_static_graph", "auto") + if isinstance(static_graph_cfg, str): + static_graph_norm = static_graph_cfg.strip().lower() + if static_graph_norm == "auto": + static_graph = guide_enabled + else: + static_graph = static_graph_norm == "true" + else: + static_graph = bool(static_graph_cfg) + + ddp_kwargs = { + "find_unused_parameters": find_unused_parameters, + "static_graph": static_graph, + "gradient_as_bucket_view": bool(getattr(self.mgr, "ddp_gradient_as_bucket_view", False)), + } + if self.device.type == 'cuda': + ddp_kwargs.update( + device_ids=[self.assigned_gpu_id], + output_device=self.assigned_gpu_id, + ) + if not self.is_distributed or self.rank == 0: + print( + "DDP kwargs: " + f"find_unused_parameters={ddp_kwargs['find_unused_parameters']}, " + f"static_graph={ddp_kwargs['static_graph']}, " + f"gradient_as_bucket_view={ddp_kwargs['gradient_as_bucket_view']}" + ) + return DDP(model, **ddp_kwargs) + + def _record_startup_timing(self, label, 
duration_seconds): + if not self.startup_timing: + return + label = str(label) + duration_seconds = float(duration_seconds) + self._startup_timing_records.append((label, duration_seconds)) + if not self.is_distributed or self.rank == 0: + print(f"[Timing] {label}: {duration_seconds:.3f}s") + + def _flush_startup_timing(self, prefix="startup"): + if not self.startup_timing or self._startup_timing_logged: + return + self._startup_timing_logged = True + + def _resolve_compile_policy(self, model): + policy = self.compile_policy + if policy == "auto": + raw_model = self._unwrap_model(model) + return "module" if getattr(raw_model, "guide_enabled", False) else "ddp_wrapper" + return policy + + def _compile_module_in_place(self, model): + compile_method = getattr(model, "compile", None) + if callable(compile_method): + compile_method() + return model + return torch.compile(model) + + def _maybe_compile_model(self, model): + if self.device.type != "cuda": + return model + + compile_policy = self._resolve_compile_policy(model) + if compile_policy == "off": + if not self.is_distributed or self.rank == 0: + print("Compile policy set to 'off'; using eager mode.") + return model + + compile_start = perf_counter() + if not self.is_distributed or self.rank == 0: + print(f"Compiling model with policy '{compile_policy}'") + + try: + if compile_policy == "module": + model = self._compile_module_in_place(model) + elif compile_policy == "ddp_wrapper": + model = torch.compile(model) + else: + raise ValueError(f"Unsupported compile policy: {compile_policy}") + except Exception as e: + if not self.is_distributed or self.rank == 0: + print(f"Compile policy '{compile_policy}' failed; continuing without compile. 
Reason: {e}") + self._record_startup_timing("compile_failed", perf_counter() - compile_start) + return model + + self._record_startup_timing("compile", perf_counter() - compile_start) + return model + + def _create_ema_model(self, model): + ema_model = deepcopy(self._unwrap_model(model)) + ema_model = ema_model.to(self.device) + ema_model.eval() + for parameter in ema_model.parameters(): + parameter.requires_grad_(False) + return ema_model + + def _initialize_ema_model(self, model): + if not self.ema_enabled: + self.ema_model = None + return None + + self.ema_model = self._create_ema_model(model) + if self._checkpoint_ema_state is not None: + try: + self.ema_model.load_state_dict(self._checkpoint_ema_state) + self._ema_optimizer_step = int(self._checkpoint_ema_optimizer_step or 0) + print( + "Restored EMA model from checkpoint " + f"(optimizer_step={self._ema_optimizer_step})" + ) + except Exception as exc: + print(f"Warning: Failed to restore EMA model from checkpoint: {exc}") + print("Using freshly initialized EMA model") + self._ema_optimizer_step = 0 + finally: + self._checkpoint_ema_state = None + self._checkpoint_ema_optimizer_step = None + else: + print( + "Created EMA model " + f"(decay={self.ema_decay}, start_step={self.ema_start_step}, " + f"update_every_steps={self.ema_update_every_steps})" + ) + return self.ema_model + + def _update_ema_model(self, model): + if self.ema_model is None: + return + + self._ema_optimizer_step += 1 + if self._ema_optimizer_step < self.ema_start_step: + return + if ((self._ema_optimizer_step - self.ema_start_step) % self.ema_update_every_steps) != 0: + return + + ema_state = self.ema_model.state_dict() + for name, model_value in self._unwrap_model(model).state_dict().items(): + ema_value = ema_state[name] + model_value = model_value.detach() + if torch.is_floating_point(ema_value): + ema_value.lerp_(model_value.to(dtype=ema_value.dtype), 1.0 - self.ema_decay) + else: + ema_value.copy_(model_value) + + def 
_get_validation_model(self, model): + if self.ema_model is not None and self.ema_validate: + if not self._printed_ema_validation_mode: + print("Validation will use the EMA model") + self._printed_ema_validation_mode = True + return self.ema_model + return model + # --- configure dataset --- # def _configure_dataset(self, is_training=True): dataset = self._build_dataset_for_mgr(self.mgr, is_training=is_training) @@ -272,10 +492,401 @@ def _compute_loss_value( base_loss = getattr(loss_fn, 'loss', loss_fn) skeleton_losses = {'DC_SkelREC_and_CE_loss', 'SoftSkeletonRecallLoss'} + extra_loss_kwargs = {} + if self.guide_fusion_stage == "direct_segmentation": + prediction, ground_truth, extra_loss_kwargs = self._adapt_direct_segmentation_loss_inputs( + base_loss, + prediction, + ground_truth, + target_name=target_name, + ) + if skeleton_data is not None and base_loss.__class__.__name__ in skeleton_losses: - return loss_fn(prediction, ground_truth, skeleton_data) + return loss_fn(prediction, ground_truth, skeleton_data, **extra_loss_kwargs) + + return loss_fn(prediction, ground_truth, **extra_loss_kwargs) + + def _resolve_target_ignore_value(self, target_name: str): + target_info = self.mgr.targets.get(target_name, {}) or {} + for alias in ("ignore_index", "ignore_label", "ignore_value"): + value = target_info.get(alias) + if value is not None: + return value + return None + + @staticmethod + def _make_ignore_mask(target: torch.Tensor, ignore_label): + if ignore_label is None: + return None + if isinstance(ignore_label, float) and np.isnan(ignore_label): + return torch.isnan(target) + return target == float(ignore_label) + + def _adapt_direct_segmentation_loss_inputs(self, base_loss, prediction, ground_truth, *, target_name: str): + loss_name = base_loss.__class__.__name__ + ce_style_losses = {"DC_and_CE_loss", "DC_SkelREC_and_CE_loss"} + bce_region_losses = {"DC_and_BCE_loss", "LabelSmoothedDCAndBCELoss"} + bce_single_channel_losses = {"BCEWithLogitsLoss", 
"MemoryEfficientSoftDiceLoss", "SoftDiceLoss"} + supported_losses = ce_style_losses | bce_region_losses | bce_single_channel_losses + if loss_name not in supported_losses: + raise ValueError( + f"Loss '{loss_name}' is not supported with guide_fusion_stage='direct_segmentation'." + ) + + if not torch.is_tensor(ground_truth): + raise TypeError("direct_segmentation requires tensor targets") + + target = ground_truth + if target.ndim == 4: + target = target.unsqueeze(1) + if target.ndim != 5: + raise ValueError( + "direct_segmentation expects 5D targets [B, C, D, H, W] or 4D label maps, " + f"got shape {tuple(ground_truth.shape)}" + ) + + ignore_label = self._resolve_target_ignore_value(target_name) + ignore_mask = self._make_ignore_mask(target, ignore_label) + + if loss_name in ce_style_losses: + if target.shape[1] > 1: + target = torch.argmax(target, dim=1, keepdim=True) + return prediction, target, {} + + if prediction.ndim != 5 or prediction.shape[1] < 2: + raise ValueError( + "direct_segmentation expects synthesized 2-channel logits for BCE-style losses, " + f"got shape {tuple(prediction.shape)}" + ) + fg_prediction = prediction[:, 1:2] + + if target.shape[1] > 1: + foreground = target[:, -1:].float() + ignore_mask = None + else: + foreground = (target > 0).float() + if ignore_mask is not None: + foreground = foreground.masked_fill(ignore_mask, 0.0) + + if loss_name in bce_region_losses: + ignore_channel = ignore_mask.float() if ignore_mask is not None else torch.zeros_like(foreground) + region_target = torch.cat([foreground, ignore_channel], dim=1) + return fg_prediction, region_target, {} + + if loss_name == "BCEWithLogitsLoss": + if ignore_mask is not None: + bce_target = target.float().clone() + bce_target = torch.where(ignore_mask, bce_target, foreground) + else: + bce_target = foreground + return fg_prediction, bce_target, {} + + extra_loss_kwargs = {} + if ignore_mask is not None: + extra_loss_kwargs["loss_mask"] = (~ignore_mask).float() + return 
fg_prediction, foreground, extra_loss_kwargs + + def _resolve_guide_supervision_target(self, targets_dict): + if self.guide_supervision_target is not None: + return self.guide_supervision_target + for target_name, target_info in self.mgr.targets.items(): + if not target_info.get("auxiliary_task", False) and target_name in targets_dict: + return target_name + return None + + def _compute_guide_alignment_loss(self, targets_dict): + if self.guide_loss_weight <= 0.0: + return None + + guide_mask = self._current_aux_outputs.get("guide_mask") + if guide_mask is None: + return None + + target_name = self._resolve_guide_supervision_target(targets_dict) + if target_name is None or target_name not in targets_dict: + return None + + guide_target = targets_dict[target_name] + if isinstance(guide_target, (list, tuple)): + guide_target = guide_target[0] + if not torch.is_tensor(guide_target): + return None + + ignore_label = None + target_info = self.mgr.targets.get(target_name, {}) or {} + for key in ("ignore_label", "ignore_index", "ignore_value"): + value = target_info.get(key) + if value is not None: + ignore_label = value + break + + guide_target = guide_target.float() + if guide_target.ndim != 5: + return None + + if guide_target.shape[1] == 1: + ignore_mask = torch.zeros_like(guide_target, dtype=torch.bool) + if ignore_label is not None: + if isinstance(ignore_label, float) and np.isnan(ignore_label): + ignore_mask = torch.isnan(guide_target) + else: + ignore_mask = guide_target == float(ignore_label) + foreground = (guide_target > 0).float() + if ignore_label is not None: + foreground = foreground.masked_fill(ignore_mask, 0.0) + valid_mask = (~ignore_mask).float() + else: + foreground = (guide_target > 0).any(dim=1, keepdim=True).float() + valid_mask = torch.ones_like(foreground) + + target_resize = F.interpolate(foreground, size=guide_mask.shape[2:], mode="nearest") + valid_resize = F.interpolate(valid_mask, size=guide_mask.shape[2:], mode="nearest") + device_type = 
guide_mask.device.type + autocast_enabled = device_type in {"cuda", "cpu"} + with torch.amp.autocast(device_type=device_type, enabled=False) if autocast_enabled else nullcontext(): + guide_mask_fp32 = guide_mask.float().clamp(1e-6, 1.0 - 1e-6) + target_resize_fp32 = target_resize.float() + loss_map = F.binary_cross_entropy(guide_mask_fp32, target_resize_fp32, reduction="none") + weighted_loss = loss_map * valid_resize + denom = valid_resize.sum().clamp_min(1.0) + return weighted_loss.sum() / denom + + def _resolve_guide_supervision_target_name(self) -> str: + if self.guide_supervision_target is not None: + return str(self.guide_supervision_target) + for target_name, target_info in self.mgr.targets.items(): + if not target_info.get("auxiliary_task", False): + return str(target_name) + return "target" + + def _slice_aux_outputs_for_debug(self, aux_outputs, batch_index: int): + if not aux_outputs: + return {} + sliced = {} + for name, value in aux_outputs.items(): + if isinstance(value, torch.Tensor) and value.ndim > 0 and value.shape[0] > batch_index: + sliced[name] = value[batch_index: batch_index + 1] + return sliced + + @staticmethod + def _select_debug_sample_index_from_targets(targets_dict): + if not targets_dict: + return 0, False + first_target_any = next(iter(targets_dict.values())) + first_target = first_target_any[0] if isinstance(first_target_any, (list, tuple)) else first_target_any + if not torch.is_tensor(first_target) or first_target.ndim == 0: + return 0, False + for b in range(first_target.shape[0]): + if torch.any(first_target[b] != 0): + return b, True + return 0, False + + @staticmethod + def _detach_debug_tensors(data_dict): + detached = {} + for name, value in data_dict.items(): + if isinstance(value, torch.Tensor): + detached[name] = value.detach().cpu() + return detached + + def _capture_validation_debug_snapshot(self, inputs, targets_dict, outputs, aux_outputs, batch_index: int): + inputs_first = inputs[batch_index: batch_index + 
1].detach().cpu() + + targets_dict_first_all = {} + for t_name, t_val in targets_dict.items(): + if isinstance(t_val, (list, tuple)): + targets_dict_first_all[t_name] = t_val[0][batch_index: batch_index + 1].detach().cpu() + else: + targets_dict_first_all[t_name] = t_val[batch_index: batch_index + 1].detach().cpu() + + outputs_dict_first = {} + for t_name, p_val in outputs.items(): + if isinstance(p_val, (list, tuple)): + outputs_dict_first[t_name] = p_val[0][batch_index: batch_index + 1].detach().cpu() + else: + outputs_dict_first[t_name] = p_val[batch_index: batch_index + 1].detach().cpu() + + raw_aux_outputs_dict_first = self._detach_debug_tensors( + self._slice_aux_outputs_for_debug(aux_outputs, batch_index) + ) + + return { + "input": inputs_first, + "targets_all": targets_dict_first_all, + "outputs": outputs_dict_first, + "raw_aux_outputs": raw_aux_outputs_dict_first, + } + + @staticmethod + def _select_validation_debug_snapshot_for_epoch(candidates, fallback, epoch: int): + if candidates: + return candidates[int(epoch) % len(candidates)] + return fallback + + def _prepare_aux_debug_outputs(self, aux_outputs, reference_input): + if not aux_outputs or not isinstance(reference_input, torch.Tensor): + return {} + + prepared = {} + guide_mask = aux_outputs.get("guide_mask") + if guide_mask is not None: + upsampled = F.interpolate( + guide_mask.float(), + size=reference_input.shape[2:], + mode="trilinear", + align_corners=False, + ).clamp_(0.0, 1.0) + prepared[f"guide_{self._resolve_guide_supervision_target_name()}"] = upsampled.detach() + return prepared + + @staticmethod + def _is_feature_encoder_aux_output(name: str) -> bool: + raw = str(name).strip().lower() + if not raw.startswith("enc_"): + return False + try: + int(raw.split("_", 1)[1]) + except (IndexError, ValueError): + return False + return True + + @staticmethod + def _stack_preview_rows(rows): + valid_rows = [row for row in rows if row is not None] + if not valid_rows: + return None + max_width = 
max(row.shape[1] for row in valid_rows) + padded_rows = [] + for row in valid_rows: + if row.shape[1] < max_width: + row = np.pad( + row, + ((0, 0), (0, max_width - row.shape[1]), (0, 0)), + mode="constant", + constant_values=0, + ) + padded_rows.append(np.ascontiguousarray(row, dtype=np.uint8)) + return np.ascontiguousarray(np.vstack(padded_rows), dtype=np.uint8) + + def _make_feature_encoder_guide_preview_row(self, aux_outputs_dict, *, row_prefix: str | None = None): + if not aux_outputs_dict: + return None + + stage_items = [ + (name, tensor) + for name, tensor in aux_outputs_dict.items() + if self._is_feature_encoder_aux_output(name) and isinstance(tensor, torch.Tensor) + ] + if not stage_items: + return None + + stage_items.sort(key=lambda item: int(str(item[0]).split("_", 1)[1])) + rendered_panels = [] + max_height = 0 + max_width = 0 + + for aux_name, aux_tensor in stage_items: + aux_tensor = aux_tensor.float() if aux_tensor.dtype == torch.bfloat16 else aux_tensor + aux_np = aux_tensor.detach().cpu().numpy() + while aux_np.ndim > 4: + aux_np = aux_np[0] + if aux_np.ndim == 4: + aux_np = aux_np[0] + + if aux_np.ndim == 2: + preview_slice = aux_np + is_2d = True + else: + preview_slice = aux_np[max(aux_np.shape[0] // 2, 0)] + is_2d = False + + value_range = _compute_display_value_range( + aux_np, + is_2d_run=is_2d, + task_name=aux_name, + task_cfg={}, + ) + preview = convert_slice_to_bgr(preview_slice, value_range=value_range) + rendered_panels.append((aux_name, np.ascontiguousarray(preview, dtype=np.uint8))) + max_height = max(max_height, preview.shape[0]) + max_width = max(max_width, preview.shape[1]) + + labeled_panels = [] + for aux_name, preview in rendered_panels: + if preview.shape[0] != max_height or preview.shape[1] != max_width: + preview = np.array( + Image.fromarray(preview).resize((max_width, max_height), resample=Image.Resampling.BILINEAR), + dtype=np.uint8, + ) + label = aux_name if not row_prefix else f"{row_prefix} {aux_name}" + 
labeled_panels.append(add_text_label(preview, label)) + + return np.ascontiguousarray(np.hstack(labeled_panels), dtype=np.uint8) - return loss_fn(prediction, ground_truth) + def _make_feature_encoder_guide_preview_image(self, aux_outputs_dict, *, train_aux_outputs_dict=None): + val_row = self._make_feature_encoder_guide_preview_row(aux_outputs_dict, row_prefix="Val") + train_row = self._make_feature_encoder_guide_preview_row(train_aux_outputs_dict, row_prefix="Train") + return self._stack_preview_rows([val_row, train_row]) + + @staticmethod + def _save_preview_image(preview_image, save_path): + if preview_image is None or save_path is None: + return + out_path = Path(save_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + Image.fromarray(np.ascontiguousarray(preview_image, dtype=np.uint8)).save(out_path) + + def _make_aux_preview_image(self, aux_outputs_dict): + if not aux_outputs_dict: + return None + + aux_name = sorted(aux_outputs_dict.keys())[0] + aux_tensor = aux_outputs_dict[aux_name] + if not isinstance(aux_tensor, torch.Tensor): + return None + if aux_tensor.dtype == torch.bfloat16: + aux_tensor = aux_tensor.float() + + aux_np = aux_tensor.detach().cpu().numpy() + while aux_np.ndim > 4: + aux_np = aux_np[0] + if aux_np.ndim == 4: + aux_np = aux_np[0] + + if aux_np.ndim == 2: + preview_slice = aux_np + is_2d = True + else: + preview_slice = aux_np[max(aux_np.shape[0] // 2, 0)] + is_2d = False + + value_range = _compute_display_value_range( + aux_np, + is_2d_run=is_2d, + task_name=aux_name, + task_cfg={}, + ) + preview = convert_slice_to_bgr(preview_slice, value_range=value_range) + return np.ascontiguousarray(preview, dtype=np.uint8) + + def _build_debug_media_payload(self, debug_preview_image=None, guide_preview_image=None): + payload = {} + for key, image in (("debug_image", debug_preview_image), ("debug_guide_image", guide_preview_image)): + if image is None: + continue + image_to_log = image + if image_to_log.ndim == 3 and image_to_log.shape[2] 
== 3: + image_to_log = image_to_log[..., ::-1] + payload[key] = np.ascontiguousarray(image_to_log) + return payload + + @staticmethod + def _attach_debug_media_to_wandb_metrics(metrics, media_payload, wandb_module): + if "debug_image" in media_payload: + metrics["debug_image"] = wandb_module.Image(media_payload["debug_image"]) + if "debug_guide_image" in media_payload: + metrics["debug_guide_image"] = wandb_module.Image(media_payload["debug_guide_image"]) + return metrics # --- losses ---- # def _build_loss(self): @@ -626,6 +1237,17 @@ def _get_optimizer(self, model): 'learning_rate': self.mgr.initial_lr, 'weight_decay': self.mgr.weight_decay } + tr_configs = getattr(self.mgr, 'tr_configs', {}) or {} + if 'betas' in tr_configs: + optimizer_config['betas'] = tuple(tr_configs['betas']) + if 'eps' in tr_configs: + optimizer_config['eps'] = float(tr_configs['eps']) + if 'momentum' in tr_configs: + optimizer_config['momentum'] = float(tr_configs['momentum']) + if 'dampening' in tr_configs: + optimizer_config['dampening'] = float(tr_configs['dampening']) + if 'nesterov' in tr_configs: + optimizer_config['nesterov'] = bool(tr_configs['nesterov']) return create_optimizer(optimizer_config, model) @@ -890,11 +1512,13 @@ def _initialize_training(self): print("\nDetected S3 paths in configuration") setup_multiprocessing_for_s3() - # the is_training flag forces the dataset to perform augmentations - # we put augmentations in the dataset class so we can use the __getitem__ method - # for free multi processing of augmentations , alternatively you can perform on-device within the train loop + # By default the training dataset applies augmentations in __getitem__ so + # workers can parallelize them. When augment_on_device=True we preserve the + # same pipeline here and disable dataset-side augmentation to avoid applying + # transforms twice. 
+ stage_start = perf_counter() train_dataset = self._configure_dataset(is_training=True) - # Keep a handle to the training dataset for on-device augmentation + self._record_startup_timing("train_dataset_init", perf_counter() - stage_start) self._train_dataset = train_dataset if self._profile_augmentations: self._augmentation_names = getattr(train_dataset, '_augmentation_names', None) @@ -911,7 +1535,9 @@ def _initialize_training(self): raise ValueError(f"Could not determine data format for validation directory: {val_mgr.data_path}") val_mgr.data_format = detected_val_fmt + stage_start = perf_counter() val_dataset = self._build_dataset_for_mgr(val_mgr, is_training=False) + self._record_startup_timing("val_dataset_init", perf_counter() - stage_start) print(f"Using {val_mgr.data_format} dataset format (validation from --val-dir)") else: # Reuse same source for validation without re-running expensive image checks @@ -919,15 +1545,21 @@ def _initialize_training(self): val_mgr = deepcopy(self.mgr) setattr(val_mgr, 'skip_image_checks', True) + stage_start = perf_counter() val_dataset = self._build_dataset_for_mgr(val_mgr, is_training=False) + self._record_startup_timing("val_dataset_init", perf_counter() - stage_start) - + stage_start = perf_counter() autodetect_sample = self._prepare_sample(train_dataset[0], is_training=True) if len(train_dataset) > 0 else None + self._record_startup_timing("sample_autodetect", perf_counter() - stage_start) self.mgr.auto_detect_channels(dataset=train_dataset, sample=autodetect_sample) + stage_start = perf_counter() model = self._build_model() + self._record_startup_timing("model_build", perf_counter() - stage_start) self._ds_scales = None self._ds_weights = None + stage_start = perf_counter() optimizer = self._get_optimizer(model) scheduler, is_per_iteration_scheduler = self._get_scheduler(optimizer) self._is_per_iteration_scheduler = is_per_iteration_scheduler # Store for later use @@ -974,23 +1606,21 @@ def _initialize_training(self): 
print("Using CUDA AMP with bfloat16 (GradScaler disabled)") scaler = self._get_scaler(self.device.type, use_amp=use_amp, amp_dtype=self.amp_dtype) + self._record_startup_timing("optimizer_scheduler_scaler_init", perf_counter() - stage_start) + stage_start = perf_counter() train_dataloader, val_dataloader, train_indices, val_indices = self._configure_dataloaders(train_dataset, val_dataset) + self._record_startup_timing("dataloader_build", perf_counter() - stage_start) - # Wrap model with DDP if distributed - if self.is_distributed: - if self.device.type == 'cuda': - model = DDP(model, device_ids=[self.assigned_gpu_id], output_device=self.assigned_gpu_id, find_unused_parameters=False) - else: - model = DDP(model) - os.makedirs(self.mgr.ckpt_out_base, exist_ok=True) - model_ckpt_dir = os.path.join(self.mgr.ckpt_out_base, self.mgr.model_name) + ckpt_out_base = str(self.mgr.ckpt_out_base) + os.makedirs(ckpt_out_base, exist_ok=True) + model_ckpt_dir = os.path.join(ckpt_out_base, self.mgr.model_name) os.makedirs(model_ckpt_dir, exist_ok=True) now = datetime.now() date_str = now.strftime('%m%d%y') time_str = now.strftime('%H%M') - ckpt_dir = os.path.join('checkpoints', f"{self.mgr.model_name}_{date_str}{time_str}") + ckpt_dir = os.path.join(ckpt_out_base, f"{self.mgr.model_name}_{date_str}{time_str}") os.makedirs(ckpt_dir, exist_ok=True) loss_overrides = self._capture_loss_overrides() @@ -1033,12 +1663,24 @@ def _initialize_training(self): self._ds_scales = None self._ds_weights = None loss_fns = self._build_loss() + self._initialize_ema_model(model) + raw_model = self._unwrap_model(model) + if hasattr(raw_model, "_compile_guidance_submodules"): + stage_start = perf_counter() + compiled_guide_modules = raw_model._compile_guidance_submodules(device_type=self.device.type) + if compiled_guide_modules: + self._record_startup_timing("guide_submodule_compile", perf_counter() - stage_start) + if not self.is_distributed or self.rank == 0: + print("Compiled guide submodules: " + ", 
".join(compiled_guide_modules)) - if self.device.type == 'cuda': - try: - model = torch.compile(model) - except Exception as e: - print(f"torch.compile failed; continuing without compile. Reason: {e}") + compile_policy = self._resolve_compile_policy(model) + if compile_policy == "module": + model = self._maybe_compile_model(model) + stage_start = perf_counter() + model = self._wrap_model_for_distributed_training(model) + self._record_startup_timing("ddp_wrap", perf_counter() - stage_start) + if compile_policy == "ddp_wrapper": + model = self._maybe_compile_model(model) return { 'model': model, @@ -1115,18 +1757,45 @@ def _initialize_wandb(self, train_dataset, val_dataset, train_indices, val_indic def _extract_targets(self, data_dict): # Only include tensor targets; skip metadata and lists (e.g., 'regression_keys') + non_blocking = self.device.type == 'cuda' return { - k: v.to(self.device) + k: v.to(self.device, non_blocking=non_blocking) for k, v in data_dict.items() if k not in ["image", "patch_info", "is_unlabeled", "regression_keys"] and hasattr(v, "to") } def _get_model_outputs(self, model, data_dict): - inputs = data_dict["image"].to(self.device) + non_blocking = self.device.type == 'cuda' + inputs = data_dict["image"].to(self.device, non_blocking=non_blocking) targets_dict = self._extract_targets(data_dict) + if self.device.type == 'cuda' and self._volume_dilate_by_name: + padding_mask = data_dict["padding_mask"].to(self.device, non_blocking=True) + volume_names = data_dict["patch_info"]["volume_name"] + for i, volume_name in enumerate(volume_names): + distance = self._volume_dilate_by_name.get(volume_name) + if distance in (None, 0): + continue + for target_name in self.mgr.targets: + if target_name not in targets_dict: + continue + target_tensor = targets_dict[target_name] + target_info = self.mgr.targets.get(target_name, {}) + ignore_label = target_info.get("ignore_label", target_info.get("ignore_index", target_info.get("ignore_value"))) + 
targets_dict[target_name][i:i + 1] = dilate_label_batch_with_cucim( + target_tensor[i:i + 1], + padding_mask[i:i + 1], + distance, + ignore_label, + ) - outputs = model(inputs) + raw_model = self._unwrap_model(model) + if getattr(raw_model, "guide_enabled", False): + outputs, aux_outputs = model(inputs, return_aux=True) + self._current_aux_outputs = aux_outputs + else: + outputs = model(inputs) + self._current_aux_outputs = {} # If deep supervision is enabled, prepare lists of downsampled targets if getattr(self.mgr, 'enable_deep_supervision', False): @@ -1193,7 +1862,7 @@ def _train_step(self, model, data_dict, loss_fns, use_amp, autocast_ctx, epoch, else: dd = {} for k, v in data_dict.items(): - dd[k] = v.to(self.device) if isinstance(v, torch.Tensor) else v + dd[k] = v.to(self.device, non_blocking=(self.device.type == 'cuda')) if isinstance(v, torch.Tensor) else v # Apply transforms per-sample (transforms expect unbatched (C, ...) tensors) try: @@ -1205,25 +1874,43 @@ def _train_step(self, model, data_dict, loss_fns, use_amp, autocast_ctx, epoch, self._accumulate_aug_timers(data_for_forward) + should_time_first_step = self.startup_timing and epoch == 0 and step == 0 + if should_time_first_step: + step_stage_start = perf_counter() with autocast_ctx: inputs, targets_dict, outputs = self._get_model_outputs(model, data_for_forward) + if should_time_first_step: + self._record_startup_timing("first_forward", perf_counter() - step_stage_start) + step_stage_start = perf_counter() total_loss, task_losses = self._compute_train_loss(outputs, targets_dict, loss_fns) + if should_time_first_step: + self._record_startup_timing("first_loss_compute", perf_counter() - step_stage_start) # Handle gradient accumulation, clipping, and optimizer step # Scale loss by accumulation steps to maintain same effective batch size scaled_loss = total_loss / grad_accumulate_n # backward + if should_time_first_step: + step_stage_start = perf_counter() scaler.scale(scaled_loss).backward() + if 
should_time_first_step: + self._record_startup_timing("first_backward", perf_counter() - step_stage_start) optimizer_stepped = False if (step + 1) % grad_accumulate_n == 0 or (step + 1) == num_iters: + should_time_optimizer = self.startup_timing and not self._first_optimizer_timing_recorded + if should_time_optimizer: + step_stage_start = perf_counter() scaler.unscale_(optimizer) grad_clip = getattr(self.mgr, 'gradient_clip', 12.0) torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) scaler.step(optimizer) scaler.update() optimizer_stepped = True + if should_time_optimizer: + self._record_startup_timing("first_optimizer_step", perf_counter() - step_stage_start) + self._first_optimizer_timing_recorded = True return total_loss, task_losses, inputs, targets_dict, outputs, optimizer_stepped @@ -1276,6 +1963,13 @@ def _compute_task_losses(self, outputs, targets_dict, loss_fns, *, apply_task_we total_loss = total_loss + weighted_loss.to(total_loss.dtype) task_losses[t_name] = weighted_loss.detach().cpu().item() + if self.guide_loss_weight > 0.0: + guide_loss = self._compute_guide_alignment_loss(targets_dict) + if guide_loss is not None: + weighted_guide_loss = self.guide_loss_weight * guide_loss + total_loss = total_loss + weighted_guide_loss.to(total_loss.dtype) + task_losses["guide_mask"] = weighted_guide_loss.detach().cpu().item() + return total_loss, task_losses def _compute_train_loss(self, outputs, targets_dict, loss_fns): @@ -1349,11 +2043,10 @@ def _prepare_metrics_for_logging(self, epoch, step, epoch_losses, current_lr=Non metrics = {"epoch": epoch, "step": step} - # Add training losses - for t_name in self.mgr.targets: - if t_name in epoch_losses and len(epoch_losses[t_name]) > 0: - # Use recent average for training losses - metrics[f"train_loss_{t_name}"] = np.mean(epoch_losses[t_name][-100:]) + # Add training losses, including auxiliary entries such as guide_mask + for t_name, losses in epoch_losses.items(): + if len(losses) > 0: + 
metrics[f"train_loss_{t_name}"] = np.mean(losses[-100:]) # Add total training loss if epoch_losses: @@ -1365,18 +2058,20 @@ def _prepare_metrics_for_logging(self, epoch, step, epoch_losses, current_lr=Non if current_lr is not None: metrics["learning_rate"] = current_lr - # Add validation losses if provided + # Add validation losses if provided, including auxiliary entries such as guide_mask if val_losses is not None: total_val_loss = 0.0 - for t_name in self.mgr.targets: - if t_name in val_losses and len(val_losses[t_name]) > 0: - val_avg = np.mean(val_losses[t_name]) + num_val_entries = 0 + for t_name, losses in val_losses.items(): + if len(losses) > 0: + val_avg = np.mean(losses) metrics[f"val_loss_{t_name}"] = val_avg total_val_loss += val_avg + num_val_entries += 1 # Add total validation loss - if self.mgr.targets: - metrics["val_loss_total"] = total_val_loss / len(self.mgr.targets) + if num_val_entries > 0: + metrics["val_loss_total"] = total_val_loss / num_val_entries return metrics @@ -1462,6 +2157,8 @@ def train(self): train_sample_input = None train_sample_targets = None train_sample_outputs = None + train_sample_aux_outputs = None + train_sample_raw_aux_outputs = None print(f"Using optimizer : {optimizer.__class__.__name__}") print(f"Using scheduler : {scheduler.__class__.__name__} (per-iteration: {is_per_iteration_scheduler})") @@ -1471,7 +2168,10 @@ def train(self): if i % grad_accumulate_n == 0: optimizer.zero_grad(set_to_none=True) + fetch_start = perf_counter() if (self.startup_timing and epoch == 0 and i == 0) else None data_dict = next(train_iter) + if fetch_start is not None: + self._record_startup_timing("first_batch_fetch", perf_counter() - fetch_start) global_step += 1 # Setup autocast context (dtype resolved based on CLI/config) @@ -1542,10 +2242,21 @@ def train(self): train_sample_outputs[t_name] = p_val[0][b_idx: b_idx + 1] else: train_sample_outputs[t_name] = p_val[b_idx: b_idx + 1] + train_sample_raw_aux_outputs = 
self._slice_aux_outputs_for_debug( + getattr(self, "_current_aux_outputs", {}), + b_idx, + ) + train_sample_aux_outputs = self._prepare_aux_debug_outputs( + train_sample_raw_aux_outputs, + train_sample_input, + ) if optimizer_stepped and is_per_iteration_scheduler: scheduler.step() + if self.startup_timing and self._first_optimizer_timing_recorded and not self._startup_timing_logged: + self._flush_startup_timing(prefix="startup_and_first_step") + if pbar is not None: loss_str = " | ".join([f"{t}: {np.mean(epoch_losses[t][-100:]):.4f}" for t in epoch_losses.keys() if len(epoch_losses[t]) > 0]) @@ -1607,6 +2318,9 @@ def train(self): with torch.no_grad(): val_losses = {t_name: [] for t_name in self.mgr.targets} debug_preview_image = None + debug_guide_preview_image = None + debug_preview_fallback = None + debug_preview_candidates = [] # Initialize evaluation metrics evaluation_metrics = self._initialize_evaluation_metrics() @@ -1658,88 +2372,21 @@ def train(self): continue metric.update(pred=pred_val, gt=gt_val, mask=mask_tensor) - if i == 0: - # Find first non-zero sample for debug visualization, but save even if all zeros - b_idx = 0 - found_non_zero = False - - first_target_any = next(iter(targets_dict.values())) - first_target = first_target_any[0] if isinstance(first_target_any, (list, tuple)) else first_target_any - if torch.any(first_target[0] != 0): - found_non_zero = True - else: - # Look for a non-zero sample - for b in range(first_target.shape[0]): - if torch.any(first_target[b] != 0): - b_idx = b - found_non_zero = True - break - - if True: # Was: if found_non_zero: - # Slicing shape: [1, c, z, y, x ] - inputs_first = inputs[b_idx: b_idx + 1] - - targets_dict_first_all = {} - for t_name, t_val in targets_dict.items(): - if isinstance(t_val, (list, tuple)): - targets_dict_first_all[t_name] = t_val[0][b_idx: b_idx + 1] - else: - targets_dict_first_all[t_name] = t_val[b_idx: b_idx + 1] - - outputs_dict_first = {} - for t_name, p_val in outputs.items(): - if 
isinstance(p_val, (list, tuple)): - outputs_dict_first[t_name] = p_val[0][b_idx: b_idx + 1] - else: - outputs_dict_first[t_name] = p_val[b_idx: b_idx + 1] - - # Use human-friendly 1-based epoch numbering in debug image filenames - debug_img_path = f"{ckpt_dir}/{self.mgr.model_name}_debug_epoch{epoch + 1}.gif" - - # handle skel data from skeleton-based losses - skeleton_dict = None - train_skeleton_dict = None - if 'skel' in targets_dict_first_all: - skeleton_dict = {'segmentation': targets_dict_first_all.get('skel')} - # Check if train_sample_targets_all exists (from earlier training step) - if 'train_sample_targets_all' in locals() and train_sample_targets_all and 'skel' in train_sample_targets_all: - train_skeleton_dict = {'segmentation': train_sample_targets_all.get('skel')} - - targets_dict_first = {} - for t_name, t_tensor in targets_dict_first_all.items(): - if t_name not in ['skel', 'is_unlabeled']: - targets_dict_first[t_name] = t_tensor - - # Check for custom debug visualization (e.g., self-supervised trainers) - custom_debug_method = getattr(self, '_save_lejepa_debug', None) - if custom_debug_method is not None: - saved_path = custom_debug_method(debug_img_path, epoch) - if saved_path: - debug_gif_history.append((epoch, saved_path)) - else: - # Get unlabeled debug samples if available (for semi-supervised trainers) - unlabeled_input = getattr(self, '_debug_unlabeled_input', None) - unlabeled_pseudo = getattr(self, '_debug_unlabeled_pseudo_label', None) - unlabeled_pred = getattr(self, '_debug_unlabeled_student_pred', None) - - _, debug_preview_image = save_debug( - input_volume=inputs_first, - targets_dict=targets_dict_first, - outputs_dict=outputs_dict_first, - tasks_dict=self.mgr.targets, - # dictionary, e.g. 
{"sheet": {"activation":"sigmoid"}, "normals": {"activation":"none"}} - epoch=epoch, - save_path=debug_img_path, - train_input=train_sample_input, - train_targets_dict=train_sample_targets, - train_outputs_dict=train_sample_outputs, - skeleton_dict=skeleton_dict, - train_skeleton_dict=train_skeleton_dict, - unlabeled_input=unlabeled_input, - unlabeled_pseudo_dict=unlabeled_pseudo, - unlabeled_outputs_dict=unlabeled_pred - ) - debug_gif_history.append((epoch, debug_img_path)) + if debug_preview_image is None: + b_idx, found_non_zero = self._select_debug_sample_index_from_targets(targets_dict) + snapshot = self._capture_validation_debug_snapshot( + inputs, + targets_dict, + outputs, + getattr(self, "_current_aux_outputs", {}), + b_idx, + ) + + if i == 0: + debug_preview_fallback = snapshot + + if found_non_zero: + debug_preview_candidates.append(snapshot) loss_str = " | ".join([f"{t}: {np.mean(val_losses[t]):.4f}" for t in self.mgr.targets if len(val_losses[t]) > 0]) @@ -1747,6 +2394,65 @@ def train(self): del outputs, inputs, targets_dict + selected_debug_snapshot = self._select_validation_debug_snapshot_for_epoch( + debug_preview_candidates, + debug_preview_fallback, + epoch, + ) + + if debug_preview_image is None and selected_debug_snapshot is not None: + inputs_first = selected_debug_snapshot["input"] + targets_dict_first_all = selected_debug_snapshot["targets_all"] + outputs_dict_first = selected_debug_snapshot["outputs"] + raw_aux_outputs_dict_first = selected_debug_snapshot["raw_aux_outputs"] + aux_outputs_dict_first = self._prepare_aux_debug_outputs( + raw_aux_outputs_dict_first, + inputs_first, + ) + debug_img_path = f"{ckpt_dir}/{self.mgr.model_name}_debug_epoch{epoch + 1}.gif" + guide_debug_img_path = f"{ckpt_dir}/{self.mgr.model_name}_guide_epoch{epoch + 1}.png" + skeleton_dict = None + train_skeleton_dict = None + if 'skel' in targets_dict_first_all: + skeleton_dict = {'segmentation': targets_dict_first_all.get('skel')} + if 'train_sample_targets_all' in 
locals() and train_sample_targets_all and 'skel' in train_sample_targets_all: + train_skeleton_dict = {'segmentation': train_sample_targets_all.get('skel')} + targets_dict_first = {} + for t_name, t_tensor in targets_dict_first_all.items(): + if t_name not in ['skel', 'is_unlabeled']: + targets_dict_first[t_name] = t_tensor + save_debug_media = bool(getattr(self.mgr, 'save_gifs', True)) + unlabeled_input = getattr(self, '_debug_unlabeled_input', None) + unlabeled_pseudo = getattr(self, '_debug_unlabeled_pseudo_label', None) + unlabeled_pred = getattr(self, '_debug_unlabeled_student_pred', None) + _, debug_preview_image = save_debug( + input_volume=inputs_first, + targets_dict=targets_dict_first, + outputs_dict=outputs_dict_first, + aux_outputs_dict=aux_outputs_dict_first, + tasks_dict=self.mgr.targets, + epoch=epoch, + save_path=debug_img_path, + train_input=train_sample_input, + train_targets_dict=train_sample_targets, + train_outputs_dict=train_sample_outputs, + train_aux_outputs_dict=train_sample_aux_outputs, + skeleton_dict=skeleton_dict, + train_skeleton_dict=train_skeleton_dict, + unlabeled_input=unlabeled_input, + unlabeled_pseudo_dict=unlabeled_pseudo, + unlabeled_outputs_dict=unlabeled_pred, + save_media=save_debug_media, + ) + debug_guide_preview_image = self._make_feature_encoder_guide_preview_image( + raw_aux_outputs_dict_first, + train_aux_outputs_dict=train_sample_raw_aux_outputs, + ) + if save_debug_media and debug_guide_preview_image is not None: + self._save_preview_image(debug_guide_preview_image, guide_debug_img_path) + if save_debug_media: + debug_gif_history.append((epoch, debug_img_path)) + print(f"\n[Validation] Epoch {epoch + 1} summary:") total_val_loss = 0.0 for t_name in self.mgr.targets: @@ -1783,13 +2489,15 @@ def train(self): import wandb - if debug_preview_image is not None: - preview_to_log = debug_preview_image - if preview_to_log.ndim == 3 and preview_to_log.shape[2] == 3: - # Convert BGR (OpenCV) to RGB for wandb - preview_to_log 
= preview_to_log[..., ::-1] - preview_to_log = np.ascontiguousarray(preview_to_log) - val_metrics["debug_image"] = wandb.Image(preview_to_log) + media_payload = self._build_debug_media_payload( + debug_preview_image=debug_preview_image, + guide_preview_image=debug_guide_preview_image, + ) + val_metrics = self._attach_debug_media_to_wandb_metrics( + val_metrics, + media_payload, + wandb, + ) wandb.log(val_metrics) diff --git a/vesuvius/src/vesuvius/models/utils.py b/vesuvius/src/vesuvius/models/utils.py index 773568483..0eb8f6385 100644 --- a/vesuvius/src/vesuvius/models/utils.py +++ b/vesuvius/src/vesuvius/models/utils.py @@ -35,6 +35,8 @@ def __init__(self, neg_slope: float = 1e-2): def __call__(self, module): if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor) and not module.weight.requires_grad: + return module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) if module.bias is not None: module.bias = nn.init.constant_(module.bias, 0) @@ -46,6 +48,8 @@ def __init__(self, gain: int = 1): def __call__(self, module): if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): + if hasattr(module, "weight") and isinstance(module.weight, torch.Tensor) and not module.weight.requires_grad: + return module.weight = nn.init.xavier_uniform_(module.weight, self.gain) if module.bias is not None: module.bias = nn.init.constant_(module.bias, 0) diff --git a/vesuvius/src/vesuvius/utils/plotting.py b/vesuvius/src/vesuvius/utils/plotting.py index 235e30e6b..b0dba03e7 100644 --- a/vesuvius/src/vesuvius/utils/plotting.py +++ b/vesuvius/src/vesuvius/utils/plotting.py @@ -205,6 +205,14 @@ def add_text_label(img, text): return np.array(pil_img, dtype=np.uint8) +def _format_aux_panel_label(name: str) -> str: 
+ raw = str(name).strip() + if raw.startswith("guide_"): + target_name = raw[len("guide_"):].replace("_", " ") + return f"Guide {target_name}" + return raw.replace("_", " ") + + def _apply_activation(array_np: np.ndarray, activation: Optional[str], *, is_surface: bool) -> np.ndarray: if activation is None or str(activation).lower() in {"none", "identity"} or is_surface: return array_np @@ -323,6 +331,77 @@ def convert_slice_to_bgr( raise ValueError(f"Expected 2D or 3D array, got shape {slice_2d_or_3d.shape}") +def _scale_with_fixed_value_range_to_8bit( + arr_np: np.ndarray, + value_range: Optional[tuple[float, float]] = None, +) -> np.ndarray: + """Scale an array to uint8 using a fixed global range when provided.""" + arr = arr_np.astype(np.float32, copy=False) + if value_range is None: + return minmax_scale_to_8bit(arr) + + lower, upper = value_range + lower = float(lower) + upper = float(upper) + if not np.isfinite(lower) or not np.isfinite(upper) or upper <= lower: + return minmax_scale_to_8bit(arr) + + arr = np.clip(arr, lower, upper) + arr = (arr - lower) / (upper - lower) * 255.0 + return np.clip(arr, 0, 255).astype(np.uint8) + + +def _gray_volume_to_bgr(volume_zhw: np.ndarray, value_range: Optional[tuple[float, float]]) -> np.ndarray: + gray_8u = _scale_with_fixed_value_range_to_8bit(volume_zhw, value_range) + return np.repeat(gray_8u[..., None], 3, axis=-1) + + +def _render_3d_volume_to_bgr( + arr_np: np.ndarray, + *, + task_name: str | None = None, + task_cfg: Dict | None = None, + value_range: Optional[tuple[float, float]] = None, +) -> np.ndarray: + """Render a full 3D tensor/volume to a Z-stacked uint8 BGR panel volume.""" + task_cfg = task_cfg or {} + task_type = task_cfg.get("visualization") or task_cfg.get("type") + is_surface = ( + (task_name is not None and task_name.endswith("surface_frame")) + or task_type == "surface_frame" + or (arr_np.ndim >= 4 and arr_np.shape[0] == 9) + ) + + if arr_np.ndim == 3: + return _gray_volume_to_bgr(arr_np, 
value_range) + + if arr_np.ndim != 4: + raise ValueError(f"Expected 3D volume or 4D channel-first tensor, got shape {arr_np.shape}") + + if is_surface: + return np.stack( + [_surface_frame_to_bgr(arr_np[:, z_idx, :, :]) for z_idx in range(arr_np.shape[1])], + axis=0, + ) + + if arr_np.shape[0] == 1: + return _gray_volume_to_bgr(arr_np[0], value_range) + + if arr_np.shape[0] == 3: + rgb_zhwc = np.transpose(arr_np, (1, 2, 3, 0)) + return _scale_with_fixed_value_range_to_8bit(rgb_zhwc, value_range) + + if arr_np.shape[0] == 2: + return _gray_volume_to_bgr(arr_np[1], value_range) + + is_affinity = ( + task_type == "affinity" + or (task_name is not None and "affinity" in task_name.lower()) + ) + reduced = arr_np.mean(axis=0) if is_affinity else arr_np[0] + return _gray_volume_to_bgr(reduced, value_range) + + def save_debug( input_volume: torch.Tensor, # shape [1, C, Z, H, W] for 3D or [1, C, H, W] for 2D targets_dict: dict, # e.g. {"sheet": tensor([1, Z, H, W]), "normals": tensor([3, Z, H, W])} @@ -335,9 +414,12 @@ def save_debug( train_input: torch.Tensor = None, # Optional train sample input train_targets_dict: dict = None, # Optional train sample targets train_outputs_dict: dict = None, # Optional train sample outputs + aux_outputs_dict: dict = None, # Optional display-only aux outputs + train_aux_outputs_dict: dict = None,# Optional display-only train aux outputs skeleton_dict: dict = None, # Optional skeleton data for visualization train_skeleton_dict: dict = None, # Optional train skeleton data apply_activation: bool = True, # Whether to apply activation functions + save_media: bool = True, # Whether to write the debug image/GIF to disk # Unlabeled sample visualization for semi-supervised training unlabeled_input: torch.Tensor = None, # Optional unlabeled sample input unlabeled_pseudo_dict: dict = None, # Teacher predictions (pseudo-labels) @@ -401,10 +483,20 @@ def save_debug( preds_np[t_name] = arr_np + aux_outputs_np = {} + for aux_name, aux_tensor in 
(aux_outputs_dict or {}).items(): + if aux_tensor.dtype == torch.bfloat16: + aux_tensor = aux_tensor.float() + arr_np = aux_tensor.cpu().numpy() + while arr_np.ndim > (3 if is_2d else 4): + arr_np = arr_np[0] + aux_outputs_np[aux_name] = arr_np + # Process train data if provided train_inp_np = None train_targets_np = {} train_preds_np = {} + train_aux_outputs_np = {} if train_input is not None and train_targets_dict is not None and train_outputs_dict is not None: # Convert BFloat16 to Float32 before numpy conversion @@ -444,6 +536,14 @@ def save_debug( train_preds_np[t_name] = arr_np + for aux_name, aux_tensor in (train_aux_outputs_dict or {}).items(): + if aux_tensor.dtype == torch.bfloat16: + aux_tensor = aux_tensor.float() + arr_np = aux_tensor.cpu().numpy() + while arr_np.ndim > (3 if is_2d else 4): + arr_np = arr_np[0] + train_aux_outputs_np[aux_name] = arr_np + # Process unlabeled data if provided (for semi-supervised training visualization) unlabeled_inp_np = None unlabeled_pseudo_np = {} @@ -515,6 +615,15 @@ def save_debug( ) for t_name, p_arr in preds_np.items() } + aux_ranges = { + aux_name: _compute_display_value_range( + aux_arr, + is_2d_run=is_2d, + task_name=aux_name, + task_cfg={}, + ) + for aux_name, aux_arr in aux_outputs_np.items() + } train_input_range = ( _compute_display_value_range(train_inp_np, is_2d_run=is_2d) if train_inp_np is not None else None @@ -537,6 +646,15 @@ def save_debug( ) for t_name, p_arr in train_preds_np.items() } + train_aux_ranges = { + aux_name: _compute_display_value_range( + aux_arr, + is_2d_run=is_2d, + task_name=aux_name, + task_cfg={}, + ) + for aux_name, aux_arr in train_aux_outputs_np.items() + } unlabeled_input_range = ( _compute_display_value_range(unlabeled_inp_np, is_2d_run=is_2d) if unlabeled_inp_np is not None else None @@ -586,6 +704,18 @@ def _pad_rows_to_uniform_width(rows_list: list[np.ndarray]) -> list[np.ndarray]: # Val row: input, targets (including skels), preds val_imgs = 
[add_text_label(convert_slice_to_bgr(inp_np, value_range=val_input_range), "Val Input")] + for aux_name in sorted(aux_outputs_np.keys()): + aux_arr = aux_outputs_np[aux_name] + aux_slice = aux_arr[0] if aux_arr.ndim == 3 and aux_arr.shape[0] == 1 else aux_arr + val_imgs.append( + add_text_label( + convert_slice_to_bgr( + aux_slice, + value_range=aux_ranges.get(aux_name), + ), + _format_aux_panel_label(aux_name), + ) + ) # Show all targets (including skeleton data) for t_name in sorted(targets_np.keys()): @@ -625,6 +755,18 @@ def _pad_rows_to_uniform_width(rows_list: list[np.ndarray]) -> list[np.ndarray]: # Train row if available if train_inp_np is not None: train_imgs = [add_text_label(convert_slice_to_bgr(train_inp_np, value_range=train_input_range), "Train Input")] + for aux_name in sorted(train_aux_outputs_np.keys()): + aux_arr = train_aux_outputs_np[aux_name] + aux_slice = aux_arr[0] if aux_arr.ndim == 3 and aux_arr.shape[0] == 1 else aux_arr + train_imgs.append( + add_text_label( + convert_slice_to_bgr( + aux_slice, + value_range=train_aux_ranges.get(aux_name), + ), + _format_aux_panel_label(aux_name), + ) + ) # Show all train targets (including skeleton data) for t_name in sorted(train_targets_np.keys()): @@ -704,187 +846,145 @@ def _pad_rows_to_uniform_width(rows_list: list[np.ndarray]) -> list[np.ndarray]: # Stack rows and save rows = _pad_rows_to_uniform_width(rows) final_img = np.vstack(rows) - out_dir = Path(save_path).parent - out_dir.mkdir(parents=True, exist_ok=True) - print(f"[Epoch {epoch}] Saving PNG to: {save_path}") - # Use PIL for saving - Image.fromarray(final_img).save(save_path) - preview_img = np.ascontiguousarray(final_img, dtype=np.uint8) + if save_media: + out_dir = Path(save_path).parent + out_dir.mkdir(parents=True, exist_ok=True) + print(f"[Epoch {epoch}] Saving PNG to: {save_path}") + # Use PIL for saving + Image.fromarray(final_img).save(save_path) return None, preview_img else: # Build frames for 3D GIF frames = [] - preview_frame = 
None z_dim = inp_np.shape[0] if inp_np.ndim == 3 else inp_np.shape[1] mid_z_idx = max(z_dim // 2, 0) - for z_idx in range(z_dim): + val_input_bgr = _render_3d_volume_to_bgr(inp_np, value_range=val_input_range) + target_panel_volumes = { + t_name: _render_3d_volume_to_bgr( + t_arr, + task_name=t_name, + task_cfg=tasks_dict.get(t_name, {}), + value_range=target_ranges.get(t_name), + ) + for t_name, t_arr in targets_np.items() + } + pred_panel_volumes = { + t_name: _render_3d_volume_to_bgr( + p_arr, + task_name=t_name, + task_cfg=tasks_dict.get(t_name, {}), + value_range=pred_ranges.get(t_name), + ) + for t_name, p_arr in preds_np.items() + } + aux_panel_volumes = { + aux_name: _render_3d_volume_to_bgr( + aux_arr, + value_range=aux_ranges.get(aux_name), + ) + for aux_name, aux_arr in aux_outputs_np.items() + } + + train_input_bgr = ( + _render_3d_volume_to_bgr(train_inp_np, value_range=train_input_range) + if train_inp_np is not None else None + ) + train_target_panel_volumes = { + t_name: _render_3d_volume_to_bgr( + t_arr, + task_name=t_name, + task_cfg=tasks_dict.get(t_name, {}), + value_range=train_target_ranges.get(t_name), + ) + for t_name, t_arr in train_targets_np.items() + } + train_pred_panel_volumes = { + t_name: _render_3d_volume_to_bgr( + p_arr, + task_name=t_name, + task_cfg=tasks_dict.get(t_name, {}), + value_range=train_pred_ranges.get(t_name), + ) + for t_name, p_arr in train_preds_np.items() + } + train_aux_panel_volumes = { + aux_name: _render_3d_volume_to_bgr( + aux_arr, + value_range=train_aux_ranges.get(aux_name), + ) + for aux_name, aux_arr in train_aux_outputs_np.items() + } + + unlabeled_input_bgr = ( + _render_3d_volume_to_bgr(unlabeled_inp_np, value_range=unlabeled_input_range) + if unlabeled_inp_np is not None else None + ) + unlabeled_pseudo_panel_volumes = { + t_name: _render_3d_volume_to_bgr( + p_arr, + task_name=t_name, + task_cfg=tasks_dict.get(t_name, {}), + value_range=unlabeled_pseudo_ranges.get(t_name), + ) + for t_name, p_arr in 
unlabeled_pseudo_np.items() + } + unlabeled_pred_panel_volumes = { + t_name: _render_3d_volume_to_bgr( + p_arr, + task_name=t_name, + task_cfg=tasks_dict.get(t_name, {}), + value_range=unlabeled_pred_ranges.get(t_name), + ) + for t_name, p_arr in unlabeled_preds_np.items() + } + + def _build_3d_frame(z_idx: int) -> np.ndarray: rows = [] - - # Get slices - inp_slice = inp_np[z_idx] if inp_np.ndim == 3 else inp_np[:, z_idx, :, :] - - # Val row - val_imgs = [add_text_label(convert_slice_to_bgr(inp_slice, value_range=val_input_range), "Val Input")] - - # Show all targets (including skeleton data) + + val_imgs = [add_text_label(val_input_bgr[z_idx], "Val Input")] + for aux_name in sorted(aux_outputs_np.keys()): + val_imgs.append(add_text_label(aux_panel_volumes[aux_name][z_idx], _format_aux_panel_label(aux_name))) for t_name in sorted(targets_np.keys()): - gt = targets_np[t_name] - if gt.shape[0] == 1: - gt_slice = gt[0, z_idx, :, :] - else: - gt_slice = gt[:, z_idx, :, :] label = f"Skel {t_name.replace('_skel', '')}" if t_name.endswith('_skel') else f"GT {t_name}" - val_imgs.append( - add_text_label( - convert_slice_to_bgr( - gt_slice, - task_name=t_name, - task_cfg=tasks_dict.get(t_name, {}), - value_range=target_ranges.get(t_name), - ), - label, - ) - ) - - # Show predictions (only for actual model outputs) + val_imgs.append(add_text_label(target_panel_volumes[t_name][z_idx], label)) for t_name in pred_task_names: - pred = preds_np[t_name] - if pred.ndim == 4: - if pred.shape[0] == 1: - pred_slice = pred[0, z_idx, :, :] - else: - pred_slice = pred[:, z_idx, :, :] - else: - pred_slice = pred[z_idx, :, :] - val_imgs.append( - add_text_label( - convert_slice_to_bgr( - pred_slice, - task_name=t_name, - task_cfg=tasks_dict.get(t_name, {}), - value_range=pred_ranges.get(t_name), - ), - f"Pred {t_name}", - ) - ) - + val_imgs.append(add_text_label(pred_panel_volumes[t_name][z_idx], f"Pred {t_name}")) rows.append(np.hstack(val_imgs)) - - # Train row if available - if 
train_inp_np is not None: - train_slice = train_inp_np[z_idx] if train_inp_np.ndim == 3 else train_inp_np[:, z_idx, :, :] - train_imgs = [add_text_label(convert_slice_to_bgr(train_slice, value_range=train_input_range), "Train Input")] - - # Show all train targets (including skeleton data) + + if train_input_bgr is not None: + train_imgs = [add_text_label(train_input_bgr[z_idx], "Train Input")] + for aux_name in sorted(train_aux_outputs_np.keys()): + train_imgs.append(add_text_label(train_aux_panel_volumes[aux_name][z_idx], _format_aux_panel_label(aux_name))) for t_name in sorted(train_targets_np.keys()): - gt = train_targets_np[t_name] - if gt.shape[0] == 1: - gt_slice = gt[0, z_idx, :, :] - else: - gt_slice = gt[:, z_idx, :, :] label = f"Skel {t_name.replace('_skel', '')}" if t_name.endswith('_skel') else f"GT {t_name}" - train_imgs.append( - add_text_label( - convert_slice_to_bgr( - gt_slice, - task_name=t_name, - task_cfg=tasks_dict.get(t_name, {}), - value_range=train_target_ranges.get(t_name), - ), - label, - ) - ) - - # Show train predictions (only for actual model outputs) + train_imgs.append(add_text_label(train_target_panel_volumes[t_name][z_idx], label)) for t_name in pred_task_names: - if t_name in train_preds_np: - pred = train_preds_np[t_name] - if pred.ndim == 4: - if pred.shape[0] == 1: - pred_slice = pred[0, z_idx, :, :] - else: - pred_slice = pred[:, z_idx, :, :] - else: - pred_slice = pred[z_idx, :, :] - train_imgs.append( - add_text_label( - convert_slice_to_bgr( - pred_slice, - task_name=t_name, - task_cfg=tasks_dict.get(t_name, {}), - value_range=train_pred_ranges.get(t_name), - ), - f"Pred {t_name}", - ) - ) - + if t_name in train_pred_panel_volumes: + train_imgs.append(add_text_label(train_pred_panel_volumes[t_name][z_idx], f"Pred {t_name}")) rows.append(np.hstack(train_imgs)) - # Unlabeled row if available (for semi-supervised training) - if unlabeled_inp_np is not None: - unlabeled_slice = unlabeled_inp_np[z_idx] if unlabeled_inp_np.ndim == 
3 else unlabeled_inp_np[:, z_idx, :, :] - unlabeled_imgs = [add_text_label(convert_slice_to_bgr(unlabeled_slice, value_range=unlabeled_input_range), "Unlabeled")] - - # Show pseudo-labels (teacher predictions) + if unlabeled_input_bgr is not None: + unlabeled_imgs = [add_text_label(unlabeled_input_bgr[z_idx], "Unlabeled")] for t_name in sorted(unlabeled_pseudo_np.keys()): - pseudo = unlabeled_pseudo_np[t_name] - if pseudo.ndim == 4: - if pseudo.shape[0] == 1: - pseudo_slice = pseudo[0, z_idx, :, :] - else: - pseudo_slice = pseudo[:, z_idx, :, :] - else: - pseudo_slice = pseudo[z_idx, :, :] - unlabeled_imgs.append( - add_text_label( - convert_slice_to_bgr( - pseudo_slice, - task_name=t_name, - task_cfg=tasks_dict.get(t_name, {}), - value_range=unlabeled_pseudo_ranges.get(t_name), - ), - f"Pseudo {t_name}", - ) - ) - - # Show student predictions on unlabeled data + unlabeled_imgs.append(add_text_label(unlabeled_pseudo_panel_volumes[t_name][z_idx], f"Pseudo {t_name}")) for t_name in pred_task_names: - if t_name in unlabeled_preds_np: - pred = unlabeled_preds_np[t_name] - if pred.ndim == 4: - if pred.shape[0] == 1: - pred_slice = pred[0, z_idx, :, :] - else: - pred_slice = pred[:, z_idx, :, :] - else: - pred_slice = pred[z_idx, :, :] - unlabeled_imgs.append( - add_text_label( - convert_slice_to_bgr( - pred_slice, - task_name=t_name, - task_cfg=tasks_dict.get(t_name, {}), - value_range=unlabeled_pred_ranges.get(t_name), - ), - f"Pred {t_name}", - ) - ) - + if t_name in unlabeled_pred_panel_volumes: + unlabeled_imgs.append(add_text_label(unlabeled_pred_panel_volumes[t_name][z_idx], f"Pred {t_name}")) rows.append(np.hstack(unlabeled_imgs)) - # Stack rows for this frame rows = _pad_rows_to_uniform_width(rows) - frame = np.vstack(rows) - # Ensure frame is uint8 and contiguous - frame = np.ascontiguousarray(frame, dtype=np.uint8) + return np.ascontiguousarray(np.vstack(rows), dtype=np.uint8) + + if not save_media: + return None, _build_3d_frame(mid_z_idx) + + for z_idx in 
range(z_dim): + frame = _build_3d_frame(z_idx) frames.append(frame) - if z_idx == mid_z_idx: - preview_frame = frame.copy() - # Save GIF in a subprocess to avoid crashing main training process on encoder segfaults out_dir = Path(save_path).parent out_dir.mkdir(parents=True, exist_ok=True) @@ -899,19 +999,16 @@ def _pad_rows_to_uniform_width(rows_list: list[np.ndarray]) -> list[np.ndarray]: if proc.is_alive(): proc.terminate() print("Warning: GIF save timed out; skipping debug visualization") - if preview_frame is None and frames: - preview_frame = frames[len(frames) // 2].copy() + preview_frame = frames[mid_z_idx].copy() if frames else None return None, preview_frame if proc.exitcode == 0: print(f"Successfully saved GIF to: {save_path}") - if preview_frame is None and frames: - preview_frame = frames[len(frames) // 2].copy() + preview_frame = frames[mid_z_idx].copy() if frames else None return frames, preview_frame else: print(f"Warning: GIF save failed in subprocess (exit code {proc.exitcode}); skipping") - if preview_frame is None and frames: - preview_frame = frames[len(frames) // 2].copy() + preview_frame = frames[mid_z_idx].copy() if frames else None return None, preview_frame diff --git a/vesuvius/tests/models/build/test_dinovol_local_backbone.py b/vesuvius/tests/models/build/test_dinovol_local_backbone.py new file mode 100644 index 000000000..ad7ddee18 --- /dev/null +++ b/vesuvius/tests/models/build/test_dinovol_local_backbone.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import json +from pathlib import Path + +import pytest +import torch + +from vesuvius.models.build.pretrained_backbones.dinov2 import build_dinov2_backbone +from vesuvius.models.build.pretrained_backbones.dinovol_2_builder import build_dinovol_2_backbone + + +def _tiny_dinovol_model_config() -> dict: + return { + "model_type": "v2", + "input_channels": 1, + "global_crops_size": [16, 16, 16], + "local_crops_size": [16, 16, 16], + "patch_size": [8, 8, 8], + "embed_dim": 48, + 
"depth": 2, + "num_heads": 4, + "num_reg_tokens": 2, + "mlp_ratio": 2.0, + "drop_path_rate": 0.0, + "qkv_fused": True, + } + + +def _write_checkpoint(path: Path, *, include_config: bool, config: dict | None = None) -> None: + backbone = build_dinovol_2_backbone(_tiny_dinovol_model_config()) + teacher_state = {f"backbone.{key}": value.cpu() for key, value in backbone.state_dict().items()} + checkpoint = {"teacher": teacher_state} + if include_config: + checkpoint["config"] = config + torch.save(checkpoint, path) + + +def test_build_dinov2_backbone_from_local_checkpoint_with_embedded_config(tmp_path: Path): + checkpoint_path = tmp_path / "tiny_dinovol.pt" + config = { + "model": _tiny_dinovol_model_config(), + "dataset": { + "global_crop_size": [16, 16, 16], + "local_crop_size": [16, 16, 16], + }, + } + _write_checkpoint(checkpoint_path, include_config=True, config=config) + + backbone = build_dinov2_backbone( + str(checkpoint_path), + input_channels=1, + input_shape=(16, 16, 16), + ) + + x = torch.randn(1, 1, 16, 16, 16) + features = backbone(x) + + assert backbone.patch_embed_size == (8, 8, 8) + assert len(features) == 1 + assert features[0].shape == (1, 48, 2, 2, 2) + + +def test_build_dinov2_backbone_from_local_checkpoint_with_sidecar_config(tmp_path: Path): + checkpoint_path = tmp_path / "tiny_dinovol_no_config.pt" + sidecar_path = tmp_path / "guide_backbone.json" + config = { + "model": _tiny_dinovol_model_config(), + "dataset": { + "global_crop_size": [16, 16, 16], + "local_crop_size": [16, 16, 16], + }, + } + _write_checkpoint(checkpoint_path, include_config=False) + sidecar_path.write_text(json.dumps(config), encoding="utf-8") + + backbone = build_dinov2_backbone( + str(checkpoint_path), + input_channels=1, + input_shape=(16, 16, 16), + config_path=str(sidecar_path), + ) + + x = torch.randn(1, 1, 16, 16, 16) + features = backbone(x) + + assert features[0].shape == (1, 48, 2, 2, 2) + + +def test_build_dinov2_backbone_local_checkpoint_requires_config(tmp_path: 
Path): + checkpoint_path = tmp_path / "missing_config.pt" + _write_checkpoint(checkpoint_path, include_config=False) + + with pytest.raises(ValueError, match="does not contain an embedded config"): + build_dinov2_backbone( + str(checkpoint_path), + input_channels=1, + input_shape=(16, 16, 16), + ) diff --git a/vesuvius/tests/models/build/test_guided_network.py b/vesuvius/tests/models/build/test_guided_network.py new file mode 100644 index 000000000..8c1e447dd --- /dev/null +++ b/vesuvius/tests/models/build/test_guided_network.py @@ -0,0 +1,960 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +import torch.nn.functional as F + +from vesuvius.models.build.build_network_from_config import NetworkFromConfig +from vesuvius.models.build.guidance import TokenBook3D +from vesuvius.models.build.pretrained_backbones.dinovol_2_builder import build_dinovol_2_backbone +from vesuvius.models.build.pretrained_backbones.dinov2 import ( + PixelShuffleConvHeadBigDinov2Decoder, + PixelShuffleConvDinov2Decoder, + build_dinov2_backbone, + build_dinov2_decoder, +) +from vesuvius.models.build.simple_conv_blocks import ConvDropoutNormReLU +from vesuvius.models.build.transformers.patch_encode_decode import PixelShuffle3D +from vesuvius.models.utils import InitWeights_He + + +def _tiny_dinovol_model_config() -> dict: + return { + "model_type": "v2", + "input_channels": 1, + "global_crops_size": [16, 16, 16], + "local_crops_size": [16, 16, 16], + "patch_size": [8, 8, 8], + "embed_dim": 48, + "depth": 2, + "num_heads": 4, + "num_reg_tokens": 2, + "mlp_ratio": 2.0, + "drop_path_rate": 0.0, + "qkv_fused": True, + } + + +def _write_local_guide_checkpoint(path: Path) -> None: + backbone = build_dinovol_2_backbone(_tiny_dinovol_model_config()) + teacher_state = {f"backbone.{key}": value.cpu() for key, value in backbone.state_dict().items()} + checkpoint = { + "config": { + "model": _tiny_dinovol_model_config(), + 
"dataset": { + "global_crop_size": [16, 16, 16], + "local_crop_size": [16, 16, 16], + }, + }, + "teacher": teacher_state, + } + torch.save(checkpoint, path) + + +def _make_mgr( + checkpoint_path: Path, + *, + basic_encoder_block: str, + guide_tokenbook_tokens: int | None = None, + guide_compile_policy: str | None = None, + guide_fusion_stage: str | None = None, + guide_tokenbook_prototype_weighting: str | None = None, + guide_feature_gate_alpha: float | None = None, + guide_skip_concat_projector_channels: list[int] | None = None, + target_out_channels: int = 1, +) -> SimpleNamespace: + guided_config = {} + if guide_tokenbook_tokens is not None: + guided_config["guide_tokenbook_tokens"] = int(guide_tokenbook_tokens) + if guide_compile_policy is not None: + guided_config["guide_compile_policy"] = str(guide_compile_policy) + if guide_fusion_stage is not None: + guided_config["guide_fusion_stage"] = str(guide_fusion_stage) + if guide_tokenbook_prototype_weighting is not None: + guided_config["guide_tokenbook_prototype_weighting"] = str(guide_tokenbook_prototype_weighting) + if guide_feature_gate_alpha is not None: + guided_config["guide_feature_gate_alpha"] = float(guide_feature_gate_alpha) + if guide_skip_concat_projector_channels is not None: + guided_config["guide_skip_concat_projector_channels"] = [int(v) for v in guide_skip_concat_projector_channels] + return SimpleNamespace( + targets={"ink": {"out_channels": int(target_out_channels), "activation": "none"}}, + train_patch_size=(16, 16, 16), + train_batch_size=2, + in_channels=1, + autoconfigure=False, + model_name="guided_test", + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + model_config={ + "features_per_stage": [8, 16, 32], + "n_stages": 3, + "n_blocks_per_stage": [1, 1, 1], + "n_conv_per_stage_decoder": [1, 1], + "kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], + "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], + "basic_encoder_block": basic_encoder_block, + "basic_decoder_block": "ConvBlock", + 
"bottleneck_block": "BasicBlockD", + "separate_decoders": True, + "guide_backbone": str(checkpoint_path), + "guide_freeze": True, + "guide_tokenbook_sample_rate": 1.0, + "input_shape": [16, 16, 16], + **guided_config, + }, + ) + + +def _make_pretrained_backbone_mgr( + checkpoint_path: Path, + *, + freeze_encoder: bool, + decoder_type: str = "primus_patch_decode", +) -> SimpleNamespace: + return SimpleNamespace( + targets={"surface": {"out_channels": 2, "activation": "none"}}, + train_patch_size=(16, 16, 16), + train_batch_size=2, + in_channels=1, + autoconfigure=False, + model_name="pretrained_backbone_test", + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + model_config={ + "pretrained_backbone": str(checkpoint_path), + "pretrained_decoder_type": str(decoder_type), + "freeze_encoder": bool(freeze_encoder), + "input_shape": [16, 16, 16], + }, + ) + + +def test_tokenbook3d_returns_unit_interval_mask(): + module = TokenBook3D(n_tokens=8, embed_dim=16) + x = torch.randn(2, 16, 2, 2, 2) + + guide = module(x) + + assert guide.shape == (2, 1, 2, 2, 2) + assert float(guide.min().detach()) >= 0.0 + assert float(guide.max().detach()) <= 1.0 + + +def test_tokenbook3d_token_mlp_returns_unit_interval_mask(): + module = TokenBook3D( + n_tokens=8, + embed_dim=16, + prototype_weighting="token_mlp", + weight_mlp_hidden=12, + ) + x = torch.randn(2, 16, 2, 2, 2) + + guide = module(x) + + assert guide.shape == (2, 1, 2, 2, 2) + assert module.prototype_weight_mlp is not None + assert float(guide.min().detach()) >= 0.0 + assert float(guide.max().detach()) <= 1.0 + + +def test_pixelshuffle3d_upsamples_tuple_factor_shape(): + module = PixelShuffle3D((2, 2, 2)) + x = torch.randn(2, 32, 3, 4, 5) + + y = module(x) + + assert y.shape == (2, 4, 6, 8, 10) + + +@pytest.mark.parametrize("basic_encoder_block", ["ConvBlock", "BasicBlockD"]) +def test_guided_network_forward_shapes_and_aux_outputs(tmp_path: Path, basic_encoder_block: str): + checkpoint_path = tmp_path / 
"guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr(checkpoint_path, basic_encoder_block=basic_encoder_block) + model = NetworkFromConfig(mgr) + x = torch.randn(2, 1, 16, 16, 16) + + outputs = model(x) + outputs_with_aux, aux = model(x, return_aux=True) + + assert set(outputs.keys()) == {"ink"} + assert outputs["ink"].shape == (2, 1, 16, 16, 16) + assert set(outputs_with_aux.keys()) == {"ink"} + assert aux["guide_mask"].shape == (2, 1, 2, 2, 2) + + +@pytest.mark.parametrize("basic_encoder_block", ["ConvBlock", "BasicBlockD"]) +def test_feature_encoder_guidance_returns_encoder_stage_aux_outputs(tmp_path: Path, basic_encoder_block: str): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block=basic_encoder_block, + guide_fusion_stage="feature_encoder", + ) + model = NetworkFromConfig(mgr) + + outputs, aux = model(torch.randn(2, 1, 16, 16, 16), return_aux=True) + + assert outputs["ink"].shape == (2, 1, 16, 16, 16) + assert list(aux.keys()) == ["enc_0", "enc_1", "enc_2"] + assert aux["enc_0"].shape == (2, 1, 16, 16, 16) + assert aux["enc_1"].shape == (2, 1, 8, 8, 8) + assert aux["enc_2"].shape == (2, 1, 4, 4, 4) + assert model.final_config["guide_fusion_stage"] == "feature_encoder" + assert model.final_config["guide_stage_keys"] == ["enc_0", "enc_1", "enc_2"] + assert model.final_config["guide_feature_gate_alpha"] == pytest.approx(1.0) + assert len(model.guide_stage_tokenbooks) == 3 + + +@pytest.mark.parametrize("basic_encoder_block", ["ConvBlock", "BasicBlockD"]) +def test_feature_skip_concat_forward_shapes_and_empty_aux_outputs(tmp_path: Path, basic_encoder_block: str): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block=basic_encoder_block, + guide_fusion_stage="feature_skip_concat", + ) + model = NetworkFromConfig(mgr) + + 
outputs, aux = model(torch.randn(2, 1, 16, 16, 16), return_aux=True) + + assert outputs["ink"].shape == (2, 1, 16, 16, 16) + assert aux == {} + assert list(model.guide_skip_projectors.keys()) == ["enc_0", "enc_1", "enc_2"] + assert all(isinstance(projector, ConvDropoutNormReLU) for projector in model.guide_skip_projectors.values()) + assert all(projector.conv.bias is None for projector in model.guide_skip_projectors.values()) + assert all(isinstance(projector.norm, model.norm_op) for projector in model.guide_skip_projectors.values()) + assert model.final_config["guide_fusion_stage"] == "feature_skip_concat" + assert model.final_config["guide_feature_gate_alpha"] is None + assert model.final_config["guide_skip_concat_projector_channels"] == [8, 16, 32] + assert model.final_config["guide_tokenbook_tokens"] is None + assert model.final_config["guide_tokenbook_prototype_weighting"] is None + assert model.final_config["guide_tokenbook_weight_mlp_hidden"] is None + + +def test_direct_segmentation_returns_two_channel_logits_and_low_res_guide_mask(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="direct_segmentation", + guide_tokenbook_prototype_weighting="token_mlp", + target_out_channels=2, + ) + model = NetworkFromConfig(mgr) + + outputs, aux = model(torch.randn(2, 1, 16, 16, 16), return_aux=True) + probabilities = torch.softmax(outputs["ink"], dim=1)[:, 1:2] + expected = F.interpolate( + aux["guide_mask"], + size=probabilities.shape[2:], + mode="trilinear", + align_corners=False, + ) + + assert model.shared_encoder is None + assert model.final_config["guide_fusion_stage"] == "direct_segmentation" + assert model.final_config["guide_direct_output_mode"] == "two_channel_logits" + assert outputs["ink"].shape == (2, 2, 16, 16, 16) + assert aux["guide_mask"].shape == (2, 1, 2, 2, 2) + assert 
torch.allclose(probabilities, expected, atol=1e-5, rtol=1e-5) + + +def test_pretrained_backbone_pixelshuffle_conv_decoder_returns_input_resolution(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_pretrained_backbone_mgr( + checkpoint_path, + freeze_encoder=True, + decoder_type="pixelshuffle_conv", + ) + model = NetworkFromConfig(mgr) + + outputs = model(torch.randn(2, 1, 16, 16, 16)) + + assert outputs["surface"].shape == (2, 2, 16, 16, 16) + assert model.final_config["pretrained_decoder_type"] == "pixelshuffle_conv" + + +def test_pretrained_backbone_pixelshuffle_convhead_big_returns_input_resolution(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_pretrained_backbone_mgr( + checkpoint_path, + freeze_encoder=True, + decoder_type="pixelshuffle_convhead_big", + ) + model = NetworkFromConfig(mgr) + + outputs = model(torch.randn(2, 1, 16, 16, 16)) + + assert outputs["surface"].shape == (2, 2, 16, 16, 16) + assert model.final_config["pretrained_decoder_type"] == "pixelshuffle_convhead_big" + + +def test_encoder_skip_only_feature_gating_uses_residual_alpha_formula(): + class _TestStage(torch.nn.Module): + def __init__(self, value: float): + super().__init__() + self.value = value + self.last_input = None + + def forward(self, x): + self.last_input = x.detach().clone() + return torch.full_like(x, self.value) + + from vesuvius.models.build.encoder import Encoder + + test_encoder = Encoder.__new__(Encoder) + torch.nn.Module.__init__(test_encoder) + stage0 = _TestStage(2.0) + stage1 = _TestStage(3.0) + test_encoder.stem = None + test_encoder.stages = torch.nn.Sequential(stage0, stage1) + test_encoder.output_channels = [1, 1] + test_encoder.strides = [[1, 1, 1], [1, 1, 1]] + test_encoder.return_skips = True + test_encoder.conv_op = torch.nn.Conv3d + test_encoder.norm_op = None + test_encoder.norm_op_kwargs = {} + 
test_encoder.nonlin = None + test_encoder.nonlin_kwargs = {} + test_encoder.dropout_op = None + test_encoder.dropout_op_kwargs = {} + test_encoder.conv_bias = True + test_encoder.kernel_sizes = [[3, 3, 3], [3, 3, 3]] + + x = torch.ones((1, 1, 2, 2, 2), dtype=torch.float32) + feature_gates = { + "enc_0": torch.full((1, 1, 2, 2, 2), 0.5, dtype=torch.float32), + "enc_1": torch.full((1, 1, 2, 2, 2), 0.25, dtype=torch.float32), + } + + skips = test_encoder(x, feature_gates=feature_gates, feature_gate_alpha=0.9) + + expected_stage0_skip = 2.0 * (0.1 + 0.9 * 0.5) + expected_stage1_skip = 3.0 * (0.1 + 0.9 * 0.25) + assert torch.allclose(skips[0], torch.full_like(skips[0], expected_stage0_skip)) + assert torch.allclose(skips[1], torch.full_like(skips[1], expected_stage1_skip)) + assert torch.allclose(stage1.last_input, torch.full_like(stage1.last_input, 2.0)) + + +def test_feature_skip_concat_decoder_contract_uses_augmented_skip_channels(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_skip_concat", + ) + model = NetworkFromConfig(mgr) + decoder = model.task_decoders["ink"] + + assert decoder.skip_channels == [16, 32, 64] + assert decoder.bottleneck_input_channels == 64 + assert decoder.decoder_stage_input_channels == [16 + 32, 8 + 16] + assert decoder.decoder_stage_output_channels == [16, 8] + assert model.guide_stage_keys == ["enc_0", "enc_1", "enc_2"] + assert all( + projector.output_channels == native_width + for projector, native_width in zip(model.guide_skip_projectors.values(), [8, 16, 32]) + ) + + +def test_feature_skip_concat_decoder_contract_uses_configured_projector_widths(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + 
guide_fusion_stage="feature_skip_concat", + guide_skip_concat_projector_channels=[2, 4, 8], + ) + model = NetworkFromConfig(mgr) + decoder = model.task_decoders["ink"] + + assert decoder.skip_channels == [10, 20, 40] + assert decoder.bottleneck_input_channels == 40 + assert decoder.decoder_stage_input_channels == [16 + 20, 8 + 10] + assert decoder.decoder_stage_output_channels == [16, 8] + assert model.final_config["guide_skip_concat_projector_channels"] == [2, 4, 8] + assert [projector.output_channels for projector in model.guide_skip_projectors.values()] == [2, 4, 8] + assert all(projector.conv.bias is None for projector in model.guide_skip_projectors.values()) + assert all(isinstance(projector.norm, model.norm_op) for projector in model.guide_skip_projectors.values()) + + +def test_feature_skip_concat_projects_low_res_guide_features_before_upsampling(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_skip_concat", + ) + model = NetworkFromConfig(mgr) + + class _RecordingProjector(torch.nn.Module): + def __init__(self, out_channels: int): + super().__init__() + self.out_channels = out_channels + self.last_input_shape = None + + def forward(self, x): + self.last_input_shape = tuple(x.shape) + return torch.ones( + x.shape[0], + self.out_channels, + x.shape[2], + x.shape[3], + x.shape[4], + device=x.device, + dtype=x.dtype, + ) + + model.guide_skip_projectors = torch.nn.ModuleDict( + { + "enc_0": _RecordingProjector(8), + "enc_1": _RecordingProjector(16), + "enc_2": _RecordingProjector(32), + } + ) + low_res_guide = torch.randn(2, 48, 2, 2, 2) + model._compute_guide_features = lambda x: low_res_guide + + encoder_features = [ + torch.randn(2, 8, 16, 16, 16), + torch.randn(2, 16, 8, 8, 8), + torch.randn(2, 32, 4, 4, 4), + ] + + augmented_features, aux = model._build_skip_concat_features( + 
torch.randn(2, 1, 16, 16, 16), + encoder_features, + ) + + assert aux == {} + assert [feature.shape for feature in augmented_features] == [ + (2, 16, 16, 16, 16), + (2, 32, 8, 8, 8), + (2, 64, 4, 4, 4), + ] + for projector in model.guide_skip_projectors.values(): + assert projector.last_input_shape == (2, 48, 2, 2, 2) + + +def test_guided_network_backprop_updates_tokenbook_but_not_frozen_guide_backbone(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr(checkpoint_path, basic_encoder_block="ConvBlock") + model = NetworkFromConfig(mgr) + model.train() + + x = torch.randn(2, 1, 16, 16, 16) + target = torch.randn(2, 1, 16, 16, 16) + outputs, aux = model(x, return_aux=True) + loss = torch.nn.functional.mse_loss(outputs["ink"], target) + aux["guide_mask"].mean() + loss.backward() + + assert any(parameter.grad is not None for parameter in model.guide_tokenbook.parameters()) + assert all(parameter.grad is None for parameter in model.guide_backbone.parameters()) + + +def test_direct_segmentation_backprop_updates_tokenbook_but_not_frozen_guide_backbone(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="direct_segmentation", + target_out_channels=2, + ) + model = NetworkFromConfig(mgr) + model.train() + + x = torch.randn(2, 1, 16, 16, 16) + target = torch.randint(0, 2, (2, 16, 16, 16)) + outputs, aux = model(x, return_aux=True) + loss = torch.nn.functional.cross_entropy(outputs["ink"], target) + aux["guide_mask"].mean() + loss.backward() + + assert any(parameter.grad is not None for parameter in model.guide_tokenbook.parameters()) + assert all(parameter.grad is None for parameter in model.guide_backbone.parameters()) + + +def test_init_weights_he_preserves_frozen_guide_backbone_weights(tmp_path: Path): + checkpoint_path = tmp_path / 
"guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="direct_segmentation", + target_out_channels=2, + ) + model = NetworkFromConfig(mgr) + before = { + key: value.detach().clone() + for key, value in model.guide_backbone.state_dict().items() + } + + model.apply(InitWeights_He(neg_slope=0.2)) + + after = model.guide_backbone.state_dict() + assert all(torch.equal(before[key], after[key]) for key in before) + + +def test_feature_skip_concat_requires_projector_channels_to_match_num_stages(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_skip_concat", + guide_skip_concat_projector_channels=[8, 16], + ) + + with pytest.raises(ValueError, match="guide_skip_concat_projector_channels must have as many entries"): + NetworkFromConfig(mgr) + + +def test_direct_segmentation_requires_single_binary_target(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = SimpleNamespace( + targets={ + "ink": {"out_channels": 2, "activation": "none"}, + "surface": {"out_channels": 2, "activation": "none"}, + }, + train_patch_size=(16, 16, 16), + train_batch_size=2, + in_channels=1, + autoconfigure=False, + model_name="guided_test", + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + model_config={ + "guide_backbone": str(checkpoint_path), + "guide_freeze": True, + "guide_fusion_stage": "direct_segmentation", + "input_shape": [16, 16, 16], + }, + ) + + with pytest.raises(ValueError, match="exactly one non-auxiliary target"): + NetworkFromConfig(mgr) + + +def test_pretrained_backbone_freeze_encoder_disables_grads_and_preserves_weights_under_init(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + 
_write_local_guide_checkpoint(checkpoint_path) + mgr = _make_pretrained_backbone_mgr(checkpoint_path, freeze_encoder=True) + model = NetworkFromConfig(mgr) + + assert model.final_config["freeze_encoder"] is True + assert all(not parameter.requires_grad for parameter in model.shared_encoder.parameters()) + + before = { + key: value.detach().clone() + for key, value in model.shared_encoder.state_dict().items() + } + model.apply(InitWeights_He(neg_slope=0.2)) + after = model.shared_encoder.state_dict() + + assert all(torch.equal(before[key], after[key]) for key in before) + model.train() + assert model.shared_encoder.training is False + + +def test_pretrained_backbone_without_freeze_encoder_still_reinitializes_conv_stem(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_pretrained_backbone_mgr(checkpoint_path, freeze_encoder=False) + model = NetworkFromConfig(mgr) + + before = { + key: value.detach().clone() + for key, value in model.shared_encoder.state_dict().items() + } + model.apply(InitWeights_He(neg_slope=0.2)) + after = model.shared_encoder.state_dict() + + changed = [key for key in before if not torch.equal(before[key], after[key])] + assert "backbone.down_projection.proj.weight" in changed + + +def test_build_dinov2_decoder_accepts_pixelshuffle_conv(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + encoder = build_dinov2_backbone(str(checkpoint_path), input_channels=1, input_shape=(16, 16, 16)) + decoder = build_dinov2_decoder("pixelshuffle_conv", encoder, num_classes=2) + + features = torch.randn(1, encoder.embed_dim, 2, 2, 2) + output = decoder(features) + + assert isinstance(decoder, PixelShuffleConvDinov2Decoder) + assert output.shape == (1, 2, 16, 16, 16) + assert all(isinstance(stage[3], torch.nn.GroupNorm) for stage in decoder.decode) + assert all(isinstance(stage[4], torch.nn.GELU) for stage in 
decoder.decode) + assert all(isinstance(stage[2], torch.nn.Conv3d) for stage in decoder.decode) + assert all(stage[2].kernel_size == (3, 3, 3) for stage in decoder.decode) + assert all(stage[2].padding == (1, 1, 1) for stage in decoder.decode) + assert all(stage[2].bias is None for stage in decoder.decode) + assert decoder.decode[-1][2].out_channels > 2 + assert decoder.final_refine[0].in_channels == decoder.decode[-1][2].out_channels + assert decoder.final_refine[0].bias is None + assert isinstance(decoder.final_refine[1], torch.nn.GroupNorm) + assert isinstance(decoder.final_refine[2], torch.nn.GELU) + assert len(decoder.final_refine) == 5 + assert decoder.final_refine[-1].kernel_size == (1, 1, 1) + assert decoder.final_refine[0].kernel_size == (3, 3, 3) + assert decoder.final_refine[3].kernel_size == (3, 3, 3) + + +def test_build_dinov2_decoder_accepts_pixelshuffle_convhead_big(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + encoder = build_dinov2_backbone(str(checkpoint_path), input_channels=1, input_shape=(16, 16, 16)) + decoder = build_dinov2_decoder("pixelshuffle_convhead_big", encoder, num_classes=2) + + features = torch.randn(1, encoder.embed_dim, 2, 2, 2) + output = decoder(features) + + assert isinstance(decoder, PixelShuffleConvHeadBigDinov2Decoder) + assert output.shape == (1, 2, 16, 16, 16) + assert not hasattr(decoder, "pre_refine") + assert all(isinstance(stage[3], torch.nn.GroupNorm) for stage in decoder.decode) + assert all(isinstance(stage[4], torch.nn.GELU) for stage in decoder.decode) + assert all(isinstance(stage[2], torch.nn.Conv3d) for stage in decoder.decode) + assert all(stage[2].kernel_size == (3, 3, 3) for stage in decoder.decode) + assert all(stage[2].padding == (1, 1, 1) for stage in decoder.decode) + assert all(stage[2].bias is None for stage in decoder.decode) + assert decoder.decode[-1][2].out_channels > 2 + assert decoder.final_refine[0].in_channels == 
decoder.decode[-1][2].out_channels + assert decoder.final_refine[0].bias is None + assert isinstance(decoder.final_refine[1], torch.nn.GroupNorm) + assert isinstance(decoder.final_refine[2], torch.nn.GELU) + assert len(decoder.final_refine) == 5 + assert decoder.final_refine[-1].kernel_size == (1, 1, 1) + assert decoder.final_refine[0].kernel_size == (7, 7, 7) + assert decoder.final_refine[0].padding == (3, 3, 3) + assert decoder.final_refine[3].kernel_size == (3, 3, 3) + assert decoder.final_refine[3].padding == (1, 1, 1) + + +def test_feature_encoder_guidance_backprop_updates_all_stage_tokenbooks(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_encoder", + ) + model = NetworkFromConfig(mgr) + model.train() + + x = torch.randn(2, 1, 16, 16, 16) + target = torch.randn(2, 1, 16, 16, 16) + outputs, aux = model(x, return_aux=True) + loss = torch.nn.functional.mse_loss(outputs["ink"], target) + loss = loss + sum(stage_aux.mean() for stage_aux in aux.values()) + loss.backward() + + assert set(aux.keys()) == {"enc_0", "enc_1", "enc_2"} + assert all( + any(parameter.grad is not None for parameter in tokenbook.parameters()) + for tokenbook in model.guide_stage_tokenbooks.values() + ) + assert all(parameter.grad is None for parameter in model.guide_backbone.parameters()) + + +def test_feature_encoder_token_mlp_backprop_updates_stage_weight_mlps(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_encoder", + guide_tokenbook_prototype_weighting="token_mlp", + ) + model = NetworkFromConfig(mgr) + model.train() + + outputs, aux = model(torch.randn(2, 1, 16, 16, 16), return_aux=True) + loss = outputs["ink"].square().mean() + sum(stage_aux.mean() 
for stage_aux in aux.values()) + loss.backward() + + assert model.final_config["guide_tokenbook_prototype_weighting"] == "token_mlp" + assert all( + tokenbook.prototype_weight_mlp is not None + for tokenbook in model.guide_stage_tokenbooks.values() + ) + assert all( + any(parameter.grad is not None for parameter in tokenbook.prototype_weight_mlp.parameters()) + for tokenbook in model.guide_stage_tokenbooks.values() + ) + + +def test_guided_network_supports_reduced_tokenbook_prototype_count(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_tokenbook_tokens=3, + ) + model = NetworkFromConfig(mgr) + outputs, aux = model(torch.randn(1, 1, 16, 16, 16), return_aux=True) + + assert model.guide_tokenbook.book.shape[0] == 3 + assert model.final_config["guide_tokenbook_tokens"] == 3 + assert outputs["ink"].shape == (1, 1, 16, 16, 16) + assert aux["guide_mask"].shape == (1, 1, 2, 2, 2) + + +def test_feature_encoder_guidance_supports_reduced_tokenbook_prototype_count(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_tokenbook_tokens=3, + guide_fusion_stage="feature_encoder", + ) + model = NetworkFromConfig(mgr) + outputs, aux = model(torch.randn(1, 1, 16, 16, 16), return_aux=True) + + assert model.final_config["guide_tokenbook_tokens"] == 3 + assert all(tokenbook.book.shape[0] == 3 for tokenbook in model.guide_stage_tokenbooks.values()) + assert outputs["ink"].shape == (1, 1, 16, 16, 16) + assert set(aux.keys()) == {"enc_0", "enc_1", "enc_2"} + + +@pytest.mark.parametrize("guide_compile_policy", ["off", "backbone_only", "tokenbook_only", "all_guidance"]) +def test_guided_network_records_guide_compile_policy(tmp_path: Path, guide_compile_policy: str): + checkpoint_path = tmp_path / 
"guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_compile_policy=guide_compile_policy, + ) + model = NetworkFromConfig(mgr) + + assert model.guide_compile_policy == guide_compile_policy + assert model.final_config["guide_compile_policy"] == guide_compile_policy + + +def test_guided_network_all_guidance_compile_policy_compiles_backbone_and_tokenbook(tmp_path: Path, monkeypatch): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_compile_policy="all_guidance", + ) + model = NetworkFromConfig(mgr) + compiled_targets: list[str] = [] + + def _record_compile(module): + compiled_targets.append(type(module).__name__) + return module + + monkeypatch.setattr(model, "_compile_module_in_place", _record_compile) + + compiled_modules = model._compile_guidance_submodules(device_type="cuda") + + assert compiled_modules == ["guide_backbone", "guide_tokenbook"] + assert compiled_targets == ["Dinov2Backbone", "TokenBook3D"] + + +def test_feature_encoder_all_guidance_compile_policy_compiles_backbone_and_stage_tokenbooks(tmp_path: Path, monkeypatch): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_compile_policy="all_guidance", + guide_fusion_stage="feature_encoder", + ) + model = NetworkFromConfig(mgr) + compiled_targets: list[str] = [] + + def _record_compile(module): + compiled_targets.append(type(module).__name__) + return module + + monkeypatch.setattr(model, "_compile_module_in_place", _record_compile) + + compiled_modules = model._compile_guidance_submodules(device_type="cuda") + + assert compiled_modules == [ + "guide_backbone", + "guide_tokenbook:enc_0", + "guide_tokenbook:enc_1", + "guide_tokenbook:enc_2", + ] + assert 
compiled_targets == ["Dinov2Backbone", "TokenBook3D", "TokenBook3D", "TokenBook3D"] + + +def test_feature_encoder_guidance_records_exact_fusion_stage(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_encoder", + ) + model = NetworkFromConfig(mgr) + + assert model.guide_fusion_stage == "feature_encoder" + assert model.final_config["guide_fusion_stage"] == "feature_encoder" + assert model.guide_feature_gate_alpha == pytest.approx(1.0) + + +def test_feature_skip_concat_records_exact_fusion_stage_and_no_alpha(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_skip_concat", + ) + model = NetworkFromConfig(mgr) + + assert model.guide_fusion_stage == "feature_skip_concat" + assert model.final_config["guide_fusion_stage"] == "feature_skip_concat" + assert model.final_config["guide_feature_gate_alpha"] is None + assert model.final_config["guide_tokenbook_tokens"] is None + assert model.final_config["guide_tokenbook_prototype_weighting"] is None + assert model.final_config["guide_tokenbook_weight_mlp_hidden"] is None + + +def test_feature_skip_concat_ignores_stale_tokenbook_and_alpha_keys(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_skip_concat", + guide_tokenbook_tokens=17, + guide_tokenbook_prototype_weighting="token_mlp", + guide_feature_gate_alpha=0.9, + ) + model = NetworkFromConfig(mgr) + + assert model.guide_fusion_stage == "feature_skip_concat" + assert model.final_config["guide_tokenbook_tokens"] is None + assert 
model.final_config["guide_tokenbook_prototype_weighting"] is None + assert model.final_config["guide_tokenbook_weight_mlp_hidden"] is None + assert model.final_config["guide_feature_gate_alpha"] is None + + +def test_feature_encoder_guidance_records_configured_gate_alpha(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_encoder", + guide_feature_gate_alpha=0.9, + ) + model = NetworkFromConfig(mgr) + + assert model.guide_feature_gate_alpha == pytest.approx(0.9) + assert model.final_config["guide_feature_gate_alpha"] == pytest.approx(0.9) + + +def test_feature_encoder_guidance_rejects_invalid_gate_alpha(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_encoder", + guide_feature_gate_alpha=1.5, + ) + + with pytest.raises(ValueError, match="guide_feature_gate_alpha"): + NetworkFromConfig(mgr) + + +def test_feature_encoder_guidance_records_token_mlp_weighting(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_fusion_stage="feature_encoder", + guide_tokenbook_prototype_weighting="token_mlp", + ) + model = NetworkFromConfig(mgr) + + assert model.guide_tokenbook_prototype_weighting == "token_mlp" + assert model.final_config["guide_tokenbook_prototype_weighting"] == "token_mlp" + + +def test_guided_network_rejects_invalid_guide_compile_policy(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_compile_policy="bad_policy", + ) + + with 
pytest.raises(ValueError, match="guide_compile_policy"): + NetworkFromConfig(mgr) + + +def test_feature_skip_concat_all_guidance_compile_policy_compiles_backbone_and_projectors(tmp_path: Path, monkeypatch): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_mgr( + checkpoint_path, + basic_encoder_block="ConvBlock", + guide_compile_policy="all_guidance", + guide_fusion_stage="feature_skip_concat", + ) + model = NetworkFromConfig(mgr) + compiled_targets: list[str] = [] + + def _record_compile(module): + compiled_targets.append(type(module).__name__) + return module + + monkeypatch.setattr(model, "_compile_module_in_place", _record_compile) + + compiled_modules = model._compile_guidance_submodules(device_type="cuda") + + assert compiled_modules == ["guide_backbone", "guide_projector:enc_0", "guide_projector:enc_1", "guide_projector:enc_2"] + assert compiled_targets == ["Dinov2Backbone", "ConvDropoutNormReLU", "ConvDropoutNormReLU", "ConvDropoutNormReLU"] + + +def test_guided_network_rejects_primus_architecture(tmp_path: Path): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + + mgr = SimpleNamespace( + targets={"ink": {"out_channels": 1, "activation": "none"}}, + train_patch_size=(16, 16, 16), + train_batch_size=1, + in_channels=1, + autoconfigure=False, + model_name="guided_test", + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + model_config={ + "architecture_type": "primus_s", + "input_shape": [16, 16, 16], + "patch_embed_size": [8, 8, 8], + "guide_backbone": str(checkpoint_path), + }, + ) + + with pytest.raises(ValueError, match="not supported with primus"): + NetworkFromConfig(mgr) diff --git a/vesuvius/tests/models/build/test_mednext_shapes.py b/vesuvius/tests/models/build/test_mednext_shapes.py new file mode 100644 index 000000000..be216b6c6 --- /dev/null +++ b/vesuvius/tests/models/build/test_mednext_shapes.py @@ -0,0 +1,157 @@ +from 
__future__ import annotations + +from types import SimpleNamespace + +import pytest +import torch + +from vesuvius.models.build.build_network_from_config import NetworkFromConfig +from vesuvius.models.build.mednext import MedNeXtV1, create_mednext_v1, get_mednext_v1_config + + +def _make_mgr( + patch_size=(32, 32, 32), + *, + targets=None, + separate_decoders=True, + enable_deep_supervision=False, + architecture_type="mednext_v1", + mednext_model_id="B", +): + return SimpleNamespace( + targets=targets or {"surface": {"out_channels": 2, "activation": "none"}}, + train_patch_size=patch_size, + train_batch_size=1, + in_channels=1, + autoconfigure=False, + model_name="mednext_test", + enable_deep_supervision=enable_deep_supervision, + spacing=(1.0, 1.0, 1.0), + model_config={ + "architecture_type": architecture_type, + "mednext_model_id": mednext_model_id, + "mednext_kernel_size": 3, + "separate_decoders": separate_decoders, + }, + ) + + +def test_create_mednext_v1_small_forward_shape_3d(): + model = create_mednext_v1(1, 3, "S", kernel_size=3, deep_supervision=False) + out = model(torch.randn(2, 1, 32, 32, 32)) + + assert isinstance(out, torch.Tensor) + assert out.shape == (2, 3, 32, 32, 32) + + +def test_mednext_v1_deep_supervision_returns_multiscale_outputs(): + model = create_mednext_v1(1, 2, "B", kernel_size=3, deep_supervision=True) + out = model(torch.randn(1, 1, 32, 32, 32)) + + assert isinstance(out, list) + assert len(out) == 5 + assert out[0].shape == (1, 2, 32, 32, 32) + assert out[1].shape == (1, 2, 16, 16, 16) + assert out[2].shape == (1, 2, 8, 8, 8) + assert out[3].shape == (1, 2, 4, 4, 4) + assert out[4].shape == (1, 2, 2, 2, 2) + + +def test_mednext_v1_outside_block_checkpointing_smoke_forward(): + cfg = get_mednext_v1_config("M") + model = MedNeXtV1( + in_channels=1, + n_channels=cfg["n_channels"], + n_classes=2, + exp_r=cfg["exp_r"], + kernel_size=3, + deep_supervision=False, + do_res=True, + do_res_up_down=True, + checkpoint_style="outside_block", + 
block_counts=cfg["block_counts"], +        norm_type="group", +        grn=False, +    ) +    out = model(torch.randn(1, 1, 32, 32, 32)) + +    assert out.shape == (1, 2, 32, 32, 32) + + +def test_network_from_config_mednext_v1_forward_shape(): +    model = NetworkFromConfig(_make_mgr()) +    out = model(torch.randn(2, 1, 32, 32, 32)) + +    assert set(out.keys()) == {"surface"} +    assert out["surface"].shape == (2, 2, 32, 32, 32) +    assert model.final_config["architecture_type"] == "mednext_v1" +    assert model.final_config["mednext_model_id"] == "B" + + +def test_network_from_config_mednext_v1_shared_decoder_multi_target(): +    mgr = _make_mgr( +        targets={ +            "surface": {"out_channels": 2, "activation": "none"}, +            "tissue": {"out_channels": 1, "activation": "none"}, +        }, +        separate_decoders=False, +    ) +    mgr.targets = { +        "surface": {"out_channels": 2, "activation": "none"}, +        "tissue": {"out_channels": 1, "activation": "none"}, +    } +    model = NetworkFromConfig(mgr) +    out = model(torch.randn(1, 1, 32, 32, 32)) + +    assert set(out.keys()) == {"surface", "tissue"} +    assert out["surface"].shape == (1, 2, 32, 32, 32) +    assert out["tissue"].shape == (1, 1, 32, 32, 32) + + +def test_network_from_config_mednext_v1_deep_supervision_returns_multiscale_logits(): +    model = NetworkFromConfig(_make_mgr(enable_deep_supervision=True)) +    out = model(torch.randn(1, 1, 32, 32, 32)) + +    assert isinstance(out["surface"], list) +    assert len(out["surface"]) == 5 +    assert out["surface"][0].shape == (1, 2, 32, 32, 32) + + +def test_network_from_config_mednext_requires_explicit_model_id(): +    mgr = _make_mgr(mednext_model_id="") +    del mgr.model_config["mednext_model_id"] + +    with pytest.raises(ValueError, match="mednext_model_id"): +        NetworkFromConfig(mgr) + + +def test_network_from_config_mednext_v2_uses_instance_norm_and_grn(): +    mgr = _make_mgr(architecture_type="mednext_v2", mednext_model_id="L") +    model = NetworkFromConfig(mgr) + +    assert model.final_config["architecture_type"] == "mednext_v2" +    assert
model.final_config["mednext_norm_type"] == "instance" + assert model.final_config["mednext_grn"] is True + first_block = model.shared_encoder.enc_block_0[0] + assert isinstance(first_block.norm, torch.nn.InstanceNorm3d) + assert hasattr(first_block, "grn_gamma") + assert hasattr(first_block, "grn_beta") + + +def test_network_from_config_mednext_v2_width_factor_doubles_base_channels(): + model_base = NetworkFromConfig(_make_mgr(architecture_type="mednext_v2", mednext_model_id="B")) + model_wide = NetworkFromConfig( + SimpleNamespace( + **{ + **_make_mgr(architecture_type="mednext_v2", mednext_model_id="B").__dict__, + "model_config": { + **_make_mgr(architecture_type="mednext_v2", mednext_model_id="B").model_config, + "mednext_width_factor": 2, + }, + } + ) + ) + + assert model_base.final_config["mednext_width_factor"] == 1 + assert model_wide.final_config["mednext_width_factor"] == 2 + assert model_wide.final_config["mednext_base_channels"] == 2 * model_base.final_config["mednext_base_channels"] diff --git a/vesuvius/tests/models/build/test_primus_shapes.py b/vesuvius/tests/models/build/test_primus_shapes.py index 8711f9682..efeaa8aa7 100644 --- a/vesuvius/tests/models/build/test_primus_shapes.py +++ b/vesuvius/tests/models/build/test_primus_shapes.py @@ -88,7 +88,10 @@ def test_primus_single_train_step_backward(): loss.backward() assert model.shared_encoder.patch_embed.stem.blocks[0].conv1.conv.weight.grad is not None - assert model.task_heads["ink"].weight.grad is not None + if "ink" in model.task_heads: + assert model.task_heads["ink"].weight.grad is not None + else: + assert model.task_decoders["ink"].patch_decoder.decode[-1].weight.grad is not None optim.step() diff --git a/vesuvius/tests/models/configuration/test_config_manager.py b/vesuvius/tests/models/configuration/test_config_manager.py index 604a71550..ed5592062 100644 --- a/vesuvius/tests/models/configuration/test_config_manager.py +++ b/vesuvius/tests/models/configuration/test_config_manager.py @@ 
-102,3 +102,109 @@ def test_multiple_reserved_names(self): # Should raise for the first reserved name encountered with pytest.raises(ValueError, match="is reserved"): mgr.validate_target_names(names_with_reserved) + + def test_compile_policy_and_startup_timing_load_from_config(self): + config = { + "tr_config": { + "compile_policy": "module", + "startup_timing": True, + "ddp_find_unused_parameters": False, + "ddp_static_graph": True, + "ddp_gradient_as_bucket_view": True, + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + temp_path = f.name + + try: + mgr = ConfigManager(verbose=False) + mgr.load_config(temp_path) + assert mgr.compile_policy == "module" + assert mgr.startup_timing is True + assert mgr.ddp_find_unused_parameters is False + assert mgr.ddp_static_graph is True + assert mgr.ddp_gradient_as_bucket_view is True + finally: + Path(temp_path).unlink() + + def test_invalid_compile_policy_raises_value_error(self): + config = { + "tr_config": { + "compile_policy": "not_a_policy", + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + temp_path = f.name + + try: + mgr = ConfigManager(verbose=False) + with pytest.raises(ValueError, match="compile_policy"): + mgr.load_config(temp_path) + finally: + Path(temp_path).unlink() + + def test_invalid_ddp_find_unused_parameters_raises_value_error(self): + config = { + "tr_config": { + "ddp_find_unused_parameters": "sometimes", + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + temp_path = f.name + + try: + mgr = ConfigManager(verbose=False) + with pytest.raises(ValueError, match="ddp_find_unused_parameters"): + mgr.load_config(temp_path) + finally: + Path(temp_path).unlink() + + def test_invalid_ddp_static_graph_raises_value_error(self): + config = { + "tr_config": { + "ddp_static_graph": "sometimes", + } + } + + with 
tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + temp_path = f.name + + try: + mgr = ConfigManager(verbose=False) + with pytest.raises(ValueError, match="ddp_static_graph"): + mgr.load_config(temp_path) + finally: + Path(temp_path).unlink() + + def test_explicit_single_label_volume_does_not_force_allow_unlabeled(self): + config = { + "dataset_config": { + "targets": { + "surface": {"out_channels": 2} + }, + "volumes": [ + { + "image": "/tmp/image.zarr", + "label": "/tmp/label_surface.zarr", + } + ], + } + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + yaml.dump(config, f) + temp_path = f.name + + try: + mgr = ConfigManager(verbose=False) + mgr.load_config(temp_path) + assert mgr.allow_unlabeled_data is False + finally: + Path(temp_path).unlink() diff --git a/vesuvius/tests/models/configuration/test_ps256_config_compat.py b/vesuvius/tests/models/configuration/test_ps256_config_compat.py new file mode 100644 index 000000000..cb68865f6 --- /dev/null +++ b/vesuvius/tests/models/configuration/test_ps256_config_compat.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +from pathlib import Path + +from vesuvius.models.build.build_network_from_config import NetworkFromConfig +from vesuvius.models.configuration.config_manager import ConfigManager + + +def _load_mgr(rel_path: str) -> ConfigManager: + cfg = Path(rel_path) + mgr = ConfigManager(verbose=False) + mgr.load_config(str(cfg)) + mgr.spacing = (1.0, 1.0, 1.0) + mgr.autoconfigure = False + mgr.train_batch_size = 1 + return mgr + + +def test_ps256_medial_config_loads_and_builds_without_guide_path(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is False + assert model.final_config["guide_backbone"] is None + + +def 
test_ps256_dicece_config_loads_and_builds_without_guide_path(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_dicece.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is False + assert model.final_config["guide_backbone"] is None + + +def test_ps256_guided_medial_config_loads_and_builds_with_guide_path(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_backbone"] is not None + assert model.final_config["guide_tokenbook_tokens"] == 256 + + +def test_ps256_guided_dicece_config_loads_and_builds_with_guide_path(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_dicece.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_backbone"] is not None + assert model.final_config["guide_tokenbook_tokens"] == 256 + + +def test_ps128_guided_feature_encoder_config_loads_and_builds_with_encoder_guidance(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps128_guided_feature_encoder_ink.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["ink"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "feature_encoder" + assert model.final_config["guide_stage_keys"] == ["enc_0", "enc_1", "enc_2", "enc_3", "enc_4"] + assert model.final_config["guide_feature_gate_alpha"] == 0.9 + assert model.final_config["guide_tokenbook_prototype_weighting"] == "token_mlp" + + +def 
test_ps256_guided_feature_encoder_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "feature_encoder" + assert model.final_config["guide_tokenbook_tokens"] == 256 + assert model.final_config["guide_feature_gate_alpha"] == 0.9 + assert model.final_config["guide_tokenbook_prototype_weighting"] == "token_mlp" + + +def test_ps256_guided_feature_encoder_dicece_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_feature_encoder_dicece.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "feature_encoder" + assert model.final_config["guide_tokenbook_tokens"] == 256 + assert model.final_config["guide_feature_gate_alpha"] == 0.9 + assert model.final_config["guide_tokenbook_prototype_weighting"] == "token_mlp" + + +def test_ps128_guided_feature_skip_concat_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps128_guided_feature_skip_concat_ink.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["ink"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "feature_skip_concat" + assert model.final_config["guide_stage_keys"] == ["enc_0", "enc_1", "enc_2", "enc_3", "enc_4"] + assert model.final_config["guide_skip_concat_projector_channels"] == [8, 16, 32, 64, 80] + assert model.final_config["guide_feature_gate_alpha"] is None + assert model.final_config["guide_tokenbook_tokens"] is None + 
assert model.final_config["guide_tokenbook_prototype_weighting"] is None + + +def test_ps256_guided_feature_skip_concat_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "feature_skip_concat" + assert model.final_config["guide_skip_concat_projector_channels"] == [8, 16, 32, 64, 80, 80, 80] + assert model.final_config["guide_feature_gate_alpha"] is None + assert model.final_config["guide_tokenbook_tokens"] is None + assert model.final_config["guide_tokenbook_prototype_weighting"] is None + + +def test_ps256_guided_feature_skip_concat_dicece_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_feature_skip_concat_dicece.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "feature_skip_concat" + assert model.final_config["guide_skip_concat_projector_channels"] == [8, 16, 32, 64, 80, 80, 80] + assert model.final_config["guide_feature_gate_alpha"] is None + assert model.final_config["guide_tokenbook_tokens"] is None + assert model.final_config["guide_tokenbook_prototype_weighting"] is None + + +def test_ps128_guided_direct_segmentation_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps128_guided_direct_segmentation_ink.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["ink"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "direct_segmentation" + assert 
model.final_config["guide_direct_output_mode"] == "two_channel_logits" + + +def test_ps256_guided_direct_segmentation_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "direct_segmentation" + assert model.final_config["guide_direct_output_mode"] == "two_channel_logits" + + +def test_ps256_guided_direct_segmentation_dicece_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_guided_direct_segmentation_dicece.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.guide_enabled is True + assert model.final_config["guide_fusion_stage"] == "direct_segmentation" + assert model.final_config["guide_direct_output_mode"] == "two_channel_logits" + + +def test_ps128_pretrained_dino_pixelshuffle_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps128_pretrained_dino_pixelshuffle_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["surface"] + assert model.final_config["pretrained_backbone"] is not None + assert model.final_config["pretrained_decoder_type"] == "pixelshuffle_conv" + assert model.final_config["freeze_encoder"] is True + + +def test_ps256_pretrained_dino_pixelshuffle_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_pretrained_dino_pixelshuffle_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert 
model.final_config["pretrained_backbone"] is not None + assert model.final_config["pretrained_decoder_type"] == "pixelshuffle_conv" + assert model.final_config["freeze_encoder"] is True + + +def test_ps128_mednext_v1_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps128_mednext_v1_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["surface"] + assert model.final_config["architecture_type"] == "mednext_v1" + assert model.final_config["mednext_model_id"] == "B" + + +def test_ps256_mednext_v1_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps256_mednext_v1_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (256, 256, 256) + assert list(mgr.targets.keys()) == ["surface"] + assert model.final_config["architecture_type"] == "mednext_v1" + assert model.final_config["mednext_model_id"] == "B" + + +def test_ps128_mednext_v2_l_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["surface"] + assert model.final_config["architecture_type"] == "mednext_v2" + assert model.final_config["mednext_model_id"] == "L" + assert model.final_config["mednext_grn"] is True + + +def test_ps192_mednext_v2_l_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps192_mednext_v2_l_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (192, 192, 192) + assert list(mgr.targets.keys()) == ["surface"] + assert model.final_config["architecture_type"] == "mednext_v2" + assert model.final_config["mednext_model_id"] == "L" + + +def test_ps128_mednext_v2_b_medial_config_loads_and_builds(): + mgr = 
_load_mgr("src/vesuvius/models/configuration/single_task/ps128_mednext_v2_b_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["surface"] + assert model.final_config["architecture_type"] == "mednext_v2" + assert model.final_config["mednext_model_id"] == "B" + + +def test_ps128_mednext_v2_l_width2_medial_config_loads_and_builds(): + mgr = _load_mgr("src/vesuvius/models/configuration/single_task/ps128_mednext_v2_l_width2_medial.yaml") + model = NetworkFromConfig(mgr) + + assert mgr.train_patch_size == (128, 128, 128) + assert list(mgr.targets.keys()) == ["surface"] + assert model.final_config["architecture_type"] == "mednext_v2" + assert model.final_config["mednext_model_id"] == "L" + assert model.final_config["mednext_width_factor"] == 2 diff --git a/vesuvius/tests/models/training/test_base_trainer.py b/vesuvius/tests/models/training/test_base_trainer.py index 8a1fd3d4c..5e09c1498 100644 --- a/vesuvius/tests/models/training/test_base_trainer.py +++ b/vesuvius/tests/models/training/test_base_trainer.py @@ -1,6 +1,7 @@ import torch from types import SimpleNamespace +from vesuvius.models.training.optimizers import create_optimizer from vesuvius.models.training.train import BaseTrainer @@ -49,3 +50,19 @@ def test_base_trainer_should_include_target_in_loss(): assert trainer._should_include_target_in_loss("ink") is True assert trainer._should_include_target_in_loss("ink_skel") is False assert trainer._should_include_target_in_loss("is_unlabeled") is False + + +def test_create_optimizer_adamw_honors_eps_override(): + model = torch.nn.Linear(4, 2) + optimizer = create_optimizer( + { + "name": "adamw", + "learning_rate": 1e-3, + "weight_decay": 3e-5, + "eps": 1e-4, + }, + model, + ) + + assert isinstance(optimizer, torch.optim.AdamW) + assert optimizer.defaults["eps"] == 1e-4 diff --git a/vesuvius/tests/models/training/test_guided_trainer.py 
b/vesuvius/tests/models/training/test_guided_trainer.py new file mode 100644 index 000000000..03b6572a7 --- /dev/null +++ b/vesuvius/tests/models/training/test_guided_trainer.py @@ -0,0 +1,1294 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +import zarr +from torch.utils.data import DataLoader + +from vesuvius.models.build.pretrained_backbones.dinovol_2_builder import build_dinovol_2_backbone +from vesuvius.models.run.inference import Inferer +from vesuvius.models.training.train import BaseTrainer +from vesuvius.utils.plotting import save_debug + + +def _tiny_dinovol_model_config() -> dict: + return { + "model_type": "v2", + "input_channels": 1, + "global_crops_size": [16, 16, 16], + "local_crops_size": [16, 16, 16], + "patch_size": [8, 8, 8], + "embed_dim": 48, + "depth": 2, + "num_heads": 4, + "num_reg_tokens": 2, + "mlp_ratio": 2.0, + "drop_path_rate": 0.0, + "qkv_fused": True, + } + + +def _write_local_guide_checkpoint(path: Path) -> None: + backbone = build_dinovol_2_backbone(_tiny_dinovol_model_config()) + teacher_state = {f"backbone.{key}": value.cpu() for key, value in backbone.state_dict().items()} + checkpoint = { + "config": { + "model": _tiny_dinovol_model_config(), + "dataset": { + "global_crop_size": [16, 16, 16], + "local_crop_size": [16, 16, 16], + }, + }, + "teacher": teacher_state, + } + torch.save(checkpoint, path) + + +def _make_synthetic_dataset(root: Path) -> Path: + data_root = root / "data" + image_root = data_root / "images" / "volume1.zarr" + label_root_ink = data_root / "labels" / "volume1_ink.zarr" + label_root_surface = data_root / "labels" / "volume1_surface.zarr" + + image_group = zarr.open_group(str(image_root), mode="w") + label_group_ink = zarr.open_group(str(label_root_ink), mode="w") + label_group_surface = zarr.open_group(str(label_root_surface), mode="w") + image_array = image_group.create_dataset("0", shape=(16, 16, 16), 
chunks=(16, 16, 16), dtype="float32") + label_array_ink = label_group_ink.create_dataset("0", shape=(16, 16, 16), chunks=(16, 16, 16), dtype="uint8") + label_array_surface = label_group_surface.create_dataset("0", shape=(16, 16, 16), chunks=(16, 16, 16), dtype="uint8") + + coords = np.linspace(-1.0, 1.0, 16, dtype=np.float32) + z = coords[:, None, None] + y = coords[None, :, None] + x = coords[None, None, :] + image = np.exp(-(x * x + y * y + z * z) * 4.0).astype(np.float32) + label = ((x * x + y * y + z * z) < 0.35).astype(np.uint8) + + image_array[:] = image + label_array_ink[:] = label + label_array_surface[:] = label + return data_root + + +def _make_mgr( + data_root: Path, + guide_checkpoint: Path, + *, + guide_fusion_stage: str | None = None, + guide_loss_weight: float = 0.5, + guide_tokenbook_prototype_weighting: str | None = None, + losses: list[dict] | None = None, +) -> SimpleNamespace: + guided_model_config = { + "features_per_stage": [8, 16, 32], + "n_stages": 3, + "n_blocks_per_stage": [1, 1, 1], + "n_conv_per_stage_decoder": [1, 1], + "kernel_sizes": [[3, 3, 3], [3, 3, 3], [3, 3, 3]], + "strides": [[1, 1, 1], [2, 2, 2], [2, 2, 2]], + "basic_encoder_block": "ConvBlock", + "basic_decoder_block": "ConvBlock", + "bottleneck_block": "BasicBlockD", + "separate_decoders": True, + "guide_backbone": str(guide_checkpoint), + "guide_freeze": True, + "guide_tokenbook_sample_rate": 1.0, + "input_shape": [16, 16, 16], + } + if guide_fusion_stage is not None: + guided_model_config["guide_fusion_stage"] = str(guide_fusion_stage) + if guide_tokenbook_prototype_weighting is not None: + guided_model_config["guide_tokenbook_prototype_weighting"] = str(guide_tokenbook_prototype_weighting) + + return SimpleNamespace( + gpu_ids=None, + use_ddp=False, + targets={ + "ink": { + "out_channels": 2, + "activation": "none", + "losses": losses if losses is not None else [{"name": "nnUNet_DC_and_CE_loss", "weight": 1.0}], + } + }, + tr_configs={}, + model_name="guided_smoke", + 
train_patch_size=(16, 16, 16), + train_batch_size=1, + in_channels=1, + autoconfigure=False, + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + data_path=data_root, + min_labeled_ratio=0.0, + min_bbox_percent=0.0, + skip_patch_validation=True, + allow_unlabeled_data=False, + normalization_scheme="zscore", + intensity_properties={}, + skip_intensity_sampling=True, + no_spatial_augmentation=True, + no_scaling_augmentation=True, + cache_valid_patches=False, + dataset_config={}, + ome_zarr_resolution=0, + valid_patch_find_resolution=0, + bg_sampling_enabled=False, + bg_to_fg_ratio=0.5, + unlabeled_foreground_enabled=False, + unlabeled_foreground_threshold=0.05, + unlabeled_foreground_bbox_threshold=0.15, + unlabeled_foreground_volumes=None, + profile_augmentations=False, + guide_loss_weight=guide_loss_weight, + guide_supervision_target="ink", + train_num_dataloader_workers=0, + model_config=guided_model_config, + ) + + +def _make_pretrained_backbone_mgr( + data_root: Path, + checkpoint_path: Path, + *, + pretrained_decoder_type: str = "pixelshuffle_conv", + losses: list[dict] | None = None, +) -> SimpleNamespace: + return SimpleNamespace( + gpu_ids=None, + use_ddp=False, + targets={ + "surface": { + "out_channels": 2, + "activation": "none", + "losses": losses if losses is not None else [{"name": "nnUNet_DC_and_CE_loss", "weight": 1.0}], + } + }, + tr_configs={}, + model_name="pretrained_backbone_smoke", + train_patch_size=(16, 16, 16), + train_batch_size=1, + in_channels=1, + autoconfigure=False, + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + data_path=data_root, + min_labeled_ratio=0.0, + min_bbox_percent=0.0, + skip_patch_validation=True, + allow_unlabeled_data=False, + normalization_scheme="zscore", + intensity_properties={}, + skip_intensity_sampling=True, + no_spatial_augmentation=True, + no_scaling_augmentation=True, + cache_valid_patches=False, + dataset_config={}, + ome_zarr_resolution=0, + valid_patch_find_resolution=0, + 
bg_sampling_enabled=False, + bg_to_fg_ratio=0.5, + unlabeled_foreground_enabled=False, + unlabeled_foreground_threshold=0.05, + unlabeled_foreground_bbox_threshold=0.15, + unlabeled_foreground_volumes=None, + profile_augmentations=False, + guide_loss_weight=0.0, + guide_supervision_target=None, + train_num_dataloader_workers=0, + model_config={ + "pretrained_backbone": str(checkpoint_path), + "pretrained_decoder_type": str(pretrained_decoder_type), + "freeze_encoder": True, + "input_shape": [16, 16, 16], + }, + ) + + +def test_guided_base_trainer_smoke_runs_two_training_steps(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr(data_root, guide_checkpoint) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + for _step in range(2): + optimizer.zero_grad(set_to_none=True) + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + assert torch.isfinite(loss) + assert "guide_mask" in task_losses + assert task_losses["guide_mask"] > 0.0 + loss.backward() + optimizer.step() + + +def test_guided_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr(data_root, guide_checkpoint) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + checkpoint_path = tmp_path / "guided_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, 
checkpoint_path) + + output_dir = tmp_path / "inference_out" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 16, 16, 16)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"ink"} + + +def test_feature_encoder_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="feature_encoder", + guide_loss_weight=0.0, + guide_tokenbook_prototype_weighting="token_mlp", + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + checkpoint_path = tmp_path / "feature_encoder_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_feature_encoder" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 16, 16, 16)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"ink"} + assert model_info["network"].final_config["guide_fusion_stage"] == "feature_encoder" + + +def test_feature_skip_concat_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = 
tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="feature_skip_concat", + guide_loss_weight=0.0, + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + checkpoint_path = tmp_path / "feature_skip_concat_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_feature_skip_concat" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 16, 16, 16)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"ink"} + assert model_info["network"].final_config["guide_fusion_stage"] == "feature_skip_concat" + + +def test_pretrained_backbone_pixelshuffle_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_pretrained_backbone_mgr( + data_root, + guide_checkpoint, + pretrained_decoder_type="pixelshuffle_conv", + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + checkpoint_path = tmp_path / "pretrained_pixelshuffle_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_pretrained_pixelshuffle" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + 
device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 16, 16, 16)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"surface"} + assert output["surface"].shape == (1, 2, 16, 16, 16) + assert model_info["network"].final_config["pretrained_decoder_type"] == "pixelshuffle_conv" + + +def test_pretrained_backbone_pixelshuffle_convhead_big_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_pretrained_backbone_mgr( + data_root, + guide_checkpoint, + pretrained_decoder_type="pixelshuffle_convhead_big", + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + checkpoint_path = tmp_path / "pretrained_pixelshuffle_convhead_big_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_pretrained_pixelshuffle_convhead_big" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 16, 16, 16)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"surface"} + assert output["surface"].shape == (1, 2, 16, 16, 16) + assert model_info["network"].final_config["pretrained_decoder_type"] == "pixelshuffle_convhead_big" + + +def test_pretrained_backbone_explicit_volumes_dataset_loads_from_volume_specs(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / 
"guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_pretrained_backbone_mgr( + data_root, + guide_checkpoint, + pretrained_decoder_type="pixelshuffle_convhead_big", + ) + + runtime_cfg_dir = tmp_path / "runtime_configs" + runtime_cfg_dir.mkdir() + mgr.data_path = runtime_cfg_dir + mgr.dataset_config = { + "volumes": [ + { + "image": str(data_root / "images" / "volume1.zarr"), + "label": str(data_root / "labels" / "volume1_surface.zarr"), + } + ], + "skip_patch_validation": True, + } + mgr.skip_patch_validation = True + + trainer = BaseTrainer(mgr=mgr, verbose=False) + dataset = trainer._configure_dataset(is_training=True) + + assert len(dataset._volumes) == 1 + assert dataset._volumes[0].image_path == data_root / "images" / "volume1.zarr" + assert dataset._volumes[0].label_paths["surface"] == data_root / "labels" / "volume1_surface.zarr" + assert dataset._volumes[0].has_labels is True + assert len(dataset) > 0 + + +def test_pretrained_backbone_explicit_volumes_duplicate_image_ids_match_cache_naming(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_pretrained_backbone_mgr( + data_root, + guide_checkpoint, + pretrained_decoder_type="pixelshuffle_convhead_big", + ) + + runtime_cfg_dir = tmp_path / "runtime_configs" + runtime_cfg_dir.mkdir() + image_path = data_root / "images" / "volume1.zarr" + label_path = data_root / "labels" / "volume1_surface.zarr" + mgr.data_path = runtime_cfg_dir + mgr.dataset_config = { + "volumes": [ + {"image": str(image_path), "label": str(label_path)}, + {"image": str(image_path), "label": str(label_path)}, + {"image": str(image_path), "label": str(label_path)}, + ], + "skip_patch_validation": True, + } + mgr.skip_patch_validation = True + + trainer = BaseTrainer(mgr=mgr, verbose=False) + dataset = trainer._configure_dataset(is_training=True) + + assert [vol.volume_id for vol in 
dataset._volumes] == ["volume1", "volume1_1", "volume1_2"] + + +def test_direct_segmentation_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="direct_segmentation", + guide_loss_weight=0.0, + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + checkpoint_path = tmp_path / "direct_segmentation_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_direct_segmentation" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 16, 16, 16)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"ink"} + assert output["ink"].shape == (1, 2, 16, 16, 16) + assert model_info["network"].final_config["guide_fusion_stage"] == "direct_segmentation" + + +def test_inferer_finalize_output_batch_returns_float16_numpy(): + inferer = object.__new__(Inferer) + inferer.device = torch.device("cpu") + + output_batch = torch.randn(1, 2, 4, 4, 4, dtype=torch.float32) + output_np = inferer._finalize_output_batch(output_batch) + + assert output_np.dtype == np.float16 + assert output_np.shape == (1, 2, 4, 4, 4) + + +def test_save_debug_adds_aux_panel_to_preview_image(): + input_volume = torch.zeros((1, 1, 4, 8, 8), dtype=torch.float32) + targets_dict = {"surface": torch.zeros((1, 2, 4, 8, 8), dtype=torch.float32)} + outputs_dict = {"surface": torch.zeros((1, 2, 4, 8, 8), 
dtype=torch.float32)} + aux_outputs_dict = {"guide_surface": torch.ones((1, 1, 4, 8, 8), dtype=torch.float32)} + tasks_dict = {"surface": {"activation": "none"}} + + _, preview_without_aux = save_debug( + input_volume=input_volume, + targets_dict=targets_dict, + outputs_dict=outputs_dict, + tasks_dict=tasks_dict, + epoch=0, + save_media=False, + ) + _, preview_with_aux = save_debug( + input_volume=input_volume, + targets_dict=targets_dict, + outputs_dict=outputs_dict, + aux_outputs_dict=aux_outputs_dict, + tasks_dict=tasks_dict, + epoch=0, + save_media=False, + ) + + assert preview_with_aux is not None + assert preview_without_aux is not None + assert preview_with_aux.shape[1] > preview_without_aux.shape[1] + + +def test_trainer_builds_guide_preview_and_media_payload(): + mgr = _make_mgr(Path("/tmp/guide_backbone.pt"), Path("/tmp/guide_backbone.pt")) + trainer = BaseTrainer(mgr=mgr, verbose=False) + + aux_outputs_dict = { + "guide_surface": torch.zeros((1, 1, 8, 8, 8), dtype=torch.float32) + } + aux_outputs_dict["guide_surface"][:, :, 4] = 1.0 + + guide_preview = trainer._make_aux_preview_image(aux_outputs_dict) + payload = trainer._build_debug_media_payload( + debug_preview_image=np.zeros((8, 8, 3), dtype=np.uint8), + guide_preview_image=guide_preview, + ) + + assert guide_preview is not None + assert guide_preview.ndim == 3 + assert guide_preview.shape[2] == 3 + assert set(payload.keys()) == {"debug_image", "debug_guide_image"} + + +def test_trainer_builds_only_standard_debug_payload_when_no_guide_preview(): + mgr = _make_mgr(Path("/tmp/guide_backbone.pt"), Path("/tmp/guide_backbone.pt")) + trainer = BaseTrainer(mgr=mgr, verbose=False) + + payload = trainer._build_debug_media_payload( + debug_preview_image=np.zeros((8, 8, 3), dtype=np.uint8), + guide_preview_image=None, + ) + + assert set(payload.keys()) == {"debug_image"} + + +def test_attach_debug_media_to_wandb_metrics_includes_guide_image(): + class _FakeWandb: + @staticmethod + def Image(value): + return 
("image", value.shape) + + metrics = {"epoch": 0, "step": 1} + payload = { + "debug_image": np.zeros((8, 8, 3), dtype=np.uint8), + "debug_guide_image": np.zeros((6, 10, 3), dtype=np.uint8), + } + + result = BaseTrainer._attach_debug_media_to_wandb_metrics(metrics, payload, _FakeWandb) + + assert result["debug_image"] == ("image", (8, 8, 3)) + assert result["debug_guide_image"] == ("image", (6, 10, 3)) + + +def test_guided_base_trainer_omits_guide_loss_when_weight_is_zero(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr(data_root, guide_checkpoint, guide_loss_weight=0.0) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + + assert torch.isfinite(loss) + assert "guide_mask" not in task_losses + + +def test_feature_encoder_base_trainer_smoke_runs_two_training_steps_without_aux_guide_loss(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="feature_encoder", + guide_loss_weight=0.0, + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + for _step in range(2): + optimizer.zero_grad(set_to_none=True) + _inputs, 
targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + assert torch.isfinite(loss) + assert "guide_mask" not in task_losses + assert set(trainer._current_aux_outputs.keys()) == {"enc_0", "enc_1", "enc_2"} + loss.backward() + optimizer.step() + + +def test_feature_skip_concat_base_trainer_smoke_runs_two_training_steps_without_aux_outputs(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="feature_skip_concat", + guide_loss_weight=0.0, + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + for _step in range(2): + optimizer.zero_grad(set_to_none=True) + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + assert torch.isfinite(loss) + assert "guide_mask" not in task_losses + assert trainer._current_aux_outputs == {} + loss.backward() + optimizer.step() + + +def test_pretrained_backbone_pixelshuffle_base_trainer_smoke_runs_two_training_steps(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_pretrained_backbone_mgr( + data_root, + guide_checkpoint, + pretrained_decoder_type="pixelshuffle_conv", + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) 
+ batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + for _step in range(2): + optimizer.zero_grad(set_to_none=True) + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + assert torch.isfinite(loss) + assert outputs["surface"].shape == (1, 2, 16, 16, 16) + loss.backward() + optimizer.step() + + +def test_direct_segmentation_base_trainer_smoke_runs_two_training_steps_without_aux_guide_loss(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="direct_segmentation", + guide_loss_weight=0.0, + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + for _step in range(2): + optimizer.zero_grad(set_to_none=True) + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + assert torch.isfinite(loss) + assert outputs["ink"].shape == (1, 2, 16, 16, 16) + assert "guide_mask" not in task_losses + assert set(trainer._current_aux_outputs.keys()) == {"guide_mask"} + loss.backward() + optimizer.step() + + +def test_direct_segmentation_bce_style_loss_smoke_runs(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="direct_segmentation", + guide_loss_weight=0.0, 
+ losses=[{"name": "nnUNet_DC_and_BCE_loss", "weight": 1.0}], + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + + assert torch.isfinite(loss) + assert "guide_mask" not in task_losses + + +def test_direct_segmentation_plain_bce_with_logits_loss_smoke_runs(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="direct_segmentation", + guide_loss_weight=0.0, + losses=[{"name": "BCEWithLogitsLoss", "weight": 1.0}], + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + + assert torch.isfinite(loss) + assert "guide_mask" not in task_losses + + +def test_direct_segmentation_medial_surface_recall_smoke_runs(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="direct_segmentation", + guide_loss_weight=0.0, + losses=[{"name": "MedialSurfaceRecall", "weight": 1.0}], + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) 
+ loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + + assert torch.isfinite(loss) + assert "guide_mask" not in task_losses + + +def test_feature_encoder_rejects_nonzero_guide_loss_weight(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="feature_encoder", + guide_loss_weight=0.25, + guide_tokenbook_prototype_weighting="token_mlp", + ) + + with pytest.raises(ValueError, match="guide_loss_weight"): + BaseTrainer(mgr=mgr, verbose=False) + + +def test_feature_skip_concat_rejects_nonzero_guide_loss_weight(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="feature_skip_concat", + guide_loss_weight=0.25, + ) + + with pytest.raises(ValueError, match="guide_loss_weight"): + BaseTrainer(mgr=mgr, verbose=False) + + +def test_direct_segmentation_rejects_nonzero_guide_loss_weight(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="direct_segmentation", + guide_loss_weight=0.25, + ) + + with pytest.raises(ValueError, match="guide_loss_weight"): + BaseTrainer(mgr=mgr, verbose=False) + + +def test_direct_segmentation_rejects_unsupported_loss_name(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + guide_checkpoint = tmp_path / 
"guide_backbone.pt" + _write_local_guide_checkpoint(guide_checkpoint) + mgr = _make_mgr( + data_root, + guide_checkpoint, + guide_fusion_stage="direct_segmentation", + guide_loss_weight=0.0, + losses=[{"name": "MSELoss", "weight": 1.0}], + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + + with pytest.raises(ValueError, match="not supported with guide_fusion_stage='direct_segmentation'"): + trainer._compute_train_loss(outputs, targets_dict, loss_fns) + + +def test_prepare_metrics_for_logging_includes_guide_loss_entries(): + mgr = _make_mgr(Path("/tmp/guide_backbone.pt"), Path("/tmp/guide_backbone.pt")) + trainer = BaseTrainer(mgr=mgr, verbose=False) + + metrics = trainer._prepare_metrics_for_logging( + epoch=0, + step=3, + epoch_losses={"ink": [0.4, 0.2], "guide_mask": [0.1, 0.3]}, + current_lr=1e-3, + val_losses={"ink": [0.5], "guide_mask": [0.25]}, + ) + + assert metrics["train_loss_ink"] == pytest.approx(0.3) + assert metrics["train_loss_guide_mask"] == pytest.approx(0.2) + assert metrics["val_loss_ink"] == pytest.approx(0.5) + assert metrics["val_loss_guide_mask"] == pytest.approx(0.25) + assert metrics["val_loss_total"] == pytest.approx(0.375) + + +def test_feature_encoder_aux_outputs_stay_out_of_debug_composite(): + mgr = _make_mgr( + Path("/tmp/guide_backbone.pt"), + Path("/tmp/guide_backbone.pt"), + guide_fusion_stage="feature_encoder", + guide_loss_weight=0.0, + guide_tokenbook_prototype_weighting="token_mlp", + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + aux_outputs_dict = { + "enc_0": torch.ones((1, 1, 8, 8, 8), dtype=torch.float32), + "enc_1": torch.ones((1, 1, 4, 4, 4), dtype=torch.float32), + } + + prepared = trainer._prepare_aux_debug_outputs( + 
aux_outputs_dict, + torch.zeros((1, 1, 8, 8, 8), dtype=torch.float32), + ) + + assert prepared == {} + + +def test_feature_encoder_trainer_builds_stage_montage_and_media_payload(tmp_path: Path): + mgr = _make_mgr( + tmp_path, + tmp_path / "guide_backbone.pt", + guide_fusion_stage="feature_encoder", + guide_loss_weight=0.0, + guide_tokenbook_prototype_weighting="token_mlp", + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + aux_outputs_dict = { + "enc_0": torch.zeros((1, 1, 8, 8, 8), dtype=torch.float32), + "enc_1": torch.zeros((1, 1, 4, 4, 4), dtype=torch.float32), + "enc_2": torch.zeros((1, 1, 2, 2, 2), dtype=torch.float32), + } + aux_outputs_dict["enc_0"][:, :, 4] = 1.0 + aux_outputs_dict["enc_1"][:, :, 2] = 1.0 + aux_outputs_dict["enc_2"][:, :, 1] = 1.0 + + guide_preview = trainer._make_feature_encoder_guide_preview_image(aux_outputs_dict) + save_path = tmp_path / "guide_preview.png" + trainer._save_preview_image(guide_preview, save_path) + payload = trainer._build_debug_media_payload( + debug_preview_image=np.zeros((8, 8, 3), dtype=np.uint8), + guide_preview_image=guide_preview, + ) + + assert guide_preview is not None + assert guide_preview.ndim == 3 + assert guide_preview.shape[2] == 3 + assert guide_preview.shape[:2] == (38, 24) + assert save_path.exists() + assert set(payload.keys()) == {"debug_image", "debug_guide_image"} + + +def test_feature_encoder_trainer_builds_train_and_val_stage_montage(tmp_path: Path): + mgr = _make_mgr( + tmp_path, + tmp_path / "guide_backbone.pt", + guide_fusion_stage="feature_encoder", + guide_loss_weight=0.0, + guide_tokenbook_prototype_weighting="token_mlp", + ) + trainer = BaseTrainer(mgr=mgr, verbose=False) + val_aux_outputs_dict = { + "enc_0": torch.zeros((1, 1, 8, 8, 8), dtype=torch.float32), + "enc_1": torch.zeros((1, 1, 4, 4, 4), dtype=torch.float32), + "enc_2": torch.zeros((1, 1, 2, 2, 2), dtype=torch.float32), + } + train_aux_outputs_dict = { + "enc_0": torch.ones((1, 1, 8, 8, 8), dtype=torch.float32), + "enc_1": 
torch.ones((1, 1, 4, 4, 4), dtype=torch.float32), + "enc_2": torch.ones((1, 1, 2, 2, 2), dtype=torch.float32), + } + + guide_preview = trainer._make_feature_encoder_guide_preview_image( + val_aux_outputs_dict, + train_aux_outputs_dict=train_aux_outputs_dict, + ) + + assert guide_preview is not None + assert guide_preview.ndim == 3 + assert guide_preview.shape[2] == 3 + assert guide_preview.shape[:2] == (76, 24) + + +def test_direct_segmentation_aux_outputs_render_only_standard_debug_payload(): + mgr = _make_mgr(Path("/tmp/guide_backbone.pt"), Path("/tmp/guide_backbone.pt"), guide_fusion_stage="direct_segmentation", guide_loss_weight=0.0) + trainer = BaseTrainer(mgr=mgr, verbose=False) + aux_outputs_dict = { + "guide_mask": torch.ones((1, 1, 4, 4, 4), dtype=torch.float32), + } + + prepared = trainer._prepare_aux_debug_outputs( + aux_outputs_dict, + torch.zeros((1, 1, 8, 8, 8), dtype=torch.float32), + ) + guide_preview = trainer._make_feature_encoder_guide_preview_image(aux_outputs_dict) + payload = trainer._build_debug_media_payload( + debug_preview_image=np.zeros((8, 8, 3), dtype=np.uint8), + guide_preview_image=guide_preview, + ) + + assert set(prepared.keys()) == {"guide_ink"} + assert guide_preview is None + assert set(payload.keys()) == {"debug_image"} + + +def test_select_debug_sample_index_from_targets_prefers_first_nonzero_sample(): + targets_dict = { + "surface": torch.tensor( + [ + [[[[0.0, 0.0], [0.0, 0.0]]]], + [[[[0.0, 1.0], [0.0, 0.0]]]], + [[[[1.0, 1.0], [0.0, 0.0]]]], + ], + dtype=torch.float32, + ) + } + + index, found = BaseTrainer._select_debug_sample_index_from_targets(targets_dict) + + assert found is True + assert index == 1 + + +class _DummyDataset: + def __len__(self): + return 1 + + def __getitem__(self, index): + return {"image": torch.zeros((1, 4, 4, 4), dtype=torch.float32)} + + +class _DummyModel(torch.nn.Module): + def __init__(self, *, guide_enabled: bool): + super().__init__() + self.guide_enabled = guide_enabled + self.weight = 
torch.nn.Parameter(torch.tensor(1.0)) + + def forward(self, x): + return x * self.weight + + +def _make_compile_mgr(tmp_path: Path, *, compile_policy: str) -> SimpleNamespace: + return SimpleNamespace( + gpu_ids=None, + use_ddp=False, + targets={"ink": {"out_channels": 1, "activation": "none"}}, + tr_configs={}, + model_name="compile_policy_test", + train_patch_size=(4, 4, 4), + train_batch_size=1, + in_channels=1, + autoconfigure=False, + enable_deep_supervision=False, + spacing=(1.0, 1.0, 1.0), + data_path=tmp_path, + ckpt_out_base=tmp_path / "ckpts", + checkpoint_path=None, + load_weights_only=False, + guide_loss_weight=0.0, + guide_supervision_target=None, + train_num_dataloader_workers=0, + compile_policy=compile_policy, + startup_timing=False, + ddp_find_unused_parameters="auto", + ddp_static_graph="auto", + ddp_gradient_as_bucket_view=False, + no_amp=True, + amp_dtype="float16", + verbose=False, + auto_detect_channels=lambda dataset, sample: None, + ) + + +def _make_pretrained_compile_mgr(tmp_path: Path, checkpoint_path: Path, *, compile_policy: str) -> SimpleNamespace: + mgr = _make_compile_mgr(tmp_path, compile_policy=compile_policy) + mgr.targets = {"surface": {"out_channels": 2, "activation": "none"}} + mgr.model_name = "compile_policy_pretrained_pixelshuffle" + mgr.train_patch_size = (16, 16, 16) + mgr.train_batch_size = 1 + mgr.model_config = { + "pretrained_backbone": str(checkpoint_path), + "pretrained_decoder_type": "pixelshuffle_conv", + "freeze_encoder": True, + "input_shape": [16, 16, 16], + } + return mgr + + +def _configure_compile_test_trainer( + trainer: BaseTrainer, + monkeypatch, + *, + guide_enabled: bool, + call_order: list[str], + compile_guidance: bool = False, +): + dummy_dataset = _DummyDataset() + dummy_model = _DummyModel(guide_enabled=guide_enabled) + if compile_guidance: + dummy_model._compile_guidance_submodules = lambda device_type: call_order.append("guide_submodule_compile") or ["guide_backbone"] + + 
monkeypatch.setattr(trainer, "_configure_dataset", lambda is_training: dummy_dataset) + monkeypatch.setattr(trainer, "_build_dataset_for_mgr", lambda mgr, is_training: dummy_dataset) + monkeypatch.setattr(trainer, "_prepare_sample", lambda sample, is_training: sample) + monkeypatch.setattr(trainer, "_build_model", lambda: dummy_model) + monkeypatch.setattr(trainer, "_get_optimizer", lambda model: torch.optim.SGD(model.parameters(), lr=0.1)) + monkeypatch.setattr( + trainer, + "_get_scheduler", + lambda optimizer: (torch.optim.lr_scheduler.StepLR(optimizer, step_size=1), False), + ) + monkeypatch.setattr(trainer, "_get_scaler", lambda *args, **kwargs: object()) + monkeypatch.setattr( + trainer, + "_configure_dataloaders", + lambda train_dataset, val_dataset: ([], [], [0], [0]), + ) + monkeypatch.setattr(trainer, "_build_loss", lambda: {}) + monkeypatch.setattr(trainer, "_initialize_ema_model", lambda model: None) + monkeypatch.setattr( + trainer, + "_wrap_model_for_distributed_training", + lambda model: call_order.append("wrap") or model, + ) + monkeypatch.setattr( + trainer, + "_maybe_compile_model", + lambda model: call_order.append("compile") or model, + ) + + +def test_initialize_training_compiles_guided_model_before_ddp_wrap(tmp_path: Path, monkeypatch): + mgr = _make_compile_mgr(tmp_path, compile_policy="auto") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + _configure_compile_test_trainer(trainer, monkeypatch, guide_enabled=True, call_order=call_order) + + trainer._initialize_training() + + assert call_order == ["compile", "wrap"] + + +def test_initialize_training_compiles_guidance_submodule_before_model_compile(tmp_path: Path, monkeypatch): + mgr = _make_compile_mgr(tmp_path, compile_policy="auto") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + _configure_compile_test_trainer( + trainer, + monkeypatch, + guide_enabled=True, + call_order=call_order, + compile_guidance=True, + ) + + 
trainer._initialize_training() + + assert call_order == ["guide_submodule_compile", "compile", "wrap"] + + +def test_initialize_training_compiles_unguided_ddp_wrapper_by_default(tmp_path: Path, monkeypatch): + mgr = _make_compile_mgr(tmp_path, compile_policy="auto") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + _configure_compile_test_trainer(trainer, monkeypatch, guide_enabled=False, call_order=call_order) + + trainer._initialize_training() + + assert call_order == ["wrap", "compile"] + + +def test_initialize_training_skips_compile_when_policy_is_off(tmp_path: Path, monkeypatch): + mgr = _make_compile_mgr(tmp_path, compile_policy="off") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + _configure_compile_test_trainer(trainer, monkeypatch, guide_enabled=True, call_order=call_order) + + trainer._initialize_training() + + assert call_order == ["wrap"] + + +def test_initialize_training_places_run_checkpoint_dir_under_ckpt_out_base(tmp_path: Path, monkeypatch): + mgr = _make_compile_mgr(tmp_path, compile_policy="off") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + _configure_compile_test_trainer(trainer, monkeypatch, guide_enabled=True, call_order=call_order) + + state = trainer._initialize_training() + + assert str(state["ckpt_dir"]).startswith(str(mgr.ckpt_out_base)) + assert str(state["model_ckpt_dir"]).startswith(str(mgr.ckpt_out_base)) + + +def test_initialize_training_compiles_pretrained_pixelshuffle_model_before_ddp_wrap(tmp_path: Path, monkeypatch): + checkpoint_path = tmp_path / "guide_backbone.pt" + _write_local_guide_checkpoint(checkpoint_path) + mgr = _make_pretrained_compile_mgr(tmp_path, checkpoint_path, compile_policy="module") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + dummy_dataset = _DummyDataset() + + monkeypatch.setattr(trainer, "_configure_dataset", lambda is_training: dummy_dataset) + monkeypatch.setattr(trainer, 
"_build_dataset_for_mgr", lambda mgr, is_training: dummy_dataset) + monkeypatch.setattr(trainer, "_prepare_sample", lambda sample, is_training: sample) + monkeypatch.setattr(trainer, "_get_optimizer", lambda model: torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)) + monkeypatch.setattr( + trainer, + "_get_scheduler", + lambda optimizer: (torch.optim.lr_scheduler.StepLR(optimizer, step_size=1), False), + ) + monkeypatch.setattr(trainer, "_get_scaler", lambda *args, **kwargs: object()) + monkeypatch.setattr( + trainer, + "_configure_dataloaders", + lambda train_dataset, val_dataset: ([], [], [0], [0]), + ) + monkeypatch.setattr(trainer, "_build_loss", lambda: {}) + monkeypatch.setattr(trainer, "_initialize_ema_model", lambda model: None) + monkeypatch.setattr( + trainer, + "_wrap_model_for_distributed_training", + lambda model: call_order.append("wrap") or model, + ) + monkeypatch.setattr( + trainer, + "_maybe_compile_model", + lambda model: call_order.append("compile") or model, + ) + + trainer._initialize_training() + + assert call_order == ["compile", "wrap"] + + +def test_wrap_model_uses_guided_ddp_defaults(tmp_path: Path): + mgr = _make_compile_mgr(tmp_path, compile_policy="module") + trainer = BaseTrainer(mgr=mgr, verbose=False) + trainer.is_distributed = True + trainer.device = torch.device("cpu") + trainer.rank = 0 + captured = {} + + class _FakeDDP: + def __init__(self, model, **kwargs): + captured["model"] = model + captured["kwargs"] = kwargs + + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr("vesuvius.models.training.train.DDP", _FakeDDP) + try: + trainer._wrap_model_for_distributed_training(_DummyModel(guide_enabled=True)) + finally: + monkeypatch.undo() + + assert captured["kwargs"]["find_unused_parameters"] is False + assert captured["kwargs"]["static_graph"] is True + + +def test_wrap_model_uses_unguided_ddp_defaults(tmp_path: Path): + mgr = _make_compile_mgr(tmp_path, compile_policy="module") + trainer = 
BaseTrainer(mgr=mgr, verbose=False) + trainer.is_distributed = True + trainer.device = torch.device("cpu") + trainer.rank = 0 + captured = {} + + class _FakeDDP: + def __init__(self, model, **kwargs): + captured["model"] = model + captured["kwargs"] = kwargs + + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr("vesuvius.models.training.train.DDP", _FakeDDP) + try: + trainer._wrap_model_for_distributed_training(_DummyModel(guide_enabled=False)) + finally: + monkeypatch.undo() + + assert captured["kwargs"]["find_unused_parameters"] is True + assert captured["kwargs"]["static_graph"] is False + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for autocast safety test") +def test_guide_alignment_loss_is_amp_safe_on_cuda(): + mgr = _make_mgr(Path("/tmp/guide_backbone.pt"), Path("/tmp/guide_backbone.pt")) + trainer = BaseTrainer(mgr=mgr, verbose=False) + trainer.device = torch.device("cuda") + trainer._current_aux_outputs = { + "guide_mask": torch.full((1, 1, 2, 2, 2), 0.5, device=trainer.device, dtype=torch.float16) + } + targets_dict = { + "ink": torch.ones((1, 1, 2, 2, 2), device=trainer.device, dtype=torch.float16) + } + + with torch.amp.autocast("cuda"): + loss = trainer._compute_guide_alignment_loss(targets_dict) + + assert loss is not None + assert torch.isfinite(loss) diff --git a/vesuvius/tests/models/training/test_mednext_trainer.py b/vesuvius/tests/models/training/test_mednext_trainer.py new file mode 100644 index 000000000..d0cb81315 --- /dev/null +++ b/vesuvius/tests/models/training/test_mednext_trainer.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +import zarr +import numpy as np +from torch.utils.data import DataLoader + +from vesuvius.models.build.build_network_from_config import NetworkFromConfig +from vesuvius.models.run.inference import Inferer +from vesuvius.models.training.train import BaseTrainer + + +def 
_make_synthetic_dataset(root: Path) -> Path: + data_root = root / "data" + image_root = data_root / "images" / "volume1.zarr" + label_root = data_root / "labels" / "volume1_surface.zarr" + + image_group = zarr.open_group(str(image_root), mode="w") + label_group = zarr.open_group(str(label_root), mode="w") + image_array = image_group.create_dataset("0", shape=(32, 32, 32), chunks=(32, 32, 32), dtype="float32") + label_array = label_group.create_dataset("0", shape=(32, 32, 32), chunks=(32, 32, 32), dtype="uint8") + + coords = np.linspace(-1.0, 1.0, 32, dtype=np.float32) + z = coords[:, None, None] + y = coords[None, :, None] + x = coords[None, None, :] + image = np.exp(-(x * x + y * y + z * z) * 4.0).astype(np.float32) + label = ((x * x + y * y + z * z) < 0.35).astype(np.uint8) + + image_array[:] = image + label_array[:] = label + return data_root + + +def _make_mgr( + data_root: Path, + *, + architecture_type: str = "mednext_v1", + mednext_model_id: str = "S", + compile_policy: str = "off", + targets: dict | None = None, + enable_deep_supervision: bool = False, + separate_decoders: bool | None = None, +) -> SimpleNamespace: + model_config = { + "architecture_type": architecture_type, + "mednext_model_id": mednext_model_id, + "mednext_kernel_size": 3, + } + if separate_decoders is not None: + model_config["separate_decoders"] = separate_decoders + + return SimpleNamespace( + gpu_ids=None, + use_ddp=False, + targets=targets or { + "surface": { + "out_channels": 2, + "activation": "none", + "losses": [{"name": "nnUNet_DC_and_CE_loss", "weight": 1.0}], + } + }, + tr_configs={}, + model_name="mednext_smoke", + train_patch_size=(32, 32, 32), + train_batch_size=1, + in_channels=1, + autoconfigure=False, + enable_deep_supervision=enable_deep_supervision, + spacing=(1.0, 1.0, 1.0), + data_path=data_root, + min_labeled_ratio=0.0, + min_bbox_percent=0.0, + skip_patch_validation=True, + allow_unlabeled_data=False, + normalization_scheme="zscore", + intensity_properties={}, + 
skip_intensity_sampling=True, + no_spatial_augmentation=True, + no_scaling_augmentation=True, + cache_valid_patches=False, + dataset_config={}, + ome_zarr_resolution=0, + valid_patch_find_resolution=0, + bg_sampling_enabled=False, + bg_to_fg_ratio=0.5, + unlabeled_foreground_enabled=False, + unlabeled_foreground_threshold=0.05, + unlabeled_foreground_bbox_threshold=0.15, + unlabeled_foreground_volumes=None, + profile_augmentations=False, + guide_loss_weight=0.0, + guide_supervision_target=None, + train_num_dataloader_workers=0, + compile_policy=compile_policy, + startup_timing=False, + ddp_find_unused_parameters="auto", + ddp_static_graph="auto", + ddp_gradient_as_bucket_view=False, + no_amp=True, + amp_dtype="float16", + verbose=False, + auto_detect_channels=lambda dataset, sample: None, + ckpt_out_base=data_root / "ckpts", + checkpoint_path=None, + load_weights_only=False, + model_config=model_config, + ) + + +def test_mednext_v1_base_trainer_smoke_runs_two_training_steps(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + mgr = _make_mgr(data_root) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model().to(trainer.device) + loss_fns = trainer._build_loss() + dataset = trainer._configure_dataset(is_training=True) + batch = next(iter(DataLoader(dataset, batch_size=1, shuffle=False))) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + + for _ in range(2): + optimizer.zero_grad(set_to_none=True) + _inputs, targets_dict, outputs = trainer._get_model_outputs(model, batch) + loss, task_losses = trainer._compute_train_loss(outputs, targets_dict, loss_fns) + assert torch.isfinite(loss) + assert "surface" in task_losses + loss.backward() + optimizer.step() + + +def test_mednext_v1_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + mgr = _make_mgr(data_root) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + 
checkpoint_path = tmp_path / "mednext_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 32, 32, 32)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"surface"} + assert output["surface"].shape == (1, 2, 32, 32, 32) + assert model_info["network"].final_config["architecture_type"] == "mednext_v1" + assert model_info["network"].final_config["mednext_model_id"] == "S" + + +class _DummyDataset: + def __len__(self): + return 1 + + def __getitem__(self, index): + return {"image": torch.zeros((1, 4, 4, 4), dtype=torch.float32)} + + +def test_initialize_training_compiles_mednext_model_before_ddp_wrap(tmp_path: Path, monkeypatch): + mgr = _make_mgr(tmp_path, compile_policy="module") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + dummy_dataset = _DummyDataset() + + monkeypatch.setattr(trainer, "_configure_dataset", lambda is_training: dummy_dataset) + monkeypatch.setattr(trainer, "_build_dataset_for_mgr", lambda mgr, is_training: dummy_dataset) + monkeypatch.setattr(trainer, "_prepare_sample", lambda sample, is_training: sample) + monkeypatch.setattr(trainer, "_get_optimizer", lambda model: torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)) + monkeypatch.setattr( + trainer, + "_get_scheduler", + lambda optimizer: (torch.optim.lr_scheduler.StepLR(optimizer, step_size=1), False), + ) + monkeypatch.setattr(trainer, "_get_scaler", lambda *args, **kwargs: object()) + monkeypatch.setattr(trainer, "_configure_dataloaders", 
lambda train_dataset, val_dataset: ([], [], [0], [0])) + monkeypatch.setattr(trainer, "_build_loss", lambda: {}) + monkeypatch.setattr(trainer, "_initialize_ema_model", lambda model: None) + monkeypatch.setattr(trainer, "_wrap_model_for_distributed_training", lambda model: call_order.append("wrap") or model) + monkeypatch.setattr(trainer, "_maybe_compile_model", lambda model: call_order.append("compile") or model) + + trainer._initialize_training() + + assert call_order == ["compile", "wrap"] + + +def test_mednext_v2_checkpoint_roundtrip_preserves_plain_inference_forward(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + mgr = _make_mgr(data_root, architecture_type="mednext_v2", mednext_model_id="L") + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + checkpoint_path = tmp_path / "mednext_v2_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_v2" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 32, 32, 32)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"surface"} + assert output["surface"].shape == (1, 2, 32, 32, 32) + assert model_info["network"].final_config["architecture_type"] == "mednext_v2" + assert model_info["network"].final_config["mednext_model_id"] == "L" + + +def test_mednext_v1_deep_supervision_checkpoint_roundtrip_loads_plain_inference_output(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + mgr = _make_mgr(data_root, enable_deep_supervision=True) + trainer = BaseTrainer(mgr=mgr, verbose=False) + model = trainer._build_model() + 
checkpoint_path = tmp_path / "mednext_ds_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_ds" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + output = model_info["network"](torch.randn(1, 1, 32, 32, 32)) + + assert isinstance(output, dict) + assert isinstance(output["surface"], torch.Tensor) + assert output["surface"].shape == (1, 2, 32, 32, 32) + assert model_info["network"].final_config["enable_deep_supervision"] is True + + +def test_mednext_mixed_decoder_checkpoint_roundtrip_preserves_layout(tmp_path: Path): + data_root = _make_synthetic_dataset(tmp_path) + targets = { + "surface": { + "out_channels": 2, + "activation": "none", + "losses": [{"name": "nnUNet_DC_and_CE_loss", "weight": 1.0}], + "separate_decoder": True, + }, + "tissue": { + "out_channels": 1, + "activation": "none", + "losses": [{"name": "nnUNet_DC_and_CE_loss", "weight": 1.0}], + }, + } + mgr = _make_mgr( + data_root, + targets=targets, + separate_decoders=False, + ) + model = NetworkFromConfig(mgr) + assert model.shared_decoder is not None + assert sorted(model.task_decoders.keys()) == ["surface"] + assert sorted(model.task_heads.keys()) == ["tissue"] + + checkpoint_path = tmp_path / "mednext_mixed_model.pth" + torch.save({"model_config": model.final_config, "model": model.state_dict()}, checkpoint_path) + + output_dir = tmp_path / "inference_out_mixed" + output_dir.mkdir() + inferer = Inferer( + model_path=str(checkpoint_path), + input_dir=str(data_root / "images" / "volume1.zarr"), + output_dir=str(output_dir), + input_format="zarr", + do_tta=False, + device="cpu", + num_dataloader_workers=0, + 
model_type="train_py", + ) + + model_info = inferer._load_train_py_model(checkpoint_path) + loaded = model_info["network"] + output = loaded(torch.randn(1, 1, 32, 32, 32)) + + assert isinstance(output, dict) + assert set(output.keys()) == {"surface", "tissue"} + assert loaded.shared_decoder is not None + assert sorted(loaded.task_decoders.keys()) == ["surface"] + assert sorted(loaded.task_heads.keys()) == ["tissue"] + assert loaded.final_config["targets"]["surface"]["separate_decoder"] is True + assert loaded.final_config["targets"]["tissue"]["separate_decoder"] is False + + +def test_initialize_training_compiles_mednext_v2_model_before_ddp_wrap(tmp_path: Path, monkeypatch): + mgr = _make_mgr(tmp_path, architecture_type="mednext_v2", mednext_model_id="B", compile_policy="module") + trainer = BaseTrainer(mgr=mgr, verbose=False) + call_order: list[str] = [] + dummy_dataset = _DummyDataset() + + monkeypatch.setattr(trainer, "_configure_dataset", lambda is_training: dummy_dataset) + monkeypatch.setattr(trainer, "_build_dataset_for_mgr", lambda mgr, is_training: dummy_dataset) + monkeypatch.setattr(trainer, "_prepare_sample", lambda sample, is_training: sample) + monkeypatch.setattr(trainer, "_get_optimizer", lambda model: torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.1)) + monkeypatch.setattr( + trainer, + "_get_scheduler", + lambda optimizer: (torch.optim.lr_scheduler.StepLR(optimizer, step_size=1), False), + ) + monkeypatch.setattr(trainer, "_get_scaler", lambda *args, **kwargs: object()) + monkeypatch.setattr(trainer, "_configure_dataloaders", lambda train_dataset, val_dataset: ([], [], [0], [0])) + monkeypatch.setattr(trainer, "_build_loss", lambda: {}) + monkeypatch.setattr(trainer, "_initialize_ema_model", lambda model: None) + monkeypatch.setattr(trainer, "_wrap_model_for_distributed_training", lambda model: call_order.append("wrap") or model) + monkeypatch.setattr(trainer, "_maybe_compile_model", lambda model: 
call_order.append("compile") or model) + + trainer._initialize_training() + + assert call_order == ["compile", "wrap"] diff --git a/vesuvius/uv.lock b/vesuvius/uv.lock index b2cdec62e..70575a489 100644 --- a/vesuvius/uv.lock +++ b/vesuvius/uv.lock @@ -1758,21 +1758,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/b2/ad5c38b500a6bc12762c0a16529a8123abfa12b23fb937e74357ffad9bca/imagecodecs-2025.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d588e51fe0ff35d759c310e209c12c76f5db35893de3e95c8e7fa69d78244ca5", size = 10168537, upload-time = "2025-08-03T06:06:40.788Z" }, { url = "https://files.pythonhosted.org/packages/3c/8e/58a58f851043325f955bfc30eef14a23bc15dd02c343f7bf4028c27fd005/imagecodecs-2025.8.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e13144b6af43bf75024a92f1522af6d8fd0ef61c6e0fcd01ae028961df93963a", size = 25522552, upload-time = "2025-08-03T06:06:44.527Z" }, { url = "https://files.pythonhosted.org/packages/d6/7e/4eb0dc145b7abfbff4139251326946b44d375ae0ab29c775a58adee63a0f/imagecodecs-2025.8.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e993b0b1601d19ee1aadb10fceb93524d88b874ee728f82d3f9b40bddccd223", size = 26369822, upload-time = "2025-08-03T06:06:48.308Z" }, - { url = "https://files.pythonhosted.org/packages/ee/1b/05baebc43671f7f6d1d169841438cd689a12e257a021eab6fa5ebc00b5e1/imagecodecs-2025.8.2-cp311-cp311-win32.whl", hash = "sha256:a1ec2006cbc25bf273616c90478338d44f37aafa50935dcbc18767775b416722", size = 18567263, upload-time = "2025-08-03T06:06:51.692Z" }, { url = "https://files.pythonhosted.org/packages/10/27/0254f64a1d759f22f2c3add691ef915bba016aa7b3630192bbaa9ce1d65d/imagecodecs-2025.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:369a2e0784787ec2ea56676de5fdb0bdaff3202679f45b27e51261f0960923e1", size = 22686127, upload-time = "2025-08-03T06:06:55.182Z" }, { url = 
"https://files.pythonhosted.org/packages/72/89/3a43d23e27da6e145148b96886b9f1b32b864cc4dc5f5356e585e2eef863/imagecodecs-2025.8.2-cp311-cp311-win_arm64.whl", hash = "sha256:5fa688d3521edeb192179c9808182f00064fb123af11c989a5d509e36a178ebe", size = 18003465, upload-time = "2025-08-03T06:06:58.387Z" }, { url = "https://files.pythonhosted.org/packages/3b/2e/b129b2bf95a4332b86452dd93829cb6ab1e203d86a488fc8639e9f08db7d/imagecodecs-2025.8.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:32ce9707c85cf90e5fb4d9e076ec78288f0f1c084954376f798397743bc4d4e5", size = 12480416, upload-time = "2025-08-03T06:07:02.511Z" }, { url = "https://files.pythonhosted.org/packages/dc/36/b1d3117c34f5eebfd7a92314042338dce608654173c71b1fb707e416aef5/imagecodecs-2025.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7b7cd0c41ea83f249db2f4ee04c7c3f23bd6fbef89a58bdaa7d2bf11421ffc15", size = 10127356, upload-time = "2025-08-03T06:07:04.771Z" }, { url = "https://files.pythonhosted.org/packages/e6/41/9ee8f13c4e3425c1aed0b9c5bca5547b968da20cb60af094fafcb52be05f/imagecodecs-2025.8.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:10f421048206ffb6763b1eb28f2205535801794e9f2c612bbcc219d14b9e44bc", size = 25765528, upload-time = "2025-08-03T06:07:08.044Z" }, { url = "https://files.pythonhosted.org/packages/d6/1a/f12736807a179a4bd5a26387f2edf00cc3f972139b61f2833f6fbe9b9685/imagecodecs-2025.8.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f14a5eb8151e32316028c09212eed1d14334460e8adbbef819c1d8ec6ab7e633", size = 26718026, upload-time = "2025-08-03T06:07:13.54Z" }, - { url = "https://files.pythonhosted.org/packages/9c/cd/b0367017443af901279952817900e7f496998fb11de24d91cccdb4fa3dc3/imagecodecs-2025.8.2-cp312-cp312-win32.whl", hash = "sha256:e5cb491fa698d055643d3519bab36a3da8d4144ac759430464e2f75f945f3325", size = 18546766, upload-time = "2025-08-03T06:07:16.836Z" }, { url = 
"https://files.pythonhosted.org/packages/c4/97/5eb48564f5fdbe3ef83a07ba45d301d45bf85af382a051de34e0a3d6af85/imagecodecs-2025.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:7371169cff1b98f07a983b28fffd2d334cf40a61fce3c876dda930149081bdeb", size = 22674367, upload-time = "2025-08-03T06:07:20.208Z" }, { url = "https://files.pythonhosted.org/packages/2b/b9/a1a70b5d2250937cd747ceff94ddb9cb70d6f1b9fef9f4164d88c63dd3cb/imagecodecs-2025.8.2-cp312-cp312-win_arm64.whl", hash = "sha256:bf0a97cd58810d5d7ecd983332f1f85a0df991080a774f7caadb1add172165d0", size = 17981844, upload-time = "2025-08-03T06:07:23.219Z" }, { url = "https://files.pythonhosted.org/packages/f2/69/77a5c020030dbba1a566264b9201f091534b4b10e9ac5725b7bd40895a8b/imagecodecs-2025.8.2-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:06a1b54847cbf595ed58879936eabdab990f1780b80d03a8d4edb57e61c768b0", size = 12454348, upload-time = "2025-08-03T06:07:26.135Z" }, { url = "https://files.pythonhosted.org/packages/f9/5d/c5dd7f0706dc48d38deefe4cba05ac78236b662a8ef86f9c0cc569b9db96/imagecodecs-2025.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:dbf3abd281034e29034b773407c27eff35dd94832a9b6e3c97490db079e3c0a8", size = 10103114, upload-time = "2025-08-03T06:07:28.657Z" }, { url = "https://files.pythonhosted.org/packages/60/40/8b656577120f758ce4b177917b57c76f15e695ff0e63584f641db2063bbe/imagecodecs-2025.8.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:248bb2ea3e43690196acdb1dae5aa40213dbd00c47255295d57da80770ceeaa7", size = 25602557, upload-time = "2025-08-03T06:07:32.399Z" }, { url = "https://files.pythonhosted.org/packages/8a/de/6c1cf78cc0ecc45d98a0eb0d8920df7b90719f8643c7ed9b1bb700f95890/imagecodecs-2025.8.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52230fcd0c331b0167ccd14d7ce764bb780006b65bf69761d8bde6863419fdbf", size = 26544468, upload-time = "2025-08-03T06:07:36.096Z" }, - { url = 
"https://files.pythonhosted.org/packages/51/14/bda4974256e31eca65e33446026c91d54d1124aa59ce042fc326836597ec/imagecodecs-2025.8.2-cp313-cp313-win32.whl", hash = "sha256:151f9e0879ed8a025b19fcffbf4785dacf29002d5fec94d318e84ba154ddd54c", size = 18534480, upload-time = "2025-08-03T06:07:39.901Z" }, { url = "https://files.pythonhosted.org/packages/a6/bb/ada8851ee56ab835562fb96f764f055e16a5d43a81ebc95207f6b2f3c1d4/imagecodecs-2025.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:d167c265caaf36f9f090e55b68a7c0ec4e5b436923cdc047358743f450534154", size = 22662134, upload-time = "2025-08-03T06:07:43.845Z" }, { url = "https://files.pythonhosted.org/packages/7c/b6/687ad4e137fe637213540cf71bf6a3957cc48667ce6d96d6d9dcb8409305/imagecodecs-2025.8.2-cp313-cp313-win_arm64.whl", hash = "sha256:37199e19d61c7a6c0cf901859d7a81c4449d18047a15ca9d2ebe17176c8d1b69", size = 17968181, upload-time = "2025-08-03T06:07:47.254Z" }, ] @@ -6439,6 +6436,7 @@ models = [ { name = "huggingface-hub" }, { name = "lxml" }, { name = "monai" }, + { name = "nest-asyncio" }, { name = "nnunetv2" }, { name = "numba" }, { name = "pandas" }, @@ -6541,13 +6539,14 @@ requires-dist = [ { name = "napari", marker = "extra == 'gui'", specifier = ">=0.5.6" }, { name = "nest-asyncio", marker = "extra == 'all'", specifier = ">=1.6.0" }, { name = "nest-asyncio", marker = "extra == 'blending'", specifier = ">=1.6.0" }, + { name = "nest-asyncio", marker = "extra == 'models'", specifier = ">=1.6.0" }, { name = "nest-asyncio", marker = "extra == 'notebooks'", specifier = ">=1.6.0" }, { name = "nnunetv2", marker = "extra == 'all'", directory = "../segmentation/models/arch/nnunet" }, { name = "nnunetv2", marker = "extra == 'models'", directory = "../segmentation/models/arch/nnunet" }, { name = "numba", marker = "extra == 'all'", specifier = ">=0.60.0" }, { name = "numba", marker = "extra == 'models'", specifier = ">=0.60.0" }, - { name = "numcodecs", specifier = ">=0.12.1" }, - { name = "numcodecs", marker = "extra == 
'volume-only'", specifier = ">=0.12.1" }, + { name = "numcodecs", specifier = ">=0.12.1,<0.16" }, + { name = "numcodecs", marker = "extra == 'volume-only'", specifier = ">=0.12.1,<0.16" }, { name = "numpy", specifier = ">=2.0.2" }, { name = "numpy", marker = "extra == 'volume-only'", specifier = ">=2.0.2" }, { name = "open3d", marker = "extra == 'all'", specifier = ">=0.18,<0.19" },