This folder contains refactored versions of the provided training (3_*.py) and inference (4_*.py) scripts.
- Single GPU by default (no torch.distributed.launch, no mp.spawn).
- Optional DDP via torchrun.
- Faster dataloading: pin_memory, persistent_workers, prefetch_factor, and a robust dict_collate.
- TensorBoard fixed: step + epoch scalars are written and flush() is called.
- Early stopping, periodic checkpoints (latest, best, and epochXXXX).
- Validation each epoch with a tqdm progress bar (stdout). Logs go to timestamped log files.
- Metrics:
  - 3D: MAE, PSNR (averaged over modalities)
  - 2D slice-wise (added): SSIM, PSNR, MAE (averaged over slices, configurable stride)
  - Optional LPIPS proxy (MONAI PerceptualLoss in fake-3D mode) in the supervised UNet script.
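
As a rough illustration of how the 3D and slice-wise metrics listed above could be computed, here is a minimal NumPy sketch. It is not taken from the training scripts: the function names, the (C, D, H, W) array layout, the slicing axis, and the `data_range` default are all assumptions made for the example. SSIM and the LPIPS proxy are left out because they depend on external libraries.

```python
import numpy as np


def mae(pred, target):
    """Mean absolute error over all elements."""
    return float(np.mean(np.abs(pred - target)))


def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB; infinite when the inputs are identical."""
    mse = float(np.mean((pred - target) ** 2))
    if mse == 0.0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))


def volume_metrics(pred, target, data_range=1.0):
    """3D metrics averaged over modalities. Assumes arrays shaped (C, D, H, W)."""
    maes = [mae(p, t) for p, t in zip(pred, target)]
    psnrs = [psnr(p, t, data_range) for p, t in zip(pred, target)]
    return {"mae": float(np.mean(maes)), "psnr": float(np.mean(psnrs))}


def slice_metrics(pred, target, stride=1, axis=0, data_range=1.0):
    """2D metrics averaged over every `stride`-th slice of a (D, H, W) volume."""
    maes, psnrs = [], []
    for i in range(0, pred.shape[axis], stride):
        p = np.take(pred, i, axis=axis)
        t = np.take(target, i, axis=axis)
        maes.append(mae(p, t))
        psnrs.append(psnr(p, t, data_range))
    return {
        "mae": float(np.mean(maes)),
        "psnr": float(np.mean(psnrs)),
        "n_slices": len(maes),
    }


# Hypothetical data: two modalities, constant error of 0.1 per voxel.
demo_pred = np.zeros((2, 4, 4, 4))
demo_target = np.full((2, 4, 4, 4), 0.1)
print(volume_metrics(demo_pred, demo_target))
print(slice_metrics(demo_pred[0], demo_target[0], stride=2))
```

A larger stride evaluates fewer slices, which trades some precision in the slice-wise average for faster validation.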
Supervised UNet baseline:
python 3_train_ours.py --config configs/config_unet3d_rcg.yaml

RCG (needs pretrained RDM ckpt):
python 3_train_rcg_optimized_metrics_final.py --config configs/config_unet3d_rcg.yaml --rdm-ckpt /path/to/autoencoder.pt

Inference:
python 4_inference_ours.py --config configs/config_unet3d_rcg.yaml --ckpt /path/to/ckpt_best.pt

- Single GPU default: use python ...
- Multi-GPU optional: use torchrun --standalone --nproc_per_node=2 ...
- Fast smoke test: add --max-train-batches 10 --max-val-batches 10 (runs only 10 batches per training epoch and per validation pass)
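
For readers wondering how such batch caps typically work, the sketch below shows one common way to wire them up with argparse and itertools.islice. Only the two flag names come from this README; `build_parser`, `limit_batches`, and the stand-in loaders are hypothetical, and the actual scripts may implement the cap differently.

```python
import argparse
from itertools import islice


def build_parser():
    """Parser exposing the two smoke-test flags; None means no cap."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max-train-batches", type=int, default=None)
    parser.add_argument("--max-val-batches", type=int, default=None)
    return parser


def limit_batches(loader, max_batches):
    """Iterate over at most `max_batches` batches; iterate fully when it is None."""
    return islice(loader, max_batches)


# Explicit argument list instead of sys.argv so the example runs anywhere.
args = build_parser().parse_args(["--max-train-batches", "10", "--max-val-batches", "10"])

train_loader = range(1000)  # stand-in for a real DataLoader
val_loader = range(200)     # stand-in for a real DataLoader

n_train = sum(1 for _ in limit_batches(train_loader, args.max_train_batches))
n_val = sum(1 for _ in limit_batches(val_loader, args.max_val_batches))
print(n_train, n_val)
```

Wrapping the loader rather than changing the dataset leaves the rest of the epoch logic (logging, checkpointing, validation) untouched, which is what a smoke test is meant to exercise.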
OURS (UNet raw 3D)
python 3_train_ours.py --config configs/config_unet3d_rcg.yaml
python 3_train_ours.py --config configs/config_unet3d_rcg.yaml --max-train-batches 10 --max-val-batches 10
torchrun --standalone --nproc_per_node=2 3_train_ours.py --config configs/config_unet3d_rcg.yaml

UNEST
python 3_train_unest.py --config configs/config_unest.yaml
python 3_train_unest.py --config configs/config_unest.yaml --max-train-batches 10 --max-val-batches 10
torchrun --standalone --nproc_per_node=2 3_train_unest.py --config configs/config_unest.yaml

MedSyn
python 3_train_medsyn.py --config configs/config_medsyn.yaml
python 3_train_medsyn.py --config configs/config_medsyn.yaml --max-train-batches 10 --max-val-batches 10
torchrun --standalone --nproc_per_node=2 3_train_medsyn.py --config configs/config_medsyn.yaml

I2I-Mamba
python 3_train_i2imamba.py --config configs/config_i2imamba.yaml
python 3_train_i2imamba.py --config configs/config_i2imamba.yaml --max-train-batches 10 --max-val-batches 10
torchrun --standalone --nproc_per_node=2 3_train_i2imamba.py --config configs/config_i2imamba.yaml

Vanilla diffusion
python 3_train_vanilla_diffusion.py --config configs/config_unet3d_rcg.yaml
python 3_train_vanilla_diffusion.py --config configs/config_unet3d_rcg.yaml --max-train-batches 10 --max-val-batches 10
torchrun --standalone --nproc_per_node=2 3_train_vanilla_diffusion.py --config configs/config_unet3d_rcg.yaml

RCG (requires pretrained RDM checkpoint)
python 3_train_rcg_optimized_metrics_final.py --config configs/config_unet3d_rcg.yaml --rdm-ckpt /path/to/autoencoder.pt
python 3_train_rcg_optimized_metrics_final.py --config configs/config_unet3d_rcg.yaml --rdm-ckpt /path/to/autoencoder.pt --max-train-batches 10 --max-val-batches 10
torchrun --standalone --nproc_per_node=2 3_train_rcg_optimized_metrics_final.py --config configs/config_unet3d_rcg.yaml --rdm-ckpt /path/to/autoencoder.pt

All inference scripts expect --ckpt pointing to your trained checkpoint.
python 4_inference_ours.py --config configs/config_unet3d_rcg.yaml --ckpt /path/to/ckpt_best.pt
python 4_inference_unest.py --config configs/config_unest.yaml --ckpt /path/to/ckpt_best.pt
python 4_inference_medsyn.py --config configs/config_medsyn.yaml --ckpt /path/to/ckpt_best.pt
python 4_inference_i2imamba.py --config configs/config_i2imamba.yaml --ckpt /path/to/ckpt_best.pt
python 4_inference_cyclegan.py --config configs/config_cyclegan.yaml --ckpt /path/to/ckpt_best.pt
python 4_inference_medsyn_updated.py \
--config configs/config_medsyn.yaml \
--ckpt logs/medsyn_ep200_noaug/checkpoints/ckpt_best.pt \
--save-dir ./predictions_dump \
--save-limit 20 \
--save-every 1

python -m utils.check_nifti_integrity --max 0 --workers 16