
clisazo/free_prototype_brain_age


Learning age-conditioned brain atlases via free-prototype modeling for brain age prediction

A PyTorch Lightning framework for interpretable brain age (BA) estimation. Unlike traditional prototype models that select reference samples from the training data, this framework directly learns age-conditioned prototypes in latent space and decodes them into anatomically grounded brain atlases.

Overview

This project implements a prototype-based regression model that estimates biological brain age from 3D structural MRI. The key innovation is the use of freely learned prototypes: representative feature vectors, each corresponding to a specific age. These prototypes make predictions interpretable, because each age is predicted as a weighted average of the prototype ages.
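The weighted-average prediction can be illustrated with a minimal sketch. It assumes that the weights come from a softmax over negative squared Euclidean distances between a subject's latent vector and each prototype. The actual similarity function, temperature, and tensor shapes in models.py may differ. `predict_age` is a hypothetical helper written in plain Python for clarity, not the repository's code.

```python
import math
from typing import List, Sequence, Tuple


def predict_age(
    z: Sequence[float],
    prototypes: Sequence[Sequence[float]],
    prototype_ages: Sequence[float],
    temperature: float = 1.0,
) -> Tuple[float, List[float]]:
    """Hypothetical sketch: age as a similarity-weighted average of prototype ages.

    Assumption: similarity = -(squared Euclidean distance) / temperature,
    normalized with a softmax. The repository's models.py may use another rule.
    """
    # Squared Euclidean distance from the latent vector to each prototype
    dists = [sum((zi - pi) ** 2 for zi, pi in zip(z, p)) for p in prototypes]
    # Softmax over negative distances (subtract the max for numerical stability)
    logits = [-d / temperature for d in dists]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Predicted age = weighted average of the ages attached to each prototype
    age = sum(w * a for w, a in zip(weights, prototype_ages))
    return age, weights


if __name__ == "__main__":
    protos = [[0.0, 0.0], [10.0, 10.0]]  # toy 2-D latent prototypes
    ages = [20.0, 80.0]                  # age attached to each prototype
    age, w = predict_age([5.0, 5.0], protos, ages)
    print(age, w)  # equidistant latent vector, so both weights are 0.5
```

Because the weights are non-negative and sum to one, the prediction always lies between the youngest and oldest prototype ages, and inspecting the weights shows which age prototypes drove a given estimate.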

Requirements

Install the dependencies from the requirements file:

pip install -r requirements.txt

Project Structure

free_prototype_brain_age/
├── README.md                          # This file
├── config.py                          # Configuration/parameter definitions
├── train.py                           # Training script
├── inference.py                       # Inference script
├── models.py                          # Model architectures
├── loss_functions.py                  # Custom loss functions
├── data_module.py                     # PyTorch data module
├── push_prototypes.py                 # Prototype pushing logic
├── utils.py                           # Utility functions (preprocessing, augmentation)
└── data/
    └── dataset.csv                    # Sample dataset CSV

Data Preparation

Dataset Format

Your dataset should be organized in a CSV file with the following columns:

participant_id,session,split,age,T1aff_path,synthseg_path
P001,ses-01,train,45.2,/path/to/mri_t1.nii.gz,/path/to/segmentation.nii.gz
P002,ses-01,val,67.8,/path/to/mri_t1.nii.gz,/path/to/segmentation.nii.gz
...

Column Descriptions:

  • participant_id: Unique subject identifier
  • session: Session/timepoint identifier
  • split: Dataset split (train, val, test, or push; push is the subset used by the prototype-pushing step)
  • age: Target age label (years)
  • T1aff_path: Path to affine-registered T1-weighted MRI (NIfTI format)
  • synthseg_path: Path to the segmentation (from SynthSeg or a similar tool)
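Before training, it can help to check that a CSV matches this layout. The sketch below is a hypothetical helper (`load_dataset_rows` is not part of the repository) that uses only the Python standard library. It checks the required columns and split values listed above, parses ages as floats, and groups rows by split. It does not check that the NIfTI paths exist, since the example paths are placeholders.

```python
import csv
from collections import defaultdict
from typing import Dict, Iterable, List, Set

REQUIRED_COLUMNS = {"participant_id", "session", "split", "age",
                    "T1aff_path", "synthseg_path"}
VALID_SPLITS = {"train", "val", "test", "push"}


def load_dataset_rows(lines: Iterable[str],
                      allowed_splits: Set[str] = VALID_SPLITS) -> Dict[str, List[dict]]:
    """Hypothetical helper: validate the dataset CSV and group rows by split.

    `lines` can be an open file handle (opened with newline="") or a list of lines.
    """
    reader = csv.DictReader(lines)
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"missing columns: {sorted(missing)}")
    by_split: Dict[str, List[dict]] = defaultdict(list)
    for line_no, row in enumerate(reader, start=2):  # the header is line 1
        if row["split"] not in allowed_splits:
            raise ValueError(f"line {line_no}: unknown split {row['split']!r}")
        row["age"] = float(row["age"])  # raises ValueError on non-numeric ages
        by_split[row["split"]].append(row)
    return dict(by_split)
```

Typical use: `with open("data/dataset.csv", newline="") as f: rows = load_dataset_rows(f)`, then `len(rows["train"])` gives the number of training scans.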

Training

python train.py \
    --save_basepath /path/to/output \
    --experiment_name BrainAge_ProtoModel \
    --run_name Run1

Inference

Run Inference on Test Set

python inference.py \
    --csv_path /path/to/dataset.csv \
    --ckpt_path /path/to/best_model.ckpt \
    --save_basepath /path/to/output \
    --split_name test

Inference on Custom Data

python inference.py \
    --csv_path /path/to/custom_dataset.csv \
    --ckpt_path /path/to/trained_model.ckpt \
    --save_basepath /path/to/output \
    --split_name custom
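A custom CSV can reuse the same column layout. The sketch below assumes, without confirmation from this README, that inference.py selects the rows whose split column equals the value passed to --split_name. It also assumes that an age column is still expected, since the output reports groundtruth_age. The helper `build_custom_csv`, the subject IDs, and the paths in the usage note are hypothetical.

```python
import csv
import io
from typing import Iterable, Tuple

COLUMNS = ["participant_id", "session", "split", "age", "T1aff_path", "synthseg_path"]


def build_custom_csv(subjects: Iterable[Tuple[str, str, float, str, str]],
                     split_name: str = "custom") -> str:
    """Hypothetical helper: return CSV text in the dataset layout.

    Every row gets split == split_name. Assumption: inference.py filters rows
    by this value when --split_name is given.
    """
    buf = io.StringIO()
    writer = csv.writer(buf)
    writer.writerow(COLUMNS)
    for pid, session, age, t1_path, seg_path in subjects:
        writer.writerow([pid, session, split_name, age, t1_path, seg_path])
    return buf.getvalue()
```

For example, `build_custom_csv([("S001", "ses-01", 70.0, "/path/to/mri_t1.nii.gz", "/path/to/segmentation.nii.gz")])` produces text that can be written to /path/to/custom_dataset.csv.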

Output Format

Predictions are saved to results/test_predictions.csv:

participant_id,groundtruth_age,predicted_age,session,path
P001,68.5,67.2,ses-01,/path/to/image.nii.gz
P002,45.3,46.1,ses-01,/path/to/image.nii.gz
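To check a run, the predictions file can be summarized with the mean absolute error (MAE) and the brain-age gap (predicted minus ground-truth age). The sketch below is a hypothetical helper (`summarize_predictions` is not part of the repository) using only the Python standard library. It relies solely on the column names shown above.

```python
import csv
from typing import Dict, Iterable, List


def summarize_predictions(lines: Iterable[str]) -> Dict[str, object]:
    """Hypothetical helper: per-scan brain-age gap, mean gap, and MAE."""
    reader = csv.DictReader(lines)
    gaps: List[float] = []
    per_scan: Dict[str, float] = {}
    for row in reader:
        # Brain-age gap: positive means the model predicts an older brain
        gap = float(row["predicted_age"]) - float(row["groundtruth_age"])
        per_scan[f"{row['participant_id']}/{row['session']}"] = gap
        gaps.append(gap)
    if not gaps:
        raise ValueError("no prediction rows found")
    return {
        "n": len(gaps),
        "mae": sum(abs(g) for g in gaps) / len(gaps),
        "mean_gap": sum(gaps) / len(gaps),  # a nonzero mean suggests systematic bias
        "gaps": per_scan,
    }
```

Typical use: `with open("results/test_predictions.csv", newline="") as f: print(summarize_predictions(f))`. On the two example rows above, the MAE is about 1.05 years.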

Contact

For questions or issues, please contact: clara.lisazo@udg.edu


Last Updated: March 2026
