This package provides elastic deformation methods based on non-linear image registration. Everything runs on the GPU, so deformations can be applied on the fly, in near real time, while training your PyTorch Lightning model. The setup aims to be as minimal as possible: you only need to use our custom BrainDataModule, which is based on PyTorch Lightning's LightningDataModule class.
At the moment, only FSL-style FNIRT warp coefficient files are supported.
You can:
- apply linear or nonlinear transformations to MNI space
- apply affine transformations for data augmentation
- apply partial subject-to-subject warps for data augmentation
- apply partial subject-to-subject intensity transformations for data augmentation
- implement contrastive/semi-supervised learning
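To make the idea of a *partial* warp concrete, here is a minimal pure-Python sketch in 1-D. It assumes (this is not taken from brain-deform's code) that a partial warp resamples the image along a displacement field scaled by a random factor alpha drawn from [0, max_factor], so alpha = 0 leaves the image unchanged and alpha = 1 applies the full subject-to-subject warp. The names warp_1d and random_partial_warp are hypothetical; the package itself works on 3-D volumes on the GPU.

```python
import random


def warp_1d(signal, displacement, alpha):
    """Backward warp: out[x] = signal[x + alpha * displacement[x]], linearly interpolated."""
    n = len(signal)
    out = []
    for x in range(n):
        # Scale the displacement by alpha; alpha = 0 is the identity, alpha = 1 the full warp.
        pos = x + alpha * displacement[x]
        # Clamp to the valid range, i.e. replicate the border values.
        pos = min(max(pos, 0.0), n - 1.0)
        i0 = int(pos)
        i1 = min(i0 + 1, n - 1)
        t = pos - i0
        out.append((1.0 - t) * signal[i0] + t * signal[i1])
    return out


def random_partial_warp(signal, displacement, max_factor, rng=random):
    """Apply the warp with a random strength alpha drawn uniformly from [0, max_factor]."""
    alpha = rng.uniform(0.0, max_factor)
    return warp_1d(signal, displacement, alpha), alpha
```

For example, warp_1d([0, 1, 2, 3], [1, 1, 1, 1], 0.5) returns [0.5, 1.5, 2.5, 3.0]: every sample is read from half a voxel further along the displacement.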
A minimal PyTorch Lightning example is given further below.
(Figure: example of subject-to-subject cross-registration.)
First install PyTorch and PyTorch Lightning, then install brain-deform via

```bash
pip -v install git+https://github.com/brain-tools/brain-deform.git
```

or, for development, clone the repository and install it in editable mode:

```bash
git clone https://github.com/brain-tools/brain-deform.git
pip -v install -e brain-deform/
```

A CUDA toolkit is required to compile the package, and a GPU is required to run it; no CPU implementation is available (yet).
You can easily add the library to your existing project by using the provided data module BrainDataModule. It will take care of all the heavy lifting, so that your training_step is provided with an already augmented batch.
```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI

from brain_deform.lightning import BrainDataModule


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # ... define your network and self.criterion (the loss function) here

    def training_step(self, batch, batch_idx):
        # The batch is (batch_main, batch_aux); batch_main contains the registered image,
        # the registered and augmented image, and the label
        (x, x_augmented, y), _ = batch
        # x_augmented holds the augmented view; use it instead of x to train on augmented data
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss)
        return loss

    # ...

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


cli = LightningCLI(LitModel, BrainDataModule)
```

To configure what the data module does under the hood, you need to provide a configuration file. The config file contains the usual PyTorch Lightning configuration plus our data module configuration under the data section, for example:
```yaml
data:
  # data_table and split file paths
  data_table_path: data.csv
  split_path: split.json
  # column names in the data_table file
  index_column: eid
  t1_column: t1
  coefs_to_mni_column: coefs_to_mni
  target_column: target
  # provide an additional image from the auxiliary set
  auxiliary_image: False
  # provide images in linear or nonlinear MNI space
  registration: "linear"
  # affine transformations for data augmentation
  translation: 0  # max voxels
  rotation: 0  # max degrees
  scale: 0  # max growth/shrink factor (0-1)
  flip: False  # flip hemispheres
  # advanced augmentations
  cross_warp: 0  # max subject-to-subject warp factor (0-1)
  random_warp: 0  # max random warp factor (0-1)
  cross_intensity: 0  # max subject-to-subject intensity transformation factor (0-1)
  # ...
```

The arguments data_table_path and split_path define which data you plan to use and how it is split into training, validation and test sets.
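The table referred to by data_table_path is a CSV file with at least four columns, whose names are set via index_column, t1_column, coefs_to_mni_column and target_column: an index (EID), a T1 FOV image, a B-spline warp coefficient file (to MNI space), and a target label.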
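A minimal sketch of such a table, written and checked with Python's standard csv module. The column names match the example configuration above; the EIDs, file paths and target values are hypothetical placeholders, and storing file paths in the t1 and coefs_to_mni columns is an assumption.

```python
import csv
import io

COLUMNS = ("eid", "t1", "coefs_to_mni", "target")

# Hypothetical rows: EIDs, paths and targets are placeholders, not real data.
EXAMPLE_ROWS = [
    {"eid": "1000001", "t1": "/data/1000001/T1.nii.gz",
     "coefs_to_mni": "/data/1000001/T1_to_MNI_warp_coef.nii.gz", "target": "63.5"},
    {"eid": "1000002", "t1": "/data/1000002/T1.nii.gz",
     "coefs_to_mni": "/data/1000002/T1_to_MNI_warp_coef.nii.gz", "target": "58.0"},
]


def write_data_table(rows, fileobj, columns=COLUMNS):
    """Write the rows as CSV, preceded by a header line."""
    writer = csv.DictWriter(fileobj, fieldnames=list(columns))
    writer.writeheader()
    writer.writerows(rows)


def read_data_table(fileobj, required=COLUMNS):
    """Read the table back and fail early if a required column is missing."""
    reader = csv.DictReader(fileobj)
    missing = [c for c in required if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"data table is missing columns: {missing}")
    return list(reader)
```

To write the file to disk, open it with newline="" as the csv module recommends, e.g. `with open("data.csv", "w", newline="") as f: write_data_table(EXAMPLE_ROWS, f)`.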
The split referred to by split_path is a JSON file containing arrays of EIDs for each set: train, val and test, plus augmentation (the target subjects for the subject-to-subject augmentation methods) and auxiliary (the subjects providing the optional second, "auxiliary" image used for contrastive/semi-supervised learning).
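A split file with this layout could be generated as follows. This is a sketch under assumptions: the keys train, val, test, augmentation and auxiliary follow the description above, while the split fractions, the choice to reuse the training EIDs as augmentation targets (so that validation and test subjects never enter training), and the empty auxiliary set are illustrative. make_split is a hypothetical helper, not part of brain-deform.

```python
import json
import random


def make_split(eids, val_frac=0.1, test_frac=0.1, seed=0):
    """Randomly partition EIDs into train/val/test sets (hypothetical helper)."""
    eids = sorted(eids)  # sort first so the result depends only on the seed
    random.Random(seed).shuffle(eids)
    n_val = int(len(eids) * val_frac)
    n_test = int(len(eids) * test_frac)
    val = eids[:n_val]
    test = eids[n_val:n_val + n_test]
    train = eids[n_val + n_test:]
    return {
        "train": train,
        "val": val,
        "test": test,
        # Assumption: subject-to-subject augmentation targets come from train only.
        "augmentation": list(train),
        # Empty because auxiliary_image is False in the example configuration.
        "auxiliary": [],
    }
```

The result can then be saved with `json.dump(make_split(eids), f, indent=2)`.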
For each batch, the module returns (batch_main, batch_aux), where each part consists of (registered_image, augmented_registered_image, label). If auxiliary_image is False, batch_aux is simply (None, None, None).
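For contrastive or semi-supervised setups it can help to unpack both parts explicitly. The helper below is a hypothetical sketch, not part of brain-deform; it relies only on the (batch_main, batch_aux) structure described above and returns None for the auxiliary part when it holds only None values.

```python
from typing import Any, Optional, Tuple

Triple = Tuple[Any, Any, Any]


def split_batch(batch: Tuple[Triple, Triple]) -> Tuple[Triple, Optional[Triple]]:
    """Return (batch_main, batch_aux), with batch_aux replaced by None if unused."""
    batch_main, batch_aux = batch
    # With auxiliary_image=False the data module yields (None, None, None).
    if all(item is None for item in batch_aux):
        return tuple(batch_main), None
    return tuple(batch_main), tuple(batch_aux)
```

Inside training_step this could be used as `(x, x_augmented, y), aux = split_batch(batch)`, feeding the auxiliary images to a contrastive loss only when aux is not None.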
See the PyTorch Lightning documentation for more information on how to use the LightningCLI. For a more concrete minimal working example, see the example folder.
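To illustrate what the affine options in the data section control, here is a pure-Python sketch that samples one random augmentation as a 4x4 homogeneous matrix. It assumes an interpretation of the config values that is not taken from brain-deform's code: translation is a per-axis offset drawn from [-translation, translation] voxels, rotation draws one angle per axis from [-rotation, rotation] degrees, scale draws an isotropic factor from [1 - scale, 1 + scale], and flip mirrors the first (assumed left-right) axis with probability 0.5. It also rotates about the coordinate origin, whereas a real pipeline would usually rotate about the volume center. sample_affine is a hypothetical name.

```python
import math
import random


def matmul(a, b):
    """Multiply two 4x4 matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(4)) for j in range(4)] for i in range(4)]


def rotation_matrix(ax, ay, az):
    """Rotate about x, then y, then z (angles in radians), as a 4x4 matrix."""
    cx, sx = math.cos(ax), math.sin(ax)
    cy, sy = math.cos(ay), math.sin(ay)
    cz, sz = math.cos(az), math.sin(az)
    rx = [[1, 0, 0, 0], [0, cx, -sx, 0], [0, sx, cx, 0], [0, 0, 0, 1]]
    ry = [[cy, 0, sy, 0], [0, 1, 0, 0], [-sy, 0, cy, 0], [0, 0, 0, 1]]
    rz = [[cz, -sz, 0, 0], [sz, cz, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
    return matmul(rz, matmul(ry, rx))


def sample_affine(translation=0.0, rotation=0.0, scale=0.0, flip=False, rng=random):
    """Sample a random affine x -> T(R(S x)) under the assumptions stated above."""
    angles = [math.radians(rng.uniform(-rotation, rotation)) for _ in range(3)]
    s = 1.0 + rng.uniform(-scale, scale)  # isotropic growth/shrink factor
    flip_x = -1.0 if (flip and rng.random() < 0.5) else 1.0  # mirror the first axis
    t = [rng.uniform(-translation, translation) for _ in range(3)]  # offsets in voxels
    scale_flip = [[s * flip_x, 0, 0, 0], [0, s, 0, 0], [0, 0, s, 0], [0, 0, 0, 1]]
    m = matmul(rotation_matrix(*angles), scale_flip)
    for i in range(3):
        m[i][3] = t[i]  # the translation is applied last
    return m
```

With all options at their defaults (as in the example configuration), the sampled matrix is the identity, i.e. no affine augmentation is applied.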
First, visually confirm that the images look fine. The first column shows the registered_image, followed by the augmented_registered_image for each modality:

```bash
python tools/plot_batch.py -c config.yaml -o batch.png
```

That's it! Now you can train your model:

```bash
python main.py fit -c config.yaml
```

The tests rely on external test data. The following downloads the HCP Unrelated 100 subset and the NFBS skull-stripped repository:

```bash
cd testdata
./make_hcp.sh
./make_nfbs.sh
cd ..
```

Then run the tests:

```bash
pytest tests
```