diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b205a8c --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +.DS_Store + +__pycache__/ +dist/ +*.egg-info/ + +datasets/eeg_train/ +datasets/eeg_test/ +datasets/*.pth +datasets/**/*.JPEG +exps/ +exp/ +results/ +wandb/ + +pretrains/models/v1-5-pruned.ckpt +pretrains/eeg-pretrain/*.pth +pretrains/generation/*.pth diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..09b66ad --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "code/TimeMAE"] + path = code/TimeMAE + url = git@github.com:Mingyue-Cheng/TimeMAE.git diff --git a/README.md b/README.md index 08c25a4..af9e02b 100644 --- a/README.md +++ b/README.md @@ -10,15 +10,14 @@ This document introduces the precesedures required for replicating the results i ## Abstract This paper introduces DreamDiffusion, a novel method for generating high-quality images directly from brain electroencephalogram (EEG) signals, without the need to translate thoughts into text. DreamDiffusion leverages pre-trained text-to-image models and employs temporal masked signal modeling to pre-train the EEG encoder for effective and robust EEG representations. Additionally, the method further leverages the CLIP image encoder to provide extra supervision to better align EEG, text, and image embeddings with limited EEG-image pairs. Overall, the proposed method overcomes the challenges of using EEG signals for image generation, such as noise, limited information, and individual differences, and achieves promising results. Quantitative and qualitative results demonstrate the effectiveness of the proposed method as a significant step towards portable and low-cost "thoughts-to-image", with potential applications in neuroscience and computer vision. - ## Overview ![pipeline](assets/eeg_pipeline.png) -The **datasets** folder and **pretrains** folder are not included in this repository. +The **datasets** folder is not included in this repository. 
Neither are the pretrained models' checkpoints. Please download them from [eeg](https://github.com/perceivelab/eeg_visual_classification) and put them in the root directory of this repository as shown below. We also provide a copy of the Imagenet subset [imagenet](https://drive.google.com/file/d/1y7I9bG1zKYqBM94odcox_eQjnP9HGo9-/view?usp=drive_link). -For Stable Diffusion, we just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). +For Stable Diffusion, we just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). You want the file ["v1-5-pruned.ckpt"](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). File path | Description ``` @@ -29,10 +28,10 @@ File path | Description ┃ ┗ 📜 v1-5-pruned.ckpt ┣ 📂 generation -┃ ┗ 📜 checkpoint_best.pth +┃ ┗ 📜 checkpoint.pth ┣ 📂 eeg_pretain -┃ ┗ 📜 checkpoint.pth (pre-trained EEG encoder) +┃ ┗ 📜 checkpoint-eeg-500.pth (pre-trained EEG encoder) /datasets ┣ 📂 imageNet_images (subset of Imagenet) @@ -74,16 +73,24 @@ conda env create -f env.yaml conda activate dreamdiffusion ``` -## Download checkpoints +## Download checkpoint after pre-training on EEG data -We also checkpoints to run the finetuing and decoding directly. +No checkpoint is provided. +To make the code work (poorly), please find a checkpoint after pretraining the EEG Autoencoder in [this repository](https://github.com/alinvdu/reproduce-dream-diffusion/). They propose to use only the EEGs from the (EEG, Image) pairs, but this yields poor results as it is very little data. 
+And put it in ```pretrains/eeg-pretrain/checkpoint-eeg-500.pth``` ## Pre-training on EEG data +Remember to move generated checkpoints from the ```results``` directory to the +appropriate location in the ```pretrains``` directory. + You can download the dataset for pretraining from here [MOABB](https://github.com/NeuroTechX/moabb). +The datasets used to get the paper's results aren't specified, but you can use +any EEG dataset; put the files in ```datasets/eeg_train/``` and ```datasets/eeg_test/``` in a `.npy` format. + To perform the pre-training from scratch with defaults parameters, run ```sh python3 code/stageA1_eeg_pretrain.py @@ -101,18 +108,20 @@ Multiple-GPU (DDP) training is supported, run with python -m torch.distributed.launch --nproc_per_node=NUM_GPUS code/stageA1_eeg_pretrain.py ``` +## Finetune the Stable Diffusion with Pre-trained EEG Encoder +Remember to move generated checkpoints from the ```results``` directory to the +appropriate location in the ```pretrains``` directory. -## Finetune the Stable Diffusion with Pre-trained EEG Encoder In this stage, the cross-attention heads and pre-trained EEG encoder will be jointly optimized with EEG-image pairs. ```sh -python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ../dreamdiffusion/pretrains/eeg_pretrain/checkpoint.pth +python3 code/eeg_ldm.py --dataset EEG --num_epoch 300 --batch_size 4 --pretrain_mbm_path ./pretrains/eeg_pretrain/checkpoint-eeg-500.pth ``` - ## Generating Images with Trained Checkpoints Run this stage with our provided checkpoints: Here we provide a checkpoint [ckpt](https://drive.google.com/file/d/1Ygplxe1TB68-aYu082bjc89nD8Ngklnc/view?usp=drive_link), which you may want to try. 
+ ```sh python3 code/gen_eval_eeg.py --dataset EEG --model_path pretrains/generation/checkpoint.pth ``` diff --git a/code/TimeMAE b/code/TimeMAE new file mode 160000 index 0000000..f4401b5 --- /dev/null +++ b/code/TimeMAE @@ -0,0 +1 @@ +Subproject commit f4401b5a3b80056a385078af0caa1a58a04115e4 diff --git a/code/config.py b/code/config.py index 493c884..f1d91b1 100644 --- a/code/config.py +++ b/code/config.py @@ -16,25 +16,24 @@ def __init__(self): self.min_lr = 0. self.weight_decay = 0.05 self.num_epoch = 500 - self.warmup_epochs = 40 - self.batch_size = 100 + self.warmup_epochs = 10 + self.batch_size = 128 self.clip_grad = 0.8 # Model Parameters - self.mask_ratio = 0.1 + self.mask_ratio = 0.5 self.patch_size = 4 # 1 self.embed_dim = 1024 #256 # has to be a multiple of num_heads self.decoder_embed_dim = 512 #128 self.depth = 24 self.num_heads = 16 self.decoder_num_heads = 16 - self.mlp_ratio = 1.0 + self.mlp_ratio = 0.8 # Project setting - self.root_path = '../dreamdiffusion/' - self.output_path = '../dreamdiffusion/exps/' + self.root_path = './' + self.output_path = './exps/' self.seed = 2022 - self.roi = 'VC' self.aug_times = 1 self.num_sub_limit = None self.include_hcp = True @@ -54,15 +53,15 @@ class Config_EEG_finetune(Config_MBM_finetune): def __init__(self): # Project setting - self.root_path = '../dreamdiffusion/' + self.root_path = './' # self.root_path = '.' 
- self.output_path = '../dreamdiffusion/exps/' + self.output_path = './exps/' self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth') self.splits_path = os.path.join(self.root_path, 'datasets/block_splits_by_image_all.pth') self.dataset = 'EEG' - self.pretrain_mbm_path = '../dreamdiffusion/pretrains/eeg_pretrain/checkpoint.pth' + self.pretrain_mbm_path = './pretrains/eeg_pretrain/checkpoint.pth' self.include_nonavg_test = True @@ -71,7 +70,7 @@ def __init__(self): self.lr = 5.3e-5 self.weight_decay = 0.05 self.num_epoch = 15 - self.batch_size = 16 if self.dataset == 'GOD' else 4 + self.batch_size = 32 if self.dataset == 'GOD' else 16 self.mask_ratio = 0.5 self.accum_iter = 1 self.clip_grad = 0.8 @@ -89,8 +88,8 @@ class Config_Generative_Model: def __init__(self): # project parameters self.seed = 2022 - self.root_path = '../dreamdiffusion/' - self.output_path = '../dreamdiffusion/exps/' + self.root_path = './' + self.output_path = './exps/' self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth') self.splits_path = os.path.join(self.root_path, 'datasets/block_splits_by_image_single.pth') @@ -111,7 +110,7 @@ def __init__(self): np.random.seed(self.seed) # finetune parameters - self.batch_size = 5 if self.dataset == 'GOD' else 25 + self.batch_size = 32 if self.dataset == 'GOD' else 16 self.lr = 5.3e-5 self.num_epoch = 500 @@ -139,8 +138,8 @@ class Config_Cls_Model: def __init__(self): # project parameters self.seed = 2022 - self.root_path = '../dreamdiffusion/' - self.output_path = '../dreamdiffusion/exps/' + self.root_path = './' + self.output_path = './exps/' # self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_5_95_std.pth') self.eeg_signals_path = os.path.join(self.root_path, 'datasets/eeg_14_70_std.pth') @@ -162,7 +161,7 @@ def __init__(self): np.random.seed(self.seed) # finetune parameters - self.batch_size = 5 if self.dataset == 'GOD' else 25 + self.batch_size = 32 if self.dataset == 'GOD' else 
16 self.lr = 5.3e-5 self.num_epoch = 50 @@ -181,4 +180,4 @@ def __init__(self): self.HW = None # resume check util self.model_meta = None - self.checkpoint_path = None \ No newline at end of file + self.checkpoint_path = None diff --git a/code/dataset.py b/code/dataset.py index 3c8957b..b72665d 100644 --- a/code/dataset.py +++ b/code/dataset.py @@ -106,24 +106,25 @@ def is_npy_ext(fname: Union[str, Path]) -> bool: return f'{ext}' == 'npy'# type: ignore class eeg_pretrain_dataset(Dataset): - def __init__(self, path='../dreamdiffusion/datasets/mne_data/', roi='VC', patch_size=16, transform=identity, aug_times=2, - num_sub_limit=None, include_kam=False, include_hcp=True): + def __init__(self, path): super(eeg_pretrain_dataset, self).__init__() data = [] images = [] self.input_paths = [str(f) for f in sorted(Path(path).rglob('*')) if is_npy_ext(f) and os.path.isfile(f)] assert len(self.input_paths) != 0, 'No data found' + + self.inputs = [np.load(data_path) for data_path in self.input_paths] + self.real_input = np.concatenate(self.inputs, axis=0) + self.data_len = 512 self.data_chan = 128 def __len__(self): - return len(self.input_paths) + return len(self.real_input) def __getitem__(self, index): - data_path = self.input_paths[index] - - data = np.load(data_path) + data = self.real_input[index] if data.shape[-1] > self.data_len: idx = np.random.randint(0, int(data.shape[-1] - self.data_len)+1) @@ -240,7 +241,7 @@ class EEGDataset_r(Dataset): # Constructor def __init__(self, eeg_signals_path, image_transform=identity): - self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/' + self.imagenet = './datasets/imageNet_2012/train/' self.image_transform = image_transform self.num_voxels = 440 self.data_len = 512 @@ -276,7 +277,7 @@ def __init__(self, eeg_signals_path, image_transform=identity): self.data = loaded['dataset'] self.labels = loaded["labels"] self.images = loaded["images"] - self.imagenet = 
'/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/' + self.imagenet = './datasets/imageNet_images/' self.image_transform = image_transform self.num_voxels = 440 # Compute size @@ -316,7 +317,7 @@ def __init__(self, eeg_signals_path, image_transform=identity, subject = 4): self.data = loaded['dataset'] self.labels = loaded["labels"] self.images = loaded["images"] - self.imagenet = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train/' + self.imagenet = './datasets/imageNet_images' self.image_transform = image_transform self.num_voxels = 440 self.data_len = 512 @@ -386,12 +387,12 @@ def __getitem__(self, i): return self.dataset[self.split_idx[i]] -def create_EEG_dataset(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std.pth', - splits_path = '../dreamdiffusion/datasets/block_splits_by_image_single.pth', - # splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth', +def create_EEG_dataset(eeg_signals_path='./datasets/eeg_5_95_std.pth', + splits_path = './datasets/block_splits_by_image_single.pth', + # splits_path = './datasets/block_splits_by_image_all.pth', image_transform=identity, subject = 0): # if subject == 0: - # splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth' + # splits_path = './datasets/block_splits_by_image_all.pth' if isinstance(image_transform, list): dataset_train = EEGDataset(eeg_signals_path, image_transform[0], subject ) dataset_test = EEGDataset(eeg_signals_path, image_transform[1], subject) @@ -405,9 +406,9 @@ def create_EEG_dataset(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std -def create_EEG_dataset_r(eeg_signals_path='../dreamdiffusion/datasets/eeg_5_95_std.pth', - # splits_path = '../dreamdiffusion/datasets/block_splits_by_image_single.pth', - splits_path = '../dreamdiffusion/datasets/block_splits_by_image_all.pth', +def create_EEG_dataset_r(eeg_signals_path='./datasets/eeg_5_95_std.pth', + # splits_path = './datasets/block_splits_by_image_single.pth', + 
splits_path = './datasets/block_splits_by_image_all.pth', image_transform=identity): if isinstance(image_transform, list): dataset_train = EEGDataset_r(eeg_signals_path, image_transform[0]) diff --git a/code/eeg_ldm.py b/code/eeg_ldm.py index fddd98a..4aa1c30 100644 --- a/code/eeg_ldm.py +++ b/code/eeg_ldm.py @@ -12,7 +12,7 @@ import copy # own code -from config import Config_Generative_Model +from config import Config_Generative_Model, Config_MBM_EEG from dataset import create_EEG_dataset from dc_ldm.ldm_for_eeg import eLDM from eval_metrics import get_similarity_metric @@ -178,7 +178,7 @@ def get_args_parser(): parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False) # project parameters parser.add_argument('--seed', type=int) - parser.add_argument('--root_path', type=str, default = '../dreamdiffusion/') + parser.add_argument('--root_path', type=str, default = './') parser.add_argument('--pretrain_mbm_path', type=str) parser.add_argument('--checkpoint_path', type=str) parser.add_argument('--crop_ratio', type=float) diff --git a/code/gen_eval_eeg.py b/code/gen_eval_eeg.py index a079f51..c9e1720 100644 --- a/code/gen_eval_eeg.py +++ b/code/gen_eval_eeg.py @@ -51,7 +51,7 @@ def __call__(self, img): def get_args_parser(): parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False) # project parameters - parser.add_argument('--root', type=str, default='../dreamdiffusion/') + parser.add_argument('--root', type=str, default='./') parser.add_argument('--dataset', type=str, default='GOD') parser.add_argument('--model_path', type=str) @@ -68,9 +68,10 @@ def get_args_parser(): config = sd['config'] # update paths config.root_path = root - config.pretrain_mbm_path = '../dreamdiffusion/results/eeg_pretrain/19-02-2023-08-48-17/checkpoints/checkpoint.pth' - config.pretrain_gm_path = '../dreamdiffusion/pretrains/' - print(config.__dict__) + config.eeg_signals_path = "./datasets/eeg_5_95_std.pth" + config.splits_path = 
"./datasets/block_splits_by_image_single.pth" + config.pretrain_mbm_path = "./pretrains/eeg-pretrain/checkpoint-eeg-500.pth" + config.pretrain_gm_path = './pretrains/' output_path = os.path.join(config.root_path, 'results', 'eval', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) @@ -91,7 +92,7 @@ def get_args_parser(): ]) - splits_path = "../dreamdiffusion/datasets/block_splits_by_image_single.pth" + splits_path = "./datasets/block_splits_by_image_single.pth" dataset_train, dataset_test = create_EEG_dataset(eeg_signals_path = config.eeg_signals_path, splits_path = splits_path, image_transform=[img_transform_train, img_transform_test], subject = 4) num_voxels = dataset_test.dataset.data_len diff --git a/code/sc_mbm/mae_for_eeg.py b/code/sc_mbm/mae_for_eeg.py index 90ed837..d4b14ab 100644 --- a/code/sc_mbm/mae_for_eeg.py +++ b/code/sc_mbm/mae_for_eeg.py @@ -1,5 +1,5 @@ import sys -sys.path.append('../dreamdiffusion/code/') +sys.path.append('./code/') # print(sys.path) import sc_mbm.utils as ut import torch diff --git a/code/stageA1_eeg_pretrain.py b/code/stageA1_eeg_pretrain_dd.py similarity index 96% rename from code/stageA1_eeg_pretrain.py rename to code/stageA1_eeg_pretrain_dd.py index 7c1d8e6..4913774 100644 --- a/code/stageA1_eeg_pretrain.py +++ b/code/stageA1_eeg_pretrain_dd.py @@ -74,7 +74,6 @@ def get_args_parser(): # Project setting parser.add_argument('--root_path', type=str) parser.add_argument('--seed', type=str) - parser.add_argument('--roi', type=str) parser.add_argument('--aug_times', type=int) parser.add_argument('--num_sub_limit', type=int) @@ -109,8 +108,8 @@ def main(config): torch.distributed.init_process_group(backend='nccl') output_path = os.path.join(config.root_path, 'results', 'eeg_pretrain', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) config.output_path = output_path - # logger = wandb_logger(config) if config.local_rank == 0 else None - logger = None + logger = wandb_logger(config) if config.local_rank == 0 
else None + # logger = None if config.local_rank == 0: os.makedirs(output_path, exist_ok=True) @@ -121,9 +120,7 @@ def main(config): np.random.seed(config.seed) # create dataset and dataloader - dataset_pretrain = eeg_pretrain_dataset(path='../dreamdiffusion/datasets/mne_data/', roi=config.roi, patch_size=config.patch_size, - transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, - include_kam=config.include_kam, include_hcp=config.include_hcp) + dataset_pretrain = eeg_pretrain_dataset(path='./datasets/eeg_train/') print(f'Dataset size: {len(dataset_pretrain)}\n Time len: {dataset_pretrain.data_len}') sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None @@ -287,4 +284,4 @@ def update_config(args, config): config = Config_MBM_EEG() config = update_config(args, config) main(config) - \ No newline at end of file + diff --git a/code/stageA1_eeg_pretrain_timeMAE.py b/code/stageA1_eeg_pretrain_timeMAE.py new file mode 100644 index 0000000..d83a26e --- /dev/null +++ b/code/stageA1_eeg_pretrain_timeMAE.py @@ -0,0 +1,300 @@ +import os, sys +import numpy as np +import torch +from torch.utils.data import DataLoader +from torch.nn.parallel import DistributedDataParallel +import argparse +import time +import timm.optim.optim_factory as optim_factory +import datetime +import matplotlib.pyplot as plt +import wandb +import copy + +from config import Config_MBM_EEG +from dataset import eeg_pretrain_dataset +from sc_mbm.mae_for_eeg import MAEforEEG +from sc_mbm.trainer import train_one_epoch +from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler +from sc_mbm.utils import save_model + +from TimeMAE.model.TimeMAE import TimeMAE + +os.environ["WANDB_START_METHOD"] = "thread" +os.environ['WANDB_DIR'] = "." 
+ +class wandb_logger: + def __init__(self, config): + wandb.init( + project="dreamdiffusion", + anonymous="allow", + group='stageA_sc-mbm', + config=config, + reinit=True) + + self.config = config + self.step = None + + def log(self, name, data, step=None): + if step is None: + wandb.log({name: data}) + else: + wandb.log({name: data}, step=step) + self.step = step + + def watch_model(self, *args, **kwargs): + wandb.watch(*args, **kwargs) + + def log_image(self, name, fig): + if self.step is None: + wandb.log({name: wandb.Image(fig)}) + else: + wandb.log({name: wandb.Image(fig)}, step=self.step) + + def finish(self): + wandb.finish(quiet=True) + +def get_args_parser(): + parser = argparse.ArgumentParser('MBM pre-training for fMRI', add_help=False) + + # Training Parameters + parser.add_argument('--lr', type=float) + parser.add_argument('--weight_decay', type=float) + parser.add_argument('--num_epoch', type=int) + parser.add_argument('--batch_size', type=int) + parser.add_argument('--momentum', type=float) + + # Model Parameters + parser.add_argument('--mask_ratio', type=float) + parser.add_argument('--patch_size', type=int) + parser.add_argument('--embed_dim', type=int) + parser.add_argument('--decoder_embed_dim', type=int) + parser.add_argument('--depth', type=int) + parser.add_argument('--num_heads', type=int) + parser.add_argument('--decoder_num_heads', type=int) + parser.add_argument('--mlp_ratio', type=float) + parser.add_argument('--wave_length', type=int) + parser.add_argument('--attn_heads', type=int) + parser.add_argument('--layers', type=int) + parser.add_argument('--dropout', type=float) + parser.add_argument('--enable_res_parameter', type=int) + parser.add_argument('--vocab_size', type=int) + parser.add_argument('--reg_layers', type=int) + parser.add_argument('--num_class', type=int) + + # Project setting + parser.add_argument('--root_path', type=str) + parser.add_argument('--seed', type=str) + parser.add_argument('--roi', type=str) + 
parser.add_argument('--aug_times', type=int) + parser.add_argument('--num_sub_limit', type=int) + + parser.add_argument('--include_hcp', type=bool) + parser.add_argument('--include_kam', type=bool) + + parser.add_argument('--use_nature_img_loss', type=bool) + parser.add_argument('--img_recon_weight', type=float) + + # distributed training parameters + parser.add_argument('--local_rank', type=int) + + return parser + +def create_readme(config, path): + print(config.__dict__) + with open(os.path.join(path, 'README.md'), 'w+') as f: + print(config.__dict__, file=f) + +def fmri_transform(x, sparse_rate=0.2): + # x: 1, num_voxels + x_aug = copy.deepcopy(x) + idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False) + x_aug[idx] = 0 + return torch.FloatTensor(x_aug) + +def main(config): + # print('num of gpu:') + # print(torch.cuda.device_count()) + if torch.cuda.device_count() > 1: + torch.cuda.set_device(config.local_rank) + torch.distributed.init_process_group(backend='nccl') + output_path = os.path.join(config.root_path, 'results', 'eeg_pretrain', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) + config.output_path = output_path + logger = wandb_logger(config) if config.local_rank == 0 else None + # logger = None + + if config.local_rank == 0: + os.makedirs(output_path, exist_ok=True) + create_readme(config, output_path) + + device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu') + torch.manual_seed(config.seed) + np.random.seed(config.seed) + + # create dataset and dataloader + dataset_pretrain = eeg_pretrain_dataset(path='./datasets/eeg_test/') + + print(f'Dataset size: {len(dataset_pretrain)}\n Time len: {dataset_pretrain.data_len}') + sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None + + dataloader_eeg = DataLoader(dataset_pretrain, batch_size=config.batch_size, sampler=sampler, + shuffle=(sampler is 
None), pin_memory=True) + + # create model + config.time_len = dataset_pretrain.data_len + config.d_model = config.embed_dim + config.device = device + config.data_shape = dataset_pretrain.__getitem__(0)['eeg'].transpose(0,1).shape + print(f"Input data shape: {config.data_shape}") + + model = TimeMAE(config) + + model.to(device) + model_without_ddp = model + if torch.cuda.device_count() > 1: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank, find_unused_parameters=config.use_nature_img_loss) + + param_groups = optim_factory.add_weight_decay(model, config.weight_decay) + optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(0.9, 0.95)) + print(optimizer) + loss_scaler = NativeScaler() + + if logger is not None: + logger.watch_model(model,log='all', log_freq=1000) + + cor_list = [] + start_time = time.time() + print('Start Training the EEG MAE ... ...') + img_feature_extractor = None + preprocess = None + if config.use_nature_img_loss: + from torchvision.models import resnet50, ResNet50_Weights + from torchvision.models.feature_extraction import create_feature_extractor + weights = ResNet50_Weights.DEFAULT + preprocess = weights.transforms() + m = resnet50(weights=weights) + img_feature_extractor = create_feature_extractor(m, return_nodes={f'layer2': 'layer2'}).to(device).eval() + for param in img_feature_extractor.parameters(): + param.requires_grad = False + + for ep in range(config.num_epoch): + + if torch.cuda.device_count() > 1: + sampler.set_epoch(ep) # to shuffle the data at every epoch + cor = train_one_epoch(model, dataloader_eeg, optimizer, device, ep, loss_scaler, logger, config, start_time, model_without_ddp, + img_feature_extractor, preprocess) + cor_list.append(cor) + if (ep % 20 == 0 or ep + 1 == config.num_epoch) and config.local_rank == 0: #and ep != 0 + # save models + # if True: + save_model(config, ep, model_without_ddp, 
optimizer, loss_scaler, os.path.join(output_path,'checkpoints')) + # plot figures + # plot_recon_figures(model, device, dataset_pretrain, output_path, 5, config, logger, model_without_ddp) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + if logger is not None: + logger.log('max cor', np.max(cor_list), step=config.num_epoch-1) + logger.finish() + return + +@torch.no_grad() +def plot_recon_figures(model, device, dataset, output_path, num_figures = 5, config=None, logger=None, model_without_ddp=None): + dataloader = DataLoader(dataset, batch_size=1, shuffle=True) + model.eval() + fig, axs = plt.subplots(num_figures, 3, figsize=(30,15)) + fig.tight_layout() + axs[0,0].set_title('Ground-truth') + axs[0,1].set_title('Masked Ground-truth') + axs[0,2].set_title('Reconstruction') + + for ax in axs: + sample = next(iter(dataloader))['eeg'] + sample = sample.to(device) + _, pred, mask = model(sample, mask_ratio=config.mask_ratio) + # sample_with_mask = model_without_ddp.patchify(sample.transpose(1,2))[0].to('cpu').numpy().reshape(-1, model_without_ddp.patch_size) + sample_with_mask = sample.to('cpu').squeeze(0)[0].numpy().reshape(-1, model_without_ddp.patch_size) + # pred = model_without_ddp.unpatchify(pred.transpose(1,2)).to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + # sample = sample.to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + pred = model_without_ddp.unpatchify(pred).to('cpu').squeeze(0)[0].numpy() + # pred = model_without_ddp.unpatchify(model_without_ddp.patchify(sample.transpose(1,2))).to('cpu').squeeze(0)[0].numpy() + sample = sample.to('cpu').squeeze(0)[0].numpy() + mask = mask.to('cpu').numpy().reshape(-1) + + cor = np.corrcoef([pred, sample])[0,1] + + x_axis = np.arange(0, sample.shape[-1]) + # groundtruth + ax[0].plot(x_axis, sample) + # groundtruth with mask + s = 0 + for x, m in zip(sample_with_mask,mask): + if m == 0: + 
ax[1].plot(x_axis[s:s+len(x)], x, color='#1f77b4') + s += len(x) + # pred + ax[2].plot(x_axis, pred) + ax[2].set_ylabel('cor: %.4f'%cor, weight = 'bold') + ax[2].yaxis.set_label_position("right") + + fig_name = 'reconst-%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")) + fig.savefig(os.path.join(output_path, f'{fig_name}.png')) + if logger is not None: + logger.log_image('reconst', fig) + plt.close(fig) + + +@torch.no_grad() +def plot_recon_figures2(model, device, dataset, output_path, num_figures = 5, config=None, logger=None, model_without_ddp=None): + dataloader = DataLoader(dataset, batch_size=1, shuffle=True) + model.eval() + fig, axs = plt.subplots(num_figures, 2, figsize=(20,15)) + fig.tight_layout() + axs[0,0].set_title('Ground-truth') + # axs[0,1].set_title('Masked Ground-truth') + axs[0,1].set_title('Reconstruction') + + for ax in axs: + sample = next(iter(dataloader))['eeg'] + sample = sample.to(device) + _, pred, mask = model(sample, mask_ratio=config.mask_ratio) + # sample_with_mask = model_without_ddp.patchify(sample.transpose(1,2))[0].to('cpu').numpy().reshape(-1, model_without_ddp.patch_size) + sample_with_mask = sample.to('cpu').squeeze(0)[0].numpy().reshape(-1, model_without_ddp.patch_size) + # pred = model_without_ddp.unpatchify(pred.transpose(1,2)).to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + # sample = sample.to('cpu').squeeze(0)[0].unsqueeze(0).numpy() + pred = model_without_ddp.unpatchify(pred).to('cpu').squeeze(0)[0].numpy() + # pred = model_without_ddp.unpatchify(model_without_ddp.patchify(sample.transpose(1,2))).to('cpu').squeeze(0)[0].numpy() + sample = sample.to('cpu').squeeze(0)[0].numpy() + cor = np.corrcoef([pred, sample])[0,1] + + x_axis = np.arange(0, sample.shape[-1]) + # groundtruth + ax[0].plot(x_axis, sample) + + ax[1].plot(x_axis, pred) + ax[1].set_ylabel('cor: %.4f'%cor, weight = 'bold') + ax[1].yaxis.set_label_position("right") + + fig_name = 'reconst-%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")) + 
fig.savefig(os.path.join(output_path, f'{fig_name}.png')) + if logger is not None: + logger.log_image('reconst', fig) + plt.close(fig) + + +def update_config(args, config): + for attr in config.__dict__: + if hasattr(args, attr): + if getattr(args, attr) != None: + setattr(config, attr, getattr(args, attr)) + return config + + +if __name__ == '__main__': + args = get_args_parser() + args = args.parse_args() + config = Config_MBM_EEG() + config = update_config(args, config) + main(config) \ No newline at end of file diff --git a/datasets/dl_moabb_datasets.py b/datasets/dl_moabb_datasets.py new file mode 100644 index 0000000..e79c081 --- /dev/null +++ b/datasets/dl_moabb_datasets.py @@ -0,0 +1,72 @@ +import warnings +from urllib3.exceptions import InsecureRequestWarning +from mne.io import BaseRaw +import os + +def find_raw_objects(data, types=(BaseRaw), path=""): + found = [] + + if isinstance(data, dict): + for key, value in data.items(): + new_path = f"{path}.{key}" if path else key + found.extend(find_raw_objects(value, types, new_path)) + + elif isinstance(data, list): + for idx, item in enumerate(data): + new_path = f"{path}[{idx}]" + found.extend(find_raw_objects(item, types, new_path)) + + elif isinstance(data, types): + found.append((path, data)) + + return found + +print() +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=InsecureRequestWarning) +print("WARNING : All 'FutureWarning' and 'InsecureRequestWarning' have been disabled") + +import numpy as np +from moabb.datasets import * + +datasets = [Schirrmeister2017(), GrosseWentrup2009()] +sample_size = 490 + +def dl_dataset(dataset, subjects, outpath): + try: + print("\n#####") + print(dataset) + print("#####\n") + data = dataset.get_data(subjects=subjects) + for raw_path, raw_data in find_raw_objects(data): + print() + npdata = raw_data.pick_types(eeg=True).get_data() + print(f"Shape of data: {npdata.shape}") + + print(f"Splitting into samples of size 
{sample_size}") + length = npdata.shape[1] + samples = [] + i=0 + while (i + 1) * sample_size < length: + sample = npdata[:, i * sample_size : (i+1) * sample_size] + samples.append(sample) + i += 1 + + res = np.array(samples) + print(f"New shape: {res.shape}") + + + whole_path = outpath+'/'+dataset.code+"/"+raw_path+'.npy' + print(f"Saving in {whole_path}") + + os.makedirs(os.path.dirname(whole_path), exist_ok=True) + np.save(whole_path, res) + print() + + except Exception as e: + print(e) + +for dataset in datasets: + subjects = dataset.subject_list + dl_dataset(dataset, subjects[:1], "eeg_train/") + dl_dataset(dataset, subjects[1:2], "eeg_test/") diff --git a/datasets/min_size_sample.py b/datasets/min_size_sample.py new file mode 100644 index 0000000..11a0950 --- /dev/null +++ b/datasets/min_size_sample.py @@ -0,0 +1,12 @@ +import torch +import numpy as np + +data = torch.load("eeg_5_95_std.pth") + +print(f"data: dict with keys {data.keys()}") +print(f"data['dataset']: list of length {len(data['dataset'])}") +print(f"data['dataset'][0]['eeg']: {data['dataset'][0]['eeg'].shape}") + +res1 = [sample['eeg'] for sample in data['dataset']] +print() +print(f"Smallest sample has size {min(arr.shape[1] for arr in res1)}") diff --git a/env.yaml b/env.yaml index ebfecab..fd123ce 100644 --- a/env.yaml +++ b/env.yaml @@ -4,9 +4,10 @@ channels: - defaults dependencies: - python=3.8.5 - - pip>=20.3 + - pip>=20.3,<24.1 - pip: - - numpy + - numpy==1.19.5 + - transformers==4.36.2 - matplotlib - natsort - kornia diff --git a/env_moabb.yaml b/env_moabb.yaml new file mode 100644 index 0000000..5971711 --- /dev/null +++ b/env_moabb.yaml @@ -0,0 +1,18 @@ +name: dreamdiffusion-moabb +channels: + - conda-forge + - defaults +dependencies: + - python=3.10 + - pip + - numpy + - scipy + - scikit-learn + - matplotlib + - pandas + - seaborn + - mne + - joblib + - pyxdf + - pip: + - moabb