vesuvius: add guided Dinovol modes, pixelshuffle pretrained decoder, and MedNeXt architectures#806

Draft
giorgioangel wants to merge 59 commits into main from pr/vesuvius-guided-dinovol-mednext-clean

Conversation

@giorgioangel
Member

Summary

Adds guided volumetric Dinovol segmentation features, a frozen-backbone PixelShuffle decoder, and MedNeXt v1/v2 architectures to vesuvius, while keeping training and inference routed through the existing CLI, NetworkFromConfig, and checkpoint loader paths.

Included

  • Guided Dinovol input gating, encoder-feature gating, skip concatenation, and direct TokenBook segmentation
  • Guided compile/runtime/debug improvements, including compile policy controls and W&B/debug preview support
  • Frozen pretrained_backbone protection from global InitWeights_He
  • model_config.pretrained_decoder_type: pixelshuffle_conv
  • model_config.architecture_type: mednext_v1 and mednext_v2
  • MedNeXt benchmark coverage
  • Config-plumbed AdamW epsilon support
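The AdamW epsilon plumbing could look roughly like the following, a minimal sketch assuming hypothetical config key names (`optimizer_eps`, `initial_lr`, `weight_decay` are illustrative, not the actual vesuvius config schema):

```python
# Hypothetical sketch of config-plumbed AdamW epsilon. Key names below are
# assumptions for illustration, not the real vesuvius configuration keys.
def build_adamw_kwargs(tr_config: dict) -> dict:
    """Collect AdamW keyword arguments, letting the config override eps."""
    return {
        "lr": tr_config.get("initial_lr", 1e-3),
        "weight_decay": tr_config.get("weight_decay", 1e-2),
        # PyTorch's AdamW default eps is 1e-8; expose it via config.
        "eps": float(tr_config.get("optimizer_eps", 1e-8)),
    }
```

The resulting dict would be splatted into `torch.optim.AdamW(model.parameters(), **kwargs)`.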

Final retained behavior

  • The PixelShuffle pretrained-backbone decoder does not use the later experimental input skip
  • The retained PixelShuffle structure is:
    • per-stage Conv -> PixelShuffle -> Conv -> GroupNorm -> GELU
    • final head 3x3x3 -> GroupNorm -> GELU -> 3x3x3 -> 1x1x1 logits
  • mednext_v2 is implemented as a paper-derived extension over vendored MedNeXt v1, with explicit preset selection via mednext_model_id
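The channel-to-space rearrangement behind the per-stage PixelShuffle step can be sketched as follows. This is a minimal NumPy illustration of the (N, C·r³, D, H, W) → (N, C, D·r, H·r, W·r) reshuffle only; the real decoder uses torch modules, and the surrounding Conv/GroupNorm/GELU layers are omitted:

```python
import numpy as np

def pixel_shuffle_3d(x: np.ndarray, r: int) -> np.ndarray:
    """Rearrange (N, C*r^3, D, H, W) -> (N, C, D*r, H*r, W*r),
    the volumetric analogue of torch.nn.PixelShuffle."""
    n, c, d, h, w = x.shape
    assert c % (r ** 3) == 0, "channels must be divisible by r^3"
    c_out = c // (r ** 3)
    # split the channel axis into (c_out, r, r, r) ...
    x = x.reshape(n, c_out, r, r, r, d, h, w)
    # ... then interleave each upscaling factor next to its spatial axis
    x = x.transpose(0, 1, 5, 2, 6, 3, 7, 4)  # n, c_out, d, r, h, r, w, r
    return x.reshape(n, c_out, d * r, h * r, w * r)
```

A stage's first conv would expand channels to C·r³ before this reshuffle, and the following Conv -> GroupNorm -> GELU would refine the upsampled volume.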

Validation

  • uv run --extra models --extra tests pytest tests/models/build/test_guided_network.py tests/models/build/test_mednext_shapes.py tests/models/build/test_primus_shapes.py tests/models/training/test_guided_trainer.py tests/models/training/test_mednext_trainer.py tests/models/training/test_base_trainer.py tests/models/configuration/test_config_manager.py tests/models/configuration/test_ps256_config_compat.py -q
  • Result on the clean PR branch: 139 passed

Benchmark snapshots

Guided Dinovol benchmark on the clean PR branch:

  • 32^3: baseline train step 61.21 ms; direct segmentation 11.57 ms; feature encoder 14.52 ms; skip concat 23.18 ms; input gating 62.65 ms
  • 64^3: baseline train step 16.35 ms; direct segmentation 9.68 ms; feature encoder 16.92 ms; skip concat 24.12 ms; input gating 19.32 ms

MedNeXt benchmark on the clean PR branch:

  • 128^3: UNet train step 118.62 ms; mednext_v1 B 248.90 ms; mednext_v2 L 1160.38 ms
  • 128^3: mednext_v2 L width2 forward runs but train-step OOMs; mednext_v2 B startup OOMs on the local RTX 4090
  • 192^3: only the UNet baseline remains trainable locally; the current MedNeXt variants OOM in this setup

Caveats

  • Do not treat mednext_v2 as upstream-official nnUNet code; it is a paper-derived extension over vendored MedNeXt v1
  • Remote MedNeXt training recipes are still exploratory; VRAM/stability findings are not presented as solved
  • Local working files notes.md and implementation.md were used to reconcile retained behavior vs reverted experiments, but they are outside the villa git repo and are not part of this PR

@vercel

vercel bot commented Apr 4, 2026

The latest updates on your projects. Learn more about Vercel for GitHub.

1 Skipped Deployment: scrollprize-org (Preview, Ignored), updated Apr 5, 2026 2:23pm

Member Author

Follow-up fix pushed in 941daa8ed:

  • preserve resolved per-target MedNeXt decoder layout in final_config so mixed shared/separate decoder checkpoints rebuild exactly
  • rebuild train.py checkpoints with enable_deep_supervision preserved in the inference loader
  • wrap DS-enabled train.py models for plain inference outputs so strict checkpoint load still works while inference receives highest-resolution logits
  • add regressions for:
    • DS-enabled MedNeXt checkpoint reload through Inferer
    • mixed shared/separate MedNeXt decoder checkpoint reload
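The DS-unwrap idea can be sketched as a thin callable wrapper (illustrative names, not the actual vesuvius classes):

```python
# Hedged sketch of wrapping a deep-supervision model for plain inference.
# A DS-enabled model returns a list of logits from high to low resolution;
# inference callers expect a single tensor.
class DeepSupervisionUnwrap:
    def __init__(self, model):
        # keep the original module intact so strict state-dict
        # loading still sees the exact training-time keys
        self.model = model

    def __call__(self, x):
        out = self.model(x)
        # DS training returns [full_res, half_res, ...];
        # hand back only the highest-resolution logits
        return out[0] if isinstance(out, (list, tuple)) else out
```

Non-DS models pass through unchanged, so the same inference path serves both checkpoint flavors.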

Validation rerun on the clean PR branch:

  • uv run --extra models --extra tests pytest tests/models/build/test_guided_network.py tests/models/build/test_mednext_shapes.py tests/models/build/test_primus_shapes.py tests/models/training/test_guided_trainer.py tests/models/training/test_mednext_trainer.py tests/models/training/test_base_trainer.py tests/models/configuration/test_config_manager.py tests/models/configuration/test_ps256_config_compat.py -q
  • Result: 141 passed

@giorgioangel
Member Author

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 941daa8ed2


Comment on lines +980 to +984
if self.guide_freeze:
    with torch.inference_mode():
        frozen_features = self.guide_backbone(x)[0]
return frozen_features.clone()
return self.guide_backbone(x)[0]

P1: Fix non-frozen guide feature path

When guide_freeze is false, this function skips the if block so frozen_features is never assigned, but return frozen_features.clone() still runs unconditionally. That raises UnboundLocalError on the first guided forward pass and makes trainable guidance unusable. The following return self.guide_backbone(x)[0] is unreachable, so the intended non-frozen branch currently never executes.

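A minimal sketch of the corrected control flow, with `torch.inference_mode()` and `.clone()` reduced to comments so the branching runs stand-alone:

```python
# Sketch of the fix: the frozen return must live inside the if block,
# so the trainable (non-frozen) path below becomes reachable.
def guide_features(self, x):
    if self.guide_freeze:
        # real code wraps this in torch.inference_mode()
        frozen_features = self.guide_backbone(x)[0]
        return frozen_features  # real code: frozen_features.clone()
    # non-frozen path: gradients flow into the guide backbone
    return self.guide_backbone(x)[0]
```

With the original indentation, `guide_freeze=False` hit an unassigned `frozen_features`; here both branches return cleanly.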

Comment on lines 1908 to +1911
scaler.step(optimizer)
scaler.update()
optimizer_stepped = True
if should_time_optimizer:

P1: Update EMA weights after optimizer step

EMA support is wired into config/loading, but after scaler.step(optimizer) there is no call to _update_ema_model(model) (and no other call site in BaseTrainer). In runs with ema_enabled: true, the EMA copy stays at initialization and never tracks training weights, so any EMA validation/checkpoint flow silently uses stale parameters.

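A minimal sketch of the missing EMA step, using plain floats instead of tensors (the real trainer would iterate paired parameter tensors right after `scaler.step(optimizer)`):

```python
# Illustrative stand-in for the absent _update_ema_model call:
# exponential moving average of model weights, updated once per optimizer step.
def update_ema(ema_params: dict, model_params: dict, decay: float = 0.999) -> None:
    """In-place EMA update: ema <- decay * ema + (1 - decay) * current."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
```

Without this call after each step, the EMA copy stays at its initialization, which is exactly the stale-parameter failure mode described above.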
