
Brain-Deform Library

This package provides elastic deformation methods based on non-linear image registration. Everything runs on the GPU, so it can be used in near real-time with your PyTorch Lightning model. The setup aims to be as minimal as possible: you only need our custom BrainDataModule, which is based on pytorch-lightning's LightningDataModule class.

At the moment, only FSL-style FNIRT warp coefficient files are supported.

You can

  • apply linear/nonlinear transformations to MNI space
  • apply affine transformations for data augmentation
  • apply partial subject-to-subject warps for data augmentation
  • apply partial subject-to-subject intensity transformation for data augmentation
  • implement contrastive/semisupervised learning

A minimal pytorch-lightning example and an example of subject-to-subject cross-registration can be found in the example folder.

Installation

First install pytorch and pytorch_lightning, then install brain-deform via

pip -v install git+https://github.com/brain-tools/brain-deform.git

or, for development, clone and install in editable mode

git clone https://github.com/brain-tools/brain-deform.git
pip -v install -e brain-deform/

The CUDA toolkit is required to compile and a GPU is required to run; no CPU implementation is available (yet).
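
Since there is no CPU fallback, it can be worth verifying your environment before training. A minimal check, using only standard PyTorch calls:

import torch

# brain-deform has no CPU implementation, so a CUDA device must be visible
assert torch.cuda.is_available(), "No CUDA device found"
print("device:", torch.cuda.get_device_name(0))
print("CUDA runtime:", torch.version.cuda)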

Usage

Step 1: Define your model

You can easily add the library to your existing project by using the provided data module BrainDataModule. It takes care of all the heavy lifting, so that your training_step receives an already-augmented batch.

import torch
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from brain_deform.lightning import BrainDataModule

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
    
    # ...

    def training_step(self, batch, batch_idx):
        # Batch contains the registered image, the registered and augmented image, and the label
        (x, x_augmented, y), _ = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss)
        return loss
    
    # ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

if __name__ == "__main__":
    cli = LightningCLI(LitModel, BrainDataModule)

Step 2: Write a configuration file

To configure what the data module does under the hood, you need to provide a configuration file (see the example below). The config file contains the usual pytorch-lightning configuration plus our data module configuration under the data section.

data:
    # data_table and split file paths
    data_table_path: data.csv
    split_path: split.json

    # column names in data_table file
    index_column: eid
    t1_column: t1
    coefs_to_mni_column: coefs_to_mni
    target_column: target

    # provide additional image from auxiliary set
    auxiliary_image: False

    # provide images in linear or nonlinear MNI space
    registration: "linear" 

    # affine transformations for data augmentation
    translation: 0 # max voxels
    rotation: 0 # max degrees
    scale: 0 # max growth/shrink factor (0-1)
    flip: False # flip hemispheres

    # advanced augmentations
    cross_warp: 0 # max subject-to-subject warp factor (0-1)
    random_warp: 0 # max random warp factor (0-1)
    cross_intensity: 0 # max subject-to-subject intensity transformation factor (0-1)

# ...

The arguments data_table_path and split_path define which data you plan on using and how it will be split into training, validation and test sets. The table referred to by data_table_path is a CSV file containing at least an index (EID), a T1 FOV image, a B-spline coefficient file, and a target label. The split referred to by split_path is a JSON file containing arrays of EIDs for the corresponding sets:

  • train, val, test: the usual training, validation and test sets
  • augmentation: targets for subject-to-subject augmentation methods
  • auxiliary: the optional second "auxiliary" image used for contrastive/semisupervised learning
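
To make the expected inputs concrete, here is a small sketch that writes a matching data.csv and split.json. The column names follow the config above; the file paths and EIDs are purely hypothetical:

import json
import pandas as pd

eids = [1001, 1002, 1003, 1004]

# Hypothetical data table; the columns correspond to index_column,
# t1_column, coefs_to_mni_column and target_column in the config
pd.DataFrame({
    "eid": eids,
    "t1": [f"images/{e}_t1.nii.gz" for e in eids],
    "coefs_to_mni": [f"warps/{e}_coefs.nii.gz" for e in eids],
    "target": [0, 1, 0, 1],
}).to_csv("data.csv", index=False)

# Hypothetical split; each key maps to an array of EIDs
split = {
    "train": [1001, 1002],
    "val": [1003],
    "test": [1004],
    "augmentation": [1001, 1002],  # subject-to-subject targets
    "auxiliary": [1001, 1002],     # optional second image
}
with open("split.json", "w") as f:
    json.dump(split, f, indent=2)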

For each batch, the module returns (batch_main, batch_aux), with each batch consisting of (registered_image, augmented_registered_image, label). If auxiliary_image=False, batch_aux is simply (None, None, None).
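
If auxiliary_image is set to True, the second tuple carries the auxiliary view, which you can use for contrastive/semisupervised objectives. Here is a sketch of a contrastive-style training_step for the LitModel above; self.encoder and self.contrastive_loss are hypothetical placeholders, not part of brain-deform's API:

    def training_step(self, batch, batch_idx):
        # batch = (batch_main, batch_aux); each tuple is
        # (registered_image, augmented_registered_image, label)
        (x, x_aug, y), (x_aux, x_aux_aug, y_aux) = batch

        # Hypothetical contrastive objective between the two augmented views
        z1 = self.encoder(x_aug)
        z2 = self.encoder(x_aux_aug)
        loss = self.contrastive_loss(z1, z2)
        self.log("train_loss", loss)
        return loss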

For more information on how to use the LightningCLI, see the pytorch-lightning documentation. For a more concrete minimal working example, see the example folder.

Step 3: Train

First, confirm visually that the images look fine: the first column shows the registered_image, followed by the augmented_registered_image for each modality.

python tools/plot_batch.py -c config.yaml -o batch.png

That's it! Now you can train your model.

python main.py fit -c config.yaml

Running tests (WIP)

Tests rely on external test data. The following scripts download the HCP Unrelated 100 subset and the NFBS Skull-Stripped Repository.

cd testdata
./make_hcp.sh
./make_nfbs.sh
cd ..
pytest tests
