18 changes: 18 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
.DS_Store

__pycache__/
dist/
*.egg-info/

datasets/eeg_train/
datasets/eeg_test/
datasets/*.pth
datasets/**/*.JPEG
exps/
exp/
results/
wandb/

pretrains/models/v1-5-pruned.ckpt
pretrains/eeg-pretrain/*.pth
pretrains/generation/*.pth
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "code/TimeMAE"]
path = code/TimeMAE
url = git@github.com:Mingyue-Cheng/TimeMAE.git
29 changes: 19 additions & 10 deletions README.md
@@ -10,15 +10,14 @@ This document introduces the procedures required for replicating the results i
## Abstract
This paper introduces DreamDiffusion, a novel method for generating high-quality images directly from brain electroencephalogram (EEG) signals, without the need to translate thoughts into text. DreamDiffusion leverages pre-trained text-to-image models and employs temporal masked signal modeling to pre-train the EEG encoder for effective and robust EEG representations. Additionally, the method further leverages the CLIP image encoder to provide extra supervision to better align EEG, text, and image embeddings with limited EEG-image pairs. Overall, the proposed method overcomes the challenges of using EEG signals for image generation, such as noise, limited information, and individual differences, and achieves promising results. Quantitative and qualitative results demonstrate the effectiveness of the proposed method as a significant step towards portable and low-cost "thoughts-to-image", with potential applications in neuroscience and computer vision.


## Overview
![pipeline](assets/eeg_pipeline.png)


The **datasets** folder and **pretrains** folder are not included in this repository.
The **datasets** folder is not included in this repository, and neither are the pretrained model checkpoints.
Please download them from [eeg](https://github.com/perceivelab/eeg_visual_classification) and put them in the root directory of this repository as shown below. We also provide a copy of the ImageNet subset: [imagenet](https://drive.google.com/file/d/1y7I9bG1zKYqBM94odcox_eQjnP9HGo9-/view?usp=drive_link).

For Stable Diffusion, we just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main).
For Stable Diffusion, we use the standard SD 1.5 checkpoint. Download the file ["v1-5-pruned.ckpt"](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) from the [official repository page](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5).

File path | Description
```
@@ -29,10 +28,10 @@ File path | Description
┃ ┗ 📜 v1-5-pruned.ckpt

┣ 📂 generation
┃ ┗ 📜 checkpoint_best.pth
┃ ┗ 📜 checkpoint.pth

┣ 📂 eeg_pretain
┃ ┗ 📜 checkpoint.pth (pre-trained EEG encoder)
┃ ┗ 📜 checkpoint-eeg-500.pth (pre-trained EEG encoder)

/datasets
┣ 📂 imageNet_images (subset of Imagenet)
@@ -74,16 +73,24 @@ conda env create -f env.yaml
conda activate dreamdiffusion
```

## Download checkpoints
## Download checkpoint after pre-training on EEG data

We also provide checkpoints to run the finetuning and decoding directly.
No checkpoint is provided here.

To make the code work (albeit poorly), use a checkpoint obtained after pre-training the EEG autoencoder, available in [this repository](https://github.com/alinvdu/reproduce-dream-diffusion/). They propose using only the EEG signals from the (EEG, image) pairs, but this yields poor results because so little data is available.

Then put it at ```pretrains/eeg-pretrain/checkpoint-eeg-500.pth```.
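As a quick sanity check that everything landed in the right place, a small helper can list what is still missing. This is a sketch, not part of the repository: the file names are assumptions taken from the tree above, and note that the repository mixes `eeg-pretrain` and `eeg_pretrain` in different places.

```python
from pathlib import Path

# Files the setup above expects; names are assumptions taken from this README.
EXPECTED = [
    "pretrains/models/v1-5-pruned.ckpt",
    "pretrains/eeg-pretrain/checkpoint-eeg-500.pth",
    "datasets/eeg_5_95_std.pth",
    "datasets/block_splits_by_image_single.pth",
]

def missing_files(root="."):
    """Return the expected files that are absent under `root`."""
    base = Path(root)
    return [p for p in EXPECTED if not (base / p).is_file()]

if __name__ == "__main__":
    for p in missing_files():
        print(f"missing: {p}")
```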

## Pre-training on EEG data

Remember to move the generated checkpoints from the ```results``` directory to the
appropriate location in the ```pretrains``` directory.
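A minimal sketch of that move (a hypothetical helper, not part of the repository; the destination path matches the `pretrain_mbm_path` default in `code/config.py`):

```python
import shutil
from pathlib import Path

def promote_latest_checkpoint(results_dir="results/eeg_pretrain",
                              dest="pretrains/eeg_pretrain/checkpoint.pth"):
    """Copy the most recently written pre-training checkpoint into the pretrains tree."""
    candidates = sorted(Path(results_dir).rglob("checkpoint*.pth"),
                        key=lambda p: p.stat().st_mtime)
    if not candidates:
        raise FileNotFoundError(f"no checkpoints found under {results_dir}")
    dest = Path(dest)
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(candidates[-1], dest)
    return dest
```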

You can download datasets for pre-training from [MOABB](https://github.com/NeuroTechX/moabb).

The datasets used to obtain the paper's results aren't specified, but you can use
any EEG dataset: place the recordings in ```datasets/eeg_train/``` and ```datasets/eeg_test/``` as `.npy` files.
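The pre-training loader in `code/dataset.py` reads every `.npy` file under the directory, concatenates them along the first axis, and later pads or crops each sample to 128 channels by 512 time points. A hedged sketch of exporting recordings in that layout (the `(n_samples, channels, time)` shape is an assumption based on that loader; `export_recordings` is a hypothetical helper):

```python
import numpy as np
from pathlib import Path

def export_recordings(recordings, out_dir, name="session"):
    """Save a batch of EEG recordings as one .npy file the pre-training loader can read.

    `recordings` is assumed to be shaped (n_samples, channels, time); the loader
    concatenates all files along the first axis, then pads/crops each sample to
    128 channels x 512 time points.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    path = out / f"{name}.npy"
    np.save(path, np.asarray(recordings, dtype=np.float32))
    return path

if __name__ == "__main__":
    # Example: two dummy 128-channel recordings of 512 time points each.
    export_recordings(np.random.randn(2, 128, 512), "datasets/eeg_train")
```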

To perform the pre-training from scratch with default parameters, run
```sh
python3 code/stageA1_eeg_pretrain.py
@@ -101,18 +108,20 @@ Multiple-GPU (DDP) training is supported; run it with
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_eeg_pretrain.py
```

## Finetune the Stable Diffusion with Pre-trained EEG Encoder

Remember to move the generated checkpoints from the ```results``` directory to the
appropriate location in the ```pretrains``` directory.

## Finetune the Stable Diffusion with Pre-trained EEG Encoder
In this stage, the cross-attention heads and pre-trained EEG encoder will be jointly optimized with EEG-image pairs.

```sh
python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ../dreamdiffusion/pretrains/eeg_pretrain/checkpoint.pth
python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ./pretrains/eeg_pretrain/checkpoint-eeg-500.pth
```


## Generating Images with Trained Checkpoints
Run this stage with the provided checkpoint [ckpt](https://drive.google.com/file/d/1Ygplxe1TB68-aYu082bjc89nD8Ngklnc/view?usp=drive_link), which you may want to try.

```sh
python3 code/gen_eval_eeg.py --dataset EEG --model_path pretrains/generation/checkpoint.pth
```
1 change: 1 addition & 0 deletions code/TimeMAE
Submodule TimeMAE added at f4401b
35 changes: 17 additions & 18 deletions code/config.py
@@ -16,25 +16,24 @@ def __init__(self):
self.min_lr = 0.
self.weight_decay = 0.05
self.num_epoch = 500
self.warmup_epochs = 40
self.batch_size = 100
self.warmup_epochs = 10
self.batch_size = 128
self.clip_grad = 0.8

# Model Parameters
self.mask_ratio = 0.1
self.mask_ratio = 0.5
self.patch_size = 4 # 1
self.embed_dim = 1024 #256 # has to be a multiple of num_heads
self.decoder_embed_dim = 512 #128
self.depth = 24
self.num_heads = 16
self.decoder_num_heads = 16
self.mlp_ratio = 1.0
self.mlp_ratio = 0.8

# Project setting
self.root_path = '../dreamdiffusion/'
self.output_path = '../dreamdiffusion/exps/'
self.root_path = './'
self.output_path = './exps/'
self.seed = 2022
self.roi = 'VC'
self.aug_times = 1
self.num_sub_limit = None
self.include_hcp = True
@@ -54,15 +53,15 @@ class Config_EEG_finetune(Config_MBM_finetune):
def __init__(self):

# Project setting
self.root_path = '../dreamdiffusion/'
self.root_path = './'
# self.root_path = '.'
self.output_path = '../dreamdiffusion/exps/'
self.output_path = './exps/'

self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth')
self.splits_path = os.path.join(self.root_path, 'datasets/block_splits_by_image_all.pth')

self.dataset = 'EEG'
self.pretrain_mbm_path = '../dreamdiffusion/pretrains/eeg_pretrain/checkpoint.pth'
self.pretrain_mbm_path = './pretrains/eeg_pretrain/checkpoint.pth'

self.include_nonavg_test = True

@@ -71,7 +70,7 @@ def __init__(self):
self.lr = 5.3e-5
self.weight_decay = 0.05
self.num_epoch = 15
self.batch_size = 16 if self.dataset == 'GOD' else 4
self.batch_size = 32 if self.dataset == 'GOD' else 16
self.mask_ratio = 0.5
self.accum_iter = 1
self.clip_grad = 0.8
@@ -89,8 +88,8 @@ class Config_Generative_Model:
def __init__(self):
# project parameters
self.seed = 2022
self.root_path = '../dreamdiffusion/'
self.output_path = '../dreamdiffusion/exps/'
self.root_path = './'
self.output_path = './exps/'

self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth')
self.splits_path = os.path.join(self.root_path, 'datasets/block_splits_by_image_single.pth')
@@ -111,7 +110,7 @@ def __init__(self):

np.random.seed(self.seed)
# finetune parameters
self.batch_size = 5 if self.dataset == 'GOD' else 25
self.batch_size = 32 if self.dataset == 'GOD' else 16
self.lr = 5.3e-5
self.num_epoch = 500

@@ -139,8 +138,8 @@ class Config_Cls_Model:
def __init__(self):
# project parameters
self.seed = 2022
self.root_path = '../dreamdiffusion/'
self.output_path = '../dreamdiffusion/exps/'
self.root_path = './'
self.output_path = './exps/'

# self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth')
self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_14_70_std.pth')
@@ -162,7 +161,7 @@ def __init__(self):

np.random.seed(self.seed)
# finetune parameters
self.batch_size = 5 if self.dataset == 'GOD' else 25
self.batch_size = 32 if self.dataset == 'GOD' else 16
self.lr = 5.3e-5
self.num_epoch = 50

@@ -181,4 +180,4 @@ def __init__(self):
self.HW = None
# resume check util
self.model_meta = None
self.checkpoint_path = None
self.checkpoint_path = None
33 changes: 17 additions & 16 deletions code/dataset.py
@@ -106,24 +106,25 @@ def is_npy_ext(fname: Union[str, Path]) -> bool:
return f'{ext}' == 'npy'# type: ignore

class eeg_pretrain_dataset(Dataset):
def __init__(self, path='../dreamdiffusion/datasets/mne_data/', roi='VC', patch_size=16, transform=identity, aug_times=2,
num_sub_limit=None, include_kam=False, include_hcp=True):
def __init__(self, path):
super(eeg_pretrain_dataset, self).__init__()
data = []
images = []
self.input_paths = [str(f) for f in sorted(Path(path).rglob('*')) if is_npy_ext(f) and os.path.isfile(f)]

assert len(self.input_paths) != 0, 'No data found'

self.inputs = [np.load(data_path) for data_path in self.input_paths]
self.real_input = np.concatenate(self.inputs, axis=0)

self.data_len = 512
self.data_chan = 128

def __len__(self):
return len(self.input_paths)
return len(self.real_input)

def __getitem__(self, index):
data_path = self.input_paths[index]

data = np.load(data_path)
data = self.real_input[index]

if data.shape[-1] > self.data_len:
idx = np.random.randint(0, int(data.shape[-1] - self.data_len)+1)
@@ -240,7 +241,7 @@ class EEGDataset_r(Dataset):
# Constructor
def __init__(self, eeg_signals_path, image_transform=identity):

self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/'
self.imagenet = './datasets/imageNet_2012/train/'
self.image_transform = image_transform
self.num_voxels = 440
self.data_len = 512
@@ -276,7 +277,7 @@ def __init__(self, eeg_signals_path, image_transform=identity):
self.data = loaded['dataset']
self.labels = loaded["labels"]
self.images = loaded["images"]
self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/'
self.imagenet = './datasets/imageNet_images/'
self.image_transform = image_transform
self.num_voxels = 440
# Compute size
@@ -316,7 +317,7 @@ def __init__(self, eeg_signals_path, image_transform=identity, subject = 4):
self.data = loaded['dataset']
self.labels = loaded["labels"]
self.images = loaded["images"]
self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/'
self.imagenet = './datasets/imageNet_images'
self.image_transform = image_transform
self.num_voxels = 440
self.data_len = 512
@@ -386,12 +387,12 @@ def __getitem__(self, i):
return self.dataset[self.split_idx[i]]


def create_EEG_dataset(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std.pth',
splits_path = '../dreamdiffusion/datasets/block_splits_by_image_single.pth',
# splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth',
def create_EEG_dataset(eeg_signals_path='./datasets/eeg_5_95_std.pth',
splits_path = './datasets/block_splits_by_image_single.pth',
# splits_path = './datasets/block_splits_by_image_all.pth',
image_transform=identity, subject = 0):
# if subject == 0:
# splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth'
# splits_path = './datasets/block_splits_by_image_all.pth'
if isinstance(image_transform, list):
dataset_train = EEGDataset(eeg_signals_path, image_transform[0], subject )
dataset_test = EEGDataset(eeg_signals_path, image_transform[1], subject)
@@ -405,9 +406,9 @@ def create_EEG_dataset(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std



def create_EEG_dataset_r(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std.pth',
# splits_path = '../dreamdiffusion/datasets/block_splits_by_image_single.pth',
splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth',
def create_EEG_dataset_r(eeg_signals_path='./datasets/eeg_5_95_std.pth',
# splits_path = './datasets/block_splits_by_image_single.pth',
splits_path = './datasets/block_splits_by_image_all.pth',
image_transform=identity):
if isinstance(image_transform, list):
dataset_train = EEGDataset_r(eeg_signals_path, image_transform[0])
4 changes: 2 additions & 2 deletions code/eeg_ldm.py
@@ -12,7 +12,7 @@
import copy

# own code
from config import Config_Generative_Model
from config import Config_Generative_Model, Config_MBM_EEG
from dataset import create_EEG_dataset
from dc_ldm.ldm_for_eeg import eLDM
from eval_metrics import get_similarity_metric
@@ -178,7 +178,7 @@ def get_args_parser():
parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False)
# project parameters
parser.add_argument('--seed', type=int)
parser.add_argument('--root_path', type=str, default = '../dreamdiffusion/')
parser.add_argument('--root_path', type=str, default = './')
parser.add_argument('--pretrain_mbm_path', type=str)
parser.add_argument('--checkpoint_path', type=str)
parser.add_argument('--crop_ratio', type=float)
11 changes: 6 additions & 5 deletions code/gen_eval_eeg.py
@@ -51,7 +51,7 @@ def __call__(self, img):
def get_args_parser():
parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False)
# project parameters
parser.add_argument('--root', type=str, default='../dreamdiffusion/')
parser.add_argument('--root', type=str, default='./')
parser.add_argument('--dataset', type=str, default='GOD')
parser.add_argument('--model_path', type=str)

@@ -68,9 +68,10 @@ def get_args_parser():
config = sd['config']
# update paths
config.root_path = root
config.pretrain_mbm_path = '../dreamdiffusion/results/eeg_pretrain/19-02-2023-08-48-17/checkpoints/checkpoint.pth'
config.pretrain_gm_path = '../dreamdiffusion/pretrains/'
print(config.__dict__)
config.eeg_signals_path = "./datasets/eeg_5_95_std.pth"
config.splits_path = "./datasets/block_splits_by_image_single.pth"
config.pretrain_mbm_path = "./pretrains/eeg-pretrain/checkpoint-eeg-500.pth"
config.pretrain_gm_path = './pretrains/'

output_path = os.path.join(config.root_path, 'results', 'eval',
'%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")))
@@ -91,7 +92,7 @@ def get_args_parser():
])


splits_path = "../dreamdiffusion/datasets/block_splits_by_image_single.pth"
splits_path = "./datasets/block_splits_by_image_single.pth"
dataset_train, dataset_test = create_EEG_dataset(eeg_signals_path = config.eeg_signals_path, splits_path = splits_path,
image_transform=[img_transform_train, img_transform_test], subject = 4)
num_voxels = dataset_test.dataset.data_len
2 changes: 1 addition & 1 deletion code/sc_mbm/mae_for_eeg.py
@@ -1,5 +1,5 @@
import sys
sys.path.append('../dreamdiffusion/code/')
sys.path.append('./code/')
# print(sys.path)
import sc_mbm.utils as ut
import torch
11 changes: 4 additions & 7 deletions code/stageA1_eeg_pretrain.py → code/stageA1_eeg_pretrain_dd.py
@@ -74,7 +74,6 @@ def get_args_parser():
# Project setting
parser.add_argument('--root_path', type=str)
parser.add_argument('--seed', type=str)
parser.add_argument('--roi', type=str)
parser.add_argument('--aug_times', type=int)
parser.add_argument('--num_sub_limit', type=int)

@@ -109,8 +108,8 @@ def main(config):
torch.distributed.init_process_group(backend='nccl')
output_path = os.path.join(config.root_path, 'results', 'eeg_pretrain', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")))
config.output_path = output_path
# logger = wandb_logger(config) if config.local_rank == 0 else None
logger = None
logger = wandb_logger(config) if config.local_rank == 0 else None
# logger = None

if config.local_rank == 0:
os.makedirs(output_path, exist_ok=True)
@@ -121,9 +120,7 @@
np.random.seed(config.seed)

# create dataset and dataloader
dataset_pretrain = eeg_pretrain_dataset(path='../dreamdiffusion/datasets/mne_data/', roi=config.roi, patch_size=config.patch_size,
transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit,
include_kam=config.include_kam, include_hcp=config.include_hcp)
dataset_pretrain = eeg_pretrain_dataset(path='./datasets/eeg_train/')

print(f'Dataset size: {len(dataset_pretrain)}\n Time len: {dataset_pretrain.data_len}')
sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None
@@ -287,4 +284,4 @@ def update_config(args, config):
config = Config_MBM_EEG()
config = update_config(args, config)
main(config)

