This repository contains code to preprocess, train, and evaluate a 3D U-Net model on the BraTS (Brain Tumor Segmentation) dataset. The main script, 3dunet.py, implements data loading, preprocessing (cropping, normalization), model definition, training loop with Dice loss, validation, learning rate scheduling, checkpointing, and uploading to Hugging Face.
- Python 3.10 or higher
- NVIDIA GPU with CUDA (for training)
Install dependencies via pip:

```bash
pip install -r requirements.txt
```

`requirements.txt` should include:

```
torch>=1.8.0
numpy
pandas
scikit-learn
nibabel
opencv-python
matplotlib
tqdm
huggingface_hub
kagglehub
kaggle-secrets
tensorboard
```
- Kaggle authentication: ensure you have a Kaggle API token (`kaggle.json`) in `~/.kaggle/`, or set credentials via `kaggle-secrets`:

  ```bash
  export KAGGLE_USERNAME=your_username
  export KAGGLE_KEY=your_key
  ```

- Hugging Face: to upload checkpoints, set your token:

  ```bash
  export HUGGING_FACE_TOKEN=hf_xxx
  ```
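Before launching a long training run, it can help to verify that these credentials are actually visible to the process. Below is a minimal check; it is not part of the repository, and the file location and variable names simply follow the conventions described above:

```python
import os
from pathlib import Path


def check_credentials():
    """Report which credentials are visible to the current process.

    Returns a dict of booleans keyed by credential source. The paths and
    environment variable names follow the setup instructions above.
    """
    kaggle_json = Path.home() / ".kaggle" / "kaggle.json"
    return {
        "kaggle_file": kaggle_json.exists(),
        "kaggle_env": bool(
            os.environ.get("KAGGLE_USERNAME") and os.environ.get("KAGGLE_KEY")
        ),
        "hf_token": bool(os.environ.get("HUGGING_FACE_TOKEN")),
    }


if __name__ == "__main__":
    for name, ok in check_credentials().items():
        print(f"{name}: {'OK' if ok else 'missing'}")
```

Either `kaggle_file` or `kaggle_env` must report `OK` for the dataset download to work.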
Open `3dunet.py` in your editor or notebook environment and adjust the configuration variables at the top of the file:

```python
data_dir = 'data/BraTS20_Training'
batch_size = 2
num_epochs = 300
slice_range = (40, 190)
modalities = ['flair', 't1', 't1ce', 't2']
learning_rate = 4e-5
weight_decay = 5e-6
use_amp = True  # Mixed-precision training
checkpoint_dir = 'models/'
```

Then run:

```bash
python 3dunet.py
```

This will:
- Download the BraTS dataset from Kaggle (via `kagglehub`).
- Preprocess each patient scan (crop, z-score normalization).
- Create PyTorch `Dataset` and `DataLoader` objects.
- Train a 3D U-Net with a combined Dice + cross-entropy loss.
- Validate on a hold-out set and apply early stopping.
- Save model checkpoints every 5 epochs and upload them to Hugging Face.
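The early-stopping step above can be sketched as a simple patience counter over the validation loss. This is an illustrative sketch, not the script's exact class:

```python
class EarlyStopping:
    """Stop training after `patience` epochs without improvement.

    Illustrative sketch of the behavior described in the README; the
    repository's actual implementation may differ in details.
    """

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        """Call once per epoch; returns True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss   # improvement: record it and reset the counter
            self.counter = 0
        else:
            self.counter += 1      # no improvement this epoch
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop
```

With `patience=10`, training continues until ten consecutive epochs fail to improve on the best validation loss seen so far.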
The script defines the `BrainTumorDataset` class:

```python
class BrainTumorDataset(Dataset):
    def __init__(self, data_dir, modalities, slice_range, is_train=True):
        # data_dir: directory with patient subfolders
        # modalities: list such as ['flair', 't1', 't1ce', 't2']
        # slice_range: tuple (start, end) along the axial axis
        # is_train: whether to load segmentation masks
```

- Loading: reads `.nii` files via `nibabel`.
- Cropping: selects slices in `slice_range` along the z-axis.
- Normalization: z-score per scan.
- Mask processing: converts raw segmentation masks into three channels (NCR/NET, ED, ET).
- Visualization: `visualize_sample(idx)` plots each modality and the overlaid masks.
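The normalization and mask-processing steps can be sketched as follows. This assumes the standard BraTS label convention (0 = background, 1 = NCR/NET, 2 = ED, 4 = ET); the dataset class operates on full 3D numpy arrays, but plain lists are used here to keep the example dependency-free:

```python
def split_mask_channels(mask):
    """Convert a flat list of raw BraTS labels into three binary channels.

    Assumes labels 1 = NCR/NET, 2 = ED, 4 = ET (standard BraTS convention).
    """
    ncr_net = [1 if v == 1 else 0 for v in mask]
    ed = [1 if v == 2 else 0 for v in mask]
    et = [1 if v == 4 else 0 for v in mask]
    return ncr_net, ed, et


def zscore(values):
    """Per-scan z-score normalization: subtract the mean, divide by the std."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    std = var ** 0.5 or 1.0  # guard against an all-constant scan
    return [(v - mean) / std for v in values]
```

In the real pipeline both operations are vectorized over `(H, W, D)` volumes, but the per-voxel logic is the same.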
Example:

```python
from torch.utils.data import DataLoader
from preprocessing import BrainTumorDataset

dataset = BrainTumorDataset(
    data_dir='data/BraTS20_Training',
    modalities=['flair', 't1', 't1ce', 't2'],
    slice_range=(40, 190),
    is_train=True,
)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

for imgs, masks in loader:
    # imgs:  (batch, 4, H, W, D)
    # masks: (batch, 3, H, W, D)
    break
```

Training components:

- Model: `ImprovedUNet3D`, defined at the top of the script.
- Loss: `CombinedLoss` (Dice + cross-entropy).
- Optimizer: AdamW with a `ReduceLROnPlateau` scheduler.
- Metrics: Dice coefficient per class.
- Early stopping: stops after 10 epochs without improvement.
- Checkpointing: saves every 5 epochs to `models/` and uploads to Hugging Face.
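For reference, the per-class Dice metric listed above can be written as a few lines over flat binary masks. This is an illustrative sketch, not the repository's tensor implementation:

```python
def dice_coefficient(pred, target, eps=1e-6):
    """Dice coefficient for one class over flat binary lists.

    dice = 2 * |pred ∩ target| / (|pred| + |target|)

    `eps` avoids division by zero when both masks are empty. The script
    computes the same quantity per class on GPU tensors.
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return (2.0 * intersection + eps) / (total + eps)
```

The Dice *loss* used inside `CombinedLoss` is conventionally `1 - dice`, computed on soft (post-sigmoid or post-softmax) predictions rather than hard binary masks.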
To visualize training metrics:

```bash
tensorboard --logdir logs/
```

This project is licensed under the MIT License. See LICENSE for details.
- BraTS Challenge for the dataset.
- U-Net 3D implementations.
- Kaggle forum and PyTorch community for examples and support.